Instructions to use kernels-community/megablocks with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/megablocks with Kernels:
# !pip install kernels from kernels import get_kernel kernel = get_kernel("kernels-community/megablocks") - Notebooks
- Google Colab
- Kaggle
Uploaded using `kernel-builder`.
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- build/torch211-cxx11-cu126-x86_64-linux/__init__.py +5 -2
- build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/__init__.py +0 -0
- build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py +0 -0
- build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py +574 -0
- build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/adapter.py +53 -0
- build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/configs.py +5 -0
- build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/gmm.py +567 -0
- build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/__init__.py +0 -0
- build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py +0 -0
- build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py +46 -0
- build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py +100 -0
- build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py +752 -0
- build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/logger.py +47 -0
- build/torch211-cxx11-cu126-x86_64-linux/{_megablocks_cuda_ae601bb.abi3.so → _megablocks_cuda_f8f8b50.abi3.so} +2 -2
- build/torch211-cxx11-cu126-x86_64-linux/_ops.py +3 -3
- build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py +9 -9
- build/torch211-cxx11-cu126-x86_64-linux/metadata.json +3 -2
- build/torch211-cxx11-cu128-x86_64-linux/__init__.py +5 -2
- build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/__init__.py +0 -0
- build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py +0 -0
- build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py +574 -0
- build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/adapter.py +53 -0
- build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/configs.py +5 -0
- build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/gmm.py +567 -0
- build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/__init__.py +0 -0
- build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py +0 -0
- build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py +46 -0
- build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py +100 -0
- build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py +752 -0
- build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/logger.py +47 -0
- build/torch211-cxx11-cu128-x86_64-linux/{_megablocks_cuda_ae601bb.abi3.so → _megablocks_cuda_f8f8b50.abi3.so} +2 -2
- build/torch211-cxx11-cu128-x86_64-linux/_ops.py +3 -3
- build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py +9 -9
- build/torch211-cxx11-cu128-x86_64-linux/metadata.json +2 -1
- build/torch211-cxx11-cu130-x86_64-linux/__init__.py +5 -2
- build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/__init__.py +0 -0
- build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py +0 -0
- build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py +574 -0
- build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/adapter.py +53 -0
- build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/configs.py +5 -0
- build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/gmm.py +567 -0
- build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/__init__.py +0 -0
- build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py +0 -0
- build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py +46 -0
- build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py +100 -0
- build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py +752 -0
- build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/logger.py +47 -0
- build/torch211-cxx11-cu130-x86_64-linux/{_megablocks_cuda_ae601bb.abi3.so → _megablocks_cuda_f8f8b50.abi3.so} +2 -2
- build/torch211-cxx11-cu130-x86_64-linux/_ops.py +3 -3
- build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py +9 -9
build/torch211-cxx11-cu126-x86_64-linux/__init__.py
CHANGED
|
@@ -3,7 +3,9 @@
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
|
| 6 |
-
from .
|
|
|
|
|
|
|
| 7 |
|
| 8 |
from .grouped_gemm import backend as gg_backend
|
| 9 |
from .grouped_gemm import ops as gg_ops
|
|
@@ -136,7 +138,8 @@ def sort(
|
|
| 136 |
Returns:
|
| 137 |
The sorted values tensor
|
| 138 |
"""
|
| 139 |
-
|
|
|
|
| 140 |
|
| 141 |
|
| 142 |
# Convenience functions for common use cases
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
|
| 6 |
+
# Stable alias: bare `ops` is shadowed by `from . import layers` below.
|
| 7 |
+
from ._ops import ops as _compiled_ops
|
| 8 |
+
from . import ops
|
| 9 |
|
| 10 |
from .grouped_gemm import backend as gg_backend
|
| 11 |
from .grouped_gemm import ops as gg_ops
|
|
|
|
| 138 |
Returns:
|
| 139 |
The sorted values tensor
|
| 140 |
"""
|
| 141 |
+
_compiled_ops.sort(x, end_bit, x_out, iota_out)
|
| 142 |
+
return x_out
|
| 143 |
|
| 144 |
|
| 145 |
# Convenience functions for common use cases
|
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/__init__.py
ADDED
|
File without changes
|
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py
ADDED
|
File without changes
|
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py
ADDED
|
@@ -0,0 +1,574 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
# Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Imports.
|
| 6 |
+
# ------------------------------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
# Python standard library
|
| 9 |
+
import functools
|
| 10 |
+
|
| 11 |
+
# Triton
|
| 12 |
+
import triton
|
| 13 |
+
import triton.language as tl
|
| 14 |
+
|
| 15 |
+
# AITER
|
| 16 |
+
from ..configs import CONFIGS as _CONFIGS
|
| 17 |
+
from ..utils._triton import arch_info
|
| 18 |
+
from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
|
| 19 |
+
|
| 20 |
+
# Kernel config.
|
| 21 |
+
# ------------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@functools.lru_cache()
|
| 25 |
+
def get_config(
|
| 26 |
+
gmm_type: str, M: int, K: int, N: int, G: int, accumulate: bool = False
|
| 27 |
+
) -> dict[str, int]:
|
| 28 |
+
assert gmm_type in {
|
| 29 |
+
"gmm",
|
| 30 |
+
"ptgmm",
|
| 31 |
+
"nptgmm",
|
| 32 |
+
}, f"'{gmm_type}' is an invalid GMM variant."
|
| 33 |
+
dev = arch_info.get_arch()
|
| 34 |
+
assert (
|
| 35 |
+
dev in _CONFIGS
|
| 36 |
+
), f"No GMM configuration tuned for arch '{dev}'. Supported: {sorted(_CONFIGS)}."
|
| 37 |
+
arch_configs = _CONFIGS[dev]
|
| 38 |
+
assert (
|
| 39 |
+
"default" in arch_configs[gmm_type]
|
| 40 |
+
), "Default configuration is absent."
|
| 41 |
+
key = "accumulate" if accumulate else "default"
|
| 42 |
+
return arch_configs[gmm_type][key]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Common code shared by GMM and TGMM kernels.
|
| 46 |
+
# ------------------------------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# XCD remapping followed by 1D PID to 2D grid mapping.
|
| 50 |
+
@triton.jit
|
| 51 |
+
def _remap_xcd_tile_grid(
|
| 52 |
+
tile_in_mm,
|
| 53 |
+
num_row_tiles,
|
| 54 |
+
num_col_tiles,
|
| 55 |
+
GROUP_SIZE: tl.constexpr = 1,
|
| 56 |
+
NUM_XCDS: tl.constexpr = 8,
|
| 57 |
+
):
|
| 58 |
+
return pid_grid(
|
| 59 |
+
remap_xcd(tile_in_mm, num_row_tiles * num_col_tiles, NUM_XCDS=NUM_XCDS),
|
| 60 |
+
num_row_tiles,
|
| 61 |
+
num_col_tiles,
|
| 62 |
+
GROUP_SIZE_M=GROUP_SIZE,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# GMM kernel.
|
| 67 |
+
# ------------------------------------------------------------------------------
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@triton.heuristics(
|
| 71 |
+
{
|
| 72 |
+
"K_DIVISIBLE_BY_BLOCK_SIZE_K": lambda META: META["K"] % META["BLOCK_SIZE_K"]
|
| 73 |
+
== 0,
|
| 74 |
+
}
|
| 75 |
+
)
|
| 76 |
+
@triton.jit
|
| 77 |
+
def gmm_kernel(
|
| 78 |
+
# Tensor pointers:
|
| 79 |
+
lhs_ptr,
|
| 80 |
+
rhs_ptr,
|
| 81 |
+
group_sizes_ptr,
|
| 82 |
+
out_ptr,
|
| 83 |
+
bias_ptr,
|
| 84 |
+
# Tensor shapes:
|
| 85 |
+
M: int,
|
| 86 |
+
K: int,
|
| 87 |
+
N: int,
|
| 88 |
+
G: int,
|
| 89 |
+
# Meta-parameters:
|
| 90 |
+
TRANS_RHS: tl.constexpr,
|
| 91 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 92 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 93 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 94 |
+
K_DIVISIBLE_BY_BLOCK_SIZE_K: tl.constexpr,
|
| 95 |
+
GROUP_SIZE: tl.constexpr,
|
| 96 |
+
GRID_DIM: tl.constexpr,
|
| 97 |
+
USE_BIAS: tl.constexpr,
|
| 98 |
+
):
|
| 99 |
+
tl.assume(M > 0)
|
| 100 |
+
tl.assume(K > 0)
|
| 101 |
+
tl.assume(N > 0)
|
| 102 |
+
tl.assume(G > 0)
|
| 103 |
+
|
| 104 |
+
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
|
| 105 |
+
tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0")
|
| 106 |
+
|
| 107 |
+
# Current tile. Each program computes multiple tiles of each group.
|
| 108 |
+
tile = tl.program_id(0)
|
| 109 |
+
tl.device_assert(tile >= 0, "tile < 0 (at initialization)")
|
| 110 |
+
|
| 111 |
+
# Tile limit of last MM problem (inclusive).
|
| 112 |
+
last_mm_tile = 0
|
| 113 |
+
|
| 114 |
+
# Last input row of lhs and output row of out. Each group reads some rows of
|
| 115 |
+
# lhs and writes some rows to out.
|
| 116 |
+
last_m = 0
|
| 117 |
+
|
| 118 |
+
# Loop through all (m, K, N) MM problems:
|
| 119 |
+
# (m, K) x (K, N) = (m, N)
|
| 120 |
+
# sum(m) = M
|
| 121 |
+
for g in range(G):
|
| 122 |
+
# Get m dimension of current MM problem.
|
| 123 |
+
m = tl.load(group_sizes_ptr + g)
|
| 124 |
+
# m can be zero if group is empty
|
| 125 |
+
tl.device_assert(m >= 0, "m < 0")
|
| 126 |
+
|
| 127 |
+
num_m_tiles = tl.cdiv(m, BLOCK_SIZE_M)
|
| 128 |
+
# num_m_tiles can be zero if group is empty
|
| 129 |
+
tl.device_assert(num_m_tiles >= 0, "num_m_tiles < 0")
|
| 130 |
+
|
| 131 |
+
num_tiles = num_m_tiles * num_n_tiles
|
| 132 |
+
# num_tiles can be zero if group is empty
|
| 133 |
+
tl.device_assert(num_tiles >= 0, "num_tiles < 0")
|
| 134 |
+
|
| 135 |
+
# Loop through tiles of current MM problem.
|
| 136 |
+
while tile >= last_mm_tile and tile < last_mm_tile + num_tiles:
|
| 137 |
+
# Figure out tile coordinates in current MM problem.
|
| 138 |
+
tile_in_mm = tile - last_mm_tile
|
| 139 |
+
tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0")
|
| 140 |
+
|
| 141 |
+
tile_m, tile_n = _remap_xcd_tile_grid(
|
| 142 |
+
tile_in_mm, num_m_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# Do regular MM:
|
| 146 |
+
|
| 147 |
+
tl.device_assert(tile_m * BLOCK_SIZE_M >= 0, "tile_m * BLOCK_SIZE_M < 0")
|
| 148 |
+
tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0")
|
| 149 |
+
|
| 150 |
+
offs_lhs_m = (
|
| 151 |
+
tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 152 |
+
) % m
|
| 153 |
+
offs_rhs_n = (
|
| 154 |
+
tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 155 |
+
) % N
|
| 156 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
|
| 157 |
+
|
| 158 |
+
lhs_ptrs = lhs_ptr + (last_m + offs_lhs_m[:, None]) * K + offs_k[None, :]
|
| 159 |
+
|
| 160 |
+
if TRANS_RHS:
|
| 161 |
+
rhs_ptrs = (
|
| 162 |
+
rhs_ptr
|
| 163 |
+
+ g.to(tl.int64) * K * N
|
| 164 |
+
+ offs_k[:, None]
|
| 165 |
+
+ offs_rhs_n[None, :] * K
|
| 166 |
+
)
|
| 167 |
+
else:
|
| 168 |
+
rhs_ptrs = (
|
| 169 |
+
rhs_ptr
|
| 170 |
+
+ g.to(tl.int64) * K * N
|
| 171 |
+
+ offs_k[:, None] * N
|
| 172 |
+
+ offs_rhs_n[None, :]
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 176 |
+
|
| 177 |
+
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
| 178 |
+
if K_DIVISIBLE_BY_BLOCK_SIZE_K:
|
| 179 |
+
lhs = tl.load(lhs_ptrs)
|
| 180 |
+
rhs = tl.load(rhs_ptrs)
|
| 181 |
+
else:
|
| 182 |
+
k_mask_limit = K - k * BLOCK_SIZE_K
|
| 183 |
+
lhs = tl.load(
|
| 184 |
+
lhs_ptrs, mask=offs_k[None, :] < k_mask_limit, other=0
|
| 185 |
+
)
|
| 186 |
+
rhs = tl.load(
|
| 187 |
+
rhs_ptrs, mask=offs_k[:, None] < k_mask_limit, other=0
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
acc = tl.dot(lhs, rhs, acc=acc)
|
| 191 |
+
|
| 192 |
+
lhs_ptrs += BLOCK_SIZE_K
|
| 193 |
+
|
| 194 |
+
if TRANS_RHS:
|
| 195 |
+
rhs_ptrs += BLOCK_SIZE_K
|
| 196 |
+
else:
|
| 197 |
+
rhs_ptrs += BLOCK_SIZE_K * N
|
| 198 |
+
|
| 199 |
+
# Add bias if enabled
|
| 200 |
+
if USE_BIAS:
|
| 201 |
+
offs_bias_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(
|
| 202 |
+
0, BLOCK_SIZE_N
|
| 203 |
+
)
|
| 204 |
+
bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bias_n
|
| 205 |
+
bias = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0)
|
| 206 |
+
# Convert bias to float32 to match accumulator precision
|
| 207 |
+
bias = bias.to(tl.float32)
|
| 208 |
+
# Broadcast bias across M dimension and add in float32
|
| 209 |
+
acc += bias[None, :]
|
| 210 |
+
|
| 211 |
+
# Convert to output dtype after all computations
|
| 212 |
+
acc = acc.to(out_ptr.type.element_ty)
|
| 213 |
+
|
| 214 |
+
offs_out_m = tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 215 |
+
offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 216 |
+
|
| 217 |
+
out_ptrs = (
|
| 218 |
+
out_ptr + (last_m + offs_out_m[:, None]) * N + offs_out_n[None, :]
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
tl.store(
|
| 222 |
+
out_ptrs,
|
| 223 |
+
acc,
|
| 224 |
+
mask=(offs_out_m[:, None] < m) & (offs_out_n[None, :] < N),
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
# Go to the next tile by advancing number of programs.
|
| 228 |
+
tile += GRID_DIM
|
| 229 |
+
tl.device_assert(tile > 0, "tile <= 0 (at update)")
|
| 230 |
+
|
| 231 |
+
# Get ready to go to the next MM problem.
|
| 232 |
+
|
| 233 |
+
last_mm_tile += num_tiles
|
| 234 |
+
# last_mm_tile can be zero if group 0 is skipped
|
| 235 |
+
tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)")
|
| 236 |
+
|
| 237 |
+
last_m += m
|
| 238 |
+
# last_m can be zero if group 0 is skipped
|
| 239 |
+
tl.device_assert(last_m >= 0, "last_m < 0 (at update)")
|
| 240 |
+
tl.device_assert(last_m <= M, "last_m > M (at update)")
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
# Persistent TGMM kernel.
|
| 244 |
+
# ------------------------------------------------------------------------------
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
@triton.jit
|
| 248 |
+
def tgmm_persistent_kernel(
|
| 249 |
+
# Tensor pointers:
|
| 250 |
+
lhs_ptr,
|
| 251 |
+
rhs_ptr,
|
| 252 |
+
group_sizes_ptr,
|
| 253 |
+
out_ptr,
|
| 254 |
+
bias_grad_ptr,
|
| 255 |
+
# Tensor shapes:
|
| 256 |
+
M: int,
|
| 257 |
+
K: int,
|
| 258 |
+
N: int,
|
| 259 |
+
G: int,
|
| 260 |
+
# Meta-parameters:
|
| 261 |
+
TRANS_LHS: tl.constexpr,
|
| 262 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 263 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 264 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 265 |
+
GROUP_SIZE: tl.constexpr,
|
| 266 |
+
GRID_DIM: tl.constexpr,
|
| 267 |
+
COMPUTE_BIAS_GRAD: tl.constexpr,
|
| 268 |
+
ACCUMULATE: tl.constexpr,
|
| 269 |
+
):
|
| 270 |
+
tl.assume(M > 0)
|
| 271 |
+
tl.assume(K > 0)
|
| 272 |
+
tl.assume(N > 0)
|
| 273 |
+
tl.assume(G > 0)
|
| 274 |
+
|
| 275 |
+
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
| 276 |
+
tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0")
|
| 277 |
+
|
| 278 |
+
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
|
| 279 |
+
tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0")
|
| 280 |
+
|
| 281 |
+
num_tiles = num_k_tiles * num_n_tiles
|
| 282 |
+
tl.device_assert(num_tiles > 0, "num_tiles <= 0")
|
| 283 |
+
|
| 284 |
+
# Current tile. Each program computes multiple tiles of each group.
|
| 285 |
+
tile = tl.program_id(0)
|
| 286 |
+
tl.device_assert(tile >= 0, "tile < 0 (at initialization)")
|
| 287 |
+
|
| 288 |
+
# Tile limit of last MM problem (inclusive).
|
| 289 |
+
last_mm_tile = 0
|
| 290 |
+
|
| 291 |
+
# Last input column of lhs and input row of rhs. Each group reads some
|
| 292 |
+
# columns of lhs and some rows of rhs.
|
| 293 |
+
last_m = 0
|
| 294 |
+
|
| 295 |
+
# Loop through all (K, m, N) MM problems:
|
| 296 |
+
# (K, m) x (m, N) = (K, N)
|
| 297 |
+
# sum(m) = M
|
| 298 |
+
for g in range(G):
|
| 299 |
+
# Get m dimension of current MM problem.
|
| 300 |
+
m = tl.load(group_sizes_ptr + g)
|
| 301 |
+
# m can be zero if group is empty
|
| 302 |
+
tl.device_assert(m >= 0, "m < 0")
|
| 303 |
+
|
| 304 |
+
# Loop through tiles of current MM problem.
|
| 305 |
+
while tile >= last_mm_tile and tile < last_mm_tile + num_tiles:
|
| 306 |
+
# Figure out tile coordinates in current MM problem.
|
| 307 |
+
tile_in_mm = tile - last_mm_tile
|
| 308 |
+
tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0")
|
| 309 |
+
|
| 310 |
+
tile_k, tile_n = _remap_xcd_tile_grid(
|
| 311 |
+
tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# Do regular MM:
|
| 315 |
+
|
| 316 |
+
tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0")
|
| 317 |
+
tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0")
|
| 318 |
+
|
| 319 |
+
offs_lhs_k = (
|
| 320 |
+
tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
| 321 |
+
) % K
|
| 322 |
+
offs_rhs_n = (
|
| 323 |
+
tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 324 |
+
) % N
|
| 325 |
+
offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
| 326 |
+
|
| 327 |
+
if TRANS_LHS:
|
| 328 |
+
lhs_ptrs = (
|
| 329 |
+
lhs_ptr + offs_lhs_k[:, None] + (last_m + offs_m[None, :]) * K
|
| 330 |
+
)
|
| 331 |
+
else:
|
| 332 |
+
lhs_ptrs = (
|
| 333 |
+
lhs_ptr + offs_lhs_k[:, None] * M + (last_m + offs_m[None, :])
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
rhs_ptrs = rhs_ptr + (last_m + offs_m[:, None]) * N + offs_rhs_n[None, :]
|
| 337 |
+
|
| 338 |
+
loop_m = tl.cdiv(m, BLOCK_SIZE_M)
|
| 339 |
+
m_divisible_by_block_m = m % BLOCK_SIZE_M == 0
|
| 340 |
+
if not m_divisible_by_block_m:
|
| 341 |
+
loop_m -= 1
|
| 342 |
+
|
| 343 |
+
acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32)
|
| 344 |
+
|
| 345 |
+
# Initialize bias accumulator
|
| 346 |
+
bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)
|
| 347 |
+
|
| 348 |
+
for _ in range(0, loop_m):
|
| 349 |
+
lhs = tl.load(lhs_ptrs)
|
| 350 |
+
rhs = tl.load(rhs_ptrs)
|
| 351 |
+
|
| 352 |
+
acc = tl.dot(lhs, rhs, acc=acc)
|
| 353 |
+
|
| 354 |
+
# Accumulate for bias gradient: sum lhs across M dimension
|
| 355 |
+
if COMPUTE_BIAS_GRAD and tile_n == 0:
|
| 356 |
+
bias_acc += tl.sum(
|
| 357 |
+
lhs, axis=1
|
| 358 |
+
) # Sum across M dimension [K, M] -> [K]
|
| 359 |
+
|
| 360 |
+
if TRANS_LHS:
|
| 361 |
+
lhs_ptrs += BLOCK_SIZE_M * K
|
| 362 |
+
else:
|
| 363 |
+
lhs_ptrs += BLOCK_SIZE_M
|
| 364 |
+
|
| 365 |
+
rhs_ptrs += BLOCK_SIZE_M * N
|
| 366 |
+
|
| 367 |
+
if not m_divisible_by_block_m:
|
| 368 |
+
offs_lhs_k = (
|
| 369 |
+
tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
| 370 |
+
) % K
|
| 371 |
+
offs_rhs_n = (
|
| 372 |
+
tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 373 |
+
) % N
|
| 374 |
+
offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 375 |
+
lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0)
|
| 376 |
+
rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0)
|
| 377 |
+
acc = tl.dot(lhs, rhs, acc=acc)
|
| 378 |
+
|
| 379 |
+
# Accumulate last chunk for bias gradient
|
| 380 |
+
if COMPUTE_BIAS_GRAD and tile_n == 0:
|
| 381 |
+
bias_acc += tl.sum(lhs, axis=1)
|
| 382 |
+
|
| 383 |
+
acc = acc.to(out_ptr.type.element_ty)
|
| 384 |
+
|
| 385 |
+
offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
| 386 |
+
offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 387 |
+
|
| 388 |
+
out_ptrs = (
|
| 389 |
+
out_ptr
|
| 390 |
+
+ g.to(tl.int64) * K * N
|
| 391 |
+
+ offs_out_k[:, None] * N
|
| 392 |
+
+ offs_out_n[None, :]
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N)
|
| 396 |
+
if ACCUMULATE:
|
| 397 |
+
# Load existing values and add to them (like beta=1 in BLAS)
|
| 398 |
+
old_vals = tl.load(out_ptrs, mask=mask, other=0.0)
|
| 399 |
+
tl.store(out_ptrs, acc + old_vals, mask=mask)
|
| 400 |
+
else:
|
| 401 |
+
# Overwrite output (like beta=0 in BLAS)
|
| 402 |
+
tl.store(out_ptrs, acc, mask=mask)
|
| 403 |
+
|
| 404 |
+
# Store bias gradient (only for first N tile, sum across all M)
|
| 405 |
+
if COMPUTE_BIAS_GRAD and tile_n == 0:
|
| 406 |
+
# Keep as float32 for atomic_add (bf16 not supported for atomics)
|
| 407 |
+
bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k
|
| 408 |
+
# Use atomic add since multiple K-tiles may write to same expert's bias
|
| 409 |
+
tl.atomic_add(
|
| 410 |
+
bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed"
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
# Go to the next tile by advancing number of programs.
|
| 414 |
+
tile += GRID_DIM
|
| 415 |
+
tl.device_assert(tile > 0, "tile <= 0 (at update)")
|
| 416 |
+
|
| 417 |
+
# Get ready to go to the next MM problem.
|
| 418 |
+
|
| 419 |
+
last_mm_tile += num_tiles
|
| 420 |
+
# last_mm_tile can be zero if group 0 is skipped
|
| 421 |
+
tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)")
|
| 422 |
+
|
| 423 |
+
last_m += m
|
| 424 |
+
# last_m can be zero if group 0 is skipped
|
| 425 |
+
tl.device_assert(last_m >= 0, "last_m < 0 (at update)")
|
| 426 |
+
tl.device_assert(last_m <= M, "last_m > M (at update)")
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
# Regular non-persistent TGMM kernel.
|
| 430 |
+
# ------------------------------------------------------------------------------
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
@triton.heuristics({"BLOCK_SIZE_G": lambda META: triton.next_power_of_2(META["G"])})
|
| 434 |
+
@triton.jit
|
| 435 |
+
def tgmm_non_persistent_kernel(
|
| 436 |
+
# Tensor pointers:
|
| 437 |
+
lhs_ptr,
|
| 438 |
+
rhs_ptr,
|
| 439 |
+
group_sizes_ptr,
|
| 440 |
+
out_ptr,
|
| 441 |
+
bias_grad_ptr,
|
| 442 |
+
# Tensor shapes:
|
| 443 |
+
M: int,
|
| 444 |
+
K: int,
|
| 445 |
+
N: int,
|
| 446 |
+
G: int,
|
| 447 |
+
# Meta-parameters:
|
| 448 |
+
TRANS_LHS: tl.constexpr,
|
| 449 |
+
BLOCK_SIZE_G: tl.constexpr,
|
| 450 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 451 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 452 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 453 |
+
GROUP_SIZE: tl.constexpr,
|
| 454 |
+
COMPUTE_BIAS_GRAD: tl.constexpr,
|
| 455 |
+
ACCUMULATE: tl.constexpr,
|
| 456 |
+
):
|
| 457 |
+
tl.assume(M > 0)
|
| 458 |
+
tl.assume(K > 0)
|
| 459 |
+
tl.assume(N > 0)
|
| 460 |
+
tl.assume(G > 0)
|
| 461 |
+
|
| 462 |
+
# Get group ID from grid.
|
| 463 |
+
g = tl.program_id(0)
|
| 464 |
+
tl.device_assert(g >= 0, "g < 0")
|
| 465 |
+
tl.device_assert(g < G, "g >= G")
|
| 466 |
+
|
| 467 |
+
# Get m dimension of current MM group.
|
| 468 |
+
m = tl.load(group_sizes_ptr + g)
|
| 469 |
+
# m can be zero if group is empty.
|
| 470 |
+
tl.device_assert(m >= 0, "m < 0")
|
| 471 |
+
|
| 472 |
+
# Skip empty groups.
|
| 473 |
+
if m == 0:
|
| 474 |
+
return
|
| 475 |
+
|
| 476 |
+
# Compute sum(group_sizes) until current group g.
|
| 477 |
+
# It's the starting column of lhs and starting row of rhs.
|
| 478 |
+
offs_g = tl.arange(0, BLOCK_SIZE_G)
|
| 479 |
+
group_sizes = tl.load(group_sizes_ptr + offs_g, mask=offs_g < g, other=0)
|
| 480 |
+
start_m = tl.sum(group_sizes)
|
| 481 |
+
|
| 482 |
+
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
| 483 |
+
tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0")
|
| 484 |
+
|
| 485 |
+
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
|
| 486 |
+
tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0")
|
| 487 |
+
|
| 488 |
+
# Get MM tile from grid.
|
| 489 |
+
tile_in_mm = tl.program_id(1)
|
| 490 |
+
tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0")
|
| 491 |
+
|
| 492 |
+
tile_k, tile_n = _remap_xcd_tile_grid(
|
| 493 |
+
tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0")
|
| 497 |
+
tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0")
|
| 498 |
+
|
| 499 |
+
offs_lhs_k = (tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K
|
| 500 |
+
offs_rhs_n = (tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
| 501 |
+
offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
| 502 |
+
|
| 503 |
+
if TRANS_LHS:
|
| 504 |
+
lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] + (start_m + offs_m[None, :]) * K
|
| 505 |
+
else:
|
| 506 |
+
lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] * M + (start_m + offs_m[None, :])
|
| 507 |
+
|
| 508 |
+
rhs_ptrs = rhs_ptr + (start_m + offs_m[:, None]) * N + offs_rhs_n[None, :]
|
| 509 |
+
|
| 510 |
+
loop_m = tl.cdiv(m, BLOCK_SIZE_M)
|
| 511 |
+
m_divisible_by_block_m = m % BLOCK_SIZE_M == 0
|
| 512 |
+
if not m_divisible_by_block_m:
|
| 513 |
+
loop_m -= 1
|
| 514 |
+
|
| 515 |
+
acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32)
|
| 516 |
+
# Initialize bias accumulator
|
| 517 |
+
bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)
|
| 518 |
+
|
| 519 |
+
for _ in range(0, loop_m):
|
| 520 |
+
lhs = tl.load(lhs_ptrs)
|
| 521 |
+
rhs = tl.load(rhs_ptrs)
|
| 522 |
+
|
| 523 |
+
acc = tl.dot(lhs, rhs, acc=acc)
|
| 524 |
+
|
| 525 |
+
# Accumulate for bias gradient: sum lhs across M dimension
|
| 526 |
+
if COMPUTE_BIAS_GRAD and tile_n == 0:
|
| 527 |
+
bias_acc += tl.sum(lhs, axis=1) # [K, M] -> [K]
|
| 528 |
+
|
| 529 |
+
if TRANS_LHS:
|
| 530 |
+
lhs_ptrs += BLOCK_SIZE_M * K
|
| 531 |
+
else:
|
| 532 |
+
lhs_ptrs += BLOCK_SIZE_M
|
| 533 |
+
|
| 534 |
+
rhs_ptrs += BLOCK_SIZE_M * N
|
| 535 |
+
|
| 536 |
+
if not m_divisible_by_block_m:
|
| 537 |
+
offs_lhs_k = (
|
| 538 |
+
tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
| 539 |
+
) % K
|
| 540 |
+
offs_rhs_n = (
|
| 541 |
+
tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 542 |
+
) % N
|
| 543 |
+
offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 544 |
+
lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0)
|
| 545 |
+
rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0)
|
| 546 |
+
acc = tl.dot(lhs, rhs, acc=acc)
|
| 547 |
+
# Accumulate last chunk for bias gradient
|
| 548 |
+
if COMPUTE_BIAS_GRAD and tile_n == 0:
|
| 549 |
+
bias_acc += tl.sum(lhs, axis=1)
|
| 550 |
+
|
| 551 |
+
acc = acc.to(out_ptr.type.element_ty)
|
| 552 |
+
|
| 553 |
+
offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
| 554 |
+
offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 555 |
+
|
| 556 |
+
out_ptrs = (
|
| 557 |
+
out_ptr + g.to(tl.int64) * K * N + offs_out_k[:, None] * N + offs_out_n[None, :]
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N)
|
| 561 |
+
if ACCUMULATE:
|
| 562 |
+
# Load existing values and add to them (like beta=1 in BLAS)
|
| 563 |
+
old_vals = tl.load(out_ptrs, mask=mask, other=0.0)
|
| 564 |
+
tl.store(out_ptrs, acc + old_vals, mask=mask)
|
| 565 |
+
else:
|
| 566 |
+
# Overwrite output (like beta=0 in BLAS)
|
| 567 |
+
tl.store(out_ptrs, acc, mask=mask)
|
| 568 |
+
|
| 569 |
+
# Store bias gradient (only for first N tile, sum across all M)
|
| 570 |
+
if COMPUTE_BIAS_GRAD and tile_n == 0:
|
| 571 |
+
# Keep as float32 for atomic_add (bf16/fp16 not supported for atomics)
|
| 572 |
+
bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k
|
| 573 |
+
# Use atomic add since multiple K-tiles may write to same expert's bias
|
| 574 |
+
tl.atomic_add(bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed")
|
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/adapter.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Adapt AITER's Triton grouped GEMM to MegaBlocks' ``gmm`` calling convention.
|
| 3 |
+
|
| 4 |
+
MegaBlocks (following tgale96/grouped_gemm) uses a single ``gmm`` entry point
|
| 5 |
+
with ``trans_a`` / ``trans_b`` flags:
|
| 6 |
+
|
| 7 |
+
* ``trans_a=False, trans_b=False``: a(M,K) @ b(G,K,N) -> c(M,N)
|
| 8 |
+
* ``trans_a=False, trans_b=True`` : a(M,K) @ b(G,N,K)^T -> c(M,N) (dgrad)
|
| 9 |
+
* ``trans_a=True`` : a(M,K)^T @ b(M,N) per group -> c(G,K,N) (wgrad)
|
| 10 |
+
|
| 11 |
+
AITER exposes these as two kernels: ``gmm`` ((M,K)@(G,K,N)->(M,N), transposition
|
| 12 |
+
of the 3D operand inferred from strides) and ``ptgmm`` ((K,M)@(M,N)->(G,K,N),
|
| 13 |
+
transposition of the 2D operand inferred from strides).
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
from .gmm import gmm as _aiter_gmm
|
| 19 |
+
from .gmm import ptgmm as _aiter_ptgmm
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def gmm(a, b, c, batch_sizes, trans_a=False, trans_b=False):
|
| 23 |
+
# AITER requires group sizes to be int32 and to live on the compute device.
|
| 24 |
+
group_sizes = batch_sizes.to(device=a.device, dtype=torch.int32)
|
| 25 |
+
|
| 26 |
+
# AITER asserts exact strides: gmm wants lhs/rhs row-major (a transposed
|
| 27 |
+
# 3D operand must be exactly column-major), tgmm wants rhs row-major and
|
| 28 |
+
# lhs row/column-major. Make operands contiguous first so the transposed
|
| 29 |
+
# views have the precise strides the kernels expect. `.contiguous()` is a
|
| 30 |
+
# no-op when the tensor is already contiguous.
|
| 31 |
+
if trans_a:
|
| 32 |
+
# Weight gradient: a(M,K), b(M,N) -> c(G,K,N).
|
| 33 |
+
# Pass a transposed so AITER sees lhs(K,M) column-major (TRANS_LHS).
|
| 34 |
+
_aiter_ptgmm(
|
| 35 |
+
a.contiguous().transpose(0, 1),
|
| 36 |
+
b.contiguous(),
|
| 37 |
+
group_sizes,
|
| 38 |
+
preferred_element_type=c.dtype,
|
| 39 |
+
existing_out=c,
|
| 40 |
+
)
|
| 41 |
+
else:
|
| 42 |
+
# trans_b contracts b's last dim: pass a column-major (G,K,N) view.
|
| 43 |
+
rhs = b.contiguous()
|
| 44 |
+
if trans_b:
|
| 45 |
+
rhs = rhs.transpose(1, 2)
|
| 46 |
+
_aiter_gmm(
|
| 47 |
+
a.contiguous(),
|
| 48 |
+
rhs,
|
| 49 |
+
group_sizes,
|
| 50 |
+
preferred_element_type=c.dtype,
|
| 51 |
+
existing_out=c,
|
| 52 |
+
)
|
| 53 |
+
return c
|
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/configs.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
# Tuned GMM configs vendored from ROCm/aiter (aiter/ops/triton/configs/).
|
| 3 |
+
# Inlined as a Python module so packaging always includes them.
|
| 4 |
+
|
| 5 |
+
CONFIGS = {'gfx1250': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}, 'gfx942': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}, 'gfx950': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}}
|
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/gmm.py
ADDED
|
@@ -0,0 +1,567 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Imports.
|
| 6 |
+
# ------------------------------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
# PyTorch
|
| 9 |
+
import torch
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
|
| 12 |
+
# Triton
|
| 13 |
+
import triton
|
| 14 |
+
|
| 15 |
+
# AITER: GMM utility functions
|
| 16 |
+
from .utils.gmm_common import (
|
| 17 |
+
DTYPE,
|
| 18 |
+
is_power_of_2,
|
| 19 |
+
check_input_device_dtype,
|
| 20 |
+
check_bias_shape_stride,
|
| 21 |
+
get_gmm_shape,
|
| 22 |
+
get_gmm_output,
|
| 23 |
+
get_gmm_transposition,
|
| 24 |
+
get_tgmm_shape,
|
| 25 |
+
get_tgmm_output,
|
| 26 |
+
get_tgmm_bias_grad,
|
| 27 |
+
get_tgmm_transposition,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# AITER: GMM Triton kernels
|
| 31 |
+
from ._triton_kernels.gmm import (
|
| 32 |
+
gmm_kernel,
|
| 33 |
+
tgmm_persistent_kernel,
|
| 34 |
+
tgmm_non_persistent_kernel,
|
| 35 |
+
get_config,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# GMM PyTorch wrapper.
|
| 39 |
+
# ------------------------------------------------------------------------------
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _gmm_grid(
|
| 43 |
+
N: int,
|
| 44 |
+
block_size_m: int,
|
| 45 |
+
block_size_n: int,
|
| 46 |
+
group_sizes: Tensor,
|
| 47 |
+
grid_dim: int,
|
| 48 |
+
) -> tuple[int]:
|
| 49 |
+
assert N > 0, f"N must be positive, it's {N}."
|
| 50 |
+
assert is_power_of_2(
|
| 51 |
+
block_size_m
|
| 52 |
+
), f"M-dimension tile size must be a power of 2 (it's {block_size_m})."
|
| 53 |
+
assert is_power_of_2(
|
| 54 |
+
block_size_n
|
| 55 |
+
), f"N-dimension tile size must be a power of 2 (it's {block_size_n})."
|
| 56 |
+
assert torch.all(group_sizes >= 0).item(), "All group_sizes must be non-negative."
|
| 57 |
+
assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})."
|
| 58 |
+
num_m_tiles = (group_sizes + block_size_m - 1) // block_size_m
|
| 59 |
+
assert torch.all(num_m_tiles >= 0).item(), "All num_m_tiles must be non-negative."
|
| 60 |
+
num_n_tiles = triton.cdiv(N, block_size_n)
|
| 61 |
+
assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}."
|
| 62 |
+
num_tiles = torch.sum(num_m_tiles * num_n_tiles).item()
|
| 63 |
+
assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}."
|
| 64 |
+
num_programs = int(min(grid_dim, num_tiles))
|
| 65 |
+
assert num_programs > 0, f"num_programs must be positive, it's {num_programs}."
|
| 66 |
+
return (num_programs,)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def gmm(
|
| 70 |
+
lhs: Tensor,
|
| 71 |
+
rhs: Tensor,
|
| 72 |
+
group_sizes: Tensor,
|
| 73 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 74 |
+
existing_out: Tensor | None = None,
|
| 75 |
+
config: dict[str, int] | None = None,
|
| 76 |
+
bias: Tensor | None = None,
|
| 77 |
+
) -> Tensor:
|
| 78 |
+
"""
|
| 79 |
+
Perform Group Matrix Multiplication (GMM): out = lhs @ rhs + bias
|
| 80 |
+
|
| 81 |
+
lhs rows are divided into G groups. Each group of lhs rows is matrix multiplied with a plane of
|
| 82 |
+
rhs 3D tensor and then stored in a slice of out. In PyTorch parlance, it can be implemented as
|
| 83 |
+
follows for a given group g:
|
| 84 |
+
out[group_start:group_end, :] = lhs[group_start:group_end, :] @ rhs[g] + bias[g]
|
| 85 |
+
|
| 86 |
+
The size of each group, and their respective start and end positions are specified by
|
| 87 |
+
group_sizes tensor. For instance, suppose that group_sizes = [3, 2, 4, 1]. In this particular
|
| 88 |
+
case we have 4 groups. The 1st group starts at 0 and ends at 2, the second group starts at 3 and
|
| 89 |
+
ends at 4, the third group starts at 5 and ends at 8, and the fourth and final group consists of
|
| 90 |
+
just the 10th (last) row of lhs.
|
| 91 |
+
|
| 92 |
+
Parameters
|
| 93 |
+
----------
|
| 94 |
+
lhs : torch.Tensor
|
| 95 |
+
Left-hand side 2D input tensor. Shape: (M, K).
|
| 96 |
+
lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type.
|
| 97 |
+
lhs must be on the same device of rhs and group_sizes.
|
| 98 |
+
rhs : torch.Tensor
|
| 99 |
+
Right-hand side 3D input tensor. Shape: (G, K, N).
|
| 100 |
+
rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type.
|
| 101 |
+
rhs must be on the same device of lhs and group_sizes.
|
| 102 |
+
group_sizes : torch.Tensor
|
| 103 |
+
1D input tensor describing group sizes. Shape: (G,).
|
| 104 |
+
group_sizes data type must be torch.int32 and all its elements must be non-negative.
|
| 105 |
+
group_sizes must be on the same device of lhs and rhs.
|
| 106 |
+
preferred_element_type : torch.dtype, optional
|
| 107 |
+
Desired data type for output tensor. Default is torch.bfloat16.
|
| 108 |
+
Supported output types are torch.float16 and torch.bfloat16.
|
| 109 |
+
existing_out : torch.Tensor or None, optional
|
| 110 |
+
Preallocated output tensor. Default is None.
|
| 111 |
+
If provided, results are written into this tensor. Otherwise, a new output tensor is
|
| 112 |
+
allocated.
|
| 113 |
+
If provided then it must have shape (M, N), its data type must match preferred_element_type
|
| 114 |
+
and it must be on the same device of other input tensors.
|
| 115 |
+
config : dict[str, int] or None, optional
|
| 116 |
+
Optional dictionary with kernel metaparameters. If absent, config will be queried from
|
| 117 |
+
internal tuning database.
|
| 118 |
+
bias : torch.Tensor or None, optional
|
| 119 |
+
Optional bias tensor. Shape: (G, N).
|
| 120 |
+
If provided, bias data type must match lhs and rhs data type, and bias must be on the same
|
| 121 |
+
device as other input tensors. Each group g adds bias[g] to the output.
|
| 122 |
+
|
| 123 |
+
Returns
|
| 124 |
+
-------
|
| 125 |
+
torch.Tensor
|
| 126 |
+
The computed output 2D tensor. Shape: (M, N).
|
| 127 |
+
Output tensor data type is given by preferred_element_type.
|
| 128 |
+
If existing_out is provided then existing_out is also returned.
|
| 129 |
+
|
| 130 |
+
Implementation Notes
|
| 131 |
+
--------------------
|
| 132 |
+
- GMM is implemented with a persistent Triton kernel.
|
| 133 |
+
- lhs must be row-major (lhs.stride() == (K, 1)).
|
| 134 |
+
- rhs can be row-major (rhs.stride() == (K * N, N, 1)) or column-major (rhs.stride() ==
|
| 135 |
+
(K * N, 1, K)). If rhs is row-major then kernel parameter TRANS_RHS == False, this is useful
|
| 136 |
+
for implementing forward pass. If rhs is column-major then kernel parameter TRANS_RHS == True,
|
| 137 |
+
this is useful for computing the lhs derivative in the backward pass, while fusing the
|
| 138 |
+
transposition.
|
| 139 |
+
- out must be row-major (out.stride() == (N, 1)).
|
| 140 |
+
- bias must be row-major (bias.stride() == (N, 1)) if provided.
|
| 141 |
+
"""
|
| 142 |
+
use_bias = bias is not None
|
| 143 |
+
check_input_device_dtype(lhs, rhs, group_sizes, bias)
|
| 144 |
+
|
| 145 |
+
M, K, N, G = get_gmm_shape(lhs, rhs, group_sizes)
|
| 146 |
+
|
| 147 |
+
if use_bias:
|
| 148 |
+
check_bias_shape_stride(bias, G, N)
|
| 149 |
+
|
| 150 |
+
out = get_gmm_output(
|
| 151 |
+
M,
|
| 152 |
+
N,
|
| 153 |
+
device=lhs.device,
|
| 154 |
+
preferred_element_type=preferred_element_type,
|
| 155 |
+
existing_out=existing_out,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
trans_rhs, _ = get_gmm_transposition(lhs, rhs, out)
|
| 159 |
+
|
| 160 |
+
if config is None:
|
| 161 |
+
config = get_config("gmm", M, K, N, G)
|
| 162 |
+
|
| 163 |
+
assert all(
|
| 164 |
+
key in config
|
| 165 |
+
and isinstance(config[key], int)
|
| 166 |
+
and (
|
| 167 |
+
is_power_of_2(config[key])
|
| 168 |
+
if key.startswith("BLOCK_SIZE_")
|
| 169 |
+
else config[key] > 0
|
| 170 |
+
)
|
| 171 |
+
for key in {
|
| 172 |
+
"BLOCK_SIZE_M",
|
| 173 |
+
"BLOCK_SIZE_K",
|
| 174 |
+
"BLOCK_SIZE_N",
|
| 175 |
+
"GROUP_SIZE",
|
| 176 |
+
"GRID_DIM",
|
| 177 |
+
}
|
| 178 |
+
), "Invalid GMM kernel config."
|
| 179 |
+
|
| 180 |
+
grid = _gmm_grid(
|
| 181 |
+
N,
|
| 182 |
+
config["BLOCK_SIZE_M"],
|
| 183 |
+
config["BLOCK_SIZE_N"],
|
| 184 |
+
group_sizes,
|
| 185 |
+
config["GRID_DIM"],
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# fmt: off
|
| 189 |
+
gmm_kernel[grid](
|
| 190 |
+
# Tensor pointers:
|
| 191 |
+
lhs, rhs, group_sizes, out, bias,
|
| 192 |
+
# Tensor shapes:
|
| 193 |
+
M, K, N, G,
|
| 194 |
+
# Meta-parameters:
|
| 195 |
+
TRANS_RHS=trans_rhs,
|
| 196 |
+
USE_BIAS=use_bias,
|
| 197 |
+
**config,
|
| 198 |
+
)
|
| 199 |
+
# fmt: on
|
| 200 |
+
|
| 201 |
+
return out
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# Persistent TGMM PyTorch wrapper.
|
| 205 |
+
# ------------------------------------------------------------------------------
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def _ptgmm_grid(
|
| 209 |
+
K: int,
|
| 210 |
+
N: int,
|
| 211 |
+
G: int,
|
| 212 |
+
block_size_k: int,
|
| 213 |
+
block_size_n: int,
|
| 214 |
+
grid_dim: int,
|
| 215 |
+
) -> tuple[int]:
|
| 216 |
+
assert K > 0, f"K must be positive, it's {K}."
|
| 217 |
+
assert N > 0, f"N must be positive, it's {N}."
|
| 218 |
+
assert G > 0, f"G must be positive, it's {G}."
|
| 219 |
+
assert is_power_of_2(
|
| 220 |
+
block_size_k
|
| 221 |
+
), f"K-dimension tile size must be a power of 2 (it's {block_size_k})."
|
| 222 |
+
assert is_power_of_2(
|
| 223 |
+
block_size_n
|
| 224 |
+
), f"N-dimension tile size must be a power of 2 (it's {block_size_n})."
|
| 225 |
+
assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})."
|
| 226 |
+
num_k_tiles = triton.cdiv(K, block_size_k)
|
| 227 |
+
assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}."
|
| 228 |
+
num_n_tiles = triton.cdiv(N, block_size_n)
|
| 229 |
+
assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}."
|
| 230 |
+
num_tiles = G * num_k_tiles * num_n_tiles
|
| 231 |
+
assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}."
|
| 232 |
+
num_programs = min(grid_dim, num_tiles)
|
| 233 |
+
assert num_programs > 0, f"num_programs must be positive, it's {num_programs}."
|
| 234 |
+
return (num_programs,)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def ptgmm(
|
| 238 |
+
lhs: Tensor,
|
| 239 |
+
rhs: Tensor,
|
| 240 |
+
group_sizes: Tensor,
|
| 241 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 242 |
+
existing_out: Tensor | None = None,
|
| 243 |
+
config: dict[str, int] | None = None,
|
| 244 |
+
bias_grad: Tensor | None = None,
|
| 245 |
+
accumulate: bool = False,
|
| 246 |
+
) -> Tensor:
|
| 247 |
+
"""
|
| 248 |
+
Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs
|
| 249 |
+
|
| 250 |
+
lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with
|
| 251 |
+
the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch
|
| 252 |
+
parlance, it can be implemented as follows for a given group g:
|
| 253 |
+
out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :]
|
| 254 |
+
|
| 255 |
+
The 't' in the operator name derives from MaxText implementation
|
| 256 |
+
(https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py),
|
| 257 |
+
which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor
|
| 258 |
+
shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N).
|
| 259 |
+
|
| 260 |
+
The 'p' in the operator name means that it is implemented with a persistent kernel. There is
|
| 261 |
+
also the non-persistent variation, which is implemented with a regular kernel. Please take a
|
| 262 |
+
look at nptgmm operator. Both ptgmm and nptgmm implement the same computation, choosing one or
|
| 263 |
+
the other is a matter of performance for the target workload.
|
| 264 |
+
|
| 265 |
+
Parameters
|
| 266 |
+
----------
|
| 267 |
+
lhs : torch.Tensor
|
| 268 |
+
Left-hand side 2D input tensor. Shape: (K, M).
|
| 269 |
+
lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type.
|
| 270 |
+
lhs must be on the same device of rhs and group_sizes.
|
| 271 |
+
rhs : torch.Tensor
|
| 272 |
+
Right-hand side 2D input tensor. Shape: (M, N).
|
| 273 |
+
rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type.
|
| 274 |
+
rhs must be on the same device of lhs and group_sizes.
|
| 275 |
+
group_sizes : torch.Tensor
|
| 276 |
+
1D input tensor describing group sizes. Shape: (G,).
|
| 277 |
+
group_sizes data type must be torch.int32 and all its elements must be non-negative.
|
| 278 |
+
group_sizes must be on the same device of lhs and rhs.
|
| 279 |
+
preferred_element_type : torch.dtype, optional
|
| 280 |
+
Desired data type for output tensor. Default is torch.bfloat16.
|
| 281 |
+
Supported output types are torch.float16 and torch.bfloat16.
|
| 282 |
+
existing_out : torch.Tensor or None, optional
|
| 283 |
+
Preallocated output tensor. Default is None.
|
| 284 |
+
If provided, results are written into this tensor. Otherwise, a new output tensor is
|
| 285 |
+
allocated.
|
| 286 |
+
If provided then it must have shape (G, K, N), its data type must match
|
| 287 |
+
preferred_element_type and it must be on the same device of other input tensors.
|
| 288 |
+
config : dict[str, int] or None, optional
|
| 289 |
+
Optional dictionary with kernel metaparameters. If absent, config will be queried from
|
| 290 |
+
internal tuning database.
|
| 291 |
+
bias_grad : torch.Tensor or None, optional
|
| 292 |
+
Optional bias gradient output tensor. Shape: (G, K).
|
| 293 |
+
If provided, the kernel will compute the bias gradient and write it to this tensor.
|
| 294 |
+
bias_grad must be torch.float32 (kernel uses atomic_add which requires float32),
|
| 295 |
+
accumulate : bool, optional
|
| 296 |
+
Whether to accumulate into existing output tensor values. Default is False.
|
| 297 |
+
If False, output will be overwritten with fresh computation.
|
| 298 |
+
If True, results will be added to existing output tensor values.
|
| 299 |
+
|
| 300 |
+
Returns
|
| 301 |
+
-------
|
| 302 |
+
torch.Tensor
|
| 303 |
+
The computed output 3D tensor. Shape: (G, K, N).
|
| 304 |
+
Output tensor data type is given by preferred_element_type.
|
| 305 |
+
If existing_out is provided then existing_out is also returned.
|
| 306 |
+
|
| 307 |
+
Implementation Notes
|
| 308 |
+
--------------------
|
| 309 |
+
- PTGMM is implemented with a persistent Triton kernel.
|
| 310 |
+
- lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs
|
| 311 |
+
is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel
|
| 312 |
+
parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward
|
| 313 |
+
pass, while fusing the transposition.
|
| 314 |
+
- rhs must be row-major (rhs.stride() == (N, 1)).
|
| 315 |
+
- out must be row-major (out.stride() == (K * N, N, 1)).
|
| 316 |
+
"""
|
| 317 |
+
check_input_device_dtype(lhs, rhs, group_sizes)
|
| 318 |
+
|
| 319 |
+
M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes)
|
| 320 |
+
|
| 321 |
+
out = get_tgmm_output(
|
| 322 |
+
K,
|
| 323 |
+
N,
|
| 324 |
+
G,
|
| 325 |
+
device=lhs.device,
|
| 326 |
+
preferred_element_type=preferred_element_type,
|
| 327 |
+
existing_out=existing_out,
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out)
|
| 331 |
+
|
| 332 |
+
if config is None:
|
| 333 |
+
config = get_config("ptgmm", M, K, N, G, accumulate)
|
| 334 |
+
|
| 335 |
+
assert all(
|
| 336 |
+
key in config
|
| 337 |
+
and isinstance(config[key], int)
|
| 338 |
+
and (
|
| 339 |
+
is_power_of_2(config[key])
|
| 340 |
+
if key.startswith("BLOCK_SIZE_")
|
| 341 |
+
else config[key] > 0
|
| 342 |
+
)
|
| 343 |
+
for key in {
|
| 344 |
+
"BLOCK_SIZE_M",
|
| 345 |
+
"BLOCK_SIZE_K",
|
| 346 |
+
"BLOCK_SIZE_N",
|
| 347 |
+
"GROUP_SIZE",
|
| 348 |
+
"GRID_DIM",
|
| 349 |
+
}
|
| 350 |
+
), "Invalid PTGMM kernel config."
|
| 351 |
+
|
| 352 |
+
# Bias gradient handling.
|
| 353 |
+
# -----------------------
|
| 354 |
+
# Get or validate bias gradient tensor.
|
| 355 |
+
compute_bias_grad = bias_grad is not None
|
| 356 |
+
bias_grad_ptr = get_tgmm_bias_grad(
|
| 357 |
+
K,
|
| 358 |
+
G,
|
| 359 |
+
device=lhs.device,
|
| 360 |
+
existing_bias_grad=bias_grad,
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
grid = _ptgmm_grid(
|
| 364 |
+
K,
|
| 365 |
+
N,
|
| 366 |
+
G,
|
| 367 |
+
config["BLOCK_SIZE_K"],
|
| 368 |
+
config["BLOCK_SIZE_N"],
|
| 369 |
+
config["GRID_DIM"],
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
# fmt: off
|
| 373 |
+
tgmm_persistent_kernel[grid](
|
| 374 |
+
# Tensor pointers:
|
| 375 |
+
lhs, rhs, group_sizes, out, bias_grad_ptr,
|
| 376 |
+
# Tensor shapes:
|
| 377 |
+
M, K, N, G,
|
| 378 |
+
# Meta-parameters:
|
| 379 |
+
TRANS_LHS=trans_lhs,
|
| 380 |
+
COMPUTE_BIAS_GRAD=compute_bias_grad,
|
| 381 |
+
ACCUMULATE=accumulate,
|
| 382 |
+
**config,
|
| 383 |
+
)
|
| 384 |
+
# fmt: on
|
| 385 |
+
|
| 386 |
+
return out
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
# Regular non-persistent TGMM PyTorch wrapper.
|
| 390 |
+
# ------------------------------------------------------------------------------
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def _nptgmm_grid(
|
| 394 |
+
K: int,
|
| 395 |
+
N: int,
|
| 396 |
+
G: int,
|
| 397 |
+
block_size_k: int,
|
| 398 |
+
block_size_n: int,
|
| 399 |
+
) -> tuple[int, int]:
|
| 400 |
+
assert K > 0, f"K must be positive, it's {K}."
|
| 401 |
+
assert N > 0, f"N must be positive, it's {N}."
|
| 402 |
+
assert G > 0, f"G must be positive, it's {G}."
|
| 403 |
+
assert is_power_of_2(
|
| 404 |
+
block_size_k
|
| 405 |
+
), f"K-dimension tile size must be a power of 2 (it's {block_size_k})."
|
| 406 |
+
assert is_power_of_2(
|
| 407 |
+
block_size_n
|
| 408 |
+
), f"N-dimension tile size must be a power of 2 (it's {block_size_n})."
|
| 409 |
+
num_k_tiles = triton.cdiv(K, block_size_k)
|
| 410 |
+
assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}."
|
| 411 |
+
num_n_tiles = triton.cdiv(N, block_size_n)
|
| 412 |
+
assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}."
|
| 413 |
+
num_tiles_per_mm = num_k_tiles * num_n_tiles
|
| 414 |
+
assert (
|
| 415 |
+
num_tiles_per_mm > 0
|
| 416 |
+
), f"num_tiles_per_mm must be positive, it's {num_tiles_per_mm}."
|
| 417 |
+
return (G, num_tiles_per_mm)
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def nptgmm(
|
| 421 |
+
lhs: Tensor,
|
| 422 |
+
rhs: Tensor,
|
| 423 |
+
group_sizes: Tensor,
|
| 424 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 425 |
+
existing_out: Tensor | None = None,
|
| 426 |
+
config: dict[str, int] | None = None,
|
| 427 |
+
bias_grad: Tensor | None = None,
|
| 428 |
+
accumulate: bool = False,
|
| 429 |
+
) -> Tensor:
|
| 430 |
+
"""
|
| 431 |
+
Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs
|
| 432 |
+
|
| 433 |
+
lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with
|
| 434 |
+
the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch
|
| 435 |
+
parlance, it can be implemented as follows for a given group g:
|
| 436 |
+
out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :]
|
| 437 |
+
|
| 438 |
+
The 't' in the operator name derives from MaxText implementation
|
| 439 |
+
(https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py),
|
| 440 |
+
which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor
|
| 441 |
+
shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N).
|
| 442 |
+
|
| 443 |
+
The 'np' in the operator name means that it is implemented with a non-persistent, i.e. regular
|
| 444 |
+
kernel. There is also the persistent variation, which is implemented with a persistent kernel.
|
| 445 |
+
Please take a look at ptgmm operator. Both nptgmm and ptgmm implement the same computation,
|
| 446 |
+
choosing one or the other is a matter of performance for the target workload.
|
| 447 |
+
|
| 448 |
+
Parameters
|
| 449 |
+
----------
|
| 450 |
+
lhs : torch.Tensor
|
| 451 |
+
Left-hand side 2D input tensor. Shape: (K, M).
|
| 452 |
+
lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type.
|
| 453 |
+
lhs must be on the same device of rhs and group_sizes.
|
| 454 |
+
rhs : torch.Tensor
|
| 455 |
+
Right-hand side 2D input tensor. Shape: (M, N).
|
| 456 |
+
rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type.
|
| 457 |
+
rhs must be on the same device of lhs and group_sizes.
|
| 458 |
+
group_sizes : torch.Tensor
|
| 459 |
+
1D input tensor describing group sizes. Shape: (G,).
|
| 460 |
+
group_sizes data type must be torch.int32 and all its elements must be non-negative.
|
| 461 |
+
group_sizes must be on the same device of lhs and rhs.
|
| 462 |
+
preferred_element_type : torch.dtype, optional
|
| 463 |
+
Desired data type for output tensor. Default is torch.bfloat16.
|
| 464 |
+
Supported output types are torch.float16 and torch.bfloat16.
|
| 465 |
+
existing_out : torch.Tensor or None, optional
|
| 466 |
+
Preallocated output tensor. Default is None.
|
| 467 |
+
If provided, results are written into this tensor. Otherwise, a new output tensor is
|
| 468 |
+
allocated.
|
| 469 |
+
If provided then it must have shape (G, K, N), its data type must match
|
| 470 |
+
preferred_element_type and it must be on the same device of other input tensors.
|
| 471 |
+
config : dict[str, int] or None, optional
|
| 472 |
+
Optional dictionary with kernel metaparameters. If absent, config will be queried from
|
| 473 |
+
internal tuning database.
|
| 474 |
+
bias_grad : torch.Tensor or None, optional
|
| 475 |
+
Optional bias gradient output tensor. Shape: (G, K).
|
| 476 |
+
If provided, the kernel will compute the bias gradient and write it to this tensor.
|
| 477 |
+
bias_grad must be torch.float32 (kernel uses atomic_add which requires float32),
|
| 478 |
+
accumulate : bool, optional
|
| 479 |
+
Whether to accumulate into existing output tensor values. Default is False.
|
| 480 |
+
If False, output will be overwritten with fresh computation.
|
| 481 |
+
If True, results will be added to existing output tensor values.
|
| 482 |
+
|
| 483 |
+
Returns
|
| 484 |
+
-------
|
| 485 |
+
torch.Tensor
|
| 486 |
+
The computed output 3D tensor. Shape: (G, K, N).
|
| 487 |
+
Output tensor data type is given by preferred_element_type.
|
| 488 |
+
If existing_out is provided then existing_out is also returned.
|
| 489 |
+
|
| 490 |
+
Implementation Notes
|
| 491 |
+
--------------------
|
| 492 |
+
- NPTGMM is implemented with a non-persistent regular Triton kernel.
|
| 493 |
+
- lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs
|
| 494 |
+
is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel
|
| 495 |
+
parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward
|
| 496 |
+
pass, while fusing the transposition.
|
| 497 |
+
- rhs must be row-major (rhs.stride() == (N, 1)).
|
| 498 |
+
- out must be row-major (out.stride() == (K * N, N, 1)).
|
| 499 |
+
"""
|
| 500 |
+
check_input_device_dtype(lhs, rhs, group_sizes)
|
| 501 |
+
|
| 502 |
+
M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes)
|
| 503 |
+
|
| 504 |
+
out = get_tgmm_output(
|
| 505 |
+
K,
|
| 506 |
+
N,
|
| 507 |
+
G,
|
| 508 |
+
device=lhs.device,
|
| 509 |
+
preferred_element_type=preferred_element_type,
|
| 510 |
+
existing_out=existing_out,
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out)
|
| 514 |
+
|
| 515 |
+
# Bias gradient handling.
|
| 516 |
+
# -----------------------
|
| 517 |
+
# Get or validate bias gradient tensor.
|
| 518 |
+
compute_bias_grad = bias_grad is not None
|
| 519 |
+
bias_grad_ptr = get_tgmm_bias_grad(
|
| 520 |
+
K,
|
| 521 |
+
G,
|
| 522 |
+
device=lhs.device,
|
| 523 |
+
existing_bias_grad=bias_grad,
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
if config is None:
|
| 527 |
+
config = get_config("nptgmm", M, K, N, G, accumulate)
|
| 528 |
+
|
| 529 |
+
assert all(
|
| 530 |
+
key in config
|
| 531 |
+
and isinstance(config[key], int)
|
| 532 |
+
and (
|
| 533 |
+
is_power_of_2(config[key])
|
| 534 |
+
if key.startswith("BLOCK_SIZE_")
|
| 535 |
+
else config[key] > 0
|
| 536 |
+
)
|
| 537 |
+
for key in {
|
| 538 |
+
"BLOCK_SIZE_M",
|
| 539 |
+
"BLOCK_SIZE_K",
|
| 540 |
+
"BLOCK_SIZE_N",
|
| 541 |
+
"GROUP_SIZE",
|
| 542 |
+
}
|
| 543 |
+
), "Invalid NPTGMM kernel config."
|
| 544 |
+
|
| 545 |
+
grid = _nptgmm_grid(
|
| 546 |
+
K,
|
| 547 |
+
N,
|
| 548 |
+
G,
|
| 549 |
+
config["BLOCK_SIZE_K"],
|
| 550 |
+
config["BLOCK_SIZE_N"],
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
# fmt: off
|
| 554 |
+
tgmm_non_persistent_kernel[grid](
|
| 555 |
+
# Tensor pointers:
|
| 556 |
+
lhs, rhs, group_sizes, out, bias_grad_ptr,
|
| 557 |
+
# Tensor shapes:
|
| 558 |
+
M, K, N, G,
|
| 559 |
+
# Meta-parameters:
|
| 560 |
+
TRANS_LHS=trans_lhs,
|
| 561 |
+
COMPUTE_BIAS_GRAD=compute_bias_grad,
|
| 562 |
+
ACCUMULATE=accumulate,
|
| 563 |
+
**config,
|
| 564 |
+
)
|
| 565 |
+
# fmt: on
|
| 566 |
+
|
| 567 |
+
return out
|
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/__init__.py
ADDED
|
File without changes
|
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py
ADDED
|
File without changes
|
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import triton
|
| 2 |
+
|
| 3 |
+
# Detect the GPU arch lazily: querying the triton driver at import time fails
|
| 4 |
+
# in headless environments (e.g. the kernel-builder ABI check sandbox has no
|
| 5 |
+
# GPU), and the original JAX fallback pulled in an unrelated runtime dep. The
|
| 6 |
+
# arch is only actually needed when a GMM kernel is dispatched, so resolve and
|
| 7 |
+
# cache on first call.
|
| 8 |
+
_CACHED_ARCH = None
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_arch():
|
| 12 |
+
global _CACHED_ARCH
|
| 13 |
+
if _CACHED_ARCH is not None:
|
| 14 |
+
return _CACHED_ARCH
|
| 15 |
+
try:
|
| 16 |
+
_CACHED_ARCH = triton.runtime.driver.active.get_current_target().arch
|
| 17 |
+
except RuntimeError:
|
| 18 |
+
try:
|
| 19 |
+
from jax._src.lib import gpu_triton as triton_kernel_call_lib
|
| 20 |
+
_CACHED_ARCH = triton_kernel_call_lib.get_arch_details("0").split(":")[0]
|
| 21 |
+
except ImportError as e:
|
| 22 |
+
raise RuntimeError(
|
| 23 |
+
"Cannot determine GPU arch: triton driver is inactive and "
|
| 24 |
+
"JAX is not available. A GPU is required for grouped GEMM."
|
| 25 |
+
) from e
|
| 26 |
+
return _CACHED_ARCH
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def is_gluon_avail():
|
| 30 |
+
return get_arch() in ("gfx950", "gfx1250")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def is_fp4_avail():
|
| 34 |
+
return get_arch() in ("gfx950", "gfx1250")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def is_fp8_avail():
|
| 38 |
+
return get_arch() in ("gfx942", "gfx950", "gfx1250", "gfx1200", "gfx1201")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def is_mx_scale_preshuffling_avail():
|
| 42 |
+
return get_arch() in ("gfx950", "gfx1250")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def is_tdm_avail():
|
| 46 |
+
return get_arch() in ("gfx1250",)
|
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
|
| 3 |
+
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
|
| 4 |
+
|
| 5 |
+
import triton
|
| 6 |
+
import triton.language as tl
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@triton.jit
|
| 10 |
+
def remap_xcd_chunked(
|
| 11 |
+
pid, GRID_MN, NUM_XCDS: tl.constexpr = 8, CHUNK_SIZE: tl.constexpr = 2
|
| 12 |
+
):
|
| 13 |
+
# Compute current XCD and local PID
|
| 14 |
+
xcd = pid % NUM_XCDS
|
| 15 |
+
# distribute the modulo pids in round robin
|
| 16 |
+
if pid > (GRID_MN // (NUM_XCDS * CHUNK_SIZE)) * (NUM_XCDS * CHUNK_SIZE):
|
| 17 |
+
return pid
|
| 18 |
+
local_pid = pid // NUM_XCDS
|
| 19 |
+
# Calculate chunk index and position within chunk
|
| 20 |
+
chunk_idx = local_pid // CHUNK_SIZE
|
| 21 |
+
pos_in_chunk = local_pid % CHUNK_SIZE
|
| 22 |
+
# Calculate new PID
|
| 23 |
+
new_pid = chunk_idx * NUM_XCDS * CHUNK_SIZE + xcd * CHUNK_SIZE + pos_in_chunk
|
| 24 |
+
return new_pid
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@triton.jit
|
| 28 |
+
def remap_xcd(pid, GRID_MN, NUM_XCDS: tl.constexpr = 8):
|
| 29 |
+
## pid remapping on xcds
|
| 30 |
+
# Number of pids per XCD in the new arrangement
|
| 31 |
+
pids_per_xcd = (GRID_MN + NUM_XCDS - 1) // NUM_XCDS
|
| 32 |
+
# When GRID_MN cannot divide NUM_XCDS, some xcds will have
|
| 33 |
+
# pids_per_xcd pids, the other will have pids_per_xcd - 1 pids.
|
| 34 |
+
# We calculate the number of xcds that have pids_per_xcd pids as
|
| 35 |
+
# tall_xcds
|
| 36 |
+
tall_xcds = GRID_MN % NUM_XCDS
|
| 37 |
+
tall_xcds = NUM_XCDS if tall_xcds == 0 else tall_xcds
|
| 38 |
+
# Compute current XCD and local pid within the XCD
|
| 39 |
+
xcd = pid % NUM_XCDS
|
| 40 |
+
local_pid = pid // NUM_XCDS
|
| 41 |
+
# Calculate new pid based on the new grouping
|
| 42 |
+
# Note that we need to consider the following two cases:
|
| 43 |
+
# 1. the current pid is on a tall xcd
|
| 44 |
+
# 2. the current pid is on a short xcd
|
| 45 |
+
if xcd < tall_xcds:
|
| 46 |
+
pid = xcd * pids_per_xcd + local_pid
|
| 47 |
+
else:
|
| 48 |
+
pid = (
|
| 49 |
+
tall_xcds * pids_per_xcd
|
| 50 |
+
+ (xcd - tall_xcds) * (pids_per_xcd - 1)
|
| 51 |
+
+ local_pid
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
return pid
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@triton.jit
|
| 58 |
+
def pid_grid(pid: int, num_pid_m: int, num_pid_n: int, GROUP_SIZE_M: tl.constexpr = 1):
|
| 59 |
+
"""
|
| 60 |
+
Maps 1D pid to 2D grid coords (pid_m, pid_n).
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
- pid: 1D pid
|
| 64 |
+
- num_pid_m: grid m size
|
| 65 |
+
- num_pid_n: grid n size
|
| 66 |
+
- GROUP_SIZE_M: tl.constexpr: default is 1
|
| 67 |
+
"""
|
| 68 |
+
if GROUP_SIZE_M == 1:
|
| 69 |
+
pid_m = pid // num_pid_n
|
| 70 |
+
pid_n = pid % num_pid_n
|
| 71 |
+
else:
|
| 72 |
+
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
| 73 |
+
group_id = pid // num_pid_in_group
|
| 74 |
+
first_pid_m = group_id * GROUP_SIZE_M
|
| 75 |
+
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
| 76 |
+
tl.assume(group_size_m >= 0)
|
| 77 |
+
pid_m = first_pid_m + (pid % group_size_m)
|
| 78 |
+
pid_n = (pid % num_pid_in_group) // group_size_m
|
| 79 |
+
|
| 80 |
+
return pid_m, pid_n
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@triton.jit
|
| 84 |
+
def pid_grid_3d(pid: int, num_pid_m: int, num_pid_n: int, num_pid_k):
|
| 85 |
+
"""
|
| 86 |
+
Maps 1D pid to 3D grid coords (pid_m, pid_n, pid_k).
|
| 87 |
+
Args:
|
| 88 |
+
- pid: 1D pid
|
| 89 |
+
- num_pid_m: grid m size
|
| 90 |
+
- num_pid_n: grid n size
|
| 91 |
+
- num_pid_k: grid k size
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
- pid_m, pid_n, pid_k: 3D grid coordinates
|
| 95 |
+
"""
|
| 96 |
+
pid_m = pid % num_pid_m
|
| 97 |
+
pid_n = (pid // num_pid_m) % num_pid_n
|
| 98 |
+
pid_k = pid // (num_pid_m * num_pid_n) % num_pid_k
|
| 99 |
+
|
| 100 |
+
return pid_m, pid_n, pid_k
|
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py
ADDED
|
@@ -0,0 +1,752 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
| 3 |
+
|
| 4 |
+
# Imports.
|
| 5 |
+
# ------------------------------------------------------------------------------
|
| 6 |
+
|
| 7 |
+
# PyTorch
|
| 8 |
+
import torch
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
|
| 11 |
+
# AITER: logging
|
| 12 |
+
from .logger import AiterTritonLogger
|
| 13 |
+
|
| 14 |
+
_LOGGER: AiterTritonLogger = AiterTritonLogger()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Supported data types.
|
| 18 |
+
# ------------------------------------------------------------------------------
|
| 19 |
+
|
| 20 |
+
# Supported data types, as strings.
|
| 21 |
+
SUPPORTED_DTYPES_STR: set[str] = {"fp16", "bf16"}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Convert string data type to PyTorch data type.
|
| 25 |
+
def dtype_from_str(dtype_str: str) -> torch.dtype:
|
| 26 |
+
dtype_str = dtype_str.strip().lower()
|
| 27 |
+
dtype_str = dtype_str[1:] if dtype_str[0] in {"i", "o"} else dtype_str
|
| 28 |
+
assert (
|
| 29 |
+
dtype_str in SUPPORTED_DTYPES_STR
|
| 30 |
+
), "String data type isn't in set of supported string data types."
|
| 31 |
+
return {"fp16": torch.float16, "bf16": torch.bfloat16}[dtype_str]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Supported data types, as PyTorch types.
|
| 35 |
+
SUPPORTED_DTYPES: set[torch.dtype] = {
|
| 36 |
+
dtype_from_str(dtype_str) for dtype_str in SUPPORTED_DTYPES_STR
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Convert PyTorch data type to string data type.
|
| 41 |
+
def str_from_dtype(dtype: torch.dtype) -> str:
|
| 42 |
+
assert (
|
| 43 |
+
dtype in SUPPORTED_DTYPES
|
| 44 |
+
), "PyTorch data type isn't in set of supported PyTorch data types."
|
| 45 |
+
return {torch.float16: "fp16", torch.bfloat16: "bf16"}[dtype]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Default data type, as string.
|
| 49 |
+
DTYPE_STR: str = "bf16"
|
| 50 |
+
assert (
|
| 51 |
+
DTYPE_STR in SUPPORTED_DTYPES_STR
|
| 52 |
+
), "Default string data type isn't in set of supported string data types."
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Default data type, as PyTorch type.
|
| 56 |
+
DTYPE: torch.dtype = dtype_from_str(DTYPE_STR)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Other defaults.
|
| 60 |
+
# ------------------------------------------------------------------------------
|
| 61 |
+
|
| 62 |
+
# Default device.
|
| 63 |
+
DEVICE: torch.device | str = "cuda"
|
| 64 |
+
|
| 65 |
+
# Default RNG seed for input generation.
|
| 66 |
+
RNG_SEED: int = 0
|
| 67 |
+
|
| 68 |
+
# Default number of group sizes.
|
| 69 |
+
NUM_GROUP_SIZES: int = 1
|
| 70 |
+
|
| 71 |
+
# Default transposition (NN).
|
| 72 |
+
TRANS_LHS: bool = False
|
| 73 |
+
TRANS_RHS: bool = False
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Parameter checking functions.
|
| 77 |
+
# ------------------------------------------------------------------------------
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def is_power_of_2(x: int) -> bool:
|
| 81 |
+
return (x > 0) and (x & (x - 1) == 0)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def check_input_device_dtype(
|
| 85 |
+
lhs: Tensor, rhs: Tensor, group_sizes: Tensor, bias: Tensor | None = None
|
| 86 |
+
) -> None:
|
| 87 |
+
assert (
|
| 88 |
+
lhs.device == rhs.device == group_sizes.device
|
| 89 |
+
), f"All input tensors must be in the same device (lhs = {lhs.device}, rhs = {rhs.device}, group_sizes = {group_sizes.device})."
|
| 90 |
+
assert (
|
| 91 |
+
lhs.dtype == rhs.dtype
|
| 92 |
+
), f"lhs and rhs types must match (lhs = {lhs.dtype}, rhs = {rhs.dtype})."
|
| 93 |
+
assert group_sizes.dtype == torch.int32, "group_sizes type must be int32."
|
| 94 |
+
|
| 95 |
+
if bias is not None:
|
| 96 |
+
assert (
|
| 97 |
+
bias.device == lhs.device
|
| 98 |
+
), f"bias must be on the same device as lhs (bias = {bias.device}, lhs = {lhs.device})."
|
| 99 |
+
assert (
|
| 100 |
+
bias.dtype == lhs.dtype
|
| 101 |
+
), f"bias dtype must match lhs dtype (bias = {bias.dtype}, lhs = {lhs.dtype})."
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def check_bias_shape_stride(bias: Tensor, G: int, N: int) -> None:
|
| 105 |
+
assert bias.shape == (
|
| 106 |
+
G,
|
| 107 |
+
N,
|
| 108 |
+
), f"bias must have shape (G, N) = ({G}, {N}), got {bias.shape}."
|
| 109 |
+
assert bias.stride() == (N, 1), "bias must be row-major (bias.stride() == (N, 1))."
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# Generation of group sizes.
|
| 113 |
+
# ------------------------------------------------------------------------------
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# Probabilities for generating random group sizes.
|
| 117 |
+
UNUSED_TOKENS_PROB: float = 0.0
|
| 118 |
+
UNUSED_EXPERTS_PROB: float = 0.1
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def gen_uniform_group_sizes(
|
| 122 |
+
M: int,
|
| 123 |
+
G: int,
|
| 124 |
+
device: torch.device | str = DEVICE,
|
| 125 |
+
) -> Tensor:
|
| 126 |
+
assert M >= 0, f"Number of tokens M must be non-negative (it's {M})."
|
| 127 |
+
assert G > 0, f"Number of experts G must be positive (it's {G})."
|
| 128 |
+
|
| 129 |
+
base = M // G
|
| 130 |
+
remainder = M % G
|
| 131 |
+
group_sizes = torch.full((G,), base, dtype=torch.int32, device=device)
|
| 132 |
+
if remainder > 0:
|
| 133 |
+
group_sizes[:remainder] += 1
|
| 134 |
+
|
| 135 |
+
assert (
|
| 136 |
+
len(group_sizes) == G
|
| 137 |
+
), f"Group sizes don't have {G} elements (it's {len(group_sizes)})."
|
| 138 |
+
assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative."
|
| 139 |
+
assert (
|
| 140 |
+
torch.sum(group_sizes).item() == M
|
| 141 |
+
), f"Group sizes don't add up to total tokens {M}."
|
| 142 |
+
assert group_sizes.dtype == torch.int32, "Group sizes must be int32."
|
| 143 |
+
|
| 144 |
+
return group_sizes
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def gen_group_sizes(
|
| 148 |
+
M: int,
|
| 149 |
+
G: int,
|
| 150 |
+
device: torch.device | str = DEVICE,
|
| 151 |
+
rng_seed: int | None = RNG_SEED,
|
| 152 |
+
unused_tokens_prob: float = UNUSED_TOKENS_PROB,
|
| 153 |
+
unused_experts_prob: float = UNUSED_EXPERTS_PROB,
|
| 154 |
+
) -> Tensor:
|
| 155 |
+
assert M >= 0, f"Number of tokens M must be non-negative (it's {M})."
|
| 156 |
+
assert G > 0, f"Number of experts G must be positive (it's {G})."
|
| 157 |
+
assert (
|
| 158 |
+
0 <= unused_tokens_prob <= 1
|
| 159 |
+
), f"Probability of unused tokens must be in [0, 1] interval (it's {unused_tokens_prob})."
|
| 160 |
+
assert (
|
| 161 |
+
0 <= unused_experts_prob <= 1
|
| 162 |
+
), f"Probability of unused experts must be in [0, 1] interval (it's {unused_experts_prob})."
|
| 163 |
+
|
| 164 |
+
if rng_seed is not None:
|
| 165 |
+
torch.manual_seed(rng_seed)
|
| 166 |
+
|
| 167 |
+
if unused_tokens_prob > 0:
|
| 168 |
+
# Optionally drop tokens to simulate routing sparsity, some tokens may not be routed.
|
| 169 |
+
num_unused_tokens = M
|
| 170 |
+
while num_unused_tokens == M:
|
| 171 |
+
num_unused_tokens = int(
|
| 172 |
+
torch.binomial(
|
| 173 |
+
torch.tensor(float(M), device=device),
|
| 174 |
+
torch.tensor(unused_tokens_prob, device=device),
|
| 175 |
+
).item()
|
| 176 |
+
)
|
| 177 |
+
else:
|
| 178 |
+
num_unused_tokens = 0
|
| 179 |
+
num_used_tokens = M - num_unused_tokens
|
| 180 |
+
assert (
|
| 181 |
+
num_unused_tokens >= 0
|
| 182 |
+
), f"Number of unused tokens must be non-negative (it's {num_unused_tokens})."
|
| 183 |
+
assert (
|
| 184 |
+
num_used_tokens > 0
|
| 185 |
+
), f"Number of used tokens must be positive (it's {num_used_tokens})."
|
| 186 |
+
assert (
|
| 187 |
+
num_used_tokens + num_unused_tokens == M
|
| 188 |
+
), f"Unused + used tokens don't add up total tokens ({num_used_tokens} + {num_unused_tokens} != {M})."
|
| 189 |
+
|
| 190 |
+
if num_unused_tokens > 0:
|
| 191 |
+
_LOGGER.debug(
|
| 192 |
+
f"Group sizes generation: dropped {num_unused_tokens} token{'s' if num_unused_tokens > 1 else ''}.",
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
if unused_experts_prob > 0:
|
| 196 |
+
# Some experts may have zero tokens assigned to them.
|
| 197 |
+
num_used_experts = 0
|
| 198 |
+
while num_used_experts == 0:
|
| 199 |
+
used_experts = torch.nonzero(
|
| 200 |
+
torch.rand((G,), device=device) >= unused_experts_prob
|
| 201 |
+
).squeeze()
|
| 202 |
+
num_used_experts = used_experts.numel()
|
| 203 |
+
else:
|
| 204 |
+
used_experts = torch.arange(0, G, device=device)
|
| 205 |
+
num_used_experts = G
|
| 206 |
+
num_unused_experts = G - num_used_experts
|
| 207 |
+
assert (
|
| 208 |
+
num_unused_experts >= 0
|
| 209 |
+
), f"Number of unused experts must be non-negative (it's {num_unused_experts})."
|
| 210 |
+
assert (
|
| 211 |
+
num_used_experts >= 1
|
| 212 |
+
), f"At least one expert must be used (it's {num_used_experts})."
|
| 213 |
+
assert (
|
| 214 |
+
num_unused_experts + num_used_experts == G
|
| 215 |
+
), f"Unused + used experts don't add up total experts ({num_unused_experts} + {num_used_experts} != {G})."
|
| 216 |
+
|
| 217 |
+
if num_unused_experts > 0:
|
| 218 |
+
_LOGGER.debug(
|
| 219 |
+
f"Group sizes generation: dropped {num_unused_experts} expert{'s' if num_unused_experts > 1 else ''}.",
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
group_sizes = torch.bincount(
|
| 223 |
+
used_experts[
|
| 224 |
+
torch.randint(low=0, high=num_used_experts, size=(num_used_tokens,))
|
| 225 |
+
],
|
| 226 |
+
minlength=G,
|
| 227 |
+
).to(torch.int32)
|
| 228 |
+
|
| 229 |
+
assert (
|
| 230 |
+
len(group_sizes) == G
|
| 231 |
+
), f"Group sizes don't have {G} elements (it's {len(group_sizes)})."
|
| 232 |
+
assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative."
|
| 233 |
+
assert (
|
| 234 |
+
torch.sum(group_sizes).item() == num_used_tokens
|
| 235 |
+
), f"Group sizes don't add up to used tokens {num_used_tokens}."
|
| 236 |
+
assert group_sizes.dtype == torch.int32, "Group sizes must be int32."
|
| 237 |
+
|
| 238 |
+
return group_sizes
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def gen_multiple_group_sizes(
|
| 242 |
+
num_group_sizes: int,
|
| 243 |
+
M: int,
|
| 244 |
+
G: int,
|
| 245 |
+
device: torch.device | str = DEVICE,
|
| 246 |
+
rng_seed: int | None = RNG_SEED,
|
| 247 |
+
unused_tokens_prob: float = UNUSED_TOKENS_PROB,
|
| 248 |
+
unused_experts_prob: float = UNUSED_EXPERTS_PROB,
|
| 249 |
+
group_sizes_0: Tensor | None = None,
|
| 250 |
+
) -> list[Tensor]:
|
| 251 |
+
assert (
|
| 252 |
+
num_group_sizes > 0
|
| 253 |
+
), f"Number of group sizes to be generated must be positive, it's {num_group_sizes}."
|
| 254 |
+
multiple_group_sizes = [
|
| 255 |
+
gen_group_sizes(
|
| 256 |
+
M,
|
| 257 |
+
G,
|
| 258 |
+
device=device,
|
| 259 |
+
rng_seed=rng_seed if g == 0 else None,
|
| 260 |
+
unused_tokens_prob=unused_tokens_prob,
|
| 261 |
+
unused_experts_prob=unused_experts_prob,
|
| 262 |
+
)
|
| 263 |
+
for g in range(
|
| 264 |
+
num_group_sizes if group_sizes_0 is None else num_group_sizes - 1
|
| 265 |
+
)
|
| 266 |
+
]
|
| 267 |
+
if group_sizes_0 is not None:
|
| 268 |
+
multiple_group_sizes.insert(0, group_sizes_0)
|
| 269 |
+
assert (
|
| 270 |
+
len(multiple_group_sizes) == num_group_sizes
|
| 271 |
+
), f"Expecting {num_group_sizes} distinct group sizes (it's {len(multiple_group_sizes)})."
|
| 272 |
+
return multiple_group_sizes
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# GMM helpers: tensor generation.
|
| 276 |
+
# ------------------------------------------------------------------------------
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def gen_gmm_input(
|
| 280 |
+
M: int,
|
| 281 |
+
K: int,
|
| 282 |
+
N: int,
|
| 283 |
+
G: int,
|
| 284 |
+
device: torch.device | str = DEVICE,
|
| 285 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 286 |
+
trans_rhs: bool = TRANS_RHS,
|
| 287 |
+
rng_seed: int | None = RNG_SEED,
|
| 288 |
+
unif_group_sizes: bool = False,
|
| 289 |
+
) -> tuple[Tensor, Tensor, Tensor]:
|
| 290 |
+
assert M > 0, f"Number of lhs rows M must be positive (M = {M})."
|
| 291 |
+
assert K > 0, f"Number of lhs columns / rhs rows K must be positive (K = {K})."
|
| 292 |
+
assert N > 0, f"Number of rhs columns N must be positive (N = {N})."
|
| 293 |
+
assert G > 0, f"Number of groups G must be positive (G = {G})."
|
| 294 |
+
|
| 295 |
+
if rng_seed is not None:
|
| 296 |
+
torch.manual_seed(rng_seed)
|
| 297 |
+
|
| 298 |
+
lhs = torch.randn((M, K), dtype=torch.float32, device=device)
|
| 299 |
+
lhs = lhs.to(preferred_element_type)
|
| 300 |
+
|
| 301 |
+
if trans_rhs:
|
| 302 |
+
rhs = torch.randn((G, N, K), dtype=torch.float32, device=device).permute(
|
| 303 |
+
0, 2, 1
|
| 304 |
+
)
|
| 305 |
+
else:
|
| 306 |
+
rhs = torch.randn((G, K, N), dtype=torch.float32, device=device)
|
| 307 |
+
rhs = rhs.to(preferred_element_type)
|
| 308 |
+
|
| 309 |
+
group_sizes = (
|
| 310 |
+
gen_uniform_group_sizes(M, G, device=device)
|
| 311 |
+
if unif_group_sizes
|
| 312 |
+
else gen_group_sizes(M, G, device=device, rng_seed=None)
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
return lhs, rhs, group_sizes
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def gen_gmm_output(
|
| 319 |
+
M: int,
|
| 320 |
+
N: int,
|
| 321 |
+
device: torch.device | str = DEVICE,
|
| 322 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 323 |
+
) -> Tensor:
|
| 324 |
+
assert M > 0, f"Number of out rows M must be positive (M = {M})."
|
| 325 |
+
assert N > 0, f"Number of out columns N must be positive (N = {N})."
|
| 326 |
+
|
| 327 |
+
out = torch.empty((M, N), dtype=preferred_element_type, device=device)
|
| 328 |
+
|
| 329 |
+
return out
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def gen_gmm_tensors(
|
| 333 |
+
M: int,
|
| 334 |
+
K: int,
|
| 335 |
+
N: int,
|
| 336 |
+
G: int,
|
| 337 |
+
num_group_sizes: int,
|
| 338 |
+
device: torch.device | str = DEVICE,
|
| 339 |
+
input_type: torch.dtype = DTYPE,
|
| 340 |
+
output_type: torch.dtype = DTYPE,
|
| 341 |
+
trans_lhs: bool = False,
|
| 342 |
+
trans_rhs: bool = TRANS_RHS,
|
| 343 |
+
rng_seed: int | None = RNG_SEED,
|
| 344 |
+
unif_group_sizes: bool = False,
|
| 345 |
+
use_bias: bool = False,
|
| 346 |
+
) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]:
|
| 347 |
+
lhs, rhs, group_sizes_0 = gen_gmm_input(
|
| 348 |
+
M,
|
| 349 |
+
K,
|
| 350 |
+
N,
|
| 351 |
+
G,
|
| 352 |
+
device=device,
|
| 353 |
+
preferred_element_type=input_type,
|
| 354 |
+
trans_rhs=trans_rhs,
|
| 355 |
+
rng_seed=rng_seed,
|
| 356 |
+
unif_group_sizes=unif_group_sizes,
|
| 357 |
+
)
|
| 358 |
+
multiple_group_sizes = gen_multiple_group_sizes(
|
| 359 |
+
num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0
|
| 360 |
+
)
|
| 361 |
+
out = gen_gmm_output(M, N, device=device, preferred_element_type=output_type)
|
| 362 |
+
bias = None
|
| 363 |
+
if use_bias:
|
| 364 |
+
torch.manual_seed(rng_seed + 1000) # Different seed for bias
|
| 365 |
+
bias = torch.randn(G, N, dtype=input_type, device=device)
|
| 366 |
+
|
| 367 |
+
return lhs, rhs, multiple_group_sizes, out, bias
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
# GMM helpers: get information from tensors.
|
| 371 |
+
# ------------------------------------------------------------------------------
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def get_gmm_shape(
|
| 375 |
+
lhs: Tensor, rhs: Tensor, group_sizes: Tensor
|
| 376 |
+
) -> tuple[int, int, int, int]:
|
| 377 |
+
assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
|
| 378 |
+
assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})."
|
| 379 |
+
assert (
|
| 380 |
+
group_sizes.dim() == 1
|
| 381 |
+
), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})."
|
| 382 |
+
|
| 383 |
+
M, lhs_k = lhs.shape
|
| 384 |
+
rhs_g, rhs_k, N = rhs.shape
|
| 385 |
+
group_sizes_g = group_sizes.shape[0]
|
| 386 |
+
|
| 387 |
+
assert (
|
| 388 |
+
lhs_k == rhs_k
|
| 389 |
+
), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})."
|
| 390 |
+
K = lhs_k
|
| 391 |
+
assert (
|
| 392 |
+
rhs_g == group_sizes_g
|
| 393 |
+
), f"G dimension of rhs and group_sizes don't match (rhs = {rhs_g}, group_sizes = {group_sizes_g})."
|
| 394 |
+
G = rhs_g
|
| 395 |
+
|
| 396 |
+
assert M > 0, f"M must be positive, it's {M}."
|
| 397 |
+
assert K > 0, f"K must be positive, it's {K}."
|
| 398 |
+
assert N > 0, f"N must be positive, it's {N}"
|
| 399 |
+
assert G > 0, f"G must be positive, it's {G}"
|
| 400 |
+
|
| 401 |
+
return M, K, N, G
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def get_gmm_output(
|
| 405 |
+
M: int,
|
| 406 |
+
N: int,
|
| 407 |
+
device: torch.device | str = DEVICE,
|
| 408 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 409 |
+
existing_out: Tensor | None = None,
|
| 410 |
+
) -> Tensor:
|
| 411 |
+
assert M > 0, f"Number of out rows M must be positive (M = {M})."
|
| 412 |
+
assert N > 0, f"Number of out columns N must be positive (N = {N})."
|
| 413 |
+
|
| 414 |
+
if existing_out is not None:
|
| 415 |
+
assert (
|
| 416 |
+
existing_out.device == device
|
| 417 |
+
), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})."
|
| 418 |
+
assert (
|
| 419 |
+
existing_out.dtype == preferred_element_type
|
| 420 |
+
), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})."
|
| 421 |
+
assert existing_out.shape == (
|
| 422 |
+
M,
|
| 423 |
+
N,
|
| 424 |
+
), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(M, N)})."
|
| 425 |
+
return existing_out
|
| 426 |
+
|
| 427 |
+
return gen_gmm_output(
|
| 428 |
+
M,
|
| 429 |
+
N,
|
| 430 |
+
device=device,
|
| 431 |
+
preferred_element_type=preferred_element_type,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def get_gmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]:
|
| 436 |
+
assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
|
| 437 |
+
assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})."
|
| 438 |
+
assert out.dim() == 2, f"out must have 2 dimensions (it's {out.dim()})."
|
| 439 |
+
|
| 440 |
+
lhs_m, lhs_k = lhs.shape
|
| 441 |
+
G, rhs_k, rhs_n = rhs.shape
|
| 442 |
+
out_m, out_n = out.shape
|
| 443 |
+
|
| 444 |
+
assert (
|
| 445 |
+
lhs_m == out_m
|
| 446 |
+
), f"M dimension of lhs and out don't match (lhs = {lhs_m}, rhs = {out_m})."
|
| 447 |
+
M = lhs_m
|
| 448 |
+
assert (
|
| 449 |
+
lhs_k == rhs_k
|
| 450 |
+
), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})."
|
| 451 |
+
K = lhs_k
|
| 452 |
+
assert (
|
| 453 |
+
rhs_n == out_n
|
| 454 |
+
), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})."
|
| 455 |
+
N = rhs_n
|
| 456 |
+
|
| 457 |
+
assert M > 0, f"M must be positive, it's {M}."
|
| 458 |
+
assert K > 0, f"K must be positive, it's {K}."
|
| 459 |
+
assert N > 0, f"N must be positive, it's {N}"
|
| 460 |
+
assert G > 0, f"G must be positive, it's {G}"
|
| 461 |
+
|
| 462 |
+
is_lhs_row_major = lhs.stride() == (K, 1)
|
| 463 |
+
assert is_lhs_row_major, "lhs must be row-major."
|
| 464 |
+
is_rhs_row_major = rhs.stride() == (K * N, N, 1)
|
| 465 |
+
is_rhs_col_major = rhs.stride() == (K * N, 1, K)
|
| 466 |
+
assert (
|
| 467 |
+
is_rhs_row_major != is_rhs_col_major
|
| 468 |
+
), "rhs must be row-major or column-major."
|
| 469 |
+
is_out_row_major = out.stride() == (N, 1)
|
| 470 |
+
assert is_out_row_major, "out must be row-major."
|
| 471 |
+
|
| 472 |
+
# Get rhs leading dimension according to transposition configuration.
|
| 473 |
+
ld_rhs = N if is_rhs_row_major else K
|
| 474 |
+
|
| 475 |
+
return is_rhs_col_major, ld_rhs
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
# TGMM helpers: tensor generation.
|
| 479 |
+
# ------------------------------------------------------------------------------
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def gen_tgmm_input(
|
| 483 |
+
M: int,
|
| 484 |
+
K: int,
|
| 485 |
+
N: int,
|
| 486 |
+
G: int,
|
| 487 |
+
device: torch.device | str = DEVICE,
|
| 488 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 489 |
+
trans_lhs: bool = TRANS_LHS,
|
| 490 |
+
rng_seed: int | None = RNG_SEED,
|
| 491 |
+
unif_group_sizes: bool = False,
|
| 492 |
+
) -> tuple[Tensor, Tensor, Tensor]:
|
| 493 |
+
assert K > 0, f"Number of lhs rows K must be positive (M = {K})."
|
| 494 |
+
assert M > 0, f"Number of lhs columns / rhs rows M must be positive (K = {M})."
|
| 495 |
+
assert N > 0, f"Number of rhs columns N must be positive (N = {N})."
|
| 496 |
+
assert G > 0, f"Number of groups G must be positive (G = {G})."
|
| 497 |
+
|
| 498 |
+
if rng_seed is not None:
|
| 499 |
+
torch.manual_seed(rng_seed)
|
| 500 |
+
|
| 501 |
+
if trans_lhs:
|
| 502 |
+
lhs = torch.randn((M, K), dtype=torch.float32, device=device).T
|
| 503 |
+
else:
|
| 504 |
+
lhs = torch.randn((K, M), dtype=torch.float32, device=device)
|
| 505 |
+
lhs = lhs.to(preferred_element_type)
|
| 506 |
+
|
| 507 |
+
rhs = torch.randn((M, N), dtype=torch.float32, device=device)
|
| 508 |
+
rhs = rhs.to(preferred_element_type)
|
| 509 |
+
|
| 510 |
+
group_sizes = (
|
| 511 |
+
gen_uniform_group_sizes(M, G, device=device)
|
| 512 |
+
if unif_group_sizes
|
| 513 |
+
else gen_group_sizes(M, G, device=device, rng_seed=None)
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
return lhs, rhs, group_sizes
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def gen_tgmm_output(
|
| 520 |
+
K: int,
|
| 521 |
+
N: int,
|
| 522 |
+
G: int,
|
| 523 |
+
device: torch.device | str = DEVICE,
|
| 524 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 525 |
+
) -> Tensor:
|
| 526 |
+
assert K > 0, f"Number of out rows K must be positive (K = {K})."
|
| 527 |
+
assert N > 0, f"Number of out columns N must be positive (N = {N})."
|
| 528 |
+
assert G > 0, f"Number of groups G must be positive (G = {G})."
|
| 529 |
+
|
| 530 |
+
out = torch.empty((G, K, N), dtype=preferred_element_type, device=device)
|
| 531 |
+
|
| 532 |
+
return out
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
def gen_tgmm_bias_grad(
|
| 536 |
+
K: int,
|
| 537 |
+
G: int,
|
| 538 |
+
device: torch.device | str = DEVICE,
|
| 539 |
+
with_bias_grad: bool = False,
|
| 540 |
+
) -> Tensor:
|
| 541 |
+
if with_bias_grad:
|
| 542 |
+
assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})."
|
| 543 |
+
assert G > 0, f"Number of groups G must be positive (G = {G})."
|
| 544 |
+
return torch.empty((G, K), device=device, dtype=torch.float32)
|
| 545 |
+
else:
|
| 546 |
+
# Return dummy pointer when bias_grad is not needed.
|
| 547 |
+
# Must be float32 because atomic_add does not support bf16/fp16,
|
| 548 |
+
# and Triton validates the pointer dtype even in dead branches.
|
| 549 |
+
return torch.tensor([], device=device, dtype=torch.float32)
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
def gen_tgmm_tensors(
|
| 553 |
+
M: int,
|
| 554 |
+
K: int,
|
| 555 |
+
N: int,
|
| 556 |
+
G: int,
|
| 557 |
+
num_group_sizes: int,
|
| 558 |
+
device: torch.device | str = DEVICE,
|
| 559 |
+
input_type: torch.dtype = DTYPE,
|
| 560 |
+
output_type: torch.dtype = DTYPE,
|
| 561 |
+
trans_lhs: bool = TRANS_LHS,
|
| 562 |
+
trans_rhs: bool = False,
|
| 563 |
+
rng_seed: int | None = RNG_SEED,
|
| 564 |
+
unif_group_sizes: bool = False,
|
| 565 |
+
use_bias: bool = False,
|
| 566 |
+
) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]:
|
| 567 |
+
lhs, rhs, group_sizes_0 = gen_tgmm_input(
|
| 568 |
+
M,
|
| 569 |
+
K,
|
| 570 |
+
N,
|
| 571 |
+
G,
|
| 572 |
+
device=device,
|
| 573 |
+
preferred_element_type=input_type,
|
| 574 |
+
trans_lhs=trans_lhs,
|
| 575 |
+
rng_seed=rng_seed,
|
| 576 |
+
unif_group_sizes=unif_group_sizes,
|
| 577 |
+
)
|
| 578 |
+
multiple_group_sizes = gen_multiple_group_sizes(
|
| 579 |
+
num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0
|
| 580 |
+
)
|
| 581 |
+
out = gen_tgmm_output(K, N, G, device=device, preferred_element_type=output_type)
|
| 582 |
+
if use_bias:
|
| 583 |
+
bias_grad = gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=True)
|
| 584 |
+
else:
|
| 585 |
+
bias_grad = None
|
| 586 |
+
return lhs, rhs, multiple_group_sizes, out, bias_grad
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
# TGMM helpers: get information from tensors.
|
| 590 |
+
# ------------------------------------------------------------------------------
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
def get_tgmm_shape(
|
| 594 |
+
lhs: Tensor, rhs: Tensor, group_sizes: Tensor
|
| 595 |
+
) -> tuple[int, int, int, int]:
|
| 596 |
+
assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
|
| 597 |
+
assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})."
|
| 598 |
+
assert (
|
| 599 |
+
group_sizes.dim() == 1
|
| 600 |
+
), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})."
|
| 601 |
+
|
| 602 |
+
K, lhs_m = lhs.shape
|
| 603 |
+
rhs_m, N = rhs.shape
|
| 604 |
+
G = group_sizes.shape[0]
|
| 605 |
+
|
| 606 |
+
assert (
|
| 607 |
+
lhs_m == rhs_m
|
| 608 |
+
), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})."
|
| 609 |
+
M = lhs_m
|
| 610 |
+
|
| 611 |
+
assert M > 0, f"M must be positive, it's {M}."
|
| 612 |
+
assert K > 0, f"K must be positive, it's {K}."
|
| 613 |
+
assert N > 0, f"N must be positive, it's {N}"
|
| 614 |
+
assert G > 0, f"G must be positive, it's {G}"
|
| 615 |
+
|
| 616 |
+
return M, K, N, G
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
def get_tgmm_output(
|
| 620 |
+
K: int,
|
| 621 |
+
N: int,
|
| 622 |
+
G: int,
|
| 623 |
+
device: torch.device | str = DEVICE,
|
| 624 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 625 |
+
existing_out: Tensor | None = None,
|
| 626 |
+
) -> Tensor:
|
| 627 |
+
assert K > 0, f"Number of out rows K must be positive (K = {K})."
|
| 628 |
+
assert N > 0, f"Number of out columns N must be positive (N = {N})."
|
| 629 |
+
assert G > 0, f"Number of groups G must be positive (G = {G})."
|
| 630 |
+
|
| 631 |
+
if existing_out is not None:
|
| 632 |
+
assert (
|
| 633 |
+
existing_out.device == device
|
| 634 |
+
), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})."
|
| 635 |
+
assert (
|
| 636 |
+
existing_out.dtype == preferred_element_type
|
| 637 |
+
), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})."
|
| 638 |
+
assert existing_out.shape == (
|
| 639 |
+
G,
|
| 640 |
+
K,
|
| 641 |
+
N,
|
| 642 |
+
), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(G, K, N)})."
|
| 643 |
+
return existing_out
|
| 644 |
+
|
| 645 |
+
return gen_tgmm_output(
|
| 646 |
+
K,
|
| 647 |
+
N,
|
| 648 |
+
G,
|
| 649 |
+
device=device,
|
| 650 |
+
preferred_element_type=preferred_element_type,
|
| 651 |
+
)
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
def get_tgmm_bias_grad(
|
| 655 |
+
K: int,
|
| 656 |
+
G: int,
|
| 657 |
+
device: torch.device | str = DEVICE,
|
| 658 |
+
existing_bias_grad: Tensor | None = None,
|
| 659 |
+
) -> Tensor:
|
| 660 |
+
"""
|
| 661 |
+
Get or validate bias gradient tensor for TGMM.
|
| 662 |
+
|
| 663 |
+
If existing_bias_grad is provided, validates its shape, device, dtype, and stride,
|
| 664 |
+
and always zeros it before returning (since the kernel uses atomic_add).
|
| 665 |
+
If existing_bias_grad is None, returns a dummy tensor (for use when COMPUTE_BIAS_GRAD=False).
|
| 666 |
+
Parameters
|
| 667 |
+
----------
|
| 668 |
+
K : int
|
| 669 |
+
Number of rows in the bias gradient tensor.
|
| 670 |
+
G : int
|
| 671 |
+
Number of groups.
|
| 672 |
+
device : torch.device or str
|
| 673 |
+
Device for the tensor.
|
| 674 |
+
existing_bias_grad : torch.Tensor or None
|
| 675 |
+
Existing bias gradient tensor to validate and use.
|
| 676 |
+
Returns
|
| 677 |
+
-------
|
| 678 |
+
torch.Tensor
|
| 679 |
+
Valid bias gradient tensor or dummy tensor.
|
| 680 |
+
"""
|
| 681 |
+
assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})."
|
| 682 |
+
assert G > 0, f"Number of groups G must be positive (G = {G})."
|
| 683 |
+
|
| 684 |
+
if existing_bias_grad is not None:
|
| 685 |
+
# Validate existing bias_grad tensor.
|
| 686 |
+
expected_shape = (G, K)
|
| 687 |
+
assert (
|
| 688 |
+
tuple(existing_bias_grad.shape) == expected_shape
|
| 689 |
+
), f"bias_grad must have shape {expected_shape}, got {tuple(existing_bias_grad.shape)}."
|
| 690 |
+
assert (
|
| 691 |
+
existing_bias_grad.device == device
|
| 692 |
+
), f"bias_grad must be on the same device (bias_grad = {existing_bias_grad.device}, device = {device})."
|
| 693 |
+
assert (
|
| 694 |
+
existing_bias_grad.dtype == torch.float32
|
| 695 |
+
), f"bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), got {existing_bias_grad.dtype}."
|
| 696 |
+
assert existing_bias_grad.stride() == (
|
| 697 |
+
K,
|
| 698 |
+
1,
|
| 699 |
+
), f"bias_grad must be row-major with stride (K, 1) = ({K}, 1), got {existing_bias_grad.stride()}."
|
| 700 |
+
|
| 701 |
+
# Always zero the tensor since bias_grad represents gradients for the current
|
| 702 |
+
# computation and should start fresh. The kernel uses atomic_add which adds to
|
| 703 |
+
# existing values, so we must zero before the kernel runs.
|
| 704 |
+
existing_bias_grad.zero_()
|
| 705 |
+
|
| 706 |
+
return existing_bias_grad
|
| 707 |
+
|
| 708 |
+
else:
|
| 709 |
+
return gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=False)
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
def get_tgmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]:
|
| 713 |
+
assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
|
| 714 |
+
assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})."
|
| 715 |
+
assert out.dim() == 3, f"out must have 3 dimensions (it's {out.dim()})."
|
| 716 |
+
|
| 717 |
+
lhs_k, lhs_m = lhs.shape
|
| 718 |
+
rhs_m, rhs_n = rhs.shape
|
| 719 |
+
G, out_k, out_n = out.shape
|
| 720 |
+
|
| 721 |
+
assert (
|
| 722 |
+
lhs_m == rhs_m
|
| 723 |
+
), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})."
|
| 724 |
+
M = lhs_m
|
| 725 |
+
assert (
|
| 726 |
+
lhs_k == out_k
|
| 727 |
+
), f"K dimension of lhs and out don't match (lhs = {lhs_k}, rhs = {out_k})."
|
| 728 |
+
K = lhs_k
|
| 729 |
+
assert (
|
| 730 |
+
rhs_n == out_n
|
| 731 |
+
), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})."
|
| 732 |
+
N = rhs_n
|
| 733 |
+
|
| 734 |
+
assert M > 0, f"M must be positive, it's {M}."
|
| 735 |
+
assert K > 0, f"K must be positive, it's {K}."
|
| 736 |
+
assert N > 0, f"N must be positive, it's {N}"
|
| 737 |
+
assert G > 0, f"G must be positive, it's {G}"
|
| 738 |
+
|
| 739 |
+
is_lhs_row_major = lhs.stride() == (M, 1)
|
| 740 |
+
is_lhs_col_major = lhs.stride() == (1, K)
|
| 741 |
+
assert (
|
| 742 |
+
is_lhs_row_major != is_lhs_col_major
|
| 743 |
+
), "lhs must be row-major or column-major."
|
| 744 |
+
is_rhs_row_major = rhs.stride() == (N, 1)
|
| 745 |
+
assert is_rhs_row_major, "rhs must be row-major."
|
| 746 |
+
is_out_row_major = out.stride() == (K * N, N, 1)
|
| 747 |
+
assert is_out_row_major, "out must be row-major."
|
| 748 |
+
|
| 749 |
+
# Get lhs leading dimension according to transposition configuration.
|
| 750 |
+
ld_lhs = M if is_lhs_row_major else K
|
| 751 |
+
|
| 752 |
+
return is_lhs_col_major, ld_lhs
|
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/logger.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# AITER Triton Logger which is singleton object around python logging.
|
| 6 |
+
# Note: Python logging is also a singleton object, but we want to read the
|
| 7 |
+
# env var AITER_LOG_LEVEL once at the beginning. Another alternative is to do
|
| 8 |
+
# this in __init__.py. In fact, that's how CK logger is setup. We can look at
|
| 9 |
+
# switching to that at some point
|
| 10 |
+
#
|
| 11 |
+
# AITER_LOG_LEVEL follows python logging levels
|
| 12 |
+
# DEBUG
|
| 13 |
+
# INFO
|
| 14 |
+
# WARNING
|
| 15 |
+
# ERROR
|
| 16 |
+
# CRITICAL
|
| 17 |
+
#
|
| 18 |
+
class AiterTritonLogger(object):
|
| 19 |
+
_instance = None
|
| 20 |
+
|
| 21 |
+
def __new__(cls):
|
| 22 |
+
if cls._instance is None:
|
| 23 |
+
cls._instance = super(AiterTritonLogger, cls).__new__(cls)
|
| 24 |
+
log_level_str = os.getenv("AITER_TRITON_LOG_LEVEL", "WARNING").upper()
|
| 25 |
+
numeric_level = getattr(logging, log_level_str, logging.WARNING)
|
| 26 |
+
cls._instance._logger = logging.getLogger("AITER_TRITON")
|
| 27 |
+
cls._instance._logger.setLevel(numeric_level)
|
| 28 |
+
|
| 29 |
+
return cls._instance
|
| 30 |
+
|
| 31 |
+
def get_logger(self):
|
| 32 |
+
return self._logger
|
| 33 |
+
|
| 34 |
+
def debug(self, msg):
|
| 35 |
+
self._logger.debug(msg)
|
| 36 |
+
|
| 37 |
+
def info(self, msg):
|
| 38 |
+
self._logger.info(msg)
|
| 39 |
+
|
| 40 |
+
def warning(self, msg):
|
| 41 |
+
self._logger.warning(msg)
|
| 42 |
+
|
| 43 |
+
def error(self, msg):
|
| 44 |
+
self._logger.error(msg)
|
| 45 |
+
|
| 46 |
+
def critical(self, msg):
|
| 47 |
+
self._logger.critical(msg)
|
build/torch211-cxx11-cu126-x86_64-linux/{_megablocks_cuda_ae601bb.abi3.so → _megablocks_cuda_f8f8b50.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:78a283cd033d5770287d652455033307d26b1896681abbeb5ed4d1cba4dbc1fe
|
| 3 |
+
size 13822768
|
build/torch211-cxx11-cu126-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_cuda_f8f8b50
|
| 3 |
+
ops = torch.ops._megablocks_cuda_f8f8b50
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_cuda_f8f8b50::{op_name}"
|
build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py
CHANGED
|
@@ -2,16 +2,16 @@
|
|
| 2 |
# extensions. Otherwise libc10.so cannot be found.
|
| 3 |
import torch
|
| 4 |
|
| 5 |
-
#
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
# import grouped_gemm_backend as backend
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
#
|
| 14 |
-
|
|
|
|
| 15 |
|
| 16 |
def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
|
| 17 |
assert not (trans_a and trans_b)
|
|
|
|
| 2 |
# extensions. Otherwise libc10.so cannot be found.
|
| 3 |
import torch
|
| 4 |
|
| 5 |
+
# On ROCm there is no CUTLASS grouped GEMM; dispatch to the vendored AITER
|
| 6 |
+
# Triton kernels instead. On CUDA we use the compiled CUTLASS `gmm` op.
|
| 7 |
+
_IS_ROCM = torch.version.hip is not None
|
|
|
|
| 8 |
|
| 9 |
+
if _IS_ROCM:
|
| 10 |
+
from .._grouped_gemm_triton import adapter as backend
|
| 11 |
+
else:
|
| 12 |
+
# We import the backend operations from the megablocks package as
|
| 13 |
+
# grouped_gemm is vendored in megablocks in this repository.
|
| 14 |
+
from .._ops import ops as backend # type: ignore
|
| 15 |
|
| 16 |
def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
|
| 17 |
assert not (trans_a and trans_b)
|
build/torch211-cxx11-cu126-x86_64-linux/metadata.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
{
|
| 2 |
"name": "megablocks",
|
| 3 |
-
"id": "
|
| 4 |
"version": 1,
|
| 5 |
"license": "Apache-2.0",
|
| 6 |
"python-depends": [],
|
|
@@ -14,7 +14,8 @@
|
|
| 14 |
"8.6",
|
| 15 |
"8.7",
|
| 16 |
"8.9",
|
| 17 |
-
"9.0"
|
|
|
|
| 18 |
]
|
| 19 |
}
|
| 20 |
}
|
|
|
|
| 1 |
{
|
| 2 |
"name": "megablocks",
|
| 3 |
+
"id": "_megablocks_cuda_f8f8b50",
|
| 4 |
"version": 1,
|
| 5 |
"license": "Apache-2.0",
|
| 6 |
"python-depends": [],
|
|
|
|
| 14 |
"8.6",
|
| 15 |
"8.7",
|
| 16 |
"8.9",
|
| 17 |
+
"9.0",
|
| 18 |
+
"9.0+PTX"
|
| 19 |
]
|
| 20 |
}
|
| 21 |
}
|
build/torch211-cxx11-cu128-x86_64-linux/__init__.py
CHANGED
|
@@ -3,7 +3,9 @@
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
|
| 6 |
-
from .
|
|
|
|
|
|
|
| 7 |
|
| 8 |
from .grouped_gemm import backend as gg_backend
|
| 9 |
from .grouped_gemm import ops as gg_ops
|
|
@@ -136,7 +138,8 @@ def sort(
|
|
| 136 |
Returns:
|
| 137 |
The sorted values tensor
|
| 138 |
"""
|
| 139 |
-
|
|
|
|
| 140 |
|
| 141 |
|
| 142 |
# Convenience functions for common use cases
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
|
| 6 |
+
# Stable alias: bare `ops` is shadowed by `from . import layers` below.
|
| 7 |
+
from ._ops import ops as _compiled_ops
|
| 8 |
+
from . import ops
|
| 9 |
|
| 10 |
from .grouped_gemm import backend as gg_backend
|
| 11 |
from .grouped_gemm import ops as gg_ops
|
|
|
|
| 138 |
Returns:
|
| 139 |
The sorted values tensor
|
| 140 |
"""
|
| 141 |
+
_compiled_ops.sort(x, end_bit, x_out, iota_out)
|
| 142 |
+
return x_out
|
| 143 |
|
| 144 |
|
| 145 |
# Convenience functions for common use cases
|
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/__init__.py
ADDED
|
File without changes
|
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py
ADDED
|
File without changes
|
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py
ADDED
|
@@ -0,0 +1,574 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
# Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Imports.
|
| 6 |
+
# ------------------------------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
# Python standard library
|
| 9 |
+
import functools
|
| 10 |
+
|
| 11 |
+
# Triton
|
| 12 |
+
import triton
|
| 13 |
+
import triton.language as tl
|
| 14 |
+
|
| 15 |
+
# AITER
|
| 16 |
+
from ..configs import CONFIGS as _CONFIGS
|
| 17 |
+
from ..utils._triton import arch_info
|
| 18 |
+
from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
|
| 19 |
+
|
| 20 |
+
# Kernel config.
|
| 21 |
+
# ------------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@functools.lru_cache()
|
| 25 |
+
def get_config(
|
| 26 |
+
gmm_type: str, M: int, K: int, N: int, G: int, accumulate: bool = False
|
| 27 |
+
) -> dict[str, int]:
|
| 28 |
+
assert gmm_type in {
|
| 29 |
+
"gmm",
|
| 30 |
+
"ptgmm",
|
| 31 |
+
"nptgmm",
|
| 32 |
+
}, f"'{gmm_type}' is an invalid GMM variant."
|
| 33 |
+
dev = arch_info.get_arch()
|
| 34 |
+
assert (
|
| 35 |
+
dev in _CONFIGS
|
| 36 |
+
), f"No GMM configuration tuned for arch '{dev}'. Supported: {sorted(_CONFIGS)}."
|
| 37 |
+
arch_configs = _CONFIGS[dev]
|
| 38 |
+
assert (
|
| 39 |
+
"default" in arch_configs[gmm_type]
|
| 40 |
+
), "Default configuration is absent."
|
| 41 |
+
key = "accumulate" if accumulate else "default"
|
| 42 |
+
return arch_configs[gmm_type][key]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Common code shared by GMM and TGMM kernels.
|
| 46 |
+
# ------------------------------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# XCD remapping followed by 1D PID to 2D grid mapping.
|
| 50 |
+
@triton.jit
|
| 51 |
+
def _remap_xcd_tile_grid(
|
| 52 |
+
tile_in_mm,
|
| 53 |
+
num_row_tiles,
|
| 54 |
+
num_col_tiles,
|
| 55 |
+
GROUP_SIZE: tl.constexpr = 1,
|
| 56 |
+
NUM_XCDS: tl.constexpr = 8,
|
| 57 |
+
):
|
| 58 |
+
return pid_grid(
|
| 59 |
+
remap_xcd(tile_in_mm, num_row_tiles * num_col_tiles, NUM_XCDS=NUM_XCDS),
|
| 60 |
+
num_row_tiles,
|
| 61 |
+
num_col_tiles,
|
| 62 |
+
GROUP_SIZE_M=GROUP_SIZE,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# GMM kernel.
|
| 67 |
+
# ------------------------------------------------------------------------------
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@triton.heuristics(
|
| 71 |
+
{
|
| 72 |
+
"K_DIVISIBLE_BY_BLOCK_SIZE_K": lambda META: META["K"] % META["BLOCK_SIZE_K"]
|
| 73 |
+
== 0,
|
| 74 |
+
}
|
| 75 |
+
)
|
| 76 |
+
@triton.jit
|
| 77 |
+
def gmm_kernel(
|
| 78 |
+
# Tensor pointers:
|
| 79 |
+
lhs_ptr,
|
| 80 |
+
rhs_ptr,
|
| 81 |
+
group_sizes_ptr,
|
| 82 |
+
out_ptr,
|
| 83 |
+
bias_ptr,
|
| 84 |
+
# Tensor shapes:
|
| 85 |
+
M: int,
|
| 86 |
+
K: int,
|
| 87 |
+
N: int,
|
| 88 |
+
G: int,
|
| 89 |
+
# Meta-parameters:
|
| 90 |
+
TRANS_RHS: tl.constexpr,
|
| 91 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 92 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 93 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 94 |
+
K_DIVISIBLE_BY_BLOCK_SIZE_K: tl.constexpr,
|
| 95 |
+
GROUP_SIZE: tl.constexpr,
|
| 96 |
+
GRID_DIM: tl.constexpr,
|
| 97 |
+
USE_BIAS: tl.constexpr,
|
| 98 |
+
):
|
| 99 |
+
tl.assume(M > 0)
|
| 100 |
+
tl.assume(K > 0)
|
| 101 |
+
tl.assume(N > 0)
|
| 102 |
+
tl.assume(G > 0)
|
| 103 |
+
|
| 104 |
+
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
|
| 105 |
+
tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0")
|
| 106 |
+
|
| 107 |
+
# Current tile. Each program computes multiple tiles of each group.
|
| 108 |
+
tile = tl.program_id(0)
|
| 109 |
+
tl.device_assert(tile >= 0, "tile < 0 (at initialization)")
|
| 110 |
+
|
| 111 |
+
# Tile limit of last MM problem (inclusive).
|
| 112 |
+
last_mm_tile = 0
|
| 113 |
+
|
| 114 |
+
# Last input row of lhs and output row of out. Each group reads some rows of
|
| 115 |
+
# lhs and writes some rows to out.
|
| 116 |
+
last_m = 0
|
| 117 |
+
|
| 118 |
+
# Loop through all (m, K, N) MM problems:
|
| 119 |
+
# (m, K) x (K, N) = (m, N)
|
| 120 |
+
# sum(m) = M
|
| 121 |
+
for g in range(G):
|
| 122 |
+
# Get m dimension of current MM problem.
|
| 123 |
+
m = tl.load(group_sizes_ptr + g)
|
| 124 |
+
# m can be zero if group is empty
|
| 125 |
+
tl.device_assert(m >= 0, "m < 0")
|
| 126 |
+
|
| 127 |
+
num_m_tiles = tl.cdiv(m, BLOCK_SIZE_M)
|
| 128 |
+
# num_m_tiles can be zero if group is empty
|
| 129 |
+
tl.device_assert(num_m_tiles >= 0, "num_m_tiles < 0")
|
| 130 |
+
|
| 131 |
+
num_tiles = num_m_tiles * num_n_tiles
|
| 132 |
+
# num_tiles can be zero if group is empty
|
| 133 |
+
tl.device_assert(num_tiles >= 0, "num_tiles < 0")
|
| 134 |
+
|
| 135 |
+
# Loop through tiles of current MM problem.
|
| 136 |
+
while tile >= last_mm_tile and tile < last_mm_tile + num_tiles:
|
| 137 |
+
# Figure out tile coordinates in current MM problem.
|
| 138 |
+
tile_in_mm = tile - last_mm_tile
|
| 139 |
+
tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0")
|
| 140 |
+
|
| 141 |
+
tile_m, tile_n = _remap_xcd_tile_grid(
|
| 142 |
+
tile_in_mm, num_m_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# Do regular MM:
|
| 146 |
+
|
| 147 |
+
tl.device_assert(tile_m * BLOCK_SIZE_M >= 0, "tile_m * BLOCK_SIZE_M < 0")
|
| 148 |
+
tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0")
|
| 149 |
+
|
| 150 |
+
offs_lhs_m = (
|
| 151 |
+
tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 152 |
+
) % m
|
| 153 |
+
offs_rhs_n = (
|
| 154 |
+
tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 155 |
+
) % N
|
| 156 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
|
| 157 |
+
|
| 158 |
+
lhs_ptrs = lhs_ptr + (last_m + offs_lhs_m[:, None]) * K + offs_k[None, :]
|
| 159 |
+
|
| 160 |
+
if TRANS_RHS:
|
| 161 |
+
rhs_ptrs = (
|
| 162 |
+
rhs_ptr
|
| 163 |
+
+ g.to(tl.int64) * K * N
|
| 164 |
+
+ offs_k[:, None]
|
| 165 |
+
+ offs_rhs_n[None, :] * K
|
| 166 |
+
)
|
| 167 |
+
else:
|
| 168 |
+
rhs_ptrs = (
|
| 169 |
+
rhs_ptr
|
| 170 |
+
+ g.to(tl.int64) * K * N
|
| 171 |
+
+ offs_k[:, None] * N
|
| 172 |
+
+ offs_rhs_n[None, :]
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 176 |
+
|
| 177 |
+
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
| 178 |
+
if K_DIVISIBLE_BY_BLOCK_SIZE_K:
|
| 179 |
+
lhs = tl.load(lhs_ptrs)
|
| 180 |
+
rhs = tl.load(rhs_ptrs)
|
| 181 |
+
else:
|
| 182 |
+
k_mask_limit = K - k * BLOCK_SIZE_K
|
| 183 |
+
lhs = tl.load(
|
| 184 |
+
lhs_ptrs, mask=offs_k[None, :] < k_mask_limit, other=0
|
| 185 |
+
)
|
| 186 |
+
rhs = tl.load(
|
| 187 |
+
rhs_ptrs, mask=offs_k[:, None] < k_mask_limit, other=0
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
acc = tl.dot(lhs, rhs, acc=acc)
|
| 191 |
+
|
| 192 |
+
lhs_ptrs += BLOCK_SIZE_K
|
| 193 |
+
|
| 194 |
+
if TRANS_RHS:
|
| 195 |
+
rhs_ptrs += BLOCK_SIZE_K
|
| 196 |
+
else:
|
| 197 |
+
rhs_ptrs += BLOCK_SIZE_K * N
|
| 198 |
+
|
| 199 |
+
# Add bias if enabled
|
| 200 |
+
if USE_BIAS:
|
| 201 |
+
offs_bias_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(
|
| 202 |
+
0, BLOCK_SIZE_N
|
| 203 |
+
)
|
| 204 |
+
bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bias_n
|
| 205 |
+
bias = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0)
|
| 206 |
+
# Convert bias to float32 to match accumulator precision
|
| 207 |
+
bias = bias.to(tl.float32)
|
| 208 |
+
# Broadcast bias across M dimension and add in float32
|
| 209 |
+
acc += bias[None, :]
|
| 210 |
+
|
| 211 |
+
# Convert to output dtype after all computations
|
| 212 |
+
acc = acc.to(out_ptr.type.element_ty)
|
| 213 |
+
|
| 214 |
+
offs_out_m = tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 215 |
+
offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 216 |
+
|
| 217 |
+
out_ptrs = (
|
| 218 |
+
out_ptr + (last_m + offs_out_m[:, None]) * N + offs_out_n[None, :]
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
tl.store(
|
| 222 |
+
out_ptrs,
|
| 223 |
+
acc,
|
| 224 |
+
mask=(offs_out_m[:, None] < m) & (offs_out_n[None, :] < N),
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
# Go to the next tile by advancing number of programs.
|
| 228 |
+
tile += GRID_DIM
|
| 229 |
+
tl.device_assert(tile > 0, "tile <= 0 (at update)")
|
| 230 |
+
|
| 231 |
+
# Get ready to go to the next MM problem.
|
| 232 |
+
|
| 233 |
+
last_mm_tile += num_tiles
|
| 234 |
+
# last_mm_tile can be zero if group 0 is skipped
|
| 235 |
+
tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)")
|
| 236 |
+
|
| 237 |
+
last_m += m
|
| 238 |
+
# last_m can be zero if group 0 is skipped
|
| 239 |
+
tl.device_assert(last_m >= 0, "last_m < 0 (at update)")
|
| 240 |
+
tl.device_assert(last_m <= M, "last_m > M (at update)")
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
# Persistent TGMM kernel.
|
| 244 |
+
# ------------------------------------------------------------------------------
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
@triton.jit
|
| 248 |
+
def tgmm_persistent_kernel(
|
| 249 |
+
# Tensor pointers:
|
| 250 |
+
lhs_ptr,
|
| 251 |
+
rhs_ptr,
|
| 252 |
+
group_sizes_ptr,
|
| 253 |
+
out_ptr,
|
| 254 |
+
bias_grad_ptr,
|
| 255 |
+
# Tensor shapes:
|
| 256 |
+
M: int,
|
| 257 |
+
K: int,
|
| 258 |
+
N: int,
|
| 259 |
+
G: int,
|
| 260 |
+
# Meta-parameters:
|
| 261 |
+
TRANS_LHS: tl.constexpr,
|
| 262 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 263 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 264 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 265 |
+
GROUP_SIZE: tl.constexpr,
|
| 266 |
+
GRID_DIM: tl.constexpr,
|
| 267 |
+
COMPUTE_BIAS_GRAD: tl.constexpr,
|
| 268 |
+
ACCUMULATE: tl.constexpr,
|
| 269 |
+
):
|
| 270 |
+
tl.assume(M > 0)
|
| 271 |
+
tl.assume(K > 0)
|
| 272 |
+
tl.assume(N > 0)
|
| 273 |
+
tl.assume(G > 0)
|
| 274 |
+
|
| 275 |
+
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
| 276 |
+
tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0")
|
| 277 |
+
|
| 278 |
+
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
|
| 279 |
+
tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0")
|
| 280 |
+
|
| 281 |
+
num_tiles = num_k_tiles * num_n_tiles
|
| 282 |
+
tl.device_assert(num_tiles > 0, "num_tiles <= 0")
|
| 283 |
+
|
| 284 |
+
# Current tile. Each program computes multiple tiles of each group.
|
| 285 |
+
tile = tl.program_id(0)
|
| 286 |
+
tl.device_assert(tile >= 0, "tile < 0 (at initialization)")
|
| 287 |
+
|
| 288 |
+
# Tile limit of last MM problem (inclusive).
|
| 289 |
+
last_mm_tile = 0
|
| 290 |
+
|
| 291 |
+
# Last input column of lhs and input row of rhs. Each group reads some
|
| 292 |
+
# columns of lhs and some rows of rhs.
|
| 293 |
+
last_m = 0
|
| 294 |
+
|
| 295 |
+
# Loop through all (K, m, N) MM problems:
|
| 296 |
+
# (K, m) x (m, N) = (K, N)
|
| 297 |
+
# sum(m) = M
|
| 298 |
+
for g in range(G):
|
| 299 |
+
# Get m dimension of current MM problem.
|
| 300 |
+
m = tl.load(group_sizes_ptr + g)
|
| 301 |
+
# m can be zero if group is empty
|
| 302 |
+
tl.device_assert(m >= 0, "m < 0")
|
| 303 |
+
|
| 304 |
+
# Loop through tiles of current MM problem.
|
| 305 |
+
while tile >= last_mm_tile and tile < last_mm_tile + num_tiles:
|
| 306 |
+
# Figure out tile coordinates in current MM problem.
|
| 307 |
+
tile_in_mm = tile - last_mm_tile
|
| 308 |
+
tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0")
|
| 309 |
+
|
| 310 |
+
tile_k, tile_n = _remap_xcd_tile_grid(
|
| 311 |
+
tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# Do regular MM:
|
| 315 |
+
|
| 316 |
+
tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0")
|
| 317 |
+
tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0")
|
| 318 |
+
|
| 319 |
+
offs_lhs_k = (
|
| 320 |
+
tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
| 321 |
+
) % K
|
| 322 |
+
offs_rhs_n = (
|
| 323 |
+
tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 324 |
+
) % N
|
| 325 |
+
offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
| 326 |
+
|
| 327 |
+
if TRANS_LHS:
|
| 328 |
+
lhs_ptrs = (
|
| 329 |
+
lhs_ptr + offs_lhs_k[:, None] + (last_m + offs_m[None, :]) * K
|
| 330 |
+
)
|
| 331 |
+
else:
|
| 332 |
+
lhs_ptrs = (
|
| 333 |
+
lhs_ptr + offs_lhs_k[:, None] * M + (last_m + offs_m[None, :])
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
rhs_ptrs = rhs_ptr + (last_m + offs_m[:, None]) * N + offs_rhs_n[None, :]
|
| 337 |
+
|
| 338 |
+
loop_m = tl.cdiv(m, BLOCK_SIZE_M)
|
| 339 |
+
m_divisible_by_block_m = m % BLOCK_SIZE_M == 0
|
| 340 |
+
if not m_divisible_by_block_m:
|
| 341 |
+
loop_m -= 1
|
| 342 |
+
|
| 343 |
+
acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32)
|
| 344 |
+
|
| 345 |
+
# Initialize bias accumulator
|
| 346 |
+
bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)
|
| 347 |
+
|
| 348 |
+
for _ in range(0, loop_m):
|
| 349 |
+
lhs = tl.load(lhs_ptrs)
|
| 350 |
+
rhs = tl.load(rhs_ptrs)
|
| 351 |
+
|
| 352 |
+
acc = tl.dot(lhs, rhs, acc=acc)
|
| 353 |
+
|
| 354 |
+
# Accumulate for bias gradient: sum lhs across M dimension
|
| 355 |
+
if COMPUTE_BIAS_GRAD and tile_n == 0:
|
| 356 |
+
bias_acc += tl.sum(
|
| 357 |
+
lhs, axis=1
|
| 358 |
+
) # Sum across M dimension [K, M] -> [K]
|
| 359 |
+
|
| 360 |
+
if TRANS_LHS:
|
| 361 |
+
lhs_ptrs += BLOCK_SIZE_M * K
|
| 362 |
+
else:
|
| 363 |
+
lhs_ptrs += BLOCK_SIZE_M
|
| 364 |
+
|
| 365 |
+
rhs_ptrs += BLOCK_SIZE_M * N
|
| 366 |
+
|
| 367 |
+
if not m_divisible_by_block_m:
|
| 368 |
+
offs_lhs_k = (
|
| 369 |
+
tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
| 370 |
+
) % K
|
| 371 |
+
offs_rhs_n = (
|
| 372 |
+
tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 373 |
+
) % N
|
| 374 |
+
offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 375 |
+
lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0)
|
| 376 |
+
rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0)
|
| 377 |
+
acc = tl.dot(lhs, rhs, acc=acc)
|
| 378 |
+
|
| 379 |
+
# Accumulate last chunk for bias gradient
|
| 380 |
+
if COMPUTE_BIAS_GRAD and tile_n == 0:
|
| 381 |
+
bias_acc += tl.sum(lhs, axis=1)
|
| 382 |
+
|
| 383 |
+
acc = acc.to(out_ptr.type.element_ty)
|
| 384 |
+
|
| 385 |
+
offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
| 386 |
+
offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 387 |
+
|
| 388 |
+
out_ptrs = (
|
| 389 |
+
out_ptr
|
| 390 |
+
+ g.to(tl.int64) * K * N
|
| 391 |
+
+ offs_out_k[:, None] * N
|
| 392 |
+
+ offs_out_n[None, :]
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N)
|
| 396 |
+
if ACCUMULATE:
|
| 397 |
+
# Load existing values and add to them (like beta=1 in BLAS)
|
| 398 |
+
old_vals = tl.load(out_ptrs, mask=mask, other=0.0)
|
| 399 |
+
tl.store(out_ptrs, acc + old_vals, mask=mask)
|
| 400 |
+
else:
|
| 401 |
+
# Overwrite output (like beta=0 in BLAS)
|
| 402 |
+
tl.store(out_ptrs, acc, mask=mask)
|
| 403 |
+
|
| 404 |
+
# Store bias gradient (only for first N tile, sum across all M)
|
| 405 |
+
if COMPUTE_BIAS_GRAD and tile_n == 0:
|
| 406 |
+
# Keep as float32 for atomic_add (bf16 not supported for atomics)
|
| 407 |
+
bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k
|
| 408 |
+
# Use atomic add since multiple K-tiles may write to same expert's bias
|
| 409 |
+
tl.atomic_add(
|
| 410 |
+
bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed"
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
# Go to the next tile by advancing number of programs.
|
| 414 |
+
tile += GRID_DIM
|
| 415 |
+
tl.device_assert(tile > 0, "tile <= 0 (at update)")
|
| 416 |
+
|
| 417 |
+
# Get ready to go to the next MM problem.
|
| 418 |
+
|
| 419 |
+
last_mm_tile += num_tiles
|
| 420 |
+
# last_mm_tile can be zero if group 0 is skipped
|
| 421 |
+
tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)")
|
| 422 |
+
|
| 423 |
+
last_m += m
|
| 424 |
+
# last_m can be zero if group 0 is skipped
|
| 425 |
+
tl.device_assert(last_m >= 0, "last_m < 0 (at update)")
|
| 426 |
+
tl.device_assert(last_m <= M, "last_m > M (at update)")
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
# Regular non-persistent TGMM kernel.
|
| 430 |
+
# ------------------------------------------------------------------------------
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
@triton.heuristics({"BLOCK_SIZE_G": lambda META: triton.next_power_of_2(META["G"])})
|
| 434 |
+
@triton.jit
|
| 435 |
+
def tgmm_non_persistent_kernel(
|
| 436 |
+
# Tensor pointers:
|
| 437 |
+
lhs_ptr,
|
| 438 |
+
rhs_ptr,
|
| 439 |
+
group_sizes_ptr,
|
| 440 |
+
out_ptr,
|
| 441 |
+
bias_grad_ptr,
|
| 442 |
+
# Tensor shapes:
|
| 443 |
+
M: int,
|
| 444 |
+
K: int,
|
| 445 |
+
N: int,
|
| 446 |
+
G: int,
|
| 447 |
+
# Meta-parameters:
|
| 448 |
+
TRANS_LHS: tl.constexpr,
|
| 449 |
+
BLOCK_SIZE_G: tl.constexpr,
|
| 450 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 451 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 452 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 453 |
+
GROUP_SIZE: tl.constexpr,
|
| 454 |
+
COMPUTE_BIAS_GRAD: tl.constexpr,
|
| 455 |
+
ACCUMULATE: tl.constexpr,
|
| 456 |
+
):
|
| 457 |
+
tl.assume(M > 0)
|
| 458 |
+
tl.assume(K > 0)
|
| 459 |
+
tl.assume(N > 0)
|
| 460 |
+
tl.assume(G > 0)
|
| 461 |
+
|
| 462 |
+
# Get group ID from grid.
|
| 463 |
+
g = tl.program_id(0)
|
| 464 |
+
tl.device_assert(g >= 0, "g < 0")
|
| 465 |
+
tl.device_assert(g < G, "g >= G")
|
| 466 |
+
|
| 467 |
+
# Get m dimension of current MM group.
|
| 468 |
+
m = tl.load(group_sizes_ptr + g)
|
| 469 |
+
# m can be zero if group is empty.
|
| 470 |
+
tl.device_assert(m >= 0, "m < 0")
|
| 471 |
+
|
| 472 |
+
# Skip empty groups.
|
| 473 |
+
if m == 0:
|
| 474 |
+
return
|
| 475 |
+
|
| 476 |
+
# Compute sum(group_sizes) until current group g.
|
| 477 |
+
# It's the starting column of lhs and starting row of rhs.
|
| 478 |
+
offs_g = tl.arange(0, BLOCK_SIZE_G)
|
| 479 |
+
group_sizes = tl.load(group_sizes_ptr + offs_g, mask=offs_g < g, other=0)
|
| 480 |
+
start_m = tl.sum(group_sizes)
|
| 481 |
+
|
| 482 |
+
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
| 483 |
+
tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0")
|
| 484 |
+
|
| 485 |
+
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
|
| 486 |
+
tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0")
|
| 487 |
+
|
| 488 |
+
# Get MM tile from grid.
|
| 489 |
+
tile_in_mm = tl.program_id(1)
|
| 490 |
+
tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0")
|
| 491 |
+
|
| 492 |
+
tile_k, tile_n = _remap_xcd_tile_grid(
|
| 493 |
+
tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0")
|
| 497 |
+
tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0")
|
| 498 |
+
|
| 499 |
+
offs_lhs_k = (tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K
|
| 500 |
+
offs_rhs_n = (tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
| 501 |
+
offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
| 502 |
+
|
| 503 |
+
if TRANS_LHS:
|
| 504 |
+
lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] + (start_m + offs_m[None, :]) * K
|
| 505 |
+
else:
|
| 506 |
+
lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] * M + (start_m + offs_m[None, :])
|
| 507 |
+
|
| 508 |
+
rhs_ptrs = rhs_ptr + (start_m + offs_m[:, None]) * N + offs_rhs_n[None, :]
|
| 509 |
+
|
| 510 |
+
loop_m = tl.cdiv(m, BLOCK_SIZE_M)
|
| 511 |
+
m_divisible_by_block_m = m % BLOCK_SIZE_M == 0
|
| 512 |
+
if not m_divisible_by_block_m:
|
| 513 |
+
loop_m -= 1
|
| 514 |
+
|
| 515 |
+
acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32)
|
| 516 |
+
# Initialize bias accumulator
|
| 517 |
+
bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)
|
| 518 |
+
|
| 519 |
+
for _ in range(0, loop_m):
|
| 520 |
+
lhs = tl.load(lhs_ptrs)
|
| 521 |
+
rhs = tl.load(rhs_ptrs)
|
| 522 |
+
|
| 523 |
+
acc = tl.dot(lhs, rhs, acc=acc)
|
| 524 |
+
|
| 525 |
+
# Accumulate for bias gradient: sum lhs across M dimension
|
| 526 |
+
if COMPUTE_BIAS_GRAD and tile_n == 0:
|
| 527 |
+
bias_acc += tl.sum(lhs, axis=1) # [K, M] -> [K]
|
| 528 |
+
|
| 529 |
+
if TRANS_LHS:
|
| 530 |
+
lhs_ptrs += BLOCK_SIZE_M * K
|
| 531 |
+
else:
|
| 532 |
+
lhs_ptrs += BLOCK_SIZE_M
|
| 533 |
+
|
| 534 |
+
rhs_ptrs += BLOCK_SIZE_M * N
|
| 535 |
+
|
| 536 |
+
if not m_divisible_by_block_m:
|
| 537 |
+
offs_lhs_k = (
|
| 538 |
+
tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
| 539 |
+
) % K
|
| 540 |
+
offs_rhs_n = (
|
| 541 |
+
tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 542 |
+
) % N
|
| 543 |
+
offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 544 |
+
lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0)
|
| 545 |
+
rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0)
|
| 546 |
+
acc = tl.dot(lhs, rhs, acc=acc)
|
| 547 |
+
# Accumulate last chunk for bias gradient
|
| 548 |
+
if COMPUTE_BIAS_GRAD and tile_n == 0:
|
| 549 |
+
bias_acc += tl.sum(lhs, axis=1)
|
| 550 |
+
|
| 551 |
+
acc = acc.to(out_ptr.type.element_ty)
|
| 552 |
+
|
| 553 |
+
offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
| 554 |
+
offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 555 |
+
|
| 556 |
+
out_ptrs = (
|
| 557 |
+
out_ptr + g.to(tl.int64) * K * N + offs_out_k[:, None] * N + offs_out_n[None, :]
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N)
|
| 561 |
+
if ACCUMULATE:
|
| 562 |
+
# Load existing values and add to them (like beta=1 in BLAS)
|
| 563 |
+
old_vals = tl.load(out_ptrs, mask=mask, other=0.0)
|
| 564 |
+
tl.store(out_ptrs, acc + old_vals, mask=mask)
|
| 565 |
+
else:
|
| 566 |
+
# Overwrite output (like beta=0 in BLAS)
|
| 567 |
+
tl.store(out_ptrs, acc, mask=mask)
|
| 568 |
+
|
| 569 |
+
# Store bias gradient (only for first N tile, sum across all M)
|
| 570 |
+
if COMPUTE_BIAS_GRAD and tile_n == 0:
|
| 571 |
+
# Keep as float32 for atomic_add (bf16/fp16 not supported for atomics)
|
| 572 |
+
bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k
|
| 573 |
+
# Use atomic add since multiple K-tiles may write to same expert's bias
|
| 574 |
+
tl.atomic_add(bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed")
|
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/adapter.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Adapt AITER's Triton grouped GEMM to MegaBlocks' ``gmm`` calling convention.
|
| 3 |
+
|
| 4 |
+
MegaBlocks (following tgale96/grouped_gemm) uses a single ``gmm`` entry point
|
| 5 |
+
with ``trans_a`` / ``trans_b`` flags:
|
| 6 |
+
|
| 7 |
+
* ``trans_a=False, trans_b=False``: a(M,K) @ b(G,K,N) -> c(M,N)
|
| 8 |
+
* ``trans_a=False, trans_b=True`` : a(M,K) @ b(G,N,K)^T -> c(M,N) (dgrad)
|
| 9 |
+
* ``trans_a=True`` : a(M,K)^T @ b(M,N) per group -> c(G,K,N) (wgrad)
|
| 10 |
+
|
| 11 |
+
AITER exposes these as two kernels: ``gmm`` ((M,K)@(G,K,N)->(M,N), transposition
|
| 12 |
+
of the 3D operand inferred from strides) and ``ptgmm`` ((K,M)@(M,N)->(G,K,N),
|
| 13 |
+
transposition of the 2D operand inferred from strides).
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
from .gmm import gmm as _aiter_gmm
|
| 19 |
+
from .gmm import ptgmm as _aiter_ptgmm
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def gmm(a, b, c, batch_sizes, trans_a=False, trans_b=False):
|
| 23 |
+
# AITER requires group sizes to be int32 and to live on the compute device.
|
| 24 |
+
group_sizes = batch_sizes.to(device=a.device, dtype=torch.int32)
|
| 25 |
+
|
| 26 |
+
# AITER asserts exact strides: gmm wants lhs/rhs row-major (a transposed
|
| 27 |
+
# 3D operand must be exactly column-major), tgmm wants rhs row-major and
|
| 28 |
+
# lhs row/column-major. Make operands contiguous first so the transposed
|
| 29 |
+
# views have the precise strides the kernels expect. `.contiguous()` is a
|
| 30 |
+
# no-op when the tensor is already contiguous.
|
| 31 |
+
if trans_a:
|
| 32 |
+
# Weight gradient: a(M,K), b(M,N) -> c(G,K,N).
|
| 33 |
+
# Pass a transposed so AITER sees lhs(K,M) column-major (TRANS_LHS).
|
| 34 |
+
_aiter_ptgmm(
|
| 35 |
+
a.contiguous().transpose(0, 1),
|
| 36 |
+
b.contiguous(),
|
| 37 |
+
group_sizes,
|
| 38 |
+
preferred_element_type=c.dtype,
|
| 39 |
+
existing_out=c,
|
| 40 |
+
)
|
| 41 |
+
else:
|
| 42 |
+
# trans_b contracts b's last dim: pass a column-major (G,K,N) view.
|
| 43 |
+
rhs = b.contiguous()
|
| 44 |
+
if trans_b:
|
| 45 |
+
rhs = rhs.transpose(1, 2)
|
| 46 |
+
_aiter_gmm(
|
| 47 |
+
a.contiguous(),
|
| 48 |
+
rhs,
|
| 49 |
+
group_sizes,
|
| 50 |
+
preferred_element_type=c.dtype,
|
| 51 |
+
existing_out=c,
|
| 52 |
+
)
|
| 53 |
+
return c
|
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/configs.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
# Tuned GMM configs vendored from ROCm/aiter (aiter/ops/triton/configs/).
|
| 3 |
+
# Inlined as a Python module so packaging always includes them.
|
| 4 |
+
|
| 5 |
+
CONFIGS = {'gfx1250': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}, 'gfx942': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}, 'gfx950': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}}
|
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/gmm.py
ADDED
|
@@ -0,0 +1,567 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Imports.
|
| 6 |
+
# ------------------------------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
# PyTorch
|
| 9 |
+
import torch
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
|
| 12 |
+
# Triton
|
| 13 |
+
import triton
|
| 14 |
+
|
| 15 |
+
# AITER: GMM utility functions
|
| 16 |
+
from .utils.gmm_common import (
|
| 17 |
+
DTYPE,
|
| 18 |
+
is_power_of_2,
|
| 19 |
+
check_input_device_dtype,
|
| 20 |
+
check_bias_shape_stride,
|
| 21 |
+
get_gmm_shape,
|
| 22 |
+
get_gmm_output,
|
| 23 |
+
get_gmm_transposition,
|
| 24 |
+
get_tgmm_shape,
|
| 25 |
+
get_tgmm_output,
|
| 26 |
+
get_tgmm_bias_grad,
|
| 27 |
+
get_tgmm_transposition,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# AITER: GMM Triton kernels
|
| 31 |
+
from ._triton_kernels.gmm import (
|
| 32 |
+
gmm_kernel,
|
| 33 |
+
tgmm_persistent_kernel,
|
| 34 |
+
tgmm_non_persistent_kernel,
|
| 35 |
+
get_config,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# GMM PyTorch wrapper.
|
| 39 |
+
# ------------------------------------------------------------------------------
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _gmm_grid(
|
| 43 |
+
N: int,
|
| 44 |
+
block_size_m: int,
|
| 45 |
+
block_size_n: int,
|
| 46 |
+
group_sizes: Tensor,
|
| 47 |
+
grid_dim: int,
|
| 48 |
+
) -> tuple[int]:
|
| 49 |
+
assert N > 0, f"N must be positive, it's {N}."
|
| 50 |
+
assert is_power_of_2(
|
| 51 |
+
block_size_m
|
| 52 |
+
), f"M-dimension tile size must be a power of 2 (it's {block_size_m})."
|
| 53 |
+
assert is_power_of_2(
|
| 54 |
+
block_size_n
|
| 55 |
+
), f"N-dimension tile size must be a power of 2 (it's {block_size_n})."
|
| 56 |
+
assert torch.all(group_sizes >= 0).item(), "All group_sizes must be non-negative."
|
| 57 |
+
assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})."
|
| 58 |
+
num_m_tiles = (group_sizes + block_size_m - 1) // block_size_m
|
| 59 |
+
assert torch.all(num_m_tiles >= 0).item(), "All num_m_tiles must be non-negative."
|
| 60 |
+
num_n_tiles = triton.cdiv(N, block_size_n)
|
| 61 |
+
assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}."
|
| 62 |
+
num_tiles = torch.sum(num_m_tiles * num_n_tiles).item()
|
| 63 |
+
assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}."
|
| 64 |
+
num_programs = int(min(grid_dim, num_tiles))
|
| 65 |
+
assert num_programs > 0, f"num_programs must be positive, it's {num_programs}."
|
| 66 |
+
return (num_programs,)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def gmm(
|
| 70 |
+
lhs: Tensor,
|
| 71 |
+
rhs: Tensor,
|
| 72 |
+
group_sizes: Tensor,
|
| 73 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 74 |
+
existing_out: Tensor | None = None,
|
| 75 |
+
config: dict[str, int] | None = None,
|
| 76 |
+
bias: Tensor | None = None,
|
| 77 |
+
) -> Tensor:
|
| 78 |
+
"""
|
| 79 |
+
Perform Group Matrix Multiplication (GMM): out = lhs @ rhs + bias
|
| 80 |
+
|
| 81 |
+
lhs rows are divided into G groups. Each group of lhs rows is matrix multiplied with a plane of
|
| 82 |
+
rhs 3D tensor and then stored in a slice of out. In PyTorch parlance, it can be implemented as
|
| 83 |
+
follows for a given group g:
|
| 84 |
+
out[group_start:group_end, :] = lhs[group_start:group_end, :] @ rhs[g] + bias[g]
|
| 85 |
+
|
| 86 |
+
The size of each group, and their respective start and end positions are specified by
|
| 87 |
+
group_sizes tensor. For instance, suppose that group_sizes = [3, 2, 4, 1]. In this particular
|
| 88 |
+
case we have 4 groups. The 1st group starts at 0 and ends at 2, the second group starts at 3 and
|
| 89 |
+
ends at 4, the third group starts at 5 and ends at 8, and the fourth and final group consists of
|
| 90 |
+
just the 10th (last) row of lhs.
|
| 91 |
+
|
| 92 |
+
Parameters
|
| 93 |
+
----------
|
| 94 |
+
lhs : torch.Tensor
|
| 95 |
+
Left-hand side 2D input tensor. Shape: (M, K).
|
| 96 |
+
lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type.
|
| 97 |
+
lhs must be on the same device of rhs and group_sizes.
|
| 98 |
+
rhs : torch.Tensor
|
| 99 |
+
Right-hand side 3D input tensor. Shape: (G, K, N).
|
| 100 |
+
rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type.
|
| 101 |
+
rhs must be on the same device of lhs and group_sizes.
|
| 102 |
+
group_sizes : torch.Tensor
|
| 103 |
+
1D input tensor describing group sizes. Shape: (G,).
|
| 104 |
+
group_sizes data type must be torch.int32 and all its elements must be non-negative.
|
| 105 |
+
group_sizes must be on the same device of lhs and rhs.
|
| 106 |
+
preferred_element_type : torch.dtype, optional
|
| 107 |
+
Desired data type for output tensor. Default is torch.bfloat16.
|
| 108 |
+
Supported output types are torch.float16 and torch.bfloat16.
|
| 109 |
+
existing_out : torch.Tensor or None, optional
|
| 110 |
+
Preallocated output tensor. Default is None.
|
| 111 |
+
If provided, results are written into this tensor. Otherwise, a new output tensor is
|
| 112 |
+
allocated.
|
| 113 |
+
If provided then it must have shape (M, N), its data type must match preferred_element_type
|
| 114 |
+
and it must be on the same device of other input tensors.
|
| 115 |
+
config : dict[str, int] or None, optional
|
| 116 |
+
Optional dictionary with kernel metaparameters. If absent, config will be queried from
|
| 117 |
+
internal tuning database.
|
| 118 |
+
bias : torch.Tensor or None, optional
|
| 119 |
+
Optional bias tensor. Shape: (G, N).
|
| 120 |
+
If provided, bias data type must match lhs and rhs data type, and bias must be on the same
|
| 121 |
+
device as other input tensors. Each group g adds bias[g] to the output.
|
| 122 |
+
|
| 123 |
+
Returns
|
| 124 |
+
-------
|
| 125 |
+
torch.Tensor
|
| 126 |
+
The computed output 2D tensor. Shape: (M, N).
|
| 127 |
+
Output tensor data type is given by preferred_element_type.
|
| 128 |
+
If existing_out is provided then existing_out is also returned.
|
| 129 |
+
|
| 130 |
+
Implementation Notes
|
| 131 |
+
--------------------
|
| 132 |
+
- GMM is implemented with a persistent Triton kernel.
|
| 133 |
+
- lhs must be row-major (lhs.stride() == (K, 1)).
|
| 134 |
+
- rhs can be row-major (rhs.stride() == (K * N, N, 1)) or column-major (rhs.stride() ==
|
| 135 |
+
(K * N, 1, K)). If rhs is row-major then kernel parameter TRANS_RHS == False, this is useful
|
| 136 |
+
for implementing forward pass. If rhs is column-major then kernel parameter TRANS_RHS == True,
|
| 137 |
+
this is useful for computing the lhs derivative in the backward pass, while fusing the
|
| 138 |
+
transposition.
|
| 139 |
+
- out must be row-major (out.stride() == (N, 1)).
|
| 140 |
+
- bias must be row-major (bias.stride() == (N, 1)) if provided.
|
| 141 |
+
"""
|
| 142 |
+
use_bias = bias is not None
|
| 143 |
+
check_input_device_dtype(lhs, rhs, group_sizes, bias)
|
| 144 |
+
|
| 145 |
+
M, K, N, G = get_gmm_shape(lhs, rhs, group_sizes)
|
| 146 |
+
|
| 147 |
+
if use_bias:
|
| 148 |
+
check_bias_shape_stride(bias, G, N)
|
| 149 |
+
|
| 150 |
+
out = get_gmm_output(
|
| 151 |
+
M,
|
| 152 |
+
N,
|
| 153 |
+
device=lhs.device,
|
| 154 |
+
preferred_element_type=preferred_element_type,
|
| 155 |
+
existing_out=existing_out,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
trans_rhs, _ = get_gmm_transposition(lhs, rhs, out)
|
| 159 |
+
|
| 160 |
+
if config is None:
|
| 161 |
+
config = get_config("gmm", M, K, N, G)
|
| 162 |
+
|
| 163 |
+
assert all(
|
| 164 |
+
key in config
|
| 165 |
+
and isinstance(config[key], int)
|
| 166 |
+
and (
|
| 167 |
+
is_power_of_2(config[key])
|
| 168 |
+
if key.startswith("BLOCK_SIZE_")
|
| 169 |
+
else config[key] > 0
|
| 170 |
+
)
|
| 171 |
+
for key in {
|
| 172 |
+
"BLOCK_SIZE_M",
|
| 173 |
+
"BLOCK_SIZE_K",
|
| 174 |
+
"BLOCK_SIZE_N",
|
| 175 |
+
"GROUP_SIZE",
|
| 176 |
+
"GRID_DIM",
|
| 177 |
+
}
|
| 178 |
+
), "Invalid GMM kernel config."
|
| 179 |
+
|
| 180 |
+
grid = _gmm_grid(
|
| 181 |
+
N,
|
| 182 |
+
config["BLOCK_SIZE_M"],
|
| 183 |
+
config["BLOCK_SIZE_N"],
|
| 184 |
+
group_sizes,
|
| 185 |
+
config["GRID_DIM"],
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# fmt: off
|
| 189 |
+
gmm_kernel[grid](
|
| 190 |
+
# Tensor pointers:
|
| 191 |
+
lhs, rhs, group_sizes, out, bias,
|
| 192 |
+
# Tensor shapes:
|
| 193 |
+
M, K, N, G,
|
| 194 |
+
# Meta-parameters:
|
| 195 |
+
TRANS_RHS=trans_rhs,
|
| 196 |
+
USE_BIAS=use_bias,
|
| 197 |
+
**config,
|
| 198 |
+
)
|
| 199 |
+
# fmt: on
|
| 200 |
+
|
| 201 |
+
return out
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# Persistent TGMM PyTorch wrapper.
|
| 205 |
+
# ------------------------------------------------------------------------------
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def _ptgmm_grid(
|
| 209 |
+
K: int,
|
| 210 |
+
N: int,
|
| 211 |
+
G: int,
|
| 212 |
+
block_size_k: int,
|
| 213 |
+
block_size_n: int,
|
| 214 |
+
grid_dim: int,
|
| 215 |
+
) -> tuple[int]:
|
| 216 |
+
assert K > 0, f"K must be positive, it's {K}."
|
| 217 |
+
assert N > 0, f"N must be positive, it's {N}."
|
| 218 |
+
assert G > 0, f"G must be positive, it's {G}."
|
| 219 |
+
assert is_power_of_2(
|
| 220 |
+
block_size_k
|
| 221 |
+
), f"K-dimension tile size must be a power of 2 (it's {block_size_k})."
|
| 222 |
+
assert is_power_of_2(
|
| 223 |
+
block_size_n
|
| 224 |
+
), f"N-dimension tile size must be a power of 2 (it's {block_size_n})."
|
| 225 |
+
assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})."
|
| 226 |
+
num_k_tiles = triton.cdiv(K, block_size_k)
|
| 227 |
+
assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}."
|
| 228 |
+
num_n_tiles = triton.cdiv(N, block_size_n)
|
| 229 |
+
assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}."
|
| 230 |
+
num_tiles = G * num_k_tiles * num_n_tiles
|
| 231 |
+
assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}."
|
| 232 |
+
num_programs = min(grid_dim, num_tiles)
|
| 233 |
+
assert num_programs > 0, f"num_programs must be positive, it's {num_programs}."
|
| 234 |
+
return (num_programs,)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def ptgmm(
|
| 238 |
+
lhs: Tensor,
|
| 239 |
+
rhs: Tensor,
|
| 240 |
+
group_sizes: Tensor,
|
| 241 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 242 |
+
existing_out: Tensor | None = None,
|
| 243 |
+
config: dict[str, int] | None = None,
|
| 244 |
+
bias_grad: Tensor | None = None,
|
| 245 |
+
accumulate: bool = False,
|
| 246 |
+
) -> Tensor:
|
| 247 |
+
"""
|
| 248 |
+
Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs
|
| 249 |
+
|
| 250 |
+
lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with
|
| 251 |
+
the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch
|
| 252 |
+
parlance, it can be implemented as follows for a given group g:
|
| 253 |
+
out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :]
|
| 254 |
+
|
| 255 |
+
The 't' in the operator name derives from MaxText implementation
|
| 256 |
+
(https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py),
|
| 257 |
+
which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor
|
| 258 |
+
shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N).
|
| 259 |
+
|
| 260 |
+
The 'p' in the operator name means that it is implemented with a persistent kernel. There is
|
| 261 |
+
also the non-persistent variation, which is implemented with a regular kernel. Please take a
|
| 262 |
+
look at nptgmm operator. Both ptgmm and nptgmm implement the same computation, choosing one or
|
| 263 |
+
the other is a matter of performance for the target workload.
|
| 264 |
+
|
| 265 |
+
Parameters
|
| 266 |
+
----------
|
| 267 |
+
lhs : torch.Tensor
|
| 268 |
+
Left-hand side 2D input tensor. Shape: (K, M).
|
| 269 |
+
lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type.
|
| 270 |
+
lhs must be on the same device of rhs and group_sizes.
|
| 271 |
+
rhs : torch.Tensor
|
| 272 |
+
Right-hand side 2D input tensor. Shape: (M, N).
|
| 273 |
+
rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type.
|
| 274 |
+
rhs must be on the same device of lhs and group_sizes.
|
| 275 |
+
group_sizes : torch.Tensor
|
| 276 |
+
1D input tensor describing group sizes. Shape: (G,).
|
| 277 |
+
group_sizes data type must be torch.int32 and all its elements must be non-negative.
|
| 278 |
+
group_sizes must be on the same device of lhs and rhs.
|
| 279 |
+
preferred_element_type : torch.dtype, optional
|
| 280 |
+
Desired data type for output tensor. Default is torch.bfloat16.
|
| 281 |
+
Supported output types are torch.float16 and torch.bfloat16.
|
| 282 |
+
existing_out : torch.Tensor or None, optional
|
| 283 |
+
Preallocated output tensor. Default is None.
|
| 284 |
+
If provided, results are written into this tensor. Otherwise, a new output tensor is
|
| 285 |
+
allocated.
|
| 286 |
+
If provided then it must have shape (G, K, N), its data type must match
|
| 287 |
+
preferred_element_type and it must be on the same device of other input tensors.
|
| 288 |
+
config : dict[str, int] or None, optional
|
| 289 |
+
Optional dictionary with kernel metaparameters. If absent, config will be queried from
|
| 290 |
+
internal tuning database.
|
| 291 |
+
bias_grad : torch.Tensor or None, optional
|
| 292 |
+
Optional bias gradient output tensor. Shape: (G, K).
|
| 293 |
+
If provided, the kernel will compute the bias gradient and write it to this tensor.
|
| 294 |
+
bias_grad must be torch.float32 (kernel uses atomic_add which requires float32),
|
| 295 |
+
accumulate : bool, optional
|
| 296 |
+
Whether to accumulate into existing output tensor values. Default is False.
|
| 297 |
+
If False, output will be overwritten with fresh computation.
|
| 298 |
+
If True, results will be added to existing output tensor values.
|
| 299 |
+
|
| 300 |
+
Returns
|
| 301 |
+
-------
|
| 302 |
+
torch.Tensor
|
| 303 |
+
The computed output 3D tensor. Shape: (G, K, N).
|
| 304 |
+
Output tensor data type is given by preferred_element_type.
|
| 305 |
+
If existing_out is provided then existing_out is also returned.
|
| 306 |
+
|
| 307 |
+
Implementation Notes
|
| 308 |
+
--------------------
|
| 309 |
+
- PTGMM is implemented with a persistent Triton kernel.
|
| 310 |
+
- lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs
|
| 311 |
+
is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel
|
| 312 |
+
parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward
|
| 313 |
+
pass, while fusing the transposition.
|
| 314 |
+
- rhs must be row-major (rhs.stride() == (N, 1)).
|
| 315 |
+
- out must be row-major (out.stride() == (K * N, N, 1)).
|
| 316 |
+
"""
|
| 317 |
+
check_input_device_dtype(lhs, rhs, group_sizes)
|
| 318 |
+
|
| 319 |
+
M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes)
|
| 320 |
+
|
| 321 |
+
out = get_tgmm_output(
|
| 322 |
+
K,
|
| 323 |
+
N,
|
| 324 |
+
G,
|
| 325 |
+
device=lhs.device,
|
| 326 |
+
preferred_element_type=preferred_element_type,
|
| 327 |
+
existing_out=existing_out,
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out)
|
| 331 |
+
|
| 332 |
+
if config is None:
|
| 333 |
+
config = get_config("ptgmm", M, K, N, G, accumulate)
|
| 334 |
+
|
| 335 |
+
assert all(
|
| 336 |
+
key in config
|
| 337 |
+
and isinstance(config[key], int)
|
| 338 |
+
and (
|
| 339 |
+
is_power_of_2(config[key])
|
| 340 |
+
if key.startswith("BLOCK_SIZE_")
|
| 341 |
+
else config[key] > 0
|
| 342 |
+
)
|
| 343 |
+
for key in {
|
| 344 |
+
"BLOCK_SIZE_M",
|
| 345 |
+
"BLOCK_SIZE_K",
|
| 346 |
+
"BLOCK_SIZE_N",
|
| 347 |
+
"GROUP_SIZE",
|
| 348 |
+
"GRID_DIM",
|
| 349 |
+
}
|
| 350 |
+
), "Invalid PTGMM kernel config."
|
| 351 |
+
|
| 352 |
+
# Bias gradient handling.
|
| 353 |
+
# -----------------------
|
| 354 |
+
# Get or validate bias gradient tensor.
|
| 355 |
+
compute_bias_grad = bias_grad is not None
|
| 356 |
+
bias_grad_ptr = get_tgmm_bias_grad(
|
| 357 |
+
K,
|
| 358 |
+
G,
|
| 359 |
+
device=lhs.device,
|
| 360 |
+
existing_bias_grad=bias_grad,
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
grid = _ptgmm_grid(
|
| 364 |
+
K,
|
| 365 |
+
N,
|
| 366 |
+
G,
|
| 367 |
+
config["BLOCK_SIZE_K"],
|
| 368 |
+
config["BLOCK_SIZE_N"],
|
| 369 |
+
config["GRID_DIM"],
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
# fmt: off
|
| 373 |
+
tgmm_persistent_kernel[grid](
|
| 374 |
+
# Tensor pointers:
|
| 375 |
+
lhs, rhs, group_sizes, out, bias_grad_ptr,
|
| 376 |
+
# Tensor shapes:
|
| 377 |
+
M, K, N, G,
|
| 378 |
+
# Meta-parameters:
|
| 379 |
+
TRANS_LHS=trans_lhs,
|
| 380 |
+
COMPUTE_BIAS_GRAD=compute_bias_grad,
|
| 381 |
+
ACCUMULATE=accumulate,
|
| 382 |
+
**config,
|
| 383 |
+
)
|
| 384 |
+
# fmt: on
|
| 385 |
+
|
| 386 |
+
return out
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
# Regular non-persistent TGMM PyTorch wrapper.
|
| 390 |
+
# ------------------------------------------------------------------------------
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def _nptgmm_grid(
|
| 394 |
+
K: int,
|
| 395 |
+
N: int,
|
| 396 |
+
G: int,
|
| 397 |
+
block_size_k: int,
|
| 398 |
+
block_size_n: int,
|
| 399 |
+
) -> tuple[int, int]:
|
| 400 |
+
assert K > 0, f"K must be positive, it's {K}."
|
| 401 |
+
assert N > 0, f"N must be positive, it's {N}."
|
| 402 |
+
assert G > 0, f"G must be positive, it's {G}."
|
| 403 |
+
assert is_power_of_2(
|
| 404 |
+
block_size_k
|
| 405 |
+
), f"K-dimension tile size must be a power of 2 (it's {block_size_k})."
|
| 406 |
+
assert is_power_of_2(
|
| 407 |
+
block_size_n
|
| 408 |
+
), f"N-dimension tile size must be a power of 2 (it's {block_size_n})."
|
| 409 |
+
num_k_tiles = triton.cdiv(K, block_size_k)
|
| 410 |
+
assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}."
|
| 411 |
+
num_n_tiles = triton.cdiv(N, block_size_n)
|
| 412 |
+
assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}."
|
| 413 |
+
num_tiles_per_mm = num_k_tiles * num_n_tiles
|
| 414 |
+
assert (
|
| 415 |
+
num_tiles_per_mm > 0
|
| 416 |
+
), f"num_tiles_per_mm must be positive, it's {num_tiles_per_mm}."
|
| 417 |
+
return (G, num_tiles_per_mm)
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def nptgmm(
|
| 421 |
+
lhs: Tensor,
|
| 422 |
+
rhs: Tensor,
|
| 423 |
+
group_sizes: Tensor,
|
| 424 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 425 |
+
existing_out: Tensor | None = None,
|
| 426 |
+
config: dict[str, int] | None = None,
|
| 427 |
+
bias_grad: Tensor | None = None,
|
| 428 |
+
accumulate: bool = False,
|
| 429 |
+
) -> Tensor:
|
| 430 |
+
"""
|
| 431 |
+
Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs
|
| 432 |
+
|
| 433 |
+
lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with
|
| 434 |
+
the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch
|
| 435 |
+
parlance, it can be implemented as follows for a given group g:
|
| 436 |
+
out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :]
|
| 437 |
+
|
| 438 |
+
The 't' in the operator name derives from MaxText implementation
|
| 439 |
+
(https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py),
|
| 440 |
+
which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor
|
| 441 |
+
shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N).
|
| 442 |
+
|
| 443 |
+
The 'np' in the operator name means that it is implemented with a non-persistent, i.e. regular
|
| 444 |
+
kernel. There is also the persistent variation, which is implemented with a persistent kernel.
|
| 445 |
+
Please take a look at ptgmm operator. Both nptgmm and ptgmm implement the same computation,
|
| 446 |
+
choosing one or the other is a matter of performance for the target workload.
|
| 447 |
+
|
| 448 |
+
Parameters
|
| 449 |
+
----------
|
| 450 |
+
lhs : torch.Tensor
|
| 451 |
+
Left-hand side 2D input tensor. Shape: (K, M).
|
| 452 |
+
lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type.
|
| 453 |
+
lhs must be on the same device of rhs and group_sizes.
|
| 454 |
+
rhs : torch.Tensor
|
| 455 |
+
Right-hand side 2D input tensor. Shape: (M, N).
|
| 456 |
+
rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type.
|
| 457 |
+
rhs must be on the same device of lhs and group_sizes.
|
| 458 |
+
group_sizes : torch.Tensor
|
| 459 |
+
1D input tensor describing group sizes. Shape: (G,).
|
| 460 |
+
group_sizes data type must be torch.int32 and all its elements must be non-negative.
|
| 461 |
+
group_sizes must be on the same device of lhs and rhs.
|
| 462 |
+
preferred_element_type : torch.dtype, optional
|
| 463 |
+
Desired data type for output tensor. Default is torch.bfloat16.
|
| 464 |
+
Supported output types are torch.float16 and torch.bfloat16.
|
| 465 |
+
existing_out : torch.Tensor or None, optional
|
| 466 |
+
Preallocated output tensor. Default is None.
|
| 467 |
+
If provided, results are written into this tensor. Otherwise, a new output tensor is
|
| 468 |
+
allocated.
|
| 469 |
+
If provided then it must have shape (G, K, N), its data type must match
|
| 470 |
+
preferred_element_type and it must be on the same device of other input tensors.
|
| 471 |
+
config : dict[str, int] or None, optional
|
| 472 |
+
Optional dictionary with kernel metaparameters. If absent, config will be queried from
|
| 473 |
+
internal tuning database.
|
| 474 |
+
bias_grad : torch.Tensor or None, optional
|
| 475 |
+
Optional bias gradient output tensor. Shape: (G, K).
|
| 476 |
+
If provided, the kernel will compute the bias gradient and write it to this tensor.
|
| 477 |
+
bias_grad must be torch.float32 (kernel uses atomic_add which requires float32),
|
| 478 |
+
accumulate : bool, optional
|
| 479 |
+
Whether to accumulate into existing output tensor values. Default is False.
|
| 480 |
+
If False, output will be overwritten with fresh computation.
|
| 481 |
+
If True, results will be added to existing output tensor values.
|
| 482 |
+
|
| 483 |
+
Returns
|
| 484 |
+
-------
|
| 485 |
+
torch.Tensor
|
| 486 |
+
The computed output 3D tensor. Shape: (G, K, N).
|
| 487 |
+
Output tensor data type is given by preferred_element_type.
|
| 488 |
+
If existing_out is provided then existing_out is also returned.
|
| 489 |
+
|
| 490 |
+
Implementation Notes
|
| 491 |
+
--------------------
|
| 492 |
+
- NPTGMM is implemented with a non-persistent regular Triton kernel.
|
| 493 |
+
- lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs
|
| 494 |
+
is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel
|
| 495 |
+
parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward
|
| 496 |
+
pass, while fusing the transposition.
|
| 497 |
+
- rhs must be row-major (rhs.stride() == (N, 1)).
|
| 498 |
+
- out must be row-major (out.stride() == (K * N, N, 1)).
|
| 499 |
+
"""
|
| 500 |
+
check_input_device_dtype(lhs, rhs, group_sizes)
|
| 501 |
+
|
| 502 |
+
M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes)
|
| 503 |
+
|
| 504 |
+
out = get_tgmm_output(
|
| 505 |
+
K,
|
| 506 |
+
N,
|
| 507 |
+
G,
|
| 508 |
+
device=lhs.device,
|
| 509 |
+
preferred_element_type=preferred_element_type,
|
| 510 |
+
existing_out=existing_out,
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out)
|
| 514 |
+
|
| 515 |
+
# Bias gradient handling.
|
| 516 |
+
# -----------------------
|
| 517 |
+
# Get or validate bias gradient tensor.
|
| 518 |
+
compute_bias_grad = bias_grad is not None
|
| 519 |
+
bias_grad_ptr = get_tgmm_bias_grad(
|
| 520 |
+
K,
|
| 521 |
+
G,
|
| 522 |
+
device=lhs.device,
|
| 523 |
+
existing_bias_grad=bias_grad,
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
if config is None:
|
| 527 |
+
config = get_config("nptgmm", M, K, N, G, accumulate)
|
| 528 |
+
|
| 529 |
+
assert all(
|
| 530 |
+
key in config
|
| 531 |
+
and isinstance(config[key], int)
|
| 532 |
+
and (
|
| 533 |
+
is_power_of_2(config[key])
|
| 534 |
+
if key.startswith("BLOCK_SIZE_")
|
| 535 |
+
else config[key] > 0
|
| 536 |
+
)
|
| 537 |
+
for key in {
|
| 538 |
+
"BLOCK_SIZE_M",
|
| 539 |
+
"BLOCK_SIZE_K",
|
| 540 |
+
"BLOCK_SIZE_N",
|
| 541 |
+
"GROUP_SIZE",
|
| 542 |
+
}
|
| 543 |
+
), "Invalid NPTGMM kernel config."
|
| 544 |
+
|
| 545 |
+
grid = _nptgmm_grid(
|
| 546 |
+
K,
|
| 547 |
+
N,
|
| 548 |
+
G,
|
| 549 |
+
config["BLOCK_SIZE_K"],
|
| 550 |
+
config["BLOCK_SIZE_N"],
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
# fmt: off
|
| 554 |
+
tgmm_non_persistent_kernel[grid](
|
| 555 |
+
# Tensor pointers:
|
| 556 |
+
lhs, rhs, group_sizes, out, bias_grad_ptr,
|
| 557 |
+
# Tensor shapes:
|
| 558 |
+
M, K, N, G,
|
| 559 |
+
# Meta-parameters:
|
| 560 |
+
TRANS_LHS=trans_lhs,
|
| 561 |
+
COMPUTE_BIAS_GRAD=compute_bias_grad,
|
| 562 |
+
ACCUMULATE=accumulate,
|
| 563 |
+
**config,
|
| 564 |
+
)
|
| 565 |
+
# fmt: on
|
| 566 |
+
|
| 567 |
+
return out
|
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/__init__.py
ADDED
|
File without changes
|
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py
ADDED
|
File without changes
|
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import triton
|
| 2 |
+
|
| 3 |
+
# Detect the GPU arch lazily: querying the triton driver at import time fails
|
| 4 |
+
# in headless environments (e.g. the kernel-builder ABI check sandbox has no
|
| 5 |
+
# GPU), and the original JAX fallback pulled in an unrelated runtime dep. The
|
| 6 |
+
# arch is only actually needed when a GMM kernel is dispatched, so resolve and
|
| 7 |
+
# cache on first call.
|
| 8 |
+
_CACHED_ARCH = None
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_arch():
|
| 12 |
+
global _CACHED_ARCH
|
| 13 |
+
if _CACHED_ARCH is not None:
|
| 14 |
+
return _CACHED_ARCH
|
| 15 |
+
try:
|
| 16 |
+
_CACHED_ARCH = triton.runtime.driver.active.get_current_target().arch
|
| 17 |
+
except RuntimeError:
|
| 18 |
+
try:
|
| 19 |
+
from jax._src.lib import gpu_triton as triton_kernel_call_lib
|
| 20 |
+
_CACHED_ARCH = triton_kernel_call_lib.get_arch_details("0").split(":")[0]
|
| 21 |
+
except ImportError as e:
|
| 22 |
+
raise RuntimeError(
|
| 23 |
+
"Cannot determine GPU arch: triton driver is inactive and "
|
| 24 |
+
"JAX is not available. A GPU is required for grouped GEMM."
|
| 25 |
+
) from e
|
| 26 |
+
return _CACHED_ARCH
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def is_gluon_avail():
|
| 30 |
+
return get_arch() in ("gfx950", "gfx1250")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def is_fp4_avail():
|
| 34 |
+
return get_arch() in ("gfx950", "gfx1250")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def is_fp8_avail():
|
| 38 |
+
return get_arch() in ("gfx942", "gfx950", "gfx1250", "gfx1200", "gfx1201")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def is_mx_scale_preshuffling_avail():
|
| 42 |
+
return get_arch() in ("gfx950", "gfx1250")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def is_tdm_avail():
|
| 46 |
+
return get_arch() in ("gfx1250",)
|
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
|
| 3 |
+
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
|
| 4 |
+
|
| 5 |
+
import triton
|
| 6 |
+
import triton.language as tl
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@triton.jit
|
| 10 |
+
def remap_xcd_chunked(
|
| 11 |
+
pid, GRID_MN, NUM_XCDS: tl.constexpr = 8, CHUNK_SIZE: tl.constexpr = 2
|
| 12 |
+
):
|
| 13 |
+
# Compute current XCD and local PID
|
| 14 |
+
xcd = pid % NUM_XCDS
|
| 15 |
+
# distribute the modulo pids in round robin
|
| 16 |
+
if pid > (GRID_MN // (NUM_XCDS * CHUNK_SIZE)) * (NUM_XCDS * CHUNK_SIZE):
|
| 17 |
+
return pid
|
| 18 |
+
local_pid = pid // NUM_XCDS
|
| 19 |
+
# Calculate chunk index and position within chunk
|
| 20 |
+
chunk_idx = local_pid // CHUNK_SIZE
|
| 21 |
+
pos_in_chunk = local_pid % CHUNK_SIZE
|
| 22 |
+
# Calculate new PID
|
| 23 |
+
new_pid = chunk_idx * NUM_XCDS * CHUNK_SIZE + xcd * CHUNK_SIZE + pos_in_chunk
|
| 24 |
+
return new_pid
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@triton.jit
|
| 28 |
+
def remap_xcd(pid, GRID_MN, NUM_XCDS: tl.constexpr = 8):
|
| 29 |
+
## pid remapping on xcds
|
| 30 |
+
# Number of pids per XCD in the new arrangement
|
| 31 |
+
pids_per_xcd = (GRID_MN + NUM_XCDS - 1) // NUM_XCDS
|
| 32 |
+
# When GRID_MN cannot divide NUM_XCDS, some xcds will have
|
| 33 |
+
# pids_per_xcd pids, the other will have pids_per_xcd - 1 pids.
|
| 34 |
+
# We calculate the number of xcds that have pids_per_xcd pids as
|
| 35 |
+
# tall_xcds
|
| 36 |
+
tall_xcds = GRID_MN % NUM_XCDS
|
| 37 |
+
tall_xcds = NUM_XCDS if tall_xcds == 0 else tall_xcds
|
| 38 |
+
# Compute current XCD and local pid within the XCD
|
| 39 |
+
xcd = pid % NUM_XCDS
|
| 40 |
+
local_pid = pid // NUM_XCDS
|
| 41 |
+
# Calculate new pid based on the new grouping
|
| 42 |
+
# Note that we need to consider the following two cases:
|
| 43 |
+
# 1. the current pid is on a tall xcd
|
| 44 |
+
# 2. the current pid is on a short xcd
|
| 45 |
+
if xcd < tall_xcds:
|
| 46 |
+
pid = xcd * pids_per_xcd + local_pid
|
| 47 |
+
else:
|
| 48 |
+
pid = (
|
| 49 |
+
tall_xcds * pids_per_xcd
|
| 50 |
+
+ (xcd - tall_xcds) * (pids_per_xcd - 1)
|
| 51 |
+
+ local_pid
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
return pid
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@triton.jit
|
| 58 |
+
def pid_grid(pid: int, num_pid_m: int, num_pid_n: int, GROUP_SIZE_M: tl.constexpr = 1):
|
| 59 |
+
"""
|
| 60 |
+
Maps 1D pid to 2D grid coords (pid_m, pid_n).
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
- pid: 1D pid
|
| 64 |
+
- num_pid_m: grid m size
|
| 65 |
+
- num_pid_n: grid n size
|
| 66 |
+
- GROUP_SIZE_M: tl.constexpr: default is 1
|
| 67 |
+
"""
|
| 68 |
+
if GROUP_SIZE_M == 1:
|
| 69 |
+
pid_m = pid // num_pid_n
|
| 70 |
+
pid_n = pid % num_pid_n
|
| 71 |
+
else:
|
| 72 |
+
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
| 73 |
+
group_id = pid // num_pid_in_group
|
| 74 |
+
first_pid_m = group_id * GROUP_SIZE_M
|
| 75 |
+
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
| 76 |
+
tl.assume(group_size_m >= 0)
|
| 77 |
+
pid_m = first_pid_m + (pid % group_size_m)
|
| 78 |
+
pid_n = (pid % num_pid_in_group) // group_size_m
|
| 79 |
+
|
| 80 |
+
return pid_m, pid_n
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@triton.jit
|
| 84 |
+
def pid_grid_3d(pid: int, num_pid_m: int, num_pid_n: int, num_pid_k):
|
| 85 |
+
"""
|
| 86 |
+
Maps 1D pid to 3D grid coords (pid_m, pid_n, pid_k).
|
| 87 |
+
Args:
|
| 88 |
+
- pid: 1D pid
|
| 89 |
+
- num_pid_m: grid m size
|
| 90 |
+
- num_pid_n: grid n size
|
| 91 |
+
- num_pid_k: grid k size
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
- pid_m, pid_n, pid_k: 3D grid coordinates
|
| 95 |
+
"""
|
| 96 |
+
pid_m = pid % num_pid_m
|
| 97 |
+
pid_n = (pid // num_pid_m) % num_pid_n
|
| 98 |
+
pid_k = pid // (num_pid_m * num_pid_n) % num_pid_k
|
| 99 |
+
|
| 100 |
+
return pid_m, pid_n, pid_k
|
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py
ADDED
|
@@ -0,0 +1,752 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
| 3 |
+
|
| 4 |
+
# Imports.
|
| 5 |
+
# ------------------------------------------------------------------------------
|
| 6 |
+
|
| 7 |
+
# PyTorch
|
| 8 |
+
import torch
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
|
| 11 |
+
# AITER: logging
|
| 12 |
+
from .logger import AiterTritonLogger
|
| 13 |
+
|
| 14 |
+
_LOGGER: AiterTritonLogger = AiterTritonLogger()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Supported data types.
|
| 18 |
+
# ------------------------------------------------------------------------------
|
| 19 |
+
|
| 20 |
+
# Supported data types, as strings.
|
| 21 |
+
SUPPORTED_DTYPES_STR: set[str] = {"fp16", "bf16"}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Convert string data type to PyTorch data type.
|
| 25 |
+
def dtype_from_str(dtype_str: str) -> torch.dtype:
|
| 26 |
+
dtype_str = dtype_str.strip().lower()
|
| 27 |
+
dtype_str = dtype_str[1:] if dtype_str[0] in {"i", "o"} else dtype_str
|
| 28 |
+
assert (
|
| 29 |
+
dtype_str in SUPPORTED_DTYPES_STR
|
| 30 |
+
), "String data type isn't in set of supported string data types."
|
| 31 |
+
return {"fp16": torch.float16, "bf16": torch.bfloat16}[dtype_str]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Supported data types, as PyTorch types.
|
| 35 |
+
SUPPORTED_DTYPES: set[torch.dtype] = {
|
| 36 |
+
dtype_from_str(dtype_str) for dtype_str in SUPPORTED_DTYPES_STR
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Convert PyTorch data type to string data type.
|
| 41 |
+
def str_from_dtype(dtype: torch.dtype) -> str:
|
| 42 |
+
assert (
|
| 43 |
+
dtype in SUPPORTED_DTYPES
|
| 44 |
+
), "PyTorch data type isn't in set of supported PyTorch data types."
|
| 45 |
+
return {torch.float16: "fp16", torch.bfloat16: "bf16"}[dtype]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Default data type, as string.
|
| 49 |
+
DTYPE_STR: str = "bf16"
|
| 50 |
+
assert (
|
| 51 |
+
DTYPE_STR in SUPPORTED_DTYPES_STR
|
| 52 |
+
), "Default string data type isn't in set of supported string data types."
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Default data type, as PyTorch type.
|
| 56 |
+
DTYPE: torch.dtype = dtype_from_str(DTYPE_STR)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Other defaults.
|
| 60 |
+
# ------------------------------------------------------------------------------
|
| 61 |
+
|
| 62 |
+
# Default device.
|
| 63 |
+
DEVICE: torch.device | str = "cuda"
|
| 64 |
+
|
| 65 |
+
# Default RNG seed for input generation.
|
| 66 |
+
RNG_SEED: int = 0
|
| 67 |
+
|
| 68 |
+
# Default number of group sizes.
|
| 69 |
+
NUM_GROUP_SIZES: int = 1
|
| 70 |
+
|
| 71 |
+
# Default transposition (NN).
|
| 72 |
+
TRANS_LHS: bool = False
|
| 73 |
+
TRANS_RHS: bool = False
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Parameter checking functions.
|
| 77 |
+
# ------------------------------------------------------------------------------
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def is_power_of_2(x: int) -> bool:
|
| 81 |
+
return (x > 0) and (x & (x - 1) == 0)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def check_input_device_dtype(
|
| 85 |
+
lhs: Tensor, rhs: Tensor, group_sizes: Tensor, bias: Tensor | None = None
|
| 86 |
+
) -> None:
|
| 87 |
+
assert (
|
| 88 |
+
lhs.device == rhs.device == group_sizes.device
|
| 89 |
+
), f"All input tensors must be in the same device (lhs = {lhs.device}, rhs = {rhs.device}, group_sizes = {group_sizes.device})."
|
| 90 |
+
assert (
|
| 91 |
+
lhs.dtype == rhs.dtype
|
| 92 |
+
), f"lhs and rhs types must match (lhs = {lhs.dtype}, rhs = {rhs.dtype})."
|
| 93 |
+
assert group_sizes.dtype == torch.int32, "group_sizes type must be int32."
|
| 94 |
+
|
| 95 |
+
if bias is not None:
|
| 96 |
+
assert (
|
| 97 |
+
bias.device == lhs.device
|
| 98 |
+
), f"bias must be on the same device as lhs (bias = {bias.device}, lhs = {lhs.device})."
|
| 99 |
+
assert (
|
| 100 |
+
bias.dtype == lhs.dtype
|
| 101 |
+
), f"bias dtype must match lhs dtype (bias = {bias.dtype}, lhs = {lhs.dtype})."
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def check_bias_shape_stride(bias: Tensor, G: int, N: int) -> None:
|
| 105 |
+
assert bias.shape == (
|
| 106 |
+
G,
|
| 107 |
+
N,
|
| 108 |
+
), f"bias must have shape (G, N) = ({G}, {N}), got {bias.shape}."
|
| 109 |
+
assert bias.stride() == (N, 1), "bias must be row-major (bias.stride() == (N, 1))."
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# Generation of group sizes.
|
| 113 |
+
# ------------------------------------------------------------------------------
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# Probabilities for generating random group sizes.
|
| 117 |
+
UNUSED_TOKENS_PROB: float = 0.0
|
| 118 |
+
UNUSED_EXPERTS_PROB: float = 0.1
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def gen_uniform_group_sizes(
|
| 122 |
+
M: int,
|
| 123 |
+
G: int,
|
| 124 |
+
device: torch.device | str = DEVICE,
|
| 125 |
+
) -> Tensor:
|
| 126 |
+
assert M >= 0, f"Number of tokens M must be non-negative (it's {M})."
|
| 127 |
+
assert G > 0, f"Number of experts G must be positive (it's {G})."
|
| 128 |
+
|
| 129 |
+
base = M // G
|
| 130 |
+
remainder = M % G
|
| 131 |
+
group_sizes = torch.full((G,), base, dtype=torch.int32, device=device)
|
| 132 |
+
if remainder > 0:
|
| 133 |
+
group_sizes[:remainder] += 1
|
| 134 |
+
|
| 135 |
+
assert (
|
| 136 |
+
len(group_sizes) == G
|
| 137 |
+
), f"Group sizes don't have {G} elements (it's {len(group_sizes)})."
|
| 138 |
+
assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative."
|
| 139 |
+
assert (
|
| 140 |
+
torch.sum(group_sizes).item() == M
|
| 141 |
+
), f"Group sizes don't add up to total tokens {M}."
|
| 142 |
+
assert group_sizes.dtype == torch.int32, "Group sizes must be int32."
|
| 143 |
+
|
| 144 |
+
return group_sizes
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def gen_group_sizes(
|
| 148 |
+
M: int,
|
| 149 |
+
G: int,
|
| 150 |
+
device: torch.device | str = DEVICE,
|
| 151 |
+
rng_seed: int | None = RNG_SEED,
|
| 152 |
+
unused_tokens_prob: float = UNUSED_TOKENS_PROB,
|
| 153 |
+
unused_experts_prob: float = UNUSED_EXPERTS_PROB,
|
| 154 |
+
) -> Tensor:
|
| 155 |
+
assert M >= 0, f"Number of tokens M must be non-negative (it's {M})."
|
| 156 |
+
assert G > 0, f"Number of experts G must be positive (it's {G})."
|
| 157 |
+
assert (
|
| 158 |
+
0 <= unused_tokens_prob <= 1
|
| 159 |
+
), f"Probability of unused tokens must be in [0, 1] interval (it's {unused_tokens_prob})."
|
| 160 |
+
assert (
|
| 161 |
+
0 <= unused_experts_prob <= 1
|
| 162 |
+
), f"Probability of unused experts must be in [0, 1] interval (it's {unused_experts_prob})."
|
| 163 |
+
|
| 164 |
+
if rng_seed is not None:
|
| 165 |
+
torch.manual_seed(rng_seed)
|
| 166 |
+
|
| 167 |
+
if unused_tokens_prob > 0:
|
| 168 |
+
# Optionally drop tokens to simulate routing sparsity, some tokens may not be routed.
|
| 169 |
+
num_unused_tokens = M
|
| 170 |
+
while num_unused_tokens == M:
|
| 171 |
+
num_unused_tokens = int(
|
| 172 |
+
torch.binomial(
|
| 173 |
+
torch.tensor(float(M), device=device),
|
| 174 |
+
torch.tensor(unused_tokens_prob, device=device),
|
| 175 |
+
).item()
|
| 176 |
+
)
|
| 177 |
+
else:
|
| 178 |
+
num_unused_tokens = 0
|
| 179 |
+
num_used_tokens = M - num_unused_tokens
|
| 180 |
+
assert (
|
| 181 |
+
num_unused_tokens >= 0
|
| 182 |
+
), f"Number of unused tokens must be non-negative (it's {num_unused_tokens})."
|
| 183 |
+
assert (
|
| 184 |
+
num_used_tokens > 0
|
| 185 |
+
), f"Number of used tokens must be positive (it's {num_used_tokens})."
|
| 186 |
+
assert (
|
| 187 |
+
num_used_tokens + num_unused_tokens == M
|
| 188 |
+
), f"Unused + used tokens don't add up total tokens ({num_used_tokens} + {num_unused_tokens} != {M})."
|
| 189 |
+
|
| 190 |
+
if num_unused_tokens > 0:
|
| 191 |
+
_LOGGER.debug(
|
| 192 |
+
f"Group sizes generation: dropped {num_unused_tokens} token{'s' if num_unused_tokens > 1 else ''}.",
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
if unused_experts_prob > 0:
|
| 196 |
+
# Some experts may have zero tokens assigned to them.
|
| 197 |
+
num_used_experts = 0
|
| 198 |
+
while num_used_experts == 0:
|
| 199 |
+
used_experts = torch.nonzero(
|
| 200 |
+
torch.rand((G,), device=device) >= unused_experts_prob
|
| 201 |
+
).squeeze()
|
| 202 |
+
num_used_experts = used_experts.numel()
|
| 203 |
+
else:
|
| 204 |
+
used_experts = torch.arange(0, G, device=device)
|
| 205 |
+
num_used_experts = G
|
| 206 |
+
num_unused_experts = G - num_used_experts
|
| 207 |
+
assert (
|
| 208 |
+
num_unused_experts >= 0
|
| 209 |
+
), f"Number of unused experts must be non-negative (it's {num_unused_experts})."
|
| 210 |
+
assert (
|
| 211 |
+
num_used_experts >= 1
|
| 212 |
+
), f"At least one expert must be used (it's {num_used_experts})."
|
| 213 |
+
assert (
|
| 214 |
+
num_unused_experts + num_used_experts == G
|
| 215 |
+
), f"Unused + used experts don't add up total experts ({num_unused_experts} + {num_used_experts} != {G})."
|
| 216 |
+
|
| 217 |
+
if num_unused_experts > 0:
|
| 218 |
+
_LOGGER.debug(
|
| 219 |
+
f"Group sizes generation: dropped {num_unused_experts} expert{'s' if num_unused_experts > 1 else ''}.",
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
group_sizes = torch.bincount(
|
| 223 |
+
used_experts[
|
| 224 |
+
torch.randint(low=0, high=num_used_experts, size=(num_used_tokens,))
|
| 225 |
+
],
|
| 226 |
+
minlength=G,
|
| 227 |
+
).to(torch.int32)
|
| 228 |
+
|
| 229 |
+
assert (
|
| 230 |
+
len(group_sizes) == G
|
| 231 |
+
), f"Group sizes don't have {G} elements (it's {len(group_sizes)})."
|
| 232 |
+
assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative."
|
| 233 |
+
assert (
|
| 234 |
+
torch.sum(group_sizes).item() == num_used_tokens
|
| 235 |
+
), f"Group sizes don't add up to used tokens {num_used_tokens}."
|
| 236 |
+
assert group_sizes.dtype == torch.int32, "Group sizes must be int32."
|
| 237 |
+
|
| 238 |
+
return group_sizes
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def gen_multiple_group_sizes(
|
| 242 |
+
num_group_sizes: int,
|
| 243 |
+
M: int,
|
| 244 |
+
G: int,
|
| 245 |
+
device: torch.device | str = DEVICE,
|
| 246 |
+
rng_seed: int | None = RNG_SEED,
|
| 247 |
+
unused_tokens_prob: float = UNUSED_TOKENS_PROB,
|
| 248 |
+
unused_experts_prob: float = UNUSED_EXPERTS_PROB,
|
| 249 |
+
group_sizes_0: Tensor | None = None,
|
| 250 |
+
) -> list[Tensor]:
|
| 251 |
+
assert (
|
| 252 |
+
num_group_sizes > 0
|
| 253 |
+
), f"Number of group sizes to be generated must be positive, it's {num_group_sizes}."
|
| 254 |
+
multiple_group_sizes = [
|
| 255 |
+
gen_group_sizes(
|
| 256 |
+
M,
|
| 257 |
+
G,
|
| 258 |
+
device=device,
|
| 259 |
+
rng_seed=rng_seed if g == 0 else None,
|
| 260 |
+
unused_tokens_prob=unused_tokens_prob,
|
| 261 |
+
unused_experts_prob=unused_experts_prob,
|
| 262 |
+
)
|
| 263 |
+
for g in range(
|
| 264 |
+
num_group_sizes if group_sizes_0 is None else num_group_sizes - 1
|
| 265 |
+
)
|
| 266 |
+
]
|
| 267 |
+
if group_sizes_0 is not None:
|
| 268 |
+
multiple_group_sizes.insert(0, group_sizes_0)
|
| 269 |
+
assert (
|
| 270 |
+
len(multiple_group_sizes) == num_group_sizes
|
| 271 |
+
), f"Expecting {num_group_sizes} distinct group sizes (it's {len(multiple_group_sizes)})."
|
| 272 |
+
return multiple_group_sizes
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# GMM helpers: tensor generation.
|
| 276 |
+
# ------------------------------------------------------------------------------
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def gen_gmm_input(
|
| 280 |
+
M: int,
|
| 281 |
+
K: int,
|
| 282 |
+
N: int,
|
| 283 |
+
G: int,
|
| 284 |
+
device: torch.device | str = DEVICE,
|
| 285 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 286 |
+
trans_rhs: bool = TRANS_RHS,
|
| 287 |
+
rng_seed: int | None = RNG_SEED,
|
| 288 |
+
unif_group_sizes: bool = False,
|
| 289 |
+
) -> tuple[Tensor, Tensor, Tensor]:
|
| 290 |
+
assert M > 0, f"Number of lhs rows M must be positive (M = {M})."
|
| 291 |
+
assert K > 0, f"Number of lhs columns / rhs rows K must be positive (K = {K})."
|
| 292 |
+
assert N > 0, f"Number of rhs columns N must be positive (N = {N})."
|
| 293 |
+
assert G > 0, f"Number of groups G must be positive (G = {G})."
|
| 294 |
+
|
| 295 |
+
if rng_seed is not None:
|
| 296 |
+
torch.manual_seed(rng_seed)
|
| 297 |
+
|
| 298 |
+
lhs = torch.randn((M, K), dtype=torch.float32, device=device)
|
| 299 |
+
lhs = lhs.to(preferred_element_type)
|
| 300 |
+
|
| 301 |
+
if trans_rhs:
|
| 302 |
+
rhs = torch.randn((G, N, K), dtype=torch.float32, device=device).permute(
|
| 303 |
+
0, 2, 1
|
| 304 |
+
)
|
| 305 |
+
else:
|
| 306 |
+
rhs = torch.randn((G, K, N), dtype=torch.float32, device=device)
|
| 307 |
+
rhs = rhs.to(preferred_element_type)
|
| 308 |
+
|
| 309 |
+
group_sizes = (
|
| 310 |
+
gen_uniform_group_sizes(M, G, device=device)
|
| 311 |
+
if unif_group_sizes
|
| 312 |
+
else gen_group_sizes(M, G, device=device, rng_seed=None)
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
return lhs, rhs, group_sizes
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def gen_gmm_output(
|
| 319 |
+
M: int,
|
| 320 |
+
N: int,
|
| 321 |
+
device: torch.device | str = DEVICE,
|
| 322 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 323 |
+
) -> Tensor:
|
| 324 |
+
assert M > 0, f"Number of out rows M must be positive (M = {M})."
|
| 325 |
+
assert N > 0, f"Number of out columns N must be positive (N = {N})."
|
| 326 |
+
|
| 327 |
+
out = torch.empty((M, N), dtype=preferred_element_type, device=device)
|
| 328 |
+
|
| 329 |
+
return out
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def gen_gmm_tensors(
|
| 333 |
+
M: int,
|
| 334 |
+
K: int,
|
| 335 |
+
N: int,
|
| 336 |
+
G: int,
|
| 337 |
+
num_group_sizes: int,
|
| 338 |
+
device: torch.device | str = DEVICE,
|
| 339 |
+
input_type: torch.dtype = DTYPE,
|
| 340 |
+
output_type: torch.dtype = DTYPE,
|
| 341 |
+
trans_lhs: bool = False,
|
| 342 |
+
trans_rhs: bool = TRANS_RHS,
|
| 343 |
+
rng_seed: int | None = RNG_SEED,
|
| 344 |
+
unif_group_sizes: bool = False,
|
| 345 |
+
use_bias: bool = False,
|
| 346 |
+
) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]:
|
| 347 |
+
lhs, rhs, group_sizes_0 = gen_gmm_input(
|
| 348 |
+
M,
|
| 349 |
+
K,
|
| 350 |
+
N,
|
| 351 |
+
G,
|
| 352 |
+
device=device,
|
| 353 |
+
preferred_element_type=input_type,
|
| 354 |
+
trans_rhs=trans_rhs,
|
| 355 |
+
rng_seed=rng_seed,
|
| 356 |
+
unif_group_sizes=unif_group_sizes,
|
| 357 |
+
)
|
| 358 |
+
multiple_group_sizes = gen_multiple_group_sizes(
|
| 359 |
+
num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0
|
| 360 |
+
)
|
| 361 |
+
out = gen_gmm_output(M, N, device=device, preferred_element_type=output_type)
|
| 362 |
+
bias = None
|
| 363 |
+
if use_bias:
|
| 364 |
+
torch.manual_seed(rng_seed + 1000) # Different seed for bias
|
| 365 |
+
bias = torch.randn(G, N, dtype=input_type, device=device)
|
| 366 |
+
|
| 367 |
+
return lhs, rhs, multiple_group_sizes, out, bias
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
# GMM helpers: get information from tensors.
|
| 371 |
+
# ------------------------------------------------------------------------------
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def get_gmm_shape(
|
| 375 |
+
lhs: Tensor, rhs: Tensor, group_sizes: Tensor
|
| 376 |
+
) -> tuple[int, int, int, int]:
|
| 377 |
+
assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
|
| 378 |
+
assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})."
|
| 379 |
+
assert (
|
| 380 |
+
group_sizes.dim() == 1
|
| 381 |
+
), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})."
|
| 382 |
+
|
| 383 |
+
M, lhs_k = lhs.shape
|
| 384 |
+
rhs_g, rhs_k, N = rhs.shape
|
| 385 |
+
group_sizes_g = group_sizes.shape[0]
|
| 386 |
+
|
| 387 |
+
assert (
|
| 388 |
+
lhs_k == rhs_k
|
| 389 |
+
), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})."
|
| 390 |
+
K = lhs_k
|
| 391 |
+
assert (
|
| 392 |
+
rhs_g == group_sizes_g
|
| 393 |
+
), f"G dimension of rhs and group_sizes don't match (rhs = {rhs_g}, group_sizes = {group_sizes_g})."
|
| 394 |
+
G = rhs_g
|
| 395 |
+
|
| 396 |
+
assert M > 0, f"M must be positive, it's {M}."
|
| 397 |
+
assert K > 0, f"K must be positive, it's {K}."
|
| 398 |
+
assert N > 0, f"N must be positive, it's {N}"
|
| 399 |
+
assert G > 0, f"G must be positive, it's {G}"
|
| 400 |
+
|
| 401 |
+
return M, K, N, G
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def get_gmm_output(
|
| 405 |
+
M: int,
|
| 406 |
+
N: int,
|
| 407 |
+
device: torch.device | str = DEVICE,
|
| 408 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 409 |
+
existing_out: Tensor | None = None,
|
| 410 |
+
) -> Tensor:
|
| 411 |
+
assert M > 0, f"Number of out rows M must be positive (M = {M})."
|
| 412 |
+
assert N > 0, f"Number of out columns N must be positive (N = {N})."
|
| 413 |
+
|
| 414 |
+
if existing_out is not None:
|
| 415 |
+
assert (
|
| 416 |
+
existing_out.device == device
|
| 417 |
+
), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})."
|
| 418 |
+
assert (
|
| 419 |
+
existing_out.dtype == preferred_element_type
|
| 420 |
+
), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})."
|
| 421 |
+
assert existing_out.shape == (
|
| 422 |
+
M,
|
| 423 |
+
N,
|
| 424 |
+
), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(M, N)})."
|
| 425 |
+
return existing_out
|
| 426 |
+
|
| 427 |
+
return gen_gmm_output(
|
| 428 |
+
M,
|
| 429 |
+
N,
|
| 430 |
+
device=device,
|
| 431 |
+
preferred_element_type=preferred_element_type,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def get_gmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]:
|
| 436 |
+
assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
|
| 437 |
+
assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})."
|
| 438 |
+
assert out.dim() == 2, f"out must have 2 dimensions (it's {out.dim()})."
|
| 439 |
+
|
| 440 |
+
lhs_m, lhs_k = lhs.shape
|
| 441 |
+
G, rhs_k, rhs_n = rhs.shape
|
| 442 |
+
out_m, out_n = out.shape
|
| 443 |
+
|
| 444 |
+
assert (
|
| 445 |
+
lhs_m == out_m
|
| 446 |
+
), f"M dimension of lhs and out don't match (lhs = {lhs_m}, rhs = {out_m})."
|
| 447 |
+
M = lhs_m
|
| 448 |
+
assert (
|
| 449 |
+
lhs_k == rhs_k
|
| 450 |
+
), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})."
|
| 451 |
+
K = lhs_k
|
| 452 |
+
assert (
|
| 453 |
+
rhs_n == out_n
|
| 454 |
+
), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})."
|
| 455 |
+
N = rhs_n
|
| 456 |
+
|
| 457 |
+
assert M > 0, f"M must be positive, it's {M}."
|
| 458 |
+
assert K > 0, f"K must be positive, it's {K}."
|
| 459 |
+
assert N > 0, f"N must be positive, it's {N}"
|
| 460 |
+
assert G > 0, f"G must be positive, it's {G}"
|
| 461 |
+
|
| 462 |
+
is_lhs_row_major = lhs.stride() == (K, 1)
|
| 463 |
+
assert is_lhs_row_major, "lhs must be row-major."
|
| 464 |
+
is_rhs_row_major = rhs.stride() == (K * N, N, 1)
|
| 465 |
+
is_rhs_col_major = rhs.stride() == (K * N, 1, K)
|
| 466 |
+
assert (
|
| 467 |
+
is_rhs_row_major != is_rhs_col_major
|
| 468 |
+
), "rhs must be row-major or column-major."
|
| 469 |
+
is_out_row_major = out.stride() == (N, 1)
|
| 470 |
+
assert is_out_row_major, "out must be row-major."
|
| 471 |
+
|
| 472 |
+
# Get rhs leading dimension according to transposition configuration.
|
| 473 |
+
ld_rhs = N if is_rhs_row_major else K
|
| 474 |
+
|
| 475 |
+
return is_rhs_col_major, ld_rhs
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
# TGMM helpers: tensor generation.
|
| 479 |
+
# ------------------------------------------------------------------------------
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def gen_tgmm_input(
|
| 483 |
+
M: int,
|
| 484 |
+
K: int,
|
| 485 |
+
N: int,
|
| 486 |
+
G: int,
|
| 487 |
+
device: torch.device | str = DEVICE,
|
| 488 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 489 |
+
trans_lhs: bool = TRANS_LHS,
|
| 490 |
+
rng_seed: int | None = RNG_SEED,
|
| 491 |
+
unif_group_sizes: bool = False,
|
| 492 |
+
) -> tuple[Tensor, Tensor, Tensor]:
|
| 493 |
+
assert K > 0, f"Number of lhs rows K must be positive (M = {K})."
|
| 494 |
+
assert M > 0, f"Number of lhs columns / rhs rows M must be positive (K = {M})."
|
| 495 |
+
assert N > 0, f"Number of rhs columns N must be positive (N = {N})."
|
| 496 |
+
assert G > 0, f"Number of groups G must be positive (G = {G})."
|
| 497 |
+
|
| 498 |
+
if rng_seed is not None:
|
| 499 |
+
torch.manual_seed(rng_seed)
|
| 500 |
+
|
| 501 |
+
if trans_lhs:
|
| 502 |
+
lhs = torch.randn((M, K), dtype=torch.float32, device=device).T
|
| 503 |
+
else:
|
| 504 |
+
lhs = torch.randn((K, M), dtype=torch.float32, device=device)
|
| 505 |
+
lhs = lhs.to(preferred_element_type)
|
| 506 |
+
|
| 507 |
+
rhs = torch.randn((M, N), dtype=torch.float32, device=device)
|
| 508 |
+
rhs = rhs.to(preferred_element_type)
|
| 509 |
+
|
| 510 |
+
group_sizes = (
|
| 511 |
+
gen_uniform_group_sizes(M, G, device=device)
|
| 512 |
+
if unif_group_sizes
|
| 513 |
+
else gen_group_sizes(M, G, device=device, rng_seed=None)
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
return lhs, rhs, group_sizes
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def gen_tgmm_output(
|
| 520 |
+
K: int,
|
| 521 |
+
N: int,
|
| 522 |
+
G: int,
|
| 523 |
+
device: torch.device | str = DEVICE,
|
| 524 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 525 |
+
) -> Tensor:
|
| 526 |
+
assert K > 0, f"Number of out rows K must be positive (K = {K})."
|
| 527 |
+
assert N > 0, f"Number of out columns N must be positive (N = {N})."
|
| 528 |
+
assert G > 0, f"Number of groups G must be positive (G = {G})."
|
| 529 |
+
|
| 530 |
+
out = torch.empty((G, K, N), dtype=preferred_element_type, device=device)
|
| 531 |
+
|
| 532 |
+
return out
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
def gen_tgmm_bias_grad(
|
| 536 |
+
K: int,
|
| 537 |
+
G: int,
|
| 538 |
+
device: torch.device | str = DEVICE,
|
| 539 |
+
with_bias_grad: bool = False,
|
| 540 |
+
) -> Tensor:
|
| 541 |
+
if with_bias_grad:
|
| 542 |
+
assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})."
|
| 543 |
+
assert G > 0, f"Number of groups G must be positive (G = {G})."
|
| 544 |
+
return torch.empty((G, K), device=device, dtype=torch.float32)
|
| 545 |
+
else:
|
| 546 |
+
# Return dummy pointer when bias_grad is not needed.
|
| 547 |
+
# Must be float32 because atomic_add does not support bf16/fp16,
|
| 548 |
+
# and Triton validates the pointer dtype even in dead branches.
|
| 549 |
+
return torch.tensor([], device=device, dtype=torch.float32)
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
def gen_tgmm_tensors(
|
| 553 |
+
M: int,
|
| 554 |
+
K: int,
|
| 555 |
+
N: int,
|
| 556 |
+
G: int,
|
| 557 |
+
num_group_sizes: int,
|
| 558 |
+
device: torch.device | str = DEVICE,
|
| 559 |
+
input_type: torch.dtype = DTYPE,
|
| 560 |
+
output_type: torch.dtype = DTYPE,
|
| 561 |
+
trans_lhs: bool = TRANS_LHS,
|
| 562 |
+
trans_rhs: bool = False,
|
| 563 |
+
rng_seed: int | None = RNG_SEED,
|
| 564 |
+
unif_group_sizes: bool = False,
|
| 565 |
+
use_bias: bool = False,
|
| 566 |
+
) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]:
|
| 567 |
+
lhs, rhs, group_sizes_0 = gen_tgmm_input(
|
| 568 |
+
M,
|
| 569 |
+
K,
|
| 570 |
+
N,
|
| 571 |
+
G,
|
| 572 |
+
device=device,
|
| 573 |
+
preferred_element_type=input_type,
|
| 574 |
+
trans_lhs=trans_lhs,
|
| 575 |
+
rng_seed=rng_seed,
|
| 576 |
+
unif_group_sizes=unif_group_sizes,
|
| 577 |
+
)
|
| 578 |
+
multiple_group_sizes = gen_multiple_group_sizes(
|
| 579 |
+
num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0
|
| 580 |
+
)
|
| 581 |
+
out = gen_tgmm_output(K, N, G, device=device, preferred_element_type=output_type)
|
| 582 |
+
if use_bias:
|
| 583 |
+
bias_grad = gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=True)
|
| 584 |
+
else:
|
| 585 |
+
bias_grad = None
|
| 586 |
+
return lhs, rhs, multiple_group_sizes, out, bias_grad
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
# TGMM helpers: get information from tensors.
|
| 590 |
+
# ------------------------------------------------------------------------------
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
def get_tgmm_shape(
|
| 594 |
+
lhs: Tensor, rhs: Tensor, group_sizes: Tensor
|
| 595 |
+
) -> tuple[int, int, int, int]:
|
| 596 |
+
assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
|
| 597 |
+
assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})."
|
| 598 |
+
assert (
|
| 599 |
+
group_sizes.dim() == 1
|
| 600 |
+
), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})."
|
| 601 |
+
|
| 602 |
+
K, lhs_m = lhs.shape
|
| 603 |
+
rhs_m, N = rhs.shape
|
| 604 |
+
G = group_sizes.shape[0]
|
| 605 |
+
|
| 606 |
+
assert (
|
| 607 |
+
lhs_m == rhs_m
|
| 608 |
+
), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})."
|
| 609 |
+
M = lhs_m
|
| 610 |
+
|
| 611 |
+
assert M > 0, f"M must be positive, it's {M}."
|
| 612 |
+
assert K > 0, f"K must be positive, it's {K}."
|
| 613 |
+
assert N > 0, f"N must be positive, it's {N}"
|
| 614 |
+
assert G > 0, f"G must be positive, it's {G}"
|
| 615 |
+
|
| 616 |
+
return M, K, N, G
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
def get_tgmm_output(
|
| 620 |
+
K: int,
|
| 621 |
+
N: int,
|
| 622 |
+
G: int,
|
| 623 |
+
device: torch.device | str = DEVICE,
|
| 624 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 625 |
+
existing_out: Tensor | None = None,
|
| 626 |
+
) -> Tensor:
|
| 627 |
+
assert K > 0, f"Number of out rows K must be positive (K = {K})."
|
| 628 |
+
assert N > 0, f"Number of out columns N must be positive (N = {N})."
|
| 629 |
+
assert G > 0, f"Number of groups G must be positive (G = {G})."
|
| 630 |
+
|
| 631 |
+
if existing_out is not None:
|
| 632 |
+
assert (
|
| 633 |
+
existing_out.device == device
|
| 634 |
+
), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})."
|
| 635 |
+
assert (
|
| 636 |
+
existing_out.dtype == preferred_element_type
|
| 637 |
+
), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})."
|
| 638 |
+
assert existing_out.shape == (
|
| 639 |
+
G,
|
| 640 |
+
K,
|
| 641 |
+
N,
|
| 642 |
+
), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(G, K, N)})."
|
| 643 |
+
return existing_out
|
| 644 |
+
|
| 645 |
+
return gen_tgmm_output(
|
| 646 |
+
K,
|
| 647 |
+
N,
|
| 648 |
+
G,
|
| 649 |
+
device=device,
|
| 650 |
+
preferred_element_type=preferred_element_type,
|
| 651 |
+
)
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
def get_tgmm_bias_grad(
|
| 655 |
+
K: int,
|
| 656 |
+
G: int,
|
| 657 |
+
device: torch.device | str = DEVICE,
|
| 658 |
+
existing_bias_grad: Tensor | None = None,
|
| 659 |
+
) -> Tensor:
|
| 660 |
+
"""
|
| 661 |
+
Get or validate bias gradient tensor for TGMM.
|
| 662 |
+
|
| 663 |
+
If existing_bias_grad is provided, validates its shape, device, dtype, and stride,
|
| 664 |
+
and always zeros it before returning (since the kernel uses atomic_add).
|
| 665 |
+
If existing_bias_grad is None, returns a dummy tensor (for use when COMPUTE_BIAS_GRAD=False).
|
| 666 |
+
Parameters
|
| 667 |
+
----------
|
| 668 |
+
K : int
|
| 669 |
+
Number of rows in the bias gradient tensor.
|
| 670 |
+
G : int
|
| 671 |
+
Number of groups.
|
| 672 |
+
device : torch.device or str
|
| 673 |
+
Device for the tensor.
|
| 674 |
+
existing_bias_grad : torch.Tensor or None
|
| 675 |
+
Existing bias gradient tensor to validate and use.
|
| 676 |
+
Returns
|
| 677 |
+
-------
|
| 678 |
+
torch.Tensor
|
| 679 |
+
Valid bias gradient tensor or dummy tensor.
|
| 680 |
+
"""
|
| 681 |
+
assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})."
|
| 682 |
+
assert G > 0, f"Number of groups G must be positive (G = {G})."
|
| 683 |
+
|
| 684 |
+
if existing_bias_grad is not None:
|
| 685 |
+
# Validate existing bias_grad tensor.
|
| 686 |
+
expected_shape = (G, K)
|
| 687 |
+
assert (
|
| 688 |
+
tuple(existing_bias_grad.shape) == expected_shape
|
| 689 |
+
), f"bias_grad must have shape {expected_shape}, got {tuple(existing_bias_grad.shape)}."
|
| 690 |
+
assert (
|
| 691 |
+
existing_bias_grad.device == device
|
| 692 |
+
), f"bias_grad must be on the same device (bias_grad = {existing_bias_grad.device}, device = {device})."
|
| 693 |
+
assert (
|
| 694 |
+
existing_bias_grad.dtype == torch.float32
|
| 695 |
+
), f"bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), got {existing_bias_grad.dtype}."
|
| 696 |
+
assert existing_bias_grad.stride() == (
|
| 697 |
+
K,
|
| 698 |
+
1,
|
| 699 |
+
), f"bias_grad must be row-major with stride (K, 1) = ({K}, 1), got {existing_bias_grad.stride()}."
|
| 700 |
+
|
| 701 |
+
# Always zero the tensor since bias_grad represents gradients for the current
|
| 702 |
+
# computation and should start fresh. The kernel uses atomic_add which adds to
|
| 703 |
+
# existing values, so we must zero before the kernel runs.
|
| 704 |
+
existing_bias_grad.zero_()
|
| 705 |
+
|
| 706 |
+
return existing_bias_grad
|
| 707 |
+
|
| 708 |
+
else:
|
| 709 |
+
return gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=False)
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
def get_tgmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]:
|
| 713 |
+
assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
|
| 714 |
+
assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})."
|
| 715 |
+
assert out.dim() == 3, f"out must have 3 dimensions (it's {out.dim()})."
|
| 716 |
+
|
| 717 |
+
lhs_k, lhs_m = lhs.shape
|
| 718 |
+
rhs_m, rhs_n = rhs.shape
|
| 719 |
+
G, out_k, out_n = out.shape
|
| 720 |
+
|
| 721 |
+
assert (
|
| 722 |
+
lhs_m == rhs_m
|
| 723 |
+
), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})."
|
| 724 |
+
M = lhs_m
|
| 725 |
+
assert (
|
| 726 |
+
lhs_k == out_k
|
| 727 |
+
), f"K dimension of lhs and out don't match (lhs = {lhs_k}, rhs = {out_k})."
|
| 728 |
+
K = lhs_k
|
| 729 |
+
assert (
|
| 730 |
+
rhs_n == out_n
|
| 731 |
+
), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})."
|
| 732 |
+
N = rhs_n
|
| 733 |
+
|
| 734 |
+
assert M > 0, f"M must be positive, it's {M}."
|
| 735 |
+
assert K > 0, f"K must be positive, it's {K}."
|
| 736 |
+
assert N > 0, f"N must be positive, it's {N}"
|
| 737 |
+
assert G > 0, f"G must be positive, it's {G}"
|
| 738 |
+
|
| 739 |
+
is_lhs_row_major = lhs.stride() == (M, 1)
|
| 740 |
+
is_lhs_col_major = lhs.stride() == (1, K)
|
| 741 |
+
assert (
|
| 742 |
+
is_lhs_row_major != is_lhs_col_major
|
| 743 |
+
), "lhs must be row-major or column-major."
|
| 744 |
+
is_rhs_row_major = rhs.stride() == (N, 1)
|
| 745 |
+
assert is_rhs_row_major, "rhs must be row-major."
|
| 746 |
+
is_out_row_major = out.stride() == (K * N, N, 1)
|
| 747 |
+
assert is_out_row_major, "out must be row-major."
|
| 748 |
+
|
| 749 |
+
# Get lhs leading dimension according to transposition configuration.
|
| 750 |
+
ld_lhs = M if is_lhs_row_major else K
|
| 751 |
+
|
| 752 |
+
return is_lhs_col_major, ld_lhs
|
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/logger.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# AITER Triton Logger which is singleton object around python logging.
|
| 6 |
+
# Note: Python logging is also a singleton object, but we want to read the
|
| 7 |
+
# env var AITER_LOG_LEVEL once at the beginning. Another alternative is to do
|
| 8 |
+
# this in __init__.py. In fact, that's how CK logger is setup. We can look at
|
| 9 |
+
# switching to that at some point
|
| 10 |
+
#
|
| 11 |
+
# AITER_LOG_LEVEL follows python logging levels
|
| 12 |
+
# DEBUG
|
| 13 |
+
# INFO
|
| 14 |
+
# WARNING
|
| 15 |
+
# ERROR
|
| 16 |
+
# CRITICAL
|
| 17 |
+
#
|
| 18 |
+
class AiterTritonLogger(object):
|
| 19 |
+
_instance = None
|
| 20 |
+
|
| 21 |
+
def __new__(cls):
|
| 22 |
+
if cls._instance is None:
|
| 23 |
+
cls._instance = super(AiterTritonLogger, cls).__new__(cls)
|
| 24 |
+
log_level_str = os.getenv("AITER_TRITON_LOG_LEVEL", "WARNING").upper()
|
| 25 |
+
numeric_level = getattr(logging, log_level_str, logging.WARNING)
|
| 26 |
+
cls._instance._logger = logging.getLogger("AITER_TRITON")
|
| 27 |
+
cls._instance._logger.setLevel(numeric_level)
|
| 28 |
+
|
| 29 |
+
return cls._instance
|
| 30 |
+
|
| 31 |
+
def get_logger(self):
|
| 32 |
+
return self._logger
|
| 33 |
+
|
| 34 |
+
def debug(self, msg):
|
| 35 |
+
self._logger.debug(msg)
|
| 36 |
+
|
| 37 |
+
def info(self, msg):
|
| 38 |
+
self._logger.info(msg)
|
| 39 |
+
|
| 40 |
+
def warning(self, msg):
|
| 41 |
+
self._logger.warning(msg)
|
| 42 |
+
|
| 43 |
+
def error(self, msg):
|
| 44 |
+
self._logger.error(msg)
|
| 45 |
+
|
| 46 |
+
def critical(self, msg):
|
| 47 |
+
self._logger.critical(msg)
|
build/torch211-cxx11-cu128-x86_64-linux/{_megablocks_cuda_ae601bb.abi3.so → _megablocks_cuda_f8f8b50.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4ea3f6a68cbc730572a4a4c8d3814a2075cc775bffcf3082c9dbd6291e888555
|
| 3 |
+
size 19750504
|
build/torch211-cxx11-cu128-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_cuda_f8f8b50
|
| 3 |
+
ops = torch.ops._megablocks_cuda_f8f8b50
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_cuda_f8f8b50::{op_name}"
|
build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py
CHANGED
|
@@ -2,16 +2,16 @@
|
|
| 2 |
# extensions. Otherwise libc10.so cannot be found.
|
| 3 |
import torch
|
| 4 |
|
| 5 |
-
#
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
# import grouped_gemm_backend as backend
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
#
|
| 14 |
-
|
|
|
|
| 15 |
|
| 16 |
def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
|
| 17 |
assert not (trans_a and trans_b)
|
|
|
|
| 2 |
# extensions. Otherwise libc10.so cannot be found.
|
| 3 |
import torch
|
| 4 |
|
| 5 |
+
# On ROCm there is no CUTLASS grouped GEMM; dispatch to the vendored AITER
|
| 6 |
+
# Triton kernels instead. On CUDA we use the compiled CUTLASS `gmm` op.
|
| 7 |
+
_IS_ROCM = torch.version.hip is not None
|
|
|
|
| 8 |
|
| 9 |
+
if _IS_ROCM:
|
| 10 |
+
from .._grouped_gemm_triton import adapter as backend
|
| 11 |
+
else:
|
| 12 |
+
# We import the backend operations from the megablocks package as
|
| 13 |
+
# grouped_gemm is vendored in megablocks in this repository.
|
| 14 |
+
from .._ops import ops as backend # type: ignore
|
| 15 |
|
| 16 |
def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
|
| 17 |
assert not (trans_a and trans_b)
|
build/torch211-cxx11-cu128-x86_64-linux/metadata.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
{
|
| 2 |
"name": "megablocks",
|
| 3 |
-
"id": "
|
| 4 |
"version": 1,
|
| 5 |
"license": "Apache-2.0",
|
| 6 |
"python-depends": [],
|
|
@@ -10,6 +10,7 @@
|
|
| 10 |
"10.0",
|
| 11 |
"10.1",
|
| 12 |
"12.0",
|
|
|
|
| 13 |
"7.0",
|
| 14 |
"7.2",
|
| 15 |
"7.5",
|
|
|
|
| 1 |
{
|
| 2 |
"name": "megablocks",
|
| 3 |
+
"id": "_megablocks_cuda_f8f8b50",
|
| 4 |
"version": 1,
|
| 5 |
"license": "Apache-2.0",
|
| 6 |
"python-depends": [],
|
|
|
|
| 10 |
"10.0",
|
| 11 |
"10.1",
|
| 12 |
"12.0",
|
| 13 |
+
"12.0+PTX",
|
| 14 |
"7.0",
|
| 15 |
"7.2",
|
| 16 |
"7.5",
|
build/torch211-cxx11-cu130-x86_64-linux/__init__.py
CHANGED
|
@@ -3,7 +3,9 @@
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
|
| 6 |
-
from .
|
|
|
|
|
|
|
| 7 |
|
| 8 |
from .grouped_gemm import backend as gg_backend
|
| 9 |
from .grouped_gemm import ops as gg_ops
|
|
@@ -136,7 +138,8 @@ def sort(
|
|
| 136 |
Returns:
|
| 137 |
The sorted values tensor
|
| 138 |
"""
|
| 139 |
-
|
|
|
|
| 140 |
|
| 141 |
|
| 142 |
# Convenience functions for common use cases
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
|
| 6 |
+
# Stable alias: bare `ops` is shadowed by `from . import layers` below.
|
| 7 |
+
from ._ops import ops as _compiled_ops
|
| 8 |
+
from . import ops
|
| 9 |
|
| 10 |
from .grouped_gemm import backend as gg_backend
|
| 11 |
from .grouped_gemm import ops as gg_ops
|
|
|
|
| 138 |
Returns:
|
| 139 |
The sorted values tensor
|
| 140 |
"""
|
| 141 |
+
_compiled_ops.sort(x, end_bit, x_out, iota_out)
|
| 142 |
+
return x_out
|
| 143 |
|
| 144 |
|
| 145 |
# Convenience functions for common use cases
|
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/__init__.py
ADDED
|
File without changes
|
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py
ADDED
|
File without changes
|
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py
ADDED
|
@@ -0,0 +1,574 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
# Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Imports.
|
| 6 |
+
# ------------------------------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
# Python standard library
|
| 9 |
+
import functools
|
| 10 |
+
|
| 11 |
+
# Triton
|
| 12 |
+
import triton
|
| 13 |
+
import triton.language as tl
|
| 14 |
+
|
| 15 |
+
# AITER
|
| 16 |
+
from ..configs import CONFIGS as _CONFIGS
|
| 17 |
+
from ..utils._triton import arch_info
|
| 18 |
+
from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
|
| 19 |
+
|
| 20 |
+
# Kernel config.
|
| 21 |
+
# ------------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@functools.lru_cache()
|
| 25 |
+
def get_config(
|
| 26 |
+
gmm_type: str, M: int, K: int, N: int, G: int, accumulate: bool = False
|
| 27 |
+
) -> dict[str, int]:
|
| 28 |
+
assert gmm_type in {
|
| 29 |
+
"gmm",
|
| 30 |
+
"ptgmm",
|
| 31 |
+
"nptgmm",
|
| 32 |
+
}, f"'{gmm_type}' is an invalid GMM variant."
|
| 33 |
+
dev = arch_info.get_arch()
|
| 34 |
+
assert (
|
| 35 |
+
dev in _CONFIGS
|
| 36 |
+
), f"No GMM configuration tuned for arch '{dev}'. Supported: {sorted(_CONFIGS)}."
|
| 37 |
+
arch_configs = _CONFIGS[dev]
|
| 38 |
+
assert (
|
| 39 |
+
"default" in arch_configs[gmm_type]
|
| 40 |
+
), "Default configuration is absent."
|
| 41 |
+
key = "accumulate" if accumulate else "default"
|
| 42 |
+
return arch_configs[gmm_type][key]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Common code shared by GMM and TGMM kernels.
|
| 46 |
+
# ------------------------------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# XCD remapping followed by 1D PID to 2D grid mapping.
|
| 50 |
+
@triton.jit
|
| 51 |
+
def _remap_xcd_tile_grid(
|
| 52 |
+
tile_in_mm,
|
| 53 |
+
num_row_tiles,
|
| 54 |
+
num_col_tiles,
|
| 55 |
+
GROUP_SIZE: tl.constexpr = 1,
|
| 56 |
+
NUM_XCDS: tl.constexpr = 8,
|
| 57 |
+
):
|
| 58 |
+
return pid_grid(
|
| 59 |
+
remap_xcd(tile_in_mm, num_row_tiles * num_col_tiles, NUM_XCDS=NUM_XCDS),
|
| 60 |
+
num_row_tiles,
|
| 61 |
+
num_col_tiles,
|
| 62 |
+
GROUP_SIZE_M=GROUP_SIZE,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# GMM kernel.
|
| 67 |
+
# ------------------------------------------------------------------------------
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@triton.heuristics(
|
| 71 |
+
{
|
| 72 |
+
"K_DIVISIBLE_BY_BLOCK_SIZE_K": lambda META: META["K"] % META["BLOCK_SIZE_K"]
|
| 73 |
+
== 0,
|
| 74 |
+
}
|
| 75 |
+
)
|
| 76 |
+
@triton.jit
|
| 77 |
+
def gmm_kernel(
|
| 78 |
+
# Tensor pointers:
|
| 79 |
+
lhs_ptr,
|
| 80 |
+
rhs_ptr,
|
| 81 |
+
group_sizes_ptr,
|
| 82 |
+
out_ptr,
|
| 83 |
+
bias_ptr,
|
| 84 |
+
# Tensor shapes:
|
| 85 |
+
M: int,
|
| 86 |
+
K: int,
|
| 87 |
+
N: int,
|
| 88 |
+
G: int,
|
| 89 |
+
# Meta-parameters:
|
| 90 |
+
TRANS_RHS: tl.constexpr,
|
| 91 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 92 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 93 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 94 |
+
K_DIVISIBLE_BY_BLOCK_SIZE_K: tl.constexpr,
|
| 95 |
+
GROUP_SIZE: tl.constexpr,
|
| 96 |
+
GRID_DIM: tl.constexpr,
|
| 97 |
+
USE_BIAS: tl.constexpr,
|
| 98 |
+
):
|
| 99 |
+
tl.assume(M > 0)
|
| 100 |
+
tl.assume(K > 0)
|
| 101 |
+
tl.assume(N > 0)
|
| 102 |
+
tl.assume(G > 0)
|
| 103 |
+
|
| 104 |
+
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
|
| 105 |
+
tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0")
|
| 106 |
+
|
| 107 |
+
# Current tile. Each program computes multiple tiles of each group.
|
| 108 |
+
tile = tl.program_id(0)
|
| 109 |
+
tl.device_assert(tile >= 0, "tile < 0 (at initialization)")
|
| 110 |
+
|
| 111 |
+
# Tile limit of last MM problem (inclusive).
|
| 112 |
+
last_mm_tile = 0
|
| 113 |
+
|
| 114 |
+
# Last input row of lhs and output row of out. Each group reads some rows of
|
| 115 |
+
# lhs and writes some rows to out.
|
| 116 |
+
last_m = 0
|
| 117 |
+
|
| 118 |
+
# Loop through all (m, K, N) MM problems:
|
| 119 |
+
# (m, K) x (K, N) = (m, N)
|
| 120 |
+
# sum(m) = M
|
| 121 |
+
for g in range(G):
|
| 122 |
+
# Get m dimension of current MM problem.
|
| 123 |
+
m = tl.load(group_sizes_ptr + g)
|
| 124 |
+
# m can be zero if group is empty
|
| 125 |
+
tl.device_assert(m >= 0, "m < 0")
|
| 126 |
+
|
| 127 |
+
num_m_tiles = tl.cdiv(m, BLOCK_SIZE_M)
|
| 128 |
+
# num_m_tiles can be zero if group is empty
|
| 129 |
+
tl.device_assert(num_m_tiles >= 0, "num_m_tiles < 0")
|
| 130 |
+
|
| 131 |
+
num_tiles = num_m_tiles * num_n_tiles
|
| 132 |
+
# num_tiles can be zero if group is empty
|
| 133 |
+
tl.device_assert(num_tiles >= 0, "num_tiles < 0")
|
| 134 |
+
|
| 135 |
+
# Loop through tiles of current MM problem.
|
| 136 |
+
while tile >= last_mm_tile and tile < last_mm_tile + num_tiles:
|
| 137 |
+
# Figure out tile coordinates in current MM problem.
|
| 138 |
+
tile_in_mm = tile - last_mm_tile
|
| 139 |
+
tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0")
|
| 140 |
+
|
| 141 |
+
tile_m, tile_n = _remap_xcd_tile_grid(
|
| 142 |
+
tile_in_mm, num_m_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# Do regular MM:
|
| 146 |
+
|
| 147 |
+
tl.device_assert(tile_m * BLOCK_SIZE_M >= 0, "tile_m * BLOCK_SIZE_M < 0")
|
| 148 |
+
tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0")
|
| 149 |
+
|
| 150 |
+
offs_lhs_m = (
|
| 151 |
+
tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 152 |
+
) % m
|
| 153 |
+
offs_rhs_n = (
|
| 154 |
+
tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 155 |
+
) % N
|
| 156 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
|
| 157 |
+
|
| 158 |
+
lhs_ptrs = lhs_ptr + (last_m + offs_lhs_m[:, None]) * K + offs_k[None, :]
|
| 159 |
+
|
| 160 |
+
if TRANS_RHS:
|
| 161 |
+
rhs_ptrs = (
|
| 162 |
+
rhs_ptr
|
| 163 |
+
+ g.to(tl.int64) * K * N
|
| 164 |
+
+ offs_k[:, None]
|
| 165 |
+
+ offs_rhs_n[None, :] * K
|
| 166 |
+
)
|
| 167 |
+
else:
|
| 168 |
+
rhs_ptrs = (
|
| 169 |
+
rhs_ptr
|
| 170 |
+
+ g.to(tl.int64) * K * N
|
| 171 |
+
+ offs_k[:, None] * N
|
| 172 |
+
+ offs_rhs_n[None, :]
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 176 |
+
|
| 177 |
+
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
| 178 |
+
if K_DIVISIBLE_BY_BLOCK_SIZE_K:
|
| 179 |
+
lhs = tl.load(lhs_ptrs)
|
| 180 |
+
rhs = tl.load(rhs_ptrs)
|
| 181 |
+
else:
|
| 182 |
+
k_mask_limit = K - k * BLOCK_SIZE_K
|
| 183 |
+
lhs = tl.load(
|
| 184 |
+
lhs_ptrs, mask=offs_k[None, :] < k_mask_limit, other=0
|
| 185 |
+
)
|
| 186 |
+
rhs = tl.load(
|
| 187 |
+
rhs_ptrs, mask=offs_k[:, None] < k_mask_limit, other=0
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
acc = tl.dot(lhs, rhs, acc=acc)
|
| 191 |
+
|
| 192 |
+
lhs_ptrs += BLOCK_SIZE_K
|
| 193 |
+
|
| 194 |
+
if TRANS_RHS:
|
| 195 |
+
rhs_ptrs += BLOCK_SIZE_K
|
| 196 |
+
else:
|
| 197 |
+
rhs_ptrs += BLOCK_SIZE_K * N
|
| 198 |
+
|
| 199 |
+
# Add bias if enabled
|
| 200 |
+
if USE_BIAS:
|
| 201 |
+
offs_bias_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(
|
| 202 |
+
0, BLOCK_SIZE_N
|
| 203 |
+
)
|
| 204 |
+
bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bias_n
|
| 205 |
+
bias = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0)
|
| 206 |
+
# Convert bias to float32 to match accumulator precision
|
| 207 |
+
bias = bias.to(tl.float32)
|
| 208 |
+
# Broadcast bias across M dimension and add in float32
|
| 209 |
+
acc += bias[None, :]
|
| 210 |
+
|
| 211 |
+
# Convert to output dtype after all computations
|
| 212 |
+
acc = acc.to(out_ptr.type.element_ty)
|
| 213 |
+
|
| 214 |
+
offs_out_m = tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 215 |
+
offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 216 |
+
|
| 217 |
+
out_ptrs = (
|
| 218 |
+
out_ptr + (last_m + offs_out_m[:, None]) * N + offs_out_n[None, :]
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
tl.store(
|
| 222 |
+
out_ptrs,
|
| 223 |
+
acc,
|
| 224 |
+
mask=(offs_out_m[:, None] < m) & (offs_out_n[None, :] < N),
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
# Go to the next tile by advancing number of programs.
|
| 228 |
+
tile += GRID_DIM
|
| 229 |
+
tl.device_assert(tile > 0, "tile <= 0 (at update)")
|
| 230 |
+
|
| 231 |
+
# Get ready to go to the next MM problem.
|
| 232 |
+
|
| 233 |
+
last_mm_tile += num_tiles
|
| 234 |
+
# last_mm_tile can be zero if group 0 is skipped
|
| 235 |
+
tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)")
|
| 236 |
+
|
| 237 |
+
last_m += m
|
| 238 |
+
# last_m can be zero if group 0 is skipped
|
| 239 |
+
tl.device_assert(last_m >= 0, "last_m < 0 (at update)")
|
| 240 |
+
tl.device_assert(last_m <= M, "last_m > M (at update)")
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
# Persistent TGMM kernel.
|
| 244 |
+
# ------------------------------------------------------------------------------
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
@triton.jit
|
| 248 |
+
def tgmm_persistent_kernel(
|
| 249 |
+
# Tensor pointers:
|
| 250 |
+
lhs_ptr,
|
| 251 |
+
rhs_ptr,
|
| 252 |
+
group_sizes_ptr,
|
| 253 |
+
out_ptr,
|
| 254 |
+
bias_grad_ptr,
|
| 255 |
+
# Tensor shapes:
|
| 256 |
+
M: int,
|
| 257 |
+
K: int,
|
| 258 |
+
N: int,
|
| 259 |
+
G: int,
|
| 260 |
+
# Meta-parameters:
|
| 261 |
+
TRANS_LHS: tl.constexpr,
|
| 262 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 263 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 264 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 265 |
+
GROUP_SIZE: tl.constexpr,
|
| 266 |
+
GRID_DIM: tl.constexpr,
|
| 267 |
+
COMPUTE_BIAS_GRAD: tl.constexpr,
|
| 268 |
+
ACCUMULATE: tl.constexpr,
|
| 269 |
+
):
|
| 270 |
+
tl.assume(M > 0)
|
| 271 |
+
tl.assume(K > 0)
|
| 272 |
+
tl.assume(N > 0)
|
| 273 |
+
tl.assume(G > 0)
|
| 274 |
+
|
| 275 |
+
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
| 276 |
+
tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0")
|
| 277 |
+
|
| 278 |
+
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
|
| 279 |
+
tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0")
|
| 280 |
+
|
| 281 |
+
num_tiles = num_k_tiles * num_n_tiles
|
| 282 |
+
tl.device_assert(num_tiles > 0, "num_tiles <= 0")
|
| 283 |
+
|
| 284 |
+
# Current tile. Each program computes multiple tiles of each group.
|
| 285 |
+
tile = tl.program_id(0)
|
| 286 |
+
tl.device_assert(tile >= 0, "tile < 0 (at initialization)")
|
| 287 |
+
|
| 288 |
+
# Tile limit of last MM problem (inclusive).
|
| 289 |
+
last_mm_tile = 0
|
| 290 |
+
|
| 291 |
+
# Last input column of lhs and input row of rhs. Each group reads some
|
| 292 |
+
# columns of lhs and some rows of rhs.
|
| 293 |
+
last_m = 0
|
| 294 |
+
|
| 295 |
+
# Loop through all (K, m, N) MM problems:
|
| 296 |
+
# (K, m) x (m, N) = (K, N)
|
| 297 |
+
# sum(m) = M
|
| 298 |
+
for g in range(G):
|
| 299 |
+
# Get m dimension of current MM problem.
|
| 300 |
+
m = tl.load(group_sizes_ptr + g)
|
| 301 |
+
# m can be zero if group is empty
|
| 302 |
+
tl.device_assert(m >= 0, "m < 0")
|
| 303 |
+
|
| 304 |
+
# Loop through tiles of current MM problem.
|
| 305 |
+
while tile >= last_mm_tile and tile < last_mm_tile + num_tiles:
|
| 306 |
+
# Figure out tile coordinates in current MM problem.
|
| 307 |
+
tile_in_mm = tile - last_mm_tile
|
| 308 |
+
tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0")
|
| 309 |
+
|
| 310 |
+
tile_k, tile_n = _remap_xcd_tile_grid(
|
| 311 |
+
tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# Do regular MM:
|
| 315 |
+
|
| 316 |
+
tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0")
|
| 317 |
+
tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0")
|
| 318 |
+
|
| 319 |
+
offs_lhs_k = (
|
| 320 |
+
tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
| 321 |
+
) % K
|
| 322 |
+
offs_rhs_n = (
|
| 323 |
+
tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 324 |
+
) % N
|
| 325 |
+
offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
| 326 |
+
|
| 327 |
+
if TRANS_LHS:
|
| 328 |
+
lhs_ptrs = (
|
| 329 |
+
lhs_ptr + offs_lhs_k[:, None] + (last_m + offs_m[None, :]) * K
|
| 330 |
+
)
|
| 331 |
+
else:
|
| 332 |
+
lhs_ptrs = (
|
| 333 |
+
lhs_ptr + offs_lhs_k[:, None] * M + (last_m + offs_m[None, :])
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
rhs_ptrs = rhs_ptr + (last_m + offs_m[:, None]) * N + offs_rhs_n[None, :]
|
| 337 |
+
|
| 338 |
+
loop_m = tl.cdiv(m, BLOCK_SIZE_M)
|
| 339 |
+
m_divisible_by_block_m = m % BLOCK_SIZE_M == 0
|
| 340 |
+
if not m_divisible_by_block_m:
|
| 341 |
+
loop_m -= 1
|
| 342 |
+
|
| 343 |
+
acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32)
|
| 344 |
+
|
| 345 |
+
# Initialize bias accumulator
|
| 346 |
+
bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)
|
| 347 |
+
|
| 348 |
+
for _ in range(0, loop_m):
|
| 349 |
+
lhs = tl.load(lhs_ptrs)
|
| 350 |
+
rhs = tl.load(rhs_ptrs)
|
| 351 |
+
|
| 352 |
+
acc = tl.dot(lhs, rhs, acc=acc)
|
| 353 |
+
|
| 354 |
+
# Accumulate for bias gradient: sum lhs across M dimension
|
| 355 |
+
if COMPUTE_BIAS_GRAD and tile_n == 0:
|
| 356 |
+
bias_acc += tl.sum(
|
| 357 |
+
lhs, axis=1
|
| 358 |
+
) # Sum across M dimension [K, M] -> [K]
|
| 359 |
+
|
| 360 |
+
if TRANS_LHS:
|
| 361 |
+
lhs_ptrs += BLOCK_SIZE_M * K
|
| 362 |
+
else:
|
| 363 |
+
lhs_ptrs += BLOCK_SIZE_M
|
| 364 |
+
|
| 365 |
+
rhs_ptrs += BLOCK_SIZE_M * N
|
| 366 |
+
|
| 367 |
+
if not m_divisible_by_block_m:
|
| 368 |
+
offs_lhs_k = (
|
| 369 |
+
tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
| 370 |
+
) % K
|
| 371 |
+
offs_rhs_n = (
|
| 372 |
+
tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 373 |
+
) % N
|
| 374 |
+
offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 375 |
+
lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0)
|
| 376 |
+
rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0)
|
| 377 |
+
acc = tl.dot(lhs, rhs, acc=acc)
|
| 378 |
+
|
| 379 |
+
# Accumulate last chunk for bias gradient
|
| 380 |
+
if COMPUTE_BIAS_GRAD and tile_n == 0:
|
| 381 |
+
bias_acc += tl.sum(lhs, axis=1)
|
| 382 |
+
|
| 383 |
+
acc = acc.to(out_ptr.type.element_ty)
|
| 384 |
+
|
| 385 |
+
offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
| 386 |
+
offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 387 |
+
|
| 388 |
+
out_ptrs = (
|
| 389 |
+
out_ptr
|
| 390 |
+
+ g.to(tl.int64) * K * N
|
| 391 |
+
+ offs_out_k[:, None] * N
|
| 392 |
+
+ offs_out_n[None, :]
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N)
|
| 396 |
+
if ACCUMULATE:
|
| 397 |
+
# Load existing values and add to them (like beta=1 in BLAS)
|
| 398 |
+
old_vals = tl.load(out_ptrs, mask=mask, other=0.0)
|
| 399 |
+
tl.store(out_ptrs, acc + old_vals, mask=mask)
|
| 400 |
+
else:
|
| 401 |
+
# Overwrite output (like beta=0 in BLAS)
|
| 402 |
+
tl.store(out_ptrs, acc, mask=mask)
|
| 403 |
+
|
| 404 |
+
# Store bias gradient (only for first N tile, sum across all M)
|
| 405 |
+
if COMPUTE_BIAS_GRAD and tile_n == 0:
|
| 406 |
+
# Keep as float32 for atomic_add (bf16 not supported for atomics)
|
| 407 |
+
bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k
|
| 408 |
+
# Use atomic add since multiple K-tiles may write to same expert's bias
|
| 409 |
+
tl.atomic_add(
|
| 410 |
+
bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed"
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
# Go to the next tile by advancing number of programs.
|
| 414 |
+
tile += GRID_DIM
|
| 415 |
+
tl.device_assert(tile > 0, "tile <= 0 (at update)")
|
| 416 |
+
|
| 417 |
+
# Get ready to go to the next MM problem.
|
| 418 |
+
|
| 419 |
+
last_mm_tile += num_tiles
|
| 420 |
+
# last_mm_tile can be zero if group 0 is skipped
|
| 421 |
+
tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)")
|
| 422 |
+
|
| 423 |
+
last_m += m
|
| 424 |
+
# last_m can be zero if group 0 is skipped
|
| 425 |
+
tl.device_assert(last_m >= 0, "last_m < 0 (at update)")
|
| 426 |
+
tl.device_assert(last_m <= M, "last_m > M (at update)")
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
# Regular non-persistent TGMM kernel.
|
| 430 |
+
# ------------------------------------------------------------------------------
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
@triton.heuristics({"BLOCK_SIZE_G": lambda META: triton.next_power_of_2(META["G"])})
|
| 434 |
+
@triton.jit
|
| 435 |
+
def tgmm_non_persistent_kernel(
|
| 436 |
+
# Tensor pointers:
|
| 437 |
+
lhs_ptr,
|
| 438 |
+
rhs_ptr,
|
| 439 |
+
group_sizes_ptr,
|
| 440 |
+
out_ptr,
|
| 441 |
+
bias_grad_ptr,
|
| 442 |
+
# Tensor shapes:
|
| 443 |
+
M: int,
|
| 444 |
+
K: int,
|
| 445 |
+
N: int,
|
| 446 |
+
G: int,
|
| 447 |
+
# Meta-parameters:
|
| 448 |
+
TRANS_LHS: tl.constexpr,
|
| 449 |
+
BLOCK_SIZE_G: tl.constexpr,
|
| 450 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 451 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 452 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 453 |
+
GROUP_SIZE: tl.constexpr,
|
| 454 |
+
COMPUTE_BIAS_GRAD: tl.constexpr,
|
| 455 |
+
ACCUMULATE: tl.constexpr,
|
| 456 |
+
):
|
| 457 |
+
tl.assume(M > 0)
|
| 458 |
+
tl.assume(K > 0)
|
| 459 |
+
tl.assume(N > 0)
|
| 460 |
+
tl.assume(G > 0)
|
| 461 |
+
|
| 462 |
+
# Get group ID from grid.
|
| 463 |
+
g = tl.program_id(0)
|
| 464 |
+
tl.device_assert(g >= 0, "g < 0")
|
| 465 |
+
tl.device_assert(g < G, "g >= G")
|
| 466 |
+
|
| 467 |
+
# Get m dimension of current MM group.
|
| 468 |
+
m = tl.load(group_sizes_ptr + g)
|
| 469 |
+
# m can be zero if group is empty.
|
| 470 |
+
tl.device_assert(m >= 0, "m < 0")
|
| 471 |
+
|
| 472 |
+
# Skip empty groups.
|
| 473 |
+
if m == 0:
|
| 474 |
+
return
|
| 475 |
+
|
| 476 |
+
# Compute sum(group_sizes) until current group g.
|
| 477 |
+
# It's the starting column of lhs and starting row of rhs.
|
| 478 |
+
offs_g = tl.arange(0, BLOCK_SIZE_G)
|
| 479 |
+
group_sizes = tl.load(group_sizes_ptr + offs_g, mask=offs_g < g, other=0)
|
| 480 |
+
start_m = tl.sum(group_sizes)
|
| 481 |
+
|
| 482 |
+
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
| 483 |
+
tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0")
|
| 484 |
+
|
| 485 |
+
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
|
| 486 |
+
tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0")
|
| 487 |
+
|
| 488 |
+
# Get MM tile from grid.
|
| 489 |
+
tile_in_mm = tl.program_id(1)
|
| 490 |
+
tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0")
|
| 491 |
+
|
| 492 |
+
tile_k, tile_n = _remap_xcd_tile_grid(
|
| 493 |
+
tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0")
|
| 497 |
+
tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0")
|
| 498 |
+
|
| 499 |
+
offs_lhs_k = (tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K
|
| 500 |
+
offs_rhs_n = (tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
| 501 |
+
offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
| 502 |
+
|
| 503 |
+
if TRANS_LHS:
|
| 504 |
+
lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] + (start_m + offs_m[None, :]) * K
|
| 505 |
+
else:
|
| 506 |
+
lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] * M + (start_m + offs_m[None, :])
|
| 507 |
+
|
| 508 |
+
rhs_ptrs = rhs_ptr + (start_m + offs_m[:, None]) * N + offs_rhs_n[None, :]
|
| 509 |
+
|
| 510 |
+
loop_m = tl.cdiv(m, BLOCK_SIZE_M)
|
| 511 |
+
m_divisible_by_block_m = m % BLOCK_SIZE_M == 0
|
| 512 |
+
if not m_divisible_by_block_m:
|
| 513 |
+
loop_m -= 1
|
| 514 |
+
|
| 515 |
+
acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32)
|
| 516 |
+
# Initialize bias accumulator
|
| 517 |
+
bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)
|
| 518 |
+
|
| 519 |
+
for _ in range(0, loop_m):
|
| 520 |
+
lhs = tl.load(lhs_ptrs)
|
| 521 |
+
rhs = tl.load(rhs_ptrs)
|
| 522 |
+
|
| 523 |
+
acc = tl.dot(lhs, rhs, acc=acc)
|
| 524 |
+
|
| 525 |
+
# Accumulate for bias gradient: sum lhs across M dimension
|
| 526 |
+
if COMPUTE_BIAS_GRAD and tile_n == 0:
|
| 527 |
+
bias_acc += tl.sum(lhs, axis=1) # [K, M] -> [K]
|
| 528 |
+
|
| 529 |
+
if TRANS_LHS:
|
| 530 |
+
lhs_ptrs += BLOCK_SIZE_M * K
|
| 531 |
+
else:
|
| 532 |
+
lhs_ptrs += BLOCK_SIZE_M
|
| 533 |
+
|
| 534 |
+
rhs_ptrs += BLOCK_SIZE_M * N
|
| 535 |
+
|
| 536 |
+
if not m_divisible_by_block_m:
|
| 537 |
+
offs_lhs_k = (
|
| 538 |
+
tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
| 539 |
+
) % K
|
| 540 |
+
offs_rhs_n = (
|
| 541 |
+
tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 542 |
+
) % N
|
| 543 |
+
offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 544 |
+
lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0)
|
| 545 |
+
rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0)
|
| 546 |
+
acc = tl.dot(lhs, rhs, acc=acc)
|
| 547 |
+
# Accumulate last chunk for bias gradient
|
| 548 |
+
if COMPUTE_BIAS_GRAD and tile_n == 0:
|
| 549 |
+
bias_acc += tl.sum(lhs, axis=1)
|
| 550 |
+
|
| 551 |
+
acc = acc.to(out_ptr.type.element_ty)
|
| 552 |
+
|
| 553 |
+
offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
| 554 |
+
offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 555 |
+
|
| 556 |
+
out_ptrs = (
|
| 557 |
+
out_ptr + g.to(tl.int64) * K * N + offs_out_k[:, None] * N + offs_out_n[None, :]
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N)
|
| 561 |
+
if ACCUMULATE:
|
| 562 |
+
# Load existing values and add to them (like beta=1 in BLAS)
|
| 563 |
+
old_vals = tl.load(out_ptrs, mask=mask, other=0.0)
|
| 564 |
+
tl.store(out_ptrs, acc + old_vals, mask=mask)
|
| 565 |
+
else:
|
| 566 |
+
# Overwrite output (like beta=0 in BLAS)
|
| 567 |
+
tl.store(out_ptrs, acc, mask=mask)
|
| 568 |
+
|
| 569 |
+
# Store bias gradient (only for first N tile, sum across all M)
|
| 570 |
+
if COMPUTE_BIAS_GRAD and tile_n == 0:
|
| 571 |
+
# Keep as float32 for atomic_add (bf16/fp16 not supported for atomics)
|
| 572 |
+
bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k
|
| 573 |
+
# Use atomic add since multiple K-tiles may write to same expert's bias
|
| 574 |
+
tl.atomic_add(bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed")
|
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/adapter.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Adapt AITER's Triton grouped GEMM to MegaBlocks' ``gmm`` calling convention.
|
| 3 |
+
|
| 4 |
+
MegaBlocks (following tgale96/grouped_gemm) uses a single ``gmm`` entry point
|
| 5 |
+
with ``trans_a`` / ``trans_b`` flags:
|
| 6 |
+
|
| 7 |
+
* ``trans_a=False, trans_b=False``: a(M,K) @ b(G,K,N) -> c(M,N)
|
| 8 |
+
* ``trans_a=False, trans_b=True`` : a(M,K) @ b(G,N,K)^T -> c(M,N) (dgrad)
|
| 9 |
+
* ``trans_a=True`` : a(M,K)^T @ b(M,N) per group -> c(G,K,N) (wgrad)
|
| 10 |
+
|
| 11 |
+
AITER exposes these as two kernels: ``gmm`` ((M,K)@(G,K,N)->(M,N), transposition
|
| 12 |
+
of the 3D operand inferred from strides) and ``ptgmm`` ((K,M)@(M,N)->(G,K,N),
|
| 13 |
+
transposition of the 2D operand inferred from strides).
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
from .gmm import gmm as _aiter_gmm
|
| 19 |
+
from .gmm import ptgmm as _aiter_ptgmm
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def gmm(a, b, c, batch_sizes, trans_a=False, trans_b=False):
|
| 23 |
+
# AITER requires group sizes to be int32 and to live on the compute device.
|
| 24 |
+
group_sizes = batch_sizes.to(device=a.device, dtype=torch.int32)
|
| 25 |
+
|
| 26 |
+
# AITER asserts exact strides: gmm wants lhs/rhs row-major (a transposed
|
| 27 |
+
# 3D operand must be exactly column-major), tgmm wants rhs row-major and
|
| 28 |
+
# lhs row/column-major. Make operands contiguous first so the transposed
|
| 29 |
+
# views have the precise strides the kernels expect. `.contiguous()` is a
|
| 30 |
+
# no-op when the tensor is already contiguous.
|
| 31 |
+
if trans_a:
|
| 32 |
+
# Weight gradient: a(M,K), b(M,N) -> c(G,K,N).
|
| 33 |
+
# Pass a transposed so AITER sees lhs(K,M) column-major (TRANS_LHS).
|
| 34 |
+
_aiter_ptgmm(
|
| 35 |
+
a.contiguous().transpose(0, 1),
|
| 36 |
+
b.contiguous(),
|
| 37 |
+
group_sizes,
|
| 38 |
+
preferred_element_type=c.dtype,
|
| 39 |
+
existing_out=c,
|
| 40 |
+
)
|
| 41 |
+
else:
|
| 42 |
+
# trans_b contracts b's last dim: pass a column-major (G,K,N) view.
|
| 43 |
+
rhs = b.contiguous()
|
| 44 |
+
if trans_b:
|
| 45 |
+
rhs = rhs.transpose(1, 2)
|
| 46 |
+
_aiter_gmm(
|
| 47 |
+
a.contiguous(),
|
| 48 |
+
rhs,
|
| 49 |
+
group_sizes,
|
| 50 |
+
preferred_element_type=c.dtype,
|
| 51 |
+
existing_out=c,
|
| 52 |
+
)
|
| 53 |
+
return c
|
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/configs.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
# Tuned GMM configs vendored from ROCm/aiter (aiter/ops/triton/configs/).
|
| 3 |
+
# Inlined as a Python module so packaging always includes them.
|
| 4 |
+
|
| 5 |
+
CONFIGS = {'gfx1250': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}, 'gfx942': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}, 'gfx950': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}}
|
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/gmm.py
ADDED
|
@@ -0,0 +1,567 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Imports.
|
| 6 |
+
# ------------------------------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
# PyTorch
|
| 9 |
+
import torch
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
|
| 12 |
+
# Triton
|
| 13 |
+
import triton
|
| 14 |
+
|
| 15 |
+
# AITER: GMM utility functions
|
| 16 |
+
from .utils.gmm_common import (
|
| 17 |
+
DTYPE,
|
| 18 |
+
is_power_of_2,
|
| 19 |
+
check_input_device_dtype,
|
| 20 |
+
check_bias_shape_stride,
|
| 21 |
+
get_gmm_shape,
|
| 22 |
+
get_gmm_output,
|
| 23 |
+
get_gmm_transposition,
|
| 24 |
+
get_tgmm_shape,
|
| 25 |
+
get_tgmm_output,
|
| 26 |
+
get_tgmm_bias_grad,
|
| 27 |
+
get_tgmm_transposition,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# AITER: GMM Triton kernels
|
| 31 |
+
from ._triton_kernels.gmm import (
|
| 32 |
+
gmm_kernel,
|
| 33 |
+
tgmm_persistent_kernel,
|
| 34 |
+
tgmm_non_persistent_kernel,
|
| 35 |
+
get_config,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# GMM PyTorch wrapper.
|
| 39 |
+
# ------------------------------------------------------------------------------
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _gmm_grid(
|
| 43 |
+
N: int,
|
| 44 |
+
block_size_m: int,
|
| 45 |
+
block_size_n: int,
|
| 46 |
+
group_sizes: Tensor,
|
| 47 |
+
grid_dim: int,
|
| 48 |
+
) -> tuple[int]:
|
| 49 |
+
assert N > 0, f"N must be positive, it's {N}."
|
| 50 |
+
assert is_power_of_2(
|
| 51 |
+
block_size_m
|
| 52 |
+
), f"M-dimension tile size must be a power of 2 (it's {block_size_m})."
|
| 53 |
+
assert is_power_of_2(
|
| 54 |
+
block_size_n
|
| 55 |
+
), f"N-dimension tile size must be a power of 2 (it's {block_size_n})."
|
| 56 |
+
assert torch.all(group_sizes >= 0).item(), "All group_sizes must be non-negative."
|
| 57 |
+
assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})."
|
| 58 |
+
num_m_tiles = (group_sizes + block_size_m - 1) // block_size_m
|
| 59 |
+
assert torch.all(num_m_tiles >= 0).item(), "All num_m_tiles must be non-negative."
|
| 60 |
+
num_n_tiles = triton.cdiv(N, block_size_n)
|
| 61 |
+
assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}."
|
| 62 |
+
num_tiles = torch.sum(num_m_tiles * num_n_tiles).item()
|
| 63 |
+
assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}."
|
| 64 |
+
num_programs = int(min(grid_dim, num_tiles))
|
| 65 |
+
assert num_programs > 0, f"num_programs must be positive, it's {num_programs}."
|
| 66 |
+
return (num_programs,)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def gmm(
|
| 70 |
+
lhs: Tensor,
|
| 71 |
+
rhs: Tensor,
|
| 72 |
+
group_sizes: Tensor,
|
| 73 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 74 |
+
existing_out: Tensor | None = None,
|
| 75 |
+
config: dict[str, int] | None = None,
|
| 76 |
+
bias: Tensor | None = None,
|
| 77 |
+
) -> Tensor:
|
| 78 |
+
"""
|
| 79 |
+
Perform Group Matrix Multiplication (GMM): out = lhs @ rhs + bias
|
| 80 |
+
|
| 81 |
+
lhs rows are divided into G groups. Each group of lhs rows is matrix multiplied with a plane of
|
| 82 |
+
rhs 3D tensor and then stored in a slice of out. In PyTorch parlance, it can be implemented as
|
| 83 |
+
follows for a given group g:
|
| 84 |
+
out[group_start:group_end, :] = lhs[group_start:group_end, :] @ rhs[g] + bias[g]
|
| 85 |
+
|
| 86 |
+
The size of each group, and their respective start and end positions are specified by
|
| 87 |
+
group_sizes tensor. For instance, suppose that group_sizes = [3, 2, 4, 1]. In this particular
|
| 88 |
+
case we have 4 groups. The 1st group starts at 0 and ends at 2, the second group starts at 3 and
|
| 89 |
+
ends at 4, the third group starts at 5 and ends at 8, and the fourth and final group consists of
|
| 90 |
+
just the 10th (last) row of lhs.
|
| 91 |
+
|
| 92 |
+
Parameters
|
| 93 |
+
----------
|
| 94 |
+
lhs : torch.Tensor
|
| 95 |
+
Left-hand side 2D input tensor. Shape: (M, K).
|
| 96 |
+
lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type.
|
| 97 |
+
lhs must be on the same device of rhs and group_sizes.
|
| 98 |
+
rhs : torch.Tensor
|
| 99 |
+
Right-hand side 3D input tensor. Shape: (G, K, N).
|
| 100 |
+
rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type.
|
| 101 |
+
rhs must be on the same device of lhs and group_sizes.
|
| 102 |
+
group_sizes : torch.Tensor
|
| 103 |
+
1D input tensor describing group sizes. Shape: (G,).
|
| 104 |
+
group_sizes data type must be torch.int32 and all its elements must be non-negative.
|
| 105 |
+
group_sizes must be on the same device of lhs and rhs.
|
| 106 |
+
preferred_element_type : torch.dtype, optional
|
| 107 |
+
Desired data type for output tensor. Default is torch.bfloat16.
|
| 108 |
+
Supported output types are torch.float16 and torch.bfloat16.
|
| 109 |
+
existing_out : torch.Tensor or None, optional
|
| 110 |
+
Preallocated output tensor. Default is None.
|
| 111 |
+
If provided, results are written into this tensor. Otherwise, a new output tensor is
|
| 112 |
+
allocated.
|
| 113 |
+
If provided then it must have shape (M, N), its data type must match preferred_element_type
|
| 114 |
+
and it must be on the same device of other input tensors.
|
| 115 |
+
config : dict[str, int] or None, optional
|
| 116 |
+
Optional dictionary with kernel metaparameters. If absent, config will be queried from
|
| 117 |
+
internal tuning database.
|
| 118 |
+
bias : torch.Tensor or None, optional
|
| 119 |
+
Optional bias tensor. Shape: (G, N).
|
| 120 |
+
If provided, bias data type must match lhs and rhs data type, and bias must be on the same
|
| 121 |
+
device as other input tensors. Each group g adds bias[g] to the output.
|
| 122 |
+
|
| 123 |
+
Returns
|
| 124 |
+
-------
|
| 125 |
+
torch.Tensor
|
| 126 |
+
The computed output 2D tensor. Shape: (M, N).
|
| 127 |
+
Output tensor data type is given by preferred_element_type.
|
| 128 |
+
If existing_out is provided then existing_out is also returned.
|
| 129 |
+
|
| 130 |
+
Implementation Notes
|
| 131 |
+
--------------------
|
| 132 |
+
- GMM is implemented with a persistent Triton kernel.
|
| 133 |
+
- lhs must be row-major (lhs.stride() == (K, 1)).
|
| 134 |
+
- rhs can be row-major (rhs.stride() == (K * N, N, 1)) or column-major (rhs.stride() ==
|
| 135 |
+
(K * N, 1, K)). If rhs is row-major then kernel parameter TRANS_RHS == False, this is useful
|
| 136 |
+
for implementing forward pass. If rhs is column-major then kernel parameter TRANS_RHS == True,
|
| 137 |
+
this is useful for computing the lhs derivative in the backward pass, while fusing the
|
| 138 |
+
transposition.
|
| 139 |
+
- out must be row-major (out.stride() == (N, 1)).
|
| 140 |
+
- bias must be row-major (bias.stride() == (N, 1)) if provided.
|
| 141 |
+
"""
|
| 142 |
+
use_bias = bias is not None
|
| 143 |
+
check_input_device_dtype(lhs, rhs, group_sizes, bias)
|
| 144 |
+
|
| 145 |
+
M, K, N, G = get_gmm_shape(lhs, rhs, group_sizes)
|
| 146 |
+
|
| 147 |
+
if use_bias:
|
| 148 |
+
check_bias_shape_stride(bias, G, N)
|
| 149 |
+
|
| 150 |
+
out = get_gmm_output(
|
| 151 |
+
M,
|
| 152 |
+
N,
|
| 153 |
+
device=lhs.device,
|
| 154 |
+
preferred_element_type=preferred_element_type,
|
| 155 |
+
existing_out=existing_out,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
trans_rhs, _ = get_gmm_transposition(lhs, rhs, out)
|
| 159 |
+
|
| 160 |
+
if config is None:
|
| 161 |
+
config = get_config("gmm", M, K, N, G)
|
| 162 |
+
|
| 163 |
+
assert all(
|
| 164 |
+
key in config
|
| 165 |
+
and isinstance(config[key], int)
|
| 166 |
+
and (
|
| 167 |
+
is_power_of_2(config[key])
|
| 168 |
+
if key.startswith("BLOCK_SIZE_")
|
| 169 |
+
else config[key] > 0
|
| 170 |
+
)
|
| 171 |
+
for key in {
|
| 172 |
+
"BLOCK_SIZE_M",
|
| 173 |
+
"BLOCK_SIZE_K",
|
| 174 |
+
"BLOCK_SIZE_N",
|
| 175 |
+
"GROUP_SIZE",
|
| 176 |
+
"GRID_DIM",
|
| 177 |
+
}
|
| 178 |
+
), "Invalid GMM kernel config."
|
| 179 |
+
|
| 180 |
+
grid = _gmm_grid(
|
| 181 |
+
N,
|
| 182 |
+
config["BLOCK_SIZE_M"],
|
| 183 |
+
config["BLOCK_SIZE_N"],
|
| 184 |
+
group_sizes,
|
| 185 |
+
config["GRID_DIM"],
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# fmt: off
|
| 189 |
+
gmm_kernel[grid](
|
| 190 |
+
# Tensor pointers:
|
| 191 |
+
lhs, rhs, group_sizes, out, bias,
|
| 192 |
+
# Tensor shapes:
|
| 193 |
+
M, K, N, G,
|
| 194 |
+
# Meta-parameters:
|
| 195 |
+
TRANS_RHS=trans_rhs,
|
| 196 |
+
USE_BIAS=use_bias,
|
| 197 |
+
**config,
|
| 198 |
+
)
|
| 199 |
+
# fmt: on
|
| 200 |
+
|
| 201 |
+
return out
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# Persistent TGMM PyTorch wrapper.
|
| 205 |
+
# ------------------------------------------------------------------------------
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def _ptgmm_grid(
|
| 209 |
+
K: int,
|
| 210 |
+
N: int,
|
| 211 |
+
G: int,
|
| 212 |
+
block_size_k: int,
|
| 213 |
+
block_size_n: int,
|
| 214 |
+
grid_dim: int,
|
| 215 |
+
) -> tuple[int]:
|
| 216 |
+
assert K > 0, f"K must be positive, it's {K}."
|
| 217 |
+
assert N > 0, f"N must be positive, it's {N}."
|
| 218 |
+
assert G > 0, f"G must be positive, it's {G}."
|
| 219 |
+
assert is_power_of_2(
|
| 220 |
+
block_size_k
|
| 221 |
+
), f"K-dimension tile size must be a power of 2 (it's {block_size_k})."
|
| 222 |
+
assert is_power_of_2(
|
| 223 |
+
block_size_n
|
| 224 |
+
), f"N-dimension tile size must be a power of 2 (it's {block_size_n})."
|
| 225 |
+
assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})."
|
| 226 |
+
num_k_tiles = triton.cdiv(K, block_size_k)
|
| 227 |
+
assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}."
|
| 228 |
+
num_n_tiles = triton.cdiv(N, block_size_n)
|
| 229 |
+
assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}."
|
| 230 |
+
num_tiles = G * num_k_tiles * num_n_tiles
|
| 231 |
+
assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}."
|
| 232 |
+
num_programs = min(grid_dim, num_tiles)
|
| 233 |
+
assert num_programs > 0, f"num_programs must be positive, it's {num_programs}."
|
| 234 |
+
return (num_programs,)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def ptgmm(
|
| 238 |
+
lhs: Tensor,
|
| 239 |
+
rhs: Tensor,
|
| 240 |
+
group_sizes: Tensor,
|
| 241 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 242 |
+
existing_out: Tensor | None = None,
|
| 243 |
+
config: dict[str, int] | None = None,
|
| 244 |
+
bias_grad: Tensor | None = None,
|
| 245 |
+
accumulate: bool = False,
|
| 246 |
+
) -> Tensor:
|
| 247 |
+
"""
|
| 248 |
+
Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs
|
| 249 |
+
|
| 250 |
+
lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with
|
| 251 |
+
the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch
|
| 252 |
+
parlance, it can be implemented as follows for a given group g:
|
| 253 |
+
out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :]
|
| 254 |
+
|
| 255 |
+
The 't' in the operator name derives from MaxText implementation
|
| 256 |
+
(https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py),
|
| 257 |
+
which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor
|
| 258 |
+
shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N).
|
| 259 |
+
|
| 260 |
+
The 'p' in the operator name means that it is implemented with a persistent kernel. There is
|
| 261 |
+
also the non-persistent variation, which is implemented with a regular kernel. Please take a
|
| 262 |
+
look at nptgmm operator. Both ptgmm and nptgmm implement the same computation, choosing one or
|
| 263 |
+
the other is a matter of performance for the target workload.
|
| 264 |
+
|
| 265 |
+
Parameters
|
| 266 |
+
----------
|
| 267 |
+
lhs : torch.Tensor
|
| 268 |
+
Left-hand side 2D input tensor. Shape: (K, M).
|
| 269 |
+
lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type.
|
| 270 |
+
lhs must be on the same device of rhs and group_sizes.
|
| 271 |
+
rhs : torch.Tensor
|
| 272 |
+
Right-hand side 2D input tensor. Shape: (M, N).
|
| 273 |
+
rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type.
|
| 274 |
+
rhs must be on the same device of lhs and group_sizes.
|
| 275 |
+
group_sizes : torch.Tensor
|
| 276 |
+
1D input tensor describing group sizes. Shape: (G,).
|
| 277 |
+
group_sizes data type must be torch.int32 and all its elements must be non-negative.
|
| 278 |
+
group_sizes must be on the same device of lhs and rhs.
|
| 279 |
+
preferred_element_type : torch.dtype, optional
|
| 280 |
+
Desired data type for output tensor. Default is torch.bfloat16.
|
| 281 |
+
Supported output types are torch.float16 and torch.bfloat16.
|
| 282 |
+
existing_out : torch.Tensor or None, optional
|
| 283 |
+
Preallocated output tensor. Default is None.
|
| 284 |
+
If provided, results are written into this tensor. Otherwise, a new output tensor is
|
| 285 |
+
allocated.
|
| 286 |
+
If provided then it must have shape (G, K, N), its data type must match
|
| 287 |
+
preferred_element_type and it must be on the same device of other input tensors.
|
| 288 |
+
config : dict[str, int] or None, optional
|
| 289 |
+
Optional dictionary with kernel metaparameters. If absent, config will be queried from
|
| 290 |
+
internal tuning database.
|
| 291 |
+
bias_grad : torch.Tensor or None, optional
|
| 292 |
+
Optional bias gradient output tensor. Shape: (G, K).
|
| 293 |
+
If provided, the kernel will compute the bias gradient and write it to this tensor.
|
| 294 |
+
bias_grad must be torch.float32 (kernel uses atomic_add which requires float32),
|
| 295 |
+
accumulate : bool, optional
|
| 296 |
+
Whether to accumulate into existing output tensor values. Default is False.
|
| 297 |
+
If False, output will be overwritten with fresh computation.
|
| 298 |
+
If True, results will be added to existing output tensor values.
|
| 299 |
+
|
| 300 |
+
Returns
|
| 301 |
+
-------
|
| 302 |
+
torch.Tensor
|
| 303 |
+
The computed output 3D tensor. Shape: (G, K, N).
|
| 304 |
+
Output tensor data type is given by preferred_element_type.
|
| 305 |
+
If existing_out is provided then existing_out is also returned.
|
| 306 |
+
|
| 307 |
+
Implementation Notes
|
| 308 |
+
--------------------
|
| 309 |
+
- PTGMM is implemented with a persistent Triton kernel.
|
| 310 |
+
- lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs
|
| 311 |
+
is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel
|
| 312 |
+
parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward
|
| 313 |
+
pass, while fusing the transposition.
|
| 314 |
+
- rhs must be row-major (rhs.stride() == (N, 1)).
|
| 315 |
+
- out must be row-major (out.stride() == (K * N, N, 1)).
|
| 316 |
+
"""
|
| 317 |
+
check_input_device_dtype(lhs, rhs, group_sizes)
|
| 318 |
+
|
| 319 |
+
M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes)
|
| 320 |
+
|
| 321 |
+
out = get_tgmm_output(
|
| 322 |
+
K,
|
| 323 |
+
N,
|
| 324 |
+
G,
|
| 325 |
+
device=lhs.device,
|
| 326 |
+
preferred_element_type=preferred_element_type,
|
| 327 |
+
existing_out=existing_out,
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out)
|
| 331 |
+
|
| 332 |
+
if config is None:
|
| 333 |
+
config = get_config("ptgmm", M, K, N, G, accumulate)
|
| 334 |
+
|
| 335 |
+
assert all(
|
| 336 |
+
key in config
|
| 337 |
+
and isinstance(config[key], int)
|
| 338 |
+
and (
|
| 339 |
+
is_power_of_2(config[key])
|
| 340 |
+
if key.startswith("BLOCK_SIZE_")
|
| 341 |
+
else config[key] > 0
|
| 342 |
+
)
|
| 343 |
+
for key in {
|
| 344 |
+
"BLOCK_SIZE_M",
|
| 345 |
+
"BLOCK_SIZE_K",
|
| 346 |
+
"BLOCK_SIZE_N",
|
| 347 |
+
"GROUP_SIZE",
|
| 348 |
+
"GRID_DIM",
|
| 349 |
+
}
|
| 350 |
+
), "Invalid PTGMM kernel config."
|
| 351 |
+
|
| 352 |
+
# Bias gradient handling.
|
| 353 |
+
# -----------------------
|
| 354 |
+
# Get or validate bias gradient tensor.
|
| 355 |
+
compute_bias_grad = bias_grad is not None
|
| 356 |
+
bias_grad_ptr = get_tgmm_bias_grad(
|
| 357 |
+
K,
|
| 358 |
+
G,
|
| 359 |
+
device=lhs.device,
|
| 360 |
+
existing_bias_grad=bias_grad,
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
grid = _ptgmm_grid(
|
| 364 |
+
K,
|
| 365 |
+
N,
|
| 366 |
+
G,
|
| 367 |
+
config["BLOCK_SIZE_K"],
|
| 368 |
+
config["BLOCK_SIZE_N"],
|
| 369 |
+
config["GRID_DIM"],
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
# fmt: off
|
| 373 |
+
tgmm_persistent_kernel[grid](
|
| 374 |
+
# Tensor pointers:
|
| 375 |
+
lhs, rhs, group_sizes, out, bias_grad_ptr,
|
| 376 |
+
# Tensor shapes:
|
| 377 |
+
M, K, N, G,
|
| 378 |
+
# Meta-parameters:
|
| 379 |
+
TRANS_LHS=trans_lhs,
|
| 380 |
+
COMPUTE_BIAS_GRAD=compute_bias_grad,
|
| 381 |
+
ACCUMULATE=accumulate,
|
| 382 |
+
**config,
|
| 383 |
+
)
|
| 384 |
+
# fmt: on
|
| 385 |
+
|
| 386 |
+
return out
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
# Regular non-persistent TGMM PyTorch wrapper.
|
| 390 |
+
# ------------------------------------------------------------------------------
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def _nptgmm_grid(
|
| 394 |
+
K: int,
|
| 395 |
+
N: int,
|
| 396 |
+
G: int,
|
| 397 |
+
block_size_k: int,
|
| 398 |
+
block_size_n: int,
|
| 399 |
+
) -> tuple[int, int]:
|
| 400 |
+
assert K > 0, f"K must be positive, it's {K}."
|
| 401 |
+
assert N > 0, f"N must be positive, it's {N}."
|
| 402 |
+
assert G > 0, f"G must be positive, it's {G}."
|
| 403 |
+
assert is_power_of_2(
|
| 404 |
+
block_size_k
|
| 405 |
+
), f"K-dimension tile size must be a power of 2 (it's {block_size_k})."
|
| 406 |
+
assert is_power_of_2(
|
| 407 |
+
block_size_n
|
| 408 |
+
), f"N-dimension tile size must be a power of 2 (it's {block_size_n})."
|
| 409 |
+
num_k_tiles = triton.cdiv(K, block_size_k)
|
| 410 |
+
assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}."
|
| 411 |
+
num_n_tiles = triton.cdiv(N, block_size_n)
|
| 412 |
+
assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}."
|
| 413 |
+
num_tiles_per_mm = num_k_tiles * num_n_tiles
|
| 414 |
+
assert (
|
| 415 |
+
num_tiles_per_mm > 0
|
| 416 |
+
), f"num_tiles_per_mm must be positive, it's {num_tiles_per_mm}."
|
| 417 |
+
return (G, num_tiles_per_mm)
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def nptgmm(
|
| 421 |
+
lhs: Tensor,
|
| 422 |
+
rhs: Tensor,
|
| 423 |
+
group_sizes: Tensor,
|
| 424 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 425 |
+
existing_out: Tensor | None = None,
|
| 426 |
+
config: dict[str, int] | None = None,
|
| 427 |
+
bias_grad: Tensor | None = None,
|
| 428 |
+
accumulate: bool = False,
|
| 429 |
+
) -> Tensor:
|
| 430 |
+
"""
|
| 431 |
+
Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs
|
| 432 |
+
|
| 433 |
+
lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with
|
| 434 |
+
the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch
|
| 435 |
+
parlance, it can be implemented as follows for a given group g:
|
| 436 |
+
out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :]
|
| 437 |
+
|
| 438 |
+
The 't' in the operator name derives from MaxText implementation
|
| 439 |
+
(https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py),
|
| 440 |
+
which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor
|
| 441 |
+
shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N).
|
| 442 |
+
|
| 443 |
+
The 'np' in the operator name means that it is implemented with a non-persistent, i.e. regular
|
| 444 |
+
kernel. There is also the persistent variation, which is implemented with a persistent kernel.
|
| 445 |
+
Please take a look at ptgmm operator. Both nptgmm and ptgmm implement the same computation,
|
| 446 |
+
choosing one or the other is a matter of performance for the target workload.
|
| 447 |
+
|
| 448 |
+
Parameters
|
| 449 |
+
----------
|
| 450 |
+
lhs : torch.Tensor
|
| 451 |
+
Left-hand side 2D input tensor. Shape: (K, M).
|
| 452 |
+
lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type.
|
| 453 |
+
lhs must be on the same device of rhs and group_sizes.
|
| 454 |
+
rhs : torch.Tensor
|
| 455 |
+
Right-hand side 2D input tensor. Shape: (M, N).
|
| 456 |
+
rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type.
|
| 457 |
+
rhs must be on the same device of lhs and group_sizes.
|
| 458 |
+
group_sizes : torch.Tensor
|
| 459 |
+
1D input tensor describing group sizes. Shape: (G,).
|
| 460 |
+
group_sizes data type must be torch.int32 and all its elements must be non-negative.
|
| 461 |
+
group_sizes must be on the same device of lhs and rhs.
|
| 462 |
+
preferred_element_type : torch.dtype, optional
|
| 463 |
+
Desired data type for output tensor. Default is torch.bfloat16.
|
| 464 |
+
Supported output types are torch.float16 and torch.bfloat16.
|
| 465 |
+
existing_out : torch.Tensor or None, optional
|
| 466 |
+
Preallocated output tensor. Default is None.
|
| 467 |
+
If provided, results are written into this tensor. Otherwise, a new output tensor is
|
| 468 |
+
allocated.
|
| 469 |
+
If provided then it must have shape (G, K, N), its data type must match
|
| 470 |
+
preferred_element_type and it must be on the same device of other input tensors.
|
| 471 |
+
config : dict[str, int] or None, optional
|
| 472 |
+
Optional dictionary with kernel metaparameters. If absent, config will be queried from
|
| 473 |
+
internal tuning database.
|
| 474 |
+
bias_grad : torch.Tensor or None, optional
|
| 475 |
+
Optional bias gradient output tensor. Shape: (G, K).
|
| 476 |
+
If provided, the kernel will compute the bias gradient and write it to this tensor.
|
| 477 |
+
bias_grad must be torch.float32 (kernel uses atomic_add which requires float32),
|
| 478 |
+
accumulate : bool, optional
|
| 479 |
+
Whether to accumulate into existing output tensor values. Default is False.
|
| 480 |
+
If False, output will be overwritten with fresh computation.
|
| 481 |
+
If True, results will be added to existing output tensor values.
|
| 482 |
+
|
| 483 |
+
Returns
|
| 484 |
+
-------
|
| 485 |
+
torch.Tensor
|
| 486 |
+
The computed output 3D tensor. Shape: (G, K, N).
|
| 487 |
+
Output tensor data type is given by preferred_element_type.
|
| 488 |
+
If existing_out is provided then existing_out is also returned.
|
| 489 |
+
|
| 490 |
+
Implementation Notes
|
| 491 |
+
--------------------
|
| 492 |
+
- NPTGMM is implemented with a non-persistent regular Triton kernel.
|
| 493 |
+
- lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs
|
| 494 |
+
is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel
|
| 495 |
+
parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward
|
| 496 |
+
pass, while fusing the transposition.
|
| 497 |
+
- rhs must be row-major (rhs.stride() == (N, 1)).
|
| 498 |
+
- out must be row-major (out.stride() == (K * N, N, 1)).
|
| 499 |
+
"""
|
| 500 |
+
check_input_device_dtype(lhs, rhs, group_sizes)
|
| 501 |
+
|
| 502 |
+
M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes)
|
| 503 |
+
|
| 504 |
+
out = get_tgmm_output(
|
| 505 |
+
K,
|
| 506 |
+
N,
|
| 507 |
+
G,
|
| 508 |
+
device=lhs.device,
|
| 509 |
+
preferred_element_type=preferred_element_type,
|
| 510 |
+
existing_out=existing_out,
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out)
|
| 514 |
+
|
| 515 |
+
# Bias gradient handling.
|
| 516 |
+
# -----------------------
|
| 517 |
+
# Get or validate bias gradient tensor.
|
| 518 |
+
compute_bias_grad = bias_grad is not None
|
| 519 |
+
bias_grad_ptr = get_tgmm_bias_grad(
|
| 520 |
+
K,
|
| 521 |
+
G,
|
| 522 |
+
device=lhs.device,
|
| 523 |
+
existing_bias_grad=bias_grad,
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
if config is None:
|
| 527 |
+
config = get_config("nptgmm", M, K, N, G, accumulate)
|
| 528 |
+
|
| 529 |
+
assert all(
|
| 530 |
+
key in config
|
| 531 |
+
and isinstance(config[key], int)
|
| 532 |
+
and (
|
| 533 |
+
is_power_of_2(config[key])
|
| 534 |
+
if key.startswith("BLOCK_SIZE_")
|
| 535 |
+
else config[key] > 0
|
| 536 |
+
)
|
| 537 |
+
for key in {
|
| 538 |
+
"BLOCK_SIZE_M",
|
| 539 |
+
"BLOCK_SIZE_K",
|
| 540 |
+
"BLOCK_SIZE_N",
|
| 541 |
+
"GROUP_SIZE",
|
| 542 |
+
}
|
| 543 |
+
), "Invalid NPTGMM kernel config."
|
| 544 |
+
|
| 545 |
+
grid = _nptgmm_grid(
|
| 546 |
+
K,
|
| 547 |
+
N,
|
| 548 |
+
G,
|
| 549 |
+
config["BLOCK_SIZE_K"],
|
| 550 |
+
config["BLOCK_SIZE_N"],
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
# fmt: off
|
| 554 |
+
tgmm_non_persistent_kernel[grid](
|
| 555 |
+
# Tensor pointers:
|
| 556 |
+
lhs, rhs, group_sizes, out, bias_grad_ptr,
|
| 557 |
+
# Tensor shapes:
|
| 558 |
+
M, K, N, G,
|
| 559 |
+
# Meta-parameters:
|
| 560 |
+
TRANS_LHS=trans_lhs,
|
| 561 |
+
COMPUTE_BIAS_GRAD=compute_bias_grad,
|
| 562 |
+
ACCUMULATE=accumulate,
|
| 563 |
+
**config,
|
| 564 |
+
)
|
| 565 |
+
# fmt: on
|
| 566 |
+
|
| 567 |
+
return out
|
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/__init__.py
ADDED
|
File without changes
|
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py
ADDED
|
File without changes
|
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import triton
|
| 2 |
+
|
| 3 |
+
# Detect the GPU arch lazily: querying the triton driver at import time fails
|
| 4 |
+
# in headless environments (e.g. the kernel-builder ABI check sandbox has no
|
| 5 |
+
# GPU), and the original JAX fallback pulled in an unrelated runtime dep. The
|
| 6 |
+
# arch is only actually needed when a GMM kernel is dispatched, so resolve and
|
| 7 |
+
# cache on first call.
|
| 8 |
+
_CACHED_ARCH = None
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_arch():
|
| 12 |
+
global _CACHED_ARCH
|
| 13 |
+
if _CACHED_ARCH is not None:
|
| 14 |
+
return _CACHED_ARCH
|
| 15 |
+
try:
|
| 16 |
+
_CACHED_ARCH = triton.runtime.driver.active.get_current_target().arch
|
| 17 |
+
except RuntimeError:
|
| 18 |
+
try:
|
| 19 |
+
from jax._src.lib import gpu_triton as triton_kernel_call_lib
|
| 20 |
+
_CACHED_ARCH = triton_kernel_call_lib.get_arch_details("0").split(":")[0]
|
| 21 |
+
except ImportError as e:
|
| 22 |
+
raise RuntimeError(
|
| 23 |
+
"Cannot determine GPU arch: triton driver is inactive and "
|
| 24 |
+
"JAX is not available. A GPU is required for grouped GEMM."
|
| 25 |
+
) from e
|
| 26 |
+
return _CACHED_ARCH
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def is_gluon_avail():
|
| 30 |
+
return get_arch() in ("gfx950", "gfx1250")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def is_fp4_avail():
|
| 34 |
+
return get_arch() in ("gfx950", "gfx1250")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def is_fp8_avail():
|
| 38 |
+
return get_arch() in ("gfx942", "gfx950", "gfx1250", "gfx1200", "gfx1201")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def is_mx_scale_preshuffling_avail():
|
| 42 |
+
return get_arch() in ("gfx950", "gfx1250")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def is_tdm_avail():
|
| 46 |
+
return get_arch() in ("gfx1250",)
|
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
|
| 3 |
+
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
|
| 4 |
+
|
| 5 |
+
import triton
|
| 6 |
+
import triton.language as tl
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@triton.jit
|
| 10 |
+
def remap_xcd_chunked(
|
| 11 |
+
pid, GRID_MN, NUM_XCDS: tl.constexpr = 8, CHUNK_SIZE: tl.constexpr = 2
|
| 12 |
+
):
|
| 13 |
+
# Compute current XCD and local PID
|
| 14 |
+
xcd = pid % NUM_XCDS
|
| 15 |
+
# distribute the modulo pids in round robin
|
| 16 |
+
if pid > (GRID_MN // (NUM_XCDS * CHUNK_SIZE)) * (NUM_XCDS * CHUNK_SIZE):
|
| 17 |
+
return pid
|
| 18 |
+
local_pid = pid // NUM_XCDS
|
| 19 |
+
# Calculate chunk index and position within chunk
|
| 20 |
+
chunk_idx = local_pid // CHUNK_SIZE
|
| 21 |
+
pos_in_chunk = local_pid % CHUNK_SIZE
|
| 22 |
+
# Calculate new PID
|
| 23 |
+
new_pid = chunk_idx * NUM_XCDS * CHUNK_SIZE + xcd * CHUNK_SIZE + pos_in_chunk
|
| 24 |
+
return new_pid
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@triton.jit
|
| 28 |
+
def remap_xcd(pid, GRID_MN, NUM_XCDS: tl.constexpr = 8):
|
| 29 |
+
## pid remapping on xcds
|
| 30 |
+
# Number of pids per XCD in the new arrangement
|
| 31 |
+
pids_per_xcd = (GRID_MN + NUM_XCDS - 1) // NUM_XCDS
|
| 32 |
+
# When GRID_MN cannot divide NUM_XCDS, some xcds will have
|
| 33 |
+
# pids_per_xcd pids, the other will have pids_per_xcd - 1 pids.
|
| 34 |
+
# We calculate the number of xcds that have pids_per_xcd pids as
|
| 35 |
+
# tall_xcds
|
| 36 |
+
tall_xcds = GRID_MN % NUM_XCDS
|
| 37 |
+
tall_xcds = NUM_XCDS if tall_xcds == 0 else tall_xcds
|
| 38 |
+
# Compute current XCD and local pid within the XCD
|
| 39 |
+
xcd = pid % NUM_XCDS
|
| 40 |
+
local_pid = pid // NUM_XCDS
|
| 41 |
+
# Calculate new pid based on the new grouping
|
| 42 |
+
# Note that we need to consider the following two cases:
|
| 43 |
+
# 1. the current pid is on a tall xcd
|
| 44 |
+
# 2. the current pid is on a short xcd
|
| 45 |
+
if xcd < tall_xcds:
|
| 46 |
+
pid = xcd * pids_per_xcd + local_pid
|
| 47 |
+
else:
|
| 48 |
+
pid = (
|
| 49 |
+
tall_xcds * pids_per_xcd
|
| 50 |
+
+ (xcd - tall_xcds) * (pids_per_xcd - 1)
|
| 51 |
+
+ local_pid
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
return pid
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@triton.jit
|
| 58 |
+
def pid_grid(pid: int, num_pid_m: int, num_pid_n: int, GROUP_SIZE_M: tl.constexpr = 1):
|
| 59 |
+
"""
|
| 60 |
+
Maps 1D pid to 2D grid coords (pid_m, pid_n).
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
- pid: 1D pid
|
| 64 |
+
- num_pid_m: grid m size
|
| 65 |
+
- num_pid_n: grid n size
|
| 66 |
+
- GROUP_SIZE_M: tl.constexpr: default is 1
|
| 67 |
+
"""
|
| 68 |
+
if GROUP_SIZE_M == 1:
|
| 69 |
+
pid_m = pid // num_pid_n
|
| 70 |
+
pid_n = pid % num_pid_n
|
| 71 |
+
else:
|
| 72 |
+
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
| 73 |
+
group_id = pid // num_pid_in_group
|
| 74 |
+
first_pid_m = group_id * GROUP_SIZE_M
|
| 75 |
+
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
| 76 |
+
tl.assume(group_size_m >= 0)
|
| 77 |
+
pid_m = first_pid_m + (pid % group_size_m)
|
| 78 |
+
pid_n = (pid % num_pid_in_group) // group_size_m
|
| 79 |
+
|
| 80 |
+
return pid_m, pid_n
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@triton.jit
|
| 84 |
+
def pid_grid_3d(pid: int, num_pid_m: int, num_pid_n: int, num_pid_k):
|
| 85 |
+
"""
|
| 86 |
+
Maps 1D pid to 3D grid coords (pid_m, pid_n, pid_k).
|
| 87 |
+
Args:
|
| 88 |
+
- pid: 1D pid
|
| 89 |
+
- num_pid_m: grid m size
|
| 90 |
+
- num_pid_n: grid n size
|
| 91 |
+
- num_pid_k: grid k size
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
- pid_m, pid_n, pid_k: 3D grid coordinates
|
| 95 |
+
"""
|
| 96 |
+
pid_m = pid % num_pid_m
|
| 97 |
+
pid_n = (pid // num_pid_m) % num_pid_n
|
| 98 |
+
pid_k = pid // (num_pid_m * num_pid_n) % num_pid_k
|
| 99 |
+
|
| 100 |
+
return pid_m, pid_n, pid_k
|
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py
ADDED
|
@@ -0,0 +1,752 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
| 3 |
+
|
| 4 |
+
# Imports.
|
| 5 |
+
# ------------------------------------------------------------------------------
|
| 6 |
+
|
| 7 |
+
# PyTorch
|
| 8 |
+
import torch
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
|
| 11 |
+
# AITER: logging
|
| 12 |
+
from .logger import AiterTritonLogger
|
| 13 |
+
|
| 14 |
+
_LOGGER: AiterTritonLogger = AiterTritonLogger()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Supported data types.
|
| 18 |
+
# ------------------------------------------------------------------------------
|
| 19 |
+
|
| 20 |
+
# Supported data types, as strings.
|
| 21 |
+
SUPPORTED_DTYPES_STR: set[str] = {"fp16", "bf16"}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Convert string data type to PyTorch data type.
|
| 25 |
+
def dtype_from_str(dtype_str: str) -> torch.dtype:
|
| 26 |
+
dtype_str = dtype_str.strip().lower()
|
| 27 |
+
dtype_str = dtype_str[1:] if dtype_str[0] in {"i", "o"} else dtype_str
|
| 28 |
+
assert (
|
| 29 |
+
dtype_str in SUPPORTED_DTYPES_STR
|
| 30 |
+
), "String data type isn't in set of supported string data types."
|
| 31 |
+
return {"fp16": torch.float16, "bf16": torch.bfloat16}[dtype_str]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Supported data types, as PyTorch types.
|
| 35 |
+
SUPPORTED_DTYPES: set[torch.dtype] = {
|
| 36 |
+
dtype_from_str(dtype_str) for dtype_str in SUPPORTED_DTYPES_STR
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Convert PyTorch data type to string data type.
|
| 41 |
+
def str_from_dtype(dtype: torch.dtype) -> str:
|
| 42 |
+
assert (
|
| 43 |
+
dtype in SUPPORTED_DTYPES
|
| 44 |
+
), "PyTorch data type isn't in set of supported PyTorch data types."
|
| 45 |
+
return {torch.float16: "fp16", torch.bfloat16: "bf16"}[dtype]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Default data type, as string.
|
| 49 |
+
DTYPE_STR: str = "bf16"
|
| 50 |
+
assert (
|
| 51 |
+
DTYPE_STR in SUPPORTED_DTYPES_STR
|
| 52 |
+
), "Default string data type isn't in set of supported string data types."
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Default data type, as PyTorch type.
|
| 56 |
+
DTYPE: torch.dtype = dtype_from_str(DTYPE_STR)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Other defaults.
|
| 60 |
+
# ------------------------------------------------------------------------------
|
| 61 |
+
|
| 62 |
+
# Default device.
|
| 63 |
+
DEVICE: torch.device | str = "cuda"
|
| 64 |
+
|
| 65 |
+
# Default RNG seed for input generation.
|
| 66 |
+
RNG_SEED: int = 0
|
| 67 |
+
|
| 68 |
+
# Default number of group sizes.
|
| 69 |
+
NUM_GROUP_SIZES: int = 1
|
| 70 |
+
|
| 71 |
+
# Default transposition (NN).
|
| 72 |
+
TRANS_LHS: bool = False
|
| 73 |
+
TRANS_RHS: bool = False
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Parameter checking functions.
|
| 77 |
+
# ------------------------------------------------------------------------------
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def is_power_of_2(x: int) -> bool:
|
| 81 |
+
return (x > 0) and (x & (x - 1) == 0)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def check_input_device_dtype(
|
| 85 |
+
lhs: Tensor, rhs: Tensor, group_sizes: Tensor, bias: Tensor | None = None
|
| 86 |
+
) -> None:
|
| 87 |
+
assert (
|
| 88 |
+
lhs.device == rhs.device == group_sizes.device
|
| 89 |
+
), f"All input tensors must be in the same device (lhs = {lhs.device}, rhs = {rhs.device}, group_sizes = {group_sizes.device})."
|
| 90 |
+
assert (
|
| 91 |
+
lhs.dtype == rhs.dtype
|
| 92 |
+
), f"lhs and rhs types must match (lhs = {lhs.dtype}, rhs = {rhs.dtype})."
|
| 93 |
+
assert group_sizes.dtype == torch.int32, "group_sizes type must be int32."
|
| 94 |
+
|
| 95 |
+
if bias is not None:
|
| 96 |
+
assert (
|
| 97 |
+
bias.device == lhs.device
|
| 98 |
+
), f"bias must be on the same device as lhs (bias = {bias.device}, lhs = {lhs.device})."
|
| 99 |
+
assert (
|
| 100 |
+
bias.dtype == lhs.dtype
|
| 101 |
+
), f"bias dtype must match lhs dtype (bias = {bias.dtype}, lhs = {lhs.dtype})."
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def check_bias_shape_stride(bias: Tensor, G: int, N: int) -> None:
|
| 105 |
+
assert bias.shape == (
|
| 106 |
+
G,
|
| 107 |
+
N,
|
| 108 |
+
), f"bias must have shape (G, N) = ({G}, {N}), got {bias.shape}."
|
| 109 |
+
assert bias.stride() == (N, 1), "bias must be row-major (bias.stride() == (N, 1))."
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# Generation of group sizes.
|
| 113 |
+
# ------------------------------------------------------------------------------
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# Probabilities for generating random group sizes.
|
| 117 |
+
UNUSED_TOKENS_PROB: float = 0.0
|
| 118 |
+
UNUSED_EXPERTS_PROB: float = 0.1
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def gen_uniform_group_sizes(
|
| 122 |
+
M: int,
|
| 123 |
+
G: int,
|
| 124 |
+
device: torch.device | str = DEVICE,
|
| 125 |
+
) -> Tensor:
|
| 126 |
+
assert M >= 0, f"Number of tokens M must be non-negative (it's {M})."
|
| 127 |
+
assert G > 0, f"Number of experts G must be positive (it's {G})."
|
| 128 |
+
|
| 129 |
+
base = M // G
|
| 130 |
+
remainder = M % G
|
| 131 |
+
group_sizes = torch.full((G,), base, dtype=torch.int32, device=device)
|
| 132 |
+
if remainder > 0:
|
| 133 |
+
group_sizes[:remainder] += 1
|
| 134 |
+
|
| 135 |
+
assert (
|
| 136 |
+
len(group_sizes) == G
|
| 137 |
+
), f"Group sizes don't have {G} elements (it's {len(group_sizes)})."
|
| 138 |
+
assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative."
|
| 139 |
+
assert (
|
| 140 |
+
torch.sum(group_sizes).item() == M
|
| 141 |
+
), f"Group sizes don't add up to total tokens {M}."
|
| 142 |
+
assert group_sizes.dtype == torch.int32, "Group sizes must be int32."
|
| 143 |
+
|
| 144 |
+
return group_sizes
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def gen_group_sizes(
|
| 148 |
+
M: int,
|
| 149 |
+
G: int,
|
| 150 |
+
device: torch.device | str = DEVICE,
|
| 151 |
+
rng_seed: int | None = RNG_SEED,
|
| 152 |
+
unused_tokens_prob: float = UNUSED_TOKENS_PROB,
|
| 153 |
+
unused_experts_prob: float = UNUSED_EXPERTS_PROB,
|
| 154 |
+
) -> Tensor:
|
| 155 |
+
assert M >= 0, f"Number of tokens M must be non-negative (it's {M})."
|
| 156 |
+
assert G > 0, f"Number of experts G must be positive (it's {G})."
|
| 157 |
+
assert (
|
| 158 |
+
0 <= unused_tokens_prob <= 1
|
| 159 |
+
), f"Probability of unused tokens must be in [0, 1] interval (it's {unused_tokens_prob})."
|
| 160 |
+
assert (
|
| 161 |
+
0 <= unused_experts_prob <= 1
|
| 162 |
+
), f"Probability of unused experts must be in [0, 1] interval (it's {unused_experts_prob})."
|
| 163 |
+
|
| 164 |
+
if rng_seed is not None:
|
| 165 |
+
torch.manual_seed(rng_seed)
|
| 166 |
+
|
| 167 |
+
if unused_tokens_prob > 0:
|
| 168 |
+
# Optionally drop tokens to simulate routing sparsity, some tokens may not be routed.
|
| 169 |
+
num_unused_tokens = M
|
| 170 |
+
while num_unused_tokens == M:
|
| 171 |
+
num_unused_tokens = int(
|
| 172 |
+
torch.binomial(
|
| 173 |
+
torch.tensor(float(M), device=device),
|
| 174 |
+
torch.tensor(unused_tokens_prob, device=device),
|
| 175 |
+
).item()
|
| 176 |
+
)
|
| 177 |
+
else:
|
| 178 |
+
num_unused_tokens = 0
|
| 179 |
+
num_used_tokens = M - num_unused_tokens
|
| 180 |
+
assert (
|
| 181 |
+
num_unused_tokens >= 0
|
| 182 |
+
), f"Number of unused tokens must be non-negative (it's {num_unused_tokens})."
|
| 183 |
+
assert (
|
| 184 |
+
num_used_tokens > 0
|
| 185 |
+
), f"Number of used tokens must be positive (it's {num_used_tokens})."
|
| 186 |
+
assert (
|
| 187 |
+
num_used_tokens + num_unused_tokens == M
|
| 188 |
+
), f"Unused + used tokens don't add up total tokens ({num_used_tokens} + {num_unused_tokens} != {M})."
|
| 189 |
+
|
| 190 |
+
if num_unused_tokens > 0:
|
| 191 |
+
_LOGGER.debug(
|
| 192 |
+
f"Group sizes generation: dropped {num_unused_tokens} token{'s' if num_unused_tokens > 1 else ''}.",
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
if unused_experts_prob > 0:
|
| 196 |
+
# Some experts may have zero tokens assigned to them.
|
| 197 |
+
num_used_experts = 0
|
| 198 |
+
while num_used_experts == 0:
|
| 199 |
+
used_experts = torch.nonzero(
|
| 200 |
+
torch.rand((G,), device=device) >= unused_experts_prob
|
| 201 |
+
).squeeze()
|
| 202 |
+
num_used_experts = used_experts.numel()
|
| 203 |
+
else:
|
| 204 |
+
used_experts = torch.arange(0, G, device=device)
|
| 205 |
+
num_used_experts = G
|
| 206 |
+
num_unused_experts = G - num_used_experts
|
| 207 |
+
assert (
|
| 208 |
+
num_unused_experts >= 0
|
| 209 |
+
), f"Number of unused experts must be non-negative (it's {num_unused_experts})."
|
| 210 |
+
assert (
|
| 211 |
+
num_used_experts >= 1
|
| 212 |
+
), f"At least one expert must be used (it's {num_used_experts})."
|
| 213 |
+
assert (
|
| 214 |
+
num_unused_experts + num_used_experts == G
|
| 215 |
+
), f"Unused + used experts don't add up total experts ({num_unused_experts} + {num_used_experts} != {G})."
|
| 216 |
+
|
| 217 |
+
if num_unused_experts > 0:
|
| 218 |
+
_LOGGER.debug(
|
| 219 |
+
f"Group sizes generation: dropped {num_unused_experts} expert{'s' if num_unused_experts > 1 else ''}.",
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
group_sizes = torch.bincount(
|
| 223 |
+
used_experts[
|
| 224 |
+
torch.randint(low=0, high=num_used_experts, size=(num_used_tokens,))
|
| 225 |
+
],
|
| 226 |
+
minlength=G,
|
| 227 |
+
).to(torch.int32)
|
| 228 |
+
|
| 229 |
+
assert (
|
| 230 |
+
len(group_sizes) == G
|
| 231 |
+
), f"Group sizes don't have {G} elements (it's {len(group_sizes)})."
|
| 232 |
+
assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative."
|
| 233 |
+
assert (
|
| 234 |
+
torch.sum(group_sizes).item() == num_used_tokens
|
| 235 |
+
), f"Group sizes don't add up to used tokens {num_used_tokens}."
|
| 236 |
+
assert group_sizes.dtype == torch.int32, "Group sizes must be int32."
|
| 237 |
+
|
| 238 |
+
return group_sizes
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def gen_multiple_group_sizes(
|
| 242 |
+
num_group_sizes: int,
|
| 243 |
+
M: int,
|
| 244 |
+
G: int,
|
| 245 |
+
device: torch.device | str = DEVICE,
|
| 246 |
+
rng_seed: int | None = RNG_SEED,
|
| 247 |
+
unused_tokens_prob: float = UNUSED_TOKENS_PROB,
|
| 248 |
+
unused_experts_prob: float = UNUSED_EXPERTS_PROB,
|
| 249 |
+
group_sizes_0: Tensor | None = None,
|
| 250 |
+
) -> list[Tensor]:
|
| 251 |
+
assert (
|
| 252 |
+
num_group_sizes > 0
|
| 253 |
+
), f"Number of group sizes to be generated must be positive, it's {num_group_sizes}."
|
| 254 |
+
multiple_group_sizes = [
|
| 255 |
+
gen_group_sizes(
|
| 256 |
+
M,
|
| 257 |
+
G,
|
| 258 |
+
device=device,
|
| 259 |
+
rng_seed=rng_seed if g == 0 else None,
|
| 260 |
+
unused_tokens_prob=unused_tokens_prob,
|
| 261 |
+
unused_experts_prob=unused_experts_prob,
|
| 262 |
+
)
|
| 263 |
+
for g in range(
|
| 264 |
+
num_group_sizes if group_sizes_0 is None else num_group_sizes - 1
|
| 265 |
+
)
|
| 266 |
+
]
|
| 267 |
+
if group_sizes_0 is not None:
|
| 268 |
+
multiple_group_sizes.insert(0, group_sizes_0)
|
| 269 |
+
assert (
|
| 270 |
+
len(multiple_group_sizes) == num_group_sizes
|
| 271 |
+
), f"Expecting {num_group_sizes} distinct group sizes (it's {len(multiple_group_sizes)})."
|
| 272 |
+
return multiple_group_sizes
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# GMM helpers: tensor generation.
|
| 276 |
+
# ------------------------------------------------------------------------------
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def gen_gmm_input(
|
| 280 |
+
M: int,
|
| 281 |
+
K: int,
|
| 282 |
+
N: int,
|
| 283 |
+
G: int,
|
| 284 |
+
device: torch.device | str = DEVICE,
|
| 285 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 286 |
+
trans_rhs: bool = TRANS_RHS,
|
| 287 |
+
rng_seed: int | None = RNG_SEED,
|
| 288 |
+
unif_group_sizes: bool = False,
|
| 289 |
+
) -> tuple[Tensor, Tensor, Tensor]:
|
| 290 |
+
assert M > 0, f"Number of lhs rows M must be positive (M = {M})."
|
| 291 |
+
assert K > 0, f"Number of lhs columns / rhs rows K must be positive (K = {K})."
|
| 292 |
+
assert N > 0, f"Number of rhs columns N must be positive (N = {N})."
|
| 293 |
+
assert G > 0, f"Number of groups G must be positive (G = {G})."
|
| 294 |
+
|
| 295 |
+
if rng_seed is not None:
|
| 296 |
+
torch.manual_seed(rng_seed)
|
| 297 |
+
|
| 298 |
+
lhs = torch.randn((M, K), dtype=torch.float32, device=device)
|
| 299 |
+
lhs = lhs.to(preferred_element_type)
|
| 300 |
+
|
| 301 |
+
if trans_rhs:
|
| 302 |
+
rhs = torch.randn((G, N, K), dtype=torch.float32, device=device).permute(
|
| 303 |
+
0, 2, 1
|
| 304 |
+
)
|
| 305 |
+
else:
|
| 306 |
+
rhs = torch.randn((G, K, N), dtype=torch.float32, device=device)
|
| 307 |
+
rhs = rhs.to(preferred_element_type)
|
| 308 |
+
|
| 309 |
+
group_sizes = (
|
| 310 |
+
gen_uniform_group_sizes(M, G, device=device)
|
| 311 |
+
if unif_group_sizes
|
| 312 |
+
else gen_group_sizes(M, G, device=device, rng_seed=None)
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
return lhs, rhs, group_sizes
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def gen_gmm_output(
|
| 319 |
+
M: int,
|
| 320 |
+
N: int,
|
| 321 |
+
device: torch.device | str = DEVICE,
|
| 322 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 323 |
+
) -> Tensor:
|
| 324 |
+
assert M > 0, f"Number of out rows M must be positive (M = {M})."
|
| 325 |
+
assert N > 0, f"Number of out columns N must be positive (N = {N})."
|
| 326 |
+
|
| 327 |
+
out = torch.empty((M, N), dtype=preferred_element_type, device=device)
|
| 328 |
+
|
| 329 |
+
return out
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def gen_gmm_tensors(
|
| 333 |
+
M: int,
|
| 334 |
+
K: int,
|
| 335 |
+
N: int,
|
| 336 |
+
G: int,
|
| 337 |
+
num_group_sizes: int,
|
| 338 |
+
device: torch.device | str = DEVICE,
|
| 339 |
+
input_type: torch.dtype = DTYPE,
|
| 340 |
+
output_type: torch.dtype = DTYPE,
|
| 341 |
+
trans_lhs: bool = False,
|
| 342 |
+
trans_rhs: bool = TRANS_RHS,
|
| 343 |
+
rng_seed: int | None = RNG_SEED,
|
| 344 |
+
unif_group_sizes: bool = False,
|
| 345 |
+
use_bias: bool = False,
|
| 346 |
+
) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]:
|
| 347 |
+
lhs, rhs, group_sizes_0 = gen_gmm_input(
|
| 348 |
+
M,
|
| 349 |
+
K,
|
| 350 |
+
N,
|
| 351 |
+
G,
|
| 352 |
+
device=device,
|
| 353 |
+
preferred_element_type=input_type,
|
| 354 |
+
trans_rhs=trans_rhs,
|
| 355 |
+
rng_seed=rng_seed,
|
| 356 |
+
unif_group_sizes=unif_group_sizes,
|
| 357 |
+
)
|
| 358 |
+
multiple_group_sizes = gen_multiple_group_sizes(
|
| 359 |
+
num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0
|
| 360 |
+
)
|
| 361 |
+
out = gen_gmm_output(M, N, device=device, preferred_element_type=output_type)
|
| 362 |
+
bias = None
|
| 363 |
+
if use_bias:
|
| 364 |
+
torch.manual_seed(rng_seed + 1000) # Different seed for bias
|
| 365 |
+
bias = torch.randn(G, N, dtype=input_type, device=device)
|
| 366 |
+
|
| 367 |
+
return lhs, rhs, multiple_group_sizes, out, bias
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
# GMM helpers: get information from tensors.
|
| 371 |
+
# ------------------------------------------------------------------------------
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def get_gmm_shape(
|
| 375 |
+
lhs: Tensor, rhs: Tensor, group_sizes: Tensor
|
| 376 |
+
) -> tuple[int, int, int, int]:
|
| 377 |
+
assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
|
| 378 |
+
assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})."
|
| 379 |
+
assert (
|
| 380 |
+
group_sizes.dim() == 1
|
| 381 |
+
), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})."
|
| 382 |
+
|
| 383 |
+
M, lhs_k = lhs.shape
|
| 384 |
+
rhs_g, rhs_k, N = rhs.shape
|
| 385 |
+
group_sizes_g = group_sizes.shape[0]
|
| 386 |
+
|
| 387 |
+
assert (
|
| 388 |
+
lhs_k == rhs_k
|
| 389 |
+
), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})."
|
| 390 |
+
K = lhs_k
|
| 391 |
+
assert (
|
| 392 |
+
rhs_g == group_sizes_g
|
| 393 |
+
), f"G dimension of rhs and group_sizes don't match (rhs = {rhs_g}, group_sizes = {group_sizes_g})."
|
| 394 |
+
G = rhs_g
|
| 395 |
+
|
| 396 |
+
assert M > 0, f"M must be positive, it's {M}."
|
| 397 |
+
assert K > 0, f"K must be positive, it's {K}."
|
| 398 |
+
assert N > 0, f"N must be positive, it's {N}"
|
| 399 |
+
assert G > 0, f"G must be positive, it's {G}"
|
| 400 |
+
|
| 401 |
+
return M, K, N, G
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def get_gmm_output(
|
| 405 |
+
M: int,
|
| 406 |
+
N: int,
|
| 407 |
+
device: torch.device | str = DEVICE,
|
| 408 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 409 |
+
existing_out: Tensor | None = None,
|
| 410 |
+
) -> Tensor:
|
| 411 |
+
assert M > 0, f"Number of out rows M must be positive (M = {M})."
|
| 412 |
+
assert N > 0, f"Number of out columns N must be positive (N = {N})."
|
| 413 |
+
|
| 414 |
+
if existing_out is not None:
|
| 415 |
+
assert (
|
| 416 |
+
existing_out.device == device
|
| 417 |
+
), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})."
|
| 418 |
+
assert (
|
| 419 |
+
existing_out.dtype == preferred_element_type
|
| 420 |
+
), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})."
|
| 421 |
+
assert existing_out.shape == (
|
| 422 |
+
M,
|
| 423 |
+
N,
|
| 424 |
+
), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(M, N)})."
|
| 425 |
+
return existing_out
|
| 426 |
+
|
| 427 |
+
return gen_gmm_output(
|
| 428 |
+
M,
|
| 429 |
+
N,
|
| 430 |
+
device=device,
|
| 431 |
+
preferred_element_type=preferred_element_type,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def get_gmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]:
|
| 436 |
+
assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
|
| 437 |
+
assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})."
|
| 438 |
+
assert out.dim() == 2, f"out must have 2 dimensions (it's {out.dim()})."
|
| 439 |
+
|
| 440 |
+
lhs_m, lhs_k = lhs.shape
|
| 441 |
+
G, rhs_k, rhs_n = rhs.shape
|
| 442 |
+
out_m, out_n = out.shape
|
| 443 |
+
|
| 444 |
+
assert (
|
| 445 |
+
lhs_m == out_m
|
| 446 |
+
), f"M dimension of lhs and out don't match (lhs = {lhs_m}, rhs = {out_m})."
|
| 447 |
+
M = lhs_m
|
| 448 |
+
assert (
|
| 449 |
+
lhs_k == rhs_k
|
| 450 |
+
), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})."
|
| 451 |
+
K = lhs_k
|
| 452 |
+
assert (
|
| 453 |
+
rhs_n == out_n
|
| 454 |
+
), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})."
|
| 455 |
+
N = rhs_n
|
| 456 |
+
|
| 457 |
+
assert M > 0, f"M must be positive, it's {M}."
|
| 458 |
+
assert K > 0, f"K must be positive, it's {K}."
|
| 459 |
+
assert N > 0, f"N must be positive, it's {N}"
|
| 460 |
+
assert G > 0, f"G must be positive, it's {G}"
|
| 461 |
+
|
| 462 |
+
is_lhs_row_major = lhs.stride() == (K, 1)
|
| 463 |
+
assert is_lhs_row_major, "lhs must be row-major."
|
| 464 |
+
is_rhs_row_major = rhs.stride() == (K * N, N, 1)
|
| 465 |
+
is_rhs_col_major = rhs.stride() == (K * N, 1, K)
|
| 466 |
+
assert (
|
| 467 |
+
is_rhs_row_major != is_rhs_col_major
|
| 468 |
+
), "rhs must be row-major or column-major."
|
| 469 |
+
is_out_row_major = out.stride() == (N, 1)
|
| 470 |
+
assert is_out_row_major, "out must be row-major."
|
| 471 |
+
|
| 472 |
+
# Get rhs leading dimension according to transposition configuration.
|
| 473 |
+
ld_rhs = N if is_rhs_row_major else K
|
| 474 |
+
|
| 475 |
+
return is_rhs_col_major, ld_rhs
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
# TGMM helpers: tensor generation.
|
| 479 |
+
# ------------------------------------------------------------------------------
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def gen_tgmm_input(
|
| 483 |
+
M: int,
|
| 484 |
+
K: int,
|
| 485 |
+
N: int,
|
| 486 |
+
G: int,
|
| 487 |
+
device: torch.device | str = DEVICE,
|
| 488 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 489 |
+
trans_lhs: bool = TRANS_LHS,
|
| 490 |
+
rng_seed: int | None = RNG_SEED,
|
| 491 |
+
unif_group_sizes: bool = False,
|
| 492 |
+
) -> tuple[Tensor, Tensor, Tensor]:
|
| 493 |
+
assert K > 0, f"Number of lhs rows K must be positive (M = {K})."
|
| 494 |
+
assert M > 0, f"Number of lhs columns / rhs rows M must be positive (K = {M})."
|
| 495 |
+
assert N > 0, f"Number of rhs columns N must be positive (N = {N})."
|
| 496 |
+
assert G > 0, f"Number of groups G must be positive (G = {G})."
|
| 497 |
+
|
| 498 |
+
if rng_seed is not None:
|
| 499 |
+
torch.manual_seed(rng_seed)
|
| 500 |
+
|
| 501 |
+
if trans_lhs:
|
| 502 |
+
lhs = torch.randn((M, K), dtype=torch.float32, device=device).T
|
| 503 |
+
else:
|
| 504 |
+
lhs = torch.randn((K, M), dtype=torch.float32, device=device)
|
| 505 |
+
lhs = lhs.to(preferred_element_type)
|
| 506 |
+
|
| 507 |
+
rhs = torch.randn((M, N), dtype=torch.float32, device=device)
|
| 508 |
+
rhs = rhs.to(preferred_element_type)
|
| 509 |
+
|
| 510 |
+
group_sizes = (
|
| 511 |
+
gen_uniform_group_sizes(M, G, device=device)
|
| 512 |
+
if unif_group_sizes
|
| 513 |
+
else gen_group_sizes(M, G, device=device, rng_seed=None)
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
return lhs, rhs, group_sizes
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def gen_tgmm_output(
|
| 520 |
+
K: int,
|
| 521 |
+
N: int,
|
| 522 |
+
G: int,
|
| 523 |
+
device: torch.device | str = DEVICE,
|
| 524 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 525 |
+
) -> Tensor:
|
| 526 |
+
assert K > 0, f"Number of out rows K must be positive (K = {K})."
|
| 527 |
+
assert N > 0, f"Number of out columns N must be positive (N = {N})."
|
| 528 |
+
assert G > 0, f"Number of groups G must be positive (G = {G})."
|
| 529 |
+
|
| 530 |
+
out = torch.empty((G, K, N), dtype=preferred_element_type, device=device)
|
| 531 |
+
|
| 532 |
+
return out
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
def gen_tgmm_bias_grad(
|
| 536 |
+
K: int,
|
| 537 |
+
G: int,
|
| 538 |
+
device: torch.device | str = DEVICE,
|
| 539 |
+
with_bias_grad: bool = False,
|
| 540 |
+
) -> Tensor:
|
| 541 |
+
if with_bias_grad:
|
| 542 |
+
assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})."
|
| 543 |
+
assert G > 0, f"Number of groups G must be positive (G = {G})."
|
| 544 |
+
return torch.empty((G, K), device=device, dtype=torch.float32)
|
| 545 |
+
else:
|
| 546 |
+
# Return dummy pointer when bias_grad is not needed.
|
| 547 |
+
# Must be float32 because atomic_add does not support bf16/fp16,
|
| 548 |
+
# and Triton validates the pointer dtype even in dead branches.
|
| 549 |
+
return torch.tensor([], device=device, dtype=torch.float32)
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
def gen_tgmm_tensors(
|
| 553 |
+
M: int,
|
| 554 |
+
K: int,
|
| 555 |
+
N: int,
|
| 556 |
+
G: int,
|
| 557 |
+
num_group_sizes: int,
|
| 558 |
+
device: torch.device | str = DEVICE,
|
| 559 |
+
input_type: torch.dtype = DTYPE,
|
| 560 |
+
output_type: torch.dtype = DTYPE,
|
| 561 |
+
trans_lhs: bool = TRANS_LHS,
|
| 562 |
+
trans_rhs: bool = False,
|
| 563 |
+
rng_seed: int | None = RNG_SEED,
|
| 564 |
+
unif_group_sizes: bool = False,
|
| 565 |
+
use_bias: bool = False,
|
| 566 |
+
) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]:
|
| 567 |
+
lhs, rhs, group_sizes_0 = gen_tgmm_input(
|
| 568 |
+
M,
|
| 569 |
+
K,
|
| 570 |
+
N,
|
| 571 |
+
G,
|
| 572 |
+
device=device,
|
| 573 |
+
preferred_element_type=input_type,
|
| 574 |
+
trans_lhs=trans_lhs,
|
| 575 |
+
rng_seed=rng_seed,
|
| 576 |
+
unif_group_sizes=unif_group_sizes,
|
| 577 |
+
)
|
| 578 |
+
multiple_group_sizes = gen_multiple_group_sizes(
|
| 579 |
+
num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0
|
| 580 |
+
)
|
| 581 |
+
out = gen_tgmm_output(K, N, G, device=device, preferred_element_type=output_type)
|
| 582 |
+
if use_bias:
|
| 583 |
+
bias_grad = gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=True)
|
| 584 |
+
else:
|
| 585 |
+
bias_grad = None
|
| 586 |
+
return lhs, rhs, multiple_group_sizes, out, bias_grad
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
# TGMM helpers: get information from tensors.
|
| 590 |
+
# ------------------------------------------------------------------------------
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
def get_tgmm_shape(
|
| 594 |
+
lhs: Tensor, rhs: Tensor, group_sizes: Tensor
|
| 595 |
+
) -> tuple[int, int, int, int]:
|
| 596 |
+
assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
|
| 597 |
+
assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})."
|
| 598 |
+
assert (
|
| 599 |
+
group_sizes.dim() == 1
|
| 600 |
+
), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})."
|
| 601 |
+
|
| 602 |
+
K, lhs_m = lhs.shape
|
| 603 |
+
rhs_m, N = rhs.shape
|
| 604 |
+
G = group_sizes.shape[0]
|
| 605 |
+
|
| 606 |
+
assert (
|
| 607 |
+
lhs_m == rhs_m
|
| 608 |
+
), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})."
|
| 609 |
+
M = lhs_m
|
| 610 |
+
|
| 611 |
+
assert M > 0, f"M must be positive, it's {M}."
|
| 612 |
+
assert K > 0, f"K must be positive, it's {K}."
|
| 613 |
+
assert N > 0, f"N must be positive, it's {N}"
|
| 614 |
+
assert G > 0, f"G must be positive, it's {G}"
|
| 615 |
+
|
| 616 |
+
return M, K, N, G
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
def get_tgmm_output(
|
| 620 |
+
K: int,
|
| 621 |
+
N: int,
|
| 622 |
+
G: int,
|
| 623 |
+
device: torch.device | str = DEVICE,
|
| 624 |
+
preferred_element_type: torch.dtype = DTYPE,
|
| 625 |
+
existing_out: Tensor | None = None,
|
| 626 |
+
) -> Tensor:
|
| 627 |
+
assert K > 0, f"Number of out rows K must be positive (K = {K})."
|
| 628 |
+
assert N > 0, f"Number of out columns N must be positive (N = {N})."
|
| 629 |
+
assert G > 0, f"Number of groups G must be positive (G = {G})."
|
| 630 |
+
|
| 631 |
+
if existing_out is not None:
|
| 632 |
+
assert (
|
| 633 |
+
existing_out.device == device
|
| 634 |
+
), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})."
|
| 635 |
+
assert (
|
| 636 |
+
existing_out.dtype == preferred_element_type
|
| 637 |
+
), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})."
|
| 638 |
+
assert existing_out.shape == (
|
| 639 |
+
G,
|
| 640 |
+
K,
|
| 641 |
+
N,
|
| 642 |
+
), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(G, K, N)})."
|
| 643 |
+
return existing_out
|
| 644 |
+
|
| 645 |
+
return gen_tgmm_output(
|
| 646 |
+
K,
|
| 647 |
+
N,
|
| 648 |
+
G,
|
| 649 |
+
device=device,
|
| 650 |
+
preferred_element_type=preferred_element_type,
|
| 651 |
+
)
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
def get_tgmm_bias_grad(
|
| 655 |
+
K: int,
|
| 656 |
+
G: int,
|
| 657 |
+
device: torch.device | str = DEVICE,
|
| 658 |
+
existing_bias_grad: Tensor | None = None,
|
| 659 |
+
) -> Tensor:
|
| 660 |
+
"""
|
| 661 |
+
Get or validate bias gradient tensor for TGMM.
|
| 662 |
+
|
| 663 |
+
If existing_bias_grad is provided, validates its shape, device, dtype, and stride,
|
| 664 |
+
and always zeros it before returning (since the kernel uses atomic_add).
|
| 665 |
+
If existing_bias_grad is None, returns a dummy tensor (for use when COMPUTE_BIAS_GRAD=False).
|
| 666 |
+
Parameters
|
| 667 |
+
----------
|
| 668 |
+
K : int
|
| 669 |
+
Number of rows in the bias gradient tensor.
|
| 670 |
+
G : int
|
| 671 |
+
Number of groups.
|
| 672 |
+
device : torch.device or str
|
| 673 |
+
Device for the tensor.
|
| 674 |
+
existing_bias_grad : torch.Tensor or None
|
| 675 |
+
Existing bias gradient tensor to validate and use.
|
| 676 |
+
Returns
|
| 677 |
+
-------
|
| 678 |
+
torch.Tensor
|
| 679 |
+
Valid bias gradient tensor or dummy tensor.
|
| 680 |
+
"""
|
| 681 |
+
assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})."
|
| 682 |
+
assert G > 0, f"Number of groups G must be positive (G = {G})."
|
| 683 |
+
|
| 684 |
+
if existing_bias_grad is not None:
|
| 685 |
+
# Validate existing bias_grad tensor.
|
| 686 |
+
expected_shape = (G, K)
|
| 687 |
+
assert (
|
| 688 |
+
tuple(existing_bias_grad.shape) == expected_shape
|
| 689 |
+
), f"bias_grad must have shape {expected_shape}, got {tuple(existing_bias_grad.shape)}."
|
| 690 |
+
assert (
|
| 691 |
+
existing_bias_grad.device == device
|
| 692 |
+
), f"bias_grad must be on the same device (bias_grad = {existing_bias_grad.device}, device = {device})."
|
| 693 |
+
assert (
|
| 694 |
+
existing_bias_grad.dtype == torch.float32
|
| 695 |
+
), f"bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), got {existing_bias_grad.dtype}."
|
| 696 |
+
assert existing_bias_grad.stride() == (
|
| 697 |
+
K,
|
| 698 |
+
1,
|
| 699 |
+
), f"bias_grad must be row-major with stride (K, 1) = ({K}, 1), got {existing_bias_grad.stride()}."
|
| 700 |
+
|
| 701 |
+
# Always zero the tensor since bias_grad represents gradients for the current
|
| 702 |
+
# computation and should start fresh. The kernel uses atomic_add which adds to
|
| 703 |
+
# existing values, so we must zero before the kernel runs.
|
| 704 |
+
existing_bias_grad.zero_()
|
| 705 |
+
|
| 706 |
+
return existing_bias_grad
|
| 707 |
+
|
| 708 |
+
else:
|
| 709 |
+
return gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=False)
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
def get_tgmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]:
|
| 713 |
+
assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
|
| 714 |
+
assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})."
|
| 715 |
+
assert out.dim() == 3, f"out must have 3 dimensions (it's {out.dim()})."
|
| 716 |
+
|
| 717 |
+
lhs_k, lhs_m = lhs.shape
|
| 718 |
+
rhs_m, rhs_n = rhs.shape
|
| 719 |
+
G, out_k, out_n = out.shape
|
| 720 |
+
|
| 721 |
+
assert (
|
| 722 |
+
lhs_m == rhs_m
|
| 723 |
+
), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})."
|
| 724 |
+
M = lhs_m
|
| 725 |
+
assert (
|
| 726 |
+
lhs_k == out_k
|
| 727 |
+
), f"K dimension of lhs and out don't match (lhs = {lhs_k}, rhs = {out_k})."
|
| 728 |
+
K = lhs_k
|
| 729 |
+
assert (
|
| 730 |
+
rhs_n == out_n
|
| 731 |
+
), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})."
|
| 732 |
+
N = rhs_n
|
| 733 |
+
|
| 734 |
+
assert M > 0, f"M must be positive, it's {M}."
|
| 735 |
+
assert K > 0, f"K must be positive, it's {K}."
|
| 736 |
+
assert N > 0, f"N must be positive, it's {N}"
|
| 737 |
+
assert G > 0, f"G must be positive, it's {G}"
|
| 738 |
+
|
| 739 |
+
is_lhs_row_major = lhs.stride() == (M, 1)
|
| 740 |
+
is_lhs_col_major = lhs.stride() == (1, K)
|
| 741 |
+
assert (
|
| 742 |
+
is_lhs_row_major != is_lhs_col_major
|
| 743 |
+
), "lhs must be row-major or column-major."
|
| 744 |
+
is_rhs_row_major = rhs.stride() == (N, 1)
|
| 745 |
+
assert is_rhs_row_major, "rhs must be row-major."
|
| 746 |
+
is_out_row_major = out.stride() == (K * N, N, 1)
|
| 747 |
+
assert is_out_row_major, "out must be row-major."
|
| 748 |
+
|
| 749 |
+
# Get lhs leading dimension according to transposition configuration.
|
| 750 |
+
ld_lhs = M if is_lhs_row_major else K
|
| 751 |
+
|
| 752 |
+
return is_lhs_col_major, ld_lhs
|
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/logger.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# AITER Triton Logger which is singleton object around python logging.
|
| 6 |
+
# Note: Python logging is also a singleton object, but we want to read the
|
| 7 |
+
# env var AITER_LOG_LEVEL once at the beginning. Another alternative is to do
|
| 8 |
+
# this in __init__.py. In fact, that's how CK logger is setup. We can look at
|
| 9 |
+
# switching to that at some point
|
| 10 |
+
#
|
| 11 |
+
# AITER_LOG_LEVEL follows python logging levels
|
| 12 |
+
# DEBUG
|
| 13 |
+
# INFO
|
| 14 |
+
# WARNING
|
| 15 |
+
# ERROR
|
| 16 |
+
# CRITICAL
|
| 17 |
+
#
|
| 18 |
+
class AiterTritonLogger(object):
|
| 19 |
+
_instance = None
|
| 20 |
+
|
| 21 |
+
def __new__(cls):
|
| 22 |
+
if cls._instance is None:
|
| 23 |
+
cls._instance = super(AiterTritonLogger, cls).__new__(cls)
|
| 24 |
+
log_level_str = os.getenv("AITER_TRITON_LOG_LEVEL", "WARNING").upper()
|
| 25 |
+
numeric_level = getattr(logging, log_level_str, logging.WARNING)
|
| 26 |
+
cls._instance._logger = logging.getLogger("AITER_TRITON")
|
| 27 |
+
cls._instance._logger.setLevel(numeric_level)
|
| 28 |
+
|
| 29 |
+
return cls._instance
|
| 30 |
+
|
| 31 |
+
def get_logger(self):
|
| 32 |
+
return self._logger
|
| 33 |
+
|
| 34 |
+
def debug(self, msg):
|
| 35 |
+
self._logger.debug(msg)
|
| 36 |
+
|
| 37 |
+
def info(self, msg):
|
| 38 |
+
self._logger.info(msg)
|
| 39 |
+
|
| 40 |
+
def warning(self, msg):
|
| 41 |
+
self._logger.warning(msg)
|
| 42 |
+
|
| 43 |
+
def error(self, msg):
|
| 44 |
+
self._logger.error(msg)
|
| 45 |
+
|
| 46 |
+
def critical(self, msg):
|
| 47 |
+
self._logger.critical(msg)
|
build/torch211-cxx11-cu130-x86_64-linux/{_megablocks_cuda_ae601bb.abi3.so → _megablocks_cuda_f8f8b50.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5ef673d78d220cea71eace3a5bdb4b952444ab7b95ed15774258ad108ad40d51
|
| 3 |
+
size 11769248
|
build/torch211-cxx11-cu130-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_cuda_f8f8b50
|
| 3 |
+
ops = torch.ops._megablocks_cuda_f8f8b50
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_cuda_f8f8b50::{op_name}"
|
build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py
CHANGED
|
@@ -2,16 +2,16 @@
|
|
| 2 |
# extensions. Otherwise libc10.so cannot be found.
|
| 3 |
import torch
|
| 4 |
|
| 5 |
-
#
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
# import grouped_gemm_backend as backend
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
#
|
| 14 |
-
|
|
|
|
| 15 |
|
| 16 |
def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
|
| 17 |
assert not (trans_a and trans_b)
|
|
|
|
| 2 |
# extensions. Otherwise libc10.so cannot be found.
|
| 3 |
import torch
|
| 4 |
|
| 5 |
+
# On ROCm there is no CUTLASS grouped GEMM; dispatch to the vendored AITER
|
| 6 |
+
# Triton kernels instead. On CUDA we use the compiled CUTLASS `gmm` op.
|
| 7 |
+
_IS_ROCM = torch.version.hip is not None
|
|
|
|
| 8 |
|
| 9 |
+
if _IS_ROCM:
|
| 10 |
+
from .._grouped_gemm_triton import adapter as backend
|
| 11 |
+
else:
|
| 12 |
+
# We import the backend operations from the megablocks package as
|
| 13 |
+
# grouped_gemm is vendored in megablocks in this repository.
|
| 14 |
+
from .._ops import ops as backend # type: ignore
|
| 15 |
|
| 16 |
def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
|
| 17 |
assert not (trans_a and trans_b)
|