Build uploaded using `kernels` (batch 1/10).
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- build/torch210-cxx11-cu126-aarch64-linux/__init__.py +684 -0
- build/torch210-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so +3 -0
- build/torch210-cxx11-cu126-aarch64-linux/_ops.py +9 -0
- build/torch210-cxx11-cu126-aarch64-linux/deep_gemm/__init__.py +26 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/common/cute_tie.cuh +48 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/common/epilogue_utils.cuh +27 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/common/reduction.cuh +44 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/common/scheduler.cuh +288 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/common/sm100_utils.cuh +266 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/common/sm90_utils.cuh +332 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/common/tma_utils.cuh +116 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/common/types.hpp +41 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/common/utils.cuh +183 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh +482 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh +265 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +563 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh +404 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh +398 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh +345 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh +381 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh +174 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh +349 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +440 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh +329 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh +413 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh +287 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh +67 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/smxx_layout.cuh +176 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/options.h +121 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/register_layout.h +59 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/visualize_layout.h +383 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h +719 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_gemm_run.h +763 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h +450 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h +749 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h +798 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_gemm.h +352 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h +300 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h +811 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h +157 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h +521 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h +94 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h +749 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h +740 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h +817 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h +804 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h +503 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h +384 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/grouped.h +168 -0
.gitattributes
CHANGED
|
@@ -11,3 +11,4 @@ build/torch29-cxx11-cu126-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lf
|
|
| 11 |
build/torch29-cxx11-cu128-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 12 |
build/torch29-cxx11-cu129-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 13 |
build/torch29-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 11 |
build/torch29-cxx11-cu128-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 12 |
build/torch29-cxx11-cu129-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 13 |
build/torch29-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
build/torch210-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
build/torch210-cxx11-cu126-aarch64-linux/__init__.py
ADDED
|
@@ -0,0 +1,684 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
# Import the compiled extension
|
| 6 |
+
from ._ops import ops
|
| 7 |
+
from . import utils
|
| 8 |
+
|
| 9 |
+
__version__ = "2.3.0"
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Runtime
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def set_num_sms(num_sms: int):
|
| 16 |
+
ops.set_num_sms(num_sms)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_num_sms() -> int:
|
| 20 |
+
return ops.get_num_sms()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def set_tc_util(tc_util: int):
|
| 24 |
+
ops.set_tc_util(tc_util)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_tc_util() -> int:
|
| 28 |
+
return ops.get_tc_util()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_mk_alignment_for_contiguous_layout() -> int:
|
| 32 |
+
return ops.get_mk_alignment_for_contiguous_layout()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# Layout utilities
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_tma_aligned_size(mn: int, element_size: int) -> int:
|
| 39 |
+
return ops.get_tma_aligned_size(mn, element_size).item()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_mn_major_tma_aligned_tensor(sf):
|
| 43 |
+
return ops.get_mn_major_tma_aligned_tensor(sf)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf):
|
| 47 |
+
return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks):
|
| 51 |
+
ks_int = torch.tensor(ks, dtype=torch.int32, device="cpu")
|
| 52 |
+
return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
|
| 53 |
+
sf, ks_tensor, ks_int
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def transform_sf_into_required_layout(
|
| 58 |
+
sf,
|
| 59 |
+
mn,
|
| 60 |
+
k,
|
| 61 |
+
recipe=None,
|
| 62 |
+
recipe_ab=None,
|
| 63 |
+
num_groups=None,
|
| 64 |
+
is_sfa=False,
|
| 65 |
+
disable_ue8m0_cast=False,
|
| 66 |
+
):
|
| 67 |
+
has_recipe = recipe is not None
|
| 68 |
+
r0, r1, r2 = recipe if has_recipe else (0, 0, 0)
|
| 69 |
+
has_recipe_ab = recipe_ab is not None
|
| 70 |
+
rab0, rab1 = recipe_ab if has_recipe_ab else (0, 0)
|
| 71 |
+
has_ng = num_groups is not None
|
| 72 |
+
ng = num_groups if has_ng else 0
|
| 73 |
+
return ops.transform_sf_into_required_layout(
|
| 74 |
+
sf,
|
| 75 |
+
mn,
|
| 76 |
+
k,
|
| 77 |
+
r0,
|
| 78 |
+
r1,
|
| 79 |
+
r2,
|
| 80 |
+
has_recipe,
|
| 81 |
+
rab0,
|
| 82 |
+
rab1,
|
| 83 |
+
has_recipe_ab,
|
| 84 |
+
ng,
|
| 85 |
+
has_ng,
|
| 86 |
+
is_sfa,
|
| 87 |
+
disable_ue8m0_cast,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# Aliases for contiguous layout alignment
|
| 92 |
+
get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
|
| 93 |
+
get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# Helper to flatten recipe args
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _flatten_recipe(recipe, recipe_a=None, recipe_b=None):
|
| 100 |
+
has_recipe = recipe is not None
|
| 101 |
+
r0, r1, r2 = recipe if has_recipe else (0, 0, 0)
|
| 102 |
+
has_ra = recipe_a is not None
|
| 103 |
+
ra0, ra1 = recipe_a if has_ra else (0, 0)
|
| 104 |
+
has_rb = recipe_b is not None
|
| 105 |
+
rb0, rb1 = recipe_b if has_rb else (0, 0)
|
| 106 |
+
return r0, r1, r2, has_recipe, ra0, ra1, has_ra, rb0, rb1, has_rb
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# FP8/FP4 GEMM ops
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def fp8_fp4_gemm_nt(
|
| 113 |
+
a,
|
| 114 |
+
b,
|
| 115 |
+
d,
|
| 116 |
+
c=None,
|
| 117 |
+
recipe=None,
|
| 118 |
+
recipe_a=None,
|
| 119 |
+
recipe_b=None,
|
| 120 |
+
compiled_dims="nk",
|
| 121 |
+
disable_ue8m0_cast=False,
|
| 122 |
+
):
|
| 123 |
+
a_data, a_sf = a
|
| 124 |
+
b_data, b_sf = b
|
| 125 |
+
r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe(
|
| 126 |
+
recipe, recipe_a, recipe_b
|
| 127 |
+
)
|
| 128 |
+
ops.fp8_fp4_gemm_nt(
|
| 129 |
+
a_data,
|
| 130 |
+
a_sf,
|
| 131 |
+
b_data,
|
| 132 |
+
b_sf,
|
| 133 |
+
d,
|
| 134 |
+
c,
|
| 135 |
+
r0,
|
| 136 |
+
r1,
|
| 137 |
+
r2,
|
| 138 |
+
hr,
|
| 139 |
+
ra0,
|
| 140 |
+
ra1,
|
| 141 |
+
hra,
|
| 142 |
+
rb0,
|
| 143 |
+
rb1,
|
| 144 |
+
hrb,
|
| 145 |
+
compiled_dims,
|
| 146 |
+
disable_ue8m0_cast,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def fp8_fp4_gemm_nn(
|
| 151 |
+
a,
|
| 152 |
+
b,
|
| 153 |
+
d,
|
| 154 |
+
c=None,
|
| 155 |
+
recipe=None,
|
| 156 |
+
recipe_a=None,
|
| 157 |
+
recipe_b=None,
|
| 158 |
+
compiled_dims="nk",
|
| 159 |
+
disable_ue8m0_cast=False,
|
| 160 |
+
):
|
| 161 |
+
a_data, a_sf = a
|
| 162 |
+
b_data, b_sf = b
|
| 163 |
+
r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe(
|
| 164 |
+
recipe, recipe_a, recipe_b
|
| 165 |
+
)
|
| 166 |
+
ops.fp8_fp4_gemm_nn(
|
| 167 |
+
a_data,
|
| 168 |
+
a_sf,
|
| 169 |
+
b_data,
|
| 170 |
+
b_sf,
|
| 171 |
+
d,
|
| 172 |
+
c,
|
| 173 |
+
r0,
|
| 174 |
+
r1,
|
| 175 |
+
r2,
|
| 176 |
+
hr,
|
| 177 |
+
ra0,
|
| 178 |
+
ra1,
|
| 179 |
+
hra,
|
| 180 |
+
rb0,
|
| 181 |
+
rb1,
|
| 182 |
+
hrb,
|
| 183 |
+
compiled_dims,
|
| 184 |
+
disable_ue8m0_cast,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def fp8_fp4_gemm_tn(
|
| 189 |
+
a,
|
| 190 |
+
b,
|
| 191 |
+
d,
|
| 192 |
+
c=None,
|
| 193 |
+
recipe=None,
|
| 194 |
+
recipe_a=None,
|
| 195 |
+
recipe_b=None,
|
| 196 |
+
compiled_dims="mn",
|
| 197 |
+
disable_ue8m0_cast=False,
|
| 198 |
+
):
|
| 199 |
+
a_data, a_sf = a
|
| 200 |
+
b_data, b_sf = b
|
| 201 |
+
r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe(
|
| 202 |
+
recipe, recipe_a, recipe_b
|
| 203 |
+
)
|
| 204 |
+
ops.fp8_fp4_gemm_tn(
|
| 205 |
+
a_data,
|
| 206 |
+
a_sf,
|
| 207 |
+
b_data,
|
| 208 |
+
b_sf,
|
| 209 |
+
d,
|
| 210 |
+
c,
|
| 211 |
+
r0,
|
| 212 |
+
r1,
|
| 213 |
+
r2,
|
| 214 |
+
hr,
|
| 215 |
+
ra0,
|
| 216 |
+
ra1,
|
| 217 |
+
hra,
|
| 218 |
+
rb0,
|
| 219 |
+
rb1,
|
| 220 |
+
hrb,
|
| 221 |
+
compiled_dims,
|
| 222 |
+
disable_ue8m0_cast,
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def fp8_fp4_gemm_tt(
|
| 227 |
+
a,
|
| 228 |
+
b,
|
| 229 |
+
d,
|
| 230 |
+
c=None,
|
| 231 |
+
recipe=None,
|
| 232 |
+
recipe_a=None,
|
| 233 |
+
recipe_b=None,
|
| 234 |
+
compiled_dims="mn",
|
| 235 |
+
disable_ue8m0_cast=False,
|
| 236 |
+
):
|
| 237 |
+
a_data, a_sf = a
|
| 238 |
+
b_data, b_sf = b
|
| 239 |
+
r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe(
|
| 240 |
+
recipe, recipe_a, recipe_b
|
| 241 |
+
)
|
| 242 |
+
ops.fp8_fp4_gemm_tt(
|
| 243 |
+
a_data,
|
| 244 |
+
a_sf,
|
| 245 |
+
b_data,
|
| 246 |
+
b_sf,
|
| 247 |
+
d,
|
| 248 |
+
c,
|
| 249 |
+
r0,
|
| 250 |
+
r1,
|
| 251 |
+
r2,
|
| 252 |
+
hr,
|
| 253 |
+
ra0,
|
| 254 |
+
ra1,
|
| 255 |
+
hra,
|
| 256 |
+
rb0,
|
| 257 |
+
rb1,
|
| 258 |
+
hrb,
|
| 259 |
+
compiled_dims,
|
| 260 |
+
disable_ue8m0_cast,
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
# FP8 aliases (same as FP8/FP4)
|
| 265 |
+
fp8_gemm_nt = fp8_fp4_gemm_nt
|
| 266 |
+
fp8_gemm_nn = fp8_fp4_gemm_nn
|
| 267 |
+
fp8_gemm_tn = fp8_fp4_gemm_tn
|
| 268 |
+
fp8_gemm_tt = fp8_fp4_gemm_tt
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
# M-grouped FP8/FP4 GEMM ops
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def m_grouped_fp8_fp4_gemm_nt_contiguous(
|
| 275 |
+
a,
|
| 276 |
+
b,
|
| 277 |
+
d,
|
| 278 |
+
grouped_layout,
|
| 279 |
+
recipe=None,
|
| 280 |
+
recipe_a=None,
|
| 281 |
+
recipe_b=None,
|
| 282 |
+
compiled_dims="nk",
|
| 283 |
+
disable_ue8m0_cast=False,
|
| 284 |
+
use_psum_layout=False,
|
| 285 |
+
expected_m_for_psum_layout=None,
|
| 286 |
+
):
|
| 287 |
+
a_data, a_sf = a
|
| 288 |
+
b_data, b_sf = b
|
| 289 |
+
r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe(
|
| 290 |
+
recipe, recipe_a, recipe_b
|
| 291 |
+
)
|
| 292 |
+
has_em = expected_m_for_psum_layout is not None
|
| 293 |
+
em = expected_m_for_psum_layout if has_em else 0
|
| 294 |
+
ops.m_grouped_fp8_fp4_gemm_nt_contiguous(
|
| 295 |
+
a_data,
|
| 296 |
+
a_sf,
|
| 297 |
+
b_data,
|
| 298 |
+
b_sf,
|
| 299 |
+
d,
|
| 300 |
+
grouped_layout,
|
| 301 |
+
r0,
|
| 302 |
+
r1,
|
| 303 |
+
r2,
|
| 304 |
+
hr,
|
| 305 |
+
ra0,
|
| 306 |
+
ra1,
|
| 307 |
+
hra,
|
| 308 |
+
rb0,
|
| 309 |
+
rb1,
|
| 310 |
+
hrb,
|
| 311 |
+
compiled_dims,
|
| 312 |
+
disable_ue8m0_cast,
|
| 313 |
+
use_psum_layout,
|
| 314 |
+
em,
|
| 315 |
+
has_em,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def m_grouped_fp8_fp4_gemm_nn_contiguous(
|
| 320 |
+
a,
|
| 321 |
+
b,
|
| 322 |
+
d,
|
| 323 |
+
grouped_layout,
|
| 324 |
+
recipe=None,
|
| 325 |
+
recipe_a=None,
|
| 326 |
+
recipe_b=None,
|
| 327 |
+
compiled_dims="nk",
|
| 328 |
+
disable_ue8m0_cast=False,
|
| 329 |
+
use_psum_layout=False,
|
| 330 |
+
):
|
| 331 |
+
a_data, a_sf = a
|
| 332 |
+
b_data, b_sf = b
|
| 333 |
+
r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe(
|
| 334 |
+
recipe, recipe_a, recipe_b
|
| 335 |
+
)
|
| 336 |
+
ops.m_grouped_fp8_fp4_gemm_nn_contiguous(
|
| 337 |
+
a_data,
|
| 338 |
+
a_sf,
|
| 339 |
+
b_data,
|
| 340 |
+
b_sf,
|
| 341 |
+
d,
|
| 342 |
+
grouped_layout,
|
| 343 |
+
r0,
|
| 344 |
+
r1,
|
| 345 |
+
r2,
|
| 346 |
+
hr,
|
| 347 |
+
ra0,
|
| 348 |
+
ra1,
|
| 349 |
+
hra,
|
| 350 |
+
rb0,
|
| 351 |
+
rb1,
|
| 352 |
+
hrb,
|
| 353 |
+
compiled_dims,
|
| 354 |
+
disable_ue8m0_cast,
|
| 355 |
+
use_psum_layout,
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def m_grouped_fp8_fp4_gemm_nt_masked(
|
| 360 |
+
a,
|
| 361 |
+
b,
|
| 362 |
+
d,
|
| 363 |
+
masked_m,
|
| 364 |
+
expected_m,
|
| 365 |
+
recipe=None,
|
| 366 |
+
recipe_a=None,
|
| 367 |
+
recipe_b=None,
|
| 368 |
+
compiled_dims="nk",
|
| 369 |
+
disable_ue8m0_cast=False,
|
| 370 |
+
):
|
| 371 |
+
a_data, a_sf = a
|
| 372 |
+
b_data, b_sf = b
|
| 373 |
+
r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe(
|
| 374 |
+
recipe, recipe_a, recipe_b
|
| 375 |
+
)
|
| 376 |
+
ops.m_grouped_fp8_fp4_gemm_nt_masked(
|
| 377 |
+
a_data,
|
| 378 |
+
a_sf,
|
| 379 |
+
b_data,
|
| 380 |
+
b_sf,
|
| 381 |
+
d,
|
| 382 |
+
masked_m,
|
| 383 |
+
expected_m,
|
| 384 |
+
r0,
|
| 385 |
+
r1,
|
| 386 |
+
r2,
|
| 387 |
+
hr,
|
| 388 |
+
ra0,
|
| 389 |
+
ra1,
|
| 390 |
+
hra,
|
| 391 |
+
rb0,
|
| 392 |
+
rb1,
|
| 393 |
+
hrb,
|
| 394 |
+
compiled_dims,
|
| 395 |
+
disable_ue8m0_cast,
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
# M-grouped FP8 aliases
|
| 400 |
+
m_grouped_fp8_gemm_nt_contiguous = m_grouped_fp8_fp4_gemm_nt_contiguous
|
| 401 |
+
m_grouped_fp8_gemm_nn_contiguous = m_grouped_fp8_fp4_gemm_nn_contiguous
|
| 402 |
+
m_grouped_fp8_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked
|
| 403 |
+
|
| 404 |
+
# Legacy aliases
|
| 405 |
+
fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
# K-grouped FP8 GEMM ops
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def k_grouped_fp8_gemm_tn_contiguous(
|
| 412 |
+
a, b, d, ks, ks_tensor, c=None, recipe=(1, 1, 128), compiled_dims="mn"
|
| 413 |
+
):
|
| 414 |
+
a_data, a_sf = a
|
| 415 |
+
b_data, b_sf = b
|
| 416 |
+
r0, r1, r2 = recipe
|
| 417 |
+
ops.k_grouped_fp8_gemm_tn_contiguous(
|
| 418 |
+
a_data, a_sf, b_data, b_sf, d, ks_tensor, c, r0, r1, r2, compiled_dims
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def k_grouped_fp8_gemm_nt_contiguous(
|
| 423 |
+
a, b, d, ks, ks_tensor, c=None, recipe=(1, 1, 128), compiled_dims="mn"
|
| 424 |
+
):
|
| 425 |
+
a_data, a_sf = a
|
| 426 |
+
b_data, b_sf = b
|
| 427 |
+
r0, r1, r2 = recipe
|
| 428 |
+
ops.k_grouped_fp8_gemm_nt_contiguous(
|
| 429 |
+
a_data, a_sf, b_data, b_sf, d, ks_tensor, c, r0, r1, r2, compiled_dims
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
# BF16 GEMM ops
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def bf16_gemm_nt(a, b, d, c=None, compiled_dims="nk"):
|
| 437 |
+
ops.bf16_gemm_nt(a, b, d, c, compiled_dims)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def bf16_gemm_nn(a, b, d, c=None, compiled_dims="nk"):
|
| 441 |
+
ops.bf16_gemm_nn(a, b, d, c, compiled_dims)
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def bf16_gemm_tn(a, b, d, c=None, compiled_dims="mn"):
|
| 445 |
+
ops.bf16_gemm_tn(a, b, d, c, compiled_dims)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def bf16_gemm_tt(a, b, d, c=None, compiled_dims="mn"):
|
| 449 |
+
ops.bf16_gemm_tt(a, b, d, c, compiled_dims)
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
# M-grouped BF16 GEMM ops
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def m_grouped_bf16_gemm_nt_contiguous(
|
| 456 |
+
a,
|
| 457 |
+
b,
|
| 458 |
+
d,
|
| 459 |
+
grouped_layout,
|
| 460 |
+
compiled_dims="nk",
|
| 461 |
+
use_psum_layout=False,
|
| 462 |
+
expected_m_for_psum_layout=None,
|
| 463 |
+
):
|
| 464 |
+
has_em = expected_m_for_psum_layout is not None
|
| 465 |
+
em = expected_m_for_psum_layout if has_em else 0
|
| 466 |
+
ops.m_grouped_bf16_gemm_nt_contiguous(
|
| 467 |
+
a, b, d, grouped_layout, compiled_dims, use_psum_layout, em, has_em
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def m_grouped_bf16_gemm_nn_contiguous(
|
| 472 |
+
a, b, d, grouped_layout, compiled_dims="nk", use_psum_layout=False
|
| 473 |
+
):
|
| 474 |
+
ops.m_grouped_bf16_gemm_nn_contiguous(
|
| 475 |
+
a, b, d, grouped_layout, compiled_dims, use_psum_layout
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
def m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m, compiled_dims="nk"):
|
| 480 |
+
ops.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m, compiled_dims)
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
# Legacy alias
|
| 484 |
+
bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
# K-grouped BF16 GEMM ops
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
def k_grouped_bf16_gemm_tn_contiguous(
|
| 491 |
+
a, b, d, ks, ks_tensor, c=None, compiled_dims="mn"
|
| 492 |
+
):
|
| 493 |
+
ops.k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks_tensor, c, compiled_dims)
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
# cuBLASLt GEMM ops
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def cublaslt_gemm_nt(a, b, d, c=None):
|
| 500 |
+
ops.cublaslt_gemm_nt(a, b, d, c)
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
def cublaslt_gemm_nn(a, b, d, c=None):
|
| 504 |
+
ops.cublaslt_gemm_nn(a, b, d, c)
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
def cublaslt_gemm_tn(a, b, d, c=None):
|
| 508 |
+
ops.cublaslt_gemm_tn(a, b, d, c)
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
def cublaslt_gemm_tt(a, b, d, c=None):
|
| 512 |
+
ops.cublaslt_gemm_tt(a, b, d, c)
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
# Attention ops
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
def fp8_gemm_nt_skip_head_mid(
|
| 519 |
+
a, b, d, head_splits, recipe=None, compiled_dims="nk", disable_ue8m0_cast=False
|
| 520 |
+
):
|
| 521 |
+
a_data, a_sf = a
|
| 522 |
+
b_data, b_sf = b
|
| 523 |
+
left, mid, right = head_splits
|
| 524 |
+
has_recipe = recipe is not None
|
| 525 |
+
r0, r1, r2 = recipe if has_recipe else (0, 0, 0)
|
| 526 |
+
ops.fp8_gemm_nt_skip_head_mid(
|
| 527 |
+
a_data,
|
| 528 |
+
a_sf,
|
| 529 |
+
b_data,
|
| 530 |
+
b_sf,
|
| 531 |
+
d,
|
| 532 |
+
left,
|
| 533 |
+
mid,
|
| 534 |
+
right,
|
| 535 |
+
r0,
|
| 536 |
+
r1,
|
| 537 |
+
r2,
|
| 538 |
+
has_recipe,
|
| 539 |
+
compiled_dims,
|
| 540 |
+
disable_ue8m0_cast,
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def fp8_mqa_logits(
|
| 545 |
+
q,
|
| 546 |
+
kv,
|
| 547 |
+
weights,
|
| 548 |
+
cu_seq_len_k_start,
|
| 549 |
+
cu_seq_len_k_end,
|
| 550 |
+
clean_logits=True,
|
| 551 |
+
max_seqlen_k=0,
|
| 552 |
+
):
|
| 553 |
+
kv_data, kv_sf = kv
|
| 554 |
+
return ops.fp8_mqa_logits(
|
| 555 |
+
q,
|
| 556 |
+
kv_data,
|
| 557 |
+
kv_sf,
|
| 558 |
+
weights,
|
| 559 |
+
cu_seq_len_k_start,
|
| 560 |
+
cu_seq_len_k_end,
|
| 561 |
+
clean_logits,
|
| 562 |
+
max_seqlen_k,
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms):
|
| 567 |
+
return ops.get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms)
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
def fp8_paged_mqa_logits(
|
| 571 |
+
q,
|
| 572 |
+
kv_cache,
|
| 573 |
+
weights,
|
| 574 |
+
context_lens,
|
| 575 |
+
block_table,
|
| 576 |
+
schedule_meta,
|
| 577 |
+
max_context_len,
|
| 578 |
+
clean_logits=False,
|
| 579 |
+
):
|
| 580 |
+
return ops.fp8_paged_mqa_logits(
|
| 581 |
+
q,
|
| 582 |
+
kv_cache,
|
| 583 |
+
weights,
|
| 584 |
+
context_lens,
|
| 585 |
+
block_table,
|
| 586 |
+
schedule_meta,
|
| 587 |
+
max_context_len,
|
| 588 |
+
clean_logits,
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
# Einsum ops
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
def einsum(expr, a, b, d, c=None, use_cublaslt=False):
|
| 596 |
+
ops.einsum(expr, a, b, d, c, use_cublaslt)
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
def fp8_einsum(expr, a, b, d, c=None, recipe=(1, 128, 128)):
|
| 600 |
+
a_data, a_sf = a
|
| 601 |
+
b_data, b_sf = b
|
| 602 |
+
r0, r1, r2 = recipe
|
| 603 |
+
ops.fp8_einsum(expr, a_data, a_sf, b_data, b_sf, d, c, r0, r1, r2)
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
# Hyperconnection ops
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None):
|
| 610 |
+
has_ns = num_splits is not None
|
| 611 |
+
ns = num_splits if has_ns else 0
|
| 612 |
+
ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, ns, has_ns)
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
# Initialize the C++ runtime
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
def _find_cuda_home() -> str:
|
| 619 |
+
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
|
| 620 |
+
if cuda_home is None:
|
| 621 |
+
try:
|
| 622 |
+
with open(os.devnull, "w") as devnull:
|
| 623 |
+
nvcc = (
|
| 624 |
+
subprocess.check_output(["which", "nvcc"], stderr=devnull)
|
| 625 |
+
.decode()
|
| 626 |
+
.rstrip("\r\n")
|
| 627 |
+
)
|
| 628 |
+
cuda_home = os.path.dirname(os.path.dirname(nvcc))
|
| 629 |
+
except Exception:
|
| 630 |
+
cuda_home = "/usr/local/cuda"
|
| 631 |
+
if not os.path.exists(cuda_home):
|
| 632 |
+
cuda_home = None
|
| 633 |
+
assert cuda_home is not None, "Could not find CUDA installation"
|
| 634 |
+
return cuda_home
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
# Find the library root for JIT headers
|
| 638 |
+
# In development: use the repo's deep_gemm/ directory
|
| 639 |
+
# In installed wheel: use this package's directory
|
| 640 |
+
_lib_root = os.path.join(
|
| 641 |
+
os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "deep_gemm"
|
| 642 |
+
)
|
| 643 |
+
if not os.path.isdir(os.path.join(_lib_root, "include")):
|
| 644 |
+
# Fallback: try the parent package
|
| 645 |
+
_lib_root = os.path.dirname(os.path.abspath(__file__))
|
| 646 |
+
|
| 647 |
+
_initialized = False
|
| 648 |
+
|
| 649 |
+
# Set DG_CUTLASS_INCLUDE for JIT kernel compilation (if not already set by user)
|
| 650 |
+
if "DG_CUTLASS_INCLUDE" not in os.environ:
|
| 651 |
+
_include = os.path.join(_lib_root, "include")
|
| 652 |
+
_cutlass_include_candidates = [
|
| 653 |
+
_include, # legacy layout: include/cutlass
|
| 654 |
+
os.path.join(_include, "third-party", "cutlass", "include"), # submodule layout
|
| 655 |
+
]
|
| 656 |
+
for _cutlass_include in _cutlass_include_candidates:
|
| 657 |
+
if os.path.isdir(os.path.join(_cutlass_include, "cutlass")):
|
| 658 |
+
os.environ["DG_CUTLASS_INCLUDE"] = _cutlass_include
|
| 659 |
+
break
|
| 660 |
+
else:
|
| 661 |
+
# Fall back to nvidia-cutlass pip package
|
| 662 |
+
try:
|
| 663 |
+
import nvidia.cutlass as _nc
|
| 664 |
+
os.environ["DG_CUTLASS_INCLUDE"] = os.path.join(
|
| 665 |
+
os.path.dirname(_nc.__file__), "include"
|
| 666 |
+
)
|
| 667 |
+
except ImportError:
|
| 668 |
+
pass
|
| 669 |
+
|
| 670 |
+
def _ensure_initialized():
|
| 671 |
+
global _initialized
|
| 672 |
+
if _initialized:
|
| 673 |
+
return
|
| 674 |
+
_initialized = True
|
| 675 |
+
ops.init(_lib_root, _find_cuda_home())
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
# Try to initialize eagerly, but don't fail if CUDA is not found
|
| 679 |
+
# (e.g., during build-time import checks). init() will be called
|
| 680 |
+
# lazily on first actual kernel use.
|
| 681 |
+
try:
|
| 682 |
+
_ensure_initialized()
|
| 683 |
+
except (AssertionError, RuntimeError):
|
| 684 |
+
pass
|
build/torch210-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:59c03e7c2a78c9cc545723380a52b88ac95ffa5803211a7ab5cecc5358524720
|
| 3 |
+
size 2828112
|
build/torch210-cxx11-cu126-aarch64-linux/_ops.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from . import _deep_gemm_cuda_a68a39f
|
| 3 |
+
ops = torch.ops._deep_gemm_cuda_a68a39f
|
| 4 |
+
|
| 5 |
+
def add_op_namespace_prefix(op_name: str):
|
| 6 |
+
"""
|
| 7 |
+
Prefix op by namespace.
|
| 8 |
+
"""
|
| 9 |
+
return f"_deep_gemm_cuda_a68a39f::{op_name}"
|
build/torch210-cxx11-cu126-aarch64-linux/deep_gemm/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import importlib
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from types import ModuleType
|
| 7 |
+
|
| 8 |
+
def _import_from_path(file_path: Path) -> ModuleType:
|
| 9 |
+
# We cannot use the module name as-is, after adding it to `sys.modules`,
|
| 10 |
+
# it would also be used for other imports. So, we make a module name that
|
| 11 |
+
# depends on the path for it to be unique using the hex-encoded hash of
|
| 12 |
+
# the path.
|
| 13 |
+
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
|
| 14 |
+
module_name = path_hash
|
| 15 |
+
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
| 16 |
+
if spec is None:
|
| 17 |
+
raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
|
| 18 |
+
module = importlib.util.module_from_spec(spec)
|
| 19 |
+
if module is None:
|
| 20 |
+
raise ImportError(f"Cannot load module {module_name} from spec")
|
| 21 |
+
sys.modules[module_name] = module
|
| 22 |
+
spec.loader.exec_module(module) # type: ignore
|
| 23 |
+
return module
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/common/cute_tie.cuh
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
namespace cute {
|
| 4 |
+
|
| 5 |
+
struct ignore_t {
|
| 6 |
+
template <typename T>
|
| 7 |
+
constexpr const ignore_t& operator=(T&&) const noexcept {
|
| 8 |
+
return *this;
|
| 9 |
+
}
|
| 10 |
+
};
|
| 11 |
+
|
| 12 |
+
inline constexpr ignore_t ignore{};
|
| 13 |
+
|
| 14 |
+
} // namespace cute
|
| 15 |
+
|
| 16 |
+
#define CUTE_TIE_CONCAT_IMPL(A, B) A##B
|
| 17 |
+
#define CUTE_TIE_CONCAT(A, B) CUTE_TIE_CONCAT_IMPL(A, B)
|
| 18 |
+
|
| 19 |
+
#define CUTE_TIE_GET_NTH_ARG(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) N
|
| 20 |
+
#define CUTE_TIE_COUNT_ARGS(...) \
|
| 21 |
+
CUTE_TIE_GET_NTH_ARG(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
|
| 22 |
+
|
| 23 |
+
#define CUTE_TIE_OP_DECL(I, TUPLE, VAR) auto VAR = ::cute::get<I>(TUPLE)
|
| 24 |
+
#define CUTE_TIE_OP_ASSIGN(I, TUPLE, VAR) VAR = ::cute::get<I>(TUPLE)
|
| 25 |
+
|
| 26 |
+
#define CUTE_TIE_APPLY_OP_1(OP, T, V1) OP(0, T, V1);
|
| 27 |
+
#define CUTE_TIE_APPLY_OP_2(OP, T, V1, V2) OP(0, T, V1); OP(1, T, V2);
|
| 28 |
+
#define CUTE_TIE_APPLY_OP_3(OP, T, V1, V2, V3) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3);
|
| 29 |
+
#define CUTE_TIE_APPLY_OP_4(OP, T, V1, V2, V3, V4) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4);
|
| 30 |
+
#define CUTE_TIE_APPLY_OP_5(OP, T, V1, V2, V3, V4, V5) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); OP(4, T, V5);
|
| 31 |
+
|
| 32 |
+
#define CUTE_TIE_DECL(TUPLE_EXPR, ...) \
|
| 33 |
+
auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \
|
| 34 |
+
CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \
|
| 35 |
+
CUTE_TIE_OP_DECL, \
|
| 36 |
+
CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \
|
| 37 |
+
__VA_ARGS__ \
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
#define CUTE_TIE(TUPLE_EXPR, ...) \
|
| 41 |
+
do { \
|
| 42 |
+
auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \
|
| 43 |
+
CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \
|
| 44 |
+
CUTE_TIE_OP_ASSIGN, \
|
| 45 |
+
CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \
|
| 46 |
+
__VA_ARGS__ \
|
| 47 |
+
); \
|
| 48 |
+
} while (0)
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/common/epilogue_utils.cuh
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <deep_gemm/common/types.hpp>
|
| 4 |
+
#include <deep_gemm/common/utils.cuh>
|
| 5 |
+
|
| 6 |
+
namespace deep_gemm {
|
| 7 |
+
|
| 8 |
+
struct EpilogueIdentity {
|
| 9 |
+
template <uint32_t STORE_BLOCK_N>
|
| 10 |
+
__device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) {
|
| 11 |
+
return n_idx;
|
| 12 |
+
}
|
| 13 |
+
};
|
| 14 |
+
|
| 15 |
+
template <uint32_t kLeft, uint32_t kMid, uint32_t kRight>
|
| 16 |
+
struct EpilogueHeadSplits: EpilogueIdentity {
|
| 17 |
+
template <uint32_t STORE_BLOCK_N>
|
| 18 |
+
__device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) {
|
| 19 |
+
DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0
|
| 20 |
+
and kRight % STORE_BLOCK_N == 0, "Invalid head splits config");
|
| 21 |
+
return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid;
|
| 22 |
+
}
|
| 23 |
+
};
|
| 24 |
+
|
| 25 |
+
#pragma clang diagnostic pop
|
| 26 |
+
|
| 27 |
+
} // namespace deep_gemm
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/common/reduction.cuh
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cuda_bf16.h>
|
| 4 |
+
#include <cuda_fp8.h>
|
| 5 |
+
#include <cuda/std/cstdint>
|
| 6 |
+
#include <cuda/std/utility>
|
| 7 |
+
|
| 8 |
+
#include <deep_gemm/common/utils.cuh>
|
| 9 |
+
|
| 10 |
+
// Operation functors
|
| 11 |
+
template <typename T> struct ReduceSum { __device__ T operator()(T a, T b) const { return a + b; } };
|
| 12 |
+
template <typename T> struct ReduceMax { __device__ T operator()(T a, T b) const { return a > b ? a : b; } };
|
| 13 |
+
template <typename T> struct ReduceMin { __device__ T operator()(T a, T b) const { return a < b ? a : b; } };
|
| 14 |
+
template <typename T> struct ReduceAnd { __device__ T operator()(T a, T b) const { return a & b; } };
|
| 15 |
+
template <typename T> struct ReduceOr { __device__ T operator()(T a, T b) const { return a | b; } };
|
| 16 |
+
|
| 17 |
+
// Unified reduction function
|
| 18 |
+
template <int kNumLanesPerGroup, bool kIntergroupReduce, typename T, typename Op>
|
| 19 |
+
__forceinline__ __device__ T warp_reduce(T value, Op op) {
|
| 20 |
+
DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or
|
| 21 |
+
kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1,
|
| 22 |
+
"Invalid number of lanes");
|
| 23 |
+
constexpr uint32_t mask = 0xffffffff;
|
| 24 |
+
if constexpr (kIntergroupReduce) {
|
| 25 |
+
if constexpr (kNumLanesPerGroup <= 1) value = op(value, __shfl_xor_sync(mask, value, 1));
|
| 26 |
+
if constexpr (kNumLanesPerGroup <= 2) value = op(value, __shfl_xor_sync(mask, value, 2));
|
| 27 |
+
if constexpr (kNumLanesPerGroup <= 4) value = op(value, __shfl_xor_sync(mask, value, 4));
|
| 28 |
+
if constexpr (kNumLanesPerGroup <= 8) value = op(value, __shfl_xor_sync(mask, value, 8));
|
| 29 |
+
if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16));
|
| 30 |
+
} else {
|
| 31 |
+
if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16));
|
| 32 |
+
if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8));
|
| 33 |
+
if constexpr (kNumLanesPerGroup >= 8) value = op(value, __shfl_xor_sync(mask, value, 4));
|
| 34 |
+
if constexpr (kNumLanesPerGroup >= 4) value = op(value, __shfl_xor_sync(mask, value, 2));
|
| 35 |
+
if constexpr (kNumLanesPerGroup >= 2) value = op(value, __shfl_xor_sync(mask, value, 1));
|
| 36 |
+
}
|
| 37 |
+
return value;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
// Convenience aliases
|
| 41 |
+
template <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
|
| 42 |
+
__forceinline__ __device__ T warp_reduce_sum(T value) {
|
| 43 |
+
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceSum<T>{});
|
| 44 |
+
}
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/common/scheduler.cuh
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <deep_gemm/common/types.hpp>
|
| 4 |
+
#include <deep_gemm/common/utils.cuh>
|
| 5 |
+
|
| 6 |
+
namespace deep_gemm {
|
| 7 |
+
|
| 8 |
+
enum class IndexType {
|
| 9 |
+
MN,
|
| 10 |
+
K,
|
| 11 |
+
SF_K,
|
| 12 |
+
};
|
| 13 |
+
|
| 14 |
+
template <GemmType kGemmType, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumSMs, bool kIsMulticastOnA>
|
| 15 |
+
static constexpr uint32_t get_num_1d_blocks_per_group() {
|
| 16 |
+
// Select the best from candidates
|
| 17 |
+
uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits<uint32_t>::max();
|
| 18 |
+
for (const auto& candidate: {8u, 16u}) {
|
| 19 |
+
const auto& usage = kIsMulticastOnA ?
|
| 20 |
+
candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N
|
| 21 |
+
candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M
|
| 22 |
+
if (usage < min_usage)
|
| 23 |
+
min_usage = usage, num_best_blocks = candidate;
|
| 24 |
+
}
|
| 25 |
+
return num_best_blocks;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
#pragma clang diagnostic push
|
| 29 |
+
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
|
| 30 |
+
template <GemmType kGemmType,
|
| 31 |
+
uint32_t BLOCK_M, uint32_t BLOCK_N,
|
| 32 |
+
uint32_t kNumGroups,
|
| 33 |
+
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
| 34 |
+
uint32_t kNumSMs,
|
| 35 |
+
uint32_t SF_K_ALIGNMENT = 512u, // for k-grouped GEMM only: 128 (SM90 float SF) or 512 (SM100 UE8M0 SF)
|
| 36 |
+
uint32_t kNum1DBlocksPerGroup = get_num_1d_blocks_per_group<kGemmType, BLOCK_M, BLOCK_N, kNumSMs, kIsMulticastOnA>()>
|
| 37 |
+
struct Scheduler {
|
| 38 |
+
int current_iter = -1;
|
| 39 |
+
|
| 40 |
+
// Block configs
|
| 41 |
+
uint32_t num_blocks;
|
| 42 |
+
uint32_t num_m_blocks;
|
| 43 |
+
uint32_t num_n_blocks;
|
| 44 |
+
|
| 45 |
+
// For SM90 multicast checks
|
| 46 |
+
uint32_t num_blocks_in_group;
|
| 47 |
+
bool is_peer_cta_alive = true;
|
| 48 |
+
|
| 49 |
+
// For grouped GEMM
|
| 50 |
+
int* grouped_layout;
|
| 51 |
+
uint32_t current_group_idx = 0;
|
| 52 |
+
// Only used for masked layout
|
| 53 |
+
uint32_t current_m_cumsum = 0;
|
| 54 |
+
// Only used for countiguous psum layout
|
| 55 |
+
uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0;
|
| 56 |
+
// Only used for k-grouped layout
|
| 57 |
+
uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0;
|
| 58 |
+
uint32_t next_group_idx, next_shape_k;
|
| 59 |
+
|
| 60 |
+
// Only used for k-grouped gemm
|
| 61 |
+
__device__ __forceinline__ void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const {
|
| 62 |
+
for (; group_idx < kNumGroups; ++ group_idx) {
|
| 63 |
+
shape_k = __ldg(grouped_layout + group_idx);
|
| 64 |
+
if (shape_k > 0)
|
| 65 |
+
break;
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
// ReSharper disable once CppPossiblyUninitializedMember
|
| 70 |
+
__device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, const uint32_t& shape_k,
|
| 71 |
+
int* grouped_layout = nullptr) {
|
| 72 |
+
num_m_blocks = ceil_div(shape_m, BLOCK_M);
|
| 73 |
+
num_n_blocks = ceil_div(shape_n, BLOCK_N);
|
| 74 |
+
current_shape_k = shape_k;
|
| 75 |
+
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) {
|
| 76 |
+
num_blocks = num_m_blocks * num_n_blocks;
|
| 77 |
+
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
|
| 78 |
+
num_blocks = num_m_blocks * num_n_blocks;
|
| 79 |
+
this->grouped_layout = grouped_layout;
|
| 80 |
+
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
|
| 81 |
+
this->grouped_layout = grouped_layout;
|
| 82 |
+
} else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
|
| 83 |
+
this->grouped_layout = grouped_layout;
|
| 84 |
+
current_psum_m = __ldg(grouped_layout);
|
| 85 |
+
num_m_blocks = ceil_div(current_psum_m, BLOCK_M);
|
| 86 |
+
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
|
| 87 |
+
this->grouped_layout = grouped_layout;
|
| 88 |
+
get_next_k_group(current_group_idx, current_shape_k);
|
| 89 |
+
next_group_idx = current_group_idx + 1;
|
| 90 |
+
get_next_k_group(next_group_idx, next_shape_k);
|
| 91 |
+
}
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
| 95 |
+
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size");
|
| 96 |
+
|
| 97 |
+
// Swizzle for better L2 usages
|
| 98 |
+
const auto& primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks;
|
| 99 |
+
const auto& secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks;
|
| 100 |
+
const auto& num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup;
|
| 101 |
+
const auto& group_idx = block_idx / num_blocks_per_group;
|
| 102 |
+
auto first_block_idx = group_idx * kNum1DBlocksPerGroup;
|
| 103 |
+
auto in_group_idx = block_idx % num_blocks_per_group;
|
| 104 |
+
num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);
|
| 105 |
+
|
| 106 |
+
// Fix unaligned TMA multicast
|
| 107 |
+
// NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast
|
| 108 |
+
// while SM100 uses 2-CTA, which can not be dynamically disabled
|
| 109 |
+
#if __CUDA_ARCH__ < 1000
|
| 110 |
+
if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) {
|
| 111 |
+
if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
|
| 112 |
+
num_blocks_in_group = num_blocks_in_group ^ 1;
|
| 113 |
+
} else {
|
| 114 |
+
in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks;
|
| 115 |
+
first_block_idx += num_blocks_in_group ^ 1;
|
| 116 |
+
num_blocks_in_group = 1;
|
| 117 |
+
}
|
| 118 |
+
}
|
| 119 |
+
#endif
|
| 120 |
+
|
| 121 |
+
// Convert to final M/N block indices
|
| 122 |
+
// `kIsMulticastOnA == true` leads to groups on N
|
| 123 |
+
if constexpr (kIsMulticastOnA) {
|
| 124 |
+
m_block_idx = in_group_idx / num_blocks_in_group;
|
| 125 |
+
n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
|
| 126 |
+
} else {
|
| 127 |
+
m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
|
| 128 |
+
n_block_idx = in_group_idx / num_blocks_in_group;
|
| 129 |
+
}
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
template <bool kWithGroupOffset, IndexType kIndexType = IndexType::MN>
|
| 133 |
+
__device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
|
| 134 |
+
const uint32_t& block_idx, const uint32_t& m_block_idx = 0) {
|
| 135 |
+
if constexpr (kGemmType == GemmType::Normal) {
|
| 136 |
+
return block_idx * block_size;
|
| 137 |
+
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
|
| 138 |
+
const auto offset = kWithGroupOffset ? cute::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0;
|
| 139 |
+
return offset * shape_dim + block_idx * block_size;
|
| 140 |
+
} else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
|
| 141 |
+
const auto offset = kWithGroupOffset ? current_group_idx : 0;
|
| 142 |
+
return offset * shape_dim + block_idx * block_size;
|
| 143 |
+
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
|
| 144 |
+
auto offset = 0;
|
| 145 |
+
if constexpr (kWithGroupOffset) {
|
| 146 |
+
if constexpr (kIndexType == IndexType::MN)
|
| 147 |
+
offset = current_group_idx * shape_dim;
|
| 148 |
+
else if constexpr (kIndexType == IndexType::K)
|
| 149 |
+
offset = current_k_cumsum;
|
| 150 |
+
else if constexpr (kIndexType == IndexType::SF_K)
|
| 151 |
+
offset = current_sf_k_cumsum;
|
| 152 |
+
}
|
| 153 |
+
return offset + block_idx * block_size;
|
| 154 |
+
} else if constexpr (kGemmType == GemmType::Batched) {
|
| 155 |
+
// Ignore kWithGroupOffset, and apply offset for IndexType::SF_K
|
| 156 |
+
const auto offset = kIndexType == IndexType::SF_K ? current_group_idx : 0;
|
| 157 |
+
return offset * shape_dim + block_idx * block_size;
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
| 162 |
+
const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x;
|
| 163 |
+
|
| 164 |
+
if constexpr (kGemmType == GemmType::MGroupedMasked) {
|
| 165 |
+
while (true) {
|
| 166 |
+
// End of the task
|
| 167 |
+
if (current_group_idx == kNumGroups)
|
| 168 |
+
return false;
|
| 169 |
+
|
| 170 |
+
// Within current group
|
| 171 |
+
num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + current_group_idx)), BLOCK_M);
|
| 172 |
+
const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks;
|
| 173 |
+
if (next_block_idx < current_m_block_cumsum * num_n_blocks)
|
| 174 |
+
break;
|
| 175 |
+
|
| 176 |
+
// Move to check the next group
|
| 177 |
+
current_group_idx ++, current_m_cumsum = current_m_block_cumsum;
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx);
|
| 181 |
+
} else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
|
| 182 |
+
while (true) {
|
| 183 |
+
// Within current group
|
| 184 |
+
if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks)
|
| 185 |
+
break;
|
| 186 |
+
|
| 187 |
+
// Move to check the next group
|
| 188 |
+
if (++ current_group_idx == kNumGroups)
|
| 189 |
+
return false;
|
| 190 |
+
|
| 191 |
+
// NOTES: `num_m_blocks` varies with the increase of the group index
|
| 192 |
+
last_psum_m = align(current_psum_m, 128u);
|
| 193 |
+
current_psum_m = __ldg(grouped_layout + current_group_idx);
|
| 194 |
+
current_m_block_cumsum += num_m_blocks;
|
| 195 |
+
num_m_blocks = ceil_div(current_psum_m - last_psum_m, BLOCK_M);
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx);
|
| 199 |
+
|
| 200 |
+
// NOTES: `last_psum_m` is aligned with 128
|
| 201 |
+
m_block_idx += last_psum_m / BLOCK_M;
|
| 202 |
+
DG_STATIC_ASSERT(128 % BLOCK_M == 0, "Invalid BLOCK_M");
|
| 203 |
+
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
|
| 204 |
+
while (true) {
|
| 205 |
+
// End of the task
|
| 206 |
+
if (current_group_idx == kNumGroups)
|
| 207 |
+
return false;
|
| 208 |
+
|
| 209 |
+
// Within current group
|
| 210 |
+
if (next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks)
|
| 211 |
+
break;
|
| 212 |
+
|
| 213 |
+
// Move to check the next group
|
| 214 |
+
current_k_cumsum += current_shape_k;
|
| 215 |
+
current_sf_k_cumsum += ceil_div(current_shape_k, SF_K_ALIGNMENT);
|
| 216 |
+
current_num_valid_groups ++;
|
| 217 |
+
|
| 218 |
+
current_group_idx = next_group_idx ++;
|
| 219 |
+
current_shape_k = next_shape_k;
|
| 220 |
+
get_next_k_group(next_group_idx, next_shape_k);
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_m_blocks * num_n_blocks, m_block_idx, n_block_idx);
|
| 224 |
+
} else if constexpr (kGemmType == GemmType::Batched) {
|
| 225 |
+
if (next_block_idx >= num_blocks * kNumGroups)
|
| 226 |
+
return false;
|
| 227 |
+
|
| 228 |
+
current_group_idx = next_block_idx / num_blocks;
|
| 229 |
+
const auto& block_idx = next_block_idx - current_group_idx * num_blocks;
|
| 230 |
+
if constexpr (kIsMulticastOnA) {
|
| 231 |
+
m_block_idx = block_idx / num_n_blocks;
|
| 232 |
+
n_block_idx = block_idx % num_n_blocks;
|
| 233 |
+
} else {
|
| 234 |
+
m_block_idx = block_idx % num_m_blocks;
|
| 235 |
+
n_block_idx = block_idx / num_m_blocks;
|
| 236 |
+
}
|
| 237 |
+
} else {
|
| 238 |
+
if (next_block_idx >= num_blocks)
|
| 239 |
+
return false;
|
| 240 |
+
|
| 241 |
+
// For SM90 only
|
| 242 |
+
// NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned
|
| 243 |
+
is_peer_cta_alive = num_n_blocks % kNumMulticast == 0 or // Always aligned on N (constant bypass)
|
| 244 |
+
num_m_blocks % kNumMulticast == 0 or // Always aligned on M (constant bypass)
|
| 245 |
+
(next_block_idx ^ 1) < num_blocks; // Peer CTA in bound
|
| 246 |
+
get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx);
|
| 247 |
+
}
|
| 248 |
+
return true;
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
// For SM90 only
|
| 252 |
+
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
|
| 253 |
+
if (num_blocks_in_group == 1)
|
| 254 |
+
return false;
|
| 255 |
+
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or
|
| 256 |
+
kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched) {
|
| 257 |
+
return true;
|
| 258 |
+
} else {
|
| 259 |
+
DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type");
|
| 260 |
+
if constexpr (kIsMulticastOnA) {
|
| 261 |
+
return true;
|
| 262 |
+
} else {
|
| 263 |
+
const auto& group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
|
| 264 |
+
const auto& peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
|
| 265 |
+
return group_idx == peer_group_idx;
|
| 266 |
+
}
|
| 267 |
+
}
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
// For SM90 only
|
| 271 |
+
// ReSharper disable once CppNotAllPathsReturnValue
|
| 272 |
+
__device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const {
|
| 273 |
+
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) {
|
| 274 |
+
return true;
|
| 275 |
+
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
|
| 276 |
+
return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0;
|
| 277 |
+
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
|
| 278 |
+
return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + current_group_idx);
|
| 279 |
+
} else {
|
| 280 |
+
// Unreachable
|
| 281 |
+
DG_TRAP_ONLY_DEVICE_ASSERT(false);
|
| 282 |
+
}
|
| 283 |
+
}
|
| 284 |
+
};
|
| 285 |
+
|
| 286 |
+
#pragma clang diagnostic pop
|
| 287 |
+
|
| 288 |
+
} // namespace deep_gemm
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/common/sm100_utils.cuh
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cute/atom/mma_traits_sm100.hpp>
|
| 4 |
+
#include <cute/arch/mma_sm100_umma.hpp>
|
| 5 |
+
#include <cute/arch/tmem_allocator_sm100.hpp>
|
| 6 |
+
#include <cutlass/arch/barrier.h>
|
| 7 |
+
|
| 8 |
+
#include <deep_gemm/common/utils.cuh>
|
| 9 |
+
#include <deep_gemm/common/tma_utils.cuh>
|
| 10 |
+
|
| 11 |
+
namespace deep_gemm::sm100 {
|
| 12 |
+
|
| 13 |
+
__device__ __forceinline__
|
| 14 |
+
cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr,
|
| 15 |
+
uint32_t stride_byte_offset, uint32_t leading_byte_offset) {
|
| 16 |
+
cute::UMMA::SmemDescriptor desc;
|
| 17 |
+
|
| 18 |
+
// Set the version for SM100
|
| 19 |
+
desc.version_ = 1;
|
| 20 |
+
|
| 21 |
+
// Legacy mode
|
| 22 |
+
desc.lbo_mode_ = 0;
|
| 23 |
+
|
| 24 |
+
// Layout
|
| 25 |
+
desc.layout_type_ = static_cast<uint8_t>(layout);
|
| 26 |
+
|
| 27 |
+
// Start address
|
| 28 |
+
const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
| 29 |
+
desc.start_address_ = static_cast<uint16_t>(uint_ptr >> 4);
|
| 30 |
+
|
| 31 |
+
// Base offset
|
| 32 |
+
desc.base_offset_ = 0;
|
| 33 |
+
|
| 34 |
+
// SBO and LBO
|
| 35 |
+
desc.stride_byte_offset_ = stride_byte_offset >> 4;
|
| 36 |
+
desc.leading_byte_offset_ = leading_byte_offset >> 4;
|
| 37 |
+
|
| 38 |
+
return desc;
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
__device__ __forceinline__
|
| 42 |
+
cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) {
|
| 43 |
+
// NOTES: the UTCCP layout is K-major by default
|
| 44 |
+
// Atom size: 8 x 128 bits
|
| 45 |
+
// {SBO, LBO} means the byte stride between atoms on {MN, K}
|
| 46 |
+
// Since the UTCCP we used is 128b-wide (only 1 atom on K), so LBO can be zero
|
| 47 |
+
return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr, 8 * 16, 0);
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
__device__ __forceinline__
|
| 51 |
+
void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) {
|
| 52 |
+
const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
| 53 |
+
desc.start_address_ = static_cast<uint16_t>(uint_ptr >> 4);
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
__device__ __forceinline__
|
| 57 |
+
static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) {
|
| 58 |
+
return layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 32 : 16;
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
// ReSharper disable once CppNotAllPathsReturnValue
|
| 62 |
+
template <cute::UMMA::Major kMajorMode, uint32_t kSwizzleMode, bool kUseBase32, typename dtype_t>
|
| 63 |
+
constexpr static cute::UMMA::LayoutType to_umma_layout_type() {
|
| 64 |
+
DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or
|
| 65 |
+
kSwizzleMode == 32 or kSwizzleMode == 64 or
|
| 66 |
+
kSwizzleMode == 128, "Invalid swizzling mode");
|
| 67 |
+
// A special case
|
| 68 |
+
if constexpr ((cute::is_same_v<dtype_t, float> and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) {
|
| 69 |
+
DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base");
|
| 70 |
+
return cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
// Normal cases
|
| 74 |
+
if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE;
|
| 75 |
+
if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE;
|
| 76 |
+
if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B;
|
| 77 |
+
if constexpr (kSwizzleMode == 64) return cute::UMMA::LayoutType::SWIZZLE_64B;
|
| 78 |
+
if constexpr (kSwizzleMode == 128) return cute::UMMA::LayoutType::SWIZZLE_128B;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
|
| 82 |
+
__device__ __forceinline__
|
| 83 |
+
constexpr uint32_t get_umma_desc_stride_k() {
|
| 84 |
+
return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
|
| 88 |
+
__device__ __forceinline__
|
| 89 |
+
uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) {
|
| 90 |
+
return base + (((offset + k_idx * get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>()) * static_cast<uint32_t>(sizeof(dtype_t))) >> 4u);
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
// Build an SM100 UMMA (tcgen05) shared-memory descriptor for the sub-tile at
// (`mn_idx`, `k_idx`) of a `BLOCK_MN x BLOCK_K` shared-memory block.
// `kSwizzleMode` is the swizzle atom size in bytes (16 means interleaving, i.e.
// no swizzling); `kUseBase32` is forwarded to the layout-type selection.
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, bool kUseBase32 = false, typename dtype_t>
__device__ __forceinline__
cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) {
    // Element stride between consecutive `k_idx` steps (layout dependent)
    const uint32_t stride_k = get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
    const auto& layout_type = to_umma_layout_type<kMajorMode, kSwizzleMode, kUseBase32, dtype_t>();
    // 128 divided by the layout's atom base — assumed to be the number of
    // non-contiguous rows per swizzle atom (NOTE(review): confirm)
    const auto& num_non_contiguous = 128 / get_atom_base(layout_type);
    if constexpr (kMajorMode == cute::UMMA::Major::K) {
        // NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)`
        // also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis
        DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");

        // Atom size: 8 x `kSwizzleMode` (in bytes, on K)
        // {SBO, LBO} means the byte stride between atoms on {MN, K}
        // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0
        const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t);
        const uint32_t leading_byte_offset = 0;
        return make_smem_desc(layout_type,
                              base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k,
                              stride_byte_offset, leading_byte_offset);
    } else {
        // MN-major: the inner block may be split into several swizzle atoms
        constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();

        // Must have no in-atom MN-idx
        // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time
        DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0);
        DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling");

        // Atom size: `kSwizzleMode` (in bytes, on MN) x 8
        // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving
        // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling
        // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling
        uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t);
        uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t);
        if constexpr (kSwizzleMode == 16)
            swap(stride_byte_offset, leading_byte_offset);
        return make_smem_desc(layout_type,
                              base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k,
                              stride_byte_offset, leading_byte_offset);
    }
}
|
| 133 |
+
|
| 134 |
+
// Patch the tensor-memory scale-factor IDs for A and B into a block-scaled
// instruction descriptor and return it positioned in the upper 32 bits,
// ready to be combined into a 64-bit runtime descriptor.
__device__ __forceinline__
uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) {
    desc.a_sf_id_ = sfa_id;
    desc.b_sf_id_ = sfb_id;
    const auto bits = static_cast<uint32_t>(desc);
    return static_cast<uint64_t>(bits) << 32;
}
|
| 139 |
+
|
| 140 |
+
template <uint32_t kNumCols>
|
| 141 |
+
__device__ constexpr uint32_t get_num_aligned_tmem_cols() {
|
| 142 |
+
DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns");
|
| 143 |
+
if (kNumCols <= 32) return 32;
|
| 144 |
+
if (kNumCols <= 64) return 64;
|
| 145 |
+
if (kNumCols <= 128) return 128;
|
| 146 |
+
if (kNumCols <= 256) return 256;
|
| 147 |
+
return 512;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
// Fence ordering prior tcgen05 (tensor-memory) operations before a subsequent
// thread synchronization (e.g. a barrier arrive).
__device__ __forceinline__ void tcgen05_before_thread_sync() {
    asm volatile("tcgen05.fence::before_thread_sync;");
}
|
| 153 |
+
|
| 154 |
+
// Fence ordering subsequent tcgen05 (tensor-memory) operations after a
// preceding thread synchronization (e.g. a barrier wait).
__device__ __forceinline__ void tcgen05_after_thread_sync() {
    asm volatile("tcgen05.fence::after_thread_sync;");
}
|
| 157 |
+
|
| 158 |
+
// Issue a TMA `gather4` load: four rows (`row_idxs`) of a 2D tensor starting
// at column `col_idx` are copied into shared memory; completion is signaled
// on `mbarrier` via its transaction-byte count. `desc_ptr` points to the
// tensor map in global memory.
__device__ __forceinline__
void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier &mbarrier, void* smem_ptr, int col_idx, int4 row_idxs, uint64_t cache_hint) {
    // Convert generic pointers into 32-bit shared-memory addresses for PTX
    uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
    uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier);
    asm volatile(
        "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
        :
        : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
          "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
          "r"(mbarrier_addr), "l"(cache_hint)
        : "memory"
    );
}
|
| 171 |
+
|
| 172 |
+
// UMMA versions with relaxed assertions
|
| 173 |
+
// 1-CTA FP16/BF16 UMMA with both operands in shared memory.
// Predicate `p` (scale_c != 0) enables accumulation into C; zero overwrites C.
struct SM100_MMA_F16BF16_SS {
    __device__ static void
    fma(uint64_t const& desc_a,   // UMMA shared-memory descriptor for A
        uint64_t const& desc_b,   // UMMA shared-memory descriptor for B
        uint32_t const& tmem_c,   // tensor-memory address of accumulator C
        uint32_t const& scale_c,  // non-zero: accumulate into C; zero: overwrite
        uint64_t const& desc) {   // instruction descriptor (upper 32 bits are used)
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.ne.b32 p, %4, 0;\n\t"
            "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t"
            "}\n"
            :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
    }
};
|
| 189 |
+
|
| 190 |
+
// 2-CTA (2x1 SM pair) FP16/BF16 UMMA with both operands in shared memory.
// Same operand layout as the 1-CTA variant; `scale_c != 0` enables
// accumulation into C.
struct SM100_MMA_F16BF16_2x1SM_SS {
    __device__ static void
    fma(uint64_t const& desc_a,
        uint64_t const& desc_b,
        uint32_t const& tmem_c,
        uint32_t const& scale_c,
        uint64_t const& desc) {
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.ne.b32 p, %4, 0;\n\t"
            "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t"
            "}\n"
            :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
    }
};
|
| 206 |
+
|
| 207 |
+
// 1-CTA block-scaled MX FP8/FP6/FP4 UMMA with both operands in shared memory.
// `tmem_sfa` / `tmem_sfb` are tensor-memory addresses of the scale factors
// for A and B; `scale_c != 0` enables accumulation into C.
struct SM100_MMA_MXF8F6F4_SS {
    __device__ static void
    fma(uint64_t const& desc_a,
        uint64_t const& desc_b,
        uint32_t const& tmem_c,
        uint32_t const& scale_c,
        uint64_t const& desc,
        uint32_t const& tmem_sfa,
        uint32_t const& tmem_sfb) {
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.ne.b32 p, %4, 0;\n\t"
            "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t"
            "}\n"
            :
            : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
              "r"(tmem_sfa), "r"(tmem_sfb));
    }
};
|
| 227 |
+
|
| 228 |
+
// 2-CTA (2x1 SM pair) block-scaled MX FP8/FP6/FP4 UMMA with both operands in
// shared memory. Same operand layout as the 1-CTA variant.
struct SM100_MMA_MXF8F6F4_2x1SM_SS {
    __device__ static void
    fma(uint64_t const& desc_a,
        uint64_t const& desc_b,
        uint32_t const& tmem_c,
        uint32_t const& scale_c,
        uint64_t const& desc,
        uint32_t const& tmem_sfa,
        uint32_t const& tmem_sfb) {
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.ne.b32 p, %4, 0;\n\t"
            "tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t"
            "}\n"
            :
            : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
              "r"(tmem_sfa), "r"(tmem_sfb));
    }
};
|
| 248 |
+
|
| 249 |
+
// 1-CTA FP16/BF16 weight-stationary (`tcgen05.mma.ws`) UMMA with both
// operands in shared memory. `scale_c != 0` enables accumulation into C.
struct SM100_MMA_F16BF16_WS_SS {
    __device__ static void
    fma(uint64_t const& desc_a,
        uint64_t const& desc_b,
        uint32_t const& tmem_c,
        uint32_t const& scale_c,
        uint64_t const& desc) {
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.ne.b32 p, %4, 0;\n\t"
            "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t"
            "}\n"
            :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
    }
};
|
| 265 |
+
|
| 266 |
+
} // namespace `deep_gemm::sm100`
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/common/sm90_utils.cuh
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cute/arch/cluster_sm90.hpp>
|
| 4 |
+
#include <cute/arch/mma_sm90_desc.hpp>
|
| 5 |
+
#include <cute/arch/mma_sm90_gmma.hpp>
|
| 6 |
+
#include <cute/arch/mma_sm90_gmma_ext.hpp>
|
| 7 |
+
#include <cute/arch/mma_sm100_desc.hpp>
|
| 8 |
+
|
| 9 |
+
#include <deep_gemm/common/utils.cuh>
|
| 10 |
+
#include <deep_gemm/common/sm100_utils.cuh>
|
| 11 |
+
#include <deep_gemm/common/tma_utils.cuh>
|
| 12 |
+
|
| 13 |
+
namespace deep_gemm::sm90 {
|
| 14 |
+
|
| 15 |
+
// Wrapper around one SM90 WGMMA atom `MMA` of shape 64 x N_ x 32 with FP32
// accumulation, exposing a uniform `wgmma` entry point over a flat
// accumulator array.
template <int N_, typename MMA>
struct FP8MMA {

    // Expand the accumulator array into the individual register operands
    // the CUTLASS `MMA::fma` signature expects (N_/2 of them).
    template <size_t ...Idx>
    __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
        using namespace cute::SM90::GMMA;
        MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
    }

    // Issue one WGMMA; `scale_d` selects accumulate (One) vs overwrite (Zero).
    __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence<N_/2>{});
    }

    static constexpr int M = 64;
    static constexpr int N = N_;
    static constexpr int K = 32;
    // FP32 accumulators per thread: M * N elements over a 128-thread warpgroup
    static constexpr int kNumAccum = M * N / 128;
};
|
| 33 |
+
|
| 34 |
+
// Maps an accumulator width `N` (multiple of 8, up to 256) to the CUTLASS
// SM90 FP8 (E4M3 x E4M3 -> FP32) WGMMA atom of shape 64 x N x 32, wrapped in
// `FP8MMA`. Unsupported `N` leaves `select_mma()` without a return value and
// fails at compile time.
template <int N>
struct FP8MMASelector {

    static constexpr auto select_mma() {
        using namespace cute::SM90::GMMA;
        if constexpr (N == 8) return MMA_64x8x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 168) return MMA_64x168x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 176) return MMA_64x176x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 184) return MMA_64x184x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 200) return MMA_64x200x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 208) return MMA_64x208x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 216) return MMA_64x216x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 224) return MMA_64x224x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 232) return MMA_64x232x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 240) return MMA_64x240x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 248) return MMA_64x248x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 256) return MMA_64x256x32_F32E4M3E4M3_SS_TN();
    }

    static constexpr auto select_type() {
        return FP8MMA<N, decltype(select_mma())>();
    }

    using type = decltype(select_type());
};
|
| 79 |
+
|
| 80 |
+
// Wrapper around one SM90 WGMMA atom `MMA` of shape 64 x N_ x 16 with FP32
// accumulation; same structure as `FP8MMA` but with BF16's K = 16.
template <int N_, typename MMA>
struct BF16MMA {

    // Expand the accumulator array into the individual register operands
    // the CUTLASS `MMA::fma` signature expects (N_/2 of them).
    template <size_t ...Idx>
    __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
        using namespace cute::SM90::GMMA;
        MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
    }

    // Issue one WGMMA; `scale_d` selects accumulate (One) vs overwrite (Zero).
    __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence<N_/2>{});
    }

    static constexpr int M = 64;
    static constexpr int N = N_;
    static constexpr int K = 16;
    // FP32 accumulators per thread: M * N elements over a 128-thread warpgroup
    static constexpr int kNumAccum = M * N / 128;
};
|
| 98 |
+
|
| 99 |
+
template <cute::UMMA::Major kMajor>
|
| 100 |
+
constexpr cute::SM90::GMMA::Major to_sm90_major() {
|
| 101 |
+
DG_STATIC_ASSERT(kMajor == cute::UMMA::Major::K or kMajor == cute::UMMA::Major::MN, "Invalid major-ness");
|
| 102 |
+
return kMajor == cute::UMMA::Major::K ? cute::SM90::GMMA::Major::K : cute::SM90::GMMA::Major::MN;
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
// Maps an accumulator width `N` (multiple of 8, up to 256) and the operand
// major-ness of A/B to the CUTLASS SM90 BF16 WGMMA atom of shape 64 x N x 16,
// wrapped in `BF16MMA`. Unsupported `N` leaves `select_mma()` without a
// return value and fails at compile time.
template <int N,
          cute::UMMA::Major kMajorA = cute::UMMA::Major::K,
          cute::UMMA::Major kMajorB = cute::UMMA::Major::K>
struct BF16MMASelector {

    static constexpr auto select_mma() {
        using namespace cute::SM90::GMMA;
        // UMMA major tags are converted to the SM90 GMMA enumeration
        constexpr auto kGMMAMajorA = to_sm90_major<kMajorA>();
        constexpr auto kGMMAMajorB = to_sm90_major<kMajorB>();
        if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
    }

    static constexpr auto select_type() {
        return BF16MMA<N, decltype(select_mma())>();
    }

    using type = decltype(select_type());
};
|
| 154 |
+
|
| 155 |
+
// Wrapper around one SM90 TF32 WGMMA atom `MMA` of shape 64 x N_ x 8 in RS
// mode: the A operand comes from registers (four 32-bit values), B from a
// shared-memory descriptor.
template <int N_, typename MMA>
struct TF32MMARS {

    // Expand the accumulator array into the individual register operands the
    // CUTLASS `MMA::fma` signature expects; A is passed as its 4 raw words.
    template <size_t ...Idx>
    __forceinline__ __device__ static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
        using namespace cute::SM90::GMMA;
        MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
    }

    // Issue one WGMMA; `a` points to 4 TF32 (float) register values, and
    // `scale_d` selects accumulate (One) vs overwrite (Zero).
    __forceinline__ __device__ static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) {
        call_fma_impl(reinterpret_cast<uint32_t*>(a), desc_b, d, scale_d, cute::make_index_sequence<N_/2>{});
    }

    static constexpr int M = 64;
    static constexpr int N = N_;
    static constexpr int K = 8;
    // FP32 accumulators per thread: M * N elements over a 128-thread warpgroup
    static constexpr int kNumAccum = M * N / 128;
};
|
| 173 |
+
|
| 174 |
+
// Maps an accumulator width `N` (powers of two, 8..256) to the CUTLASS SM90
// TF32 WGMMA atom of shape 64 x N x 8 in RS mode, wrapped in `TF32MMARS`.
// Only RS (register A operand) mode is implemented; `kUseRS == false` fails
// at compile time.
template <int N, bool kUseRS = true>
struct TF32MMASelector {

    static constexpr auto select_mma() {
        using namespace cute::SM90::GMMA;
        if constexpr (kUseRS) {
            if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN();
            if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN();
            if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN();
            if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN();
            if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN();
            if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN();
            DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N");
        }
    }

    static constexpr auto select_type() {
        if constexpr (kUseRS) {
            return TF32MMARS<N, decltype(select_mma())>();
        } else {
            DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now");
        }
    }

    using type = decltype(select_type());
};
|
| 200 |
+
|
| 201 |
+
// `stmatrix.x2`: store two 8x8 matrix fragments from registers into shared
// memory. Assumes sizeof(dtype_t) == 4 (each source is reinterpreted as one
// 32-bit word).
template <typename dtype_t>
struct SM90_U32x2_STSM_N {
    __device__ __forceinline__ static void
    copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
        const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
        asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
                     :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1]));
    }
};
|
| 210 |
+
|
| 211 |
+
// `ldmatrix.x2`: load two 8x8 matrix fragments from shared memory into two
// 32-bit registers.
struct SM90_U32x2_LDSM_N {
    __device__ __forceinline__ static void
    copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) {
        asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
                     : "=r"(dst_0), "=r"(dst_1)
                     : "l"(__cvta_generic_to_shared(smem_src)));
    }
};
|
| 219 |
+
|
| 220 |
+
// `ldmatrix.x4`: load four 8x8 matrix fragments from shared memory into four
// 32-bit registers.
struct SM90_U32x4_LDSM_N {
    __device__ __forceinline__ static void
    copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) {
        asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                     : "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3)
                     : "l"(__cvta_generic_to_shared(smem_src)));
    }
};
|
| 228 |
+
|
| 229 |
+
// WGMMA fence: orders prior register/shared-memory writes before subsequent
// `wgmma.mma_async` instructions issued by this warpgroup.
__forceinline__ __device__ void warpgroup_arrive() {
    asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
}
|
| 232 |
+
|
| 233 |
+
// Commit all previously issued (uncommitted) WGMMA operations as one group,
// to be waited on with `warpgroup_wait<N>()`.
__forceinline__ __device__ void warpgroup_commit_batch() {
    asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
}
|
| 236 |
+
|
| 237 |
+
// Empty asm marking `reg` as read-and-written: stops the compiler from
// reordering accumulator register accesses across WGMMA boundaries.
__forceinline__ __device__ void warpgroup_fence_operand(float& reg) {
    asm volatile("" : "+f"(reg) :: "memory");
}
|
| 240 |
+
|
| 241 |
+
// Wait until at most `N` committed WGMMA groups remain outstanding
// (N == 0 waits for all of them).
template <int N>
__forceinline__ __device__ void warpgroup_wait() {
    DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
    asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
}
|
| 246 |
+
|
| 247 |
+
// Build an SM90 GMMA (WGMMA) shared-memory matrix descriptor. The start
// address, LBO and SBO are all encoded at 16-byte granularity (hence `>> 4`).
template <class PointerType>
__device__ cute::GmmaDescriptor make_smem_desc(PointerType smem_ptr, const int& layout_type,
                                               const int& leading_byte_offset = 0,
                                               const int& stride_byte_offset = 1024) {
    // NOTES: the default LBO and SBO are for K-major types
    cute::GmmaDescriptor desc;
    // Shared-memory byte address of the tile start
    const auto& uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    desc.bitfield.start_address_ = uint_ptr >> 4;
    desc.bitfield.layout_type_ = layout_type;
    desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
    desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
    desc.bitfield.base_offset_ = 0;
    return desc;
}
|
| 261 |
+
|
| 262 |
+
// Size (in elements) of one swizzle atom along the inner dimension.
// A swizzle mode of 0 means the whole inner block is a single atom;
// otherwise one atom spans `kSwizzleMode` bytes of `dtype_t` elements.
template <uint32_t BLOCK_INNER, uint32_t kSwizzleMode, typename dtype_t>
constexpr uint32_t get_inner_block_atom_size() {
    if (kSwizzleMode == 0)
        return BLOCK_INNER;
    return kSwizzleMode / sizeof(dtype_t);
}
|
| 266 |
+
|
| 267 |
+
// Element stride applied per `k_idx` step when building/advancing a GMMA
// descriptor: 1 for K-major layouts, one inner-block atom for MN-major.
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
__device__ __forceinline__
constexpr uint32_t get_gmma_desc_stride_k() {
    return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
}
|
| 272 |
+
|
| 273 |
+
// ReSharper disable once CppNotAllPathsReturnValue
|
| 274 |
+
// Map a swizzle atom size (in bytes) onto the SM90 GMMA layout-type encoding;
// 0 and 16 bytes both encode as INTERLEAVE (no swizzling).
template <cute::UMMA::Major kMajorMode, uint32_t kSwizzleMode, typename dtype_t>
constexpr static cute::SM90::GMMA::LayoutType to_gmma_layout_type() {
    DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or
                     kSwizzleMode == 32 or kSwizzleMode == 64 or
                     kSwizzleMode == 128, "Invalid swizzling mode");

    // Normal cases
    if constexpr (kSwizzleMode == 0) return cute::SM90::GMMA::LayoutType::INTERLEAVE;
    if constexpr (kSwizzleMode == 16) return cute::SM90::GMMA::LayoutType::INTERLEAVE;
    if constexpr (kSwizzleMode == 32) return cute::SM90::GMMA::LayoutType::B32;
    if constexpr (kSwizzleMode == 64) return cute::SM90::GMMA::LayoutType::B64;
    if constexpr (kSwizzleMode == 128) return cute::SM90::GMMA::LayoutType::B128;
}
|
| 287 |
+
|
| 288 |
+
// Advance the low 32 bits of a GMMA descriptor to the tile at
// (`mn_idx`, `k_idx`) plus an extra element `offset`; the descriptor encodes
// addresses at 16-byte granularity, hence the final `>> 4`.
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, typename dtype_t>
__device__ __forceinline__
uint32_t advance_gmma_desc_lo(const uint32_t& base, const uint32_t& mn_idx, const uint32_t& k_idx, const uint32_t& offset = 0) {
    return base + (((offset + mn_idx * BLOCK_K + k_idx * get_gmma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>()) * static_cast<uint32_t>(sizeof(dtype_t))) >> 4u);
}
|
| 293 |
+
|
| 294 |
+
// Build an SM90 GMMA shared-memory descriptor for the sub-tile at
// (`mn_idx`, `k_idx`) of a `BLOCK_MN x BLOCK_K` shared-memory block; the SM90
// analogue of `deep_gemm::sm100::make_umma_desc`.
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, typename dtype_t>
__device__ __forceinline__
cute::GmmaDescriptor make_gmma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) {
    // Element stride between consecutive `k_idx` steps (layout dependent)
    const uint32_t stride_k = get_gmma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
    const auto& layout_type = to_gmma_layout_type<kMajorMode, kSwizzleMode, dtype_t>();
    // 128 / 16 — assumed rows per swizzle atom (NOTE(review): confirm)
    constexpr uint32_t num_non_contiguous = 128 / 16;
    if constexpr (kMajorMode == cute::UMMA::Major::K) {
        // NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128
        DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");

        // Atom size: 8 x `kSwizzleMode` (in bytes, on K)
        // {SBO, LBO} means the byte stride between atoms on {MN, K}
        // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0
        const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t);
        const uint32_t leading_byte_offset = 0;
        return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast<uint32_t>(layout_type),
                              leading_byte_offset, stride_byte_offset);
    } else {
        // MN-major: the inner block may be split into several swizzle atoms
        constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();

        // Must have no in-atom MN-idx
        // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time
        DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0);
        DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling");

        // Atom size: `kSwizzleMode` (in bytes, on MN) x 8
        // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving
        // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling
        // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling
        uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t);
        uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t);
        if constexpr (kSwizzleMode == 16)
            swap(stride_byte_offset, leading_byte_offset);
        return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast<uint32_t>(layout_type),
                              leading_byte_offset, stride_byte_offset);
    }
}
|
| 331 |
+
|
| 332 |
+
} // namespace `deep_gemm::sm90`
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/common/tma_utils.cuh
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cute/arch/copy_sm90_tma.hpp>
|
| 4 |
+
#include <cute/arch/copy_sm100_tma.hpp>
|
| 5 |
+
#include <cutlass/arch/barrier.h>
|
| 6 |
+
|
| 7 |
+
namespace deep_gemm {
|
| 8 |
+
|
| 9 |
+
// Number of `dtype_t` elements covered by one swizzle atom on the inner axis.
// `kSwizzleMode == 0` (no swizzling) keeps the full inner block as one atom.
template <uint32_t BLOCK_INNER, uint32_t kSwizzleMode, typename dtype_t>
constexpr uint32_t get_inner_block_atom_size() {
    if (kSwizzleMode != 0)
        return kSwizzleMode / sizeof(dtype_t);
    return BLOCK_INNER;
}
|
| 13 |
+
|
| 14 |
+
// Issue the TMA loads covering one `BLOCK_OUTER x BLOCK_INNER` tile, split
// into swizzle-atom-sized chunks along the inner axis. Selects 2D vs 3D TMA
// via `kIs3DTMA` and, when `num_tma_multicast > 1`, the architecture's
// multicast path: SM100 2-SM TMA, or SM90 multicast issued by cluster rank 0
// only. Completion is signaled on `barrier_ptr`.
template <uint32_t BLOCK_INNER, uint32_t BLOCK_OUTER,
          uint32_t kSwizzleMode,
          typename dtype_t, bool kIs3DTMA = false>
__device__ __forceinline__ void
tma_copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr,
         dtype_t* smem_ptr, const uint32_t& inner_idx, const uint32_t& outer_idx,
         const uint32_t& num_tma_multicast = 1, const uint32_t& batch_idx = 0) {
    // The SM90 and SM100 cache-hint encodings must agree, since both are used below
    DG_STATIC_ASSERT(static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL) ==
                     static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint");
    constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size<BLOCK_INNER, kSwizzleMode, dtype_t>();

    if constexpr (not kIs3DTMA) {
        if (num_tma_multicast == 1) {
            // Plain 2D loads, one per swizzle atom
            #pragma unroll
            for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
                cute::SM90_TMA_LOAD_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
                                             static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
                                             smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
                                             inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
            }
        } else {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000))
            // 2-CTA function will send signals to the leader CTA only
            #pragma unroll
            for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
                cute::SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
                                                  static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
                                                  smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
                                                  inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
            }
#elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900))
            // SM90 multicast: only the first CTA in the cluster issues the load
            if (cute::block_rank_in_cluster() == 0) {
                #pragma unroll
                for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
                    cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
                                                           (1 << num_tma_multicast) - 1, static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL),
                                                           smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
                                                           inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
                }
            }
#endif
        }
    } else {
        if (num_tma_multicast == 1) {
            // Plain 3D loads (extra `batch_idx` coordinate), one per swizzle atom
            #pragma unroll
            for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
                cute::SM90_TMA_LOAD_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
                                             static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
                                             smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
                                             inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
            }
        } else {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000))
            // 2-CTA function will send signals to the leader CTA only
            #pragma unroll
            for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
                cute::SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
                                                  static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
                                                  smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
                                                  inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
            }
#elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900))
            // SM90 multicast: only the first CTA in the cluster issues the load
            if (cute::block_rank_in_cluster() == 0) {
                #pragma unroll
                for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
                    cute::SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
                                                           (1 << num_tma_multicast) - 1, static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL),
                                                           smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
                                                           inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
                }
            }
#endif
        }
    }
}
|
| 89 |
+
|
| 90 |
+
// Tensormap related
|
| 91 |
+
__device__ __forceinline__ void tensor_map_release_cta() {
    // Release fence on the generic->tensormap proxy at CTA scope: makes prior
    // generic-proxy writes to a tensormap visible to the tensormap (TMA) proxy.
    // Pairs with `tensor_map_acquire_cta` below.
    asm volatile ("fence.proxy.tensormap::generic.release.cta;");
}
|
| 94 |
+
|
| 95 |
+
__device__ __forceinline__ void tensor_map_acquire_cta(const cute::TmaDescriptor* gmem_desc_ptr) {
    // Acquire fence pairing with `tensor_map_release_cta`: makes the 128-byte
    // tensormap at `gmem_desc_ptr` (modified through the generic proxy) safe
    // for subsequent TMA instructions issued by this CTA.
    auto gmem_int_desc = reinterpret_cast<uint64_t>(gmem_desc_ptr);
    asm volatile ("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" :: "l"(gmem_int_desc) : "memory");
}
|
| 99 |
+
|
| 100 |
+
__device__ __forceinline__ void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) {
    // In-place patches the global base-address field of a TMA descriptor that
    // resides in shared memory (`tensormap.replace...shared::cta` requires a
    // shared-space address operand, hence the `__cvta_generic_to_shared`).
    auto smem_int_desc = static_cast<uint32_t>(__cvta_generic_to_shared(smem_desc));
    const auto new_int64_addr = reinterpret_cast<uint64_t>(new_addr);
    asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr));
}
|
| 105 |
+
|
| 106 |
+
__device__ __forceinline__ void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) {
    // In-place patches dimension 0's global size and stride of a shared-memory
    // TMA descriptor (the literal `0` in both instructions selects the dim).
    auto smem_int_desc = __cvta_generic_to_shared(smem_desc);
    asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim));
    // `tensormap.replace.tile.global_stride` is only emitted by CUDA >= 12.3,
    // so older toolchains fail the build here instead of miscompiling.
#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 3)))
    asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride));
#else
    DG_STATIC_ASSERT(false, "Invalid CUDA version");
#endif
}
|
| 115 |
+
|
| 116 |
+
} // namespace `deep_gemm`
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/common/types.hpp
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
namespace deep_gemm {
|
| 4 |
+
|
| 5 |
+
// Operand kind of a tensor-core MMA; drives the element byte width
// (see `get_element_size`).
enum class MmaKind {
    BF16 = 0,
    MXFP8FP4 = 1,
};
|
| 9 |
+
|
| 10 |
+
// Byte width of a single element for the given MMA kind; 0 for an
// unrecognized kind.
constexpr __host__ __device__ int get_element_size(const MmaKind& mma_kind) {
    if (mma_kind == MmaKind::BF16)
        return 2;
    return mma_kind == MmaKind::MXFP8FP4 ? 1 : 0;
}
|
| 17 |
+
|
| 18 |
+
// GEMM variant selector; names encode grouping dimension and layout —
// exact per-variant semantics live in the kernel implementations
// (TODO confirm against the `impls/` headers when documenting further).
enum class GemmType {
    Normal = 0,
    MGroupedContiguous = 1,
    MGroupedMasked = 2,
    KGroupedContiguous = 3,
    Batched = 4,
    MGroupedContiguousWithPsumLayout = 5,
};
|
| 26 |
+
|
| 27 |
+
// True for the two m-grouped contiguous GEMM variants (plain, and the one
// with the partial-sum output layout); false for every other type.
constexpr __host__ __device__ bool is_m_grouped_contiguous(const GemmType& gemm_type) {
    return gemm_type == GemmType::MGroupedContiguous or
           gemm_type == GemmType::MGroupedContiguousWithPsumLayout;
}
|
| 34 |
+
|
| 35 |
+
// Kernel flavour w.r.t. scaling factors; `Kernel1D1D`/`Kernel1D2D` presumably
// refer to the scaling-factor granularity of the two operands and `KernelNoSF`
// to kernels without scaling factors — TODO confirm against the kernel impls.
enum class KernelType {
    Kernel1D1D = 0,
    Kernel1D2D = 1,
    KernelNoSF = 2
};
|
| 40 |
+
|
| 41 |
+
} // namespace deep_gemm
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/common/utils.cuh
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cuda_bf16.h>
|
| 4 |
+
#include <cuda_fp8.h>
|
| 5 |
+
#include <cuda/std/cstdint>
|
| 6 |
+
#include <cuda/std/utility>
|
| 7 |
+
#include <cute/container/tuple.hpp>
|
| 8 |
+
|
| 9 |
+
#include "cute_tie.cuh"
|
| 10 |
+
|
| 11 |
+
// Under the CLion IDE's code model, alias `printf` to a trapping stub so that
// device-visible `printf` calls still parse during indexing.
#ifdef __CLION_IDE__

__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) {
    asm volatile("trap;");
}

#define printf host_device_printf
#endif

// Device-side assertion: prints the failing file/line/condition, then traps.
#ifndef DG_DEVICE_ASSERT
#define DG_DEVICE_ASSERT(cond) \
do { \
    if (not (cond)) { \
        printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
        asm("trap;"); \
    } \
} while (0)
#endif

// Lighter-weight assertion: trap without printing (no format-string cost).
#ifndef DG_TRAP_ONLY_DEVICE_ASSERT
#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \
do { \
    if (not (cond)) \
        asm("trap;"); \
} while (0)
#endif

// Project-local spelling of `static_assert`, overridable by the includer.
#ifndef DG_STATIC_ASSERT
#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__)
#endif
|
| 41 |
+
|
| 42 |
+
namespace deep_gemm {
|
| 43 |
+
|
| 44 |
+
// Wraps a callable so `visitor[i]` evaluates `func(i)`, giving index-computed
// values (e.g. per-stage shared-memory pointers) an array-like syntax.
template <typename FuncT>
struct PatternVisitor {
    FuncT func;

    __device__ __host__
    explicit PatternVisitor(FuncT&& func): func(std::forward<FuncT>(func)) {}

    // Returns `func(i)` by value; the visitor itself holds no data per index.
    __device__ __host__
    auto operator [](const uint32_t& i) {
        return func(i);
    }
};
|
| 56 |
+
|
| 57 |
+
template <typename T>
|
| 58 |
+
__device__ __host__ T ceil_div(T a, T b) {
|
| 59 |
+
return (a + b - 1) / b;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
template <typename T>
|
| 63 |
+
__device__ __host__ constexpr T constexpr_ceil_div(T a, T b) {
|
| 64 |
+
return (a + b - 1) / b;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
template <typename T>
|
| 68 |
+
__device__ __host__ T align(T a, T b) {
|
| 69 |
+
return ceil_div(a, b) * b;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
template <typename T>
|
| 73 |
+
__device__ __host__ constexpr T constexpr_align(T a, T b) {
|
| 74 |
+
return constexpr_ceil_div(a, b) * b;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
template <typename T>
|
| 78 |
+
__device__ __host__ constexpr T constexpr_gcd(T a, T b) {
|
| 79 |
+
return b == 0 ? a : constexpr_gcd(b, a % b);
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
// Minimal device-side value swap; requires `T` to be copy-constructible and
// copy-assignable.
template<typename T>
__forceinline__ __device__ void swap(T& a, T& b) {
    T temp = a;
    a = b;
    b = temp;
}
|
| 88 |
+
|
| 89 |
+
// Returns the hardware SM id of the calling thread (PTX `%smid` special
// register). Per the PTX ISA, SM ids are not guaranteed to be contiguous.
__forceinline__ __device__ uint32_t get_sm_idx() {
    uint32_t sm_idx;
    asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx));
    return sm_idx;
}
|
| 94 |
+
|
| 95 |
+
// Returns the calling thread's lane index (0..31) within its warp by reading
// the PTX `%laneid` special register.
// NOTES: in GCC-style inline-asm templates `%` is an escape character, so the
// special register must be written `%%laneid` — matching the `%%smid` usage in
// `get_sm_idx`; a lone `%` before a letter is an invalid operand reference.
__forceinline__ __device__ uint32_t get_lane_idx() {
    uint32_t lane_id;
    asm ("mov.u32 %0, %%laneid;" : "=r"(lane_id));
    return lane_id;
}
|
| 100 |
+
|
| 101 |
+
// Typed shared-memory loads via explicit `ld.shared` PTX.
// `__cvta_generic_to_shared` converts the generic pointer into a shared-space
// address for the `[%n]` operand; each overload picks the matching PTX type
// (u32/f32) and vector width (scalar/v2/v4).
__device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) {
    uint32_t ret;
    asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr)));
    return ret;
}

// Vectorized 2 x f32 load
__device__ __forceinline__ float2 ld_shared(const float2* ptr) {
    float2 ret;
    asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr)));
    return ret;
}

// Vectorized 4 x f32 load
__device__ __forceinline__ float4 ld_shared(const float4* ptr) {
    float4 ret;
    asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr)));
    return ret;
}

// Vectorized 4 x u32 load
__device__ __forceinline__ uint4 ld_shared(const uint4* ptr) {
    uint4 ret;
    asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr)));
    return ret;
}

__device__ __forceinline__ float ld_shared(const float* ptr) {
    float ret;
    asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr)));
    return ret;
}
|
| 130 |
+
|
| 131 |
+
// Typed shared-memory stores via explicit `st.shared` PTX; mirrors the
// `ld_shared` overload set above. The `const` on `ptr` refers to the pointer
// type only — the pointee is written by the asm.
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
    asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val));
}

// Vectorized 2 x f32 store
__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) {
    asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y));
}

__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
    asm volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val));
}

// Vectorized 2 x u32 store from separate registers
__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y) {
    asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y));
}

// Vectorized 4 x u32 store from separate registers
__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) {
    asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w));
}

// 128-bit store; NOTE(review): the "q" asm constraint for a 128-bit operand is
// non-standard for NVCC inline PTX — confirm it compiles on all supported toolchains.
__device__ __forceinline__ void st_shared(const __int128_t* ptr, __int128_t val) {
    asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val));
}
|
| 154 |
+
|
| 155 |
+
template <typename old_t>
|
| 156 |
+
__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) {
|
| 157 |
+
auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast<float*>(&x), *reinterpret_cast<float*>(&y)});
|
| 158 |
+
return *reinterpret_cast<int*>(&bf16x2);
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
__device__ __forceinline__ void prefetch_l1(void *ptr) {
    // Hint to bring the global-memory line at `ptr` into the L1 cache.
    asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr));
}
|
| 164 |
+
|
| 165 |
+
// Selects the widest unsigned vector type (uint4 / uint2 / 32-bit scalar)
// whose size evenly divides `kNumBytes`, exposed as `vec_t` for vectorized
// memory accesses.
template <uint32_t kNumBytes>
struct Vectorized {
    static auto zeros() {
        // TODO: add `ulonglong4` for SM100 once `__ldg` support this
        if constexpr (kNumBytes > 0 and kNumBytes % 16 == 0) {
            return make_uint4(0, 0, 0, 0);
        } else if constexpr (kNumBytes > 0 and kNumBytes % 8 == 0) {
            return make_uint2(0, 0);
        } else if constexpr (kNumBytes > 0 and kNumBytes % 4 == 0) {
            // NOTE: the deduced scalar type here is `int`, not `uint32_t`
            return 0;
        } else {
            // Any size not a positive multiple of 4 fails at compile time
            DG_STATIC_ASSERT(kNumBytes > 0 and kNumBytes % 4 == 0, "Invalid vectorization");
        }
    }

    // The vector type callers actually use, derived from `zeros()`'s return type
    using vec_t = decltype(zeros());
};
|
| 182 |
+
|
| 183 |
+
} // namespace `deep_gemm`
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#pragma clang diagnostic push
|
| 3 |
+
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
| 4 |
+
|
| 5 |
+
#include <cutlass/arch/barrier.h>
|
| 6 |
+
|
| 7 |
+
#include <deep_gemm/common/scheduler.cuh>
|
| 8 |
+
#include <deep_gemm/common/utils.cuh>
|
| 9 |
+
#include <deep_gemm/common/sm100_utils.cuh>
|
| 10 |
+
|
| 11 |
+
namespace deep_gemm {
|
| 12 |
+
|
| 13 |
+
using namespace deep_gemm::sm100;
|
| 14 |
+
|
| 15 |
+
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
| 16 |
+
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
| 17 |
+
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K_,
|
| 18 |
+
uint32_t kNumGroups,
|
| 19 |
+
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
|
| 20 |
+
uint32_t kNumStages_,
|
| 21 |
+
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
|
| 22 |
+
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
| 23 |
+
uint32_t kNumSMs,
|
| 24 |
+
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t,
|
| 25 |
+
uint64_t kTensorCoreUtilControl>
|
| 26 |
+
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
|
| 27 |
+
sm100_bf16_gemm_impl(int* grouped_layout,
|
| 28 |
+
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
| 29 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
| 30 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
| 31 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_cd) {
|
| 32 |
+
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
|
| 33 |
+
// Enlarge `BLOCK_K` for some cases
|
| 34 |
+
// NOTES: this is for reducing the `umma_arrive()` overhead
|
| 35 |
+
constexpr bool kDoMergeStages =
|
| 36 |
+
kNumStages_ >= 8 and kGemmType == GemmType::Normal and
|
| 37 |
+
kMajorA == cute::UMMA::Major::K and kMajorB == cute::UMMA::Major::K;
|
| 38 |
+
// Ensure there are at least `kNumMinStages` stages after merge
|
| 39 |
+
constexpr uint32_t kNumMinStages = 8;
|
| 40 |
+
constexpr uint32_t kNumStagesPerMerge = kDoMergeStages ? kNumStages_ / kNumMinStages : 1;
|
| 41 |
+
constexpr uint32_t BLOCK_K = BLOCK_K_ * kNumStagesPerMerge;
|
| 42 |
+
constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge;
|
| 43 |
+
|
| 44 |
+
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 45 |
+
using Allocator = cute::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
|
| 46 |
+
|
| 47 |
+
// GEMM with accumulation must have FP32 output
|
| 48 |
+
if constexpr (kWithAccumulation)
|
| 49 |
+
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
|
| 50 |
+
|
| 51 |
+
// Configs
|
| 52 |
+
constexpr uint32_t LAYOUT_AD_M = 128;
|
| 53 |
+
constexpr uint32_t WAVE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
|
| 54 |
+
constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M;
|
| 55 |
+
constexpr uint32_t kNumTMAStoreStages = 2;
|
| 56 |
+
DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K");
|
| 57 |
+
DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M");
|
| 58 |
+
DG_STATIC_ASSERT(sizeof(cutlass::bfloat16_t) * LAYOUT_AD_M % kSwizzleAMode == 0, "Invalid swizzle A mode");
|
| 59 |
+
|
| 60 |
+
// Overwrite shape constants if the compiler gives
|
| 61 |
+
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
|
| 62 |
+
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
|
| 63 |
+
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
| 64 |
+
|
| 65 |
+
// Utils
|
| 66 |
+
bool is_leader_cta = cute::block_rank_in_cluster() == 0;
|
| 67 |
+
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
| 68 |
+
const auto lane_idx = get_lane_idx();
|
| 69 |
+
|
| 70 |
+
// Align to 1024 bytes for swizzle-128B
|
| 71 |
+
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
| 72 |
+
|
| 73 |
+
// 2-CTA MMA
|
| 74 |
+
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
|
| 75 |
+
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
|
| 76 |
+
constexpr uint32_t STORE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
|
| 77 |
+
constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t);
|
| 78 |
+
constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M;
|
| 79 |
+
DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast");
|
| 80 |
+
DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D");
|
| 81 |
+
DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
|
| 82 |
+
DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M");
|
| 83 |
+
|
| 84 |
+
// Share memory sizes
|
| 85 |
+
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode;
|
| 86 |
+
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
|
| 87 |
+
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t);
|
| 88 |
+
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t);
|
| 89 |
+
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0,
|
| 90 |
+
"Shared memory of A/B must be aligned to 1024 bytes");
|
| 91 |
+
DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
|
| 92 |
+
|
| 93 |
+
// NOTES: Make sure we have enough shared memory for UMMA padding
|
| 94 |
+
static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16);
|
| 95 |
+
DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA");
|
| 96 |
+
|
| 97 |
+
// Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size
|
| 98 |
+
// TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2`
|
| 99 |
+
constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N) > 512 ? 1 : 2;
|
| 100 |
+
|
| 101 |
+
// Real tensor memory size and offsets
|
| 102 |
+
constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N;
|
| 103 |
+
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols>();
|
| 104 |
+
|
| 105 |
+
// Prefetch TMA descriptors at the very beginning
|
| 106 |
+
if (warp_idx == 0 and cute::elect_one_sync()) {
|
| 107 |
+
cute::prefetch_tma_descriptor(&tensor_map_a);
|
| 108 |
+
cute::prefetch_tma_descriptor(&tensor_map_b);
|
| 109 |
+
cute::prefetch_tma_descriptor(&tensor_map_cd);
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
// D/A/B shared memory
|
| 113 |
+
auto smem_cd = PatternVisitor([&](const uint32_t& i) {
|
| 114 |
+
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
|
| 115 |
+
});
|
| 116 |
+
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
| 117 |
+
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
| 118 |
+
});
|
| 119 |
+
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
| 120 |
+
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
| 121 |
+
});
|
| 122 |
+
|
| 123 |
+
// Fill barriers
|
| 124 |
+
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
| 125 |
+
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 126 |
+
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
| 127 |
+
auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
|
| 128 |
+
auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); });
|
| 129 |
+
auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2;
|
| 130 |
+
|
| 131 |
+
// Fill the tensor memory pointer
|
| 132 |
+
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2 + 1);
|
| 133 |
+
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
| 134 |
+
|
| 135 |
+
// Initialize barriers
|
| 136 |
+
if (warp_idx == 1 and cute::elect_one_sync()) {
|
| 137 |
+
#pragma unroll
|
| 138 |
+
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
| 139 |
+
// Arrive only at the leader CTA
|
| 140 |
+
full_barriers[i]->init(kNumMulticast);
|
| 141 |
+
// Arrive at all CTAs
|
| 142 |
+
empty_barriers[i]->init(1);
|
| 143 |
+
}
|
| 144 |
+
#pragma unroll
|
| 145 |
+
for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
|
| 146 |
+
// Arrive at all CTAs
|
| 147 |
+
tmem_full_barriers[i]->init(1);
|
| 148 |
+
// Arrive only at the leader CTA
|
| 149 |
+
tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads);
|
| 150 |
+
}
|
| 151 |
+
if constexpr (kTensorCoreUtilControl < 100)
|
| 152 |
+
tensor_core_full_barrier->init(1);
|
| 153 |
+
|
| 154 |
+
// Make initialized barrier visible in async proxy
|
| 155 |
+
cutlass::arch::fence_barrier_init();
|
| 156 |
+
} else if (warp_idx == 2) {
|
| 157 |
+
// Allocate tensor memory
|
| 158 |
+
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
| 159 |
+
}
|
| 160 |
+
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
|
| 161 |
+
|
| 162 |
+
// Block scheduler
|
| 163 |
+
uint32_t m_block_idx, n_block_idx;
|
| 164 |
+
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
|
| 165 |
+
|
| 166 |
+
// Pipeline and TMA phases
|
| 167 |
+
uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0;
|
| 168 |
+
auto advance_pipeline = [&](uint32_t& k_block_idx) {
|
| 169 |
+
++ k_block_idx;
|
| 170 |
+
|
| 171 |
+
// Flip phases only if reach the next first stage
|
| 172 |
+
stage_idx = (stage_idx + 1) % kNumStages;
|
| 173 |
+
phase ^= stage_idx == 0;
|
| 174 |
+
};
|
| 175 |
+
|
| 176 |
+
// Dispatch warps into different roles
|
| 177 |
+
if (warp_idx == 0 and cute::elect_one_sync()) {
|
| 178 |
+
// TMA load warp
|
| 179 |
+
// Persistently schedule over blocks
|
| 180 |
+
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 181 |
+
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
|
| 182 |
+
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 183 |
+
// Wait consumer release
|
| 184 |
+
empty_barriers[stage_idx]->wait(phase ^ 1);
|
| 185 |
+
|
| 186 |
+
// Compute offsets
|
| 187 |
+
// NOTES: the group is always concatenated with the outer dimension
|
| 188 |
+
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> (
|
| 189 |
+
shape_m, BLOCK_M, m_block_idx);
|
| 190 |
+
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> (
|
| 191 |
+
shape_n, BLOCK_N, n_block_idx, m_block_idx);
|
| 192 |
+
|
| 193 |
+
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
|
| 194 |
+
// And for all m-grouped GEMMs, A must be K-majored
|
| 195 |
+
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
|
| 196 |
+
kMajorA == cute::UMMA::Major::K, "Invalid major");
|
| 197 |
+
uint32_t k_idx = k_block_idx * BLOCK_K;
|
| 198 |
+
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> (
|
| 199 |
+
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
| 200 |
+
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> (
|
| 201 |
+
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
| 202 |
+
|
| 203 |
+
// Add 2 CTA offsets
|
| 204 |
+
if constexpr (kNumMulticast > 1) {
|
| 205 |
+
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
|
| 206 |
+
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
// Issue TMAs
|
| 210 |
+
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
|
| 211 |
+
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
|
| 212 |
+
if constexpr (kMajorA == cute::UMMA::Major::K)
|
| 213 |
+
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
|
| 214 |
+
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, kNumMulticast, batch_idx);
|
| 215 |
+
if constexpr (kMajorA == cute::UMMA::Major::MN)
|
| 216 |
+
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
|
| 217 |
+
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, kNumMulticast, batch_idx);
|
| 218 |
+
if constexpr (kMajorB == cute::UMMA::Major::K)
|
| 219 |
+
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
|
| 220 |
+
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, kNumMulticast, batch_idx);
|
| 221 |
+
if constexpr (kMajorB == cute::UMMA::Major::MN)
|
| 222 |
+
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
|
| 223 |
+
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, kNumMulticast, batch_idx);
|
| 224 |
+
|
| 225 |
+
// Arrive at full barriers
|
| 226 |
+
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
|
| 227 |
+
if (is_leader_cta) {
|
| 228 |
+
full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast);
|
| 229 |
+
} else {
|
| 230 |
+
full_barriers[stage_idx]->arrive(0u);
|
| 231 |
+
}
|
| 232 |
+
}
|
| 233 |
+
}
|
| 234 |
+
} else if (warp_idx == 1 and is_leader_cta) {
|
| 235 |
+
// MMA issue warp
|
| 236 |
+
// NOTES: only the leader CTA will do this
|
| 237 |
+
// Make instruction descriptor
|
| 238 |
+
// TODO: refactor `UMMA_M` calculation
|
| 239 |
+
constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast);
|
| 240 |
+
constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1);
|
| 241 |
+
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t);
|
| 242 |
+
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, kMajorA, kMajorB>();
|
| 243 |
+
|
| 244 |
+
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
|
| 245 |
+
// Merged stages only happens in NT normal GEMM cases
|
| 246 |
+
constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge;
|
| 247 |
+
auto a_desc = make_umma_desc<kMajorA, LOAD_BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode>(smem_a[0], 0, 0);
|
| 248 |
+
auto b_desc = make_umma_desc<kMajorB, LOAD_BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode>(smem_b[0], 0, 0);
|
| 249 |
+
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
|
| 250 |
+
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
| 251 |
+
|
| 252 |
+
// Checks for MMA instructions
|
| 253 |
+
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
|
| 254 |
+
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
|
| 255 |
+
(UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
|
| 256 |
+
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
|
| 257 |
+
"Invalid MMA instruction shape");
|
| 258 |
+
|
| 259 |
+
// Persistently schedule over blocks
|
| 260 |
+
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 261 |
+
// Wait tensor memory empty barrier arrival
|
| 262 |
+
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
|
| 263 |
+
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
| 264 |
+
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
|
| 265 |
+
tcgen05_after_thread_sync();
|
| 266 |
+
|
| 267 |
+
// UMMA and empty barrier arrival alias
|
| 268 |
+
auto umma_arrive = [](const uint64_t* barrier) {
|
| 269 |
+
if constexpr (kNumMulticast == 1) {
|
| 270 |
+
cutlass::arch::umma_arrive(barrier);
|
| 271 |
+
} else {
|
| 272 |
+
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
|
| 273 |
+
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
|
| 274 |
+
}
|
| 275 |
+
};
|
| 276 |
+
auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) {
|
| 277 |
+
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
|
| 278 |
+
|
| 279 |
+
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
|
| 280 |
+
if (do_tmem_full_arrive)
|
| 281 |
+
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
|
| 282 |
+
};
|
| 283 |
+
|
| 284 |
+
// Launch MMAs
|
| 285 |
+
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
|
| 286 |
+
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 287 |
+
// Wait TMA arrival
|
| 288 |
+
full_barriers[stage_idx]->wait(phase);
|
| 289 |
+
tcgen05_after_thread_sync();
|
| 290 |
+
|
| 291 |
+
// Issue UMMA in the leader CTA
|
| 292 |
+
using mma_t = cute::conditional_t<kNumMulticast == 1, SM100_MMA_F16BF16_SS, SM100_MMA_F16BF16_2x1SM_SS>;
|
| 293 |
+
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
|
| 294 |
+
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
|
| 295 |
+
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
|
| 296 |
+
if (cute::elect_one_sync()) {
|
| 297 |
+
#pragma unroll
|
| 298 |
+
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
| 299 |
+
uint32_t atom_k_idx = k * UMMA_K / BLOCK_ATOM_K;
|
| 300 |
+
b_desc.lo = advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t>(b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K);
|
| 301 |
+
#pragma unroll
|
| 302 |
+
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
| 303 |
+
DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset");
|
| 304 |
+
a_desc.lo = advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t>(a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K + w * WAVE_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K);
|
| 305 |
+
mma_t::fma(a_desc, b_desc,
|
| 306 |
+
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
|
| 307 |
+
k_block_idx > 0 or k > 0,
|
| 308 |
+
runtime_instr_desc);
|
| 309 |
+
}
|
| 310 |
+
}
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
// Commit to the mbarrier object
|
| 314 |
+
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
|
| 315 |
+
empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1);
|
| 316 |
+
|
| 317 |
+
// Let tensor cores relax for lower possibility of frequency drop
|
| 318 |
+
DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control");
|
| 319 |
+
if constexpr (kTensorCoreUtilControl < 100) {
|
| 320 |
+
// For utilization control
|
| 321 |
+
umma_arrive(reinterpret_cast<uint64_t*>(tensor_core_full_barrier));
|
| 322 |
+
|
| 323 |
+
// Wait for last UMMA to be done
|
| 324 |
+
tensor_core_full_barrier->wait(tensor_core_phase);
|
| 325 |
+
tensor_core_phase ^= 1;
|
| 326 |
+
|
| 327 |
+
// Sleep for certain cycles
|
| 328 |
+
constexpr static uint64_t kNumUMMACycles = (2ull * LAYOUT_AD_M * kNumMWaves * BLOCK_N * BLOCK_K) / 8192ull;
|
| 329 |
+
constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl;
|
| 330 |
+
const auto& start_clock = clock64();
|
| 331 |
+
if (cute::elect_one_sync())
|
| 332 |
+
while (clock64() - start_clock < kNumDummyCycles) {}
|
| 333 |
+
__syncwarp();
|
| 334 |
+
}
|
| 335 |
+
}
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
// To safely deconstruct barriers, we need another round of waits
|
| 339 |
+
const auto& iter_idx = scheduler.current_iter - 1;
|
| 340 |
+
if (kNumMulticast > 1 and iter_idx >= 0) {
|
| 341 |
+
const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
|
| 342 |
+
tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
|
| 343 |
+
}
|
| 344 |
+
} else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) {
|
| 345 |
+
// Epilogue warp groups
|
| 346 |
+
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
|
| 347 |
+
|
| 348 |
+
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
|
| 349 |
+
// i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
|
| 350 |
+
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
|
| 351 |
+
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
|
| 352 |
+
|
| 353 |
+
// TMA checks
|
| 354 |
+
constexpr uint32_t kNumBankGroupBytes = 16;
|
| 355 |
+
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t);
|
| 356 |
+
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
|
| 357 |
+
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
|
| 358 |
+
|
| 359 |
+
// Share store pipeline between blocks
|
| 360 |
+
uint32_t tma_stage_idx = 0;
|
| 361 |
+
auto advance_store_pipeline = [&]() {
|
| 362 |
+
tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
|
| 363 |
+
};
|
| 364 |
+
|
| 365 |
+
// Persistently schedule over blocks
|
| 366 |
+
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 367 |
+
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
|
| 368 |
+
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
| 369 |
+
|
| 370 |
+
// Wait UMMA arrival
|
| 371 |
+
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
|
| 372 |
+
tcgen05_after_thread_sync();
|
| 373 |
+
|
| 374 |
+
// Load from tensor memory into registers, and write shared memory with STSM
|
| 375 |
+
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
|
| 376 |
+
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
|
| 377 |
+
|
| 378 |
+
// Iterate over M waves
|
| 379 |
+
#pragma unroll
|
| 380 |
+
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
| 381 |
+
// Issue every swizzled atom and pipeline STSM and TMA store
|
| 382 |
+
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
|
| 383 |
+
#pragma unroll
|
| 384 |
+
for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) {
|
| 385 |
+
// Wait shared memory to be released
|
| 386 |
+
if (epilogue_warp_idx == 0)
|
| 387 |
+
cute::tma_store_wait<kNumTMAStoreStages - 1>();
|
| 388 |
+
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
|
| 389 |
+
|
| 390 |
+
// The pipeline stage
|
| 391 |
+
const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
|
| 392 |
+
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
|
| 393 |
+
|
| 394 |
+
// Store into shared memory
|
| 395 |
+
#pragma unroll
|
| 396 |
+
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
|
| 397 |
+
// Calculate the index of the bank group to be written in the atom
|
| 398 |
+
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
|
| 399 |
+
|
| 400 |
+
// Reshape the atom in another view and swizzle
|
| 401 |
+
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
|
| 402 |
+
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
|
| 403 |
+
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
|
| 404 |
+
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
|
| 405 |
+
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
|
| 406 |
+
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
|
| 407 |
+
col ^= row % (kSwizzleCDMode / 16);
|
| 408 |
+
|
| 409 |
+
// Source and destination memory address
|
| 410 |
+
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
|
| 411 |
+
w * BLOCK_N + // Wave offset
|
| 412 |
+
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
|
| 413 |
+
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
|
| 414 |
+
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
|
| 415 |
+
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
| 416 |
+
|
| 417 |
+
// Load from tensor memory, store into shared memory
|
| 418 |
+
uint32_t values[kNumElemsPerBankGroup];
|
| 419 |
+
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
|
| 420 |
+
// For FP32 output, read and store
|
| 421 |
+
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
|
| 422 |
+
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
|
| 423 |
+
values[0], values[1], values[2], values[3]);
|
| 424 |
+
cutlass::arch::fence_view_async_tmem_load();
|
| 425 |
+
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
|
| 426 |
+
} else {
|
| 427 |
+
// For BF16 output, read, cast and store
|
| 428 |
+
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
|
| 429 |
+
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
|
| 430 |
+
values[0], values[1], values[2], values[3],
|
| 431 |
+
values[4], values[5], values[6], values[7]);
|
| 432 |
+
cutlass::arch::fence_view_async_tmem_load();
|
| 433 |
+
st_shared(smem_ptr,
|
| 434 |
+
cast_into_bf16_and_pack(values[0], values[1]),
|
| 435 |
+
cast_into_bf16_and_pack(values[2], values[3]),
|
| 436 |
+
cast_into_bf16_and_pack(values[4], values[5]),
|
| 437 |
+
cast_into_bf16_and_pack(values[6], values[7]));
|
| 438 |
+
}
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
|
| 442 |
+
// NOTES: only the last stage needs to do this
|
| 443 |
+
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
|
| 444 |
+
tcgen05_before_thread_sync();
|
| 445 |
+
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
|
| 446 |
+
}
|
| 447 |
+
__syncwarp();
|
| 448 |
+
|
| 449 |
+
// Synchronize all threads and issue TMA
|
| 450 |
+
cute::tma_store_fence();
|
| 451 |
+
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
|
| 452 |
+
if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
|
| 453 |
+
if constexpr (kGemmType == GemmType::Batched) {
|
| 454 |
+
using cute_tma_t = cute::conditional_t<kWithAccumulation,
|
| 455 |
+
cute::SM90_TMA_REDUCE_ADD_3D, cute::SM90_TMA_STORE_3D>;
|
| 456 |
+
cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx],
|
| 457 |
+
n_idx, m_idx, scheduler.current_group_idx);
|
| 458 |
+
} else {
|
| 459 |
+
using cute_tma_t = cute::conditional_t<kWithAccumulation,
|
| 460 |
+
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
|
| 461 |
+
cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx);
|
| 462 |
+
}
|
| 463 |
+
cute::tma_store_arrive();
|
| 464 |
+
}
|
| 465 |
+
}
|
| 466 |
+
}
|
| 467 |
+
}
|
| 468 |
+
|
| 469 |
+
// Deallocate tensor memory by the last UMMA store warp
|
| 470 |
+
// NOTES: warp 0 is waiting TMA store
|
| 471 |
+
if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1)
|
| 472 |
+
Allocator().free(0, kNumTmemCols);
|
| 473 |
+
}
|
| 474 |
+
#else
|
| 475 |
+
if (blockIdx.x == 0 and threadIdx.x == 0)
|
| 476 |
+
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
|
| 477 |
+
#endif
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
}; // namespace deep_gemm
|
| 481 |
+
|
| 482 |
+
#pragma clang diagnostic pop
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cute/arch/cluster_sm90.hpp>
|
| 4 |
+
#include <cute/util/type_traits.hpp>
|
| 5 |
+
#include <cutlass/arch/barrier.h>
|
| 6 |
+
|
| 7 |
+
#include <deep_gemm/common/utils.cuh>
|
| 8 |
+
#include <deep_gemm/common/sm100_utils.cuh>
|
| 9 |
+
|
| 10 |
+
namespace deep_gemm {
|
| 11 |
+
|
| 12 |
+
using namespace deep_gemm::sm100;
|
| 13 |
+
|
| 14 |
+
// Single-CTA SM100 (Blackwell) kernel computing a batched contraction that reduces
// BOTH the batch (S) and K axes into one (M, N) output:
//   D[m, n] += sum_{s, k} A[s * SHAPE_M + m, k] * B[s * SHAPE_N + n, k]
// (the (s, m/n) addressing is visible in the TMA coordinates below; inputs are BF16,
// the output is FP32 and is combined across CTAs via TMA reduce-add).
// The flattened (s, k) stage axis is chunked into `kSplitFactor` BLOCK_K-stages per
// `sk_block_idx`, so multiple CTAs may contribute partial sums to the same (M, N) tile.
// Warp roles: warp 0 = TMA loads, warp 1 = UMMA issue (and barrier init),
// warp 2 = tensor-memory allocation; all warps cooperate in the epilogue.
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
          uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
          uint32_t kSplitFactor,
          uint32_t kSwizzleABMode, uint32_t kSwizzleCDMode,
          uint32_t kNumStages, uint32_t kNumThreads>
__global__ void __launch_bounds__(kNumThreads, 1)
sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
                           const __grid_constant__ cute::TmaDescriptor tensor_map_a,
                           const __grid_constant__ cute::TmaDescriptor tensor_map_b,
                           const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
    using Barrier = cutlass::arch::ClusterTransactionBarrier;

    // Configs
    constexpr uint32_t LAYOUT_AD_M = 128;       // M extent of one accumulator/output layout atom
    constexpr uint32_t kNumTMAStoreStages = 2;  // Double-buffered epilogue TMA-store pipeline

    // Utils
    const auto warp_idx = cutlass::canonical_warp_idx_sync();
    const auto lane_idx = get_lane_idx();
    DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size");
    DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode");

    // Align to 1024 bytes for swizzle-128B
    extern __shared__ __align__(1024) uint8_t smem_buffer[];

    // Shared memory sizes (C/D staged as FP32, A/B staged as BF16)
    constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = BLOCK_M * kSwizzleCDMode;
    constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
    constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t);
    constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t);

    // Prefetch TMA descriptors at the very beginning
    if (warp_idx == 0 and cute::elect_one_sync()) {
        cute::prefetch_tma_descriptor(&tensor_map_a);
        cute::prefetch_tma_descriptor(&tensor_map_b);
        cute::prefetch_tma_descriptor(&tensor_map_d);
    }

    // Real tensor memory size and offsets
    constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<BLOCK_N>();

    // Fill D/A/B
    // Shared memory layout: [CD store stages | A stages | B stages | barriers | tmem ptr]
    auto smem_cd = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<float*>(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE));
    });
    auto smem_a = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
    });
    auto smem_b = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
    });

    // Fill barriers
    // `full_barriers[i]`:  stage i's A/B tiles have arrived in SMEM (TMA -> UMMA handoff)
    // `empty_barriers[i]`: stage i's SMEM has been consumed by UMMA (UMMA -> TMA handoff)
    // `tmem_full_barrier`: the final accumulator is complete in tensor memory
    auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
                                                        kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
    auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
    auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
    auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2);

    // Fill the tensor memory pointer (written by the TMEM allocator below)
    auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 2 + 1);
    DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");

    // Initialize barriers (warp 1) and allocate tensor memory (warp 2), concurrently
    if (warp_idx == 1 and cute::elect_one_sync()) {
        #pragma unroll
        for (uint32_t i = 0; i < kNumStages; ++ i) {
            full_barriers[i]->init(1);
            empty_barriers[i]->init(1);
        }
        tmem_full_barrier->init(1);

        // Make initialized barrier visible in async proxy
        cutlass::arch::fence_barrier_init();
    } else if (warp_idx == 2) {
        // Allocate tensor memory
        cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
    }
    __syncthreads();

    // Block indices
    // Grid decomposition: the low part of `blockIdx.x` selects the (m, n) output tile,
    // the high part selects which `kSplitFactor`-sized chunk of the flattened (s, k)
    // stage axis this CTA accumulates.
    const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N);
    const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M);
    const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
    const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
    const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
    const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
    // Stages owned by this CTA; the last chunk may be shorter than `kSplitFactor`
    const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);

    if (warp_idx == 0) {
        // TMA load warp
        for (uint32_t s = 0; s < num_total_stages; ++ s) {
            // Wait for this SMEM stage to be released by UMMA (phase flips every `kNumStages` rounds)
            const auto& stage_idx = s % kNumStages;
            empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);

            // Decompose the flattened (s, k) stage index into batch and K coordinates
            uint32_t m_idx = BLOCK_M * m_block_idx;
            uint32_t n_idx = BLOCK_N * n_block_idx;
            uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
            uint32_t k_idx = sk_idx % SHAPE_K;
            uint32_t s_idx = sk_idx / SHAPE_K;

            // Issue TMAs
            // NOTES: A is addressed as (s * SHAPE_M + m, k) and B as (s * SHAPE_N + n, k),
            // i.e. the batch axis is stacked onto the row axis of each 2-D tensor map
            if (cute::elect_one_sync()) {
                tma_copy<BLOCK_K, BLOCK_M, kSwizzleABMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M);
                tma_copy<BLOCK_K, BLOCK_N, kSwizzleABMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N);
            }

            // Arrive at full barriers, expecting both tiles' bytes
            constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
            if (cute::elect_one_sync())
                full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes);
        }
    } else if (warp_idx == 1) {
        // MMA issue warp
        // NOTES: this kernel uses the 1-SM path (`Allocator1Sm`, `SM100_MMA_F16BF16_SS`),
        // so every CTA issues its own UMMAs
        // Make instruction descriptor
        constexpr uint32_t UMMA_M = LAYOUT_AD_M;
        constexpr uint32_t UMMA_N = BLOCK_N;
        constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t);
        auto instr_desc = cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();

        // Precompute per-stage SMEM descriptor low words: lane `i` holds the descriptor
        // for stage `i` (hence the `kNumStages <= 32` limit); broadcast later via shuffle
        DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
        auto a_desc = make_umma_desc<cute::UMMA::Major::K, BLOCK_M, BLOCK_K, kSwizzleABMode>(smem_a[0], 0, 0);
        auto b_desc = make_umma_desc<cute::UMMA::Major::K, BLOCK_N, BLOCK_K, kSwizzleABMode>(smem_b[0], 0, 0);
        uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
        uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;

        // Checks for MMA instructions
        // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
        DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
                         (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
                         (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
                         "Invalid MMA instruction shape");

        // Order tcgen05 ops after the preceding `__syncthreads`
        tcgen05_after_thread_sync();

        // Launch MMAs
        for (uint32_t s = 0; s < num_total_stages; ++ s) {
            // Wait TMA arrival
            const auto& stage_idx = s % kNumStages;
            full_barriers[stage_idx]->wait((s / kNumStages) & 1);
            tcgen05_after_thread_sync();

            // Issue UMMAs; broadcast this stage's descriptor base from the owning lane
            const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
            const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, stage_idx);
            const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, stage_idx);
            if (cute::elect_one_sync()) {
                #pragma unroll
                for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
                    a_desc.lo = advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_M, kSwizzleABMode, cutlass::bfloat16_t>(a_desc_base_lo, 0, k * UMMA_K);
                    b_desc.lo = advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_N, kSwizzleABMode, cutlass::bfloat16_t>(b_desc_base_lo, 0, k * UMMA_K);
                    // Accumulate into tensor memory column 0; the flag overwrites on the
                    // very first sub-MMA (`s == 0 and k == 0`) and accumulates afterwards
                    SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc);
                }
            }

            // Commit to the mbarrier object, releasing this SMEM stage for the next TMA
            // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
            cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
        }
        // All stages issued: signal the epilogue that the accumulator is complete
        cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barrier));
    }

    // ---- Epilogue: executed by all warps ----
    // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
    // i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`.
    // NOTES: we also forbid two CTAs to share the same SM and its tensor memory
    if (warp_idx == 2)
        DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);

    // TMA checks
    constexpr uint32_t kNumBankGroupBytes = 16;
    constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float);
    constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(float);  // N columns per swizzled store atom
    DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
    DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");

    // Wait UMMA arrival (single-use barrier, so the phase is a constant 0)
    tmem_full_barrier->wait(0);
    tcgen05_after_thread_sync();

    // Load from tensor memory into registers, and write shared memory with STSM
    DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");

    // Issue every swizzled atom and pipeline STSM and TMA store
    constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
    #pragma unroll
    for (uint32_t s = 0; s < kNumStores; ++ s) {
        // Wait shared memory to be released by an in-flight TMA store (only needed
        // once the double-buffer wraps around)
        if (s >= kNumTMAStoreStages) {
            if (warp_idx == 0 and cute::elect_one_sync())
                cute::tma_store_wait<kNumTMAStoreStages - 1>();
            cutlass::arch::NamedBarrier(kNumThreads).sync();
        }

        // The pipeline stage
        const auto tma_stage_idx = s % kNumTMAStoreStages;
        const auto m_idx = m_block_idx * BLOCK_M;
        const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;

        // Store into shared memory
        #pragma unroll
        for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
            // Calculate the index of the bank group to be written in the atom
            auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);

            // Reshape the atom in another view and swizzle
            //  - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
            //  - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
            // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
            constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
            auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
            auto col = kHasShortcut ? (i) : (bank_group_index % 8);
            col ^= row % (kSwizzleCDMode / 16);

            // Source and destination memory address
            uint32_t tmem_addr = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup;          // In-block offset
            auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) +         // Base pointer
                            warp_idx * 32 * kSwizzleCDMode +                             // Warp offset
                            row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes;   // In-atom offset

            // Load from tensor memory, store into shared memory
            uint32_t values[kNumElemsPerBankGroup];
            DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
            cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
                                                  values[0], values[1], values[2], values[3]);
            cutlass::arch::fence_view_async_tmem_load();
            st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
        }

        // Synchronize all threads and issue TMA
        cute::tma_store_fence();
        cutlass::arch::NamedBarrier(kNumThreads).sync();
        if (warp_idx == 0 and cute::elect_one_sync()) {
            // Reduce-add (not a plain store) so partial sums from different
            // `sk_block_idx` CTAs combine correctly in global memory
            cute::SM90_TMA_REDUCE_ADD_2D::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
            cute::tma_store_arrive();
        }
    }

    // Deallocate tensor memory by warp 1
    // NOTES: warp 0 is doing TMA stores
    if (warp_idx == 1)
        cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);

#else
    if (blockIdx.x == 0 and threadIdx.x == 0)
        DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
#endif
}
| 264 |
+
|
| 265 |
+
}
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh
ADDED
|
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#pragma clang diagnostic push
|
| 3 |
+
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
| 4 |
+
|
| 5 |
+
#include <cutlass/arch/barrier.h>
|
| 6 |
+
|
| 7 |
+
#include <deep_gemm/common/epilogue_utils.cuh>
|
| 8 |
+
#include <deep_gemm/common/scheduler.cuh>
|
| 9 |
+
#include <deep_gemm/common/utils.cuh>
|
| 10 |
+
#include <deep_gemm/common/sm100_utils.cuh>
|
| 11 |
+
|
| 12 |
+
namespace deep_gemm {
|
| 13 |
+
|
| 14 |
+
using namespace deep_gemm::sm100;
|
| 15 |
+
|
| 16 |
+
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
| 17 |
+
uint32_t kGranKA, uint32_t kGranKB,
|
| 18 |
+
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
| 19 |
+
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
| 20 |
+
uint32_t kNumGroups,
|
| 21 |
+
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
|
| 22 |
+
uint32_t kNumStages,
|
| 23 |
+
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
|
| 24 |
+
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
| 25 |
+
uint32_t kNumSMs,
|
| 26 |
+
GemmType kGemmType, bool kWithAccumulation,
|
| 27 |
+
typename a_dtype_t, typename b_dtype_t, typename cd_dtype_t,
|
| 28 |
+
typename epilogue_type_t>
|
| 29 |
+
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
|
| 30 |
+
sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
| 31 |
+
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
| 32 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
| 33 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
| 34 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_sfa,
|
| 35 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_sfb,
|
| 36 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_cd) {
|
| 37 |
+
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
|
| 38 |
+
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 39 |
+
using Allocator = cute::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
|
| 40 |
+
|
| 41 |
+
// GEMM with accumulation must have FP32 output
|
| 42 |
+
if constexpr (kWithAccumulation)
|
| 43 |
+
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
|
| 44 |
+
|
| 45 |
+
// Configs
|
| 46 |
+
constexpr uint32_t LAYOUT_AD_M = 128;
|
| 47 |
+
constexpr uint32_t WAVE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
|
| 48 |
+
constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M;
|
| 49 |
+
constexpr uint32_t kNumTMAStoreStages = 2;
|
| 50 |
+
constexpr uint32_t kNumUTCCPAlignedElems = 128;
|
| 51 |
+
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
|
| 52 |
+
DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M");
|
| 53 |
+
|
| 54 |
+
constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4;
|
| 55 |
+
constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4;
|
| 56 |
+
DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A");
|
| 57 |
+
DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B");
|
| 58 |
+
|
| 59 |
+
// Overwrite shape constants if the compiler gives
|
| 60 |
+
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
|
| 61 |
+
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
|
| 62 |
+
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
| 63 |
+
const uint32_t shape_sfa_k = ceil_div(shape_k, kGranKA * 4);
|
| 64 |
+
const uint32_t shape_sfb_k = ceil_div(shape_k, kGranKB * 4);
|
| 65 |
+
|
| 66 |
+
// Utils
|
| 67 |
+
bool is_leader_cta = cute::block_rank_in_cluster() == 0;
|
| 68 |
+
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
| 69 |
+
const auto lane_idx = get_lane_idx();
|
| 70 |
+
|
| 71 |
+
// Align to 1024 bytes for swizzle-128B
|
| 72 |
+
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
| 73 |
+
|
| 74 |
+
// 2-CTA MMA
|
| 75 |
+
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
|
| 76 |
+
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
|
| 77 |
+
constexpr uint32_t STORE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
|
| 78 |
+
constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t);
|
| 79 |
+
constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M;
|
| 80 |
+
DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast");
|
| 81 |
+
DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D");
|
| 82 |
+
DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
|
| 83 |
+
DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M");
|
| 84 |
+
|
| 85 |
+
// Share memory sizes
|
| 86 |
+
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode;
|
| 87 |
+
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
|
| 88 |
+
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t);
|
| 89 |
+
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
|
| 90 |
+
constexpr uint32_t SF_BLOCK_M = constexpr_align(BLOCK_M, kNumUTCCPAlignedElems);
|
| 91 |
+
constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems);
|
| 92 |
+
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
|
| 93 |
+
constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
|
| 94 |
+
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0,
|
| 95 |
+
"Shared memory of A/B must be aligned to 1024 bytes");
|
| 96 |
+
DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
|
| 97 |
+
|
| 98 |
+
// NOTES: Make sure we have enough shared memory for UMMA padding
|
| 99 |
+
static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t);
|
| 100 |
+
DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA");
|
| 101 |
+
|
| 102 |
+
// Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size
|
| 103 |
+
// TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2`
|
| 104 |
+
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
|
| 105 |
+
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
|
| 106 |
+
constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N + kNumSFATmemCols + kNumSFBTmemCols) > 512 ? 1 : 2;
|
| 107 |
+
|
| 108 |
+
// Real tensor memory size and offsets
|
| 109 |
+
constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N;
|
| 110 |
+
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
|
| 111 |
+
constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
|
| 112 |
+
constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
|
| 113 |
+
|
| 114 |
+
// Prefetch TMA descriptors at the very beginning
|
| 115 |
+
if (warp_idx == 0 and cute::elect_one_sync()) {
|
| 116 |
+
cute::prefetch_tma_descriptor(&tensor_map_a);
|
| 117 |
+
cute::prefetch_tma_descriptor(&tensor_map_b);
|
| 118 |
+
cute::prefetch_tma_descriptor(&tensor_map_sfa);
|
| 119 |
+
cute::prefetch_tma_descriptor(&tensor_map_sfb);
|
| 120 |
+
cute::prefetch_tma_descriptor(&tensor_map_cd);
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
// D/A/B shared memory
|
| 124 |
+
auto smem_cd = PatternVisitor([&](const uint32_t& i) {
|
| 125 |
+
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
|
| 126 |
+
});
|
| 127 |
+
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
| 128 |
+
return reinterpret_cast<a_dtype_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
| 129 |
+
});
|
| 130 |
+
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
| 131 |
+
return reinterpret_cast<b_dtype_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
| 132 |
+
});
|
| 133 |
+
|
| 134 |
+
// SFA/SFB shared memory
|
| 135 |
+
auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
| 136 |
+
auto smem_sfa = PatternVisitor([=](const uint32_t& i) {
|
| 137 |
+
return reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
|
| 138 |
+
});
|
| 139 |
+
auto smem_sfb = PatternVisitor([=](const uint32_t& i) {
|
| 140 |
+
return reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
|
| 141 |
+
});
|
| 142 |
+
|
| 143 |
+
// Fill barriers
|
| 144 |
+
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer +
|
| 145 |
+
SMEM_CD_SIZE +
|
| 146 |
+
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) +
|
| 147 |
+
kNumStages * (SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE));
|
| 148 |
+
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 149 |
+
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
| 150 |
+
auto with_sf_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
|
| 151 |
+
auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
|
| 152 |
+
auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); });
|
| 153 |
+
|
| 154 |
+
// Fill the tensor memory pointer
|
| 155 |
+
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
|
| 156 |
+
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
| 157 |
+
|
| 158 |
+
// Initialize barriers
|
| 159 |
+
if (warp_idx == 1 and cute::elect_one_sync()) {
|
| 160 |
+
#pragma unroll
|
| 161 |
+
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
| 162 |
+
// Arrive at all CTAs
|
| 163 |
+
full_barriers[i]->init(1);
|
| 164 |
+
empty_barriers[i]->init(1);
|
| 165 |
+
// Arrive only at the leader CTA
|
| 166 |
+
with_sf_full_barriers[i]->init(kNumMulticast * 32);
|
| 167 |
+
}
|
| 168 |
+
#pragma unroll
|
| 169 |
+
for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
|
| 170 |
+
// Arrive at all CTAs
|
| 171 |
+
tmem_full_barriers[i]->init(1);
|
| 172 |
+
// Arrive only at the leader CTA
|
| 173 |
+
tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads);
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
// Make initialized barrier visible in async proxy
|
| 177 |
+
cutlass::arch::fence_barrier_init();
|
| 178 |
+
} else if (warp_idx == 2) {
|
| 179 |
+
// Allocate tensor memory
|
| 180 |
+
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
| 181 |
+
}
|
| 182 |
+
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
|
| 183 |
+
|
| 184 |
+
// Block scheduler
|
| 185 |
+
uint32_t m_block_idx, n_block_idx;
|
| 186 |
+
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
|
| 187 |
+
|
| 188 |
+
// Pipeline and TMA phases
|
| 189 |
+
uint32_t stage_idx = 0, phase = 0;
|
| 190 |
+
auto advance_pipeline = [&](uint32_t& k_block_idx) {
|
| 191 |
+
++ k_block_idx;
|
| 192 |
+
|
| 193 |
+
// Flip phases only if reach the next first stage
|
| 194 |
+
stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
|
| 195 |
+
phase ^= stage_idx == 0;
|
| 196 |
+
};
|
| 197 |
+
|
| 198 |
+
// Dispatch warps into different roles
|
| 199 |
+
if (warp_idx == 0 and cute::elect_one_sync()) {
|
| 200 |
+
// TMA load warp
|
| 201 |
+
// Persistently schedule over blocks
|
| 202 |
+
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 203 |
+
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
|
| 204 |
+
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 205 |
+
// Wait consumer release
|
| 206 |
+
empty_barriers[stage_idx]->wait(phase ^ 1);
|
| 207 |
+
|
| 208 |
+
// Compute offsets
|
| 209 |
+
// NOTES: the group is always concatenated with the outer dimension
|
| 210 |
+
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> (
|
| 211 |
+
shape_m, BLOCK_M, m_block_idx);
|
| 212 |
+
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> (
|
| 213 |
+
shape_n, BLOCK_N, n_block_idx, m_block_idx);
|
| 214 |
+
|
| 215 |
+
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
|
| 216 |
+
// And for all m-grouped GEMMs, A must be K-majored
|
| 217 |
+
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
|
| 218 |
+
kMajorA == cute::UMMA::Major::K, "Invalid major");
|
| 219 |
+
uint32_t k_idx = k_block_idx * BLOCK_K;
|
| 220 |
+
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> (
|
| 221 |
+
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
| 222 |
+
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> (
|
| 223 |
+
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
| 224 |
+
|
| 225 |
+
// Add 2 CTA offsets
|
| 226 |
+
if constexpr (kNumMulticast > 1) {
|
| 227 |
+
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
|
| 228 |
+
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
// Issue TMAs
|
| 232 |
+
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
|
| 233 |
+
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
|
| 234 |
+
if constexpr (kMajorA == cute::UMMA::Major::K)
|
| 235 |
+
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
|
| 236 |
+
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx);
|
| 237 |
+
if constexpr (kMajorA == cute::UMMA::Major::MN)
|
| 238 |
+
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
|
| 239 |
+
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx);
|
| 240 |
+
if constexpr (kMajorB == cute::UMMA::Major::K)
|
| 241 |
+
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
|
| 242 |
+
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx);
|
| 243 |
+
if constexpr (kMajorB == cute::UMMA::Major::MN)
|
| 244 |
+
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
|
| 245 |
+
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx);
|
| 246 |
+
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v<a_dtype_t, cutlass::float_e4m3_t> ? 1 : 2) +
|
| 247 |
+
SMEM_B_SIZE_PER_STAGE / (std::is_same_v<b_dtype_t, cutlass::float_e4m3_t> ? 1 : 2);
|
| 248 |
+
|
| 249 |
+
// Issue SFA and SFB TMAs at certain stages
|
| 250 |
+
// No swizzling, so one TMA for one SF is enough
|
| 251 |
+
if (k_block_idx % kNumSFAStagesPerLoad == 0) {
|
| 252 |
+
tma_copy<BLOCK_M, 1, 0>(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], m_block_idx * BLOCK_M,
|
| 253 |
+
scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::SF_K>(shape_sfa_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad)));
|
| 254 |
+
num_arrival_bytes += BLOCK_M * sizeof(uint32_t);
|
| 255 |
+
}
|
| 256 |
+
if (k_block_idx % kNumSFBStagesPerLoad == 0) {
|
| 257 |
+
tma_copy<BLOCK_N, 1, 0>(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], n_block_idx * BLOCK_N,
|
| 258 |
+
scheduler.template get_global_idx<true, IndexType::SF_K>(shape_sfb_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx));
|
| 259 |
+
num_arrival_bytes += BLOCK_N * sizeof(uint32_t);
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
// Arrive at full barriers
|
| 263 |
+
full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes);
|
| 264 |
+
}
|
| 265 |
+
}
|
| 266 |
+
} else if (warp_idx == 1 and is_leader_cta) {
|
| 267 |
+
// MMA issue warp
|
| 268 |
+
// NOTES: only the leader CTA will do this
|
| 269 |
+
// Make instruction descriptor
|
| 270 |
+
// TODO: refactor `UMMA_M` calculation
|
| 271 |
+
constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast);
|
| 272 |
+
constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1);
|
| 273 |
+
constexpr uint32_t UMMA_K = 32;
|
| 274 |
+
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<a_dtype_t, b_dtype_t, float, cutlass::float_ue8m0_t,
|
| 275 |
+
UMMA_M, UMMA_N, kMajorA, kMajorB>();
|
| 276 |
+
auto sf_desc = make_sf_desc(nullptr);
|
| 277 |
+
|
| 278 |
+
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
|
| 279 |
+
auto a_desc = make_umma_desc<kMajorA, LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode>(smem_a[0], 0, 0);
|
| 280 |
+
auto b_desc = make_umma_desc<kMajorB, LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode>(smem_b[0], 0, 0);
|
| 281 |
+
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
|
| 282 |
+
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
| 283 |
+
|
| 284 |
+
// Checks for MMA instructions
|
| 285 |
+
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
|
| 286 |
+
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
|
| 287 |
+
(UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
|
| 288 |
+
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
|
| 289 |
+
"Invalid MMA instruction shape");
|
| 290 |
+
|
| 291 |
+
// Persistently schedule over blocks
|
| 292 |
+
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 293 |
+
// Wait tensor memory empty barrier arrival
|
| 294 |
+
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
|
| 295 |
+
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
| 296 |
+
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
|
| 297 |
+
tcgen05_after_thread_sync();
|
| 298 |
+
|
| 299 |
+
// Empty barrier arrival
|
| 300 |
+
auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) {
|
| 301 |
+
auto umma_arrive = [](const uint64_t* barrier) {
|
| 302 |
+
if constexpr (kNumMulticast == 1) {
|
| 303 |
+
cutlass::arch::umma_arrive(barrier);
|
| 304 |
+
} else {
|
| 305 |
+
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
|
| 306 |
+
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
|
| 307 |
+
}
|
| 308 |
+
};
|
| 309 |
+
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
|
| 310 |
+
|
| 311 |
+
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
|
| 312 |
+
if (do_tmem_full_arrive)
|
| 313 |
+
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
|
| 314 |
+
};
|
| 315 |
+
|
| 316 |
+
// Launch MMAs
|
| 317 |
+
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
|
| 318 |
+
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 319 |
+
// Wait TMA and SF-transpose arrival
|
| 320 |
+
with_sf_full_barriers[stage_idx]->wait(phase);
|
| 321 |
+
tcgen05_after_thread_sync();
|
| 322 |
+
|
| 323 |
+
// Do SF copy at certain stages
|
| 324 |
+
// NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves
|
| 325 |
+
// TODO: process shared memory descriptor by addition
|
| 326 |
+
using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
|
| 327 |
+
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
|
| 328 |
+
const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad;
|
| 329 |
+
if (sfa_stage_in_group_idx == 0 and cute::elect_one_sync()) {
|
| 330 |
+
#pragma unroll
|
| 331 |
+
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
|
| 332 |
+
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
|
| 333 |
+
replace_smem_desc_addr(sf_desc, smem_ptr);
|
| 334 |
+
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
|
| 335 |
+
}
|
| 336 |
+
}
|
| 337 |
+
const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad;
|
| 338 |
+
if (sfb_stage_in_group_idx == 0 and cute::elect_one_sync()) {
|
| 339 |
+
#pragma unroll
|
| 340 |
+
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
|
| 341 |
+
auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
|
| 342 |
+
replace_smem_desc_addr(sf_desc, smem_ptr);
|
| 343 |
+
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
|
| 344 |
+
}
|
| 345 |
+
}
|
| 346 |
+
__syncwarp();
|
| 347 |
+
|
| 348 |
+
// Issue UMMA in the leader CTA
|
| 349 |
+
using mma_t = cute::conditional_t<kNumMulticast == 1, SM100_MMA_MXF8F6F4_SS, SM100_MMA_MXF8F6F4_2x1SM_SS>;
|
| 350 |
+
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
|
| 351 |
+
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
|
| 352 |
+
if (cute::elect_one_sync()) {
|
| 353 |
+
#pragma unroll
|
| 354 |
+
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
| 355 |
+
const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx);
|
| 356 |
+
const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx);
|
| 357 |
+
const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id);
|
| 358 |
+
|
| 359 |
+
b_desc.lo = advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K);
|
| 360 |
+
#pragma unroll
|
| 361 |
+
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
| 362 |
+
DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset");
|
| 363 |
+
a_desc.lo = advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, w * WAVE_BLOCK_M * BLOCK_K, k * UMMA_K);
|
| 364 |
+
mma_t::fma(a_desc, b_desc,
|
| 365 |
+
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
|
| 366 |
+
k_block_idx > 0 or k > 0,
|
| 367 |
+
runtime_instr_desc,
|
| 368 |
+
kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32),
|
| 369 |
+
kTmemStartColOfSFB);
|
| 370 |
+
}
|
| 371 |
+
}
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
// Commit to the mbarrier object
|
| 375 |
+
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
|
| 376 |
+
empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1);
|
| 377 |
+
}
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
// To safely deconstruct barriers, we need another round of waits
|
| 381 |
+
const auto& iter_idx = scheduler.current_iter - 1;
|
| 382 |
+
if (kNumMulticast > 1 and iter_idx >= 0) {
|
| 383 |
+
const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
|
| 384 |
+
tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
|
| 385 |
+
}
|
| 386 |
+
} else if (warp_idx == 2) {
|
| 387 |
+
// UTCCP transposer
|
| 388 |
+
auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) {
|
| 389 |
+
DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements");
|
| 390 |
+
uint32_t values[4];
|
| 391 |
+
#pragma unroll
|
| 392 |
+
for (uint32_t i = 0; i < 4; ++ i)
|
| 393 |
+
values[i] = ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
|
| 394 |
+
__syncwarp();
|
| 395 |
+
#pragma unroll
|
| 396 |
+
for (uint32_t i = 0; i < 4; ++ i)
|
| 397 |
+
st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
|
| 398 |
+
};
|
| 399 |
+
|
| 400 |
+
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 401 |
+
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
|
| 402 |
+
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 403 |
+
// Wait TMA arrival
|
| 404 |
+
full_barriers[stage_idx]->wait(phase);
|
| 405 |
+
|
| 406 |
+
// Transpose for UTCCP at certain stages
|
| 407 |
+
if (k_block_idx % kNumSFAStagesPerLoad == 0) {
|
| 408 |
+
#pragma unroll
|
| 409 |
+
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i)
|
| 410 |
+
utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems);
|
| 411 |
+
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
|
| 412 |
+
cutlass::arch::fence_view_async_shared();
|
| 413 |
+
}
|
| 414 |
+
if (k_block_idx % kNumSFBStagesPerLoad == 0) {
|
| 415 |
+
#pragma unroll
|
| 416 |
+
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i)
|
| 417 |
+
utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems);
|
| 418 |
+
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
|
| 419 |
+
cutlass::arch::fence_view_async_shared();
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
// Arrive
|
| 423 |
+
with_sf_full_barriers[stage_idx]->arrive(0u);
|
| 424 |
+
}
|
| 425 |
+
}
|
| 426 |
+
} else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) {
|
| 427 |
+
// Epilogue warp groups
|
| 428 |
+
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
|
| 429 |
+
|
| 430 |
+
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
|
| 431 |
+
// i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
|
| 432 |
+
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
|
| 433 |
+
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
|
| 434 |
+
|
| 435 |
+
// TMA checks
|
| 436 |
+
constexpr uint32_t kNumBankGroupBytes = 16;
|
| 437 |
+
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t);
|
| 438 |
+
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
|
| 439 |
+
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
|
| 440 |
+
|
| 441 |
+
// Share store pipeline between blocks
|
| 442 |
+
uint32_t tma_stage_idx = 0;
|
| 443 |
+
auto advance_store_pipeline = [&]() {
|
| 444 |
+
tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
|
| 445 |
+
};
|
| 446 |
+
|
| 447 |
+
// Persistently schedule over blocks
|
| 448 |
+
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 449 |
+
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
|
| 450 |
+
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
| 451 |
+
|
| 452 |
+
// Wait UMMA arrival
|
| 453 |
+
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
|
| 454 |
+
tcgen05_after_thread_sync();
|
| 455 |
+
|
| 456 |
+
// Load from tensor memory into registers, and write shared memory with STSM
|
| 457 |
+
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
|
| 458 |
+
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
|
| 459 |
+
|
| 460 |
+
// Iterate over M waves
|
| 461 |
+
#pragma unroll
|
| 462 |
+
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
| 463 |
+
// Issue every swizzled atom and pipeline STSM and TMA store
|
| 464 |
+
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
|
| 465 |
+
#pragma unroll
|
| 466 |
+
for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) {
|
| 467 |
+
// Wait shared memory to be released
|
| 468 |
+
if (epilogue_warp_idx == 0)
|
| 469 |
+
cute::tma_store_wait<kNumTMAStoreStages - 1>();
|
| 470 |
+
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
|
| 471 |
+
|
| 472 |
+
// The pipeline stage
|
| 473 |
+
const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
|
| 474 |
+
const auto n_idx = epilogue_type_t::apply_index_n<STORE_BLOCK_N>(n_block_idx * BLOCK_N + s * STORE_BLOCK_N);
|
| 475 |
+
|
| 476 |
+
// Store into shared memory
|
| 477 |
+
#pragma unroll
|
| 478 |
+
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
|
| 479 |
+
// Calculate the index of the bank group to be written in the atom
|
| 480 |
+
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
|
| 481 |
+
|
| 482 |
+
// Reshape the atom in another view and swizzle
|
| 483 |
+
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
|
| 484 |
+
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
|
| 485 |
+
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
|
| 486 |
+
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
|
| 487 |
+
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
|
| 488 |
+
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
|
| 489 |
+
col ^= row % (kSwizzleCDMode / 16);
|
| 490 |
+
|
| 491 |
+
// Source and destination memory address
|
| 492 |
+
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
|
| 493 |
+
w * BLOCK_N + // Wave offset
|
| 494 |
+
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
|
| 495 |
+
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
|
| 496 |
+
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
|
| 497 |
+
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
| 498 |
+
|
| 499 |
+
// Load from tensor memory, store into shared memory
|
| 500 |
+
uint32_t values[kNumElemsPerBankGroup];
|
| 501 |
+
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
|
| 502 |
+
// For FP32 output, read and store
|
| 503 |
+
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
|
| 504 |
+
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
|
| 505 |
+
values[0], values[1], values[2], values[3]);
|
| 506 |
+
cutlass::arch::fence_view_async_tmem_load();
|
| 507 |
+
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
|
| 508 |
+
} else {
|
| 509 |
+
// For BF16 output, read, cast and store
|
| 510 |
+
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
|
| 511 |
+
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
|
| 512 |
+
values[0], values[1], values[2], values[3],
|
| 513 |
+
values[4], values[5], values[6], values[7]);
|
| 514 |
+
cutlass::arch::fence_view_async_tmem_load();
|
| 515 |
+
st_shared(smem_ptr,
|
| 516 |
+
cast_into_bf16_and_pack(values[0], values[1]),
|
| 517 |
+
cast_into_bf16_and_pack(values[2], values[3]),
|
| 518 |
+
cast_into_bf16_and_pack(values[4], values[5]),
|
| 519 |
+
cast_into_bf16_and_pack(values[6], values[7]));
|
| 520 |
+
}
|
| 521 |
+
}
|
| 522 |
+
|
| 523 |
+
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
|
| 524 |
+
// NOTES: only the last stage needs to do this
|
| 525 |
+
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
|
| 526 |
+
tcgen05_before_thread_sync();
|
| 527 |
+
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
|
| 528 |
+
}
|
| 529 |
+
|
| 530 |
+
// Synchronize all threads and issue TMA
|
| 531 |
+
cute::tma_store_fence();
|
| 532 |
+
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
|
| 533 |
+
if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
|
| 534 |
+
if constexpr (kGemmType == GemmType::Batched) {
|
| 535 |
+
using cute_tma_t = cute::conditional_t<kWithAccumulation,
|
| 536 |
+
cute::SM90_TMA_REDUCE_ADD_3D, cute::SM90_TMA_STORE_3D>;
|
| 537 |
+
cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx],
|
| 538 |
+
n_idx, m_idx, scheduler.current_group_idx);
|
| 539 |
+
} else {
|
| 540 |
+
using cute_tma_t = cute::conditional_t<kWithAccumulation,
|
| 541 |
+
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
|
| 542 |
+
cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx);
|
| 543 |
+
}
|
| 544 |
+
cute::tma_store_arrive();
|
| 545 |
+
}
|
| 546 |
+
}
|
| 547 |
+
}
|
| 548 |
+
}
|
| 549 |
+
|
| 550 |
+
// Deallocate tensor memory by the last UMMA store warp
|
| 551 |
+
// NOTES: warp 0 is waiting TMA store
|
| 552 |
+
if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1)
|
| 553 |
+
Allocator().free(0, kNumTmemCols);
|
| 554 |
+
}
|
| 555 |
+
#else
|
| 556 |
+
if (blockIdx.x == 0 and threadIdx.x == 0)
|
| 557 |
+
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
|
| 558 |
+
#endif
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
}; // namespace deep_gemm
|
| 562 |
+
|
| 563 |
+
#pragma clang diagnostic pop
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cutlass/arch/barrier.h>
|
| 4 |
+
#include <cutlass/arch/reg_reconfig.h>
|
| 5 |
+
|
| 6 |
+
#include <cute/arch/cluster_sm90.hpp>
|
| 7 |
+
#include <cute/arch/copy_sm90_desc.hpp>
|
| 8 |
+
|
| 9 |
+
#include <deep_gemm/common/utils.cuh>
|
| 10 |
+
#include <deep_gemm/common/sm90_utils.cuh>
|
| 11 |
+
#include <deep_gemm/common/sm100_utils.cuh>
|
| 12 |
+
|
| 13 |
+
namespace deep_gemm {
|
| 14 |
+
|
| 15 |
+
using namespace deep_gemm::sm90;
|
| 16 |
+
using namespace deep_gemm::sm100;
|
| 17 |
+
|
| 18 |
+
template <uint32_t kNumHeads, uint32_t kHeadDim,
|
| 19 |
+
bool kIsCompressedLogits,
|
| 20 |
+
uint32_t BLOCK_Q, uint32_t BLOCK_KV,
|
| 21 |
+
uint32_t kNumQStages, uint32_t kNumKVStages,
|
| 22 |
+
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
|
| 23 |
+
uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
|
| 24 |
+
__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
|
| 25 |
+
void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
| 26 |
+
const uint32_t max_seqlen_k, const uint64_t stride_logits,
|
| 27 |
+
uint32_t* cu_seq_len_k_start,
|
| 28 |
+
uint32_t* cu_seq_len_k_end,
|
| 29 |
+
float* logits,
|
| 30 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
| 31 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
| 32 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
| 33 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
|
| 34 |
+
// TODO: consider TMA multicast
|
| 35 |
+
// Normally, `h (kNumHeads) == 32` and `d (kHeadDim) == 64`
|
| 36 |
+
// For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
|
| 37 |
+
// Q should be load only at once for a block
|
| 38 |
+
const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q);
|
| 39 |
+
|
| 40 |
+
// Types
|
| 41 |
+
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 42 |
+
|
| 43 |
+
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
| 44 |
+
const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
| 45 |
+
const auto& warp_in_group_idx = warp_idx % 4;
|
| 46 |
+
const auto& warpgroup_idx = warp_idx / 4;
|
| 47 |
+
const auto& lane_idx = get_lane_idx();
|
| 48 |
+
|
| 49 |
+
// Prefetch TMA descriptors
|
| 50 |
+
DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
|
| 51 |
+
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
| 52 |
+
cute::prefetch_tma_descriptor(&tensor_map_q);
|
| 53 |
+
cute::prefetch_tma_descriptor(&tensor_map_kv);
|
| 54 |
+
cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
|
| 55 |
+
cute::prefetch_tma_descriptor(&tensor_map_weights);
|
| 56 |
+
}
|
| 57 |
+
__syncwarp();
|
| 58 |
+
|
| 59 |
+
// Shared memory configs
|
| 60 |
+
// NOTES: weight may be unaligned
|
| 61 |
+
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
|
| 62 |
+
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float);
|
| 63 |
+
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
|
| 64 |
+
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
|
| 65 |
+
static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u);
|
| 66 |
+
|
| 67 |
+
// Align to 512 bytes for swizzle-64B
|
| 68 |
+
extern __shared__ __align__(512) uint8_t smem_buffer[];
|
| 69 |
+
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling");
|
| 70 |
+
DG_STATIC_ASSERT(SMEM_WEIGHT_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling");
|
| 71 |
+
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling");
|
| 72 |
+
|
| 73 |
+
// TMA configs
|
| 74 |
+
constexpr uint32_t kNumTmemCols = BLOCK_Q * kNumHeads * kNumMathWarpGroups;
|
| 75 |
+
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
|
| 76 |
+
|
| 77 |
+
// Data on shared memory
|
| 78 |
+
auto smem_q = PatternVisitor([&](const uint32_t& i) {
|
| 79 |
+
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer +
|
| 80 |
+
SMEM_Q_SIZE_PER_STAGE * i);
|
| 81 |
+
});
|
| 82 |
+
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
|
| 83 |
+
return reinterpret_cast<float*>(smem_buffer +
|
| 84 |
+
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
| 85 |
+
});
|
| 86 |
+
auto smem_kv = PatternVisitor([&](const uint32_t& i) {
|
| 87 |
+
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
|
| 88 |
+
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
|
| 89 |
+
});
|
| 90 |
+
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
|
| 91 |
+
return reinterpret_cast<float*>(smem_buffer +
|
| 92 |
+
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages +
|
| 93 |
+
SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
|
| 94 |
+
});
|
| 95 |
+
|
| 96 |
+
// TMA barriers
|
| 97 |
+
auto barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
|
| 98 |
+
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
|
| 99 |
+
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
|
| 100 |
+
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
|
| 101 |
+
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
|
| 102 |
+
auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); });
|
| 103 |
+
auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); });
|
| 104 |
+
|
| 105 |
+
// Tensor memory allocation
|
| 106 |
+
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_ptr + kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups * 2);
|
| 107 |
+
|
| 108 |
+
// Initialize barriers
|
| 109 |
+
DG_STATIC_ASSERT(kNumSpecializedThreads % 128 == 0 and kNumSpecializedThreads >= 64, "Invalid threads");
|
| 110 |
+
const bool& is_tma_load_warp = (warp_idx == (kNumMathThreads / 32));
|
| 111 |
+
const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 1));
|
| 112 |
+
if (is_tma_load_warp and cute::elect_one_sync()) {
|
| 113 |
+
#pragma unroll
|
| 114 |
+
for (uint32_t i = 0; i < kNumQStages; ++ i) {
|
| 115 |
+
full_q_barriers[i]->init(1);
|
| 116 |
+
empty_q_barriers[i]->init(kNumMathThreads);
|
| 117 |
+
}
|
| 118 |
+
#pragma unroll
|
| 119 |
+
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
|
| 120 |
+
full_kv_barriers[i]->init(1);
|
| 121 |
+
empty_kv_barriers[i]->init(kNumMathThreads);
|
| 122 |
+
}
|
| 123 |
+
#pragma unroll
|
| 124 |
+
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
|
| 125 |
+
full_umma_barriers[i]->init(1);
|
| 126 |
+
empty_umma_barriers[i]->init(128);
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
// Make initialized barrier visible in async proxy
|
| 130 |
+
cutlass::arch::fence_barrier_init();
|
| 131 |
+
} else if (is_umma_warp) {
|
| 132 |
+
// Allocate tensor memory
|
| 133 |
+
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
| 134 |
+
}
|
| 135 |
+
__syncthreads();
|
| 136 |
+
|
| 137 |
+
// Register reconfigurations
|
| 138 |
+
constexpr uint32_t kNumSpecializedRegisters = 24;
|
| 139 |
+
constexpr uint32_t kNumMathRegisters = 240;
|
| 140 |
+
|
| 141 |
+
// Block scheduler
|
| 142 |
+
uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0;
|
| 143 |
+
const auto& get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
|
| 144 |
+
return {block_q_idx + gridDim.x, q_iter_idx + 1};
|
| 145 |
+
};
|
| 146 |
+
uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
|
| 147 |
+
const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
|
| 148 |
+
uint32_t start = cute::numeric_limits<uint32_t>::max();
|
| 149 |
+
uint32_t end = cute::numeric_limits<uint32_t>::min();
|
| 150 |
+
|
| 151 |
+
#pragma unroll
|
| 152 |
+
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 153 |
+
const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
|
| 154 |
+
seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx);
|
| 155 |
+
seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx);
|
| 156 |
+
start = min(start, min(seq_k_start[i], seq_len_kv));
|
| 157 |
+
end = max(end, min(seq_k_end[i], seq_len_kv));
|
| 158 |
+
}
|
| 159 |
+
start = start / 4 * 4;
|
| 160 |
+
return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
|
| 161 |
+
((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
|
| 162 |
+
start, ceil_div(end - start, BLOCK_KV)}; // Task info
|
| 163 |
+
};
|
| 164 |
+
|
| 165 |
+
// KV pipeline
|
| 166 |
+
uint32_t num_total_kv_blocks = 0;
|
| 167 |
+
const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple<uint32_t, uint32_t> {
|
| 168 |
+
return {
|
| 169 |
+
(num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage
|
| 170 |
+
((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
|
| 171 |
+
};
|
| 172 |
+
};
|
| 173 |
+
|
| 174 |
+
// UMMA settings
|
| 175 |
+
// Construct instruction with layout D
|
| 176 |
+
constexpr uint32_t UMMA_M = 128;
|
| 177 |
+
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
|
| 178 |
+
constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads;
|
| 179 |
+
|
| 180 |
+
if (is_tma_load_warp) {
|
| 181 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 182 |
+
|
| 183 |
+
// Prefetch
|
| 184 |
+
const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
|
| 185 |
+
tma_copy<kHeadDim, BLOCK_Q * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
|
| 186 |
+
tma_copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
|
| 187 |
+
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
|
| 188 |
+
};
|
| 189 |
+
if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
|
| 190 |
+
issue_tma_q(0, block_q_idx);
|
| 191 |
+
|
| 192 |
+
// Only the first lane persistently schedules over blocks
|
| 193 |
+
if (cute::elect_one_sync()) {
|
| 194 |
+
while (block_q_idx < num_q_blocks) {
|
| 195 |
+
CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks);
|
| 196 |
+
|
| 197 |
+
// Wait Q consumer release
|
| 198 |
+
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
|
| 199 |
+
|
| 200 |
+
// Issue TMA Q
|
| 201 |
+
if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks)
|
| 202 |
+
issue_tma_q(q_stage_idx, next_block_q_idx);
|
| 203 |
+
|
| 204 |
+
// Issue TMA KV
|
| 205 |
+
#pragma unroll
|
| 206 |
+
for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
|
| 207 |
+
// Wait consumer release
|
| 208 |
+
CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
|
| 209 |
+
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
|
| 210 |
+
|
| 211 |
+
// Issue TMA KV
|
| 212 |
+
tma_copy<kHeadDim, BLOCK_KV, kHeadDim>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
|
| 213 |
+
smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
|
| 214 |
+
tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
|
| 215 |
+
smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
|
| 216 |
+
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
|
| 217 |
+
}
|
| 218 |
+
num_total_kv_blocks += num_kv_blocks;
|
| 219 |
+
|
| 220 |
+
// Jump to the next block
|
| 221 |
+
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
|
| 222 |
+
}
|
| 223 |
+
}
|
| 224 |
+
} else if (is_umma_warp) {
|
| 225 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 226 |
+
|
| 227 |
+
// Require full allocation
|
| 228 |
+
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
|
| 229 |
+
|
| 230 |
+
// Make UMMA desc
|
| 231 |
+
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
|
| 232 |
+
UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
|
| 233 |
+
auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
|
| 234 |
+
|
| 235 |
+
while (block_q_idx < num_q_blocks) {
|
| 236 |
+
CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
|
| 237 |
+
|
| 238 |
+
// Wait TMA Q arrival
|
| 239 |
+
full_q_barriers[q_stage_idx]->wait(q_phase);
|
| 240 |
+
|
| 241 |
+
// Compute over KV blocks
|
| 242 |
+
#pragma unroll
|
| 243 |
+
for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
|
| 244 |
+
// Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]`
|
| 245 |
+
// Wait TMA KV arrival
|
| 246 |
+
CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
|
| 247 |
+
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
| 248 |
+
|
| 249 |
+
// Issue UMMA
|
| 250 |
+
DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads, "Invalid block size");
|
| 251 |
+
DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim");
|
| 252 |
+
#pragma unroll
|
| 253 |
+
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
|
| 254 |
+
empty_umma_barriers[i]->wait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1);
|
| 255 |
+
tcgen05_after_thread_sync();
|
| 256 |
+
#pragma unroll
|
| 257 |
+
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
|
| 258 |
+
auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
|
| 259 |
+
smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
|
| 260 |
+
auto b_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
|
| 261 |
+
smem_q[q_stage_idx], 0, k * UMMA_K);
|
| 262 |
+
cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
|
| 263 |
+
}
|
| 264 |
+
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(full_umma_barriers[i]));
|
| 265 |
+
}
|
| 266 |
+
}
|
| 267 |
+
num_total_kv_blocks += num_kv_blocks;
|
| 268 |
+
|
| 269 |
+
// Jump to the next block
|
| 270 |
+
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
|
| 271 |
+
}
|
| 272 |
+
} else if (warp_idx >= kNumMathThreads / 32) {
|
| 273 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 274 |
+
} else if (warp_idx < kNumMathThreads / 32) {
|
| 275 |
+
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
| 276 |
+
|
| 277 |
+
// Offsets
|
| 278 |
+
const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0);
|
| 279 |
+
const auto& warp_offset = warp_idx * 32;
|
| 280 |
+
const auto& v_offset = lane_idx;
|
| 281 |
+
|
| 282 |
+
// Preload weights
|
| 283 |
+
constexpr uint32_t kNumWeightsInReg = cute::min(52, kNumHeads);
|
| 284 |
+
float weights[BLOCK_Q][kNumWeightsInReg];
|
| 285 |
+
DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers");
|
| 286 |
+
|
| 287 |
+
while (block_q_idx < num_q_blocks) {
|
| 288 |
+
CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
|
| 289 |
+
|
| 290 |
+
// Wait TMA Q arrival
|
| 291 |
+
full_q_barriers[q_stage_idx]->wait(q_phase);
|
| 292 |
+
|
| 293 |
+
// Read weights
|
| 294 |
+
#pragma unroll
|
| 295 |
+
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 296 |
+
for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) {
|
| 297 |
+
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
|
| 298 |
+
}
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
// Compute over KV blocks
|
| 302 |
+
#pragma unroll
|
| 303 |
+
for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
|
| 304 |
+
// Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]`
|
| 305 |
+
// Wait TMA KV arrival
|
| 306 |
+
CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
|
| 307 |
+
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
| 308 |
+
|
| 309 |
+
// Read per-KV scales
|
| 310 |
+
float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_offset);
|
| 311 |
+
|
| 312 |
+
// Wait UMMA arrival
|
| 313 |
+
full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1);
|
| 314 |
+
tcgen05_after_thread_sync();
|
| 315 |
+
|
| 316 |
+
// Release KV empty
|
| 317 |
+
empty_kv_barriers[kv_stage_idx]->arrive();
|
| 318 |
+
|
| 319 |
+
// Reduce over the head dim and store
|
| 320 |
+
const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset;
|
| 321 |
+
static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
|
| 322 |
+
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
|
| 323 |
+
|
| 324 |
+
constexpr uint32_t kNumLDTMElems = kNumHeads * BLOCK_Q;
|
| 325 |
+
DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid kNumLDTMElems");
|
| 326 |
+
uint32_t shifted_accum[kNumLDTMElems];
|
| 327 |
+
auto tmem_load = [&](auto... Is) {
|
| 328 |
+
if constexpr (kNumLDTMElems == 32) {
|
| 329 |
+
cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...);
|
| 330 |
+
} else if constexpr (kNumLDTMElems == 64) {
|
| 331 |
+
cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...);
|
| 332 |
+
} else if constexpr (kNumLDTMElems == 128) {
|
| 333 |
+
cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...);
|
| 334 |
+
}
|
| 335 |
+
};
|
| 336 |
+
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
|
| 337 |
+
cutlass::arch::fence_view_async_tmem_load();
|
| 338 |
+
|
| 339 |
+
tcgen05_before_thread_sync();
|
| 340 |
+
empty_umma_barriers[warpgroup_idx]->arrive();
|
| 341 |
+
|
| 342 |
+
#pragma unroll
|
| 343 |
+
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 344 |
+
auto accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
|
| 345 |
+
|
| 346 |
+
auto sum_0 = make_float2(0, 0);
|
| 347 |
+
auto sum_1 = make_float2(0, 0);
|
| 348 |
+
|
| 349 |
+
const auto& transform_reg = [&](const uint32_t& j, const float2& sum) {
|
| 350 |
+
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
| 351 |
+
auto b = make_float2(weights[i][j], weights[i][j + 1]);
|
| 352 |
+
return __ffma2_rn(a, b, sum);
|
| 353 |
+
};
|
| 354 |
+
|
| 355 |
+
#pragma unroll
|
| 356 |
+
for (int j = 0; j < kNumWeightsInReg; j += 4) {
|
| 357 |
+
sum_0 = transform_reg(j, sum_0);
|
| 358 |
+
sum_1 = transform_reg(j + 2, sum_1);
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
const auto& transform_smem = [&](const uint32_t& j, const float2& sum) {
|
| 362 |
+
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
| 363 |
+
auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j),
|
| 364 |
+
ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1));
|
| 365 |
+
return __ffma2_rn(a, b, sum);
|
| 366 |
+
};
|
| 367 |
+
|
| 368 |
+
#pragma unroll
|
| 369 |
+
for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) {
|
| 370 |
+
sum_0 = transform_smem(j, sum_0);
|
| 371 |
+
sum_1 = transform_smem(j + 2, sum_1);
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
auto sum = __fadd2_rn(sum_0, sum_1);
|
| 375 |
+
float result = scale_kv * (sum.x + sum.y);
|
| 376 |
+
|
| 377 |
+
// Store into the global memory
|
| 378 |
+
// NOTES: we have redundant writes here, consider more carefully
|
| 379 |
+
const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
|
| 380 |
+
if constexpr (kIsCompressedLogits) {
|
| 381 |
+
if (seq_k_start[i] <= kv_offset + v_offset and kv_offset + v_offset < seq_k_end[i])
|
| 382 |
+
logits[q_idx * stride_logits + kv_offset + v_offset - seq_k_start[i]] = result;
|
| 383 |
+
} else {
|
| 384 |
+
logits[q_idx * stride_logits + kv_offset + v_offset] = result;
|
| 385 |
+
}
|
| 386 |
+
}
|
| 387 |
+
}
|
| 388 |
+
num_total_kv_blocks += num_kv_blocks;
|
| 389 |
+
|
| 390 |
+
// Release Q empty
|
| 391 |
+
empty_q_barriers[q_stage_idx]->arrive();
|
| 392 |
+
|
| 393 |
+
// Jump to the next block
|
| 394 |
+
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
|
| 395 |
+
}
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
// Free tensor memory
|
| 399 |
+
__syncthreads();
|
| 400 |
+
if (is_tma_load_warp)
|
| 401 |
+
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
} // namespace deep_gemm
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cutlass/arch/barrier.h>
|
| 4 |
+
#include <cutlass/arch/reg_reconfig.h>
|
| 5 |
+
|
| 6 |
+
#include <cute/arch/cluster_sm90.hpp>
|
| 7 |
+
#include <cute/arch/copy_sm90_desc.hpp>
|
| 8 |
+
|
| 9 |
+
#include <deep_gemm/common/utils.cuh>
|
| 10 |
+
#include <deep_gemm/common/sm90_utils.cuh>
|
| 11 |
+
#include <deep_gemm/common/sm100_utils.cuh>
|
| 12 |
+
|
| 13 |
+
#include <deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh>
|
| 14 |
+
|
| 15 |
+
namespace deep_gemm {
|
| 16 |
+
|
| 17 |
+
using namespace deep_gemm::sm90;
|
| 18 |
+
using namespace deep_gemm::sm100;
|
| 19 |
+
|
| 20 |
+
template <uint32_t kNextN, uint32_t kNumHeads,
|
| 21 |
+
uint32_t kHeadDim, uint32_t BLOCK_KV,
|
| 22 |
+
bool kIsContextLens2D,
|
| 23 |
+
uint32_t kNumQStages, uint32_t kNumKVStages,
|
| 24 |
+
uint32_t SPLIT_KV,
|
| 25 |
+
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
|
| 26 |
+
uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
|
| 27 |
+
__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
|
| 28 |
+
void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
| 29 |
+
const uint64_t logits_stride, const uint64_t block_table_stride,
|
| 30 |
+
const uint32_t* context_lens, float* logits,
|
| 31 |
+
const uint32_t* block_table, const uint32_t* schedule_meta,
|
| 32 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
| 33 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
| 34 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
| 35 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
|
| 36 |
+
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 37 |
+
|
| 38 |
+
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
| 39 |
+
const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
| 40 |
+
const auto& warpgroup_idx = warp_idx / 4;
|
| 41 |
+
const auto& lane_idx = get_lane_idx();
|
| 42 |
+
|
| 43 |
+
// Prefetch TMA descriptors
|
| 44 |
+
DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
|
| 45 |
+
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
| 46 |
+
cute::prefetch_tma_descriptor(&tensor_map_q);
|
| 47 |
+
cute::prefetch_tma_descriptor(&tensor_map_kv);
|
| 48 |
+
cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
|
| 49 |
+
cute::prefetch_tma_descriptor(&tensor_map_weights);
|
| 50 |
+
}
|
| 51 |
+
__syncwarp();
|
| 52 |
+
|
| 53 |
+
// Shared memory configs
|
| 54 |
+
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
|
| 55 |
+
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
|
| 56 |
+
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
|
| 57 |
+
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float);
|
| 58 |
+
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
|
| 59 |
+
|
| 60 |
+
// Align to swizzling alignment bytes
|
| 61 |
+
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
|
| 62 |
+
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
| 63 |
+
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
| 64 |
+
|
| 65 |
+
// Q and KV data on shared memory
|
| 66 |
+
auto smem_q = PatternVisitor([&](const uint32_t& i) {
|
| 67 |
+
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
|
| 68 |
+
});
|
| 69 |
+
auto smem_kv = PatternVisitor([&](const uint32_t& i) {
|
| 70 |
+
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i);
|
| 71 |
+
});
|
| 72 |
+
constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages;
|
| 73 |
+
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
|
| 74 |
+
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
|
| 75 |
+
});
|
| 76 |
+
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
|
| 77 |
+
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
| 78 |
+
});
|
| 79 |
+
|
| 80 |
+
// Barriers and TMEM pointer on shared memory
|
| 81 |
+
const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
|
| 82 |
+
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
|
| 83 |
+
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
|
| 84 |
+
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
|
| 85 |
+
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
|
| 86 |
+
const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
|
| 87 |
+
auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; });
|
| 88 |
+
auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; });
|
| 89 |
+
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(umma_barrier_ptr + kNumMathWarpGroups * 2);
|
| 90 |
+
|
| 91 |
+
constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups;
|
| 92 |
+
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
|
| 93 |
+
const bool& is_math_warp = (warp_idx < kNumMathWarpGroups * 4);
|
| 94 |
+
const bool& is_tma_load_warp = (warp_idx == kNumMathWarpGroups * 4);
|
| 95 |
+
const bool& is_umma_warp = (warp_idx == kNumMathWarpGroups * 4 + 1);
|
| 96 |
+
|
| 97 |
+
// Initialize barriers
|
| 98 |
+
if (is_tma_load_warp and cute::elect_one_sync()) {
|
| 99 |
+
#pragma unroll
|
| 100 |
+
for (uint32_t i = 0; i < kNumQStages; ++ i) {
|
| 101 |
+
full_q_barriers[i]->init(1);
|
| 102 |
+
empty_q_barriers[i]->init(kNumMathThreads);
|
| 103 |
+
}
|
| 104 |
+
#pragma unroll
|
| 105 |
+
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
|
| 106 |
+
full_kv_barriers[i]->init(1);
|
| 107 |
+
empty_kv_barriers[i]->init(kNumMathThreads);
|
| 108 |
+
}
|
| 109 |
+
cutlass::arch::fence_barrier_init();
|
| 110 |
+
}
|
| 111 |
+
if (is_umma_warp) {
|
| 112 |
+
if (cute::elect_one_sync()) {
|
| 113 |
+
#pragma unroll
|
| 114 |
+
for (uint32_t i = 0; i < kNumMathWarpGroups; ++i) {
|
| 115 |
+
full_umma_barriers[i]->init(1);
|
| 116 |
+
empty_umma_barriers[i]->init(128);
|
| 117 |
+
}
|
| 118 |
+
cutlass::arch::fence_barrier_init();
|
| 119 |
+
}
|
| 120 |
+
// Allocate tensor memory
|
| 121 |
+
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
| 122 |
+
}
|
| 123 |
+
__syncthreads();
|
| 124 |
+
|
| 125 |
+
// Register reconfigurations
|
| 126 |
+
constexpr uint32_t kNumSpecializedRegisters = 40;
|
| 127 |
+
constexpr uint32_t kNumMathRegisters = 232;
|
| 128 |
+
|
| 129 |
+
// Scheduler
|
| 130 |
+
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
|
| 131 |
+
auto scheduler = PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit>(batch_size, blockIdx.x, context_lens, schedule_meta);
|
| 132 |
+
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
|
| 133 |
+
|
| 134 |
+
// Q and KV pipeline
|
| 135 |
+
const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
|
| 136 |
+
return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase
|
| 137 |
+
};
|
| 138 |
+
const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
|
| 139 |
+
return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase
|
| 140 |
+
};
|
| 141 |
+
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
|
| 142 |
+
|
| 143 |
+
// UMMA settings
|
| 144 |
+
// Construct instruction with layout D
|
| 145 |
+
constexpr uint32_t UMMA_M = 128;
|
| 146 |
+
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
|
| 147 |
+
constexpr uint32_t UMMA_N = kNextN * kNumHeads;
|
| 148 |
+
DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
|
| 149 |
+
|
| 150 |
+
if (is_tma_load_warp) {
|
| 151 |
+
// TMA warp-group for loading data
|
| 152 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 153 |
+
|
| 154 |
+
const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
|
| 155 |
+
if (cute::elect_one_sync()) {
|
| 156 |
+
tma_copy<kHeadDim, kNextN * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
|
| 157 |
+
tma_copy<kNextN * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx);
|
| 158 |
+
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
|
| 159 |
+
}
|
| 160 |
+
};
|
| 161 |
+
|
| 162 |
+
// Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
|
| 163 |
+
uint32_t q_idx = batch_size, kv_idx, num_kv;
|
| 164 |
+
uint32_t next_q_idx, next_kv_idx, next_num_kv;
|
| 165 |
+
bool fetched_next_task;
|
| 166 |
+
|
| 167 |
+
// Prefetch the first Q
|
| 168 |
+
if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)))
|
| 169 |
+
issue_tma_q(0, next_q_idx), q_iter_idx = 1;
|
| 170 |
+
|
| 171 |
+
int kv_block_idx_ptr = 32;
|
| 172 |
+
uint32_t kv_block_idx_storage;
|
| 173 |
+
|
| 174 |
+
while (fetched_next_task) {
|
| 175 |
+
// Prefetch next Q when current Q changes
|
| 176 |
+
bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1));
|
| 177 |
+
q_idx = next_q_idx;
|
| 178 |
+
kv_idx = next_kv_idx;
|
| 179 |
+
num_kv = next_num_kv;
|
| 180 |
+
|
| 181 |
+
// Read KV block index
|
| 182 |
+
// TODO: deal with `-1`?
|
| 183 |
+
if (kv_idx == 0 or kv_block_idx_ptr == 32) {
|
| 184 |
+
kv_block_idx_ptr = 0;
|
| 185 |
+
kv_block_idx_storage = (kv_idx + lane_idx < num_kv ? __ldg(block_table + q_idx * block_table_stride + (kv_idx + lane_idx)) : 0);
|
| 186 |
+
}
|
| 187 |
+
DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`");
|
| 188 |
+
|
| 189 |
+
// Wait Q consumer release and issue TMA Q
|
| 190 |
+
if (prefetch_q) {
|
| 191 |
+
CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
|
| 192 |
+
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
|
| 193 |
+
issue_tma_q(q_stage_idx, q_idx + 1);
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
int kv_block_idx[kNumBlocksPerSplit];
|
| 197 |
+
#pragma unroll
|
| 198 |
+
for (int i = 0; i < kNumBlocksPerSplit; ++ i)
|
| 199 |
+
kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i);
|
| 200 |
+
kv_block_idx_ptr += kNumBlocksPerSplit;
|
| 201 |
+
|
| 202 |
+
// Wait KV consumer release
|
| 203 |
+
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
|
| 204 |
+
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
|
| 205 |
+
|
| 206 |
+
if (cute::elect_one_sync()) {
|
| 207 |
+
#pragma unroll
|
| 208 |
+
for (int i = 0; i < kNumBlocksPerSplit; ++ i) {
|
| 209 |
+
tma_copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
|
| 210 |
+
smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i,
|
| 211 |
+
0, 0, 1, kv_block_idx[i]);
|
| 212 |
+
tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
|
| 213 |
+
smem_kv_scales[kv_stage_idx] + BLOCK_KV * i,
|
| 214 |
+
0, kv_block_idx[i]);
|
| 215 |
+
}
|
| 216 |
+
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
// Fetch next task
|
| 220 |
+
fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv);
|
| 221 |
+
}
|
| 222 |
+
} else if (is_umma_warp) {
|
| 223 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 224 |
+
|
| 225 |
+
// Require full allocation
|
| 226 |
+
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
|
| 227 |
+
|
| 228 |
+
// Make UMMA desc
|
| 229 |
+
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
|
| 230 |
+
UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
|
| 231 |
+
auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
|
| 232 |
+
|
| 233 |
+
uint32_t q_idx = batch_size, kv_idx;
|
| 234 |
+
uint32_t next_q_idx, next_kv_idx, next_num_kv;
|
| 235 |
+
uint32_t q_stage_idx, q_phase;
|
| 236 |
+
uint32_t umma_phase = 1;
|
| 237 |
+
|
| 238 |
+
while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
|
| 239 |
+
if (q_idx != next_q_idx) {
|
| 240 |
+
CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
|
| 241 |
+
full_q_barriers[q_stage_idx]->wait(q_phase);
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
q_idx = next_q_idx;
|
| 245 |
+
kv_idx = next_kv_idx;
|
| 246 |
+
|
| 247 |
+
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
|
| 248 |
+
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
| 249 |
+
|
| 250 |
+
DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim");
|
| 251 |
+
#pragma unroll
|
| 252 |
+
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
|
| 253 |
+
empty_umma_barriers[i]->wait(umma_phase);
|
| 254 |
+
tcgen05_after_thread_sync();
|
| 255 |
+
#pragma unroll
|
| 256 |
+
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
|
| 257 |
+
auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
|
| 258 |
+
smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
|
| 259 |
+
auto b_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
|
| 260 |
+
smem_q[q_stage_idx], 0, k * UMMA_K);
|
| 261 |
+
cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
|
| 262 |
+
}
|
| 263 |
+
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(full_umma_barriers[i]));
|
| 264 |
+
}
|
| 265 |
+
umma_phase ^= 1;
|
| 266 |
+
}
|
| 267 |
+
} else if (is_math_warp) {
|
| 268 |
+
// Math warp-groups for WGMMA
|
| 269 |
+
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
| 270 |
+
|
| 271 |
+
// Offsets
|
| 272 |
+
const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0);
|
| 273 |
+
const uint32_t thread_idx = threadIdx.x;
|
| 274 |
+
|
| 275 |
+
// Weights
|
| 276 |
+
constexpr uint32_t kNumWeightsInReg = (kNextN == 1 ? kNumHeads : cute::min(48, kNumHeads));
|
| 277 |
+
float weights[kNextN][kNumWeightsInReg];
|
| 278 |
+
DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers");
|
| 279 |
+
|
| 280 |
+
// Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
|
| 281 |
+
uint32_t q_idx = batch_size, kv_idx;
|
| 282 |
+
uint32_t next_q_idx, next_kv_idx, next_num_kv;
|
| 283 |
+
uint32_t q_stage_idx, q_phase;
|
| 284 |
+
uint32_t umma_phase = 0;
|
| 285 |
+
|
| 286 |
+
while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
|
| 287 |
+
// Current Q changes
|
| 288 |
+
if (q_idx != next_q_idx) {
|
| 289 |
+
// Release Last Q empty
|
| 290 |
+
if (q_iter_idx > 0)
|
| 291 |
+
empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive();
|
| 292 |
+
|
| 293 |
+
// Wait TMA Q arrival
|
| 294 |
+
CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
|
| 295 |
+
full_q_barriers[q_stage_idx]->wait(q_phase);
|
| 296 |
+
|
| 297 |
+
// Read weights
|
| 298 |
+
#pragma unroll
|
| 299 |
+
for (uint32_t i = 0; i < kNextN; ++ i) {
|
| 300 |
+
for (uint32_t j = 0; j < kNumWeightsInReg; ++ j)
|
| 301 |
+
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
|
| 302 |
+
}
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
// Get current Q and KV index
|
| 306 |
+
q_idx = next_q_idx;
|
| 307 |
+
kv_idx = next_kv_idx;
|
| 308 |
+
|
| 309 |
+
// Calculate KV offset in advance
|
| 310 |
+
auto kv_offset = q_idx * kNextN * logits_stride + kv_idx * BLOCK_KV;
|
| 311 |
+
|
| 312 |
+
// Compute `[kNextN * kNumHeads, kHeadDim] @ [SPLIT_KV, kHeadDim] -> [kNextN, SPLIT_KV]`
|
| 313 |
+
// Wait TMA KV arrival
|
| 314 |
+
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
|
| 315 |
+
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
| 316 |
+
|
| 317 |
+
// Read per-KV scales
|
| 318 |
+
float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + thread_idx);
|
| 319 |
+
|
| 320 |
+
// Wait UMMA arrival
|
| 321 |
+
full_umma_barriers[warpgroup_idx]->wait(umma_phase);
|
| 322 |
+
tcgen05_after_thread_sync();
|
| 323 |
+
umma_phase ^= 1;
|
| 324 |
+
|
| 325 |
+
// Release KV empty
|
| 326 |
+
empty_kv_barriers[kv_stage_idx]->arrive();
|
| 327 |
+
|
| 328 |
+
// Reduce over the head dim and store
|
| 329 |
+
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
|
| 330 |
+
constexpr uint32_t kNumLDTMElems = kNumHeads * kNextN;
|
| 331 |
+
uint32_t shifted_accum[kNumLDTMElems];
|
| 332 |
+
DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid LDTM");
|
| 333 |
+
auto tmem_load = [&](auto... Is) {
|
| 334 |
+
if constexpr (kNumLDTMElems == 32) {
|
| 335 |
+
cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...);
|
| 336 |
+
} else if constexpr (kNumLDTMElems == 64) {
|
| 337 |
+
cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...);
|
| 338 |
+
} else if constexpr (kNumLDTMElems == 128) {
|
| 339 |
+
cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...);
|
| 340 |
+
}
|
| 341 |
+
};
|
| 342 |
+
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
|
| 343 |
+
cutlass::arch::fence_view_async_tmem_load();
|
| 344 |
+
|
| 345 |
+
tcgen05_before_thread_sync();
|
| 346 |
+
empty_umma_barriers[warpgroup_idx]->arrive();
|
| 347 |
+
|
| 348 |
+
#pragma unroll
|
| 349 |
+
for (uint32_t i = 0; i < kNextN; ++ i) {
|
| 350 |
+
auto accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
|
| 351 |
+
|
| 352 |
+
auto sum_0 = make_float2(0, 0);
|
| 353 |
+
auto sum_1 = make_float2(0, 0);
|
| 354 |
+
|
| 355 |
+
const auto& transform_reg = [&](const uint32_t& j, const float2& sum) {
|
| 356 |
+
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
| 357 |
+
auto b = make_float2(weights[i][j], weights[i][j + 1]);
|
| 358 |
+
return __ffma2_rn(a, b, sum);
|
| 359 |
+
};
|
| 360 |
+
|
| 361 |
+
#pragma unroll
|
| 362 |
+
for (int j = 0; j < kNumWeightsInReg; j += 4) {
|
| 363 |
+
sum_0 = transform_reg(j, sum_0);
|
| 364 |
+
sum_1 = transform_reg(j + 2, sum_1);
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
const auto& transform_smem = [&](const uint32_t& j, const float2& sum) {
|
| 368 |
+
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
| 369 |
+
auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j),
|
| 370 |
+
ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1));
|
| 371 |
+
return __ffma2_rn(a, b, sum);
|
| 372 |
+
};
|
| 373 |
+
|
| 374 |
+
#pragma unroll
|
| 375 |
+
for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) {
|
| 376 |
+
sum_0 = transform_smem(j, sum_0);
|
| 377 |
+
sum_1 = transform_smem(j + 2, sum_1);
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
auto sum = __fadd2_rn(sum_0, sum_1);
|
| 381 |
+
float result = scale_kv * (sum.x + sum.y);
|
| 382 |
+
|
| 383 |
+
// Store into the global memory
|
| 384 |
+
// NOTES: we have redundant writes here, consider more carefully
|
| 385 |
+
logits[kv_offset + i * logits_stride + thread_idx] = result;
|
| 386 |
+
}
|
| 387 |
+
}
|
| 388 |
+
} else {
|
| 389 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
// Free tensor memory
|
| 393 |
+
__syncthreads();
|
| 394 |
+
if (is_umma_warp)
|
| 395 |
+
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
} // namespace deep_gemm
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#pragma clang diagnostic push
|
| 3 |
+
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
| 4 |
+
|
| 5 |
+
#include <cutlass/arch/barrier.h>
|
| 6 |
+
|
| 7 |
+
#include <deep_gemm/common/reduction.cuh>
|
| 8 |
+
#include <deep_gemm/common/utils.cuh>
|
| 9 |
+
#include <deep_gemm/common/sm90_utils.cuh>
|
| 10 |
+
#include <deep_gemm/common/sm100_utils.cuh>
|
| 11 |
+
|
| 12 |
+
namespace deep_gemm {
|
| 13 |
+
|
| 14 |
+
using namespace deep_gemm::sm100;
|
| 15 |
+
|
| 16 |
+
template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
|
| 17 |
+
__device__ __forceinline__
|
| 18 |
+
uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) {
|
| 19 |
+
// Calculate the index of the bank group to be written in the atom
|
| 20 |
+
const auto& bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase);
|
| 21 |
+
|
| 22 |
+
// Reshape the atom in another view and swizzle
|
| 23 |
+
// - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)`
|
| 24 |
+
// - new: `(BLOCK_N * kSwizzleMode / kSwizzleBase / kNumBankGroups, kNumBankGroups)`
|
| 25 |
+
constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase;
|
| 26 |
+
constexpr bool kHasShortcut = (kSwizzleMode / kSwizzleBase) == kNumBankGroups;
|
| 27 |
+
auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups);
|
| 28 |
+
auto col = kHasShortcut ? (offset) : (bank_group_idx % kNumBankGroups);
|
| 29 |
+
col ^= row % (kSwizzleMode / kSwizzleBase);
|
| 30 |
+
|
| 31 |
+
return row * 128 + col * kSwizzleBase;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
| 35 |
+
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
| 36 |
+
uint32_t kNumSplits,
|
| 37 |
+
uint32_t kSwizzleCDMode,
|
| 38 |
+
uint32_t kNumStages,
|
| 39 |
+
uint32_t kNumMMAThreads, uint32_t kNumCastAndReduceThreads>
|
| 40 |
+
__global__ void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1)
|
| 41 |
+
sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
| 42 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
| 43 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
| 44 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_d,
|
| 45 |
+
float* sqr_sum) {
|
| 46 |
+
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
|
| 47 |
+
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 48 |
+
|
| 49 |
+
// Configs
|
| 50 |
+
constexpr uint32_t kNumCastStages = 2;
|
| 51 |
+
constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128);
|
| 52 |
+
constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128);
|
| 53 |
+
constexpr auto kMajorA = cute::UMMA::Major::K;
|
| 54 |
+
constexpr auto kMajorB = cute::UMMA::Major::K;
|
| 55 |
+
DG_STATIC_ASSERT(kNumCastStages <= kNumStages, "Invalid cast stages");
|
| 56 |
+
DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N");
|
| 57 |
+
DG_STATIC_ASSERT(kNumMMAThreads == 128, "Invalid MMA threads");
|
| 58 |
+
|
| 59 |
+
// Utils
|
| 60 |
+
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
| 61 |
+
const auto lane_idx = get_lane_idx();
|
| 62 |
+
|
| 63 |
+
// Align to 1024 bytes for swizzle-128B
|
| 64 |
+
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
| 65 |
+
|
| 66 |
+
// Share memory sizes
|
| 67 |
+
constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode;
|
| 68 |
+
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16);
|
| 69 |
+
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float);
|
| 70 |
+
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
| 71 |
+
|
| 72 |
+
// Real tensor memory size and offsets
|
| 73 |
+
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<BLOCK_K * kNumCastStages + BLOCK_N>();
|
| 74 |
+
|
| 75 |
+
// Prefetch TMA descriptors at the very beginning
|
| 76 |
+
if (warp_idx == 0 and cute::elect_one_sync()) {
|
| 77 |
+
cute::prefetch_tma_descriptor(&tensor_map_a);
|
| 78 |
+
cute::prefetch_tma_descriptor(&tensor_map_b);
|
| 79 |
+
cute::prefetch_tma_descriptor(&tensor_map_d);
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
// Data on shared memory (layout as ordered below)
|
| 83 |
+
// Fill D/A/B pointers
|
| 84 |
+
auto smem_cd = reinterpret_cast<float*>(smem_buffer);
|
| 85 |
+
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
| 86 |
+
return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
|
| 87 |
+
});
|
| 88 |
+
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
| 89 |
+
return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
|
| 90 |
+
});
|
| 91 |
+
|
| 92 |
+
// Fill barriers
|
| 93 |
+
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
|
| 94 |
+
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
| 95 |
+
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 96 |
+
auto full_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
| 97 |
+
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
|
| 98 |
+
auto empty_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
|
| 99 |
+
auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4;
|
| 100 |
+
|
| 101 |
+
// Fill the tensor memory pointer
|
| 102 |
+
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 4 + 1);
|
| 103 |
+
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
| 104 |
+
|
| 105 |
+
// Initialize barriers
|
| 106 |
+
if (warp_idx == 1 and cute::elect_one_sync()) {
|
| 107 |
+
#pragma unroll
|
| 108 |
+
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
| 109 |
+
full_barriers[i]->init(1);
|
| 110 |
+
full_cast_barriers[i]->init(kNumCastAndReduceThreads);
|
| 111 |
+
empty_barriers[i]->init(1);
|
| 112 |
+
empty_cast_barriers[i]->init(1);
|
| 113 |
+
}
|
| 114 |
+
tmem_full_barrier->init(1);
|
| 115 |
+
|
| 116 |
+
// Make initialized barrier visible in async proxy
|
| 117 |
+
cutlass::arch::fence_barrier_init();
|
| 118 |
+
} else if (warp_idx == 2) {
|
| 119 |
+
// Allocate tensor memory
|
| 120 |
+
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
| 121 |
+
}
|
| 122 |
+
__syncthreads();
|
| 123 |
+
|
| 124 |
+
constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K);
|
| 125 |
+
constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
|
| 126 |
+
constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
|
| 127 |
+
const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
|
| 128 |
+
const uint32_t m_block_idx = block_idx / kNumSplits;
|
| 129 |
+
const uint32_t k_split_idx = block_idx % kNumSplits;
|
| 130 |
+
const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K;
|
| 131 |
+
const uint32_t m_offset = shape_m * k_split_idx;
|
| 132 |
+
const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks);
|
| 133 |
+
|
| 134 |
+
// Dispatch warps into different roles
|
| 135 |
+
if (warp_idx < kNumMMAThreads / 32) {
|
| 136 |
+
// TMA load warp
|
| 137 |
+
if (warp_idx == 0 and cute::elect_one_sync()) {
|
| 138 |
+
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
| 139 |
+
// Wait consumer release
|
| 140 |
+
const auto& stage_idx = s % kNumStages;
|
| 141 |
+
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
|
| 142 |
+
|
| 143 |
+
// Compute offsets
|
| 144 |
+
uint32_t m_idx = m_block_idx * BLOCK_M;
|
| 145 |
+
uint32_t k_idx = k_offset + s * BLOCK_K;
|
| 146 |
+
|
| 147 |
+
// Issue TMAs
|
| 148 |
+
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
|
| 149 |
+
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
|
| 150 |
+
|
| 151 |
+
// Arrive at full barriers
|
| 152 |
+
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
|
| 153 |
+
full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes);
|
| 154 |
+
}
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
// MMA issue warp
|
| 158 |
+
if (warp_idx == 1) {
|
| 159 |
+
// Make instruction descriptor
|
| 160 |
+
constexpr uint32_t UMMA_M = BLOCK_M;
|
| 161 |
+
constexpr uint32_t UMMA_N = BLOCK_N;
|
| 162 |
+
constexpr uint32_t UMMA_K = 32 / sizeof(float);
|
| 163 |
+
constexpr uint32_t BLOCK_SWIZZLED_BK = kSwizzleBMode / sizeof(float);
|
| 164 |
+
using umma_t = cute::SM100_MMA_TF32_TS<cutlass::tfloat32_t, cutlass::tfloat32_t, float,
|
| 165 |
+
BLOCK_M, BLOCK_N, kMajorA, kMajorB>;
|
| 166 |
+
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::tfloat32_t, cutlass::tfloat32_t, float,
|
| 167 |
+
UMMA_M, UMMA_N, kMajorA, kMajorB>();
|
| 168 |
+
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
|
| 169 |
+
|
| 170 |
+
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
|
| 171 |
+
auto b_desc = make_umma_desc<kMajorB, BLOCK_N, BLOCK_SWIZZLED_BK, kSwizzleBMode>(smem_b[0], 0, 0);
|
| 172 |
+
const uint32_t& b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
| 173 |
+
|
| 174 |
+
// Checks for MMA instructions
|
| 175 |
+
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
|
| 176 |
+
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
|
| 177 |
+
(UMMA_M == 128 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
|
| 178 |
+
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
|
| 179 |
+
"Invalid MMA instruction shape");
|
| 180 |
+
|
| 181 |
+
// Launch MMAs
|
| 182 |
+
// We can not unroll this part
|
| 183 |
+
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
| 184 |
+
// Wait TMA arrival
|
| 185 |
+
const auto& stage_idx = s % kNumStages;
|
| 186 |
+
const auto& cast_stage_idx = s % kNumCastStages;
|
| 187 |
+
full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1);
|
| 188 |
+
tcgen05_after_thread_sync();
|
| 189 |
+
|
| 190 |
+
// Issue UMMA
|
| 191 |
+
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
|
| 192 |
+
#pragma unroll
|
| 193 |
+
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
| 194 |
+
const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK;
|
| 195 |
+
const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK;
|
| 196 |
+
const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK;
|
| 197 |
+
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, float>(b_desc_base_lo, offset, in_atom_idx);
|
| 198 |
+
umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc);
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
// Commit
|
| 202 |
+
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_cast_barriers[cast_stage_idx]));
|
| 203 |
+
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
// Commit to epilogue threads
|
| 207 |
+
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barrier));
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
// TMA checks
|
| 211 |
+
constexpr uint32_t kNumBankGroupBytes = 16;
|
| 212 |
+
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float);
|
| 213 |
+
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
|
| 214 |
+
DG_STATIC_ASSERT(BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
|
| 215 |
+
|
| 216 |
+
// Only support layout F (M = 64) and D (M = 128)
|
| 217 |
+
DG_STATIC_ASSERT(BLOCK_M == 64 or BLOCK_M == 128, "Invalid block M");
|
| 218 |
+
|
| 219 |
+
// Wait UMMA arrival
|
| 220 |
+
tmem_full_barrier->wait(0);
|
| 221 |
+
tcgen05_after_thread_sync();
|
| 222 |
+
|
| 223 |
+
// Load from tensor memory into registers, and write shared memory with STSM
|
| 224 |
+
DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough");
|
| 225 |
+
|
| 226 |
+
// Store into shared memory
|
| 227 |
+
#pragma unroll
|
| 228 |
+
for (uint32_t i = 0; i < BLOCK_N / kNumElemsPerBankGroup; ++ i) {
|
| 229 |
+
// Source and destination memory address
|
| 230 |
+
uint32_t tmem_addr = BLOCK_K * kNumCastStages + i * kNumElemsPerBankGroup;
|
| 231 |
+
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd) + // Base pointer
|
| 232 |
+
warp_idx * BLOCK_M / 4 * kSwizzleCDMode + // Warp offset
|
| 233 |
+
get_swizzled_smem_offset<kSwizzleCDMode>(i, lane_idx); // In-atom offset
|
| 234 |
+
|
| 235 |
+
// Load from tensor memory, store into shared memory
|
| 236 |
+
uint32_t values[kNumElemsPerBankGroup];
|
| 237 |
+
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
|
| 238 |
+
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
|
| 239 |
+
values[0], values[1], values[2], values[3]);
|
| 240 |
+
cutlass::arch::fence_view_async_tmem_load();
|
| 241 |
+
if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16))
|
| 242 |
+
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
|
| 243 |
+
if constexpr (BLOCK_M == 64)
|
| 244 |
+
__syncwarp();
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
// Synchronize all threads and issue TMA
|
| 248 |
+
cute::tma_store_fence();
|
| 249 |
+
cutlass::arch::NamedBarrier::sync(kNumMMAThreads, 0);
|
| 250 |
+
if (warp_idx == 0 and cute::elect_one_sync()) {
|
| 251 |
+
if constexpr (kNumSplits == 1) {
|
| 252 |
+
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M);
|
| 253 |
+
} else {
|
| 254 |
+
cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx);
|
| 255 |
+
}
|
| 256 |
+
cute::tma_store_arrive();
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
// Deallocate tensor memory by warp 1
|
| 260 |
+
// NOTES: warp 0 is waiting TMA store
|
| 261 |
+
if (warp_idx == 1)
|
| 262 |
+
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
|
| 263 |
+
} else {
|
| 264 |
+
DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M");
|
| 265 |
+
DG_STATIC_ASSERT(kNumCastAndReduceThreads == 128, "Invalid cast-and-reduce threads");
|
| 266 |
+
constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4;
|
| 267 |
+
const uint32_t sub_warp_idx = warp_idx - kNumMMAThreads / 32;
|
| 268 |
+
|
| 269 |
+
// TODO: make even larger block K
|
| 270 |
+
DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K");
|
| 271 |
+
|
| 272 |
+
// Launch reductions
|
| 273 |
+
float2 sum[2] = {float2{0, 0}, float2{0, 0}};
|
| 274 |
+
#pragma unroll kNumStages
|
| 275 |
+
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
| 276 |
+
// Wait TMA arrival
|
| 277 |
+
const auto& stage_idx = s % kNumStages;
|
| 278 |
+
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
|
| 279 |
+
|
| 280 |
+
// Load from shared memory into tensor memory using movement shape `.16x256b` (shared memory part is 128b)
|
| 281 |
+
constexpr uint32_t kNumBankGroupBytes = 16;
|
| 282 |
+
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16);
|
| 283 |
+
constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup;
|
| 284 |
+
const auto& smem_base_ptr = reinterpret_cast<uint8_t*>(smem_a[stage_idx]) + // Base pointer
|
| 285 |
+
sub_warp_idx * BLOCK_M_PER_WARP * kSwizzleAMode; // Warp offset
|
| 286 |
+
|
| 287 |
+
// 4 lanes shared a bank group
|
| 288 |
+
uint32_t uint32_values[2][kNumLoads];
|
| 289 |
+
DG_STATIC_ASSERT(kNumLoads % 2 == 0, "Invalid number of loads");
|
| 290 |
+
#pragma unroll
|
| 291 |
+
for (uint32_t i = 0; i < kNumLoads; i += 2) {
|
| 292 |
+
auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset<kSwizzleAMode>(i + lane_idx / 16, lane_idx % 16);
|
| 293 |
+
sm90::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0],
|
| 294 |
+
uint32_values[0][i + 1], uint32_values[1][i + 1],
|
| 295 |
+
smem_ptr);
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
// Wait tensor memory empty
|
| 299 |
+
const auto& cast_stage_idx = s % kNumCastStages;
|
| 300 |
+
empty_cast_barriers[cast_stage_idx]->wait(((s / kNumCastStages) & 1) ^ 1);
|
| 301 |
+
|
| 302 |
+
// Cast, reduce and store into tensor memory
|
| 303 |
+
float2 fp32x2_values[2][kNumLoads];
|
| 304 |
+
const auto& upper_view = reinterpret_cast<uint32_t*>(&fp32x2_values[0]);
|
| 305 |
+
const auto& lower_view = reinterpret_cast<uint32_t*>(&fp32x2_values[1]);
|
| 306 |
+
#pragma unroll
|
| 307 |
+
for (uint32_t i = 0; i < kNumLoads; ++ i) {
|
| 308 |
+
#pragma unroll
|
| 309 |
+
for (uint32_t u = 0; u < 2; ++ u) {
|
| 310 |
+
fp32x2_values[u][i] = __bfloat1622float2(*reinterpret_cast<nv_bfloat162*>(&uint32_values[u][i]));
|
| 311 |
+
sum[u] = __ffma2_rn(fp32x2_values[u][i], fp32x2_values[u][i], sum[u]);
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
// Store upper and lower part at the same time
|
| 315 |
+
const auto idx_0 = i * 2, idx_1 = i * 2 + 1;
|
| 316 |
+
cute::SM100_TMEM_STORE_16dp256b1x::copy(
|
| 317 |
+
upper_view[idx_0], upper_view[idx_1],
|
| 318 |
+
lower_view[idx_0], lower_view[idx_1],
|
| 319 |
+
cast_stage_idx * BLOCK_K + i * 8);
|
| 320 |
+
}
|
| 321 |
+
cutlass::arch::fence_view_async_tmem_store();
|
| 322 |
+
|
| 323 |
+
// Arrive for issuing MMAs
|
| 324 |
+
tcgen05_before_thread_sync();
|
| 325 |
+
full_cast_barriers[cast_stage_idx]->arrive();
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
// Intra-warp reduction and write back
|
| 329 |
+
#pragma unroll
|
| 330 |
+
for (uint32_t u = 0; u < 2; ++ u) {
|
| 331 |
+
const auto& reduced_sum = warp_reduce_sum<4>(sum[u].x + sum[u].y);
|
| 332 |
+
const auto& m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8;
|
| 333 |
+
if (lane_idx % 4 == 0 and m_idx < shape_m)
|
| 334 |
+
sqr_sum[m_offset + m_idx] = reduced_sum;
|
| 335 |
+
}
|
| 336 |
+
}
|
| 337 |
+
#else
|
| 338 |
+
if (blockIdx.x == 0 and threadIdx.x == 0)
|
| 339 |
+
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
|
| 340 |
+
#endif
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
} // namespace deep_gemm
|
| 344 |
+
|
| 345 |
+
#pragma clang diagnostic pop
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#pragma clang diagnostic push
|
| 4 |
+
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
| 5 |
+
|
| 6 |
+
#include <cutlass/arch/barrier.h>
|
| 7 |
+
#include <cutlass/arch/reg_reconfig.h>
|
| 8 |
+
|
| 9 |
+
#include <cute/arch/cluster_sm90.hpp>
|
| 10 |
+
#include <cute/arch/copy_sm90_desc.hpp>
|
| 11 |
+
#include <cute/arch/copy_sm90_tma.hpp>
|
| 12 |
+
#include <cute/arch/mma_sm100_desc.hpp>
|
| 13 |
+
|
| 14 |
+
#include <deep_gemm/common/utils.cuh>
|
| 15 |
+
#include <deep_gemm/common/scheduler.cuh>
|
| 16 |
+
#include <deep_gemm/common/sm90_utils.cuh>
|
| 17 |
+
|
| 18 |
+
namespace deep_gemm {
|
| 19 |
+
|
| 20 |
+
using namespace deep_gemm::sm90;
|
| 21 |
+
|
| 22 |
+
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
| 23 |
+
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
| 24 |
+
uint32_t kNumGroups,
|
| 25 |
+
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K_,
|
| 26 |
+
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleDMode,
|
| 27 |
+
uint32_t kNumStages_,
|
| 28 |
+
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
|
| 29 |
+
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
|
| 30 |
+
uint32_t kNumSMs,
|
| 31 |
+
GemmType kGemmType, bool kWithAccumulation,
|
| 32 |
+
typename cd_dtype_t>
|
| 33 |
+
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
|
| 34 |
+
sm90_bf16_gemm_impl(int* grouped_layout,
|
| 35 |
+
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
| 36 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
| 37 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
| 38 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_cd) {
|
| 39 |
+
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
| 40 |
+
// Enlarge `BLOCK_K` for some cases
|
| 41 |
+
// NOTES: this is for reducing the `warpgroup_wait<0>()` overhead
|
| 42 |
+
constexpr uint32_t kDoMergeStages =
|
| 43 |
+
kNumStages_ >= 10 and
|
| 44 |
+
kGemmType == GemmType::Normal and
|
| 45 |
+
kMajorA == cute::UMMA::Major::K and kMajorB == cute::UMMA::Major::K and
|
| 46 |
+
kNumMathThreads == 128;
|
| 47 |
+
// Ensure there are at least `kNumMinStages` stages after merge
|
| 48 |
+
constexpr uint32_t kNumMinStages = 5;
|
| 49 |
+
constexpr uint32_t kNumStagesPerMerge = kDoMergeStages ? kNumStages_ / kNumMinStages : 1;
|
| 50 |
+
constexpr uint32_t BLOCK_K = BLOCK_K_ * kNumStagesPerMerge;
|
| 51 |
+
constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge;
|
| 52 |
+
|
| 53 |
+
// Types
|
| 54 |
+
using WGMMA = typename BF16MMASelector<BLOCK_N, kMajorA, kMajorB>::type;
|
| 55 |
+
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 56 |
+
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size");
|
| 57 |
+
|
| 58 |
+
// Overwrite shape constants if the compiler gives
|
| 59 |
+
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
|
| 60 |
+
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
|
| 61 |
+
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
| 62 |
+
|
| 63 |
+
// Shared memory
|
| 64 |
+
static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(cd_dtype_t)), 1024u);
|
| 65 |
+
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
|
| 66 |
+
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
|
| 67 |
+
|
| 68 |
+
// NOTES: Make sure we have enough shared memory for WGMMA padding
|
| 69 |
+
static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
| 70 |
+
DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA");
|
| 71 |
+
|
| 72 |
+
// Configs
|
| 73 |
+
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
| 74 |
+
const uint32_t lane_idx = get_lane_idx();
|
| 75 |
+
|
| 76 |
+
// Prefetch TMA descriptors at the very beginning
|
| 77 |
+
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
| 78 |
+
cute::prefetch_tma_descriptor(&tensor_map_a);
|
| 79 |
+
cute::prefetch_tma_descriptor(&tensor_map_b);
|
| 80 |
+
cute::prefetch_tma_descriptor(&tensor_map_cd);
|
| 81 |
+
}
|
| 82 |
+
__syncwarp();
|
| 83 |
+
|
| 84 |
+
// Align to 1024 bytes for swizzle-128B
|
| 85 |
+
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
| 86 |
+
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0,
|
| 87 |
+
"Shared memory of A/B/D must be aligned to 1024 bytes");
|
| 88 |
+
|
| 89 |
+
// D/A/B shared memory
|
| 90 |
+
auto smem_d = reinterpret_cast<cd_dtype_t*>(smem_buffer);
|
| 91 |
+
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
| 92 |
+
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
| 93 |
+
});
|
| 94 |
+
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
| 95 |
+
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
| 96 |
+
});
|
| 97 |
+
|
| 98 |
+
// Fill barriers
|
| 99 |
+
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
| 100 |
+
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 101 |
+
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
| 102 |
+
|
| 103 |
+
// Initialize barriers
|
| 104 |
+
if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
|
| 105 |
+
#pragma unroll
|
| 106 |
+
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
| 107 |
+
full_barriers[i]->init(1);
|
| 108 |
+
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
// Make initialized barrier visible in async proxy
|
| 112 |
+
cutlass::arch::fence_barrier_init();
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
// Synchronize all threads to make barrier visible in normal memory model
|
| 116 |
+
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
|
| 117 |
+
|
| 118 |
+
// Register reconfigurations
|
| 119 |
+
constexpr uint32_t kNumTMARegisters = 48;
|
| 120 |
+
constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 224;
|
| 121 |
+
|
| 122 |
+
// Block scheduler
|
| 123 |
+
uint32_t m_block_idx, n_block_idx;
|
| 124 |
+
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
|
| 125 |
+
|
| 126 |
+
// Pipeline and TMA phases
|
| 127 |
+
uint32_t stage_idx = 0, phase = 0;
|
| 128 |
+
auto advance_pipeline = [&](uint32_t& k_block_idx) {
|
| 129 |
+
++ k_block_idx;
|
| 130 |
+
|
| 131 |
+
// Flip phases only if reach the next first stage
|
| 132 |
+
stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
|
| 133 |
+
phase ^= stage_idx == 0;
|
| 134 |
+
};
|
| 135 |
+
|
| 136 |
+
if (warp_idx >= kNumMathThreads / 32) {
|
| 137 |
+
// TMA warp-group for loading data
|
| 138 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
| 139 |
+
|
| 140 |
+
// NOTES: only one thread (or warp) will be used
|
| 141 |
+
// We use the third warp, as warp 0/1 may be doing WGMMA with `BLOCK_M == 32`
|
| 142 |
+
if (warp_idx == kNumMathThreads / 32 + 2 and cute::elect_one_sync()) {
|
| 143 |
+
DG_STATIC_ASSERT(kNumTMAThreads >= 128, "Need at least 128 threads for TMA warp-group");
|
| 144 |
+
|
| 145 |
+
// Persistently schedule over blocks
|
| 146 |
+
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 147 |
+
// Assign TMA multicast number into A and B
|
| 148 |
+
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
|
| 149 |
+
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
|
| 150 |
+
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
| 151 |
+
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
| 152 |
+
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
|
| 153 |
+
|
| 154 |
+
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
|
| 155 |
+
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 156 |
+
// Wait consumer release
|
| 157 |
+
empty_barriers[stage_idx]->wait(phase ^ 1);
|
| 158 |
+
|
| 159 |
+
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
|
| 160 |
+
auto& full_barrier = *full_barriers[stage_idx];
|
| 161 |
+
|
| 162 |
+
const auto m_idx = scheduler.template get_global_idx<kWithGroupOffsetA, IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
|
| 163 |
+
const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx);
|
| 164 |
+
|
| 165 |
+
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
|
| 166 |
+
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> (
|
| 167 |
+
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
| 168 |
+
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> (
|
| 169 |
+
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
| 170 |
+
|
| 171 |
+
// Issue TMAs
|
| 172 |
+
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
|
| 173 |
+
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
|
| 174 |
+
if constexpr (kMajorA == cute::UMMA::Major::K)
|
| 175 |
+
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
|
| 176 |
+
&tensor_map_a, &full_barrier, smem_a[stage_idx], k_a_idx, m_idx, num_tma_multicast_a, batch_idx);
|
| 177 |
+
if constexpr (kMajorA == cute::UMMA::Major::MN)
|
| 178 |
+
tma_copy<BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
|
| 179 |
+
&tensor_map_a, &full_barrier, smem_a[stage_idx], m_idx, k_a_idx, num_tma_multicast_a, batch_idx);
|
| 180 |
+
if constexpr (kMajorB == cute::UMMA::Major::K)
|
| 181 |
+
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
|
| 182 |
+
&tensor_map_b, &full_barrier, smem_b[stage_idx], k_b_idx, n_idx, num_tma_multicast_b, batch_idx);
|
| 183 |
+
if constexpr (kMajorB == cute::UMMA::Major::MN)
|
| 184 |
+
tma_copy<BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
|
| 185 |
+
&tensor_map_b, &full_barrier, smem_b[stage_idx], n_idx, k_b_idx, num_tma_multicast_b, batch_idx);
|
| 186 |
+
|
| 187 |
+
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
| 188 |
+
}
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
// To safely deconstruct distributed shared barriers, we need another round of empty waits
|
| 192 |
+
if constexpr (kNumTMAMulticast > 1) {
|
| 193 |
+
for (uint32_t i = 0; i < kNumStages; advance_pipeline(i))
|
| 194 |
+
empty_barriers[stage_idx]->wait(phase ^ 1);
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
+
} else {
|
| 198 |
+
// Math warp-groups for WGMMA
|
| 199 |
+
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
| 200 |
+
|
| 201 |
+
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
| 202 |
+
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
|
| 203 |
+
|
| 204 |
+
// Merged stages only happens in NT normal GEMM cases
|
| 205 |
+
constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge;
|
| 206 |
+
auto a_desc = make_gmma_desc<kMajorA, BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode>(smem_a[0], math_wg_idx * WGMMA::M, 0);
|
| 207 |
+
auto b_desc = make_gmma_desc<kMajorB, BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode>(smem_b[0], 0, 0);
|
| 208 |
+
const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
|
| 209 |
+
const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
|
| 210 |
+
|
| 211 |
+
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 212 |
+
constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? BLOCK_M : WGMMA::M * 2;
|
| 213 |
+
DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes");
|
| 214 |
+
float accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
|
| 215 |
+
|
| 216 |
+
// Pick threads whose WGMMA results are to be stored in shared memory
|
| 217 |
+
DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for `BLOCK_M < 64`");
|
| 218 |
+
constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M);
|
| 219 |
+
const bool do_wgmma_store = BLOCK_M >= 64 or warp_idx < kNumWGMMAStoreThreads / 32;
|
| 220 |
+
|
| 221 |
+
// Empty barrier arrival
|
| 222 |
+
auto empty_barrier_arrive = [&](uint32_t s) {
|
| 223 |
+
if constexpr (kNumTMAMulticast == 1) {
|
| 224 |
+
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
|
| 225 |
+
} else {
|
| 226 |
+
auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
|
| 227 |
+
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
|
| 228 |
+
}
|
| 229 |
+
};
|
| 230 |
+
|
| 231 |
+
// TODO: remove some useless computation for unaligned Ms
|
| 232 |
+
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
|
| 233 |
+
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
| 234 |
+
const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
|
| 235 |
+
const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
|
| 236 |
+
|
| 237 |
+
// Wait TMA arrivals
|
| 238 |
+
full_barriers[stage_idx]->wait(phase);
|
| 239 |
+
|
| 240 |
+
// Commit WGMMA instructions
|
| 241 |
+
#pragma unroll
|
| 242 |
+
for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i)
|
| 243 |
+
warpgroup_fence_operand(accum[i]);
|
| 244 |
+
warpgroup_arrive();
|
| 245 |
+
#pragma unroll
|
| 246 |
+
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
| 247 |
+
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
|
| 248 |
+
#pragma unroll
|
| 249 |
+
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
| 250 |
+
const uint32_t& atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K;
|
| 251 |
+
a_desc.reg32_[0] = advance_gmma_desc_lo<kMajorA, BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode, nv_bfloat16>(
|
| 252 |
+
a_desc_base_lo, local_idx * WAVE_BLOCK_M, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_M * BLOCK_ATOM_K);
|
| 253 |
+
b_desc.reg32_[0] = advance_gmma_desc_lo<kMajorB, BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode, nv_bfloat16>(
|
| 254 |
+
b_desc_base_lo, 0, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_N * BLOCK_ATOM_K);
|
| 255 |
+
WGMMA::wgmma(a_desc, b_desc, shifted_accum, 1);
|
| 256 |
+
}
|
| 257 |
+
}
|
| 258 |
+
warpgroup_commit_batch();
|
| 259 |
+
#pragma unroll
|
| 260 |
+
for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i)
|
| 261 |
+
warpgroup_fence_operand(accum[i]);
|
| 262 |
+
warpgroup_wait<0>();
|
| 263 |
+
|
| 264 |
+
// Notify barrier arrival
|
| 265 |
+
empty_barrier_arrive(stage_idx);
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
// TMA checks
|
| 269 |
+
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
|
| 270 |
+
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes);
|
| 271 |
+
constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
|
| 272 |
+
DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom");
|
| 273 |
+
DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
|
| 274 |
+
"Unaligned TMA store or too many TMA store instructions");
|
| 275 |
+
DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");
|
| 276 |
+
|
| 277 |
+
// Skip WGMMA store for the unfilled parts
|
| 278 |
+
if (not do_wgmma_store)
|
| 279 |
+
continue;
|
| 280 |
+
|
| 281 |
+
// Wait last TMA store to be finished
|
| 282 |
+
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)
|
| 283 |
+
cute::tma_store_wait<0>();
|
| 284 |
+
cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0);
|
| 285 |
+
|
| 286 |
+
if constexpr (cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>) {
|
| 287 |
+
// Write back to shared memory using STSM and issue TMA stores
|
| 288 |
+
DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type");
|
| 289 |
+
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
|
| 290 |
+
#pragma unroll
|
| 291 |
+
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
| 292 |
+
auto m_offset = local_idx * WAVE_BLOCK_M;
|
| 293 |
+
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
|
| 294 |
+
#pragma unroll
|
| 295 |
+
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
| 296 |
+
// Swizzle or padding into the correct address
|
| 297 |
+
uint8_t* smem_ptr = nullptr;
|
| 298 |
+
if constexpr (kSwizzleDMode > 0) {
|
| 299 |
+
// Calculate the swizzling atom offset and in-atom offset
|
| 300 |
+
constexpr uint32_t kNumBankGroupBytes = 16;
|
| 301 |
+
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
|
| 302 |
+
|
| 303 |
+
// Calculate the index of the bank group to be written in the atom
|
| 304 |
+
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
|
| 305 |
+
|
| 306 |
+
// Reshape the atom in another view and swizzle
|
| 307 |
+
// - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)`
|
| 308 |
+
// - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
|
| 309 |
+
constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8;
|
| 310 |
+
auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8);
|
| 311 |
+
auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8);
|
| 312 |
+
col ^= row % (kSwizzleDMode / 16);
|
| 313 |
+
|
| 314 |
+
// Add back into the base pointer
|
| 315 |
+
// NOTES: think twice before modifying this, as changes may affect the number of instructions
|
| 316 |
+
smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer
|
| 317 |
+
warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset
|
| 318 |
+
m_offset * kSwizzleDMode + // Wave offset
|
| 319 |
+
atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants)
|
| 320 |
+
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
| 321 |
+
} else {
|
| 322 |
+
// No swizzling
|
| 323 |
+
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8);
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
// NOTES: only 16 lanes' addresses are used
|
| 327 |
+
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
|
| 328 |
+
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
|
| 329 |
+
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
|
| 330 |
+
smem_ptr
|
| 331 |
+
);
|
| 332 |
+
}
|
| 333 |
+
}
|
| 334 |
+
} else {
|
| 335 |
+
// Use `st.shared` if STSM is not available
|
| 336 |
+
#pragma unroll
|
| 337 |
+
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
| 338 |
+
auto m_offset = local_idx * WAVE_BLOCK_M;
|
| 339 |
+
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
|
| 340 |
+
auto smem_d_0 = reinterpret_cast<float2*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 0) * BLOCK_N + (lane_idx % 4) * 2);
|
| 341 |
+
auto smem_d_1 = reinterpret_cast<float2*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2);
|
| 342 |
+
#pragma unroll
|
| 343 |
+
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
| 344 |
+
st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]));
|
| 345 |
+
st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]));
|
| 346 |
+
}
|
| 347 |
+
}
|
| 348 |
+
}
|
| 349 |
+
cute::tma_store_fence();
|
| 350 |
+
cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0);
|
| 351 |
+
|
| 352 |
+
// Use TMA store to write back to global memory
|
| 353 |
+
const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
|
| 354 |
+
DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
|
| 355 |
+
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
|
| 356 |
+
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
|
| 357 |
+
auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
|
| 358 |
+
if constexpr (kGemmType == GemmType::Batched) {
|
| 359 |
+
cute::SM90_TMA_STORE_3D::copy(&tensor_map_cd, smem_ptr,
|
| 360 |
+
n_block_idx * BLOCK_N + in_block_n_offset,
|
| 361 |
+
m_idx, scheduler.current_group_idx);
|
| 362 |
+
} else {
|
| 363 |
+
using cute_tma_t = cute::conditional_t<kWithAccumulation,
|
| 364 |
+
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
|
| 365 |
+
cute_tma_t::copy(&tensor_map_cd, smem_ptr,
|
| 366 |
+
n_block_idx * BLOCK_N + in_block_n_offset, m_idx);
|
| 367 |
+
}
|
| 368 |
+
cute::tma_store_arrive();
|
| 369 |
+
}
|
| 370 |
+
__syncwarp();
|
| 371 |
+
}
|
| 372 |
+
}
|
| 373 |
+
#else
|
| 374 |
+
if (blockIdx.x == 0 and threadIdx.x == 0)
|
| 375 |
+
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
|
| 376 |
+
#endif
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
}; // namespace deep_gemm
|
| 380 |
+
|
| 381 |
+
#pragma clang diagnostic pop
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cute/arch/cluster_sm90.hpp>
|
| 4 |
+
#include <cutlass/arch/barrier.h>
|
| 5 |
+
#include <cutlass/arch/reg_reconfig.h>
|
| 6 |
+
|
| 7 |
+
#include <deep_gemm/common/utils.cuh>
|
| 8 |
+
#include <deep_gemm/common/sm90_utils.cuh>
|
| 9 |
+
|
| 10 |
+
namespace deep_gemm {
|
| 11 |
+
|
| 12 |
+
using namespace deep_gemm::sm90;
|
| 13 |
+
|
| 14 |
+
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
| 15 |
+
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
| 16 |
+
uint32_t kSplitFactor,
|
| 17 |
+
uint32_t kNumStages,
|
| 18 |
+
uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
|
| 19 |
+
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
|
| 20 |
+
sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
|
| 21 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
| 22 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
| 23 |
+
float *d) {
|
| 24 |
+
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
| 25 |
+
// Types
|
| 26 |
+
using WGMMA = typename BF16MMASelector<BLOCK_N>::type;
|
| 27 |
+
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 28 |
+
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
|
| 29 |
+
|
| 30 |
+
// Shared memory
|
| 31 |
+
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
|
| 32 |
+
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
|
| 33 |
+
|
| 34 |
+
// Configs
|
| 35 |
+
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
| 36 |
+
const uint32_t lane_idx = get_lane_idx();
|
| 37 |
+
DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M");
|
| 38 |
+
DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads");
|
| 39 |
+
DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads");
|
| 40 |
+
|
| 41 |
+
// Prefetch TMA descriptors at the very beginning
|
| 42 |
+
if (warp_idx == 0 and cute::elect_one_sync()) {
|
| 43 |
+
cute::prefetch_tma_descriptor(&tensor_map_a);
|
| 44 |
+
cute::prefetch_tma_descriptor(&tensor_map_b);
|
| 45 |
+
}
|
| 46 |
+
__syncwarp();
|
| 47 |
+
|
| 48 |
+
// Align to 1024 bytes for swizzle-128B
|
| 49 |
+
// Fill shared memory pointers
|
| 50 |
+
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
| 51 |
+
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
| 52 |
+
return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE));
|
| 53 |
+
});
|
| 54 |
+
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
| 55 |
+
return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
|
| 56 |
+
});
|
| 57 |
+
|
| 58 |
+
// Fill barriers
|
| 59 |
+
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
| 60 |
+
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
| 61 |
+
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
| 62 |
+
|
| 63 |
+
// Initialize barriers
|
| 64 |
+
if (warp_idx == 1 and cute::elect_one_sync()) {
|
| 65 |
+
#pragma unroll
|
| 66 |
+
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
| 67 |
+
full_barriers[i]->init(1);
|
| 68 |
+
empty_barriers[i]->init(kNumMathThreads);
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
// Make initialized barrier visible in async proxy
|
| 72 |
+
cutlass::arch::fence_barrier_init();
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
// Synchronize all threads to make barrier visible in normal memory model
|
| 76 |
+
__syncthreads();
|
| 77 |
+
|
| 78 |
+
// Register reconfigurations
|
| 79 |
+
constexpr uint32_t kNumTMARegisters = 40;
|
| 80 |
+
constexpr uint32_t kNumMathRegisters = 232;
|
| 81 |
+
|
| 82 |
+
// Block indices
|
| 83 |
+
const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N);
|
| 84 |
+
const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M);
|
| 85 |
+
const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
|
| 86 |
+
const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
|
| 87 |
+
const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
|
| 88 |
+
const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
|
| 89 |
+
const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
|
| 90 |
+
|
| 91 |
+
if (warp_idx >= kNumMathThreads / 32) {
|
| 92 |
+
// TMA warp-group for loading data
|
| 93 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
| 94 |
+
|
| 95 |
+
// NOTES: only one thread (or warp) will be used
|
| 96 |
+
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
| 97 |
+
// Persistently schedule over blocks
|
| 98 |
+
#pragma unroll
|
| 99 |
+
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
| 100 |
+
// Wait consumer release
|
| 101 |
+
const auto& stage_idx = s % kNumStages;
|
| 102 |
+
empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1);
|
| 103 |
+
|
| 104 |
+
auto& full_barrier = *full_barriers[stage_idx];
|
| 105 |
+
const uint32_t& sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
|
| 106 |
+
const uint32_t& k_idx = sk_idx % SHAPE_K;
|
| 107 |
+
const uint32_t& s_idx = sk_idx / SHAPE_K;
|
| 108 |
+
|
| 109 |
+
constexpr uint32_t kSwizzle = BLOCK_K * sizeof(nv_bfloat16);
|
| 110 |
+
tma_copy<BLOCK_K, BLOCK_M, kSwizzle>(
|
| 111 |
+
&tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1);
|
| 112 |
+
tma_copy<BLOCK_K, BLOCK_N, kSwizzle>(
|
| 113 |
+
&tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1);
|
| 114 |
+
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
} else {
|
| 118 |
+
// Math warp-groups for WGMMA
|
| 119 |
+
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
| 120 |
+
|
| 121 |
+
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
| 122 |
+
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
|
| 123 |
+
float accum[WGMMA::kNumAccum] = {0};
|
| 124 |
+
|
| 125 |
+
// Launch MMAs
|
| 126 |
+
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
| 127 |
+
// Wait TMA arrivals
|
| 128 |
+
const auto& stage_idx = s % kNumStages;
|
| 129 |
+
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
|
| 130 |
+
|
| 131 |
+
// Commit WGMMA instructions
|
| 132 |
+
#pragma unroll
|
| 133 |
+
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 134 |
+
warpgroup_fence_operand(accum[i]);
|
| 135 |
+
warpgroup_arrive();
|
| 136 |
+
#pragma unroll
|
| 137 |
+
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
| 138 |
+
auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1);
|
| 139 |
+
auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
|
| 140 |
+
WGMMA::wgmma(desc_a, desc_b, accum, 1);
|
| 141 |
+
}
|
| 142 |
+
warpgroup_commit_batch();
|
| 143 |
+
#pragma unroll
|
| 144 |
+
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 145 |
+
warpgroup_fence_operand(accum[i]);
|
| 146 |
+
warpgroup_wait<0>();
|
| 147 |
+
|
| 148 |
+
// Notify barrier arrival at the last warpgroup wave
|
| 149 |
+
empty_barriers[stage_idx]->arrive();
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
const auto& row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4;
|
| 153 |
+
const auto& col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2;
|
| 154 |
+
#pragma unroll
|
| 155 |
+
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
| 156 |
+
if (col + i * 8 >= SHAPE_N)
|
| 157 |
+
break;
|
| 158 |
+
if (row < SHAPE_M) {
|
| 159 |
+
atomicAdd(reinterpret_cast<float2*>(d + (row + 0) * SHAPE_N + col + i * 8),
|
| 160 |
+
make_float2(accum[i * 4 + 0], accum[i * 4 + 1]));
|
| 161 |
+
}
|
| 162 |
+
if (row + 8 < SHAPE_M) {
|
| 163 |
+
atomicAdd(reinterpret_cast<float2*>(d + (row + 8) * SHAPE_N + col + i * 8),
|
| 164 |
+
make_float2(accum[i * 4 + 2], accum[i * 4 + 3]));
|
| 165 |
+
}
|
| 166 |
+
}
|
| 167 |
+
}
|
| 168 |
+
#else
|
| 169 |
+
if (blockIdx.x == 0 and threadIdx.x == 0)
|
| 170 |
+
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
|
| 171 |
+
#endif
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
}; // namespace deep_gemm
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#pragma clang diagnostic push
|
| 4 |
+
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
| 5 |
+
|
| 6 |
+
#include <cutlass/arch/barrier.h>
|
| 7 |
+
#include <cutlass/arch/reg_reconfig.h>
|
| 8 |
+
|
| 9 |
+
#include <cute/arch/cluster_sm90.hpp>
|
| 10 |
+
#include <cute/arch/copy_sm90_desc.hpp>
|
| 11 |
+
#include <cute/arch/copy_sm90_tma.hpp>
|
| 12 |
+
|
| 13 |
+
#include <deep_gemm/common/utils.cuh>
|
| 14 |
+
#include <deep_gemm/common/scheduler.cuh>
|
| 15 |
+
#include <deep_gemm/common/sm90_utils.cuh>
|
| 16 |
+
|
| 17 |
+
namespace deep_gemm {
|
| 18 |
+
|
| 19 |
+
using namespace deep_gemm::sm90;
|
| 20 |
+
|
| 21 |
+
// SM90 (Hopper) FP8 GEMM kernel with "1D1D" scaling: per-128-channel FP32 scale
// factors for BOTH operands (SFA over rows of A, SFB over rows of B), promoted
// in registers after each WGMMA and accumulated in FP32. Results are written
// back with TMA reduce-add (D += A @ B), so `cd_dtype_t` must be float.
//
// Thread layout (enforced by the asserts below): one 128-thread TMA producer
// warp-group plus `kNumMathThreads / 128` consumer warp-groups; the producer
// and consumers hand stages around through `kNumStages` full/empty
// transaction-barrier pairs in shared memory.
//
// For `kGemmType == KGroupedContiguous`, per-group tensor maps for A/B are
// rebuilt on the fly: two staging slots in shared memory are patched
// (global address + inner-dim stride) and published through a per-CTA region
// of `tensor_map_buffer` in global memory, double-buffered by group parity.
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
          uint32_t kNumGroups,
          uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
          uint32_t kSwizzleAMode, uint32_t kSwizzleBMode,
          uint32_t kNumStages,
          uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
          uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
          uint32_t kNumSMs,
          GemmType kGemmType, typename cd_dtype_t>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
                        int* grouped_layout,
                        cute::TmaDescriptor* tensor_map_buffer,
                        uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
                        const __grid_constant__ cute::TmaDescriptor tensor_map_a_base,
                        const __grid_constant__ cute::TmaDescriptor tensor_map_b_base,
                        const __grid_constant__ cute::TmaDescriptor tensor_map_sfa,
                        const __grid_constant__ cute::TmaDescriptor tensor_map_sfb,
                        const __grid_constant__ cute::TmaDescriptor tensor_map_cd) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
    // Scaling checks
    DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid Threads");
    DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
    DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
    DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type");

    // Types
    using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
    using Barrier = cutlass::arch::ClusterTransactionBarrier;
    DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");

    // Overwrite shape constants if the compiler gives them (a compile-time
    // SHAPE_* of 0 means "dynamic"; otherwise the literal wins so the
    // optimizer can fold all dependent arithmetic)
    shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
    shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
    shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;

    // Shared memory sizes. Layout order below: [tensor maps][D tile]
    // [A stages][B stages][SFA stages][SFB stages][barriers].
    static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 4 : 0);
    static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
    static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
    static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
    static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
    static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
    // SFB may need padding to the 128-byte TMA alignment (SFA is asserted to
    // already satisfy it)
    static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u);
    DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment");

    // Configs
    // NOTES: `__shfl_sync` broadcasts lane 0's value so the warp index lands
    // in a uniform register
    const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    const uint32_t lane_idx = threadIdx.x % 32;

    // Prefetch TMA descriptors at the very beginning (one elected thread of
    // the first producer warp)
    if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
        cute::prefetch_tma_descriptor(&tensor_map_a_base);
        cute::prefetch_tma_descriptor(&tensor_map_b_base);
        cute::prefetch_tma_descriptor(&tensor_map_sfa);
        cute::prefetch_tma_descriptor(&tensor_map_sfb);
        cute::prefetch_tma_descriptor(&tensor_map_cd);
    }
    __syncwarp();

    // Align to 1024 bytes for swizzle-128B
    extern __shared__ __align__(1024) uint8_t smem_buffer[];
    DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");

    // Tensor maps on shared and global memory
    // Slots [0..1] stage A's descriptor, [2..3] stage B's; each CTA owns a
    // 4-descriptor window of `tensor_map_buffer` indexed by `blockIdx.x`
    auto smem_tensor_map_a = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<cute::TmaDescriptor*>(smem_buffer + static_cast<uint32_t>(sizeof(cute::TmaDescriptor)) * i);
    });
    auto smem_tensor_map_b = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<cute::TmaDescriptor*>(smem_buffer + static_cast<uint32_t>(sizeof(cute::TmaDescriptor)) * (2 + i));
    });
    auto gmem_tensor_map_a = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + i; });
    auto gmem_tensor_map_b = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + 2 + i; });

    // Data on shared memory
    auto smem_d = reinterpret_cast<float*>(smem_buffer + SMEM_TENSOR_MAP_SIZE);
    auto smem_a = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE));
    });
    auto smem_b = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
    });
    constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
    auto smem_sfa = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE));
    });
    auto smem_sfb = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE));
    });

    // Barriers on shared memory: `kNumStages` full barriers (producer ->
    // consumers) followed by `kNumStages` empty barriers (consumers ->
    // producer)
    constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE);
    auto full_barriers = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast<uint32_t>(sizeof(Barrier))));
    });
    auto empty_barriers = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast<uint32_t>(sizeof(Barrier))));
    });

    // One elected thread of the *second* producer warp does the setup, so it
    // overlaps with the first producer warp's descriptor prefetch above
    if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
        // Load tensormap A/B to shared memory (both double-buffer slots start
        // as copies of the base descriptors)
        if constexpr (kGemmType == GemmType::KGroupedContiguous) {
            *smem_tensor_map_a[0] = tensor_map_a_base;
            *smem_tensor_map_a[1] = tensor_map_a_base;
            *smem_tensor_map_b[0] = tensor_map_b_base;
            *smem_tensor_map_b[1] = tensor_map_b_base;
        }

        // Initialize barriers
        // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
        // even with TMA multicast disabled, we want to make the behavior aligned.
        // Full barriers expect a single TMA transaction; empty barriers expect
        // one arrival per math warp per participating CTA.
        #pragma unroll
        for (uint32_t i = 0; i < kNumStages; ++ i) {
            full_barriers[i]->init(1);
            empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
        }

        // Make initialized barrier visible in async proxy
        cutlass::arch::fence_barrier_init();
    }

    // Synchronize all threads to make barrier visible in normal memory model
    // (cluster-wide when multicasting, since peer CTAs arrive on our barriers)
    (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();

    // Pipeline unroll control: 0 disables unrolling (grouped K is
    // data-dependent), otherwise unroll by the stage count
    constexpr uint32_t kNumPipelineUnrolls = (kGemmType == GemmType::KGroupedContiguous ? 0 : kNumStages);

    // Register reconfigurations (more math registers are needed with unrolling)
    constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24);
    constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240);

    // Block scheduler
    uint32_t m_block_idx, n_block_idx;
    auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs, 128u>(shape_m, shape_n, shape_k, grouped_layout);

    // TMA and MMA pipeline: map a monotonically increasing iteration counter
    // onto (stage ring slot, phase bit)
    const auto& get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple<uint32_t, uint32_t> {
        return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase
    };
    uint32_t iter_idx = 0;

    if (warp_idx >= kNumMathThreads / 32) {
        // TMA warp-group for loading data
        cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();

        // NOTES: only one thread (or warp) will be used
        if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
            const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base;
            const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base;
            // `last_group_idx` starts out-of-range so the first grouped block
            // always triggers a descriptor switch; `sum_k` accumulates the
            // element offset of the current group along K
            uint32_t last_group_idx = kNumGroups, sum_k = 0;

            // Persistently schedule over blocks
            while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
                // Assign TMA multicast number into A and B
                // NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
                const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
                const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
                const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
                DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");

                const uint32_t& num_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
                const uint32_t& m_idx = m_block_idx * BLOCK_M;
                const uint32_t& n_idx = n_block_idx * BLOCK_N;

                // On a group switch, rotate to the other tensor-map slot and
                // pre-publish the next group's descriptors
                if (kGemmType == GemmType::KGroupedContiguous and last_group_idx != scheduler.current_group_idx) {
                    // Double-buffer slot selection by valid-group parity
                    const uint32_t& stage_idx = scheduler.current_num_valid_groups & 1;
                    const uint32_t& next_stage_idx = stage_idx ^ 1;
                    last_group_idx = scheduler.current_group_idx;

                    // Prepare next tensor map: patch base address (A/B are
                    // contiguous per group, hence the `sum_k * shape_{m,n}`
                    // offsets) and the inner-dim stride for the next group's K
                    sum_k += scheduler.current_shape_k;
                    if (scheduler.next_group_idx < kNumGroups) {
                        tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + static_cast<uint64_t>(sum_k) * shape_m);
                        tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + static_cast<uint64_t>(sum_k) * shape_n);
                        tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
                        tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
                        *(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]);
                        *(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]);
                        tensor_map_release_cta();
                    }

                    // Get current tensor map (the very first group keeps the
                    // base descriptors, so only acquire from the second on)
                    if (scheduler.current_num_valid_groups > 0) {
                        tensor_map_acquire_cta(gmem_tensor_map_a[stage_idx]);
                        tensor_map_acquire_cta(gmem_tensor_map_b[stage_idx]);
                        current_tensor_map_a = gmem_tensor_map_a[stage_idx];
                        current_tensor_map_b = gmem_tensor_map_b[stage_idx];
                    }
                }

                #pragma unroll kNumPipelineUnrolls
                for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) {
                    // Wait consumer release (`phase ^ 1`: the empty barrier's
                    // phase runs one flip ahead of the full barrier's)
                    CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase);
                    empty_barriers[stage_idx]->wait(phase ^ 1);

                    // Issue TMA: scale factors first, then the A/B tiles; all
                    // four transfers complete on the same transaction barrier
                    auto& full_barrier = *full_barriers[stage_idx];
                    const uint32_t& k_idx = k_block_idx * BLOCK_K;
                    const uint32_t& sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx;
                    tma_copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a);
                    tma_copy<BLOCK_N, BLOCK_K, 0>(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b);
                    tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(current_tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a);
                    tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(current_tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b);
                    full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE);
                }
            }

            // To safely deconstruct distributed shared barriers, we need another round of empty waits
            if constexpr (kNumTMAMulticast > 1) {
                #pragma unroll
                for (uint32_t s = 0; s < kNumStages; ++ s) {
                    CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase);
                    empty_barriers[stage_idx]->wait(phase ^ 1);
                }
            }
        }
    } else {
        // Math warp-groups for WGMMA
        cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();

        // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
        const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
        // Per-lane coordinates inside the WGMMA accumulator fragment: each
        // lane owns rows r_0 and r_1 (= r_0 + 8) and a 2-wide column slice
        const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4;
        const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8;

        // Persistently schedule over blocks
        while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
            // Accumulation for WGMMA or CUDA promotion
            DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes");
            const uint32_t& current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
            const uint32_t& current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0);
            const uint32_t& num_k_blocks = ceil_div(current_shape_k, BLOCK_K);
            // `accum` is the raw per-K-block WGMMA result; `final_accum`
            // carries the scale-promoted running sum across K blocks
            float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
            float2 scales_b[WGMMA::kNumAccum / 4];

            // Empty barrier arrival (one lane per target CTA; with multicast,
            // arrive on every live peer in the cluster)
            auto empty_barrier_arrive = [&](uint32_t s) {
                if constexpr (kNumTMAMulticast == 1) {
                    lane_idx == 0 ? empty_barriers[s]->arrive() : void();
                } else {
                    auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
                    lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
                }
            };

            #pragma unroll kNumPipelineUnrolls
            for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) {
                // Wait TMA arrivals
                CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase);
                full_barriers[stage_idx]->wait(phase);

                // Read A scales
                // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
                auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0);
                auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1);

                // Read B scales (one float2 per 4-accumulator group,
                // addressed by this lane's column slice)
                #pragma unroll
                for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
                    scales_b[i] = ld_shared(reinterpret_cast<float2*>(smem_sfb[stage_idx] + i * 8 + col_idx * 2));

                // Commit WGMMA instructions (fence operands around the async
                // batch so the compiler does not reorder accumulator accesses)
                #pragma unroll
                for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                    warpgroup_fence_operand(accum[i]);
                warpgroup_arrive();
                #pragma unroll
                for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
                    auto desc_a = make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
                    auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
                    // `k` as scale-D: only the first sub-K MMA clears the
                    // accumulator, the rest accumulate
                    WGMMA::wgmma(desc_a, desc_b, accum, k);
                }
                warpgroup_commit_batch();
                #pragma unroll
                for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                    warpgroup_fence_operand(accum[i]);
                warpgroup_wait<0>();

                // Notify barrier arrival (stage can be refilled while we promote)
                empty_barrier_arrive(stage_idx);

                // Promote with scales: rows use scale_a_0/scale_a_1, the two
                // columns of each float2 use scale_b_0/scale_b_1
                #pragma unroll
                for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
                    const float &scale_b_0 = scales_b[i].x;
                    const float &scale_b_1 = scales_b[i].y;
                    final_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0];
                    final_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1];
                    final_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2];
                    final_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3];
                }
            }

            // Flush previous stores before overwriting the shared D tile
            if (warp_idx % 4 == 0 and cute::elect_one_sync())
                cute::tma_store_wait<0>();
            cutlass::arch::NamedBarrier::sync(128, math_wg_idx);

            // Store to D shared memory (float2 per lane, rows r_0 and r_1)
            const auto& smem_d_0 = reinterpret_cast<float2*>(smem_d + r_0 * BLOCK_N + col_idx * 2);
            const auto& smem_d_1 = reinterpret_cast<float2*>(smem_d + r_1 * BLOCK_N + col_idx * 2);
            #pragma unroll
            for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
                st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]});
                st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]});
            }
            // Make the generic-proxy stores visible to the TMA async proxy
            cute::tma_store_fence();
            cutlass::arch::NamedBarrier::sync(128, math_wg_idx);

            // Use TMA store to write back to global memory
            // NOTE(review): reduce-add means D is accumulated into, not
            // overwritten; the source address `smem_d_0` here is the elected
            // lane's own pointer — presumably each warp-leader issues its own
            // row slice, TODO confirm against the TMA descriptor's box shape
            if (warp_idx % 4 == 0 and cute::elect_one_sync()) {
                cute::SM90_TMA_REDUCE_ADD_2D::copy(
                    &tensor_map_cd, smem_d_0, n_block_idx * BLOCK_N,
                    current_group_idx * shape_m + m_block_idx * BLOCK_M + r_0);
                cute::tma_store_arrive();
            }
            __syncwarp();
        }
    }
#else
    if (blockIdx.x == 0 and threadIdx.x == 0)
        DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
|
| 346 |
+
|
| 347 |
+
}; // namespace deep_gemm
|
| 348 |
+
|
| 349 |
+
#pragma clang diagnostic pop
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#pragma clang diagnostic push
|
| 4 |
+
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
| 5 |
+
|
| 6 |
+
#include <cutlass/arch/barrier.h>
|
| 7 |
+
#include <cutlass/arch/reg_reconfig.h>
|
| 8 |
+
|
| 9 |
+
#include <cute/arch/cluster_sm90.hpp>
|
| 10 |
+
#include <cute/arch/copy_sm90_desc.hpp>
|
| 11 |
+
#include <cute/arch/copy_sm90_tma.hpp>
|
| 12 |
+
|
| 13 |
+
#include <deep_gemm/common/epilogue_utils.cuh>
|
| 14 |
+
#include <deep_gemm/common/utils.cuh>
|
| 15 |
+
#include <deep_gemm/common/scheduler.cuh>
|
| 16 |
+
#include <deep_gemm/common/sm90_utils.cuh>
|
| 17 |
+
|
| 18 |
+
namespace deep_gemm {
|
| 19 |
+
|
| 20 |
+
using namespace deep_gemm::sm90;
|
| 21 |
+
|
| 22 |
+
template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd, typename func_t>
|
| 23 |
+
__device__ void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) {
|
| 24 |
+
if (num_former_iters == kNumFormerIters) {
|
| 25 |
+
func(cute::Int<kNumFormerIters>{});
|
| 26 |
+
return;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
if constexpr (kNumFormerIters + kGap <= kEnd)
|
| 30 |
+
dispatch_num_former_iters<kNumFormerIters + kGap, kGap, kEnd>(num_former_iters, func);
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
/// SM90 (Hopper) FP8 GEMM kernel with 1D (per-row) A scaling and 2D (per-128-channel) B scaling.
///
/// Warp-specialized persistent kernel:
///  - One TMA warp-group (threads >= kNumMathThreads) produces A/B tiles and A scale factors
///    into a `kNumStages`-deep shared-memory pipeline guarded by full/empty cluster barriers.
///  - Math warp-groups consume the pipeline with WGMMA, promote FP8 partial sums into FP32
///    with the per-block scales, then write back via STSM + TMA stores as BF16.
///
/// @param sfb            Global pointer to B scale factors (layout selected by `kMajorSFB`).
/// @param grouped_layout Group layout metadata consumed by the `Scheduler` (grouped GEMM modes).
/// @param shape_m/n/k    Runtime shapes; overridden by SHAPE_M/N/K template constants when nonzero.
/// @param tensor_map_*   Grid-constant TMA descriptors for A, B, D and the A scale factors.
template <cute::UMMA::Major kMajorSFB,
          uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
          uint32_t kNumGroups,
          uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
          uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleDMode,
          uint32_t kNumStages, uint32_t kNumLastStages,
          uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
          uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
          uint32_t kNumSMs, GemmType kGemmType,
          typename epilogue_type_t>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
                        uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
                        const __grid_constant__ cute::TmaDescriptor tensor_map_a,
                        const __grid_constant__ cute::TmaDescriptor tensor_map_b,
                        const __grid_constant__ cute::TmaDescriptor tensor_map_d,
                        const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
    // Scaling checks
    DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
    DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");

    // Types
    using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
    using Barrier = cutlass::arch::ClusterTransactionBarrier;
    DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size");

    // Overwrite shape constants if the compiler gives
    // (a nonzero SHAPE_* template argument wins over the runtime argument)
    shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
    shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
    shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;

    // Shared memory
    // `kMustUseUniformedScaleB`: when BLOCK_K is a multiple of BLOCK_N, a block never straddles
    // two B-scale columns, so a single scale per k-block suffices.
    static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
    static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(__nv_bfloat16)), 1024u);
    static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
    static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
    static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
    static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u);
    const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K);
    const uint32_t& shape_n_sfb = ceil_div(shape_n, BLOCK_K);
    // One or two B-scale rows per k-step (two when a block may span a scale boundary), padded to barrier alignment
    const uint32_t& smem_sfb_size = align<uint32_t>(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier));

    // NOTES: Make sure we have enough shared memory for WGMMA padding
    static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3);
    DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA");

    // Configs
    const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K);
    // Broadcast from lane 0 so the value lives in a uniform register for the whole warp
    const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    const uint32_t lane_idx = get_lane_idx();

    // Prefetch TMA descriptors at the very beginning
    if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
        cute::prefetch_tma_descriptor(&tensor_map_a);
        cute::prefetch_tma_descriptor(&tensor_map_b);
        cute::prefetch_tma_descriptor(&tensor_map_sfa);
        cute::prefetch_tma_descriptor(&tensor_map_d);
    }
    __syncwarp();

    // Align to 1024 bytes for swizzle-128B
    extern __shared__ __align__(1024) uint8_t smem_buffer[];
    DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");

    // Data on shared memory
    // Layout (in order): D tile | A stages | B stages | SFA stages | SFB | barriers
    auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
    auto smem_a = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
    });
    auto smem_b = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
    });
    constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
    auto smem_sfa = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE);
    });
    auto smem_sfb = reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE);

    // Fill barriers
    auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_sfb) + smem_sfb_size);
    auto full_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; });
    auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; });

    // Initialize barriers
    DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
    if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
        // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
        // even with TMA multicast disabled, we want to make the behavior aligned
        #pragma unroll
        for (uint32_t i = 0; i < kNumStages; ++ i) {
            // Full barrier: signaled once per stage by the single TMA transaction
            full_barriers[i]->init(1);
            // Empty barrier: one arrival per math warp per multicast CTA
            empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
        }

        // Make initialized barrier visible in async proxy
        cutlass::arch::fence_barrier_init();
    }

    // Synchronize all threads to make barrier visible in normal memory model
    (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();

    // Register reconfigurations
    // (producer warps shed registers so consumer warp-groups can allocate more)
    constexpr uint32_t kNumTMARegisters = 40;
    constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 232;

    // Block scheduler
    uint32_t m_block_idx, n_block_idx;
    auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);

    // Pipeline and TMA phases
    // `phase` is the mbarrier parity bit; it flips each time the ring of stages wraps around.
    uint32_t stage_idx = 0, phase = 0;
    auto advance_pipeline = [&](uint32_t& k_block_idx) {
        ++ k_block_idx;

        // Flip phases only if reach the next first stage
        stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
        phase ^= stage_idx == 0;
    };

    if (warp_idx >= kNumMathThreads / 32) {
        // TMA warp-group for loading data
        cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();

        // NOTES: only one thread (or warp) will be used
        // We use the third warp, as warp 0/1 may be doing WGMMA with `BLOCK_M == 32`
        if (warp_idx == kNumMathThreads / 32 + 2 and cute::elect_one_sync()) {
            // Persistently schedule over blocks
            while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
                // Assign TMA multicast number into A and B
                // NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
                const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
                const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
                const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
                DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");

                for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
                    // Wait consumer release
                    // (producer waits on the *opposite* parity of the consumers' arrivals)
                    empty_barriers[stage_idx]->wait(phase ^ 1);

                    // Issue TMA A
                    constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
                    const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);

                    constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
                    auto& full_barrier = *full_barriers[stage_idx];
                    const uint32_t k_idx = k_block_idx * BLOCK_K;
                    tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_a, &full_barrier,
                        smem_a[stage_idx], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
                        num_tma_multicast_a, batch_idx);
                    // A scale factors: one float per row of the M tile, for this k step
                    tma_copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier,
                        smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx<kWithGroupOffsetA, IndexType::SF_K>(shape_k_scales, 1, k_block_idx),
                        num_tma_multicast_a);

                    // Issue TMA B
                    tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_b, &full_barrier,
                        smem_b[stage_idx], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
                        num_tma_multicast_b, batch_idx);
                    // Arm the barrier with the exact byte count of all three copies above
                    full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
                }
            }

            // To safely deconstruct distributed shared barriers, we need another round of empty waits
            if constexpr (kNumTMAMulticast > 1) {
                for (uint32_t i = 0; i < kNumStages; advance_pipeline(i))
                    empty_barriers[stage_idx]->wait(phase ^ 1);
            }
        }
    } else {
        // Math warp-groups for WGMMA
        cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();

        // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
        const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
        // r_0/r_1: the two accumulator row indices this thread owns in the WGMMA fragment layout
        const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;

        // Shared-memory matrix descriptors for stage 0; per-stage variants are derived below
        // by patching only the low 32 bits (the address field).
        auto a_desc = make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1);
        auto b_desc = make_smem_desc(smem_b[0], 1);
        const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
        const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);

        // Persistently schedule over blocks
        while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
            // Decide the number of scales B to load
            DG_TRAP_ONLY_DEVICE_ASSERT(shape_n % 8 == 0);
            uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
            if constexpr (not kMustUseUniformedScaleB) {
                // `num_former_iters`: 8-column groups covered by the first B-scale column;
                // the remainder of the block (if any) uses the second scale column.
                num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8;
                num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N) / 8;
            }
            uint32_t num_sfb = shape_k_scales * (num_former_iters >= num_full_iters ? 1 : 2);

            // Load B scales with math warp-groups
            // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
            if (threadIdx.x >= 32) {
                auto previous_group_offset = scheduler.template get_global_idx<true, IndexType::SF_K>(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx);
                const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales;
                const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1;
                auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb;

                // First `shape_k_scales` entries = first scale column; the rest = second column
                #pragma unroll
                for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32)
                    st_shared(smem_sfb + i, __ldg(i < shape_k_scales ? local_sfb + i * stride_k_sfb : local_sfb + (i - shape_k_scales) * stride_k_sfb + stride_n_sfb));
            }
            cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);

            // Accumulation for WGMMA or CUDA promotion
            // A "wave" covers WAVE_BLOCK_M rows; BLOCK_M / WAVE_BLOCK_M waves reuse the same WGMMA accumulators.
            constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? BLOCK_M : WGMMA::M * 2;
            DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes");
            float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};

            // Pick threads whose WGMMA results are to be stored in shared memory
            DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for `BLOCK_M < 64`");
            constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M);
            const bool do_wgmma_store = BLOCK_M >= WGMMA::M or warp_idx < kNumWGMMAStoreThreads / 32;

            // Empty barrier arrival
            auto empty_barrier_arrive = [&]() {
                if constexpr (kNumTMAMulticast == 1) {
                    lane_idx == 0 ? empty_barriers[stage_idx]->arrive() : void();
                } else {
                    // With multicast, each of the first `kNumTMAMulticast` lanes arrives at a peer CTA's barrier
                    auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
                    lane_idx < kNumTMAMulticast ? empty_barriers[stage_idx]->arrive(target_cta) : void();
                }
            };

            // Skip useless computations
            if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) {
                // The compiler must know the dynamic variable `num_former_iters`'s real value
                constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
                constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
                constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;

                // Dispatch `num_former_iters` and launch MMAs
                dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) {
                    #pragma unroll 8
                    for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
                        // Patch the descriptor low word to point at the current stage's buffers
                        // (descriptor addresses are encoded in 16-byte units)
                        const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
                        const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);

                        // Read B scales
                        float scale_b_0 = ld_shared(smem_sfb + k_block_idx), scale_b_1;
                        // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
                        if constexpr (not kMustUseUniformedScaleB)
                            scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales);

                        // Wait TMA arrivals
                        full_barriers[stage_idx]->wait(phase);

                        // TODO: remove some useless computation for unaligned Ms
                        #pragma unroll
                        for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
                            auto m_offset = local_idx * WAVE_BLOCK_M;

                            // Read A scales
                            // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
                            auto scale_a_0 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0;
                            auto scale_a_1 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0;

                            // Commit WGMMA instructions
                            #pragma unroll
                            for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                                warpgroup_fence_operand(accum[i]);
                            warpgroup_arrive();
                            #pragma unroll
                            for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
                                a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16;
                                b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16;
                                WGMMA::wgmma(a_desc, b_desc, accum, k);
                            }
                            warpgroup_commit_batch();
                            #pragma unroll
                            for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                                warpgroup_fence_operand(accum[i]);
                            warpgroup_wait<0>();

                            // Notify barrier arrival at the last warpgroup wave
                            if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
                                empty_barrier_arrive();

                            // Skip promotion for the unfilled parts
                            if (not do_wgmma_store)
                                continue;

                            // Promote with scales
                            // NOTES: making it as predicates is very important for performance, comparing to two loops
                            float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
                            float scale_0_1, scale_1_1;
                            if constexpr (not kMustUseUniformedScaleB)
                                scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;

                            auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
                            #pragma unroll
                            for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
                                // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
                                const bool& predicate = kMustUseUniformedScaleB or i < num_former_iters;
                                shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
                                shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
                                shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
                                shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
                            }
                        }
                    }
                });
            } else {
                // Invalid block for this warp-group: still drain the pipeline so barrier
                // parity stays consistent with the producer.
                #pragma unroll
                for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
                    full_barriers[stage_idx]->wait(phase);
                    empty_barrier_arrive();
                }
            }

            // TMA checks
            constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
            constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes);
            constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
            DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom");
            DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
                             "Unaligned TMA store or too many TMA store instructions");
            DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");

            // Skip WGMMA store for the unfilled parts
            if (not do_wgmma_store)
                continue;

            // Wait last TMA store to be finished
            if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)
                cute::tma_store_wait<0>();
            cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1);

            // Write back to shared memory using STSM and issue TMA stores
            DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
            #pragma unroll
            for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
                auto m_offset = local_idx * WAVE_BLOCK_M;
                auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
                #pragma unroll
                for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
                    // Swizzle or padding into the correct address
                    uint8_t* smem_ptr = nullptr;
                    if constexpr (kSwizzleDMode > 0) {
                        // Calculate the swizzling atom offset and in-atom offset
                        constexpr uint32_t kNumBankGroupBytes = 16;
                        auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);

                        // Calculate the index of the bank group to be written in the atom
                        auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);

                        // Reshape the atom in another view and swizzle
                        //  - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)`
                        //  - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
                        constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8;
                        auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8);
                        auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8);
                        col ^= row % (kSwizzleDMode / 16);

                        // Add back into the base pointer
                        // NOTES: think twice before modifying this, as changes may affect the number of instructions
                        smem_ptr = reinterpret_cast<uint8_t*>(smem_d) +            // Base pointer
                                   warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset
                                   m_offset * kSwizzleDMode +                      // Wave offset
                                   atom_offset * BLOCK_M * kSwizzleDMode +         // Swizzle atom offset (constants)
                                   row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
                    } else {
                        // No swizzling, just padding
                        smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8);
                    }

                    // NOTES: only 16 lanes' addresses are used
                    SM90_U32x2_STSM_N<nv_bfloat162>::copy(
                        __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
                        __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
                        smem_ptr
                    );
                }
            }
            // Order STSM writes before the TMA store engine reads them
            cute::tma_store_fence();
            cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1);

            // Use TMA store to write back to global memory
            // TODO: compatible with FP32 output
            constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked;
            DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
            if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
                auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
                auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
                auto n_idx = epilogue_type_t::apply_index_n<TMA_D_BLOCK_N>(n_block_idx * BLOCK_N + in_block_n_offset);
                auto m_idx = scheduler.get_global_idx<kWithGroupOffsetD>(shape_m, BLOCK_M, m_block_idx);
                if constexpr (kGemmType == GemmType::Batched) {
                    cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_ptr,
                                                  n_idx, m_idx, scheduler.current_group_idx);
                } else {
                    cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, n_idx, m_idx);
                }
                cute::tma_store_arrive();
            }
            __syncwarp();
        }
    }
#else
    if (blockIdx.x == 0 and threadIdx.x == 0)
        DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
|
| 437 |
+
|
| 438 |
+
}; // namespace deep_gemm
|
| 439 |
+
|
| 440 |
+
#pragma clang diagnostic pop
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cutlass/arch/barrier.h>
|
| 4 |
+
#include <cutlass/arch/reg_reconfig.h>
|
| 5 |
+
|
| 6 |
+
#include <cute/arch/cluster_sm90.hpp>
|
| 7 |
+
#include <cute/arch/copy_sm90_desc.hpp>
|
| 8 |
+
#include <cute/arch/mma_sm90_desc.hpp>
|
| 9 |
+
|
| 10 |
+
#include <deep_gemm/common/utils.cuh>
|
| 11 |
+
#include <deep_gemm/common/sm90_utils.cuh>
|
| 12 |
+
|
| 13 |
+
namespace deep_gemm {
|
| 14 |
+
|
| 15 |
+
using namespace deep_gemm::sm90;
|
| 16 |
+
|
| 17 |
+
// ReSharper disable once CppNotAllPathsReturnValue
|
| 18 |
+
template <uint32_t kHeadDim>
|
| 19 |
+
static constexpr int to_swizzle_cute_type() {
|
| 20 |
+
DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling");
|
| 21 |
+
if constexpr (kHeadDim == 32)
|
| 22 |
+
return static_cast<int>(cute::SM90::GMMA::LayoutType::B32);
|
| 23 |
+
if constexpr (kHeadDim == 64)
|
| 24 |
+
return static_cast<int>(cute::SM90::GMMA::LayoutType::B64);
|
| 25 |
+
if constexpr (kHeadDim == 128)
|
| 26 |
+
return static_cast<int>(cute::SM90::GMMA::LayoutType::B128);
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
template <uint32_t kNumHeads, uint32_t kHeadDim,
|
| 30 |
+
bool kIsCompressedLogits,
|
| 31 |
+
uint32_t BLOCK_Q, uint32_t BLOCK_KV,
|
| 32 |
+
uint32_t kNumQStages, uint32_t kNumKVStages,
|
| 33 |
+
uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
|
| 34 |
+
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
|
| 35 |
+
void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
| 36 |
+
const uint32_t max_seqlen_k, const uint64_t stride_logits,
|
| 37 |
+
uint32_t* cu_seq_len_k_start,
|
| 38 |
+
uint32_t* cu_seq_len_k_end,
|
| 39 |
+
float* logits,
|
| 40 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
| 41 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
| 42 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
| 43 |
+
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
|
| 44 |
+
// TODO: consider TMA multicast
|
| 45 |
+
// For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
|
| 46 |
+
// Q should be load only at once for a block
|
| 47 |
+
const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q);
|
| 48 |
+
|
| 49 |
+
// Types
|
| 50 |
+
using WGMMA = typename FP8MMASelector<BLOCK_Q * kNumHeads>::type;
|
| 51 |
+
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 52 |
+
|
| 53 |
+
// Prefetch TMA descriptors
|
| 54 |
+
DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
|
| 55 |
+
if (threadIdx.x / 32 == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
| 56 |
+
cute::prefetch_tma_descriptor(&tensor_map_q);
|
| 57 |
+
cute::prefetch_tma_descriptor(&tensor_map_kv);
|
| 58 |
+
cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
|
| 59 |
+
cute::prefetch_tma_descriptor(&tensor_map_weights);
|
| 60 |
+
}
|
| 61 |
+
__syncwarp();
|
| 62 |
+
|
| 63 |
+
// Shared memory configs
|
| 64 |
+
// NOTES: weight may be unaligned
|
| 65 |
+
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
|
| 66 |
+
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
|
| 67 |
+
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float);
|
| 68 |
+
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
|
| 69 |
+
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
|
| 70 |
+
|
| 71 |
+
// Align to swizzling alignment bytes
|
| 72 |
+
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
|
| 73 |
+
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
| 74 |
+
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
| 75 |
+
|
| 76 |
+
// Data on shared memory
|
| 77 |
+
auto smem_q = PatternVisitor([&](const uint32_t& i) {
|
| 78 |
+
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer +
|
| 79 |
+
SMEM_Q_SIZE_PER_STAGE * i);
|
| 80 |
+
});
|
| 81 |
+
auto smem_kv = PatternVisitor([&](const uint32_t& i) {
|
| 82 |
+
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
|
| 83 |
+
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
|
| 84 |
+
});
|
| 85 |
+
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
|
| 86 |
+
return reinterpret_cast<float*>(smem_buffer +
|
| 87 |
+
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
| 88 |
+
});
|
| 89 |
+
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
|
| 90 |
+
return reinterpret_cast<float*>(smem_buffer +
|
| 91 |
+
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages +
|
| 92 |
+
SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
|
| 93 |
+
});
|
| 94 |
+
|
| 95 |
+
// TMA barriers
|
| 96 |
+
auto barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
|
| 97 |
+
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
|
| 98 |
+
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
|
| 99 |
+
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
|
| 100 |
+
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
|
| 101 |
+
|
| 102 |
+
// Initialize barriers
|
| 103 |
+
const bool& is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32;
|
| 104 |
+
if (is_tma_load_warp and cute::elect_one_sync()) {
|
| 105 |
+
#pragma unroll
|
| 106 |
+
for (uint32_t i = 0; i < kNumQStages; ++ i) {
|
| 107 |
+
full_q_barriers[i]->init(1);
|
| 108 |
+
empty_q_barriers[i]->init(kNumMathThreads);
|
| 109 |
+
}
|
| 110 |
+
#pragma unroll
|
| 111 |
+
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
|
| 112 |
+
full_kv_barriers[i]->init(1);
|
| 113 |
+
empty_kv_barriers[i]->init(kNumMathThreads);
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
// Make initialized barrier visible in async proxy
|
| 117 |
+
cutlass::arch::fence_barrier_init();
|
| 118 |
+
}
|
| 119 |
+
__syncthreads();
|
| 120 |
+
|
| 121 |
+
// Register reconfigurations
|
| 122 |
+
constexpr uint32_t kNumTMARegisters = 32;
|
| 123 |
+
constexpr uint32_t kNumMathRegisters = 112;
|
| 124 |
+
|
| 125 |
+
// Block scheduler
|
| 126 |
+
uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0;
|
| 127 |
+
const auto& get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
|
| 128 |
+
return {block_q_idx + gridDim.x, q_iter_idx + 1};
|
| 129 |
+
};
|
| 130 |
+
uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
|
| 131 |
+
const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
|
| 132 |
+
uint32_t start = cute::numeric_limits<uint32_t>::max();
|
| 133 |
+
uint32_t end = cute::numeric_limits<uint32_t>::min();
|
| 134 |
+
|
| 135 |
+
#pragma unroll
|
| 136 |
+
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 137 |
+
const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
|
| 138 |
+
seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx);
|
| 139 |
+
seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx);
|
| 140 |
+
start = min(start, min(seq_k_start[i], seq_len_kv));
|
| 141 |
+
end = max(end, min(seq_k_end[i], seq_len_kv));
|
| 142 |
+
}
|
| 143 |
+
start = start / 4 * 4;
|
| 144 |
+
return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
|
| 145 |
+
((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
|
| 146 |
+
start, ceil_div(end - start, BLOCK_KV)}; // Task info
|
| 147 |
+
};
|
| 148 |
+
|
| 149 |
+
// KV pipeline
|
| 150 |
+
uint32_t num_total_kv_blocks = 0;
|
| 151 |
+
const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple<uint32_t, uint32_t> {
|
| 152 |
+
return {
|
| 153 |
+
(num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage
|
| 154 |
+
((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
|
| 155 |
+
};
|
| 156 |
+
};
|
| 157 |
+
|
| 158 |
+
if (threadIdx.x >= kNumMathThreads) {
|
| 159 |
+
// TMA warp-group for loading data
|
| 160 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
| 161 |
+
|
| 162 |
+
// Only the first warp remains
|
| 163 |
+
if (not is_tma_load_warp)
|
| 164 |
+
return;
|
| 165 |
+
|
| 166 |
+
// Prefetch
|
| 167 |
+
const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
|
| 168 |
+
tma_copy<kHeadDim, BLOCK_Q * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
|
| 169 |
+
tma_copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
|
| 170 |
+
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
|
| 171 |
+
};
|
| 172 |
+
if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
|
| 173 |
+
issue_tma_q(0, block_q_idx);
|
| 174 |
+
|
| 175 |
+
// Only the first lane persistently schedules over blocks
|
| 176 |
+
if (cute::elect_one_sync()) {
|
| 177 |
+
while (block_q_idx < num_q_blocks) {
|
| 178 |
+
CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks);
|
| 179 |
+
|
| 180 |
+
// Wait Q consumer release
|
| 181 |
+
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
|
| 182 |
+
|
| 183 |
+
// Issue TMA Q
|
| 184 |
+
if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks)
|
| 185 |
+
issue_tma_q(q_stage_idx, next_block_q_idx);
|
| 186 |
+
|
| 187 |
+
// Issue TMA KV
|
| 188 |
+
#pragma unroll
|
| 189 |
+
for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
|
| 190 |
+
// Wait consumer release
|
| 191 |
+
CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
|
| 192 |
+
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
|
| 193 |
+
|
| 194 |
+
// Issue TMA KV
|
| 195 |
+
tma_copy<kHeadDim, BLOCK_KV, kHeadDim>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
|
| 196 |
+
smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
|
| 197 |
+
tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
|
| 198 |
+
smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
|
| 199 |
+
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
|
| 200 |
+
}
|
| 201 |
+
num_total_kv_blocks += num_kv_blocks;
|
| 202 |
+
|
| 203 |
+
// Jump to the next block
|
| 204 |
+
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
|
| 205 |
+
}
|
| 206 |
+
}
|
| 207 |
+
} else {
|
| 208 |
+
// Math warp-groups for WGMMA
|
| 209 |
+
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
| 210 |
+
|
| 211 |
+
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
| 212 |
+
const auto& thread_idx = threadIdx.x % kNumMathThreads;
|
| 213 |
+
const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0);
|
| 214 |
+
const auto& warpgroup_idx = warp_idx / 4;
|
| 215 |
+
const auto& lane_idx = get_lane_idx();
|
| 216 |
+
float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4];
|
| 217 |
+
|
| 218 |
+
const auto& warp_offset = warp_idx * 16;
|
| 219 |
+
const auto& v_0_offset = lane_idx / 4 + 0;
|
| 220 |
+
const auto& v_1_offset = lane_idx / 4 + 8;
|
| 221 |
+
|
| 222 |
+
while (block_q_idx < num_q_blocks) {
|
| 223 |
+
CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
|
| 224 |
+
|
| 225 |
+
// Wait TMA Q arrival
|
| 226 |
+
full_q_barriers[q_stage_idx]->wait(q_phase);
|
| 227 |
+
|
| 228 |
+
// Read weights
|
| 229 |
+
#pragma unroll
|
| 230 |
+
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 231 |
+
#pragma unroll
|
| 232 |
+
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
|
| 233 |
+
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
// Compute over KV blocks
|
| 237 |
+
#pragma unroll
|
| 238 |
+
for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
|
| 239 |
+
// Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]`
|
| 240 |
+
// Wait TMA KV arrival
|
| 241 |
+
CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
|
| 242 |
+
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
| 243 |
+
|
| 244 |
+
// Read per-KV scales
|
| 245 |
+
float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset);
|
| 246 |
+
float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset);
|
| 247 |
+
|
| 248 |
+
// Issue WGMMA
|
| 249 |
+
DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size");
|
| 250 |
+
DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
|
| 251 |
+
#pragma unroll
|
| 252 |
+
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 253 |
+
warpgroup_fence_operand(accum[i]);
|
| 254 |
+
warpgroup_arrive();
|
| 255 |
+
#pragma unroll
|
| 256 |
+
for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
|
| 257 |
+
auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K,
|
| 258 |
+
to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
|
| 259 |
+
auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K,
|
| 260 |
+
to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
|
| 261 |
+
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
| 262 |
+
}
|
| 263 |
+
warpgroup_commit_batch();
|
| 264 |
+
#pragma unroll
|
| 265 |
+
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 266 |
+
warpgroup_fence_operand(accum[i]);
|
| 267 |
+
warpgroup_wait<0>();
|
| 268 |
+
|
| 269 |
+
// Release KV empty
|
| 270 |
+
empty_kv_barriers[kv_stage_idx]->arrive();
|
| 271 |
+
|
| 272 |
+
// Reduce over the head dim and store
|
| 273 |
+
const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset;
|
| 274 |
+
static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
|
| 275 |
+
DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation");
|
| 276 |
+
DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == BLOCK_Q, "Invalid accumulation");
|
| 277 |
+
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
|
| 278 |
+
#pragma unroll
|
| 279 |
+
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
| 280 |
+
auto shifted_accum = accum + i * kNumAccumPerReduce;
|
| 281 |
+
const auto& transform = [&](const uint32_t& j) {
|
| 282 |
+
return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
|
| 283 |
+
};
|
| 284 |
+
|
| 285 |
+
// Intra-thread reduction
|
| 286 |
+
float sum[4] = {transform(0), transform(1), transform(2), transform(3)};
|
| 287 |
+
#pragma unroll
|
| 288 |
+
for (uint32_t j = 1; j < kNumHeads / 8; ++ j) {
|
| 289 |
+
#pragma unroll
|
| 290 |
+
for (uint32_t k = 0; k < 4; k ++)
|
| 291 |
+
sum[k] += transform(j * 4 + k);
|
| 292 |
+
}
|
| 293 |
+
float v_0 = (sum[0] + sum[1]) * scale_kv_0;
|
| 294 |
+
float v_1 = (sum[2] + sum[3]) * scale_kv_1;
|
| 295 |
+
|
| 296 |
+
// Inter-thread reduction
|
| 297 |
+
#pragma unroll
|
| 298 |
+
for (uint32_t j = 0; j < 2; ++ j) {
|
| 299 |
+
const auto& offset = static_cast<int>(1u << j);
|
| 300 |
+
v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset);
|
| 301 |
+
v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset);
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
// Store into the global memory
|
| 305 |
+
// NOTES: we have redundant writes here, consider more carefully
|
| 306 |
+
const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
|
| 307 |
+
if constexpr (kIsCompressedLogits) {
|
| 308 |
+
if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i])
|
| 309 |
+
logits[q_idx * stride_logits + kv_offset + v_0_offset - seq_k_start[i]] = v_0;
|
| 310 |
+
if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i])
|
| 311 |
+
logits[q_idx * stride_logits + kv_offset + v_1_offset - seq_k_start[i]] = v_1;
|
| 312 |
+
} else {
|
| 313 |
+
logits[q_idx * stride_logits + kv_offset + v_0_offset] = v_0;
|
| 314 |
+
logits[q_idx * stride_logits + kv_offset + v_1_offset] = v_1;
|
| 315 |
+
}
|
| 316 |
+
}
|
| 317 |
+
}
|
| 318 |
+
num_total_kv_blocks += num_kv_blocks;
|
| 319 |
+
|
| 320 |
+
// Release Q empty
|
| 321 |
+
empty_q_barriers[q_stage_idx]->arrive();
|
| 322 |
+
|
| 323 |
+
// Jump to the next block
|
| 324 |
+
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
|
| 325 |
+
}
|
| 326 |
+
}
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
} // namespace deep_gemm
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cutlass/arch/barrier.h>
|
| 4 |
+
#include <cutlass/arch/reg_reconfig.h>
|
| 5 |
+
|
| 6 |
+
#include <cute/arch/cluster_sm90.hpp>
|
| 7 |
+
#include <cute/arch/copy_sm90_desc.hpp>
|
| 8 |
+
|
| 9 |
+
#include <deep_gemm/common/utils.cuh>
|
| 10 |
+
#include <deep_gemm/common/sm90_utils.cuh>
|
| 11 |
+
#include <deep_gemm/impls/sm90_fp8_mqa_logits.cuh>
|
| 12 |
+
|
| 13 |
+
namespace deep_gemm {
|
| 14 |
+
|
| 15 |
+
// Scheduling-metadata kernel, launched with a single warp.
// Each query's context is chopped into `ceil(context_len / SPLIT_KV)` KV segments;
// the total segment count is divided as evenly as possible across `kNumSMs` SMs.
// For every SM the kernel writes a `(q_idx, kv_split_idx)` pair locating the SM's
// first segment, plus one extra sentinel pair at `sm_idx == kNumSMs`, so that a
// consumer can read `[own pair, next SM's pair)` as its half-open assignment.
template <uint32_t kAlignedBatchSize, uint32_t SPLIT_KV, uint32_t kNumSMs>
__global__ __launch_bounds__(32, 1)
void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d,
                                    const uint32_t* context_lens, uint32_t* schedule_metadata) {
    DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size");
    const uint32_t lane_idx = get_lane_idx();

    // Per-lane segment counts; lanes mapped past `batch_size` contribute 0
    uint32_t num_segs[kAlignedBatchSize / 32];
    #pragma unroll
    for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
        const uint32_t q_idx = k * 32 + lane_idx;
        // For 2D context lengths, use the last entry of the query's `next_n` row
        const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx);
        const uint32_t& context_len = (q_idx < batch_size ? __ldg(context_lens + lens_idx) : 0);
        num_segs[k] = ceil_div(context_len, SPLIT_KV);
    }

    // Inclusive prefix sum of segment counts over the aligned batch,
    // computed 32 entries at a time with a warp shuffle scan
    __shared__ uint32_t prefix_sum[kAlignedBatchSize];
    uint32_t sum = 0;
    #pragma unroll
    for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
        uint32_t x = num_segs[k];
        #pragma unroll
        for (uint32_t offset = 1; offset < 32; offset <<= 1) {
            const uint32_t& y = __shfl_up_sync(0xffffffff, x, offset);
            x += (lane_idx >= offset ? y : 0);
        }
        x += sum;
        prefix_sum[k * 32 + lane_idx] = x;
        // Carry the last lane's running total into the next 32-wide chunk
        sum = __shfl_sync(0xffffffff, x, 31);
    }

    // `sum` segments over `kNumSMs` SMs: the first `r` SMs take `q + 1` segments each
    const uint32_t& q = sum / kNumSMs, r = sum % kNumSMs;
    // NOTE: `sm_idx <= kNumSMs` (inclusive) intentionally emits the end sentinel pair
    for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
        // Global index of this SM's first segment
        uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
        // Locate `seg_starts` in the prefix sum: owning query and split offset within it
        // NOTE(review): `prefix_sum` entries written by other lanes are read here before
        // the `__syncwarp` below — presumably safe after the full-mask shuffle scan; confirm
        uint32_t q_idx = 0;
        while (q_idx < batch_size and prefix_sum[q_idx] <= seg_starts)
            ++ q_idx;
        const uint32_t& kv_split_idx = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1]);
        __syncwarp();

        // Packed output: `(q_idx, kv_split_idx)` per SM slot
        schedule_metadata[sm_idx * 2] = q_idx;
        schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
    }
}
|
| 59 |
+
|
| 60 |
+
template <uint32_t kNextN, bool kIsContextLens2D,
|
| 61 |
+
uint32_t BLOCK_KV, uint32_t kNumBlocksPerSplit>
|
| 62 |
+
struct PagedMQALogitsScheduler {
|
| 63 |
+
uint32_t batch_size;
|
| 64 |
+
const uint32_t* context_lens;
|
| 65 |
+
|
| 66 |
+
uint32_t current_q_idx, current_kv_idx;
|
| 67 |
+
uint32_t end_q_idx, end_kv_idx;
|
| 68 |
+
uint32_t current_num_kv;
|
| 69 |
+
|
| 70 |
+
__device__ __forceinline__ uint32_t get_num_kv(const uint32_t& q_idx) {
|
| 71 |
+
const auto& lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx);
|
| 72 |
+
return q_idx < batch_size ? ceil_div(__ldg(context_lens + lens_idx), BLOCK_KV) : 0;
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
__device__ __forceinline__ explicit PagedMQALogitsScheduler(const uint32_t& batch_size, const uint32_t& sm_idx,
|
| 76 |
+
const uint32_t* context_lens, const uint32_t* schedule_meta) {
|
| 77 |
+
this->batch_size = batch_size;
|
| 78 |
+
this->context_lens = context_lens;
|
| 79 |
+
|
| 80 |
+
const auto& current_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx);
|
| 81 |
+
const auto& end_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx + 1);
|
| 82 |
+
current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit;
|
| 83 |
+
end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit;
|
| 84 |
+
|
| 85 |
+
current_num_kv = get_num_kv(current_q_idx);
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
__device__ __forceinline__ bool fetch_next_task(uint32_t &q_idx, uint32_t &kv_idx, uint32_t &num_kv) {
|
| 89 |
+
q_idx = current_q_idx;
|
| 90 |
+
kv_idx = current_kv_idx;
|
| 91 |
+
num_kv = current_num_kv;
|
| 92 |
+
|
| 93 |
+
if (q_idx == end_q_idx and kv_idx == end_kv_idx)
|
| 94 |
+
return false;
|
| 95 |
+
|
| 96 |
+
current_kv_idx += kNumBlocksPerSplit;
|
| 97 |
+
if (current_kv_idx >= current_num_kv) {
|
| 98 |
+
++ current_q_idx;
|
| 99 |
+
current_kv_idx = 0;
|
| 100 |
+
current_num_kv = get_num_kv(current_q_idx);
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
return true;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
__device__ __forceinline__ bool exist_q_idx(const uint32_t& q_idx) const {
|
| 107 |
+
return q_idx < end_q_idx or q_idx == end_q_idx and 0 < end_kv_idx;
|
| 108 |
+
}
|
| 109 |
+
};
|
| 110 |
+
|
| 111 |
+
using namespace deep_gemm::sm90;
|
| 112 |
+
|
| 113 |
+
// FP8 paged-KV MQA logits kernel for SM90 (Hopper), warp-specialized:
// one TMA warp group streams Q/weights plus paged KV (and per-KV scales) into
// multi-stage shared-memory pipelines; the math warp groups run FP8 WGMMAs,
// apply ReLU + head weights, reduce over the head dimension, and store the
// `[kNextN, context]` logits. Work distribution comes from `schedule_meta`
// (produced by `smxx_paged_mqa_logits_metadata`) via `PagedMQALogitsScheduler`.
template <uint32_t kNextN, uint32_t kNumHeads,
          uint32_t kHeadDim, uint32_t BLOCK_KV,
          bool kIsContextLens2D,
          uint32_t kNumQStages, uint32_t kNumKVStages,
          uint32_t SPLIT_KV,
          uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
                               const uint64_t logits_stride, const uint64_t block_table_stride,
                               const uint32_t* context_lens, float* logits,
                               const uint32_t* block_table, const uint32_t* schedule_meta,
                               const __grid_constant__ cute::TmaDescriptor tensor_map_q,
                               const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
                               const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
                               const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
    // Types
    using WGMMA = typename FP8MMASelector<kNextN * kNumHeads>::type;
    using Barrier = cutlass::arch::ClusterTransactionBarrier;

    // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
    const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    const auto& warpgroup_idx = warp_idx / 4;
    const auto& lane_idx = get_lane_idx();

    // Prefetch TMA descriptors (first TMA warp only)
    static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128;
    DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
    DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
    if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
        cute::prefetch_tma_descriptor(&tensor_map_q);
        cute::prefetch_tma_descriptor(&tensor_map_kv);
        cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
        cute::prefetch_tma_descriptor(&tensor_map_weights);
    }
    __syncwarp();

    // Shared memory configs
    // Each pipeline stores its stages contiguously, then the (aligned) per-stage
    // extras, then space for 2 * num_stages 8-byte barriers (full + empty)
    static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
    static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
    static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
    static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment);
    static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) +
                                                 constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment);

    static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
    static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
    static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment);
    static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) +
                                                  constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment);

    // Align to swizzling alignment bytes
    extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
    DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
    DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");

    // Q data and barriers on shared memory (one Q pipeline shared by all groups)
    auto smem_q = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
    });
    auto smem_weights = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<float*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i);
    });
    // Barrier storage starts right after the last weight stage
    auto q_barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
    auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; });
    auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); });

    // Separate math warpgroups and tma load warps into KV groups
    // Each math warpgroup corresponds to a tma load warp
    const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0);

    // Per group KV data and barriers on shared memory (one KV pipeline per group)
    const auto& smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx;
    auto smem_kv = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i);
    });
    auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
    });
    auto kv_barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
    auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; });
    auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; });

    // Initialize barriers: one elected lane per TMA warp; Q barriers are shared
    // so only group 0 initializes them
    if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) {
        if (kv_group_idx == 0) {
            #pragma unroll
            for (uint32_t i = 0; i < kNumQStages; ++ i) {
                full_q_barriers[i]->init(1);
                // Every math thread arrives to release a Q stage
                empty_q_barriers[i]->init(kNumMathThreads);
            }
        }
        if (kv_group_idx < kNumMathWarpGroups) {
            #pragma unroll
            for (uint32_t i = 0; i < kNumKVStages; ++ i) {
                full_kv_barriers[i]->init(1);
                // Only the owning 128-thread math warpgroup releases a KV stage
                empty_kv_barriers[i]->init(128);
            }
        }

        // Make initialized barrier visible in async proxy
        cutlass::arch::fence_barrier_init();
    }
    __syncthreads();

    // Register reconfigurations
    constexpr uint32_t kNumTMARegisters = 64;
    constexpr uint32_t kNumMathRegisters = 104;

    // Scheduler
    auto scheduler = PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumMathWarpGroups>(batch_size, blockIdx.x, context_lens, schedule_meta);
    DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");

    // Q and KV pipeline
    const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
        return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase
    };
    const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
        return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase
    };
    uint32_t q_iter_idx = 0, kv_iter_idx = 0;

    if (warp_idx >= kNumMathThreads / 32) {
        // TMA warp-group for loading data
        cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
        // Warps beyond one-per-math-warpgroup have nothing to load
        if (kv_group_idx >= kNumMathWarpGroups)
            return;

        // Q and weights are loaded once per query by KV group 0 only
        const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
            if (kv_group_idx == 0 and cute::elect_one_sync()) {
                tma_copy<kHeadDim, kNextN * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
                tma_copy<kNextN * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx);
                full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
            }
        };

        // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
        uint32_t q_idx = batch_size, kv_idx, num_kv;
        uint32_t next_q_idx, next_kv_idx, next_num_kv;
        bool fetched_next_task;

        // Prefetch the first Q
        if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)))
            issue_tma_q(0, next_q_idx), q_iter_idx = 1;

        // Lane-cached window of block-table entries; `32` means "cache empty, refill"
        int kv_block_idx_ptr = 32;
        uint32_t kv_block_idx_storage;

        while (fetched_next_task) {
            // Prefetch next Q when current Q changes
            bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1));
            q_idx = next_q_idx;
            kv_idx = next_kv_idx;
            num_kv = next_num_kv;

            // Wait Q consumer release and issue TMA Q
            if (prefetch_q) {
                CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
                empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
                issue_tma_q(q_stage_idx, q_idx + 1);
            }

            // Read KV block index: refill the 32-entry per-lane cache (strided by
            // `kNumMathWarpGroups` so each group reads its own blocks), then pull
            // entries out one per iteration via shuffle
            // TODO: deal with `-1`?
            if (kv_idx == 0 or kv_block_idx_ptr == 32) {
                kv_block_idx_ptr = 0;
                kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ?
                                        __ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0);
            }
            const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++);

            // Wait KV consumer release
            CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
            empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);

            // Issue TMA KV
            if (cute::elect_one_sync()) {
                tma_copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
                                                                     smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx);
                tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
                                         smem_kv_scales[kv_stage_idx], 0, kv_block_idx);
                full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
            }

            // Fetch next task
            fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv);
        }
    } else {
        // Math warp-groups for WGMMA
        cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();

        float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4];
        // Each of the 4 warps in the warpgroup owns 16 consecutive KV rows
        const auto& sub_warp_offset = (warp_idx % 4) * 16;
        const auto& v_0_offset = lane_idx / 4 + 0;
        const auto& v_1_offset = lane_idx / 4 + 8;

        // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
        uint32_t q_idx = batch_size, kv_idx;
        uint32_t next_q_idx, next_kv_idx, next_num_kv;
        uint32_t q_stage_idx, q_phase;

        while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
            // Current Q changes
            if (q_idx != next_q_idx) {
                // Release Last Q empty
                if (q_iter_idx > 0)
                    empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive();

                // Wait TMA Q arrival
                CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
                full_q_barriers[q_stage_idx]->wait(q_phase);

                // Read weights (layout mirrors the WGMMA accumulator ownership per lane)
                #pragma unroll
                for (uint32_t i = 0; i < kNextN; ++ i) {
                    #pragma unroll
                    for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
                        weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
                }
            }

            // Get current Q and KV index
            q_idx = next_q_idx;
            kv_idx = next_kv_idx;

            // Calculate KV offset in advance
            auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset);

            // Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]`
            // Wait TMA KV arrival
            CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
            full_kv_barriers[kv_stage_idx]->wait(kv_phase);

            // Issue WGMMA (KV as operand A, Q as operand B)
            DG_STATIC_ASSERT(BLOCK_KV == 64, "Invalid block size");
            DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
            #pragma unroll
            for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                warpgroup_fence_operand(accum[i]);
            warpgroup_arrive();
            #pragma unroll
            for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
                auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + k * WGMMA::K, to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
                auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
                WGMMA::wgmma(desc_a, desc_b, accum, k);
            }
            warpgroup_commit_batch();
            #pragma unroll
            for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                warpgroup_fence_operand(accum[i]);

            // Read per-KV scales (overlapped with the in-flight WGMMA)
            float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset);
            float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset);

            // Wait WGMMA
            warpgroup_wait<0>();

            // Release KV empty
            empty_kv_barriers[kv_stage_idx]->arrive();

            // Reduce over the head dim and store
            static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
            DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation");
            DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == kNextN, "Invalid accumulation");
            DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
            #pragma unroll
            for (uint32_t i = 0; i < kNextN; ++ i) {
                auto shifted_accum = accum + i * kNumAccumPerReduce;
                // ReLU then apply the per-head weight for this lane's accumulator slot
                const auto& transform = [&](const uint32_t& j) {
                    return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
                };

                // Intra-thread reduction
                float sum[4] = {transform(0), transform(1), transform(2), transform(3)};
                #pragma unroll
                for (uint32_t j = 1; j < kNumHeads / 8; ++ j) {
                    #pragma unroll
                    for (uint32_t k = 0; k < 4; k ++)
                        sum[k] += transform(j * 4 + k);
                }
                float v_0 = (sum[0] + sum[1]) * scale_kv_0;
                float v_1 = (sum[2] + sum[3]) * scale_kv_1;

                // Inter-thread reduction across the 4-lane quad
                #pragma unroll
                for (uint32_t j = 0; j < 2; ++ j) {
                    const auto& offset = static_cast<int>(1u << j);
                    v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset);
                    v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset);
                }

                // Store into the global memory
                // NOTES: we have redundant writes here, consider more carefully
                logits[kv_offset + i * logits_stride + v_0_offset] = v_0;
                logits[kv_offset + i * logits_stride + v_1_offset] = v_1;
            }
        }
    }
}
|
| 412 |
+
|
| 413 |
+
} // namespace deep_gemm
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#pragma clang diagnostic push
|
| 3 |
+
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
| 4 |
+
|
| 5 |
+
#include <cutlass/arch/barrier.h>
|
| 6 |
+
#include <cutlass/arch/reg_reconfig.h>
|
| 7 |
+
|
| 8 |
+
#include <deep_gemm/common/reduction.cuh>
|
| 9 |
+
#include <deep_gemm/common/utils.cuh>
|
| 10 |
+
#include <deep_gemm/common/sm90_utils.cuh>
|
| 11 |
+
|
| 12 |
+
namespace deep_gemm {
|
| 13 |
+
|
| 14 |
+
using namespace deep_gemm::sm90;
|
| 15 |
+
|
| 16 |
+
template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
|
| 17 |
+
__device__ __forceinline__
|
| 18 |
+
uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) {
|
| 19 |
+
constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase;
|
| 20 |
+
|
| 21 |
+
const auto& bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange;
|
| 22 |
+
|
| 23 |
+
constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase;
|
| 24 |
+
constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups;
|
| 25 |
+
auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups);
|
| 26 |
+
auto col = kHasShortcut ? (offset) : (bank_group_idx % kNumBankGroups);
|
| 27 |
+
col ^= row % kGroupsInSwizzleRange;
|
| 28 |
+
|
| 29 |
+
return (row * kNumBankGroups + col) % kGroupsInSwizzleRange;
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
// SM90 split-K GEMM kernel: A is loaded as BF16 via TMA, up-converted to FP32
// in registers, and multiplied against FP32 B tiles with TF32 WGMMA (register
// operand A, shared-memory descriptor B). While converting A, each thread also
// accumulates the per-row sum of squares of A's elements, which is reduced
// across the warp and written to `sqr_sum` — presumably the "hc prenorm"
// statistic; confirm against the host-side launcher.
//
// Layout of the two warpgroups (kNumMathThreads == 128, +kNumTMAThreads):
//   - warps 0..3: math warps (WGMMA consumer)
//   - warp  4:    single-thread TMA producer
// Split-K: blockIdx.x enumerates (m_block, k_split) pairs; each split writes
// its partial D tile to a distinct slice of the 3D output tensor map and its
// partial square-sums at `sqr_sum + shape_m * k_split_idx`.
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
          uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
          uint32_t kNumSplits,
          uint32_t kSwizzleCDMode,
          uint32_t kNumStages,
          uint32_t kNumMathThreads, uint32_t kNumTMAThreads>
__global__ void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1)
sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
                               const __grid_constant__ cute::TmaDescriptor tensor_map_a,
                               const __grid_constant__ cute::TmaDescriptor tensor_map_b,
                               const __grid_constant__ cute::TmaDescriptor tensor_map_d,
                               float* sqr_sum) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
    using Barrier = cutlass::arch::ClusterTransactionBarrier;

    // kSwizzleAMode and kSwizzleBMode must be 128 for now
    constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128);
    constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128);
    DG_STATIC_ASSERT(BLOCK_K == 64, "Invalid block K");
    DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode");
    DG_STATIC_ASSERT(kSwizzleBMode == 128, "Invalid swizzle B mode");

    DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N");
    DG_STATIC_ASSERT(kNumMathThreads == 128, "Invalid MMA threads");

    // Utils
    const auto warp_idx = cutlass::canonical_warp_idx_sync();
    const auto lane_idx = get_lane_idx();

    // Align to 1024 bytes for swizzle-128B
    extern __shared__ __align__(1024) uint8_t smem_buffer[];

    // Share memory sizes: one C/D tile, then kNumStages of A and B each
    constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode;
    constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16);
    constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float);
    DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");

    // Prefetch TMA descriptors into cache (one elected thread of warp 0)
    if (warp_idx == 0 and cute::elect_one_sync()) {
        cute::prefetch_tma_descriptor(&tensor_map_a);
        cute::prefetch_tma_descriptor(&tensor_map_b);
        cute::prefetch_tma_descriptor(&tensor_map_d);
    }

    // Data on shared memory (layout as ordered below)
    // Fill D/A/B pointers
    auto smem_cd = reinterpret_cast<float*>(smem_buffer);
    auto smem_a = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
    });
    auto smem_b = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
    });

    // Fill barriers: `full` signals TMA data arrival, `empty` signals consumer release
    auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
    auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
    auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });

    // Initialize barriers: `full` is armed by the single TMA thread (count 1),
    // `empty` by all 128 math threads
    if (warp_idx == 1 and cute::elect_one_sync()) {
        #pragma unroll
        for (uint32_t i = 0; i < kNumStages; ++ i) {
            full_barriers[i]->init(1);
            empty_barriers[i]->init(128);
        }

        // Make initialized barrier visible in async proxy
        cutlass::arch::fence_barrier_init();
    }
    __syncthreads();

    // Split-K decomposition: the first `kRemainKBlocks` splits take one extra
    // K block each; `k_offset` is this split's starting K element
    constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K);
    constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
    constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
    // Broadcast from lane 0 so the block index is uniform across the warp
    const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
    const uint32_t m_block_idx = block_idx / kNumSplits;
    const uint32_t k_split_idx = block_idx % kNumSplits;
    const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K;
    // Each split writes its square-sums to a distinct `shape_m`-sized slice
    const uint32_t m_offset = shape_m * k_split_idx;
    const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks);
    constexpr uint32_t kNumTMARegisters = 40;
    constexpr uint32_t kNumMathRegisters = 256;

    // TMA load warp
    if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
        // Producer needs few registers; donate the rest to the math warpgroup
        cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
        for (uint32_t s = 0; s < num_total_stages; ++ s) {
            // Wait consumer release (phase flips every kNumStages iterations;
            // XOR 1 because `empty` barriers start in the released state)
            const auto& stage_idx = s % kNumStages;
            empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);

            // Compute offsets
            uint32_t m_idx = m_block_idx * BLOCK_M;
            uint32_t k_idx = k_offset + s * BLOCK_K;

            // Issue TMAs
            tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
            tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);

            // Arrive at full barriers
            constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
            full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes);
        }

        // Drain: wait for the consumer to release every in-flight stage so
        // shared memory is not reused before the math warps finish
        for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) {
            const auto& stage_idx = s % kNumStages;
            empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
        }
    } else if (warp_idx < kNumMathThreads / 32) {
        // Math warpgroup (warps 0..3)
        cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();

        DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M");
        DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K");
        constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4;
        constexpr uint32_t WGMMA_M = 64;
        constexpr uint32_t WGMMA_N = BLOCK_N;
        constexpr uint32_t WGMMA_K = 8;

        using WGMMA = typename TF32MMASelector<WGMMA_N, true>::type;
        float accum[WGMMA::kNumAccum] = {0};

        // A is read from shared memory in 16-byte bank groups (8 BF16 each)
        constexpr uint32_t kNumBankGroupBytes = 16;
        constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16);
        constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup;
        // Running sum-of-squares for the two M rows this thread touches
        float sqr_sum_acc_0 = 0;
        float sqr_sum_acc_1 = 0;

        #pragma unroll kNumStages < 8 ? kNumStages : kNumStages / 2
        for (uint32_t s = 0; s < num_total_stages; ++ s) {
            // Wait TMA arrival
            const auto& stage_idx = s % kNumStages;
            full_barriers[stage_idx]->wait((s / kNumStages) & 1);

            constexpr uint32_t kNumRegPerWgmma = WGMMA::M * WGMMA::K / 128;
            constexpr uint32_t kNumWgmmaPerBlockK = BLOCK_K / WGMMA::K;

            float a[kNumRegPerWgmma * kNumWgmmaPerBlockK];
            // Assume swizzle A mode is 128
            DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode");

            // Load BF16 A fragment from shared memory into registers, and transpose to FP32
            uint32_t row = warp_idx * 16 + lane_idx / 4;
            #pragma unroll
            for (uint32_t i = 0; i < kNumLoads; ++ i) {
                // Refer to the A layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-a
                // Undo the 128B swizzle: XOR the row into the bank-group column
                uint32_t bank_group_idx = (row ^ i) % 8;
                nv_bfloat16* a_bf16_smem_ptr_upper = smem_a[stage_idx] + row * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup;
                nv_bfloat16* a_bf16_smem_ptr_lower = smem_a[stage_idx] + (row + 8) * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup;

                // Register order [0,2]=upper row, [1,3]=lower row matches the
                // wgmma A fragment layout linked above
                uint32_t elem_offset = lane_idx % 4;
                nv_bfloat16 a_bf16[kNumRegPerWgmma];
                a_bf16[0] = a_bf16_smem_ptr_upper[elem_offset];
                a_bf16[2] = a_bf16_smem_ptr_upper[elem_offset + 4];
                a_bf16[1] = a_bf16_smem_ptr_lower[elem_offset];
                a_bf16[3] = a_bf16_smem_ptr_lower[elem_offset + 4];

                // Convert BF16 pairs -> FP32 and fold squares into the per-row
                // accumulators (x lanes belong to `row`, y lanes to `row + 8`)
                auto a_bf16x2_ptr = reinterpret_cast<nv_bfloat162*>(a_bf16);
                auto a_float2_ptr = reinterpret_cast<float2*>(a);
                float2 a_float2_0 = __bfloat1622float2(a_bf16x2_ptr[0]);
                float2 a_float2_1 = __bfloat1622float2(a_bf16x2_ptr[1]);
                a_float2_ptr[i * 2 + 0] = a_float2_0;
                a_float2_ptr[i * 2 + 1] = a_float2_1;
                sqr_sum_acc_0 += a_float2_0.x * a_float2_0.x + a_float2_1.x * a_float2_1.x;
                sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y;
            }

            // Previous stage's WGMMAs are done; its buffers can be refilled
            warpgroup_wait<0>();
            if (s > 0)
                empty_barriers[(s - 1) % kNumStages]->arrive();

            #pragma unroll
            for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                warpgroup_fence_operand(accum[i]);
            warpgroup_arrive();

            // B lives in shared memory with swizzle-128B; step through it in
            // 128-byte swizzle ranges, WGMMA::K columns at a time
            constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float);
            constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K;
            DG_STATIC_ASSERT(BLOCK_K % kNumElemsInSwizzleRange == 0, "Invalid block K");

            #pragma unroll
            for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) {
                #pragma unroll
                for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) {
                    auto b_desc = make_smem_desc(smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1);
                    WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1);
                }
            }
            warpgroup_commit_batch();
            #pragma unroll
            for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                warpgroup_fence_operand(accum[i]);
        }

        // Reduce the square-sums across the 4 lanes that share an M row
        const auto& reduced_sum_0 = warp_reduce_sum<4>(sqr_sum_acc_0);
        const auto& reduced_sum_1 = warp_reduce_sum<4>(sqr_sum_acc_1);

        // Lane 0 of each 4-lane group writes this split's partial square-sums
        const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4);
        if (lane_idx % 4 == 0) {
            if (m_idx < shape_m)
                sqr_sum[m_offset + m_idx] = reduced_sum_0;
            if (m_idx + 8 < shape_m)
                sqr_sum[m_offset + m_idx + 8] = reduced_sum_1;
        }
        // Final WGMMA batch done; release the last stage for the TMA drain loop
        warpgroup_wait<0>();
        empty_barriers[(num_total_stages-1) % kNumStages]->arrive();

        // Write accum to shared memory
        // Every 2 threads (one pair) will write to the same bank group (16 bytes).
        // Refer to the D layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-d
        uint32_t is_odd_pair = lane_idx / 2 % 2;

        // Four threads per group; write the data to the same row.
        uint32_t row_idx = lane_idx / 4;

        // Even/odd index pairs write to the same column, we need to reorder idx:
        // group even pair indices consecutively, and likewise for odd ones.
        uint32_t reordered_pair_idx = is_odd_pair * 8 + row_idx;

        auto shifted_smem_ptr = reinterpret_cast<uint8_t*>(smem_cd) +
            (warp_idx * BLOCK_M_PER_WARP + row_idx) * kSwizzleCDMode + // Row offset, each warp has 16 rows
            lane_idx % 2 * 8; // One thread of a pair writes 8 bytes

        #pragma unroll
        for (uint32_t i = 0; i < (kSwizzleCDMode / sizeof(float)) / 4; i += 2) {
            // Get the swizzled bank group index (16 bytes per group)
            uint32_t bank_group_idx = get_swizzled_bank_group_idx<kSwizzleCDMode>(i + is_odd_pair, reordered_pair_idx);
            auto smem_ptr = shifted_smem_ptr + bank_group_idx * kNumBankGroupBytes; // Col offset, 16 bytes per group

            // 0/1 write to the same row, 2/3 write to another row
            auto values = reinterpret_cast<uint32_t*>(accum + i * 2);
            st_shared(smem_ptr, values[0], values[1]);
            st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]);
        }
        // Order the shared-memory stores before the TMA store, then sync the
        // 128 math threads so the whole D tile is visible
        cute::tma_store_fence();
        cutlass::arch::NamedBarrier::sync(128, 1);

        // Issue TMA stores
        // NOTE(review): only `tma_store_arrive()` is issued, with no
        // `tma_store_wait` before kernel exit — presumably acceptable at
        // end-of-kernel; confirm against the other SM90 impls in this project.
        if (warp_idx == 0 and cute::elect_one_sync()) {
            if constexpr (kNumSplits == 1) {
                cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M);
            } else {
                // Split-K: each split writes its own slice along the 3rd dim
                cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx);
            }
            cute::tma_store_arrive();
        }
    }
#else
    if (blockIdx.x == 0 and threadIdx.x == 0)
        DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
|
| 284 |
+
|
| 285 |
+
} // namespace deep_gemm
|
| 286 |
+
|
| 287 |
+
#pragma clang diagnostic pop
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cutlass/arch/barrier.h>
|
| 4 |
+
#include <cute/arch/cluster_sm90.hpp>
|
| 5 |
+
|
| 6 |
+
#include <deep_gemm/common/utils.cuh>
|
| 7 |
+
|
| 8 |
+
namespace deep_gemm {
|
| 9 |
+
|
| 10 |
+
// Overwrites out-of-window logits with `-inf`. For each query row `i`, the
// valid KV range is `[ks, ke)` (per-sequence bounds read from the cumulative
// arrays, causal-shifted by `i % kNextN`); everything outside it is cleaned.
// Bulk of the work uses async S2G copies from an `-inf`-filled shared buffer
// over 16-byte-aligned regions; the up-to-3-element unaligned edges around
// `ks`/`ke` are fixed up with scalar stores afterwards.
template <uint32_t kNextN, uint32_t BLOCK_KV, uint32_t kNumWarps>
__global__ __launch_bounds__(kNumWarps * 32, 1)
void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_logits,
                       const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, float* logits) {
    const uint32_t& num_sms = gridDim.x;
    const uint32_t& sm_idx = blockIdx.x;
    // Broadcast from lane 0 so the warp index is warp-uniform
    const uint32_t& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    constexpr float neg_inf = -cute::numeric_limits<float>::infinity();

    // Allocate filled `-inf` shared memory
    extern __shared__ __align__(1024) float smem_buffer[];
    #pragma unroll
    for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32)
        smem_buffer[i] = neg_inf;
    // Make the fill visible to the async (bulk-copy) proxy before any S2G
    cute::tma_store_fence();
    __syncthreads();

    // Assign sequence to each warp: split `total` rows starting at `start`
    // into `num` near-equal chunks, the first `rem` chunks one row larger
    const auto& assign_task = [&](const uint32_t& num, const uint32_t& idx,
                                  const uint32_t& start, const uint32_t& total) -> cute::tuple<uint32_t, uint32_t> {
        const auto& per = total / num, rem = total % num;
        return {start + idx * per + min(idx, rem), per + (idx < rem)};
    };
    // Two-level split: rows across SMs, then each SM's rows across its warps
    CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len);
    CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len);

    // One elected lane per warp issues the bulk copies
    if (cute::elect_one_sync()) {
        for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
            // Valid KV window [ks, ke) for row `i`; `ke` applies the causal
            // shift within the kNextN group. `aligned_*` round to 4 floats
            // (16 bytes), the bulk-copy alignment granularity.
            const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN);
            const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1;
            const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4;

            for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) {
                // NOTE(review): the loop bound is `seq_len_kv` but `right` is
                // clamped to `stride_logits` — looks intentional (clean up to
                // the padded stride), but confirm the two cannot disagree.
                const auto& right = min(left + BLOCK_KV, static_cast<uint32_t>(stride_logits));
                if (right <= ks or ke <= left) {
                    // Tile entirely outside the window: blast the whole tile
                    cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(float));
                } else {
                    // Tile straddles the window: clean only the aligned parts
                    // before `ks` and after `ke`
                    if (left < aligned_ks)
                        cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(float));
                    if (aligned_ke < right)
                        cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(float));
                }
            }
        }
    }

    // Scalar fix-up of the unaligned edges skipped above.
    // NOTE(review): executed by all 32 lanes writing identical values —
    // redundant but harmless; presumably kept simple on purpose.
    for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
        const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN);
        const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1;
        const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4;
        for (uint32_t j = aligned_ks; j < ks; ++ j)
            logits[i * stride_logits + j] = neg_inf;
        for (uint32_t j = ke; j < aligned_ke; ++ j)
            logits[i * stride_logits + j] = neg_inf;
    }
}
|
| 66 |
+
|
| 67 |
+
}
|
build/torch210-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/smxx_layout.cuh
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <deep_gemm/common/utils.cuh>
|
| 4 |
+
|
| 5 |
+
namespace deep_gemm {
|
| 6 |
+
|
| 7 |
+
// Transposes an MN-major FP32 scaling-factor tensor `[groups, mn, SF_K]` into
// K-major `[groups, SF_K, tma_aligned_mn]`, padding MN up to the 16-byte TMA
// alignment. Grid: blockIdx.x tiles MN in BLOCK_MN chunks, blockIdx.y indexes
// the group. `PADDED_SF_K` forces an odd shared-memory row stride when SF_K
// is even — presumably to avoid bank conflicts on the transposed reads.
template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K,
          uint32_t PADDED_SF_K = SF_K + (1 - (SF_K % 2))>
__global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) {
    // Vector type used to read a whole (or partial) row of SF_K floats at once
    typedef typename Vectorized<sizeof(float) * SF_K>::vec_t in_vec_t;
    constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float);
    constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec;

    // Shapes and strides
    extern __shared__ float smem_buffer[];
    constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(float));
    // This block's actual MN extent (the last tile may be partial)
    const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
    const auto tma_aligned_mn = align<uint32_t>(mn, kNumTMAAlignedElems);

    // Shift into the block (input reads `mn` rows, output writes the aligned stride)
    sf = sf + static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
    out = out + static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * SF_K;
    const auto& local_sf = reinterpret_cast<const in_vec_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));

    // Load: vectorized, MN-major reads into the padded shared-memory tile
    for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) {
        auto in_vec = __ldg(local_sf + i);
        const auto& in_values = reinterpret_cast<float*>(&in_vec);

        const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec;
        #pragma unroll
        for (uint32_t j = 0; j < kNumElemsPerVec; ++ j)
            smem_buffer[row * PADDED_SF_K + col + j] = in_values[j];
    }
    __syncthreads();

    // Store: K-major writes, transposing via the shared-memory tile
    #pragma unroll
    for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) {
        const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn;
        const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx;
        out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx);
    }
}
|
| 45 |
+
|
| 46 |
+
// NOTES: the two kernels below always pack the K dimension
|
| 47 |
+
|
| 48 |
+
// Transposes an MN-major FP32 SF tensor `[groups, mn, SF_K]` into K-major and
// simultaneously packs each run of 4 FP32 scales into one UE8M0 uint32 by
// extracting the 8-bit exponent byte of each float (assumes the scales are
// positive powers of two so sign/mantissa bits are zero — TODO confirm with
// the producer of these SFs). Output shape: `[groups, ceil(SF_K/4), tma_aligned_mn]`.
template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K>
__global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) {
    extern __shared__ uint32_t smem_buffer[];

    // Shapes and strides
    constexpr auto kNumPackedSFK = constexpr_ceil_div(SF_K, 4u);
    constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(int));
    // This block's actual MN extent (the last tile may be partial)
    const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
    const auto tma_aligned_mn = align<uint64_t>(mn, kNumTMAAlignedElems);

    // Shift into the group
    sf = sf + static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
    out = out + static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * kNumPackedSFK;

    // Load FP32 SFs (reinterpreted as raw uint32 bit patterns) into shared
    // memory, 16 bytes at a time
    DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size");
    const auto local_sf = reinterpret_cast<uint32_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));
    const auto num_values = in_block_mn * SF_K;
    const auto num_uint4 = num_values / 4;
    #pragma unroll
    for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) {
        const auto& [x, y, z, w] = __ldg(reinterpret_cast<uint4*>(local_sf) + i);
        st_shared(reinterpret_cast<uint4*>(smem_buffer) + i, x, y, z, w);
    }

    // Fill unaligned values as well (the < 4 element tail after the uint4 loop)
    if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values)
        st_shared(smem_buffer + unaligned_idx, __ldg(local_sf + unaligned_idx));
    __syncthreads();

    // Pack into UE8M0 and store
    #pragma unroll
    for (uint32_t i = threadIdx.x; i < (kNumPackedSFK * BLOCK_MN); i += kNumThreads) {
        const auto sf_k_pack_idx = i / BLOCK_MN, mn_idx = i % BLOCK_MN;

        // Load shared memory: the 4 consecutive K scales for this packed word
        // (zero-padded past SF_K)
        uint32_t values[4];
        #pragma unroll
        for (uint32_t j = 0; j < 4; ++ j) {
            const auto sf_k_idx = sf_k_pack_idx * 4 + j;
            values[j] = sf_k_idx < SF_K ? ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0;
        }

        // Pack and store: shift each float's exponent field (bits 23..30) into
        // successive bytes 0..3 of the packed word
        uint32_t packed = 0;
        packed |= (values[0] >> 23u);
        packed |= (values[1] >> 15u);
        packed |= (values[2] >> 7u);
        packed |= (values[3] << 1u);
        // Guard the (possibly partial) last MN tile
        if (const auto global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; global_mn_idx < mn)
            out[sf_k_pack_idx * tma_aligned_mn + global_mn_idx] = packed;
    }
}
|
| 101 |
+
|
| 102 |
+
// Packs an already MN-major (transposed) FP32 SF tensor into UE8M0: every 4
// consecutive K rows collapse into one uint32 row, extracting each float's
// 8-bit exponent into successive bytes (assumes positive power-of-two scales —
// TODO confirm). With `kNumGroups > 1`, per-group K sizes are read from `ks`
// and packing restarts at each group boundary (groups are padded to a
// multiple of 4 packed inputs via `input_offset`).
// Grid: blockIdx.x tiles MN, blockIdx.y tiles the packed-K dim; one warp
// handles one packed output row.
template <uint32_t kNumGroups, uint32_t kNumThreads,
          uint32_t BLOCK_MN, uint32_t BLOCK_PACKED_SF_K, bool kTransposed = true>
__global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
                                     const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k) {
    // Always packing the K dimension
    // NOTES: should also assert `mn % 4 == 0` at launch
    DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)");
    DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block sizes");
    DG_STATIC_ASSERT(BLOCK_PACKED_SF_K == kNumThreads / 32, "Invalid block sizes");

    // Shapes and strides (last tiles along MN / packed-K may be partial)
    const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
    const auto in_block_mn_uint4 = in_block_mn / 4;
    const auto in_block_packed_sf_k = min(BLOCK_PACKED_SF_K, packed_sf_k - blockIdx.y * BLOCK_PACKED_SF_K);

    // Shift into the right block along MN
    sf += blockIdx.x * BLOCK_MN;
    out += blockIdx.x * BLOCK_MN;

    // Each warp is responsible for a packed row
    const auto warp_idx = threadIdx.x / 32;
    const auto lane_idx = get_lane_idx();
    const auto packed_sf_k_idx = static_cast<uint64_t>(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx;
    if (warp_idx >= in_block_packed_sf_k)
        return;

    // Make an offset on the input: number of padding rows to skip so this
    // warp's 4 inputs never straddle a group boundary
    uint32_t input_offset = 0;
    if constexpr (kNumGroups > 1) {
        // Load each group's size: lane `l` holds groups 4l..4l+3
        DG_STATIC_ASSERT(kNumGroups <= 128, "Too many groups");
        uint32_t group_ks[4];
        #pragma unroll
        for (uint32_t i = 0; i < 4; ++ i) {
            const auto group_idx = lane_idx * 4 + i;
            group_ks[i] = group_idx < kNumGroups ? __ldg(ks + group_idx) : 0;
        }
        __syncwarp();

        // Make the offset: walk the groups until the one containing this
        // warp's packed row, accumulating the per-group padding.
        // NOTE(review): `/ 128` converts a group's K size to its SF count —
        // presumably one SF per 128 elements; confirm with the launcher.
        sf_k = 0;
        auto sum_packed_sf_k = 0;
        #pragma unroll
        for (uint32_t i = 0; i < kNumGroups; ++ i) {
            // Broadcast group i's SF count from the lane that holds it
            const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / 128, i / 4);
            sf_k += sf_k_in_group;
            sum_packed_sf_k += ceil_div(sf_k_in_group, 4u);
            if (packed_sf_k_idx < sum_packed_sf_k)
                break;
            if (const auto remainder = sf_k_in_group % 4; remainder > 0)
                input_offset += 4 - remainder;
        }
    }

    // Lanes stride over the MN tile in uint4 (4-float) chunks
    for (uint32_t mn_idx = get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) {
        // Load the 4 source SF rows for this packed row (zero-padded past
        // the group's / tensor's K extent)
        uint4 values[4];
        #pragma unroll
        for (uint32_t j = 0; j < 4; ++ j) {
            values[j] = make_uint4(0, 0, 0, 0);
            if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k)
                values[j] = __ldg(reinterpret_cast<uint4*>(sf + sf_k_idx * mn) + mn_idx);
        }

        // Pack and store: exponent field (bits 23..30) of row j goes to byte j
        uint4 packed;
        packed.x = (values[0].x >> 23u) | (values[1].x >> 15u) | (values[2].x >> 7u) | (values[3].x << 1u);
        packed.y = (values[0].y >> 23u) | (values[1].y >> 15u) | (values[2].y >> 7u) | (values[3].y << 1u);
        packed.z = (values[0].z >> 23u) | (values[1].z >> 15u) | (values[2].z >> 7u) | (values[3].z << 1u);
        packed.w = (values[0].w >> 23u) | (values[1].w >> 15u) | (values[2].w >> 7u) | (values[3].w << 1u);
        reinterpret_cast<uint4*>(out + packed_sf_k_idx * mn)[mn_idx] = packed;
    }
}
|
| 175 |
+
|
| 176 |
+
} // namespace deep_gemm
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/options.h
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
#pragma once
|
| 33 |
+
|
| 34 |
+
#include <vector>
|
| 35 |
+
#include <iostream>
|
| 36 |
+
|
| 37 |
+
// Cutlass command line parser
|
| 38 |
+
#include "cutlass/util/command_line.h"
|
| 39 |
+
|
| 40 |
+
class Options {
|
| 41 |
+
public:
|
| 42 |
+
|
| 43 |
+
bool help;
|
| 44 |
+
bool good;
|
| 45 |
+
std::vector<int> extent; ///< extent of tile to fill
|
| 46 |
+
std::vector<int> stride; ///< stride vector for layout function
|
| 47 |
+
std::vector<int> output_shape; ///< output shape
|
| 48 |
+
int vectorize; ///< sequences of consecutive output elements are concatenated into a vector
|
| 49 |
+
/// if, and only if, they were consecutive in source memory
|
| 50 |
+
|
| 51 |
+
public:
|
| 52 |
+
|
| 53 |
+
/// Options
|
| 54 |
+
Options():
|
| 55 |
+
help(false),
|
| 56 |
+
good(true),
|
| 57 |
+
extent({32, 8}),
|
| 58 |
+
stride({32}),
|
| 59 |
+
output_shape({16, 8}),
|
| 60 |
+
vectorize(1) {
|
| 61 |
+
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
/// Constructs from command line parser
|
| 65 |
+
Options(cutlass::CommandLine const & cmd_line): help(false), good(true) {
|
| 66 |
+
|
| 67 |
+
if (cmd_line.check_cmd_line_flag("help") ||
|
| 68 |
+
cmd_line.check_cmd_line_flag("h")) {
|
| 69 |
+
|
| 70 |
+
help = true;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
if (cmd_line.check_cmd_line_flag("extent")) {
|
| 74 |
+
cmd_line.get_cmd_line_arguments("extent", extent);
|
| 75 |
+
}
|
| 76 |
+
else {
|
| 77 |
+
extent = {32, 8};
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
if (cmd_line.check_cmd_line_flag("stride")) {
|
| 81 |
+
cmd_line.get_cmd_line_arguments("stride", stride);
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
int default_output_shape[] = {16, 8};
|
| 85 |
+
|
| 86 |
+
if (cmd_line.check_cmd_line_flag("output-shape")) {
|
| 87 |
+
cmd_line.get_cmd_line_arguments("output-shape", output_shape);
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
for (int i = int(output_shape.size()); i < 2; ++i) {
|
| 91 |
+
output_shape.push_back(default_output_shape[i]);
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
if (cmd_line.check_cmd_line_flag("vectorize")) {
|
| 95 |
+
cmd_line.get_cmd_line_argument("vectorize", vectorize);
|
| 96 |
+
}
|
| 97 |
+
else {
|
| 98 |
+
vectorize = 1;
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
if (output_shape.front() % vectorize) {
|
| 102 |
+
|
| 103 |
+
std::cerr << "Error: --vectorize=" << vectorize
|
| 104 |
+
<< " must divide contiguous elements in --output-shape="
|
| 105 |
+
<< output_shape.at(0) << "," << output_shape.at(1) << std::endl;
|
| 106 |
+
|
| 107 |
+
good = false;
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
/// Prints usage statement
|
| 112 |
+
static void print_usage(std::ostream &out) {
|
| 113 |
+
out
|
| 114 |
+
<< " Options:\n"
|
| 115 |
+
<< " --help Displays this help message.\n"
|
| 116 |
+
<< " --extent=<extent> Specifies the layout-specific extent (as comma-delimited array).\n"
|
| 117 |
+
<< " --stride=<stride> Specifies the layout-specific stride vector (comma-delimited array)\n"
|
| 118 |
+
<< " --output-shape=<extent> Specifies the dimensions of a row-major output matrix. \n"
|
| 119 |
+
<< " --vectorize=<vector length> If possible, vectorizes the output into vectors of consecutive elements\n";
|
| 120 |
+
}
|
| 121 |
+
};
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/register_layout.h
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief CUTLASS layout visualization example
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include <map>
|
| 39 |
+
#include <memory>
|
| 40 |
+
|
| 41 |
+
#include "options.h"
|
| 42 |
+
|
| 43 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
struct VisualizeLayoutBase {
|
| 46 |
+
virtual bool visualize(Options const &) = 0;
|
| 47 |
+
virtual bool verify(bool verbose, std::ostream &out) = 0;
|
| 48 |
+
virtual void print_csv(std::ostream &out, char delim = '|', char new_line = '\n') = 0;
|
| 49 |
+
virtual std::ostream &print_help(std::ostream &out) {
|
| 50 |
+
return out;
|
| 51 |
+
}
|
| 52 |
+
virtual ~VisualizeLayoutBase() { }
|
| 53 |
+
};
|
| 54 |
+
|
| 55 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 56 |
+
|
| 57 |
+
void RegisterLayouts(std::map<std::string, std::unique_ptr<VisualizeLayoutBase> > &layouts);
|
| 58 |
+
|
| 59 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/visualize_layout.h
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief CUTLASS layout visualization example
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include <algorithm>
|
| 39 |
+
#include <stdexcept>
|
| 40 |
+
#include <vector>
|
| 41 |
+
|
| 42 |
+
#include "cutlass/coord.h"
|
| 43 |
+
#include "cutlass/util/reference/host/tensor_foreach.h"
|
| 44 |
+
|
| 45 |
+
#include "register_layout.h"
|
| 46 |
+
|
| 47 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
|
| 49 |
+
/// Permits copying dynamic vectors into static-length vectors
|
| 50 |
+
template <typename TensorCoord, int Rank>
|
| 51 |
+
struct vector_to_coord {
|
| 52 |
+
|
| 53 |
+
vector_to_coord(TensorCoord &coord, std::vector<int> const &vec) {
|
| 54 |
+
|
| 55 |
+
coord[Rank - 1] = vec.at(Rank - 1);
|
| 56 |
+
|
| 57 |
+
if (Rank > 1) {
|
| 58 |
+
vector_to_coord<TensorCoord, Rank - 1>(coord, vec);
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
};
|
| 62 |
+
|
| 63 |
+
/// Permits copying dynamic vectors into static-length vectors
|
| 64 |
+
template <typename TensorCoord>
|
| 65 |
+
struct vector_to_coord<TensorCoord, 1> {
|
| 66 |
+
|
| 67 |
+
vector_to_coord(TensorCoord &coord, std::vector<int> const &vec) {
|
| 68 |
+
|
| 69 |
+
coord[0] = vec.at(0);
|
| 70 |
+
}
|
| 71 |
+
};
|
| 72 |
+
|
| 73 |
+
/// Permits copying dynamic vectors into static-length vectors
|
| 74 |
+
template <typename TensorCoord>
|
| 75 |
+
struct vector_to_coord<TensorCoord, 0> {
|
| 76 |
+
|
| 77 |
+
vector_to_coord(TensorCoord &coord, std::vector<int> const &vec) {
|
| 78 |
+
|
| 79 |
+
}
|
| 80 |
+
};
|
| 81 |
+
|
| 82 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 83 |
+
|
| 84 |
+
template <typename T>
|
| 85 |
+
std::ostream &operator<<(std::ostream &out, std::vector<T> const &vec) {
|
| 86 |
+
auto it = vec.begin();
|
| 87 |
+
if (it != vec.end()) {
|
| 88 |
+
out << *it;
|
| 89 |
+
for (++it; it != vec.end(); ++it) {
|
| 90 |
+
out << ", " << *it;
|
| 91 |
+
}
|
| 92 |
+
}
|
| 93 |
+
return out;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 97 |
+
|
| 98 |
+
/// Permits copying static-length vectors into dynamic vectors
|
| 99 |
+
template <typename TensorCoord, int Rank>
|
| 100 |
+
struct coord_to_vector {
|
| 101 |
+
|
| 102 |
+
coord_to_vector(std::vector<int> &vec, TensorCoord const &coord) {
|
| 103 |
+
|
| 104 |
+
vec.at(Rank - 1) = coord[Rank - 1];
|
| 105 |
+
coord_to_vector<TensorCoord, Rank - 1>(vec, coord);
|
| 106 |
+
}
|
| 107 |
+
};
|
| 108 |
+
|
| 109 |
+
/// Permits copying static-length vectors into dynamic vectors
|
| 110 |
+
template <typename TensorCoord>
|
| 111 |
+
struct coord_to_vector<TensorCoord, 1> {
|
| 112 |
+
|
| 113 |
+
coord_to_vector(std::vector<int> &vec, TensorCoord const &coord) {
|
| 114 |
+
|
| 115 |
+
vec.at(0) = coord[0];
|
| 116 |
+
}
|
| 117 |
+
};
|
| 118 |
+
|
| 119 |
+
/// Permits copying static-length vectors into dynamic vectors
|
| 120 |
+
template <typename TensorCoord>
|
| 121 |
+
struct coord_to_vector<TensorCoord, 0> {
|
| 122 |
+
|
| 123 |
+
coord_to_vector(std::vector<int> &vec, TensorCoord const &coord) {
|
| 124 |
+
}
|
| 125 |
+
};
|
| 126 |
+
|
| 127 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 128 |
+
|
| 129 |
+
/// Structure representing an element in source memory
|
| 130 |
+
struct Element {
|
| 131 |
+
|
| 132 |
+
std::vector<int> coord; ///< logical coordinate of element (as vector)
|
| 133 |
+
int offset; ///< linear offset from source memory
|
| 134 |
+
int color; ///< enables coloring each element to indicate
|
| 135 |
+
|
| 136 |
+
/// Default ctor
|
| 137 |
+
inline Element(): offset(-1), color(0) { }
|
| 138 |
+
|
| 139 |
+
/// Construct from logical coordinate and initial offset
|
| 140 |
+
inline Element(
|
| 141 |
+
std::vector<int> const &coord_,
|
| 142 |
+
int offset_,
|
| 143 |
+
int color_ = 0
|
| 144 |
+
):
|
| 145 |
+
coord(coord_), offset(offset_), color(color_) { }
|
| 146 |
+
|
| 147 |
+
/// Returns true if element is in a defined state
|
| 148 |
+
inline bool valid() const {
|
| 149 |
+
return offset >= 0;
|
| 150 |
+
}
|
| 151 |
+
};
|
| 152 |
+
|
| 153 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 154 |
+
|
| 155 |
+
/// Visualizes memory layouts by constructing a 'shape'
|
| 156 |
+
template <typename Layout_>
|
| 157 |
+
class VisualizeLayout : public VisualizeLayoutBase {
|
| 158 |
+
public:
|
| 159 |
+
|
| 160 |
+
using Layout = Layout_;
|
| 161 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 162 |
+
using Stride = typename Layout::Stride;
|
| 163 |
+
|
| 164 |
+
public:
|
| 165 |
+
|
| 166 |
+
Options options;
|
| 167 |
+
Layout layout;
|
| 168 |
+
TensorCoord extent;
|
| 169 |
+
std::vector<Element> elements;
|
| 170 |
+
|
| 171 |
+
public:
|
| 172 |
+
|
| 173 |
+
/// Initializes the problem space
|
| 174 |
+
VisualizeLayout() {
|
| 175 |
+
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
/// visualization method
|
| 179 |
+
bool visualize(Options const &options_) {
|
| 180 |
+
|
| 181 |
+
options = options_;
|
| 182 |
+
|
| 183 |
+
if (options.extent.size() != TensorCoord::kRank) {
|
| 184 |
+
|
| 185 |
+
std::cerr
|
| 186 |
+
<< "--extent must have rank " << TensorCoord::kRank
|
| 187 |
+
<< " (given: " << options.extent.size() << ")" << std::endl;
|
| 188 |
+
|
| 189 |
+
return false;
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
vector_to_coord<TensorCoord, TensorCoord::kRank>(extent, options.extent);
|
| 193 |
+
|
| 194 |
+
// Construct the layout for a packed tensor
|
| 195 |
+
if (options.stride.empty()) {
|
| 196 |
+
|
| 197 |
+
layout = Layout::packed(extent);
|
| 198 |
+
}
|
| 199 |
+
else if (options.stride.size() != Stride::kRank) {
|
| 200 |
+
|
| 201 |
+
std::cerr
|
| 202 |
+
<< "--stride must have rank " << Stride::kRank
|
| 203 |
+
<< " (given: " << options.stride.size() << ")" << std::endl;
|
| 204 |
+
|
| 205 |
+
return false;
|
| 206 |
+
}
|
| 207 |
+
else {
|
| 208 |
+
// Stride from
|
| 209 |
+
Stride stride;
|
| 210 |
+
vector_to_coord<Stride, Stride::kRank>(stride, options.stride);
|
| 211 |
+
|
| 212 |
+
layout = Layout(stride);
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
// Resize elements, setting elements to 'undefined' state
|
| 216 |
+
elements.resize(layout.capacity(extent));
|
| 217 |
+
|
| 218 |
+
// enumerate points in tensor space and assign
|
| 219 |
+
cutlass::reference::host::TensorForEachLambda(
|
| 220 |
+
extent,
|
| 221 |
+
[&](TensorCoord coord) {
|
| 222 |
+
|
| 223 |
+
std::vector<int> coord_vec(TensorCoord::kRank, 0);
|
| 224 |
+
coord_to_vector<TensorCoord, TensorCoord::kRank>(coord_vec, coord);
|
| 225 |
+
|
| 226 |
+
int offset = int(layout(coord));
|
| 227 |
+
|
| 228 |
+
if (offset >= int(elements.size())) {
|
| 229 |
+
std::cerr
|
| 230 |
+
<< "Layout error - " << coord_vec
|
| 231 |
+
<< " is out of range (computed offset: " << offset
|
| 232 |
+
<< ", capacity: " << elements.size() << std::endl;
|
| 233 |
+
|
| 234 |
+
throw std::out_of_range("(TensorForEach) layout error - coordinate out of range");
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
elements.at(offset) = Element(coord_vec, offset);
|
| 238 |
+
});
|
| 239 |
+
|
| 240 |
+
return true;
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
/// Verifies the layout satisfies vectorization requirements
|
| 244 |
+
bool verify(bool verbose, std::ostream &out) {
|
| 245 |
+
return true;
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
private:
|
| 249 |
+
|
| 250 |
+
/// returns a pair (is_vectorizable, one_changing_rank) to determine if a
|
| 251 |
+
/// vector exists (consecutive logical coordinates or uniformly invalid)
|
| 252 |
+
/// at the given location.
|
| 253 |
+
std::pair< bool, int > _is_vectorizable(int i) const {
|
| 254 |
+
// (all elements are invalid) or
|
| 255 |
+
// (all elements are valid AND
|
| 256 |
+
// exactly one rank is changing AND
|
| 257 |
+
// elements are consecutive)
|
| 258 |
+
|
| 259 |
+
// Don't need vectorization.
|
| 260 |
+
if (options.vectorize <= 2) return std::make_pair(false, -1);
|
| 261 |
+
|
| 262 |
+
// Boundary check.
|
| 263 |
+
if (i > int(elements.size()) || (i + options.vectorize - 1) > int(elements.size()))
|
| 264 |
+
return std::make_pair(false, -1);
|
| 265 |
+
|
| 266 |
+
// Check if either all elements are valid or invalid.
|
| 267 |
+
bool all_elements_invalid = std::all_of(
|
| 268 |
+
elements.begin() + i, elements.begin() + i + options.vectorize,
|
| 269 |
+
[](Element const &e) { return !e.valid(); });
|
| 270 |
+
|
| 271 |
+
bool all_elements_valid = std::all_of(
|
| 272 |
+
elements.begin() + i, elements.begin() + i + options.vectorize,
|
| 273 |
+
[](Element const &e) { return e.valid(); });
|
| 274 |
+
|
| 275 |
+
if (!all_elements_invalid && !all_elements_valid)
|
| 276 |
+
return std::make_pair(false, -1);
|
| 277 |
+
|
| 278 |
+
// From here, it is vectorizable.
|
| 279 |
+
if (all_elements_invalid) return std::make_pair(true, -1);
|
| 280 |
+
|
| 281 |
+
// Check if only exactly one rank is changing.
|
| 282 |
+
int one_changing_rank = -1;
|
| 283 |
+
for (int j = 0; j < options.vectorize; ++j) {
|
| 284 |
+
for (int r = 0; r < TensorCoord::kRank; ++r) {
|
| 285 |
+
if (elements.at(i + j).coord.at(r) != elements.at(i).coord.at(r)) {
|
| 286 |
+
if (one_changing_rank == -1) {
|
| 287 |
+
one_changing_rank = r;
|
| 288 |
+
} else if (one_changing_rank != r) {
|
| 289 |
+
return std::make_pair(false, -1);
|
| 290 |
+
}
|
| 291 |
+
}
|
| 292 |
+
}
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
return std::make_pair(true, one_changing_rank);
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
/// Prints a vector of elements
|
| 299 |
+
void _print_vector(std::ostream &out, int i, int one_changing_rank) {
|
| 300 |
+
Element const &base_element = elements.at(i);
|
| 301 |
+
if (base_element.valid()) {
|
| 302 |
+
out << "(";
|
| 303 |
+
for (int r = 0; r < TensorCoord::kRank; ++r) {
|
| 304 |
+
if (r) {
|
| 305 |
+
out << ", ";
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
if (r == one_changing_rank) {
|
| 309 |
+
out
|
| 310 |
+
<< base_element.coord.at(r)
|
| 311 |
+
<< ".."
|
| 312 |
+
<< (base_element.coord.at(r) + options.vectorize - 1);
|
| 313 |
+
}
|
| 314 |
+
else {
|
| 315 |
+
out << base_element.coord.at(r);
|
| 316 |
+
}
|
| 317 |
+
}
|
| 318 |
+
out << ")";
|
| 319 |
+
}
|
| 320 |
+
else {
|
| 321 |
+
out << " ";
|
| 322 |
+
}
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
/// Prints a single element
|
| 326 |
+
void _print_element(std::ostream &out, int k) {
|
| 327 |
+
Element const &element = elements.at(k);
|
| 328 |
+
if (element.valid()) {
|
| 329 |
+
out << "(";
|
| 330 |
+
for (int v = 0; v < TensorCoord::kRank; ++v) {
|
| 331 |
+
out << (v ? ", " : "") << element.coord.at(v);
|
| 332 |
+
}
|
| 333 |
+
out << ")";
|
| 334 |
+
}
|
| 335 |
+
else {
|
| 336 |
+
out << " ";
|
| 337 |
+
}
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
public:
|
| 341 |
+
|
| 342 |
+
/// Pretty-prints the layout to the console
|
| 343 |
+
void print_csv(std::ostream &out, char delim = '|', char new_line = '\n') {
|
| 344 |
+
int row = -1;
|
| 345 |
+
|
| 346 |
+
for (int i = 0; i < int(elements.size()); i += options.vectorize) {
|
| 347 |
+
if (i % options.output_shape.at(0)) {
|
| 348 |
+
out << delim;
|
| 349 |
+
}
|
| 350 |
+
else {
|
| 351 |
+
if (row >= 0) {
|
| 352 |
+
out << new_line;
|
| 353 |
+
}
|
| 354 |
+
++row;
|
| 355 |
+
if (row == options.output_shape.at(1)) {
|
| 356 |
+
out << new_line;
|
| 357 |
+
row = 0;
|
| 358 |
+
}
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
auto is_vector = _is_vectorizable(i);
|
| 362 |
+
|
| 363 |
+
if (is_vector.first) {
|
| 364 |
+
_print_vector(out, i, is_vector.second); // print a vector starting at element i
|
| 365 |
+
}
|
| 366 |
+
else {
|
| 367 |
+
for (int j = 0; j < options.vectorize; ++j) { // print individual elements [i..i+j)
|
| 368 |
+
_print_element(out, i + j);
|
| 369 |
+
}
|
| 370 |
+
}
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
out << new_line << std::flush;
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
/// Help message
|
| 377 |
+
virtual std::ostream &print_help(std::ostream &out) {
|
| 378 |
+
out << "TensorCoord rank " << TensorCoord::kRank << ", Stride rank: " << Stride::kRank;
|
| 379 |
+
return out;
|
| 380 |
+
}
|
| 381 |
+
};
|
| 382 |
+
|
| 383 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h
ADDED
|
@@ -0,0 +1,719 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
#pragma once
|
| 33 |
+
|
| 34 |
+
#include <iostream>
|
| 35 |
+
#include <fstream>
|
| 36 |
+
#include <sstream>
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
|
| 40 |
+
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
| 41 |
+
#include "cutlass/reduction/device/reduce_split_k.h"
|
| 42 |
+
#include "cutlass/reduction/thread/reduction_operators.h"
|
| 43 |
+
|
| 44 |
+
#include "cutlass/util/host_tensor.h"
|
| 45 |
+
#include "cutlass/util/reference/host/tensor_fill.h"
|
| 46 |
+
#include "cutlass/util/reference/device/tensor_compare.h"
|
| 47 |
+
#include "cutlass/util/reference/host/tensor_compare.h"
|
| 48 |
+
#include "cutlass/util/reference/host/tensor_norm.h"
|
| 49 |
+
|
| 50 |
+
#include "cutlass/util/reference/host/convolution.h"
|
| 51 |
+
#include "cutlass/util/reference/device/convolution.h"
|
| 52 |
+
#include "cutlass/util/reference/device/tensor_relu.h"
|
| 53 |
+
|
| 54 |
+
#include "cutlass/core_io.h"
|
| 55 |
+
#include "cutlass/util/tensor_view_io.h"
|
| 56 |
+
|
| 57 |
+
#include "reference/device/tensor_scale_bias.h"
|
| 58 |
+
#include "helper.h"
|
| 59 |
+
|
| 60 |
+
#define CHECK_GT(val1, val2) \
|
| 61 |
+
if((val1) <= (val2)) \
|
| 62 |
+
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
|
| 63 |
+
#define CHECK_TRUE(val) \
|
| 64 |
+
if(!(val)) \
|
| 65 |
+
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
template <typename Conv2d0_, typename Conv2d1_>
|
| 69 |
+
class B2bNonFusedConv2dRun {
|
| 70 |
+
public:
|
| 71 |
+
|
| 72 |
+
using Conv2d0 = Conv2d0_;
|
| 73 |
+
using Conv2d1 = Conv2d1_;
|
| 74 |
+
using ElementAccumulator = typename Conv2d0::ElementAccumulator;
|
| 75 |
+
using ElementCompute = typename Conv2d0::ElementCompute;
|
| 76 |
+
|
| 77 |
+
static cutlass::conv::Operator const kConvolutionalOperator = Conv2d0::kConvolutionalOperator;
|
| 78 |
+
static_assert(kConvolutionalOperator == Conv2d1::kConvolutionalOperator,
|
| 79 |
+
"Fused convolution operators must be the same");
|
| 80 |
+
|
| 81 |
+
public:
|
| 82 |
+
|
| 83 |
+
/// Initialization
|
| 84 |
+
cutlass::Distribution::Kind init_A;
|
| 85 |
+
cutlass::Distribution::Kind init_B;
|
| 86 |
+
cutlass::Distribution::Kind init_C;
|
| 87 |
+
cutlass::Distribution::Kind init_Bias;
|
| 88 |
+
uint64_t seed;
|
| 89 |
+
|
| 90 |
+
cutlass::HostTensor<typename Conv2d0::ElementA, typename Conv2d0::LayoutA> tensor_A0;
|
| 91 |
+
cutlass::HostTensor<typename Conv2d0::ElementB, typename Conv2d0::LayoutB> tensor_B0;
|
| 92 |
+
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_C0;
|
| 93 |
+
cutlass::HostTensor<typename Conv2d0::ElementCompute, typename Conv2d0::LayoutC> tensor_Bias0;
|
| 94 |
+
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_computed;
|
| 95 |
+
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_reference;
|
| 96 |
+
|
| 97 |
+
cutlass::HostTensor<typename Conv2d1::ElementB, typename Conv2d1::LayoutB> tensor_B1;
|
| 98 |
+
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_C1;
|
| 99 |
+
cutlass::HostTensor<typename Conv2d1::ElementCompute, typename Conv2d0::LayoutC> tensor_Bias1;
|
| 100 |
+
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_computed;
|
| 101 |
+
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_reference;
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
public:
|
| 105 |
+
|
| 106 |
+
B2bNonFusedConv2dRun(
|
| 107 |
+
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
| 108 |
+
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
| 109 |
+
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
| 110 |
+
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
| 111 |
+
uint64_t seed_ = 2080
|
| 112 |
+
):
|
| 113 |
+
init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) {
|
| 114 |
+
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
/// Helper to initialize a tensor view
|
| 118 |
+
template <typename Element, typename Layout>
|
| 119 |
+
void initialize_tensor(
|
| 120 |
+
cutlass::TensorView<Element, Layout> view,
|
| 121 |
+
cutlass::Distribution::Kind dist_kind,
|
| 122 |
+
uint64_t seed) {
|
| 123 |
+
|
| 124 |
+
if (dist_kind == cutlass::Distribution::Uniform) {
|
| 125 |
+
|
| 126 |
+
int scope;
|
| 127 |
+
int bits = cutlass::sizeof_bits<Element>::value;
|
| 128 |
+
|
| 129 |
+
if (bits <= 16) {
|
| 130 |
+
scope = 2;
|
| 131 |
+
}
|
| 132 |
+
else {
|
| 133 |
+
scope = 8;
|
| 134 |
+
}
|
| 135 |
+
cutlass::reference::host::TensorFillRandomUniform(
|
| 136 |
+
view, seed, scope, -scope, 0);
|
| 137 |
+
}
|
| 138 |
+
else if (dist_kind == cutlass::Distribution::Identity) {
|
| 139 |
+
|
| 140 |
+
cutlass::reference::host::TensorFillIdentity(view);
|
| 141 |
+
}
|
| 142 |
+
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
| 143 |
+
|
| 144 |
+
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
| 145 |
+
}
|
| 146 |
+
else if (dist_kind == cutlass::Distribution::Sequential) {
|
| 147 |
+
|
| 148 |
+
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
| 149 |
+
}
|
| 150 |
+
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
| 151 |
+
cutlass::reference::host::TensorFill(view, Element(0));
|
| 152 |
+
}
|
| 153 |
+
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
| 154 |
+
cutlass::reference::host::TensorFill(view, Element(1));
|
| 155 |
+
}
|
| 156 |
+
else {
|
| 157 |
+
std::cerr << "Not implemented\n";
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
void initialize(
|
| 162 |
+
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
| 163 |
+
cutlass::conv::Conv2dProblemSize const &problem_size_1,
|
| 164 |
+
uint64_t seed = 2019) {
|
| 165 |
+
|
| 166 |
+
tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0));
|
| 167 |
+
tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
|
| 168 |
+
tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
| 169 |
+
tensor_Bias0.resize({1, 1, 1, problem_size_0.K});
|
| 170 |
+
tensor_D0_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
| 171 |
+
tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
| 172 |
+
tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
|
| 173 |
+
tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
| 174 |
+
tensor_Bias1.resize({1, 1, 1, problem_size_1.K});
|
| 175 |
+
tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
| 176 |
+
tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
| 177 |
+
|
| 178 |
+
initialize_tensor(tensor_A0.host_view(), init_A, seed);
|
| 179 |
+
initialize_tensor(tensor_B0.host_view(), init_B, seed * 17);
|
| 180 |
+
initialize_tensor(tensor_C0.host_view(), init_C, seed * 39);
|
| 181 |
+
initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83);
|
| 182 |
+
initialize_tensor(tensor_B1.host_view(), init_B, seed * 18);
|
| 183 |
+
initialize_tensor(tensor_C1.host_view(), init_C, seed * 40);
|
| 184 |
+
initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84);
|
| 185 |
+
|
| 186 |
+
tensor_A0.sync_device();
|
| 187 |
+
tensor_B0.sync_device();
|
| 188 |
+
tensor_C0.sync_device();
|
| 189 |
+
tensor_Bias0.sync_device();
|
| 190 |
+
tensor_D0_computed.sync_device();
|
| 191 |
+
tensor_D0_reference.sync_device();
|
| 192 |
+
tensor_B1.sync_device();
|
| 193 |
+
tensor_C1.sync_device();
|
| 194 |
+
tensor_Bias1.sync_device();
|
| 195 |
+
tensor_D1_computed.sync_device();
|
| 196 |
+
tensor_D1_reference.sync_device();
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
/// Executes one test
|
| 200 |
+
bool run(
|
| 201 |
+
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
| 202 |
+
cutlass::conv::Conv2dProblemSize const &problem_size_1,
|
| 203 |
+
cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
|
| 204 |
+
ElementCompute alpha0 = ElementCompute(1),
|
| 205 |
+
ElementCompute beta0 = ElementCompute(0),
|
| 206 |
+
ElementCompute alpha1 = ElementCompute(1),
|
| 207 |
+
ElementCompute beta1 = ElementCompute(0),
|
| 208 |
+
bool relu = true,
|
| 209 |
+
int warm_ups = 1,
|
| 210 |
+
int runs = 100) {
|
| 211 |
+
|
| 212 |
+
initialize(problem_size_0, problem_size_1);
|
| 213 |
+
|
| 214 |
+
// configure the operator
|
| 215 |
+
Conv2d0 conv2d_op_0;
|
| 216 |
+
Conv2d1 conv2d_op_1;
|
| 217 |
+
|
| 218 |
+
typename Conv2d0::Arguments conv2d_args_0(
|
| 219 |
+
problem_size_0,
|
| 220 |
+
tensor_A0.device_ref(),
|
| 221 |
+
tensor_B0.device_ref(),
|
| 222 |
+
{tensor_Bias0.device_data(), typename Conv2d0::LayoutC::Stride(0)},
|
| 223 |
+
tensor_D0_computed.device_ref(),
|
| 224 |
+
{alpha0, beta0},
|
| 225 |
+
split_k_mode
|
| 226 |
+
);
|
| 227 |
+
typename Conv2d1::Arguments conv2d_args_1(
|
| 228 |
+
problem_size_1,
|
| 229 |
+
tensor_D0_computed.device_ref(),
|
| 230 |
+
tensor_B1.device_ref(),
|
| 231 |
+
{tensor_Bias1.device_data(), typename Conv2d1::LayoutC::Stride(0)},
|
| 232 |
+
tensor_D1_computed.device_ref(),
|
| 233 |
+
{alpha1, beta1},
|
| 234 |
+
split_k_mode
|
| 235 |
+
);
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
cutlass::Status status = conv2d_op_0.initialize(conv2d_args_0);
|
| 239 |
+
|
| 240 |
+
CUTLASS_CHECK(status);
|
| 241 |
+
|
| 242 |
+
status = conv2d_op_1.initialize(conv2d_args_1);
|
| 243 |
+
|
| 244 |
+
CUTLASS_CHECK(status);
|
| 245 |
+
|
| 246 |
+
for(int i = 0; i < warm_ups; i++) {
|
| 247 |
+
status = conv2d_op_0();
|
| 248 |
+
CUTLASS_CHECK(status);
|
| 249 |
+
status = conv2d_op_1();
|
| 250 |
+
CUTLASS_CHECK(status);
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
//
|
| 254 |
+
// Run Conv2d
|
| 255 |
+
//
|
| 256 |
+
cudaEvent_t start, stop1, stop2;
|
| 257 |
+
cudaEventCreate(&start);
|
| 258 |
+
cudaEventCreate(&stop1);
|
| 259 |
+
cudaEventCreate(&stop2);
|
| 260 |
+
|
| 261 |
+
cudaEventRecord(start);
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
for(int i = 0; i < runs; i++) {
|
| 265 |
+
// run conv2d operator
|
| 266 |
+
status = conv2d_op_0();
|
| 267 |
+
CUTLASS_CHECK(status);
|
| 268 |
+
}
|
| 269 |
+
cudaEventRecord(stop1);
|
| 270 |
+
|
| 271 |
+
for(int i = 0; i < runs; i++) {
|
| 272 |
+
// run conv2d operator
|
| 273 |
+
status = conv2d_op_1();
|
| 274 |
+
CUTLASS_CHECK(status);
|
| 275 |
+
}
|
| 276 |
+
cudaEventRecord(stop2);
|
| 277 |
+
cudaDeviceSynchronize();
|
| 278 |
+
float conv2d0Time, conv2d1Time, totalTime;
|
| 279 |
+
cudaEventElapsedTime(&conv2d0Time, start, stop1);
|
| 280 |
+
cudaEventElapsedTime(&conv2d1Time, stop1, stop2);
|
| 281 |
+
cudaEventElapsedTime(&totalTime, start, stop2);
|
| 282 |
+
std::cout << "conv2d 0 time " << conv2d0Time / (float)runs << " ms\n";
|
| 283 |
+
std::cout << "conv2d 1 time " << conv2d1Time / (float)runs << " ms\n";
|
| 284 |
+
std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";
|
| 285 |
+
|
| 286 |
+
tensor_D0_computed.sync_host();
|
| 287 |
+
tensor_D1_computed.sync_host();
|
| 288 |
+
|
| 289 |
+
bool passed = false;
|
| 290 |
+
|
| 291 |
+
cutlass::reference::device::Conv2d<
|
| 292 |
+
typename Conv2d0::ElementA,
|
| 293 |
+
typename Conv2d0::LayoutA,
|
| 294 |
+
typename Conv2d0::ElementB,
|
| 295 |
+
typename Conv2d0::LayoutB,
|
| 296 |
+
typename Conv2d0::ElementC,
|
| 297 |
+
typename Conv2d0::LayoutC,
|
| 298 |
+
ElementCompute,
|
| 299 |
+
ElementAccumulator
|
| 300 |
+
>(
|
| 301 |
+
kConvolutionalOperator,
|
| 302 |
+
problem_size_0,
|
| 303 |
+
tensor_A0.device_ref(),
|
| 304 |
+
tensor_B0.device_ref(),
|
| 305 |
+
{tensor_Bias0.device_data(), typename Conv2d0::LayoutC::Stride(0)},
|
| 306 |
+
tensor_D0_reference.device_ref(),
|
| 307 |
+
alpha0,
|
| 308 |
+
beta0);
|
| 309 |
+
|
| 310 |
+
if(relu) {
|
| 311 |
+
cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
cutlass::reference::device::Conv2d<
|
| 315 |
+
typename Conv2d1::ElementA,
|
| 316 |
+
typename Conv2d1::LayoutA,
|
| 317 |
+
typename Conv2d1::ElementB,
|
| 318 |
+
typename Conv2d1::LayoutB,
|
| 319 |
+
typename Conv2d1::ElementC,
|
| 320 |
+
typename Conv2d1::LayoutC,
|
| 321 |
+
ElementCompute,
|
| 322 |
+
ElementAccumulator
|
| 323 |
+
>(
|
| 324 |
+
kConvolutionalOperator,
|
| 325 |
+
problem_size_1,
|
| 326 |
+
tensor_D0_reference.device_ref(),
|
| 327 |
+
tensor_B1.device_ref(),
|
| 328 |
+
{tensor_Bias1.device_data(), typename Conv2d1::LayoutC::Stride(0)},
|
| 329 |
+
tensor_D1_reference.device_ref(),
|
| 330 |
+
alpha1,
|
| 331 |
+
beta1);
|
| 332 |
+
|
| 333 |
+
if(relu) {
|
| 334 |
+
cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view());
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
cudaError_t result = cudaDeviceSynchronize();
|
| 338 |
+
CHECK_TRUE(result == cudaSuccess);
|
| 339 |
+
|
| 340 |
+
// sync host (copy device data to host) for dumping error output in case of mismatches
|
| 341 |
+
tensor_D0_reference.sync_host();
|
| 342 |
+
tensor_D1_reference.sync_host();
|
| 343 |
+
|
| 344 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_computed.host_view()), 0);
|
| 345 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0);
|
| 346 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0);
|
| 347 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0);
|
| 348 |
+
|
| 349 |
+
passed = cutlass::reference::host::TensorEquals(
|
| 350 |
+
tensor_D1_computed.host_view(),
|
| 351 |
+
tensor_D1_reference.host_view());
|
| 352 |
+
|
| 353 |
+
CHECK_TRUE(passed);
|
| 354 |
+
|
| 355 |
+
if (!passed) {
|
| 356 |
+
std::stringstream fname;
|
| 357 |
+
|
| 358 |
+
fname << "error_B2bImplicitGemm_device_nonfused.txt";
|
| 359 |
+
std::cerr << "Dumping results in " << fname.str() << "\n";
|
| 360 |
+
|
| 361 |
+
std::ofstream results(fname.str());
|
| 362 |
+
|
| 363 |
+
results << problem_size_0 << std::endl;
|
| 364 |
+
results << problem_size_1 << std::endl;
|
| 365 |
+
|
| 366 |
+
results
|
| 367 |
+
<< "\nA0:\n" << tensor_A0.host_view() << "\n"
|
| 368 |
+
<< "\nB0:\n" << tensor_B0.host_view() << "\n"
|
| 369 |
+
<< "\nC0:\n" << tensor_C0.host_view() << "\n"
|
| 370 |
+
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
|
| 371 |
+
<< "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n"
|
| 372 |
+
<< "\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n"
|
| 373 |
+
<< "\nB1:\n" << tensor_B1.host_view() << "\n"
|
| 374 |
+
<< "\nC1:\n" << tensor_C1.host_view() << "\n"
|
| 375 |
+
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
|
| 376 |
+
<< "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
|
| 377 |
+
<< "\nD1 computed:\n" << tensor_D1_computed.host_view();
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
return passed;
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
};
|
| 386 |
+
|
| 387 |
+
template <typename B2bConv2d_>
|
| 388 |
+
class B2bFusedConv2dRun {
|
| 389 |
+
public:
|
| 390 |
+
|
| 391 |
+
using B2bConv2d = B2bConv2d_;
|
| 392 |
+
using ElementAccumulator = typename B2bConv2d::ElementAccumulator;
|
| 393 |
+
using ElementCompute = typename B2bConv2d::ElementCompute;
|
| 394 |
+
|
| 395 |
+
static cutlass::conv::Operator const kConvolutionalOperator = B2bConv2d::kConvolutionalOperator;
|
| 396 |
+
|
| 397 |
+
public:
|
| 398 |
+
|
| 399 |
+
/// Initialization
|
| 400 |
+
cutlass::Distribution::Kind init_A;
|
| 401 |
+
cutlass::Distribution::Kind init_B;
|
| 402 |
+
cutlass::Distribution::Kind init_C;
|
| 403 |
+
cutlass::Distribution::Kind init_Scale;
|
| 404 |
+
cutlass::Distribution::Kind init_Bias;
|
| 405 |
+
uint64_t seed;
|
| 406 |
+
|
| 407 |
+
cutlass::HostTensor<typename B2bConv2d::ElementA, typename B2bConv2d::LayoutA> tensor_A0;
|
| 408 |
+
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B0;
|
| 409 |
+
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C0;
|
| 410 |
+
cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Scale0;
|
| 411 |
+
cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Bias0;
|
| 412 |
+
cutlass::HostTensor<ElementAccumulator, typename B2bConv2d::LayoutC> tensor_Z0_reference;
|
| 413 |
+
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D0_reference;
|
| 414 |
+
|
| 415 |
+
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1;
|
| 416 |
+
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C1;
|
| 417 |
+
cutlass::HostTensor<typename B2bConv2d::ElementCompute, typename B2bConv2d::LayoutC> tensor_Bias1;
|
| 418 |
+
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_computed;
|
| 419 |
+
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_reference;
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
public:
|
| 423 |
+
|
| 424 |
+
B2bFusedConv2dRun(
|
| 425 |
+
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
| 426 |
+
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
| 427 |
+
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
| 428 |
+
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
| 429 |
+
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
| 430 |
+
uint64_t seed_ = 2080
|
| 431 |
+
):
|
| 432 |
+
init_A(init_A_), init_B(init_B_), init_C(init_C_),
|
| 433 |
+
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) {
|
| 434 |
+
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
/// Helper to initialize a tensor view
|
| 438 |
+
template <typename Element, typename Layout>
|
| 439 |
+
void initialize_tensor(
|
| 440 |
+
cutlass::TensorView<Element, Layout> view,
|
| 441 |
+
cutlass::Distribution::Kind dist_kind,
|
| 442 |
+
uint64_t seed) {
|
| 443 |
+
|
| 444 |
+
if (dist_kind == cutlass::Distribution::Uniform) {
|
| 445 |
+
|
| 446 |
+
int scope;
|
| 447 |
+
int bits = cutlass::sizeof_bits<Element>::value;
|
| 448 |
+
|
| 449 |
+
if (bits <= 16) {
|
| 450 |
+
scope = 2;
|
| 451 |
+
}
|
| 452 |
+
else {
|
| 453 |
+
scope = 8;
|
| 454 |
+
}
|
| 455 |
+
cutlass::reference::host::TensorFillRandomUniform(
|
| 456 |
+
view, seed, scope, -scope, 0);
|
| 457 |
+
}
|
| 458 |
+
else if (dist_kind == cutlass::Distribution::Identity) {
|
| 459 |
+
|
| 460 |
+
cutlass::reference::host::TensorFillIdentity(view);
|
| 461 |
+
}
|
| 462 |
+
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
| 463 |
+
|
| 464 |
+
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
| 465 |
+
}
|
| 466 |
+
else if (dist_kind == cutlass::Distribution::Sequential) {
|
| 467 |
+
|
| 468 |
+
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
| 469 |
+
}
|
| 470 |
+
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
| 471 |
+
cutlass::reference::host::TensorFill(view, Element(0));
|
| 472 |
+
}
|
| 473 |
+
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
| 474 |
+
cutlass::reference::host::TensorFill(view, Element(1));
|
| 475 |
+
}
|
| 476 |
+
else {
|
| 477 |
+
}
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
void initialize(
|
| 481 |
+
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
| 482 |
+
cutlass::conv::Conv2dProblemSize const &problem_size_1,
|
| 483 |
+
ElementCompute alpha0,
|
| 484 |
+
ElementCompute alpha1,
|
| 485 |
+
uint64_t seed = 2019) {
|
| 486 |
+
|
| 487 |
+
tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0));
|
| 488 |
+
tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
|
| 489 |
+
tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
| 490 |
+
if(alpha0 == ElementCompute(0)) //per-channel scale
|
| 491 |
+
tensor_Scale0.resize({1, problem_size_0.K});
|
| 492 |
+
tensor_Bias0.resize({1, problem_size_0.K});
|
| 493 |
+
tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
| 494 |
+
tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
| 495 |
+
tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
|
| 496 |
+
tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
| 497 |
+
tensor_Bias1.resize({1, 1, 1, problem_size_1.K});
|
| 498 |
+
tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
| 499 |
+
tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
| 500 |
+
|
| 501 |
+
initialize_tensor(tensor_A0.host_view(), init_A, seed);
|
| 502 |
+
initialize_tensor(tensor_B0.host_view(), init_B, seed * 17);
|
| 503 |
+
initialize_tensor(tensor_C0.host_view(), init_C, seed * 39);
|
| 504 |
+
if(alpha0 == ElementCompute(0)) //per-channel scale
|
| 505 |
+
initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed * 61);
|
| 506 |
+
initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83);
|
| 507 |
+
initialize_tensor(tensor_B1.host_view(), init_B, seed * 18);
|
| 508 |
+
initialize_tensor(tensor_C1.host_view(), init_C, seed * 40);
|
| 509 |
+
initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84);
|
| 510 |
+
|
| 511 |
+
tensor_A0.sync_device();
|
| 512 |
+
tensor_B0.sync_device();
|
| 513 |
+
tensor_C0.sync_device();
|
| 514 |
+
if(alpha0 == ElementCompute(0)) //per-channel scale
|
| 515 |
+
tensor_Scale0.sync_device();
|
| 516 |
+
tensor_Bias0.sync_device();
|
| 517 |
+
tensor_D0_reference.sync_device();
|
| 518 |
+
tensor_B1.sync_device();
|
| 519 |
+
tensor_C1.sync_device();
|
| 520 |
+
tensor_Bias1.sync_device();
|
| 521 |
+
tensor_D1_computed.sync_device();
|
| 522 |
+
tensor_D1_reference.sync_device();
|
| 523 |
+
}
|
| 524 |
+
|
| 525 |
+
/// Executes one test
|
| 526 |
+
bool run(
|
| 527 |
+
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
| 528 |
+
cutlass::conv::Conv2dProblemSize const &problem_size_1,
|
| 529 |
+
cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
|
| 530 |
+
ElementCompute alpha0 = ElementCompute(1),
|
| 531 |
+
ElementCompute beta0 = ElementCompute(0),
|
| 532 |
+
ElementCompute alpha1 = ElementCompute(1),
|
| 533 |
+
ElementCompute beta1 = ElementCompute(0),
|
| 534 |
+
bool relu = true,
|
| 535 |
+
int warm_ups = 1,
|
| 536 |
+
int runs = 100) {
|
| 537 |
+
|
| 538 |
+
initialize(problem_size_0, problem_size_1, alpha0, alpha1);
|
| 539 |
+
|
| 540 |
+
// configure the operator
|
| 541 |
+
B2bConv2d b2b_conv2d_op;
|
| 542 |
+
|
| 543 |
+
typename B2bConv2d::Arguments b2b_conv2d_args(
|
| 544 |
+
problem_size_0,
|
| 545 |
+
problem_size_1,
|
| 546 |
+
tensor_A0.device_ref(),
|
| 547 |
+
tensor_B0.device_ref(),
|
| 548 |
+
tensor_C0.device_ref(),
|
| 549 |
+
tensor_Scale0.device_ref(),
|
| 550 |
+
tensor_Bias0.device_ref(),
|
| 551 |
+
tensor_B1.device_ref(),
|
| 552 |
+
{tensor_Bias1.device_data(), typename B2bConv2d::LayoutC::Stride(0)},
|
| 553 |
+
tensor_D1_computed.device_ref(),
|
| 554 |
+
{alpha0, beta0},
|
| 555 |
+
{alpha1, beta1},
|
| 556 |
+
split_k_mode
|
| 557 |
+
);
|
| 558 |
+
|
| 559 |
+
cutlass::Status status = b2b_conv2d_op.can_implement(b2b_conv2d_args);
|
| 560 |
+
|
| 561 |
+
if(status != cutlass::Status::kSuccess) {
|
| 562 |
+
std::cout << "Problem sizes not supported.\n"
|
| 563 |
+
<< "Requirments:\n"
|
| 564 |
+
<< " problem_size_0.N*P*Q = problem_size_1.N*P*Q\n"
|
| 565 |
+
<< " problem_size_0.K = problem_size_1.C\n"
|
| 566 |
+
<< " problem_size_1.R = problem_size_1.S = 1\n"
|
| 567 |
+
<< " ThreadblockShape0::kN = problem_size_0.K\n"
|
| 568 |
+
<< " ThreadblockShape1::kN = problem_size_1.K" << std::endl;
|
| 569 |
+
}
|
| 570 |
+
|
| 571 |
+
CUTLASS_CHECK(status);
|
| 572 |
+
|
| 573 |
+
status = b2b_conv2d_op.initialize(b2b_conv2d_args);
|
| 574 |
+
|
| 575 |
+
CUTLASS_CHECK(status);
|
| 576 |
+
|
| 577 |
+
for(int i = 0; i < warm_ups; i++) {
|
| 578 |
+
status = b2b_conv2d_op();
|
| 579 |
+
CUTLASS_CHECK(status);
|
| 580 |
+
}
|
| 581 |
+
|
| 582 |
+
//
|
| 583 |
+
// Run the Conv2d
|
| 584 |
+
//
|
| 585 |
+
|
| 586 |
+
cudaEvent_t start, stop;
|
| 587 |
+
cudaEventCreate(&start);
|
| 588 |
+
cudaEventCreate(&stop);
|
| 589 |
+
|
| 590 |
+
cudaEventRecord(start);
|
| 591 |
+
|
| 592 |
+
for(int i = 0; i < runs; i++) {
|
| 593 |
+
|
| 594 |
+
// run conv2d operator
|
| 595 |
+
status = b2b_conv2d_op();
|
| 596 |
+
CUTLASS_CHECK(status);
|
| 597 |
+
}
|
| 598 |
+
|
| 599 |
+
cudaEventRecord(stop);
|
| 600 |
+
cudaDeviceSynchronize();
|
| 601 |
+
float conv2dTime;
|
| 602 |
+
cudaEventElapsedTime(&conv2dTime, start, stop);
|
| 603 |
+
std::cout << "Fusion time " << conv2dTime / (float)runs << " ms\n";
|
| 604 |
+
|
| 605 |
+
tensor_D1_computed.sync_host();
|
| 606 |
+
|
| 607 |
+
bool passed = false;
|
| 608 |
+
|
| 609 |
+
cutlass::reference::device::Conv2d<
|
| 610 |
+
typename B2bConv2d::ElementA,
|
| 611 |
+
typename B2bConv2d::LayoutA,
|
| 612 |
+
typename B2bConv2d::ElementB,
|
| 613 |
+
typename B2bConv2d::LayoutB,
|
| 614 |
+
ElementAccumulator,
|
| 615 |
+
typename B2bConv2d::LayoutC,
|
| 616 |
+
ElementAccumulator,
|
| 617 |
+
ElementAccumulator
|
| 618 |
+
>(
|
| 619 |
+
kConvolutionalOperator,
|
| 620 |
+
problem_size_0,
|
| 621 |
+
tensor_A0.device_ref(),
|
| 622 |
+
tensor_B0.device_ref(),
|
| 623 |
+
tensor_Z0_reference.device_ref(),
|
| 624 |
+
tensor_Z0_reference.device_ref(),
|
| 625 |
+
ElementAccumulator(1), // intermediate alpha = 1
|
| 626 |
+
ElementAccumulator(0) // beta = 0
|
| 627 |
+
);
|
| 628 |
+
|
| 629 |
+
cutlass::reference::device::TensorScaleBiasConv2d<
|
| 630 |
+
ElementAccumulator,
|
| 631 |
+
typename B2bConv2d::ElementC,
|
| 632 |
+
typename B2bConv2d::LayoutC,
|
| 633 |
+
ElementCompute,
|
| 634 |
+
typename B2bConv2d::LayoutScaleBias
|
| 635 |
+
>(
|
| 636 |
+
problem_size_0,
|
| 637 |
+
tensor_Z0_reference.device_ref(),
|
| 638 |
+
tensor_D0_reference.device_ref(),
|
| 639 |
+
alpha0,
|
| 640 |
+
tensor_Scale0.device_ref(),
|
| 641 |
+
tensor_Bias0.device_ref()
|
| 642 |
+
);
|
| 643 |
+
|
| 644 |
+
if(relu) {
|
| 645 |
+
cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
cutlass::reference::device::Conv2d<
|
| 649 |
+
typename B2bConv2d::ElementA,
|
| 650 |
+
typename B2bConv2d::LayoutA,
|
| 651 |
+
typename B2bConv2d::ElementB,
|
| 652 |
+
typename B2bConv2d::LayoutB,
|
| 653 |
+
typename B2bConv2d::ElementC,
|
| 654 |
+
typename B2bConv2d::LayoutC,
|
| 655 |
+
ElementCompute,
|
| 656 |
+
ElementAccumulator
|
| 657 |
+
>(
|
| 658 |
+
kConvolutionalOperator,
|
| 659 |
+
problem_size_1,
|
| 660 |
+
tensor_D0_reference.device_ref(),
|
| 661 |
+
tensor_B1.device_ref(),
|
| 662 |
+
{tensor_Bias1.device_data(), typename B2bConv2d::LayoutC::Stride(0)},
|
| 663 |
+
tensor_D1_reference.device_ref(),
|
| 664 |
+
alpha1,
|
| 665 |
+
beta1);
|
| 666 |
+
|
| 667 |
+
if(relu) {
|
| 668 |
+
cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view());
|
| 669 |
+
}
|
| 670 |
+
|
| 671 |
+
cudaError_t result = cudaDeviceSynchronize();
|
| 672 |
+
CHECK_TRUE(result == cudaSuccess);
|
| 673 |
+
|
| 674 |
+
// sync host (copy device data to host) for dumping error output in case of mismatches
|
| 675 |
+
tensor_D0_reference.sync_host();
|
| 676 |
+
tensor_D1_reference.sync_host();
|
| 677 |
+
|
| 678 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0);
|
| 679 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0);
|
| 680 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0);
|
| 681 |
+
|
| 682 |
+
passed = cutlass::reference::host::TensorEquals(
|
| 683 |
+
tensor_D1_computed.host_view(),
|
| 684 |
+
tensor_D1_reference.host_view());
|
| 685 |
+
|
| 686 |
+
CHECK_TRUE(passed);
|
| 687 |
+
|
| 688 |
+
if (!passed) {
|
| 689 |
+
std::stringstream fname;
|
| 690 |
+
|
| 691 |
+
fname << "error_B2bImplicitGemm_device_fused.txt";
|
| 692 |
+
std::cerr << "Dumping results in " << fname.str() << "\n";
|
| 693 |
+
|
| 694 |
+
std::ofstream results(fname.str());
|
| 695 |
+
|
| 696 |
+
results << problem_size_0 << std::endl;
|
| 697 |
+
results << problem_size_1 << std::endl;
|
| 698 |
+
|
| 699 |
+
results
|
| 700 |
+
<< "\nA0:\n" << tensor_A0.host_view() << "\n"
|
| 701 |
+
<< "\nB0:\n" << tensor_B0.host_view() << "\n"
|
| 702 |
+
<< "\nC0:\n" << tensor_C0.host_view() << "\n"
|
| 703 |
+
<< "\nScale0:\n" << tensor_Scale0.host_view() << "\n"
|
| 704 |
+
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
|
| 705 |
+
<< "\nB1:\n" << tensor_B1.host_view() << "\n"
|
| 706 |
+
<< "\nC1:\n" << tensor_C1.host_view() << "\n"
|
| 707 |
+
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
|
| 708 |
+
<< "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
|
| 709 |
+
<< "\nD1 computed:\n" << tensor_D1_computed.host_view();
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
}
|
| 713 |
+
|
| 714 |
+
return passed;
|
| 715 |
+
}
|
| 716 |
+
|
| 717 |
+
};
|
| 718 |
+
|
| 719 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_gemm_run.h
ADDED
|
@@ -0,0 +1,763 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
#pragma once
|
| 32 |
+
|
| 33 |
+
#include <iostream>
|
| 34 |
+
#include <fstream>
|
| 35 |
+
#include <sstream>
|
| 36 |
+
|
| 37 |
+
#include "cutlass/util/host_tensor.h"
|
| 38 |
+
#include "cutlass/util/tensor_view_io.h"
|
| 39 |
+
#include "cutlass/util/distribution.h"
|
| 40 |
+
#include "cutlass/util/reference/host/tensor_fill.h"
|
| 41 |
+
#include "cutlass/util/reference/host/tensor_copy.h"
|
| 42 |
+
#include "cutlass/util/reference/host/tensor_compare.h"
|
| 43 |
+
#include "cutlass/util/reference/host/tensor_norm.h"
|
| 44 |
+
#include "cutlass/util/reference/device/gemm.h"
|
| 45 |
+
#include "cutlass/util/reference/device/gemm_complex.h"
|
| 46 |
+
#include "cutlass/util/reference/device/tensor_relu.h"
|
| 47 |
+
|
| 48 |
+
#include "reference/device/tensor_scale_bias.h"
|
| 49 |
+
#include "helper.h"
|
| 50 |
+
|
| 51 |
+
#define CHECK_GT(val1, val2) \
|
| 52 |
+
if((val1) <= (val2)) \
|
| 53 |
+
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
|
| 54 |
+
#define CHECK_TRUE(val) \
|
| 55 |
+
if(!(val)) \
|
| 56 |
+
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
|
| 57 |
+
|
| 58 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
|
| 60 |
+
template <typename Gemm0_, typename Gemm1_>
|
| 61 |
+
struct B2bNonFusedGemmRun
|
| 62 |
+
{
|
| 63 |
+
|
| 64 |
+
using Gemm0 = Gemm0_;
|
| 65 |
+
using Gemm1 = Gemm1_;
|
| 66 |
+
using ElementAccumulator = typename Gemm0::ElementAccumulator;
|
| 67 |
+
using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute;
|
| 68 |
+
|
| 69 |
+
/// Initialization
|
| 70 |
+
cutlass::Distribution::Kind init_A;
|
| 71 |
+
cutlass::Distribution::Kind init_B;
|
| 72 |
+
cutlass::Distribution::Kind init_C;
|
| 73 |
+
cutlass::Distribution::Kind init_Bias;
|
| 74 |
+
uint64_t seed;
|
| 75 |
+
|
| 76 |
+
//
|
| 77 |
+
// Methods
|
| 78 |
+
//
|
| 79 |
+
|
| 80 |
+
B2bNonFusedGemmRun(
|
| 81 |
+
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
| 82 |
+
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
| 83 |
+
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
| 84 |
+
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
| 85 |
+
uint64_t seed_ = 2080
|
| 86 |
+
):
|
| 87 |
+
init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { }
|
| 88 |
+
|
| 89 |
+
/// Helper to initialize a tensor view
|
| 90 |
+
template <typename Element, typename Layout>
|
| 91 |
+
bool initialize_tensor(
|
| 92 |
+
cutlass::TensorView<Element, Layout> view,
|
| 93 |
+
cutlass::Distribution::Kind dist_kind,
|
| 94 |
+
uint64_t seed) {
|
| 95 |
+
|
| 96 |
+
if (dist_kind == cutlass::Distribution::Uniform) {
|
| 97 |
+
|
| 98 |
+
cutlass::reference::host::TensorFillRandomUniform(
|
| 99 |
+
view, seed, 2, -2, 0);
|
| 100 |
+
}
|
| 101 |
+
else if (dist_kind == cutlass::Distribution::Identity) {
|
| 102 |
+
|
| 103 |
+
cutlass::reference::host::TensorFillIdentity(view);
|
| 104 |
+
}
|
| 105 |
+
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
| 106 |
+
|
| 107 |
+
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
| 108 |
+
}
|
| 109 |
+
else if (dist_kind == cutlass::Distribution::Sequential) {
|
| 110 |
+
|
| 111 |
+
cutlass::reference::host::BlockFillSequential(
|
| 112 |
+
view.data(), view.capacity());
|
| 113 |
+
}
|
| 114 |
+
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
| 115 |
+
cutlass::reference::host::TensorFill(view, Element(0));
|
| 116 |
+
}
|
| 117 |
+
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
| 118 |
+
cutlass::reference::host::TensorFill(view, Element(1));
|
| 119 |
+
}
|
| 120 |
+
else {
|
| 121 |
+
std::cerr << "Not implemented\n";
|
| 122 |
+
return false;
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
return true;
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
/// Executes one test
|
| 132 |
+
bool run(
|
| 133 |
+
cutlass::gemm::GemmCoord problem_size_0,
|
| 134 |
+
cutlass::gemm::GemmCoord problem_size_1,
|
| 135 |
+
ElementCompute alpha0 = ElementCompute(1),
|
| 136 |
+
ElementCompute beta0 = ElementCompute(0),
|
| 137 |
+
ElementCompute alpha1 = ElementCompute(1),
|
| 138 |
+
ElementCompute beta1 = ElementCompute(0),
|
| 139 |
+
bool relu = true,
|
| 140 |
+
int warm_ups = 1,
|
| 141 |
+
int runs = 100) {
|
| 142 |
+
|
| 143 |
+
//
|
| 144 |
+
// Allocate the GEMM workspace
|
| 145 |
+
//
|
| 146 |
+
|
| 147 |
+
cutlass::HostTensor<
|
| 148 |
+
typename Gemm0::ElementA,
|
| 149 |
+
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
|
| 150 |
+
|
| 151 |
+
cutlass::HostTensor<
|
| 152 |
+
typename Gemm0::ElementB,
|
| 153 |
+
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
|
| 154 |
+
|
| 155 |
+
cutlass::HostTensor<
|
| 156 |
+
typename Gemm0::ElementC,
|
| 157 |
+
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
|
| 158 |
+
|
| 159 |
+
cutlass::HostTensor<
|
| 160 |
+
ElementCompute,
|
| 161 |
+
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
|
| 162 |
+
|
| 163 |
+
cutlass::HostTensor<
|
| 164 |
+
typename Gemm0::ElementC,
|
| 165 |
+
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
|
| 166 |
+
|
| 167 |
+
cutlass::HostTensor<
|
| 168 |
+
typename Gemm0::ElementC,
|
| 169 |
+
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
|
| 170 |
+
|
| 171 |
+
cutlass::HostTensor<
|
| 172 |
+
typename Gemm1::ElementB,
|
| 173 |
+
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
|
| 174 |
+
|
| 175 |
+
cutlass::HostTensor<
|
| 176 |
+
typename Gemm1::ElementC,
|
| 177 |
+
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
|
| 178 |
+
|
| 179 |
+
cutlass::HostTensor<
|
| 180 |
+
ElementCompute,
|
| 181 |
+
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
|
| 182 |
+
|
| 183 |
+
cutlass::HostTensor<
|
| 184 |
+
typename Gemm1::ElementC,
|
| 185 |
+
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
|
| 186 |
+
|
| 187 |
+
cutlass::HostTensor<
|
| 188 |
+
typename Gemm1::ElementC,
|
| 189 |
+
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
| 193 |
+
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
|
| 194 |
+
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
|
| 195 |
+
CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014));
|
| 196 |
+
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
|
| 197 |
+
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
|
| 198 |
+
CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013));
|
| 199 |
+
|
| 200 |
+
cutlass::reference::host::TensorFill(
|
| 201 |
+
tensor_D0.host_view());
|
| 202 |
+
cutlass::reference::host::TensorFill(
|
| 203 |
+
tensor_D1.host_view());
|
| 204 |
+
cutlass::reference::host::TensorFill(
|
| 205 |
+
reference_D0.host_view());
|
| 206 |
+
cutlass::reference::host::TensorFill(
|
| 207 |
+
reference_D1.host_view());
|
| 208 |
+
|
| 209 |
+
tensor_A0.sync_device();
|
| 210 |
+
tensor_B0.sync_device();
|
| 211 |
+
tensor_C0.sync_device();
|
| 212 |
+
tensor_Bias0.sync_device();
|
| 213 |
+
tensor_D0.sync_device();
|
| 214 |
+
tensor_B1.sync_device();
|
| 215 |
+
tensor_C1.sync_device();
|
| 216 |
+
tensor_Bias1.sync_device();
|
| 217 |
+
tensor_D1.sync_device();
|
| 218 |
+
reference_D0.sync_device();
|
| 219 |
+
reference_D1.sync_device();
|
| 220 |
+
|
| 221 |
+
//
|
| 222 |
+
// Initialize the GEMM operator
|
| 223 |
+
//
|
| 224 |
+
|
| 225 |
+
typename Gemm0::Arguments arguments_0{
|
| 226 |
+
problem_size_0,
|
| 227 |
+
tensor_A0.device_ref(),
|
| 228 |
+
tensor_B0.device_ref(),
|
| 229 |
+
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
|
| 230 |
+
tensor_D0.device_ref(),
|
| 231 |
+
{alpha0, beta0}
|
| 232 |
+
};
|
| 233 |
+
|
| 234 |
+
typename Gemm1::Arguments arguments_1{
|
| 235 |
+
problem_size_1,
|
| 236 |
+
tensor_D0.device_ref(),
|
| 237 |
+
tensor_B1.device_ref(),
|
| 238 |
+
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
|
| 239 |
+
tensor_D1.device_ref(),
|
| 240 |
+
{alpha1, beta1}
|
| 241 |
+
};
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
Gemm0 gemm_op_0;
|
| 245 |
+
Gemm1 gemm_op_1;
|
| 246 |
+
|
| 247 |
+
cutlass::Status status = gemm_op_0.initialize(arguments_0);
|
| 248 |
+
|
| 249 |
+
CUTLASS_CHECK(status);
|
| 250 |
+
|
| 251 |
+
status = gemm_op_1.initialize(arguments_1);
|
| 252 |
+
|
| 253 |
+
CUTLASS_CHECK(status);
|
| 254 |
+
|
| 255 |
+
for(int i = 0; i < warm_ups; i++) {
|
| 256 |
+
status = gemm_op_0();
|
| 257 |
+
CUTLASS_CHECK(status);
|
| 258 |
+
status = gemm_op_1();
|
| 259 |
+
CUTLASS_CHECK(status);
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
//
|
| 263 |
+
// Run the GEMM
|
| 264 |
+
//
|
| 265 |
+
cudaEvent_t start, stop1, stop2;
|
| 266 |
+
cudaEventCreate(&start);
|
| 267 |
+
cudaEventCreate(&stop1);
|
| 268 |
+
cudaEventCreate(&stop2);
|
| 269 |
+
|
| 270 |
+
cudaEventRecord(start);
|
| 271 |
+
|
| 272 |
+
for(int i = 0; i < runs; i++) {
|
| 273 |
+
status = gemm_op_0();
|
| 274 |
+
|
| 275 |
+
CUTLASS_CHECK(status);
|
| 276 |
+
}
|
| 277 |
+
cudaEventRecord(stop1);
|
| 278 |
+
for(int i = 0; i < runs; i++) {
|
| 279 |
+
status = gemm_op_1();
|
| 280 |
+
|
| 281 |
+
CUTLASS_CHECK(status);
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
cudaEventRecord(stop2);
|
| 285 |
+
cudaDeviceSynchronize();
|
| 286 |
+
float gemm0Time, gemm1Time, totalTime;
|
| 287 |
+
cudaEventElapsedTime(&gemm0Time, start, stop1);
|
| 288 |
+
cudaEventElapsedTime(&gemm1Time, stop1, stop2);
|
| 289 |
+
cudaEventElapsedTime(&totalTime, start, stop2);
|
| 290 |
+
std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
|
| 291 |
+
std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
|
| 292 |
+
std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";
|
| 293 |
+
|
| 294 |
+
tensor_D0.sync_host();
|
| 295 |
+
tensor_D1.sync_host();
|
| 296 |
+
|
| 297 |
+
//
|
| 298 |
+
// Verify
|
| 299 |
+
//
|
| 300 |
+
cutlass::reference::device::Gemm<
|
| 301 |
+
typename Gemm0::ElementA, typename Gemm0::LayoutA,
|
| 302 |
+
typename Gemm0::ElementB, typename Gemm0::LayoutB,
|
| 303 |
+
typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute,
|
| 304 |
+
ElementAccumulator, typename Gemm0::Operator>
|
| 305 |
+
reference_gemm_0;
|
| 306 |
+
|
| 307 |
+
cutlass::reference::device::Gemm<
|
| 308 |
+
typename Gemm1::ElementA, typename Gemm1::LayoutA,
|
| 309 |
+
typename Gemm1::ElementB, typename Gemm1::LayoutB,
|
| 310 |
+
typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute,
|
| 311 |
+
ElementAccumulator, typename Gemm1::Operator>
|
| 312 |
+
reference_gemm_1;
|
| 313 |
+
|
| 314 |
+
reference_gemm_0(
|
| 315 |
+
problem_size_0,
|
| 316 |
+
alpha0,
|
| 317 |
+
tensor_A0.device_ref(),
|
| 318 |
+
tensor_B0.device_ref(),
|
| 319 |
+
beta0,
|
| 320 |
+
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
|
| 321 |
+
reference_D0.device_ref()
|
| 322 |
+
);
|
| 323 |
+
|
| 324 |
+
if(relu) {
|
| 325 |
+
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
reference_gemm_1(
|
| 329 |
+
problem_size_1,
|
| 330 |
+
alpha1,
|
| 331 |
+
reference_D0.device_ref(),
|
| 332 |
+
tensor_B1.device_ref(),
|
| 333 |
+
beta1,
|
| 334 |
+
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
|
| 335 |
+
reference_D1.device_ref()
|
| 336 |
+
);
|
| 337 |
+
|
| 338 |
+
if(relu) {
|
| 339 |
+
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
// Wait for kernels to finish
|
| 343 |
+
cudaDeviceSynchronize();
|
| 344 |
+
reference_D0.sync_host();
|
| 345 |
+
reference_D1.sync_host();
|
| 346 |
+
|
| 347 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
|
| 348 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
|
| 349 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
|
| 350 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
| 351 |
+
|
| 352 |
+
bool passed = cutlass::reference::host::TensorEquals(
|
| 353 |
+
reference_D1.host_view(),
|
| 354 |
+
tensor_D1.host_view());
|
| 355 |
+
|
| 356 |
+
CHECK_TRUE(passed);
|
| 357 |
+
if (!passed) {
|
| 358 |
+
|
| 359 |
+
std::stringstream fname;
|
| 360 |
+
|
| 361 |
+
fname << "error_B2bGemm_device_nonfused.txt";
|
| 362 |
+
std::cerr << "Dumping results in " << fname.str() << "\n";
|
| 363 |
+
|
| 364 |
+
std::ofstream file(fname.str());
|
| 365 |
+
|
| 366 |
+
file
|
| 367 |
+
<< "A0 =\n" << tensor_A0.host_view()
|
| 368 |
+
<< "\nB0 =\n" << tensor_B0.host_view()
|
| 369 |
+
<< "\nC0 =\n" << tensor_C0.host_view()
|
| 370 |
+
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
|
| 371 |
+
<< "\nD0 =\n" << tensor_D0.host_view()
|
| 372 |
+
<< "\nB1 =\n" << tensor_B1.host_view()
|
| 373 |
+
<< "\nC1 =\n" << tensor_C1.host_view()
|
| 374 |
+
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
|
| 375 |
+
<< "\n\nReference =\n" << reference_D1.host_view()
|
| 376 |
+
<< "\nComputed =\n" << tensor_D1.host_view();
|
| 377 |
+
}
|
| 378 |
+
return passed;
|
| 379 |
+
}
|
| 380 |
+
};
|
| 381 |
+
|
| 382 |
+
template <typename B2bGemm_>
|
| 383 |
+
struct B2bFusedGemmRun
|
| 384 |
+
{
|
| 385 |
+
|
| 386 |
+
using B2bGemm = B2bGemm_;
|
| 387 |
+
using ElementAccumulator = typename B2bGemm::ElementAccumulator;
|
| 388 |
+
using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute;
|
| 389 |
+
|
| 390 |
+
/// Initialization
|
| 391 |
+
cutlass::Distribution::Kind init_A;
|
| 392 |
+
cutlass::Distribution::Kind init_B;
|
| 393 |
+
cutlass::Distribution::Kind init_C;
|
| 394 |
+
cutlass::Distribution::Kind init_Scale;
|
| 395 |
+
cutlass::Distribution::Kind init_Bias;
|
| 396 |
+
uint64_t seed;
|
| 397 |
+
|
| 398 |
+
//
|
| 399 |
+
// Methods
|
| 400 |
+
//
|
| 401 |
+
|
| 402 |
+
B2bFusedGemmRun(
|
| 403 |
+
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
| 404 |
+
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
| 405 |
+
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
| 406 |
+
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
| 407 |
+
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
| 408 |
+
uint64_t seed_ = 2080
|
| 409 |
+
):
|
| 410 |
+
init_A(init_A_), init_B(init_B_), init_C(init_C_),
|
| 411 |
+
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }
|
| 412 |
+
|
| 413 |
+
/// Helper to initialize a tensor view
|
| 414 |
+
template <typename Element, typename Layout>
|
| 415 |
+
bool initialize_tensor(
|
| 416 |
+
cutlass::TensorView<Element, Layout> view,
|
| 417 |
+
cutlass::Distribution::Kind dist_kind,
|
| 418 |
+
uint64_t seed) {
|
| 419 |
+
|
| 420 |
+
if (dist_kind == cutlass::Distribution::Uniform) {
|
| 421 |
+
|
| 422 |
+
cutlass::reference::host::TensorFillRandomUniform(
|
| 423 |
+
view, seed, 2, -2, 0);
|
| 424 |
+
}
|
| 425 |
+
else if (dist_kind == cutlass::Distribution::Identity) {
|
| 426 |
+
|
| 427 |
+
cutlass::reference::host::TensorFillIdentity(view);
|
| 428 |
+
}
|
| 429 |
+
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
| 430 |
+
|
| 431 |
+
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
| 432 |
+
}
|
| 433 |
+
else if (dist_kind == cutlass::Distribution::Sequential) {
|
| 434 |
+
|
| 435 |
+
cutlass::reference::host::BlockFillSequential(
|
| 436 |
+
view.data(), view.capacity());
|
| 437 |
+
}
|
| 438 |
+
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
| 439 |
+
cutlass::reference::host::TensorFill(view, Element(0));
|
| 440 |
+
}
|
| 441 |
+
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
| 442 |
+
cutlass::reference::host::TensorFill(view, Element(1));
|
| 443 |
+
}
|
| 444 |
+
else {
|
| 445 |
+
std::cerr << "Not implemented\n";
|
| 446 |
+
return false;
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
return true;
|
| 450 |
+
}
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
/// Executes one test
|
| 456 |
+
bool run(
|
| 457 |
+
cutlass::gemm::GemmCoord problem_size_0,
|
| 458 |
+
cutlass::gemm::GemmCoord problem_size_1,
|
| 459 |
+
ElementCompute alpha0 = ElementCompute(1),
|
| 460 |
+
ElementCompute beta0 = ElementCompute(0),
|
| 461 |
+
ElementCompute alpha1 = ElementCompute(1),
|
| 462 |
+
ElementCompute beta1 = ElementCompute(0),
|
| 463 |
+
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
|
| 464 |
+
|
| 465 |
+
// batch_count is used as split-k when mode is kGemm according
|
| 466 |
+
// to the GemmUniversal interface
|
| 467 |
+
|
| 468 |
+
int batch_count = 1,
|
| 469 |
+
int64_t batch_stride_A0 = 0,
|
| 470 |
+
int64_t batch_stride_B0 = 0,
|
| 471 |
+
int64_t batch_stride_C0 = 0,
|
| 472 |
+
int64_t batch_stride_B1 = 0,
|
| 473 |
+
int64_t batch_stride_C1 = 0,
|
| 474 |
+
int64_t batch_stride_D1 = 0,
|
| 475 |
+
int64_t batch_stride_Bias0 = 0,
|
| 476 |
+
int64_t batch_stride_Scale0 = 0,
|
| 477 |
+
bool relu = true,
|
| 478 |
+
int warm_ups = 1,
|
| 479 |
+
int runs = 100) {
|
| 480 |
+
|
| 481 |
+
//
|
| 482 |
+
// Allocate the GEMM workspace
|
| 483 |
+
//
|
| 484 |
+
|
| 485 |
+
cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
|
| 486 |
+
cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
|
| 487 |
+
cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
|
| 488 |
+
cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
|
| 489 |
+
cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
|
| 490 |
+
|
| 491 |
+
cutlass::HostTensor<
|
| 492 |
+
typename B2bGemm::ElementA,
|
| 493 |
+
typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
|
| 494 |
+
|
| 495 |
+
cutlass::HostTensor<
|
| 496 |
+
typename B2bGemm::ElementB,
|
| 497 |
+
typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
|
| 498 |
+
|
| 499 |
+
cutlass::HostTensor<
|
| 500 |
+
typename B2bGemm::ElementC,
|
| 501 |
+
typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
|
| 502 |
+
|
| 503 |
+
cutlass::HostTensor<
|
| 504 |
+
typename B2bGemm::ElementScaleBias,
|
| 505 |
+
typename B2bGemm::LayoutScaleBias> tensor_Scale0;
|
| 506 |
+
|
| 507 |
+
if(alpha0 == ElementCompute(0)) //per-channel scale
|
| 508 |
+
tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
|
| 509 |
+
|
| 510 |
+
cutlass::HostTensor<
|
| 511 |
+
typename B2bGemm::ElementScaleBias,
|
| 512 |
+
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
|
| 513 |
+
|
| 514 |
+
cutlass::HostTensor<
|
| 515 |
+
ElementAccumulator,
|
| 516 |
+
typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
|
| 517 |
+
|
| 518 |
+
cutlass::HostTensor<
|
| 519 |
+
typename B2bGemm::ElementC,
|
| 520 |
+
typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
|
| 521 |
+
|
| 522 |
+
cutlass::HostTensor<
|
| 523 |
+
typename B2bGemm::ElementB,
|
| 524 |
+
typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
|
| 525 |
+
|
| 526 |
+
cutlass::HostTensor<
|
| 527 |
+
typename B2bGemm::ElementC,
|
| 528 |
+
typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
|
| 529 |
+
|
| 530 |
+
cutlass::HostTensor<
|
| 531 |
+
typename B2bGemm::ElementC,
|
| 532 |
+
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
|
| 533 |
+
|
| 534 |
+
cutlass::HostTensor<
|
| 535 |
+
typename B2bGemm::ElementC,
|
| 536 |
+
typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
|
| 537 |
+
|
| 538 |
+
cutlass::HostTensor<
|
| 539 |
+
typename B2bGemm::ElementC,
|
| 540 |
+
typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
| 544 |
+
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
|
| 545 |
+
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
|
| 546 |
+
if(alpha0 == ElementCompute(0)) //per-channel scale
|
| 547 |
+
CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014));
|
| 548 |
+
CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013));
|
| 549 |
+
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
|
| 550 |
+
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
|
| 551 |
+
CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012));
|
| 552 |
+
|
| 553 |
+
cutlass::reference::host::TensorFill(
|
| 554 |
+
tensor_D1.host_view());
|
| 555 |
+
cutlass::reference::host::TensorFill(
|
| 556 |
+
reference_D0.host_view());
|
| 557 |
+
cutlass::reference::host::TensorFill(
|
| 558 |
+
reference_D1.host_view());
|
| 559 |
+
|
| 560 |
+
tensor_A0.sync_device();
|
| 561 |
+
tensor_B0.sync_device();
|
| 562 |
+
tensor_C0.sync_device();
|
| 563 |
+
if(alpha0 == ElementCompute(0)) //per-channel scale
|
| 564 |
+
tensor_Scale0.sync_device();
|
| 565 |
+
tensor_Bias0.sync_device();
|
| 566 |
+
tensor_B1.sync_device();
|
| 567 |
+
tensor_C1.sync_device();
|
| 568 |
+
tensor_Bias1.sync_device();
|
| 569 |
+
tensor_D1.sync_device();
|
| 570 |
+
reference_D0.sync_device();
|
| 571 |
+
reference_D1.sync_device();
|
| 572 |
+
|
| 573 |
+
//
|
| 574 |
+
// Initialize the GEMM operator
|
| 575 |
+
//
|
| 576 |
+
|
| 577 |
+
typename B2bGemm::Arguments arguments{
|
| 578 |
+
mode,
|
| 579 |
+
problem_size_0,
|
| 580 |
+
problem_size_1,
|
| 581 |
+
tensor_A0.device_ref(),
|
| 582 |
+
tensor_B0.device_ref(),
|
| 583 |
+
tensor_C0.device_ref(),
|
| 584 |
+
tensor_Scale0.device_ref(),
|
| 585 |
+
tensor_Bias0.device_ref(),
|
| 586 |
+
tensor_B1.device_ref(),
|
| 587 |
+
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
| 588 |
+
tensor_D1.device_ref(),
|
| 589 |
+
batch_stride_A0,
|
| 590 |
+
batch_stride_B0,
|
| 591 |
+
batch_stride_B1,
|
| 592 |
+
batch_stride_C1,
|
| 593 |
+
batch_stride_D1,
|
| 594 |
+
batch_stride_Bias0,
|
| 595 |
+
batch_stride_Scale0,
|
| 596 |
+
{alpha0, beta0},
|
| 597 |
+
{alpha1, beta1},
|
| 598 |
+
batch_count,
|
| 599 |
+
};
|
| 600 |
+
|
| 601 |
+
B2bGemm b2b_gemm_op;
|
| 602 |
+
|
| 603 |
+
cutlass::Status status = b2b_gemm_op.can_implement(arguments);
|
| 604 |
+
|
| 605 |
+
if(status != cutlass::Status::kSuccess) {
|
| 606 |
+
std::cout << "Problem sizes not supported.\n"
|
| 607 |
+
<< "Requirments:\n"
|
| 608 |
+
<< " problem_size_0.M = problem_size_1.M\n"
|
| 609 |
+
<< " problem_size_0.N = problem_size_1.K\n"
|
| 610 |
+
<< " ThreadblockShape0::kN = problem_size_0.N\n"
|
| 611 |
+
<< " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
|
| 612 |
+
}
|
| 613 |
+
|
| 614 |
+
status = b2b_gemm_op.initialize(arguments);
|
| 615 |
+
|
| 616 |
+
CUTLASS_CHECK(status);
|
| 617 |
+
|
| 618 |
+
for(int i = 0; i < warm_ups; i++) {
|
| 619 |
+
status = b2b_gemm_op();
|
| 620 |
+
CUTLASS_CHECK(status);
|
| 621 |
+
}
|
| 622 |
+
|
| 623 |
+
//
|
| 624 |
+
// Run the GEMM
|
| 625 |
+
//
|
| 626 |
+
|
| 627 |
+
cudaEvent_t start, stop;
|
| 628 |
+
cudaEventCreate(&start);
|
| 629 |
+
cudaEventCreate(&stop);
|
| 630 |
+
|
| 631 |
+
cudaEventRecord(start);
|
| 632 |
+
|
| 633 |
+
for(int i = 0; i < runs; i++) {
|
| 634 |
+
status = b2b_gemm_op();
|
| 635 |
+
|
| 636 |
+
CUTLASS_CHECK(status);
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
cudaEventRecord(stop);
|
| 640 |
+
cudaDeviceSynchronize();
|
| 641 |
+
float gemmTime;
|
| 642 |
+
cudaEventElapsedTime(&gemmTime, start, stop);
|
| 643 |
+
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
|
| 644 |
+
|
| 645 |
+
tensor_D1.sync_host();
|
| 646 |
+
|
| 647 |
+
//
|
| 648 |
+
// Verify
|
| 649 |
+
//
|
| 650 |
+
|
| 651 |
+
cutlass::reference::device::GemmComplex<
|
| 652 |
+
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
| 653 |
+
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
| 654 |
+
ElementAccumulator, typename B2bGemm::LayoutC,
|
| 655 |
+
ElementAccumulator, ElementAccumulator
|
| 656 |
+
>(
|
| 657 |
+
|
| 658 |
+
problem_size_0,
|
| 659 |
+
ElementAccumulator(1), //intermediate alpha=1
|
| 660 |
+
tensor_A0.device_ref(),
|
| 661 |
+
cutlass::ComplexTransform::kNone,
|
| 662 |
+
tensor_B0.device_ref(),
|
| 663 |
+
cutlass::ComplexTransform::kNone,
|
| 664 |
+
ElementAccumulator(0), //beta = 0
|
| 665 |
+
reference_Z0.device_ref(),
|
| 666 |
+
reference_Z0.device_ref(),
|
| 667 |
+
ElementAccumulator(0),
|
| 668 |
+
int(batch_count),
|
| 669 |
+
batch_stride_A0,
|
| 670 |
+
batch_stride_B0,
|
| 671 |
+
batch_stride_C0,
|
| 672 |
+
batch_stride_C0
|
| 673 |
+
);
|
| 674 |
+
|
| 675 |
+
cutlass::reference::device::TensorScaleBiasGemmBatched<
|
| 676 |
+
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
| 677 |
+
ElementCompute, typename B2bGemm::LayoutScaleBias
|
| 678 |
+
> (
|
| 679 |
+
problem_size_0,
|
| 680 |
+
reference_Z0.device_ref(),
|
| 681 |
+
reference_D0.device_ref(),
|
| 682 |
+
alpha0,
|
| 683 |
+
tensor_Scale0.device_ref(),
|
| 684 |
+
tensor_Bias0.device_ref(),
|
| 685 |
+
int(batch_count),
|
| 686 |
+
batch_stride_C0,
|
| 687 |
+
batch_stride_C0,
|
| 688 |
+
batch_stride_Scale0,
|
| 689 |
+
batch_stride_Bias0
|
| 690 |
+
);
|
| 691 |
+
|
| 692 |
+
if(relu) {
|
| 693 |
+
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
| 694 |
+
}
|
| 695 |
+
|
| 696 |
+
cutlass::reference::device::GemmComplex<
|
| 697 |
+
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
| 698 |
+
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
| 699 |
+
typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
| 700 |
+
ElementCompute, ElementAccumulator
|
| 701 |
+
>(
|
| 702 |
+
problem_size_1,
|
| 703 |
+
alpha1, //intermediate alpha=1
|
| 704 |
+
reference_D0.device_ref(),
|
| 705 |
+
cutlass::ComplexTransform::kNone,
|
| 706 |
+
tensor_B1.device_ref(),
|
| 707 |
+
cutlass::ComplexTransform::kNone,
|
| 708 |
+
beta1, //beta = 0
|
| 709 |
+
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
| 710 |
+
reference_D1.device_ref(),
|
| 711 |
+
ElementAccumulator(0),
|
| 712 |
+
int(batch_count),
|
| 713 |
+
batch_stride_C0,
|
| 714 |
+
batch_stride_B1,
|
| 715 |
+
batch_stride_C1,
|
| 716 |
+
batch_stride_D1
|
| 717 |
+
);
|
| 718 |
+
|
| 719 |
+
if(relu) {
|
| 720 |
+
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
| 721 |
+
}
|
| 722 |
+
|
| 723 |
+
cudaDeviceSynchronize();
|
| 724 |
+
reference_D0.sync_host();
|
| 725 |
+
reference_D1.sync_host();
|
| 726 |
+
|
| 727 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
|
| 728 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
|
| 729 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
| 730 |
+
|
| 731 |
+
bool passed = cutlass::reference::host::TensorEquals(
|
| 732 |
+
reference_D1.host_view(),
|
| 733 |
+
tensor_D1.host_view());
|
| 734 |
+
|
| 735 |
+
CHECK_TRUE(passed);
|
| 736 |
+
if (!passed)
|
| 737 |
+
{
|
| 738 |
+
|
| 739 |
+
std::stringstream fname;
|
| 740 |
+
|
| 741 |
+
fname << "error_B2bGemm_device_fused.txt";
|
| 742 |
+
std::cerr << "Dumping results in " << fname.str() << "\n";
|
| 743 |
+
|
| 744 |
+
std::ofstream file(fname.str());
|
| 745 |
+
|
| 746 |
+
file
|
| 747 |
+
<< "A0 =\n" << tensor_A0.host_view()
|
| 748 |
+
<< "\nB0 =\n" << tensor_B0.host_view()
|
| 749 |
+
<< "\nC0 =\n" << tensor_C0.host_view()
|
| 750 |
+
<< "\nScale0:\n" << tensor_Scale0.host_view() << "\n"
|
| 751 |
+
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
|
| 752 |
+
<< "\nB1 =\n" << tensor_B1.host_view()
|
| 753 |
+
<< "\nC1 =\n" << tensor_C1.host_view()
|
| 754 |
+
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
|
| 755 |
+
<< "\n\nReference =\n" << reference_D1.host_view()
|
| 756 |
+
<< "\nComputed =\n" << tensor_D1.host_view();
|
| 757 |
+
}
|
| 758 |
+
return passed;
|
| 759 |
+
}
|
| 760 |
+
|
| 761 |
+
};
|
| 762 |
+
|
| 763 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h
ADDED
|
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Containers for running grouped back-to-back GEMMs
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include <iostream>
|
| 38 |
+
#include <fstream>
|
| 39 |
+
#include <sstream>
|
| 40 |
+
|
| 41 |
+
#include "cutlass/util/device_memory.h"
|
| 42 |
+
#include "cutlass/util/host_tensor.h"
|
| 43 |
+
#include "cutlass/util/tensor_view_io.h"
|
| 44 |
+
#include "cutlass/util/distribution.h"
|
| 45 |
+
#include "cutlass/util/reference/host/tensor_fill.h"
|
| 46 |
+
#include "cutlass/util/reference/host/tensor_copy.h"
|
| 47 |
+
#include "cutlass/util/reference/host/tensor_compare.h"
|
| 48 |
+
#include "cutlass/util/reference/host/tensor_norm.h"
|
| 49 |
+
#include "cutlass/util/reference/device/gemm.h"
|
| 50 |
+
#include "cutlass/util/reference/device/tensor_relu.h"
|
| 51 |
+
|
| 52 |
+
#include "reference/device/tensor_scale_bias.h"
|
| 53 |
+
#include "helper.h"
|
| 54 |
+
|
| 55 |
+
#define CHECK_GT(val1, val2) \
|
| 56 |
+
if((val1) <= (val2)) \
|
| 57 |
+
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
|
| 58 |
+
#define CHECK_TRUE(val) \
|
| 59 |
+
if(!(val)) \
|
| 60 |
+
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
|
| 61 |
+
|
| 62 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 63 |
+
|
| 64 |
+
template <typename B2bGemm_>
|
| 65 |
+
struct B2bFusedGroupedGemmRun
|
| 66 |
+
{
|
| 67 |
+
|
| 68 |
+
using B2bGemm = B2bGemm_;
|
| 69 |
+
using ElementAccumulator = typename B2bGemm::ElementAccumulator;
|
| 70 |
+
using ElementCompute = typename B2bGemm::BaseKernel::Epilogue::OutputOp::ElementCompute;
|
| 71 |
+
|
| 72 |
+
/// Initialization
|
| 73 |
+
cutlass::Distribution::Kind init_A;
|
| 74 |
+
cutlass::Distribution::Kind init_B;
|
| 75 |
+
cutlass::Distribution::Kind init_C;
|
| 76 |
+
cutlass::Distribution::Kind init_Scale;
|
| 77 |
+
cutlass::Distribution::Kind init_Bias;
|
| 78 |
+
uint64_t seed;
|
| 79 |
+
|
| 80 |
+
//
|
| 81 |
+
// Methods
|
| 82 |
+
//
|
| 83 |
+
|
| 84 |
+
B2bFusedGroupedGemmRun(
|
| 85 |
+
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
| 86 |
+
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
| 87 |
+
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
| 88 |
+
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
| 89 |
+
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
| 90 |
+
uint64_t seed_ = 2080
|
| 91 |
+
):
|
| 92 |
+
init_A(init_A_), init_B(init_B_), init_C(init_C_),
|
| 93 |
+
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }
|
| 94 |
+
|
| 95 |
+
/// Helper to initialize a tensor view
|
| 96 |
+
template <typename Element, typename Layout>
|
| 97 |
+
bool initialize_tensor(
|
| 98 |
+
cutlass::TensorView<Element, Layout> view,
|
| 99 |
+
cutlass::Distribution::Kind dist_kind,
|
| 100 |
+
uint64_t seed) {
|
| 101 |
+
|
| 102 |
+
if (dist_kind == cutlass::Distribution::Uniform) {
|
| 103 |
+
|
| 104 |
+
cutlass::reference::host::TensorFillRandomUniform(
|
| 105 |
+
view, seed, 1, -1, 0);
|
| 106 |
+
}
|
| 107 |
+
else if (dist_kind == cutlass::Distribution::Identity) {
|
| 108 |
+
|
| 109 |
+
cutlass::reference::host::TensorFillIdentity(view);
|
| 110 |
+
}
|
| 111 |
+
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
| 112 |
+
|
| 113 |
+
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
| 114 |
+
}
|
| 115 |
+
else if (dist_kind == cutlass::Distribution::Sequential) {
|
| 116 |
+
|
| 117 |
+
cutlass::reference::host::BlockFillSequential(
|
| 118 |
+
view.data(), view.capacity());
|
| 119 |
+
}
|
| 120 |
+
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
| 121 |
+
cutlass::reference::host::TensorFill(view, Element(0));
|
| 122 |
+
}
|
| 123 |
+
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
| 124 |
+
cutlass::reference::host::TensorFill(view, Element(1));
|
| 125 |
+
}
|
| 126 |
+
else {
|
| 127 |
+
std::cerr << "Not implemented\n";
|
| 128 |
+
return false;
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
return true;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
/// Executes one test
|
| 135 |
+
bool run(
|
| 136 |
+
std::vector<cutlass::gemm::GemmCoord> problem_sizes_0,
|
| 137 |
+
std::vector<cutlass::gemm::GemmCoord> problem_sizes_1,
|
| 138 |
+
ElementCompute alpha0 = ElementCompute(1),
|
| 139 |
+
ElementCompute beta0 = ElementCompute(0),
|
| 140 |
+
ElementCompute alpha1 = ElementCompute(1),
|
| 141 |
+
ElementCompute beta1 = ElementCompute(0),
|
| 142 |
+
bool relu = true,
|
| 143 |
+
int warm_ups = 1,
|
| 144 |
+
int runs = 100) {
|
| 145 |
+
|
| 146 |
+
using HostTensorA = cutlass::HostTensor<typename B2bGemm::ElementA, typename B2bGemm::LayoutA>;
|
| 147 |
+
using HostTensorB = cutlass::HostTensor<typename B2bGemm::ElementB, typename B2bGemm::LayoutB>;
|
| 148 |
+
using HostTensorC = cutlass::HostTensor<typename B2bGemm::ElementC, typename B2bGemm::LayoutC>;
|
| 149 |
+
using HostTensorScale = cutlass::HostTensor<ElementCompute, typename B2bGemm::LayoutC>;
|
| 150 |
+
using HostTensorZ = cutlass::HostTensor<ElementAccumulator, typename B2bGemm::LayoutC>;
|
| 151 |
+
using HostTensorBias = cutlass::HostTensor<ElementCompute, typename B2bGemm::LayoutC>;
|
| 152 |
+
|
| 153 |
+
int problem_count = (int)problem_sizes_0.size();
|
| 154 |
+
|
| 155 |
+
std::vector<HostTensorA> host_tensor_A0(problem_count);
|
| 156 |
+
std::vector<HostTensorB> host_tensor_B0(problem_count);
|
| 157 |
+
std::vector<HostTensorC> host_tensor_C0(problem_count);
|
| 158 |
+
std::vector<HostTensorScale> host_tensor_Scale0(problem_count);
|
| 159 |
+
std::vector<HostTensorScale> host_tensor_Bias0(problem_count);
|
| 160 |
+
std::vector<HostTensorB> host_tensor_B1(problem_count);
|
| 161 |
+
std::vector<HostTensorC> host_tensor_C1(problem_count);
|
| 162 |
+
std::vector<HostTensorBias> host_tensor_Bias1(problem_count);
|
| 163 |
+
std::vector<HostTensorC> host_tensor_D1(problem_count);
|
| 164 |
+
std::vector<HostTensorZ> host_tensor_Z(problem_count);
|
| 165 |
+
std::vector<HostTensorC> host_tensor_ref_D0(problem_count);
|
| 166 |
+
std::vector<HostTensorC> host_tensor_ref_D1(problem_count);
|
| 167 |
+
|
| 168 |
+
std::vector<typename HostTensorA::TensorRef> ref_A0(problem_count);
|
| 169 |
+
std::vector<typename HostTensorB::TensorRef> ref_B0(problem_count);
|
| 170 |
+
std::vector<typename HostTensorC::TensorRef> ref_C0(problem_count);
|
| 171 |
+
std::vector<typename HostTensorScale::TensorRef> ref_Scale0(problem_count);
|
| 172 |
+
std::vector<typename HostTensorScale::TensorRef> ref_Bias0(problem_count);
|
| 173 |
+
std::vector<typename HostTensorB::TensorRef> ref_B1(problem_count);
|
| 174 |
+
std::vector<typename HostTensorC::TensorRef> ref_C1(problem_count);
|
| 175 |
+
std::vector<typename HostTensorBias::TensorRef> ref_Bias1(problem_count);
|
| 176 |
+
std::vector<typename HostTensorC::TensorRef> ref_D1(problem_count);
|
| 177 |
+
std::vector<typename HostTensorZ::TensorRef> ref_Z(problem_count);
|
| 178 |
+
std::vector<typename HostTensorC::TensorRef> ref_ref_D0(problem_count);
|
| 179 |
+
std::vector<typename HostTensorC::TensorRef> ref_ref_D1(problem_count);
|
| 180 |
+
|
| 181 |
+
for (int i = 0; i < problem_count; ++i) {
|
| 182 |
+
//
|
| 183 |
+
// Allocate the GEMM workspace
|
| 184 |
+
//
|
| 185 |
+
|
| 186 |
+
auto problem_size_0 = problem_sizes_0[i];
|
| 187 |
+
auto problem_size_1 = problem_sizes_1[i];
|
| 188 |
+
|
| 189 |
+
host_tensor_A0.at(i) = HostTensorA(problem_size_0.mk());
|
| 190 |
+
host_tensor_B0.at(i) = HostTensorB(problem_size_0.kn());
|
| 191 |
+
host_tensor_C0.at(i) = HostTensorC(problem_size_0.mn());
|
| 192 |
+
if (alpha0 == ElementCompute(0)) //per-channel scale
|
| 193 |
+
host_tensor_Scale0.at(i) = HostTensorScale(typename HostTensorZ::Layout::TensorCoord{1, problem_size_0.n()});
|
| 194 |
+
host_tensor_Bias0.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_0.n()});
|
| 195 |
+
host_tensor_Z.at(i) = HostTensorZ(problem_size_0.mn());
|
| 196 |
+
host_tensor_ref_D0.at(i) = HostTensorC(problem_size_0.mn());
|
| 197 |
+
host_tensor_B1.at(i) = HostTensorB(problem_size_1.kn());
|
| 198 |
+
host_tensor_C1.at(i) = HostTensorC(problem_size_1.mn());
|
| 199 |
+
host_tensor_Bias1.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_1.n()});
|
| 200 |
+
host_tensor_D1.at(i) = HostTensorC(problem_size_1.mn());
|
| 201 |
+
host_tensor_ref_D1.at(i) = HostTensorC(problem_size_1.mn());
|
| 202 |
+
|
| 203 |
+
CHECK_TRUE(initialize_tensor(host_tensor_A0.at(i).host_view(), init_A, seed + 2019));
|
| 204 |
+
CHECK_TRUE(initialize_tensor(host_tensor_B0.at(i).host_view(), init_B, seed + 2018));
|
| 205 |
+
CHECK_TRUE(initialize_tensor(host_tensor_C0.at(i).host_view(), init_C, seed + 2017));
|
| 206 |
+
if (alpha0 == ElementCompute(0)) //per-channel scale
|
| 207 |
+
CHECK_TRUE(initialize_tensor(host_tensor_Scale0.at(i).host_view(), init_Scale, seed + 2014));
|
| 208 |
+
CHECK_TRUE(initialize_tensor(host_tensor_Bias0.at(i).host_view(), init_Bias, seed + 2013));
|
| 209 |
+
CHECK_TRUE(initialize_tensor(host_tensor_B1.at(i).host_view(), init_B, seed + 2016));
|
| 210 |
+
CHECK_TRUE(initialize_tensor(host_tensor_C1.at(i).host_view(), init_C, seed + 2015));
|
| 211 |
+
CHECK_TRUE(initialize_tensor(host_tensor_Bias1.at(i).host_view(), init_Bias, seed + 2012));
|
| 212 |
+
|
| 213 |
+
cutlass::reference::host::TensorFill(
|
| 214 |
+
host_tensor_D1.at(i).host_view());
|
| 215 |
+
cutlass::reference::host::TensorFill(
|
| 216 |
+
host_tensor_ref_D0.at(i).host_view());
|
| 217 |
+
cutlass::reference::host::TensorFill(
|
| 218 |
+
host_tensor_ref_D1.at(i).host_view());
|
| 219 |
+
|
| 220 |
+
host_tensor_A0.at(i).sync_device();
|
| 221 |
+
host_tensor_B0.at(i).sync_device();
|
| 222 |
+
host_tensor_C0.at(i).sync_device();
|
| 223 |
+
if (alpha0 == ElementCompute(0)) //per-channel scale
|
| 224 |
+
host_tensor_Scale0.at(i).sync_device();
|
| 225 |
+
host_tensor_Bias0.at(i).sync_device();
|
| 226 |
+
host_tensor_B1.at(i).sync_device();
|
| 227 |
+
host_tensor_C1.at(i).sync_device();
|
| 228 |
+
host_tensor_Bias1.at(i).sync_device();
|
| 229 |
+
host_tensor_D1.at(i).sync_device();
|
| 230 |
+
host_tensor_ref_D0.at(i).sync_device();
|
| 231 |
+
host_tensor_ref_D1.at(i).sync_device();
|
| 232 |
+
|
| 233 |
+
ref_A0.at(i) = (host_tensor_A0.at(i).device_ref());
|
| 234 |
+
ref_B0.at(i) = (host_tensor_B0.at(i).device_ref());
|
| 235 |
+
ref_C0.at(i) = (host_tensor_C0.at(i).device_ref());
|
| 236 |
+
if (alpha0 == ElementCompute(0)) //per-channel scale
|
| 237 |
+
ref_Scale0.at(i) = (host_tensor_Scale0.at(i).device_ref());
|
| 238 |
+
ref_Bias0.at(i) = (host_tensor_Bias0.at(i).device_ref());
|
| 239 |
+
ref_B1.at(i) = (host_tensor_B1.at(i).device_ref());
|
| 240 |
+
ref_C1.at(i) = {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)};
|
| 241 |
+
ref_Bias1.at(i) = (host_tensor_Bias1.at(i).device_ref());
|
| 242 |
+
ref_D1.at(i) = (host_tensor_D1.at(i).device_ref());
|
| 243 |
+
ref_Z.at(i) = (host_tensor_Z.at(i).device_ref());
|
| 244 |
+
ref_ref_D0.at(i) = (host_tensor_ref_D0.at(i).device_ref());
|
| 245 |
+
ref_ref_D1.at(i) = (host_tensor_ref_D1.at(i).device_ref());
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
//
|
| 249 |
+
// Initialize the GEMM operator
|
| 250 |
+
//
|
| 251 |
+
|
| 252 |
+
cutlass::DeviceAllocation<typename HostTensorA::TensorRef> device_ref_A0(problem_count);
|
| 253 |
+
device_ref_A0.copy_from_host(ref_A0.data());
|
| 254 |
+
cutlass::DeviceAllocation<typename HostTensorB::TensorRef> device_ref_B0(problem_count);
|
| 255 |
+
device_ref_B0.copy_from_host(ref_B0.data());
|
| 256 |
+
cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_C0(problem_count);
|
| 257 |
+
device_ref_C0.copy_from_host(ref_C0.data());
|
| 258 |
+
cutlass::DeviceAllocation<typename HostTensorScale::TensorRef> device_ref_Scale0(problem_count);
|
| 259 |
+
device_ref_Scale0.copy_from_host(ref_Scale0.data());
|
| 260 |
+
cutlass::DeviceAllocation<typename HostTensorScale::TensorRef> device_ref_Bias0(problem_count);
|
| 261 |
+
device_ref_Bias0.copy_from_host(ref_Bias0.data());
|
| 262 |
+
cutlass::DeviceAllocation<typename HostTensorB::TensorRef> device_ref_B1(problem_count);
|
| 263 |
+
device_ref_B1.copy_from_host(ref_B1.data());
|
| 264 |
+
cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_C1(problem_count);
|
| 265 |
+
device_ref_C1.copy_from_host(ref_C1.data());
|
| 266 |
+
cutlass::DeviceAllocation<typename HostTensorBias::TensorRef> device_ref_Bias1(problem_count);
|
| 267 |
+
device_ref_Bias1.copy_from_host(ref_Bias1.data());
|
| 268 |
+
cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_D1(problem_count);
|
| 269 |
+
device_ref_D1.copy_from_host(ref_D1.data());
|
| 270 |
+
|
| 271 |
+
cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes_0(problem_count);
|
| 272 |
+
device_problem_sizes_0.copy_from_host(problem_sizes_0.data());
|
| 273 |
+
cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes_1(problem_count);
|
| 274 |
+
device_problem_sizes_1.copy_from_host(problem_sizes_1.data());
|
| 275 |
+
|
| 276 |
+
B2bGemm b2b_gemm_op;
|
| 277 |
+
|
| 278 |
+
int threadblock_count = B2bGemm::sufficient(problem_sizes_1.data(), problem_count);
|
| 279 |
+
if (!threadblock_count) {
|
| 280 |
+
std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl;
|
| 281 |
+
return false;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
typename B2bGemm::Arguments arguments{
|
| 285 |
+
problem_count,
|
| 286 |
+
device_problem_sizes_0.get(),
|
| 287 |
+
device_problem_sizes_1.get(),
|
| 288 |
+
device_ref_A0.get(),
|
| 289 |
+
device_ref_B0.get(),
|
| 290 |
+
device_ref_C0.get(),
|
| 291 |
+
device_ref_Scale0.get(),
|
| 292 |
+
device_ref_Bias0.get(),
|
| 293 |
+
device_ref_B1.get(),
|
| 294 |
+
device_ref_C1.get(),
|
| 295 |
+
device_ref_D1.get(),
|
| 296 |
+
{alpha0, beta0},
|
| 297 |
+
{alpha1, beta1},
|
| 298 |
+
threadblock_count
|
| 299 |
+
};
|
| 300 |
+
|
| 301 |
+
cutlass::Status status = b2b_gemm_op.can_implement(arguments);
|
| 302 |
+
|
| 303 |
+
if(status != cutlass::Status::kSuccess) {
|
| 304 |
+
std::cout << "Problem sizes not supported.\n"
|
| 305 |
+
<< "Requirments:\n"
|
| 306 |
+
<< " problem_size_0.M = problem_size_1.M\n"
|
| 307 |
+
<< " problem_size_0.N = problem_size_1.K\n"
|
| 308 |
+
<< " ThreadblockShape0::kN = problem_size_0.N\n"
|
| 309 |
+
<< " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
status = b2b_gemm_op.initialize(arguments);
|
| 313 |
+
|
| 314 |
+
CUTLASS_CHECK(status);
|
| 315 |
+
|
| 316 |
+
for(int i = 0; i < warm_ups; i++) {
|
| 317 |
+
status = b2b_gemm_op();
|
| 318 |
+
CUTLASS_CHECK(status);
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
//
|
| 322 |
+
// Run the GEMM
|
| 323 |
+
//
|
| 324 |
+
|
| 325 |
+
cudaEvent_t start, stop;
|
| 326 |
+
cudaEventCreate(&start);
|
| 327 |
+
cudaEventCreate(&stop);
|
| 328 |
+
|
| 329 |
+
cudaEventRecord(start);
|
| 330 |
+
|
| 331 |
+
for(int i = 0; i < runs; i++) {
|
| 332 |
+
status = b2b_gemm_op();
|
| 333 |
+
CUTLASS_CHECK(status);
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
cudaEventRecord(stop);
|
| 337 |
+
cudaDeviceSynchronize();
|
| 338 |
+
float gemmTime;
|
| 339 |
+
cudaEventElapsedTime(&gemmTime, start, stop);
|
| 340 |
+
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
|
| 341 |
+
|
| 342 |
+
for (int i = 0; i < problem_count; ++i) {
|
| 343 |
+
host_tensor_D1.at(i).sync_host();
|
| 344 |
+
|
| 345 |
+
//
|
| 346 |
+
// Verify
|
| 347 |
+
//
|
| 348 |
+
|
| 349 |
+
cutlass::reference::device::Gemm<
|
| 350 |
+
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
| 351 |
+
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
| 352 |
+
ElementAccumulator, typename B2bGemm::LayoutC,
|
| 353 |
+
ElementAccumulator, ElementAccumulator>
|
| 354 |
+
reference_gemm_0;
|
| 355 |
+
|
| 356 |
+
cutlass::reference::device::Gemm<
|
| 357 |
+
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
| 358 |
+
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
| 359 |
+
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
|
| 360 |
+
ElementAccumulator>
|
| 361 |
+
reference_gemm_1;
|
| 362 |
+
|
| 363 |
+
auto problem_size_0 = problem_sizes_0[i];
|
| 364 |
+
auto problem_size_1 = problem_sizes_1[i];
|
| 365 |
+
|
| 366 |
+
reference_gemm_0(
|
| 367 |
+
problem_size_0,
|
| 368 |
+
ElementAccumulator(1), //intermediate alpha=1
|
| 369 |
+
ref_A0.at(i),
|
| 370 |
+
ref_B0.at(i),
|
| 371 |
+
ElementAccumulator(0), //beta = 0
|
| 372 |
+
ref_Z.at(i),
|
| 373 |
+
ref_Z.at(i),
|
| 374 |
+
ElementAccumulator(0)
|
| 375 |
+
);
|
| 376 |
+
|
| 377 |
+
cutlass::reference::device::TensorScaleBiasGemm<
|
| 378 |
+
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
| 379 |
+
ElementCompute, typename B2bGemm::LayoutC
|
| 380 |
+
> (
|
| 381 |
+
problem_size_0,
|
| 382 |
+
ref_Z.at(i),
|
| 383 |
+
ref_ref_D0.at(i),
|
| 384 |
+
alpha0,
|
| 385 |
+
ref_Scale0.at(i),
|
| 386 |
+
ref_Bias0.at(i)
|
| 387 |
+
);
|
| 388 |
+
|
| 389 |
+
if(relu) {
|
| 390 |
+
cutlass::reference::device::TensorReLu(host_tensor_ref_D0.at(i).device_view());
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
reference_gemm_1(
|
| 394 |
+
problem_size_1,
|
| 395 |
+
alpha1,
|
| 396 |
+
ref_ref_D0.at(i),
|
| 397 |
+
ref_B1.at(i),
|
| 398 |
+
beta1,
|
| 399 |
+
{host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
| 400 |
+
ref_ref_D1.at(i)
|
| 401 |
+
);
|
| 402 |
+
if(relu) {
|
| 403 |
+
cutlass::reference::device::TensorReLu(host_tensor_ref_D1.at(i).device_view());
|
| 404 |
+
}
|
| 405 |
+
cudaDeviceSynchronize();
|
| 406 |
+
host_tensor_ref_D0.at(i).sync_host();
|
| 407 |
+
host_tensor_ref_D1.at(i).sync_host();
|
| 408 |
+
|
| 409 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D0.at(i).host_view()), 0);
|
| 410 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_D1.at(i).host_view()), 0);
|
| 411 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D1.at(i).host_view()), 0);
|
| 412 |
+
|
| 413 |
+
bool passed = cutlass::reference::host::TensorEquals(
|
| 414 |
+
host_tensor_ref_D1.at(i).host_view(),
|
| 415 |
+
host_tensor_D1.at(i).host_view());
|
| 416 |
+
|
| 417 |
+
CHECK_TRUE(passed);
|
| 418 |
+
if (!passed)
|
| 419 |
+
{
|
| 420 |
+
|
| 421 |
+
std::stringstream fname;
|
| 422 |
+
|
| 423 |
+
fname << "error_B2bGemm_device_fused.txt";
|
| 424 |
+
std::cerr << "Check failed for GEMM " << i << " in the group." << std::endl;
|
| 425 |
+
std::cerr << "Dumping results in " << fname.str() << "\n";
|
| 426 |
+
|
| 427 |
+
std::ofstream file(fname.str());
|
| 428 |
+
|
| 429 |
+
file
|
| 430 |
+
<< "GEMM " << i << " in group\n"
|
| 431 |
+
<< "A0 =\n" << host_tensor_A0.at(i).host_view()
|
| 432 |
+
<< "\nB0 =\n" << host_tensor_B0.at(i).host_view()
|
| 433 |
+
<< "\nC0 =\n" << host_tensor_C0.at(i).host_view()
|
| 434 |
+
<< "\nScale0:\n" << host_tensor_Scale0.at(i).host_view() << "\n"
|
| 435 |
+
<< "\nBias0:\n" << host_tensor_Bias0.at(i).host_view() << "\n"
|
| 436 |
+
<< "\nB1 =\n" << host_tensor_B1.at(i).host_view()
|
| 437 |
+
<< "\nC1 =\n" << host_tensor_C1.at(i).host_view()
|
| 438 |
+
<< "\nBias1:\n" << host_tensor_Bias1.at(i).host_view() << "\n"
|
| 439 |
+
<< "\n\nReference =\n" << host_tensor_ref_D1.at(i).host_view()
|
| 440 |
+
<< "\nComputed =\n" << host_tensor_D1.at(i).host_view();
|
| 441 |
+
|
| 442 |
+
return false;
|
| 443 |
+
}
|
| 444 |
+
}
|
| 445 |
+
return true;
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
};
|
| 449 |
+
|
| 450 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h
ADDED
|
@@ -0,0 +1,749 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
#pragma once
|
| 33 |
+
|
| 34 |
+
#include <iostream>
|
| 35 |
+
#include <fstream>
|
| 36 |
+
#include <sstream>
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
|
| 40 |
+
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
| 41 |
+
#include "cutlass/reduction/device/reduce_split_k.h"
|
| 42 |
+
#include "cutlass/reduction/thread/reduction_operators.h"
|
| 43 |
+
|
| 44 |
+
#include "cutlass/util/host_tensor.h"
|
| 45 |
+
#include "cutlass/util/reference/host/tensor_fill.h"
|
| 46 |
+
#include "cutlass/util/reference/device/tensor_compare.h"
|
| 47 |
+
#include "cutlass/util/reference/host/tensor_compare.h"
|
| 48 |
+
#include "cutlass/util/reference/host/tensor_norm.h"
|
| 49 |
+
#include "cutlass/util/host_reorder.h"
|
| 50 |
+
|
| 51 |
+
#include "cutlass/util/reference/host/convolution.h"
|
| 52 |
+
#include "cutlass/util/reference/device/convolution.h"
|
| 53 |
+
#include "cutlass/util/reference/device/tensor_relu.h"
|
| 54 |
+
|
| 55 |
+
#include "cutlass/core_io.h"
|
| 56 |
+
#include "cutlass/util/tensor_view_io.h"
|
| 57 |
+
|
| 58 |
+
#include "reference/device/tensor_scale_bias.h"
|
| 59 |
+
#include "helper.h"
|
| 60 |
+
|
| 61 |
+
#define CHECK_GT(val1, val2) \
|
| 62 |
+
if((val1) <= (val2)) \
|
| 63 |
+
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
|
| 64 |
+
#define CHECK_TRUE(val) \
|
| 65 |
+
if(!(val)) \
|
| 66 |
+
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
template <typename Conv2d0_, typename Conv2d1_, int InterleavedK>
|
| 70 |
+
class B2bInterleavedNonFusedConv2dRun {
|
| 71 |
+
public:
|
| 72 |
+
|
| 73 |
+
using Conv2d0 = Conv2d0_;
|
| 74 |
+
using Conv2d1 = Conv2d1_;
|
| 75 |
+
using ElementAccumulator = typename Conv2d0::ElementAccumulator;
|
| 76 |
+
using ElementCompute = typename Conv2d0::ElementCompute;
|
| 77 |
+
|
| 78 |
+
static cutlass::conv::Operator const kConvolutionalOperator = Conv2d0::kConvolutionalOperator;
|
| 79 |
+
static_assert(kConvolutionalOperator == Conv2d1::kConvolutionalOperator,
|
| 80 |
+
"Fused convolution operators must be the same");
|
| 81 |
+
|
| 82 |
+
public:
|
| 83 |
+
|
| 84 |
+
/// Initialization
|
| 85 |
+
cutlass::Distribution::Kind init_A;
|
| 86 |
+
cutlass::Distribution::Kind init_B;
|
| 87 |
+
cutlass::Distribution::Kind init_C;
|
| 88 |
+
cutlass::Distribution::Kind init_Bias;
|
| 89 |
+
uint64_t seed;
|
| 90 |
+
|
| 91 |
+
cutlass::HostTensor<typename Conv2d0::ElementA, typename Conv2d0::LayoutA> tensor_A0;
|
| 92 |
+
cutlass::HostTensor<typename Conv2d0::ElementB, typename Conv2d0::LayoutB> tensor_B0;
|
| 93 |
+
cutlass::HostTensor<typename Conv2d0::ElementB, typename Conv2d0::LayoutB> tensor_B0_reordered;
|
| 94 |
+
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_C0;
|
| 95 |
+
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_Bias0;
|
| 96 |
+
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_computed;
|
| 97 |
+
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_reference;
|
| 98 |
+
|
| 99 |
+
cutlass::HostTensor<typename Conv2d1::ElementB, typename Conv2d1::LayoutB> tensor_B1;
|
| 100 |
+
cutlass::HostTensor<typename Conv2d1::ElementB, typename Conv2d1::LayoutB> tensor_B1_reordered;
|
| 101 |
+
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_C1;
|
| 102 |
+
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d0::LayoutC> tensor_Bias1;
|
| 103 |
+
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_computed;
|
| 104 |
+
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_reference;
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
public:
|
| 108 |
+
|
| 109 |
+
B2bInterleavedNonFusedConv2dRun(
|
| 110 |
+
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
| 111 |
+
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
| 112 |
+
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
| 113 |
+
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
| 114 |
+
uint64_t seed_ = 2080
|
| 115 |
+
):
|
| 116 |
+
init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) {
|
| 117 |
+
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
/// Helper to initialize a tensor view
|
| 121 |
+
template <typename Element, typename Layout>
|
| 122 |
+
void initialize_tensor(
|
| 123 |
+
cutlass::TensorView<Element, Layout> view,
|
| 124 |
+
cutlass::Distribution::Kind dist_kind,
|
| 125 |
+
uint64_t seed) {
|
| 126 |
+
|
| 127 |
+
if (dist_kind == cutlass::Distribution::Uniform) {
|
| 128 |
+
|
| 129 |
+
int scope;
|
| 130 |
+
int bits = cutlass::sizeof_bits<Element>::value;
|
| 131 |
+
|
| 132 |
+
if (bits <= 16) {
|
| 133 |
+
scope = 2;
|
| 134 |
+
}
|
| 135 |
+
else {
|
| 136 |
+
scope = 8;
|
| 137 |
+
}
|
| 138 |
+
cutlass::reference::host::TensorFillRandomUniform(
|
| 139 |
+
view, seed, scope, -scope, 0);
|
| 140 |
+
}
|
| 141 |
+
else if (dist_kind == cutlass::Distribution::Identity) {
|
| 142 |
+
|
| 143 |
+
cutlass::reference::host::TensorFillIdentity(view);
|
| 144 |
+
}
|
| 145 |
+
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
| 146 |
+
|
| 147 |
+
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
| 148 |
+
}
|
| 149 |
+
else if (dist_kind == cutlass::Distribution::Sequential) {
|
| 150 |
+
|
| 151 |
+
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
| 152 |
+
}
|
| 153 |
+
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
| 154 |
+
cutlass::reference::host::TensorFill(view, Element(0));
|
| 155 |
+
}
|
| 156 |
+
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
| 157 |
+
cutlass::reference::host::TensorFill(view, Element(1));
|
| 158 |
+
}
|
| 159 |
+
else {
|
| 160 |
+
}
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
void initialize(
|
| 164 |
+
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
| 165 |
+
cutlass::conv::Conv2dProblemSize const &problem_size_1, uint64_t seed = 2019) {
|
| 166 |
+
|
| 167 |
+
tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0));
|
| 168 |
+
tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
|
| 169 |
+
tensor_B0_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
|
| 170 |
+
tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
| 171 |
+
tensor_Bias0.resize({1, 1, 1, problem_size_0.K});
|
| 172 |
+
tensor_D0_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
| 173 |
+
tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
| 174 |
+
tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
|
| 175 |
+
tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
|
| 176 |
+
tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
| 177 |
+
tensor_Bias1.resize({1, 1, 1, problem_size_1.K});
|
| 178 |
+
tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
| 179 |
+
tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
| 180 |
+
|
| 181 |
+
initialize_tensor(tensor_A0.host_view(), init_A, seed);
|
| 182 |
+
initialize_tensor(tensor_B0.host_view(), init_B, seed * 17);
|
| 183 |
+
initialize_tensor(tensor_C0.host_view(), init_C, seed * 39);
|
| 184 |
+
initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83);
|
| 185 |
+
initialize_tensor(tensor_B1.host_view(), init_B, seed * 18);
|
| 186 |
+
initialize_tensor(tensor_C1.host_view(), init_C, seed * 40);
|
| 187 |
+
|
| 188 |
+
//Reorder B0 and B1
|
| 189 |
+
cutlass::reorder_convK<InterleavedK, InterleavedK>(
|
| 190 |
+
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_0));
|
| 191 |
+
cutlass::reorder_convK<InterleavedK, InterleavedK>(
|
| 192 |
+
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_1));
|
| 193 |
+
|
| 194 |
+
tensor_A0.sync_device();
|
| 195 |
+
tensor_B0.sync_device();
|
| 196 |
+
tensor_B0_reordered.sync_device();
|
| 197 |
+
tensor_C0.sync_device();
|
| 198 |
+
tensor_Bias0.sync_device();
|
| 199 |
+
tensor_D0_computed.sync_device();
|
| 200 |
+
tensor_D0_reference.sync_device();
|
| 201 |
+
tensor_B1.sync_device();
|
| 202 |
+
tensor_B1_reordered.sync_device();
|
| 203 |
+
tensor_C1.sync_device();
|
| 204 |
+
tensor_Bias1.sync_device();
|
| 205 |
+
tensor_D1_computed.sync_device();
|
| 206 |
+
tensor_D1_reference.sync_device();
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
/// Executes one test
|
| 210 |
+
bool run(
|
| 211 |
+
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
| 212 |
+
cutlass::conv::Conv2dProblemSize const &problem_size_1,
|
| 213 |
+
cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
|
| 214 |
+
ElementCompute alpha0 = ElementCompute(1),
|
| 215 |
+
ElementCompute beta0 = ElementCompute(0),
|
| 216 |
+
ElementCompute alpha1 = ElementCompute(1),
|
| 217 |
+
ElementCompute beta1 = ElementCompute(0),
|
| 218 |
+
bool relu = true,
|
| 219 |
+
int warm_ups = 1,
|
| 220 |
+
int runs = 100) {
|
| 221 |
+
|
| 222 |
+
initialize(problem_size_0, problem_size_1);
|
| 223 |
+
|
| 224 |
+
// configure the operator
|
| 225 |
+
Conv2d0 conv2d_op_0;
|
| 226 |
+
Conv2d1 conv2d_op_1;
|
| 227 |
+
|
| 228 |
+
typename Conv2d0::Arguments conv2d_args_0(
|
| 229 |
+
problem_size_0,
|
| 230 |
+
tensor_A0.device_ref(),
|
| 231 |
+
tensor_B0_reordered.device_ref(),
|
| 232 |
+
tensor_C0.device_ref(),
|
| 233 |
+
tensor_D0_computed.device_ref(),
|
| 234 |
+
{alpha0, beta0},
|
| 235 |
+
split_k_mode
|
| 236 |
+
);
|
| 237 |
+
typename Conv2d1::Arguments conv2d_args_1(
|
| 238 |
+
problem_size_1,
|
| 239 |
+
tensor_D0_computed.device_ref(),
|
| 240 |
+
tensor_B1_reordered.device_ref(),
|
| 241 |
+
tensor_C1.device_ref(),
|
| 242 |
+
tensor_D1_computed.device_ref(),
|
| 243 |
+
{alpha1, beta1},
|
| 244 |
+
split_k_mode
|
| 245 |
+
);
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
cutlass::Status status = conv2d_op_0.initialize(conv2d_args_0);
|
| 249 |
+
|
| 250 |
+
CUTLASS_CHECK(status);
|
| 251 |
+
|
| 252 |
+
status = conv2d_op_1.initialize(conv2d_args_1);
|
| 253 |
+
|
| 254 |
+
CUTLASS_CHECK(status);
|
| 255 |
+
|
| 256 |
+
for(int i = 0; i < warm_ups; i++) {
|
| 257 |
+
status = conv2d_op_0();
|
| 258 |
+
CUTLASS_CHECK(status);
|
| 259 |
+
status = conv2d_op_1();
|
| 260 |
+
CUTLASS_CHECK(status);
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
//
|
| 264 |
+
// Run Conv2d
|
| 265 |
+
//
|
| 266 |
+
cudaEvent_t start, stop1, stop2;
|
| 267 |
+
cudaEventCreate(&start);
|
| 268 |
+
cudaEventCreate(&stop1);
|
| 269 |
+
cudaEventCreate(&stop2);
|
| 270 |
+
|
| 271 |
+
cudaEventRecord(start);
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
for(int i = 0; i < runs; i++) {
|
| 275 |
+
// run conv2d operator
|
| 276 |
+
status = conv2d_op_0();
|
| 277 |
+
CUTLASS_CHECK(status);
|
| 278 |
+
}
|
| 279 |
+
cudaEventRecord(stop1);
|
| 280 |
+
|
| 281 |
+
for(int i = 0; i < runs; i++) {
|
| 282 |
+
// run conv2d operator
|
| 283 |
+
status = conv2d_op_1();
|
| 284 |
+
CUTLASS_CHECK(status);
|
| 285 |
+
}
|
| 286 |
+
cudaEventRecord(stop2);
|
| 287 |
+
cudaDeviceSynchronize();
|
| 288 |
+
float conv2d0Time, conv2d1Time, totalTime;
|
| 289 |
+
cudaEventElapsedTime(&conv2d0Time, start, stop1);
|
| 290 |
+
cudaEventElapsedTime(&conv2d1Time, stop1, stop2);
|
| 291 |
+
cudaEventElapsedTime(&totalTime, start, stop2);
|
| 292 |
+
std::cout << "conv2d 0 time " << conv2d0Time / (float)runs << " ms\n";
|
| 293 |
+
std::cout << "conv2d 1 time " << conv2d1Time / (float)runs << " ms\n";
|
| 294 |
+
std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";
|
| 295 |
+
|
| 296 |
+
tensor_D0_computed.sync_host();
|
| 297 |
+
tensor_D1_computed.sync_host();
|
| 298 |
+
|
| 299 |
+
bool passed = false;
|
| 300 |
+
|
| 301 |
+
cutlass::reference::device::Conv2d<
|
| 302 |
+
typename Conv2d0::ElementA,
|
| 303 |
+
typename Conv2d0::LayoutA,
|
| 304 |
+
typename Conv2d0::ElementB,
|
| 305 |
+
typename Conv2d0::LayoutB,
|
| 306 |
+
typename Conv2d0::ElementC,
|
| 307 |
+
typename Conv2d0::LayoutC,
|
| 308 |
+
ElementCompute,
|
| 309 |
+
ElementAccumulator,
|
| 310 |
+
cutlass::NumericConverterClamp<typename Conv2d0::ElementC, ElementCompute>
|
| 311 |
+
>(
|
| 312 |
+
kConvolutionalOperator,
|
| 313 |
+
problem_size_0,
|
| 314 |
+
tensor_A0.device_ref(),
|
| 315 |
+
tensor_B0.device_ref(),
|
| 316 |
+
tensor_C0.device_ref(),
|
| 317 |
+
tensor_D0_reference.device_ref(),
|
| 318 |
+
alpha0,
|
| 319 |
+
beta0);
|
| 320 |
+
|
| 321 |
+
if(relu) {
|
| 322 |
+
cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
cutlass::reference::device::Conv2d<
|
| 326 |
+
typename Conv2d1::ElementA,
|
| 327 |
+
typename Conv2d1::LayoutA,
|
| 328 |
+
typename Conv2d1::ElementB,
|
| 329 |
+
typename Conv2d1::LayoutB,
|
| 330 |
+
typename Conv2d1::ElementC,
|
| 331 |
+
typename Conv2d1::LayoutC,
|
| 332 |
+
ElementCompute,
|
| 333 |
+
ElementAccumulator,
|
| 334 |
+
cutlass::NumericConverterClamp<typename Conv2d1::ElementC, ElementCompute>
|
| 335 |
+
>(
|
| 336 |
+
kConvolutionalOperator,
|
| 337 |
+
problem_size_1,
|
| 338 |
+
tensor_D0_reference.device_ref(),
|
| 339 |
+
tensor_B1.device_ref(),
|
| 340 |
+
tensor_C1.device_ref(),
|
| 341 |
+
tensor_D1_reference.device_ref(),
|
| 342 |
+
alpha1,
|
| 343 |
+
beta1);
|
| 344 |
+
|
| 345 |
+
if(relu) {
|
| 346 |
+
cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view());
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
cudaError_t result = cudaDeviceSynchronize();
|
| 350 |
+
CHECK_TRUE(result == cudaSuccess);
|
| 351 |
+
|
| 352 |
+
// sync host (copy device data to host) for dumping error output in case of mismatches
|
| 353 |
+
tensor_D0_reference.sync_host();
|
| 354 |
+
tensor_D1_reference.sync_host();
|
| 355 |
+
|
| 356 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_computed.host_view()), 0);
|
| 357 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0);
|
| 358 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0);
|
| 359 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0);
|
| 360 |
+
|
| 361 |
+
passed = cutlass::reference::host::TensorEquals(
|
| 362 |
+
tensor_D1_computed.host_view(),
|
| 363 |
+
tensor_D1_reference.host_view());
|
| 364 |
+
|
| 365 |
+
CHECK_TRUE(passed);
|
| 366 |
+
|
| 367 |
+
if (!passed) {
|
| 368 |
+
std::stringstream fname;
|
| 369 |
+
|
| 370 |
+
fname << "error_B2bImplicitGemm_device_interleaved_nonfused.txt";
|
| 371 |
+
std::cerr << "Dumping results in " << fname.str() << "\n";
|
| 372 |
+
|
| 373 |
+
std::ofstream results(fname.str());
|
| 374 |
+
|
| 375 |
+
results << problem_size_0 << std::endl;
|
| 376 |
+
results << problem_size_1 << std::endl;
|
| 377 |
+
|
| 378 |
+
results
|
| 379 |
+
<< "\nA0:\n" << tensor_A0.host_view() << "\n"
|
| 380 |
+
<< "\nB0:\n" << tensor_B0.host_view() << "\n"
|
| 381 |
+
<< "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n"
|
| 382 |
+
<< "\nC0:\n" << tensor_C0.host_view() << "\n"
|
| 383 |
+
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
|
| 384 |
+
<< "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n"
|
| 385 |
+
<< "\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n"
|
| 386 |
+
<< "\nB1:\n" << tensor_B1.host_view() << "\n"
|
| 387 |
+
<< "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n"
|
| 388 |
+
<< "\nC1:\n" << tensor_C1.host_view() << "\n"
|
| 389 |
+
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
|
| 390 |
+
<< "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
|
| 391 |
+
<< "\nD1 computed:\n" << tensor_D1_computed.host_view();
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
return passed;
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
};
|
| 400 |
+
|
| 401 |
+
template <typename B2bConv2d_, int InterleavedK>
|
| 402 |
+
class B2bInterleavedFusedConv2dRun {
|
| 403 |
+
public:
|
| 404 |
+
|
| 405 |
+
using B2bConv2d = B2bConv2d_;
|
| 406 |
+
using ElementAccumulator = typename B2bConv2d::ElementAccumulator;
|
| 407 |
+
using ElementCompute = typename B2bConv2d::ElementCompute;
|
| 408 |
+
|
| 409 |
+
static cutlass::conv::Operator const kConvolutionalOperator = B2bConv2d::kConvolutionalOperator;
|
| 410 |
+
|
| 411 |
+
public:
|
| 412 |
+
|
| 413 |
+
/// Initialization
|
| 414 |
+
cutlass::Distribution::Kind init_A;
|
| 415 |
+
cutlass::Distribution::Kind init_B;
|
| 416 |
+
cutlass::Distribution::Kind init_C;
|
| 417 |
+
cutlass::Distribution::Kind init_Scale;
|
| 418 |
+
cutlass::Distribution::Kind init_Bias;
|
| 419 |
+
uint64_t seed;
|
| 420 |
+
|
| 421 |
+
cutlass::HostTensor<typename B2bConv2d::ElementA, typename B2bConv2d::LayoutA> tensor_A0;
|
| 422 |
+
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B0;
|
| 423 |
+
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B0_reordered;
|
| 424 |
+
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C0;
|
| 425 |
+
cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Scale0;
|
| 426 |
+
cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Bias0;
|
| 427 |
+
cutlass::HostTensor<ElementAccumulator, typename B2bConv2d::LayoutC> tensor_Z0_reference;
|
| 428 |
+
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D0_reference;
|
| 429 |
+
|
| 430 |
+
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1;
|
| 431 |
+
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1_reordered;
|
| 432 |
+
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C1;
|
| 433 |
+
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_Bias1;
|
| 434 |
+
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_computed;
|
| 435 |
+
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_reference;
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
public:
|
| 439 |
+
|
| 440 |
+
B2bInterleavedFusedConv2dRun(
|
| 441 |
+
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
| 442 |
+
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
| 443 |
+
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
| 444 |
+
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
| 445 |
+
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
| 446 |
+
uint64_t seed_ = 2080
|
| 447 |
+
):
|
| 448 |
+
init_A(init_A_), init_B(init_B_), init_C(init_C_),
|
| 449 |
+
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) {
|
| 450 |
+
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
/// Helper to initialize a tensor view
|
| 454 |
+
template <typename Element, typename Layout>
|
| 455 |
+
void initialize_tensor(
|
| 456 |
+
cutlass::TensorView<Element, Layout> view,
|
| 457 |
+
cutlass::Distribution::Kind dist_kind,
|
| 458 |
+
uint64_t seed) {
|
| 459 |
+
|
| 460 |
+
if (dist_kind == cutlass::Distribution::Uniform) {
|
| 461 |
+
|
| 462 |
+
int scope;
|
| 463 |
+
int bits = cutlass::sizeof_bits<Element>::value;
|
| 464 |
+
|
| 465 |
+
if (bits <= 16) {
|
| 466 |
+
scope = 2;
|
| 467 |
+
}
|
| 468 |
+
else {
|
| 469 |
+
scope = 8;
|
| 470 |
+
}
|
| 471 |
+
cutlass::reference::host::TensorFillRandomUniform(
|
| 472 |
+
view, seed, scope, -scope, 0);
|
| 473 |
+
}
|
| 474 |
+
else if (dist_kind == cutlass::Distribution::Identity) {
|
| 475 |
+
|
| 476 |
+
cutlass::reference::host::TensorFillIdentity(view);
|
| 477 |
+
}
|
| 478 |
+
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
| 479 |
+
|
| 480 |
+
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
| 481 |
+
}
|
| 482 |
+
else if (dist_kind == cutlass::Distribution::Sequential) {
|
| 483 |
+
|
| 484 |
+
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
| 485 |
+
}
|
| 486 |
+
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
| 487 |
+
cutlass::reference::host::TensorFill(view, Element(0));
|
| 488 |
+
}
|
| 489 |
+
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
| 490 |
+
cutlass::reference::host::TensorFill(view, Element(1));
|
| 491 |
+
}
|
| 492 |
+
else {
|
| 493 |
+
}
|
| 494 |
+
}
|
| 495 |
+
|
| 496 |
+
void initialize(
|
| 497 |
+
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
| 498 |
+
cutlass::conv::Conv2dProblemSize const &problem_size_1,
|
| 499 |
+
ElementCompute alpha0,
|
| 500 |
+
ElementCompute alpha1,
|
| 501 |
+
uint64_t seed = 2019) {
|
| 502 |
+
|
| 503 |
+
tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0));
|
| 504 |
+
tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
|
| 505 |
+
tensor_B0_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
|
| 506 |
+
tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
| 507 |
+
if(alpha0 == ElementCompute(0)) //per-channel scale
|
| 508 |
+
tensor_Scale0.resize({1, problem_size_0.K});
|
| 509 |
+
tensor_Bias0.resize({1, problem_size_0.K});
|
| 510 |
+
tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
| 511 |
+
tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
| 512 |
+
tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
|
| 513 |
+
tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
|
| 514 |
+
tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
| 515 |
+
tensor_Bias1.resize({1, 1, 1, problem_size_1.K});
|
| 516 |
+
tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
| 517 |
+
tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
| 518 |
+
|
| 519 |
+
initialize_tensor(tensor_A0.host_view(), init_A, seed);
|
| 520 |
+
initialize_tensor(tensor_B0.host_view(), init_B, seed * 17);
|
| 521 |
+
initialize_tensor(tensor_C0.host_view(), init_C, seed * 39);
|
| 522 |
+
if(alpha0 == ElementCompute(0)) //per-channel scale
|
| 523 |
+
initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed * 61);
|
| 524 |
+
initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83);
|
| 525 |
+
initialize_tensor(tensor_B1.host_view(), init_B, seed * 18);
|
| 526 |
+
initialize_tensor(tensor_C1.host_view(), init_C, seed * 40);
|
| 527 |
+
initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84);
|
| 528 |
+
|
| 529 |
+
//Reorder B0 and B1
|
| 530 |
+
cutlass::reorder_convK<16, InterleavedK>(
|
| 531 |
+
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_0));
|
| 532 |
+
cutlass::reorder_convK<InterleavedK, InterleavedK>(
|
| 533 |
+
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_1));
|
| 534 |
+
|
| 535 |
+
tensor_A0.sync_device();
|
| 536 |
+
tensor_B0.sync_device();
|
| 537 |
+
tensor_B0_reordered.sync_device();
|
| 538 |
+
tensor_C0.sync_device();
|
| 539 |
+
if(alpha0 == ElementCompute(0)) //per-channel scale
|
| 540 |
+
tensor_Scale0.sync_device();
|
| 541 |
+
tensor_Bias0.sync_device();
|
| 542 |
+
tensor_D0_reference.sync_device();
|
| 543 |
+
tensor_B1.sync_device();
|
| 544 |
+
tensor_B1_reordered.sync_device();
|
| 545 |
+
tensor_C1.sync_device();
|
| 546 |
+
tensor_Bias1.sync_device();
|
| 547 |
+
tensor_D1_computed.sync_device();
|
| 548 |
+
tensor_D1_reference.sync_device();
|
| 549 |
+
}
|
| 550 |
+
|
| 551 |
+
/// Executes one test
|
| 552 |
+
bool run(
|
| 553 |
+
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
| 554 |
+
cutlass::conv::Conv2dProblemSize const &problem_size_1,
|
| 555 |
+
cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
|
| 556 |
+
ElementCompute alpha0 = ElementCompute(1),
|
| 557 |
+
ElementCompute beta0 = ElementCompute(0),
|
| 558 |
+
ElementCompute alpha1 = ElementCompute(1),
|
| 559 |
+
ElementCompute beta1 = ElementCompute(0),
|
| 560 |
+
bool relu = true,
|
| 561 |
+
int warm_ups = 1,
|
| 562 |
+
int runs = 100) {
|
| 563 |
+
|
| 564 |
+
initialize(problem_size_0, problem_size_1, alpha0, alpha1);
|
| 565 |
+
|
| 566 |
+
// configure the operator
|
| 567 |
+
B2bConv2d b2b_conv2d_op;
|
| 568 |
+
|
| 569 |
+
typename B2bConv2d::Arguments b2b_conv2d_args(
|
| 570 |
+
problem_size_0,
|
| 571 |
+
problem_size_1,
|
| 572 |
+
tensor_A0.device_ref(),
|
| 573 |
+
tensor_B0_reordered.device_ref(),
|
| 574 |
+
tensor_C0.device_ref(),
|
| 575 |
+
tensor_Scale0.device_ref(),
|
| 576 |
+
tensor_Bias0.device_ref(),
|
| 577 |
+
tensor_B1_reordered.device_ref(),
|
| 578 |
+
tensor_C1.device_ref(),
|
| 579 |
+
tensor_D1_computed.device_ref(),
|
| 580 |
+
{alpha0, beta0},
|
| 581 |
+
{alpha1, beta1},
|
| 582 |
+
split_k_mode
|
| 583 |
+
);
|
| 584 |
+
|
| 585 |
+
cutlass::Status status = b2b_conv2d_op.can_implement(b2b_conv2d_args);
|
| 586 |
+
|
| 587 |
+
if(status != cutlass::Status::kSuccess) {
|
| 588 |
+
std::cout << "Problem sizes not supported.\n"
|
| 589 |
+
<< "Requirments:\n"
|
| 590 |
+
<< " problem_size_0.N*P*Q = problem_size_1.N*P*Q\n"
|
| 591 |
+
<< " problem_size_0.K = problem_size_1.C\n"
|
| 592 |
+
<< " problem_size_1.R = problem_size_1.S = 1\n"
|
| 593 |
+
<< " ThreadblockShape0::kN = problem_size_0.K\n"
|
| 594 |
+
<< " ThreadblockShape1::kN = problem_size_1.K" << std::endl;
|
| 595 |
+
}
|
| 596 |
+
|
| 597 |
+
CUTLASS_CHECK(status);
|
| 598 |
+
|
| 599 |
+
status = b2b_conv2d_op.initialize(b2b_conv2d_args);
|
| 600 |
+
|
| 601 |
+
CUTLASS_CHECK(status);
|
| 602 |
+
|
| 603 |
+
for(int i = 0; i < warm_ups; i++) {
|
| 604 |
+
status = b2b_conv2d_op();
|
| 605 |
+
CUTLASS_CHECK(status);
|
| 606 |
+
}
|
| 607 |
+
|
| 608 |
+
//
|
| 609 |
+
// Run the Conv2d
|
| 610 |
+
//
|
| 611 |
+
|
| 612 |
+
cudaEvent_t start, stop;
|
| 613 |
+
cudaEventCreate(&start);
|
| 614 |
+
cudaEventCreate(&stop);
|
| 615 |
+
|
| 616 |
+
cudaEventRecord(start);
|
| 617 |
+
|
| 618 |
+
for(int i = 0; i < runs; i++) {
|
| 619 |
+
|
| 620 |
+
// run conv2d operator
|
| 621 |
+
status = b2b_conv2d_op();
|
| 622 |
+
CUTLASS_CHECK(status);
|
| 623 |
+
}
|
| 624 |
+
|
| 625 |
+
cudaEventRecord(stop);
|
| 626 |
+
cudaDeviceSynchronize();
|
| 627 |
+
float conv2dTime;
|
| 628 |
+
cudaEventElapsedTime(&conv2dTime, start, stop);
|
| 629 |
+
std::cout << "Fusion time " << conv2dTime / (float)runs << " ms\n";
|
| 630 |
+
|
| 631 |
+
tensor_D1_computed.sync_host();
|
| 632 |
+
|
| 633 |
+
bool passed = false;
|
| 634 |
+
|
| 635 |
+
cutlass::reference::device::Conv2d<
|
| 636 |
+
typename B2bConv2d::ElementA,
|
| 637 |
+
typename B2bConv2d::LayoutA,
|
| 638 |
+
typename B2bConv2d::ElementB,
|
| 639 |
+
typename B2bConv2d::LayoutB,
|
| 640 |
+
ElementAccumulator,
|
| 641 |
+
typename B2bConv2d::LayoutC,
|
| 642 |
+
ElementAccumulator,
|
| 643 |
+
ElementAccumulator
|
| 644 |
+
>(
|
| 645 |
+
kConvolutionalOperator,
|
| 646 |
+
problem_size_0,
|
| 647 |
+
tensor_A0.device_ref(),
|
| 648 |
+
tensor_B0.device_ref(),
|
| 649 |
+
tensor_Z0_reference.device_ref(),
|
| 650 |
+
tensor_Z0_reference.device_ref(),
|
| 651 |
+
ElementAccumulator(1), // intermediate alpha = 1
|
| 652 |
+
ElementAccumulator(0) // beta = 0
|
| 653 |
+
);
|
| 654 |
+
|
| 655 |
+
cutlass::reference::device::TensorScaleBiasConv2d<
|
| 656 |
+
ElementAccumulator,
|
| 657 |
+
typename B2bConv2d::ElementC,
|
| 658 |
+
typename B2bConv2d::LayoutC,
|
| 659 |
+
ElementCompute,
|
| 660 |
+
typename B2bConv2d::LayoutScaleBias,
|
| 661 |
+
cutlass::NumericConverterClamp<typename B2bConv2d::ElementC, ElementCompute>
|
| 662 |
+
>(
|
| 663 |
+
problem_size_0,
|
| 664 |
+
tensor_Z0_reference.device_ref(),
|
| 665 |
+
tensor_D0_reference.device_ref(),
|
| 666 |
+
alpha0,
|
| 667 |
+
tensor_Scale0.device_ref(),
|
| 668 |
+
tensor_Bias0.device_ref()
|
| 669 |
+
);
|
| 670 |
+
|
| 671 |
+
if(relu) {
|
| 672 |
+
cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
|
| 673 |
+
}
|
| 674 |
+
|
| 675 |
+
cutlass::reference::device::Conv2d<
|
| 676 |
+
typename B2bConv2d::ElementA,
|
| 677 |
+
typename B2bConv2d::LayoutA,
|
| 678 |
+
typename B2bConv2d::ElementB,
|
| 679 |
+
typename B2bConv2d::LayoutB,
|
| 680 |
+
typename B2bConv2d::ElementC,
|
| 681 |
+
typename B2bConv2d::LayoutC,
|
| 682 |
+
ElementCompute,
|
| 683 |
+
ElementAccumulator,
|
| 684 |
+
cutlass::NumericConverterClamp<typename B2bConv2d::ElementC, ElementCompute>
|
| 685 |
+
>(
|
| 686 |
+
kConvolutionalOperator,
|
| 687 |
+
problem_size_1,
|
| 688 |
+
tensor_D0_reference.device_ref(),
|
| 689 |
+
tensor_B1.device_ref(),
|
| 690 |
+
tensor_C1.device_ref(),
|
| 691 |
+
tensor_D1_reference.device_ref(),
|
| 692 |
+
alpha1,
|
| 693 |
+
beta1);
|
| 694 |
+
|
| 695 |
+
if(relu) {
|
| 696 |
+
cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view());
|
| 697 |
+
}
|
| 698 |
+
|
| 699 |
+
cudaError_t result = cudaDeviceSynchronize();
|
| 700 |
+
CHECK_TRUE(result == cudaSuccess);
|
| 701 |
+
|
| 702 |
+
// sync host (copy device data to host) for dumping error output in case of mismatches
|
| 703 |
+
tensor_D0_reference.sync_host();
|
| 704 |
+
tensor_D1_reference.sync_host();
|
| 705 |
+
|
| 706 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0);
|
| 707 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0);
|
| 708 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0);
|
| 709 |
+
|
| 710 |
+
passed = cutlass::reference::host::TensorEquals(
|
| 711 |
+
tensor_D1_computed.host_view(),
|
| 712 |
+
tensor_D1_reference.host_view());
|
| 713 |
+
|
| 714 |
+
CHECK_TRUE(passed);
|
| 715 |
+
|
| 716 |
+
if (!passed) {
|
| 717 |
+
std::stringstream fname;
|
| 718 |
+
|
| 719 |
+
fname << "error_B2bImplicitGemm_device_interleaved_fused.txt";
|
| 720 |
+
std::cerr << "Dumping results in " << fname.str() << "\n";
|
| 721 |
+
|
| 722 |
+
std::ofstream results(fname.str());
|
| 723 |
+
|
| 724 |
+
results << problem_size_0 << std::endl;
|
| 725 |
+
results << problem_size_1 << std::endl;
|
| 726 |
+
|
| 727 |
+
results
|
| 728 |
+
<< "\nA0:\n" << tensor_A0.host_view() << "\n"
|
| 729 |
+
<< "\nB0:\n" << tensor_B0.host_view() << "\n"
|
| 730 |
+
<< "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n"
|
| 731 |
+
<< "\nC0:\n" << tensor_C0.host_view() << "\n"
|
| 732 |
+
<< "\nScale0:\n" << tensor_Scale0.host_view() << "\n"
|
| 733 |
+
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
|
| 734 |
+
<< "\nB1:\n" << tensor_B1.host_view() << "\n"
|
| 735 |
+
<< "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n"
|
| 736 |
+
<< "\nC1:\n" << tensor_C1.host_view() << "\n"
|
| 737 |
+
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
|
| 738 |
+
<< "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
|
| 739 |
+
<< "\nD1 computed:\n" << tensor_D1_computed.host_view();
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
}
|
| 743 |
+
|
| 744 |
+
return passed;
|
| 745 |
+
}
|
| 746 |
+
|
| 747 |
+
};
|
| 748 |
+
|
| 749 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h
ADDED
|
@@ -0,0 +1,798 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
#pragma once
|
| 32 |
+
|
| 33 |
+
#include <iostream>
|
| 34 |
+
#include <fstream>
|
| 35 |
+
#include <sstream>
|
| 36 |
+
|
| 37 |
+
#include "cutlass/util/host_tensor.h"
|
| 38 |
+
#include "cutlass/util/tensor_view_io.h"
|
| 39 |
+
#include "cutlass/util/distribution.h"
|
| 40 |
+
#include "cutlass/util/reference/host/tensor_fill.h"
|
| 41 |
+
#include "cutlass/util/reference/host/tensor_copy.h"
|
| 42 |
+
#include "cutlass/util/reference/host/tensor_compare.h"
|
| 43 |
+
#include "cutlass/util/reference/host/tensor_norm.h"
|
| 44 |
+
#include "cutlass/util/host_reorder.h"
|
| 45 |
+
#include "cutlass/util/reference/device/gemm.h"
|
| 46 |
+
#include "cutlass/util/reference/device/gemm_complex.h"
|
| 47 |
+
#include "cutlass/util/reference/device/tensor_relu.h"
|
| 48 |
+
|
| 49 |
+
#include "reference/device/tensor_scale_bias.h"
|
| 50 |
+
#include "helper.h"
|
| 51 |
+
|
| 52 |
+
#define CHECK_GT(val1, val2) \
|
| 53 |
+
if((val1) <= (val2)) \
|
| 54 |
+
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
|
| 55 |
+
#define CHECK_TRUE(val) \
|
| 56 |
+
if(!(val)) \
|
| 57 |
+
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
|
| 58 |
+
|
| 59 |
+
template <typename Gemm0_, typename Gemm1_, int InterleavedK_>
|
| 60 |
+
struct B2bInterleavedNonFusedGemmRun
|
| 61 |
+
{
|
| 62 |
+
|
| 63 |
+
using Gemm0 = Gemm0_;
|
| 64 |
+
using Gemm1 = Gemm1_;
|
| 65 |
+
using ElementAccumulator = typename Gemm0::ElementAccumulator;
|
| 66 |
+
using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute;
|
| 67 |
+
|
| 68 |
+
/// Initialization
|
| 69 |
+
cutlass::Distribution::Kind init_A;
|
| 70 |
+
cutlass::Distribution::Kind init_B;
|
| 71 |
+
cutlass::Distribution::Kind init_C;
|
| 72 |
+
cutlass::Distribution::Kind init_Bias;
|
| 73 |
+
uint64_t seed;
|
| 74 |
+
|
| 75 |
+
//
|
| 76 |
+
// Methods
|
| 77 |
+
//
|
| 78 |
+
|
| 79 |
+
B2bInterleavedNonFusedGemmRun(
|
| 80 |
+
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
| 81 |
+
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
| 82 |
+
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
| 83 |
+
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
| 84 |
+
uint64_t seed_ = 2080
|
| 85 |
+
):
|
| 86 |
+
init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { }
|
| 87 |
+
|
| 88 |
+
/// Helper to initialize a tensor view
|
| 89 |
+
template <typename Element, typename Layout>
|
| 90 |
+
bool initialize_tensor(
|
| 91 |
+
cutlass::TensorView<Element, Layout> view,
|
| 92 |
+
cutlass::Distribution::Kind dist_kind,
|
| 93 |
+
uint64_t seed) {
|
| 94 |
+
|
| 95 |
+
if (dist_kind == cutlass::Distribution::Uniform) {
|
| 96 |
+
|
| 97 |
+
cutlass::reference::host::TensorFillRandomUniform(
|
| 98 |
+
view, seed, 2, -2, 0);
|
| 99 |
+
}
|
| 100 |
+
else if (dist_kind == cutlass::Distribution::Identity) {
|
| 101 |
+
|
| 102 |
+
cutlass::reference::host::TensorFillIdentity(view);
|
| 103 |
+
}
|
| 104 |
+
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
| 105 |
+
|
| 106 |
+
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
| 107 |
+
}
|
| 108 |
+
else if (dist_kind == cutlass::Distribution::Sequential) {
|
| 109 |
+
|
| 110 |
+
cutlass::reference::host::BlockFillSequential(
|
| 111 |
+
view.data(), view.capacity());
|
| 112 |
+
}
|
| 113 |
+
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
| 114 |
+
cutlass::reference::host::TensorFill(view, Element(0));
|
| 115 |
+
}
|
| 116 |
+
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
| 117 |
+
cutlass::reference::host::TensorFill(view, Element(1));
|
| 118 |
+
}
|
| 119 |
+
else {
|
| 120 |
+
std::cerr << "Not implemented\n";
|
| 121 |
+
return false;
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
return true;
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
/// Executes one test
|
| 131 |
+
bool run(
|
| 132 |
+
cutlass::gemm::GemmCoord problem_size_0,
|
| 133 |
+
cutlass::gemm::GemmCoord problem_size_1,
|
| 134 |
+
ElementCompute alpha0 = ElementCompute(1),
|
| 135 |
+
ElementCompute beta0 = ElementCompute(0),
|
| 136 |
+
ElementCompute alpha1 = ElementCompute(1),
|
| 137 |
+
ElementCompute beta1 = ElementCompute(0),
|
| 138 |
+
bool relu = true,
|
| 139 |
+
int warm_ups = 1,
|
| 140 |
+
int runs = 100) {
|
| 141 |
+
|
| 142 |
+
//
|
| 143 |
+
// Allocate the GEMM workspace
|
| 144 |
+
//
|
| 145 |
+
|
| 146 |
+
cutlass::HostTensor<
|
| 147 |
+
typename Gemm0::ElementA,
|
| 148 |
+
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
|
| 149 |
+
|
| 150 |
+
cutlass::HostTensor<
|
| 151 |
+
typename Gemm0::ElementB,
|
| 152 |
+
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
|
| 153 |
+
|
| 154 |
+
cutlass::HostTensor<
|
| 155 |
+
typename Gemm0::ElementB,
|
| 156 |
+
typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn());
|
| 157 |
+
|
| 158 |
+
cutlass::HostTensor<
|
| 159 |
+
typename Gemm0::ElementC,
|
| 160 |
+
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
|
| 161 |
+
|
| 162 |
+
cutlass::HostTensor<
|
| 163 |
+
typename Gemm0::ElementC,
|
| 164 |
+
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
|
| 165 |
+
|
| 166 |
+
cutlass::HostTensor<
|
| 167 |
+
typename Gemm0::ElementC,
|
| 168 |
+
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
|
| 169 |
+
|
| 170 |
+
cutlass::HostTensor<
|
| 171 |
+
typename Gemm0::ElementC,
|
| 172 |
+
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
|
| 173 |
+
|
| 174 |
+
cutlass::HostTensor<
|
| 175 |
+
typename Gemm1::ElementB,
|
| 176 |
+
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
|
| 177 |
+
|
| 178 |
+
cutlass::HostTensor<
|
| 179 |
+
typename Gemm1::ElementB,
|
| 180 |
+
typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn());
|
| 181 |
+
|
| 182 |
+
cutlass::HostTensor<
|
| 183 |
+
typename Gemm1::ElementC,
|
| 184 |
+
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
|
| 185 |
+
|
| 186 |
+
cutlass::HostTensor<
|
| 187 |
+
typename Gemm0::ElementC,
|
| 188 |
+
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
|
| 189 |
+
|
| 190 |
+
cutlass::HostTensor<
|
| 191 |
+
typename Gemm1::ElementC,
|
| 192 |
+
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
|
| 193 |
+
|
| 194 |
+
cutlass::HostTensor<
|
| 195 |
+
typename Gemm1::ElementC,
|
| 196 |
+
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
|
| 197 |
+
|
| 198 |
+
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
| 199 |
+
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
|
| 200 |
+
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
|
| 201 |
+
CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014));
|
| 202 |
+
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
|
| 203 |
+
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
|
| 204 |
+
CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013));
|
| 205 |
+
|
| 206 |
+
//Reorder B0 and B1
|
| 207 |
+
cutlass::reorder_column<InterleavedK_>(
|
| 208 |
+
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
|
| 209 |
+
cutlass::reorder_column<InterleavedK_>(
|
| 210 |
+
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
|
| 211 |
+
|
| 212 |
+
cutlass::reference::host::TensorFill(
|
| 213 |
+
tensor_D0.host_view());
|
| 214 |
+
cutlass::reference::host::TensorFill(
|
| 215 |
+
tensor_D1.host_view());
|
| 216 |
+
cutlass::reference::host::TensorFill(
|
| 217 |
+
reference_D0.host_view());
|
| 218 |
+
cutlass::reference::host::TensorFill(
|
| 219 |
+
reference_D1.host_view());
|
| 220 |
+
|
| 221 |
+
tensor_A0.sync_device();
|
| 222 |
+
tensor_B0.sync_device();
|
| 223 |
+
tensor_B0_reordered.sync_device();
|
| 224 |
+
tensor_C0.sync_device();
|
| 225 |
+
tensor_Bias0.sync_device();
|
| 226 |
+
tensor_D0.sync_device();
|
| 227 |
+
tensor_B1.sync_device();
|
| 228 |
+
tensor_B1_reordered.sync_device();
|
| 229 |
+
tensor_C1.sync_device();
|
| 230 |
+
tensor_Bias1.sync_device();
|
| 231 |
+
tensor_D1.sync_device();
|
| 232 |
+
reference_D0.sync_device();
|
| 233 |
+
reference_D1.sync_device();
|
| 234 |
+
|
| 235 |
+
//
|
| 236 |
+
// Initialize the GEMM operator
|
| 237 |
+
//
|
| 238 |
+
|
| 239 |
+
typename Gemm0::Arguments arguments_0{
|
| 240 |
+
problem_size_0,
|
| 241 |
+
tensor_A0.device_ref(),
|
| 242 |
+
tensor_B0_reordered.device_ref(),
|
| 243 |
+
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
|
| 244 |
+
tensor_D0.device_ref(),
|
| 245 |
+
{alpha0, beta0}
|
| 246 |
+
};
|
| 247 |
+
|
| 248 |
+
typename Gemm1::Arguments arguments_1{
|
| 249 |
+
problem_size_1,
|
| 250 |
+
tensor_D0.device_ref(),
|
| 251 |
+
tensor_B1_reordered.device_ref(),
|
| 252 |
+
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
|
| 253 |
+
tensor_D1.device_ref(),
|
| 254 |
+
{alpha1, beta1}
|
| 255 |
+
};
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
Gemm0 gemm_op_0;
|
| 259 |
+
Gemm1 gemm_op_1;
|
| 260 |
+
|
| 261 |
+
cutlass::Status status = gemm_op_0.initialize(arguments_0);
|
| 262 |
+
|
| 263 |
+
CUTLASS_CHECK(status);
|
| 264 |
+
|
| 265 |
+
status = gemm_op_1.initialize(arguments_1);
|
| 266 |
+
|
| 267 |
+
CUTLASS_CHECK(status);
|
| 268 |
+
|
| 269 |
+
for(int i = 0; i < warm_ups; i++) {
|
| 270 |
+
status = gemm_op_0();
|
| 271 |
+
CUTLASS_CHECK(status);
|
| 272 |
+
status = gemm_op_1();
|
| 273 |
+
CUTLASS_CHECK(status);
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
//
|
| 277 |
+
// Run the GEMM
|
| 278 |
+
//
|
| 279 |
+
cudaEvent_t start, stop1, stop2;
|
| 280 |
+
cudaEventCreate(&start);
|
| 281 |
+
cudaEventCreate(&stop1);
|
| 282 |
+
cudaEventCreate(&stop2);
|
| 283 |
+
|
| 284 |
+
cudaEventRecord(start);
|
| 285 |
+
|
| 286 |
+
for(int i = 0; i < runs; i++) {
|
| 287 |
+
status = gemm_op_0();
|
| 288 |
+
|
| 289 |
+
CUTLASS_CHECK(status);
|
| 290 |
+
}
|
| 291 |
+
cudaEventRecord(stop1);
|
| 292 |
+
for(int i = 0; i < runs; i++) {
|
| 293 |
+
status = gemm_op_1();
|
| 294 |
+
|
| 295 |
+
CUTLASS_CHECK(status);
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
cudaEventRecord(stop2);
|
| 299 |
+
cudaDeviceSynchronize();
|
| 300 |
+
float gemm0Time, gemm1Time, totalTime;
|
| 301 |
+
cudaEventElapsedTime(&gemm0Time, start, stop1);
|
| 302 |
+
cudaEventElapsedTime(&gemm1Time, stop1, stop2);
|
| 303 |
+
cudaEventElapsedTime(&totalTime, start, stop2);
|
| 304 |
+
std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
|
| 305 |
+
std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
|
| 306 |
+
std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";
|
| 307 |
+
|
| 308 |
+
tensor_D0.sync_host();
|
| 309 |
+
tensor_D1.sync_host();
|
| 310 |
+
|
| 311 |
+
//
|
| 312 |
+
// Verify
|
| 313 |
+
//
|
| 314 |
+
cutlass::reference::device::Gemm<
|
| 315 |
+
typename Gemm0::ElementA, typename Gemm0::LayoutA,
|
| 316 |
+
typename Gemm0::ElementB, typename Gemm0::LayoutB,
|
| 317 |
+
typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute,
|
| 318 |
+
ElementAccumulator, typename Gemm0::Operator>
|
| 319 |
+
reference_gemm_0;
|
| 320 |
+
|
| 321 |
+
cutlass::reference::device::Gemm<
|
| 322 |
+
typename Gemm1::ElementA, typename Gemm1::LayoutA,
|
| 323 |
+
typename Gemm1::ElementB, typename Gemm1::LayoutB,
|
| 324 |
+
typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute,
|
| 325 |
+
ElementAccumulator, typename Gemm1::Operator>
|
| 326 |
+
reference_gemm_1;
|
| 327 |
+
|
| 328 |
+
reference_gemm_0(
|
| 329 |
+
problem_size_0,
|
| 330 |
+
alpha0,
|
| 331 |
+
tensor_A0.device_ref(),
|
| 332 |
+
tensor_B0.device_ref(),
|
| 333 |
+
beta0,
|
| 334 |
+
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
|
| 335 |
+
reference_D0.device_ref()
|
| 336 |
+
);
|
| 337 |
+
|
| 338 |
+
if(relu) {
|
| 339 |
+
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
reference_gemm_1(
|
| 343 |
+
problem_size_1,
|
| 344 |
+
alpha1,
|
| 345 |
+
reference_D0.device_ref(),
|
| 346 |
+
tensor_B1.device_ref(),
|
| 347 |
+
beta1,
|
| 348 |
+
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
|
| 349 |
+
reference_D1.device_ref()
|
| 350 |
+
);
|
| 351 |
+
|
| 352 |
+
if(relu) {
|
| 353 |
+
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
// Wait for kernels to finish
|
| 357 |
+
cudaDeviceSynchronize();
|
| 358 |
+
reference_D0.sync_host();
|
| 359 |
+
reference_D1.sync_host();
|
| 360 |
+
|
| 361 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
|
| 362 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
|
| 363 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
|
| 364 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
| 365 |
+
|
| 366 |
+
bool passed = cutlass::reference::host::TensorEquals(
|
| 367 |
+
reference_D1.host_view(),
|
| 368 |
+
tensor_D1.host_view());
|
| 369 |
+
|
| 370 |
+
CHECK_TRUE(passed);
|
| 371 |
+
if (!passed) {
|
| 372 |
+
|
| 373 |
+
std::stringstream fname;
|
| 374 |
+
|
| 375 |
+
fname << "error_B2bGemm_device_interleaved_nonfused.txt";
|
| 376 |
+
std::cerr << "Dumping results in " << fname.str() << "\n";
|
| 377 |
+
|
| 378 |
+
std::ofstream file(fname.str());
|
| 379 |
+
|
| 380 |
+
file
|
| 381 |
+
<< "A0 =\n" << tensor_A0.host_view()
|
| 382 |
+
<< "\nB0 =\n" << tensor_B0.host_view()
|
| 383 |
+
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
|
| 384 |
+
<< "\nC0 =\n" << tensor_C0.host_view()
|
| 385 |
+
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
|
| 386 |
+
<< "\nD0 =\n" << tensor_D0.host_view()
|
| 387 |
+
<< "\nB1 =\n" << tensor_B1.host_view()
|
| 388 |
+
<< "\nB1_reordered =\n" << tensor_B1_reordered.host_view()
|
| 389 |
+
<< "\nC1 =\n" << tensor_C1.host_view()
|
| 390 |
+
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
|
| 391 |
+
<< "\n\nReference =\n" << reference_D1.host_view()
|
| 392 |
+
<< "\nComputed =\n" << tensor_D1.host_view();
|
| 393 |
+
}
|
| 394 |
+
return passed;
|
| 395 |
+
}
|
| 396 |
+
};
|
| 397 |
+
|
| 398 |
+
template <typename B2bGemm_, int InterleavedK_>
|
| 399 |
+
struct B2bInterleavedFusedGemmRun
|
| 400 |
+
{
|
| 401 |
+
|
| 402 |
+
using B2bGemm = B2bGemm_;
|
| 403 |
+
using ElementAccumulator = typename B2bGemm::ElementAccumulator;
|
| 404 |
+
using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute;
|
| 405 |
+
|
| 406 |
+
/// Initialization
|
| 407 |
+
cutlass::Distribution::Kind init_A;
|
| 408 |
+
cutlass::Distribution::Kind init_B;
|
| 409 |
+
cutlass::Distribution::Kind init_C;
|
| 410 |
+
cutlass::Distribution::Kind init_Scale;
|
| 411 |
+
cutlass::Distribution::Kind init_Bias;
|
| 412 |
+
uint64_t seed;
|
| 413 |
+
|
| 414 |
+
//
|
| 415 |
+
// Methods
|
| 416 |
+
//
|
| 417 |
+
|
| 418 |
+
B2bInterleavedFusedGemmRun(
|
| 419 |
+
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
| 420 |
+
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
| 421 |
+
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
| 422 |
+
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
| 423 |
+
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
| 424 |
+
uint64_t seed_ = 2080
|
| 425 |
+
):
|
| 426 |
+
init_A(init_A_), init_B(init_B_), init_C(init_C_),
|
| 427 |
+
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }
|
| 428 |
+
|
| 429 |
+
/// Helper to initialize a tensor view
|
| 430 |
+
template <typename Element, typename Layout>
|
| 431 |
+
bool initialize_tensor(
|
| 432 |
+
cutlass::TensorView<Element, Layout> view,
|
| 433 |
+
cutlass::Distribution::Kind dist_kind,
|
| 434 |
+
uint64_t seed) {
|
| 435 |
+
|
| 436 |
+
if (dist_kind == cutlass::Distribution::Uniform) {
|
| 437 |
+
|
| 438 |
+
cutlass::reference::host::TensorFillRandomUniform(
|
| 439 |
+
view, seed, 2, -2, 0);
|
| 440 |
+
}
|
| 441 |
+
else if (dist_kind == cutlass::Distribution::Identity) {
|
| 442 |
+
|
| 443 |
+
cutlass::reference::host::TensorFillIdentity(view);
|
| 444 |
+
}
|
| 445 |
+
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
| 446 |
+
|
| 447 |
+
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
| 448 |
+
}
|
| 449 |
+
else if (dist_kind == cutlass::Distribution::Sequential) {
|
| 450 |
+
|
| 451 |
+
cutlass::reference::host::BlockFillSequential(
|
| 452 |
+
view.data(), view.capacity());
|
| 453 |
+
}
|
| 454 |
+
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
| 455 |
+
cutlass::reference::host::TensorFill(view, Element(0));
|
| 456 |
+
}
|
| 457 |
+
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
| 458 |
+
cutlass::reference::host::TensorFill(view, Element(1));
|
| 459 |
+
}
|
| 460 |
+
else {
|
| 461 |
+
std::cerr << "Not implemented\n";
|
| 462 |
+
return false;
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
return true;
|
| 466 |
+
}
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
/// Executes one test
|
| 472 |
+
bool run(
|
| 473 |
+
cutlass::gemm::GemmCoord problem_size_0,
|
| 474 |
+
cutlass::gemm::GemmCoord problem_size_1,
|
| 475 |
+
ElementCompute alpha0 = ElementCompute(1),
|
| 476 |
+
ElementCompute beta0 = ElementCompute(0),
|
| 477 |
+
ElementCompute alpha1 = ElementCompute(1),
|
| 478 |
+
ElementCompute beta1 = ElementCompute(0),
|
| 479 |
+
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
|
| 480 |
+
|
| 481 |
+
// batch_count is used as split-k when mode is kGemm according
|
| 482 |
+
// to the GemmUniversal interface
|
| 483 |
+
|
| 484 |
+
int batch_count = 1,
|
| 485 |
+
|
| 486 |
+
int64_t batch_stride_A0 = 0,
|
| 487 |
+
int64_t batch_stride_B0 = 0,
|
| 488 |
+
int64_t batch_stride_C0 = 0,
|
| 489 |
+
int64_t batch_stride_B1 = 0,
|
| 490 |
+
int64_t batch_stride_C1 = 0,
|
| 491 |
+
int64_t batch_stride_D1 = 0,
|
| 492 |
+
int64_t batch_stride_Bias0 = 0,
|
| 493 |
+
int64_t batch_stride_Scale0 = 0,
|
| 494 |
+
bool relu = true,
|
| 495 |
+
int warm_ups = 1,
|
| 496 |
+
int runs = 100) {
|
| 497 |
+
|
| 498 |
+
//
|
| 499 |
+
// Allocate the GEMM workspace
|
| 500 |
+
//
|
| 501 |
+
|
| 502 |
+
cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
|
| 503 |
+
cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
|
| 504 |
+
cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
|
| 505 |
+
cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
|
| 506 |
+
cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
|
| 507 |
+
|
| 508 |
+
cutlass::HostTensor<
|
| 509 |
+
typename B2bGemm::ElementA,
|
| 510 |
+
typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
|
| 511 |
+
|
| 512 |
+
cutlass::HostTensor<
|
| 513 |
+
typename B2bGemm::ElementB,
|
| 514 |
+
typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
|
| 515 |
+
|
| 516 |
+
cutlass::HostTensor<
|
| 517 |
+
typename B2bGemm::ElementB,
|
| 518 |
+
typename B2bGemm::LayoutB> tensor_B0_reordered(CoordB0.kn());
|
| 519 |
+
|
| 520 |
+
cutlass::HostTensor<
|
| 521 |
+
typename B2bGemm::ElementC,
|
| 522 |
+
typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
|
| 523 |
+
|
| 524 |
+
cutlass::HostTensor<
|
| 525 |
+
typename B2bGemm::ElementScaleBias,
|
| 526 |
+
typename B2bGemm::LayoutScaleBias> tensor_Scale0;
|
| 527 |
+
|
| 528 |
+
if(alpha0 == ElementCompute(0)) //per-channel scale
|
| 529 |
+
tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
|
| 530 |
+
|
| 531 |
+
cutlass::HostTensor<
|
| 532 |
+
typename B2bGemm::ElementScaleBias,
|
| 533 |
+
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
|
| 534 |
+
|
| 535 |
+
cutlass::HostTensor<
|
| 536 |
+
ElementAccumulator,
|
| 537 |
+
typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
|
| 538 |
+
|
| 539 |
+
cutlass::HostTensor<
|
| 540 |
+
typename B2bGemm::ElementC,
|
| 541 |
+
typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
|
| 542 |
+
|
| 543 |
+
cutlass::HostTensor<
|
| 544 |
+
typename B2bGemm::ElementB,
|
| 545 |
+
typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
|
| 546 |
+
|
| 547 |
+
cutlass::HostTensor<
|
| 548 |
+
typename B2bGemm::ElementB,
|
| 549 |
+
typename B2bGemm::LayoutB> tensor_B1_reordered(CoordB1.kn());
|
| 550 |
+
|
| 551 |
+
cutlass::HostTensor<
|
| 552 |
+
typename B2bGemm::ElementC,
|
| 553 |
+
typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
|
| 554 |
+
|
| 555 |
+
cutlass::HostTensor<
|
| 556 |
+
typename B2bGemm::ElementC,
|
| 557 |
+
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
|
| 558 |
+
|
| 559 |
+
cutlass::HostTensor<
|
| 560 |
+
typename B2bGemm::ElementC,
|
| 561 |
+
typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
|
| 562 |
+
|
| 563 |
+
cutlass::HostTensor<
|
| 564 |
+
typename B2bGemm::ElementC,
|
| 565 |
+
typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
| 569 |
+
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
|
| 570 |
+
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
|
| 571 |
+
if(alpha0 == ElementCompute(0)) //per-channel scale
|
| 572 |
+
CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014));
|
| 573 |
+
CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013));
|
| 574 |
+
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
|
| 575 |
+
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
|
| 576 |
+
CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012));
|
| 577 |
+
|
| 578 |
+
//Reorder B0
|
| 579 |
+
cutlass::reorder_column<16>(
|
| 580 |
+
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), CoordB0);
|
| 581 |
+
cutlass::reorder_column<InterleavedK_>(
|
| 582 |
+
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), CoordB1);
|
| 583 |
+
|
| 584 |
+
cutlass::reference::host::TensorFill(
|
| 585 |
+
tensor_D1.host_view());
|
| 586 |
+
cutlass::reference::host::TensorFill(
|
| 587 |
+
reference_D0.host_view());
|
| 588 |
+
cutlass::reference::host::TensorFill(
|
| 589 |
+
reference_D1.host_view());
|
| 590 |
+
|
| 591 |
+
tensor_A0.sync_device();
|
| 592 |
+
tensor_B0.sync_device();
|
| 593 |
+
tensor_B0_reordered.sync_device();
|
| 594 |
+
tensor_C0.sync_device();
|
| 595 |
+
if(alpha0 == ElementCompute(0)) //per-channel scale
|
| 596 |
+
tensor_Scale0.sync_device();
|
| 597 |
+
tensor_Bias0.sync_device();
|
| 598 |
+
tensor_B1.sync_device();
|
| 599 |
+
tensor_B1_reordered.sync_device();
|
| 600 |
+
tensor_C1.sync_device();
|
| 601 |
+
tensor_Bias1.sync_device();
|
| 602 |
+
tensor_D1.sync_device();
|
| 603 |
+
reference_D0.sync_device();
|
| 604 |
+
reference_D1.sync_device();
|
| 605 |
+
// tensor_Bias0_batched.sync_device();
|
| 606 |
+
|
| 607 |
+
//
|
| 608 |
+
// Initialize the GEMM operator
|
| 609 |
+
//
|
| 610 |
+
|
| 611 |
+
typename B2bGemm::Arguments arguments{
|
| 612 |
+
mode,
|
| 613 |
+
problem_size_0,
|
| 614 |
+
problem_size_1,
|
| 615 |
+
tensor_A0.device_ref(),
|
| 616 |
+
tensor_B0_reordered.device_ref(),
|
| 617 |
+
tensor_C0.device_ref(),
|
| 618 |
+
tensor_Scale0.device_ref(),
|
| 619 |
+
tensor_Bias0.device_ref(),
|
| 620 |
+
tensor_B1_reordered.device_ref(),
|
| 621 |
+
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
| 622 |
+
tensor_D1.device_ref(),
|
| 623 |
+
batch_stride_A0,
|
| 624 |
+
batch_stride_B0,
|
| 625 |
+
batch_stride_B1,
|
| 626 |
+
batch_stride_C1,
|
| 627 |
+
batch_stride_D1,
|
| 628 |
+
batch_stride_Bias0,
|
| 629 |
+
batch_stride_Scale0,
|
| 630 |
+
{alpha0, beta0},
|
| 631 |
+
{alpha1, beta1},
|
| 632 |
+
batch_count,
|
| 633 |
+
};
|
| 634 |
+
|
| 635 |
+
B2bGemm b2b_gemm_op;
|
| 636 |
+
|
| 637 |
+
cutlass::Status status = b2b_gemm_op.can_implement(arguments);
|
| 638 |
+
|
| 639 |
+
if(status != cutlass::Status::kSuccess) {
|
| 640 |
+
std::cout << "Problem sizes not supported.\n"
|
| 641 |
+
<< "Requirments:\n"
|
| 642 |
+
<< " problem_size_0.M = problem_size_1.M\n"
|
| 643 |
+
<< " problem_size_0.N = problem_size_1.K\n"
|
| 644 |
+
<< " ThreadblockShape0::kN = problem_size_0.N\n"
|
| 645 |
+
<< " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
status = b2b_gemm_op.initialize(arguments);
|
| 649 |
+
|
| 650 |
+
CUTLASS_CHECK(status);
|
| 651 |
+
|
| 652 |
+
for(int i = 0; i < warm_ups; i++) {
|
| 653 |
+
status = b2b_gemm_op();
|
| 654 |
+
CUTLASS_CHECK(status);
|
| 655 |
+
}
|
| 656 |
+
|
| 657 |
+
//
|
| 658 |
+
// Run the GEMM
|
| 659 |
+
//
|
| 660 |
+
|
| 661 |
+
cudaEvent_t start, stop;
|
| 662 |
+
cudaEventCreate(&start);
|
| 663 |
+
cudaEventCreate(&stop);
|
| 664 |
+
|
| 665 |
+
cudaEventRecord(start);
|
| 666 |
+
|
| 667 |
+
for(int i = 0; i < runs; i++) {
|
| 668 |
+
status = b2b_gemm_op();
|
| 669 |
+
|
| 670 |
+
CUTLASS_CHECK(status);
|
| 671 |
+
}
|
| 672 |
+
|
| 673 |
+
cudaEventRecord(stop);
|
| 674 |
+
cudaDeviceSynchronize();
|
| 675 |
+
float gemmTime;
|
| 676 |
+
cudaEventElapsedTime(&gemmTime, start, stop);
|
| 677 |
+
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
|
| 678 |
+
|
| 679 |
+
tensor_D1.sync_host();
|
| 680 |
+
|
| 681 |
+
//
|
| 682 |
+
// Verify
|
| 683 |
+
//
|
| 684 |
+
|
| 685 |
+
cutlass::reference::device::GemmComplex<
|
| 686 |
+
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
| 687 |
+
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
| 688 |
+
ElementAccumulator, typename B2bGemm::LayoutC,
|
| 689 |
+
ElementAccumulator, ElementAccumulator
|
| 690 |
+
>(
|
| 691 |
+
problem_size_0,
|
| 692 |
+
ElementAccumulator(1), //intermediate alpha=1
|
| 693 |
+
tensor_A0.device_ref(),
|
| 694 |
+
cutlass::ComplexTransform::kNone,
|
| 695 |
+
tensor_B0.device_ref(),
|
| 696 |
+
cutlass::ComplexTransform::kNone,
|
| 697 |
+
ElementAccumulator(0), //beta = 0
|
| 698 |
+
reference_Z0.device_ref(),
|
| 699 |
+
reference_Z0.device_ref(),
|
| 700 |
+
ElementAccumulator(0),
|
| 701 |
+
int(batch_count),
|
| 702 |
+
batch_stride_A0,
|
| 703 |
+
batch_stride_B0,
|
| 704 |
+
batch_stride_C0,
|
| 705 |
+
batch_stride_C0
|
| 706 |
+
);
|
| 707 |
+
|
| 708 |
+
cutlass::reference::device::TensorScaleBiasGemmBatched<
|
| 709 |
+
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
| 710 |
+
ElementCompute, typename B2bGemm::LayoutScaleBias
|
| 711 |
+
> (
|
| 712 |
+
problem_size_0,
|
| 713 |
+
reference_Z0.device_ref(),
|
| 714 |
+
reference_D0.device_ref(),
|
| 715 |
+
alpha0,
|
| 716 |
+
tensor_Scale0.device_ref(),
|
| 717 |
+
tensor_Bias0.device_ref(),
|
| 718 |
+
int(batch_count),
|
| 719 |
+
batch_stride_C0,
|
| 720 |
+
batch_stride_C0,
|
| 721 |
+
batch_stride_Scale0,
|
| 722 |
+
batch_stride_Bias0
|
| 723 |
+
);
|
| 724 |
+
|
| 725 |
+
if(relu) {
|
| 726 |
+
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
| 727 |
+
}
|
| 728 |
+
|
| 729 |
+
cutlass::reference::device::GemmComplex<
|
| 730 |
+
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
| 731 |
+
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
| 732 |
+
typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
| 733 |
+
ElementCompute, ElementAccumulator
|
| 734 |
+
>(
|
| 735 |
+
problem_size_1,
|
| 736 |
+
alpha1, //intermediate alpha=1
|
| 737 |
+
reference_D0.device_ref(),
|
| 738 |
+
cutlass::ComplexTransform::kNone,
|
| 739 |
+
tensor_B1.device_ref(),
|
| 740 |
+
cutlass::ComplexTransform::kNone,
|
| 741 |
+
beta1, //beta = 0
|
| 742 |
+
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
| 743 |
+
reference_D1.device_ref(),
|
| 744 |
+
ElementAccumulator(0),
|
| 745 |
+
int(batch_count),
|
| 746 |
+
batch_stride_C0,
|
| 747 |
+
batch_stride_B1,
|
| 748 |
+
batch_stride_C1,
|
| 749 |
+
batch_stride_D1
|
| 750 |
+
);
|
| 751 |
+
|
| 752 |
+
if(relu) {
|
| 753 |
+
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
| 754 |
+
}
|
| 755 |
+
|
| 756 |
+
cudaDeviceSynchronize();
|
| 757 |
+
reference_D0.sync_host();
|
| 758 |
+
reference_D1.sync_host();
|
| 759 |
+
|
| 760 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
|
| 761 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
|
| 762 |
+
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
| 763 |
+
|
| 764 |
+
bool passed = cutlass::reference::host::TensorEquals(
|
| 765 |
+
reference_D1.host_view(),
|
| 766 |
+
tensor_D1.host_view());
|
| 767 |
+
|
| 768 |
+
CHECK_TRUE(passed);
|
| 769 |
+
if (!passed)
|
| 770 |
+
{
|
| 771 |
+
|
| 772 |
+
std::stringstream fname;
|
| 773 |
+
|
| 774 |
+
fname << "error_B2bGemm_device_interleaved_fused.txt";
|
| 775 |
+
std::cerr << "Dumping results in " << fname.str() << "\n";
|
| 776 |
+
|
| 777 |
+
std::ofstream file(fname.str());
|
| 778 |
+
|
| 779 |
+
file
|
| 780 |
+
<< "A0 =\n" << tensor_A0.host_view()
|
| 781 |
+
<< "\nB0 =\n" << tensor_B0.host_view()
|
| 782 |
+
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
|
| 783 |
+
<< "\nC0 =\n" << tensor_C0.host_view()
|
| 784 |
+
<< "\nScale0:\n" << tensor_Scale0.host_view() << "\n"
|
| 785 |
+
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
|
| 786 |
+
<< "\nB1 =\n" << tensor_B1.host_view()
|
| 787 |
+
<< "\nB1_reordered =\n" << tensor_B1_reordered.host_view()
|
| 788 |
+
<< "\nC1 =\n" << tensor_C1.host_view()
|
| 789 |
+
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
|
| 790 |
+
<< "\n\nReference =\n" << reference_D1.host_view()
|
| 791 |
+
<< "\nComputed =\n" << tensor_D1.host_view();
|
| 792 |
+
}
|
| 793 |
+
return passed;
|
| 794 |
+
}
|
| 795 |
+
|
| 796 |
+
};
|
| 797 |
+
|
| 798 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_gemm.h
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/numeric_types.h"
|
| 39 |
+
#include "cutlass/arch/arch.h"
|
| 40 |
+
#include "cutlass/device_kernel.h"
|
| 41 |
+
|
| 42 |
+
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
| 43 |
+
|
| 44 |
+
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
| 45 |
+
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
| 46 |
+
|
| 47 |
+
#include "kernel/b2b_gemm.h"
|
| 48 |
+
#include "kernel/default_b2b_gemm.h"
|
| 49 |
+
#include "kernel/default_b2b_gemm_smem_accumulator.h"
|
| 50 |
+
|
| 51 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 52 |
+
|
| 53 |
+
namespace cutlass {
|
| 54 |
+
namespace gemm {
|
| 55 |
+
namespace device {
|
| 56 |
+
|
| 57 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 58 |
+
|
| 59 |
+
template <
|
| 60 |
+
/// Element type for A matrix operand
|
| 61 |
+
typename ElementA_,
|
| 62 |
+
/// Layout type for A matrix operand
|
| 63 |
+
typename LayoutA_,
|
| 64 |
+
/// Element type for B matrix operand
|
| 65 |
+
typename ElementB_,
|
| 66 |
+
/// Layout type for B matrix operand
|
| 67 |
+
typename LayoutB_,
|
| 68 |
+
/// Element type for C and D matrix operands
|
| 69 |
+
typename ElementC_,
|
| 70 |
+
/// Layout type for C and D matrix operands
|
| 71 |
+
typename LayoutC_,
|
| 72 |
+
/// Element type for internal accumulation
|
| 73 |
+
typename ElementAccumulator_ = ElementC_,
|
| 74 |
+
/// Operator class tag
|
| 75 |
+
typename OperatorClass_ = arch::OpClassSimt,
|
| 76 |
+
/// Tag indicating architecture to tune for
|
| 77 |
+
typename ArchTag_ = arch::Sm70,
|
| 78 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 79 |
+
typename ThreadblockShape0_ = typename DefaultGemmConfiguration<
|
| 80 |
+
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
| 81 |
+
ElementAccumulator_>::ThreadblockShape,
|
| 82 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 83 |
+
typename ThreadblockShape1_ = typename DefaultGemmConfiguration<
|
| 84 |
+
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
| 85 |
+
ElementAccumulator_>::ThreadblockShape,
|
| 86 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 87 |
+
typename WarpShape0_ = typename DefaultGemmConfiguration<
|
| 88 |
+
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
| 89 |
+
ElementAccumulator_>::WarpShape,
|
| 90 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 91 |
+
typename WarpShape1_ = typename DefaultGemmConfiguration<
|
| 92 |
+
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
| 93 |
+
ElementAccumulator_>::WarpShape,
|
| 94 |
+
/// Instruction-level tile size (concept: GemmShape)
|
| 95 |
+
typename InstructionShape_ = typename DefaultGemmConfiguration<
|
| 96 |
+
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
| 97 |
+
ElementAccumulator_>::InstructionShape,
|
| 98 |
+
/// Epilogue output operator
|
| 99 |
+
typename EpilogueOutputOp0_ = typename DefaultGemmConfiguration<
|
| 100 |
+
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
| 101 |
+
ElementAccumulator_>::EpilogueOutputOp,
|
| 102 |
+
/// Epilogue output operator
|
| 103 |
+
typename EpilogueOutputOp1_ = typename DefaultGemmConfiguration<
|
| 104 |
+
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
| 105 |
+
ElementAccumulator_>::EpilogueOutputOp,
|
| 106 |
+
/// Threadblock-level swizzling operator
|
| 107 |
+
typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
|
| 108 |
+
/// Number of stages used in the pipelined mainloop
|
| 109 |
+
int Stages =
|
| 110 |
+
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
| 111 |
+
ElementC_, ElementAccumulator_>::kStages,
|
| 112 |
+
/// Stage accumulator in shared memory
|
| 113 |
+
bool SmemAccumulator = false,
|
| 114 |
+
/// Access granularity of A matrix in units of elements
|
| 115 |
+
int AlignmentA =
|
| 116 |
+
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
| 117 |
+
ElementC_, ElementAccumulator_>::kAlignmentA,
|
| 118 |
+
/// Access granularity of B matrix in units of elements
|
| 119 |
+
int AlignmentB =
|
| 120 |
+
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
| 121 |
+
ElementC_, ElementAccumulator_>::kAlignmentB,
|
| 122 |
+
/// Operation performed by GEMM
|
| 123 |
+
typename Operator_ = typename DefaultGemmConfiguration<
|
| 124 |
+
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
| 125 |
+
ElementAccumulator_>::Operator>
|
| 126 |
+
class B2bGemm {
|
| 127 |
+
public:
|
| 128 |
+
|
| 129 |
+
using ElementA = ElementA_;
|
| 130 |
+
using LayoutA = LayoutA_;
|
| 131 |
+
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
| 132 |
+
using ElementB = ElementB_;
|
| 133 |
+
using LayoutB = LayoutB_;
|
| 134 |
+
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
| 135 |
+
using ElementC = ElementC_;
|
| 136 |
+
using LayoutC = LayoutC_;
|
| 137 |
+
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
| 138 |
+
using TensorRefD = TensorRef<ElementC, LayoutC>;
|
| 139 |
+
using ElementAccumulator = ElementAccumulator_;
|
| 140 |
+
using OperatorClass = OperatorClass_;
|
| 141 |
+
using ArchTag = ArchTag_;
|
| 142 |
+
using ThreadblockShape0 = ThreadblockShape0_;
|
| 143 |
+
using ThreadblockShape1 = ThreadblockShape1_;
|
| 144 |
+
using WarpShape0 = WarpShape0_;
|
| 145 |
+
using WarpShape1 = WarpShape1_;
|
| 146 |
+
using InstructionShape = InstructionShape_;
|
| 147 |
+
using EpilogueOutputOp0 = EpilogueOutputOp0_;
|
| 148 |
+
using EpilogueOutputOp1 = EpilogueOutputOp1_;
|
| 149 |
+
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
| 150 |
+
using Operator = Operator_;
|
| 151 |
+
static int const kStages = Stages;
|
| 152 |
+
static int const kAlignmentA = AlignmentA;
|
| 153 |
+
static int const kAlignmentB = AlignmentB;
|
| 154 |
+
static int const kAlignmentC = EpilogueOutputOp1::kCount;
|
| 155 |
+
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
| 156 |
+
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
| 157 |
+
|
| 158 |
+
/// Derived types
|
| 159 |
+
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
| 160 |
+
using LayoutScaleBias = layout::RowMajor;
|
| 161 |
+
|
| 162 |
+
/// Define the kernel
|
| 163 |
+
using B2bGemmKernel = typename kernel::DefaultB2bGemm<
|
| 164 |
+
ElementA,
|
| 165 |
+
LayoutA,
|
| 166 |
+
kAlignmentA,
|
| 167 |
+
ElementB,
|
| 168 |
+
LayoutB,
|
| 169 |
+
kAlignmentB,
|
| 170 |
+
ElementC,
|
| 171 |
+
LayoutC,
|
| 172 |
+
ElementAccumulator,
|
| 173 |
+
OperatorClass,
|
| 174 |
+
ArchTag,
|
| 175 |
+
ThreadblockShape0,
|
| 176 |
+
ThreadblockShape1,
|
| 177 |
+
WarpShape0,
|
| 178 |
+
WarpShape1,
|
| 179 |
+
InstructionShape,
|
| 180 |
+
EpilogueOutputOp0,
|
| 181 |
+
EpilogueOutputOp1,
|
| 182 |
+
ThreadblockSwizzle,
|
| 183 |
+
kStages,
|
| 184 |
+
Operator,
|
| 185 |
+
SmemAccumulator
|
| 186 |
+
>::B2bGemmKernel;
|
| 187 |
+
|
| 188 |
+
using Arguments = typename B2bGemmKernel::Arguments;
|
| 189 |
+
|
| 190 |
+
private:
|
| 191 |
+
|
| 192 |
+
/// Kernel parameters object
|
| 193 |
+
typename B2bGemmKernel::Params params_;
|
| 194 |
+
|
| 195 |
+
public:
|
| 196 |
+
|
| 197 |
+
/// Constructs the GEMM.
|
| 198 |
+
B2bGemm() { }
|
| 199 |
+
|
| 200 |
+
/// Determines whether the GEMM can execute the given problem.
|
| 201 |
+
static Status can_implement(Arguments const &args) {
|
| 202 |
+
|
| 203 |
+
Status status = B2bGemmKernel::can_implement(
|
| 204 |
+
args.problem_size_0,
|
| 205 |
+
args.problem_size_1,
|
| 206 |
+
args.ref_A0.non_const_ref(),
|
| 207 |
+
args.ref_B0.non_const_ref(),
|
| 208 |
+
args.ref_C0.non_const_ref(),
|
| 209 |
+
args.ref_B1.non_const_ref(),
|
| 210 |
+
args.ref_C1.non_const_ref(),
|
| 211 |
+
args.ref_D1
|
| 212 |
+
);
|
| 213 |
+
|
| 214 |
+
if (status != Status::kSuccess) {
|
| 215 |
+
return status;
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
return Status::kSuccess;
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
/// Gets the workspace size
|
| 222 |
+
static size_t get_workspace_size(Arguments const &args) {
|
| 223 |
+
|
| 224 |
+
size_t bytes = 0;
|
| 225 |
+
|
| 226 |
+
// Determine grid shape
|
| 227 |
+
ThreadblockSwizzle threadblock_swizzle;
|
| 228 |
+
|
| 229 |
+
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
|
| 230 |
+
args.problem_size_0,
|
| 231 |
+
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
|
| 232 |
+
args.batch_count);
|
| 233 |
+
|
| 234 |
+
return bytes;
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
/// Initializes GEMM state from arguments.
|
| 238 |
+
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
|
| 239 |
+
|
| 240 |
+
// Determine grid shape
|
| 241 |
+
ThreadblockSwizzle threadblock_swizzle;
|
| 242 |
+
|
| 243 |
+
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
|
| 244 |
+
args.problem_size_0,
|
| 245 |
+
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
|
| 246 |
+
args.batch_count);
|
| 247 |
+
// cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape(
|
| 248 |
+
// args.problem_size_1,
|
| 249 |
+
// {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK},
|
| 250 |
+
// args.batch_count);
|
| 251 |
+
|
| 252 |
+
// Initialize the Params structure
|
| 253 |
+
params_ = typename B2bGemmKernel::Params{
|
| 254 |
+
args.mode,
|
| 255 |
+
args.problem_size_0,
|
| 256 |
+
args.problem_size_1,
|
| 257 |
+
grid_shape,
|
| 258 |
+
args.ref_A0.non_const_ref(),
|
| 259 |
+
args.ref_B0.non_const_ref(),
|
| 260 |
+
args.ref_C0.non_const_ref(),
|
| 261 |
+
args.ref_Scale0.non_const_ref(),
|
| 262 |
+
args.ref_Bias0.non_const_ref(),
|
| 263 |
+
args.ref_B1.non_const_ref(),
|
| 264 |
+
args.ref_C1.non_const_ref(),
|
| 265 |
+
args.ref_D1,
|
| 266 |
+
args.batch_stride_A0,
|
| 267 |
+
args.batch_stride_B0,
|
| 268 |
+
args.batch_stride_B1,
|
| 269 |
+
args.batch_stride_C1,
|
| 270 |
+
args.batch_stride_D1,
|
| 271 |
+
args.batch_stride_Bias0,
|
| 272 |
+
args.batch_stride_Scale0,
|
| 273 |
+
args.epilogue0,
|
| 274 |
+
args.epilogue1,
|
| 275 |
+
static_cast<int *>(workspace),
|
| 276 |
+
};
|
| 277 |
+
|
| 278 |
+
return Status::kSuccess;
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
/// Lightweight update given a subset of arguments
|
| 282 |
+
Status update(Arguments const &args, void *workspace = nullptr) {
|
| 283 |
+
|
| 284 |
+
params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
|
| 285 |
+
params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
|
| 286 |
+
params_.ref_C0.reset(args.ref_C0.non_const_ref().data());
|
| 287 |
+
params_.ref_Scale0.reset(args.ref_Scale0.non_const_ref().data());
|
| 288 |
+
params_.ref_Bias0.reset(args.ref_Bias0.non_const_ref().data());
|
| 289 |
+
params_.ref_B1.reset(args.ref_B1.non_const_ref().data());
|
| 290 |
+
params_.ref_C1.reset(args.ref_C1.non_const_ref().data());
|
| 291 |
+
params_.ref_D1.reset(args.ref_D1.data());
|
| 292 |
+
params_.output_op_0 = args.epilogue0;
|
| 293 |
+
params_.output_op_1 = args.epilogue1;
|
| 294 |
+
params_.semaphore = static_cast<int *>(workspace);
|
| 295 |
+
|
| 296 |
+
return Status::kSuccess;
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
/// Runs the kernel using initialized state.
|
| 300 |
+
Status run(cudaStream_t stream = nullptr) {
|
| 301 |
+
|
| 302 |
+
ThreadblockSwizzle threadblock_swizzle;
|
| 303 |
+
|
| 304 |
+
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
| 305 |
+
dim3 block(B2bGemmKernel::kThreadCount, 1, 1);
|
| 306 |
+
|
| 307 |
+
cudaError_t result;
|
| 308 |
+
|
| 309 |
+
int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage));
|
| 310 |
+
if (smem_size >= (48 << 10)) {
|
| 311 |
+
result = cudaFuncSetAttribute(Kernel<B2bGemmKernel>,
|
| 312 |
+
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
| 313 |
+
smem_size);
|
| 314 |
+
|
| 315 |
+
if (result != cudaSuccess) {
|
| 316 |
+
return Status::kErrorInternal;
|
| 317 |
+
}
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
cutlass::Kernel<B2bGemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
| 321 |
+
|
| 322 |
+
result = cudaGetLastError();
|
| 323 |
+
|
| 324 |
+
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
/// Runs the kernel using initialized state.
|
| 328 |
+
Status operator()(cudaStream_t stream = nullptr) {
|
| 329 |
+
return run(stream);
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
/// Runs the kernel using initialized state.
|
| 333 |
+
Status operator()(
|
| 334 |
+
Arguments const &args,
|
| 335 |
+
void *workspace = nullptr,
|
| 336 |
+
cudaStream_t stream = nullptr) {
|
| 337 |
+
|
| 338 |
+
Status status = initialize(args, workspace, stream);
|
| 339 |
+
|
| 340 |
+
if (status == Status::kSuccess) {
|
| 341 |
+
status = run(stream);
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
return status;
|
| 345 |
+
}
|
| 346 |
+
};
|
| 347 |
+
|
| 348 |
+
} // namespace device
|
| 349 |
+
} // namespace gemm
|
| 350 |
+
} // namespace cutlass
|
| 351 |
+
|
| 352 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/* \file
|
| 32 |
+
\brief Template for device-level Implicit GEMM
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include <limits>
|
| 38 |
+
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
#include "cutlass/device_kernel.h"
|
| 41 |
+
#include "cutlass/conv/convolution.h"
|
| 42 |
+
|
| 43 |
+
#include "kernel/b2b_implicit_gemm_convolution.h"
|
| 44 |
+
#include "kernel/default_b2b_conv2d_fprop.h"
|
| 45 |
+
#include "kernel/default_b2b_conv2d_fprop_sm75.h"
|
| 46 |
+
#include "kernel/default_b2b_conv2d_fprop_sm80.h"
|
| 47 |
+
#include "kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h"
|
| 48 |
+
#include "kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h"
|
| 49 |
+
|
| 50 |
+
namespace cutlass {
|
| 51 |
+
namespace conv {
|
| 52 |
+
namespace device {
|
| 53 |
+
|
| 54 |
+
template<typename B2bImplicitGemmKernel_>
|
| 55 |
+
class B2bImplicitGemmConvolution {
|
| 56 |
+
public:
|
| 57 |
+
|
| 58 |
+
using B2bImplicitGemmKernel = B2bImplicitGemmKernel_;
|
| 59 |
+
|
| 60 |
+
using ElementA = typename B2bImplicitGemmKernel::ElementA;
|
| 61 |
+
using LayoutA = typename B2bImplicitGemmKernel::LayoutA;
|
| 62 |
+
using ElementB = typename B2bImplicitGemmKernel::ElementB;
|
| 63 |
+
using LayoutB = typename B2bImplicitGemmKernel::LayoutB;
|
| 64 |
+
using ElementC = typename B2bImplicitGemmKernel::ElementC;
|
| 65 |
+
using LayoutC = typename B2bImplicitGemmKernel::LayoutC;
|
| 66 |
+
using ElementAccumulator = typename B2bImplicitGemmKernel::ElementAccumulator;
|
| 67 |
+
using ElementCompute = typename B2bImplicitGemmKernel::ElementCompute;
|
| 68 |
+
using ElementScaleBias = typename B2bImplicitGemmKernel::ElementScaleBias;
|
| 69 |
+
using LayoutScaleBias = typename B2bImplicitGemmKernel::LayoutScaleBias;
|
| 70 |
+
using OperatorClass = typename B2bImplicitGemmKernel::OperatorClass;
|
| 71 |
+
using ArchTag = typename B2bImplicitGemmKernel::ArchTag;
|
| 72 |
+
using ThreadblockShape0 = typename B2bImplicitGemmKernel::ThreadblockShape0;
|
| 73 |
+
using ThreadblockShape1 = typename B2bImplicitGemmKernel::ThreadblockShape1;
|
| 74 |
+
using WarpShape0 = typename B2bImplicitGemmKernel::WarpShape0;
|
| 75 |
+
using WarpShape1 = typename B2bImplicitGemmKernel::WarpShape1;
|
| 76 |
+
using InstructionShape = typename B2bImplicitGemmKernel::InstructionShape;
|
| 77 |
+
using ThreadblockSwizzle = typename B2bImplicitGemmKernel::ThreadblockSwizzle;
|
| 78 |
+
using EpilogueOutputOp0 = typename B2bImplicitGemmKernel::EpilogueOutputOp0;
|
| 79 |
+
using EpilogueOutputOp1 = typename B2bImplicitGemmKernel::EpilogueOutputOp1;
|
| 80 |
+
static int const kStages = B2bImplicitGemmKernel::kStages;
|
| 81 |
+
static int const kConvDim = B2bImplicitGemmKernel::kConvDim;
|
| 82 |
+
using WarpMmaOperator0 = typename B2bImplicitGemmKernel::WarpMmaOperator0;
|
| 83 |
+
using WarpMmaOperator1 = typename B2bImplicitGemmKernel::WarpMmaOperator1;
|
| 84 |
+
using ArchMmaOperator = typename B2bImplicitGemmKernel::ArchMmaOperator;
|
| 85 |
+
using MathOperator = typename B2bImplicitGemmKernel::MathOperator;
|
| 86 |
+
|
| 87 |
+
static cutlass::conv::Operator const kConvolutionalOperator = B2bImplicitGemmKernel::kConvolutionalOperator;
|
| 88 |
+
static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = B2bImplicitGemmKernel::kIteratorAlgorithm;
|
| 89 |
+
|
| 90 |
+
static int const kWarpCount =
|
| 91 |
+
(ThreadblockShape0::kM / WarpShape0::kM) *
|
| 92 |
+
(ThreadblockShape0::kN / WarpShape0::kN);
|
| 93 |
+
|
| 94 |
+
/// Argument structure
|
| 95 |
+
using Arguments = typename B2bImplicitGemmKernel::Arguments;
|
| 96 |
+
|
| 97 |
+
private:
|
| 98 |
+
|
| 99 |
+
/// Kernel parameters object
|
| 100 |
+
typename B2bImplicitGemmKernel::Params params_;
|
| 101 |
+
|
| 102 |
+
public:
|
| 103 |
+
|
| 104 |
+
/// Constructs Implicit GEMM
|
| 105 |
+
B2bImplicitGemmConvolution() { }
|
| 106 |
+
|
| 107 |
+
/// Determines whether the Implicit GEMM can execute the given problem.
|
| 108 |
+
static Status can_implement(Arguments const &args) {
|
| 109 |
+
|
| 110 |
+
// dispatch to iterators
|
| 111 |
+
Status status = B2bImplicitGemmKernel::B2bMma::IteratorA0::can_implement(args.problem_size_0);
|
| 112 |
+
if (Status::kSuccess != status) {
|
| 113 |
+
return status;
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
status = B2bImplicitGemmKernel::B2bMma::IteratorB0::can_implement(args.problem_size_0);
|
| 117 |
+
if (Status::kSuccess != status) {
|
| 118 |
+
return status;
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
status = B2bImplicitGemmKernel::B2bMma::IteratorB1::can_implement(args.problem_size_1);
|
| 122 |
+
if (Status::kSuccess != status) {
|
| 123 |
+
return status;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
// Determine grid shape
|
| 127 |
+
ThreadblockSwizzle threadblock_swizzle;
|
| 128 |
+
|
| 129 |
+
dim3 grid = threadblock_swizzle.get_grid_shape(
|
| 130 |
+
threadblock_swizzle.get_tiled_shape(
|
| 131 |
+
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0),
|
| 132 |
+
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
|
| 133 |
+
args.problem_size_0.split_k_slices));
|
| 134 |
+
|
| 135 |
+
if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
|
| 136 |
+
grid.z <= std::numeric_limits<uint16_t>::max())) {
|
| 137 |
+
|
| 138 |
+
return Status::kErrorInvalidProblem;
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
// Determine if fusion sizes are valid
|
| 142 |
+
|
| 143 |
+
cutlass::gemm::GemmCoord problem_size_0 = implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0);
|
| 144 |
+
cutlass::gemm::GemmCoord problem_size_1 = implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_1);
|
| 145 |
+
|
| 146 |
+
if(problem_size_0.m() != problem_size_1.m())
|
| 147 |
+
return Status::kErrorInvalidProblem;
|
| 148 |
+
|
| 149 |
+
if(problem_size_0.n() != problem_size_1.k())
|
| 150 |
+
return Status::kErrorInvalidProblem;
|
| 151 |
+
|
| 152 |
+
if(args.problem_size_1.R != 1 || args.problem_size_1.S != 1)
|
| 153 |
+
return Status::kErrorInvalidProblem;
|
| 154 |
+
|
| 155 |
+
if(problem_size_0.n() > ThreadblockShape0::kN)
|
| 156 |
+
return Status::kErrorInvalidProblem;
|
| 157 |
+
|
| 158 |
+
if(problem_size_1.n() > ThreadblockShape1::kN)
|
| 159 |
+
return Status::kErrorInvalidProblem;
|
| 160 |
+
|
| 161 |
+
return Status::kSuccess;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
/// Gets the workspace size
|
| 165 |
+
static size_t get_workspace_size(Arguments const &args) {
|
| 166 |
+
|
| 167 |
+
size_t workspace_bytes = 0;
|
| 168 |
+
|
| 169 |
+
// Determine grid shape
|
| 170 |
+
ThreadblockSwizzle threadblock_swizzle;
|
| 171 |
+
|
| 172 |
+
cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
|
| 173 |
+
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0),
|
| 174 |
+
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
|
| 175 |
+
args.problem_size_0.split_k_slices);
|
| 176 |
+
|
| 177 |
+
if(args.split_k_mode == SplitKMode::kParallel) {
|
| 178 |
+
|
| 179 |
+
// Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace.
|
| 180 |
+
// The user needs to call a reduction operator to obtain the final output tensor
|
| 181 |
+
workspace_bytes =
|
| 182 |
+
sizeof(ElementAccumulator) *
|
| 183 |
+
size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size_0)) *
|
| 184 |
+
size_t(grid_tiled_shape.k());
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size_0.split_k_slices > 1) {
|
| 188 |
+
|
| 189 |
+
// Split-K serial: The user workspace is used to store semaphore and serialize writing the
|
| 190 |
+
// final reduced output to user's output tensor
|
| 191 |
+
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
return workspace_bytes;
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
/// Initializes GEMM state from arguments.
|
| 198 |
+
Status initialize(
|
| 199 |
+
Arguments const &args,
|
| 200 |
+
void *workspace = nullptr,
|
| 201 |
+
cudaStream_t stream = nullptr) {
|
| 202 |
+
|
| 203 |
+
if (args.problem_size_0.split_k_slices > 1) {
|
| 204 |
+
|
| 205 |
+
if (!workspace) {
|
| 206 |
+
return Status::kErrorWorkspaceNull;
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream);
|
| 210 |
+
|
| 211 |
+
if (status != cudaSuccess) {
|
| 212 |
+
return Status::kErrorInternal;
|
| 213 |
+
}
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
// initialize the params structure from the arguments
|
| 217 |
+
params_ = typename B2bImplicitGemmKernel::Params(
|
| 218 |
+
args,
|
| 219 |
+
static_cast<int *>(workspace)
|
| 220 |
+
);
|
| 221 |
+
|
| 222 |
+
int smem_size = int(sizeof(typename B2bImplicitGemmKernel::SharedStorage));
|
| 223 |
+
|
| 224 |
+
if (smem_size >= (48 << 10)) {
|
| 225 |
+
cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<B2bImplicitGemmKernel>,
|
| 226 |
+
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
| 227 |
+
smem_size);
|
| 228 |
+
|
| 229 |
+
if (result != cudaSuccess) {
|
| 230 |
+
return Status::kErrorInternal;
|
| 231 |
+
}
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
return Status::kSuccess;
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
/// Initializes GEMM state from arguments.
|
| 238 |
+
Status update(Arguments const &args, void *workspace = nullptr) {
|
| 239 |
+
|
| 240 |
+
// update the params structure from the arguments
|
| 241 |
+
params_.ptr_A0 = args.ref_A0.data();
|
| 242 |
+
params_.ptr_B0 = args.ref_B0.data();
|
| 243 |
+
params_.ptr_C0 = args.ref_C0.data();
|
| 244 |
+
params_.ptr_Scale0 = args.ref_Scale0.data();
|
| 245 |
+
params_.ptr_Bias0 = args.ref_Bias0.data();
|
| 246 |
+
params_.ptr_B1 = args.ref_B1.data();
|
| 247 |
+
params_.ptr_C1 = args.ref_C1.data();
|
| 248 |
+
params_.ptr_D1 = args.ref_D1.data();
|
| 249 |
+
params_.output_op_0 = args.output_op_0;
|
| 250 |
+
params_.output_op_1 = args.output_op_1;
|
| 251 |
+
params_.semaphore = static_cast<int *>(workspace);
|
| 252 |
+
|
| 253 |
+
return Status::kSuccess;
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
/// Runs the kernel using initialized state.
|
| 257 |
+
Status run(cudaStream_t stream = nullptr) {
|
| 258 |
+
|
| 259 |
+
ThreadblockSwizzle threadblock_swizzle;
|
| 260 |
+
|
| 261 |
+
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
| 262 |
+
dim3 block(32 * kWarpCount, 1, 1);
|
| 263 |
+
|
| 264 |
+
int smem_size = int(sizeof(typename B2bImplicitGemmKernel::SharedStorage));
|
| 265 |
+
|
| 266 |
+
cutlass::Kernel<B2bImplicitGemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
| 267 |
+
|
| 268 |
+
cudaError_t result = cudaGetLastError();
|
| 269 |
+
|
| 270 |
+
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
/// Runs the kernel using initialized state.
|
| 274 |
+
Status operator()(cudaStream_t stream = nullptr) {
|
| 275 |
+
return run(stream);
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
/// Runs the kernel using initialized state.
|
| 279 |
+
Status operator()(
|
| 280 |
+
Arguments const &args,
|
| 281 |
+
void *workspace = nullptr,
|
| 282 |
+
cudaStream_t stream = nullptr) {
|
| 283 |
+
|
| 284 |
+
Status status = initialize(args, workspace, stream);
|
| 285 |
+
|
| 286 |
+
if (status == Status::kSuccess) {
|
| 287 |
+
status = run(stream);
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
return status;
|
| 291 |
+
}
|
| 292 |
+
};
|
| 293 |
+
|
| 294 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 295 |
+
|
| 296 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 297 |
+
} // namespace device
|
| 298 |
+
} // namespace conv
|
| 299 |
+
} // namespace cutlass
|
| 300 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h
ADDED
|
@@ -0,0 +1,811 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
|
| 39 |
+
#include "cutlass/gemm/gemm.h"
|
| 40 |
+
#include "cutlass/matrix_coord.h"
|
| 41 |
+
#include "cutlass/semaphore.h"
|
| 42 |
+
|
| 43 |
+
#include "kernel/b2b_gemm_grouped_problem_visitor.h"
|
| 44 |
+
#include "threadblock/grouped_threadblock_swizzle.h"
|
| 45 |
+
|
| 46 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
namespace gemm {
|
| 50 |
+
namespace kernel {
|
| 51 |
+
|
| 52 |
+
namespace detail {
|
| 53 |
+
|
| 54 |
+
/// Utility struct for returning the type of the problem visitor used by the swizzling function,
|
| 55 |
+
/// if it is a grouped swizzling function, or a default visitor. This is used only for defining
|
| 56 |
+
/// the parameters of the problem visitor used in GroupedParams.
|
| 57 |
+
template <
|
| 58 |
+
typename B2bMma_,
|
| 59 |
+
typename ThreadblockSwizzle_,
|
| 60 |
+
typename Enable = void
|
| 61 |
+
>
|
| 62 |
+
struct ProblemVisitorOrDefault;
|
| 63 |
+
|
| 64 |
+
/// Return a generic problem visitor for GEMM problems
|
| 65 |
+
template <
|
| 66 |
+
typename B2bMma_,
|
| 67 |
+
typename ThreadblockSwizzle_
|
| 68 |
+
>
|
| 69 |
+
struct ProblemVisitorOrDefault<B2bMma_,
|
| 70 |
+
ThreadblockSwizzle_,
|
| 71 |
+
typename platform::enable_if<
|
| 72 |
+
! cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle_>::value
|
| 73 |
+
>::type> {
|
| 74 |
+
using value = B2bGemmGroupedProblemVisitor<typename B2bMma_::Shape,
|
| 75 |
+
GroupScheduleMode::kDeviceOnly,
|
| 76 |
+
128,
|
| 77 |
+
128,
|
| 78 |
+
platform::is_same<typename B2bMma_::LayoutC,
|
| 79 |
+
cutlass::layout::ColumnMajor>::value>;
|
| 80 |
+
};
|
| 81 |
+
|
| 82 |
+
/// Return the problem visitor specified by the swizzling function
|
| 83 |
+
template <
|
| 84 |
+
typename B2bMma_,
|
| 85 |
+
typename ThreadblockSwizzle_
|
| 86 |
+
>
|
| 87 |
+
struct ProblemVisitorOrDefault<B2bMma_,
|
| 88 |
+
ThreadblockSwizzle_,
|
| 89 |
+
typename platform::enable_if<
|
| 90 |
+
cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle_>::value
|
| 91 |
+
>::type> {
|
| 92 |
+
using value = typename ThreadblockSwizzle_::ProblemVisitor;
|
| 93 |
+
};
|
| 94 |
+
|
| 95 |
+
} // namespace detail
|
| 96 |
+
|
| 97 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 98 |
+
|
| 99 |
+
template <
|
| 100 |
+
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
|
| 101 |
+
typename Epilogue_, ///! Epilogue
|
| 102 |
+
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
|
| 103 |
+
>
|
| 104 |
+
struct B2bGemm {
|
| 105 |
+
|
| 106 |
+
using B2bMma = B2bMma_;
|
| 107 |
+
using Epilogue = Epilogue_;
|
| 108 |
+
using OutputOp0 = typename B2bMma::OutputOp;
|
| 109 |
+
using OutputOp1 = typename Epilogue::OutputOp;
|
| 110 |
+
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
| 111 |
+
|
| 112 |
+
using ElementA0 = typename B2bMma::IteratorA0::Element;
|
| 113 |
+
using LayoutA0 = typename B2bMma::IteratorA0::Layout;
|
| 114 |
+
using ElementB0 = typename B2bMma::IteratorB0::Element;
|
| 115 |
+
using LayoutB0 = typename B2bMma::IteratorB0::Layout;
|
| 116 |
+
using ElementB1 = typename B2bMma::IteratorB1::Element;
|
| 117 |
+
using LayoutB1 = typename B2bMma::IteratorB1::Layout;
|
| 118 |
+
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
| 119 |
+
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
|
| 120 |
+
|
| 121 |
+
using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element;
|
| 122 |
+
|
| 123 |
+
/// Data types needed for higher-level containers. In some cases, a single type must be exposed
|
| 124 |
+
/// despite the B2b GEMM using two GEMMs under the hood. In such cases, we select the values from
|
| 125 |
+
/// the second GEMM (other than for ElementA/ElementB)
|
| 126 |
+
using ElementA = typename B2bMma::IteratorA0::Element;
|
| 127 |
+
using LayoutA = typename B2bMma::IteratorA0::Layout;
|
| 128 |
+
using ElementB = typename B2bMma::IteratorB0::Element;
|
| 129 |
+
using LayoutB = typename B2bMma::IteratorB0::Layout;
|
| 130 |
+
|
| 131 |
+
static ComplexTransform const kTransformA = B2bMma::kTransformA;
|
| 132 |
+
static ComplexTransform const kTransformB = B2bMma::kTransformB;
|
| 133 |
+
using Operator = typename B2bMma::Operator0;
|
| 134 |
+
|
| 135 |
+
using OperatorClass = typename Operator::OperatorClass;
|
| 136 |
+
using ThreadblockShape = typename B2bMma::Shape0;
|
| 137 |
+
using WarpShape = typename Operator::Shape;
|
| 138 |
+
using InstructionShape = typename Operator::InstructionShape;
|
| 139 |
+
using ArchTag = typename B2bMma::ArchTag;
|
| 140 |
+
|
| 141 |
+
static int const kStages = B2bMma::kStages;
|
| 142 |
+
static int const kAlignmentA = B2bMma::IteratorA::AccessType::kElements;
|
| 143 |
+
static int const kAlignmentB = B2bMma::IteratorB::AccessType::kElements;
|
| 144 |
+
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
| 145 |
+
|
| 146 |
+
using Mma = B2bMma;
|
| 147 |
+
using EpilogueOutputOp = OutputOp1;
|
| 148 |
+
|
| 149 |
+
/// Warp count (concept: GemmShape)
|
| 150 |
+
using WarpCount0 = typename B2bMma::WarpCount0;
|
| 151 |
+
static int const kThreadCount = 32 * WarpCount0::kCount;
|
| 152 |
+
|
| 153 |
+
/// Argument structure
|
| 154 |
+
struct Arguments {
|
| 155 |
+
|
| 156 |
+
//
|
| 157 |
+
// Data members
|
| 158 |
+
//
|
| 159 |
+
|
| 160 |
+
GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
|
| 161 |
+
GemmCoord problem_size_0{0,0,0};
|
| 162 |
+
GemmCoord problem_size_1{0,0,0};
|
| 163 |
+
typename B2bMma::IteratorA0::TensorRef ref_A0{};
|
| 164 |
+
typename B2bMma::IteratorB0::TensorRef ref_B0{};
|
| 165 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_C0{};
|
| 166 |
+
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0{};
|
| 167 |
+
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0{};
|
| 168 |
+
typename B2bMma::IteratorB1::TensorRef ref_B1{};
|
| 169 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_C1{};
|
| 170 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_D1{};
|
| 171 |
+
int64_t batch_stride_A0{0};
|
| 172 |
+
int64_t batch_stride_B0{0};
|
| 173 |
+
int64_t batch_stride_B1{0};
|
| 174 |
+
int64_t batch_stride_C1{0};
|
| 175 |
+
int64_t batch_stride_D1{0};
|
| 176 |
+
int64_t batch_stride_Bias0{0};
|
| 177 |
+
int64_t batch_stride_Scale0{0};
|
| 178 |
+
typename OutputOp0::Params epilogue0 {};
|
| 179 |
+
typename OutputOp1::Params epilogue1 {};
|
| 180 |
+
int batch_count{1};
|
| 181 |
+
|
| 182 |
+
//
|
| 183 |
+
// Methods
|
| 184 |
+
//
|
| 185 |
+
|
| 186 |
+
/// Default ctor
|
| 187 |
+
Arguments() = default;
|
| 188 |
+
|
| 189 |
+
/// Constructs an Arguments structure
|
| 190 |
+
CUTLASS_HOST_DEVICE
|
| 191 |
+
Arguments(
|
| 192 |
+
GemmUniversalMode mode_,
|
| 193 |
+
GemmCoord problem_size_0_,
|
| 194 |
+
GemmCoord problem_size_1_,
|
| 195 |
+
typename B2bMma::IteratorA0::TensorRef ref_A0_,
|
| 196 |
+
typename B2bMma::IteratorB0::TensorRef ref_B0_,
|
| 197 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_C0_,
|
| 198 |
+
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0_,
|
| 199 |
+
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0_,
|
| 200 |
+
typename B2bMma::IteratorB1::TensorRef ref_B1_,
|
| 201 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_C1_,
|
| 202 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_D1_,
|
| 203 |
+
int64_t batch_stride_A0_,
|
| 204 |
+
int64_t batch_stride_B0_,
|
| 205 |
+
int64_t batch_stride_B1_,
|
| 206 |
+
int64_t batch_stride_C1_,
|
| 207 |
+
int64_t batch_stride_D1_,
|
| 208 |
+
int64_t batch_stride_Bias0_,
|
| 209 |
+
int64_t batch_stride_Scale0_,
|
| 210 |
+
typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(),
|
| 211 |
+
typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(),
|
| 212 |
+
int batch_count_ = 1
|
| 213 |
+
):
|
| 214 |
+
mode(mode_),
|
| 215 |
+
problem_size_0(problem_size_0_),
|
| 216 |
+
problem_size_1(problem_size_1_),
|
| 217 |
+
ref_A0(ref_A0_),
|
| 218 |
+
ref_B0(ref_B0_),
|
| 219 |
+
ref_C0(ref_C0_),
|
| 220 |
+
ref_Scale0(ref_Scale0_),
|
| 221 |
+
ref_Bias0(ref_Bias0_),
|
| 222 |
+
ref_B1(ref_B1_),
|
| 223 |
+
ref_C1(ref_C1_),
|
| 224 |
+
ref_D1(ref_D1_),
|
| 225 |
+
batch_stride_A0(batch_stride_A0_),
|
| 226 |
+
batch_stride_B0(batch_stride_B0_),
|
| 227 |
+
batch_stride_B1(batch_stride_B1_),
|
| 228 |
+
batch_stride_C1(batch_stride_C1_),
|
| 229 |
+
batch_stride_D1(batch_stride_D1_),
|
| 230 |
+
batch_stride_Bias0(batch_stride_Bias0_),
|
| 231 |
+
batch_stride_Scale0(batch_stride_Scale0_),
|
| 232 |
+
epilogue0(epilogue0_),
|
| 233 |
+
epilogue1(epilogue1_),
|
| 234 |
+
batch_count(batch_count_) {
|
| 235 |
+
}
|
| 236 |
+
};
|
| 237 |
+
|
| 238 |
+
// Arguments structure for grouped B2B problems
|
| 239 |
+
struct GroupedArguments {
|
| 240 |
+
GemmCoord* problem_size_0;
|
| 241 |
+
GemmCoord* problem_size_1;
|
| 242 |
+
typename B2bMma::IteratorA0::TensorRef* ref_A0;
|
| 243 |
+
typename B2bMma::IteratorB0::TensorRef* ref_B0;
|
| 244 |
+
typename Epilogue::OutputTileIterator::TensorRef* ref_C0;
|
| 245 |
+
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0;
|
| 246 |
+
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0;
|
| 247 |
+
typename B2bMma::IteratorB1::TensorRef* ref_B1;
|
| 248 |
+
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
|
| 249 |
+
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
|
| 250 |
+
|
| 251 |
+
// Epilogue params remain constant across all problems in the group. Thus,
|
| 252 |
+
// the parameter here is not a pointer.
|
| 253 |
+
typename OutputOp0::Params epilogue0;
|
| 254 |
+
typename OutputOp1::Params epilogue1;
|
| 255 |
+
|
| 256 |
+
int problem_count;
|
| 257 |
+
int threadblock_count;
|
| 258 |
+
GemmCoord* host_problem_sizes;
|
| 259 |
+
|
| 260 |
+
CUTLASS_HOST_DEVICE
|
| 261 |
+
GroupedArguments(
|
| 262 |
+
int problem_count,
|
| 263 |
+
GemmCoord* problem_size_0_,
|
| 264 |
+
GemmCoord* problem_size_1_,
|
| 265 |
+
typename B2bMma::IteratorA0::TensorRef* ref_A0_,
|
| 266 |
+
typename B2bMma::IteratorB0::TensorRef* ref_B0_,
|
| 267 |
+
typename Epilogue::OutputTileIterator::TensorRef* ref_C0_,
|
| 268 |
+
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0_,
|
| 269 |
+
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0_,
|
| 270 |
+
typename B2bMma::IteratorB1::TensorRef* ref_B1_,
|
| 271 |
+
typename Epilogue::OutputTileIterator::TensorRef* ref_C1_,
|
| 272 |
+
typename Epilogue::OutputTileIterator::TensorRef* ref_D1_,
|
| 273 |
+
typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(),
|
| 274 |
+
typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(),
|
| 275 |
+
int threadblock_count = 0
|
| 276 |
+
) : problem_size_0(problem_size_0_), problem_size_1(problem_size_1_),
|
| 277 |
+
ref_A0(ref_A0_), ref_B0(ref_B0_), ref_C0(ref_C0_),
|
| 278 |
+
ref_Scale0(ref_Scale0_), ref_Bias0(ref_Bias0_), ref_B1(ref_B1_),
|
| 279 |
+
ref_C1(ref_C1_), ref_D1(ref_D1_), epilogue0(epilogue0_), epilogue1(epilogue1_),
|
| 280 |
+
problem_count(problem_count),
|
| 281 |
+
threadblock_count(threadblock_count)
|
| 282 |
+
{}
|
| 283 |
+
};
|
| 284 |
+
|
| 285 |
+
/// Parameters structure
|
| 286 |
+
struct Params {
|
| 287 |
+
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
|
| 288 |
+
cutlass::gemm::GemmCoord problem_size_0{};
|
| 289 |
+
cutlass::gemm::GemmCoord problem_size_1{};
|
| 290 |
+
cutlass::gemm::GemmCoord grid_tiled_shape{};
|
| 291 |
+
int swizzle_log_tile{0};
|
| 292 |
+
typename B2bMma::IteratorA0::Params params_A0{};
|
| 293 |
+
typename B2bMma::IteratorA0::TensorRef ref_A0{};
|
| 294 |
+
typename B2bMma::IteratorB0::Params params_B0{};
|
| 295 |
+
typename B2bMma::IteratorB0::TensorRef ref_B0{};
|
| 296 |
+
typename Epilogue::OutputTileIterator::Params params_C0{};
|
| 297 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_C0{};
|
| 298 |
+
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0{};
|
| 299 |
+
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0{};
|
| 300 |
+
typename B2bMma::IteratorB1::Params params_B1{};
|
| 301 |
+
typename B2bMma::IteratorB1::TensorRef ref_B1{};
|
| 302 |
+
typename Epilogue::OutputTileIterator::Params params_C1{};
|
| 303 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_C1{};
|
| 304 |
+
typename Epilogue::OutputTileIterator::Params params_D1{};
|
| 305 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_D1{};
|
| 306 |
+
typename OutputOp0::Params output_op_0{};
|
| 307 |
+
typename OutputOp1::Params output_op_1{};
|
| 308 |
+
int64_t batch_stride_A0{0};
|
| 309 |
+
int64_t batch_stride_B0{0};
|
| 310 |
+
int64_t batch_stride_B1{0};
|
| 311 |
+
int64_t batch_stride_C1{0};
|
| 312 |
+
int64_t batch_stride_D1{0};
|
| 313 |
+
int64_t batch_stride_Bias0{0};
|
| 314 |
+
int64_t batch_stride_Scale0{0};
|
| 315 |
+
int *semaphore = nullptr;
|
| 316 |
+
int gemm_k_iterations_0{0};
|
| 317 |
+
int gemm_k_size_0{0};
|
| 318 |
+
int gemm_k_iterations_1{0};
|
| 319 |
+
int gemm_k_size_1{0};
|
| 320 |
+
|
| 321 |
+
//
|
| 322 |
+
// Methods
|
| 323 |
+
//
|
| 324 |
+
|
| 325 |
+
Params() = default;
|
| 326 |
+
|
| 327 |
+
CUTLASS_HOST_DEVICE
|
| 328 |
+
Params(
|
| 329 |
+
cutlass::gemm::GemmUniversalMode mode,
|
| 330 |
+
cutlass::gemm::GemmCoord const & problem_size_0,
|
| 331 |
+
cutlass::gemm::GemmCoord const & problem_size_1,
|
| 332 |
+
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
| 333 |
+
typename B2bMma::IteratorA0::TensorRef ref_A0,
|
| 334 |
+
typename B2bMma::IteratorB0::TensorRef ref_B0,
|
| 335 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_C0,
|
| 336 |
+
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0,
|
| 337 |
+
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0,
|
| 338 |
+
typename B2bMma::IteratorB1::TensorRef ref_B1,
|
| 339 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_C1,
|
| 340 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_D1,
|
| 341 |
+
int64_t batch_stride_A0,
|
| 342 |
+
int64_t batch_stride_B0,
|
| 343 |
+
int64_t batch_stride_B1,
|
| 344 |
+
int64_t batch_stride_C1,
|
| 345 |
+
int64_t batch_stride_D1,
|
| 346 |
+
int64_t batch_stride_Bias0,
|
| 347 |
+
int64_t batch_stride_Scale0,
|
| 348 |
+
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
|
| 349 |
+
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
|
| 350 |
+
int *workspace = nullptr
|
| 351 |
+
):
|
| 352 |
+
mode(mode),
|
| 353 |
+
problem_size_0(problem_size_0),
|
| 354 |
+
problem_size_1(problem_size_1),
|
| 355 |
+
grid_tiled_shape(grid_tiled_shape),
|
| 356 |
+
swizzle_log_tile(ThreadblockSwizzle::get_log_tile(grid_tiled_shape)),
|
| 357 |
+
params_A0(ref_A0.layout()),
|
| 358 |
+
ref_A0(ref_A0),
|
| 359 |
+
params_B0(ref_B0.layout()),
|
| 360 |
+
ref_B0(ref_B0),
|
| 361 |
+
params_C0(ref_C0.layout()),
|
| 362 |
+
ref_C0(ref_C0),
|
| 363 |
+
ref_Scale0(ref_Scale0),
|
| 364 |
+
ref_Bias0(ref_Bias0),
|
| 365 |
+
params_B1(ref_B1.layout()),
|
| 366 |
+
ref_B1(ref_B1),
|
| 367 |
+
params_C1(ref_C1.layout()),
|
| 368 |
+
ref_C1(ref_C1),
|
| 369 |
+
params_D1(ref_D1.layout()),
|
| 370 |
+
ref_D1(ref_D1),
|
| 371 |
+
batch_stride_A0(batch_stride_A0),
|
| 372 |
+
batch_stride_B0(batch_stride_B0),
|
| 373 |
+
batch_stride_B1(batch_stride_B1),
|
| 374 |
+
batch_stride_C1(batch_stride_C1),
|
| 375 |
+
batch_stride_D1(batch_stride_D1),
|
| 376 |
+
batch_stride_Bias0(batch_stride_Bias0),
|
| 377 |
+
batch_stride_Scale0(batch_stride_Scale0),
|
| 378 |
+
output_op_0(output_op_0),
|
| 379 |
+
output_op_1(output_op_1) {
|
| 380 |
+
|
| 381 |
+
int total_gemm_k_iterations_0 = (problem_size_0.k() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
|
| 382 |
+
int gemm_k_iterations_0 = (total_gemm_k_iterations_0 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
|
| 383 |
+
gemm_k_size_0 = gemm_k_iterations_0 * B2bMma::Shape0::kK;
|
| 384 |
+
int total_gemm_k_iterations_1 = (problem_size_1.k() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
|
| 385 |
+
int gemm_k_iterations_1 = (total_gemm_k_iterations_1 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
|
| 386 |
+
gemm_k_size_1 = gemm_k_iterations_1 * B2bMma::Shape1::kK;
|
| 387 |
+
|
| 388 |
+
semaphore = workspace;
|
| 389 |
+
}
|
| 390 |
+
};
|
| 391 |
+
|
| 392 |
+
struct GroupedParams {
|
| 393 |
+
cutlass::gemm::GemmCoord* problem_size_0;
|
| 394 |
+
cutlass::gemm::GemmCoord* problem_size_1;
|
| 395 |
+
cutlass::gemm::GemmCoord* grid_tiled_shape;
|
| 396 |
+
typename B2bMma::IteratorA0::TensorRef* ref_A0;
|
| 397 |
+
typename B2bMma::IteratorB0::TensorRef* ref_B0;
|
| 398 |
+
typename Epilogue::OutputTileIterator::TensorRef* ref_C0;
|
| 399 |
+
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0;
|
| 400 |
+
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0;
|
| 401 |
+
typename B2bMma::IteratorB1::TensorRef* ref_B1;
|
| 402 |
+
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
|
| 403 |
+
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
|
| 404 |
+
|
| 405 |
+
// Epilogue params remain constant across all problems in the group. Thus,
|
| 406 |
+
// the parameter here is not a pointer.
|
| 407 |
+
typename OutputOp0::Params output_op_0;
|
| 408 |
+
typename OutputOp1::Params output_op_1;
|
| 409 |
+
|
| 410 |
+
using ProblemVisitor = typename detail::ProblemVisitorOrDefault<B2bMma, ThreadblockSwizzle>::value;
|
| 411 |
+
typename ProblemVisitor::Params problem_visitor;
|
| 412 |
+
int threadblock_count;
|
| 413 |
+
int* workspace;
|
| 414 |
+
|
| 415 |
+
CUTLASS_HOST_DEVICE
|
| 416 |
+
GroupedParams() {}
|
| 417 |
+
|
| 418 |
+
CUTLASS_HOST_DEVICE
|
| 419 |
+
GroupedParams(
|
| 420 |
+
GroupedArguments const &args,
|
| 421 |
+
void *workspace = nullptr,
|
| 422 |
+
int tile_count = 0
|
| 423 |
+
) :
|
| 424 |
+
problem_size_0(args.problem_size_0), problem_size_1(args.problem_size_1),
|
| 425 |
+
ref_A0(args.ref_A0), ref_B0(args.ref_B0), ref_C0(args.ref_C0),
|
| 426 |
+
ref_Scale0(args.ref_Scale0), ref_Bias0(args.ref_Bias0), ref_B1(args.ref_B1), ref_C1(args.ref_C1), ref_D1(args.ref_D1),
|
| 427 |
+
output_op_0(args.epilogue0), output_op_1(args.epilogue1),
|
| 428 |
+
problem_visitor(args.problem_size_0, args.problem_size_1, args.problem_count, workspace, tile_count),
|
| 429 |
+
threadblock_count(args.threadblock_count),
|
| 430 |
+
workspace(reinterpret_cast<int*>(workspace)) {}
|
| 431 |
+
|
| 432 |
+
CUTLASS_HOST_DEVICE
|
| 433 |
+
void transpose() {
|
| 434 |
+
// Only row-major outputs are currently supported, so no transpose is performed
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
/// Returns non-grouped parameters to be used as input to the kernel-level
|
| 438 |
+
/// operator for the problem indicated by problem_visitor.
|
| 439 |
+
CUTLASS_HOST_DEVICE
|
| 440 |
+
Params to_single_params(const ProblemVisitor& problem_visitor) const {
|
| 441 |
+
GemmCoord problem_size0 = problem_visitor.problem_size0();
|
| 442 |
+
GemmCoord problem_size1 = problem_visitor.problem_size1();
|
| 443 |
+
int32_t idx = problem_visitor.problem_index();
|
| 444 |
+
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size1);
|
| 445 |
+
|
| 446 |
+
return Params(
|
| 447 |
+
cutlass::gemm::GemmUniversalMode::kGemm,
|
| 448 |
+
problem_size0,
|
| 449 |
+
problem_size1,
|
| 450 |
+
grid_shape,
|
| 451 |
+
ref_A0[idx],
|
| 452 |
+
ref_B0[idx],
|
| 453 |
+
ref_C0[idx],
|
| 454 |
+
ref_Scale0[idx],
|
| 455 |
+
ref_Bias0[idx],
|
| 456 |
+
ref_B1[idx],
|
| 457 |
+
ref_C1[idx],
|
| 458 |
+
ref_D1[idx],
|
| 459 |
+
0, 0, 0, 0, 0, 0, 0, // Batched B2B GEMMs within the grouped kernel are currently unsupported
|
| 460 |
+
output_op_0,
|
| 461 |
+
output_op_1,
|
| 462 |
+
workspace
|
| 463 |
+
);
|
| 464 |
+
}
|
| 465 |
+
};
|
| 466 |
+
|
| 467 |
+
/// Shared memory storage structure
|
| 468 |
+
union SharedStorage {
|
| 469 |
+
typename B2bMma::B2bMmaSharedStorage main_loop;
|
| 470 |
+
typename Epilogue::SharedStorage epilogue;
|
| 471 |
+
};
|
| 472 |
+
|
| 473 |
+
//
|
| 474 |
+
// Methods
|
| 475 |
+
//
|
| 476 |
+
|
| 477 |
+
CUTLASS_HOST_DEVICE
|
| 478 |
+
B2bGemm() { }
|
| 479 |
+
|
| 480 |
+
/// Determines whether kernel satisfies alignment
|
| 481 |
+
static Status can_implement(
|
| 482 |
+
cutlass::gemm::GemmCoord const & problem_size_0,
|
| 483 |
+
cutlass::gemm::GemmCoord const & problem_size_1,
|
| 484 |
+
typename B2bMma::IteratorA0::TensorRef ref_A0,
|
| 485 |
+
typename B2bMma::IteratorB0::TensorRef ref_B0,
|
| 486 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_C0,
|
| 487 |
+
typename B2bMma::IteratorB1::TensorRef ref_B1,
|
| 488 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_C1,
|
| 489 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_D1) {
|
| 490 |
+
|
| 491 |
+
static int const kAlignmentA = B2bMma::IteratorA0::AccessType::kElements;
|
| 492 |
+
static int const kAlignmentB = B2bMma::IteratorB0::AccessType::kElements;
|
| 493 |
+
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
| 494 |
+
|
| 495 |
+
if (!TensorRef_aligned(ref_A0, kAlignmentA)) {
|
| 496 |
+
return Status::kErrorMisalignedOperand;
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
if (!TensorRef_aligned(ref_B0, kAlignmentB)) {
|
| 500 |
+
return Status::kErrorMisalignedOperand;
|
| 501 |
+
}
|
| 502 |
+
|
| 503 |
+
if (!TensorRef_aligned(ref_C0, kAlignmentC)) {
|
| 504 |
+
return Status::kErrorMisalignedOperand;
|
| 505 |
+
}
|
| 506 |
+
|
| 507 |
+
if (!TensorRef_aligned(ref_B1, kAlignmentB)) {
|
| 508 |
+
return Status::kErrorMisalignedOperand;
|
| 509 |
+
}
|
| 510 |
+
|
| 511 |
+
if (!TensorRef_aligned(ref_C1, kAlignmentC)) {
|
| 512 |
+
return Status::kErrorMisalignedOperand;
|
| 513 |
+
}
|
| 514 |
+
|
| 515 |
+
if (!TensorRef_aligned(ref_D1, kAlignmentC)) {
|
| 516 |
+
return Status::kErrorMisalignedOperand;
|
| 517 |
+
}
|
| 518 |
+
|
| 519 |
+
if ((problem_size_0.m() % kAlignmentA) || (problem_size_0.k() % kAlignmentA) ||
|
| 520 |
+
(problem_size_0.n() % kAlignmentB) || (problem_size_0.k() % kAlignmentB) ||
|
| 521 |
+
(problem_size_0.m() % kAlignmentC) || (problem_size_0.n() % kAlignmentC) ||
|
| 522 |
+
(problem_size_1.m() % kAlignmentA) || (problem_size_1.k() % kAlignmentA) ||
|
| 523 |
+
(problem_size_1.n() % kAlignmentB) || (problem_size_1.k() % kAlignmentB) ||
|
| 524 |
+
(problem_size_1.m() % kAlignmentC) || (problem_size_1.n() % kAlignmentC)) {
|
| 525 |
+
|
| 526 |
+
return Status::kErrorMisalignedOperand;
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
// Determine if fusion sizes are valid
|
| 530 |
+
if(problem_size_0.m() != problem_size_1.m())
|
| 531 |
+
return Status::kErrorInvalidProblem;
|
| 532 |
+
|
| 533 |
+
if(problem_size_0.n() != problem_size_1.k())
|
| 534 |
+
return Status::kErrorInvalidProblem;
|
| 535 |
+
|
| 536 |
+
if(problem_size_0.n() > B2bMma::Shape0::kN)
|
| 537 |
+
return Status::kErrorInvalidProblem;
|
| 538 |
+
|
| 539 |
+
if(problem_size_1.n() > B2bMma::Shape1::kN)
|
| 540 |
+
return Status::kErrorInvalidProblem;
|
| 541 |
+
|
| 542 |
+
return Status::kSuccess;
|
| 543 |
+
}
|
| 544 |
+
|
| 545 |
+
/// Executes one GEMM
|
| 546 |
+
CUTLASS_DEVICE
|
| 547 |
+
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
| 548 |
+
ThreadblockSwizzle threadblock_swizzle;
|
| 549 |
+
run_with_swizzle(params, shared_storage, threadblock_swizzle);
|
| 550 |
+
}
|
| 551 |
+
|
| 552 |
+
/// Executes one GEMM with an externally-provided swizzling function
|
| 553 |
+
CUTLASS_DEVICE
|
| 554 |
+
void run_with_swizzle(Params const ¶ms, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) {
|
| 555 |
+
|
| 556 |
+
cutlass::gemm::GemmCoord threadblock_tile_offset =
|
| 557 |
+
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
| 558 |
+
|
| 559 |
+
// Early exit if CTA is out of range
|
| 560 |
+
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
| 561 |
+
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
|
| 562 |
+
|
| 563 |
+
return;
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
ElementA0 *ptr_A0 = static_cast<ElementA0 *>(params.ref_A0.data());
|
| 567 |
+
ElementB0 *ptr_B0 = static_cast<ElementB0 *>(params.ref_B0.data());
|
| 568 |
+
ElementB1 *ptr_B1 = static_cast<ElementB1 *>(params.ref_B1.data());
|
| 569 |
+
|
| 570 |
+
ScaleBiasData *ptr_Bias0 = static_cast<ScaleBiasData *>(params.ref_Bias0.data());
|
| 571 |
+
ScaleBiasData *ptr_Scale0 = static_cast<ScaleBiasData *>(params.ref_Scale0.data());
|
| 572 |
+
|
| 573 |
+
int offset_k_0 = 0;
|
| 574 |
+
int offset_k_1 = 0;
|
| 575 |
+
|
| 576 |
+
int problem_size_k_0 = params.problem_size_0.k();
|
| 577 |
+
int problem_size_k_1 = params.problem_size_1.k();
|
| 578 |
+
|
| 579 |
+
if (params.mode == GemmUniversalMode::kGemm) {
|
| 580 |
+
|
| 581 |
+
// Problem size is a function of threadblock index in the K dimension
|
| 582 |
+
problem_size_k_0 = min(
|
| 583 |
+
problem_size_k_0,
|
| 584 |
+
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
|
| 585 |
+
|
| 586 |
+
// Problem size is a function of threadblock index in the K dimension
|
| 587 |
+
problem_size_k_1 = min(
|
| 588 |
+
problem_size_k_1,
|
| 589 |
+
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
|
| 590 |
+
|
| 591 |
+
offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0;
|
| 592 |
+
offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1;
|
| 593 |
+
}
|
| 594 |
+
|
| 595 |
+
else if (params.mode == GemmUniversalMode::kBatched) {
|
| 596 |
+
ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0;
|
| 597 |
+
ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
|
| 598 |
+
ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
|
| 599 |
+
ptr_Bias0 += threadblock_tile_offset.k() * params.batch_stride_Bias0;
|
| 600 |
+
ptr_Scale0 += threadblock_tile_offset.k() * params.batch_stride_Scale0;
|
| 601 |
+
}
|
| 602 |
+
|
| 603 |
+
// Compute initial location in logical coordinates
|
| 604 |
+
cutlass::MatrixCoord tb_offset_A0{
|
| 605 |
+
threadblock_tile_offset.m() * B2bMma::Shape0::kM,
|
| 606 |
+
offset_k_0,
|
| 607 |
+
};
|
| 608 |
+
|
| 609 |
+
cutlass::MatrixCoord tb_offset_B0{
|
| 610 |
+
offset_k_0,
|
| 611 |
+
threadblock_tile_offset.n() * B2bMma::Shape0::kN
|
| 612 |
+
};
|
| 613 |
+
|
| 614 |
+
cutlass::MatrixCoord tb_offset_B1{
|
| 615 |
+
offset_k_1,
|
| 616 |
+
threadblock_tile_offset.n() * B2bMma::Shape1::kN
|
| 617 |
+
};
|
| 618 |
+
|
| 619 |
+
// Compute threadblock-scoped matrix multiply-add
|
| 620 |
+
int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
|
| 621 |
+
|
| 622 |
+
// Compute threadblock-scoped matrix multiply-add
|
| 623 |
+
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
// Compute position within threadblock
|
| 627 |
+
int thread_idx = threadIdx.x;
|
| 628 |
+
|
| 629 |
+
// Construct iterators to A and B operands
|
| 630 |
+
typename B2bMma::IteratorA0 iterator_A0(
|
| 631 |
+
params.params_A0,
|
| 632 |
+
ptr_A0,
|
| 633 |
+
{params.problem_size_0.m(), problem_size_k_0},
|
| 634 |
+
thread_idx,
|
| 635 |
+
tb_offset_A0);
|
| 636 |
+
|
| 637 |
+
typename B2bMma::IteratorB0 iterator_B0(
|
| 638 |
+
params.params_B0,
|
| 639 |
+
ptr_B0,
|
| 640 |
+
{problem_size_k_0, params.problem_size_0.n()},
|
| 641 |
+
thread_idx,
|
| 642 |
+
tb_offset_B0);
|
| 643 |
+
|
| 644 |
+
typename B2bMma::IteratorB1 iterator_B1(
|
| 645 |
+
params.params_B1,
|
| 646 |
+
ptr_B1,
|
| 647 |
+
{problem_size_k_1, params.problem_size_1.n()},
|
| 648 |
+
thread_idx,
|
| 649 |
+
tb_offset_B1);
|
| 650 |
+
|
| 651 |
+
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
| 652 |
+
// is compiled as warp-uniform.
|
| 653 |
+
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
| 654 |
+
int lane_idx = threadIdx.x % 32;
|
| 655 |
+
|
| 656 |
+
// Construct iterators to accumulator scale/bias vector
|
| 657 |
+
typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
|
| 658 |
+
ptr_Scale0,
|
| 659 |
+
{1, params.problem_size_0.n()},
|
| 660 |
+
thread_idx,
|
| 661 |
+
warp_idx,
|
| 662 |
+
MatrixCoord(
|
| 663 |
+
0, threadblock_tile_offset.n() * B2bMma::Shape0::kN
|
| 664 |
+
)
|
| 665 |
+
);
|
| 666 |
+
|
| 667 |
+
typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
|
| 668 |
+
ptr_Bias0,
|
| 669 |
+
{1, params.problem_size_0.n()},
|
| 670 |
+
thread_idx,
|
| 671 |
+
warp_idx,
|
| 672 |
+
MatrixCoord(
|
| 673 |
+
0, threadblock_tile_offset.n() * B2bMma::Shape0::kN
|
| 674 |
+
)
|
| 675 |
+
);
|
| 676 |
+
|
| 677 |
+
//
|
| 678 |
+
// Main loop
|
| 679 |
+
//
|
| 680 |
+
|
| 681 |
+
OutputOp0 output_op_0(params.output_op_0);
|
| 682 |
+
|
| 683 |
+
if (cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle>::value) {
|
| 684 |
+
// Wait for all threads to finish their epilogue phases from the previous tile.
|
| 685 |
+
__syncthreads();
|
| 686 |
+
}
|
| 687 |
+
|
| 688 |
+
// Construct thread-scoped matrix multiply
|
| 689 |
+
B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n());
|
| 690 |
+
|
| 691 |
+
typename B2bMma::FragmentC0 src_accum;
|
| 692 |
+
typename B2bMma::FragmentC1 accumulators;
|
| 693 |
+
|
| 694 |
+
src_accum.clear();
|
| 695 |
+
accumulators.clear();
|
| 696 |
+
|
| 697 |
+
// Compute threadblock-scoped matrix multiply-add
|
| 698 |
+
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
|
| 699 |
+
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
|
| 700 |
+
|
| 701 |
+
//
|
| 702 |
+
// Epilogue
|
| 703 |
+
//
|
| 704 |
+
|
| 705 |
+
OutputOp1 output_op_1(params.output_op_1);
|
| 706 |
+
|
| 707 |
+
//
|
| 708 |
+
// Masked tile iterators constructed from members
|
| 709 |
+
//
|
| 710 |
+
|
| 711 |
+
threadblock_tile_offset =
|
| 712 |
+
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
| 713 |
+
|
| 714 |
+
//assume identity swizzle
|
| 715 |
+
MatrixCoord threadblock_offset(
|
| 716 |
+
threadblock_tile_offset.m() * B2bMma::Shape1::kM,
|
| 717 |
+
threadblock_tile_offset.n() * B2bMma::Shape1::kN
|
| 718 |
+
);
|
| 719 |
+
|
| 720 |
+
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
| 721 |
+
|
| 722 |
+
ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
|
| 723 |
+
ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
|
| 724 |
+
|
| 725 |
+
// Construct the semaphore.
|
| 726 |
+
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
| 727 |
+
|
| 728 |
+
if (params.mode == GemmUniversalMode::kGemm) {
|
| 729 |
+
// If performing a reduction via split-K, fetch the initial synchronization
|
| 730 |
+
|
| 731 |
+
if (params.grid_tiled_shape.k() > 1) {
|
| 732 |
+
// Fetch the synchronization lock initially but do not block.
|
| 733 |
+
semaphore.fetch();
|
| 734 |
+
|
| 735 |
+
// Indicate which position in a serial reduction the output operator is currently updating
|
| 736 |
+
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
| 737 |
+
}
|
| 738 |
+
}
|
| 739 |
+
else if (params.mode == GemmUniversalMode::kBatched) {
|
| 740 |
+
ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1;
|
| 741 |
+
ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1;
|
| 742 |
+
}
|
| 743 |
+
|
| 744 |
+
// Tile iterator loading from source tensor.
|
| 745 |
+
typename Epilogue::OutputTileIterator iterator_C1(
|
| 746 |
+
params.params_C1,
|
| 747 |
+
ptr_C1,
|
| 748 |
+
params.problem_size_1.mn(),
|
| 749 |
+
thread_idx,
|
| 750 |
+
threadblock_offset
|
| 751 |
+
);
|
| 752 |
+
|
| 753 |
+
// Tile iterator writing to destination tensor.
|
| 754 |
+
typename Epilogue::OutputTileIterator iterator_D1(
|
| 755 |
+
params.params_D1,
|
| 756 |
+
ptr_D1,
|
| 757 |
+
params.problem_size_1.mn(),
|
| 758 |
+
thread_idx,
|
| 759 |
+
threadblock_offset
|
| 760 |
+
);
|
| 761 |
+
|
| 762 |
+
Epilogue epilogue(
|
| 763 |
+
shared_storage.epilogue,
|
| 764 |
+
thread_idx,
|
| 765 |
+
warp_idx,
|
| 766 |
+
lane_idx);
|
| 767 |
+
|
| 768 |
+
// Wait on the semaphore - this latency may have been covered by iterator construction
|
| 769 |
+
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
|
| 770 |
+
|
| 771 |
+
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
| 772 |
+
if (threadblock_tile_offset.k()) {
|
| 773 |
+
iterator_C1 = iterator_D1;
|
| 774 |
+
}
|
| 775 |
+
|
| 776 |
+
semaphore.wait(threadblock_tile_offset.k());
|
| 777 |
+
|
| 778 |
+
__threadfence();
|
| 779 |
+
}
|
| 780 |
+
|
| 781 |
+
// Execute the epilogue operator to update the destination tensor.
|
| 782 |
+
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
|
| 783 |
+
|
| 784 |
+
//
|
| 785 |
+
// Release the semaphore
|
| 786 |
+
//
|
| 787 |
+
|
| 788 |
+
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
|
| 789 |
+
|
| 790 |
+
int lock = 0;
|
| 791 |
+
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
|
| 792 |
+
|
| 793 |
+
// The final threadblock resets the semaphore for subsequent grids.
|
| 794 |
+
lock = 0;
|
| 795 |
+
}
|
| 796 |
+
else {
|
| 797 |
+
// Otherwise, the semaphore is incremented
|
| 798 |
+
lock = threadblock_tile_offset.k() + 1;
|
| 799 |
+
}
|
| 800 |
+
|
| 801 |
+
__threadfence();
|
| 802 |
+
semaphore.release(lock);
|
| 803 |
+
}
|
| 804 |
+
}
|
| 805 |
+
};
|
| 806 |
+
|
| 807 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 808 |
+
|
| 809 |
+
} // namespace kernel
|
| 810 |
+
} // namespace gemm
|
| 811 |
+
} // namespace cutlass
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Scheduler for grouped B2b GEMMs
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#include "cutlass/gemm/gemm.h"
|
| 40 |
+
#include "cutlass/matrix_coord.h"
|
| 41 |
+
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
|
| 42 |
+
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
|
| 43 |
+
|
| 44 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
|
| 46 |
+
namespace cutlass {
|
| 47 |
+
namespace gemm {
|
| 48 |
+
namespace kernel {
|
| 49 |
+
|
| 50 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 51 |
+
|
| 52 |
+
/// Visitor class to abstract away the algorithm for iterating over tiles
|
| 53 |
+
template <typename ThreadblockShape,
|
| 54 |
+
GroupScheduleMode GroupScheduleMode_,
|
| 55 |
+
int PrefetchTileCount,
|
| 56 |
+
int ThreadCount,
|
| 57 |
+
bool Transposed = false>
|
| 58 |
+
struct B2bGemmGroupedProblemVisitor : public GroupedProblemVisitor<
|
| 59 |
+
detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>,
|
| 60 |
+
ThreadblockShape,
|
| 61 |
+
GroupScheduleMode_,
|
| 62 |
+
PrefetchTileCount,
|
| 63 |
+
ThreadCount> {
|
| 64 |
+
|
| 65 |
+
using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
|
| 66 |
+
using Base = GroupedProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
|
| 67 |
+
using BaseParams = typename Base::Params;
|
| 68 |
+
using SharedStorage = typename Base::SharedStorage;
|
| 69 |
+
static bool const kTransposed = Transposed;
|
| 70 |
+
|
| 71 |
+
cutlass::gemm::GemmCoord const *problem_sizes0;
|
| 72 |
+
cutlass::gemm::GemmCoord const *problem_sizes1;
|
| 73 |
+
|
| 74 |
+
struct Params {
|
| 75 |
+
cutlass::gemm::GemmCoord const *problem_sizes0;
|
| 76 |
+
cutlass::gemm::GemmCoord const *problem_sizes1;
|
| 77 |
+
int32_t problem_count;
|
| 78 |
+
void const *workspace;
|
| 79 |
+
int32_t tile_count;
|
| 80 |
+
|
| 81 |
+
//
|
| 82 |
+
// Methods
|
| 83 |
+
//
|
| 84 |
+
|
| 85 |
+
/// Ctor
|
| 86 |
+
CUTLASS_HOST_DEVICE
|
| 87 |
+
Params(): problem_sizes0(nullptr), problem_sizes1(nullptr),
|
| 88 |
+
problem_count(0), workspace(nullptr), tile_count(0) { }
|
| 89 |
+
|
| 90 |
+
/// Ctor
|
| 91 |
+
CUTLASS_HOST_DEVICE
|
| 92 |
+
Params(
|
| 93 |
+
cutlass::gemm::GemmCoord const *problem_sizes0,
|
| 94 |
+
cutlass::gemm::GemmCoord const *problem_sizes1,
|
| 95 |
+
int32_t problem_count,
|
| 96 |
+
void const *workspace = nullptr,
|
| 97 |
+
int32_t tile_count = 0
|
| 98 |
+
):
|
| 99 |
+
problem_sizes0(problem_sizes0),
|
| 100 |
+
problem_sizes1(problem_sizes1),
|
| 101 |
+
problem_count(problem_count),
|
| 102 |
+
workspace(workspace),
|
| 103 |
+
tile_count(tile_count)
|
| 104 |
+
{}
|
| 105 |
+
|
| 106 |
+
/// Convert the B2b-GEMM-specific parameters to those used by the base class
|
| 107 |
+
CUTLASS_HOST_DEVICE
|
| 108 |
+
BaseParams to_base() const {
|
| 109 |
+
return BaseParams(// Set problem_sizes as problem_sizes0 because these determine
|
| 110 |
+
// shape of the grid used in the non-grouped B2b GEMM
|
| 111 |
+
problem_sizes0,
|
| 112 |
+
problem_count,
|
| 113 |
+
workspace,
|
| 114 |
+
tile_count);
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
};
|
| 118 |
+
|
| 119 |
+
//
|
| 120 |
+
// Methods
|
| 121 |
+
//
|
| 122 |
+
CUTLASS_DEVICE
|
| 123 |
+
B2bGemmGroupedProblemVisitor(
|
| 124 |
+
Params const ¶ms_,
|
| 125 |
+
SharedStorage &shared_storage_,
|
| 126 |
+
int32_t block_idx
|
| 127 |
+
): Base (
|
| 128 |
+
params_.to_base(),
|
| 129 |
+
shared_storage_, block_idx),
|
| 130 |
+
problem_sizes0(params_.problem_sizes0),
|
| 131 |
+
problem_sizes1(params_.problem_sizes1)
|
| 132 |
+
{}
|
| 133 |
+
|
| 134 |
+
/// Returns the problem size 0 for the current problem
|
| 135 |
+
CUTLASS_HOST_DEVICE
|
| 136 |
+
cutlass::gemm::GemmCoord problem_size0() const {
|
| 137 |
+
GemmCoord problem = problem_sizes0[this->problem_idx];
|
| 138 |
+
ProblemSizeHelper::possibly_transpose_problem(problem);
|
| 139 |
+
return problem;
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
/// Returns the problem size 1 for the current problem
|
| 143 |
+
CUTLASS_HOST_DEVICE
|
| 144 |
+
cutlass::gemm::GemmCoord problem_size1() const {
|
| 145 |
+
GemmCoord problem = problem_sizes1[this->problem_idx];
|
| 146 |
+
ProblemSizeHelper::possibly_transpose_problem(problem);
|
| 147 |
+
return problem;
|
| 148 |
+
}
|
| 149 |
+
};
|
| 150 |
+
|
| 151 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 152 |
+
|
| 153 |
+
} // namespace kernel
|
| 154 |
+
} // namespace gemm
|
| 155 |
+
} // namespace cutlass
|
| 156 |
+
|
| 157 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h
ADDED
|
@@ -0,0 +1,521 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Template for a pipelined Implicit GEMM kernel.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
|
| 39 |
+
#include "cutlass/aligned_buffer.h"
|
| 40 |
+
#include "cutlass/array.h"
|
| 41 |
+
#include "cutlass/numeric_types.h"
|
| 42 |
+
#include "cutlass/matrix_shape.h"
|
| 43 |
+
#include "cutlass/semaphore.h"
|
| 44 |
+
#include "cutlass/tensor_ref.h"
|
| 45 |
+
#include "cutlass/layout/tensor.h"
|
| 46 |
+
#include "cutlass/gemm/gemm.h"
|
| 47 |
+
#include "cutlass/conv/convolution.h"
|
| 48 |
+
#include "cutlass/conv/conv2d_problem_size.h"
|
| 49 |
+
#include "cutlass/conv/conv3d_problem_size.h"
|
| 50 |
+
#include "cutlass/epilogue/threadblock/output_iterator_parameter.h"
|
| 51 |
+
|
| 52 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
namespace cutlass {
|
| 55 |
+
namespace conv {
|
| 56 |
+
namespace kernel {
|
| 57 |
+
|
| 58 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
|
| 60 |
+
template <
|
| 61 |
+
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
|
| 62 |
+
typename Epilogue_, ///! Epilogue
|
| 63 |
+
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
| 64 |
+
conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad)
|
| 65 |
+
typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem
|
| 66 |
+
>
|
| 67 |
+
struct B2bImplicitGemmConvolution {
|
| 68 |
+
|
| 69 |
+
using B2bMma = B2bMma_;
|
| 70 |
+
using Epilogue = Epilogue_;
|
| 71 |
+
using EpilogueOutputOp0 = typename B2bMma::OutputOp;
|
| 72 |
+
using EpilogueOutputOp1 = typename Epilogue::OutputOp;
|
| 73 |
+
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
| 74 |
+
static Operator const kConvolutionalOperator = ConvOperator;
|
| 75 |
+
|
| 76 |
+
using ElementA = typename B2bMma::IteratorA0::Element;
|
| 77 |
+
using LayoutA = typename B2bMma::IteratorA0::Layout;
|
| 78 |
+
using ElementB = typename B2bMma::IteratorB0::Element;
|
| 79 |
+
using LayoutB = typename B2bMma::IteratorB0::Layout;
|
| 80 |
+
using ElementC = typename EpilogueOutputOp1::ElementOutput;
|
| 81 |
+
|
| 82 |
+
/// Set output tensor C layout
|
| 83 |
+
using LayoutC = LayoutA;
|
| 84 |
+
|
| 85 |
+
using ElementAccumulator = typename EpilogueOutputOp0::ElementAccumulator;
|
| 86 |
+
using ElementCompute = typename EpilogueOutputOp0::ElementCompute;
|
| 87 |
+
|
| 88 |
+
/// Scale and Bias
|
| 89 |
+
using ElementScaleBias = typename B2bMma::IteratorAccumulatorScaleBias::Element;
|
| 90 |
+
using LayoutScaleBias = typename B2bMma::IteratorAccumulatorScaleBias::Layout;
|
| 91 |
+
|
| 92 |
+
using WarpMmaOperator0 = typename B2bMma::Policy0::Operator;
|
| 93 |
+
using WarpMmaOperator1 = typename B2bMma::Policy1::Operator;
|
| 94 |
+
|
| 95 |
+
using ArchMmaOperator = typename WarpMmaOperator0::ArchMmaOperator;
|
| 96 |
+
using MathOperator = typename ArchMmaOperator::Operator;
|
| 97 |
+
|
| 98 |
+
using OperatorClass = typename WarpMmaOperator0::OperatorClass;
|
| 99 |
+
using ArchTag = typename WarpMmaOperator0::ArchTag;
|
| 100 |
+
|
| 101 |
+
using ThreadblockShape0 = typename B2bMma::Shape0;
|
| 102 |
+
using ThreadblockShape1 = typename B2bMma::Shape1;
|
| 103 |
+
using WarpShape0 = typename WarpMmaOperator0::Shape;
|
| 104 |
+
using WarpShape1 = typename WarpMmaOperator1::Shape;
|
| 105 |
+
using InstructionShape = typename ArchMmaOperator::Shape;
|
| 106 |
+
|
| 107 |
+
static int const kStages = B2bMma::kStages;
|
| 108 |
+
static IteratorAlgorithm const kIteratorAlgorithm = B2bMma::IteratorA0::kIteratorAlgorithm;
|
| 109 |
+
|
| 110 |
+
/// Warp count (concept: GemmShape)
|
| 111 |
+
using WarpCount0 = typename B2bMma::WarpCount0;
|
| 112 |
+
static int const kThreadCount = 32 * WarpCount0::kCount;
|
| 113 |
+
|
| 114 |
+
using TensorRefA0 = typename B2bMma::IteratorA0::TensorRef;
|
| 115 |
+
using TensorRefB0 = typename B2bMma::IteratorB0::TensorRef;
|
| 116 |
+
using TensorRefScaleBias0 = typename B2bMma::IteratorAccumulatorScaleBias::TensorRef;
|
| 117 |
+
using TensorRefB1 = typename B2bMma::IteratorB1::TensorRef;
|
| 118 |
+
using TensorRefC = cutlass::TensorRef<ElementC, LayoutC>;
|
| 119 |
+
|
| 120 |
+
/// Check iterator A and B convolution dimension are the same and
|
| 121 |
+
// set device::B2bImplicitGemmConvolution::kConvDim
|
| 122 |
+
static_assert(B2bMma::IteratorA0::kConvDim == B2bMma::IteratorB0::kConvDim,
|
| 123 |
+
"Convolution on different dimensions is not supported");
|
| 124 |
+
static int const kConvDim = B2bMma::IteratorA0::kConvDim;
|
| 125 |
+
|
| 126 |
+
/// Conv dimension and problem size structure (Conv2d or Conv3d)
|
| 127 |
+
using ConvProblemSize = ConvProblemSize_;
|
| 128 |
+
|
| 129 |
+
/// Wgrad C stride idx for implicit gemm algorithm
|
| 130 |
+
// Conv2d row-major matrix C (KxRSC)
|
| 131 |
+
// Conv3d row-major matrix C (KxTRSC)
|
| 132 |
+
static int const kWgradCStrideIdx =
|
| 133 |
+
cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
| 134 |
+
|
| 135 |
+
/// This chooses the appropriate stride element of the C tensor.
|
| 136 |
+
static int const kTensorCStrideIdx =
|
| 137 |
+
(kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0);
|
| 138 |
+
|
| 139 |
+
//
|
| 140 |
+
//
|
| 141 |
+
//
|
| 142 |
+
using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter<
|
| 143 |
+
LayoutC,
|
| 144 |
+
typename Epilogue::OutputTileIterator::Layout,
|
| 145 |
+
TensorRefC,
|
| 146 |
+
ConvOperator,
|
| 147 |
+
ConvProblemSize
|
| 148 |
+
>;
|
| 149 |
+
|
| 150 |
+
/// Argument structure
|
| 151 |
+
struct Arguments {
|
| 152 |
+
|
| 153 |
+
//
|
| 154 |
+
// Data members
|
| 155 |
+
//
|
| 156 |
+
|
| 157 |
+
ConvProblemSize problem_size_0;
|
| 158 |
+
ConvProblemSize problem_size_1;
|
| 159 |
+
TensorRefA0 ref_A0;
|
| 160 |
+
TensorRefB0 ref_B0;
|
| 161 |
+
TensorRefC ref_C0;
|
| 162 |
+
TensorRefScaleBias0 ref_Scale0;
|
| 163 |
+
TensorRefScaleBias0 ref_Bias0;
|
| 164 |
+
TensorRefB1 ref_B1;
|
| 165 |
+
TensorRefC ref_C1;
|
| 166 |
+
TensorRefC ref_D1;
|
| 167 |
+
typename EpilogueOutputOp0::Params output_op_0;
|
| 168 |
+
typename EpilogueOutputOp1::Params output_op_1;
|
| 169 |
+
SplitKMode split_k_mode;
|
| 170 |
+
|
| 171 |
+
//
|
| 172 |
+
// Methods
|
| 173 |
+
//
|
| 174 |
+
|
| 175 |
+
/// Default ctor
|
| 176 |
+
CUTLASS_HOST_DEVICE
|
| 177 |
+
Arguments() { }
|
| 178 |
+
|
| 179 |
+
CUTLASS_HOST_DEVICE
|
| 180 |
+
Arguments(
|
| 181 |
+
ConvProblemSize const & problem_size_0,
|
| 182 |
+
ConvProblemSize const & problem_size_1
|
| 183 |
+
):
|
| 184 |
+
problem_size_0(problem_size_0),
|
| 185 |
+
problem_size_1(problem_size_1) { }
|
| 186 |
+
|
| 187 |
+
CUTLASS_HOST_DEVICE
|
| 188 |
+
Arguments(
|
| 189 |
+
ConvProblemSize const & problem_size_0,
|
| 190 |
+
ConvProblemSize const & problem_size_1,
|
| 191 |
+
TensorRefA0 const & ref_A0,
|
| 192 |
+
TensorRefB0 const & ref_B0,
|
| 193 |
+
TensorRefC const & ref_C0,
|
| 194 |
+
TensorRefScaleBias0 const & ref_Scale0,
|
| 195 |
+
TensorRefScaleBias0 const & ref_Bias0,
|
| 196 |
+
TensorRefB1 const & ref_B1,
|
| 197 |
+
TensorRefC const & ref_C1,
|
| 198 |
+
TensorRefC const & ref_D1,
|
| 199 |
+
typename EpilogueOutputOp0::Params const & output_op_0,
|
| 200 |
+
typename EpilogueOutputOp1::Params const & output_op_1,
|
| 201 |
+
SplitKMode const & split_k_mode = SplitKMode::kSerial
|
| 202 |
+
):
|
| 203 |
+
problem_size_0(problem_size_0),
|
| 204 |
+
problem_size_1(problem_size_1),
|
| 205 |
+
ref_A0(ref_A0),
|
| 206 |
+
ref_B0(ref_B0),
|
| 207 |
+
ref_C0(ref_C0),
|
| 208 |
+
ref_Scale0(ref_Scale0),
|
| 209 |
+
ref_Bias0(ref_Bias0),
|
| 210 |
+
ref_B1(ref_B1),
|
| 211 |
+
ref_C1(ref_C1),
|
| 212 |
+
ref_D1(ref_D1),
|
| 213 |
+
output_op_0(output_op_0),
|
| 214 |
+
output_op_1(output_op_1),
|
| 215 |
+
split_k_mode(split_k_mode)
|
| 216 |
+
{
|
| 217 |
+
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
};
|
| 221 |
+
|
| 222 |
+
/// Parameters structure
|
| 223 |
+
struct Params {
|
| 224 |
+
ConvProblemSize problem_size_0;
|
| 225 |
+
ConvProblemSize problem_size_1;
|
| 226 |
+
cutlass::gemm::GemmCoord grid_tiled_shape;
|
| 227 |
+
gemm::GemmCoord implicit_gemm_problem_size_0;
|
| 228 |
+
gemm::GemmCoord implicit_gemm_problem_size_1;
|
| 229 |
+
int swizzle_log_tile;
|
| 230 |
+
int gemm_k_iterations_0;
|
| 231 |
+
int gemm_k_iterations_1;
|
| 232 |
+
typename B2bMma::IteratorA0::Params iterator_A0;
|
| 233 |
+
typename B2bMma::IteratorA0::Element const *ptr_A0;
|
| 234 |
+
typename B2bMma::IteratorB0::Params iterator_B0;
|
| 235 |
+
typename B2bMma::IteratorB0::Element const *ptr_B0;
|
| 236 |
+
typename Epilogue::OutputTileIterator::Params iterator_C0;
|
| 237 |
+
typename Epilogue::OutputTileIterator::Element *ptr_C0;
|
| 238 |
+
typename B2bMma::IteratorAccumulatorScaleBias::Element *ptr_Scale0;
|
| 239 |
+
typename B2bMma::IteratorAccumulatorScaleBias::Element *ptr_Bias0;
|
| 240 |
+
typename B2bMma::IteratorB1::Params iterator_B1;
|
| 241 |
+
typename B2bMma::IteratorB1::Element const *ptr_B1;
|
| 242 |
+
typename Epilogue::OutputTileIterator::Params iterator_C1;
|
| 243 |
+
typename Epilogue::OutputTileIterator::Element *ptr_C1;
|
| 244 |
+
typename Epilogue::OutputTileIterator::Params iterator_D1;
|
| 245 |
+
typename Epilogue::OutputTileIterator::Element *ptr_D1;
|
| 246 |
+
typename EpilogueOutputOp0::Params output_op_0;
|
| 247 |
+
typename EpilogueOutputOp1::Params output_op_1;
|
| 248 |
+
int *semaphore;
|
| 249 |
+
SplitKMode split_k_mode;
|
| 250 |
+
|
| 251 |
+
//
|
| 252 |
+
// Methods
|
| 253 |
+
//
|
| 254 |
+
|
| 255 |
+
CUTLASS_HOST_DEVICE
|
| 256 |
+
Params(): swizzle_log_tile(0), gemm_k_iterations_0(0), gemm_k_iterations_1(0) { }
|
| 257 |
+
|
| 258 |
+
///
|
| 259 |
+
CUTLASS_HOST_DEVICE
|
| 260 |
+
Params(
|
| 261 |
+
Arguments const &args,
|
| 262 |
+
int *semaphore = nullptr
|
| 263 |
+
):
|
| 264 |
+
problem_size_0(args.problem_size_0),
|
| 265 |
+
problem_size_1(args.problem_size_1),
|
| 266 |
+
implicit_gemm_problem_size_0(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0)),
|
| 267 |
+
implicit_gemm_problem_size_1(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_1)),
|
| 268 |
+
iterator_A0(B2bMma::IteratorA0::getParams(args.problem_size_0, args.ref_A0.layout())),
|
| 269 |
+
ptr_A0(args.ref_A0.data()),
|
| 270 |
+
iterator_B0(args.problem_size_0, args.ref_B0.layout()),
|
| 271 |
+
ptr_B0(args.ref_B0.data()),
|
| 272 |
+
iterator_C0(ConvOutputIteratorParameter::layout(args.ref_C0)),
|
| 273 |
+
ptr_C0(args.ref_C0.data()),
|
| 274 |
+
ptr_Scale0(args.ref_Scale0.data()),
|
| 275 |
+
ptr_Bias0(args.ref_Bias0.data()),
|
| 276 |
+
iterator_B1(args.problem_size_1, args.ref_B1.layout()),
|
| 277 |
+
ptr_B1(args.ref_B1.data()),
|
| 278 |
+
iterator_C1(ConvOutputIteratorParameter::layout(args.ref_C1)),
|
| 279 |
+
ptr_C1(args.ref_C1.data()),
|
| 280 |
+
iterator_D1(ConvOutputIteratorParameter::layout(args.ref_D1)),
|
| 281 |
+
ptr_D1(args.ref_D1.data()),
|
| 282 |
+
output_op_0(args.output_op_0),
|
| 283 |
+
output_op_1(args.output_op_1),
|
| 284 |
+
semaphore(semaphore),
|
| 285 |
+
split_k_mode(args.split_k_mode)
|
| 286 |
+
{
|
| 287 |
+
gemm_k_iterations_0 = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape0::kK, args.problem_size_0);
|
| 288 |
+
gemm_k_iterations_1 = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape1::kK, args.problem_size_1);
|
| 289 |
+
|
| 290 |
+
ThreadblockSwizzle threadblock_swizzle;
|
| 291 |
+
|
| 292 |
+
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
|
| 293 |
+
implicit_gemm_problem_size_0,
|
| 294 |
+
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
|
| 295 |
+
args.problem_size_0.split_k_slices);
|
| 296 |
+
|
| 297 |
+
swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape);
|
| 298 |
+
}
|
| 299 |
+
};
|
| 300 |
+
|
| 301 |
+
/// Shared memory storage structure
|
| 302 |
+
union SharedStorage {
|
| 303 |
+
typename B2bMma::B2bMmaSharedStorage main_loop;
|
| 304 |
+
typename Epilogue::SharedStorage epilogue;
|
| 305 |
+
};
|
| 306 |
+
|
| 307 |
+
//
|
| 308 |
+
// Methods
|
| 309 |
+
//
|
| 310 |
+
|
| 311 |
+
CUTLASS_HOST_DEVICE
|
| 312 |
+
B2bImplicitGemmConvolution() { }
|
| 313 |
+
|
| 314 |
+
/// Executes one ImplicitGEMM
|
| 315 |
+
CUTLASS_DEVICE
|
| 316 |
+
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
| 317 |
+
|
| 318 |
+
// Compute threadblock location
|
| 319 |
+
ThreadblockSwizzle threadblock_swizzle;
|
| 320 |
+
|
| 321 |
+
cutlass::gemm::GemmCoord threadblock_tile_idx =
|
| 322 |
+
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
| 323 |
+
|
| 324 |
+
// Early exit if CTA is out of range
|
| 325 |
+
if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() ||
|
| 326 |
+
params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) {
|
| 327 |
+
|
| 328 |
+
return;
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
// Compute position within threadblock
|
| 332 |
+
int thread_idx = threadIdx.x;
|
| 333 |
+
|
| 334 |
+
// Construct iterators to A and B operands
|
| 335 |
+
typename B2bMma::IteratorA0 iterator_A0(
|
| 336 |
+
params.iterator_A0,
|
| 337 |
+
params.problem_size_0,
|
| 338 |
+
params.ptr_A0,
|
| 339 |
+
thread_idx,
|
| 340 |
+
MatrixCoord(
|
| 341 |
+
threadblock_tile_idx.m() * B2bMma::Shape0::kM,
|
| 342 |
+
threadblock_tile_idx.k() * B2bMma::Shape0::kK
|
| 343 |
+
)
|
| 344 |
+
);
|
| 345 |
+
|
| 346 |
+
typename B2bMma::IteratorB0 iterator_B0(
|
| 347 |
+
params.iterator_B0,
|
| 348 |
+
params.problem_size_0,
|
| 349 |
+
params.ptr_B0,
|
| 350 |
+
thread_idx,
|
| 351 |
+
MatrixCoord(
|
| 352 |
+
threadblock_tile_idx.k() * B2bMma::Shape0::kK,
|
| 353 |
+
threadblock_tile_idx.n() * B2bMma::Shape0::kN
|
| 354 |
+
)
|
| 355 |
+
);
|
| 356 |
+
|
| 357 |
+
typename B2bMma::IteratorB1 iterator_B1(
|
| 358 |
+
params.iterator_B1,
|
| 359 |
+
params.problem_size_1,
|
| 360 |
+
params.ptr_B1,
|
| 361 |
+
thread_idx,
|
| 362 |
+
MatrixCoord(
|
| 363 |
+
threadblock_tile_idx.k() * B2bMma::Shape1::kK,
|
| 364 |
+
threadblock_tile_idx.n() * B2bMma::Shape1::kN
|
| 365 |
+
)
|
| 366 |
+
);
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
| 370 |
+
// is compiled as warp-uniform.
|
| 371 |
+
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
| 372 |
+
int lane_idx = threadIdx.x % 32;
|
| 373 |
+
|
| 374 |
+
// Construct iterators to accumulator scale/bias vector
|
| 375 |
+
typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
|
| 376 |
+
params.ptr_Scale0,
|
| 377 |
+
{1, params.problem_size_0.K},
|
| 378 |
+
thread_idx,
|
| 379 |
+
warp_idx,
|
| 380 |
+
MatrixCoord(
|
| 381 |
+
0, threadblock_tile_idx.n() * B2bMma::Shape0::kN
|
| 382 |
+
)
|
| 383 |
+
);
|
| 384 |
+
|
| 385 |
+
typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
|
| 386 |
+
params.ptr_Bias0,
|
| 387 |
+
{1, params.problem_size_0.K},
|
| 388 |
+
thread_idx,
|
| 389 |
+
warp_idx,
|
| 390 |
+
MatrixCoord(
|
| 391 |
+
0, threadblock_tile_idx.n() * B2bMma::Shape0::kN
|
| 392 |
+
)
|
| 393 |
+
);
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
//
|
| 397 |
+
// Main loop
|
| 398 |
+
//
|
| 399 |
+
|
| 400 |
+
EpilogueOutputOp0 output_op_0(params.output_op_0);
|
| 401 |
+
|
| 402 |
+
// Construct thread-scoped matrix multiply
|
| 403 |
+
B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
| 404 |
+
|
| 405 |
+
typename B2bMma::FragmentC0 src_accum;
|
| 406 |
+
typename B2bMma::FragmentC1 accumulators;
|
| 407 |
+
|
| 408 |
+
src_accum.clear();
|
| 409 |
+
accumulators.clear();
|
| 410 |
+
|
| 411 |
+
// Compute threadblock-scoped matrix multiply-add
|
| 412 |
+
b2bMma(params.gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
|
| 413 |
+
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
|
| 414 |
+
|
| 415 |
+
//
|
| 416 |
+
// Epilogue
|
| 417 |
+
//
|
| 418 |
+
|
| 419 |
+
EpilogueOutputOp1 output_op_1(params.output_op_1);
|
| 420 |
+
|
| 421 |
+
// Construct the semaphore.
|
| 422 |
+
int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m();
|
| 423 |
+
|
| 424 |
+
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
| 425 |
+
|
| 426 |
+
// Compute logical position within grid
|
| 427 |
+
threadblock_tile_idx =
|
| 428 |
+
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
| 429 |
+
|
| 430 |
+
// If performing a reduction via split-K, fetch the initial synchronization
|
| 431 |
+
if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
|
| 432 |
+
|
| 433 |
+
// Fetch the synchronization lock initially but do not block.
|
| 434 |
+
semaphore.fetch();
|
| 435 |
+
|
| 436 |
+
// Indicate which position in a serial reduction the output operator is currently updating
|
| 437 |
+
output_op_1.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k());
|
| 438 |
+
}
|
| 439 |
+
|
| 440 |
+
MatrixCoord threadblock_offset(
|
| 441 |
+
threadblock_tile_idx.m() * B2bMma::Shape1::kM,
|
| 442 |
+
threadblock_tile_idx.n() * B2bMma::Shape1::kN
|
| 443 |
+
);
|
| 444 |
+
|
| 445 |
+
// Tile iterator writing to destination tensor
|
| 446 |
+
typename Epilogue::OutputTileIterator iterator_D1(
|
| 447 |
+
params.iterator_D1,
|
| 448 |
+
params.ptr_D1,
|
| 449 |
+
ConvOutputIteratorParameter::extent(params.problem_size_1),
|
| 450 |
+
thread_idx,
|
| 451 |
+
threadblock_offset
|
| 452 |
+
);
|
| 453 |
+
|
| 454 |
+
// Tile iterator reading from source accumulator tensor
|
| 455 |
+
typename Epilogue::OutputTileIterator iterator_C1(
|
| 456 |
+
params.iterator_C1,
|
| 457 |
+
params.ptr_C1,
|
| 458 |
+
ConvOutputIteratorParameter::extent(params.problem_size_1),
|
| 459 |
+
thread_idx,
|
| 460 |
+
threadblock_offset
|
| 461 |
+
);
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
// Construct the epilogue
|
| 465 |
+
Epilogue epilogue(
|
| 466 |
+
shared_storage.epilogue,
|
| 467 |
+
thread_idx,
|
| 468 |
+
warp_idx,
|
| 469 |
+
lane_idx);
|
| 470 |
+
|
| 471 |
+
// Wait on the semaphore - this latency may have been covered by iterator construction
|
| 472 |
+
if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
|
| 473 |
+
|
| 474 |
+
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
| 475 |
+
if (threadblock_tile_idx.k()) {
|
| 476 |
+
iterator_C1 = iterator_D1;
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
semaphore.wait(threadblock_tile_idx.k());
|
| 480 |
+
|
| 481 |
+
__threadfence();
|
| 482 |
+
}
|
| 483 |
+
// Each split-k-slice writes to a unique tensor location
|
| 484 |
+
else if (params.split_k_mode == SplitKMode::kParallel) {
|
| 485 |
+
iterator_D1.add_pointer_offset(threadblock_tile_idx.k() *
|
| 486 |
+
cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size_1));
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
// Run efficient epilogue
|
| 490 |
+
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
|
| 491 |
+
|
| 492 |
+
//
|
| 493 |
+
// Release the semaphore
|
| 494 |
+
//
|
| 495 |
+
|
| 496 |
+
if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
|
| 497 |
+
|
| 498 |
+
int lock = 0;
|
| 499 |
+
if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) {
|
| 500 |
+
|
| 501 |
+
// The final threadblock resets the semaphore for subsequent grids.
|
| 502 |
+
lock = 0;
|
| 503 |
+
}
|
| 504 |
+
else {
|
| 505 |
+
// Otherwise, the semaphore is incremented
|
| 506 |
+
lock = threadblock_tile_idx.k() + 1;
|
| 507 |
+
}
|
| 508 |
+
|
| 509 |
+
semaphore.release(lock);
|
| 510 |
+
}
|
| 511 |
+
}
|
| 512 |
+
};
|
| 513 |
+
|
| 514 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 515 |
+
|
| 516 |
+
} // namespace kernel
|
| 517 |
+
} // namespace conv
|
| 518 |
+
} // namespace cutlass
|
| 519 |
+
|
| 520 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 521 |
+
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief
|
| 34 |
+
Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
|
| 35 |
+
matrix multiply-add with the appropriate threadblock-scoped epilogue.
|
| 36 |
+
*/
|
| 37 |
+
|
| 38 |
+
#pragma once
|
| 39 |
+
|
| 40 |
+
#include "cutlass/cutlass.h"
|
| 41 |
+
#include "cutlass/conv/kernel/default_conv2d.h"
|
| 42 |
+
|
| 43 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
|
| 44 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
|
| 45 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
|
| 46 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
|
| 47 |
+
|
| 48 |
+
#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
|
| 49 |
+
#include "cutlass/transform/threadblock/vector_iterator.h"
|
| 50 |
+
#include "cutlass/transform/warp/vector_fragment_iterator.h"
|
| 51 |
+
|
| 52 |
+
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
| 53 |
+
|
| 54 |
+
#include "kernel/b2b_implicit_gemm_convolution.h"
|
| 55 |
+
#include "threadblock/b2b_implicit_gemm_pipelined.h"
|
| 56 |
+
#include "threadblock/b2b_implicit_gemm_multistage.h"
|
| 57 |
+
|
| 58 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
|
| 60 |
+
namespace cutlass {
|
| 61 |
+
namespace conv {
|
| 62 |
+
namespace kernel {
|
| 63 |
+
|
| 64 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 65 |
+
/// Defines a kernel for Conv2dFprop
|
| 66 |
+
template <
|
| 67 |
+
typename ElementA,
|
| 68 |
+
typename LayoutA,
|
| 69 |
+
typename ElementB,
|
| 70 |
+
typename LayoutB,
|
| 71 |
+
typename ElementC,
|
| 72 |
+
typename LayoutC,
|
| 73 |
+
typename ElementAccumulator,
|
| 74 |
+
typename OperatorClass,
|
| 75 |
+
typename ArchTag,
|
| 76 |
+
typename ThreadblockShape0,
|
| 77 |
+
typename ThreadblockShape1,
|
| 78 |
+
typename WarpShape0,
|
| 79 |
+
typename WarpShape1,
|
| 80 |
+
typename InstructionShape,
|
| 81 |
+
typename EpilogueOutputOp0,
|
| 82 |
+
typename EpilogueOutputOp1,
|
| 83 |
+
typename ThreadblockSwizzle,
|
| 84 |
+
int Stages,
|
| 85 |
+
typename MathOperatorTag,
|
| 86 |
+
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
| 87 |
+
bool SmemAccumulator = false
|
| 88 |
+
> struct DefaultB2bConv2dFprop;
|
| 89 |
+
|
| 90 |
+
} // namespace kernel
|
| 91 |
+
} // namespace conv
|
| 92 |
+
} // namespace cutlass
|
| 93 |
+
|
| 94 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h
ADDED
|
@@ -0,0 +1,749 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief
|
| 34 |
+
Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
|
| 35 |
+
matrix multiply-add with the appropriate threadblock-scoped epilogue.
|
| 36 |
+
*/
|
| 37 |
+
|
| 38 |
+
#pragma once
|
| 39 |
+
|
| 40 |
+
#include "cutlass/cutlass.h"
|
| 41 |
+
#include "cutlass/conv/kernel/default_conv2d.h"
|
| 42 |
+
|
| 43 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
|
| 44 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
|
| 45 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
|
| 46 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
|
| 47 |
+
|
| 48 |
+
#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
|
| 49 |
+
#include "cutlass/transform/threadblock/vector_iterator.h"
|
| 50 |
+
#include "cutlass/transform/warp/vector_fragment_iterator.h"
|
| 51 |
+
|
| 52 |
+
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
| 53 |
+
|
| 54 |
+
#include "kernel/default_b2b_conv2d_fprop.h"
|
| 55 |
+
#include "kernel/b2b_implicit_gemm_convolution.h"
|
| 56 |
+
#include "threadblock/b2b_implicit_gemm_pipelined.h"
|
| 57 |
+
|
| 58 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
|
| 60 |
+
namespace cutlass {
|
| 61 |
+
namespace conv {
|
| 62 |
+
namespace kernel {
|
| 63 |
+
|
| 64 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 65 |
+
// OpClassTensorOp convolutions
|
| 66 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 67 |
+
|
| 68 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm
|
| 69 |
+
/// and 2 stage pipeline.
|
| 70 |
+
template <
|
| 71 |
+
typename ElementA,
|
| 72 |
+
typename LayoutA,
|
| 73 |
+
typename ElementB,
|
| 74 |
+
typename LayoutB,
|
| 75 |
+
typename ElementC,
|
| 76 |
+
typename LayoutC,
|
| 77 |
+
typename ElementAccumulator,
|
| 78 |
+
typename ArchTag,
|
| 79 |
+
typename ThreadblockShape0,
|
| 80 |
+
typename ThreadblockShape1,
|
| 81 |
+
typename WarpShape0,
|
| 82 |
+
typename WarpShape1,
|
| 83 |
+
typename InstructionShape,
|
| 84 |
+
typename EpilogueOutputOp0,
|
| 85 |
+
typename EpilogueOutputOp1,
|
| 86 |
+
typename ThreadblockSwizzle,
|
| 87 |
+
typename MathOperatorTag
|
| 88 |
+
>
|
| 89 |
+
struct DefaultB2bConv2dFprop <
|
| 90 |
+
ElementA,
|
| 91 |
+
LayoutA,
|
| 92 |
+
ElementB,
|
| 93 |
+
LayoutB,
|
| 94 |
+
ElementC,
|
| 95 |
+
LayoutC,
|
| 96 |
+
ElementAccumulator,
|
| 97 |
+
arch::OpClassTensorOp,
|
| 98 |
+
ArchTag,
|
| 99 |
+
ThreadblockShape0,
|
| 100 |
+
ThreadblockShape1,
|
| 101 |
+
WarpShape0,
|
| 102 |
+
WarpShape1,
|
| 103 |
+
InstructionShape,
|
| 104 |
+
EpilogueOutputOp0,
|
| 105 |
+
EpilogueOutputOp1,
|
| 106 |
+
ThreadblockSwizzle,
|
| 107 |
+
2,
|
| 108 |
+
MathOperatorTag,
|
| 109 |
+
IteratorAlgorithm::kAnalytic
|
| 110 |
+
> {
|
| 111 |
+
|
| 112 |
+
// Define the core components from GEMM
|
| 113 |
+
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 114 |
+
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
| 115 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 116 |
+
2, MathOperatorTag>;
|
| 117 |
+
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 118 |
+
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
| 119 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 120 |
+
2, MathOperatorTag>;
|
| 121 |
+
|
| 122 |
+
// Define iterators over tiles from the A operand
|
| 123 |
+
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
| 124 |
+
using IteratorA0 =
|
| 125 |
+
cutlass::conv::threadblock::TileIterator<
|
| 126 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
| 127 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
| 128 |
+
ElementA, LayoutA,
|
| 129 |
+
ThreadMapA0
|
| 130 |
+
>
|
| 131 |
+
>;
|
| 132 |
+
|
| 133 |
+
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
| 134 |
+
|
| 135 |
+
// Define iterators over tiles from the B operand
|
| 136 |
+
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
| 137 |
+
using IteratorB0 =
|
| 138 |
+
cutlass::conv::threadblock::TileIterator<
|
| 139 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 140 |
+
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
| 141 |
+
ElementB, LayoutB,
|
| 142 |
+
ThreadMapB0
|
| 143 |
+
>
|
| 144 |
+
>;
|
| 145 |
+
|
| 146 |
+
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
| 147 |
+
|
| 148 |
+
// Use fragment iterator for A operand
|
| 149 |
+
using AccumulatorLayout = cutlass::layout::ColumnMajor;
|
| 150 |
+
using FragmentIteratorA1 =
|
| 151 |
+
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
| 152 |
+
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
| 153 |
+
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
| 154 |
+
MmaCore1::Shape::kK, //kBlocksColumn
|
| 155 |
+
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
| 156 |
+
|
| 157 |
+
/// Define iterators over tiles from scale/bias vectors
|
| 158 |
+
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
| 159 |
+
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
| 160 |
+
static int const kElementsPerAccess = 2;
|
| 161 |
+
using IteratorAccumulatorScaleBias =
|
| 162 |
+
cutlass::transform::threadblock::VectorIterator<
|
| 163 |
+
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
| 164 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
| 165 |
+
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
| 166 |
+
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
| 167 |
+
>;
|
| 168 |
+
|
| 169 |
+
// Warp-level iterators to load scale and bias vectors
|
| 170 |
+
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
| 171 |
+
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
| 172 |
+
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
| 173 |
+
|
| 174 |
+
// Define iterators over tiles from the B operand
|
| 175 |
+
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
| 176 |
+
using IteratorB1 =
|
| 177 |
+
cutlass::conv::threadblock::TileIterator<
|
| 178 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 179 |
+
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
| 180 |
+
ElementB, LayoutB,
|
| 181 |
+
ThreadMapB1
|
| 182 |
+
>
|
| 183 |
+
>;
|
| 184 |
+
|
| 185 |
+
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
| 186 |
+
|
| 187 |
+
// Warp-level GEMM components
|
| 188 |
+
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
| 189 |
+
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
| 190 |
+
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
| 191 |
+
|
| 192 |
+
// Define the Mma
|
| 193 |
+
using B2bMma = threadblock::B2bImplicitGemmPipelined<
|
| 194 |
+
ThreadblockShape0,
|
| 195 |
+
IteratorA0,
|
| 196 |
+
SmemIteratorA0,
|
| 197 |
+
IteratorB0,
|
| 198 |
+
SmemIteratorB0,
|
| 199 |
+
ThreadblockShape1,
|
| 200 |
+
FragmentIteratorA1,
|
| 201 |
+
IteratorAccumulatorScaleBias,
|
| 202 |
+
FragmentIteratorA1ScaleBias,
|
| 203 |
+
IteratorB1,
|
| 204 |
+
SmemIteratorB1,
|
| 205 |
+
ElementC,
|
| 206 |
+
LayoutC,
|
| 207 |
+
EpilogueOutputOp0,
|
| 208 |
+
MmaPolicy0,
|
| 209 |
+
MmaPolicy1
|
| 210 |
+
>;
|
| 211 |
+
|
| 212 |
+
// Define the epilogue
|
| 213 |
+
using Epilogue = typename detail::DefaultConvEpilogue<
|
| 214 |
+
ArchTag,
|
| 215 |
+
ThreadblockShape1,
|
| 216 |
+
WarpMmaTensorOp1,
|
| 217 |
+
1,
|
| 218 |
+
EpilogueOutputOp1
|
| 219 |
+
>::Epilogue;
|
| 220 |
+
|
| 221 |
+
// Define the kernel
|
| 222 |
+
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
| 223 |
+
B2bMma,
|
| 224 |
+
Epilogue,
|
| 225 |
+
ThreadblockSwizzle,
|
| 226 |
+
conv::Operator::kFprop
|
| 227 |
+
>;
|
| 228 |
+
};
|
| 229 |
+
|
| 230 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 231 |
+
|
| 232 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage
|
| 233 |
+
/// pipeline with interleaved layout.
|
| 234 |
+
template <
|
| 235 |
+
typename ElementA,
|
| 236 |
+
typename ElementB,
|
| 237 |
+
typename ElementC,
|
| 238 |
+
typename LayoutC,
|
| 239 |
+
typename ElementAccumulator,
|
| 240 |
+
typename ArchTag,
|
| 241 |
+
typename ThreadblockShape0,
|
| 242 |
+
typename ThreadblockShape1,
|
| 243 |
+
typename WarpShape0,
|
| 244 |
+
typename WarpShape1,
|
| 245 |
+
typename InstructionShape,
|
| 246 |
+
typename EpilogueOutputOp0,
|
| 247 |
+
typename EpilogueOutputOp1,
|
| 248 |
+
typename ThreadblockSwizzle,
|
| 249 |
+
typename MathOperatorTag,
|
| 250 |
+
int InterleavedK
|
| 251 |
+
>
|
| 252 |
+
struct DefaultB2bConv2dFprop <
|
| 253 |
+
ElementA,
|
| 254 |
+
layout::TensorNCxHWx<InterleavedK>,
|
| 255 |
+
ElementB,
|
| 256 |
+
layout::TensorCxRSKx<InterleavedK>,
|
| 257 |
+
ElementC,
|
| 258 |
+
LayoutC,
|
| 259 |
+
ElementAccumulator,
|
| 260 |
+
arch::OpClassTensorOp,
|
| 261 |
+
ArchTag,
|
| 262 |
+
ThreadblockShape0,
|
| 263 |
+
ThreadblockShape1,
|
| 264 |
+
WarpShape0,
|
| 265 |
+
WarpShape1,
|
| 266 |
+
InstructionShape,
|
| 267 |
+
EpilogueOutputOp0,
|
| 268 |
+
EpilogueOutputOp1,
|
| 269 |
+
ThreadblockSwizzle,
|
| 270 |
+
2,
|
| 271 |
+
MathOperatorTag,
|
| 272 |
+
IteratorAlgorithm::kAnalytic,
|
| 273 |
+
false
|
| 274 |
+
> {
|
| 275 |
+
|
| 276 |
+
// Define the core components from GEMM
|
| 277 |
+
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 278 |
+
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 279 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 280 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 281 |
+
2, MathOperatorTag, true>;
|
| 282 |
+
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 283 |
+
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 284 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 285 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 286 |
+
2, MathOperatorTag, true>;
|
| 287 |
+
|
| 288 |
+
// Define iterators over tiles from the A operand
|
| 289 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 290 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 291 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 292 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 293 |
+
// layout.
|
| 294 |
+
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
| 295 |
+
using IteratorA0 =
|
| 296 |
+
cutlass::conv::threadblock::TileIterator<
|
| 297 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
| 298 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
| 299 |
+
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
| 300 |
+
ThreadMapA0
|
| 301 |
+
>
|
| 302 |
+
>;
|
| 303 |
+
|
| 304 |
+
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
| 305 |
+
|
| 306 |
+
// Define iterators over tiles from the B operand
|
| 307 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 308 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 309 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 310 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 311 |
+
// layout.
|
| 312 |
+
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
| 313 |
+
using IteratorB0 =
|
| 314 |
+
cutlass::conv::threadblock::TileIterator<
|
| 315 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 316 |
+
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
| 317 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 318 |
+
ThreadMapB0
|
| 319 |
+
>
|
| 320 |
+
>;
|
| 321 |
+
|
| 322 |
+
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
| 323 |
+
|
| 324 |
+
// Use fragment iterator for A operand
|
| 325 |
+
using AccumulatorLayout = cutlass::layout::RowMajor;
|
| 326 |
+
using FragmentIteratorA1 =
|
| 327 |
+
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
| 328 |
+
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
| 329 |
+
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
| 330 |
+
MmaCore1::Shape::kK, //kBlocksColumn
|
| 331 |
+
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
| 332 |
+
|
| 333 |
+
/// Define iterators over tiles from scale/bias vectors
|
| 334 |
+
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
| 335 |
+
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
| 336 |
+
static int const kElementsPerAccess = 4;
|
| 337 |
+
using IteratorAccumulatorScaleBias =
|
| 338 |
+
cutlass::transform::threadblock::VectorIterator<
|
| 339 |
+
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
| 340 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
| 341 |
+
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
| 342 |
+
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
| 343 |
+
>;
|
| 344 |
+
|
| 345 |
+
// Warp-level iterators to load scale and bias vectors
|
| 346 |
+
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
| 347 |
+
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
| 348 |
+
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
| 349 |
+
|
| 350 |
+
// Define iterators over tiles from the B operand
|
| 351 |
+
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
| 352 |
+
using IteratorB1 =
|
| 353 |
+
cutlass::conv::threadblock::TileIterator<
|
| 354 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 355 |
+
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
| 356 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 357 |
+
ThreadMapB1
|
| 358 |
+
>
|
| 359 |
+
>;
|
| 360 |
+
|
| 361 |
+
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
| 362 |
+
|
| 363 |
+
// Warp-level GEMM components
|
| 364 |
+
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
| 365 |
+
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
| 366 |
+
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
| 367 |
+
|
| 368 |
+
// Define the Mma
|
| 369 |
+
using B2bMma = threadblock::B2bImplicitGemmPipelined<
|
| 370 |
+
ThreadblockShape0,
|
| 371 |
+
IteratorA0,
|
| 372 |
+
SmemIteratorA0,
|
| 373 |
+
IteratorB0,
|
| 374 |
+
SmemIteratorB0,
|
| 375 |
+
ThreadblockShape1,
|
| 376 |
+
FragmentIteratorA1,
|
| 377 |
+
IteratorAccumulatorScaleBias,
|
| 378 |
+
FragmentIteratorA1ScaleBias,
|
| 379 |
+
IteratorB1,
|
| 380 |
+
SmemIteratorB1,
|
| 381 |
+
ElementC,
|
| 382 |
+
LayoutC,
|
| 383 |
+
EpilogueOutputOp0,
|
| 384 |
+
MmaPolicy0,
|
| 385 |
+
MmaPolicy1
|
| 386 |
+
>;
|
| 387 |
+
|
| 388 |
+
// Define the epilogue
|
| 389 |
+
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
| 390 |
+
ThreadblockShape1,
|
| 391 |
+
WarpMmaTensorOp1,
|
| 392 |
+
1,
|
| 393 |
+
EpilogueOutputOp1,
|
| 394 |
+
EpilogueOutputOp1::kCount,
|
| 395 |
+
InterleavedK
|
| 396 |
+
>::Epilogue;
|
| 397 |
+
|
| 398 |
+
// Define the kernel
|
| 399 |
+
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
| 400 |
+
B2bMma,
|
| 401 |
+
Epilogue,
|
| 402 |
+
ThreadblockSwizzle,
|
| 403 |
+
conv::Operator::kFprop
|
| 404 |
+
>;
|
| 405 |
+
};
|
| 406 |
+
|
| 407 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 408 |
+
|
| 409 |
+
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm
|
| 410 |
+
/// and 2 stage pipeline.
|
| 411 |
+
template <
|
| 412 |
+
typename ElementA,
|
| 413 |
+
typename LayoutA,
|
| 414 |
+
typename ElementB,
|
| 415 |
+
typename LayoutB,
|
| 416 |
+
typename ElementC,
|
| 417 |
+
typename LayoutC,
|
| 418 |
+
typename ElementAccumulator,
|
| 419 |
+
typename ArchTag,
|
| 420 |
+
typename ThreadblockShape0,
|
| 421 |
+
typename ThreadblockShape1,
|
| 422 |
+
typename WarpShape0,
|
| 423 |
+
typename WarpShape1,
|
| 424 |
+
typename InstructionShape,
|
| 425 |
+
typename EpilogueOutputOp0,
|
| 426 |
+
typename EpilogueOutputOp1,
|
| 427 |
+
typename ThreadblockSwizzle,
|
| 428 |
+
typename MathOperatorTag
|
| 429 |
+
>
|
| 430 |
+
struct DefaultB2bConv2dFprop <
|
| 431 |
+
ElementA,
|
| 432 |
+
LayoutA,
|
| 433 |
+
ElementB,
|
| 434 |
+
LayoutB,
|
| 435 |
+
ElementC,
|
| 436 |
+
LayoutC,
|
| 437 |
+
ElementAccumulator,
|
| 438 |
+
arch::OpClassTensorOp,
|
| 439 |
+
ArchTag,
|
| 440 |
+
ThreadblockShape0,
|
| 441 |
+
ThreadblockShape1,
|
| 442 |
+
WarpShape0,
|
| 443 |
+
WarpShape1,
|
| 444 |
+
InstructionShape,
|
| 445 |
+
EpilogueOutputOp0,
|
| 446 |
+
EpilogueOutputOp1,
|
| 447 |
+
ThreadblockSwizzle,
|
| 448 |
+
2,
|
| 449 |
+
MathOperatorTag,
|
| 450 |
+
IteratorAlgorithm::kOptimized
|
| 451 |
+
> {
|
| 452 |
+
|
| 453 |
+
// Define the core components from GEMM
|
| 454 |
+
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 455 |
+
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
| 456 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 457 |
+
2, MathOperatorTag>;
|
| 458 |
+
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 459 |
+
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
| 460 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 461 |
+
2, MathOperatorTag>;
|
| 462 |
+
|
| 463 |
+
// Define iterators over tiles from the A operand
|
| 464 |
+
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
| 465 |
+
using IteratorA0 =
|
| 466 |
+
cutlass::conv::threadblock::TileIterator<
|
| 467 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
| 468 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
| 469 |
+
ElementA, LayoutA,
|
| 470 |
+
ThreadMapA0
|
| 471 |
+
>
|
| 472 |
+
>;
|
| 473 |
+
|
| 474 |
+
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
| 475 |
+
|
| 476 |
+
// Define iterators over tiles from the B operand
|
| 477 |
+
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
| 478 |
+
using IteratorB0 =
|
| 479 |
+
cutlass::conv::threadblock::TileIterator<
|
| 480 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 481 |
+
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
| 482 |
+
ElementB, LayoutB,
|
| 483 |
+
ThreadMapB0
|
| 484 |
+
>
|
| 485 |
+
>;
|
| 486 |
+
|
| 487 |
+
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
| 488 |
+
|
| 489 |
+
// Use fragment iterator for A operand
|
| 490 |
+
using AccumulatorLayout = cutlass::layout::ColumnMajor;
|
| 491 |
+
using FragmentIteratorA1 =
|
| 492 |
+
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
| 493 |
+
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
| 494 |
+
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
| 495 |
+
MmaCore1::Shape::kK, //kBlocksColumn
|
| 496 |
+
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
| 497 |
+
|
| 498 |
+
/// Define iterators over tiles from scale/bias vectors
|
| 499 |
+
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
| 500 |
+
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
| 501 |
+
static int const kElementsPerAccess = 2;
|
| 502 |
+
using IteratorAccumulatorScaleBias =
|
| 503 |
+
cutlass::transform::threadblock::VectorIterator<
|
| 504 |
+
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
| 505 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
| 506 |
+
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
| 507 |
+
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
| 508 |
+
>;
|
| 509 |
+
|
| 510 |
+
// Warp-level iterators to load scale and bias vectors
|
| 511 |
+
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
| 512 |
+
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
| 513 |
+
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
| 514 |
+
|
| 515 |
+
// Define iterators over tiles from the B operand
|
| 516 |
+
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
| 517 |
+
using IteratorB1 =
|
| 518 |
+
cutlass::conv::threadblock::TileIterator<
|
| 519 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 520 |
+
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
| 521 |
+
ElementB, LayoutB,
|
| 522 |
+
ThreadMapB1
|
| 523 |
+
>
|
| 524 |
+
>;
|
| 525 |
+
|
| 526 |
+
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
| 527 |
+
|
| 528 |
+
// Warp-level GEMM components
|
| 529 |
+
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
| 530 |
+
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
| 531 |
+
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
| 532 |
+
|
| 533 |
+
// Define the Mma
|
| 534 |
+
using B2bMma = threadblock::B2bImplicitGemmPipelined<
|
| 535 |
+
ThreadblockShape0,
|
| 536 |
+
IteratorA0,
|
| 537 |
+
SmemIteratorA0,
|
| 538 |
+
IteratorB0,
|
| 539 |
+
SmemIteratorB0,
|
| 540 |
+
ThreadblockShape1,
|
| 541 |
+
FragmentIteratorA1,
|
| 542 |
+
IteratorAccumulatorScaleBias,
|
| 543 |
+
FragmentIteratorA1ScaleBias,
|
| 544 |
+
IteratorB1,
|
| 545 |
+
SmemIteratorB1,
|
| 546 |
+
ElementC,
|
| 547 |
+
LayoutC,
|
| 548 |
+
EpilogueOutputOp0,
|
| 549 |
+
MmaPolicy0,
|
| 550 |
+
MmaPolicy1
|
| 551 |
+
>;
|
| 552 |
+
|
| 553 |
+
// Define the epilogue
|
| 554 |
+
using Epilogue = typename detail::DefaultConvEpilogue<
|
| 555 |
+
ArchTag,
|
| 556 |
+
ThreadblockShape1,
|
| 557 |
+
WarpMmaTensorOp1,
|
| 558 |
+
1,
|
| 559 |
+
EpilogueOutputOp1
|
| 560 |
+
>::Epilogue;
|
| 561 |
+
|
| 562 |
+
// Define the kernel
|
| 563 |
+
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
| 564 |
+
B2bMma,
|
| 565 |
+
Epilogue,
|
| 566 |
+
ThreadblockSwizzle,
|
| 567 |
+
conv::Operator::kFprop
|
| 568 |
+
>;
|
| 569 |
+
};
|
| 570 |
+
|
| 571 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 572 |
+
|
| 573 |
+
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage
|
| 574 |
+
/// pipeline with interleaved layout.
|
| 575 |
+
template <
|
| 576 |
+
typename ElementA,
|
| 577 |
+
typename ElementB,
|
| 578 |
+
typename ElementC,
|
| 579 |
+
typename LayoutC,
|
| 580 |
+
typename ElementAccumulator,
|
| 581 |
+
typename ArchTag,
|
| 582 |
+
typename ThreadblockShape0,
|
| 583 |
+
typename ThreadblockShape1,
|
| 584 |
+
typename WarpShape0,
|
| 585 |
+
typename WarpShape1,
|
| 586 |
+
typename InstructionShape,
|
| 587 |
+
typename EpilogueOutputOp0,
|
| 588 |
+
typename EpilogueOutputOp1,
|
| 589 |
+
typename ThreadblockSwizzle,
|
| 590 |
+
typename MathOperatorTag,
|
| 591 |
+
int InterleavedK
|
| 592 |
+
>
|
| 593 |
+
struct DefaultB2bConv2dFprop <
|
| 594 |
+
ElementA,
|
| 595 |
+
layout::TensorNCxHWx<InterleavedK>,
|
| 596 |
+
ElementB,
|
| 597 |
+
layout::TensorCxRSKx<InterleavedK>,
|
| 598 |
+
ElementC,
|
| 599 |
+
LayoutC,
|
| 600 |
+
ElementAccumulator,
|
| 601 |
+
arch::OpClassTensorOp,
|
| 602 |
+
ArchTag,
|
| 603 |
+
ThreadblockShape0,
|
| 604 |
+
ThreadblockShape1,
|
| 605 |
+
WarpShape0,
|
| 606 |
+
WarpShape1,
|
| 607 |
+
InstructionShape,
|
| 608 |
+
EpilogueOutputOp0,
|
| 609 |
+
EpilogueOutputOp1,
|
| 610 |
+
ThreadblockSwizzle,
|
| 611 |
+
2,
|
| 612 |
+
MathOperatorTag,
|
| 613 |
+
IteratorAlgorithm::kOptimized
|
| 614 |
+
> {
|
| 615 |
+
|
| 616 |
+
// Define the core components from GEMM
|
| 617 |
+
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 618 |
+
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 619 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 620 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 621 |
+
2, MathOperatorTag, true>;
|
| 622 |
+
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 623 |
+
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 624 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 625 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 626 |
+
2, MathOperatorTag, true>;
|
| 627 |
+
|
| 628 |
+
// Define iterators over tiles from the A operand
|
| 629 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 630 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 631 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 632 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 633 |
+
// layout.
|
| 634 |
+
|
| 635 |
+
// Define iterators over tiles from the A operand
|
| 636 |
+
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
| 637 |
+
using IteratorA0 =
|
| 638 |
+
cutlass::conv::threadblock::TileIterator<
|
| 639 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
| 640 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
| 641 |
+
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
| 642 |
+
ThreadMapA0
|
| 643 |
+
>
|
| 644 |
+
>;
|
| 645 |
+
|
| 646 |
+
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
| 647 |
+
|
| 648 |
+
// Define iterators over tiles from the B operand
|
| 649 |
+
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
| 650 |
+
using IteratorB0 =
|
| 651 |
+
cutlass::conv::threadblock::TileIterator<
|
| 652 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 653 |
+
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
| 654 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 655 |
+
ThreadMapB0
|
| 656 |
+
>
|
| 657 |
+
>;
|
| 658 |
+
|
| 659 |
+
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
| 660 |
+
|
| 661 |
+
// Use fragment iterator for A operand
|
| 662 |
+
using AccumulatorLayout = cutlass::layout::RowMajor;
|
| 663 |
+
using FragmentIteratorA1 =
|
| 664 |
+
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
| 665 |
+
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
| 666 |
+
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
| 667 |
+
MmaCore1::Shape::kK, //kBlocksColumn
|
| 668 |
+
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
| 669 |
+
|
| 670 |
+
/// Define iterators over tiles from scale/bias vectors
|
| 671 |
+
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
| 672 |
+
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
| 673 |
+
static int const kElementsPerAccess = 4;
|
| 674 |
+
using IteratorAccumulatorScaleBias =
|
| 675 |
+
cutlass::transform::threadblock::VectorIterator<
|
| 676 |
+
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
| 677 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
| 678 |
+
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
| 679 |
+
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
| 680 |
+
>;
|
| 681 |
+
|
| 682 |
+
// Warp-level iterators to load scale and bias vectors
|
| 683 |
+
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
| 684 |
+
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
| 685 |
+
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
| 686 |
+
|
| 687 |
+
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
| 688 |
+
using IteratorB1 =
|
| 689 |
+
cutlass::conv::threadblock::TileIterator<
|
| 690 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 691 |
+
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
| 692 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 693 |
+
ThreadMapB1
|
| 694 |
+
>
|
| 695 |
+
>;
|
| 696 |
+
|
| 697 |
+
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
| 698 |
+
|
| 699 |
+
// Warp-level GEMM components
|
| 700 |
+
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
| 701 |
+
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
| 702 |
+
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
| 703 |
+
|
| 704 |
+
// Define the Mma
|
| 705 |
+
using B2bMma = threadblock::B2bImplicitGemmPipelined<
|
| 706 |
+
ThreadblockShape0,
|
| 707 |
+
IteratorA0,
|
| 708 |
+
SmemIteratorA0,
|
| 709 |
+
IteratorB0,
|
| 710 |
+
SmemIteratorB0,
|
| 711 |
+
ThreadblockShape1,
|
| 712 |
+
FragmentIteratorA1,
|
| 713 |
+
IteratorAccumulatorScaleBias,
|
| 714 |
+
FragmentIteratorA1ScaleBias,
|
| 715 |
+
IteratorB1,
|
| 716 |
+
SmemIteratorB1,
|
| 717 |
+
ElementC,
|
| 718 |
+
LayoutC,
|
| 719 |
+
EpilogueOutputOp0,
|
| 720 |
+
MmaPolicy0,
|
| 721 |
+
MmaPolicy1
|
| 722 |
+
>;
|
| 723 |
+
|
| 724 |
+
// Define the epilogue
|
| 725 |
+
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
| 726 |
+
ThreadblockShape1,
|
| 727 |
+
WarpMmaTensorOp1,
|
| 728 |
+
1,
|
| 729 |
+
EpilogueOutputOp1,
|
| 730 |
+
EpilogueOutputOp1::kCount,
|
| 731 |
+
InterleavedK
|
| 732 |
+
>::Epilogue;
|
| 733 |
+
|
| 734 |
+
// Define the kernel
|
| 735 |
+
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
| 736 |
+
B2bMma,
|
| 737 |
+
Epilogue,
|
| 738 |
+
ThreadblockSwizzle,
|
| 739 |
+
conv::Operator::kFprop
|
| 740 |
+
>;
|
| 741 |
+
};
|
| 742 |
+
|
| 743 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 744 |
+
|
| 745 |
+
} // namespace kernel
|
| 746 |
+
} // namespace conv
|
| 747 |
+
} // namespace cutlass
|
| 748 |
+
|
| 749 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h
ADDED
|
@@ -0,0 +1,740 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief
|
| 34 |
+
Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
|
| 35 |
+
matrix multiply-add with the appropriate threadblock-scoped epilogue.
|
| 36 |
+
*/
|
| 37 |
+
|
| 38 |
+
#pragma once
|
| 39 |
+
|
| 40 |
+
#include "cutlass/cutlass.h"
|
| 41 |
+
#include "cutlass/conv/kernel/default_conv2d.h"
|
| 42 |
+
|
| 43 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
|
| 44 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
|
| 45 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
|
| 46 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
|
| 47 |
+
|
| 48 |
+
#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
|
| 49 |
+
#include "cutlass/transform/threadblock/vector_iterator.h"
|
| 50 |
+
#include "cutlass/transform/warp/vector_fragment_iterator.h"
|
| 51 |
+
|
| 52 |
+
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
| 53 |
+
|
| 54 |
+
#include "kernel/default_b2b_conv2d_fprop.h"
|
| 55 |
+
#include "kernel/b2b_implicit_gemm_convolution.h"
|
| 56 |
+
#include "threadblock/b2b_implicit_gemm_multistage.h"
|
| 57 |
+
|
| 58 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
|
| 60 |
+
namespace cutlass {
|
| 61 |
+
namespace conv {
|
| 62 |
+
namespace kernel {
|
| 63 |
+
|
| 64 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 65 |
+
// OpClassTensorOp convolutions
|
| 66 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 67 |
+
|
| 68 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
|
| 69 |
+
/// pipeline.
|
| 70 |
+
template <
|
| 71 |
+
typename ElementA,
|
| 72 |
+
typename LayoutA,
|
| 73 |
+
typename ElementB,
|
| 74 |
+
typename LayoutB,
|
| 75 |
+
typename ElementC,
|
| 76 |
+
typename LayoutC,
|
| 77 |
+
typename ElementAccumulator,
|
| 78 |
+
typename ArchTag,
|
| 79 |
+
typename ThreadblockShape0,
|
| 80 |
+
typename ThreadblockShape1,
|
| 81 |
+
typename WarpShape0,
|
| 82 |
+
typename WarpShape1,
|
| 83 |
+
typename InstructionShape,
|
| 84 |
+
typename EpilogueOutputOp0,
|
| 85 |
+
typename EpilogueOutputOp1,
|
| 86 |
+
typename ThreadblockSwizzle,
|
| 87 |
+
int Stages,
|
| 88 |
+
typename MathOperatorTag
|
| 89 |
+
>
|
| 90 |
+
struct DefaultB2bConv2dFprop <
|
| 91 |
+
ElementA,
|
| 92 |
+
LayoutA,
|
| 93 |
+
ElementB,
|
| 94 |
+
LayoutB,
|
| 95 |
+
ElementC,
|
| 96 |
+
LayoutC,
|
| 97 |
+
ElementAccumulator,
|
| 98 |
+
arch::OpClassTensorOp,
|
| 99 |
+
ArchTag,
|
| 100 |
+
ThreadblockShape0,
|
| 101 |
+
ThreadblockShape1,
|
| 102 |
+
WarpShape0,
|
| 103 |
+
WarpShape1,
|
| 104 |
+
InstructionShape,
|
| 105 |
+
EpilogueOutputOp0,
|
| 106 |
+
EpilogueOutputOp1,
|
| 107 |
+
ThreadblockSwizzle,
|
| 108 |
+
Stages,
|
| 109 |
+
MathOperatorTag,
|
| 110 |
+
IteratorAlgorithm::kAnalytic
|
| 111 |
+
> {
|
| 112 |
+
|
| 113 |
+
// Define the core components from GEMM
|
| 114 |
+
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 115 |
+
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
| 116 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 117 |
+
Stages, MathOperatorTag>;
|
| 118 |
+
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 119 |
+
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
| 120 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 121 |
+
Stages, MathOperatorTag>;
|
| 122 |
+
|
| 123 |
+
// Define iterators over tiles from the A operand
|
| 124 |
+
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
| 125 |
+
using IteratorA0 =
|
| 126 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
| 127 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
| 128 |
+
ElementA, LayoutA,
|
| 129 |
+
ThreadMapA0
|
| 130 |
+
>;
|
| 131 |
+
|
| 132 |
+
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
| 133 |
+
|
| 134 |
+
// Define iterators over tiles from the B operand
|
| 135 |
+
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
| 136 |
+
using IteratorB0 =
|
| 137 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 138 |
+
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
| 139 |
+
ElementB, LayoutB,
|
| 140 |
+
ThreadMapB0
|
| 141 |
+
>;
|
| 142 |
+
|
| 143 |
+
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
| 144 |
+
|
| 145 |
+
// Use fragment iterator for A operand
|
| 146 |
+
using AccumulatorLayout = cutlass::layout::ColumnMajor;
|
| 147 |
+
using FragmentIteratorA1 =
|
| 148 |
+
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
| 149 |
+
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
| 150 |
+
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
| 151 |
+
MmaCore1::Shape::kK, //kBlocksColumn
|
| 152 |
+
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
| 153 |
+
|
| 154 |
+
/// Define iterators over tiles from scale/bias vectors
|
| 155 |
+
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
| 156 |
+
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
| 157 |
+
static int const kElementsPerAccess = 2;
|
| 158 |
+
using IteratorAccumulatorScaleBias =
|
| 159 |
+
cutlass::transform::threadblock::VectorIterator<
|
| 160 |
+
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
| 161 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
| 162 |
+
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
| 163 |
+
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
| 164 |
+
>;
|
| 165 |
+
|
| 166 |
+
// Warp-level iterators to load scale and bias vectors
|
| 167 |
+
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
| 168 |
+
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
| 169 |
+
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
| 170 |
+
|
| 171 |
+
// Define iterators over tiles from the B operand
|
| 172 |
+
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
| 173 |
+
using IteratorB1 =
|
| 174 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 175 |
+
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
| 176 |
+
ElementB, LayoutB,
|
| 177 |
+
ThreadMapB1
|
| 178 |
+
>;
|
| 179 |
+
|
| 180 |
+
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
| 181 |
+
|
| 182 |
+
// Warp-level GEMM components
|
| 183 |
+
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
| 184 |
+
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
| 185 |
+
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
| 186 |
+
|
| 187 |
+
// Define the Mma
|
| 188 |
+
using B2bMma = threadblock::B2bImplicitGemmMultistage<
|
| 189 |
+
ThreadblockShape0,
|
| 190 |
+
IteratorA0,
|
| 191 |
+
SmemIteratorA0,
|
| 192 |
+
arch::CacheOperation::Always,
|
| 193 |
+
IteratorB0,
|
| 194 |
+
SmemIteratorB0,
|
| 195 |
+
arch::CacheOperation::Global,
|
| 196 |
+
ThreadblockShape1,
|
| 197 |
+
FragmentIteratorA1,
|
| 198 |
+
IteratorAccumulatorScaleBias,
|
| 199 |
+
FragmentIteratorA1ScaleBias,
|
| 200 |
+
IteratorB1,
|
| 201 |
+
SmemIteratorB1,
|
| 202 |
+
arch::CacheOperation::Global,
|
| 203 |
+
EpilogueOutputOp0,
|
| 204 |
+
MmaPolicy0,
|
| 205 |
+
MmaPolicy1,
|
| 206 |
+
Stages
|
| 207 |
+
>;
|
| 208 |
+
|
| 209 |
+
// Define the epilogue
|
| 210 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 211 |
+
ThreadblockShape1,
|
| 212 |
+
WarpMmaTensorOp1,
|
| 213 |
+
1,
|
| 214 |
+
EpilogueOutputOp1,
|
| 215 |
+
EpilogueOutputOp1::kCount
|
| 216 |
+
>::Epilogue;
|
| 217 |
+
|
| 218 |
+
// Define the kernel
|
| 219 |
+
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
| 220 |
+
B2bMma,
|
| 221 |
+
Epilogue,
|
| 222 |
+
ThreadblockSwizzle,
|
| 223 |
+
conv::Operator::kFprop
|
| 224 |
+
>;
|
| 225 |
+
};
|
| 226 |
+
|
| 227 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 228 |
+
|
| 229 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
|
| 230 |
+
/// pipeline with interleaved layout.
|
| 231 |
+
template <
|
| 232 |
+
typename ElementA,
|
| 233 |
+
typename ElementB,
|
| 234 |
+
typename ElementC,
|
| 235 |
+
typename LayoutC,
|
| 236 |
+
typename ElementAccumulator,
|
| 237 |
+
typename ArchTag,
|
| 238 |
+
typename ThreadblockShape0,
|
| 239 |
+
typename ThreadblockShape1,
|
| 240 |
+
typename WarpShape0,
|
| 241 |
+
typename WarpShape1,
|
| 242 |
+
typename InstructionShape,
|
| 243 |
+
typename EpilogueOutputOp0,
|
| 244 |
+
typename EpilogueOutputOp1,
|
| 245 |
+
typename ThreadblockSwizzle,
|
| 246 |
+
int Stages,
|
| 247 |
+
typename MathOperatorTag,
|
| 248 |
+
int InterleavedK
|
| 249 |
+
>
|
| 250 |
+
struct DefaultB2bConv2dFprop <
|
| 251 |
+
ElementA,
|
| 252 |
+
layout::TensorNCxHWx<InterleavedK>,
|
| 253 |
+
ElementB,
|
| 254 |
+
layout::TensorCxRSKx<InterleavedK>,
|
| 255 |
+
ElementC,
|
| 256 |
+
LayoutC,
|
| 257 |
+
ElementAccumulator,
|
| 258 |
+
arch::OpClassTensorOp,
|
| 259 |
+
ArchTag,
|
| 260 |
+
ThreadblockShape0,
|
| 261 |
+
ThreadblockShape1,
|
| 262 |
+
WarpShape0,
|
| 263 |
+
WarpShape1,
|
| 264 |
+
InstructionShape,
|
| 265 |
+
EpilogueOutputOp0,
|
| 266 |
+
EpilogueOutputOp1,
|
| 267 |
+
ThreadblockSwizzle,
|
| 268 |
+
Stages,
|
| 269 |
+
MathOperatorTag,
|
| 270 |
+
IteratorAlgorithm::kAnalytic
|
| 271 |
+
> {
|
| 272 |
+
|
| 273 |
+
// Define the core components from GEMM
|
| 274 |
+
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 275 |
+
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 276 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 277 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 278 |
+
Stages, MathOperatorTag, true>;
|
| 279 |
+
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 280 |
+
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 281 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 282 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 283 |
+
Stages, MathOperatorTag, true>;
|
| 284 |
+
|
| 285 |
+
// Define iterators over tiles from the A operand
|
| 286 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 287 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 288 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 289 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 290 |
+
// layout.
|
| 291 |
+
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
| 292 |
+
using IteratorA0 =
|
| 293 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
| 294 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
| 295 |
+
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
| 296 |
+
ThreadMapA0
|
| 297 |
+
>;
|
| 298 |
+
|
| 299 |
+
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
| 300 |
+
|
| 301 |
+
// Define iterators over tiles from the B operand
|
| 302 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 303 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 304 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 305 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 306 |
+
// layout.
|
| 307 |
+
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
| 308 |
+
using IteratorB0 =
|
| 309 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 310 |
+
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
| 311 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 312 |
+
ThreadMapB0
|
| 313 |
+
>;
|
| 314 |
+
|
| 315 |
+
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
| 316 |
+
|
| 317 |
+
// Use fragment iterator for A operand
|
| 318 |
+
using AccumulatorLayout = cutlass::layout::RowMajor;
|
| 319 |
+
using FragmentIteratorA1 =
|
| 320 |
+
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
| 321 |
+
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
| 322 |
+
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
| 323 |
+
MmaCore1::Shape::kK, //kBlocksColumn
|
| 324 |
+
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
| 325 |
+
|
| 326 |
+
/// Define iterators over tiles from scale/bias vectors
|
| 327 |
+
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
| 328 |
+
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
| 329 |
+
static int const kElementsPerAccess = 4;
|
| 330 |
+
using IteratorAccumulatorScaleBias =
|
| 331 |
+
cutlass::transform::threadblock::VectorIterator<
|
| 332 |
+
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
| 333 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
| 334 |
+
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
| 335 |
+
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
| 336 |
+
>;
|
| 337 |
+
|
| 338 |
+
// Warp-level iterators to load scale and bias vectors
|
| 339 |
+
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
| 340 |
+
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
| 341 |
+
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
| 342 |
+
|
| 343 |
+
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
| 344 |
+
using IteratorB1 =
|
| 345 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 346 |
+
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
| 347 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 348 |
+
ThreadMapB1
|
| 349 |
+
>;
|
| 350 |
+
|
| 351 |
+
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
// Warp-level GEMM components
|
| 355 |
+
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
| 356 |
+
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
| 357 |
+
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
| 358 |
+
|
| 359 |
+
// Define the Mma
|
| 360 |
+
using B2bMma = threadblock::B2bImplicitGemmMultistage<
|
| 361 |
+
ThreadblockShape0,
|
| 362 |
+
IteratorA0,
|
| 363 |
+
SmemIteratorA0,
|
| 364 |
+
arch::CacheOperation::Always,
|
| 365 |
+
IteratorB0,
|
| 366 |
+
SmemIteratorB0,
|
| 367 |
+
arch::CacheOperation::Global,
|
| 368 |
+
ThreadblockShape1,
|
| 369 |
+
FragmentIteratorA1,
|
| 370 |
+
IteratorAccumulatorScaleBias,
|
| 371 |
+
FragmentIteratorA1ScaleBias,
|
| 372 |
+
IteratorB1,
|
| 373 |
+
SmemIteratorB1,
|
| 374 |
+
arch::CacheOperation::Global,
|
| 375 |
+
EpilogueOutputOp0,
|
| 376 |
+
MmaPolicy0,
|
| 377 |
+
MmaPolicy1,
|
| 378 |
+
Stages
|
| 379 |
+
>;
|
| 380 |
+
|
| 381 |
+
// Define the epilogue
|
| 382 |
+
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
| 383 |
+
ThreadblockShape1,
|
| 384 |
+
WarpMmaTensorOp1,
|
| 385 |
+
1,
|
| 386 |
+
EpilogueOutputOp1,
|
| 387 |
+
EpilogueOutputOp1::kCount,
|
| 388 |
+
InterleavedK
|
| 389 |
+
>::Epilogue;
|
| 390 |
+
|
| 391 |
+
// Define the kernel
|
| 392 |
+
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
| 393 |
+
B2bMma,
|
| 394 |
+
Epilogue,
|
| 395 |
+
ThreadblockSwizzle,
|
| 396 |
+
conv::Operator::kFprop
|
| 397 |
+
>;
|
| 398 |
+
};
|
| 399 |
+
|
| 400 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 401 |
+
|
| 402 |
+
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
| 403 |
+
/// multistage pipeline.
|
| 404 |
+
template <
|
| 405 |
+
typename ElementA,
|
| 406 |
+
typename LayoutA,
|
| 407 |
+
typename ElementB,
|
| 408 |
+
typename LayoutB,
|
| 409 |
+
typename ElementC,
|
| 410 |
+
typename LayoutC,
|
| 411 |
+
typename ElementAccumulator,
|
| 412 |
+
typename ArchTag,
|
| 413 |
+
typename ThreadblockShape0,
|
| 414 |
+
typename ThreadblockShape1,
|
| 415 |
+
typename WarpShape0,
|
| 416 |
+
typename WarpShape1,
|
| 417 |
+
typename InstructionShape,
|
| 418 |
+
typename EpilogueOutputOp0,
|
| 419 |
+
typename EpilogueOutputOp1,
|
| 420 |
+
typename ThreadblockSwizzle,
|
| 421 |
+
int Stages,
|
| 422 |
+
typename MathOperatorTag
|
| 423 |
+
>
|
| 424 |
+
struct DefaultB2bConv2dFprop <
|
| 425 |
+
ElementA,
|
| 426 |
+
LayoutA,
|
| 427 |
+
ElementB,
|
| 428 |
+
LayoutB,
|
| 429 |
+
ElementC,
|
| 430 |
+
LayoutC,
|
| 431 |
+
ElementAccumulator,
|
| 432 |
+
arch::OpClassTensorOp,
|
| 433 |
+
ArchTag,
|
| 434 |
+
ThreadblockShape0,
|
| 435 |
+
ThreadblockShape1,
|
| 436 |
+
WarpShape0,
|
| 437 |
+
WarpShape1,
|
| 438 |
+
InstructionShape,
|
| 439 |
+
EpilogueOutputOp0,
|
| 440 |
+
EpilogueOutputOp1,
|
| 441 |
+
ThreadblockSwizzle,
|
| 442 |
+
Stages,
|
| 443 |
+
MathOperatorTag,
|
| 444 |
+
IteratorAlgorithm::kOptimized
|
| 445 |
+
> {
|
| 446 |
+
|
| 447 |
+
// Define the core components from GEMM
|
| 448 |
+
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 449 |
+
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
| 450 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 451 |
+
Stages, MathOperatorTag>;
|
| 452 |
+
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 453 |
+
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
| 454 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 455 |
+
Stages, MathOperatorTag>;
|
| 456 |
+
|
| 457 |
+
// Define iterators over tiles from the A operand
|
| 458 |
+
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
| 459 |
+
using IteratorA0 =
|
| 460 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
| 461 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
| 462 |
+
ElementA, LayoutA,
|
| 463 |
+
ThreadMapA0
|
| 464 |
+
>;
|
| 465 |
+
|
| 466 |
+
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
| 467 |
+
|
| 468 |
+
// Define iterators over tiles from the B operand
|
| 469 |
+
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
| 470 |
+
using IteratorB0 =
|
| 471 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 472 |
+
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
| 473 |
+
ElementB, LayoutB,
|
| 474 |
+
ThreadMapB0
|
| 475 |
+
>;
|
| 476 |
+
|
| 477 |
+
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
| 478 |
+
|
| 479 |
+
// Use fragment iterator for A operand
|
| 480 |
+
using AccumulatorLayout = cutlass::layout::ColumnMajor;
|
| 481 |
+
using FragmentIteratorA1 =
|
| 482 |
+
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
| 483 |
+
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
| 484 |
+
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
| 485 |
+
MmaCore1::Shape::kK, //kBlocksColumn
|
| 486 |
+
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
| 487 |
+
|
| 488 |
+
/// Define iterators over tiles from scale/bias vectors
|
| 489 |
+
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
| 490 |
+
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
| 491 |
+
static int const kElementsPerAccess = 2;
|
| 492 |
+
using IteratorAccumulatorScaleBias =
|
| 493 |
+
cutlass::transform::threadblock::VectorIterator<
|
| 494 |
+
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
| 495 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
| 496 |
+
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
| 497 |
+
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
| 498 |
+
>;
|
| 499 |
+
|
| 500 |
+
// Warp-level iterators to load scale and bias vectors
|
| 501 |
+
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
| 502 |
+
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
| 503 |
+
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
| 504 |
+
|
| 505 |
+
// Define iterators over tiles from the B operand
|
| 506 |
+
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
| 507 |
+
using IteratorB1 =
|
| 508 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 509 |
+
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
| 510 |
+
ElementB, LayoutB,
|
| 511 |
+
ThreadMapB1
|
| 512 |
+
>;
|
| 513 |
+
|
| 514 |
+
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
| 515 |
+
|
| 516 |
+
// Warp-level GEMM components
|
| 517 |
+
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
| 518 |
+
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
| 519 |
+
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
| 520 |
+
|
| 521 |
+
// Define the Mma
|
| 522 |
+
using B2bMma = threadblock::B2bImplicitGemmMultistage<
|
| 523 |
+
ThreadblockShape0,
|
| 524 |
+
IteratorA0,
|
| 525 |
+
SmemIteratorA0,
|
| 526 |
+
arch::CacheOperation::Always,
|
| 527 |
+
IteratorB0,
|
| 528 |
+
SmemIteratorB0,
|
| 529 |
+
arch::CacheOperation::Global,
|
| 530 |
+
ThreadblockShape1,
|
| 531 |
+
FragmentIteratorA1,
|
| 532 |
+
IteratorAccumulatorScaleBias,
|
| 533 |
+
FragmentIteratorA1ScaleBias,
|
| 534 |
+
IteratorB1,
|
| 535 |
+
SmemIteratorB1,
|
| 536 |
+
arch::CacheOperation::Global,
|
| 537 |
+
EpilogueOutputOp0,
|
| 538 |
+
MmaPolicy0,
|
| 539 |
+
MmaPolicy1,
|
| 540 |
+
Stages
|
| 541 |
+
>;
|
| 542 |
+
|
| 543 |
+
// Define the epilogue
|
| 544 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 545 |
+
ThreadblockShape1,
|
| 546 |
+
WarpMmaTensorOp1,
|
| 547 |
+
1,
|
| 548 |
+
EpilogueOutputOp1,
|
| 549 |
+
EpilogueOutputOp1::kCount
|
| 550 |
+
>::Epilogue;
|
| 551 |
+
|
| 552 |
+
// Define the kernel
|
| 553 |
+
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
| 554 |
+
B2bMma,
|
| 555 |
+
Epilogue,
|
| 556 |
+
ThreadblockSwizzle,
|
| 557 |
+
conv::Operator::kFprop
|
| 558 |
+
>;
|
| 559 |
+
};
|
| 560 |
+
|
| 561 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 562 |
+
|
| 563 |
+
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
| 564 |
+
// multistage pipeline with interleaved layout.
|
| 565 |
+
template <
|
| 566 |
+
typename ElementA,
|
| 567 |
+
typename ElementB,
|
| 568 |
+
typename ElementC,
|
| 569 |
+
typename LayoutC,
|
| 570 |
+
typename ElementAccumulator,
|
| 571 |
+
typename ArchTag,
|
| 572 |
+
typename ThreadblockShape0,
|
| 573 |
+
typename ThreadblockShape1,
|
| 574 |
+
typename WarpShape0,
|
| 575 |
+
typename WarpShape1,
|
| 576 |
+
typename InstructionShape,
|
| 577 |
+
typename EpilogueOutputOp0,
|
| 578 |
+
typename EpilogueOutputOp1,
|
| 579 |
+
typename ThreadblockSwizzle,
|
| 580 |
+
int Stages,
|
| 581 |
+
typename MathOperatorTag,
|
| 582 |
+
int InterleavedK
|
| 583 |
+
>
|
| 584 |
+
struct DefaultB2bConv2dFprop <
|
| 585 |
+
ElementA,
|
| 586 |
+
layout::TensorNCxHWx<InterleavedK>,
|
| 587 |
+
ElementB,
|
| 588 |
+
layout::TensorCxRSKx<InterleavedK>,
|
| 589 |
+
ElementC,
|
| 590 |
+
LayoutC,
|
| 591 |
+
ElementAccumulator,
|
| 592 |
+
arch::OpClassTensorOp,
|
| 593 |
+
ArchTag,
|
| 594 |
+
ThreadblockShape0,
|
| 595 |
+
ThreadblockShape1,
|
| 596 |
+
WarpShape0,
|
| 597 |
+
WarpShape1,
|
| 598 |
+
InstructionShape,
|
| 599 |
+
EpilogueOutputOp0,
|
| 600 |
+
EpilogueOutputOp1,
|
| 601 |
+
ThreadblockSwizzle,
|
| 602 |
+
Stages,
|
| 603 |
+
MathOperatorTag,
|
| 604 |
+
IteratorAlgorithm::kOptimized
|
| 605 |
+
> {
|
| 606 |
+
|
| 607 |
+
// Define the core components from GEMM
|
| 608 |
+
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 609 |
+
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 610 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 611 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 612 |
+
Stages, MathOperatorTag, true>;
|
| 613 |
+
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 614 |
+
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 615 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 616 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 617 |
+
Stages, MathOperatorTag, true>;
|
| 618 |
+
|
| 619 |
+
// Define iterators over tiles from the A operand
|
| 620 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 621 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 622 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 623 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 624 |
+
// layout.
|
| 625 |
+
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
| 626 |
+
using IteratorA0 =
|
| 627 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
| 628 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
| 629 |
+
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
| 630 |
+
ThreadMapA0
|
| 631 |
+
>;
|
| 632 |
+
|
| 633 |
+
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
| 634 |
+
|
| 635 |
+
// Define iterators over tiles from the B operand
|
| 636 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 637 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 638 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 639 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 640 |
+
// layout.
|
| 641 |
+
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
| 642 |
+
using IteratorB0 =
|
| 643 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 644 |
+
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
| 645 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 646 |
+
ThreadMapB0
|
| 647 |
+
>;
|
| 648 |
+
|
| 649 |
+
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
| 650 |
+
|
| 651 |
+
// Use fragment iterator for A operand
|
| 652 |
+
using AccumulatorLayout = cutlass::layout::RowMajor;
|
| 653 |
+
using FragmentIteratorA1 =
|
| 654 |
+
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
| 655 |
+
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
| 656 |
+
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
| 657 |
+
MmaCore1::Shape::kK, //kBlocksColumn
|
| 658 |
+
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
| 659 |
+
|
| 660 |
+
/// Define iterators over tiles from scale/bias vectors
|
| 661 |
+
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
| 662 |
+
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
| 663 |
+
static int const kElementsPerAccess = 4;
|
| 664 |
+
using IteratorAccumulatorScaleBias =
|
| 665 |
+
cutlass::transform::threadblock::VectorIterator<
|
| 666 |
+
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
| 667 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
| 668 |
+
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
| 669 |
+
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
| 670 |
+
>;
|
| 671 |
+
|
| 672 |
+
// Warp-level iterators to load scale and bias vectors
|
| 673 |
+
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
| 674 |
+
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
| 675 |
+
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
| 676 |
+
|
| 677 |
+
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
| 678 |
+
using IteratorB1 =
|
| 679 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 680 |
+
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
| 681 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 682 |
+
ThreadMapB1
|
| 683 |
+
>;
|
| 684 |
+
|
| 685 |
+
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
// Warp-level GEMM components
|
| 689 |
+
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
| 690 |
+
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
| 691 |
+
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
| 692 |
+
|
| 693 |
+
// Define the Mma
|
| 694 |
+
using B2bMma = threadblock::B2bImplicitGemmMultistage<
|
| 695 |
+
ThreadblockShape0,
|
| 696 |
+
IteratorA0,
|
| 697 |
+
SmemIteratorA0,
|
| 698 |
+
arch::CacheOperation::Always,
|
| 699 |
+
IteratorB0,
|
| 700 |
+
SmemIteratorB0,
|
| 701 |
+
arch::CacheOperation::Global,
|
| 702 |
+
ThreadblockShape1,
|
| 703 |
+
FragmentIteratorA1,
|
| 704 |
+
IteratorAccumulatorScaleBias,
|
| 705 |
+
FragmentIteratorA1ScaleBias,
|
| 706 |
+
IteratorB1,
|
| 707 |
+
SmemIteratorB1,
|
| 708 |
+
arch::CacheOperation::Global,
|
| 709 |
+
EpilogueOutputOp0,
|
| 710 |
+
MmaPolicy0,
|
| 711 |
+
MmaPolicy1,
|
| 712 |
+
Stages
|
| 713 |
+
>;
|
| 714 |
+
|
| 715 |
+
// Define the epilogue
|
| 716 |
+
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
| 717 |
+
ThreadblockShape1,
|
| 718 |
+
WarpMmaTensorOp1,
|
| 719 |
+
1,
|
| 720 |
+
EpilogueOutputOp1,
|
| 721 |
+
EpilogueOutputOp1::kCount,
|
| 722 |
+
InterleavedK
|
| 723 |
+
>::Epilogue;
|
| 724 |
+
|
| 725 |
+
// Define the kernel
|
| 726 |
+
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
| 727 |
+
B2bMma,
|
| 728 |
+
Epilogue,
|
| 729 |
+
ThreadblockSwizzle,
|
| 730 |
+
conv::Operator::kFprop
|
| 731 |
+
>;
|
| 732 |
+
};
|
| 733 |
+
|
| 734 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 735 |
+
|
| 736 |
+
} // namespace kernel
|
| 737 |
+
} // namespace conv
|
| 738 |
+
} // namespace cutlass
|
| 739 |
+
|
| 740 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h
ADDED
|
@@ -0,0 +1,817 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief
|
| 34 |
+
Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
|
| 35 |
+
matrix multiply-add with the appropriate threadblock-scoped epilogue.
|
| 36 |
+
*/
|
| 37 |
+
|
| 38 |
+
#pragma once
|
| 39 |
+
|
| 40 |
+
#include "cutlass/cutlass.h"
|
| 41 |
+
#include "cutlass/conv/kernel/default_conv2d.h"
|
| 42 |
+
|
| 43 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
|
| 44 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
|
| 45 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
|
| 46 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
|
| 47 |
+
|
| 48 |
+
#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
|
| 49 |
+
#include "cutlass/transform/threadblock/vector_iterator.h"
|
| 50 |
+
#include "cutlass/transform/warp/vector_fragment_iterator.h"
|
| 51 |
+
|
| 52 |
+
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
| 53 |
+
|
| 54 |
+
#include "kernel/default_b2b_conv2d_fprop.h"
|
| 55 |
+
#include "kernel/b2b_implicit_gemm_convolution.h"
|
| 56 |
+
#include "threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h"
|
| 57 |
+
|
| 58 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
|
| 60 |
+
namespace cutlass {
|
| 61 |
+
namespace conv {
|
| 62 |
+
namespace kernel {
|
| 63 |
+
|
| 64 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 65 |
+
|
| 66 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm
|
| 67 |
+
/// and 2 stage pipeline.
|
| 68 |
+
/// Accumulator will be staged in shared memory.
|
| 69 |
+
template <
|
| 70 |
+
typename ElementA,
|
| 71 |
+
typename LayoutA,
|
| 72 |
+
typename ElementB,
|
| 73 |
+
typename LayoutB,
|
| 74 |
+
typename ElementC,
|
| 75 |
+
typename LayoutC,
|
| 76 |
+
typename ElementAccumulator,
|
| 77 |
+
typename ArchTag,
|
| 78 |
+
typename ThreadblockShape0,
|
| 79 |
+
typename ThreadblockShape1,
|
| 80 |
+
typename WarpShape0,
|
| 81 |
+
typename WarpShape1,
|
| 82 |
+
typename InstructionShape,
|
| 83 |
+
typename EpilogueOutputOp0,
|
| 84 |
+
typename EpilogueOutputOp1,
|
| 85 |
+
typename ThreadblockSwizzle,
|
| 86 |
+
typename MathOperatorTag
|
| 87 |
+
>
|
| 88 |
+
struct DefaultB2bConv2dFprop <
|
| 89 |
+
ElementA,
|
| 90 |
+
LayoutA,
|
| 91 |
+
ElementB,
|
| 92 |
+
LayoutB,
|
| 93 |
+
ElementC,
|
| 94 |
+
LayoutC,
|
| 95 |
+
ElementAccumulator,
|
| 96 |
+
arch::OpClassTensorOp,
|
| 97 |
+
ArchTag,
|
| 98 |
+
ThreadblockShape0,
|
| 99 |
+
ThreadblockShape1,
|
| 100 |
+
WarpShape0,
|
| 101 |
+
WarpShape1,
|
| 102 |
+
InstructionShape,
|
| 103 |
+
EpilogueOutputOp0,
|
| 104 |
+
EpilogueOutputOp1,
|
| 105 |
+
ThreadblockSwizzle,
|
| 106 |
+
2,
|
| 107 |
+
MathOperatorTag,
|
| 108 |
+
IteratorAlgorithm::kAnalytic,
|
| 109 |
+
true
|
| 110 |
+
> {
|
| 111 |
+
|
| 112 |
+
// Define the core components from GEMM
|
| 113 |
+
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 114 |
+
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
| 115 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 116 |
+
2, MathOperatorTag>;
|
| 117 |
+
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 118 |
+
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
| 119 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 120 |
+
2, MathOperatorTag>;
|
| 121 |
+
|
| 122 |
+
// Define iterators over tiles from the A operand
|
| 123 |
+
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
| 124 |
+
using IteratorA0 =
|
| 125 |
+
cutlass::conv::threadblock::TileIterator<
|
| 126 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
| 127 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
| 128 |
+
ElementA, LayoutA,
|
| 129 |
+
ThreadMapA0
|
| 130 |
+
>
|
| 131 |
+
>;
|
| 132 |
+
|
| 133 |
+
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
| 134 |
+
|
| 135 |
+
// Define iterators over tiles from the B operand
|
| 136 |
+
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
| 137 |
+
using IteratorB0 =
|
| 138 |
+
cutlass::conv::threadblock::TileIterator<
|
| 139 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 140 |
+
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
| 141 |
+
ElementB, LayoutB,
|
| 142 |
+
ThreadMapB0
|
| 143 |
+
>
|
| 144 |
+
>;
|
| 145 |
+
|
| 146 |
+
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
| 147 |
+
|
| 148 |
+
/// Define iterators over tiles from scale/bias vectors
|
| 149 |
+
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
| 150 |
+
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
| 151 |
+
static int const kElementsPerAccess = 2;
|
| 152 |
+
using IteratorAccumulatorScaleBias =
|
| 153 |
+
cutlass::transform::threadblock::VectorIterator<
|
| 154 |
+
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
| 155 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
| 156 |
+
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
| 157 |
+
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
| 158 |
+
>;
|
| 159 |
+
|
| 160 |
+
// Define iterators over tiles from the B operand
|
| 161 |
+
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
| 162 |
+
using IteratorB1 =
|
| 163 |
+
cutlass::conv::threadblock::TileIterator<
|
| 164 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 165 |
+
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
| 166 |
+
ElementB, LayoutB,
|
| 167 |
+
ThreadMapB1
|
| 168 |
+
>
|
| 169 |
+
>;
|
| 170 |
+
|
| 171 |
+
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
| 172 |
+
|
| 173 |
+
// Warp-level GEMM components
|
| 174 |
+
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
| 175 |
+
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
| 176 |
+
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
| 177 |
+
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
| 178 |
+
|
| 179 |
+
// Use fragment iterator for the accumulator
|
| 180 |
+
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
| 181 |
+
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
| 182 |
+
WarpShape0, InstructionShape,
|
| 183 |
+
ElementAccumulator,
|
| 184 |
+
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
| 185 |
+
SmemAccumulatorLayout
|
| 186 |
+
>;
|
| 187 |
+
|
| 188 |
+
// Store Accumulator tiles to Shared Memory
|
| 189 |
+
using SmemIteratorD0 =
|
| 190 |
+
cutlass::epilogue::warp::TileIteratorTensorOp<
|
| 191 |
+
WarpShape0,
|
| 192 |
+
InstructionShape,
|
| 193 |
+
ElementC,
|
| 194 |
+
SmemAccumulatorLayout
|
| 195 |
+
>;
|
| 196 |
+
|
| 197 |
+
static int const kThreadCount = 32;
|
| 198 |
+
// load warp tile from Shared Memory accumulator
|
| 199 |
+
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
| 200 |
+
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
| 201 |
+
ElementA, SmemAccumulatorLayout,
|
| 202 |
+
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
| 203 |
+
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
| 204 |
+
|
| 205 |
+
// Define the Mma
|
| 206 |
+
using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
|
| 207 |
+
ThreadblockShape0,
|
| 208 |
+
IteratorA0,
|
| 209 |
+
SmemIteratorA0,
|
| 210 |
+
IteratorB0,
|
| 211 |
+
SmemIteratorB0,
|
| 212 |
+
IteratorAccumulatorScaleBias,
|
| 213 |
+
FragmentIteratorAccumulator,
|
| 214 |
+
SmemIteratorD0,
|
| 215 |
+
ThreadblockShape1,
|
| 216 |
+
WarpIteratorA1,
|
| 217 |
+
IteratorB1,
|
| 218 |
+
SmemIteratorB1,
|
| 219 |
+
ElementC,
|
| 220 |
+
LayoutC,
|
| 221 |
+
EpilogueOutputOp0,
|
| 222 |
+
MmaPolicy0,
|
| 223 |
+
MmaPolicy1
|
| 224 |
+
>;
|
| 225 |
+
|
| 226 |
+
// Define the epilogue
|
| 227 |
+
using Epilogue = typename detail::DefaultConvEpilogue<
|
| 228 |
+
ArchTag,
|
| 229 |
+
ThreadblockShape1,
|
| 230 |
+
WarpMmaTensorOp1,
|
| 231 |
+
1,
|
| 232 |
+
EpilogueOutputOp1
|
| 233 |
+
>::Epilogue;
|
| 234 |
+
|
| 235 |
+
// Define the kernel
|
| 236 |
+
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
| 237 |
+
B2bMma,
|
| 238 |
+
Epilogue,
|
| 239 |
+
ThreadblockSwizzle,
|
| 240 |
+
conv::Operator::kFprop
|
| 241 |
+
>;
|
| 242 |
+
};
|
| 243 |
+
|
| 244 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 245 |
+
|
| 246 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage
|
| 247 |
+
/// pipeline with interleaved layout.
|
| 248 |
+
/// Accumulator will be staged in shared memory.
|
| 249 |
+
template <
|
| 250 |
+
typename ElementA,
|
| 251 |
+
typename ElementB,
|
| 252 |
+
typename ElementC,
|
| 253 |
+
typename LayoutC,
|
| 254 |
+
typename ElementAccumulator,
|
| 255 |
+
typename ArchTag,
|
| 256 |
+
typename ThreadblockShape0,
|
| 257 |
+
typename ThreadblockShape1,
|
| 258 |
+
typename WarpShape0,
|
| 259 |
+
typename WarpShape1,
|
| 260 |
+
typename InstructionShape,
|
| 261 |
+
typename EpilogueOutputOp0,
|
| 262 |
+
typename EpilogueOutputOp1,
|
| 263 |
+
typename ThreadblockSwizzle,
|
| 264 |
+
typename MathOperatorTag,
|
| 265 |
+
int InterleavedK
|
| 266 |
+
>
|
| 267 |
+
struct DefaultB2bConv2dFprop <
|
| 268 |
+
ElementA,
|
| 269 |
+
layout::TensorNCxHWx<InterleavedK>,
|
| 270 |
+
ElementB,
|
| 271 |
+
layout::TensorCxRSKx<InterleavedK>,
|
| 272 |
+
ElementC,
|
| 273 |
+
LayoutC,
|
| 274 |
+
ElementAccumulator,
|
| 275 |
+
arch::OpClassTensorOp,
|
| 276 |
+
ArchTag,
|
| 277 |
+
ThreadblockShape0,
|
| 278 |
+
ThreadblockShape1,
|
| 279 |
+
WarpShape0,
|
| 280 |
+
WarpShape1,
|
| 281 |
+
InstructionShape,
|
| 282 |
+
EpilogueOutputOp0,
|
| 283 |
+
EpilogueOutputOp1,
|
| 284 |
+
ThreadblockSwizzle,
|
| 285 |
+
2,
|
| 286 |
+
MathOperatorTag,
|
| 287 |
+
IteratorAlgorithm::kAnalytic,
|
| 288 |
+
true
|
| 289 |
+
> {
|
| 290 |
+
|
| 291 |
+
// Define the core components from GEMM
|
| 292 |
+
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 293 |
+
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 294 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 295 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 296 |
+
2, MathOperatorTag, true>;
|
| 297 |
+
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 298 |
+
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 299 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 300 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 301 |
+
2, MathOperatorTag, true>;
|
| 302 |
+
|
| 303 |
+
// Define iterators over tiles from the A operand
|
| 304 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 305 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 306 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 307 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 308 |
+
// layout.
|
| 309 |
+
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
| 310 |
+
using IteratorA0 =
|
| 311 |
+
cutlass::conv::threadblock::TileIterator<
|
| 312 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
| 313 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
| 314 |
+
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
| 315 |
+
ThreadMapA0
|
| 316 |
+
>
|
| 317 |
+
>;
|
| 318 |
+
|
| 319 |
+
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
| 320 |
+
|
| 321 |
+
// Define iterators over tiles from the B operand
|
| 322 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 323 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 324 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 325 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 326 |
+
// layout.
|
| 327 |
+
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
| 328 |
+
using IteratorB0 =
|
| 329 |
+
cutlass::conv::threadblock::TileIterator<
|
| 330 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 331 |
+
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
| 332 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 333 |
+
ThreadMapB0
|
| 334 |
+
>
|
| 335 |
+
>;
|
| 336 |
+
|
| 337 |
+
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
| 338 |
+
|
| 339 |
+
/// Define iterators over tiles from scale/bias vectors
|
| 340 |
+
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
| 341 |
+
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
| 342 |
+
static int const kElementsPerAccess = 4; //For interleaved layout
|
| 343 |
+
using IteratorAccumulatorScaleBias =
|
| 344 |
+
cutlass::transform::threadblock::VectorIterator<
|
| 345 |
+
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
| 346 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
| 347 |
+
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
| 348 |
+
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
| 349 |
+
>;
|
| 350 |
+
|
| 351 |
+
// Define iterators over tiles from the B operand
|
| 352 |
+
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
| 353 |
+
using IteratorB1 =
|
| 354 |
+
cutlass::conv::threadblock::TileIterator<
|
| 355 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 356 |
+
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
| 357 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 358 |
+
ThreadMapB1
|
| 359 |
+
>
|
| 360 |
+
>;
|
| 361 |
+
|
| 362 |
+
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
| 363 |
+
|
| 364 |
+
// Warp-level GEMM components
|
| 365 |
+
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
| 366 |
+
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
| 367 |
+
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
| 368 |
+
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
| 369 |
+
|
| 370 |
+
// Use fragment iterator for the accumulator
|
| 371 |
+
using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
|
| 372 |
+
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
| 373 |
+
WarpShape0, InstructionShape,
|
| 374 |
+
ElementAccumulator,
|
| 375 |
+
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
| 376 |
+
SmemAccumulatorLayout
|
| 377 |
+
>;
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
// Store Accumulator tiles to Shared Memory
|
| 381 |
+
using SmemIteratorD0 =
|
| 382 |
+
cutlass::epilogue::warp::TileIteratorTensorOp<
|
| 383 |
+
WarpShape0,
|
| 384 |
+
InstructionShape,
|
| 385 |
+
ElementC,
|
| 386 |
+
SmemAccumulatorLayout
|
| 387 |
+
>;
|
| 388 |
+
|
| 389 |
+
static int const kThreadCount = 32;
|
| 390 |
+
// load warp tile from Shared Memory accumulator
|
| 391 |
+
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
|
| 392 |
+
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
| 393 |
+
ElementA, SmemAccumulatorLayout,
|
| 394 |
+
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
| 395 |
+
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
| 396 |
+
|
| 397 |
+
// Define the Mma
|
| 398 |
+
using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
|
| 399 |
+
ThreadblockShape0,
|
| 400 |
+
IteratorA0,
|
| 401 |
+
SmemIteratorA0,
|
| 402 |
+
IteratorB0,
|
| 403 |
+
SmemIteratorB0,
|
| 404 |
+
IteratorAccumulatorScaleBias,
|
| 405 |
+
FragmentIteratorAccumulator,
|
| 406 |
+
SmemIteratorD0,
|
| 407 |
+
ThreadblockShape1,
|
| 408 |
+
WarpIteratorA1,
|
| 409 |
+
IteratorB1,
|
| 410 |
+
SmemIteratorB1,
|
| 411 |
+
ElementC,
|
| 412 |
+
LayoutC,
|
| 413 |
+
EpilogueOutputOp0,
|
| 414 |
+
MmaPolicy0,
|
| 415 |
+
MmaPolicy1
|
| 416 |
+
>;
|
| 417 |
+
|
| 418 |
+
// Define the epilogue
|
| 419 |
+
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
| 420 |
+
ThreadblockShape1,
|
| 421 |
+
WarpMmaTensorOp1,
|
| 422 |
+
1,
|
| 423 |
+
EpilogueOutputOp1,
|
| 424 |
+
EpilogueOutputOp1::kCount,
|
| 425 |
+
InterleavedK
|
| 426 |
+
>::Epilogue;
|
| 427 |
+
|
| 428 |
+
// Define the kernel
|
| 429 |
+
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
| 430 |
+
B2bMma,
|
| 431 |
+
Epilogue,
|
| 432 |
+
ThreadblockSwizzle,
|
| 433 |
+
conv::Operator::kFprop
|
| 434 |
+
>;
|
| 435 |
+
};
|
| 436 |
+
|
| 437 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 438 |
+
|
| 439 |
+
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm
|
| 440 |
+
/// and 2 stage pipeline.
|
| 441 |
+
/// Accumulator will be staged in shared memory.
|
| 442 |
+
template <
|
| 443 |
+
typename ElementA,
|
| 444 |
+
typename LayoutA,
|
| 445 |
+
typename ElementB,
|
| 446 |
+
typename LayoutB,
|
| 447 |
+
typename ElementC,
|
| 448 |
+
typename LayoutC,
|
| 449 |
+
typename ElementAccumulator,
|
| 450 |
+
typename ArchTag,
|
| 451 |
+
typename ThreadblockShape0,
|
| 452 |
+
typename ThreadblockShape1,
|
| 453 |
+
typename WarpShape0,
|
| 454 |
+
typename WarpShape1,
|
| 455 |
+
typename InstructionShape,
|
| 456 |
+
typename EpilogueOutputOp0,
|
| 457 |
+
typename EpilogueOutputOp1,
|
| 458 |
+
typename ThreadblockSwizzle,
|
| 459 |
+
typename MathOperatorTag
|
| 460 |
+
>
|
| 461 |
+
struct DefaultB2bConv2dFprop <
|
| 462 |
+
ElementA,
|
| 463 |
+
LayoutA,
|
| 464 |
+
ElementB,
|
| 465 |
+
LayoutB,
|
| 466 |
+
ElementC,
|
| 467 |
+
LayoutC,
|
| 468 |
+
ElementAccumulator,
|
| 469 |
+
arch::OpClassTensorOp,
|
| 470 |
+
ArchTag,
|
| 471 |
+
ThreadblockShape0,
|
| 472 |
+
ThreadblockShape1,
|
| 473 |
+
WarpShape0,
|
| 474 |
+
WarpShape1,
|
| 475 |
+
InstructionShape,
|
| 476 |
+
EpilogueOutputOp0,
|
| 477 |
+
EpilogueOutputOp1,
|
| 478 |
+
ThreadblockSwizzle,
|
| 479 |
+
2,
|
| 480 |
+
MathOperatorTag,
|
| 481 |
+
IteratorAlgorithm::kOptimized,
|
| 482 |
+
true
|
| 483 |
+
> {
|
| 484 |
+
|
| 485 |
+
// Define the core components from GEMM
|
| 486 |
+
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 487 |
+
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
| 488 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 489 |
+
2, MathOperatorTag>;
|
| 490 |
+
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 491 |
+
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
| 492 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 493 |
+
2, MathOperatorTag>;
|
| 494 |
+
|
| 495 |
+
// Define iterators over tiles from the A operand
|
| 496 |
+
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
| 497 |
+
using IteratorA0 =
|
| 498 |
+
cutlass::conv::threadblock::TileIterator<
|
| 499 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
| 500 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
| 501 |
+
ElementA, LayoutA,
|
| 502 |
+
ThreadMapA0
|
| 503 |
+
>
|
| 504 |
+
>;
|
| 505 |
+
|
| 506 |
+
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
| 507 |
+
|
| 508 |
+
// Define iterators over tiles from the B operand
|
| 509 |
+
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
| 510 |
+
using IteratorB0 =
|
| 511 |
+
cutlass::conv::threadblock::TileIterator<
|
| 512 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 513 |
+
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
| 514 |
+
ElementB, LayoutB,
|
| 515 |
+
ThreadMapB0
|
| 516 |
+
>
|
| 517 |
+
>;
|
| 518 |
+
|
| 519 |
+
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
| 520 |
+
|
| 521 |
+
/// Define iterators over tiles from scale/bias vectors
|
| 522 |
+
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
| 523 |
+
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
| 524 |
+
static int const kElementsPerAccess = 2;
|
| 525 |
+
using IteratorAccumulatorScaleBias =
|
| 526 |
+
cutlass::transform::threadblock::VectorIterator<
|
| 527 |
+
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
| 528 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
| 529 |
+
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
| 530 |
+
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
| 531 |
+
>;
|
| 532 |
+
|
| 533 |
+
// Define iterators over tiles from the B operand
|
| 534 |
+
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
| 535 |
+
using IteratorB1 =
|
| 536 |
+
cutlass::conv::threadblock::TileIterator<
|
| 537 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 538 |
+
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
| 539 |
+
ElementB, LayoutB,
|
| 540 |
+
ThreadMapB1
|
| 541 |
+
>
|
| 542 |
+
>;
|
| 543 |
+
|
| 544 |
+
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
| 545 |
+
|
| 546 |
+
// Warp-level GEMM components
|
| 547 |
+
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
| 548 |
+
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
| 549 |
+
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
| 550 |
+
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
| 551 |
+
|
| 552 |
+
// Use fragment iterator for the accumulator
|
| 553 |
+
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
| 554 |
+
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
| 555 |
+
WarpShape0, InstructionShape,
|
| 556 |
+
ElementAccumulator,
|
| 557 |
+
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
| 558 |
+
SmemAccumulatorLayout
|
| 559 |
+
>;
|
| 560 |
+
|
| 561 |
+
// Store Accumulator tiles to Shared Memory
|
| 562 |
+
using SmemIteratorD0 =
|
| 563 |
+
cutlass::epilogue::warp::TileIteratorTensorOp<
|
| 564 |
+
WarpShape0,
|
| 565 |
+
InstructionShape,
|
| 566 |
+
ElementC,
|
| 567 |
+
SmemAccumulatorLayout
|
| 568 |
+
>;
|
| 569 |
+
|
| 570 |
+
static int const kThreadCount = 32;
|
| 571 |
+
// load warp tile from Shared Memory accumulator
|
| 572 |
+
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
| 573 |
+
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
| 574 |
+
ElementA, SmemAccumulatorLayout,
|
| 575 |
+
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
| 576 |
+
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
| 577 |
+
|
| 578 |
+
// Define the Mma
|
| 579 |
+
using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
|
| 580 |
+
ThreadblockShape0,
|
| 581 |
+
IteratorA0,
|
| 582 |
+
SmemIteratorA0,
|
| 583 |
+
IteratorB0,
|
| 584 |
+
SmemIteratorB0,
|
| 585 |
+
IteratorAccumulatorScaleBias,
|
| 586 |
+
FragmentIteratorAccumulator,
|
| 587 |
+
SmemIteratorD0,
|
| 588 |
+
ThreadblockShape1,
|
| 589 |
+
WarpIteratorA1,
|
| 590 |
+
IteratorB1,
|
| 591 |
+
SmemIteratorB1,
|
| 592 |
+
ElementC,
|
| 593 |
+
LayoutC,
|
| 594 |
+
EpilogueOutputOp0,
|
| 595 |
+
MmaPolicy0,
|
| 596 |
+
MmaPolicy1
|
| 597 |
+
>;
|
| 598 |
+
|
| 599 |
+
// Define the epilogue
|
| 600 |
+
using Epilogue = typename detail::DefaultConvEpilogue<
|
| 601 |
+
ArchTag,
|
| 602 |
+
ThreadblockShape1,
|
| 603 |
+
WarpMmaTensorOp1,
|
| 604 |
+
1,
|
| 605 |
+
EpilogueOutputOp1
|
| 606 |
+
>::Epilogue;
|
| 607 |
+
|
| 608 |
+
// Define the kernel
|
| 609 |
+
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
| 610 |
+
B2bMma,
|
| 611 |
+
Epilogue,
|
| 612 |
+
ThreadblockSwizzle,
|
| 613 |
+
conv::Operator::kFprop
|
| 614 |
+
>;
|
| 615 |
+
};
|
| 616 |
+
|
| 617 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 618 |
+
|
| 619 |
+
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage
|
| 620 |
+
/// pipeline with interleaved layout.
|
| 621 |
+
/// Accumulator will be staged in shared memory.
|
| 622 |
+
template <
|
| 623 |
+
typename ElementA,
|
| 624 |
+
typename ElementB,
|
| 625 |
+
typename ElementC,
|
| 626 |
+
typename LayoutC,
|
| 627 |
+
typename ElementAccumulator,
|
| 628 |
+
typename ArchTag,
|
| 629 |
+
typename ThreadblockShape0,
|
| 630 |
+
typename ThreadblockShape1,
|
| 631 |
+
typename WarpShape0,
|
| 632 |
+
typename WarpShape1,
|
| 633 |
+
typename InstructionShape,
|
| 634 |
+
typename EpilogueOutputOp0,
|
| 635 |
+
typename EpilogueOutputOp1,
|
| 636 |
+
typename ThreadblockSwizzle,
|
| 637 |
+
typename MathOperatorTag,
|
| 638 |
+
int InterleavedK
|
| 639 |
+
>
|
| 640 |
+
struct DefaultB2bConv2dFprop <
|
| 641 |
+
ElementA,
|
| 642 |
+
layout::TensorNCxHWx<InterleavedK>,
|
| 643 |
+
ElementB,
|
| 644 |
+
layout::TensorCxRSKx<InterleavedK>,
|
| 645 |
+
ElementC,
|
| 646 |
+
LayoutC,
|
| 647 |
+
ElementAccumulator,
|
| 648 |
+
arch::OpClassTensorOp,
|
| 649 |
+
ArchTag,
|
| 650 |
+
ThreadblockShape0,
|
| 651 |
+
ThreadblockShape1,
|
| 652 |
+
WarpShape0,
|
| 653 |
+
WarpShape1,
|
| 654 |
+
InstructionShape,
|
| 655 |
+
EpilogueOutputOp0,
|
| 656 |
+
EpilogueOutputOp1,
|
| 657 |
+
ThreadblockSwizzle,
|
| 658 |
+
2,
|
| 659 |
+
MathOperatorTag,
|
| 660 |
+
IteratorAlgorithm::kOptimized,
|
| 661 |
+
true
|
| 662 |
+
> {
|
| 663 |
+
|
| 664 |
+
// Define the core components from GEMM
|
| 665 |
+
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 666 |
+
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 667 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 668 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 669 |
+
2, MathOperatorTag, true>;
|
| 670 |
+
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 671 |
+
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 672 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 673 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 674 |
+
2, MathOperatorTag, true>;
|
| 675 |
+
|
| 676 |
+
// Define iterators over tiles from the A operand
|
| 677 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 678 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 679 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 680 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 681 |
+
// layout.
|
| 682 |
+
|
| 683 |
+
// Define iterators over tiles from the A operand
|
| 684 |
+
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
| 685 |
+
using IteratorA0 =
|
| 686 |
+
cutlass::conv::threadblock::TileIterator<
|
| 687 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
| 688 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
| 689 |
+
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
| 690 |
+
ThreadMapA0
|
| 691 |
+
>
|
| 692 |
+
>;
|
| 693 |
+
|
| 694 |
+
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
| 695 |
+
|
| 696 |
+
// Define iterators over tiles from the B operand
|
| 697 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 698 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 699 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 700 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 701 |
+
// layout.
|
| 702 |
+
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
| 703 |
+
using IteratorB0 =
|
| 704 |
+
cutlass::conv::threadblock::TileIterator<
|
| 705 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 706 |
+
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
| 707 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 708 |
+
ThreadMapB0
|
| 709 |
+
>
|
| 710 |
+
>;
|
| 711 |
+
|
| 712 |
+
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
| 713 |
+
|
| 714 |
+
/// Define iterators over tiles from scale/bias vectors
|
| 715 |
+
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
| 716 |
+
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
| 717 |
+
static int const kElementsPerAccess = 4; //For interleaved layout
|
| 718 |
+
using IteratorAccumulatorScaleBias =
|
| 719 |
+
cutlass::transform::threadblock::VectorIterator<
|
| 720 |
+
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
| 721 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
| 722 |
+
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
| 723 |
+
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
| 724 |
+
>;
|
| 725 |
+
|
| 726 |
+
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
| 727 |
+
using IteratorB1 =
|
| 728 |
+
cutlass::conv::threadblock::TileIterator<
|
| 729 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 730 |
+
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
| 731 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 732 |
+
ThreadMapB1
|
| 733 |
+
>
|
| 734 |
+
>;
|
| 735 |
+
|
| 736 |
+
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
| 737 |
+
|
| 738 |
+
// Warp-level GEMM components
|
| 739 |
+
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
| 740 |
+
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
| 741 |
+
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
| 742 |
+
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
| 743 |
+
|
| 744 |
+
// Use fragment iterator for the accumulator
|
| 745 |
+
using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
|
| 746 |
+
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
| 747 |
+
WarpShape0, InstructionShape,
|
| 748 |
+
ElementAccumulator,
|
| 749 |
+
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
| 750 |
+
SmemAccumulatorLayout
|
| 751 |
+
>;
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
// Store Accumulator tiles to Shared Memory
|
| 755 |
+
using SmemIteratorD0 =
|
| 756 |
+
cutlass::epilogue::warp::TileIteratorTensorOp<
|
| 757 |
+
WarpShape0,
|
| 758 |
+
InstructionShape,
|
| 759 |
+
ElementC,
|
| 760 |
+
SmemAccumulatorLayout
|
| 761 |
+
>;
|
| 762 |
+
|
| 763 |
+
static int const kThreadCount = 32;
|
| 764 |
+
// load warp tile from Shared Memory accumulator
|
| 765 |
+
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
|
| 766 |
+
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
| 767 |
+
ElementA, SmemAccumulatorLayout,
|
| 768 |
+
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
| 769 |
+
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
| 770 |
+
|
| 771 |
+
// Define the Mma
|
| 772 |
+
using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
|
| 773 |
+
ThreadblockShape0,
|
| 774 |
+
IteratorA0,
|
| 775 |
+
SmemIteratorA0,
|
| 776 |
+
IteratorB0,
|
| 777 |
+
SmemIteratorB0,
|
| 778 |
+
IteratorAccumulatorScaleBias,
|
| 779 |
+
FragmentIteratorAccumulator,
|
| 780 |
+
SmemIteratorD0,
|
| 781 |
+
ThreadblockShape1,
|
| 782 |
+
WarpIteratorA1,
|
| 783 |
+
IteratorB1,
|
| 784 |
+
SmemIteratorB1,
|
| 785 |
+
ElementC,
|
| 786 |
+
LayoutC,
|
| 787 |
+
EpilogueOutputOp0,
|
| 788 |
+
MmaPolicy0,
|
| 789 |
+
MmaPolicy1
|
| 790 |
+
>;
|
| 791 |
+
|
| 792 |
+
// Define the epilogue
|
| 793 |
+
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
| 794 |
+
ThreadblockShape1,
|
| 795 |
+
WarpMmaTensorOp1,
|
| 796 |
+
1,
|
| 797 |
+
EpilogueOutputOp1,
|
| 798 |
+
EpilogueOutputOp1::kCount,
|
| 799 |
+
InterleavedK
|
| 800 |
+
>::Epilogue;
|
| 801 |
+
|
| 802 |
+
// Define the kernel
|
| 803 |
+
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
| 804 |
+
B2bMma,
|
| 805 |
+
Epilogue,
|
| 806 |
+
ThreadblockSwizzle,
|
| 807 |
+
conv::Operator::kFprop
|
| 808 |
+
>;
|
| 809 |
+
};
|
| 810 |
+
|
| 811 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 812 |
+
|
| 813 |
+
} // namespace kernel
|
| 814 |
+
} // namespace conv
|
| 815 |
+
} // namespace cutlass
|
| 816 |
+
|
| 817 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h
ADDED
|
@@ -0,0 +1,804 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief
|
| 34 |
+
Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
|
| 35 |
+
matrix multiply-add with the appropriate threadblock-scoped epilogue.
|
| 36 |
+
*/
|
| 37 |
+
|
| 38 |
+
#pragma once
|
| 39 |
+
|
| 40 |
+
#include "cutlass/cutlass.h"
|
| 41 |
+
#include "cutlass/conv/kernel/default_conv2d.h"
|
| 42 |
+
|
| 43 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
|
| 44 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
|
| 45 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
|
| 46 |
+
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
|
| 47 |
+
|
| 48 |
+
#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
|
| 49 |
+
#include "cutlass/transform/threadblock/vector_iterator.h"
|
| 50 |
+
#include "cutlass/transform/warp/vector_fragment_iterator.h"
|
| 51 |
+
|
| 52 |
+
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
| 53 |
+
|
| 54 |
+
#include "kernel/default_b2b_conv2d_fprop.h"
|
| 55 |
+
#include "kernel/b2b_implicit_gemm_convolution.h"
|
| 56 |
+
#include "threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h"
|
| 57 |
+
|
| 58 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
|
| 60 |
+
namespace cutlass {
|
| 61 |
+
namespace conv {
|
| 62 |
+
namespace kernel {
|
| 63 |
+
|
| 64 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 65 |
+
|
| 66 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
|
| 67 |
+
/// pipeline.
|
| 68 |
+
/// Accumulator will be staged in shared memory.
|
| 69 |
+
template <
|
| 70 |
+
typename ElementA,
|
| 71 |
+
typename LayoutA,
|
| 72 |
+
typename ElementB,
|
| 73 |
+
typename LayoutB,
|
| 74 |
+
typename ElementC,
|
| 75 |
+
typename LayoutC,
|
| 76 |
+
typename ElementAccumulator,
|
| 77 |
+
typename ArchTag,
|
| 78 |
+
typename ThreadblockShape0,
|
| 79 |
+
typename ThreadblockShape1,
|
| 80 |
+
typename WarpShape0,
|
| 81 |
+
typename WarpShape1,
|
| 82 |
+
typename InstructionShape,
|
| 83 |
+
typename EpilogueOutputOp0,
|
| 84 |
+
typename EpilogueOutputOp1,
|
| 85 |
+
typename ThreadblockSwizzle,
|
| 86 |
+
int Stages,
|
| 87 |
+
typename MathOperatorTag
|
| 88 |
+
>
|
| 89 |
+
struct DefaultB2bConv2dFprop <
|
| 90 |
+
ElementA,
|
| 91 |
+
LayoutA,
|
| 92 |
+
ElementB,
|
| 93 |
+
LayoutB,
|
| 94 |
+
ElementC,
|
| 95 |
+
LayoutC,
|
| 96 |
+
ElementAccumulator,
|
| 97 |
+
arch::OpClassTensorOp,
|
| 98 |
+
ArchTag,
|
| 99 |
+
ThreadblockShape0,
|
| 100 |
+
ThreadblockShape1,
|
| 101 |
+
WarpShape0,
|
| 102 |
+
WarpShape1,
|
| 103 |
+
InstructionShape,
|
| 104 |
+
EpilogueOutputOp0,
|
| 105 |
+
EpilogueOutputOp1,
|
| 106 |
+
ThreadblockSwizzle,
|
| 107 |
+
Stages,
|
| 108 |
+
MathOperatorTag,
|
| 109 |
+
IteratorAlgorithm::kAnalytic,
|
| 110 |
+
true
|
| 111 |
+
> {
|
| 112 |
+
|
| 113 |
+
// Define the core components from GEMM
|
| 114 |
+
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 115 |
+
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
| 116 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 117 |
+
Stages, MathOperatorTag>;
|
| 118 |
+
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 119 |
+
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
| 120 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 121 |
+
Stages, MathOperatorTag>;
|
| 122 |
+
|
| 123 |
+
// Define iterators over tiles from the A operand
|
| 124 |
+
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
| 125 |
+
using IteratorA0 =
|
| 126 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
| 127 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
| 128 |
+
ElementA, LayoutA,
|
| 129 |
+
ThreadMapA0
|
| 130 |
+
>;
|
| 131 |
+
|
| 132 |
+
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
| 133 |
+
|
| 134 |
+
// Define iterators over tiles from the B operand
|
| 135 |
+
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
| 136 |
+
using IteratorB0 =
|
| 137 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 138 |
+
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
| 139 |
+
ElementB, LayoutB,
|
| 140 |
+
ThreadMapB0
|
| 141 |
+
>;
|
| 142 |
+
|
| 143 |
+
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
| 144 |
+
|
| 145 |
+
/// Define iterators over tiles from scale/bias vectors
|
| 146 |
+
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
| 147 |
+
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
| 148 |
+
static int const kElementsPerAccess = 2;
|
| 149 |
+
using IteratorAccumulatorScaleBias =
|
| 150 |
+
cutlass::transform::threadblock::VectorIterator<
|
| 151 |
+
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
| 152 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
| 153 |
+
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
| 154 |
+
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
| 155 |
+
>;
|
| 156 |
+
|
| 157 |
+
// Define iterators over tiles from the B operand
|
| 158 |
+
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
| 159 |
+
using IteratorB1 =
|
| 160 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 161 |
+
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
| 162 |
+
ElementB, LayoutB,
|
| 163 |
+
ThreadMapB1
|
| 164 |
+
>;
|
| 165 |
+
|
| 166 |
+
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
| 167 |
+
|
| 168 |
+
// Warp-level GEMM components
|
| 169 |
+
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
| 170 |
+
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
| 171 |
+
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
| 172 |
+
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
| 173 |
+
|
| 174 |
+
// Use fragment iterator for the accumulator
|
| 175 |
+
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
| 176 |
+
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
| 177 |
+
WarpShape0, InstructionShape,
|
| 178 |
+
ElementAccumulator,
|
| 179 |
+
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
| 180 |
+
SmemAccumulatorLayout
|
| 181 |
+
>;
|
| 182 |
+
|
| 183 |
+
// Store Accumulator tiles to Shared Memory
|
| 184 |
+
using SmemIteratorD0 =
|
| 185 |
+
cutlass::epilogue::warp::TileIteratorTensorOp<
|
| 186 |
+
WarpShape0,
|
| 187 |
+
InstructionShape,
|
| 188 |
+
ElementC,
|
| 189 |
+
SmemAccumulatorLayout
|
| 190 |
+
>;
|
| 191 |
+
|
| 192 |
+
static int const kThreadCount = 32;
|
| 193 |
+
// load warp tile from Shared Memory accumulator
|
| 194 |
+
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
| 195 |
+
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
| 196 |
+
ElementA, SmemAccumulatorLayout,
|
| 197 |
+
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
| 198 |
+
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
| 199 |
+
|
| 200 |
+
// Define the Mma
|
| 201 |
+
using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
|
| 202 |
+
ThreadblockShape0,
|
| 203 |
+
IteratorA0,
|
| 204 |
+
SmemIteratorA0,
|
| 205 |
+
arch::CacheOperation::Always,
|
| 206 |
+
IteratorB0,
|
| 207 |
+
SmemIteratorB0,
|
| 208 |
+
arch::CacheOperation::Global,
|
| 209 |
+
IteratorAccumulatorScaleBias,
|
| 210 |
+
FragmentIteratorAccumulator,
|
| 211 |
+
SmemIteratorD0,
|
| 212 |
+
ThreadblockShape1,
|
| 213 |
+
WarpIteratorA1,
|
| 214 |
+
IteratorB1,
|
| 215 |
+
SmemIteratorB1,
|
| 216 |
+
arch::CacheOperation::Global,
|
| 217 |
+
EpilogueOutputOp0,
|
| 218 |
+
MmaPolicy0,
|
| 219 |
+
MmaPolicy1,
|
| 220 |
+
Stages
|
| 221 |
+
>;
|
| 222 |
+
|
| 223 |
+
// Define the epilogue
|
| 224 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 225 |
+
ThreadblockShape1,
|
| 226 |
+
WarpMmaTensorOp1,
|
| 227 |
+
1,
|
| 228 |
+
EpilogueOutputOp1,
|
| 229 |
+
EpilogueOutputOp1::kCount
|
| 230 |
+
>::Epilogue;
|
| 231 |
+
|
| 232 |
+
// Define the kernel
|
| 233 |
+
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
| 234 |
+
B2bMma,
|
| 235 |
+
Epilogue,
|
| 236 |
+
ThreadblockSwizzle,
|
| 237 |
+
conv::Operator::kFprop
|
| 238 |
+
>;
|
| 239 |
+
};
|
| 240 |
+
|
| 241 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 242 |
+
|
| 243 |
+
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
|
| 244 |
+
/// pipeline with interleaved layout.
|
| 245 |
+
/// Accumulator will be staged in shared memory.
|
| 246 |
+
template <
|
| 247 |
+
typename ElementA,
|
| 248 |
+
typename ElementB,
|
| 249 |
+
typename ElementC,
|
| 250 |
+
typename LayoutC,
|
| 251 |
+
typename ElementAccumulator,
|
| 252 |
+
typename ArchTag,
|
| 253 |
+
typename ThreadblockShape0,
|
| 254 |
+
typename ThreadblockShape1,
|
| 255 |
+
typename WarpShape0,
|
| 256 |
+
typename WarpShape1,
|
| 257 |
+
typename InstructionShape,
|
| 258 |
+
typename EpilogueOutputOp0,
|
| 259 |
+
typename EpilogueOutputOp1,
|
| 260 |
+
typename ThreadblockSwizzle,
|
| 261 |
+
int Stages,
|
| 262 |
+
typename MathOperatorTag,
|
| 263 |
+
int InterleavedK
|
| 264 |
+
>
|
| 265 |
+
struct DefaultB2bConv2dFprop <
|
| 266 |
+
ElementA,
|
| 267 |
+
layout::TensorNCxHWx<InterleavedK>,
|
| 268 |
+
ElementB,
|
| 269 |
+
layout::TensorCxRSKx<InterleavedK>,
|
| 270 |
+
ElementC,
|
| 271 |
+
LayoutC,
|
| 272 |
+
ElementAccumulator,
|
| 273 |
+
arch::OpClassTensorOp,
|
| 274 |
+
ArchTag,
|
| 275 |
+
ThreadblockShape0,
|
| 276 |
+
ThreadblockShape1,
|
| 277 |
+
WarpShape0,
|
| 278 |
+
WarpShape1,
|
| 279 |
+
InstructionShape,
|
| 280 |
+
EpilogueOutputOp0,
|
| 281 |
+
EpilogueOutputOp1,
|
| 282 |
+
ThreadblockSwizzle,
|
| 283 |
+
Stages,
|
| 284 |
+
MathOperatorTag,
|
| 285 |
+
IteratorAlgorithm::kAnalytic,
|
| 286 |
+
true
|
| 287 |
+
> {
|
| 288 |
+
|
| 289 |
+
// Define the core components from GEMM
|
| 290 |
+
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 291 |
+
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 292 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 293 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 294 |
+
Stages, MathOperatorTag, true>;
|
| 295 |
+
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 296 |
+
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 297 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 298 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 299 |
+
Stages, MathOperatorTag, true>;
|
| 300 |
+
|
| 301 |
+
// Define iterators over tiles from the A operand
|
| 302 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 303 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 304 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 305 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 306 |
+
// layout.
|
| 307 |
+
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
| 308 |
+
using IteratorA0 =
|
| 309 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
| 310 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
| 311 |
+
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
| 312 |
+
ThreadMapA0
|
| 313 |
+
>;
|
| 314 |
+
|
| 315 |
+
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
| 316 |
+
|
| 317 |
+
// Define iterators over tiles from the B operand
|
| 318 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 319 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 320 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 321 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 322 |
+
// layout.
|
| 323 |
+
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
| 324 |
+
using IteratorB0 =
|
| 325 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 326 |
+
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
| 327 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 328 |
+
ThreadMapB0
|
| 329 |
+
>;
|
| 330 |
+
|
| 331 |
+
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
| 332 |
+
|
| 333 |
+
/// Define iterators over tiles from scale/bias vectors
|
| 334 |
+
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
| 335 |
+
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
| 336 |
+
static int const kElementsPerAccess = 4;
|
| 337 |
+
using IteratorAccumulatorScaleBias =
|
| 338 |
+
cutlass::transform::threadblock::VectorIterator<
|
| 339 |
+
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
| 340 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
| 341 |
+
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
| 342 |
+
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
| 343 |
+
>;
|
| 344 |
+
|
| 345 |
+
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
| 346 |
+
using IteratorB1 =
|
| 347 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
| 348 |
+
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
| 349 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 350 |
+
ThreadMapB1
|
| 351 |
+
>;
|
| 352 |
+
|
| 353 |
+
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
| 354 |
+
|
| 355 |
+
// Warp-level GEMM components
|
| 356 |
+
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
| 357 |
+
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
| 358 |
+
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
| 359 |
+
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
| 360 |
+
|
| 361 |
+
// Use fragment iterator for the accumulator
|
| 362 |
+
using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
|
| 363 |
+
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
| 364 |
+
WarpShape0, InstructionShape,
|
| 365 |
+
ElementAccumulator,
|
| 366 |
+
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
| 367 |
+
SmemAccumulatorLayout
|
| 368 |
+
>;
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
// Store Accumulator tiles to Shared Memory
|
| 372 |
+
using SmemIteratorD0 =
|
| 373 |
+
cutlass::epilogue::warp::TileIteratorTensorOp<
|
| 374 |
+
WarpShape0,
|
| 375 |
+
InstructionShape,
|
| 376 |
+
ElementC,
|
| 377 |
+
SmemAccumulatorLayout
|
| 378 |
+
>;
|
| 379 |
+
|
| 380 |
+
static int const kThreadCount = 32;
|
| 381 |
+
// load warp tile from Shared Memory accumulator
|
| 382 |
+
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
|
| 383 |
+
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
| 384 |
+
ElementA, SmemAccumulatorLayout,
|
| 385 |
+
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
| 386 |
+
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
| 387 |
+
|
| 388 |
+
// Define the Mma
|
| 389 |
+
using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
|
| 390 |
+
ThreadblockShape0,
|
| 391 |
+
IteratorA0,
|
| 392 |
+
SmemIteratorA0,
|
| 393 |
+
arch::CacheOperation::Always,
|
| 394 |
+
IteratorB0,
|
| 395 |
+
SmemIteratorB0,
|
| 396 |
+
arch::CacheOperation::Global,
|
| 397 |
+
IteratorAccumulatorScaleBias,
|
| 398 |
+
FragmentIteratorAccumulator,
|
| 399 |
+
SmemIteratorD0,
|
| 400 |
+
ThreadblockShape1,
|
| 401 |
+
WarpIteratorA1,
|
| 402 |
+
IteratorB1,
|
| 403 |
+
SmemIteratorB1,
|
| 404 |
+
arch::CacheOperation::Global,
|
| 405 |
+
EpilogueOutputOp0,
|
| 406 |
+
MmaPolicy0,
|
| 407 |
+
MmaPolicy1,
|
| 408 |
+
Stages
|
| 409 |
+
>;
|
| 410 |
+
|
| 411 |
+
// Define the epilogue
|
| 412 |
+
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
| 413 |
+
ThreadblockShape1,
|
| 414 |
+
WarpMmaTensorOp1,
|
| 415 |
+
1,
|
| 416 |
+
EpilogueOutputOp1,
|
| 417 |
+
EpilogueOutputOp1::kCount,
|
| 418 |
+
InterleavedK
|
| 419 |
+
>::Epilogue;
|
| 420 |
+
|
| 421 |
+
// Define the kernel
|
| 422 |
+
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
| 423 |
+
B2bMma,
|
| 424 |
+
Epilogue,
|
| 425 |
+
ThreadblockSwizzle,
|
| 426 |
+
conv::Operator::kFprop
|
| 427 |
+
>;
|
| 428 |
+
};
|
| 429 |
+
|
| 430 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 431 |
+
|
| 432 |
+
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
| 433 |
+
/// multistage pipeline.
|
| 434 |
+
/// Accumulator will be staged in shared memory.
|
| 435 |
+
template <
|
| 436 |
+
typename ElementA,
|
| 437 |
+
typename LayoutA,
|
| 438 |
+
typename ElementB,
|
| 439 |
+
typename LayoutB,
|
| 440 |
+
typename ElementC,
|
| 441 |
+
typename LayoutC,
|
| 442 |
+
typename ElementAccumulator,
|
| 443 |
+
typename ArchTag,
|
| 444 |
+
typename ThreadblockShape0,
|
| 445 |
+
typename ThreadblockShape1,
|
| 446 |
+
typename WarpShape0,
|
| 447 |
+
typename WarpShape1,
|
| 448 |
+
typename InstructionShape,
|
| 449 |
+
typename EpilogueOutputOp0,
|
| 450 |
+
typename EpilogueOutputOp1,
|
| 451 |
+
typename ThreadblockSwizzle,
|
| 452 |
+
int Stages,
|
| 453 |
+
typename MathOperatorTag
|
| 454 |
+
>
|
| 455 |
+
struct DefaultB2bConv2dFprop <
|
| 456 |
+
ElementA,
|
| 457 |
+
LayoutA,
|
| 458 |
+
ElementB,
|
| 459 |
+
LayoutB,
|
| 460 |
+
ElementC,
|
| 461 |
+
LayoutC,
|
| 462 |
+
ElementAccumulator,
|
| 463 |
+
arch::OpClassTensorOp,
|
| 464 |
+
ArchTag,
|
| 465 |
+
ThreadblockShape0,
|
| 466 |
+
ThreadblockShape1,
|
| 467 |
+
WarpShape0,
|
| 468 |
+
WarpShape1,
|
| 469 |
+
InstructionShape,
|
| 470 |
+
EpilogueOutputOp0,
|
| 471 |
+
EpilogueOutputOp1,
|
| 472 |
+
ThreadblockSwizzle,
|
| 473 |
+
Stages,
|
| 474 |
+
MathOperatorTag,
|
| 475 |
+
IteratorAlgorithm::kOptimized,
|
| 476 |
+
true
|
| 477 |
+
> {
|
| 478 |
+
|
| 479 |
+
// Define the core components from GEMM
|
| 480 |
+
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 481 |
+
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
| 482 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 483 |
+
Stages, MathOperatorTag>;
|
| 484 |
+
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 485 |
+
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
| 486 |
+
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
| 487 |
+
Stages, MathOperatorTag>;
|
| 488 |
+
|
| 489 |
+
// Define iterators over tiles from the A operand
|
| 490 |
+
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
| 491 |
+
using IteratorA0 =
|
| 492 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
| 493 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
| 494 |
+
ElementA, LayoutA,
|
| 495 |
+
ThreadMapA0
|
| 496 |
+
>;
|
| 497 |
+
|
| 498 |
+
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
| 499 |
+
|
| 500 |
+
// Define iterators over tiles from the B operand
|
| 501 |
+
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
| 502 |
+
using IteratorB0 =
|
| 503 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 504 |
+
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
| 505 |
+
ElementB, LayoutB,
|
| 506 |
+
ThreadMapB0
|
| 507 |
+
>;
|
| 508 |
+
|
| 509 |
+
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
| 510 |
+
|
| 511 |
+
/// Define iterators over tiles from scale/bias vectors
|
| 512 |
+
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
| 513 |
+
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
| 514 |
+
static int const kElementsPerAccess = 2;
|
| 515 |
+
using IteratorAccumulatorScaleBias =
|
| 516 |
+
cutlass::transform::threadblock::VectorIterator<
|
| 517 |
+
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
| 518 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
| 519 |
+
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
| 520 |
+
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
| 521 |
+
>;
|
| 522 |
+
|
| 523 |
+
// Define iterators over tiles from the B operand
|
| 524 |
+
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
| 525 |
+
using IteratorB1 =
|
| 526 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 527 |
+
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
| 528 |
+
ElementB, LayoutB,
|
| 529 |
+
ThreadMapB1
|
| 530 |
+
>;
|
| 531 |
+
|
| 532 |
+
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
| 533 |
+
|
| 534 |
+
// Warp-level GEMM components
|
| 535 |
+
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
| 536 |
+
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
| 537 |
+
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
| 538 |
+
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
| 539 |
+
|
| 540 |
+
// Use fragment iterator for the accumulator
|
| 541 |
+
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
| 542 |
+
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
| 543 |
+
WarpShape0, InstructionShape,
|
| 544 |
+
ElementAccumulator,
|
| 545 |
+
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
| 546 |
+
SmemAccumulatorLayout
|
| 547 |
+
>;
|
| 548 |
+
|
| 549 |
+
// Store Accumulator tiles to Shared Memory
|
| 550 |
+
using SmemIteratorD0 =
|
| 551 |
+
cutlass::epilogue::warp::TileIteratorTensorOp<
|
| 552 |
+
WarpShape0,
|
| 553 |
+
InstructionShape,
|
| 554 |
+
ElementC,
|
| 555 |
+
SmemAccumulatorLayout
|
| 556 |
+
>;
|
| 557 |
+
|
| 558 |
+
static int const kThreadCount = 32;
|
| 559 |
+
// load warp tile from Shared Memory accumulator
|
| 560 |
+
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
| 561 |
+
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
| 562 |
+
ElementA, SmemAccumulatorLayout,
|
| 563 |
+
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
| 564 |
+
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
| 565 |
+
|
| 566 |
+
// Define the Mma
|
| 567 |
+
using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
|
| 568 |
+
ThreadblockShape0,
|
| 569 |
+
IteratorA0,
|
| 570 |
+
SmemIteratorA0,
|
| 571 |
+
arch::CacheOperation::Always,
|
| 572 |
+
IteratorB0,
|
| 573 |
+
SmemIteratorB0,
|
| 574 |
+
arch::CacheOperation::Global,
|
| 575 |
+
IteratorAccumulatorScaleBias,
|
| 576 |
+
FragmentIteratorAccumulator,
|
| 577 |
+
SmemIteratorD0,
|
| 578 |
+
ThreadblockShape1,
|
| 579 |
+
WarpIteratorA1,
|
| 580 |
+
IteratorB1,
|
| 581 |
+
SmemIteratorB1,
|
| 582 |
+
arch::CacheOperation::Global,
|
| 583 |
+
EpilogueOutputOp0,
|
| 584 |
+
MmaPolicy0,
|
| 585 |
+
MmaPolicy1,
|
| 586 |
+
Stages
|
| 587 |
+
>;
|
| 588 |
+
|
| 589 |
+
// Define the epilogue
|
| 590 |
+
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 591 |
+
ThreadblockShape1,
|
| 592 |
+
WarpMmaTensorOp1,
|
| 593 |
+
1,
|
| 594 |
+
EpilogueOutputOp1,
|
| 595 |
+
EpilogueOutputOp1::kCount
|
| 596 |
+
>::Epilogue;
|
| 597 |
+
|
| 598 |
+
// Define the kernel
|
| 599 |
+
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
| 600 |
+
B2bMma,
|
| 601 |
+
Epilogue,
|
| 602 |
+
ThreadblockSwizzle,
|
| 603 |
+
conv::Operator::kFprop
|
| 604 |
+
>;
|
| 605 |
+
};
|
| 606 |
+
|
| 607 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 608 |
+
|
| 609 |
+
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
| 610 |
+
// multistage pipeline with interleaved layout.
|
| 611 |
+
/// Accumulator will be staged in shared memory.
|
| 612 |
+
template <
|
| 613 |
+
typename ElementA,
|
| 614 |
+
typename ElementB,
|
| 615 |
+
typename ElementC,
|
| 616 |
+
typename LayoutC,
|
| 617 |
+
typename ElementAccumulator,
|
| 618 |
+
typename ArchTag,
|
| 619 |
+
typename ThreadblockShape0,
|
| 620 |
+
typename ThreadblockShape1,
|
| 621 |
+
typename WarpShape0,
|
| 622 |
+
typename WarpShape1,
|
| 623 |
+
typename InstructionShape,
|
| 624 |
+
typename EpilogueOutputOp0,
|
| 625 |
+
typename EpilogueOutputOp1,
|
| 626 |
+
typename ThreadblockSwizzle,
|
| 627 |
+
int Stages,
|
| 628 |
+
typename MathOperatorTag,
|
| 629 |
+
int InterleavedK
|
| 630 |
+
>
|
| 631 |
+
struct DefaultB2bConv2dFprop <
|
| 632 |
+
ElementA,
|
| 633 |
+
layout::TensorNCxHWx<InterleavedK>,
|
| 634 |
+
ElementB,
|
| 635 |
+
layout::TensorCxRSKx<InterleavedK>,
|
| 636 |
+
ElementC,
|
| 637 |
+
LayoutC,
|
| 638 |
+
ElementAccumulator,
|
| 639 |
+
arch::OpClassTensorOp,
|
| 640 |
+
ArchTag,
|
| 641 |
+
ThreadblockShape0,
|
| 642 |
+
ThreadblockShape1,
|
| 643 |
+
WarpShape0,
|
| 644 |
+
WarpShape1,
|
| 645 |
+
InstructionShape,
|
| 646 |
+
EpilogueOutputOp0,
|
| 647 |
+
EpilogueOutputOp1,
|
| 648 |
+
ThreadblockSwizzle,
|
| 649 |
+
Stages,
|
| 650 |
+
MathOperatorTag,
|
| 651 |
+
IteratorAlgorithm::kOptimized,
|
| 652 |
+
true
|
| 653 |
+
> {
|
| 654 |
+
|
| 655 |
+
// Define the core components from GEMM
|
| 656 |
+
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 657 |
+
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 658 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 659 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 660 |
+
Stages, MathOperatorTag, true>;
|
| 661 |
+
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
| 662 |
+
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 663 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
| 664 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
| 665 |
+
Stages, MathOperatorTag, true>;
|
| 666 |
+
|
| 667 |
+
// Define iterators over tiles from the A operand
|
| 668 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 669 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 670 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 671 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 672 |
+
// layout.
|
| 673 |
+
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
| 674 |
+
using IteratorA0 =
|
| 675 |
+
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
| 676 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
| 677 |
+
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
| 678 |
+
ThreadMapA0
|
| 679 |
+
>;
|
| 680 |
+
|
| 681 |
+
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
| 682 |
+
|
| 683 |
+
// Define iterators over tiles from the B operand
|
| 684 |
+
// Note GEMM shared memory threadmap is used here because conv global memory
|
| 685 |
+
// layout needs to be mapped to fprop which is similar to the crosswise
|
| 686 |
+
// layout which is used by the interleaved GEMM shared memory threadmap.
|
| 687 |
+
// The Interleaved GEMM global memory layout is similar to the congruous
|
| 688 |
+
// layout.
|
| 689 |
+
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
| 690 |
+
using IteratorB0 =
|
| 691 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 692 |
+
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
| 693 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 694 |
+
ThreadMapB0
|
| 695 |
+
>;
|
| 696 |
+
|
| 697 |
+
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
| 698 |
+
|
| 699 |
+
/// Define iterators over tiles from scale/bias vectors
|
| 700 |
+
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
| 701 |
+
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
| 702 |
+
static int const kElementsPerAccess = 4;
|
| 703 |
+
using IteratorAccumulatorScaleBias =
|
| 704 |
+
cutlass::transform::threadblock::VectorIterator<
|
| 705 |
+
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
| 706 |
+
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
| 707 |
+
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
| 708 |
+
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
| 709 |
+
>;
|
| 710 |
+
|
| 711 |
+
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
| 712 |
+
using IteratorB1 =
|
| 713 |
+
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
| 714 |
+
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
| 715 |
+
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
| 716 |
+
ThreadMapB1
|
| 717 |
+
>;
|
| 718 |
+
|
| 719 |
+
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
// Warp-level GEMM components
|
| 723 |
+
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
| 724 |
+
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
| 725 |
+
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
| 726 |
+
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
| 727 |
+
|
| 728 |
+
// Use fragment iterator for the accumulator
|
| 729 |
+
using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
|
| 730 |
+
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
| 731 |
+
WarpShape0, InstructionShape,
|
| 732 |
+
ElementAccumulator,
|
| 733 |
+
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
| 734 |
+
SmemAccumulatorLayout
|
| 735 |
+
>;
|
| 736 |
+
|
| 737 |
+
|
| 738 |
+
// Store Accumulator tiles to Shared Memory
|
| 739 |
+
using SmemIteratorD0 =
|
| 740 |
+
cutlass::epilogue::warp::TileIteratorTensorOp<
|
| 741 |
+
WarpShape0,
|
| 742 |
+
InstructionShape,
|
| 743 |
+
ElementC,
|
| 744 |
+
SmemAccumulatorLayout
|
| 745 |
+
>;
|
| 746 |
+
|
| 747 |
+
static int const kThreadCount = 32;
|
| 748 |
+
// load warp tile from Shared Memory accumulator
|
| 749 |
+
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
|
| 750 |
+
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
| 751 |
+
ElementA, SmemAccumulatorLayout,
|
| 752 |
+
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
| 753 |
+
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
| 754 |
+
|
| 755 |
+
// Define the Mma
|
| 756 |
+
using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
|
| 757 |
+
ThreadblockShape0,
|
| 758 |
+
IteratorA0,
|
| 759 |
+
SmemIteratorA0,
|
| 760 |
+
arch::CacheOperation::Always,
|
| 761 |
+
IteratorB0,
|
| 762 |
+
SmemIteratorB0,
|
| 763 |
+
arch::CacheOperation::Global,
|
| 764 |
+
IteratorAccumulatorScaleBias,
|
| 765 |
+
FragmentIteratorAccumulator,
|
| 766 |
+
SmemIteratorD0,
|
| 767 |
+
ThreadblockShape1,
|
| 768 |
+
WarpIteratorA1,
|
| 769 |
+
IteratorB1,
|
| 770 |
+
SmemIteratorB1,
|
| 771 |
+
arch::CacheOperation::Global,
|
| 772 |
+
EpilogueOutputOp0,
|
| 773 |
+
MmaPolicy0,
|
| 774 |
+
MmaPolicy1,
|
| 775 |
+
Stages
|
| 776 |
+
>;
|
| 777 |
+
|
| 778 |
+
// Define the epilogue
|
| 779 |
+
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
| 780 |
+
ThreadblockShape1,
|
| 781 |
+
WarpMmaTensorOp1,
|
| 782 |
+
1,
|
| 783 |
+
EpilogueOutputOp1,
|
| 784 |
+
EpilogueOutputOp1::kCount,
|
| 785 |
+
InterleavedK
|
| 786 |
+
>::Epilogue;
|
| 787 |
+
|
| 788 |
+
// Define the kernel
|
| 789 |
+
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
| 790 |
+
B2bMma,
|
| 791 |
+
Epilogue,
|
| 792 |
+
ThreadblockSwizzle,
|
| 793 |
+
conv::Operator::kFprop
|
| 794 |
+
>;
|
| 795 |
+
};
|
| 796 |
+
|
| 797 |
+
|
| 798 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 799 |
+
|
| 800 |
+
} // namespace kernel
|
| 801 |
+
} // namespace conv
|
| 802 |
+
} // namespace cutlass
|
| 803 |
+
|
| 804 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h
ADDED
|
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief
|
| 34 |
+
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
| 35 |
+
the appropriate threadblock-scoped epilogue.
|
| 36 |
+
|
| 37 |
+
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
|
| 38 |
+
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
|
| 39 |
+
specializations here choose 'device::GemmTransposed' to implement this functionality.
|
| 40 |
+
*/
|
| 41 |
+
|
| 42 |
+
#pragma once
|
| 43 |
+
|
| 44 |
+
#include "cutlass/cutlass.h"
|
| 45 |
+
|
| 46 |
+
#include "cutlass/layout/matrix.h"
|
| 47 |
+
#include "cutlass/numeric_types.h"
|
| 48 |
+
|
| 49 |
+
#include "cutlass/epilogue/threadblock/epilogue.h"
|
| 50 |
+
#include "cutlass/epilogue/thread/linear_combination.h"
|
| 51 |
+
|
| 52 |
+
#include "cutlass/gemm/gemm.h"
|
| 53 |
+
#include "cutlass/gemm/kernel/gemm_pipelined.h"
|
| 54 |
+
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
| 55 |
+
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
| 56 |
+
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
| 57 |
+
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
|
| 58 |
+
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
| 59 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
| 60 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
| 61 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
|
| 62 |
+
|
| 63 |
+
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
| 64 |
+
|
| 65 |
+
#include "kernel/b2b_gemm.h"
|
| 66 |
+
#include "kernel/grouped.h"
|
| 67 |
+
#include "threadblock/default_b2b_mma.h"
|
| 68 |
+
#include "threadblock/grouped_threadblock_swizzle.h"
|
| 69 |
+
|
| 70 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 71 |
+
|
| 72 |
+
namespace cutlass {
|
| 73 |
+
namespace gemm {
|
| 74 |
+
namespace kernel {
|
| 75 |
+
|
| 76 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 77 |
+
|
| 78 |
+
template <typename T>
|
| 79 |
+
using IsGroupedSwizzle = cutlass::gemm::threadblock::detail::IsGroupedSwizzle<T>;
|
| 80 |
+
|
| 81 |
+
template <
|
| 82 |
+
/// Element type for A matrix operand
|
| 83 |
+
typename ElementA_,
|
| 84 |
+
/// Layout type for A matrix operand
|
| 85 |
+
typename LayoutA_,
|
| 86 |
+
/// Access granularity of A matrix in units of elements
|
| 87 |
+
int kAlignmentA,
|
| 88 |
+
/// Element type for B matrix operand
|
| 89 |
+
typename ElementB_,
|
| 90 |
+
/// Layout type for B matrix operand
|
| 91 |
+
typename LayoutB_,
|
| 92 |
+
/// Access granularity of B matrix in units of elements
|
| 93 |
+
int kAlignmentB,
|
| 94 |
+
/// Element type for C and D matrix operands
|
| 95 |
+
typename ElementC_,
|
| 96 |
+
/// Layout type for C and D matrix operands
|
| 97 |
+
typename LayoutC_,
|
| 98 |
+
/// Element type for internal accumulation
|
| 99 |
+
typename ElementAccumulator,
|
| 100 |
+
/// Operator class tag
|
| 101 |
+
typename OperatorClass,
|
| 102 |
+
/// Tag indicating architecture to tune for
|
| 103 |
+
typename ArchTag,
|
| 104 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 105 |
+
typename ThreadblockShape0,
|
| 106 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 107 |
+
typename ThreadblockShape1,
|
| 108 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 109 |
+
typename WarpShape0,
|
| 110 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 111 |
+
typename WarpShape1,
|
| 112 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 113 |
+
typename InstructionShape,
|
| 114 |
+
/// Epilogue output operator
|
| 115 |
+
typename EpilogueOutputOp0,
|
| 116 |
+
/// Epilogue output operator
|
| 117 |
+
typename EpilogueOutputOp1,
|
| 118 |
+
/// Threadblock-level swizzling operator
|
| 119 |
+
typename ThreadblockSwizzle,
|
| 120 |
+
/// Number of stages used in the pipelined mainloop
|
| 121 |
+
int Stages,
|
| 122 |
+
/// Operation performed by GEMM
|
| 123 |
+
typename Operator,
|
| 124 |
+
/// Stage accumulator in shared memory
|
| 125 |
+
bool SmemAccumulator = false,
|
| 126 |
+
/// Whether or not the operation is grouped
|
| 127 |
+
typename Enable = void
|
| 128 |
+
>
|
| 129 |
+
struct DefaultB2bGemm;
|
| 130 |
+
|
| 131 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 132 |
+
|
| 133 |
+
/// Partial specialization for Ampere Architecture
|
| 134 |
+
template <
|
| 135 |
+
/// Element type for A matrix operand
|
| 136 |
+
typename ElementA,
|
| 137 |
+
/// Layout type for A matrix operand
|
| 138 |
+
typename LayoutA,
|
| 139 |
+
/// Access granularity of A matrix in units of elements
|
| 140 |
+
int kAlignmentA,
|
| 141 |
+
/// Element type for B matrix operand
|
| 142 |
+
typename ElementB,
|
| 143 |
+
/// Layout type for B matrix operand
|
| 144 |
+
typename LayoutB,
|
| 145 |
+
/// Access granularity of A matrix in units of elements
|
| 146 |
+
int kAlignmentB,
|
| 147 |
+
/// Element type for C and D matrix operands
|
| 148 |
+
typename ElementC,
|
| 149 |
+
/// Element type for internal accumulation
|
| 150 |
+
typename ElementAccumulator,
|
| 151 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 152 |
+
typename ThreadblockShape0,
|
| 153 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 154 |
+
typename ThreadblockShape1,
|
| 155 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 156 |
+
typename WarpShape0,
|
| 157 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 158 |
+
typename WarpShape1,
|
| 159 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 160 |
+
typename InstructionShape,
|
| 161 |
+
/// Epilogue output operator
|
| 162 |
+
typename EpilogueOutputOp0,
|
| 163 |
+
/// Epilogue output operator
|
| 164 |
+
typename EpilogueOutputOp1,
|
| 165 |
+
/// Threadblock-level swizzling operator
|
| 166 |
+
typename ThreadblockSwizzle,
|
| 167 |
+
/// Number of stages used in the pipelined mainloop
|
| 168 |
+
int Stages,
|
| 169 |
+
/// Operation performed by GEMM
|
| 170 |
+
typename Operator>
|
| 171 |
+
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
|
| 172 |
+
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
|
| 173 |
+
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
|
| 174 |
+
WarpShape0, WarpShape1, InstructionShape,
|
| 175 |
+
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
|
| 176 |
+
Operator, false, typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
|
| 177 |
+
/// Define the threadblock-scoped matrix multiply-accumulate
|
| 178 |
+
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
| 179 |
+
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
| 180 |
+
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
|
| 181 |
+
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
| 182 |
+
InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma;
|
| 183 |
+
|
| 184 |
+
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
| 185 |
+
|
| 186 |
+
/// Define the epilogue
|
| 187 |
+
using Epilogue =
|
| 188 |
+
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 189 |
+
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
| 190 |
+
EpilogueOutputOp1::kCount>::Epilogue;
|
| 191 |
+
|
| 192 |
+
/// Define the kernel-level GEMM operator.
|
| 193 |
+
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
| 194 |
+
};
|
| 195 |
+
|
| 196 |
+
/// Partial specialization for Ampere Architecture with grouped operation
|
| 197 |
+
template <
|
| 198 |
+
/// Element type for A matrix operand
|
| 199 |
+
typename ElementA,
|
| 200 |
+
/// Layout type for A matrix operand
|
| 201 |
+
typename LayoutA,
|
| 202 |
+
/// Access granularity of A matrix in units of elements
|
| 203 |
+
int kAlignmentA,
|
| 204 |
+
/// Element type for B matrix operand
|
| 205 |
+
typename ElementB,
|
| 206 |
+
/// Layout type for B matrix operand
|
| 207 |
+
typename LayoutB,
|
| 208 |
+
/// Access granularity of A matrix in units of elements
|
| 209 |
+
int kAlignmentB,
|
| 210 |
+
/// Element type for C and D matrix operands
|
| 211 |
+
typename ElementC,
|
| 212 |
+
/// Element type for internal accumulation
|
| 213 |
+
typename ElementAccumulator,
|
| 214 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 215 |
+
typename ThreadblockShape0,
|
| 216 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 217 |
+
typename ThreadblockShape1,
|
| 218 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 219 |
+
typename WarpShape0,
|
| 220 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 221 |
+
typename WarpShape1,
|
| 222 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 223 |
+
typename InstructionShape,
|
| 224 |
+
/// Epilogue output operator
|
| 225 |
+
typename EpilogueOutputOp0,
|
| 226 |
+
/// Epilogue output operator
|
| 227 |
+
typename EpilogueOutputOp1,
|
| 228 |
+
/// Threadblock-level swizzling operator
|
| 229 |
+
typename ThreadblockSwizzle,
|
| 230 |
+
/// Number of stages used in the pipelined mainloop
|
| 231 |
+
int Stages,
|
| 232 |
+
/// Operation performed by GEMM
|
| 233 |
+
typename Operator>
|
| 234 |
+
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
|
| 235 |
+
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
|
| 236 |
+
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
|
| 237 |
+
WarpShape0, WarpShape1, InstructionShape,
|
| 238 |
+
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
|
| 239 |
+
Operator, false, typename platform::enable_if<IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
|
| 240 |
+
/// Define the threadblock-scoped matrix multiply-accumulate
|
| 241 |
+
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
| 242 |
+
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
| 243 |
+
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
|
| 244 |
+
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
| 245 |
+
InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma;
|
| 246 |
+
|
| 247 |
+
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
| 248 |
+
|
| 249 |
+
/// Define the epilogue
|
| 250 |
+
using Epilogue =
|
| 251 |
+
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 252 |
+
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
| 253 |
+
EpilogueOutputOp1::kCount>::Epilogue;
|
| 254 |
+
|
| 255 |
+
/// Define the kernel-level GEMM operator.
|
| 256 |
+
using UnderlyingB2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
| 257 |
+
|
| 258 |
+
using B2bGemmKernel = kernel::GroupedKernel<UnderlyingB2bGemmKernel>;
|
| 259 |
+
};
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 263 |
+
|
| 264 |
+
/// Partial specialization for Turing Architecture
|
| 265 |
+
template <
|
| 266 |
+
/// Element type for A matrix operand
|
| 267 |
+
typename ElementA,
|
| 268 |
+
/// Layout type for A matrix operand
|
| 269 |
+
typename LayoutA,
|
| 270 |
+
/// Access granularity of A matrix in units of elements
|
| 271 |
+
int kAlignmentA,
|
| 272 |
+
/// Element type for B matrix operand
|
| 273 |
+
typename ElementB,
|
| 274 |
+
/// Layout type for B matrix operand
|
| 275 |
+
typename LayoutB,
|
| 276 |
+
/// Access granularity of B matrix in units of elements
|
| 277 |
+
int kAlignmentB,
|
| 278 |
+
/// Element type for C and D matrix operands
|
| 279 |
+
typename ElementC,
|
| 280 |
+
/// Element type for internal accumulation
|
| 281 |
+
typename ElementAccumulator,
|
| 282 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 283 |
+
typename ThreadblockShape0,
|
| 284 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 285 |
+
typename ThreadblockShape1,
|
| 286 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 287 |
+
typename WarpShape0,
|
| 288 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 289 |
+
typename WarpShape1,
|
| 290 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 291 |
+
typename InstructionShape,
|
| 292 |
+
/// Epilogue output operator
|
| 293 |
+
typename EpilogueOutputOp0,
|
| 294 |
+
/// Epilogue output operator
|
| 295 |
+
typename EpilogueOutputOp1,
|
| 296 |
+
/// Threadblock-level swizzling operator
|
| 297 |
+
typename ThreadblockSwizzle,
|
| 298 |
+
/// Operation performed by GEMM
|
| 299 |
+
typename Operator
|
| 300 |
+
>
|
| 301 |
+
struct DefaultB2bGemm<
|
| 302 |
+
ElementA, LayoutA, kAlignmentA,
|
| 303 |
+
ElementB, LayoutB, kAlignmentB,
|
| 304 |
+
ElementC, layout::RowMajor,
|
| 305 |
+
ElementAccumulator,
|
| 306 |
+
arch::OpClassTensorOp,
|
| 307 |
+
arch::Sm75,
|
| 308 |
+
ThreadblockShape0,
|
| 309 |
+
ThreadblockShape1,
|
| 310 |
+
WarpShape0,
|
| 311 |
+
WarpShape1,
|
| 312 |
+
InstructionShape,
|
| 313 |
+
EpilogueOutputOp0,
|
| 314 |
+
EpilogueOutputOp1,
|
| 315 |
+
ThreadblockSwizzle,
|
| 316 |
+
2,
|
| 317 |
+
Operator,
|
| 318 |
+
false,
|
| 319 |
+
typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type
|
| 320 |
+
> {
|
| 321 |
+
|
| 322 |
+
/// Define the threadblock-scoped matrix multiply-accumulate
|
| 323 |
+
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
| 324 |
+
ElementA,
|
| 325 |
+
LayoutA,
|
| 326 |
+
kAlignmentA,
|
| 327 |
+
ElementB,
|
| 328 |
+
LayoutB,
|
| 329 |
+
kAlignmentB,
|
| 330 |
+
ElementAccumulator,
|
| 331 |
+
layout::RowMajor,
|
| 332 |
+
arch::OpClassTensorOp,
|
| 333 |
+
arch::Sm75,
|
| 334 |
+
ThreadblockShape0,
|
| 335 |
+
ThreadblockShape1,
|
| 336 |
+
WarpShape0,
|
| 337 |
+
WarpShape1,
|
| 338 |
+
InstructionShape,
|
| 339 |
+
2,
|
| 340 |
+
Operator,
|
| 341 |
+
EpilogueOutputOp0
|
| 342 |
+
>::ThreadblockB2bMma;
|
| 343 |
+
|
| 344 |
+
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
| 345 |
+
|
| 346 |
+
/// Define the epilogue
|
| 347 |
+
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 348 |
+
ThreadblockShape1,
|
| 349 |
+
typename B2bMma::Operator1,
|
| 350 |
+
kPartitionsK1,
|
| 351 |
+
EpilogueOutputOp1,
|
| 352 |
+
EpilogueOutputOp1::kCount
|
| 353 |
+
>::Epilogue;
|
| 354 |
+
|
| 355 |
+
/// Define the kernel-level GEMM operator.
|
| 356 |
+
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
| 357 |
+
};
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout
|
| 361 |
+
template <
|
| 362 |
+
/// Element type for A matrix operand
|
| 363 |
+
typename ElementA,
|
| 364 |
+
/// Access granularity of A matrix in units of elements
|
| 365 |
+
int kAlignmentA,
|
| 366 |
+
/// Element type for B matrix operand
|
| 367 |
+
typename ElementB,
|
| 368 |
+
/// Access granularity of B matrix in units of elements
|
| 369 |
+
int kAlignmentB,
|
| 370 |
+
/// Element type for C and D matrix operands
|
| 371 |
+
typename ElementC,
|
| 372 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 373 |
+
typename ThreadblockShape0,
|
| 374 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 375 |
+
typename ThreadblockShape1,
|
| 376 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 377 |
+
typename WarpShape0,
|
| 378 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 379 |
+
typename WarpShape1,
|
| 380 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 381 |
+
typename InstructionShape,
|
| 382 |
+
/// Epilogue output operator
|
| 383 |
+
typename EpilogueOutputOp0,
|
| 384 |
+
/// Epilogue output operator
|
| 385 |
+
typename EpilogueOutputOp1,
|
| 386 |
+
/// Threadblock-level swizzling operator
|
| 387 |
+
typename ThreadblockSwizzle,
|
| 388 |
+
/// Number of stages used in the pipelined mainloop
|
| 389 |
+
int Stages,
|
| 390 |
+
/// Number of Interleaved k
|
| 391 |
+
int InterleavedK,
|
| 392 |
+
/// Operation performed by GEMM
|
| 393 |
+
typename Operator>
|
| 394 |
+
struct DefaultB2bGemm<
|
| 395 |
+
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
|
| 396 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
| 397 |
+
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
|
| 398 |
+
arch::OpClassTensorOp, arch::Sm80,
|
| 399 |
+
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
| 400 |
+
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
| 401 |
+
ThreadblockSwizzle, Stages,
|
| 402 |
+
Operator, false, typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
|
| 403 |
+
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
| 404 |
+
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
| 405 |
+
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
| 406 |
+
|
| 407 |
+
using ElementAccumulator = int32_t;
|
| 408 |
+
|
| 409 |
+
/// Define the threadblock-scoped matrix multiply-accumulate
|
| 410 |
+
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
| 411 |
+
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
| 412 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80,
|
| 413 |
+
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
| 414 |
+
InstructionShape, Stages, Operator, EpilogueOutputOp0,
|
| 415 |
+
true>::ThreadblockB2bMma;
|
| 416 |
+
|
| 417 |
+
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
| 418 |
+
|
| 419 |
+
/// Define the epilogue
|
| 420 |
+
using Epilogue = typename cutlass::epilogue::threadblock::
|
| 421 |
+
DefaultInterleavedEpilogueTensorOp<
|
| 422 |
+
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
| 423 |
+
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
| 424 |
+
|
| 425 |
+
/// Define the kernel-level GEMM operator.
|
| 426 |
+
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
| 427 |
+
};
|
| 428 |
+
|
| 429 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
/// Partial specialization for Turing Integer Tensor Core Interleaved layout
|
| 433 |
+
template <
|
| 434 |
+
/// Element type for A matrix operand
|
| 435 |
+
typename ElementA,
|
| 436 |
+
/// Access granularity of A matrix in units of elements
|
| 437 |
+
int kAlignmentA,
|
| 438 |
+
/// Element type for B matrix operand
|
| 439 |
+
typename ElementB,
|
| 440 |
+
/// Access granularity of B matrix in units of elements
|
| 441 |
+
int kAlignmentB,
|
| 442 |
+
/// Element type for C and D matrix operands
|
| 443 |
+
typename ElementC,
|
| 444 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 445 |
+
typename ThreadblockShape0,
|
| 446 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 447 |
+
typename ThreadblockShape1,
|
| 448 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 449 |
+
typename WarpShape0,
|
| 450 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 451 |
+
typename WarpShape1,
|
| 452 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 453 |
+
typename InstructionShape,
|
| 454 |
+
/// Epilogue output operator
|
| 455 |
+
typename EpilogueOutputOp0,
|
| 456 |
+
/// Epilogue output operator
|
| 457 |
+
typename EpilogueOutputOp1,
|
| 458 |
+
/// Threadblock-level swizzling operator
|
| 459 |
+
typename ThreadblockSwizzle,
|
| 460 |
+
/// Number of Interleaved k
|
| 461 |
+
int InterleavedK,
|
| 462 |
+
/// Operation performed by GEMM
|
| 463 |
+
typename Operator>
|
| 464 |
+
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 465 |
+
kAlignmentA, ElementB,
|
| 466 |
+
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
| 467 |
+
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 468 |
+
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
| 469 |
+
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
| 470 |
+
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
| 471 |
+
ThreadblockSwizzle, 2, Operator, false,
|
| 472 |
+
typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
|
| 473 |
+
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
| 474 |
+
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
| 475 |
+
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
| 476 |
+
|
| 477 |
+
using ElementAccumulator = int32_t;
|
| 478 |
+
|
| 479 |
+
/// Define the threadblock-scoped matrix multiply-accumulate
|
| 480 |
+
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
| 481 |
+
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,
|
| 482 |
+
arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
|
| 483 |
+
WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma;
|
| 484 |
+
|
| 485 |
+
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
| 486 |
+
|
| 487 |
+
/// Define the epilogue for the 2nd Gemm
|
| 488 |
+
using Epilogue = typename cutlass::epilogue::threadblock::
|
| 489 |
+
DefaultInterleavedEpilogueTensorOp<
|
| 490 |
+
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
| 491 |
+
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
| 492 |
+
|
| 493 |
+
/// Define the kernel-level GEMM operator.
|
| 494 |
+
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
| 495 |
+
};
|
| 496 |
+
|
| 497 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 498 |
+
|
| 499 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 500 |
+
|
| 501 |
+
} // namespace kernel
|
| 502 |
+
} // namespace gemm
|
| 503 |
+
} // namespace cutlass
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief
|
| 34 |
+
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
| 35 |
+
the appropriate threadblock-scoped epilogue.
|
| 36 |
+
|
| 37 |
+
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
|
| 38 |
+
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
|
| 39 |
+
specializations here choose 'device::GemmTransposed' to implement this functionality.
|
| 40 |
+
*/
|
| 41 |
+
|
| 42 |
+
#pragma once
|
| 43 |
+
|
| 44 |
+
#include "cutlass/cutlass.h"
|
| 45 |
+
|
| 46 |
+
#include "cutlass/layout/matrix.h"
|
| 47 |
+
#include "cutlass/numeric_types.h"
|
| 48 |
+
|
| 49 |
+
#include "cutlass/epilogue/threadblock/epilogue.h"
|
| 50 |
+
#include "cutlass/epilogue/thread/linear_combination.h"
|
| 51 |
+
|
| 52 |
+
#include "cutlass/gemm/gemm.h"
|
| 53 |
+
#include "cutlass/gemm/kernel/gemm_pipelined.h"
|
| 54 |
+
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
| 55 |
+
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
| 56 |
+
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
| 57 |
+
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
|
| 58 |
+
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
| 59 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
| 60 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
| 61 |
+
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
|
| 62 |
+
|
| 63 |
+
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
| 64 |
+
#include "cutlass/transform/threadblock/vector_iterator.h"
|
| 65 |
+
#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
|
| 66 |
+
|
| 67 |
+
#include "kernel/b2b_gemm.h"
|
| 68 |
+
#include "threadblock/default_b2b_mma.h"
|
| 69 |
+
#include "threadblock/default_b2b_mma_smem_accumulator.h"
|
| 70 |
+
|
| 71 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 72 |
+
|
| 73 |
+
namespace cutlass {
|
| 74 |
+
namespace gemm {
|
| 75 |
+
namespace kernel {
|
| 76 |
+
|
| 77 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 78 |
+
|
| 79 |
+
/// Partial specialization for Ampere Architecture
|
| 80 |
+
template <
|
| 81 |
+
/// Element type for A matrix operand
|
| 82 |
+
typename ElementA,
|
| 83 |
+
/// Layout type for A matrix operand
|
| 84 |
+
typename LayoutA,
|
| 85 |
+
/// Access granularity of A matrix in units of elements
|
| 86 |
+
int kAlignmentA,
|
| 87 |
+
/// Element type for B matrix operand
|
| 88 |
+
typename ElementB,
|
| 89 |
+
/// Layout type for B matrix operand
|
| 90 |
+
typename LayoutB,
|
| 91 |
+
/// Access granularity of A matrix in units of elements
|
| 92 |
+
int kAlignmentB,
|
| 93 |
+
/// Element type for C and D matrix operands
|
| 94 |
+
typename ElementC,
|
| 95 |
+
/// Element type for internal accumulation
|
| 96 |
+
typename ElementAccumulator,
|
| 97 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 98 |
+
typename ThreadblockShape0,
|
| 99 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 100 |
+
typename ThreadblockShape1,
|
| 101 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 102 |
+
typename WarpShape0,
|
| 103 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 104 |
+
typename WarpShape1,
|
| 105 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 106 |
+
typename InstructionShape,
|
| 107 |
+
/// Epilogue output operator
|
| 108 |
+
typename EpilogueOutputOp0,
|
| 109 |
+
/// Epilogue output operator
|
| 110 |
+
typename EpilogueOutputOp1,
|
| 111 |
+
/// Threadblock-level swizzling operator
|
| 112 |
+
typename ThreadblockSwizzle,
|
| 113 |
+
/// Number of stages used in the pipelined mainloop
|
| 114 |
+
int Stages,
|
| 115 |
+
/// Operation performed by GEMM
|
| 116 |
+
typename Operator>
|
| 117 |
+
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
|
| 118 |
+
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
|
| 119 |
+
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
|
| 120 |
+
WarpShape0, WarpShape1, InstructionShape,
|
| 121 |
+
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
|
| 122 |
+
Operator, true> {
|
| 123 |
+
/// Define the threadblock-scoped matrix multiply-accumulate
|
| 124 |
+
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
| 125 |
+
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
| 126 |
+
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
|
| 127 |
+
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
| 128 |
+
InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma;
|
| 129 |
+
|
| 130 |
+
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
| 131 |
+
|
| 132 |
+
/// Define the epilogue
|
| 133 |
+
using Epilogue =
|
| 134 |
+
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 135 |
+
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
| 136 |
+
EpilogueOutputOp1::kCount>::Epilogue;
|
| 137 |
+
|
| 138 |
+
/// Define the kernel-level GEMM operator.
|
| 139 |
+
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
| 140 |
+
};
|
| 141 |
+
|
| 142 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 143 |
+
|
| 144 |
+
/// Partial specialization for Turing Architecture
|
| 145 |
+
template <
|
| 146 |
+
/// Element type for A matrix operand
|
| 147 |
+
typename ElementA,
|
| 148 |
+
/// Layout type for A matrix operand
|
| 149 |
+
typename LayoutA,
|
| 150 |
+
/// Access granularity of A matrix in units of elements
|
| 151 |
+
int kAlignmentA,
|
| 152 |
+
/// Element type for B matrix operand
|
| 153 |
+
typename ElementB,
|
| 154 |
+
/// Layout type for B matrix operand
|
| 155 |
+
typename LayoutB,
|
| 156 |
+
/// Access granularity of B matrix in units of elements
|
| 157 |
+
int kAlignmentB,
|
| 158 |
+
/// Element type for C and D matrix operands
|
| 159 |
+
typename ElementC,
|
| 160 |
+
/// Element type for internal accumulation
|
| 161 |
+
typename ElementAccumulator,
|
| 162 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 163 |
+
typename ThreadblockShape0,
|
| 164 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 165 |
+
typename ThreadblockShape1,
|
| 166 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 167 |
+
typename WarpShape0,
|
| 168 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 169 |
+
typename WarpShape1,
|
| 170 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 171 |
+
typename InstructionShape,
|
| 172 |
+
/// Epilogue output operator
|
| 173 |
+
typename EpilogueOutputOp0,
|
| 174 |
+
/// Epilogue output operator
|
| 175 |
+
typename EpilogueOutputOp1,
|
| 176 |
+
/// Threadblock-level swizzling operator
|
| 177 |
+
typename ThreadblockSwizzle,
|
| 178 |
+
/// Operation performed by GEMM
|
| 179 |
+
typename Operator
|
| 180 |
+
>
|
| 181 |
+
struct DefaultB2bGemm<
|
| 182 |
+
ElementA, LayoutA, kAlignmentA,
|
| 183 |
+
ElementB, LayoutB, kAlignmentB,
|
| 184 |
+
ElementC, layout::RowMajor,
|
| 185 |
+
ElementAccumulator,
|
| 186 |
+
arch::OpClassTensorOp,
|
| 187 |
+
arch::Sm75,
|
| 188 |
+
ThreadblockShape0,
|
| 189 |
+
ThreadblockShape1,
|
| 190 |
+
WarpShape0,
|
| 191 |
+
WarpShape1,
|
| 192 |
+
InstructionShape,
|
| 193 |
+
EpilogueOutputOp0,
|
| 194 |
+
EpilogueOutputOp1,
|
| 195 |
+
ThreadblockSwizzle,
|
| 196 |
+
2,
|
| 197 |
+
Operator,
|
| 198 |
+
true
|
| 199 |
+
> {
|
| 200 |
+
|
| 201 |
+
/// Define the threadblock-scoped matrix multiply-accumulate
|
| 202 |
+
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
| 203 |
+
ElementA,
|
| 204 |
+
LayoutA,
|
| 205 |
+
kAlignmentA,
|
| 206 |
+
ElementB,
|
| 207 |
+
LayoutB,
|
| 208 |
+
kAlignmentB,
|
| 209 |
+
ElementAccumulator,
|
| 210 |
+
layout::RowMajor,
|
| 211 |
+
arch::OpClassTensorOp,
|
| 212 |
+
arch::Sm75,
|
| 213 |
+
ThreadblockShape0,
|
| 214 |
+
ThreadblockShape1,
|
| 215 |
+
WarpShape0,
|
| 216 |
+
WarpShape1,
|
| 217 |
+
InstructionShape,
|
| 218 |
+
2,
|
| 219 |
+
Operator,
|
| 220 |
+
EpilogueOutputOp0,
|
| 221 |
+
false,
|
| 222 |
+
true
|
| 223 |
+
>::ThreadblockB2bMma;
|
| 224 |
+
|
| 225 |
+
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
| 226 |
+
|
| 227 |
+
/// Define the epilogue
|
| 228 |
+
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
| 229 |
+
ThreadblockShape1,
|
| 230 |
+
typename B2bMma::Operator1,
|
| 231 |
+
kPartitionsK1,
|
| 232 |
+
EpilogueOutputOp1,
|
| 233 |
+
EpilogueOutputOp1::kCount
|
| 234 |
+
>::Epilogue;
|
| 235 |
+
|
| 236 |
+
/// Define the kernel-level GEMM operator.
|
| 237 |
+
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
| 238 |
+
};
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout
|
| 242 |
+
template <
|
| 243 |
+
/// Element type for A matrix operand
|
| 244 |
+
typename ElementA,
|
| 245 |
+
/// Access granularity of A matrix in units of elements
|
| 246 |
+
int kAlignmentA,
|
| 247 |
+
/// Element type for B matrix operand
|
| 248 |
+
typename ElementB,
|
| 249 |
+
/// Access granularity of B matrix in units of elements
|
| 250 |
+
int kAlignmentB,
|
| 251 |
+
/// Element type for C and D matrix operands
|
| 252 |
+
typename ElementC,
|
| 253 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 254 |
+
typename ThreadblockShape0,
|
| 255 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 256 |
+
typename ThreadblockShape1,
|
| 257 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 258 |
+
typename WarpShape0,
|
| 259 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 260 |
+
typename WarpShape1,
|
| 261 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 262 |
+
typename InstructionShape,
|
| 263 |
+
/// Epilogue output operator
|
| 264 |
+
typename EpilogueOutputOp0,
|
| 265 |
+
/// Epilogue output operator
|
| 266 |
+
typename EpilogueOutputOp1,
|
| 267 |
+
/// Threadblock-level swizzling operator
|
| 268 |
+
typename ThreadblockSwizzle,
|
| 269 |
+
/// Number of stages used in the pipelined mainloop
|
| 270 |
+
int Stages,
|
| 271 |
+
/// Number of Interleaved k
|
| 272 |
+
int InterleavedK,
|
| 273 |
+
/// Operation performed by GEMM
|
| 274 |
+
typename Operator>
|
| 275 |
+
struct DefaultB2bGemm<
|
| 276 |
+
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
|
| 277 |
+
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
| 278 |
+
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
|
| 279 |
+
arch::OpClassTensorOp, arch::Sm80,
|
| 280 |
+
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
| 281 |
+
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
| 282 |
+
ThreadblockSwizzle, Stages,
|
| 283 |
+
Operator, true> {
|
| 284 |
+
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
| 285 |
+
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
| 286 |
+
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
| 287 |
+
|
| 288 |
+
using ElementAccumulator = int32_t;
|
| 289 |
+
|
| 290 |
+
/// Define the threadblock-scoped matrix multiply-accumulate
|
| 291 |
+
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
| 292 |
+
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
| 293 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80,
|
| 294 |
+
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
| 295 |
+
InstructionShape, Stages, Operator, EpilogueOutputOp0,
|
| 296 |
+
true, true>::ThreadblockB2bMma;
|
| 297 |
+
|
| 298 |
+
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
| 299 |
+
|
| 300 |
+
/// Define the epilogue
|
| 301 |
+
using Epilogue = typename cutlass::epilogue::threadblock::
|
| 302 |
+
DefaultInterleavedEpilogueTensorOp<
|
| 303 |
+
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
| 304 |
+
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
| 305 |
+
|
| 306 |
+
/// Define the kernel-level GEMM operator.
|
| 307 |
+
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
| 308 |
+
};
|
| 309 |
+
|
| 310 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
/// Partial specialization for Turing Integer Tensor Core Interleaved layout
|
| 314 |
+
template <
|
| 315 |
+
/// Element type for A matrix operand
|
| 316 |
+
typename ElementA,
|
| 317 |
+
/// Access granularity of A matrix in units of elements
|
| 318 |
+
int kAlignmentA,
|
| 319 |
+
/// Element type for B matrix operand
|
| 320 |
+
typename ElementB,
|
| 321 |
+
/// Access granularity of B matrix in units of elements
|
| 322 |
+
int kAlignmentB,
|
| 323 |
+
/// Element type for C and D matrix operands
|
| 324 |
+
typename ElementC,
|
| 325 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 326 |
+
typename ThreadblockShape0,
|
| 327 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 328 |
+
typename ThreadblockShape1,
|
| 329 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 330 |
+
typename WarpShape0,
|
| 331 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 332 |
+
typename WarpShape1,
|
| 333 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 334 |
+
typename InstructionShape,
|
| 335 |
+
/// Epilogue output operator
|
| 336 |
+
typename EpilogueOutputOp0,
|
| 337 |
+
/// Epilogue output operator
|
| 338 |
+
typename EpilogueOutputOp1,
|
| 339 |
+
/// Threadblock-level swizzling operator
|
| 340 |
+
typename ThreadblockSwizzle,
|
| 341 |
+
/// Number of Interleaved k
|
| 342 |
+
int InterleavedK,
|
| 343 |
+
/// Operation performed by GEMM
|
| 344 |
+
typename Operator>
|
| 345 |
+
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 346 |
+
kAlignmentA, ElementB,
|
| 347 |
+
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
| 348 |
+
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
|
| 349 |
+
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
| 350 |
+
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
| 351 |
+
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
| 352 |
+
ThreadblockSwizzle, 2, Operator, true> {
|
| 353 |
+
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
| 354 |
+
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
| 355 |
+
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
| 356 |
+
|
| 357 |
+
using ElementAccumulator = int32_t;
|
| 358 |
+
|
| 359 |
+
/// Define the threadblock-scoped matrix multiply-accumulate
|
| 360 |
+
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
| 361 |
+
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
| 362 |
+
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
|
| 363 |
+
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
| 364 |
+
InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma;
|
| 365 |
+
|
| 366 |
+
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
| 367 |
+
|
| 368 |
+
/// Define the epilogue for the 2nd Gemm
|
| 369 |
+
using Epilogue = typename cutlass::epilogue::threadblock::
|
| 370 |
+
DefaultInterleavedEpilogueTensorOp<
|
| 371 |
+
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
| 372 |
+
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
| 373 |
+
|
| 374 |
+
/// Define the kernel-level GEMM operator.
|
| 375 |
+
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
| 376 |
+
};
|
| 377 |
+
|
| 378 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 379 |
+
|
| 380 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 381 |
+
|
| 382 |
+
} // namespace kernel
|
| 383 |
+
} // namespace gemm
|
| 384 |
+
} // namespace cutlass
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/grouped.h
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief High-level interface for running a grouped version of a CUTLASS kernel
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#include "cutlass/fast_math.h"
|
| 40 |
+
#include "cutlass/gemm/gemm.h"
|
| 41 |
+
#include "cutlass/matrix_coord.h"
|
| 42 |
+
#include "cutlass/complex.h"
|
| 43 |
+
#include "cutlass/semaphore.h"
|
| 44 |
+
|
| 45 |
+
#include "cutlass/layout/matrix.h"
|
| 46 |
+
#include "cutlass/trace.h"
|
| 47 |
+
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
| 48 |
+
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
|
| 49 |
+
|
| 50 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 51 |
+
|
| 52 |
+
namespace cutlass {
|
| 53 |
+
namespace gemm {
|
| 54 |
+
namespace kernel {
|
| 55 |
+
|
| 56 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 57 |
+
|
| 58 |
+
/// High-level interface for running a grouped version of a CUTLASS kernel
|
| 59 |
+
template <
|
| 60 |
+
typename BaseKernel_ ///! Kernel-scoped matrix multiply-accumulate
|
| 61 |
+
>
|
| 62 |
+
struct GroupedKernel {
|
| 63 |
+
public:
|
| 64 |
+
|
| 65 |
+
using BaseKernel = BaseKernel_;
|
| 66 |
+
using Epilogue = typename BaseKernel::Epilogue;
|
| 67 |
+
|
| 68 |
+
/// Types that need to be exported to work properly with device::BaseGrouped
|
| 69 |
+
using ElementA = typename BaseKernel::ElementA;
|
| 70 |
+
using LayoutA = typename BaseKernel::LayoutA;
|
| 71 |
+
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
| 72 |
+
static ComplexTransform const kTransformA = BaseKernel::kTransformA;
|
| 73 |
+
static int const kAlignmentA = BaseKernel::kAlignmentA;
|
| 74 |
+
|
| 75 |
+
using ElementB = typename BaseKernel::ElementB;
|
| 76 |
+
using LayoutB = typename BaseKernel::LayoutB;
|
| 77 |
+
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
| 78 |
+
static ComplexTransform const kTransformB = BaseKernel::kTransformB;
|
| 79 |
+
static int const kAlignmentB = BaseKernel::kAlignmentB;
|
| 80 |
+
|
| 81 |
+
using ElementC = typename BaseKernel::ElementC;
|
| 82 |
+
using LayoutC = typename BaseKernel::LayoutC;
|
| 83 |
+
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
| 84 |
+
using TensorRefD = TensorRef<ElementC, LayoutC>;
|
| 85 |
+
static int const kAlignmentC = BaseKernel::kAlignmentC;
|
| 86 |
+
|
| 87 |
+
using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;
|
| 88 |
+
|
| 89 |
+
using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
|
| 90 |
+
using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle;
|
| 91 |
+
|
| 92 |
+
using Operator = typename BaseKernel::Operator;
|
| 93 |
+
using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;
|
| 94 |
+
|
| 95 |
+
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
|
| 96 |
+
using MathOperator = typename WarpMmaOperator::MathOperator;
|
| 97 |
+
using OperatorClass = typename WarpMmaOperator::OperatorClass;
|
| 98 |
+
using ArchTag = typename WarpMmaOperator::ArchTag;
|
| 99 |
+
using ThreadblockShape = typename BaseKernel::Mma::Shape;
|
| 100 |
+
using WarpShape = typename BaseKernel::WarpShape;
|
| 101 |
+
using InstructionShape = typename BaseKernel::InstructionShape;
|
| 102 |
+
static int const kStages = BaseKernel::Mma::kStages;
|
| 103 |
+
|
| 104 |
+
using Mma = typename BaseKernel::Mma;
|
| 105 |
+
|
| 106 |
+
using Arguments = typename BaseKernel::GroupedArguments;
|
| 107 |
+
using Params = typename BaseKernel::GroupedParams;
|
| 108 |
+
using ProblemVisitor = typename ThreadblockSwizzle::ProblemVisitor;
|
| 109 |
+
|
| 110 |
+
static int const kThreadCount = BaseKernel::kThreadCount;
|
| 111 |
+
|
| 112 |
+
/// Shared memory storage structure
|
| 113 |
+
struct SharedStorage {
|
| 114 |
+
typename BaseKernel::SharedStorage kernel;
|
| 115 |
+
|
| 116 |
+
// ProblemVisitor shared storage can't be overlapped with others
|
| 117 |
+
typename ProblemVisitor::SharedStorage problem_visitor;
|
| 118 |
+
};
|
| 119 |
+
|
| 120 |
+
public:
|
| 121 |
+
|
| 122 |
+
//
|
| 123 |
+
// Methods
|
| 124 |
+
//
|
| 125 |
+
|
| 126 |
+
CUTLASS_DEVICE
|
| 127 |
+
GroupedKernel() { }
|
| 128 |
+
|
| 129 |
+
/// Determines whether kernel satisfies alignment
|
| 130 |
+
static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) {
|
| 131 |
+
return Status::kSuccess;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
static Status can_implement(Arguments const &args) {
|
| 135 |
+
return Status::kSuccess;
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
/// Executes a kernel-level GEMM in a loop
|
| 139 |
+
CUTLASS_DEVICE
|
| 140 |
+
void operator()(Params ¶ms, SharedStorage &shared_storage) {
|
| 141 |
+
|
| 142 |
+
ThreadblockSwizzle swizzle(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
|
| 143 |
+
|
| 144 |
+
if (ProblemVisitor::kTransposed) {
|
| 145 |
+
params.transpose();
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
BaseKernel mma;
|
| 149 |
+
|
| 150 |
+
// Outer 'persistent' loop to iterate over tiles
|
| 151 |
+
while (swizzle.problem_visitor.next_tile()) {
|
| 152 |
+
|
| 153 |
+
typename BaseKernel::Params mma_params = params.to_single_params(swizzle.problem_visitor);
|
| 154 |
+
mma.run_with_swizzle(mma_params, shared_storage.kernel, swizzle);
|
| 155 |
+
|
| 156 |
+
// Next tile
|
| 157 |
+
swizzle.problem_visitor.advance(gridDim.x);
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
};
|
| 161 |
+
|
| 162 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 163 |
+
|
| 164 |
+
} // namespace kernel
|
| 165 |
+
} // namespace gemm
|
| 166 |
+
} // namespace cutlass
|
| 167 |
+
|
| 168 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|