Kernels
danieldk HF Staff committed on
Commit
4913396
·
0 Parent(s):

Import DeepGEMM

Browse files
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 DeepSeek
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - kernel
5
+ ---
6
+
7
+ ## DeepGEMM
8
+
9
+ [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM/) by DeepSeek.
10
+
tests/test_core.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+ from typing import Tuple
4
+
5
+ import deep_gemm
6
+ from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor
7
+
8
+
9
+ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
10
+ assert x.dim() == 2 and x.size(1) % 128 == 0
11
+ m, n = x.shape
12
+ x_view = x.view(m, -1, 128)
13
+ x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
14
+ return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
15
+
16
+
17
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 (e4m3) with one scale per 128x128 block.

    The input is zero-padded to multiples of 128 in both dimensions for the
    amax computation; the FP8 result is cropped back to ``x``'s shape. Also
    returns float32 per-block scales of shape ``(ceil(m/128), ceil(n/128))``.
    """
    assert x.dim() == 2
    rows, cols = x.shape
    padded = torch.zeros((ceil_div(rows, 128) * 128, ceil_div(cols, 128) * 128),
                         dtype=x.dtype, device=x.device)
    padded[:rows, :cols] = x
    blocks = padded.view(-1, 128, padded.size(1) // 128, 128)
    # Per-(128x128)-block absolute maximum, clamped to avoid dividing by zero.
    block_amax = blocks.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    quantized = (blocks * (448.0 / block_amax)).to(torch.float8_e4m3fn)
    scales = (block_amax / 448.0).view(blocks.size(0), blocks.size(2))
    return quantized.view_as(padded)[:rows, :cols].contiguous(), scales
26
+
27
+
28
def construct(m: int, k: int, n: int) -> \
        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
    """Build FP8 GEMM inputs, an empty BF16 output buffer, and a BF16 reference.

    Returns ``(x_fp8, y_fp8, out, ref_out)`` where ``x_fp8``/``y_fp8`` are
    (data, scales) pairs produced by the per-token / per-block casts above.
    """
    x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
    y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
    out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
    # BF16 reference for the NT GEMM (y is stored row-major as (n, k))
    ref_out = x @ y.t()

    x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y)
    # Transpose earlier so that the testing will not trigger transposing kernels
    x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
    return x_fp8, y_fp8, out, ref_out
39
+
40
+
41
def construct_grouped(num_groups: int, m: int, k: int, n: int, is_masked: bool) -> \
        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
    """Build grouped FP8 GEMM inputs, an output buffer, and a BF16 reference.

    When ``is_masked`` is False the group and M dimensions of the LHS/output
    are flattened (contiguous-grouped layout); when True they stay separate.
    """
    x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16)
    y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
    out = torch.empty((num_groups, m, n), device='cuda', dtype=torch.bfloat16)
    # Per-group BF16 reference of the NT GEMM
    ref_out = torch.einsum('gmk,gnk->gmn', x, y)

    assert m % 4 == 0, f'TMA alignment error: {m}'
    # Pre-allocate FP8 data + float32 scale buffers, then fill group by group
    x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float))
    y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float))
    for i in range(num_groups):
        x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
        y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])

    # For non-masked input, we must merge the group and M dims
    if not is_masked:
        x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1])
        out, ref_out = out.view(-1, n), ref_out.view(-1, n)

    # Transpose earlier so that the testing will not trigger transposing kernels
    x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
    return x_fp8, y_fp8, out, ref_out
63
+
64
+
65
def test_gemm() -> None:
    """Check `deep_gemm.gemm_fp8_fp8_bf16_nt` against a BF16 reference, then benchmark it."""
    print('Testing GEMM:')
    for m in (64, 128, 4096):
        for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
            # Correctness: relative difference vs. the BF16 matmul from `construct`
            x_fp8, y_fp8, out, ref_out = construct(m, k, n)
            deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
            diff = calc_diff(out, ref_out)
            assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'

            # noinspection PyShadowingNames
            def test_func():
                # Construct new tensors every time to avoid L2 cache acceleration
                x_fp8, y_fp8, out, ref_out = construct(m, k, n)
                deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)

            # Time only the `fp8_gemm` kernel via kineto traces
            t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
            print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
                  f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
                  f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
    print()
85
+
86
+
87
def test_m_grouped_gemm_contiguous() -> None:
    """Check the contiguous-grouped GEMM (group/M dims merged, routed by `m_indices`)."""
    print('Testing grouped contiguous GEMM:')

    for num_groups, m, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)):
        # TODO: make a stronger test
        x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False)
        # m_indices[i] = group id of row i; here groups are equal-sized contiguous runs
        m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int)
        m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
        diff = calc_diff(out, ref_out)
        assert diff < 0.001, f'm={m * num_groups}, {k=}, {n=}, {diff:.5f}'

        # noinspection PyShadowingNames
        def test_func():
            # Construct new tensors every time to avoid L2 cache acceleration
            x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False)
            m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int)
            m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
            deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)

        # Time only the `fp8_gemm` kernel via kineto traces
        t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
        print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
              f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
              f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
    print()
112
+
113
+
114
def test_m_grouped_gemm_masked() -> None:
    """Check the masked grouped GEMM, which computes only the first `masked_m[g]` rows per group."""
    print('Testing grouped masked GEMM:')

    for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
        for k, n in ((7168, 4096), (2048, 7168), ):
            # Test correctness
            masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384)))
            for i in range(10):
                x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
                masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
                for j in range(num_groups):
                    masked_m[j] = random.choice(masked_m_candidates)
                # `expected_m` is a scheduling hint; clamp it to the group size
                expected_m = min(int(masked_m.float().mean()) + 1, m)
                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
                # Only the first `masked_m[j]` rows of each group's output are defined
                for j in range(num_groups):
                    diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()])
                    assert diff < 0.001, f'{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'

            # noinspection PyShadowingNames
            def test_func():
                # Construct new tensors every time to avoid L2 cache acceleration
                x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
                masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m
                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, m)

            # Test performance with fixed shapes
            t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
            print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
                  f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
                  f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
    print()
145
+
146
+
147
if __name__ == '__main__':
    # TF32 affects only the BF16/FP32 reference matmuls, not the FP8 kernels under test
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Seed both torch and `random` so inputs and masks are reproducible
    torch.manual_seed(0)
    random.seed(0)

    print('Library path:')
    print(f' > {deep_gemm.__path__}\n')

    test_gemm()
    test_m_grouped_gemm_contiguous()
    test_m_grouped_gemm_masked()
tests/test_jit.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from typing import Any
4
+
5
+ from deep_gemm import jit
6
+
7
+
8
class Capture:
    """Capture writes to the process-level stdout file descriptor (fd 1).

    Unlike swapping ``sys.stdout``, duplicating fd 1 also captures output
    written directly by native code (e.g. the JIT-compiled C++ below).
    Use as a context manager; read the text back with :meth:`capture`.
    """

    def __init__(self) -> None:
        self.read_fd = None       # read end of the capture pipe
        self.write_fd = None      # write end, installed as fd 1
        self.saved_stdout = None  # duplicate of the original fd 1
        self.captured = None      # text collected on __exit__

    def __enter__(self) -> Any:
        # Route fd 1 into a pipe, keeping a duplicate of the original
        # descriptor so it can be restored on exit.
        self.read_fd, self.write_fd = os.pipe()
        self.saved_stdout = os.dup(1)
        os.dup2(self.write_fd, 1)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Restore the original stdout, then close the duplicate — the
        # original code leaked this descriptor on every use.
        os.dup2(self.saved_stdout, 1)
        os.close(self.saved_stdout)
        # Close the write end so the read below sees EOF instead of blocking.
        os.close(self.write_fd)
        with os.fdopen(self.read_fd, 'r') as f:
            self.captured = f.read()

    def capture(self) -> str:
        """Return the text captured during the most recent ``with`` block."""
        return self.captured
29
+
30
+
31
if __name__ == '__main__':
    # Runtime: ensure an NVCC toolchain is visible to the JIT
    print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n')

    # Templates: generate a trivial kernel wrapper that echoes its arguments
    print('Generated code:')
    args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16),
            ('enable_double_streams', bool), ('stream', torch.cuda.Stream))
    body = "\n"
    body += 'std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n'
    body += 'std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n'
    body += 'std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n'
    body += 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n'
    body += 'std::cout << enable_double_streams << std::endl;\n'
    body += 'std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n'
    code = jit.generate((), args, body)
    print(code)

    # Build the generated source with the JIT pipeline
    print('Building ...')
    func = jit.build('test_func', args, code)

    # Test correctness: the compiled function must print exactly the pointers
    # and values we pass in, captured at the fd level (C++ writes to fd 1)
    print('Running ...')
    fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda')
    fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda')
    bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda')
    with Capture() as capture:
        assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0
    output = capture.capture()
    ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n'
    assert output == ref_output, f'{output=}, {ref_output=}'

    print('JIT test passed')
torch-ext/deep_gemm/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from . import jit
4
+ from .jit_kernels import (
5
+ gemm_fp8_fp8_bf16_nt,
6
+ m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
7
+ m_grouped_gemm_fp8_fp8_bf16_nt_masked,
8
+ ceil_div,
9
+ set_num_sms, get_num_sms,
10
+ get_col_major_tma_aligned_tensor,
11
+ get_m_alignment_for_contiguous_layout
12
+ )
13
+ from .utils import bench, bench_kineto, calc_diff
torch-ext/deep_gemm/include/deep_gemm/fp8_gemm.cuh ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma clang diagnostic push
2
+ #pragma clang diagnostic ignored "-Wunknown-attributes"
3
+ #pragma once
4
+
5
+ #include <cutlass/arch/barrier.h>
6
+ #include <cutlass/arch/reg_reconfig.h>
7
+
8
+ #include <cute/arch/cluster_sm90.hpp>
9
+ #include <cute/arch/copy_sm90_desc.hpp>
10
+ #include <cute/arch/copy_sm90_tma.hpp>
11
+
12
+ #include "mma_utils.cuh"
13
+ #include "scheduler.cuh"
14
+ #include "tma_utils.cuh"
15
+ #include "utils.cuh"
16
+
17
+ namespace deep_gemm {
18
+
19
+ enum class Layout {
20
+ RowMajor,
21
+ ColMajor
22
+ };
23
+
24
+ template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
25
+ __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
26
+ DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
27
+ return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
28
+ }
29
+
30
+ template <uint32_t SHAPE_N, uint32_t SHAPE_K,
31
+ uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
32
+ uint32_t kNumGroups, uint32_t kNumStages,
33
+ uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
34
+ uint32_t kNumTMAMulticast,
35
+ GemmType kGemmType>
36
+ __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
37
+ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
38
+ uint32_t shape_m,
39
+ const __grid_constant__ CUtensorMap tensor_map_a,
40
+ const __grid_constant__ CUtensorMap tensor_map_b,
41
+ const __grid_constant__ CUtensorMap tensor_map_scales_a,
42
+ const __grid_constant__ CUtensorMap tensor_map_d) {
43
+ #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
44
+ // Scaling checks
45
+ DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
46
+ DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block");
47
+
48
+ // Types
49
+ using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
50
+ using Barrier = cutlass::arch::ClusterTransactionBarrier;
51
+
52
+ // Shared memory
53
+ static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
54
+ static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
55
+ static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
56
+ static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
57
+ static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
58
+ static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K);
59
+ static constexpr uint32_t SMEM_SCALES_B_SIZE = ceil_div<uint32_t>(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier);
60
+
61
+ // Configs
62
+ constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
63
+ constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
64
+ constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
65
+ constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages);
66
+ const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
67
+ const uint32_t lane_idx = get_lane_id();
68
+
69
+ // Prefetch TMA descriptors at very beginning
70
+ if (threadIdx.x == kNumMathThreads) {
71
+ cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
72
+ cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
73
+ cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
74
+ cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
75
+ }
76
+ __syncwarp();
77
+
78
+ // Align to 1024 bytes for swizzle-128B
79
+ extern __shared__ __align__(1024) uint8_t smem_buffer[];
80
+ DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
81
+
82
+ // Data on shared memory
83
+ auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
84
+ __nv_fp8_e4m3* smem_a[kNumStages];
85
+ __nv_fp8_e4m3* smem_b[kNumStages];
86
+ float* smem_scales_a[kNumStages];
87
+ float* smem_scales_b;
88
+
89
+ // TMA Barrier for both divisible and non-divisible cases
90
+ Barrier* full_barriers[kNumStages];
91
+ Barrier* empty_barriers[kNumStages];
92
+
93
+ // Fill shared memory pointers
94
+ #pragma unroll
95
+ for (int i = 0; i < kNumStages; ++ i) {
96
+ smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
97
+ smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
98
+ smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
99
+ }
100
+ smem_scales_b = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE));
101
+
102
+ // Fill barriers
103
+ auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_scales_b) + SMEM_SCALES_B_SIZE);
104
+ #pragma unroll
105
+ for (int i = 0; i < kNumStages; ++ i) {
106
+ full_barriers[i] = barrier_start_ptr + i;
107
+ empty_barriers[i] = barrier_start_ptr + kNumStages + i;
108
+ }
109
+
110
+ // Initialize barriers
111
+ DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
112
+ if (threadIdx.x == kNumMathThreads) {
113
+ // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
114
+ // even with TMA multicast disabled, we want to make the behavior aligned
115
+ #pragma unroll
116
+ for (int i = 0; i < kNumStages; ++ i) {
117
+ full_barriers[i]->init(1);
118
+ empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
119
+ }
120
+
121
+ // Make initialized barrier visible in async proxy
122
+ cutlass::arch::fence_view_async_shared();
123
+ (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
124
+ }
125
+
126
+ // Synchronize all threads to make barrier visible in normal memory model
127
+ (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
128
+
129
+ // For pipeline unrolling
130
+ struct DivisibleK {};
131
+ struct NotDivisibleK {};
132
+ auto launch_k_iterations = [](const auto& func) {
133
+ if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
134
+ for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
135
+ func(k_iter, DivisibleK{});
136
+ } else {
137
+ for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
138
+ func(k_iter, DivisibleK{});
139
+ func(kNumIterations - 1, NotDivisibleK{});
140
+ }
141
+ };
142
+
143
+ // Register reconfigurations
144
+ constexpr int kNumTMARegisters = 40;
145
+ constexpr int kNumMathRegisters = 232;
146
+
147
+ // Block scheduler
148
+ uint32_t m_block_idx, n_block_idx;
149
+ auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);
150
+
151
+ if (threadIdx.x >= kNumMathThreads) {
152
+ // TMA warp-group for loading data
153
+ cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
154
+
155
+ // NOTES: only one thread (or warp) will be used
156
+ if (threadIdx.x == kNumMathThreads) {
157
+ // Persistently schedule over blocks
158
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
159
+ launch_k_iterations([&](int k_iter, auto type) {
160
+ constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
161
+ constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
162
+ DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
163
+
164
+ // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
165
+ // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
166
+ #pragma unroll
167
+ for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
168
+ // Wait consumer release
169
+ empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
170
+
171
+ // Issue TMA A with broadcasting
172
+ auto& full_barrier = *full_barriers[s];
173
+ int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
174
+ tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
175
+ smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
176
+ tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
177
+ smem_scales_a[s], m_block_idx * BLOCK_M,
178
+ scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
179
+
180
+ // Issue TMA B without broadcasting
181
+ tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
182
+ smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
183
+ full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
184
+ }
185
+
186
+ // Wait unaligned cases
187
+ #pragma unroll
188
+ for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
189
+ empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
190
+ full_barriers[s]->arrive();
191
+ }
192
+ });
193
+ }
194
+
195
+ // To safely deconstruct distributed shared barriers, we need another round of empty waits
196
+ if constexpr (kNumTMAMulticast > 1) {
197
+ #pragma unroll
198
+ for (uint32_t s = 0; s < kNumStages; ++ s)
199
+ empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
200
+ }
201
+ }
202
+ } else {
203
+ // Math warp-groups for WGMMA
204
+ cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
205
+
206
+ // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
207
+ const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
208
+ const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
209
+
210
+ // Persistently schedule over blocks
211
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
212
+ // Decide the number of scales B to load
213
+ DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
214
+ uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
215
+ if constexpr (not kMustUseUniformedScaleB) {
216
+ num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8;
217
+ num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8;
218
+ }
219
+ uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2);
220
+
221
+ // Load B scales with math warp-groups
222
+ // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
223
+ if (threadIdx.x >= 32) {
224
+ auto num_previous_lines = scheduler.get_global_idx<false>(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
225
+ auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
226
+ #pragma unroll
227
+ for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
228
+ st_shared(smem_scales_b + i, __ldg(local_scales_b + i));
229
+ }
230
+ cutlass::arch::NamedBarrier(kNumMathThreads).sync();
231
+
232
+ // Accumulation for WGMMA or CUDA promotion
233
+ float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
234
+
235
+ // Empty barrier arrival
236
+ auto empty_barrier_arrive = [&](int s) {
237
+ if constexpr (kNumTMAMulticast == 1) {
238
+ lane_idx == 0 ? empty_barriers[s]->arrive() : void();
239
+ } else {
240
+ lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
241
+ }
242
+ };
243
+
244
+ // Launch MMAs
245
+ launch_k_iterations([&](int k_iter, auto type) {
246
+ constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
247
+ constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
248
+ DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
249
+
250
+ #pragma unroll
251
+ for (int s = 0; s < kNumInnerStages; ++ s) {
252
+ // Read B scales
253
+ float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
254
+ // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
255
+ if constexpr (not kMustUseUniformedScaleB)
256
+ scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
257
+
258
+ // Wait TMA arrivals
259
+ full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
260
+
261
+ // Read A scales
262
+ // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
263
+ auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);
264
+
265
+ // Commit WGMMA instructions
266
+ #pragma unroll
267
+ for (int i = 0; i < WGMMA::kNumAccum; ++ i)
268
+ warpgroup_fence_operand(accum[i]);
269
+ warpgroup_arrive();
270
+ #pragma unroll
271
+ for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
272
+ auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
273
+ auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
274
+ WGMMA::wgmma(desc_a, desc_b, accum, k);
275
+ }
276
+ warpgroup_commit_batch();
277
+ #pragma unroll
278
+ for (int i = 0; i < WGMMA::kNumAccum; ++ i)
279
+ warpgroup_fence_operand(accum[i]);
280
+ warpgroup_wait<0>();
281
+
282
+ // Notify barrier arrival
283
+ empty_barrier_arrive(s);
284
+
285
+ // Promote with scales
286
+ // NOTES: making it as predicates is very important for performance, comparing to two loops
287
+ float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
288
+ float scale_0_1, scale_1_1;
289
+ if constexpr (not kMustUseUniformedScaleB)
290
+ scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
291
+ #pragma unroll
292
+ for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
293
+ bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
294
+ final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
295
+ final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
296
+ final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
297
+ final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
298
+ }
299
+ }
300
+
301
+ // Wait unaligned cases
302
+ #pragma unroll
303
+ for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
304
+ full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
305
+ empty_barrier_arrive(s);
306
+ }
307
+ });
308
+
309
+ // Write back to shared memory using STSM
310
+ DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
311
+ #pragma unroll
312
+ for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
313
+ SM90_U32x4_STSM_N<nv_bfloat162>::copy(
314
+ __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
315
+ __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
316
+ __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
317
+ __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
318
+ smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
319
+ );
320
+ }
321
+ if constexpr (WGMMA::kNumAccum % 8 != 0) {
322
+ SM90_U32x2_STSM_N<nv_bfloat162>::copy(
323
+ __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
324
+ __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
325
+ smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
326
+ );
327
+ }
328
+ cute::tma_store_fence();
329
+ cutlass::arch::NamedBarrier(kNumMathThreads).sync();
330
+
331
+ // Use TMA store to write back to global memory
332
+ if (threadIdx.x == 0) {
333
+ cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
334
+ scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
335
+ cute::tma_store_arrive();
336
+ cute::tma_store_wait<0>();
337
+ }
338
+ __syncwarp();
339
+ }
340
+ }
341
+ #else
342
+ if (blockIdx.x == 0 and threadIdx.x == 0)
343
+ DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
344
+ #endif
345
+ }
346
+
347
+ template <uint32_t SHAPE_N, uint32_t SHAPE_K,
348
+ uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
349
+ uint32_t kNumGroups, uint32_t kNumStages,
350
+ uint32_t kNumTMAMulticast,
351
+ GemmType kGemmType>
352
+ class Gemm {
353
+ private:
354
+ using Barrier = cuda::barrier<cuda::thread_scope_block>;
355
+
356
+ public:
357
+ Gemm() = default;
358
+
359
+ static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
360
+ uint32_t shape_m,
361
+ const CUtensorMap& tma_a_desc,
362
+ const CUtensorMap& tma_b_desc,
363
+ const CUtensorMap& tma_scales_a_desc,
364
+ const CUtensorMap& tma_d_desc,
365
+ cudaStream_t stream,
366
+ int num_sms, uint32_t smem_size) {
367
+ // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
368
+ constexpr uint32_t kNumTMAThreads = 128;
369
+ constexpr uint32_t kNumMathThreadsPerGroup = 128;
370
+ auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
371
+ kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
372
+ kNumTMAMulticast, kGemmType>;
373
+ DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
374
+
375
+ // Cluster launch
376
+ cudaLaunchConfig_t config;
377
+ config.gridDim = num_sms;
378
+ config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
379
+ config.dynamicSmemBytes = smem_size;
380
+ config.stream = stream;
381
+
382
+ // Clusters for TMA multicast
383
+ // NOTES: `>= 4` cluster size will cause performance degradation
384
+ cudaLaunchAttribute attr;
385
+ attr.id = cudaLaunchAttributeClusterDimension;
386
+ attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
387
+ config.attrs = &attr;
388
+ config.numAttrs = 1;
389
+
390
+ // Launch
391
+ auto status = cudaLaunchKernelEx(&config, kernel,
392
+ gmem_d, scales_b, grouped_layout,
393
+ shape_m,
394
+ tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
395
+ DG_HOST_ASSERT(status == cudaSuccess);
396
+ }
397
+
398
+ template <typename T>
399
+ static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) {
400
+ return make_2d_tma_desc(global_address, Layout::RowMajor,
401
+ shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K);
402
+ }
403
+
404
+ template <typename T>
405
+ static CUtensorMap make_2d_tma_b_desc(T* global_address) {
406
+ return make_2d_tma_desc(global_address, Layout::ColMajor,
407
+ SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
408
+ }
409
+
410
+ template <typename T>
411
+ static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
412
+ return make_2d_tma_desc(global_address, Layout::RowMajor,
413
+ shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
414
+ min(BLOCK_M, shape_m), BLOCK_N,
415
+ CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
416
+ }
417
+
418
+ template <typename T>
419
+ static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
420
+ // Make TMA aligned to 16 bytes
421
+ constexpr uint32_t kAlignment = 16 / sizeof(T);
422
+ shape_m = ceil_div(shape_m, kAlignment) * kAlignment;
423
+
424
+ return make_2d_tma_desc(global_address, Layout::ColMajor,
425
+ shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
426
+ CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
427
+ }
428
+
429
+ template <typename T>
430
+ static CUtensorMap make_2d_tma_desc(
431
+ T* global_address, Layout layout,
432
+ uint32_t gmem_rows, uint32_t gmem_cols,
433
+ uint32_t smem_rows, uint32_t smem_cols,
434
+ CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
435
+ if (layout == Layout::RowMajor) {
436
+ uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
437
+ uint32_t smem_dim[2] = {smem_cols, smem_rows};
438
+ return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type);
439
+ } else {
440
+ uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
441
+ uint32_t smem_dim[2] = {smem_rows, smem_cols};
442
+ return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type);
443
+ }
444
+ }
445
+ };
446
+
447
+ }; // namespace deep_gemm
448
+
449
+ #pragma clang diagnostic pop
torch-ext/deep_gemm/include/deep_gemm/mma_utils.cuh ADDED
@@ -0,0 +1,885 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cuda.h>
4
+
5
+ #include "utils.cuh"
6
+
7
+ namespace deep_gemm {
8
+
9
// Hopper warpgroup async MMA (wgmma) wrapper: 64x16 f32 tile accumulated from
// FP8 (e4m3) operands over k=32. "SS": both A and B are read from shared
// memory through 64-bit matrix descriptors.
struct SM90_64x16x32_F32E4M3E4M3_SS {
    // Register overload: one reference per per-thread accumulator register.
    // `scale_d` is materialized into predicate `p`; the instruction's scale-D
    // operand keeps (true) or ignores (false) the previous accumulator values.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %10, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7},"
                     " %8,"
                     " %9,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    // Array overload: forwards d[0..7] to the register overload.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              scale_d);
    }

    // Tile shape; kNumAccum = f32 accumulators per thread (128-thread warpgroup).
    static constexpr int M = 64;
    static constexpr int N = 16;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};
37
+
38
// Hopper warpgroup async MMA (wgmma) wrapper: 64x24 f32 tile accumulated from
// FP8 (e4m3) operands over k=32. "SS": both A and B are read from shared
// memory through 64-bit matrix descriptors.
struct SM90_64x24x32_F32E4M3E4M3_SS {
    // Register overload; `scale_d` (via predicate `p`) keeps or clears the
    // previous accumulator contents.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %14, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11},"
                     " %12,"
                     " %13,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    // Array overload: forwards d[0..11] to the register overload.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11],
              scale_d);
    }

    // Tile shape; kNumAccum = f32 accumulators per thread (128-thread warpgroup).
    static constexpr int M = 64;
    static constexpr int N = 24;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};
70
+
71
// Hopper warpgroup async MMA (wgmma) wrapper: 64x32 f32 tile accumulated from
// FP8 (e4m3) operands over k=32. "SS": both A and B are read from shared
// memory through 64-bit matrix descriptors.
struct SM90_64x32x32_F32E4M3E4M3_SS {
    // Register overload; `scale_d` (via predicate `p`) keeps or clears the
    // previous accumulator contents.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %18, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15},"
                     " %16,"
                     " %17,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    // Array overload: forwards d[0..15] to the register overload.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              scale_d);
    }

    // Tile shape; kNumAccum = f32 accumulators per thread (128-thread warpgroup).
    static constexpr int M = 64;
    static constexpr int N = 32;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};
103
+
104
// Hopper warpgroup async MMA (wgmma) wrapper: 64x40 f32 tile accumulated from
// FP8 (e4m3) operands over k=32. "SS": both A and B are read from shared
// memory through 64-bit matrix descriptors.
struct SM90_64x40x32_F32E4M3E4M3_SS {
    // Register overload; `scale_d` (via predicate `p`) keeps or clears the
    // previous accumulator contents.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %22, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19},"
                     " %20,"
                     " %21,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    // Array overload: forwards d[0..19] to the register overload.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19],
              scale_d);
    }

    // Tile shape; kNumAccum = f32 accumulators per thread (128-thread warpgroup).
    static constexpr int M = 64;
    static constexpr int N = 40;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};
140
+
141
// Hopper warpgroup async MMA (wgmma) wrapper: 64x48 f32 tile accumulated from
// FP8 (e4m3) operands over k=32. "SS": both A and B are read from shared
// memory through 64-bit matrix descriptors.
struct SM90_64x48x32_F32E4M3E4M3_SS {
    // Register overload; `scale_d` (via predicate `p`) keeps or clears the
    // previous accumulator contents.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %26, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23},"
                     " %24,"
                     " %25,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    // Array overload: forwards d[0..23] to the register overload.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              scale_d);
    }

    // Tile shape; kNumAccum = f32 accumulators per thread (128-thread warpgroup).
    static constexpr int M = 64;
    static constexpr int N = 48;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};
177
+
178
// Hopper warpgroup async MMA (wgmma) wrapper: 64x56 f32 tile accumulated from
// FP8 (e4m3) operands over k=32. "SS": both A and B are read from shared
// memory through 64-bit matrix descriptors.
struct SM90_64x56x32_F32E4M3E4M3_SS {
    // Register overload; `scale_d` (via predicate `p`) keeps or clears the
    // previous accumulator contents.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %30, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27}, "
                     " %28,"
                     " %29,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    // Array overload: forwards d[0..27] to the register overload.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27],
              scale_d);
    }

    // Tile shape; kNumAccum = f32 accumulators per thread (128-thread warpgroup).
    static constexpr int M = 64;
    static constexpr int N = 56;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};
218
+
219
// Hopper warpgroup async MMA (wgmma) wrapper: 64x64 f32 tile accumulated from
// FP8 (e4m3) operands over k=32. "SS": both A and B are read from shared
// memory through 64-bit matrix descriptors.
struct SM90_64x64x32_F32E4M3E4M3_SS {
    // Register overload; `scale_d` (via predicate `p`) keeps or clears the
    // previous accumulator contents.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %34, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31}, "
                     " %32,"
                     " %33,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    // Array overload: forwards d[0..31] to the register overload.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              scale_d);
    }

    // Tile shape; kNumAccum = f32 accumulators per thread (128-thread warpgroup).
    static constexpr int M = 64;
    static constexpr int N = 64;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};
259
+
260
// Hopper warpgroup async MMA (wgmma) wrapper: 64x72 f32 tile accumulated from
// FP8 (e4m3) operands over k=32. "SS": both A and B are read from shared
// memory through 64-bit matrix descriptors.
struct SM90_64x72x32_F32E4M3E4M3_SS {
    // Register overload; `scale_d` (via predicate `p`) keeps or clears the
    // previous accumulator contents.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 float& d32, float& d33, float& d34, float& d35,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %38, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31, "
                     " %32, %33, %34, %35}, "
                     " %36,"
                     " %37,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    // Array overload: forwards d[0..35] to the register overload.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              d[32], d[33], d[34], d[35],
              scale_d);
    }

    // Tile shape; kNumAccum = f32 accumulators per thread (128-thread warpgroup).
    static constexpr int M = 64;
    static constexpr int N = 72;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};
304
+
305
// Hopper warpgroup async MMA (wgmma) wrapper: 64x80 f32 tile accumulated from
// FP8 (e4m3) operands over k=32. "SS": both A and B are read from shared
// memory through 64-bit matrix descriptors.
struct SM90_64x80x32_F32E4M3E4M3_SS {
    // Register overload; `scale_d` (via predicate `p`) keeps or clears the
    // previous accumulator contents.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %42, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31, "
                     " %32, %33, %34, %35, %36, %37, %38, %39}, "
                     " %40,"
                     " %41,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    // Array overload: forwards d[0..39] to the register overload.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
              scale_d);
    }

    // Tile shape; kNumAccum = f32 accumulators per thread (128-thread warpgroup).
    static constexpr int M = 64;
    static constexpr int N = 80;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};
349
+
350
// Hopper warpgroup async MMA (wgmma) wrapper: 64x88 f32 tile accumulated from
// FP8 (e4m3) operands over k=32. "SS": both A and B are read from shared
// memory through 64-bit matrix descriptors.
struct SM90_64x88x32_F32E4M3E4M3_SS {
    // Register overload; `scale_d` (via predicate `p`) keeps or clears the
    // previous accumulator contents.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
                                 float& d40, float& d41, float& d42, float& d43,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %46, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31, "
                     " %32, %33, %34, %35, %36, %37, %38, %39, "
                     " %40, %41, %42, %43}, "
                     " %44,"
                     " %45,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
                       "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    // Array overload: forwards d[0..43] to the register overload.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
              d[40], d[41], d[42], d[43],
              scale_d);
    }

    // Tile shape; kNumAccum = f32 accumulators per thread (128-thread warpgroup).
    static constexpr int M = 64;
    static constexpr int N = 88;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};
398
+
399
// Hopper warpgroup async MMA (wgmma) wrapper: 64x96 f32 tile accumulated from
// FP8 (e4m3) operands over k=32. "SS": both A and B are read from shared
// memory through 64-bit matrix descriptors.
struct SM90_64x96x32_F32E4M3E4M3_SS {
    // Register overload; `scale_d` (via predicate `p`) keeps or clears the
    // previous accumulator contents.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
                                 float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %50, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31, "
                     " %32, %33, %34, %35, %36, %37, %38, %39, "
                     " %40, %41, %42, %43, %44, %45, %46, %47}, "
                     " %48,"
                     " %49,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
                       "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    // Array overload: forwards d[0..47] to the register overload.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
              d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
              scale_d);
    }

    // Tile shape; kNumAccum = f32 accumulators per thread (128-thread warpgroup).
    static constexpr int M = 64;
    static constexpr int N = 96;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};
447
+
448
// Hopper warpgroup async MMA (wgmma) wrapper: 64x104 f32 tile accumulated from
// FP8 (e4m3) operands over k=32. "SS": both A and B are read from shared
// memory through 64-bit matrix descriptors.
struct SM90_64x104x32_F32E4M3E4M3_SS {
    // Register overload; `scale_d` (via predicate `p`) keeps or clears the
    // previous accumulator contents.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
                                 float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
                                 float& d48, float& d49, float& d50, float& d51,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %54, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31, "
                     " %32, %33, %34, %35, %36, %37, %38, %39, "
                     " %40, %41, %42, %43, %44, %45, %46, %47, "
                     " %48, %49, %50, %51}, "
                     " %52,"
                     " %53,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
                       "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
                       "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    // Array overload: forwards d[0..51] to the register overload.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
              d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
              d[48], d[49], d[50], d[51],
              scale_d);
    }

    // Tile shape; kNumAccum = f32 accumulators per thread (128-thread warpgroup).
    static constexpr int M = 64;
    static constexpr int N = 104;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};
500
+
501
// Hopper warpgroup async MMA (wgmma) wrapper: 64x112 f32 tile accumulated from
// FP8 (e4m3) operands over k=32. "SS": both A and B are read from shared
// memory through 64-bit matrix descriptors.
struct SM90_64x112x32_F32E4M3E4M3_SS {
    // Register overload; `scale_d` (via predicate `p`) keeps or clears the
    // previous accumulator contents.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
                                 float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
                                 float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %58, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31, "
                     " %32, %33, %34, %35, %36, %37, %38, %39, "
                     " %40, %41, %42, %43, %44, %45, %46, %47, "
                     " %48, %49, %50, %51, %52, %53, %54, %55}, "
                     " %56,"
                     " %57,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
                       "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
                       "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    // Array overload: forwards d[0..55] to the register overload.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
              d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
              d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
              scale_d);
    }

    // Tile shape; kNumAccum = f32 accumulators per thread (128-thread warpgroup).
    static constexpr int M = 64;
    static constexpr int N = 112;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};
553
+
554
// Hopper warpgroup async MMA (wgmma) wrapper: 64x120 f32 tile accumulated from
// FP8 (e4m3) operands over k=32. "SS": both A and B are read from shared
// memory through 64-bit matrix descriptors.
struct SM90_64x120x32_F32E4M3E4M3_SS {
    // Register overload; `scale_d` (via predicate `p`) keeps or clears the
    // previous accumulator contents.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
                                 float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
                                 float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
                                 float& d56, float& d57, float& d58, float& d59,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %62, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31, "
                     " %32, %33, %34, %35, %36, %37, %38, %39, "
                     " %40, %41, %42, %43, %44, %45, %46, %47, "
                     " %48, %49, %50, %51, %52, %53, %54, %55, "
                     " %56, %57, %58, %59}, "
                     " %60,"
                     " %61,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
                       "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
                       "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
                       "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    // Array overload: forwards d[0..59] to the register overload.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
              d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
              d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
              d[56], d[57], d[58], d[59],
              scale_d);
    }

    // Tile shape; kNumAccum = f32 accumulators per thread (128-thread warpgroup).
    static constexpr int M = 64;
    static constexpr int N = 120;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};
610
+
611
// Hopper warpgroup async MMA (wgmma) wrapper: 64x128 f32 tile accumulated from
// FP8 (e4m3) operands over k=32. "SS": both A and B are read from shared
// memory through 64-bit matrix descriptors.
struct SM90_64x128x32_F32E4M3E4M3_SS {
    // Register overload; `scale_d` (via predicate `p`) keeps or clears the
    // previous accumulator contents.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
                                 float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
                                 float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
                                 float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
                                 bool scale_d) {
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %66, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31, "
                     " %32, %33, %34, %35, %36, %37, %38, %39, "
                     " %40, %41, %42, %43, %44, %45, %46, %47, "
                     " %48, %49, %50, %51, %52, %53, %54, %55, "
                     " %56, %57, %58, %59, %60, %61, %62, %63}, "
                     " %64,"
                     " %65,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
                       "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
                       "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
                       "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    // Array overload: forwards d[0..63] to the register overload.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
              d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
              d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
              d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
              scale_d);
    }

    // Tile shape; kNumAccum = f32 accumulators per thread (128-thread warpgroup).
    static constexpr int M = 64;
    static constexpr int N = 128;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};
667
+
668
// WGMMA wrapper: one warpgroup-wide m64n192k32 MMA with FP8 e4m3 inputs and
// FP32 accumulators; both operands come from shared memory ("SS" variant).
struct SM90_64x192x32_F32E4M3E4M3_SS {
    // Issue a single `wgmma.mma_async` over 96 f32 accumulators. When
    // `scale_d` is false, predicate `p` makes the instruction overwrite the
    // accumulators instead of adding to them.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
                                 float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
                                 float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
                                 float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
                                 float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
                                 float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
                                 float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
                                 float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
                                 float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
                                 float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
                                 float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79,
                                 float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87,
                                 float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95,
                                 bool scale_d) {
        // %0-%95: accumulators, %96/%97: descriptors, %98: scale-d flag
        asm volatile("{\n"
                     ".reg .pred p;\n"
                     "setp.ne.b32 p, %98, 0;\n"
                     "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3"
                     "{%0, %1, %2, %3, %4, %5, %6, %7, "
                     " %8, %9, %10, %11, %12, %13, %14, %15, "
                     " %16, %17, %18, %19, %20, %21, %22, %23, "
                     " %24, %25, %26, %27, %28, %29, %30, %31, "
                     " %32, %33, %34, %35, %36, %37, %38, %39, "
                     " %40, %41, %42, %43, %44, %45, %46, %47, "
                     " %48, %49, %50, %51, %52, %53, %54, %55, "
                     " %56, %57, %58, %59, %60, %61, %62, %63, "
                     " %64, %65, %66, %67, %68, %69, %70, %71, "
                     " %72, %73, %74, %75, %76, %77, %78, %79, "
                     " %80, %81, %82, %83, %84, %85, %86, %87, "
                     " %88, %89, %90, %91, %92, %93, %94, %95}, "
                     " %96,"
                     " %97,"
                     " p , 1, 1;\n"
                     "}\n"
                     : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
                       "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
                       "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
                       "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
                       "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
                       "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
                       "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
                       "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
                       "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
                       "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
                       "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
                       "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
                     : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
    }

    // Array-based convenience overload over the 96 accumulator registers.
    __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        wgmma(desc_a, desc_b,
              d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
              d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
              d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
              d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
              d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
              d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
              d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
              d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
              d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
              d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79],
              d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87],
              d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95],
              scale_d);
    }

    static constexpr int M = 64;                   // tile rows
    static constexpr int N = 192;                  // tile columns
    static constexpr int K = 32;                   // reduction depth per instruction
    static constexpr int kNumAccum = M * N / 128;  // f32 accumulators per thread
};
740
+
741
// Store two 32-bit values to shared memory with `stmatrix.x2` (warp-collective
// 8x8 matrix store). `dtype_t` must be a 32-bit type; the values are
// reinterpreted bitwise as u32.
template <typename dtype_t>
struct SM90_U32x2_STSM_N {
    __device__ __forceinline__ static void
    copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
        const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
        asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
                     :: "l"(smem_dst), "r"(src[0]), "r"(src[1]));
    }
};
750
+
751
// Store four 32-bit values to shared memory with `stmatrix.x4` (warp-collective
// 8x8 matrix store). `dtype_t` must be a 32-bit type; values are
// reinterpreted bitwise as u32.
template <typename dtype_t>
struct SM90_U32x4_STSM_N {
    __device__ __forceinline__ static void
    copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) {
        const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
                                 *reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
        asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
                     :: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
    }
};
761
+
762
// Fence register accesses before the warpgroup starts issuing WGMMA ops
// (`wgmma.fence` must precede the first `wgmma.mma_async` of a batch).
__device__ void warpgroup_arrive() {
    asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
}
765
+
766
// Commit all previously issued `wgmma.mma_async` operations into one group
// that can later be waited on with `warpgroup_wait<N>()`.
__device__ void warpgroup_commit_batch() {
    asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
}
769
+
770
// Compiler-only fence on a single accumulator register: the empty asm with a
// "+f" constraint stops the compiler from reordering reads/writes of `reg`
// across in-flight WGMMA operations.
__device__ void warpgroup_fence_operand(float& reg) {
    asm volatile("" : "+f"(reg) :: "memory");
}
773
+
774
// Return the calling thread's lane index within its warp, read from the
// `%laneid` special register.
__forceinline__ __device__ uint32_t get_lane_id() {
    uint32_t lane_id;
    asm("mov.u32 %0, %laneid;" : "=r"(lane_id));
    return lane_id;
}
779
+
780
// Load one u32 from shared memory via inline PTX.
// NOTE(review): the pointer is bound with an "l" constraint to a `.shared`
// access — assumes `ptr` is valid in the shared state space; confirm callers.
__device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) {
    uint32_t ret;
    asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr));
    return ret;
}
785
+
786
// Load one int4 (four s32 lanes, vectorized) from shared memory.
__device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) {
    int4 ret;
    asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
    return ret;
}
791
+
792
// Load one f32 from shared memory via inline PTX.
__device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
    float ret;
    asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
    return ret;
}
797
+
798
// Store one f32 to shared memory via inline PTX.
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
    asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
}
801
+
802
// Store one u32 to shared memory via inline PTX.
__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
    asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
}
805
+
806
// Block until at most `N` committed WGMMA groups remain outstanding
// (`N == 0` waits for all of them). The ISA limits N to [0, 7].
template <int N>
__device__ void warpgroup_wait() {
    DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
    asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
}
811
+
812
// 64-bit WGMMA shared-memory matrix descriptor, viewable either as raw
// integers (`desc_`, `reg32_`, `reg16_`) or as its packed bitfields.
// Address and offset fields are stored in 16-byte units (see `make_smem_desc`).
union GmmaDescriptor {
    __host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {}

    // Construct directly from a raw 64-bit descriptor value.
    __host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {}

    __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {}

    __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {}

    __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept {
        desc_ = t.desc_;
        return *this;
    }

    __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept {
        desc_ = t.desc_;
        return *this;
    }

    uint64_t desc_;      // full 64-bit descriptor
    uint32_t reg32_[2];  // as two 32-bit halves
    uint16_t reg16_[4];  // as four 16-bit quarters

    // Packed field layout; unnamed members are reserved/padding bits.
    struct {
        uint16_t start_address_: 14, : 2;        // smem start address / 16
        uint16_t leading_byte_offset_: 14, : 2;  // leading-dim offset / 16
        uint16_t stride_byte_offset_: 14, : 2;   // stride offset / 16
        uint8_t : 1, base_offset_: 3, : 4;       // base offset
        uint8_t : 6, layout_type_: 2;            // swizzle/layout selector
    } bitfield;

    // Decay to an `uint64_t`
    __host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; }
};
846
+
847
// Build a WGMMA descriptor for a shared-memory operand.
// The generic pointer is converted to a shared-space address, and all
// address/offset fields are encoded divided by 16 (hardware granularity).
template <class PointerType>
__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type,
                                         int leading_byte_offset = 0,
                                         int stride_byte_offset = 1024) {
    GmmaDescriptor desc;
    auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    desc.bitfield.start_address_ = uint_ptr >> 4;
    desc.bitfield.layout_type_ = layout_type;
    desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
    desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
    desc.bitfield.base_offset_ = 0;
    return desc;
}
860
+
861
// Compile-time map from a tile width `N` to the matching FP8 WGMMA wrapper.
// NOTE(review): an unsupported `N` matches no branch, so `select_type()`
// returns void and `type` becomes `void` — misuse fails later, not here.
template <int N>
struct FP8MMASelector {
    static constexpr auto select_type() {
        if constexpr (N == 16) return SM90_64x16x32_F32E4M3E4M3_SS();
        if constexpr (N == 24) return SM90_64x24x32_F32E4M3E4M3_SS();
        if constexpr (N == 32) return SM90_64x32x32_F32E4M3E4M3_SS();
        if constexpr (N == 40) return SM90_64x40x32_F32E4M3E4M3_SS();
        if constexpr (N == 48) return SM90_64x48x32_F32E4M3E4M3_SS();
        if constexpr (N == 56) return SM90_64x56x32_F32E4M3E4M3_SS();
        if constexpr (N == 64) return SM90_64x64x32_F32E4M3E4M3_SS();
        if constexpr (N == 72) return SM90_64x72x32_F32E4M3E4M3_SS();
        if constexpr (N == 80) return SM90_64x80x32_F32E4M3E4M3_SS();
        if constexpr (N == 88) return SM90_64x88x32_F32E4M3E4M3_SS();
        if constexpr (N == 96) return SM90_64x96x32_F32E4M3E4M3_SS();
        if constexpr (N == 104) return SM90_64x104x32_F32E4M3E4M3_SS();
        if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS();
        if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS();
        if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS();
        if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS();
    }

    using type = decltype(select_type());
};
884
+
885
+ } // namespace deep_gemm
torch-ext/deep_gemm/include/deep_gemm/scheduler.cuh ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "utils.cuh"
2
+
3
+ namespace deep_gemm {
4
+
5
// Kind of GEMM the scheduler serves.
enum class GemmType {
    Normal,             // single plain GEMM
    GroupedContiguous,  // grouped GEMM, groups packed contiguously (uses grouped_layout offsets)
    GroupedMasked       // grouped GEMM with per-group runtime sizes read from grouped_layout
};
10
+
11
#pragma clang diagnostic push
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
// Persistent-kernel tile scheduler: maps linear block iterations onto
// (m_block, n_block) output tiles, swizzling along N for better L2 reuse,
// with optional grouped (contiguous / masked) layouts.
template <GemmType kGemmType,
          uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
          uint32_t kNumGroups, uint32_t kNumTMAMulticast,
          uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
          uint32_t kNumNBlocksPerGroup = 16>
struct Scheduler {
    int current_iter = -1;          // persistent-loop iteration, pre-incremented on each fetch
    uint32_t num_aligned_m_blocks;  // ceil_div(shape_m, BLOCK_M)

    // For normal GEMM
    // Maybe not used in the masked grouped GEMM
    uint32_t num_blocks;

    // For grouped GEMM
    int* grouped_layout;
    // Only used for masked layout
    uint32_t curr_group_idx, curr_cumsum;

    __device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
                                                  int* grouped_layout = nullptr) {
        num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M);
        if constexpr (kGemmType == GemmType::Normal) {
            num_blocks = num_aligned_m_blocks * kNumNBlocks;
        } else if (kGemmType == GemmType::GroupedContiguous) {
            num_blocks = num_aligned_m_blocks * kNumNBlocks;
            this->grouped_layout = grouped_layout;
        } else if (kGemmType == GemmType::GroupedMasked) {
            // Total work depends on per-group sizes read at runtime, so the
            // block count is resolved incrementally in `get_next_block`
            curr_group_idx = curr_cumsum = 0;
            this->grouped_layout = grouped_layout;
        }
    }

    // Convert a linear block index into swizzled (m, n) tile coordinates:
    // blocks are grouped into bands of `kNumNBlocksPerGroup` N-tiles so that
    // consecutive CTAs touch nearby columns.
    __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
        DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");

        // Swizzle for better L2 usages
        auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
        auto group_idx = block_idx / num_blocks_per_group;
        auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
        // The last band may hold fewer than kNumNBlocksPerGroup N-tiles
        auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
        auto in_group_idx = block_idx % num_blocks_per_group;
        m_block_idx = in_group_idx / num_n_blocks_in_group;
        n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
    }

    // Translate a block index into a global element offset along one
    // dimension, accounting for the grouped layouts.
    template <bool kIgnoreGroupedForGroupedContiguous=true>
    __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
                                                       const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
        if constexpr (kGemmType == GemmType::Normal) {
            return block_idx * block_size;
        } else if (kGemmType == GemmType::GroupedContiguous) {
            auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
            return offset * shape_dim + block_idx * block_size;
        } else if (kGemmType == GemmType::GroupedMasked) {
            return curr_group_idx * shape_dim + block_idx * block_size;
        }
    }

    // Fetch the next tile for this CTA; returns false when all work is done.
    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
        const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x;

        if constexpr (kGemmType == GemmType::GroupedMasked) {
            uint32_t num_m_blocks;
            while (true) {
                // End of the task
                if (curr_group_idx == kNumGroups)
                    return false;

                // Within current group
                num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
                auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
                if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
                    break;

                // Move to check the next group
                curr_group_idx ++, curr_cumsum = current_m_block_cumsum;
            }

            get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
        } else {
            if (next_block_idx >= num_blocks)
                return false;

            get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
        }
        return true;
    }
};
#pragma clang diagnostic pop
102
+
103
+ } // namespace deep_gemm
torch-ext/deep_gemm/include/deep_gemm/tma_utils.cuh ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cassert>
4
+ #include <cuda.h>
5
+ #include <cudaTypedefs.h>
6
+ #include <cuda_fp8.h>
7
+ #include <cuda_runtime.h>
8
+ #include <cuda/barrier>
9
+
10
+ #include "utils.cuh"
11
+
12
+ namespace deep_gemm {
13
+
14
// Map a C++ element type to the corresponding `CUtensorMapDataType` for TMA
// descriptor creation. FP8 types are described as UINT8 (same width; TMA
// moves raw bytes).
// NOTE(review): an unlisted `T` falls off the end without a return value —
// a hard compiler error/UB only if ever instantiated with such a type.
template <class T>
constexpr CUtensorMapDataType get_CUtensorMapDataType() {
    if constexpr (std::is_same<T, uint8_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    } else if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    } else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    } else if constexpr (std::is_same<T, uint16_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT16;
    } else if constexpr (std::is_same<T, uint32_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT32;
    } else if constexpr (std::is_same<T, uint64_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT64;
    } else if constexpr (std::is_same<T, int32_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_INT32;
    } else if constexpr (std::is_same<T, int64_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_INT64;
    } else if constexpr (std::is_same<T, __half>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
    } else if constexpr (std::is_same<T, float>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
    } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
    } else if constexpr (std::is_same<T, double>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
    }
}
+
43
+ PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
44
+ // Get pointer to `cuTensorMapEncodeTiled`
45
+ cudaDriverEntryPointQueryResult driver_status;
46
+ void* cuTensorMapEncodeTiled_ptr = nullptr;
47
+
48
+ #if CUDA_VERSION >= 12050
49
+ cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000,
50
+ cudaEnableDefault, &driver_status);
51
+ #else
52
+ cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr,
53
+ cudaEnableDefault, &driver_status);
54
+ #endif
55
+
56
+ if (driver_status != cudaDriverEntryPointSuccess)
57
+ throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
58
+ return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
59
+ }
60
+
61
+ template <typename T>
62
+ CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
63
+ uint64_t stride_in_bytes, uint32_t smem_dim[2],
64
+ CUtensorMapSwizzle swizzle_type,
65
+ PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
66
+ CUtensorMap tensor_map{};
67
+ constexpr uint32_t rank = 2;
68
+ uint64_t global_stride[rank - 1] = {stride_in_bytes};
69
+ uint32_t elem_strides[rank] = {1, 1};
70
+
71
+ if (encode_func == nullptr)
72
+ encode_func = get_cuTensorMapEncodeTiled();
73
+
74
+ auto result = encode_func(
75
+ &tensor_map, get_CUtensorMapDataType<typename std::remove_cv<T>::type>(), rank,
76
+ global_address, gmem_dim, global_stride, smem_dim, elem_strides,
77
+ CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
78
+ CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
79
+ CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
80
+ DG_HOST_ASSERT(result == CUDA_SUCCESS);
81
+ return tensor_map;
82
+ }
83
+
84
+ template <uint32_t kNumTMAMulticast = 1>
85
+ __device__ __forceinline__ void
86
+ tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
87
+ int32_t const& crd_0, int32_t const& crd_1) {
88
+ constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
89
+ if constexpr (kNumTMAMulticast == 1) {
90
+ cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
91
+ } else if (cute::block_rank_in_cluster() == 0) {
92
+ cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << kNumTMAMulticast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
93
+ }
94
+ }
95
+
96
+ } // namespace deep_gemm
torch-ext/deep_gemm/include/deep_gemm/utils.cuh ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <exception>
4
+
5
// CLion-IDE-only shim: give the IDE's parser a host+device `printf` symbol so
// device code indexes cleanly; the stub traps if it were ever executed.
#ifdef __CLION_IDE__
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
#define printf host_device_printf
#endif
9
+
10
#include <string>  // Fix: std::string was used with only <exception> included above

// Host-side exception thrown by DG_HOST_ASSERT, carrying the failed
// condition text so Python callers see a message instead of an abort.
class AssertionException : public std::exception {
private:
    std::string message{};

public:
    explicit AssertionException(const std::string& message) : message(message) {}

    // Expose the stored message; valid for the exception's lifetime.
    const char *what() const noexcept override { return message.c_str(); }
};
19
+
20
// Host-side assertion: print the failing location, then throw
// `AssertionException` (recoverable by Python callers, unlike abort()).
#ifndef DG_HOST_ASSERT
#define DG_HOST_ASSERT(cond) \
do { \
    if (not (cond)) { \
        printf("Assertion failed: %s:%d, condition: %s\n", \
               __FILE__, __LINE__, #cond); \
        throw AssertionException("Assertion failed: " #cond); \
    } \
} while (0)
#endif
30
+
31
// Device-side assertion: print the failing location, then trap the kernel
// (exceptions are unavailable in device code).
#ifndef DG_DEVICE_ASSERT
#define DG_DEVICE_ASSERT(cond) \
do { \
    if (not (cond)) { \
        printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
        asm("trap;"); \
    } \
} while (0)
#endif
40
+
41
// Compile-time assertion with a mandatory human-readable reason.
#ifndef DG_STATIC_ASSERT
#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
#endif
44
+
45
// Integer ceiling division: smallest integer >= a / b.
// NOTE(review): assumes non-negative `a`, positive `b`; `a + b - 1` can
// overflow near the top of T's range — callers use small tile counts.
template <typename T>
__device__ __host__ constexpr T ceil_div(T a, T b) {
    return (a + b - 1) / b;
}
torch-ext/deep_gemm/jit/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .compiler import get_nvcc_compiler, build
2
+ from .template import cpp_format, generate
3
+ from .runtime import Runtime
torch-ext/deep_gemm/jit/compiler.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import functools
3
+ import os
4
+ import re
5
+ import subprocess
6
+ import uuid
7
+ from torch.utils.cpp_extension import CUDA_HOME
8
+ from typing import Tuple
9
+
10
+ from . import interleave_ffma
11
+ from .runtime import Runtime, RuntimeCache
12
+ from .template import typename_map
13
+
14
+ runtime_cache = RuntimeCache()
15
+
16
+
17
def hash_to_hex(s: str) -> str:
    """Return the first 12 hex digits of the MD5 digest of *s* (UTF-8)."""
    return hashlib.md5(s.encode('utf-8')).hexdigest()[:12]
21
+
22
+
23
@functools.lru_cache(maxsize=None)
def get_jit_include_dir() -> str:
    """Return the path of the bundled CUDA `include` directory (cached)."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return f'{module_dir}/../include'
26
+
27
+
28
@functools.lru_cache(maxsize=None)
def get_deep_gemm_version() -> str:
    """Return a 12-hex-digit hash identifying the current kernel sources.

    The hash covers every ``.cuh`` header under the JIT include directory
    plus ``interleave_ffma.py``, so any change to them invalidates cached
    compiled kernels.
    """
    # Update include directories
    include_dir = f'{get_jit_include_dir()}/deep_gemm'
    assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
    md5 = hashlib.md5()
    # Sorted order keeps the digest deterministic across runs
    for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
        # Fix: hash each listed header file; previously a hard-coded wrong
        # path was opened instead of `filename`, so header edits were missed
        with open(f'{include_dir}/{filename}', 'rb') as f:
            md5.update(f.read())

    # Update `interleave_ffma.py` (it post-processes the SASS, so it is part
    # of the effective kernel version too)
    with open(f'{os.path.dirname(os.path.realpath(__file__))}/interleave_ffma.py', 'rb') as f:
        md5.update(f.read())
    return md5.hexdigest()[0:12]
42
+
43
+
44
@functools.lru_cache(maxsize=None)
def get_nvcc_compiler() -> Tuple[str, str]:
    """Locate an NVCC compiler and return ``(path, version_string)``.

    Checks ``$DG_NVCC_COMPILER`` first, then ``$CUDA_HOME/bin/nvcc``.
    Raises RuntimeError when no compiler exists, and asserts when the first
    one found is older than the minimum supported release.
    """
    paths = []
    if os.getenv('DG_NVCC_COMPILER'):
        paths.append(os.getenv('DG_NVCC_COMPILER'))
    paths.append(f'{CUDA_HOME}/bin/nvcc')

    # Try to find the first available NVCC compiler
    least_version_required = '12.3'
    version_pattern = re.compile(r'release (\d+\.\d+)')

    def version_tuple(v: str):
        # Compare numerically; fix: plain string comparison mis-orders
        # versions such as '12.10' vs '12.3'
        return tuple(int(part) for part in v.split('.'))

    for path in paths:
        if not os.path.exists(path):
            continue
        match = version_pattern.search(os.popen(f'{path} --version').read())
        # Fix: validate the match *before* dereferencing it; previously
        # `match.group(1)` ran first and raised AttributeError instead of
        # this assertion message
        assert match, f'Cannot get the version of NVCC compiler {path}'
        version = match.group(1)
        assert version_tuple(version) >= version_tuple(least_version_required), \
            f'NVCC {path} version {version} is lower than {least_version_required}'
        return path, version
    raise RuntimeError('Cannot find any available NVCC compiler')
62
+
63
+
64
@functools.lru_cache(maxsize=None)
def get_default_user_dir():
    """Return the DeepGEMM cache root: ``$DG_CACHE_DIR`` or ``~/.deep_gemm``."""
    if 'DG_CACHE_DIR' in os.environ:
        cache_root = os.getenv('DG_CACHE_DIR')
        # Ensure the user-specified directory exists before handing it out
        os.makedirs(cache_root, exist_ok=True)
        return cache_root
    return os.path.expanduser('~') + '/.deep_gemm'
71
+
72
+
73
@functools.lru_cache(maxsize=None)
def get_tmp_dir():
    """Return the directory used for temporary build artifacts."""
    return get_default_user_dir() + '/tmp'
76
+
77
+
78
@functools.lru_cache(maxsize=None)
def get_cache_dir():
    """Return the directory holding compiled-kernel cache entries."""
    return get_default_user_dir() + '/cache'
81
+
82
+
83
def make_tmp_dir():
    """Create the temporary-file directory if needed and return its path."""
    path = get_tmp_dir()
    os.makedirs(path, exist_ok=True)
    return path
87
+
88
+
89
def put(path, data, is_binary=False):
    """Write *data* to *path* atomically (temp file + POSIX rename)."""
    # Write and do POSIX atomic replace
    temp_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}'
    mode = 'wb' if is_binary else 'w'
    with open(temp_path, mode) as handle:
        handle.write(data)
    # `os.replace` is atomic on POSIX, so readers never see a partial file
    os.replace(temp_path, path)
95
+
96
+
97
def build(name: str, arg_defs: tuple, code: str) -> Runtime:
    """JIT-compile *code* into a cached shared object and return a `Runtime`.

    The cache key hashes the kernel name, DeepGEMM source version, generated
    code, compiler identity, flags, and the SASS-optimization switch, so any
    relevant change triggers a rebuild.
    """
    # Compiler flags
    nvcc_flags = ['-std=c++17', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
                  '-gencode=arch=compute_90a,code=sm_90a',
                  '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
                  # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
                  '--diag-suppress=177,174,940']
    cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi']
    flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
    include_dirs = [get_jit_include_dir()]

    # Build signature
    # NOTE(review): `<=` here is a lexicographic *string* comparison of the
    # NVCC version (e.g. '12.10' sorts below '12.8') — confirm intended
    enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
    signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
    name = f'kernel.{name}.{hash_to_hex(signature)}'
    path = f'{get_cache_dir()}/{name}'

    # Check runtime cache or file system hit
    global runtime_cache
    if runtime_cache[path] is not None:
        if os.getenv('DG_JIT_DEBUG', None):
            print(f'Using cached JIT runtime {name} during build')
        return runtime_cache[path]

    # Write the code (atomically, via `put`) so a valid cache dir is observable
    os.makedirs(path, exist_ok=True)
    args_path = f'{path}/kernel.args'
    src_path = f'{path}/kernel.cu'
    put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs]))
    put(src_path, code)

    # Compile into a temporary SO file
    so_path = f'{path}/kernel.so'
    tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so'

    # Compile
    command = [get_nvcc_compiler()[0],
               src_path, '-o', tmp_so_path,
               *flags,
               *[f'-I{d}' for d in include_dirs]]
    if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
        print(f'Compiling JIT runtime {name} with command {command}')
    # `check_call` raises CalledProcessError on failure and returns 0 on
    # success, so the assert below is a belt-and-braces safeguard
    return_code = subprocess.check_call(command)
    assert return_code == 0, f'Failed to compile {src_path}'

    # Interleave FFMA reuse (SASS-level post-processing of the binary)
    if enable_sass_opt:
        interleave_ffma.process(tmp_so_path)

    # Atomic replace SO file
    os.replace(tmp_so_path, so_path)

    # Put cache and return
    runtime_cache[path] = Runtime(path)
    return runtime_cache[path]
torch-ext/deep_gemm/jit/interleave_ffma.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import mmap
3
+ import os
4
+ import re
5
+ import subprocess
6
+ from torch.utils.cpp_extension import CUDA_HOME
7
+
8
+
9
def run_cuobjdump(file_path):
    """Disassemble *file_path* with `cuobjdump -sass` and return the SASS text."""
    completed = subprocess.run([f'{CUDA_HOME}/bin/cuobjdump', '-sass', file_path],
                               stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    assert completed.returncode == 0
    return completed.stdout
14
+
15
+
16
def extract_ffma(sass):
    """Collect contiguous FFMA runs from cuobjdump SASS output.

    Returns a list of ``(f'{arch}::{function}', lines)`` tuples, where
    ``lines`` is an even-length run of dump lines (two per instruction) of
    at least 16 lines, i.e. 8+ consecutive FFMAs worth post-processing.
    """
    collected = []
    current = []

    arch_name, func_name = 'N/A', 'N/A'
    skip_next_line = False

    def flush():
        nonlocal current
        if len(current) >= 16:
            assert len(current) % 2 == 0
            collected.append((f'{arch_name}::{func_name}', current))
        # Fix: always drop short runs; previously they were kept and leaked
        # into the next segment, corrupting its instruction-line pairing
        current = []

    for line in sass.splitlines():
        if 'code for' in line:
            # Fix: `str.lstrip('code for ')` strips a *character set* and
            # could eat leading characters of the architecture/function name;
            # split on the literal marker instead
            arch_name = line.split('code for', 1)[1].strip()
        elif 'Function :' in line:
            func_name = line.split('Function :', 1)[1].strip()
        elif 'FFMA' in line:
            current.append(line)
            # Each FFMA spans two dump lines; grab the continuation line too
            skip_next_line = True
        elif skip_next_line:
            current.append(line)
            skip_next_line = False
        else:
            flush()
    # Fix: flush a trailing run at end-of-output (previously lost)
    flush()

    if os.getenv('DG_PRINT_REG_REUSE', None):
        print(f'Found {len(collected)} FFMA segments')
    return collected
43
+
44
+
45
def extract_hex_from_line(line):
    """Return the `/* 0x... */` instruction-encoding comment in *line* as an int."""
    found = re.search(r'/\*\s*(0x[0-9a-fA-F]+)\s*\*/', line)
    assert found
    return int(found.group(1), 16)
49
+
50
+
51
def validate(m, offset, le_bytes, num_lines):
    """Check that `num_lines // 2` consecutive 16-byte chunks of *m* starting
    at *offset* equal *le_bytes*; the first chunk is asserted (callers only
    pass offsets found via that chunk)."""
    num_chunks = num_lines // 2
    assert len(le_bytes) == num_chunks
    assert m[offset:offset + 16] == le_bytes[0]
    return all(m[offset + i * 16:offset + (i + 1) * 16] == le_bytes[i]
               for i in range(1, num_chunks))
58
+
59
+
60
def parse_registers(line):
    """Extract register operands (e.g. 'R16') from one SASS line.

    Encoding comments and the trailing ';' are stripped; modifier suffixes
    such as '.reuse' are dropped from each register name.
    """
    cleaned = re.sub(r'/\*.*?\*/', '', line).replace(';', '')
    registers = []
    for operand in cleaned.strip().split(','):
        for word in operand.split():
            if word.startswith('R'):
                registers.append(word.split('.')[0])
    return registers
73
+
74
+
75
def modify_segment(m, name, ffma_lines):
    """Rewrite reuse/yield bits of one FFMA segment in-place in buffer *m*.

    *m* is a writable byte buffer (mmap) over the shared object; *ffma_lines*
    are the SASS dump lines (two per instruction) of one segment. The goal is
    to interleave register reuse: keep a reuse only when its benefit is real,
    clearing the flag on first occurrences and back-to-back same-register
    reuses.
    """
    num_lines = len(ffma_lines)
    assert num_lines % 2 == 0

    le_bytes, new_le_bytes = [], []
    reused_list = []
    dst_reg_set = set()
    last_reused, last_dst_reg = False, ''
    num_changed = 0
    for i in range(num_lines // 2):
        # Destination register is the second-to-last register operand
        dst_reg = parse_registers(ffma_lines[i * 2])[-2]
        low_line, high_line = ffma_lines[i * 2], ffma_lines[i * 2 + 1]
        low_hex, high_hex = extract_hex_from_line(low_line), extract_hex_from_line(high_line)
        # Original 128-bit instruction encoding, little-endian
        le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
        # Mask 0x0800...0 in the high word is the reuse flag
        reused = (high_hex & 0x0800000000000000) != 0
        if reused:
            is_first_occurred = dst_reg not in dst_reg_set
            if is_first_occurred or (last_reused and dst_reg == last_dst_reg):
                # Modify the `reuse` and `yield` bits
                assert high_hex & 0x0800200000000000, f'{hex(high_hex)}'
                high_hex ^= 0x0800200000000000
                reused = False
                num_changed += 1
            else:
                reused_list.append(i)
        dst_reg_set.add(dst_reg)
        new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
        last_reused, last_dst_reg = reused, dst_reg
    if os.getenv('DG_PRINT_REG_REUSE', None):
        print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}')

    # Find the offset: locate every occurrence of the segment's first chunk
    offsets = []
    offset = m.find(le_bytes[0])
    while offset != -1:
        offsets.append(offset)
        offset = m.find(le_bytes[0], offset + 1)
    # Keep only offsets where the *whole* segment matches, not just chunk 0
    offsets = list(filter(lambda x: validate(m, x, le_bytes, num_lines), offsets))

    # Replace with `new_le_bytes`
    for offset in offsets:
        for i in range(num_lines // 2):
            m[offset + i * 16:offset + i * 16 + 16] = new_le_bytes[i]
118
+
119
+
120
def process(path):
    """Rewrite FFMA reuse/yield bits of the shared object at *path* in place."""
    if os.getenv('DG_PRINT_REG_REUSE', None):
        print(f'Processing {path}')
    segments = extract_ffma(run_cuobjdump(path))
    with open(path, 'r+b') as f:
        mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_WRITE)
        for segment_name, ffma_lines in segments:
            modify_segment(mm, segment_name, ffma_lines)
        mm.close()
130
+
131
+
132
# Standalone CLI: interleave FFMA register-reuse bits of a compiled kernel SO.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Interleave FFMA reg reuse')
    parser.add_argument('--so', help='Path to the SO file')
    args = parser.parse_args()

    process(args.so)
torch-ext/deep_gemm/jit/runtime.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import os
3
+ import torch
4
+ from typing import Optional
5
+
6
+ from .template import map_ctype
7
+
8
+
9
class Runtime:
    """Wrapper over a compiled JIT kernel directory that lazily loads and launches `kernel.so`."""

    def __init__(self, path: str) -> None:
        self.path = path
        self.lib = None
        self.args = None

        assert self.is_path_valid(self.path)

    @staticmethod
    def is_path_valid(path: str) -> bool:
        """Return whether `path` is an existing directory holding all required build artifacts."""
        # Must be an existing directory ...
        if not (os.path.exists(path) and os.path.isdir(path)):
            return False

        # ... containing the source, the argument spec, and the compiled library
        required = ('kernel.cu', 'kernel.args', 'kernel.so')
        return all(os.path.exists(os.path.join(path, name)) for name in required)

    def __call__(self, *args) -> int:
        """Launch the kernel with `args`, returning the C-side return code."""
        # Lazily load the shared library and the recorded argument spec
        if self.lib is None or self.args is None:
            self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so'))
            with open(os.path.join(self.path, 'kernel.args'), 'r') as f:
                # NOTE(review): `eval` on a file we wrote ourselves at build time; do not
                # point `path` at untrusted directories
                self.args = eval(f.read())

        # Validate every argument against the recorded (name, dtype) spec, then cast
        assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
        cargs = []
        for arg, (name, dtype) in zip(args, self.args):
            if isinstance(arg, torch.Tensor):
                assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
            else:
                assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
            cargs.append(map_ctype(arg))

        # Launch and surface the kernel's return code by reference
        return_code = ctypes.c_int(0)
        self.lib.launch(*cargs, ctypes.byref(return_code))
        return return_code.value
47
+
48
+
49
class RuntimeCache:
    """Maps build directories to `Runtime` objects, probing disk for already-compiled kernels."""

    def __init__(self) -> None:
        self.cache = {}

    def __getitem__(self, path: str) -> Optional[Runtime]:
        # Fast path: already constructed during this process
        try:
            return self.cache[path]
        except KeyError:
            pass

        # Slow path: a previous process may have left a complete build on disk
        if os.path.exists(path) and Runtime.is_path_valid(path):
            runtime = Runtime(path)
            self.cache[path] = runtime
            return runtime
        return None

    def __setitem__(self, path, runtime) -> None:
        self.cache[path] = runtime
torch-ext/deep_gemm/jit/template.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import ctypes
3
+ import os
4
+ import torch
5
+
6
+ from typing import Any, Iterable, Dict, Tuple
7
+
8
+
9
+ # Name map for Python `eval`
10
+ typename_map: Dict[Any, str] = {
11
+ **{t: t.__name__ for t in (bool, int, float)},
12
+ torch.int: 'torch.int',
13
+ torch.float: 'torch.float',
14
+ torch.bfloat16: 'torch.bfloat16',
15
+ torch.float8_e4m3fn: 'torch.float8_e4m3fn',
16
+ torch.cuda.Stream: 'torch.cuda.Stream',
17
+ }
18
+
19
+ # `ctype` map for Python casting
20
+ ctype_map: Dict[Any, Any] = {
21
+ **{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
22
+ **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
23
+ }
24
+
25
+
26
+ # Type map for both Python API and source code usages
27
+ genc_map = {
28
+ bool: ('bool', 'bool'),
29
+ int: ('int', 'int'),
30
+ float: ('float', 'float'),
31
+ torch.int: ('void*', 'int*'),
32
+ torch.float: ('void*', 'float*'),
33
+ torch.bfloat16: ('void*', '__nv_bfloat16*'),
34
+ torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
35
+ torch.cuda.Stream: ('void*', 'cudaStream_t'),
36
+ }
37
+
38
+
39
def map_ctype(value: Any) -> Any:
    """Convert a Python/PyTorch launch argument into the `ctypes` value passed to `launch`."""
    if isinstance(value, torch.Tensor):
        # Tensors are passed by raw device address
        return ctype_map[value.dtype](value.data_ptr())

    ctype = ctype_map[type(value)]
    if isinstance(value, torch.cuda.Stream):
        # Streams are passed as their underlying CUDA stream handle
        return ctype(value.cuda_stream)
    # Plain scalars (bool/int/float) map directly
    return ctype(value)
46
+
47
+
48
def cpp_format(template: str, keys: Dict[str, Any]) -> str:
    """
    Substitute `{key}` placeholders in a C++ source template.

    `str.format` is deliberately avoided because C++ sources are full of literal
    `{}` braces; plain `str.replace` leaves those untouched.

    Arguments:
        template: the C++ source template containing `{key}` placeholders.
        keys: mapping from placeholder name to the value substituted as `str(value)`.

    Returns:
        The template with every `{key}` occurrence replaced.
    """
    # NOTES: the original code deep-copied `template`, but strings are immutable,
    # so the copy was pure overhead — `replace` already returns a new string
    new_template = template
    for key, value in keys.items():
        new_template = new_template.replace(f'{{{key}}}', f'{value}')
    return new_template
54
+
55
+
56
def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str:
    """
    Assemble a complete CUDA source file for one JIT kernel.

    Arguments:
        includes: extra `#include` entries; `<...>` entries are grouped with system
            headers and `"..."` entries with package headers.
        arg_defs: iterable of `(name, type)` pairs describing the `launch` signature;
            each type must be a key of `genc_map`.
        body: the C++ statements forming the function body (unindented).

    Returns:
        The generated source code as a single string.
    """
    # Common prefix
    code = '// DeepGEMM auto-generated JIT CUDA source file\n\n'

    # Includes: merge user-supplied headers with the always-needed ones,
    # de-duplicated and sorted for deterministic output
    preload_sys_includes = ['<cuda.h>', '<cuda_fp8.h>', '<cuda_runtime.h>', '<iostream>']
    preload_package_includes = ['"cutlass/cutlass.h"']

    assert isinstance(includes, list) or isinstance(includes, tuple)
    sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')])))
    package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')])))
    code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n'
    code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n'

    # Function signature: args whose signature type differs from their casted type
    # (see `genc_map`) are received under a `__raw_` prefix and cast below
    raw = '__raw_'
    get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n
    code += f'extern "C" void launch('
    code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ])
    code += ') {\n'

    # Cast raw types
    code += '    // Cast raw types (if needed)\n'
    for arg_name, arg_type in arg_defs:
        if genc_map[arg_type][0] != genc_map[arg_type][1]:
            code += f'    auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n'

    # Function body (non-empty lines indented, blank lines left as-is)
    code += '\n'.join([(('    ' if line else '') + line) for line in body.split('\n')])

    # End the function
    code += '}\n\n'

    # Debug print
    if os.getenv('DG_JIT_DEBUG', None):
        print(f'Generated code:\n{code}')

    return code
torch-ext/deep_gemm/jit_kernels/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from .gemm import gemm_fp8_fp8_bf16_nt
2
+ from .m_grouped_gemm import (
3
+ m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
4
+ m_grouped_gemm_fp8_fp8_bf16_nt_masked
5
+ )
6
+ from .utils import (
7
+ ceil_div, set_num_sms, get_num_sms,
8
+ get_col_major_tma_aligned_tensor,
9
+ get_m_alignment_for_contiguous_layout
10
+ )
torch-ext/deep_gemm/jit_kernels/gemm.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Tuple
3
+
4
+ from .tuner import jit_tuner
5
+ from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
6
+
7
# C++ code templates
# NOTE(review): the `{KEY}` placeholders are filled by `cpp_format` (plain string
# replacement, see the tuner) and must match the `keys` passed by
# `gemm_fp8_fp8_bf16_nt` below; the remaining identifiers (`lhs`, `rhs`, `out`,
# `m`, `stream`, ...) are the generated `launch` parameters.
includes = ('"deep_gemm/fp8_gemm.cuh"', )
template = """
using namespace deep_gemm;

// Templated args from Python JIT call
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};

// Make a templated GEMM
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;

// Launch kernel
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
GemmType::run(out, rhs_scales, nullptr,
              m,
              tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
              stream, num_sms, smem_size);
"""
32
+
33
+
34
def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms: int) -> bool:
    """Return whether a TMA multicast factor evenly tiles both the N axis and the SM count."""
    # A multicast factor of 1 is trivially legal
    if num_tma_multicast == 1:
        return True

    # Otherwise N must divide into multicast-wide column groups, and the SMs
    # must split evenly across the multicast peers
    n_divisible = n % (block_n * num_tma_multicast) == 0
    sms_divisible = num_sms % num_tma_multicast == 0
    return n_divisible and sms_divisible
38
+
39
+
40
def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int:
    """Compute the shared-memory footprint in bytes of one GEMM block for the given tiling."""
    # Per-component sizes in bytes
    smem_d = block_m * block_n * 2                 # output tile, 2 bytes per element
    smem_a_per_stage = block_m * block_k           # A tile, 1 byte per element, per stage
    smem_scales_a_per_stage = block_m * 4          # A scales (4-byte), per stage
    smem_b_per_stage = block_n * block_k           # B tile, 1 byte per element, per stage
    smem_scales_b = -(k // -block_k) * 4           # B scales: one 4-byte value per K block
    smem_barrier = num_stages * 8 * 2              # two 8-byte barriers per stage

    # Accumulate; B scales are duplicated when BLOCK_N does not divide BLOCK_K,
    # then padded up to an 8-byte boundary
    total = smem_d
    total += num_stages * (smem_a_per_stage + smem_scales_a_per_stage + smem_b_per_stage)
    scales_b_bytes = smem_scales_b * (1 if block_k % block_n == 0 else 2)
    total += -(scales_b_bytes // -8) * 8
    total += smem_barrier
    return total
56
+
57
+
58
def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
                     is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int]:
    """
    Pick a kernel configuration for a problem shape with a static cost model.

    Minimizes the number of waves over `num_sms` SMs; ties are broken by last-wave
    utilization, then by larger BLOCK_M, then by smaller BLOCK_N. The pipeline stage
    count is the largest one whose shared-memory footprint fits `sm90_capacity`.

    Arguments:
        m, n, k: the GEMM shape (for grouped GEMMs, per-group shape or its expectation).
        num_groups: number of groups (1 for a normal GEMM).
        num_sms: the SM count the kernel may use.
        is_grouped_contiguous: pins BLOCK_M to the contiguous-layout M alignment.

    Returns:
        A tuple `(block_m, block_n, num_stages, num_tma_multicast, smem_size)`.
    """
    if not is_grouped_contiguous:
        # TODO: for some cases, smaller M block is better, add them into tuning space
        block_ms = (64 if m <= 64 else 128, )
    else:
        block_ms = (get_m_alignment_for_contiguous_layout(), )
    block_ns = tuple(range(16, 129, 8))

    # A zero remainder means the last wave is completely full, i.e. `num_sms` blocks
    fix_wave_saturate = lambda x: num_sms if x == 0 else x
    get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
    get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)

    # Decide block sizes by waves
    best_block_m, best_block_n = None, None
    for block_m in block_ms:
        for block_n in block_ns:
            success = False
            num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
            if best_block_m is None or best_block_n is None:
                success = True
            elif num_waves < best_num_waves:
                success = True
            elif num_waves == best_num_waves:
                # Check last wave utilization
                util = get_last_wave_util(block_m, block_n)
                best_util = get_last_wave_util(best_block_m, best_block_n)
                success = util > best_util or (util == best_util and (block_m > best_block_m or (block_m == best_block_m and block_n < best_block_n)))
            best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
    assert best_block_m is not None and best_block_n is not None

    # Always pick the longest one
    # NOTES: for double B scales, the best number of stages may be reduced
    best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
    for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
        best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
        if best_smem_size <= sm90_capacity:
            best_num_stages = num_stages
            break
    assert best_num_stages is not None

    # Decide the number of TMA multicast
    best_num_tma_multicast = 1
    if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
        best_num_tma_multicast = 2

    return best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size
105
+
106
+
107
def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
                         rhs: Tuple[torch.Tensor, torch.Tensor],
                         out: torch.Tensor) -> None:
    """
    Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
    LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
    RHS and RHS scaling factors are required to be transposed.
    The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
    this function will do a transposing with a set of slow PyTorch operations.

    Arguments:
        lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
             the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
        rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`.
             the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
        out: the BF16 output tensor of shape `[m, n]`, representing the result.
    """
    lhs, lhs_scales = lhs
    rhs, rhs_scales = rhs
    m, k = lhs.shape
    n, k_ = rhs.shape
    m_, n_ = out.shape

    # N and K are baked into the kernel at compile time (see `keys` below) and
    # must satisfy these alignments
    assert n % 64 == 0 and k % 128 == 0

    # Type and shape checks
    assert m == m_ and n == n_ and k == k_
    assert n > 0 and k > 0
    assert lhs_scales.shape == (m, (k + 127) // 128)
    assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
    assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
    assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
    assert out.dtype == torch.bfloat16
    assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()

    # LHS scales must be transposed for TMA load, but not for RHS scales
    # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
    lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
    assert rhs_scales.is_contiguous()

    # Do nothing if `m` is zero
    if m == 0:
        return

    # Auto-tuning with compilation
    global includes, template
    num_sms = get_num_sms()
    block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
    args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
    runtime = jit_tuner.compile_and_tune(
        name='gemm_fp8_fp8_bf16_nt',
        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
              'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast},
        space=(),
        includes=includes,
        arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
                  ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
                  ('out', torch.bfloat16), ('m', int),
                  ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
        template=template,
        args=args
    )

    # Run the kernel
    runtime(*args)
torch-ext/deep_gemm/jit_kernels/m_grouped_gemm.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Tuple
3
+
4
+ from .gemm import get_best_configs
5
+ from .tuner import jit_tuner
6
+ from .utils import get_col_major_tma_aligned_tensor, get_num_sms
7
+
8
# C++ code templates
# NOTE(review): same placeholder mechanism as `gemm.py` — `{KEY}` entries are filled
# by `cpp_format`; here `{NUM_GROUPS}` and `{GEMM_TYPE}` additionally select the
# grouped GEMM variant (`GroupedContiguous` or `GroupedMasked`), and the generated
# `launch` receives the extra `grouped_layout` pointer.
includes = ('"deep_gemm/fp8_gemm.cuh"', )
template = """
using namespace deep_gemm;

// Templated args from Python JIT call
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};

// Make a templated grouped GEMM
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;

// Launch kernel
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
GemmType::run(out, rhs_scales, grouped_layout,
              m,
              tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
              stream, num_sms, smem_size);
"""
+ """
33
+
34
+
35
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
                                              rhs: Tuple[torch.Tensor, torch.Tensor],
                                              out: torch.Tensor, m_indices: torch.Tensor) -> None:
    """
    Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
    LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
    RHS and RHS scaling factors are required to be transposed.
    The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
    this function will do a transposing with a set of slow PyTorch operations.
    On the M axis, inputs are grouped into several batches, whose batch sizes are aligned to
    `get_m_alignment_for_contiguous_layout()` (128).

    Arguments:
        lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
             the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`.
        rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
             the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
        out: the BF16 output tensor of shape `[m_sum, n]`, representing the result.
        m_indices: a tensor of shape `[m_sum]` with type `torch.int`.
            `m_indices[i]` records the group which the i-th row of the LHS belongs to,
            which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
            Values of `m_indices` in every-m-alignment-block must also be the same.
    """
    lhs, lhs_scales = lhs
    rhs, rhs_scales = rhs
    m, k = lhs.shape
    num_groups, n, k_ = rhs.shape
    m_, n_ = out.shape
    m__ = m_indices.numel()

    # Type and shape checks
    assert m == m_ == m__ and k == k_ and n == n_
    assert lhs_scales.shape == (m, (k + 127) // 128)
    assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
    assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
    assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
    assert out.dtype == torch.bfloat16
    assert m_indices.dtype == torch.int32
    assert lhs.is_contiguous() and rhs.is_contiguous()
    assert out.is_contiguous() and m_indices.is_contiguous()

    # LHS scales must be transposed for TMA load, but not for RHS scales
    lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
    assert rhs_scales.is_contiguous()

    # Do nothing if `m` is zero
    if m == 0:
        return

    # Auto-tuning with compilation
    # NOTES: `num_groups=1` is passed here because rows are contiguous along M;
    # the cost model sees the whole `[m_sum, n]` problem
    global includes, template
    num_sms = get_num_sms()
    block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
                                                                                 is_grouped_contiguous=True)
    args = (lhs, lhs_scales, rhs, rhs_scales, out,
            m_indices, m, num_groups,
            torch.cuda.current_stream(), num_sms, smem_size)
    runtime = jit_tuner.compile_and_tune(
        name='m_grouped_gemm_fp8_fp8_bf16_nt',
        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
              'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'},
        space=(),
        includes=includes,
        arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
                  ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
                  ('out', torch.bfloat16),
                  ('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
                  ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
        template=template,
        args=args
    )

    # Run the kernel
    runtime(*args)
109
+
110
+
111
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
                                          rhs: Tuple[torch.Tensor, torch.Tensor],
                                          out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
    """
    Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
    LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
    RHS and RHS scaling factors are required to be transposed.
    The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
    this function will do a transposing with a set of slow PyTorch operations.
    Moreover, this alignment requirement is different from the contiguous-format kernel, as we require that each batch
    should be separately transposed.

    Arguments:
        lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
             the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`.
        rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
             the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
        out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result.
        masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
            in the i-th group.
        expected_m: a value hint (which is a value on CPU) for the M expectation of each batch,
            correctly setting this value may lead to better performance.
    """
    lhs, lhs_scales = lhs
    rhs, rhs_scales = rhs
    num_groups, m, k = lhs.shape
    num_groups_, n, k_ = rhs.shape
    num_groups__, m_, n_ = out.shape
    num_groups___ = masked_m.numel()

    # Type and shape checks
    assert num_groups == num_groups_ == num_groups__ == num_groups___
    assert m == m_ and n == n_ and k == k_
    assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
    assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
    assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
    assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
    assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
    assert out.dtype == torch.bfloat16
    assert masked_m.dtype == torch.int32
    assert lhs.is_contiguous() and rhs.is_contiguous()
    assert out.is_contiguous() and masked_m.is_contiguous()

    # LHS scales must be transposed for TMA load, but not for RHS scales
    lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
    assert rhs_scales.is_contiguous()

    # Auto-tuning with compilation
    # NOTES: the CPU-side `expected_m` hint (not the device-resident `masked_m`)
    # drives the configuration choice
    global includes, template
    num_sms = get_num_sms()
    block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)

    # Extra checks for TMA store
    if num_groups > 1 and m > block_m:
        assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'

    args = (lhs, lhs_scales, rhs, rhs_scales, out,
            masked_m, m,
            torch.cuda.current_stream(), num_sms, smem_size)
    runtime = jit_tuner.compile_and_tune(
        name='m_grouped_gemm_fp8_fp8_bf16_nt',
        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
              'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'},
        space=(),
        includes=includes,
        arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
                  ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
                  ('out', torch.bfloat16),
                  ('grouped_layout', torch.int32), ('m', int),
                  ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
        template=template,
        args=args
    )

    # Run the kernel
    runtime(*args)
torch-ext/deep_gemm/jit_kernels/tuner.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import os
3
+ import torch
4
+ from typing import Any, Dict
5
+
6
+ from ..jit import build, cpp_format, generate, Runtime
7
+
8
+
9
class JITTuner:
    """Compiles JIT kernel candidates over a search space, times them, and caches the fastest."""

    def __init__(self) -> None:
        # Maps `(name, str(sorted keys))` signatures to the best compiled `Runtime`
        self.tuned = {}

    def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
                         includes: tuple, arg_defs: tuple, template: str, args: tuple) -> Runtime:
        """
        Build one kernel per point in `space` and return the fastest runtime.

        With an empty or single-point `space`, the single build is returned without
        timing. Results are cached per `(name, keys)` signature.

        Arguments:
            name: kernel family name, part of the cache signature.
            keys: template keys shared by every candidate; merged with each point of `space`.
            space: tuple of dicts of tuned keys (empty means a single untimed build).
            includes, arg_defs, template: forwarded to code generation (see `generate`).
            args: concrete launch arguments used for validity checks and timing.

        Returns:
            The best (or only) compiled `Runtime`.
        """
        # NOTES: we always assume the space and template will not change
        # We also assume the GPU device will not be changed
        # NOTES: the function must have no accumulated side effects
        keys = {k: keys[k] for k in sorted(keys.keys())}
        signature = (name, f'{keys}')
        if signature in self.tuned:
            if os.getenv('DG_JIT_DEBUG', None):
                print(f'Using cached JIT kernel {name} with keys {keys}')
            return self.tuned[signature]

        if os.getenv('DG_JIT_DEBUG', None):
            print(f'Auto-tuning JIT kernel {name} with keys {keys}')

        assert signature not in self.tuned
        assert args is not None
        # An empty space degenerates to a single candidate with no tuned keys
        space = (dict(), ) if len(space) == 0 else space

        # Build every candidate up front
        kernels = []
        for tuned_keys in space:
            assert isinstance(tuned_keys, dict)
            full_keys = copy.deepcopy(keys)
            full_keys.update(tuned_keys)
            code = generate(includes, arg_defs, cpp_format(template, full_keys))

            # Illegal build must raise errors
            kernels.append((build(name, arg_defs, code), tuned_keys))

        best_runtime, best_time, best_keys = None, None, None
        for runtime, tuned_keys in kernels:
            if len(space) > 1:
                # Check kernel validity
                return_code = runtime(*args)
                if return_code != 0:
                    # Pass illegal kernels, e.g. insufficient shared memory capacity
                    if os.getenv('DG_JIT_DEBUG', None):
                        print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
                    continue

                # Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
                torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
                start_event.record()
                for i in range(20):
                    assert runtime(*args) == 0
                end_event.record()
                end_event.synchronize()
                elapsed_time = start_event.elapsed_time(end_event)
            else:
                # Single candidate: nothing to compare, skip the expensive timing
                elapsed_time = 0

            # Compare if better
            if best_time is None or elapsed_time < best_time:
                best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
                if os.getenv('DG_JIT_DEBUG', None):
                    print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
        assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'

        # Cache the best runtime and return
        if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None):
            print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
        self.tuned[signature] = best_runtime
        return best_runtime


# Module-level singleton shared by all JIT kernel entry points
jit_tuner = JITTuner()
torch-ext/deep_gemm/jit_kernels/utils.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ _num_sms = None
4
+
5
+
6
def set_num_sms(num_sms: int) -> None:
    """
    Set the maximum SM count for all GEMM kernels to use.

    Arguments:
        num_sms: the desired maximum SM count; must be positive and no larger than
            the number of SMs the current CUDA device physically has.
    """
    global _num_sms
    device_sm_count = torch.cuda.get_device_properties(device='cuda').multi_processor_count
    assert 0 < num_sms <= device_sm_count
    _num_sms = num_sms
16
+
17
+
18
def get_num_sms() -> int:
    """
    Get the current maximum limit of SM count for all GEMM kernels to use.

    Returns:
        The configured limit; if `set_num_sms` was never called, lazily defaults to
        (and caches) the device's physical SM count.
    """
    global _num_sms
    # Lazily resolve the default from the device on first query
    if _num_sms is None:
        props = torch.cuda.get_device_properties(device='cuda')
        _num_sms = props.multi_processor_count
    return _num_sms
30
+
31
+
32
def ceil_div(x: int, y: int) -> int:
    """
    Perform ceiling division of two integers.

    Args:
        x: the dividend.
        y: the divisor.

    Returns:
        The result of the ceiling division.
    """
    # Adding `y - 1` before floor division rounds any nonzero remainder up
    numerator = x + y - 1
    return numerator // y
44
+
45
+
46
def get_m_alignment_for_contiguous_layout():
    """
    Return the per-group M-axis alignment for grouped GEMMs in contiguous format.

    When a grouped GEMM uses the contiguous layout, LHS batches are packed along the
    M axis; each GEMM block works on exactly one RHS sub-matrix, so every batch size
    must align with the GEMM block shape.

    Returns:
        Group-level alignment requirement for grouped contiguous layout, which is always 128.
    """
    return 128
56
+
57
+
58
def get_tma_aligned_size(x: int, element_size: int) -> int:
    """
    Pad an M-axis extent so a column-major TMA load is 16-byte aligned.

    Global memory addresses used by TMA must be 16-byte aligned; since the LHS
    scaling tensor is stored column-major, its M axis is padded up accordingly.

    Arguments:
        x: original M-axis shape of the LHS scaling tensor.
        element_size: element size of the LHS scaling tensor, in bytes.

    Returns:
        M-axis shape of the LHS scaling tensor after padding.
    """
    tma_alignment_bytes = 16
    assert tma_alignment_bytes % element_size == 0
    alignment = tma_alignment_bytes // element_size
    # Round `x` up to the next multiple of the element-count alignment
    return (x + alignment - 1) // alignment * alignment
75
+
76
+
77
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
    """
    Return `x` in the TMA-aligned, column-major (transposed) layout DeepGEMM expects
    for LHS scaling tensors. If `x` already has that layout it is returned as-is;
    otherwise a padded buffer is allocated and the data is copied (slow path).

    Arguments:
        x: usually the LHS scaling tensor in GEMM, 2-D or batched 3-D.

    Returns:
        The LHS scaling tensor of TMA-aligned transposed format.
    """
    # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
    assert x.dim() in (2, 3)
    squeeze_batch = x.dim() == 2
    if squeeze_batch:
        # Treat a 2-D tensor as a single-batch 3-D one
        x = x.unsqueeze(0)

    b, m, n = x.shape
    aligned_m = get_tma_aligned_size(m, x.element_size())

    # Fast path: already column-major with the padded M stride — nothing to do
    already_aligned = (x.stride(0) == aligned_m * n
                       and x.stride(1) == 1
                       and x.stride(2) == aligned_m)
    if already_aligned:
        return x.squeeze(0) if squeeze_batch else x

    # Slow path: allocate a padded buffer viewed column-major and copy into it
    buffer = torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype)
    aligned_x = torch.transpose(buffer, 1, 2)
    aligned_x[:, :m, :] = x
    aligned_x = aligned_x[:, :m, :]
    return aligned_x.squeeze(0) if squeeze_batch else aligned_x
torch-ext/deep_gemm/utils.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import torch.distributed as dist
5
+
6
+
7
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
          high_precision: bool = False):
    """
    Measure the average wall-clock time of `fn` using CUDA events.

    Arguments:
        fn: callable to benchmark (launches CUDA work).
        num_warmups: untimed warmup invocations of `fn`.
        num_tests: timed invocations of `fn`.
        high_precision: if set, enqueue a large matmul before timing so the
            CPU launch overhead of the timed kernels is hidden behind it.

    Returns:
        Average time per call in milliseconds (CUDA event resolution).
    """
    # Flush the L2 cache with 256 MB of data so the timed runs start cold.
    torch.cuda.synchronize()
    l2_flush = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
    l2_flush.zero_()

    # Warmup (excludes JIT / auto-tuning overheads from the measurement).
    for _ in range(num_warmups):
        fn()

    # Add a large kernel to eliminate the CPU launch overhead.
    if high_precision:
        lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
        rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
        lhs @ rhs

    # Timed section bracketed by CUDA events.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(num_tests):
        fn()
    end.record()
    torch.cuda.synchronize()

    return start.elapsed_time(end) / num_tests
34
+
35
+
36
class empty_suppress:
    """No-op context manager, used when output suppression is disabled."""

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Returning a falsy value: exceptions are never suppressed.
        return None
42
+
43
+
44
class suppress_stdout_stderr:
    """
    Context manager that silences stdout/stderr at both the Python level and
    the file-descriptor level (so output written directly to fd 1/2 by native
    code is suppressed as well).
    """

    def __enter__(self):
        # Sinks that will receive all suppressed output.
        self.outnull_file = open(os.devnull, 'w')
        self.errnull_file = open(os.devnull, 'w')

        # Descriptor numbers currently backing sys.stdout/sys.stderr
        # (usually 1 and 2); these are redirected in place below.
        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        # Duplicate the original descriptors so they can be restored on exit.
        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        # Keep the Python-level stream objects for restoration too.
        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        # Point the OS-level descriptors at /dev/null (covers native writes).
        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        # Point the Python-level streams at /dev/null as well.
        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        # Restore the Python-level streams first...
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        # ...then restore the OS-level descriptors from the saved duplicates.
        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        # Close the duplicates and the /dev/null handles.
        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()
77
+
78
+
79
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
                 trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = False):
    """
    Measure per-kernel GPU time of `fn` via the Kineto (torch.profiler) backend.

    Arguments:
        fn: callable launching the kernel(s) to profile.
        kernel_names: a kernel-name substring, or a tuple of substrings; each
            must match exactly one row of the profiler summary table.
        num_tests: number of timed invocations of `fn` per profiler step.
        suppress_kineto_output: hide Kineto's console output while profiling.
        trace_path: if given, export a Chrome trace to this path.
        barrier_comm_profiling: run a large matmul and an all-reduce first to
            align ranks before timing communication kernels.
        flush_l2: zero a 256 MB buffer before every call to flush the L2 cache.

    Returns:
        Average kernel time in seconds (a tuple of times when `kernel_names`
        is a tuple), or the constant 1 when running under Nsight Systems.
    """
    # Conflict with Nsight Systems. BUGFIX: os.environ.get returns a string,
    # so the old `get('DG_NSYS_PROFILING', False)` treated ANY non-empty value
    # (including '0') as enabled; only accept explicit true-ish values.
    using_nsys = os.environ.get('DG_NSYS_PROFILING', 'false').lower() in ('true', '1')

    # For some auto-tuning kernels with prints
    fn()

    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
    with suppress():
        schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None
        profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
        with profiler:
            # Two iterations: one consumed by the schedule's warmup step, one active.
            for _ in range(2):
                # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
                if barrier_comm_profiling:
                    lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
                    rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
                    lhs @ rhs
                    dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
                for _ in range(num_tests):
                    if flush_l2:
                        # Flush L2 with a 256 MB buffer before each timed call.
                        torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
                    fn()

                if not using_nsys:
                    profiler.step()

    # Return 1 if using Nsight Systems
    if using_nsys:
        return 1

    # Parse the profiling table
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tupled = isinstance(kernel_names, tuple)
    prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
    kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
    assert all([isinstance(name, str) for name in kernel_names])
    # Each requested name must identify exactly one kernel row.
    for name in kernel_names:
        assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'

    # Save chrome traces
    if trace_path is not None:
        profiler.export_chrome_trace(trace_path)

    # Return average kernel times (second-to-last column of the table row),
    # scaled from the printed unit into seconds.
    units = {'ms': 1e3, 'us': 1e6}
    kernel_times = []
    for name in kernel_names:
        for line in prof_lines:
            if name in line:
                time_str = line.split()[-2]
                for unit, scale in units.items():
                    if unit in time_str:
                        kernel_times.append(float(time_str.replace(unit, '')) / scale)
                        break
                break
    return tuple(kernel_times) if is_tupled else kernel_times[0]
138
+
139
+
140
def calc_diff(x, y):
    """
    Similarity-based relative difference between two tensors.

    Computes 1 - 2*sum(x*y) / (sum(x*x) + sum(y*y)) in double precision:
    0 means the tensors are identical, larger values mean a bigger mismatch.
    """
    xd = x.double()
    yd = y.double()
    cross = 2 * (xd * yd).sum()
    energy = (xd * xd + yd * yd).sum()
    return 1 - cross / energy
145
+
146
+
147
def count_bytes(tensors):
    """
    Total storage size, in bytes, of a (possibly nested) collection of tensors.

    Arguments:
        tensors: an iterable whose items are tensors or tuples of tensors;
            tuples are counted recursively.

    Returns:
        The sum of `numel * element_size` over all contained tensors.
    """
    return sum(
        count_bytes(item) if isinstance(item, tuple)
        else item.numel() * item.element_size()
        for item in tensors
    )