Commit ·
4913396
0
Parent(s):
Import DeepGEMM
Browse files- LICENSE +21 -0
- README.md +10 -0
- tests/test_core.py +158 -0
- tests/test_jit.py +64 -0
- torch-ext/deep_gemm/__init__.py +13 -0
- torch-ext/deep_gemm/include/deep_gemm/fp8_gemm.cuh +449 -0
- torch-ext/deep_gemm/include/deep_gemm/mma_utils.cuh +885 -0
- torch-ext/deep_gemm/include/deep_gemm/scheduler.cuh +103 -0
- torch-ext/deep_gemm/include/deep_gemm/tma_utils.cuh +96 -0
- torch-ext/deep_gemm/include/deep_gemm/utils.cuh +48 -0
- torch-ext/deep_gemm/jit/__init__.py +3 -0
- torch-ext/deep_gemm/jit/compiler.py +151 -0
- torch-ext/deep_gemm/jit/interleave_ffma.py +137 -0
- torch-ext/deep_gemm/jit/runtime.py +66 -0
- torch-ext/deep_gemm/jit/template.py +93 -0
- torch-ext/deep_gemm/jit_kernels/__init__.py +10 -0
- torch-ext/deep_gemm/jit_kernels/gemm.py +171 -0
- torch-ext/deep_gemm/jit_kernels/m_grouped_gemm.py +186 -0
- torch-ext/deep_gemm/jit_kernels/tuner.py +81 -0
- torch-ext/deep_gemm/jit_kernels/utils.py +106 -0
- torch-ext/deep_gemm/utils.py +154 -0
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 DeepSeek
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
tags:
|
| 4 |
+
- kernel
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## DeepGEMM
|
| 8 |
+
|
| 9 |
+
[DeepGEMM](https://github.com/deepseek-ai/DeepGEMM/) by DeepSeek.
|
| 10 |
+
|
tests/test_core.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import torch
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
|
| 5 |
+
import deep_gemm
|
| 6 |
+
from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 10 |
+
assert x.dim() == 2 and x.size(1) % 128 == 0
|
| 11 |
+
m, n = x.shape
|
| 12 |
+
x_view = x.view(m, -1, 128)
|
| 13 |
+
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
| 14 |
+
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 18 |
+
assert x.dim() == 2
|
| 19 |
+
m, n = x.shape
|
| 20 |
+
x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device)
|
| 21 |
+
x_padded[:m, :n] = x
|
| 22 |
+
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
| 23 |
+
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
| 24 |
+
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
| 25 |
+
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def construct(m: int, k: int, n: int) -> \
|
| 29 |
+
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
|
| 30 |
+
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
| 31 |
+
y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
|
| 32 |
+
out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
|
| 33 |
+
ref_out = x @ y.t()
|
| 34 |
+
|
| 35 |
+
x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y)
|
| 36 |
+
# Transpose earlier so that the testing will not trigger transposing kernels
|
| 37 |
+
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
|
| 38 |
+
return x_fp8, y_fp8, out, ref_out
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def construct_grouped(num_groups: int, m: int, k: int, n: int, is_masked: bool) -> \
|
| 42 |
+
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
|
| 43 |
+
x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16)
|
| 44 |
+
y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
|
| 45 |
+
out = torch.empty((num_groups, m, n), device='cuda', dtype=torch.bfloat16)
|
| 46 |
+
ref_out = torch.einsum('gmk,gnk->gmn', x, y)
|
| 47 |
+
|
| 48 |
+
assert m % 4 == 0, f'TMA alignment error: {m}'
|
| 49 |
+
x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float))
|
| 50 |
+
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float))
|
| 51 |
+
for i in range(num_groups):
|
| 52 |
+
x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
|
| 53 |
+
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
|
| 54 |
+
|
| 55 |
+
# For non-masked input, we must merge the group and M dims
|
| 56 |
+
if not is_masked:
|
| 57 |
+
x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1])
|
| 58 |
+
out, ref_out = out.view(-1, n), ref_out.view(-1, n)
|
| 59 |
+
|
| 60 |
+
# Transpose earlier so that the testing will not trigger transposing kernels
|
| 61 |
+
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
|
| 62 |
+
return x_fp8, y_fp8, out, ref_out
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def test_gemm() -> None:
|
| 66 |
+
print('Testing GEMM:')
|
| 67 |
+
for m in (64, 128, 4096):
|
| 68 |
+
for k, n in [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
|
| 69 |
+
x_fp8, y_fp8, out, ref_out = construct(m, k, n)
|
| 70 |
+
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
|
| 71 |
+
diff = calc_diff(out, ref_out)
|
| 72 |
+
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
|
| 73 |
+
|
| 74 |
+
# noinspection PyShadowingNames
|
| 75 |
+
def test_func():
|
| 76 |
+
# Construct new tensors every time to avoid L2 cache acceleration
|
| 77 |
+
x_fp8, y_fp8, out, ref_out = construct(m, k, n)
|
| 78 |
+
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
|
| 79 |
+
|
| 80 |
+
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
| 81 |
+
print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
|
| 82 |
+
f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
|
| 83 |
+
f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
|
| 84 |
+
print()
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def test_m_grouped_gemm_contiguous() -> None:
|
| 88 |
+
print('Testing grouped contiguous GEMM:')
|
| 89 |
+
|
| 90 |
+
for num_groups, m, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)):
|
| 91 |
+
# TODO: make a stronger test
|
| 92 |
+
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False)
|
| 93 |
+
m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int)
|
| 94 |
+
m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
|
| 95 |
+
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
|
| 96 |
+
diff = calc_diff(out, ref_out)
|
| 97 |
+
assert diff < 0.001, f'm={m * num_groups}, {k=}, {n=}, {diff:.5f}'
|
| 98 |
+
|
| 99 |
+
# noinspection PyShadowingNames
|
| 100 |
+
def test_func():
|
| 101 |
+
# Construct new tensors every time to avoid L2 cache acceleration
|
| 102 |
+
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False)
|
| 103 |
+
m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int)
|
| 104 |
+
m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
|
| 105 |
+
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
|
| 106 |
+
|
| 107 |
+
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
| 108 |
+
print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
|
| 109 |
+
f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
|
| 110 |
+
f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
|
| 111 |
+
print()
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def test_m_grouped_gemm_masked() -> None:
|
| 115 |
+
print('Testing grouped masked GEMM:')
|
| 116 |
+
|
| 117 |
+
for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
|
| 118 |
+
for k, n in ((7168, 4096), (2048, 7168), ):
|
| 119 |
+
# Test correctness
|
| 120 |
+
masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384)))
|
| 121 |
+
for i in range(10):
|
| 122 |
+
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
|
| 123 |
+
masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
|
| 124 |
+
for j in range(num_groups):
|
| 125 |
+
masked_m[j] = random.choice(masked_m_candidates)
|
| 126 |
+
expected_m = min(int(masked_m.float().mean()) + 1, m)
|
| 127 |
+
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
|
| 128 |
+
for j in range(num_groups):
|
| 129 |
+
diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()])
|
| 130 |
+
assert diff < 0.001, f'{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'
|
| 131 |
+
|
| 132 |
+
# noinspection PyShadowingNames
|
| 133 |
+
def test_func():
|
| 134 |
+
# Construct new tensors every time to avoid L2 cache acceleration
|
| 135 |
+
x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
|
| 136 |
+
masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m
|
| 137 |
+
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, m)
|
| 138 |
+
|
| 139 |
+
# Test performance with fixed shapes
|
| 140 |
+
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
| 141 |
+
print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
|
| 142 |
+
f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
|
| 143 |
+
f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
|
| 144 |
+
print()
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
if __name__ == '__main__':
|
| 148 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 149 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 150 |
+
torch.manual_seed(0)
|
| 151 |
+
random.seed(0)
|
| 152 |
+
|
| 153 |
+
print('Library path:')
|
| 154 |
+
print(f' > {deep_gemm.__path__}\n')
|
| 155 |
+
|
| 156 |
+
test_gemm()
|
| 157 |
+
test_m_grouped_gemm_contiguous()
|
| 158 |
+
test_m_grouped_gemm_masked()
|
tests/test_jit.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
from deep_gemm import jit
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Capture:
|
| 9 |
+
def __init__(self) -> None:
|
| 10 |
+
self.read_fd = None
|
| 11 |
+
self.write_fd = None
|
| 12 |
+
self.saved_stdout = None
|
| 13 |
+
self.captured = None
|
| 14 |
+
|
| 15 |
+
def __enter__(self) -> Any:
|
| 16 |
+
self.read_fd, self.write_fd = os.pipe()
|
| 17 |
+
self.saved_stdout = os.dup(1)
|
| 18 |
+
os.dup2(self.write_fd, 1)
|
| 19 |
+
return self
|
| 20 |
+
|
| 21 |
+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
| 22 |
+
os.dup2(self.saved_stdout, 1)
|
| 23 |
+
os.close(self.write_fd)
|
| 24 |
+
with os.fdopen(self.read_fd, 'r') as f:
|
| 25 |
+
self.captured = f.read()
|
| 26 |
+
|
| 27 |
+
def capture(self) -> str:
|
| 28 |
+
return self.captured
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
if __name__ == '__main__':
|
| 32 |
+
# Runtime
|
| 33 |
+
print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n')
|
| 34 |
+
|
| 35 |
+
# Templates
|
| 36 |
+
print('Generated code:')
|
| 37 |
+
args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16),
|
| 38 |
+
('enable_double_streams', bool), ('stream', torch.cuda.Stream))
|
| 39 |
+
body = "\n"
|
| 40 |
+
body += 'std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n'
|
| 41 |
+
body += 'std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n'
|
| 42 |
+
body += 'std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n'
|
| 43 |
+
body += 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n'
|
| 44 |
+
body += 'std::cout << enable_double_streams << std::endl;\n'
|
| 45 |
+
body += 'std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n'
|
| 46 |
+
code = jit.generate((), args, body)
|
| 47 |
+
print(code)
|
| 48 |
+
|
| 49 |
+
# Build
|
| 50 |
+
print('Building ...')
|
| 51 |
+
func = jit.build('test_func', args, code)
|
| 52 |
+
|
| 53 |
+
# Test correctness
|
| 54 |
+
print('Running ...')
|
| 55 |
+
fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda')
|
| 56 |
+
fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda')
|
| 57 |
+
bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda')
|
| 58 |
+
with Capture() as capture:
|
| 59 |
+
assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0
|
| 60 |
+
output = capture.capture()
|
| 61 |
+
ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n'
|
| 62 |
+
assert output == ref_output, f'{output=}, {ref_output=}'
|
| 63 |
+
|
| 64 |
+
print('JIT test passed')
|
torch-ext/deep_gemm/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from . import jit
|
| 4 |
+
from .jit_kernels import (
|
| 5 |
+
gemm_fp8_fp8_bf16_nt,
|
| 6 |
+
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
|
| 7 |
+
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
|
| 8 |
+
ceil_div,
|
| 9 |
+
set_num_sms, get_num_sms,
|
| 10 |
+
get_col_major_tma_aligned_tensor,
|
| 11 |
+
get_m_alignment_for_contiguous_layout
|
| 12 |
+
)
|
| 13 |
+
from .utils import bench, bench_kineto, calc_diff
|
torch-ext/deep_gemm/include/deep_gemm/fp8_gemm.cuh
ADDED
|
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma clang diagnostic push
|
| 2 |
+
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <cutlass/arch/barrier.h>
|
| 6 |
+
#include <cutlass/arch/reg_reconfig.h>
|
| 7 |
+
|
| 8 |
+
#include <cute/arch/cluster_sm90.hpp>
|
| 9 |
+
#include <cute/arch/copy_sm90_desc.hpp>
|
| 10 |
+
#include <cute/arch/copy_sm90_tma.hpp>
|
| 11 |
+
|
| 12 |
+
#include "mma_utils.cuh"
|
| 13 |
+
#include "scheduler.cuh"
|
| 14 |
+
#include "tma_utils.cuh"
|
| 15 |
+
#include "utils.cuh"
|
| 16 |
+
|
| 17 |
+
namespace deep_gemm {
|
| 18 |
+
|
| 19 |
+
enum class Layout {
|
| 20 |
+
RowMajor,
|
| 21 |
+
ColMajor
|
| 22 |
+
};
|
| 23 |
+
|
| 24 |
+
template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
|
| 25 |
+
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
|
| 26 |
+
DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
|
| 27 |
+
return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
| 31 |
+
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
| 32 |
+
uint32_t kNumGroups, uint32_t kNumStages,
|
| 33 |
+
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
|
| 34 |
+
uint32_t kNumTMAMulticast,
|
| 35 |
+
GemmType kGemmType>
|
| 36 |
+
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
|
| 37 |
+
fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
| 38 |
+
uint32_t shape_m,
|
| 39 |
+
const __grid_constant__ CUtensorMap tensor_map_a,
|
| 40 |
+
const __grid_constant__ CUtensorMap tensor_map_b,
|
| 41 |
+
const __grid_constant__ CUtensorMap tensor_map_scales_a,
|
| 42 |
+
const __grid_constant__ CUtensorMap tensor_map_d) {
|
| 43 |
+
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
| 44 |
+
// Scaling checks
|
| 45 |
+
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
|
| 46 |
+
DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block");
|
| 47 |
+
|
| 48 |
+
// Types
|
| 49 |
+
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
|
| 50 |
+
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
| 51 |
+
|
| 52 |
+
// Shared memory
|
| 53 |
+
static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
|
| 54 |
+
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
|
| 55 |
+
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
| 56 |
+
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
| 57 |
+
static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
|
| 58 |
+
static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K);
|
| 59 |
+
static constexpr uint32_t SMEM_SCALES_B_SIZE = ceil_div<uint32_t>(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier);
|
| 60 |
+
|
| 61 |
+
// Configs
|
| 62 |
+
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
|
| 63 |
+
constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
| 64 |
+
constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
|
| 65 |
+
constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages);
|
| 66 |
+
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
| 67 |
+
const uint32_t lane_idx = get_lane_id();
|
| 68 |
+
|
| 69 |
+
// Prefetch TMA descriptors at very beginning
|
| 70 |
+
if (threadIdx.x == kNumMathThreads) {
|
| 71 |
+
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
|
| 72 |
+
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
|
| 73 |
+
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
|
| 74 |
+
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
|
| 75 |
+
}
|
| 76 |
+
__syncwarp();
|
| 77 |
+
|
| 78 |
+
// Align to 1024 bytes for swizzle-128B
|
| 79 |
+
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
| 80 |
+
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
| 81 |
+
|
| 82 |
+
// Data on shared memory
|
| 83 |
+
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
|
| 84 |
+
__nv_fp8_e4m3* smem_a[kNumStages];
|
| 85 |
+
__nv_fp8_e4m3* smem_b[kNumStages];
|
| 86 |
+
float* smem_scales_a[kNumStages];
|
| 87 |
+
float* smem_scales_b;
|
| 88 |
+
|
| 89 |
+
// TMA Barrier for both divisible and non-divisible cases
|
| 90 |
+
Barrier* full_barriers[kNumStages];
|
| 91 |
+
Barrier* empty_barriers[kNumStages];
|
| 92 |
+
|
| 93 |
+
// Fill shared memory pointers
|
| 94 |
+
#pragma unroll
|
| 95 |
+
for (int i = 0; i < kNumStages; ++ i) {
|
| 96 |
+
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
| 97 |
+
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
| 98 |
+
smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
|
| 99 |
+
}
|
| 100 |
+
smem_scales_b = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE));
|
| 101 |
+
|
| 102 |
+
// Fill barriers
|
| 103 |
+
auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_scales_b) + SMEM_SCALES_B_SIZE);
|
| 104 |
+
#pragma unroll
|
| 105 |
+
for (int i = 0; i < kNumStages; ++ i) {
|
| 106 |
+
full_barriers[i] = barrier_start_ptr + i;
|
| 107 |
+
empty_barriers[i] = barrier_start_ptr + kNumStages + i;
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
// Initialize barriers
|
| 111 |
+
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
|
| 112 |
+
if (threadIdx.x == kNumMathThreads) {
|
| 113 |
+
// NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
|
| 114 |
+
// even with TMA multicast disabled, we want to make the behavior aligned
|
| 115 |
+
#pragma unroll
|
| 116 |
+
for (int i = 0; i < kNumStages; ++ i) {
|
| 117 |
+
full_barriers[i]->init(1);
|
| 118 |
+
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
// Make initialized barrier visible in async proxy
|
| 122 |
+
cutlass::arch::fence_view_async_shared();
|
| 123 |
+
(kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
// Synchronize all threads to make barrier visible in normal memory model
|
| 127 |
+
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
|
| 128 |
+
|
| 129 |
+
// For pipeline unrolling
|
| 130 |
+
struct DivisibleK {};
|
| 131 |
+
struct NotDivisibleK {};
|
| 132 |
+
auto launch_k_iterations = [](const auto& func) {
|
| 133 |
+
if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
|
| 134 |
+
for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
|
| 135 |
+
func(k_iter, DivisibleK{});
|
| 136 |
+
} else {
|
| 137 |
+
for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
|
| 138 |
+
func(k_iter, DivisibleK{});
|
| 139 |
+
func(kNumIterations - 1, NotDivisibleK{});
|
| 140 |
+
}
|
| 141 |
+
};
|
| 142 |
+
|
| 143 |
+
// Register reconfigurations
|
| 144 |
+
constexpr int kNumTMARegisters = 40;
|
| 145 |
+
constexpr int kNumMathRegisters = 232;
|
| 146 |
+
|
| 147 |
+
// Block scheduler
|
| 148 |
+
uint32_t m_block_idx, n_block_idx;
|
| 149 |
+
auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast>(shape_m, grouped_layout);
|
| 150 |
+
|
| 151 |
+
if (threadIdx.x >= kNumMathThreads) {
|
| 152 |
+
// TMA warp-group for loading data
|
| 153 |
+
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
| 154 |
+
|
| 155 |
+
// NOTES: only one thread (or warp) will be used
|
| 156 |
+
if (threadIdx.x == kNumMathThreads) {
|
| 157 |
+
// Persistently schedule over blocks
|
| 158 |
+
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 159 |
+
launch_k_iterations([&](int k_iter, auto type) {
|
| 160 |
+
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
| 161 |
+
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
|
| 162 |
+
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
| 163 |
+
|
| 164 |
+
// NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
|
| 165 |
+
// shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
|
| 166 |
+
#pragma unroll
|
| 167 |
+
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
| 168 |
+
// Wait consumer release
|
| 169 |
+
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
|
| 170 |
+
|
| 171 |
+
// Issue TMA A with broadcasting
|
| 172 |
+
auto& full_barrier = *full_barriers[s];
|
| 173 |
+
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
|
| 174 |
+
tma_copy<kNumTMAMulticast>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
| 175 |
+
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
|
| 176 |
+
tma_copy<kNumTMAMulticast>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
| 177 |
+
smem_scales_a[s], m_block_idx * BLOCK_M,
|
| 178 |
+
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
|
| 179 |
+
|
| 180 |
+
// Issue TMA B without broadcasting
|
| 181 |
+
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
| 182 |
+
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
|
| 183 |
+
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
// Wait unaligned cases
|
| 187 |
+
#pragma unroll
|
| 188 |
+
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
| 189 |
+
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
|
| 190 |
+
full_barriers[s]->arrive();
|
| 191 |
+
}
|
| 192 |
+
});
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
// To safely deconstruct distributed shared barriers, we need another round of empty waits
|
| 196 |
+
if constexpr (kNumTMAMulticast > 1) {
|
| 197 |
+
#pragma unroll
|
| 198 |
+
for (uint32_t s = 0; s < kNumStages; ++ s)
|
| 199 |
+
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
|
| 200 |
+
}
|
| 201 |
+
}
|
| 202 |
+
} else {
|
| 203 |
+
// Math warp-groups for WGMMA
|
| 204 |
+
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
| 205 |
+
|
| 206 |
+
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
| 207 |
+
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
|
| 208 |
+
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
|
| 209 |
+
|
| 210 |
+
// Persistently schedule over blocks
|
| 211 |
+
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
| 212 |
+
// Decide the number of scales B to load
|
| 213 |
+
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
|
| 214 |
+
uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
|
| 215 |
+
if constexpr (not kMustUseUniformedScaleB) {
|
| 216 |
+
num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8;
|
| 217 |
+
num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8;
|
| 218 |
+
}
|
| 219 |
+
uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2);
|
| 220 |
+
|
| 221 |
+
// Load B scales with math warp-groups
|
| 222 |
+
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
|
| 223 |
+
if (threadIdx.x >= 32) {
|
| 224 |
+
auto num_previous_lines = scheduler.get_global_idx<false>(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
|
| 225 |
+
auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
|
| 226 |
+
#pragma unroll
|
| 227 |
+
for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
|
| 228 |
+
st_shared(smem_scales_b + i, __ldg(local_scales_b + i));
|
| 229 |
+
}
|
| 230 |
+
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
| 231 |
+
|
| 232 |
+
// Accumulation for WGMMA or CUDA promotion
|
| 233 |
+
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
|
| 234 |
+
|
| 235 |
+
// Empty barrier arrival
|
| 236 |
+
auto empty_barrier_arrive = [&](int s) {
|
| 237 |
+
if constexpr (kNumTMAMulticast == 1) {
|
| 238 |
+
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
|
| 239 |
+
} else {
|
| 240 |
+
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
|
| 241 |
+
}
|
| 242 |
+
};
|
| 243 |
+
|
| 244 |
+
// Launch MMAs
|
| 245 |
+
launch_k_iterations([&](int k_iter, auto type) {
|
| 246 |
+
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
| 247 |
+
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
|
| 248 |
+
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
| 249 |
+
|
| 250 |
+
#pragma unroll
|
| 251 |
+
for (int s = 0; s < kNumInnerStages; ++ s) {
|
| 252 |
+
// Read B scales
|
| 253 |
+
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
|
| 254 |
+
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
|
| 255 |
+
if constexpr (not kMustUseUniformedScaleB)
|
| 256 |
+
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
|
| 257 |
+
|
| 258 |
+
// Wait TMA arrivals
|
| 259 |
+
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
| 260 |
+
|
| 261 |
+
// Read A scales
|
| 262 |
+
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
|
| 263 |
+
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1);
|
| 264 |
+
|
| 265 |
+
// Commit WGMMA instructions
|
| 266 |
+
#pragma unroll
|
| 267 |
+
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 268 |
+
warpgroup_fence_operand(accum[i]);
|
| 269 |
+
warpgroup_arrive();
|
| 270 |
+
#pragma unroll
|
| 271 |
+
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
| 272 |
+
auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
|
| 273 |
+
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
|
| 274 |
+
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
| 275 |
+
}
|
| 276 |
+
warpgroup_commit_batch();
|
| 277 |
+
#pragma unroll
|
| 278 |
+
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
| 279 |
+
warpgroup_fence_operand(accum[i]);
|
| 280 |
+
warpgroup_wait<0>();
|
| 281 |
+
|
| 282 |
+
// Notify barrier arrival
|
| 283 |
+
empty_barrier_arrive(s);
|
| 284 |
+
|
| 285 |
+
// Promote with scales
|
| 286 |
+
// NOTES: making it as predicates is very important for performance, comparing to two loops
|
| 287 |
+
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
|
| 288 |
+
float scale_0_1, scale_1_1;
|
| 289 |
+
if constexpr (not kMustUseUniformedScaleB)
|
| 290 |
+
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
|
| 291 |
+
#pragma unroll
|
| 292 |
+
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
| 293 |
+
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
|
| 294 |
+
final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
|
| 295 |
+
final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
|
| 296 |
+
final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
|
| 297 |
+
final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
|
| 298 |
+
}
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
// Wait unaligned cases
|
| 302 |
+
#pragma unroll
|
| 303 |
+
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
| 304 |
+
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
| 305 |
+
empty_barrier_arrive(s);
|
| 306 |
+
}
|
| 307 |
+
});
|
| 308 |
+
|
| 309 |
+
// Write back to shared memory using STSM
|
| 310 |
+
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
|
| 311 |
+
#pragma unroll
|
| 312 |
+
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
|
| 313 |
+
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
|
| 314 |
+
__float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
|
| 315 |
+
__float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
|
| 316 |
+
__float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
|
| 317 |
+
__float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
|
| 318 |
+
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
|
| 319 |
+
);
|
| 320 |
+
}
|
| 321 |
+
if constexpr (WGMMA::kNumAccum % 8 != 0) {
|
| 322 |
+
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
|
| 323 |
+
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
|
| 324 |
+
__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
|
| 325 |
+
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16
|
| 326 |
+
);
|
| 327 |
+
}
|
| 328 |
+
cute::tma_store_fence();
|
| 329 |
+
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
| 330 |
+
|
| 331 |
+
// Use TMA store to write back to global memory
|
| 332 |
+
if (threadIdx.x == 0) {
|
| 333 |
+
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N,
|
| 334 |
+
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
|
| 335 |
+
cute::tma_store_arrive();
|
| 336 |
+
cute::tma_store_wait<0>();
|
| 337 |
+
}
|
| 338 |
+
__syncwarp();
|
| 339 |
+
}
|
| 340 |
+
}
|
| 341 |
+
#else
|
| 342 |
+
if (blockIdx.x == 0 and threadIdx.x == 0)
|
| 343 |
+
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
|
| 344 |
+
#endif
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
| 348 |
+
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
| 349 |
+
uint32_t kNumGroups, uint32_t kNumStages,
|
| 350 |
+
uint32_t kNumTMAMulticast,
|
| 351 |
+
GemmType kGemmType>
|
| 352 |
+
class Gemm {
|
| 353 |
+
private:
|
| 354 |
+
using Barrier = cuda::barrier<cuda::thread_scope_block>;
|
| 355 |
+
|
| 356 |
+
public:
|
| 357 |
+
Gemm() = default;
|
| 358 |
+
|
| 359 |
+
static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
| 360 |
+
uint32_t shape_m,
|
| 361 |
+
const CUtensorMap& tma_a_desc,
|
| 362 |
+
const CUtensorMap& tma_b_desc,
|
| 363 |
+
const CUtensorMap& tma_scales_a_desc,
|
| 364 |
+
const CUtensorMap& tma_d_desc,
|
| 365 |
+
cudaStream_t stream,
|
| 366 |
+
int num_sms, uint32_t smem_size) {
|
| 367 |
+
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
|
| 368 |
+
constexpr uint32_t kNumTMAThreads = 128;
|
| 369 |
+
constexpr uint32_t kNumMathThreadsPerGroup = 128;
|
| 370 |
+
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
|
| 371 |
+
kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
|
| 372 |
+
kNumTMAMulticast, kGemmType>;
|
| 373 |
+
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
|
| 374 |
+
|
| 375 |
+
// Cluster launch
|
| 376 |
+
cudaLaunchConfig_t config;
|
| 377 |
+
config.gridDim = num_sms;
|
| 378 |
+
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
| 379 |
+
config.dynamicSmemBytes = smem_size;
|
| 380 |
+
config.stream = stream;
|
| 381 |
+
|
| 382 |
+
// Clusters for TMA multicast
|
| 383 |
+
// NOTES: `>= 4` cluster size will cause performance degradation
|
| 384 |
+
cudaLaunchAttribute attr;
|
| 385 |
+
attr.id = cudaLaunchAttributeClusterDimension;
|
| 386 |
+
attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
|
| 387 |
+
config.attrs = &attr;
|
| 388 |
+
config.numAttrs = 1;
|
| 389 |
+
|
| 390 |
+
// Launch
|
| 391 |
+
auto status = cudaLaunchKernelEx(&config, kernel,
|
| 392 |
+
gmem_d, scales_b, grouped_layout,
|
| 393 |
+
shape_m,
|
| 394 |
+
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
|
| 395 |
+
DG_HOST_ASSERT(status == cudaSuccess);
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
template <typename T>
|
| 399 |
+
static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) {
|
| 400 |
+
return make_2d_tma_desc(global_address, Layout::RowMajor,
|
| 401 |
+
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K);
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
template <typename T>
|
| 405 |
+
static CUtensorMap make_2d_tma_b_desc(T* global_address) {
|
| 406 |
+
return make_2d_tma_desc(global_address, Layout::ColMajor,
|
| 407 |
+
SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
template <typename T>
|
| 411 |
+
static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
|
| 412 |
+
return make_2d_tma_desc(global_address, Layout::RowMajor,
|
| 413 |
+
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
|
| 414 |
+
min(BLOCK_M, shape_m), BLOCK_N,
|
| 415 |
+
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
template <typename T>
|
| 419 |
+
static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
|
| 420 |
+
// Make TMA aligned to 16 bytes
|
| 421 |
+
constexpr uint32_t kAlignment = 16 / sizeof(T);
|
| 422 |
+
shape_m = ceil_div(shape_m, kAlignment) * kAlignment;
|
| 423 |
+
|
| 424 |
+
return make_2d_tma_desc(global_address, Layout::ColMajor,
|
| 425 |
+
shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
|
| 426 |
+
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
template <typename T>
|
| 430 |
+
static CUtensorMap make_2d_tma_desc(
|
| 431 |
+
T* global_address, Layout layout,
|
| 432 |
+
uint32_t gmem_rows, uint32_t gmem_cols,
|
| 433 |
+
uint32_t smem_rows, uint32_t smem_cols,
|
| 434 |
+
CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
|
| 435 |
+
if (layout == Layout::RowMajor) {
|
| 436 |
+
uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
|
| 437 |
+
uint32_t smem_dim[2] = {smem_cols, smem_rows};
|
| 438 |
+
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type);
|
| 439 |
+
} else {
|
| 440 |
+
uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
|
| 441 |
+
uint32_t smem_dim[2] = {smem_rows, smem_cols};
|
| 442 |
+
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type);
|
| 443 |
+
}
|
| 444 |
+
}
|
| 445 |
+
};
|
| 446 |
+
|
| 447 |
+
}; // namespace deep_gemm
|
| 448 |
+
|
| 449 |
+
#pragma clang diagnostic pop
|
torch-ext/deep_gemm/include/deep_gemm/mma_utils.cuh
ADDED
|
@@ -0,0 +1,885 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cuda.h>
|
| 4 |
+
|
| 5 |
+
#include "utils.cuh"
|
| 6 |
+
|
| 7 |
+
namespace deep_gemm {
|
| 8 |
+
|
| 9 |
+
struct SM90_64x16x32_F32E4M3E4M3_SS {
|
| 10 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
| 11 |
+
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
| 12 |
+
bool scale_d) {
|
| 13 |
+
asm volatile("{\n"
|
| 14 |
+
".reg .pred p;\n"
|
| 15 |
+
"setp.ne.b32 p, %10, 0;\n"
|
| 16 |
+
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3"
|
| 17 |
+
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
| 18 |
+
" %8,"
|
| 19 |
+
" %9,"
|
| 20 |
+
" p , 1, 1;\n"
|
| 21 |
+
"}\n"
|
| 22 |
+
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07)
|
| 23 |
+
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
| 27 |
+
wgmma(desc_a, desc_b,
|
| 28 |
+
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
| 29 |
+
scale_d);
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
static constexpr int M = 64;
|
| 33 |
+
static constexpr int N = 16;
|
| 34 |
+
static constexpr int K = 32;
|
| 35 |
+
static constexpr int kNumAccum = M * N / 128;
|
| 36 |
+
};
|
| 37 |
+
|
| 38 |
+
struct SM90_64x24x32_F32E4M3E4M3_SS {
|
| 39 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
| 40 |
+
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
| 41 |
+
float& d08, float& d09, float& d10, float& d11,
|
| 42 |
+
bool scale_d) {
|
| 43 |
+
asm volatile("{\n"
|
| 44 |
+
".reg .pred p;\n"
|
| 45 |
+
"setp.ne.b32 p, %14, 0;\n"
|
| 46 |
+
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3"
|
| 47 |
+
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
| 48 |
+
" %8, %9, %10, %11},"
|
| 49 |
+
" %12,"
|
| 50 |
+
" %13,"
|
| 51 |
+
" p , 1, 1;\n"
|
| 52 |
+
"}\n"
|
| 53 |
+
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
| 54 |
+
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11)
|
| 55 |
+
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
| 59 |
+
wgmma(desc_a, desc_b,
|
| 60 |
+
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
| 61 |
+
d[8], d[9], d[10], d[11],
|
| 62 |
+
scale_d);
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
static constexpr int M = 64;
|
| 66 |
+
static constexpr int N = 24;
|
| 67 |
+
static constexpr int K = 32;
|
| 68 |
+
static constexpr int kNumAccum = M * N / 128;
|
| 69 |
+
};
|
| 70 |
+
|
| 71 |
+
struct SM90_64x32x32_F32E4M3E4M3_SS {
|
| 72 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
| 73 |
+
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
| 74 |
+
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
| 75 |
+
bool scale_d) {
|
| 76 |
+
asm volatile("{\n"
|
| 77 |
+
".reg .pred p;\n"
|
| 78 |
+
"setp.ne.b32 p, %18, 0;\n"
|
| 79 |
+
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3"
|
| 80 |
+
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
| 81 |
+
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
| 82 |
+
" %16,"
|
| 83 |
+
" %17,"
|
| 84 |
+
" p , 1, 1;\n"
|
| 85 |
+
"}\n"
|
| 86 |
+
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
| 87 |
+
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
| 88 |
+
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
| 92 |
+
wgmma(desc_a, desc_b,
|
| 93 |
+
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
| 94 |
+
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
| 95 |
+
scale_d);
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
static constexpr int M = 64;
|
| 99 |
+
static constexpr int N = 32;
|
| 100 |
+
static constexpr int K = 32;
|
| 101 |
+
static constexpr int kNumAccum = M * N / 128;
|
| 102 |
+
};
|
| 103 |
+
|
| 104 |
+
struct SM90_64x40x32_F32E4M3E4M3_SS {
|
| 105 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
| 106 |
+
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
| 107 |
+
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
| 108 |
+
float& d16, float& d17, float& d18, float& d19,
|
| 109 |
+
bool scale_d) {
|
| 110 |
+
asm volatile("{\n"
|
| 111 |
+
".reg .pred p;\n"
|
| 112 |
+
"setp.ne.b32 p, %22, 0;\n"
|
| 113 |
+
"wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3"
|
| 114 |
+
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
| 115 |
+
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
| 116 |
+
" %16, %17, %18, %19},"
|
| 117 |
+
" %20,"
|
| 118 |
+
" %21,"
|
| 119 |
+
" p , 1, 1;\n"
|
| 120 |
+
"}\n"
|
| 121 |
+
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
| 122 |
+
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
| 123 |
+
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
|
| 124 |
+
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
| 128 |
+
wgmma(desc_a, desc_b,
|
| 129 |
+
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
| 130 |
+
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
| 131 |
+
d[16], d[17], d[18], d[19],
|
| 132 |
+
scale_d);
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
static constexpr int M = 64;
|
| 136 |
+
static constexpr int N = 40;
|
| 137 |
+
static constexpr int K = 32;
|
| 138 |
+
static constexpr int kNumAccum = M * N / 128;
|
| 139 |
+
};
|
| 140 |
+
|
| 141 |
+
struct SM90_64x48x32_F32E4M3E4M3_SS {
|
| 142 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
| 143 |
+
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
| 144 |
+
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
| 145 |
+
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
| 146 |
+
bool scale_d) {
|
| 147 |
+
asm volatile("{\n"
|
| 148 |
+
".reg .pred p;\n"
|
| 149 |
+
"setp.ne.b32 p, %26, 0;\n"
|
| 150 |
+
"wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3"
|
| 151 |
+
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
| 152 |
+
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
| 153 |
+
" %16, %17, %18, %19, %20, %21, %22, %23},"
|
| 154 |
+
" %24,"
|
| 155 |
+
" %25,"
|
| 156 |
+
" p , 1, 1;\n"
|
| 157 |
+
"}\n"
|
| 158 |
+
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
| 159 |
+
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
| 160 |
+
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23)
|
| 161 |
+
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
| 165 |
+
wgmma(desc_a, desc_b,
|
| 166 |
+
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
| 167 |
+
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
| 168 |
+
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
| 169 |
+
scale_d);
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
static constexpr int M = 64;
|
| 173 |
+
static constexpr int N = 48;
|
| 174 |
+
static constexpr int K = 32;
|
| 175 |
+
static constexpr int kNumAccum = M * N / 128;
|
| 176 |
+
};
|
| 177 |
+
|
| 178 |
+
struct SM90_64x56x32_F32E4M3E4M3_SS {
|
| 179 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
| 180 |
+
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
| 181 |
+
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
| 182 |
+
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
| 183 |
+
float& d24, float& d25, float& d26, float& d27,
|
| 184 |
+
bool scale_d) {
|
| 185 |
+
asm volatile("{\n"
|
| 186 |
+
".reg .pred p;\n"
|
| 187 |
+
"setp.ne.b32 p, %30, 0;\n"
|
| 188 |
+
"wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3"
|
| 189 |
+
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
| 190 |
+
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
| 191 |
+
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
| 192 |
+
" %24, %25, %26, %27}, "
|
| 193 |
+
" %28,"
|
| 194 |
+
" %29,"
|
| 195 |
+
" p , 1, 1;\n"
|
| 196 |
+
"}\n"
|
| 197 |
+
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
| 198 |
+
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
| 199 |
+
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
| 200 |
+
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27)
|
| 201 |
+
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
| 205 |
+
wgmma(desc_a, desc_b,
|
| 206 |
+
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
| 207 |
+
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
| 208 |
+
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
| 209 |
+
d[24], d[25], d[26], d[27],
|
| 210 |
+
scale_d);
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
static constexpr int M = 64;
|
| 214 |
+
static constexpr int N = 56;
|
| 215 |
+
static constexpr int K = 32;
|
| 216 |
+
static constexpr int kNumAccum = M * N / 128;
|
| 217 |
+
};
|
| 218 |
+
|
| 219 |
+
struct SM90_64x64x32_F32E4M3E4M3_SS {
|
| 220 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
| 221 |
+
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
| 222 |
+
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
| 223 |
+
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
| 224 |
+
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
| 225 |
+
bool scale_d) {
|
| 226 |
+
asm volatile("{\n"
|
| 227 |
+
".reg .pred p;\n"
|
| 228 |
+
"setp.ne.b32 p, %34, 0;\n"
|
| 229 |
+
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3"
|
| 230 |
+
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
| 231 |
+
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
| 232 |
+
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
| 233 |
+
" %24, %25, %26, %27, %28, %29, %30, %31}, "
|
| 234 |
+
" %32,"
|
| 235 |
+
" %33,"
|
| 236 |
+
" p , 1, 1;\n"
|
| 237 |
+
"}\n"
|
| 238 |
+
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
| 239 |
+
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
| 240 |
+
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
| 241 |
+
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
| 242 |
+
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
| 246 |
+
wgmma(desc_a, desc_b,
|
| 247 |
+
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
| 248 |
+
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
| 249 |
+
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
| 250 |
+
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
| 251 |
+
scale_d);
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
static constexpr int M = 64;
|
| 255 |
+
static constexpr int N = 64;
|
| 256 |
+
static constexpr int K = 32;
|
| 257 |
+
static constexpr int kNumAccum = M * N / 128;
|
| 258 |
+
};
|
| 259 |
+
|
| 260 |
+
struct SM90_64x72x32_F32E4M3E4M3_SS {
|
| 261 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
| 262 |
+
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
| 263 |
+
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
| 264 |
+
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
| 265 |
+
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
| 266 |
+
float& d32, float& d33, float& d34, float& d35,
|
| 267 |
+
bool scale_d) {
|
| 268 |
+
asm volatile("{\n"
|
| 269 |
+
".reg .pred p;\n"
|
| 270 |
+
"setp.ne.b32 p, %38, 0;\n"
|
| 271 |
+
"wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3"
|
| 272 |
+
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
| 273 |
+
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
| 274 |
+
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
| 275 |
+
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
| 276 |
+
" %32, %33, %34, %35}, "
|
| 277 |
+
" %36,"
|
| 278 |
+
" %37,"
|
| 279 |
+
" p , 1, 1;\n"
|
| 280 |
+
"}\n"
|
| 281 |
+
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
| 282 |
+
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
| 283 |
+
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
| 284 |
+
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
| 285 |
+
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35)
|
| 286 |
+
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
| 290 |
+
wgmma(desc_a, desc_b,
|
| 291 |
+
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
| 292 |
+
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
| 293 |
+
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
| 294 |
+
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
| 295 |
+
d[32], d[33], d[34], d[35],
|
| 296 |
+
scale_d);
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
static constexpr int M = 64;
|
| 300 |
+
static constexpr int N = 72;
|
| 301 |
+
static constexpr int K = 32;
|
| 302 |
+
static constexpr int kNumAccum = M * N / 128;
|
| 303 |
+
};
|
| 304 |
+
|
| 305 |
+
struct SM90_64x80x32_F32E4M3E4M3_SS {
|
| 306 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
| 307 |
+
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
| 308 |
+
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
| 309 |
+
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
| 310 |
+
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
| 311 |
+
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
| 312 |
+
bool scale_d) {
|
| 313 |
+
asm volatile("{\n"
|
| 314 |
+
".reg .pred p;\n"
|
| 315 |
+
"setp.ne.b32 p, %42, 0;\n"
|
| 316 |
+
"wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3"
|
| 317 |
+
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
| 318 |
+
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
| 319 |
+
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
| 320 |
+
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
| 321 |
+
" %32, %33, %34, %35, %36, %37, %38, %39}, "
|
| 322 |
+
" %40,"
|
| 323 |
+
" %41,"
|
| 324 |
+
" p , 1, 1;\n"
|
| 325 |
+
"}\n"
|
| 326 |
+
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
| 327 |
+
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
| 328 |
+
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
| 329 |
+
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
| 330 |
+
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39)
|
| 331 |
+
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
| 335 |
+
wgmma(desc_a, desc_b,
|
| 336 |
+
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
| 337 |
+
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
| 338 |
+
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
| 339 |
+
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
| 340 |
+
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
| 341 |
+
scale_d);
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
static constexpr int M = 64;
|
| 345 |
+
static constexpr int N = 80;
|
| 346 |
+
static constexpr int K = 32;
|
| 347 |
+
static constexpr int kNumAccum = M * N / 128;
|
| 348 |
+
};
|
| 349 |
+
|
| 350 |
+
struct SM90_64x88x32_F32E4M3E4M3_SS {
|
| 351 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
| 352 |
+
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
| 353 |
+
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
| 354 |
+
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
| 355 |
+
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
| 356 |
+
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
| 357 |
+
float& d40, float& d41, float& d42, float& d43,
|
| 358 |
+
bool scale_d) {
|
| 359 |
+
asm volatile("{\n"
|
| 360 |
+
".reg .pred p;\n"
|
| 361 |
+
"setp.ne.b32 p, %46, 0;\n"
|
| 362 |
+
"wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3"
|
| 363 |
+
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
| 364 |
+
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
| 365 |
+
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
| 366 |
+
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
| 367 |
+
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
| 368 |
+
" %40, %41, %42, %43}, "
|
| 369 |
+
" %44,"
|
| 370 |
+
" %45,"
|
| 371 |
+
" p , 1, 1;\n"
|
| 372 |
+
"}\n"
|
| 373 |
+
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
| 374 |
+
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
| 375 |
+
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
| 376 |
+
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
| 377 |
+
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
| 378 |
+
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43)
|
| 379 |
+
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
| 383 |
+
wgmma(desc_a, desc_b,
|
| 384 |
+
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
| 385 |
+
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
| 386 |
+
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
| 387 |
+
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
| 388 |
+
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
| 389 |
+
d[40], d[41], d[42], d[43],
|
| 390 |
+
scale_d);
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
static constexpr int M = 64;
|
| 394 |
+
static constexpr int N = 88;
|
| 395 |
+
static constexpr int K = 32;
|
| 396 |
+
static constexpr int kNumAccum = M * N / 128;
|
| 397 |
+
};
|
| 398 |
+
|
| 399 |
+
struct SM90_64x96x32_F32E4M3E4M3_SS {
|
| 400 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
| 401 |
+
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
| 402 |
+
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
| 403 |
+
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
| 404 |
+
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
| 405 |
+
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
| 406 |
+
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
| 407 |
+
bool scale_d) {
|
| 408 |
+
asm volatile("{\n"
|
| 409 |
+
".reg .pred p;\n"
|
| 410 |
+
"setp.ne.b32 p, %50, 0;\n"
|
| 411 |
+
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3"
|
| 412 |
+
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
| 413 |
+
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
| 414 |
+
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
| 415 |
+
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
| 416 |
+
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
| 417 |
+
" %40, %41, %42, %43, %44, %45, %46, %47}, "
|
| 418 |
+
" %48,"
|
| 419 |
+
" %49,"
|
| 420 |
+
" p , 1, 1;\n"
|
| 421 |
+
"}\n"
|
| 422 |
+
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
| 423 |
+
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
| 424 |
+
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
| 425 |
+
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
| 426 |
+
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
| 427 |
+
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
| 428 |
+
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
| 432 |
+
wgmma(desc_a, desc_b,
|
| 433 |
+
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
| 434 |
+
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
| 435 |
+
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
| 436 |
+
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
| 437 |
+
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
| 438 |
+
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
| 439 |
+
scale_d);
|
| 440 |
+
}
|
| 441 |
+
|
| 442 |
+
static constexpr int M = 64;
|
| 443 |
+
static constexpr int N = 96;
|
| 444 |
+
static constexpr int K = 32;
|
| 445 |
+
static constexpr int kNumAccum = M * N / 128;
|
| 446 |
+
};
|
| 447 |
+
|
| 448 |
+
struct SM90_64x104x32_F32E4M3E4M3_SS {
|
| 449 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
| 450 |
+
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
| 451 |
+
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
| 452 |
+
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
| 453 |
+
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
| 454 |
+
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
| 455 |
+
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
| 456 |
+
float& d48, float& d49, float& d50, float& d51,
|
| 457 |
+
bool scale_d) {
|
| 458 |
+
asm volatile("{\n"
|
| 459 |
+
".reg .pred p;\n"
|
| 460 |
+
"setp.ne.b32 p, %54, 0;\n"
|
| 461 |
+
"wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3"
|
| 462 |
+
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
| 463 |
+
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
| 464 |
+
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
| 465 |
+
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
| 466 |
+
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
| 467 |
+
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
| 468 |
+
" %48, %49, %50, %51}, "
|
| 469 |
+
" %52,"
|
| 470 |
+
" %53,"
|
| 471 |
+
" p , 1, 1;\n"
|
| 472 |
+
"}\n"
|
| 473 |
+
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
| 474 |
+
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
| 475 |
+
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
| 476 |
+
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
| 477 |
+
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
| 478 |
+
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
| 479 |
+
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51)
|
| 480 |
+
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
| 481 |
+
}
|
| 482 |
+
|
| 483 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
| 484 |
+
wgmma(desc_a, desc_b,
|
| 485 |
+
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
| 486 |
+
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
| 487 |
+
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
| 488 |
+
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
| 489 |
+
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
| 490 |
+
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
| 491 |
+
d[48], d[49], d[50], d[51],
|
| 492 |
+
scale_d);
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
static constexpr int M = 64;
|
| 496 |
+
static constexpr int N = 104;
|
| 497 |
+
static constexpr int K = 32;
|
| 498 |
+
static constexpr int kNumAccum = M * N / 128;
|
| 499 |
+
};
|
| 500 |
+
|
| 501 |
+
struct SM90_64x112x32_F32E4M3E4M3_SS {
|
| 502 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
| 503 |
+
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
| 504 |
+
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
| 505 |
+
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
| 506 |
+
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
| 507 |
+
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
| 508 |
+
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
| 509 |
+
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
| 510 |
+
bool scale_d) {
|
| 511 |
+
asm volatile("{\n"
|
| 512 |
+
".reg .pred p;\n"
|
| 513 |
+
"setp.ne.b32 p, %58, 0;\n"
|
| 514 |
+
"wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3"
|
| 515 |
+
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
| 516 |
+
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
| 517 |
+
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
| 518 |
+
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
| 519 |
+
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
| 520 |
+
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
| 521 |
+
" %48, %49, %50, %51, %52, %53, %54, %55}, "
|
| 522 |
+
" %56,"
|
| 523 |
+
" %57,"
|
| 524 |
+
" p , 1, 1;\n"
|
| 525 |
+
"}\n"
|
| 526 |
+
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
| 527 |
+
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
| 528 |
+
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
| 529 |
+
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
| 530 |
+
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
| 531 |
+
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
| 532 |
+
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55)
|
| 533 |
+
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
| 537 |
+
wgmma(desc_a, desc_b,
|
| 538 |
+
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
| 539 |
+
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
| 540 |
+
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
| 541 |
+
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
| 542 |
+
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
| 543 |
+
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
| 544 |
+
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
| 545 |
+
scale_d);
|
| 546 |
+
}
|
| 547 |
+
|
| 548 |
+
static constexpr int M = 64;
|
| 549 |
+
static constexpr int N = 112;
|
| 550 |
+
static constexpr int K = 32;
|
| 551 |
+
static constexpr int kNumAccum = M * N / 128;
|
| 552 |
+
};
|
| 553 |
+
|
| 554 |
+
struct SM90_64x120x32_F32E4M3E4M3_SS {
|
| 555 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
| 556 |
+
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
| 557 |
+
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
| 558 |
+
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
| 559 |
+
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
| 560 |
+
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
| 561 |
+
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
| 562 |
+
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
| 563 |
+
float& d56, float& d57, float& d58, float& d59,
|
| 564 |
+
bool scale_d) {
|
| 565 |
+
asm volatile("{\n"
|
| 566 |
+
".reg .pred p;\n"
|
| 567 |
+
"setp.ne.b32 p, %62, 0;\n"
|
| 568 |
+
"wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3"
|
| 569 |
+
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
| 570 |
+
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
| 571 |
+
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
| 572 |
+
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
| 573 |
+
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
| 574 |
+
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
| 575 |
+
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
| 576 |
+
" %56, %57, %58, %59}, "
|
| 577 |
+
" %60,"
|
| 578 |
+
" %61,"
|
| 579 |
+
" p , 1, 1;\n"
|
| 580 |
+
"}\n"
|
| 581 |
+
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
| 582 |
+
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
| 583 |
+
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
| 584 |
+
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
| 585 |
+
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
| 586 |
+
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
| 587 |
+
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
| 588 |
+
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59)
|
| 589 |
+
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
| 590 |
+
}
|
| 591 |
+
|
| 592 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
| 593 |
+
wgmma(desc_a, desc_b,
|
| 594 |
+
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
| 595 |
+
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
| 596 |
+
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
| 597 |
+
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
| 598 |
+
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
| 599 |
+
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
| 600 |
+
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
| 601 |
+
d[56], d[57], d[58], d[59],
|
| 602 |
+
scale_d);
|
| 603 |
+
}
|
| 604 |
+
|
| 605 |
+
static constexpr int M = 64;
|
| 606 |
+
static constexpr int N = 120;
|
| 607 |
+
static constexpr int K = 32;
|
| 608 |
+
static constexpr int kNumAccum = M * N / 128;
|
| 609 |
+
};
|
| 610 |
+
|
| 611 |
+
struct SM90_64x128x32_F32E4M3E4M3_SS {
|
| 612 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
| 613 |
+
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
| 614 |
+
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
| 615 |
+
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
| 616 |
+
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
| 617 |
+
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
| 618 |
+
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
| 619 |
+
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
| 620 |
+
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
|
| 621 |
+
bool scale_d) {
|
| 622 |
+
asm volatile("{\n"
|
| 623 |
+
".reg .pred p;\n"
|
| 624 |
+
"setp.ne.b32 p, %66, 0;\n"
|
| 625 |
+
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3"
|
| 626 |
+
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
| 627 |
+
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
| 628 |
+
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
| 629 |
+
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
| 630 |
+
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
| 631 |
+
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
| 632 |
+
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
| 633 |
+
" %56, %57, %58, %59, %60, %61, %62, %63}, "
|
| 634 |
+
" %64,"
|
| 635 |
+
" %65,"
|
| 636 |
+
" p , 1, 1;\n"
|
| 637 |
+
"}\n"
|
| 638 |
+
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
| 639 |
+
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
| 640 |
+
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
| 641 |
+
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
| 642 |
+
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
| 643 |
+
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
| 644 |
+
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
| 645 |
+
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
| 646 |
+
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
| 647 |
+
}
|
| 648 |
+
|
| 649 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
| 650 |
+
wgmma(desc_a, desc_b,
|
| 651 |
+
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
| 652 |
+
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
| 653 |
+
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
| 654 |
+
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
| 655 |
+
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
| 656 |
+
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
| 657 |
+
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
| 658 |
+
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
|
| 659 |
+
scale_d);
|
| 660 |
+
}
|
| 661 |
+
|
| 662 |
+
static constexpr int M = 64;
|
| 663 |
+
static constexpr int N = 128;
|
| 664 |
+
static constexpr int K = 32;
|
| 665 |
+
static constexpr int kNumAccum = M * N / 128;
|
| 666 |
+
};
|
| 667 |
+
|
| 668 |
+
struct SM90_64x192x32_F32E4M3E4M3_SS {
|
| 669 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
|
| 670 |
+
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
|
| 671 |
+
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
|
| 672 |
+
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
|
| 673 |
+
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
|
| 674 |
+
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
|
| 675 |
+
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
|
| 676 |
+
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
|
| 677 |
+
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
|
| 678 |
+
float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
|
| 679 |
+
float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79,
|
| 680 |
+
float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87,
|
| 681 |
+
float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95,
|
| 682 |
+
bool scale_d) {
|
| 683 |
+
asm volatile("{\n"
|
| 684 |
+
".reg .pred p;\n"
|
| 685 |
+
"setp.ne.b32 p, %98, 0;\n"
|
| 686 |
+
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3"
|
| 687 |
+
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
| 688 |
+
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
| 689 |
+
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
| 690 |
+
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
| 691 |
+
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
| 692 |
+
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
| 693 |
+
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
| 694 |
+
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
| 695 |
+
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
| 696 |
+
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
| 697 |
+
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
| 698 |
+
" %88, %89, %90, %91, %92, %93, %94, %95}, "
|
| 699 |
+
" %96,"
|
| 700 |
+
" %97,"
|
| 701 |
+
" p , 1, 1;\n"
|
| 702 |
+
"}\n"
|
| 703 |
+
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
| 704 |
+
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
| 705 |
+
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
| 706 |
+
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
| 707 |
+
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
| 708 |
+
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
| 709 |
+
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
| 710 |
+
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
| 711 |
+
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
| 712 |
+
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
| 713 |
+
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
| 714 |
+
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
| 715 |
+
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
|
| 716 |
+
}
|
| 717 |
+
|
| 718 |
+
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
| 719 |
+
wgmma(desc_a, desc_b,
|
| 720 |
+
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
|
| 721 |
+
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
|
| 722 |
+
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
|
| 723 |
+
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
|
| 724 |
+
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
|
| 725 |
+
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
|
| 726 |
+
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
|
| 727 |
+
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
|
| 728 |
+
d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
|
| 729 |
+
d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79],
|
| 730 |
+
d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87],
|
| 731 |
+
d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95],
|
| 732 |
+
scale_d);
|
| 733 |
+
}
|
| 734 |
+
|
| 735 |
+
static constexpr int M = 64;
|
| 736 |
+
static constexpr int N = 192;
|
| 737 |
+
static constexpr int K = 32;
|
| 738 |
+
static constexpr int kNumAccum = M * N / 128;
|
| 739 |
+
};
|
| 740 |
+
|
| 741 |
+
template <typename dtype_t>
|
| 742 |
+
struct SM90_U32x2_STSM_N {
|
| 743 |
+
__device__ __forceinline__ static void
|
| 744 |
+
copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
|
| 745 |
+
const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
|
| 746 |
+
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
|
| 747 |
+
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]));
|
| 748 |
+
}
|
| 749 |
+
};
|
| 750 |
+
|
| 751 |
+
template <typename dtype_t>
|
| 752 |
+
struct SM90_U32x4_STSM_N {
|
| 753 |
+
__device__ __forceinline__ static void
|
| 754 |
+
copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) {
|
| 755 |
+
const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
|
| 756 |
+
*reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
|
| 757 |
+
asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
|
| 758 |
+
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
|
| 759 |
+
}
|
| 760 |
+
};
|
| 761 |
+
|
| 762 |
+
__device__ void warpgroup_arrive() {
|
| 763 |
+
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
|
| 764 |
+
}
|
| 765 |
+
|
| 766 |
+
__device__ void warpgroup_commit_batch() {
|
| 767 |
+
asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
|
| 768 |
+
}
|
| 769 |
+
|
| 770 |
+
__device__ void warpgroup_fence_operand(float& reg) {
|
| 771 |
+
asm volatile("" : "+f"(reg) :: "memory");
|
| 772 |
+
}
|
| 773 |
+
|
| 774 |
+
__forceinline__ __device__ uint32_t get_lane_id() {
|
| 775 |
+
uint32_t lane_id;
|
| 776 |
+
asm("mov.u32 %0, %laneid;" : "=r"(lane_id));
|
| 777 |
+
return lane_id;
|
| 778 |
+
}
|
| 779 |
+
|
| 780 |
+
__device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) {
|
| 781 |
+
uint32_t ret;
|
| 782 |
+
asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
| 783 |
+
return ret;
|
| 784 |
+
}
|
| 785 |
+
|
| 786 |
+
__device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) {
|
| 787 |
+
int4 ret;
|
| 788 |
+
asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
|
| 789 |
+
return ret;
|
| 790 |
+
}
|
| 791 |
+
|
| 792 |
+
__device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
|
| 793 |
+
float ret;
|
| 794 |
+
asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
|
| 795 |
+
return ret;
|
| 796 |
+
}
|
| 797 |
+
|
| 798 |
+
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
|
| 799 |
+
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
|
| 800 |
+
}
|
| 801 |
+
|
| 802 |
+
__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
|
| 803 |
+
asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
|
| 804 |
+
}
|
| 805 |
+
|
| 806 |
+
template <int N>
|
| 807 |
+
__device__ void warpgroup_wait() {
|
| 808 |
+
DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
|
| 809 |
+
asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
|
| 810 |
+
}
|
| 811 |
+
|
| 812 |
+
union GmmaDescriptor {
|
| 813 |
+
__host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {}
|
| 814 |
+
|
| 815 |
+
__host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {}
|
| 816 |
+
|
| 817 |
+
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {}
|
| 818 |
+
|
| 819 |
+
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {}
|
| 820 |
+
|
| 821 |
+
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept {
|
| 822 |
+
desc_ = t.desc_;
|
| 823 |
+
return *this;
|
| 824 |
+
}
|
| 825 |
+
|
| 826 |
+
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept {
|
| 827 |
+
desc_ = t.desc_;
|
| 828 |
+
return *this;
|
| 829 |
+
}
|
| 830 |
+
|
| 831 |
+
uint64_t desc_;
|
| 832 |
+
uint32_t reg32_[2];
|
| 833 |
+
uint16_t reg16_[4];
|
| 834 |
+
|
| 835 |
+
struct {
|
| 836 |
+
uint16_t start_address_: 14, : 2;
|
| 837 |
+
uint16_t leading_byte_offset_: 14, : 2;
|
| 838 |
+
uint16_t stride_byte_offset_: 14, : 2;
|
| 839 |
+
uint8_t : 1, base_offset_: 3, : 4;
|
| 840 |
+
uint8_t : 6, layout_type_: 2;
|
| 841 |
+
} bitfield;
|
| 842 |
+
|
| 843 |
+
// Decay to an `uint64_t`
|
| 844 |
+
__host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; }
|
| 845 |
+
};
|
| 846 |
+
|
| 847 |
+
template <class PointerType>
|
| 848 |
+
__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type,
|
| 849 |
+
int leading_byte_offset = 0,
|
| 850 |
+
int stride_byte_offset = 1024) {
|
| 851 |
+
GmmaDescriptor desc;
|
| 852 |
+
auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
| 853 |
+
desc.bitfield.start_address_ = uint_ptr >> 4;
|
| 854 |
+
desc.bitfield.layout_type_ = layout_type;
|
| 855 |
+
desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
|
| 856 |
+
desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
|
| 857 |
+
desc.bitfield.base_offset_ = 0;
|
| 858 |
+
return desc;
|
| 859 |
+
}
|
| 860 |
+
|
| 861 |
+
template <int N>
|
| 862 |
+
struct FP8MMASelector {
|
| 863 |
+
static constexpr auto select_type() {
|
| 864 |
+
if constexpr (N == 16) return SM90_64x16x32_F32E4M3E4M3_SS();
|
| 865 |
+
if constexpr (N == 24) return SM90_64x24x32_F32E4M3E4M3_SS();
|
| 866 |
+
if constexpr (N == 32) return SM90_64x32x32_F32E4M3E4M3_SS();
|
| 867 |
+
if constexpr (N == 40) return SM90_64x40x32_F32E4M3E4M3_SS();
|
| 868 |
+
if constexpr (N == 48) return SM90_64x48x32_F32E4M3E4M3_SS();
|
| 869 |
+
if constexpr (N == 56) return SM90_64x56x32_F32E4M3E4M3_SS();
|
| 870 |
+
if constexpr (N == 64) return SM90_64x64x32_F32E4M3E4M3_SS();
|
| 871 |
+
if constexpr (N == 72) return SM90_64x72x32_F32E4M3E4M3_SS();
|
| 872 |
+
if constexpr (N == 80) return SM90_64x80x32_F32E4M3E4M3_SS();
|
| 873 |
+
if constexpr (N == 88) return SM90_64x88x32_F32E4M3E4M3_SS();
|
| 874 |
+
if constexpr (N == 96) return SM90_64x96x32_F32E4M3E4M3_SS();
|
| 875 |
+
if constexpr (N == 104) return SM90_64x104x32_F32E4M3E4M3_SS();
|
| 876 |
+
if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS();
|
| 877 |
+
if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS();
|
| 878 |
+
if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS();
|
| 879 |
+
if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS();
|
| 880 |
+
}
|
| 881 |
+
|
| 882 |
+
using type = decltype(select_type());
|
| 883 |
+
};
|
| 884 |
+
|
| 885 |
+
} // namespace deep_gemm
|
torch-ext/deep_gemm/include/deep_gemm/scheduler.cuh
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "utils.cuh"
|
| 2 |
+
|
| 3 |
+
namespace deep_gemm {
|
| 4 |
+
|
| 5 |
+
enum class GemmType {
|
| 6 |
+
Normal,
|
| 7 |
+
GroupedContiguous,
|
| 8 |
+
GroupedMasked
|
| 9 |
+
};
|
| 10 |
+
|
| 11 |
+
#pragma clang diagnostic push
|
| 12 |
+
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
|
| 13 |
+
template <GemmType kGemmType,
|
| 14 |
+
uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
|
| 15 |
+
uint32_t kNumGroups, uint32_t kNumTMAMulticast,
|
| 16 |
+
uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
|
| 17 |
+
uint32_t kNumNBlocksPerGroup = 16>
|
| 18 |
+
struct Scheduler {
|
| 19 |
+
int current_iter = -1;
|
| 20 |
+
uint32_t num_aligned_m_blocks;
|
| 21 |
+
|
| 22 |
+
// For normal GEMM
|
| 23 |
+
// Maybe not used in the masked grouped GEMM
|
| 24 |
+
uint32_t num_blocks;
|
| 25 |
+
|
| 26 |
+
// For grouped GEMM
|
| 27 |
+
int* grouped_layout;
|
| 28 |
+
// Only used for masked layout
|
| 29 |
+
uint32_t curr_group_idx, curr_cumsum;
|
| 30 |
+
|
| 31 |
+
__device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
|
| 32 |
+
int* grouped_layout = nullptr) {
|
| 33 |
+
num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M);
|
| 34 |
+
if constexpr (kGemmType == GemmType::Normal) {
|
| 35 |
+
num_blocks = num_aligned_m_blocks * kNumNBlocks;
|
| 36 |
+
} else if (kGemmType == GemmType::GroupedContiguous) {
|
| 37 |
+
num_blocks = num_aligned_m_blocks * kNumNBlocks;
|
| 38 |
+
this->grouped_layout = grouped_layout;
|
| 39 |
+
} else if (kGemmType == GemmType::GroupedMasked) {
|
| 40 |
+
curr_group_idx = curr_cumsum = 0;
|
| 41 |
+
this->grouped_layout = grouped_layout;
|
| 42 |
+
}
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
| 46 |
+
DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
|
| 47 |
+
|
| 48 |
+
// Swizzle for better L2 usages
|
| 49 |
+
auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
|
| 50 |
+
auto group_idx = block_idx / num_blocks_per_group;
|
| 51 |
+
auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
|
| 52 |
+
auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
|
| 53 |
+
auto in_group_idx = block_idx % num_blocks_per_group;
|
| 54 |
+
m_block_idx = in_group_idx / num_n_blocks_in_group;
|
| 55 |
+
n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
template <bool kIgnoreGroupedForGroupedContiguous=true>
|
| 59 |
+
__device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
|
| 60 |
+
const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
|
| 61 |
+
if constexpr (kGemmType == GemmType::Normal) {
|
| 62 |
+
return block_idx * block_size;
|
| 63 |
+
} else if (kGemmType == GemmType::GroupedContiguous) {
|
| 64 |
+
auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
|
| 65 |
+
return offset * shape_dim + block_idx * block_size;
|
| 66 |
+
} else if (kGemmType == GemmType::GroupedMasked) {
|
| 67 |
+
return curr_group_idx * shape_dim + block_idx * block_size;
|
| 68 |
+
}
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
| 72 |
+
const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x;
|
| 73 |
+
|
| 74 |
+
if constexpr (kGemmType == GemmType::GroupedMasked) {
|
| 75 |
+
uint32_t num_m_blocks;
|
| 76 |
+
while (true) {
|
| 77 |
+
// End of the task
|
| 78 |
+
if (curr_group_idx == kNumGroups)
|
| 79 |
+
return false;
|
| 80 |
+
|
| 81 |
+
// Within current group
|
| 82 |
+
num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
|
| 83 |
+
auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
|
| 84 |
+
if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
|
| 85 |
+
break;
|
| 86 |
+
|
| 87 |
+
// Move to check the next group
|
| 88 |
+
curr_group_idx ++, curr_cumsum = current_m_block_cumsum;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
|
| 92 |
+
} else {
|
| 93 |
+
if (next_block_idx >= num_blocks)
|
| 94 |
+
return false;
|
| 95 |
+
|
| 96 |
+
get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
|
| 97 |
+
}
|
| 98 |
+
return true;
|
| 99 |
+
}
|
| 100 |
+
};
|
| 101 |
+
#pragma clang diagnostic pop
|
| 102 |
+
|
| 103 |
+
} // namespace deep_gemm
|
torch-ext/deep_gemm/include/deep_gemm/tma_utils.cuh
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cassert>
|
| 4 |
+
#include <cuda.h>
|
| 5 |
+
#include <cudaTypedefs.h>
|
| 6 |
+
#include <cuda_fp8.h>
|
| 7 |
+
#include <cuda_runtime.h>
|
| 8 |
+
#include <cuda/barrier>
|
| 9 |
+
|
| 10 |
+
#include "utils.cuh"
|
| 11 |
+
|
| 12 |
+
namespace deep_gemm {
|
| 13 |
+
|
| 14 |
+
template <class T>
|
| 15 |
+
constexpr CUtensorMapDataType get_CUtensorMapDataType() {
|
| 16 |
+
if constexpr (std::is_same<T, uint8_t>::value) {
|
| 17 |
+
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
| 18 |
+
} else if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
|
| 19 |
+
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
| 20 |
+
} else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
|
| 21 |
+
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
| 22 |
+
} else if constexpr (std::is_same<T, uint16_t>::value) {
|
| 23 |
+
return CU_TENSOR_MAP_DATA_TYPE_UINT16;
|
| 24 |
+
} else if constexpr (std::is_same<T, uint32_t>::value) {
|
| 25 |
+
return CU_TENSOR_MAP_DATA_TYPE_UINT32;
|
| 26 |
+
} else if constexpr (std::is_same<T, uint64_t>::value) {
|
| 27 |
+
return CU_TENSOR_MAP_DATA_TYPE_UINT64;
|
| 28 |
+
} else if constexpr (std::is_same<T, int32_t>::value) {
|
| 29 |
+
return CU_TENSOR_MAP_DATA_TYPE_INT32;
|
| 30 |
+
} else if constexpr (std::is_same<T, int64_t>::value) {
|
| 31 |
+
return CU_TENSOR_MAP_DATA_TYPE_INT64;
|
| 32 |
+
} else if constexpr (std::is_same<T, __half>::value) {
|
| 33 |
+
return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
|
| 34 |
+
} else if constexpr (std::is_same<T, float>::value) {
|
| 35 |
+
return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
| 36 |
+
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
| 37 |
+
return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
|
| 38 |
+
} else if constexpr (std::is_same<T, double>::value) {
|
| 39 |
+
return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
|
| 40 |
+
}
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
|
| 44 |
+
// Get pointer to `cuTensorMapEncodeTiled`
|
| 45 |
+
cudaDriverEntryPointQueryResult driver_status;
|
| 46 |
+
void* cuTensorMapEncodeTiled_ptr = nullptr;
|
| 47 |
+
|
| 48 |
+
#if CUDA_VERSION >= 12050
|
| 49 |
+
cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000,
|
| 50 |
+
cudaEnableDefault, &driver_status);
|
| 51 |
+
#else
|
| 52 |
+
cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr,
|
| 53 |
+
cudaEnableDefault, &driver_status);
|
| 54 |
+
#endif
|
| 55 |
+
|
| 56 |
+
if (driver_status != cudaDriverEntryPointSuccess)
|
| 57 |
+
throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
|
| 58 |
+
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
template <typename T>
|
| 62 |
+
CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
|
| 63 |
+
uint64_t stride_in_bytes, uint32_t smem_dim[2],
|
| 64 |
+
CUtensorMapSwizzle swizzle_type,
|
| 65 |
+
PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
|
| 66 |
+
CUtensorMap tensor_map{};
|
| 67 |
+
constexpr uint32_t rank = 2;
|
| 68 |
+
uint64_t global_stride[rank - 1] = {stride_in_bytes};
|
| 69 |
+
uint32_t elem_strides[rank] = {1, 1};
|
| 70 |
+
|
| 71 |
+
if (encode_func == nullptr)
|
| 72 |
+
encode_func = get_cuTensorMapEncodeTiled();
|
| 73 |
+
|
| 74 |
+
auto result = encode_func(
|
| 75 |
+
&tensor_map, get_CUtensorMapDataType<typename std::remove_cv<T>::type>(), rank,
|
| 76 |
+
global_address, gmem_dim, global_stride, smem_dim, elem_strides,
|
| 77 |
+
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
|
| 78 |
+
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
|
| 79 |
+
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
|
| 80 |
+
DG_HOST_ASSERT(result == CUDA_SUCCESS);
|
| 81 |
+
return tensor_map;
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
template <uint32_t kNumTMAMulticast = 1>
|
| 85 |
+
__device__ __forceinline__ void
|
| 86 |
+
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
|
| 87 |
+
int32_t const& crd_0, int32_t const& crd_1) {
|
| 88 |
+
constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
|
| 89 |
+
if constexpr (kNumTMAMulticast == 1) {
|
| 90 |
+
cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
|
| 91 |
+
} else if (cute::block_rank_in_cluster() == 0) {
|
| 92 |
+
cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << kNumTMAMulticast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
|
| 93 |
+
}
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
} // namespace deep_gemm
|
torch-ext/deep_gemm/include/deep_gemm/utils.cuh
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <exception>
|
| 4 |
+
|
| 5 |
+
#ifdef __CLION_IDE__
|
| 6 |
+
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
|
| 7 |
+
#define printf host_device_printf
|
| 8 |
+
#endif
|
| 9 |
+
|
| 10 |
+
class AssertionException : public std::exception {
|
| 11 |
+
private:
|
| 12 |
+
std::string message{};
|
| 13 |
+
|
| 14 |
+
public:
|
| 15 |
+
explicit AssertionException(const std::string& message) : message(message) {}
|
| 16 |
+
|
| 17 |
+
const char *what() const noexcept override { return message.c_str(); }
|
| 18 |
+
};
|
| 19 |
+
|
| 20 |
+
#ifndef DG_HOST_ASSERT
|
| 21 |
+
#define DG_HOST_ASSERT(cond) \
|
| 22 |
+
do { \
|
| 23 |
+
if (not (cond)) { \
|
| 24 |
+
printf("Assertion failed: %s:%d, condition: %s\n", \
|
| 25 |
+
__FILE__, __LINE__, #cond); \
|
| 26 |
+
throw AssertionException("Assertion failed: " #cond); \
|
| 27 |
+
} \
|
| 28 |
+
} while (0)
|
| 29 |
+
#endif
|
| 30 |
+
|
| 31 |
+
#ifndef DG_DEVICE_ASSERT
|
| 32 |
+
#define DG_DEVICE_ASSERT(cond) \
|
| 33 |
+
do { \
|
| 34 |
+
if (not (cond)) { \
|
| 35 |
+
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
|
| 36 |
+
asm("trap;"); \
|
| 37 |
+
} \
|
| 38 |
+
} while (0)
|
| 39 |
+
#endif
|
| 40 |
+
|
| 41 |
+
#ifndef DG_STATIC_ASSERT
|
| 42 |
+
#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
|
| 43 |
+
#endif
|
| 44 |
+
|
| 45 |
+
template <typename T>
|
| 46 |
+
__device__ __host__ constexpr T ceil_div(T a, T b) {
|
| 47 |
+
return (a + b - 1) / b;
|
| 48 |
+
}
|
torch-ext/deep_gemm/jit/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .compiler import get_nvcc_compiler, build
|
| 2 |
+
from .template import cpp_format, generate
|
| 3 |
+
from .runtime import Runtime
|
torch-ext/deep_gemm/jit/compiler.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
import functools
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import subprocess
|
| 6 |
+
import uuid
|
| 7 |
+
from torch.utils.cpp_extension import CUDA_HOME
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
|
| 10 |
+
from . import interleave_ffma
|
| 11 |
+
from .runtime import Runtime, RuntimeCache
|
| 12 |
+
from .template import typename_map
|
| 13 |
+
|
| 14 |
+
runtime_cache = RuntimeCache()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def hash_to_hex(s: str) -> str:
|
| 18 |
+
md5 = hashlib.md5()
|
| 19 |
+
md5.update(s.encode('utf-8'))
|
| 20 |
+
return md5.hexdigest()[0:12]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@functools.lru_cache(maxsize=None)
|
| 24 |
+
def get_jit_include_dir() -> str:
|
| 25 |
+
return f'{os.path.dirname(os.path.abspath(__file__))}/../include'
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@functools.lru_cache(maxsize=None)
|
| 29 |
+
def get_deep_gemm_version() -> str:
|
| 30 |
+
# Update include directories
|
| 31 |
+
include_dir = f'{get_jit_include_dir()}/deep_gemm'
|
| 32 |
+
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
|
| 33 |
+
md5 = hashlib.md5()
|
| 34 |
+
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
|
| 35 |
+
with open(f'{include_dir}/{filename}', 'rb') as f:
|
| 36 |
+
md5.update(f.read())
|
| 37 |
+
|
| 38 |
+
# Update `interleave_ffma.py`
|
| 39 |
+
with open(f'{os.path.dirname(os.path.realpath(__file__))}/interleave_ffma.py', 'rb') as f:
|
| 40 |
+
md5.update(f.read())
|
| 41 |
+
return md5.hexdigest()[0:12]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@functools.lru_cache(maxsize=None)
|
| 45 |
+
def get_nvcc_compiler() -> Tuple[str, str]:
|
| 46 |
+
paths = []
|
| 47 |
+
if os.getenv('DG_NVCC_COMPILER'):
|
| 48 |
+
paths.append(os.getenv('DG_NVCC_COMPILER'))
|
| 49 |
+
paths.append(f'{CUDA_HOME}/bin/nvcc')
|
| 50 |
+
|
| 51 |
+
# Try to find the first available NVCC compiler
|
| 52 |
+
least_version_required = '12.3'
|
| 53 |
+
version_pattern = re.compile(r'release (\d+\.\d+)')
|
| 54 |
+
for path in paths:
|
| 55 |
+
if os.path.exists(path):
|
| 56 |
+
match = version_pattern.search(os.popen(f'{path} --version').read())
|
| 57 |
+
version = match.group(1)
|
| 58 |
+
assert match, f'Cannot get the version of NVCC compiler {path}'
|
| 59 |
+
assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
|
| 60 |
+
return path, version
|
| 61 |
+
raise RuntimeError('Cannot find any available NVCC compiler')
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@functools.lru_cache(maxsize=None)
|
| 65 |
+
def get_default_user_dir():
|
| 66 |
+
if 'DG_CACHE_DIR' in os.environ:
|
| 67 |
+
path = os.getenv('DG_CACHE_DIR')
|
| 68 |
+
os.makedirs(path, exist_ok=True)
|
| 69 |
+
return path
|
| 70 |
+
return os.path.expanduser('~') + '/.deep_gemm'
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@functools.lru_cache(maxsize=None)
|
| 74 |
+
def get_tmp_dir():
|
| 75 |
+
return f'{get_default_user_dir()}/tmp'
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@functools.lru_cache(maxsize=None)
|
| 79 |
+
def get_cache_dir():
|
| 80 |
+
return f'{get_default_user_dir()}/cache'
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def make_tmp_dir():
|
| 84 |
+
tmp_dir = get_tmp_dir()
|
| 85 |
+
os.makedirs(tmp_dir, exist_ok=True)
|
| 86 |
+
return tmp_dir
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def put(path, data, is_binary=False):
|
| 90 |
+
# Write and do POSIX atomic replace
|
| 91 |
+
tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}'
|
| 92 |
+
with open(tmp_file_path, 'wb' if is_binary else 'w') as f:
|
| 93 |
+
f.write(data)
|
| 94 |
+
os.replace(tmp_file_path, path)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def build(name: str, arg_defs: tuple, code: str) -> Runtime:
|
| 98 |
+
# Compiler flags
|
| 99 |
+
nvcc_flags = ['-std=c++17', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
|
| 100 |
+
'-gencode=arch=compute_90a,code=sm_90a',
|
| 101 |
+
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
|
| 102 |
+
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
| 103 |
+
'--diag-suppress=177,174,940']
|
| 104 |
+
cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi']
|
| 105 |
+
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
|
| 106 |
+
include_dirs = [get_jit_include_dir()]
|
| 107 |
+
|
| 108 |
+
# Build signature
|
| 109 |
+
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
|
| 110 |
+
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
|
| 111 |
+
name = f'kernel.{name}.{hash_to_hex(signature)}'
|
| 112 |
+
path = f'{get_cache_dir()}/{name}'
|
| 113 |
+
|
| 114 |
+
# Check runtime cache or file system hit
|
| 115 |
+
global runtime_cache
|
| 116 |
+
if runtime_cache[path] is not None:
|
| 117 |
+
if os.getenv('DG_JIT_DEBUG', None):
|
| 118 |
+
print(f'Using cached JIT runtime {name} during build')
|
| 119 |
+
return runtime_cache[path]
|
| 120 |
+
|
| 121 |
+
# Write the code
|
| 122 |
+
os.makedirs(path, exist_ok=True)
|
| 123 |
+
args_path = f'{path}/kernel.args'
|
| 124 |
+
src_path = f'{path}/kernel.cu'
|
| 125 |
+
put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs]))
|
| 126 |
+
put(src_path, code)
|
| 127 |
+
|
| 128 |
+
# Compile into a temporary SO file
|
| 129 |
+
so_path = f'{path}/kernel.so'
|
| 130 |
+
tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so'
|
| 131 |
+
|
| 132 |
+
# Compile
|
| 133 |
+
command = [get_nvcc_compiler()[0],
|
| 134 |
+
src_path, '-o', tmp_so_path,
|
| 135 |
+
*flags,
|
| 136 |
+
*[f'-I{d}' for d in include_dirs]]
|
| 137 |
+
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
|
| 138 |
+
print(f'Compiling JIT runtime {name} with command {command}')
|
| 139 |
+
return_code = subprocess.check_call(command)
|
| 140 |
+
assert return_code == 0, f'Failed to compile {src_path}'
|
| 141 |
+
|
| 142 |
+
# Interleave FFMA reuse
|
| 143 |
+
if enable_sass_opt:
|
| 144 |
+
interleave_ffma.process(tmp_so_path)
|
| 145 |
+
|
| 146 |
+
# Atomic replace SO file
|
| 147 |
+
os.replace(tmp_so_path, so_path)
|
| 148 |
+
|
| 149 |
+
# Put cache and return
|
| 150 |
+
runtime_cache[path] = Runtime(path)
|
| 151 |
+
return runtime_cache[path]
|
torch-ext/deep_gemm/jit/interleave_ffma.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import mmap
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import subprocess
|
| 6 |
+
from torch.utils.cpp_extension import CUDA_HOME
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def run_cuobjdump(file_path):
|
| 10 |
+
command = [f'{CUDA_HOME}/bin/cuobjdump', '-sass', file_path]
|
| 11 |
+
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
| 12 |
+
assert result.returncode == 0
|
| 13 |
+
return result.stdout
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def extract_ffma(sass):
|
| 17 |
+
lines = sass.splitlines()
|
| 18 |
+
collected = []
|
| 19 |
+
current = []
|
| 20 |
+
|
| 21 |
+
arch_name, func_name = 'N/A', 'N/A'
|
| 22 |
+
skip_next_line = False
|
| 23 |
+
for line in lines:
|
| 24 |
+
if 'code for' in line:
|
| 25 |
+
arch_name = line.lstrip().lstrip('code for ').rstrip()
|
| 26 |
+
elif 'Function :' in line:
|
| 27 |
+
func_name = line.lstrip().lstrip('Function :').rstrip()
|
| 28 |
+
elif 'FFMA' in line:
|
| 29 |
+
current.append(line)
|
| 30 |
+
skip_next_line = True
|
| 31 |
+
elif skip_next_line:
|
| 32 |
+
current.append(line)
|
| 33 |
+
skip_next_line = False
|
| 34 |
+
else:
|
| 35 |
+
if len(current) >= 16:
|
| 36 |
+
assert len(current) % 2 == 0
|
| 37 |
+
collected.append((f'{arch_name}::{func_name}', current))
|
| 38 |
+
current = []
|
| 39 |
+
|
| 40 |
+
if os.getenv('DG_PRINT_REG_REUSE', None):
|
| 41 |
+
print(f'Found {len(collected)} FFMA segments')
|
| 42 |
+
return collected
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def extract_hex_from_line(line):
|
| 46 |
+
match = re.search(r'/\*\s*(0x[0-9a-fA-F]+)\s*\*/', line)
|
| 47 |
+
assert match
|
| 48 |
+
return int(match.group(1), 16)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def validate(m, offset, le_bytes, num_lines):
|
| 52 |
+
assert len(le_bytes) == num_lines // 2
|
| 53 |
+
assert m[offset:offset + 16] == le_bytes[0]
|
| 54 |
+
for i in range(1, num_lines // 2):
|
| 55 |
+
if m[offset + i * 16:offset + i * 16 + 16] != le_bytes[i]:
|
| 56 |
+
return False
|
| 57 |
+
return True
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def parse_registers(line):
|
| 61 |
+
line = re.sub(r'/\*.*?\*/', '', line)
|
| 62 |
+
line = line.replace(';', '')
|
| 63 |
+
tokens = line.strip().split(',')
|
| 64 |
+
registers = []
|
| 65 |
+
for token in tokens:
|
| 66 |
+
token = token.strip()
|
| 67 |
+
words = token.split()
|
| 68 |
+
for word in words:
|
| 69 |
+
if word.startswith('R'):
|
| 70 |
+
reg = word.split('.')[0]
|
| 71 |
+
registers.append(reg)
|
| 72 |
+
return registers
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def modify_segment(m, name, ffma_lines):
|
| 76 |
+
num_lines = len(ffma_lines)
|
| 77 |
+
assert num_lines % 2 == 0
|
| 78 |
+
|
| 79 |
+
le_bytes, new_le_bytes = [], []
|
| 80 |
+
reused_list = []
|
| 81 |
+
dst_reg_set = set()
|
| 82 |
+
last_reused, last_dst_reg = False, ''
|
| 83 |
+
num_changed = 0
|
| 84 |
+
for i in range(num_lines // 2):
|
| 85 |
+
dst_reg = parse_registers(ffma_lines[i * 2])[-2]
|
| 86 |
+
low_line, high_line = ffma_lines[i * 2], ffma_lines[i * 2 + 1]
|
| 87 |
+
low_hex, high_hex = extract_hex_from_line(low_line), extract_hex_from_line(high_line)
|
| 88 |
+
le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
|
| 89 |
+
reused = (high_hex & 0x0800000000000000) != 0
|
| 90 |
+
if reused:
|
| 91 |
+
is_first_occurred = dst_reg not in dst_reg_set
|
| 92 |
+
if is_first_occurred or (last_reused and dst_reg == last_dst_reg):
|
| 93 |
+
# Modify the `reuse` and `yield` bits
|
| 94 |
+
assert high_hex & 0x0800200000000000, f'{hex(high_hex)}'
|
| 95 |
+
high_hex ^= 0x0800200000000000
|
| 96 |
+
reused = False
|
| 97 |
+
num_changed += 1
|
| 98 |
+
else:
|
| 99 |
+
reused_list.append(i)
|
| 100 |
+
dst_reg_set.add(dst_reg)
|
| 101 |
+
new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
|
| 102 |
+
last_reused, last_dst_reg = reused, dst_reg
|
| 103 |
+
if os.getenv('DG_PRINT_REG_REUSE', None):
|
| 104 |
+
print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}')
|
| 105 |
+
|
| 106 |
+
# Find the offset
|
| 107 |
+
offsets = []
|
| 108 |
+
offset = m.find(le_bytes[0])
|
| 109 |
+
while offset != -1:
|
| 110 |
+
offsets.append(offset)
|
| 111 |
+
offset = m.find(le_bytes[0], offset + 1)
|
| 112 |
+
offsets = list(filter(lambda x: validate(m, x, le_bytes, num_lines), offsets))
|
| 113 |
+
|
| 114 |
+
# Replace with `new_le_bytes`
|
| 115 |
+
for offset in offsets:
|
| 116 |
+
for i in range(num_lines // 2):
|
| 117 |
+
m[offset + i * 16:offset + i * 16 + 16] = new_le_bytes[i]
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def process(path):
|
| 121 |
+
if os.getenv('DG_PRINT_REG_REUSE', None):
|
| 122 |
+
print(f'Processing {path}')
|
| 123 |
+
output = run_cuobjdump(path)
|
| 124 |
+
segments = extract_ffma(output)
|
| 125 |
+
with open(path, 'r+b') as f:
|
| 126 |
+
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_WRITE)
|
| 127 |
+
for segment in segments:
|
| 128 |
+
modify_segment(mm, *segment)
|
| 129 |
+
mm.close()
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
if __name__ == '__main__':
|
| 133 |
+
parser = argparse.ArgumentParser(description='Interleave FFMA reg reuse')
|
| 134 |
+
parser.add_argument('--so', help='Path to the SO file')
|
| 135 |
+
args = parser.parse_args()
|
| 136 |
+
|
| 137 |
+
process(args.so)
|
torch-ext/deep_gemm/jit/runtime.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
from .template import map_ctype
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Runtime:
|
| 10 |
+
def __init__(self, path: str) -> None:
|
| 11 |
+
self.path = path
|
| 12 |
+
self.lib = None
|
| 13 |
+
self.args = None
|
| 14 |
+
|
| 15 |
+
assert self.is_path_valid(self.path)
|
| 16 |
+
|
| 17 |
+
@staticmethod
|
| 18 |
+
def is_path_valid(path: str) -> bool:
|
| 19 |
+
# Exists and is a directory
|
| 20 |
+
if not os.path.exists(path) or not os.path.isdir(path):
|
| 21 |
+
return False
|
| 22 |
+
|
| 23 |
+
# Contains all necessary files
|
| 24 |
+
files = ['kernel.cu', 'kernel.args', 'kernel.so']
|
| 25 |
+
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
| 26 |
+
|
| 27 |
+
def __call__(self, *args) -> int:
|
| 28 |
+
# Load SO file
|
| 29 |
+
if self.lib is None or self.args is None:
|
| 30 |
+
self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so'))
|
| 31 |
+
with open(os.path.join(self.path, 'kernel.args'), 'r') as f:
|
| 32 |
+
self.args = eval(f.read())
|
| 33 |
+
|
| 34 |
+
# Check args and launch
|
| 35 |
+
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
|
| 36 |
+
cargs = []
|
| 37 |
+
for arg, (name, dtype) in zip(args, self.args):
|
| 38 |
+
if isinstance(arg, torch.Tensor):
|
| 39 |
+
assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
|
| 40 |
+
else:
|
| 41 |
+
assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
|
| 42 |
+
cargs.append(map_ctype(arg))
|
| 43 |
+
|
| 44 |
+
return_code = ctypes.c_int(0)
|
| 45 |
+
self.lib.launch(*cargs, ctypes.byref(return_code))
|
| 46 |
+
return return_code.value
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class RuntimeCache:
|
| 50 |
+
def __init__(self) -> None:
|
| 51 |
+
self.cache = {}
|
| 52 |
+
|
| 53 |
+
def __getitem__(self, path: str) -> Optional[Runtime]:
|
| 54 |
+
# In Python runtime
|
| 55 |
+
if path in self.cache:
|
| 56 |
+
return self.cache[path]
|
| 57 |
+
|
| 58 |
+
# Already compiled
|
| 59 |
+
if os.path.exists(path) and Runtime.is_path_valid(path):
|
| 60 |
+
runtime = Runtime(path)
|
| 61 |
+
self.cache[path] = runtime
|
| 62 |
+
return runtime
|
| 63 |
+
return None
|
| 64 |
+
|
| 65 |
+
def __setitem__(self, path, runtime) -> None:
|
| 66 |
+
self.cache[path] = runtime
|
torch-ext/deep_gemm/jit/template.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import ctypes
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from typing import Any, Iterable, Dict, Tuple
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Name map for Python `eval`
|
| 10 |
+
typename_map: Dict[Any, str] = {
|
| 11 |
+
**{t: t.__name__ for t in (bool, int, float)},
|
| 12 |
+
torch.int: 'torch.int',
|
| 13 |
+
torch.float: 'torch.float',
|
| 14 |
+
torch.bfloat16: 'torch.bfloat16',
|
| 15 |
+
torch.float8_e4m3fn: 'torch.float8_e4m3fn',
|
| 16 |
+
torch.cuda.Stream: 'torch.cuda.Stream',
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
# `ctype` map for Python casting
|
| 20 |
+
ctype_map: Dict[Any, Any] = {
|
| 21 |
+
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
|
| 22 |
+
**{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Type map for both Python API and source code usages
|
| 27 |
+
genc_map = {
|
| 28 |
+
bool: ('bool', 'bool'),
|
| 29 |
+
int: ('int', 'int'),
|
| 30 |
+
float: ('float', 'float'),
|
| 31 |
+
torch.int: ('void*', 'int*'),
|
| 32 |
+
torch.float: ('void*', 'float*'),
|
| 33 |
+
torch.bfloat16: ('void*', '__nv_bfloat16*'),
|
| 34 |
+
torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
|
| 35 |
+
torch.cuda.Stream: ('void*', 'cudaStream_t'),
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def map_ctype(value: Any) -> Any:
|
| 40 |
+
ctype = ctype_map[value.dtype if isinstance(value, torch.Tensor) else type(value)]
|
| 41 |
+
if isinstance(value, torch.Tensor):
|
| 42 |
+
return ctype(value.data_ptr())
|
| 43 |
+
if isinstance(value, torch.cuda.Stream):
|
| 44 |
+
return ctype(value.cuda_stream)
|
| 45 |
+
return ctype(value)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def cpp_format(template: str, keys: Dict[str, Any]) -> str:
|
| 49 |
+
# We don't use `str.format` because it's not safe for C++ {} braces
|
| 50 |
+
new_template = copy.deepcopy(template)
|
| 51 |
+
for key, value in keys.items():
|
| 52 |
+
new_template = new_template.replace(f'{{{key}}}', f'{value}')
|
| 53 |
+
return new_template
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str:
|
| 57 |
+
# Common prefix
|
| 58 |
+
code = '// DeepGEMM auto-generated JIT CUDA source file\n\n'
|
| 59 |
+
|
| 60 |
+
# Includes
|
| 61 |
+
preload_sys_includes = ['<cuda.h>', '<cuda_fp8.h>', '<cuda_runtime.h>', '<iostream>']
|
| 62 |
+
preload_package_includes = ['"cutlass/cutlass.h"']
|
| 63 |
+
|
| 64 |
+
assert isinstance(includes, list) or isinstance(includes, tuple)
|
| 65 |
+
sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')])))
|
| 66 |
+
package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')])))
|
| 67 |
+
code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n'
|
| 68 |
+
code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n'
|
| 69 |
+
|
| 70 |
+
# Function signature
|
| 71 |
+
raw = '__raw_'
|
| 72 |
+
get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n
|
| 73 |
+
code += f'extern "C" void launch('
|
| 74 |
+
code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ])
|
| 75 |
+
code += ') {\n'
|
| 76 |
+
|
| 77 |
+
# Cast raw types
|
| 78 |
+
code += ' // Cast raw types (if needed)\n'
|
| 79 |
+
for arg_name, arg_type in arg_defs:
|
| 80 |
+
if genc_map[arg_type][0] != genc_map[arg_type][1]:
|
| 81 |
+
code += f' auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n'
|
| 82 |
+
|
| 83 |
+
# Function body
|
| 84 |
+
code += '\n'.join([((' ' if line else '') + line) for line in body.split('\n')])
|
| 85 |
+
|
| 86 |
+
# End the function
|
| 87 |
+
code += '}\n\n'
|
| 88 |
+
|
| 89 |
+
# Debug print
|
| 90 |
+
if os.getenv('DG_JIT_DEBUG', None):
|
| 91 |
+
print(f'Generated code:\n{code}')
|
| 92 |
+
|
| 93 |
+
return code
|
torch-ext/deep_gemm/jit_kernels/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .gemm import gemm_fp8_fp8_bf16_nt
|
| 2 |
+
from .m_grouped_gemm import (
|
| 3 |
+
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
|
| 4 |
+
m_grouped_gemm_fp8_fp8_bf16_nt_masked
|
| 5 |
+
)
|
| 6 |
+
from .utils import (
|
| 7 |
+
ceil_div, set_num_sms, get_num_sms,
|
| 8 |
+
get_col_major_tma_aligned_tensor,
|
| 9 |
+
get_m_alignment_for_contiguous_layout
|
| 10 |
+
)
|
torch-ext/deep_gemm/jit_kernels/gemm.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
from .tuner import jit_tuner
|
| 5 |
+
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
|
| 6 |
+
|
| 7 |
+
# C++ code templates
|
| 8 |
+
includes = ('"deep_gemm/fp8_gemm.cuh"', )
|
| 9 |
+
template = """
|
| 10 |
+
using namespace deep_gemm;
|
| 11 |
+
|
| 12 |
+
// Templated args from Python JIT call
|
| 13 |
+
constexpr auto N = {N}, K = {K};
|
| 14 |
+
constexpr auto BLOCK_M = {BLOCK_M};
|
| 15 |
+
constexpr auto BLOCK_N = {BLOCK_N};
|
| 16 |
+
constexpr auto kNumStages = {NUM_STAGES};
|
| 17 |
+
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
|
| 18 |
+
|
| 19 |
+
// Make a templated GEMM
|
| 20 |
+
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;
|
| 21 |
+
|
| 22 |
+
// Launch kernel
|
| 23 |
+
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
|
| 24 |
+
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
|
| 25 |
+
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
|
| 26 |
+
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
|
| 27 |
+
GemmType::run(out, rhs_scales, nullptr,
|
| 28 |
+
m,
|
| 29 |
+
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
| 30 |
+
stream, num_sms, smem_size);
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms: int) -> bool:
|
| 35 |
+
if num_tma_multicast == 1:
|
| 36 |
+
return True
|
| 37 |
+
return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int:
|
| 41 |
+
smem_d = block_m * block_n * 2
|
| 42 |
+
smem_a_per_stage = block_m * block_k
|
| 43 |
+
smem_scales_a_per_stage = block_m * 4
|
| 44 |
+
smem_b_per_stage = block_n * block_k
|
| 45 |
+
smem_scales_b = ceil_div(k, block_k) * 4
|
| 46 |
+
smem_barrier = num_stages * 8 * 2
|
| 47 |
+
|
| 48 |
+
smem_size = 0
|
| 49 |
+
smem_size += smem_d
|
| 50 |
+
smem_size += num_stages * smem_a_per_stage
|
| 51 |
+
smem_size += num_stages * smem_scales_a_per_stage
|
| 52 |
+
smem_size += num_stages * smem_b_per_stage
|
| 53 |
+
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
|
| 54 |
+
smem_size += smem_barrier
|
| 55 |
+
return smem_size
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
| 59 |
+
is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int]:
|
| 60 |
+
if not is_grouped_contiguous:
|
| 61 |
+
# TODO: for some cases, smaller M block is better, add them into tuning space
|
| 62 |
+
block_ms = (64 if m <= 64 else 128, )
|
| 63 |
+
else:
|
| 64 |
+
block_ms = (get_m_alignment_for_contiguous_layout(), )
|
| 65 |
+
block_ns = tuple(range(16, 129, 8))
|
| 66 |
+
|
| 67 |
+
fix_wave_saturate = lambda x: num_sms if x == 0 else x
|
| 68 |
+
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
|
| 69 |
+
get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
|
| 70 |
+
|
| 71 |
+
# Decide block sizes by waves
|
| 72 |
+
best_block_m, best_block_n = None, None
|
| 73 |
+
for block_m in block_ms:
|
| 74 |
+
for block_n in block_ns:
|
| 75 |
+
success = False
|
| 76 |
+
num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
|
| 77 |
+
if best_block_m is None or best_block_n is None:
|
| 78 |
+
success = True
|
| 79 |
+
elif num_waves < best_num_waves:
|
| 80 |
+
success = True
|
| 81 |
+
elif num_waves == best_num_waves:
|
| 82 |
+
# Check last wave utilization
|
| 83 |
+
util = get_last_wave_util(block_m, block_n)
|
| 84 |
+
best_util = get_last_wave_util(best_block_m, best_block_n)
|
| 85 |
+
success = util > best_util or (util == best_util and (block_m > best_block_m or (block_m == best_block_m and block_n < best_block_n)))
|
| 86 |
+
best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
|
| 87 |
+
assert best_block_m is not None and best_block_n is not None
|
| 88 |
+
|
| 89 |
+
# Always pick the longest one
|
| 90 |
+
# NOTES: for double B scales, the best number of stages may be reduced
|
| 91 |
+
best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
|
| 92 |
+
for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
|
| 93 |
+
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
|
| 94 |
+
if best_smem_size <= sm90_capacity:
|
| 95 |
+
best_num_stages = num_stages
|
| 96 |
+
break
|
| 97 |
+
assert best_num_stages is not None
|
| 98 |
+
|
| 99 |
+
# Decide the number of TMA multicast
|
| 100 |
+
best_num_tma_multicast = 1
|
| 101 |
+
if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
|
| 102 |
+
best_num_tma_multicast = 2
|
| 103 |
+
|
| 104 |
+
return best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
| 108 |
+
rhs: Tuple[torch.Tensor, torch.Tensor],
|
| 109 |
+
out: torch.Tensor) -> None:
|
| 110 |
+
"""
|
| 111 |
+
Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
| 112 |
+
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
| 113 |
+
RHS and RHS scaling factors are required to be transposed.
|
| 114 |
+
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
|
| 115 |
+
this function will do a transposing with a set of slow PyTorch operations.
|
| 116 |
+
|
| 117 |
+
Arguments:
|
| 118 |
+
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
|
| 119 |
+
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
|
| 120 |
+
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`.
|
| 121 |
+
the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
|
| 122 |
+
out: the BF16 output tensor of shape `[m, n]`, representing the result.
|
| 123 |
+
"""
|
| 124 |
+
lhs, lhs_scales = lhs
|
| 125 |
+
rhs, rhs_scales = rhs
|
| 126 |
+
m, k = lhs.shape
|
| 127 |
+
n, k_ = rhs.shape
|
| 128 |
+
m_, n_ = out.shape
|
| 129 |
+
|
| 130 |
+
assert n % 64 == 0 and k % 128 == 0
|
| 131 |
+
|
| 132 |
+
# Type and shape checks
|
| 133 |
+
assert m == m_ and n == n_ and k == k_
|
| 134 |
+
assert n > 0 and k > 0
|
| 135 |
+
assert lhs_scales.shape == (m, (k + 127) // 128)
|
| 136 |
+
assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
|
| 137 |
+
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
| 138 |
+
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
| 139 |
+
assert out.dtype == torch.bfloat16
|
| 140 |
+
assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
|
| 141 |
+
|
| 142 |
+
# LHS scales must be transposed for TMA load, but not for RHS scales
|
| 143 |
+
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
|
| 144 |
+
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
| 145 |
+
assert rhs_scales.is_contiguous()
|
| 146 |
+
|
| 147 |
+
# Do nothing if `m` is zero
|
| 148 |
+
if m == 0:
|
| 149 |
+
return
|
| 150 |
+
|
| 151 |
+
# Auto-tuning with compilation
|
| 152 |
+
global includes, template
|
| 153 |
+
num_sms = get_num_sms()
|
| 154 |
+
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
|
| 155 |
+
args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
|
| 156 |
+
runtime = jit_tuner.compile_and_tune(
|
| 157 |
+
name='gemm_fp8_fp8_bf16_nt',
|
| 158 |
+
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
| 159 |
+
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast},
|
| 160 |
+
space=(),
|
| 161 |
+
includes=includes,
|
| 162 |
+
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
| 163 |
+
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
| 164 |
+
('out', torch.bfloat16), ('m', int),
|
| 165 |
+
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
| 166 |
+
template=template,
|
| 167 |
+
args=args
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
# Run the kernel
|
| 171 |
+
runtime(*args)
|
torch-ext/deep_gemm/jit_kernels/m_grouped_gemm.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
from .gemm import get_best_configs
|
| 5 |
+
from .tuner import jit_tuner
|
| 6 |
+
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
|
| 7 |
+
|
| 8 |
+
# C++ code templates
|
| 9 |
+
includes = ('"deep_gemm/fp8_gemm.cuh"', )
|
| 10 |
+
template = """
|
| 11 |
+
using namespace deep_gemm;
|
| 12 |
+
|
| 13 |
+
// Templated args from Python JIT call
|
| 14 |
+
constexpr auto N = {N}, K = {K};
|
| 15 |
+
constexpr auto BLOCK_M = {BLOCK_M};
|
| 16 |
+
constexpr auto BLOCK_N = {BLOCK_N};
|
| 17 |
+
constexpr auto kNumStages = {NUM_STAGES};
|
| 18 |
+
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
|
| 19 |
+
|
| 20 |
+
// Make a templated grouped GEMM
|
| 21 |
+
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
|
| 22 |
+
|
| 23 |
+
// Launch kernel
|
| 24 |
+
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
|
| 25 |
+
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
|
| 26 |
+
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
|
| 27 |
+
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
|
| 28 |
+
GemmType::run(out, rhs_scales, grouped_layout,
|
| 29 |
+
m,
|
| 30 |
+
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
| 31 |
+
stream, num_sms, smem_size);
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
|
| 36 |
+
rhs: Tuple[torch.Tensor, torch.Tensor],
|
| 37 |
+
out: torch.Tensor, m_indices: torch.Tensor) -> None:
|
| 38 |
+
"""
|
| 39 |
+
Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
| 40 |
+
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
| 41 |
+
RHS and RHS scaling factors are required to be transposed.
|
| 42 |
+
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
|
| 43 |
+
this function will do a transposing with a set of slow PyTorch operations.
|
| 44 |
+
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
|
| 45 |
+
`get_m_alignment_for_contiguous_layout()` (128).
|
| 46 |
+
|
| 47 |
+
Arguments:
|
| 48 |
+
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
|
| 49 |
+
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`.
|
| 50 |
+
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
|
| 51 |
+
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
|
| 52 |
+
out: the BF16 output tensor of shape `[m_sum, n]`, representing the result.
|
| 53 |
+
m_indices: a tensor of shape `[m_sum]` with type `torch.int`.
|
| 54 |
+
`m_indices[i]` records the group which the j-th row of the LHS belong to,
|
| 55 |
+
which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
|
| 56 |
+
Values of `m_indices` in every-m-alignment-block must also be the same.
|
| 57 |
+
"""
|
| 58 |
+
lhs, lhs_scales = lhs
|
| 59 |
+
rhs, rhs_scales = rhs
|
| 60 |
+
m, k = lhs.shape
|
| 61 |
+
num_groups, n, k_ = rhs.shape
|
| 62 |
+
m_, n_ = out.shape
|
| 63 |
+
m__ = m_indices.numel()
|
| 64 |
+
|
| 65 |
+
# Type and shape checks
|
| 66 |
+
assert m == m_ == m__ and k == k_ and n == n_
|
| 67 |
+
assert lhs_scales.shape == (m, (k + 127) // 128)
|
| 68 |
+
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
|
| 69 |
+
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
| 70 |
+
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
| 71 |
+
assert out.dtype == torch.bfloat16
|
| 72 |
+
assert m_indices.dtype == torch.int32
|
| 73 |
+
assert lhs.is_contiguous() and rhs.is_contiguous()
|
| 74 |
+
assert out.is_contiguous() and m_indices.is_contiguous()
|
| 75 |
+
|
| 76 |
+
# LHS scales must be transposed for TMA load, but not for RHS scales
|
| 77 |
+
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
| 78 |
+
assert rhs_scales.is_contiguous()
|
| 79 |
+
|
| 80 |
+
# Do nothing if `m` is zero
|
| 81 |
+
if m == 0:
|
| 82 |
+
return
|
| 83 |
+
|
| 84 |
+
# Auto-tuning with compilation
|
| 85 |
+
global includes, template
|
| 86 |
+
num_sms = get_num_sms()
|
| 87 |
+
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
|
| 88 |
+
is_grouped_contiguous=True)
|
| 89 |
+
args = (lhs, lhs_scales, rhs, rhs_scales, out,
|
| 90 |
+
m_indices, m, num_groups,
|
| 91 |
+
torch.cuda.current_stream(), num_sms, smem_size)
|
| 92 |
+
runtime = jit_tuner.compile_and_tune(
|
| 93 |
+
name='m_grouped_gemm_fp8_fp8_bf16_nt',
|
| 94 |
+
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
|
| 95 |
+
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'},
|
| 96 |
+
space=(),
|
| 97 |
+
includes=includes,
|
| 98 |
+
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
| 99 |
+
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
| 100 |
+
('out', torch.bfloat16),
|
| 101 |
+
('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
|
| 102 |
+
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
| 103 |
+
template=template,
|
| 104 |
+
args=args
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
# Run the kernel
|
| 108 |
+
runtime(*args)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
|
| 112 |
+
rhs: Tuple[torch.Tensor, torch.Tensor],
|
| 113 |
+
out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
|
| 114 |
+
"""
|
| 115 |
+
Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
| 116 |
+
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
| 117 |
+
RHS and RHS scaling factors are required to be transposed.
|
| 118 |
+
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
|
| 119 |
+
this function will do a transposing with a set of slow PyTorch operations.
|
| 120 |
+
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
|
| 121 |
+
should be separately transposed.
|
| 122 |
+
|
| 123 |
+
Arguments:
|
| 124 |
+
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
|
| 125 |
+
the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`.
|
| 126 |
+
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
|
| 127 |
+
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
|
| 128 |
+
out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result.
|
| 129 |
+
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
|
| 130 |
+
in the i-th group.
|
| 131 |
+
expected_m: a value hint (which is a value on CPU) for the M expectation of each batch,
|
| 132 |
+
correctly setting this value may lead to better performance.
|
| 133 |
+
"""
|
| 134 |
+
lhs, lhs_scales = lhs
|
| 135 |
+
rhs, rhs_scales = rhs
|
| 136 |
+
num_groups, m, k = lhs.shape
|
| 137 |
+
num_groups_, n, k_ = rhs.shape
|
| 138 |
+
num_groups__, m_, n_ = out.shape
|
| 139 |
+
num_groups___ = masked_m.numel()
|
| 140 |
+
|
| 141 |
+
# Type and shape checks
|
| 142 |
+
assert num_groups == num_groups_ == num_groups__ == num_groups___
|
| 143 |
+
assert m == m_ and n == n_ and k == k_
|
| 144 |
+
assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
|
| 145 |
+
assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
|
| 146 |
+
assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
|
| 147 |
+
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
| 148 |
+
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
| 149 |
+
assert out.dtype == torch.bfloat16
|
| 150 |
+
assert masked_m.dtype == torch.int32
|
| 151 |
+
assert lhs.is_contiguous() and rhs.is_contiguous()
|
| 152 |
+
assert out.is_contiguous() and masked_m.is_contiguous()
|
| 153 |
+
|
| 154 |
+
# LHS scales must be transposed for TMA load, but not for RHS scales
|
| 155 |
+
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
| 156 |
+
assert rhs_scales.is_contiguous()
|
| 157 |
+
|
| 158 |
+
# Auto-tuning with compilation
|
| 159 |
+
global includes, template
|
| 160 |
+
num_sms = get_num_sms()
|
| 161 |
+
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
|
| 162 |
+
|
| 163 |
+
# Extra checks for TMA store
|
| 164 |
+
if num_groups > 1 and m > block_m:
|
| 165 |
+
assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'
|
| 166 |
+
|
| 167 |
+
args = (lhs, lhs_scales, rhs, rhs_scales, out,
|
| 168 |
+
masked_m, m,
|
| 169 |
+
torch.cuda.current_stream(), num_sms, smem_size)
|
| 170 |
+
runtime = jit_tuner.compile_and_tune(
|
| 171 |
+
name='m_grouped_gemm_fp8_fp8_bf16_nt',
|
| 172 |
+
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
|
| 173 |
+
'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'},
|
| 174 |
+
space=(),
|
| 175 |
+
includes=includes,
|
| 176 |
+
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
| 177 |
+
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
| 178 |
+
('out', torch.bfloat16),
|
| 179 |
+
('grouped_layout', torch.int32), ('m', int),
|
| 180 |
+
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
| 181 |
+
template=template,
|
| 182 |
+
args=args
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
# Run the kernel
|
| 186 |
+
runtime(*args)
|
torch-ext/deep_gemm/jit_kernels/tuner.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
from typing import Any, Dict
|
| 5 |
+
|
| 6 |
+
from ..jit import build, cpp_format, generate, Runtime
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class JITTuner:
|
| 10 |
+
def __init__(self) -> None:
|
| 11 |
+
self.tuned = {}
|
| 12 |
+
|
| 13 |
+
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
|
| 14 |
+
includes: tuple, arg_defs: tuple, template: str, args: tuple) -> Runtime:
|
| 15 |
+
# NOTES: we always assume the space and template will not change
|
| 16 |
+
# We also assume the GPU device will not be changed
|
| 17 |
+
# NOTES: the function must have no accumulated side effects
|
| 18 |
+
keys = {k: keys[k] for k in sorted(keys.keys())}
|
| 19 |
+
signature = (name, f'{keys}')
|
| 20 |
+
if signature in self.tuned:
|
| 21 |
+
if os.getenv('DG_JIT_DEBUG', None):
|
| 22 |
+
print(f'Using cached JIT kernel {name} with keys {keys}')
|
| 23 |
+
return self.tuned[signature]
|
| 24 |
+
|
| 25 |
+
if os.getenv('DG_JIT_DEBUG', None):
|
| 26 |
+
print(f'Auto-tuning JIT kernel {name} with keys {keys}')
|
| 27 |
+
|
| 28 |
+
assert signature not in self.tuned
|
| 29 |
+
assert args is not None
|
| 30 |
+
space = (dict(), ) if len(space) == 0 else space
|
| 31 |
+
|
| 32 |
+
kernels = []
|
| 33 |
+
for tuned_keys in space:
|
| 34 |
+
assert isinstance(tuned_keys, dict)
|
| 35 |
+
full_keys = copy.deepcopy(keys)
|
| 36 |
+
full_keys.update(tuned_keys)
|
| 37 |
+
code = generate(includes, arg_defs, cpp_format(template, full_keys))
|
| 38 |
+
|
| 39 |
+
# Illegal build must raise errors
|
| 40 |
+
kernels.append((build(name, arg_defs, code), tuned_keys))
|
| 41 |
+
|
| 42 |
+
best_runtime, best_time, best_keys = None, None, None
|
| 43 |
+
for runtime, tuned_keys in kernels:
|
| 44 |
+
if len(space) > 1:
|
| 45 |
+
# Check kernel validity
|
| 46 |
+
return_code = runtime(*args)
|
| 47 |
+
if return_code != 0:
|
| 48 |
+
# Pass illegal kernels, e.g. insufficient shared memory capacity
|
| 49 |
+
if os.getenv('DG_JIT_DEBUG', None):
|
| 50 |
+
print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
|
| 51 |
+
continue
|
| 52 |
+
|
| 53 |
+
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
|
| 54 |
+
start_event = torch.cuda.Event(enable_timing=True)
|
| 55 |
+
end_event = torch.cuda.Event(enable_timing=True)
|
| 56 |
+
torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
|
| 57 |
+
torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
| 58 |
+
start_event.record()
|
| 59 |
+
for i in range(20):
|
| 60 |
+
assert runtime(*args) == 0
|
| 61 |
+
end_event.record()
|
| 62 |
+
end_event.synchronize()
|
| 63 |
+
elapsed_time = start_event.elapsed_time(end_event)
|
| 64 |
+
else:
|
| 65 |
+
elapsed_time = 0
|
| 66 |
+
|
| 67 |
+
# Compare if better
|
| 68 |
+
if best_time is None or elapsed_time < best_time:
|
| 69 |
+
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
|
| 70 |
+
if os.getenv('DG_JIT_DEBUG', None):
|
| 71 |
+
print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
|
| 72 |
+
assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'
|
| 73 |
+
|
| 74 |
+
# Cache the best runtime and return
|
| 75 |
+
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None):
|
| 76 |
+
print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
|
| 77 |
+
self.tuned[signature] = best_runtime
|
| 78 |
+
return best_runtime
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
jit_tuner = JITTuner()
|
torch-ext/deep_gemm/jit_kernels/utils.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
_num_sms = None
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def set_num_sms(num_sms: int) -> None:
|
| 7 |
+
"""
|
| 8 |
+
Set the maximum SM count for all GEMM kernels to use.
|
| 9 |
+
|
| 10 |
+
Arguments:
|
| 11 |
+
num_sms: the desired maximum SM count for all GEMM kernels to use.
|
| 12 |
+
"""
|
| 13 |
+
global _num_sms
|
| 14 |
+
assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
| 15 |
+
_num_sms = num_sms
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_num_sms() -> int:
|
| 19 |
+
"""
|
| 20 |
+
Get the current maximum limit of SM count for all GEMM kernels to use.
|
| 21 |
+
If the count is never specified, the function will return the number of device SMs.
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
Current maximum limit of SM count for all GEMM kernels to use.
|
| 25 |
+
"""
|
| 26 |
+
global _num_sms
|
| 27 |
+
if _num_sms is None:
|
| 28 |
+
_num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
| 29 |
+
return _num_sms
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def ceil_div(x: int, y: int) -> int:
|
| 33 |
+
"""
|
| 34 |
+
Perform ceiling division of two integers.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
x: the dividend.
|
| 38 |
+
y: the divisor.
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
The result of the ceiling division.
|
| 42 |
+
"""
|
| 43 |
+
return (x + y - 1) // y
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_m_alignment_for_contiguous_layout():
|
| 47 |
+
"""
|
| 48 |
+
When we do a grouped GEMM in contiguous format, LHS are grouped into several batches along the M axis.
|
| 49 |
+
Since we deal with exactly one sub-matrix of RHS for each GEMM block, batch sizes above should align well
|
| 50 |
+
with GEMM block shape.
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
Group-level alignment requirement for grouped contiguous layout, which is always 128.
|
| 54 |
+
"""
|
| 55 |
+
return 128
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def get_tma_aligned_size(x: int, element_size: int) -> int:
|
| 59 |
+
"""
|
| 60 |
+
Global memory address of TMA must be 16-byte aligned.
|
| 61 |
+
Since we use column-major layout for the LHS scaling tensor,
|
| 62 |
+
the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.
|
| 63 |
+
|
| 64 |
+
Arguments:
|
| 65 |
+
x: original M-axis shape of the LHS scaling tensor.
|
| 66 |
+
element_size: element size of the LHS scaling tensor.
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
M-axis shape of the LHS scaling tensor after padding.
|
| 70 |
+
"""
|
| 71 |
+
tma_alignment_bytes = 16
|
| 72 |
+
assert tma_alignment_bytes % element_size == 0
|
| 73 |
+
alignment = tma_alignment_bytes // element_size
|
| 74 |
+
return ceil_div(x, alignment) * alignment
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
| 78 |
+
"""
|
| 79 |
+
Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary.
|
| 80 |
+
If the input tensor is already column-major layout and 16-byte aligned along the M axis
|
| 81 |
+
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
|
| 82 |
+
|
| 83 |
+
Arguments:
|
| 84 |
+
x: usually the LHS scaling tensor in GEMM.
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
The LHS scaling tensor of TMA-aligned transposed format.
|
| 88 |
+
"""
|
| 89 |
+
# NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
|
| 90 |
+
assert x.dim() in (2, 3)
|
| 91 |
+
remove_dim = False
|
| 92 |
+
if x.dim() == 2:
|
| 93 |
+
x, remove_dim = x.unsqueeze(0), True
|
| 94 |
+
|
| 95 |
+
b, m, n = x.shape
|
| 96 |
+
aligned_m = get_tma_aligned_size(m, x.element_size())
|
| 97 |
+
|
| 98 |
+
# The last kernel gives a column-major TMA aligned layout
|
| 99 |
+
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
|
| 100 |
+
return x.squeeze(0) if remove_dim else x
|
| 101 |
+
|
| 102 |
+
# Normal layout requires transposing
|
| 103 |
+
aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
|
| 104 |
+
aligned_x[:, :m, :] = x
|
| 105 |
+
aligned_x = aligned_x[:, :m, :]
|
| 106 |
+
return aligned_x.squeeze(0) if remove_dim else aligned_x
|
torch-ext/deep_gemm/utils.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
| 8 |
+
high_precision: bool = False):
|
| 9 |
+
# Flush L2 cache with 256 MB data
|
| 10 |
+
torch.cuda.synchronize()
|
| 11 |
+
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
|
| 12 |
+
cache.zero_()
|
| 13 |
+
|
| 14 |
+
# Warmup
|
| 15 |
+
for _ in range(num_warmups):
|
| 16 |
+
fn()
|
| 17 |
+
|
| 18 |
+
# Add a large kernel to eliminate the CPU launch overhead
|
| 19 |
+
if high_precision:
|
| 20 |
+
x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
| 21 |
+
y = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
| 22 |
+
x @ y
|
| 23 |
+
|
| 24 |
+
# Testing
|
| 25 |
+
start_event = torch.cuda.Event(enable_timing=True)
|
| 26 |
+
end_event = torch.cuda.Event(enable_timing=True)
|
| 27 |
+
start_event.record()
|
| 28 |
+
for i in range(num_tests):
|
| 29 |
+
fn()
|
| 30 |
+
end_event.record()
|
| 31 |
+
torch.cuda.synchronize()
|
| 32 |
+
|
| 33 |
+
return start_event.elapsed_time(end_event) / num_tests
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class empty_suppress:
|
| 37 |
+
def __enter__(self):
|
| 38 |
+
return self
|
| 39 |
+
|
| 40 |
+
def __exit__(self, *_):
|
| 41 |
+
pass
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class suppress_stdout_stderr:
|
| 45 |
+
def __enter__(self):
|
| 46 |
+
self.outnull_file = open(os.devnull, 'w')
|
| 47 |
+
self.errnull_file = open(os.devnull, 'w')
|
| 48 |
+
|
| 49 |
+
self.old_stdout_fileno_undup = sys.stdout.fileno()
|
| 50 |
+
self.old_stderr_fileno_undup = sys.stderr.fileno()
|
| 51 |
+
|
| 52 |
+
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
|
| 53 |
+
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
|
| 54 |
+
|
| 55 |
+
self.old_stdout = sys.stdout
|
| 56 |
+
self.old_stderr = sys.stderr
|
| 57 |
+
|
| 58 |
+
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
|
| 59 |
+
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
|
| 60 |
+
|
| 61 |
+
sys.stdout = self.outnull_file
|
| 62 |
+
sys.stderr = self.errnull_file
|
| 63 |
+
return self
|
| 64 |
+
|
| 65 |
+
def __exit__(self, *_):
|
| 66 |
+
sys.stdout = self.old_stdout
|
| 67 |
+
sys.stderr = self.old_stderr
|
| 68 |
+
|
| 69 |
+
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
|
| 70 |
+
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
|
| 71 |
+
|
| 72 |
+
os.close(self.old_stdout_fileno)
|
| 73 |
+
os.close(self.old_stderr_fileno)
|
| 74 |
+
|
| 75 |
+
self.outnull_file.close()
|
| 76 |
+
self.errnull_file.close()
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
|
| 80 |
+
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = False):
|
| 81 |
+
# Conflict with Nsight Systems
|
| 82 |
+
using_nsys = os.environ.get('DG_NSYS_PROFILING', False)
|
| 83 |
+
|
| 84 |
+
# For some auto-tuning kernels with prints
|
| 85 |
+
fn()
|
| 86 |
+
|
| 87 |
+
# Profile
|
| 88 |
+
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
|
| 89 |
+
with suppress():
|
| 90 |
+
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None
|
| 91 |
+
profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
|
| 92 |
+
with profiler:
|
| 93 |
+
for i in range(2):
|
| 94 |
+
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
|
| 95 |
+
if barrier_comm_profiling:
|
| 96 |
+
lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
| 97 |
+
rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
| 98 |
+
lhs @ rhs
|
| 99 |
+
dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
|
| 100 |
+
for _ in range(num_tests):
|
| 101 |
+
if flush_l2:
|
| 102 |
+
torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
|
| 103 |
+
fn()
|
| 104 |
+
|
| 105 |
+
if not using_nsys:
|
| 106 |
+
profiler.step()
|
| 107 |
+
|
| 108 |
+
# Return 1 if using Nsight Systems
|
| 109 |
+
if using_nsys:
|
| 110 |
+
return 1
|
| 111 |
+
|
| 112 |
+
# Parse the profiling table
|
| 113 |
+
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
|
| 114 |
+
is_tupled = isinstance(kernel_names, tuple)
|
| 115 |
+
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
|
| 116 |
+
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
|
| 117 |
+
assert all([isinstance(name, str) for name in kernel_names])
|
| 118 |
+
for name in kernel_names:
|
| 119 |
+
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
|
| 120 |
+
|
| 121 |
+
# Save chrome traces
|
| 122 |
+
if trace_path is not None:
|
| 123 |
+
profiler.export_chrome_trace(trace_path)
|
| 124 |
+
|
| 125 |
+
# Return average kernel times
|
| 126 |
+
units = {'ms': 1e3, 'us': 1e6}
|
| 127 |
+
kernel_times = []
|
| 128 |
+
for name in kernel_names:
|
| 129 |
+
for line in prof_lines:
|
| 130 |
+
if name in line:
|
| 131 |
+
time_str = line.split()[-2]
|
| 132 |
+
for unit, scale in units.items():
|
| 133 |
+
if unit in time_str:
|
| 134 |
+
kernel_times.append(float(time_str.replace(unit, '')) / scale)
|
| 135 |
+
break
|
| 136 |
+
break
|
| 137 |
+
return tuple(kernel_times) if is_tupled else kernel_times[0]
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def calc_diff(x, y):
|
| 141 |
+
x, y = x.double(), y.double()
|
| 142 |
+
denominator = (x * x + y * y).sum()
|
| 143 |
+
sim = 2 * (x * y).sum() / denominator
|
| 144 |
+
return 1 - sim
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def count_bytes(tensors):
|
| 148 |
+
total = 0
|
| 149 |
+
for t in tensors:
|
| 150 |
+
if isinstance(t, tuple):
|
| 151 |
+
total += count_bytes(t)
|
| 152 |
+
else:
|
| 153 |
+
total += t.numel() * t.element_size()
|
| 154 |
+
return total
|