diff --git a/build/torch211-cxx11-cu126-x86_64-linux/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/__init__.py index 38075732c6d8fa0e1e6ef493145e1aca3851ae6b..0766d7b8da4f97baca212177b4bb989bc6374bf8 100644 --- a/build/torch211-cxx11-cu126-x86_64-linux/__init__.py +++ b/build/torch211-cxx11-cu126-x86_64-linux/__init__.py @@ -3,7 +3,9 @@ import torch -from ._ops import ops +# Stable alias: bare `ops` is shadowed by `from . import layers` below. +from ._ops import ops as _compiled_ops +from . import ops from .grouped_gemm import backend as gg_backend from .grouped_gemm import ops as gg_ops @@ -136,7 +138,8 @@ def sort( Returns: The sorted values tensor """ - return ops.sort(x, end_bit, x_out, iota_out) + _compiled_ops.sort(x, end_bit, x_out, iota_out) + return x_out # Convenience functions for common use cases diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py new file mode 100644 index 0000000000000000000000000000000000000000..8c101d07cea416f9390b708e5a35fdc466e48aed --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py @@ -0,0 +1,574 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + + +# Imports. +# ------------------------------------------------------------------------------ + +# Python standard library +import functools + +# Triton +import triton +import triton.language as tl + +# AITER +from ..configs import CONFIGS as _CONFIGS +from ..utils._triton import arch_info +from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd + +# Kernel config. +# ------------------------------------------------------------------------------ + + +@functools.lru_cache() +def get_config( + gmm_type: str, M: int, K: int, N: int, G: int, accumulate: bool = False +) -> dict[str, int]: + assert gmm_type in { + "gmm", + "ptgmm", + "nptgmm", + }, f"'{gmm_type}' is an invalid GMM variant." + dev = arch_info.get_arch() + assert ( + dev in _CONFIGS + ), f"No GMM configuration tuned for arch '{dev}'. Supported: {sorted(_CONFIGS)}." + arch_configs = _CONFIGS[dev] + assert ( + "default" in arch_configs[gmm_type] + ), "Default configuration is absent." + key = "accumulate" if accumulate else "default" + return arch_configs[gmm_type][key] + + +# Common code shared by GMM and TGMM kernels. +# ------------------------------------------------------------------------------ + + +# XCD remapping followed by 1D PID to 2D grid mapping. +@triton.jit +def _remap_xcd_tile_grid( + tile_in_mm, + num_row_tiles, + num_col_tiles, + GROUP_SIZE: tl.constexpr = 1, + NUM_XCDS: tl.constexpr = 8, +): + return pid_grid( + remap_xcd(tile_in_mm, num_row_tiles * num_col_tiles, NUM_XCDS=NUM_XCDS), + num_row_tiles, + num_col_tiles, + GROUP_SIZE_M=GROUP_SIZE, + ) + + +# GMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.heuristics( + { + "K_DIVISIBLE_BY_BLOCK_SIZE_K": lambda META: META["K"] % META["BLOCK_SIZE_K"] + == 0, + } +) +@triton.jit +def gmm_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_RHS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + K_DIVISIBLE_BY_BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE: tl.constexpr, + GRID_DIM: tl.constexpr, + USE_BIAS: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + # Current tile. Each program computes multiple tiles of each group. + tile = tl.program_id(0) + tl.device_assert(tile >= 0, "tile < 0 (at initialization)") + + # Tile limit of last MM problem (inclusive). + last_mm_tile = 0 + + # Last input row of lhs and output row of out. Each group reads some rows of + # lhs and writes some rows to out. + last_m = 0 + + # Loop through all (m, K, N) MM problems: + # (m, K) x (K, N) = (m, N) + # sum(m) = M + for g in range(G): + # Get m dimension of current MM problem. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty + tl.device_assert(m >= 0, "m < 0") + + num_m_tiles = tl.cdiv(m, BLOCK_SIZE_M) + # num_m_tiles can be zero if group is empty + tl.device_assert(num_m_tiles >= 0, "num_m_tiles < 0") + + num_tiles = num_m_tiles * num_n_tiles + # num_tiles can be zero if group is empty + tl.device_assert(num_tiles >= 0, "num_tiles < 0") + + # Loop through tiles of current MM problem. + while tile >= last_mm_tile and tile < last_mm_tile + num_tiles: + # Figure out tile coordinates in current MM problem. + tile_in_mm = tile - last_mm_tile + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_m, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_m_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + # Do regular MM: + + tl.device_assert(tile_m * BLOCK_SIZE_M >= 0, "tile_m * BLOCK_SIZE_M < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_m = ( + tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + ) % m + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + + lhs_ptrs = lhs_ptr + (last_m + offs_lhs_m[:, None]) * K + offs_k[None, :] + + if TRANS_RHS: + rhs_ptrs = ( + rhs_ptr + + g.to(tl.int64) * K * N + + offs_k[:, None] + + offs_rhs_n[None, :] * K + ) + else: + rhs_ptrs = ( + rhs_ptr + + g.to(tl.int64) * K * N + + offs_k[:, None] * N + + offs_rhs_n[None, :] + ) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if K_DIVISIBLE_BY_BLOCK_SIZE_K: + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + else: + k_mask_limit = K - k * BLOCK_SIZE_K + lhs = tl.load( + lhs_ptrs, mask=offs_k[None, :] < k_mask_limit, other=0 + ) + rhs = tl.load( + rhs_ptrs, mask=offs_k[:, None] < k_mask_limit, other=0 + ) + + acc = tl.dot(lhs, rhs, acc=acc) + + lhs_ptrs += BLOCK_SIZE_K + + if TRANS_RHS: + rhs_ptrs += BLOCK_SIZE_K + else: + rhs_ptrs += BLOCK_SIZE_K * N + + # Add bias if enabled + if USE_BIAS: + offs_bias_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange( + 0, BLOCK_SIZE_N + ) + bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bias_n + bias = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0) + # Convert bias to float32 to match accumulator precision + bias = bias.to(tl.float32) + # Broadcast bias across M dimension and add in float32 + acc += bias[None, :] + + # Convert to output dtype after all computations + acc = acc.to(out_ptr.type.element_ty) + + offs_out_m = tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + (last_m + offs_out_m[:, None]) * N + offs_out_n[None, :] + ) + + tl.store( + out_ptrs, + acc, + mask=(offs_out_m[:, None] < m) & (offs_out_n[None, :] < N), + ) + + # Go to the next tile by advancing number of programs. + tile += GRID_DIM + tl.device_assert(tile > 0, "tile <= 0 (at update)") + + # Get ready to go to the next MM problem. + + last_mm_tile += num_tiles + # last_mm_tile can be zero if group 0 is skipped + tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)") + + last_m += m + # last_m can be zero if group 0 is skipped + tl.device_assert(last_m >= 0, "last_m < 0 (at update)") + tl.device_assert(last_m <= M, "last_m > M (at update)") + + +# Persistent TGMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.jit +def tgmm_persistent_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_grad_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_LHS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE: tl.constexpr, + GRID_DIM: tl.constexpr, + COMPUTE_BIAS_GRAD: tl.constexpr, + ACCUMULATE: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0") + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + num_tiles = num_k_tiles * num_n_tiles + tl.device_assert(num_tiles > 0, "num_tiles <= 0") + + # Current tile. Each program computes multiple tiles of each group. + tile = tl.program_id(0) + tl.device_assert(tile >= 0, "tile < 0 (at initialization)") + + # Tile limit of last MM problem (inclusive). + last_mm_tile = 0 + + # Last input column of lhs and input row of rhs. Each group reads some + # columns of lhs and some rows of rhs. + last_m = 0 + + # Loop through all (K, m, N) MM problems: + # (K, m) x (m, N) = (K, N) + # sum(m) = M + for g in range(G): + # Get m dimension of current MM problem. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty + tl.device_assert(m >= 0, "m < 0") + + # Loop through tiles of current MM problem. + while tile >= last_mm_tile and tile < last_mm_tile + num_tiles: + # Figure out tile coordinates in current MM problem. + tile_in_mm = tile - last_mm_tile + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_k, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + # Do regular MM: + + tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + + if TRANS_LHS: + lhs_ptrs = ( + lhs_ptr + offs_lhs_k[:, None] + (last_m + offs_m[None, :]) * K + ) + else: + lhs_ptrs = ( + lhs_ptr + offs_lhs_k[:, None] * M + (last_m + offs_m[None, :]) + ) + + rhs_ptrs = rhs_ptr + (last_m + offs_m[:, None]) * N + offs_rhs_n[None, :] + + loop_m = tl.cdiv(m, BLOCK_SIZE_M) + m_divisible_by_block_m = m % BLOCK_SIZE_M == 0 + if not m_divisible_by_block_m: + loop_m -= 1 + + acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32) + + # Initialize bias accumulator + bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32) + + for _ in range(0, loop_m): + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + + acc = tl.dot(lhs, rhs, acc=acc) + + # Accumulate for bias gradient: sum lhs across M dimension + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum( + lhs, axis=1 + ) # Sum across M dimension [K, M] -> [K] + + if TRANS_LHS: + lhs_ptrs += BLOCK_SIZE_M * K + else: + lhs_ptrs += BLOCK_SIZE_M + + rhs_ptrs += BLOCK_SIZE_M * N + + if not m_divisible_by_block_m: + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0) + rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0) + acc = tl.dot(lhs, rhs, acc=acc) + + # Accumulate last chunk for bias gradient + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) + + acc = acc.to(out_ptr.type.element_ty) + + offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + + g.to(tl.int64) * K * N + + offs_out_k[:, None] * N + + offs_out_n[None, :] + ) + + mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N) + if ACCUMULATE: + # Load existing values and add to them (like beta=1 in BLAS) + old_vals = tl.load(out_ptrs, mask=mask, other=0.0) + tl.store(out_ptrs, acc + old_vals, mask=mask) + else: + # Overwrite output (like beta=0 in BLAS) + tl.store(out_ptrs, acc, mask=mask) + + # Store bias gradient (only for first N tile, sum across all M) + if COMPUTE_BIAS_GRAD and tile_n == 0: + # Keep as float32 for atomic_add (bf16 not supported for atomics) + bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k + # Use atomic add since multiple K-tiles may write to same expert's bias + tl.atomic_add( + bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed" + ) + + # Go to the next tile by advancing number of programs. + tile += GRID_DIM + tl.device_assert(tile > 0, "tile <= 0 (at update)") + + # Get ready to go to the next MM problem. + + last_mm_tile += num_tiles + # last_mm_tile can be zero if group 0 is skipped + tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)") + + last_m += m + # last_m can be zero if group 0 is skipped + tl.device_assert(last_m >= 0, "last_m < 0 (at update)") + tl.device_assert(last_m <= M, "last_m > M (at update)") + + +# Regular non-persistent TGMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.heuristics({"BLOCK_SIZE_G": lambda META: triton.next_power_of_2(META["G"])}) +@triton.jit +def tgmm_non_persistent_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_grad_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_LHS: tl.constexpr, + BLOCK_SIZE_G: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE: tl.constexpr, + COMPUTE_BIAS_GRAD: tl.constexpr, + ACCUMULATE: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + # Get group ID from grid. + g = tl.program_id(0) + tl.device_assert(g >= 0, "g < 0") + tl.device_assert(g < G, "g >= G") + + # Get m dimension of current MM group. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty. + tl.device_assert(m >= 0, "m < 0") + + # Skip empty groups. + if m == 0: + return + + # Compute sum(group_sizes) until current group g. + # It's the starting column of lhs and starting row of rhs. + offs_g = tl.arange(0, BLOCK_SIZE_G) + group_sizes = tl.load(group_sizes_ptr + offs_g, mask=offs_g < g, other=0) + start_m = tl.sum(group_sizes) + + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0") + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + # Get MM tile from grid. + tile_in_mm = tl.program_id(1) + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_k, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_k = (tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K + offs_rhs_n = (tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + + if TRANS_LHS: + lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] + (start_m + offs_m[None, :]) * K + else: + lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] * M + (start_m + offs_m[None, :]) + + rhs_ptrs = rhs_ptr + (start_m + offs_m[:, None]) * N + offs_rhs_n[None, :] + + loop_m = tl.cdiv(m, BLOCK_SIZE_M) + m_divisible_by_block_m = m % BLOCK_SIZE_M == 0 + if not m_divisible_by_block_m: + loop_m -= 1 + + acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32) + # Initialize bias accumulator + bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32) + + for _ in range(0, loop_m): + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + + acc = tl.dot(lhs, rhs, acc=acc) + + # Accumulate for bias gradient: sum lhs across M dimension + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) # [K, M] -> [K] + + if TRANS_LHS: + lhs_ptrs += BLOCK_SIZE_M * K + else: + lhs_ptrs += BLOCK_SIZE_M + + rhs_ptrs += BLOCK_SIZE_M * N + + if not m_divisible_by_block_m: + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0) + rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0) + acc = tl.dot(lhs, rhs, acc=acc) + # Accumulate last chunk for bias gradient + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) + + acc = acc.to(out_ptr.type.element_ty) + + offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + g.to(tl.int64) * K * N + offs_out_k[:, None] * N + offs_out_n[None, :] + ) + + mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N) + if ACCUMULATE: + # Load existing values and add to them (like beta=1 in BLAS) + old_vals = tl.load(out_ptrs, mask=mask, other=0.0) + tl.store(out_ptrs, acc + old_vals, mask=mask) + else: + # Overwrite output (like beta=0 in BLAS) + tl.store(out_ptrs, acc, mask=mask) + + # Store bias gradient (only for first N tile, sum across all M) + if COMPUTE_BIAS_GRAD and tile_n == 0: + # Keep as float32 for atomic_add (bf16/fp16 not supported for atomics) + bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k + # Use atomic add since multiple K-tiles may write to same expert's bias + tl.atomic_add(bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed") diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/adapter.py b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..98c224244f27445384e0c2377d73516406927536 --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/adapter.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Adapt AITER's Triton grouped GEMM to MegaBlocks' ``gmm`` calling convention. + +MegaBlocks (following tgale96/grouped_gemm) uses a single ``gmm`` entry point +with ``trans_a`` / ``trans_b`` flags: + +* ``trans_a=False, trans_b=False``: a(M,K) @ b(G,K,N) -> c(M,N) +* ``trans_a=False, trans_b=True`` : a(M,K) @ b(G,N,K)^T -> c(M,N) (dgrad) +* ``trans_a=True`` : a(M,K)^T @ b(M,N) per group -> c(G,K,N) (wgrad) + +AITER exposes these as two kernels: ``gmm`` ((M,K)@(G,K,N)->(M,N), transposition +of the 3D operand inferred from strides) and ``ptgmm`` ((K,M)@(M,N)->(G,K,N), +transposition of the 2D operand inferred from strides). +""" + +import torch + +from .gmm import gmm as _aiter_gmm +from .gmm import ptgmm as _aiter_ptgmm + + +def gmm(a, b, c, batch_sizes, trans_a=False, trans_b=False): + # AITER requires group sizes to be int32 and to live on the compute device. + group_sizes = batch_sizes.to(device=a.device, dtype=torch.int32) + + # AITER asserts exact strides: gmm wants lhs/rhs row-major (a transposed + # 3D operand must be exactly column-major), tgmm wants rhs row-major and + # lhs row/column-major. Make operands contiguous first so the transposed + # views have the precise strides the kernels expect. `.contiguous()` is a + # no-op when the tensor is already contiguous. + if trans_a: + # Weight gradient: a(M,K), b(M,N) -> c(G,K,N). + # Pass a transposed so AITER sees lhs(K,M) column-major (TRANS_LHS). + _aiter_ptgmm( + a.contiguous().transpose(0, 1), + b.contiguous(), + group_sizes, + preferred_element_type=c.dtype, + existing_out=c, + ) + else: + # trans_b contracts b's last dim: pass a column-major (G,K,N) view. + rhs = b.contiguous() + if trans_b: + rhs = rhs.transpose(1, 2) + _aiter_gmm( + a.contiguous(), + rhs, + group_sizes, + preferred_element_type=c.dtype, + existing_out=c, + ) + return c diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/configs.py b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..9a4fe5617d8100869aa76dba9b7d22c7bcab814f --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/configs.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: MIT +# Tuned GMM configs vendored from ROCm/aiter (aiter/ops/triton/configs/). +# Inlined as a Python module so packaging always includes them. + +CONFIGS = {'gfx1250': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}, 'gfx942': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}, 'gfx950': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}} diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/gmm.py b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/gmm.py new file mode 100644 index 0000000000000000000000000000000000000000..e30c9326c6d4e4836d1303e2761ea2440a7f4750 --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/gmm.py @@ -0,0 +1,567 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. + + +# Imports. +# ------------------------------------------------------------------------------ + +# PyTorch +import torch +from torch import Tensor + +# Triton +import triton + +# AITER: GMM utility functions +from .utils.gmm_common import ( + DTYPE, + is_power_of_2, + check_input_device_dtype, + check_bias_shape_stride, + get_gmm_shape, + get_gmm_output, + get_gmm_transposition, + get_tgmm_shape, + get_tgmm_output, + get_tgmm_bias_grad, + get_tgmm_transposition, +) + +# AITER: GMM Triton kernels +from ._triton_kernels.gmm import ( + gmm_kernel, + tgmm_persistent_kernel, + tgmm_non_persistent_kernel, + get_config, +) + +# GMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _gmm_grid( + N: int, + block_size_m: int, + block_size_n: int, + group_sizes: Tensor, + grid_dim: int, +) -> tuple[int]: + assert N > 0, f"N must be positive, it's {N}." + assert is_power_of_2( + block_size_m + ), f"M-dimension tile size must be a power of 2 (it's {block_size_m})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + assert torch.all(group_sizes >= 0).item(), "All group_sizes must be non-negative." + assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})." + num_m_tiles = (group_sizes + block_size_m - 1) // block_size_m + assert torch.all(num_m_tiles >= 0).item(), "All num_m_tiles must be non-negative." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles = torch.sum(num_m_tiles * num_n_tiles).item() + assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}." + num_programs = int(min(grid_dim, num_tiles)) + assert num_programs > 0, f"num_programs must be positive, it's {num_programs}." + return (num_programs,) + + +def gmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias: Tensor | None = None, +) -> Tensor: + """ + Perform Group Matrix Multiplication (GMM): out = lhs @ rhs + bias + + lhs rows are divided into G groups. Each group of lhs rows is matrix multiplied with a plane of + rhs 3D tensor and then stored in a slice of out. In PyTorch parlance, it can be implemented as + follows for a given group g: + out[group_start:group_end, :] = lhs[group_start:group_end, :] @ rhs[g] + bias[g] + + The size of each group, and their respective start and end positions are specified by + group_sizes tensor. For instance, suppose that group_sizes = [3, 2, 4, 1]. In this particular + case we have 4 groups. The 1st group starts at 0 and ends at 2, the second group starts at 3 and + ends at 4, the third group starts at 5 and ends at 8, and the fourth and final group consists of + just the 10th (last) row of lhs. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (M, K). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 3D input tensor. Shape: (G, K, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (M, N), its data type must match preferred_element_type + and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias : torch.Tensor or None, optional + Optional bias tensor. Shape: (G, N). + If provided, bias data type must match lhs and rhs data type, and bias must be on the same + device as other input tensors. Each group g adds bias[g] to the output. + + Returns + ------- + torch.Tensor + The computed output 2D tensor. Shape: (M, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - GMM is implemented with a persistent Triton kernel. + - lhs must be row-major (lhs.stride() == (K, 1)). + - rhs can be row-major (rhs.stride() == (K * N, N, 1)) or column-major (rhs.stride() == + (K * N, 1, K)). If rhs is row-major then kernel parameter TRANS_RHS == False, this is useful + for implementing forward pass. If rhs is column-major then kernel parameter TRANS_RHS == True, + this is useful for computing the lhs derivative in the backward pass, while fusing the + transposition. + - out must be row-major (out.stride() == (N, 1)). + - bias must be row-major (bias.stride() == (N, 1)) if provided. + """ + use_bias = bias is not None + check_input_device_dtype(lhs, rhs, group_sizes, bias) + + M, K, N, G = get_gmm_shape(lhs, rhs, group_sizes) + + if use_bias: + check_bias_shape_stride(bias, G, N) + + out = get_gmm_output( + M, + N, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_rhs, _ = get_gmm_transposition(lhs, rhs, out) + + if config is None: + config = get_config("gmm", M, K, N, G) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + "GRID_DIM", + } + ), "Invalid GMM kernel config." + + grid = _gmm_grid( + N, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + group_sizes, + config["GRID_DIM"], + ) + + # fmt: off + gmm_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_RHS=trans_rhs, + USE_BIAS=use_bias, + **config, + ) + # fmt: on + + return out + + +# Persistent TGMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _ptgmm_grid( + K: int, + N: int, + G: int, + block_size_k: int, + block_size_n: int, + grid_dim: int, +) -> tuple[int]: + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}." + assert G > 0, f"G must be positive, it's {G}." + assert is_power_of_2( + block_size_k + ), f"K-dimension tile size must be a power of 2 (it's {block_size_k})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})." + num_k_tiles = triton.cdiv(K, block_size_k) + assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles = G * num_k_tiles * num_n_tiles + assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}." + num_programs = min(grid_dim, num_tiles) + assert num_programs > 0, f"num_programs must be positive, it's {num_programs}." + return (num_programs,) + + +def ptgmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias_grad: Tensor | None = None, + accumulate: bool = False, +) -> Tensor: + """ + Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs + + lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with + the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch + parlance, it can be implemented as follows for a given group g: + out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :] + + The 't' in the operator name derives from MaxText implementation + (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py), + which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor + shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N). + + The 'p' in the operator name means that it is implemented with a persistent kernel. There is + also the non-persistent variation, which is implemented with a regular kernel. Please take a + look at nptgmm operator. Both ptgmm and nptgmm implement the same computation, choosing one or + the other is a matter of performance for the target workload. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (K, M). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 2D input tensor. Shape: (M, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (G, K, N), its data type must match + preferred_element_type and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias_grad : torch.Tensor or None, optional + Optional bias gradient output tensor. Shape: (G, K). + If provided, the kernel will compute the bias gradient and write it to this tensor. + bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), + accumulate : bool, optional + Whether to accumulate into existing output tensor values. Default is False. + If False, output will be overwritten with fresh computation. + If True, results will be added to existing output tensor values. + + Returns + ------- + torch.Tensor + The computed output 3D tensor. Shape: (G, K, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - PTGMM is implemented with a persistent Triton kernel. + - lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs + is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel + parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward + pass, while fusing the transposition. + - rhs must be row-major (rhs.stride() == (N, 1)). + - out must be row-major (out.stride() == (K * N, N, 1)). + """ + check_input_device_dtype(lhs, rhs, group_sizes) + + M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes) + + out = get_tgmm_output( + K, + N, + G, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out) + + if config is None: + config = get_config("ptgmm", M, K, N, G, accumulate) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + "GRID_DIM", + } + ), "Invalid PTGMM kernel config." + + # Bias gradient handling. + # ----------------------- + # Get or validate bias gradient tensor. + compute_bias_grad = bias_grad is not None + bias_grad_ptr = get_tgmm_bias_grad( + K, + G, + device=lhs.device, + existing_bias_grad=bias_grad, + ) + + grid = _ptgmm_grid( + K, + N, + G, + config["BLOCK_SIZE_K"], + config["BLOCK_SIZE_N"], + config["GRID_DIM"], + ) + + # fmt: off + tgmm_persistent_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias_grad_ptr, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_LHS=trans_lhs, + COMPUTE_BIAS_GRAD=compute_bias_grad, + ACCUMULATE=accumulate, + **config, + ) + # fmt: on + + return out + + +# Regular non-persistent TGMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _nptgmm_grid( + K: int, + N: int, + G: int, + block_size_k: int, + block_size_n: int, +) -> tuple[int, int]: + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}." + assert G > 0, f"G must be positive, it's {G}." + assert is_power_of_2( + block_size_k + ), f"K-dimension tile size must be a power of 2 (it's {block_size_k})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + num_k_tiles = triton.cdiv(K, block_size_k) + assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles_per_mm = num_k_tiles * num_n_tiles + assert ( + num_tiles_per_mm > 0 + ), f"num_tiles_per_mm must be positive, it's {num_tiles_per_mm}." + return (G, num_tiles_per_mm) + + +def nptgmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias_grad: Tensor | None = None, + accumulate: bool = False, +) -> Tensor: + """ + Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs + + lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with + the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch + parlance, it can be implemented as follows for a given group g: + out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :] + + The 't' in the operator name derives from MaxText implementation + (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py), + which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor + shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N). + + The 'np' in the operator name means that it is implemented with a non-persistent, i.e. regular + kernel. There is also the persistent variation, which is implemented with a persistent kernel. + Please take a look at ptgmm operator. Both nptgmm and ptgmm implement the same computation, + choosing one or the other is a matter of performance for the target workload. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (K, M). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 2D input tensor. Shape: (M, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (G, K, N), its data type must match + preferred_element_type and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias_grad : torch.Tensor or None, optional + Optional bias gradient output tensor. Shape: (G, K). + If provided, the kernel will compute the bias gradient and write it to this tensor. + bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), + accumulate : bool, optional + Whether to accumulate into existing output tensor values. Default is False. + If False, output will be overwritten with fresh computation. + If True, results will be added to existing output tensor values. + + Returns + ------- + torch.Tensor + The computed output 3D tensor. Shape: (G, K, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - NPTGMM is implemented with a non-persistent regular Triton kernel. + - lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs + is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel + parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward + pass, while fusing the transposition. + - rhs must be row-major (rhs.stride() == (N, 1)). + - out must be row-major (out.stride() == (K * N, N, 1)). + """ + check_input_device_dtype(lhs, rhs, group_sizes) + + M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes) + + out = get_tgmm_output( + K, + N, + G, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out) + + # Bias gradient handling. + # ----------------------- + # Get or validate bias gradient tensor. + compute_bias_grad = bias_grad is not None + bias_grad_ptr = get_tgmm_bias_grad( + K, + G, + device=lhs.device, + existing_bias_grad=bias_grad, + ) + + if config is None: + config = get_config("nptgmm", M, K, N, G, accumulate) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + } + ), "Invalid NPTGMM kernel config." + + grid = _nptgmm_grid( + K, + N, + G, + config["BLOCK_SIZE_K"], + config["BLOCK_SIZE_N"], + ) + + # fmt: off + tgmm_non_persistent_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias_grad_ptr, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_LHS=trans_lhs, + COMPUTE_BIAS_GRAD=compute_bias_grad, + ACCUMULATE=accumulate, + **config, + ) + # fmt: on + + return out diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py new file mode 100644 index 0000000000000000000000000000000000000000..3f6c88581a64044518125623f116082c53bd5474 --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py @@ -0,0 +1,46 @@ +import triton + +# Detect the GPU arch lazily: querying the triton driver at import time fails +# in headless environments (e.g. the kernel-builder ABI check sandbox has no +# GPU), and the original JAX fallback pulled in an unrelated runtime dep. The +# arch is only actually needed when a GMM kernel is dispatched, so resolve and +# cache on first call. +_CACHED_ARCH = None + + +def get_arch(): + global _CACHED_ARCH + if _CACHED_ARCH is not None: + return _CACHED_ARCH + try: + _CACHED_ARCH = triton.runtime.driver.active.get_current_target().arch + except RuntimeError: + try: + from jax._src.lib import gpu_triton as triton_kernel_call_lib + _CACHED_ARCH = triton_kernel_call_lib.get_arch_details("0").split(":")[0] + except ImportError as e: + raise RuntimeError( + "Cannot determine GPU arch: triton driver is inactive and " + "JAX is not available. A GPU is required for grouped GEMM." + ) from e + return _CACHED_ARCH + + +def is_gluon_avail(): + return get_arch() in ("gfx950", "gfx1250") + + +def is_fp4_avail(): + return get_arch() in ("gfx950", "gfx1250") + + +def is_fp8_avail(): + return get_arch() in ("gfx942", "gfx950", "gfx1250", "gfx1200", "gfx1201") + + +def is_mx_scale_preshuffling_avail(): + return get_arch() in ("gfx950", "gfx1250") + + +def is_tdm_avail(): + return get_arch() in ("gfx1250",) diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..99792bb3ba2fab8fff223bba733ced1eb6e6df53 --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: MIT + +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl + + +@triton.jit +def remap_xcd_chunked( + pid, GRID_MN, NUM_XCDS: tl.constexpr = 8, CHUNK_SIZE: tl.constexpr = 2 +): + # Compute current XCD and local PID + xcd = pid % NUM_XCDS + # distribute the modulo pids in round robin + if pid > (GRID_MN // (NUM_XCDS * CHUNK_SIZE)) * (NUM_XCDS * CHUNK_SIZE): + return pid + local_pid = pid // NUM_XCDS + # Calculate chunk index and position within chunk + chunk_idx = local_pid // CHUNK_SIZE + pos_in_chunk = local_pid % CHUNK_SIZE + # Calculate new PID + new_pid = chunk_idx * NUM_XCDS * CHUNK_SIZE + xcd * CHUNK_SIZE + pos_in_chunk + return new_pid + + +@triton.jit +def remap_xcd(pid, GRID_MN, NUM_XCDS: tl.constexpr = 8): + ## pid remapping on xcds + # Number of pids per XCD in the new arrangement + pids_per_xcd = (GRID_MN + NUM_XCDS - 1) // NUM_XCDS + # When GRID_MN cannot divide NUM_XCDS, some xcds will have + # pids_per_xcd pids, the other will have pids_per_xcd - 1 pids. + # We calculate the number of xcds that have pids_per_xcd pids as + # tall_xcds + tall_xcds = GRID_MN % NUM_XCDS + tall_xcds = NUM_XCDS if tall_xcds == 0 else tall_xcds + # Compute current XCD and local pid within the XCD + xcd = pid % NUM_XCDS + local_pid = pid // NUM_XCDS + # Calculate new pid based on the new grouping + # Note that we need to consider the following two cases: + # 1. the current pid is on a tall xcd + # 2. the current pid is on a short xcd + if xcd < tall_xcds: + pid = xcd * pids_per_xcd + local_pid + else: + pid = ( + tall_xcds * pids_per_xcd + + (xcd - tall_xcds) * (pids_per_xcd - 1) + + local_pid + ) + + return pid + + +@triton.jit +def pid_grid(pid: int, num_pid_m: int, num_pid_n: int, GROUP_SIZE_M: tl.constexpr = 1): + """ + Maps 1D pid to 2D grid coords (pid_m, pid_n). + + Args: + - pid: 1D pid + - num_pid_m: grid m size + - num_pid_n: grid n size + - GROUP_SIZE_M: tl.constexpr: default is 1 + """ + if GROUP_SIZE_M == 1: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + tl.assume(group_size_m >= 0) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + return pid_m, pid_n + + +@triton.jit +def pid_grid_3d(pid: int, num_pid_m: int, num_pid_n: int, num_pid_k): + """ + Maps 1D pid to 3D grid coords (pid_m, pid_n, pid_k). + Args: + - pid: 1D pid + - num_pid_m: grid m size + - num_pid_n: grid n size + - num_pid_k: grid k size + + Returns: + - pid_m, pid_n, pid_k: 3D grid coordinates + """ + pid_m = pid % num_pid_m + pid_n = (pid // num_pid_m) % num_pid_n + pid_k = pid // (num_pid_m * num_pid_n) % num_pid_k + + return pid_m, pid_n, pid_k diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py new file mode 100644 index 0000000000000000000000000000000000000000..153dee65b50ab5f833262481889d2184d1ca639f --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py @@ -0,0 +1,752 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. + +# Imports. +# ------------------------------------------------------------------------------ + +# PyTorch +import torch +from torch import Tensor + +# AITER: logging +from .logger import AiterTritonLogger + +_LOGGER: AiterTritonLogger = AiterTritonLogger() + + +# Supported data types. +# ------------------------------------------------------------------------------ + +# Supported data types, as strings. +SUPPORTED_DTYPES_STR: set[str] = {"fp16", "bf16"} + + +# Convert string data type to PyTorch data type. +def dtype_from_str(dtype_str: str) -> torch.dtype: + dtype_str = dtype_str.strip().lower() + dtype_str = dtype_str[1:] if dtype_str[0] in {"i", "o"} else dtype_str + assert ( + dtype_str in SUPPORTED_DTYPES_STR + ), "String data type isn't in set of supported string data types." + return {"fp16": torch.float16, "bf16": torch.bfloat16}[dtype_str] + + +# Supported data types, as PyTorch types. +SUPPORTED_DTYPES: set[torch.dtype] = { + dtype_from_str(dtype_str) for dtype_str in SUPPORTED_DTYPES_STR +} + + +# Convert PyTorch data type to string data type. +def str_from_dtype(dtype: torch.dtype) -> str: + assert ( + dtype in SUPPORTED_DTYPES + ), "PyTorch data type isn't in set of supported PyTorch data types." + return {torch.float16: "fp16", torch.bfloat16: "bf16"}[dtype] + + +# Default data type, as string. +DTYPE_STR: str = "bf16" +assert ( + DTYPE_STR in SUPPORTED_DTYPES_STR +), "Default string data type isn't in set of supported string data types." + + +# Default data type, as PyTorch type. +DTYPE: torch.dtype = dtype_from_str(DTYPE_STR) + + +# Other defaults. +# ------------------------------------------------------------------------------ + +# Default device. +DEVICE: torch.device | str = "cuda" + +# Default RNG seed for input generation. +RNG_SEED: int = 0 + +# Default number of group sizes. +NUM_GROUP_SIZES: int = 1 + +# Default transposition (NN). +TRANS_LHS: bool = False +TRANS_RHS: bool = False + + +# Parameter checking functions. +# ------------------------------------------------------------------------------ + + +def is_power_of_2(x: int) -> bool: + return (x > 0) and (x & (x - 1) == 0) + + +def check_input_device_dtype( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor, bias: Tensor | None = None +) -> None: + assert ( + lhs.device == rhs.device == group_sizes.device + ), f"All input tensors must be in the same device (lhs = {lhs.device}, rhs = {rhs.device}, group_sizes = {group_sizes.device})." + assert ( + lhs.dtype == rhs.dtype + ), f"lhs and rhs types must match (lhs = {lhs.dtype}, rhs = {rhs.dtype})." + assert group_sizes.dtype == torch.int32, "group_sizes type must be int32." + + if bias is not None: + assert ( + bias.device == lhs.device + ), f"bias must be on the same device as lhs (bias = {bias.device}, lhs = {lhs.device})." + assert ( + bias.dtype == lhs.dtype + ), f"bias dtype must match lhs dtype (bias = {bias.dtype}, lhs = {lhs.dtype})." + + +def check_bias_shape_stride(bias: Tensor, G: int, N: int) -> None: + assert bias.shape == ( + G, + N, + ), f"bias must have shape (G, N) = ({G}, {N}), got {bias.shape}." + assert bias.stride() == (N, 1), "bias must be row-major (bias.stride() == (N, 1))." + + +# Generation of group sizes. +# ------------------------------------------------------------------------------ + + +# Probabilities for generating random group sizes. +UNUSED_TOKENS_PROB: float = 0.0 +UNUSED_EXPERTS_PROB: float = 0.1 + + +def gen_uniform_group_sizes( + M: int, + G: int, + device: torch.device | str = DEVICE, +) -> Tensor: + assert M >= 0, f"Number of tokens M must be non-negative (it's {M})." + assert G > 0, f"Number of experts G must be positive (it's {G})." + + base = M // G + remainder = M % G + group_sizes = torch.full((G,), base, dtype=torch.int32, device=device) + if remainder > 0: + group_sizes[:remainder] += 1 + + assert ( + len(group_sizes) == G + ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})." + assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative." + assert ( + torch.sum(group_sizes).item() == M + ), f"Group sizes don't add up to total tokens {M}." + assert group_sizes.dtype == torch.int32, "Group sizes must be int32." + + return group_sizes + + +def gen_group_sizes( + M: int, + G: int, + device: torch.device | str = DEVICE, + rng_seed: int | None = RNG_SEED, + unused_tokens_prob: float = UNUSED_TOKENS_PROB, + unused_experts_prob: float = UNUSED_EXPERTS_PROB, +) -> Tensor: + assert M >= 0, f"Number of tokens M must be non-negative (it's {M})." + assert G > 0, f"Number of experts G must be positive (it's {G})." + assert ( + 0 <= unused_tokens_prob <= 1 + ), f"Probability of unused tokens must be in [0, 1] interval (it's {unused_tokens_prob})." + assert ( + 0 <= unused_experts_prob <= 1 + ), f"Probability of unused experts must be in [0, 1] interval (it's {unused_experts_prob})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + if unused_tokens_prob > 0: + # Optionally drop tokens to simulate routing sparsity, some tokens may not be routed. + num_unused_tokens = M + while num_unused_tokens == M: + num_unused_tokens = int( + torch.binomial( + torch.tensor(float(M), device=device), + torch.tensor(unused_tokens_prob, device=device), + ).item() + ) + else: + num_unused_tokens = 0 + num_used_tokens = M - num_unused_tokens + assert ( + num_unused_tokens >= 0 + ), f"Number of unused tokens must be non-negative (it's {num_unused_tokens})." + assert ( + num_used_tokens > 0 + ), f"Number of used tokens must be positive (it's {num_used_tokens})." + assert ( + num_used_tokens + num_unused_tokens == M + ), f"Unused + used tokens don't add up total tokens ({num_used_tokens} + {num_unused_tokens} != {M})." + + if num_unused_tokens > 0: + _LOGGER.debug( + f"Group sizes generation: dropped {num_unused_tokens} token{'s' if num_unused_tokens > 1 else ''}.", + ) + + if unused_experts_prob > 0: + # Some experts may have zero tokens assigned to them. + num_used_experts = 0 + while num_used_experts == 0: + used_experts = torch.nonzero( + torch.rand((G,), device=device) >= unused_experts_prob + ).squeeze() + num_used_experts = used_experts.numel() + else: + used_experts = torch.arange(0, G, device=device) + num_used_experts = G + num_unused_experts = G - num_used_experts + assert ( + num_unused_experts >= 0 + ), f"Number of unused experts must be non-negative (it's {num_unused_experts})." + assert ( + num_used_experts >= 1 + ), f"At least one expert must be used (it's {num_used_experts})." + assert ( + num_unused_experts + num_used_experts == G + ), f"Unused + used experts don't add up total experts ({num_unused_experts} + {num_used_experts} != {G})." + + if num_unused_experts > 0: + _LOGGER.debug( + f"Group sizes generation: dropped {num_unused_experts} expert{'s' if num_unused_experts > 1 else ''}.", + ) + + group_sizes = torch.bincount( + used_experts[ + torch.randint(low=0, high=num_used_experts, size=(num_used_tokens,)) + ], + minlength=G, + ).to(torch.int32) + + assert ( + len(group_sizes) == G + ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})." + assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative." + assert ( + torch.sum(group_sizes).item() == num_used_tokens + ), f"Group sizes don't add up to used tokens {num_used_tokens}." + assert group_sizes.dtype == torch.int32, "Group sizes must be int32." + + return group_sizes + + +def gen_multiple_group_sizes( + num_group_sizes: int, + M: int, + G: int, + device: torch.device | str = DEVICE, + rng_seed: int | None = RNG_SEED, + unused_tokens_prob: float = UNUSED_TOKENS_PROB, + unused_experts_prob: float = UNUSED_EXPERTS_PROB, + group_sizes_0: Tensor | None = None, +) -> list[Tensor]: + assert ( + num_group_sizes > 0 + ), f"Number of group sizes to be generated must be positive, it's {num_group_sizes}." + multiple_group_sizes = [ + gen_group_sizes( + M, + G, + device=device, + rng_seed=rng_seed if g == 0 else None, + unused_tokens_prob=unused_tokens_prob, + unused_experts_prob=unused_experts_prob, + ) + for g in range( + num_group_sizes if group_sizes_0 is None else num_group_sizes - 1 + ) + ] + if group_sizes_0 is not None: + multiple_group_sizes.insert(0, group_sizes_0) + assert ( + len(multiple_group_sizes) == num_group_sizes + ), f"Expecting {num_group_sizes} distinct group sizes (it's {len(multiple_group_sizes)})." + return multiple_group_sizes + + +# GMM helpers: tensor generation. +# ------------------------------------------------------------------------------ + + +def gen_gmm_input( + M: int, + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + trans_rhs: bool = TRANS_RHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, +) -> tuple[Tensor, Tensor, Tensor]: + assert M > 0, f"Number of lhs rows M must be positive (M = {M})." + assert K > 0, f"Number of lhs columns / rhs rows K must be positive (K = {K})." + assert N > 0, f"Number of rhs columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + lhs = torch.randn((M, K), dtype=torch.float32, device=device) + lhs = lhs.to(preferred_element_type) + + if trans_rhs: + rhs = torch.randn((G, N, K), dtype=torch.float32, device=device).permute( + 0, 2, 1 + ) + else: + rhs = torch.randn((G, K, N), dtype=torch.float32, device=device) + rhs = rhs.to(preferred_element_type) + + group_sizes = ( + gen_uniform_group_sizes(M, G, device=device) + if unif_group_sizes + else gen_group_sizes(M, G, device=device, rng_seed=None) + ) + + return lhs, rhs, group_sizes + + +def gen_gmm_output( + M: int, + N: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, +) -> Tensor: + assert M > 0, f"Number of out rows M must be positive (M = {M})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + + out = torch.empty((M, N), dtype=preferred_element_type, device=device) + + return out + + +def gen_gmm_tensors( + M: int, + K: int, + N: int, + G: int, + num_group_sizes: int, + device: torch.device | str = DEVICE, + input_type: torch.dtype = DTYPE, + output_type: torch.dtype = DTYPE, + trans_lhs: bool = False, + trans_rhs: bool = TRANS_RHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, + use_bias: bool = False, +) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]: + lhs, rhs, group_sizes_0 = gen_gmm_input( + M, + K, + N, + G, + device=device, + preferred_element_type=input_type, + trans_rhs=trans_rhs, + rng_seed=rng_seed, + unif_group_sizes=unif_group_sizes, + ) + multiple_group_sizes = gen_multiple_group_sizes( + num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0 + ) + out = gen_gmm_output(M, N, device=device, preferred_element_type=output_type) + bias = None + if use_bias: + torch.manual_seed(rng_seed + 1000) # Different seed for bias + bias = torch.randn(G, N, dtype=input_type, device=device) + + return lhs, rhs, multiple_group_sizes, out, bias + + +# GMM helpers: get information from tensors. +# ------------------------------------------------------------------------------ + + +def get_gmm_shape( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor +) -> tuple[int, int, int, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})." + assert ( + group_sizes.dim() == 1 + ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})." + + M, lhs_k = lhs.shape + rhs_g, rhs_k, N = rhs.shape + group_sizes_g = group_sizes.shape[0] + + assert ( + lhs_k == rhs_k + ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})." + K = lhs_k + assert ( + rhs_g == group_sizes_g + ), f"G dimension of rhs and group_sizes don't match (rhs = {rhs_g}, group_sizes = {group_sizes_g})." + G = rhs_g + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + return M, K, N, G + + +def get_gmm_output( + M: int, + N: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, +) -> Tensor: + assert M > 0, f"Number of out rows M must be positive (M = {M})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + + if existing_out is not None: + assert ( + existing_out.device == device + ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})." + assert ( + existing_out.dtype == preferred_element_type + ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})." + assert existing_out.shape == ( + M, + N, + ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(M, N)})." + return existing_out + + return gen_gmm_output( + M, + N, + device=device, + preferred_element_type=preferred_element_type, + ) + + +def get_gmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})." + assert out.dim() == 2, f"out must have 2 dimensions (it's {out.dim()})." + + lhs_m, lhs_k = lhs.shape + G, rhs_k, rhs_n = rhs.shape + out_m, out_n = out.shape + + assert ( + lhs_m == out_m + ), f"M dimension of lhs and out don't match (lhs = {lhs_m}, rhs = {out_m})." + M = lhs_m + assert ( + lhs_k == rhs_k + ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})." + K = lhs_k + assert ( + rhs_n == out_n + ), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})." + N = rhs_n + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + is_lhs_row_major = lhs.stride() == (K, 1) + assert is_lhs_row_major, "lhs must be row-major." + is_rhs_row_major = rhs.stride() == (K * N, N, 1) + is_rhs_col_major = rhs.stride() == (K * N, 1, K) + assert ( + is_rhs_row_major != is_rhs_col_major + ), "rhs must be row-major or column-major." + is_out_row_major = out.stride() == (N, 1) + assert is_out_row_major, "out must be row-major." + + # Get rhs leading dimension according to transposition configuration. + ld_rhs = N if is_rhs_row_major else K + + return is_rhs_col_major, ld_rhs + + +# TGMM helpers: tensor generation. +# ------------------------------------------------------------------------------ + + +def gen_tgmm_input( + M: int, + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + trans_lhs: bool = TRANS_LHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, +) -> tuple[Tensor, Tensor, Tensor]: + assert K > 0, f"Number of lhs rows K must be positive (M = {K})." + assert M > 0, f"Number of lhs columns / rhs rows M must be positive (K = {M})." + assert N > 0, f"Number of rhs columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + if trans_lhs: + lhs = torch.randn((M, K), dtype=torch.float32, device=device).T + else: + lhs = torch.randn((K, M), dtype=torch.float32, device=device) + lhs = lhs.to(preferred_element_type) + + rhs = torch.randn((M, N), dtype=torch.float32, device=device) + rhs = rhs.to(preferred_element_type) + + group_sizes = ( + gen_uniform_group_sizes(M, G, device=device) + if unif_group_sizes + else gen_group_sizes(M, G, device=device, rng_seed=None) + ) + + return lhs, rhs, group_sizes + + +def gen_tgmm_output( + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, +) -> Tensor: + assert K > 0, f"Number of out rows K must be positive (K = {K})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + out = torch.empty((G, K, N), dtype=preferred_element_type, device=device) + + return out + + +def gen_tgmm_bias_grad( + K: int, + G: int, + device: torch.device | str = DEVICE, + with_bias_grad: bool = False, +) -> Tensor: + if with_bias_grad: + assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + return torch.empty((G, K), device=device, dtype=torch.float32) + else: + # Return dummy pointer when bias_grad is not needed. + # Must be float32 because atomic_add does not support bf16/fp16, + # and Triton validates the pointer dtype even in dead branches. + return torch.tensor([], device=device, dtype=torch.float32) + + +def gen_tgmm_tensors( + M: int, + K: int, + N: int, + G: int, + num_group_sizes: int, + device: torch.device | str = DEVICE, + input_type: torch.dtype = DTYPE, + output_type: torch.dtype = DTYPE, + trans_lhs: bool = TRANS_LHS, + trans_rhs: bool = False, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, + use_bias: bool = False, +) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]: + lhs, rhs, group_sizes_0 = gen_tgmm_input( + M, + K, + N, + G, + device=device, + preferred_element_type=input_type, + trans_lhs=trans_lhs, + rng_seed=rng_seed, + unif_group_sizes=unif_group_sizes, + ) + multiple_group_sizes = gen_multiple_group_sizes( + num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0 + ) + out = gen_tgmm_output(K, N, G, device=device, preferred_element_type=output_type) + if use_bias: + bias_grad = gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=True) + else: + bias_grad = None + return lhs, rhs, multiple_group_sizes, out, bias_grad + + +# TGMM helpers: get information from tensors. +# ------------------------------------------------------------------------------ + + +def get_tgmm_shape( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor +) -> tuple[int, int, int, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})." + assert ( + group_sizes.dim() == 1 + ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})." + + K, lhs_m = lhs.shape + rhs_m, N = rhs.shape + G = group_sizes.shape[0] + + assert ( + lhs_m == rhs_m + ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})." + M = lhs_m + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + return M, K, N, G + + +def get_tgmm_output( + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, +) -> Tensor: + assert K > 0, f"Number of out rows K must be positive (K = {K})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if existing_out is not None: + assert ( + existing_out.device == device + ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})." + assert ( + existing_out.dtype == preferred_element_type + ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})." + assert existing_out.shape == ( + G, + K, + N, + ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(G, K, N)})." + return existing_out + + return gen_tgmm_output( + K, + N, + G, + device=device, + preferred_element_type=preferred_element_type, + ) + + +def get_tgmm_bias_grad( + K: int, + G: int, + device: torch.device | str = DEVICE, + existing_bias_grad: Tensor | None = None, +) -> Tensor: + """ + Get or validate bias gradient tensor for TGMM. + + If existing_bias_grad is provided, validates its shape, device, dtype, and stride, + and always zeros it before returning (since the kernel uses atomic_add). + If existing_bias_grad is None, returns a dummy tensor (for use when COMPUTE_BIAS_GRAD=False). + Parameters + ---------- + K : int + Number of rows in the bias gradient tensor. + G : int + Number of groups. + device : torch.device or str + Device for the tensor. + existing_bias_grad : torch.Tensor or None + Existing bias gradient tensor to validate and use. + Returns + ------- + torch.Tensor + Valid bias gradient tensor or dummy tensor. + """ + assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if existing_bias_grad is not None: + # Validate existing bias_grad tensor. + expected_shape = (G, K) + assert ( + tuple(existing_bias_grad.shape) == expected_shape + ), f"bias_grad must have shape {expected_shape}, got {tuple(existing_bias_grad.shape)}." + assert ( + existing_bias_grad.device == device + ), f"bias_grad must be on the same device (bias_grad = {existing_bias_grad.device}, device = {device})." + assert ( + existing_bias_grad.dtype == torch.float32 + ), f"bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), got {existing_bias_grad.dtype}." + assert existing_bias_grad.stride() == ( + K, + 1, + ), f"bias_grad must be row-major with stride (K, 1) = ({K}, 1), got {existing_bias_grad.stride()}." + + # Always zero the tensor since bias_grad represents gradients for the current + # computation and should start fresh. The kernel uses atomic_add which adds to + # existing values, so we must zero before the kernel runs. + existing_bias_grad.zero_() + + return existing_bias_grad + + else: + return gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=False) + + +def get_tgmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})." + assert out.dim() == 3, f"out must have 3 dimensions (it's {out.dim()})." + + lhs_k, lhs_m = lhs.shape + rhs_m, rhs_n = rhs.shape + G, out_k, out_n = out.shape + + assert ( + lhs_m == rhs_m + ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})." + M = lhs_m + assert ( + lhs_k == out_k + ), f"K dimension of lhs and out don't match (lhs = {lhs_k}, rhs = {out_k})." + K = lhs_k + assert ( + rhs_n == out_n + ), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})." + N = rhs_n + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + is_lhs_row_major = lhs.stride() == (M, 1) + is_lhs_col_major = lhs.stride() == (1, K) + assert ( + is_lhs_row_major != is_lhs_col_major + ), "lhs must be row-major or column-major." + is_rhs_row_major = rhs.stride() == (N, 1) + assert is_rhs_row_major, "rhs must be row-major." + is_out_row_major = out.stride() == (K * N, N, 1) + assert is_out_row_major, "out must be row-major." + + # Get lhs leading dimension according to transposition configuration. + ld_lhs = M if is_lhs_row_major else K + + return is_lhs_col_major, ld_lhs diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/logger.py b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..391ddf9b6543f5244e7f4932c8568d60748e15cd --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/logger.py @@ -0,0 +1,47 @@ +import os +import logging + + +# AITER Triton Logger which is singleton object around python logging. +# Note: Python logging is also a singleton object, but we want to read the +# env var AITER_LOG_LEVEL once at the beginning. Another alternative is to do +# this in __init__.py. In fact, that's how CK logger is setup. We can look at +# switching to that at some point +# +# AITER_LOG_LEVEL follows python logging levels +# DEBUG +# INFO +# WARNING +# ERROR +# CRITICAL +# +class AiterTritonLogger(object): + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(AiterTritonLogger, cls).__new__(cls) + log_level_str = os.getenv("AITER_TRITON_LOG_LEVEL", "WARNING").upper() + numeric_level = getattr(logging, log_level_str, logging.WARNING) + cls._instance._logger = logging.getLogger("AITER_TRITON") + cls._instance._logger.setLevel(numeric_level) + + return cls._instance + + def get_logger(self): + return self._logger + + def debug(self, msg): + self._logger.debug(msg) + + def info(self, msg): + self._logger.info(msg) + + def warning(self, msg): + self._logger.warning(msg) + + def error(self, msg): + self._logger.error(msg) + + def critical(self, msg): + self._logger.critical(msg) diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_megablocks_cuda_ae601bb.abi3.so b/build/torch211-cxx11-cu126-x86_64-linux/_megablocks_cuda_ae601bb.abi3.so deleted file mode 100644 index 7a59a9adada00fca27df22c3be2744a619dc2d9b..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu126-x86_64-linux/_megablocks_cuda_ae601bb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:04357ebe4748e32fc898f2b6b3c4310beda29692b0ac34b78bd1c031efdee1bb -size 13179696 diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_megablocks_cuda_f8f8b50.abi3.so b/build/torch211-cxx11-cu126-x86_64-linux/_megablocks_cuda_f8f8b50.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..9391d85aea8683ff62a9b2bde8a79d3d0b1ea4dc --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/_megablocks_cuda_f8f8b50.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:78a283cd033d5770287d652455033307d26b1896681abbeb5ed4d1cba4dbc1fe +size 13822768 diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_ops.py b/build/torch211-cxx11-cu126-x86_64-linux/_ops.py index 8dd1b7bcf680d2d32dd4ac912487118eafcee4ea..69afb8c26a3fa2691be277b0270d600d29a5865e 100644 --- a/build/torch211-cxx11-cu126-x86_64-linux/_ops.py +++ b/build/torch211-cxx11-cu126-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_cuda_ae601bb -ops = torch.ops._megablocks_cuda_ae601bb +from . import _megablocks_cuda_f8f8b50 +ops = torch.ops._megablocks_cuda_f8f8b50 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_cuda_ae601bb::{op_name}" + return f"_megablocks_cuda_f8f8b50::{op_name}" diff --git a/build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py b/build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py index 76037d8039cbfc2f0577275c78e4bc0be762592a..c7ef28ced79c830dae934177f059c1f4ddc24aad 100644 --- a/build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py +++ b/build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py @@ -2,16 +2,16 @@ # extensions. Otherwise libc10.so cannot be found. import torch -# # TODO(tgale): Wrap this in a try-block with better -# # error message and instructions for building the -# # c++ operations. -# import grouped_gemm_backend as backend +# On ROCm there is no CUTLASS grouped GEMM; dispatch to the vendored AITER +# Triton kernels instead. On CUDA we use the compiled CUTLASS `gmm` op. +_IS_ROCM = torch.version.hip is not None -# We import the backend operations from the megablocks package as -# grouped_gemm is vendored in megablocks in this repository. -# from ... import _ops as backend -# from megablocks._ops import ops as backend # type: ignore -from .._ops import ops as backend # type: ignore +if _IS_ROCM: + from .._grouped_gemm_triton import adapter as backend +else: + # We import the backend operations from the megablocks package as + # grouped_gemm is vendored in megablocks in this repository. + from .._ops import ops as backend # type: ignore def _allocate_output(a, b, batch_sizes, trans_a, trans_b): assert not (trans_a and trans_b) diff --git a/build/torch211-cxx11-cu126-x86_64-linux/metadata.json b/build/torch211-cxx11-cu126-x86_64-linux/metadata.json index bc7202ab8d715cad1aee4e42b2a479b869543603..0843e9de0dd35d448c75ae7fc5a2eff09da63271 100644 --- a/build/torch211-cxx11-cu126-x86_64-linux/metadata.json +++ b/build/torch211-cxx11-cu126-x86_64-linux/metadata.json @@ -1,6 +1,6 @@ { "name": "megablocks", - "id": "_megablocks_cuda_ae601bb", + "id": "_megablocks_cuda_f8f8b50", "version": 1, "license": "Apache-2.0", "python-depends": [], @@ -14,7 +14,8 @@ "8.6", "8.7", "8.9", - "9.0" + "9.0", + "9.0+PTX" ] } } diff --git a/build/torch211-cxx11-cu128-x86_64-linux/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py index 38075732c6d8fa0e1e6ef493145e1aca3851ae6b..0766d7b8da4f97baca212177b4bb989bc6374bf8 100644 --- a/build/torch211-cxx11-cu128-x86_64-linux/__init__.py +++ b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py @@ -3,7 +3,9 @@ import torch -from ._ops import ops +# Stable alias: bare `ops` is shadowed by `from . import layers` below. +from ._ops import ops as _compiled_ops +from . import ops from .grouped_gemm import backend as gg_backend from .grouped_gemm import ops as gg_ops @@ -136,7 +138,8 @@ def sort( Returns: The sorted values tensor """ - return ops.sort(x, end_bit, x_out, iota_out) + _compiled_ops.sort(x, end_bit, x_out, iota_out) + return x_out # Convenience functions for common use cases diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py new file mode 100644 index 0000000000000000000000000000000000000000..8c101d07cea416f9390b708e5a35fdc466e48aed --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py @@ -0,0 +1,574 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + + +# Imports. +# ------------------------------------------------------------------------------ + +# Python standard library +import functools + +# Triton +import triton +import triton.language as tl + +# AITER +from ..configs import CONFIGS as _CONFIGS +from ..utils._triton import arch_info +from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd + +# Kernel config. +# ------------------------------------------------------------------------------ + + +@functools.lru_cache() +def get_config( + gmm_type: str, M: int, K: int, N: int, G: int, accumulate: bool = False +) -> dict[str, int]: + assert gmm_type in { + "gmm", + "ptgmm", + "nptgmm", + }, f"'{gmm_type}' is an invalid GMM variant." + dev = arch_info.get_arch() + assert ( + dev in _CONFIGS + ), f"No GMM configuration tuned for arch '{dev}'. Supported: {sorted(_CONFIGS)}." + arch_configs = _CONFIGS[dev] + assert ( + "default" in arch_configs[gmm_type] + ), "Default configuration is absent." + key = "accumulate" if accumulate else "default" + return arch_configs[gmm_type][key] + + +# Common code shared by GMM and TGMM kernels. +# ------------------------------------------------------------------------------ + + +# XCD remapping followed by 1D PID to 2D grid mapping. +@triton.jit +def _remap_xcd_tile_grid( + tile_in_mm, + num_row_tiles, + num_col_tiles, + GROUP_SIZE: tl.constexpr = 1, + NUM_XCDS: tl.constexpr = 8, +): + return pid_grid( + remap_xcd(tile_in_mm, num_row_tiles * num_col_tiles, NUM_XCDS=NUM_XCDS), + num_row_tiles, + num_col_tiles, + GROUP_SIZE_M=GROUP_SIZE, + ) + + +# GMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.heuristics( + { + "K_DIVISIBLE_BY_BLOCK_SIZE_K": lambda META: META["K"] % META["BLOCK_SIZE_K"] + == 0, + } +) +@triton.jit +def gmm_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_RHS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + K_DIVISIBLE_BY_BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE: tl.constexpr, + GRID_DIM: tl.constexpr, + USE_BIAS: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + # Current tile. Each program computes multiple tiles of each group. + tile = tl.program_id(0) + tl.device_assert(tile >= 0, "tile < 0 (at initialization)") + + # Tile limit of last MM problem (inclusive). + last_mm_tile = 0 + + # Last input row of lhs and output row of out. Each group reads some rows of + # lhs and writes some rows to out. + last_m = 0 + + # Loop through all (m, K, N) MM problems: + # (m, K) x (K, N) = (m, N) + # sum(m) = M + for g in range(G): + # Get m dimension of current MM problem. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty + tl.device_assert(m >= 0, "m < 0") + + num_m_tiles = tl.cdiv(m, BLOCK_SIZE_M) + # num_m_tiles can be zero if group is empty + tl.device_assert(num_m_tiles >= 0, "num_m_tiles < 0") + + num_tiles = num_m_tiles * num_n_tiles + # num_tiles can be zero if group is empty + tl.device_assert(num_tiles >= 0, "num_tiles < 0") + + # Loop through tiles of current MM problem. + while tile >= last_mm_tile and tile < last_mm_tile + num_tiles: + # Figure out tile coordinates in current MM problem. + tile_in_mm = tile - last_mm_tile + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_m, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_m_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + # Do regular MM: + + tl.device_assert(tile_m * BLOCK_SIZE_M >= 0, "tile_m * BLOCK_SIZE_M < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_m = ( + tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + ) % m + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + + lhs_ptrs = lhs_ptr + (last_m + offs_lhs_m[:, None]) * K + offs_k[None, :] + + if TRANS_RHS: + rhs_ptrs = ( + rhs_ptr + + g.to(tl.int64) * K * N + + offs_k[:, None] + + offs_rhs_n[None, :] * K + ) + else: + rhs_ptrs = ( + rhs_ptr + + g.to(tl.int64) * K * N + + offs_k[:, None] * N + + offs_rhs_n[None, :] + ) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if K_DIVISIBLE_BY_BLOCK_SIZE_K: + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + else: + k_mask_limit = K - k * BLOCK_SIZE_K + lhs = tl.load( + lhs_ptrs, mask=offs_k[None, :] < k_mask_limit, other=0 + ) + rhs = tl.load( + rhs_ptrs, mask=offs_k[:, None] < k_mask_limit, other=0 + ) + + acc = tl.dot(lhs, rhs, acc=acc) + + lhs_ptrs += BLOCK_SIZE_K + + if TRANS_RHS: + rhs_ptrs += BLOCK_SIZE_K + else: + rhs_ptrs += BLOCK_SIZE_K * N + + # Add bias if enabled + if USE_BIAS: + offs_bias_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange( + 0, BLOCK_SIZE_N + ) + bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bias_n + bias = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0) + # Convert bias to float32 to match accumulator precision + bias = bias.to(tl.float32) + # Broadcast bias across M dimension and add in float32 + acc += bias[None, :] + + # Convert to output dtype after all computations + acc = acc.to(out_ptr.type.element_ty) + + offs_out_m = tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + (last_m + offs_out_m[:, None]) * N + offs_out_n[None, :] + ) + + tl.store( + out_ptrs, + acc, + mask=(offs_out_m[:, None] < m) & (offs_out_n[None, :] < N), + ) + + # Go to the next tile by advancing number of programs. + tile += GRID_DIM + tl.device_assert(tile > 0, "tile <= 0 (at update)") + + # Get ready to go to the next MM problem. + + last_mm_tile += num_tiles + # last_mm_tile can be zero if group 0 is skipped + tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)") + + last_m += m + # last_m can be zero if group 0 is skipped + tl.device_assert(last_m >= 0, "last_m < 0 (at update)") + tl.device_assert(last_m <= M, "last_m > M (at update)") + + +# Persistent TGMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.jit +def tgmm_persistent_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_grad_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_LHS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE: tl.constexpr, + GRID_DIM: tl.constexpr, + COMPUTE_BIAS_GRAD: tl.constexpr, + ACCUMULATE: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0") + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + num_tiles = num_k_tiles * num_n_tiles + tl.device_assert(num_tiles > 0, "num_tiles <= 0") + + # Current tile. Each program computes multiple tiles of each group. + tile = tl.program_id(0) + tl.device_assert(tile >= 0, "tile < 0 (at initialization)") + + # Tile limit of last MM problem (inclusive). + last_mm_tile = 0 + + # Last input column of lhs and input row of rhs. Each group reads some + # columns of lhs and some rows of rhs. + last_m = 0 + + # Loop through all (K, m, N) MM problems: + # (K, m) x (m, N) = (K, N) + # sum(m) = M + for g in range(G): + # Get m dimension of current MM problem. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty + tl.device_assert(m >= 0, "m < 0") + + # Loop through tiles of current MM problem. + while tile >= last_mm_tile and tile < last_mm_tile + num_tiles: + # Figure out tile coordinates in current MM problem. + tile_in_mm = tile - last_mm_tile + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_k, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + # Do regular MM: + + tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + + if TRANS_LHS: + lhs_ptrs = ( + lhs_ptr + offs_lhs_k[:, None] + (last_m + offs_m[None, :]) * K + ) + else: + lhs_ptrs = ( + lhs_ptr + offs_lhs_k[:, None] * M + (last_m + offs_m[None, :]) + ) + + rhs_ptrs = rhs_ptr + (last_m + offs_m[:, None]) * N + offs_rhs_n[None, :] + + loop_m = tl.cdiv(m, BLOCK_SIZE_M) + m_divisible_by_block_m = m % BLOCK_SIZE_M == 0 + if not m_divisible_by_block_m: + loop_m -= 1 + + acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32) + + # Initialize bias accumulator + bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32) + + for _ in range(0, loop_m): + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + + acc = tl.dot(lhs, rhs, acc=acc) + + # Accumulate for bias gradient: sum lhs across M dimension + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum( + lhs, axis=1 + ) # Sum across M dimension [K, M] -> [K] + + if TRANS_LHS: + lhs_ptrs += BLOCK_SIZE_M * K + else: + lhs_ptrs += BLOCK_SIZE_M + + rhs_ptrs += BLOCK_SIZE_M * N + + if not m_divisible_by_block_m: + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0) + rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0) + acc = tl.dot(lhs, rhs, acc=acc) + + # Accumulate last chunk for bias gradient + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) + + acc = acc.to(out_ptr.type.element_ty) + + offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + + g.to(tl.int64) * K * N + + offs_out_k[:, None] * N + + offs_out_n[None, :] + ) + + mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N) + if ACCUMULATE: + # Load existing values and add to them (like beta=1 in BLAS) + old_vals = tl.load(out_ptrs, mask=mask, other=0.0) + tl.store(out_ptrs, acc + old_vals, mask=mask) + else: + # Overwrite output (like beta=0 in BLAS) + tl.store(out_ptrs, acc, mask=mask) + + # Store bias gradient (only for first N tile, sum across all M) + if COMPUTE_BIAS_GRAD and tile_n == 0: + # Keep as float32 for atomic_add (bf16 not supported for atomics) + bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k + # Use atomic add since multiple K-tiles may write to same expert's bias + tl.atomic_add( + bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed" + ) + + # Go to the next tile by advancing number of programs. + tile += GRID_DIM + tl.device_assert(tile > 0, "tile <= 0 (at update)") + + # Get ready to go to the next MM problem. + + last_mm_tile += num_tiles + # last_mm_tile can be zero if group 0 is skipped + tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)") + + last_m += m + # last_m can be zero if group 0 is skipped + tl.device_assert(last_m >= 0, "last_m < 0 (at update)") + tl.device_assert(last_m <= M, "last_m > M (at update)") + + +# Regular non-persistent TGMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.heuristics({"BLOCK_SIZE_G": lambda META: triton.next_power_of_2(META["G"])}) +@triton.jit +def tgmm_non_persistent_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_grad_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_LHS: tl.constexpr, + BLOCK_SIZE_G: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE: tl.constexpr, + COMPUTE_BIAS_GRAD: tl.constexpr, + ACCUMULATE: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + # Get group ID from grid. + g = tl.program_id(0) + tl.device_assert(g >= 0, "g < 0") + tl.device_assert(g < G, "g >= G") + + # Get m dimension of current MM group. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty. + tl.device_assert(m >= 0, "m < 0") + + # Skip empty groups. + if m == 0: + return + + # Compute sum(group_sizes) until current group g. + # It's the starting column of lhs and starting row of rhs. + offs_g = tl.arange(0, BLOCK_SIZE_G) + group_sizes = tl.load(group_sizes_ptr + offs_g, mask=offs_g < g, other=0) + start_m = tl.sum(group_sizes) + + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0") + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + # Get MM tile from grid. + tile_in_mm = tl.program_id(1) + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_k, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_k = (tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K + offs_rhs_n = (tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + + if TRANS_LHS: + lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] + (start_m + offs_m[None, :]) * K + else: + lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] * M + (start_m + offs_m[None, :]) + + rhs_ptrs = rhs_ptr + (start_m + offs_m[:, None]) * N + offs_rhs_n[None, :] + + loop_m = tl.cdiv(m, BLOCK_SIZE_M) + m_divisible_by_block_m = m % BLOCK_SIZE_M == 0 + if not m_divisible_by_block_m: + loop_m -= 1 + + acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32) + # Initialize bias accumulator + bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32) + + for _ in range(0, loop_m): + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + + acc = tl.dot(lhs, rhs, acc=acc) + + # Accumulate for bias gradient: sum lhs across M dimension + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) # [K, M] -> [K] + + if TRANS_LHS: + lhs_ptrs += BLOCK_SIZE_M * K + else: + lhs_ptrs += BLOCK_SIZE_M + + rhs_ptrs += BLOCK_SIZE_M * N + + if not m_divisible_by_block_m: + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0) + rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0) + acc = tl.dot(lhs, rhs, acc=acc) + # Accumulate last chunk for bias gradient + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) + + acc = acc.to(out_ptr.type.element_ty) + + offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + g.to(tl.int64) * K * N + offs_out_k[:, None] * N + offs_out_n[None, :] + ) + + mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N) + if ACCUMULATE: + # Load existing values and add to them (like beta=1 in BLAS) + old_vals = tl.load(out_ptrs, mask=mask, other=0.0) + tl.store(out_ptrs, acc + old_vals, mask=mask) + else: + # Overwrite output (like beta=0 in BLAS) + tl.store(out_ptrs, acc, mask=mask) + + # Store bias gradient (only for first N tile, sum across all M) + if COMPUTE_BIAS_GRAD and tile_n == 0: + # Keep as float32 for atomic_add (bf16/fp16 not supported for atomics) + bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k + # Use atomic add since multiple K-tiles may write to same expert's bias + tl.atomic_add(bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed") diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/adapter.py b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..98c224244f27445384e0c2377d73516406927536 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/adapter.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Adapt AITER's Triton grouped GEMM to MegaBlocks' ``gmm`` calling convention. + +MegaBlocks (following tgale96/grouped_gemm) uses a single ``gmm`` entry point +with ``trans_a`` / ``trans_b`` flags: + +* ``trans_a=False, trans_b=False``: a(M,K) @ b(G,K,N) -> c(M,N) +* ``trans_a=False, trans_b=True`` : a(M,K) @ b(G,N,K)^T -> c(M,N) (dgrad) +* ``trans_a=True`` : a(M,K)^T @ b(M,N) per group -> c(G,K,N) (wgrad) + +AITER exposes these as two kernels: ``gmm`` ((M,K)@(G,K,N)->(M,N), transposition +of the 3D operand inferred from strides) and ``ptgmm`` ((K,M)@(M,N)->(G,K,N), +transposition of the 2D operand inferred from strides). +""" + +import torch + +from .gmm import gmm as _aiter_gmm +from .gmm import ptgmm as _aiter_ptgmm + + +def gmm(a, b, c, batch_sizes, trans_a=False, trans_b=False): + # AITER requires group sizes to be int32 and to live on the compute device. + group_sizes = batch_sizes.to(device=a.device, dtype=torch.int32) + + # AITER asserts exact strides: gmm wants lhs/rhs row-major (a transposed + # 3D operand must be exactly column-major), tgmm wants rhs row-major and + # lhs row/column-major. Make operands contiguous first so the transposed + # views have the precise strides the kernels expect. `.contiguous()` is a + # no-op when the tensor is already contiguous. + if trans_a: + # Weight gradient: a(M,K), b(M,N) -> c(G,K,N). + # Pass a transposed so AITER sees lhs(K,M) column-major (TRANS_LHS). + _aiter_ptgmm( + a.contiguous().transpose(0, 1), + b.contiguous(), + group_sizes, + preferred_element_type=c.dtype, + existing_out=c, + ) + else: + # trans_b contracts b's last dim: pass a column-major (G,K,N) view. + rhs = b.contiguous() + if trans_b: + rhs = rhs.transpose(1, 2) + _aiter_gmm( + a.contiguous(), + rhs, + group_sizes, + preferred_element_type=c.dtype, + existing_out=c, + ) + return c diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/configs.py b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..9a4fe5617d8100869aa76dba9b7d22c7bcab814f --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/configs.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: MIT +# Tuned GMM configs vendored from ROCm/aiter (aiter/ops/triton/configs/). +# Inlined as a Python module so packaging always includes them. + +CONFIGS = {'gfx1250': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}, 'gfx942': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}, 'gfx950': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}} diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/gmm.py b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/gmm.py new file mode 100644 index 0000000000000000000000000000000000000000..e30c9326c6d4e4836d1303e2761ea2440a7f4750 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/gmm.py @@ -0,0 +1,567 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. + + +# Imports. +# ------------------------------------------------------------------------------ + +# PyTorch +import torch +from torch import Tensor + +# Triton +import triton + +# AITER: GMM utility functions +from .utils.gmm_common import ( + DTYPE, + is_power_of_2, + check_input_device_dtype, + check_bias_shape_stride, + get_gmm_shape, + get_gmm_output, + get_gmm_transposition, + get_tgmm_shape, + get_tgmm_output, + get_tgmm_bias_grad, + get_tgmm_transposition, +) + +# AITER: GMM Triton kernels +from ._triton_kernels.gmm import ( + gmm_kernel, + tgmm_persistent_kernel, + tgmm_non_persistent_kernel, + get_config, +) + +# GMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _gmm_grid( + N: int, + block_size_m: int, + block_size_n: int, + group_sizes: Tensor, + grid_dim: int, +) -> tuple[int]: + assert N > 0, f"N must be positive, it's {N}." + assert is_power_of_2( + block_size_m + ), f"M-dimension tile size must be a power of 2 (it's {block_size_m})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + assert torch.all(group_sizes >= 0).item(), "All group_sizes must be non-negative." + assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})." + num_m_tiles = (group_sizes + block_size_m - 1) // block_size_m + assert torch.all(num_m_tiles >= 0).item(), "All num_m_tiles must be non-negative." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles = torch.sum(num_m_tiles * num_n_tiles).item() + assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}." + num_programs = int(min(grid_dim, num_tiles)) + assert num_programs > 0, f"num_programs must be positive, it's {num_programs}." + return (num_programs,) + + +def gmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias: Tensor | None = None, +) -> Tensor: + """ + Perform Group Matrix Multiplication (GMM): out = lhs @ rhs + bias + + lhs rows are divided into G groups. Each group of lhs rows is matrix multiplied with a plane of + rhs 3D tensor and then stored in a slice of out. In PyTorch parlance, it can be implemented as + follows for a given group g: + out[group_start:group_end, :] = lhs[group_start:group_end, :] @ rhs[g] + bias[g] + + The size of each group, and their respective start and end positions are specified by + group_sizes tensor. For instance, suppose that group_sizes = [3, 2, 4, 1]. In this particular + case we have 4 groups. The 1st group starts at 0 and ends at 2, the second group starts at 3 and + ends at 4, the third group starts at 5 and ends at 8, and the fourth and final group consists of + just the 10th (last) row of lhs. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (M, K). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 3D input tensor. Shape: (G, K, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (M, N), its data type must match preferred_element_type + and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias : torch.Tensor or None, optional + Optional bias tensor. Shape: (G, N). + If provided, bias data type must match lhs and rhs data type, and bias must be on the same + device as other input tensors. Each group g adds bias[g] to the output. + + Returns + ------- + torch.Tensor + The computed output 2D tensor. Shape: (M, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - GMM is implemented with a persistent Triton kernel. + - lhs must be row-major (lhs.stride() == (K, 1)). + - rhs can be row-major (rhs.stride() == (K * N, N, 1)) or column-major (rhs.stride() == + (K * N, 1, K)). If rhs is row-major then kernel parameter TRANS_RHS == False, this is useful + for implementing forward pass. If rhs is column-major then kernel parameter TRANS_RHS == True, + this is useful for computing the lhs derivative in the backward pass, while fusing the + transposition. + - out must be row-major (out.stride() == (N, 1)). + - bias must be row-major (bias.stride() == (N, 1)) if provided. + """ + use_bias = bias is not None + check_input_device_dtype(lhs, rhs, group_sizes, bias) + + M, K, N, G = get_gmm_shape(lhs, rhs, group_sizes) + + if use_bias: + check_bias_shape_stride(bias, G, N) + + out = get_gmm_output( + M, + N, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_rhs, _ = get_gmm_transposition(lhs, rhs, out) + + if config is None: + config = get_config("gmm", M, K, N, G) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + "GRID_DIM", + } + ), "Invalid GMM kernel config." + + grid = _gmm_grid( + N, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + group_sizes, + config["GRID_DIM"], + ) + + # fmt: off + gmm_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_RHS=trans_rhs, + USE_BIAS=use_bias, + **config, + ) + # fmt: on + + return out + + +# Persistent TGMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _ptgmm_grid( + K: int, + N: int, + G: int, + block_size_k: int, + block_size_n: int, + grid_dim: int, +) -> tuple[int]: + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}." + assert G > 0, f"G must be positive, it's {G}." + assert is_power_of_2( + block_size_k + ), f"K-dimension tile size must be a power of 2 (it's {block_size_k})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})." + num_k_tiles = triton.cdiv(K, block_size_k) + assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles = G * num_k_tiles * num_n_tiles + assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}." + num_programs = min(grid_dim, num_tiles) + assert num_programs > 0, f"num_programs must be positive, it's {num_programs}." + return (num_programs,) + + +def ptgmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias_grad: Tensor | None = None, + accumulate: bool = False, +) -> Tensor: + """ + Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs + + lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with + the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch + parlance, it can be implemented as follows for a given group g: + out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :] + + The 't' in the operator name derives from MaxText implementation + (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py), + which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor + shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N). + + The 'p' in the operator name means that it is implemented with a persistent kernel. There is + also the non-persistent variation, which is implemented with a regular kernel. Please take a + look at nptgmm operator. Both ptgmm and nptgmm implement the same computation, choosing one or + the other is a matter of performance for the target workload. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (K, M). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 2D input tensor. Shape: (M, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (G, K, N), its data type must match + preferred_element_type and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias_grad : torch.Tensor or None, optional + Optional bias gradient output tensor. Shape: (G, K). + If provided, the kernel will compute the bias gradient and write it to this tensor. + bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), + accumulate : bool, optional + Whether to accumulate into existing output tensor values. Default is False. + If False, output will be overwritten with fresh computation. + If True, results will be added to existing output tensor values. + + Returns + ------- + torch.Tensor + The computed output 3D tensor. Shape: (G, K, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - PTGMM is implemented with a persistent Triton kernel. + - lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs + is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel + parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward + pass, while fusing the transposition. + - rhs must be row-major (rhs.stride() == (N, 1)). + - out must be row-major (out.stride() == (K * N, N, 1)). + """ + check_input_device_dtype(lhs, rhs, group_sizes) + + M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes) + + out = get_tgmm_output( + K, + N, + G, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out) + + if config is None: + config = get_config("ptgmm", M, K, N, G, accumulate) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + "GRID_DIM", + } + ), "Invalid PTGMM kernel config." + + # Bias gradient handling. + # ----------------------- + # Get or validate bias gradient tensor. + compute_bias_grad = bias_grad is not None + bias_grad_ptr = get_tgmm_bias_grad( + K, + G, + device=lhs.device, + existing_bias_grad=bias_grad, + ) + + grid = _ptgmm_grid( + K, + N, + G, + config["BLOCK_SIZE_K"], + config["BLOCK_SIZE_N"], + config["GRID_DIM"], + ) + + # fmt: off + tgmm_persistent_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias_grad_ptr, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_LHS=trans_lhs, + COMPUTE_BIAS_GRAD=compute_bias_grad, + ACCUMULATE=accumulate, + **config, + ) + # fmt: on + + return out + + +# Regular non-persistent TGMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _nptgmm_grid( + K: int, + N: int, + G: int, + block_size_k: int, + block_size_n: int, +) -> tuple[int, int]: + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}." + assert G > 0, f"G must be positive, it's {G}." + assert is_power_of_2( + block_size_k + ), f"K-dimension tile size must be a power of 2 (it's {block_size_k})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + num_k_tiles = triton.cdiv(K, block_size_k) + assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles_per_mm = num_k_tiles * num_n_tiles + assert ( + num_tiles_per_mm > 0 + ), f"num_tiles_per_mm must be positive, it's {num_tiles_per_mm}." + return (G, num_tiles_per_mm) + + +def nptgmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias_grad: Tensor | None = None, + accumulate: bool = False, +) -> Tensor: + """ + Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs + + lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with + the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch + parlance, it can be implemented as follows for a given group g: + out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :] + + The 't' in the operator name derives from MaxText implementation + (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py), + which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor + shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N). + + The 'np' in the operator name means that it is implemented with a non-persistent, i.e. regular + kernel. There is also the persistent variation, which is implemented with a persistent kernel. + Please take a look at ptgmm operator. Both nptgmm and ptgmm implement the same computation, + choosing one or the other is a matter of performance for the target workload. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (K, M). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 2D input tensor. Shape: (M, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (G, K, N), its data type must match + preferred_element_type and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias_grad : torch.Tensor or None, optional + Optional bias gradient output tensor. Shape: (G, K). + If provided, the kernel will compute the bias gradient and write it to this tensor. + bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), + accumulate : bool, optional + Whether to accumulate into existing output tensor values. Default is False. + If False, output will be overwritten with fresh computation. + If True, results will be added to existing output tensor values. + + Returns + ------- + torch.Tensor + The computed output 3D tensor. Shape: (G, K, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - NPTGMM is implemented with a non-persistent regular Triton kernel. + - lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs + is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel + parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward + pass, while fusing the transposition. + - rhs must be row-major (rhs.stride() == (N, 1)). + - out must be row-major (out.stride() == (K * N, N, 1)). + """ + check_input_device_dtype(lhs, rhs, group_sizes) + + M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes) + + out = get_tgmm_output( + K, + N, + G, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out) + + # Bias gradient handling. + # ----------------------- + # Get or validate bias gradient tensor. + compute_bias_grad = bias_grad is not None + bias_grad_ptr = get_tgmm_bias_grad( + K, + G, + device=lhs.device, + existing_bias_grad=bias_grad, + ) + + if config is None: + config = get_config("nptgmm", M, K, N, G, accumulate) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + } + ), "Invalid NPTGMM kernel config." + + grid = _nptgmm_grid( + K, + N, + G, + config["BLOCK_SIZE_K"], + config["BLOCK_SIZE_N"], + ) + + # fmt: off + tgmm_non_persistent_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias_grad_ptr, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_LHS=trans_lhs, + COMPUTE_BIAS_GRAD=compute_bias_grad, + ACCUMULATE=accumulate, + **config, + ) + # fmt: on + + return out diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py new file mode 100644 index 0000000000000000000000000000000000000000..3f6c88581a64044518125623f116082c53bd5474 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py @@ -0,0 +1,46 @@ +import triton + +# Detect the GPU arch lazily: querying the triton driver at import time fails +# in headless environments (e.g. the kernel-builder ABI check sandbox has no +# GPU), and the original JAX fallback pulled in an unrelated runtime dep. The +# arch is only actually needed when a GMM kernel is dispatched, so resolve and +# cache on first call. +_CACHED_ARCH = None + + +def get_arch(): + global _CACHED_ARCH + if _CACHED_ARCH is not None: + return _CACHED_ARCH + try: + _CACHED_ARCH = triton.runtime.driver.active.get_current_target().arch + except RuntimeError: + try: + from jax._src.lib import gpu_triton as triton_kernel_call_lib + _CACHED_ARCH = triton_kernel_call_lib.get_arch_details("0").split(":")[0] + except ImportError as e: + raise RuntimeError( + "Cannot determine GPU arch: triton driver is inactive and " + "JAX is not available. A GPU is required for grouped GEMM." + ) from e + return _CACHED_ARCH + + +def is_gluon_avail(): + return get_arch() in ("gfx950", "gfx1250") + + +def is_fp4_avail(): + return get_arch() in ("gfx950", "gfx1250") + + +def is_fp8_avail(): + return get_arch() in ("gfx942", "gfx950", "gfx1250", "gfx1200", "gfx1201") + + +def is_mx_scale_preshuffling_avail(): + return get_arch() in ("gfx950", "gfx1250") + + +def is_tdm_avail(): + return get_arch() in ("gfx1250",) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..99792bb3ba2fab8fff223bba733ced1eb6e6df53 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: MIT + +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl + + +@triton.jit +def remap_xcd_chunked( + pid, GRID_MN, NUM_XCDS: tl.constexpr = 8, CHUNK_SIZE: tl.constexpr = 2 +): + # Compute current XCD and local PID + xcd = pid % NUM_XCDS + # distribute the modulo pids in round robin + if pid > (GRID_MN // (NUM_XCDS * CHUNK_SIZE)) * (NUM_XCDS * CHUNK_SIZE): + return pid + local_pid = pid // NUM_XCDS + # Calculate chunk index and position within chunk + chunk_idx = local_pid // CHUNK_SIZE + pos_in_chunk = local_pid % CHUNK_SIZE + # Calculate new PID + new_pid = chunk_idx * NUM_XCDS * CHUNK_SIZE + xcd * CHUNK_SIZE + pos_in_chunk + return new_pid + + +@triton.jit +def remap_xcd(pid, GRID_MN, NUM_XCDS: tl.constexpr = 8): + ## pid remapping on xcds + # Number of pids per XCD in the new arrangement + pids_per_xcd = (GRID_MN + NUM_XCDS - 1) // NUM_XCDS + # When GRID_MN cannot divide NUM_XCDS, some xcds will have + # pids_per_xcd pids, the other will have pids_per_xcd - 1 pids. + # We calculate the number of xcds that have pids_per_xcd pids as + # tall_xcds + tall_xcds = GRID_MN % NUM_XCDS + tall_xcds = NUM_XCDS if tall_xcds == 0 else tall_xcds + # Compute current XCD and local pid within the XCD + xcd = pid % NUM_XCDS + local_pid = pid // NUM_XCDS + # Calculate new pid based on the new grouping + # Note that we need to consider the following two cases: + # 1. the current pid is on a tall xcd + # 2. the current pid is on a short xcd + if xcd < tall_xcds: + pid = xcd * pids_per_xcd + local_pid + else: + pid = ( + tall_xcds * pids_per_xcd + + (xcd - tall_xcds) * (pids_per_xcd - 1) + + local_pid + ) + + return pid + + +@triton.jit +def pid_grid(pid: int, num_pid_m: int, num_pid_n: int, GROUP_SIZE_M: tl.constexpr = 1): + """ + Maps 1D pid to 2D grid coords (pid_m, pid_n). + + Args: + - pid: 1D pid + - num_pid_m: grid m size + - num_pid_n: grid n size + - GROUP_SIZE_M: tl.constexpr: default is 1 + """ + if GROUP_SIZE_M == 1: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + tl.assume(group_size_m >= 0) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + return pid_m, pid_n + + +@triton.jit +def pid_grid_3d(pid: int, num_pid_m: int, num_pid_n: int, num_pid_k): + """ + Maps 1D pid to 3D grid coords (pid_m, pid_n, pid_k). + Args: + - pid: 1D pid + - num_pid_m: grid m size + - num_pid_n: grid n size + - num_pid_k: grid k size + + Returns: + - pid_m, pid_n, pid_k: 3D grid coordinates + """ + pid_m = pid % num_pid_m + pid_n = (pid // num_pid_m) % num_pid_n + pid_k = pid // (num_pid_m * num_pid_n) % num_pid_k + + return pid_m, pid_n, pid_k diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py new file mode 100644 index 0000000000000000000000000000000000000000..153dee65b50ab5f833262481889d2184d1ca639f --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py @@ -0,0 +1,752 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. + +# Imports. +# ------------------------------------------------------------------------------ + +# PyTorch +import torch +from torch import Tensor + +# AITER: logging +from .logger import AiterTritonLogger + +_LOGGER: AiterTritonLogger = AiterTritonLogger() + + +# Supported data types. +# ------------------------------------------------------------------------------ + +# Supported data types, as strings. +SUPPORTED_DTYPES_STR: set[str] = {"fp16", "bf16"} + + +# Convert string data type to PyTorch data type. +def dtype_from_str(dtype_str: str) -> torch.dtype: + dtype_str = dtype_str.strip().lower() + dtype_str = dtype_str[1:] if dtype_str[0] in {"i", "o"} else dtype_str + assert ( + dtype_str in SUPPORTED_DTYPES_STR + ), "String data type isn't in set of supported string data types." + return {"fp16": torch.float16, "bf16": torch.bfloat16}[dtype_str] + + +# Supported data types, as PyTorch types. +SUPPORTED_DTYPES: set[torch.dtype] = { + dtype_from_str(dtype_str) for dtype_str in SUPPORTED_DTYPES_STR +} + + +# Convert PyTorch data type to string data type. +def str_from_dtype(dtype: torch.dtype) -> str: + assert ( + dtype in SUPPORTED_DTYPES + ), "PyTorch data type isn't in set of supported PyTorch data types." + return {torch.float16: "fp16", torch.bfloat16: "bf16"}[dtype] + + +# Default data type, as string. +DTYPE_STR: str = "bf16" +assert ( + DTYPE_STR in SUPPORTED_DTYPES_STR +), "Default string data type isn't in set of supported string data types." + + +# Default data type, as PyTorch type. +DTYPE: torch.dtype = dtype_from_str(DTYPE_STR) + + +# Other defaults. +# ------------------------------------------------------------------------------ + +# Default device. +DEVICE: torch.device | str = "cuda" + +# Default RNG seed for input generation. +RNG_SEED: int = 0 + +# Default number of group sizes. +NUM_GROUP_SIZES: int = 1 + +# Default transposition (NN). +TRANS_LHS: bool = False +TRANS_RHS: bool = False + + +# Parameter checking functions. +# ------------------------------------------------------------------------------ + + +def is_power_of_2(x: int) -> bool: + return (x > 0) and (x & (x - 1) == 0) + + +def check_input_device_dtype( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor, bias: Tensor | None = None +) -> None: + assert ( + lhs.device == rhs.device == group_sizes.device + ), f"All input tensors must be in the same device (lhs = {lhs.device}, rhs = {rhs.device}, group_sizes = {group_sizes.device})." + assert ( + lhs.dtype == rhs.dtype + ), f"lhs and rhs types must match (lhs = {lhs.dtype}, rhs = {rhs.dtype})." + assert group_sizes.dtype == torch.int32, "group_sizes type must be int32." + + if bias is not None: + assert ( + bias.device == lhs.device + ), f"bias must be on the same device as lhs (bias = {bias.device}, lhs = {lhs.device})." + assert ( + bias.dtype == lhs.dtype + ), f"bias dtype must match lhs dtype (bias = {bias.dtype}, lhs = {lhs.dtype})." + + +def check_bias_shape_stride(bias: Tensor, G: int, N: int) -> None: + assert bias.shape == ( + G, + N, + ), f"bias must have shape (G, N) = ({G}, {N}), got {bias.shape}." + assert bias.stride() == (N, 1), "bias must be row-major (bias.stride() == (N, 1))." + + +# Generation of group sizes. +# ------------------------------------------------------------------------------ + + +# Probabilities for generating random group sizes. +UNUSED_TOKENS_PROB: float = 0.0 +UNUSED_EXPERTS_PROB: float = 0.1 + + +def gen_uniform_group_sizes( + M: int, + G: int, + device: torch.device | str = DEVICE, +) -> Tensor: + assert M >= 0, f"Number of tokens M must be non-negative (it's {M})." + assert G > 0, f"Number of experts G must be positive (it's {G})." + + base = M // G + remainder = M % G + group_sizes = torch.full((G,), base, dtype=torch.int32, device=device) + if remainder > 0: + group_sizes[:remainder] += 1 + + assert ( + len(group_sizes) == G + ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})." + assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative." + assert ( + torch.sum(group_sizes).item() == M + ), f"Group sizes don't add up to total tokens {M}." + assert group_sizes.dtype == torch.int32, "Group sizes must be int32." + + return group_sizes + + +def gen_group_sizes( + M: int, + G: int, + device: torch.device | str = DEVICE, + rng_seed: int | None = RNG_SEED, + unused_tokens_prob: float = UNUSED_TOKENS_PROB, + unused_experts_prob: float = UNUSED_EXPERTS_PROB, +) -> Tensor: + assert M >= 0, f"Number of tokens M must be non-negative (it's {M})." + assert G > 0, f"Number of experts G must be positive (it's {G})." + assert ( + 0 <= unused_tokens_prob <= 1 + ), f"Probability of unused tokens must be in [0, 1] interval (it's {unused_tokens_prob})." + assert ( + 0 <= unused_experts_prob <= 1 + ), f"Probability of unused experts must be in [0, 1] interval (it's {unused_experts_prob})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + if unused_tokens_prob > 0: + # Optionally drop tokens to simulate routing sparsity, some tokens may not be routed. + num_unused_tokens = M + while num_unused_tokens == M: + num_unused_tokens = int( + torch.binomial( + torch.tensor(float(M), device=device), + torch.tensor(unused_tokens_prob, device=device), + ).item() + ) + else: + num_unused_tokens = 0 + num_used_tokens = M - num_unused_tokens + assert ( + num_unused_tokens >= 0 + ), f"Number of unused tokens must be non-negative (it's {num_unused_tokens})." + assert ( + num_used_tokens > 0 + ), f"Number of used tokens must be positive (it's {num_used_tokens})." + assert ( + num_used_tokens + num_unused_tokens == M + ), f"Unused + used tokens don't add up total tokens ({num_used_tokens} + {num_unused_tokens} != {M})." + + if num_unused_tokens > 0: + _LOGGER.debug( + f"Group sizes generation: dropped {num_unused_tokens} token{'s' if num_unused_tokens > 1 else ''}.", + ) + + if unused_experts_prob > 0: + # Some experts may have zero tokens assigned to them. + num_used_experts = 0 + while num_used_experts == 0: + used_experts = torch.nonzero( + torch.rand((G,), device=device) >= unused_experts_prob + ).squeeze() + num_used_experts = used_experts.numel() + else: + used_experts = torch.arange(0, G, device=device) + num_used_experts = G + num_unused_experts = G - num_used_experts + assert ( + num_unused_experts >= 0 + ), f"Number of unused experts must be non-negative (it's {num_unused_experts})." + assert ( + num_used_experts >= 1 + ), f"At least one expert must be used (it's {num_used_experts})." + assert ( + num_unused_experts + num_used_experts == G + ), f"Unused + used experts don't add up total experts ({num_unused_experts} + {num_used_experts} != {G})." + + if num_unused_experts > 0: + _LOGGER.debug( + f"Group sizes generation: dropped {num_unused_experts} expert{'s' if num_unused_experts > 1 else ''}.", + ) + + group_sizes = torch.bincount( + used_experts[ + torch.randint(low=0, high=num_used_experts, size=(num_used_tokens,)) + ], + minlength=G, + ).to(torch.int32) + + assert ( + len(group_sizes) == G + ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})." + assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative." + assert ( + torch.sum(group_sizes).item() == num_used_tokens + ), f"Group sizes don't add up to used tokens {num_used_tokens}." + assert group_sizes.dtype == torch.int32, "Group sizes must be int32." + + return group_sizes + + +def gen_multiple_group_sizes( + num_group_sizes: int, + M: int, + G: int, + device: torch.device | str = DEVICE, + rng_seed: int | None = RNG_SEED, + unused_tokens_prob: float = UNUSED_TOKENS_PROB, + unused_experts_prob: float = UNUSED_EXPERTS_PROB, + group_sizes_0: Tensor | None = None, +) -> list[Tensor]: + assert ( + num_group_sizes > 0 + ), f"Number of group sizes to be generated must be positive, it's {num_group_sizes}." + multiple_group_sizes = [ + gen_group_sizes( + M, + G, + device=device, + rng_seed=rng_seed if g == 0 else None, + unused_tokens_prob=unused_tokens_prob, + unused_experts_prob=unused_experts_prob, + ) + for g in range( + num_group_sizes if group_sizes_0 is None else num_group_sizes - 1 + ) + ] + if group_sizes_0 is not None: + multiple_group_sizes.insert(0, group_sizes_0) + assert ( + len(multiple_group_sizes) == num_group_sizes + ), f"Expecting {num_group_sizes} distinct group sizes (it's {len(multiple_group_sizes)})." + return multiple_group_sizes + + +# GMM helpers: tensor generation. +# ------------------------------------------------------------------------------ + + +def gen_gmm_input( + M: int, + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + trans_rhs: bool = TRANS_RHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, +) -> tuple[Tensor, Tensor, Tensor]: + assert M > 0, f"Number of lhs rows M must be positive (M = {M})." + assert K > 0, f"Number of lhs columns / rhs rows K must be positive (K = {K})." + assert N > 0, f"Number of rhs columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + lhs = torch.randn((M, K), dtype=torch.float32, device=device) + lhs = lhs.to(preferred_element_type) + + if trans_rhs: + rhs = torch.randn((G, N, K), dtype=torch.float32, device=device).permute( + 0, 2, 1 + ) + else: + rhs = torch.randn((G, K, N), dtype=torch.float32, device=device) + rhs = rhs.to(preferred_element_type) + + group_sizes = ( + gen_uniform_group_sizes(M, G, device=device) + if unif_group_sizes + else gen_group_sizes(M, G, device=device, rng_seed=None) + ) + + return lhs, rhs, group_sizes + + +def gen_gmm_output( + M: int, + N: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, +) -> Tensor: + assert M > 0, f"Number of out rows M must be positive (M = {M})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + + out = torch.empty((M, N), dtype=preferred_element_type, device=device) + + return out + + +def gen_gmm_tensors( + M: int, + K: int, + N: int, + G: int, + num_group_sizes: int, + device: torch.device | str = DEVICE, + input_type: torch.dtype = DTYPE, + output_type: torch.dtype = DTYPE, + trans_lhs: bool = False, + trans_rhs: bool = TRANS_RHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, + use_bias: bool = False, +) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]: + lhs, rhs, group_sizes_0 = gen_gmm_input( + M, + K, + N, + G, + device=device, + preferred_element_type=input_type, + trans_rhs=trans_rhs, + rng_seed=rng_seed, + unif_group_sizes=unif_group_sizes, + ) + multiple_group_sizes = gen_multiple_group_sizes( + num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0 + ) + out = gen_gmm_output(M, N, device=device, preferred_element_type=output_type) + bias = None + if use_bias: + torch.manual_seed(rng_seed + 1000) # Different seed for bias + bias = torch.randn(G, N, dtype=input_type, device=device) + + return lhs, rhs, multiple_group_sizes, out, bias + + +# GMM helpers: get information from tensors. +# ------------------------------------------------------------------------------ + + +def get_gmm_shape( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor +) -> tuple[int, int, int, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})." + assert ( + group_sizes.dim() == 1 + ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})." + + M, lhs_k = lhs.shape + rhs_g, rhs_k, N = rhs.shape + group_sizes_g = group_sizes.shape[0] + + assert ( + lhs_k == rhs_k + ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})." + K = lhs_k + assert ( + rhs_g == group_sizes_g + ), f"G dimension of rhs and group_sizes don't match (rhs = {rhs_g}, group_sizes = {group_sizes_g})." + G = rhs_g + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + return M, K, N, G + + +def get_gmm_output( + M: int, + N: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, +) -> Tensor: + assert M > 0, f"Number of out rows M must be positive (M = {M})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + + if existing_out is not None: + assert ( + existing_out.device == device + ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})." + assert ( + existing_out.dtype == preferred_element_type + ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})." + assert existing_out.shape == ( + M, + N, + ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(M, N)})." + return existing_out + + return gen_gmm_output( + M, + N, + device=device, + preferred_element_type=preferred_element_type, + ) + + +def get_gmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})." + assert out.dim() == 2, f"out must have 2 dimensions (it's {out.dim()})." + + lhs_m, lhs_k = lhs.shape + G, rhs_k, rhs_n = rhs.shape + out_m, out_n = out.shape + + assert ( + lhs_m == out_m + ), f"M dimension of lhs and out don't match (lhs = {lhs_m}, rhs = {out_m})." + M = lhs_m + assert ( + lhs_k == rhs_k + ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})." + K = lhs_k + assert ( + rhs_n == out_n + ), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})." + N = rhs_n + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + is_lhs_row_major = lhs.stride() == (K, 1) + assert is_lhs_row_major, "lhs must be row-major." + is_rhs_row_major = rhs.stride() == (K * N, N, 1) + is_rhs_col_major = rhs.stride() == (K * N, 1, K) + assert ( + is_rhs_row_major != is_rhs_col_major + ), "rhs must be row-major or column-major." + is_out_row_major = out.stride() == (N, 1) + assert is_out_row_major, "out must be row-major." + + # Get rhs leading dimension according to transposition configuration. + ld_rhs = N if is_rhs_row_major else K + + return is_rhs_col_major, ld_rhs + + +# TGMM helpers: tensor generation. +# ------------------------------------------------------------------------------ + + +def gen_tgmm_input( + M: int, + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + trans_lhs: bool = TRANS_LHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, +) -> tuple[Tensor, Tensor, Tensor]: + assert K > 0, f"Number of lhs rows K must be positive (M = {K})." + assert M > 0, f"Number of lhs columns / rhs rows M must be positive (K = {M})." + assert N > 0, f"Number of rhs columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + if trans_lhs: + lhs = torch.randn((M, K), dtype=torch.float32, device=device).T + else: + lhs = torch.randn((K, M), dtype=torch.float32, device=device) + lhs = lhs.to(preferred_element_type) + + rhs = torch.randn((M, N), dtype=torch.float32, device=device) + rhs = rhs.to(preferred_element_type) + + group_sizes = ( + gen_uniform_group_sizes(M, G, device=device) + if unif_group_sizes + else gen_group_sizes(M, G, device=device, rng_seed=None) + ) + + return lhs, rhs, group_sizes + + +def gen_tgmm_output( + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, +) -> Tensor: + assert K > 0, f"Number of out rows K must be positive (K = {K})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + out = torch.empty((G, K, N), dtype=preferred_element_type, device=device) + + return out + + +def gen_tgmm_bias_grad( + K: int, + G: int, + device: torch.device | str = DEVICE, + with_bias_grad: bool = False, +) -> Tensor: + if with_bias_grad: + assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + return torch.empty((G, K), device=device, dtype=torch.float32) + else: + # Return dummy pointer when bias_grad is not needed. + # Must be float32 because atomic_add does not support bf16/fp16, + # and Triton validates the pointer dtype even in dead branches. + return torch.tensor([], device=device, dtype=torch.float32) + + +def gen_tgmm_tensors( + M: int, + K: int, + N: int, + G: int, + num_group_sizes: int, + device: torch.device | str = DEVICE, + input_type: torch.dtype = DTYPE, + output_type: torch.dtype = DTYPE, + trans_lhs: bool = TRANS_LHS, + trans_rhs: bool = False, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, + use_bias: bool = False, +) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]: + lhs, rhs, group_sizes_0 = gen_tgmm_input( + M, + K, + N, + G, + device=device, + preferred_element_type=input_type, + trans_lhs=trans_lhs, + rng_seed=rng_seed, + unif_group_sizes=unif_group_sizes, + ) + multiple_group_sizes = gen_multiple_group_sizes( + num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0 + ) + out = gen_tgmm_output(K, N, G, device=device, preferred_element_type=output_type) + if use_bias: + bias_grad = gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=True) + else: + bias_grad = None + return lhs, rhs, multiple_group_sizes, out, bias_grad + + +# TGMM helpers: get information from tensors. +# ------------------------------------------------------------------------------ + + +def get_tgmm_shape( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor +) -> tuple[int, int, int, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})." + assert ( + group_sizes.dim() == 1 + ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})." + + K, lhs_m = lhs.shape + rhs_m, N = rhs.shape + G = group_sizes.shape[0] + + assert ( + lhs_m == rhs_m + ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})." + M = lhs_m + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + return M, K, N, G + + +def get_tgmm_output( + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, +) -> Tensor: + assert K > 0, f"Number of out rows K must be positive (K = {K})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if existing_out is not None: + assert ( + existing_out.device == device + ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})." + assert ( + existing_out.dtype == preferred_element_type + ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})." + assert existing_out.shape == ( + G, + K, + N, + ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(G, K, N)})." + return existing_out + + return gen_tgmm_output( + K, + N, + G, + device=device, + preferred_element_type=preferred_element_type, + ) + + +def get_tgmm_bias_grad( + K: int, + G: int, + device: torch.device | str = DEVICE, + existing_bias_grad: Tensor | None = None, +) -> Tensor: + """ + Get or validate bias gradient tensor for TGMM. + + If existing_bias_grad is provided, validates its shape, device, dtype, and stride, + and always zeros it before returning (since the kernel uses atomic_add). + If existing_bias_grad is None, returns a dummy tensor (for use when COMPUTE_BIAS_GRAD=False). + Parameters + ---------- + K : int + Number of rows in the bias gradient tensor. + G : int + Number of groups. + device : torch.device or str + Device for the tensor. + existing_bias_grad : torch.Tensor or None + Existing bias gradient tensor to validate and use. + Returns + ------- + torch.Tensor + Valid bias gradient tensor or dummy tensor. + """ + assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if existing_bias_grad is not None: + # Validate existing bias_grad tensor. + expected_shape = (G, K) + assert ( + tuple(existing_bias_grad.shape) == expected_shape + ), f"bias_grad must have shape {expected_shape}, got {tuple(existing_bias_grad.shape)}." + assert ( + existing_bias_grad.device == device + ), f"bias_grad must be on the same device (bias_grad = {existing_bias_grad.device}, device = {device})." + assert ( + existing_bias_grad.dtype == torch.float32 + ), f"bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), got {existing_bias_grad.dtype}." + assert existing_bias_grad.stride() == ( + K, + 1, + ), f"bias_grad must be row-major with stride (K, 1) = ({K}, 1), got {existing_bias_grad.stride()}." + + # Always zero the tensor since bias_grad represents gradients for the current + # computation and should start fresh. The kernel uses atomic_add which adds to + # existing values, so we must zero before the kernel runs. + existing_bias_grad.zero_() + + return existing_bias_grad + + else: + return gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=False) + + +def get_tgmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})." + assert out.dim() == 3, f"out must have 3 dimensions (it's {out.dim()})." + + lhs_k, lhs_m = lhs.shape + rhs_m, rhs_n = rhs.shape + G, out_k, out_n = out.shape + + assert ( + lhs_m == rhs_m + ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})." + M = lhs_m + assert ( + lhs_k == out_k + ), f"K dimension of lhs and out don't match (lhs = {lhs_k}, rhs = {out_k})." + K = lhs_k + assert ( + rhs_n == out_n + ), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})." + N = rhs_n + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + is_lhs_row_major = lhs.stride() == (M, 1) + is_lhs_col_major = lhs.stride() == (1, K) + assert ( + is_lhs_row_major != is_lhs_col_major + ), "lhs must be row-major or column-major." + is_rhs_row_major = rhs.stride() == (N, 1) + assert is_rhs_row_major, "rhs must be row-major." + is_out_row_major = out.stride() == (K * N, N, 1) + assert is_out_row_major, "out must be row-major." + + # Get lhs leading dimension according to transposition configuration. + ld_lhs = M if is_lhs_row_major else K + + return is_lhs_col_major, ld_lhs diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/logger.py b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..391ddf9b6543f5244e7f4932c8568d60748e15cd --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/logger.py @@ -0,0 +1,47 @@ +import os +import logging + + +# AITER Triton Logger which is singleton object around python logging. +# Note: Python logging is also a singleton object, but we want to read the +# env var AITER_LOG_LEVEL once at the beginning. Another alternative is to do +# this in __init__.py. In fact, that's how CK logger is setup. We can look at +# switching to that at some point +# +# AITER_LOG_LEVEL follows python logging levels +# DEBUG +# INFO +# WARNING +# ERROR +# CRITICAL +# +class AiterTritonLogger(object): + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(AiterTritonLogger, cls).__new__(cls) + log_level_str = os.getenv("AITER_TRITON_LOG_LEVEL", "WARNING").upper() + numeric_level = getattr(logging, log_level_str, logging.WARNING) + cls._instance._logger = logging.getLogger("AITER_TRITON") + cls._instance._logger.setLevel(numeric_level) + + return cls._instance + + def get_logger(self): + return self._logger + + def debug(self, msg): + self._logger.debug(msg) + + def info(self, msg): + self._logger.info(msg) + + def warning(self, msg): + self._logger.warning(msg) + + def error(self, msg): + self._logger.error(msg) + + def critical(self, msg): + self._logger.critical(msg) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_megablocks_cuda_ae601bb.abi3.so b/build/torch211-cxx11-cu128-x86_64-linux/_megablocks_cuda_ae601bb.abi3.so deleted file mode 100644 index de3850850f32146603ef725e8e2aaa3f24f83835..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu128-x86_64-linux/_megablocks_cuda_ae601bb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ce5790e025e92878a33c9a766bca1cda450c920f68f49549525413c7e754c100 -size 19082856 diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_megablocks_cuda_f8f8b50.abi3.so b/build/torch211-cxx11-cu128-x86_64-linux/_megablocks_cuda_f8f8b50.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..da523ae82c2189b2afaca74670239c3e6d951788 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_megablocks_cuda_f8f8b50.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4ea3f6a68cbc730572a4a4c8d3814a2075cc775bffcf3082c9dbd6291e888555 +size 19750504 diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_ops.py b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py index 8dd1b7bcf680d2d32dd4ac912487118eafcee4ea..69afb8c26a3fa2691be277b0270d600d29a5865e 100644 --- a/build/torch211-cxx11-cu128-x86_64-linux/_ops.py +++ b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_cuda_ae601bb -ops = torch.ops._megablocks_cuda_ae601bb +from . import _megablocks_cuda_f8f8b50 +ops = torch.ops._megablocks_cuda_f8f8b50 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_cuda_ae601bb::{op_name}" + return f"_megablocks_cuda_f8f8b50::{op_name}" diff --git a/build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py b/build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py index 76037d8039cbfc2f0577275c78e4bc0be762592a..c7ef28ced79c830dae934177f059c1f4ddc24aad 100644 --- a/build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py +++ b/build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py @@ -2,16 +2,16 @@ # extensions. Otherwise libc10.so cannot be found. import torch -# # TODO(tgale): Wrap this in a try-block with better -# # error message and instructions for building the -# # c++ operations. -# import grouped_gemm_backend as backend +# On ROCm there is no CUTLASS grouped GEMM; dispatch to the vendored AITER +# Triton kernels instead. On CUDA we use the compiled CUTLASS `gmm` op. +_IS_ROCM = torch.version.hip is not None -# We import the backend operations from the megablocks package as -# grouped_gemm is vendored in megablocks in this repository. -# from ... import _ops as backend -# from megablocks._ops import ops as backend # type: ignore -from .._ops import ops as backend # type: ignore +if _IS_ROCM: + from .._grouped_gemm_triton import adapter as backend +else: + # We import the backend operations from the megablocks package as + # grouped_gemm is vendored in megablocks in this repository. + from .._ops import ops as backend # type: ignore def _allocate_output(a, b, batch_sizes, trans_a, trans_b): assert not (trans_a and trans_b) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/metadata.json b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json index 2ac083d9c8aa13338502cb46e51c3c623dec8951..5851ca557336aa902e5c27d1b5e8cab5b60e971a 100644 --- a/build/torch211-cxx11-cu128-x86_64-linux/metadata.json +++ b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json @@ -1,6 +1,6 @@ { "name": "megablocks", - "id": "_megablocks_cuda_ae601bb", + "id": "_megablocks_cuda_f8f8b50", "version": 1, "license": "Apache-2.0", "python-depends": [], @@ -10,6 +10,7 @@ "10.0", "10.1", "12.0", + "12.0+PTX", "7.0", "7.2", "7.5", diff --git a/build/torch211-cxx11-cu130-x86_64-linux/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py index 38075732c6d8fa0e1e6ef493145e1aca3851ae6b..0766d7b8da4f97baca212177b4bb989bc6374bf8 100644 --- a/build/torch211-cxx11-cu130-x86_64-linux/__init__.py +++ b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py @@ -3,7 +3,9 @@ import torch -from ._ops import ops +# Stable alias: bare `ops` is shadowed by `from . import layers` below. +from ._ops import ops as _compiled_ops +from . import ops from .grouped_gemm import backend as gg_backend from .grouped_gemm import ops as gg_ops @@ -136,7 +138,8 @@ def sort( Returns: The sorted values tensor """ - return ops.sort(x, end_bit, x_out, iota_out) + _compiled_ops.sort(x, end_bit, x_out, iota_out) + return x_out # Convenience functions for common use cases diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py new file mode 100644 index 0000000000000000000000000000000000000000..8c101d07cea416f9390b708e5a35fdc466e48aed --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py @@ -0,0 +1,574 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + + +# Imports. +# ------------------------------------------------------------------------------ + +# Python standard library +import functools + +# Triton +import triton +import triton.language as tl + +# AITER +from ..configs import CONFIGS as _CONFIGS +from ..utils._triton import arch_info +from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd + +# Kernel config. +# ------------------------------------------------------------------------------ + + +@functools.lru_cache() +def get_config( + gmm_type: str, M: int, K: int, N: int, G: int, accumulate: bool = False +) -> dict[str, int]: + assert gmm_type in { + "gmm", + "ptgmm", + "nptgmm", + }, f"'{gmm_type}' is an invalid GMM variant." + dev = arch_info.get_arch() + assert ( + dev in _CONFIGS + ), f"No GMM configuration tuned for arch '{dev}'. Supported: {sorted(_CONFIGS)}." + arch_configs = _CONFIGS[dev] + assert ( + "default" in arch_configs[gmm_type] + ), "Default configuration is absent." + key = "accumulate" if accumulate else "default" + return arch_configs[gmm_type][key] + + +# Common code shared by GMM and TGMM kernels. +# ------------------------------------------------------------------------------ + + +# XCD remapping followed by 1D PID to 2D grid mapping. +@triton.jit +def _remap_xcd_tile_grid( + tile_in_mm, + num_row_tiles, + num_col_tiles, + GROUP_SIZE: tl.constexpr = 1, + NUM_XCDS: tl.constexpr = 8, +): + return pid_grid( + remap_xcd(tile_in_mm, num_row_tiles * num_col_tiles, NUM_XCDS=NUM_XCDS), + num_row_tiles, + num_col_tiles, + GROUP_SIZE_M=GROUP_SIZE, + ) + + +# GMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.heuristics( + { + "K_DIVISIBLE_BY_BLOCK_SIZE_K": lambda META: META["K"] % META["BLOCK_SIZE_K"] + == 0, + } +) +@triton.jit +def gmm_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_RHS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + K_DIVISIBLE_BY_BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE: tl.constexpr, + GRID_DIM: tl.constexpr, + USE_BIAS: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + # Current tile. Each program computes multiple tiles of each group. + tile = tl.program_id(0) + tl.device_assert(tile >= 0, "tile < 0 (at initialization)") + + # Tile limit of last MM problem (inclusive). + last_mm_tile = 0 + + # Last input row of lhs and output row of out. Each group reads some rows of + # lhs and writes some rows to out. + last_m = 0 + + # Loop through all (m, K, N) MM problems: + # (m, K) x (K, N) = (m, N) + # sum(m) = M + for g in range(G): + # Get m dimension of current MM problem. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty + tl.device_assert(m >= 0, "m < 0") + + num_m_tiles = tl.cdiv(m, BLOCK_SIZE_M) + # num_m_tiles can be zero if group is empty + tl.device_assert(num_m_tiles >= 0, "num_m_tiles < 0") + + num_tiles = num_m_tiles * num_n_tiles + # num_tiles can be zero if group is empty + tl.device_assert(num_tiles >= 0, "num_tiles < 0") + + # Loop through tiles of current MM problem. + while tile >= last_mm_tile and tile < last_mm_tile + num_tiles: + # Figure out tile coordinates in current MM problem. + tile_in_mm = tile - last_mm_tile + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_m, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_m_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + # Do regular MM: + + tl.device_assert(tile_m * BLOCK_SIZE_M >= 0, "tile_m * BLOCK_SIZE_M < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_m = ( + tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + ) % m + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + + lhs_ptrs = lhs_ptr + (last_m + offs_lhs_m[:, None]) * K + offs_k[None, :] + + if TRANS_RHS: + rhs_ptrs = ( + rhs_ptr + + g.to(tl.int64) * K * N + + offs_k[:, None] + + offs_rhs_n[None, :] * K + ) + else: + rhs_ptrs = ( + rhs_ptr + + g.to(tl.int64) * K * N + + offs_k[:, None] * N + + offs_rhs_n[None, :] + ) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if K_DIVISIBLE_BY_BLOCK_SIZE_K: + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + else: + k_mask_limit = K - k * BLOCK_SIZE_K + lhs = tl.load( + lhs_ptrs, mask=offs_k[None, :] < k_mask_limit, other=0 + ) + rhs = tl.load( + rhs_ptrs, mask=offs_k[:, None] < k_mask_limit, other=0 + ) + + acc = tl.dot(lhs, rhs, acc=acc) + + lhs_ptrs += BLOCK_SIZE_K + + if TRANS_RHS: + rhs_ptrs += BLOCK_SIZE_K + else: + rhs_ptrs += BLOCK_SIZE_K * N + + # Add bias if enabled + if USE_BIAS: + offs_bias_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange( + 0, BLOCK_SIZE_N + ) + bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bias_n + bias = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0) + # Convert bias to float32 to match accumulator precision + bias = bias.to(tl.float32) + # Broadcast bias across M dimension and add in float32 + acc += bias[None, :] + + # Convert to output dtype after all computations + acc = acc.to(out_ptr.type.element_ty) + + offs_out_m = tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + (last_m + offs_out_m[:, None]) * N + offs_out_n[None, :] + ) + + tl.store( + out_ptrs, + acc, + mask=(offs_out_m[:, None] < m) & (offs_out_n[None, :] < N), + ) + + # Go to the next tile by advancing number of programs. + tile += GRID_DIM + tl.device_assert(tile > 0, "tile <= 0 (at update)") + + # Get ready to go to the next MM problem. + + last_mm_tile += num_tiles + # last_mm_tile can be zero if group 0 is skipped + tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)") + + last_m += m + # last_m can be zero if group 0 is skipped + tl.device_assert(last_m >= 0, "last_m < 0 (at update)") + tl.device_assert(last_m <= M, "last_m > M (at update)") + + +# Persistent TGMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.jit +def tgmm_persistent_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_grad_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_LHS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE: tl.constexpr, + GRID_DIM: tl.constexpr, + COMPUTE_BIAS_GRAD: tl.constexpr, + ACCUMULATE: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0") + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + num_tiles = num_k_tiles * num_n_tiles + tl.device_assert(num_tiles > 0, "num_tiles <= 0") + + # Current tile. Each program computes multiple tiles of each group. + tile = tl.program_id(0) + tl.device_assert(tile >= 0, "tile < 0 (at initialization)") + + # Tile limit of last MM problem (inclusive). + last_mm_tile = 0 + + # Last input column of lhs and input row of rhs. Each group reads some + # columns of lhs and some rows of rhs. + last_m = 0 + + # Loop through all (K, m, N) MM problems: + # (K, m) x (m, N) = (K, N) + # sum(m) = M + for g in range(G): + # Get m dimension of current MM problem. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty + tl.device_assert(m >= 0, "m < 0") + + # Loop through tiles of current MM problem. + while tile >= last_mm_tile and tile < last_mm_tile + num_tiles: + # Figure out tile coordinates in current MM problem. + tile_in_mm = tile - last_mm_tile + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_k, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + # Do regular MM: + + tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + + if TRANS_LHS: + lhs_ptrs = ( + lhs_ptr + offs_lhs_k[:, None] + (last_m + offs_m[None, :]) * K + ) + else: + lhs_ptrs = ( + lhs_ptr + offs_lhs_k[:, None] * M + (last_m + offs_m[None, :]) + ) + + rhs_ptrs = rhs_ptr + (last_m + offs_m[:, None]) * N + offs_rhs_n[None, :] + + loop_m = tl.cdiv(m, BLOCK_SIZE_M) + m_divisible_by_block_m = m % BLOCK_SIZE_M == 0 + if not m_divisible_by_block_m: + loop_m -= 1 + + acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32) + + # Initialize bias accumulator + bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32) + + for _ in range(0, loop_m): + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + + acc = tl.dot(lhs, rhs, acc=acc) + + # Accumulate for bias gradient: sum lhs across M dimension + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum( + lhs, axis=1 + ) # Sum across M dimension [K, M] -> [K] + + if TRANS_LHS: + lhs_ptrs += BLOCK_SIZE_M * K + else: + lhs_ptrs += BLOCK_SIZE_M + + rhs_ptrs += BLOCK_SIZE_M * N + + if not m_divisible_by_block_m: + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0) + rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0) + acc = tl.dot(lhs, rhs, acc=acc) + + # Accumulate last chunk for bias gradient + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) + + acc = acc.to(out_ptr.type.element_ty) + + offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + + g.to(tl.int64) * K * N + + offs_out_k[:, None] * N + + offs_out_n[None, :] + ) + + mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N) + if ACCUMULATE: + # Load existing values and add to them (like beta=1 in BLAS) + old_vals = tl.load(out_ptrs, mask=mask, other=0.0) + tl.store(out_ptrs, acc + old_vals, mask=mask) + else: + # Overwrite output (like beta=0 in BLAS) + tl.store(out_ptrs, acc, mask=mask) + + # Store bias gradient (only for first N tile, sum across all M) + if COMPUTE_BIAS_GRAD and tile_n == 0: + # Keep as float32 for atomic_add (bf16 not supported for atomics) + bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k + # Use atomic add since multiple K-tiles may write to same expert's bias + tl.atomic_add( + bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed" + ) + + # Go to the next tile by advancing number of programs. + tile += GRID_DIM + tl.device_assert(tile > 0, "tile <= 0 (at update)") + + # Get ready to go to the next MM problem. + + last_mm_tile += num_tiles + # last_mm_tile can be zero if group 0 is skipped + tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)") + + last_m += m + # last_m can be zero if group 0 is skipped + tl.device_assert(last_m >= 0, "last_m < 0 (at update)") + tl.device_assert(last_m <= M, "last_m > M (at update)") + + +# Regular non-persistent TGMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.heuristics({"BLOCK_SIZE_G": lambda META: triton.next_power_of_2(META["G"])}) +@triton.jit +def tgmm_non_persistent_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_grad_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_LHS: tl.constexpr, + BLOCK_SIZE_G: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE: tl.constexpr, + COMPUTE_BIAS_GRAD: tl.constexpr, + ACCUMULATE: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + # Get group ID from grid. + g = tl.program_id(0) + tl.device_assert(g >= 0, "g < 0") + tl.device_assert(g < G, "g >= G") + + # Get m dimension of current MM group. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty. + tl.device_assert(m >= 0, "m < 0") + + # Skip empty groups. + if m == 0: + return + + # Compute sum(group_sizes) until current group g. + # It's the starting column of lhs and starting row of rhs. + offs_g = tl.arange(0, BLOCK_SIZE_G) + group_sizes = tl.load(group_sizes_ptr + offs_g, mask=offs_g < g, other=0) + start_m = tl.sum(group_sizes) + + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0") + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + # Get MM tile from grid. + tile_in_mm = tl.program_id(1) + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_k, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_k = (tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K + offs_rhs_n = (tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + + if TRANS_LHS: + lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] + (start_m + offs_m[None, :]) * K + else: + lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] * M + (start_m + offs_m[None, :]) + + rhs_ptrs = rhs_ptr + (start_m + offs_m[:, None]) * N + offs_rhs_n[None, :] + + loop_m = tl.cdiv(m, BLOCK_SIZE_M) + m_divisible_by_block_m = m % BLOCK_SIZE_M == 0 + if not m_divisible_by_block_m: + loop_m -= 1 + + acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32) + # Initialize bias accumulator + bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32) + + for _ in range(0, loop_m): + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + + acc = tl.dot(lhs, rhs, acc=acc) + + # Accumulate for bias gradient: sum lhs across M dimension + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) # [K, M] -> [K] + + if TRANS_LHS: + lhs_ptrs += BLOCK_SIZE_M * K + else: + lhs_ptrs += BLOCK_SIZE_M + + rhs_ptrs += BLOCK_SIZE_M * N + + if not m_divisible_by_block_m: + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0) + rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0) + acc = tl.dot(lhs, rhs, acc=acc) + # Accumulate last chunk for bias gradient + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) + + acc = acc.to(out_ptr.type.element_ty) + + offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + g.to(tl.int64) * K * N + offs_out_k[:, None] * N + offs_out_n[None, :] + ) + + mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N) + if ACCUMULATE: + # Load existing values and add to them (like beta=1 in BLAS) + old_vals = tl.load(out_ptrs, mask=mask, other=0.0) + tl.store(out_ptrs, acc + old_vals, mask=mask) + else: + # Overwrite output (like beta=0 in BLAS) + tl.store(out_ptrs, acc, mask=mask) + + # Store bias gradient (only for first N tile, sum across all M) + if COMPUTE_BIAS_GRAD and tile_n == 0: + # Keep as float32 for atomic_add (bf16/fp16 not supported for atomics) + bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k + # Use atomic add since multiple K-tiles may write to same expert's bias + tl.atomic_add(bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed") diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/adapter.py b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..98c224244f27445384e0c2377d73516406927536 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/adapter.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Adapt AITER's Triton grouped GEMM to MegaBlocks' ``gmm`` calling convention. + +MegaBlocks (following tgale96/grouped_gemm) uses a single ``gmm`` entry point +with ``trans_a`` / ``trans_b`` flags: + +* ``trans_a=False, trans_b=False``: a(M,K) @ b(G,K,N) -> c(M,N) +* ``trans_a=False, trans_b=True`` : a(M,K) @ b(G,N,K)^T -> c(M,N) (dgrad) +* ``trans_a=True`` : a(M,K)^T @ b(M,N) per group -> c(G,K,N) (wgrad) + +AITER exposes these as two kernels: ``gmm`` ((M,K)@(G,K,N)->(M,N), transposition +of the 3D operand inferred from strides) and ``ptgmm`` ((K,M)@(M,N)->(G,K,N), +transposition of the 2D operand inferred from strides). +""" + +import torch + +from .gmm import gmm as _aiter_gmm +from .gmm import ptgmm as _aiter_ptgmm + + +def gmm(a, b, c, batch_sizes, trans_a=False, trans_b=False): + # AITER requires group sizes to be int32 and to live on the compute device. + group_sizes = batch_sizes.to(device=a.device, dtype=torch.int32) + + # AITER asserts exact strides: gmm wants lhs/rhs row-major (a transposed + # 3D operand must be exactly column-major), tgmm wants rhs row-major and + # lhs row/column-major. Make operands contiguous first so the transposed + # views have the precise strides the kernels expect. `.contiguous()` is a + # no-op when the tensor is already contiguous. + if trans_a: + # Weight gradient: a(M,K), b(M,N) -> c(G,K,N). + # Pass a transposed so AITER sees lhs(K,M) column-major (TRANS_LHS). + _aiter_ptgmm( + a.contiguous().transpose(0, 1), + b.contiguous(), + group_sizes, + preferred_element_type=c.dtype, + existing_out=c, + ) + else: + # trans_b contracts b's last dim: pass a column-major (G,K,N) view. + rhs = b.contiguous() + if trans_b: + rhs = rhs.transpose(1, 2) + _aiter_gmm( + a.contiguous(), + rhs, + group_sizes, + preferred_element_type=c.dtype, + existing_out=c, + ) + return c diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/configs.py b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..9a4fe5617d8100869aa76dba9b7d22c7bcab814f --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/configs.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: MIT +# Tuned GMM configs vendored from ROCm/aiter (aiter/ops/triton/configs/). +# Inlined as a Python module so packaging always includes them. + +CONFIGS = {'gfx1250': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}, 'gfx942': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}, 'gfx950': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}} diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/gmm.py b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/gmm.py new file mode 100644 index 0000000000000000000000000000000000000000..e30c9326c6d4e4836d1303e2761ea2440a7f4750 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/gmm.py @@ -0,0 +1,567 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. + + +# Imports. +# ------------------------------------------------------------------------------ + +# PyTorch +import torch +from torch import Tensor + +# Triton +import triton + +# AITER: GMM utility functions +from .utils.gmm_common import ( + DTYPE, + is_power_of_2, + check_input_device_dtype, + check_bias_shape_stride, + get_gmm_shape, + get_gmm_output, + get_gmm_transposition, + get_tgmm_shape, + get_tgmm_output, + get_tgmm_bias_grad, + get_tgmm_transposition, +) + +# AITER: GMM Triton kernels +from ._triton_kernels.gmm import ( + gmm_kernel, + tgmm_persistent_kernel, + tgmm_non_persistent_kernel, + get_config, +) + +# GMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _gmm_grid( + N: int, + block_size_m: int, + block_size_n: int, + group_sizes: Tensor, + grid_dim: int, +) -> tuple[int]: + assert N > 0, f"N must be positive, it's {N}." + assert is_power_of_2( + block_size_m + ), f"M-dimension tile size must be a power of 2 (it's {block_size_m})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + assert torch.all(group_sizes >= 0).item(), "All group_sizes must be non-negative." + assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})." + num_m_tiles = (group_sizes + block_size_m - 1) // block_size_m + assert torch.all(num_m_tiles >= 0).item(), "All num_m_tiles must be non-negative." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles = torch.sum(num_m_tiles * num_n_tiles).item() + assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}." + num_programs = int(min(grid_dim, num_tiles)) + assert num_programs > 0, f"num_programs must be positive, it's {num_programs}." + return (num_programs,) + + +def gmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias: Tensor | None = None, +) -> Tensor: + """ + Perform Group Matrix Multiplication (GMM): out = lhs @ rhs + bias + + lhs rows are divided into G groups. Each group of lhs rows is matrix multiplied with a plane of + rhs 3D tensor and then stored in a slice of out. In PyTorch parlance, it can be implemented as + follows for a given group g: + out[group_start:group_end, :] = lhs[group_start:group_end, :] @ rhs[g] + bias[g] + + The size of each group, and their respective start and end positions are specified by + group_sizes tensor. For instance, suppose that group_sizes = [3, 2, 4, 1]. In this particular + case we have 4 groups. The 1st group starts at 0 and ends at 2, the second group starts at 3 and + ends at 4, the third group starts at 5 and ends at 8, and the fourth and final group consists of + just the 10th (last) row of lhs. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (M, K). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 3D input tensor. Shape: (G, K, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (M, N), its data type must match preferred_element_type + and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias : torch.Tensor or None, optional + Optional bias tensor. Shape: (G, N). + If provided, bias data type must match lhs and rhs data type, and bias must be on the same + device as other input tensors. Each group g adds bias[g] to the output. + + Returns + ------- + torch.Tensor + The computed output 2D tensor. Shape: (M, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - GMM is implemented with a persistent Triton kernel. + - lhs must be row-major (lhs.stride() == (K, 1)). + - rhs can be row-major (rhs.stride() == (K * N, N, 1)) or column-major (rhs.stride() == + (K * N, 1, K)). If rhs is row-major then kernel parameter TRANS_RHS == False, this is useful + for implementing forward pass. If rhs is column-major then kernel parameter TRANS_RHS == True, + this is useful for computing the lhs derivative in the backward pass, while fusing the + transposition. + - out must be row-major (out.stride() == (N, 1)). + - bias must be row-major (bias.stride() == (N, 1)) if provided. + """ + use_bias = bias is not None + check_input_device_dtype(lhs, rhs, group_sizes, bias) + + M, K, N, G = get_gmm_shape(lhs, rhs, group_sizes) + + if use_bias: + check_bias_shape_stride(bias, G, N) + + out = get_gmm_output( + M, + N, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_rhs, _ = get_gmm_transposition(lhs, rhs, out) + + if config is None: + config = get_config("gmm", M, K, N, G) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + "GRID_DIM", + } + ), "Invalid GMM kernel config." + + grid = _gmm_grid( + N, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + group_sizes, + config["GRID_DIM"], + ) + + # fmt: off + gmm_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_RHS=trans_rhs, + USE_BIAS=use_bias, + **config, + ) + # fmt: on + + return out + + +# Persistent TGMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _ptgmm_grid( + K: int, + N: int, + G: int, + block_size_k: int, + block_size_n: int, + grid_dim: int, +) -> tuple[int]: + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}." + assert G > 0, f"G must be positive, it's {G}." + assert is_power_of_2( + block_size_k + ), f"K-dimension tile size must be a power of 2 (it's {block_size_k})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})." + num_k_tiles = triton.cdiv(K, block_size_k) + assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles = G * num_k_tiles * num_n_tiles + assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}." + num_programs = min(grid_dim, num_tiles) + assert num_programs > 0, f"num_programs must be positive, it's {num_programs}." + return (num_programs,) + + +def ptgmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias_grad: Tensor | None = None, + accumulate: bool = False, +) -> Tensor: + """ + Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs + + lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with + the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch + parlance, it can be implemented as follows for a given group g: + out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :] + + The 't' in the operator name derives from MaxText implementation + (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py), + which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor + shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N). + + The 'p' in the operator name means that it is implemented with a persistent kernel. There is + also the non-persistent variation, which is implemented with a regular kernel. Please take a + look at nptgmm operator. Both ptgmm and nptgmm implement the same computation, choosing one or + the other is a matter of performance for the target workload. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (K, M). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 2D input tensor. Shape: (M, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (G, K, N), its data type must match + preferred_element_type and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias_grad : torch.Tensor or None, optional + Optional bias gradient output tensor. Shape: (G, K). + If provided, the kernel will compute the bias gradient and write it to this tensor. + bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), + accumulate : bool, optional + Whether to accumulate into existing output tensor values. Default is False. + If False, output will be overwritten with fresh computation. + If True, results will be added to existing output tensor values. + + Returns + ------- + torch.Tensor + The computed output 3D tensor. Shape: (G, K, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - PTGMM is implemented with a persistent Triton kernel. + - lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs + is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel + parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward + pass, while fusing the transposition. + - rhs must be row-major (rhs.stride() == (N, 1)). + - out must be row-major (out.stride() == (K * N, N, 1)). + """ + check_input_device_dtype(lhs, rhs, group_sizes) + + M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes) + + out = get_tgmm_output( + K, + N, + G, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out) + + if config is None: + config = get_config("ptgmm", M, K, N, G, accumulate) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + "GRID_DIM", + } + ), "Invalid PTGMM kernel config." + + # Bias gradient handling. + # ----------------------- + # Get or validate bias gradient tensor. + compute_bias_grad = bias_grad is not None + bias_grad_ptr = get_tgmm_bias_grad( + K, + G, + device=lhs.device, + existing_bias_grad=bias_grad, + ) + + grid = _ptgmm_grid( + K, + N, + G, + config["BLOCK_SIZE_K"], + config["BLOCK_SIZE_N"], + config["GRID_DIM"], + ) + + # fmt: off + tgmm_persistent_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias_grad_ptr, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_LHS=trans_lhs, + COMPUTE_BIAS_GRAD=compute_bias_grad, + ACCUMULATE=accumulate, + **config, + ) + # fmt: on + + return out + + +# Regular non-persistent TGMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _nptgmm_grid( + K: int, + N: int, + G: int, + block_size_k: int, + block_size_n: int, +) -> tuple[int, int]: + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}." + assert G > 0, f"G must be positive, it's {G}." + assert is_power_of_2( + block_size_k + ), f"K-dimension tile size must be a power of 2 (it's {block_size_k})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + num_k_tiles = triton.cdiv(K, block_size_k) + assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles_per_mm = num_k_tiles * num_n_tiles + assert ( + num_tiles_per_mm > 0 + ), f"num_tiles_per_mm must be positive, it's {num_tiles_per_mm}." + return (G, num_tiles_per_mm) + + +def nptgmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias_grad: Tensor | None = None, + accumulate: bool = False, +) -> Tensor: + """ + Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs + + lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with + the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch + parlance, it can be implemented as follows for a given group g: + out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :] + + The 't' in the operator name derives from MaxText implementation + (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py), + which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor + shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N). + + The 'np' in the operator name means that it is implemented with a non-persistent, i.e. regular + kernel. There is also the persistent variation, which is implemented with a persistent kernel. + Please take a look at ptgmm operator. Both nptgmm and ptgmm implement the same computation, + choosing one or the other is a matter of performance for the target workload. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (K, M). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 2D input tensor. Shape: (M, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (G, K, N), its data type must match + preferred_element_type and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias_grad : torch.Tensor or None, optional + Optional bias gradient output tensor. Shape: (G, K). + If provided, the kernel will compute the bias gradient and write it to this tensor. + bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), + accumulate : bool, optional + Whether to accumulate into existing output tensor values. Default is False. + If False, output will be overwritten with fresh computation. + If True, results will be added to existing output tensor values. + + Returns + ------- + torch.Tensor + The computed output 3D tensor. Shape: (G, K, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - NPTGMM is implemented with a non-persistent regular Triton kernel. + - lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs + is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel + parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward + pass, while fusing the transposition. + - rhs must be row-major (rhs.stride() == (N, 1)). + - out must be row-major (out.stride() == (K * N, N, 1)). + """ + check_input_device_dtype(lhs, rhs, group_sizes) + + M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes) + + out = get_tgmm_output( + K, + N, + G, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out) + + # Bias gradient handling. + # ----------------------- + # Get or validate bias gradient tensor. + compute_bias_grad = bias_grad is not None + bias_grad_ptr = get_tgmm_bias_grad( + K, + G, + device=lhs.device, + existing_bias_grad=bias_grad, + ) + + if config is None: + config = get_config("nptgmm", M, K, N, G, accumulate) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + } + ), "Invalid NPTGMM kernel config." + + grid = _nptgmm_grid( + K, + N, + G, + config["BLOCK_SIZE_K"], + config["BLOCK_SIZE_N"], + ) + + # fmt: off + tgmm_non_persistent_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias_grad_ptr, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_LHS=trans_lhs, + COMPUTE_BIAS_GRAD=compute_bias_grad, + ACCUMULATE=accumulate, + **config, + ) + # fmt: on + + return out diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py new file mode 100644 index 0000000000000000000000000000000000000000..3f6c88581a64044518125623f116082c53bd5474 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py @@ -0,0 +1,46 @@ +import triton + +# Detect the GPU arch lazily: querying the triton driver at import time fails +# in headless environments (e.g. the kernel-builder ABI check sandbox has no +# GPU), and the original JAX fallback pulled in an unrelated runtime dep. The +# arch is only actually needed when a GMM kernel is dispatched, so resolve and +# cache on first call. +_CACHED_ARCH = None + + +def get_arch(): + global _CACHED_ARCH + if _CACHED_ARCH is not None: + return _CACHED_ARCH + try: + _CACHED_ARCH = triton.runtime.driver.active.get_current_target().arch + except RuntimeError: + try: + from jax._src.lib import gpu_triton as triton_kernel_call_lib + _CACHED_ARCH = triton_kernel_call_lib.get_arch_details("0").split(":")[0] + except ImportError as e: + raise RuntimeError( + "Cannot determine GPU arch: triton driver is inactive and " + "JAX is not available. A GPU is required for grouped GEMM." + ) from e + return _CACHED_ARCH + + +def is_gluon_avail(): + return get_arch() in ("gfx950", "gfx1250") + + +def is_fp4_avail(): + return get_arch() in ("gfx950", "gfx1250") + + +def is_fp8_avail(): + return get_arch() in ("gfx942", "gfx950", "gfx1250", "gfx1200", "gfx1201") + + +def is_mx_scale_preshuffling_avail(): + return get_arch() in ("gfx950", "gfx1250") + + +def is_tdm_avail(): + return get_arch() in ("gfx1250",) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..99792bb3ba2fab8fff223bba733ced1eb6e6df53 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: MIT + +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl + + +@triton.jit +def remap_xcd_chunked( + pid, GRID_MN, NUM_XCDS: tl.constexpr = 8, CHUNK_SIZE: tl.constexpr = 2 +): + # Compute current XCD and local PID + xcd = pid % NUM_XCDS + # distribute the modulo pids in round robin + if pid > (GRID_MN // (NUM_XCDS * CHUNK_SIZE)) * (NUM_XCDS * CHUNK_SIZE): + return pid + local_pid = pid // NUM_XCDS + # Calculate chunk index and position within chunk + chunk_idx = local_pid // CHUNK_SIZE + pos_in_chunk = local_pid % CHUNK_SIZE + # Calculate new PID + new_pid = chunk_idx * NUM_XCDS * CHUNK_SIZE + xcd * CHUNK_SIZE + pos_in_chunk + return new_pid + + +@triton.jit +def remap_xcd(pid, GRID_MN, NUM_XCDS: tl.constexpr = 8): + ## pid remapping on xcds + # Number of pids per XCD in the new arrangement + pids_per_xcd = (GRID_MN + NUM_XCDS - 1) // NUM_XCDS + # When GRID_MN cannot divide NUM_XCDS, some xcds will have + # pids_per_xcd pids, the other will have pids_per_xcd - 1 pids. + # We calculate the number of xcds that have pids_per_xcd pids as + # tall_xcds + tall_xcds = GRID_MN % NUM_XCDS + tall_xcds = NUM_XCDS if tall_xcds == 0 else tall_xcds + # Compute current XCD and local pid within the XCD + xcd = pid % NUM_XCDS + local_pid = pid // NUM_XCDS + # Calculate new pid based on the new grouping + # Note that we need to consider the following two cases: + # 1. the current pid is on a tall xcd + # 2. the current pid is on a short xcd + if xcd < tall_xcds: + pid = xcd * pids_per_xcd + local_pid + else: + pid = ( + tall_xcds * pids_per_xcd + + (xcd - tall_xcds) * (pids_per_xcd - 1) + + local_pid + ) + + return pid + + +@triton.jit +def pid_grid(pid: int, num_pid_m: int, num_pid_n: int, GROUP_SIZE_M: tl.constexpr = 1): + """ + Maps 1D pid to 2D grid coords (pid_m, pid_n). + + Args: + - pid: 1D pid + - num_pid_m: grid m size + - num_pid_n: grid n size + - GROUP_SIZE_M: tl.constexpr: default is 1 + """ + if GROUP_SIZE_M == 1: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + tl.assume(group_size_m >= 0) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + return pid_m, pid_n + + +@triton.jit +def pid_grid_3d(pid: int, num_pid_m: int, num_pid_n: int, num_pid_k): + """ + Maps 1D pid to 3D grid coords (pid_m, pid_n, pid_k). + Args: + - pid: 1D pid + - num_pid_m: grid m size + - num_pid_n: grid n size + - num_pid_k: grid k size + + Returns: + - pid_m, pid_n, pid_k: 3D grid coordinates + """ + pid_m = pid % num_pid_m + pid_n = (pid // num_pid_m) % num_pid_n + pid_k = pid // (num_pid_m * num_pid_n) % num_pid_k + + return pid_m, pid_n, pid_k diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py new file mode 100644 index 0000000000000000000000000000000000000000..153dee65b50ab5f833262481889d2184d1ca639f --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py @@ -0,0 +1,752 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. + +# Imports. +# ------------------------------------------------------------------------------ + +# PyTorch +import torch +from torch import Tensor + +# AITER: logging +from .logger import AiterTritonLogger + +_LOGGER: AiterTritonLogger = AiterTritonLogger() + + +# Supported data types. +# ------------------------------------------------------------------------------ + +# Supported data types, as strings. +SUPPORTED_DTYPES_STR: set[str] = {"fp16", "bf16"} + + +# Convert string data type to PyTorch data type. +def dtype_from_str(dtype_str: str) -> torch.dtype: + dtype_str = dtype_str.strip().lower() + dtype_str = dtype_str[1:] if dtype_str[0] in {"i", "o"} else dtype_str + assert ( + dtype_str in SUPPORTED_DTYPES_STR + ), "String data type isn't in set of supported string data types." + return {"fp16": torch.float16, "bf16": torch.bfloat16}[dtype_str] + + +# Supported data types, as PyTorch types. +SUPPORTED_DTYPES: set[torch.dtype] = { + dtype_from_str(dtype_str) for dtype_str in SUPPORTED_DTYPES_STR +} + + +# Convert PyTorch data type to string data type. +def str_from_dtype(dtype: torch.dtype) -> str: + assert ( + dtype in SUPPORTED_DTYPES + ), "PyTorch data type isn't in set of supported PyTorch data types." + return {torch.float16: "fp16", torch.bfloat16: "bf16"}[dtype] + + +# Default data type, as string. +DTYPE_STR: str = "bf16" +assert ( + DTYPE_STR in SUPPORTED_DTYPES_STR +), "Default string data type isn't in set of supported string data types." + + +# Default data type, as PyTorch type. +DTYPE: torch.dtype = dtype_from_str(DTYPE_STR) + + +# Other defaults. +# ------------------------------------------------------------------------------ + +# Default device. +DEVICE: torch.device | str = "cuda" + +# Default RNG seed for input generation. +RNG_SEED: int = 0 + +# Default number of group sizes. +NUM_GROUP_SIZES: int = 1 + +# Default transposition (NN). +TRANS_LHS: bool = False +TRANS_RHS: bool = False + + +# Parameter checking functions. +# ------------------------------------------------------------------------------ + + +def is_power_of_2(x: int) -> bool: + return (x > 0) and (x & (x - 1) == 0) + + +def check_input_device_dtype( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor, bias: Tensor | None = None +) -> None: + assert ( + lhs.device == rhs.device == group_sizes.device + ), f"All input tensors must be in the same device (lhs = {lhs.device}, rhs = {rhs.device}, group_sizes = {group_sizes.device})." + assert ( + lhs.dtype == rhs.dtype + ), f"lhs and rhs types must match (lhs = {lhs.dtype}, rhs = {rhs.dtype})." + assert group_sizes.dtype == torch.int32, "group_sizes type must be int32." + + if bias is not None: + assert ( + bias.device == lhs.device + ), f"bias must be on the same device as lhs (bias = {bias.device}, lhs = {lhs.device})." + assert ( + bias.dtype == lhs.dtype + ), f"bias dtype must match lhs dtype (bias = {bias.dtype}, lhs = {lhs.dtype})." + + +def check_bias_shape_stride(bias: Tensor, G: int, N: int) -> None: + assert bias.shape == ( + G, + N, + ), f"bias must have shape (G, N) = ({G}, {N}), got {bias.shape}." + assert bias.stride() == (N, 1), "bias must be row-major (bias.stride() == (N, 1))." + + +# Generation of group sizes. +# ------------------------------------------------------------------------------ + + +# Probabilities for generating random group sizes. +UNUSED_TOKENS_PROB: float = 0.0 +UNUSED_EXPERTS_PROB: float = 0.1 + + +def gen_uniform_group_sizes( + M: int, + G: int, + device: torch.device | str = DEVICE, +) -> Tensor: + assert M >= 0, f"Number of tokens M must be non-negative (it's {M})." + assert G > 0, f"Number of experts G must be positive (it's {G})." + + base = M // G + remainder = M % G + group_sizes = torch.full((G,), base, dtype=torch.int32, device=device) + if remainder > 0: + group_sizes[:remainder] += 1 + + assert ( + len(group_sizes) == G + ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})." + assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative." + assert ( + torch.sum(group_sizes).item() == M + ), f"Group sizes don't add up to total tokens {M}." + assert group_sizes.dtype == torch.int32, "Group sizes must be int32." + + return group_sizes + + +def gen_group_sizes( + M: int, + G: int, + device: torch.device | str = DEVICE, + rng_seed: int | None = RNG_SEED, + unused_tokens_prob: float = UNUSED_TOKENS_PROB, + unused_experts_prob: float = UNUSED_EXPERTS_PROB, +) -> Tensor: + assert M >= 0, f"Number of tokens M must be non-negative (it's {M})." + assert G > 0, f"Number of experts G must be positive (it's {G})." + assert ( + 0 <= unused_tokens_prob <= 1 + ), f"Probability of unused tokens must be in [0, 1] interval (it's {unused_tokens_prob})." + assert ( + 0 <= unused_experts_prob <= 1 + ), f"Probability of unused experts must be in [0, 1] interval (it's {unused_experts_prob})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + if unused_tokens_prob > 0: + # Optionally drop tokens to simulate routing sparsity, some tokens may not be routed. + num_unused_tokens = M + while num_unused_tokens == M: + num_unused_tokens = int( + torch.binomial( + torch.tensor(float(M), device=device), + torch.tensor(unused_tokens_prob, device=device), + ).item() + ) + else: + num_unused_tokens = 0 + num_used_tokens = M - num_unused_tokens + assert ( + num_unused_tokens >= 0 + ), f"Number of unused tokens must be non-negative (it's {num_unused_tokens})." + assert ( + num_used_tokens > 0 + ), f"Number of used tokens must be positive (it's {num_used_tokens})." + assert ( + num_used_tokens + num_unused_tokens == M + ), f"Unused + used tokens don't add up total tokens ({num_used_tokens} + {num_unused_tokens} != {M})." + + if num_unused_tokens > 0: + _LOGGER.debug( + f"Group sizes generation: dropped {num_unused_tokens} token{'s' if num_unused_tokens > 1 else ''}.", + ) + + if unused_experts_prob > 0: + # Some experts may have zero tokens assigned to them. + num_used_experts = 0 + while num_used_experts == 0: + used_experts = torch.nonzero( + torch.rand((G,), device=device) >= unused_experts_prob + ).squeeze() + num_used_experts = used_experts.numel() + else: + used_experts = torch.arange(0, G, device=device) + num_used_experts = G + num_unused_experts = G - num_used_experts + assert ( + num_unused_experts >= 0 + ), f"Number of unused experts must be non-negative (it's {num_unused_experts})." + assert ( + num_used_experts >= 1 + ), f"At least one expert must be used (it's {num_used_experts})." + assert ( + num_unused_experts + num_used_experts == G + ), f"Unused + used experts don't add up total experts ({num_unused_experts} + {num_used_experts} != {G})." + + if num_unused_experts > 0: + _LOGGER.debug( + f"Group sizes generation: dropped {num_unused_experts} expert{'s' if num_unused_experts > 1 else ''}.", + ) + + group_sizes = torch.bincount( + used_experts[ + torch.randint(low=0, high=num_used_experts, size=(num_used_tokens,)) + ], + minlength=G, + ).to(torch.int32) + + assert ( + len(group_sizes) == G + ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})." + assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative." + assert ( + torch.sum(group_sizes).item() == num_used_tokens + ), f"Group sizes don't add up to used tokens {num_used_tokens}." + assert group_sizes.dtype == torch.int32, "Group sizes must be int32." + + return group_sizes + + +def gen_multiple_group_sizes( + num_group_sizes: int, + M: int, + G: int, + device: torch.device | str = DEVICE, + rng_seed: int | None = RNG_SEED, + unused_tokens_prob: float = UNUSED_TOKENS_PROB, + unused_experts_prob: float = UNUSED_EXPERTS_PROB, + group_sizes_0: Tensor | None = None, +) -> list[Tensor]: + assert ( + num_group_sizes > 0 + ), f"Number of group sizes to be generated must be positive, it's {num_group_sizes}." + multiple_group_sizes = [ + gen_group_sizes( + M, + G, + device=device, + rng_seed=rng_seed if g == 0 else None, + unused_tokens_prob=unused_tokens_prob, + unused_experts_prob=unused_experts_prob, + ) + for g in range( + num_group_sizes if group_sizes_0 is None else num_group_sizes - 1 + ) + ] + if group_sizes_0 is not None: + multiple_group_sizes.insert(0, group_sizes_0) + assert ( + len(multiple_group_sizes) == num_group_sizes + ), f"Expecting {num_group_sizes} distinct group sizes (it's {len(multiple_group_sizes)})." + return multiple_group_sizes + + +# GMM helpers: tensor generation. +# ------------------------------------------------------------------------------ + + +def gen_gmm_input( + M: int, + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + trans_rhs: bool = TRANS_RHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, +) -> tuple[Tensor, Tensor, Tensor]: + assert M > 0, f"Number of lhs rows M must be positive (M = {M})." + assert K > 0, f"Number of lhs columns / rhs rows K must be positive (K = {K})." + assert N > 0, f"Number of rhs columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + lhs = torch.randn((M, K), dtype=torch.float32, device=device) + lhs = lhs.to(preferred_element_type) + + if trans_rhs: + rhs = torch.randn((G, N, K), dtype=torch.float32, device=device).permute( + 0, 2, 1 + ) + else: + rhs = torch.randn((G, K, N), dtype=torch.float32, device=device) + rhs = rhs.to(preferred_element_type) + + group_sizes = ( + gen_uniform_group_sizes(M, G, device=device) + if unif_group_sizes + else gen_group_sizes(M, G, device=device, rng_seed=None) + ) + + return lhs, rhs, group_sizes + + +def gen_gmm_output( + M: int, + N: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, +) -> Tensor: + assert M > 0, f"Number of out rows M must be positive (M = {M})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + + out = torch.empty((M, N), dtype=preferred_element_type, device=device) + + return out + + +def gen_gmm_tensors( + M: int, + K: int, + N: int, + G: int, + num_group_sizes: int, + device: torch.device | str = DEVICE, + input_type: torch.dtype = DTYPE, + output_type: torch.dtype = DTYPE, + trans_lhs: bool = False, + trans_rhs: bool = TRANS_RHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, + use_bias: bool = False, +) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]: + lhs, rhs, group_sizes_0 = gen_gmm_input( + M, + K, + N, + G, + device=device, + preferred_element_type=input_type, + trans_rhs=trans_rhs, + rng_seed=rng_seed, + unif_group_sizes=unif_group_sizes, + ) + multiple_group_sizes = gen_multiple_group_sizes( + num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0 + ) + out = gen_gmm_output(M, N, device=device, preferred_element_type=output_type) + bias = None + if use_bias: + torch.manual_seed(rng_seed + 1000) # Different seed for bias + bias = torch.randn(G, N, dtype=input_type, device=device) + + return lhs, rhs, multiple_group_sizes, out, bias + + +# GMM helpers: get information from tensors. +# ------------------------------------------------------------------------------ + + +def get_gmm_shape( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor +) -> tuple[int, int, int, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})." + assert ( + group_sizes.dim() == 1 + ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})." + + M, lhs_k = lhs.shape + rhs_g, rhs_k, N = rhs.shape + group_sizes_g = group_sizes.shape[0] + + assert ( + lhs_k == rhs_k + ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})." + K = lhs_k + assert ( + rhs_g == group_sizes_g + ), f"G dimension of rhs and group_sizes don't match (rhs = {rhs_g}, group_sizes = {group_sizes_g})." + G = rhs_g + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + return M, K, N, G + + +def get_gmm_output( + M: int, + N: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, +) -> Tensor: + assert M > 0, f"Number of out rows M must be positive (M = {M})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + + if existing_out is not None: + assert ( + existing_out.device == device + ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})." + assert ( + existing_out.dtype == preferred_element_type + ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})." + assert existing_out.shape == ( + M, + N, + ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(M, N)})." + return existing_out + + return gen_gmm_output( + M, + N, + device=device, + preferred_element_type=preferred_element_type, + ) + + +def get_gmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})." + assert out.dim() == 2, f"out must have 2 dimensions (it's {out.dim()})." + + lhs_m, lhs_k = lhs.shape + G, rhs_k, rhs_n = rhs.shape + out_m, out_n = out.shape + + assert ( + lhs_m == out_m + ), f"M dimension of lhs and out don't match (lhs = {lhs_m}, rhs = {out_m})." + M = lhs_m + assert ( + lhs_k == rhs_k + ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})." + K = lhs_k + assert ( + rhs_n == out_n + ), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})." + N = rhs_n + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + is_lhs_row_major = lhs.stride() == (K, 1) + assert is_lhs_row_major, "lhs must be row-major." + is_rhs_row_major = rhs.stride() == (K * N, N, 1) + is_rhs_col_major = rhs.stride() == (K * N, 1, K) + assert ( + is_rhs_row_major != is_rhs_col_major + ), "rhs must be row-major or column-major." + is_out_row_major = out.stride() == (N, 1) + assert is_out_row_major, "out must be row-major." + + # Get rhs leading dimension according to transposition configuration. + ld_rhs = N if is_rhs_row_major else K + + return is_rhs_col_major, ld_rhs + + +# TGMM helpers: tensor generation. +# ------------------------------------------------------------------------------ + + +def gen_tgmm_input( + M: int, + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + trans_lhs: bool = TRANS_LHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, +) -> tuple[Tensor, Tensor, Tensor]: + assert K > 0, f"Number of lhs rows K must be positive (M = {K})." + assert M > 0, f"Number of lhs columns / rhs rows M must be positive (K = {M})." + assert N > 0, f"Number of rhs columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + if trans_lhs: + lhs = torch.randn((M, K), dtype=torch.float32, device=device).T + else: + lhs = torch.randn((K, M), dtype=torch.float32, device=device) + lhs = lhs.to(preferred_element_type) + + rhs = torch.randn((M, N), dtype=torch.float32, device=device) + rhs = rhs.to(preferred_element_type) + + group_sizes = ( + gen_uniform_group_sizes(M, G, device=device) + if unif_group_sizes + else gen_group_sizes(M, G, device=device, rng_seed=None) + ) + + return lhs, rhs, group_sizes + + +def gen_tgmm_output( + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, +) -> Tensor: + assert K > 0, f"Number of out rows K must be positive (K = {K})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + out = torch.empty((G, K, N), dtype=preferred_element_type, device=device) + + return out + + +def gen_tgmm_bias_grad( + K: int, + G: int, + device: torch.device | str = DEVICE, + with_bias_grad: bool = False, +) -> Tensor: + if with_bias_grad: + assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + return torch.empty((G, K), device=device, dtype=torch.float32) + else: + # Return dummy pointer when bias_grad is not needed. + # Must be float32 because atomic_add does not support bf16/fp16, + # and Triton validates the pointer dtype even in dead branches. + return torch.tensor([], device=device, dtype=torch.float32) + + +def gen_tgmm_tensors( + M: int, + K: int, + N: int, + G: int, + num_group_sizes: int, + device: torch.device | str = DEVICE, + input_type: torch.dtype = DTYPE, + output_type: torch.dtype = DTYPE, + trans_lhs: bool = TRANS_LHS, + trans_rhs: bool = False, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, + use_bias: bool = False, +) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]: + lhs, rhs, group_sizes_0 = gen_tgmm_input( + M, + K, + N, + G, + device=device, + preferred_element_type=input_type, + trans_lhs=trans_lhs, + rng_seed=rng_seed, + unif_group_sizes=unif_group_sizes, + ) + multiple_group_sizes = gen_multiple_group_sizes( + num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0 + ) + out = gen_tgmm_output(K, N, G, device=device, preferred_element_type=output_type) + if use_bias: + bias_grad = gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=True) + else: + bias_grad = None + return lhs, rhs, multiple_group_sizes, out, bias_grad + + +# TGMM helpers: get information from tensors. +# ------------------------------------------------------------------------------ + + +def get_tgmm_shape( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor +) -> tuple[int, int, int, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})." + assert ( + group_sizes.dim() == 1 + ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})." + + K, lhs_m = lhs.shape + rhs_m, N = rhs.shape + G = group_sizes.shape[0] + + assert ( + lhs_m == rhs_m + ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})." + M = lhs_m + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + return M, K, N, G + + +def get_tgmm_output( + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, +) -> Tensor: + assert K > 0, f"Number of out rows K must be positive (K = {K})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if existing_out is not None: + assert ( + existing_out.device == device + ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})." + assert ( + existing_out.dtype == preferred_element_type + ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})." + assert existing_out.shape == ( + G, + K, + N, + ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(G, K, N)})." + return existing_out + + return gen_tgmm_output( + K, + N, + G, + device=device, + preferred_element_type=preferred_element_type, + ) + + +def get_tgmm_bias_grad( + K: int, + G: int, + device: torch.device | str = DEVICE, + existing_bias_grad: Tensor | None = None, +) -> Tensor: + """ + Get or validate bias gradient tensor for TGMM. + + If existing_bias_grad is provided, validates its shape, device, dtype, and stride, + and always zeros it before returning (since the kernel uses atomic_add). + If existing_bias_grad is None, returns a dummy tensor (for use when COMPUTE_BIAS_GRAD=False). + Parameters + ---------- + K : int + Number of rows in the bias gradient tensor. + G : int + Number of groups. + device : torch.device or str + Device for the tensor. + existing_bias_grad : torch.Tensor or None + Existing bias gradient tensor to validate and use. + Returns + ------- + torch.Tensor + Valid bias gradient tensor or dummy tensor. + """ + assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if existing_bias_grad is not None: + # Validate existing bias_grad tensor. + expected_shape = (G, K) + assert ( + tuple(existing_bias_grad.shape) == expected_shape + ), f"bias_grad must have shape {expected_shape}, got {tuple(existing_bias_grad.shape)}." + assert ( + existing_bias_grad.device == device + ), f"bias_grad must be on the same device (bias_grad = {existing_bias_grad.device}, device = {device})." + assert ( + existing_bias_grad.dtype == torch.float32 + ), f"bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), got {existing_bias_grad.dtype}." + assert existing_bias_grad.stride() == ( + K, + 1, + ), f"bias_grad must be row-major with stride (K, 1) = ({K}, 1), got {existing_bias_grad.stride()}." + + # Always zero the tensor since bias_grad represents gradients for the current + # computation and should start fresh. The kernel uses atomic_add which adds to + # existing values, so we must zero before the kernel runs. + existing_bias_grad.zero_() + + return existing_bias_grad + + else: + return gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=False) + + +def get_tgmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})." + assert out.dim() == 3, f"out must have 3 dimensions (it's {out.dim()})." + + lhs_k, lhs_m = lhs.shape + rhs_m, rhs_n = rhs.shape + G, out_k, out_n = out.shape + + assert ( + lhs_m == rhs_m + ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})." + M = lhs_m + assert ( + lhs_k == out_k + ), f"K dimension of lhs and out don't match (lhs = {lhs_k}, rhs = {out_k})." + K = lhs_k + assert ( + rhs_n == out_n + ), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})." + N = rhs_n + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + is_lhs_row_major = lhs.stride() == (M, 1) + is_lhs_col_major = lhs.stride() == (1, K) + assert ( + is_lhs_row_major != is_lhs_col_major + ), "lhs must be row-major or column-major." + is_rhs_row_major = rhs.stride() == (N, 1) + assert is_rhs_row_major, "rhs must be row-major." + is_out_row_major = out.stride() == (K * N, N, 1) + assert is_out_row_major, "out must be row-major." + + # Get lhs leading dimension according to transposition configuration. + ld_lhs = M if is_lhs_row_major else K + + return is_lhs_col_major, ld_lhs diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/logger.py b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..391ddf9b6543f5244e7f4932c8568d60748e15cd --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/logger.py @@ -0,0 +1,47 @@ +import os +import logging + + +# AITER Triton Logger which is singleton object around python logging. +# Note: Python logging is also a singleton object, but we want to read the +# env var AITER_LOG_LEVEL once at the beginning. Another alternative is to do +# this in __init__.py. In fact, that's how CK logger is setup. We can look at +# switching to that at some point +# +# AITER_LOG_LEVEL follows python logging levels +# DEBUG +# INFO +# WARNING +# ERROR +# CRITICAL +# +class AiterTritonLogger(object): + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(AiterTritonLogger, cls).__new__(cls) + log_level_str = os.getenv("AITER_TRITON_LOG_LEVEL", "WARNING").upper() + numeric_level = getattr(logging, log_level_str, logging.WARNING) + cls._instance._logger = logging.getLogger("AITER_TRITON") + cls._instance._logger.setLevel(numeric_level) + + return cls._instance + + def get_logger(self): + return self._logger + + def debug(self, msg): + self._logger.debug(msg) + + def info(self, msg): + self._logger.info(msg) + + def warning(self, msg): + self._logger.warning(msg) + + def error(self, msg): + self._logger.error(msg) + + def critical(self, msg): + self._logger.critical(msg) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_megablocks_cuda_ae601bb.abi3.so b/build/torch211-cxx11-cu130-x86_64-linux/_megablocks_cuda_ae601bb.abi3.so deleted file mode 100644 index 1e3363894846896e60d405193be55a068d1922ce..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu130-x86_64-linux/_megablocks_cuda_ae601bb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8f05428251fcba79071d881be47c1d2778f2fb3a068d029c7f6c4f546efa5b64 -size 10113080 diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_megablocks_cuda_f8f8b50.abi3.so b/build/torch211-cxx11-cu130-x86_64-linux/_megablocks_cuda_f8f8b50.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..547f2d66baa2523fe1fb0a08898f70a2684e651b --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_megablocks_cuda_f8f8b50.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5ef673d78d220cea71eace3a5bdb4b952444ab7b95ed15774258ad108ad40d51 +size 11769248 diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_ops.py b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py index 8dd1b7bcf680d2d32dd4ac912487118eafcee4ea..69afb8c26a3fa2691be277b0270d600d29a5865e 100644 --- a/build/torch211-cxx11-cu130-x86_64-linux/_ops.py +++ b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_cuda_ae601bb -ops = torch.ops._megablocks_cuda_ae601bb +from . import _megablocks_cuda_f8f8b50 +ops = torch.ops._megablocks_cuda_f8f8b50 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_cuda_ae601bb::{op_name}" + return f"_megablocks_cuda_f8f8b50::{op_name}" diff --git a/build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py b/build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py index 76037d8039cbfc2f0577275c78e4bc0be762592a..c7ef28ced79c830dae934177f059c1f4ddc24aad 100644 --- a/build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py +++ b/build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py @@ -2,16 +2,16 @@ # extensions. Otherwise libc10.so cannot be found. import torch -# # TODO(tgale): Wrap this in a try-block with better -# # error message and instructions for building the -# # c++ operations. -# import grouped_gemm_backend as backend +# On ROCm there is no CUTLASS grouped GEMM; dispatch to the vendored AITER +# Triton kernels instead. On CUDA we use the compiled CUTLASS `gmm` op. +_IS_ROCM = torch.version.hip is not None -# We import the backend operations from the megablocks package as -# grouped_gemm is vendored in megablocks in this repository. -# from ... import _ops as backend -# from megablocks._ops import ops as backend # type: ignore -from .._ops import ops as backend # type: ignore +if _IS_ROCM: + from .._grouped_gemm_triton import adapter as backend +else: + # We import the backend operations from the megablocks package as + # grouped_gemm is vendored in megablocks in this repository. + from .._ops import ops as backend # type: ignore def _allocate_output(a, b, batch_sizes, trans_a, trans_b): assert not (trans_a and trans_b) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/metadata.json b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json index dae1319c841f27d4cd7a5a4b31fbde6ae4d4cacd..436ad3fc85ff69b069290830671db574d1045671 100644 --- a/build/torch211-cxx11-cu130-x86_64-linux/metadata.json +++ b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json @@ -1,6 +1,6 @@ { "name": "megablocks", - "id": "_megablocks_cuda_ae601bb", + "id": "_megablocks_cuda_f8f8b50", "version": 1, "license": "Apache-2.0", "python-depends": [], @@ -8,7 +8,9 @@ "type": "cuda", "archs": [ "10.0", + "11.0", "12.0", + "12.0+PTX", "7.5", "8.0", "8.6", diff --git a/build/torch212-cxx11-cu126-x86_64-linux/__init__.py b/build/torch212-cxx11-cu126-x86_64-linux/__init__.py index 38075732c6d8fa0e1e6ef493145e1aca3851ae6b..0766d7b8da4f97baca212177b4bb989bc6374bf8 100644 --- a/build/torch212-cxx11-cu126-x86_64-linux/__init__.py +++ b/build/torch212-cxx11-cu126-x86_64-linux/__init__.py @@ -3,7 +3,9 @@ import torch -from ._ops import ops +# Stable alias: bare `ops` is shadowed by `from . import layers` below. +from ._ops import ops as _compiled_ops +from . import ops from .grouped_gemm import backend as gg_backend from .grouped_gemm import ops as gg_ops @@ -136,7 +138,8 @@ def sort( Returns: The sorted values tensor """ - return ops.sort(x, end_bit, x_out, iota_out) + _compiled_ops.sort(x, end_bit, x_out, iota_out) + return x_out # Convenience functions for common use cases diff --git a/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/__init__.py b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py new file mode 100644 index 0000000000000000000000000000000000000000..8c101d07cea416f9390b708e5a35fdc466e48aed --- /dev/null +++ b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py @@ -0,0 +1,574 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + + +# Imports. +# ------------------------------------------------------------------------------ + +# Python standard library +import functools + +# Triton +import triton +import triton.language as tl + +# AITER +from ..configs import CONFIGS as _CONFIGS +from ..utils._triton import arch_info +from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd + +# Kernel config. +# ------------------------------------------------------------------------------ + + +@functools.lru_cache() +def get_config( + gmm_type: str, M: int, K: int, N: int, G: int, accumulate: bool = False +) -> dict[str, int]: + assert gmm_type in { + "gmm", + "ptgmm", + "nptgmm", + }, f"'{gmm_type}' is an invalid GMM variant." + dev = arch_info.get_arch() + assert ( + dev in _CONFIGS + ), f"No GMM configuration tuned for arch '{dev}'. Supported: {sorted(_CONFIGS)}." + arch_configs = _CONFIGS[dev] + assert ( + "default" in arch_configs[gmm_type] + ), "Default configuration is absent." + key = "accumulate" if accumulate else "default" + return arch_configs[gmm_type][key] + + +# Common code shared by GMM and TGMM kernels. +# ------------------------------------------------------------------------------ + + +# XCD remapping followed by 1D PID to 2D grid mapping. +@triton.jit +def _remap_xcd_tile_grid( + tile_in_mm, + num_row_tiles, + num_col_tiles, + GROUP_SIZE: tl.constexpr = 1, + NUM_XCDS: tl.constexpr = 8, +): + return pid_grid( + remap_xcd(tile_in_mm, num_row_tiles * num_col_tiles, NUM_XCDS=NUM_XCDS), + num_row_tiles, + num_col_tiles, + GROUP_SIZE_M=GROUP_SIZE, + ) + + +# GMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.heuristics( + { + "K_DIVISIBLE_BY_BLOCK_SIZE_K": lambda META: META["K"] % META["BLOCK_SIZE_K"] + == 0, + } +) +@triton.jit +def gmm_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_RHS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + K_DIVISIBLE_BY_BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE: tl.constexpr, + GRID_DIM: tl.constexpr, + USE_BIAS: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + # Current tile. Each program computes multiple tiles of each group. + tile = tl.program_id(0) + tl.device_assert(tile >= 0, "tile < 0 (at initialization)") + + # Tile limit of last MM problem (inclusive). + last_mm_tile = 0 + + # Last input row of lhs and output row of out. Each group reads some rows of + # lhs and writes some rows to out. + last_m = 0 + + # Loop through all (m, K, N) MM problems: + # (m, K) x (K, N) = (m, N) + # sum(m) = M + for g in range(G): + # Get m dimension of current MM problem. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty + tl.device_assert(m >= 0, "m < 0") + + num_m_tiles = tl.cdiv(m, BLOCK_SIZE_M) + # num_m_tiles can be zero if group is empty + tl.device_assert(num_m_tiles >= 0, "num_m_tiles < 0") + + num_tiles = num_m_tiles * num_n_tiles + # num_tiles can be zero if group is empty + tl.device_assert(num_tiles >= 0, "num_tiles < 0") + + # Loop through tiles of current MM problem. + while tile >= last_mm_tile and tile < last_mm_tile + num_tiles: + # Figure out tile coordinates in current MM problem. + tile_in_mm = tile - last_mm_tile + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_m, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_m_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + # Do regular MM: + + tl.device_assert(tile_m * BLOCK_SIZE_M >= 0, "tile_m * BLOCK_SIZE_M < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_m = ( + tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + ) % m + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + + lhs_ptrs = lhs_ptr + (last_m + offs_lhs_m[:, None]) * K + offs_k[None, :] + + if TRANS_RHS: + rhs_ptrs = ( + rhs_ptr + + g.to(tl.int64) * K * N + + offs_k[:, None] + + offs_rhs_n[None, :] * K + ) + else: + rhs_ptrs = ( + rhs_ptr + + g.to(tl.int64) * K * N + + offs_k[:, None] * N + + offs_rhs_n[None, :] + ) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if K_DIVISIBLE_BY_BLOCK_SIZE_K: + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + else: + k_mask_limit = K - k * BLOCK_SIZE_K + lhs = tl.load( + lhs_ptrs, mask=offs_k[None, :] < k_mask_limit, other=0 + ) + rhs = tl.load( + rhs_ptrs, mask=offs_k[:, None] < k_mask_limit, other=0 + ) + + acc = tl.dot(lhs, rhs, acc=acc) + + lhs_ptrs += BLOCK_SIZE_K + + if TRANS_RHS: + rhs_ptrs += BLOCK_SIZE_K + else: + rhs_ptrs += BLOCK_SIZE_K * N + + # Add bias if enabled + if USE_BIAS: + offs_bias_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange( + 0, BLOCK_SIZE_N + ) + bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bias_n + bias = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0) + # Convert bias to float32 to match accumulator precision + bias = bias.to(tl.float32) + # Broadcast bias across M dimension and add in float32 + acc += bias[None, :] + + # Convert to output dtype after all computations + acc = acc.to(out_ptr.type.element_ty) + + offs_out_m = tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + (last_m + offs_out_m[:, None]) * N + offs_out_n[None, :] + ) + + tl.store( + out_ptrs, + acc, + mask=(offs_out_m[:, None] < m) & (offs_out_n[None, :] < N), + ) + + # Go to the next tile by advancing number of programs. + tile += GRID_DIM + tl.device_assert(tile > 0, "tile <= 0 (at update)") + + # Get ready to go to the next MM problem. + + last_mm_tile += num_tiles + # last_mm_tile can be zero if group 0 is skipped + tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)") + + last_m += m + # last_m can be zero if group 0 is skipped + tl.device_assert(last_m >= 0, "last_m < 0 (at update)") + tl.device_assert(last_m <= M, "last_m > M (at update)") + + +# Persistent TGMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.jit +def tgmm_persistent_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_grad_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_LHS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE: tl.constexpr, + GRID_DIM: tl.constexpr, + COMPUTE_BIAS_GRAD: tl.constexpr, + ACCUMULATE: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0") + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + num_tiles = num_k_tiles * num_n_tiles + tl.device_assert(num_tiles > 0, "num_tiles <= 0") + + # Current tile. Each program computes multiple tiles of each group. + tile = tl.program_id(0) + tl.device_assert(tile >= 0, "tile < 0 (at initialization)") + + # Tile limit of last MM problem (inclusive). + last_mm_tile = 0 + + # Last input column of lhs and input row of rhs. Each group reads some + # columns of lhs and some rows of rhs. + last_m = 0 + + # Loop through all (K, m, N) MM problems: + # (K, m) x (m, N) = (K, N) + # sum(m) = M + for g in range(G): + # Get m dimension of current MM problem. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty + tl.device_assert(m >= 0, "m < 0") + + # Loop through tiles of current MM problem. + while tile >= last_mm_tile and tile < last_mm_tile + num_tiles: + # Figure out tile coordinates in current MM problem. + tile_in_mm = tile - last_mm_tile + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_k, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + # Do regular MM: + + tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + + if TRANS_LHS: + lhs_ptrs = ( + lhs_ptr + offs_lhs_k[:, None] + (last_m + offs_m[None, :]) * K + ) + else: + lhs_ptrs = ( + lhs_ptr + offs_lhs_k[:, None] * M + (last_m + offs_m[None, :]) + ) + + rhs_ptrs = rhs_ptr + (last_m + offs_m[:, None]) * N + offs_rhs_n[None, :] + + loop_m = tl.cdiv(m, BLOCK_SIZE_M) + m_divisible_by_block_m = m % BLOCK_SIZE_M == 0 + if not m_divisible_by_block_m: + loop_m -= 1 + + acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32) + + # Initialize bias accumulator + bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32) + + for _ in range(0, loop_m): + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + + acc = tl.dot(lhs, rhs, acc=acc) + + # Accumulate for bias gradient: sum lhs across M dimension + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum( + lhs, axis=1 + ) # Sum across M dimension [K, M] -> [K] + + if TRANS_LHS: + lhs_ptrs += BLOCK_SIZE_M * K + else: + lhs_ptrs += BLOCK_SIZE_M + + rhs_ptrs += BLOCK_SIZE_M * N + + if not m_divisible_by_block_m: + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0) + rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0) + acc = tl.dot(lhs, rhs, acc=acc) + + # Accumulate last chunk for bias gradient + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) + + acc = acc.to(out_ptr.type.element_ty) + + offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + + g.to(tl.int64) * K * N + + offs_out_k[:, None] * N + + offs_out_n[None, :] + ) + + mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N) + if ACCUMULATE: + # Load existing values and add to them (like beta=1 in BLAS) + old_vals = tl.load(out_ptrs, mask=mask, other=0.0) + tl.store(out_ptrs, acc + old_vals, mask=mask) + else: + # Overwrite output (like beta=0 in BLAS) + tl.store(out_ptrs, acc, mask=mask) + + # Store bias gradient (only for first N tile, sum across all M) + if COMPUTE_BIAS_GRAD and tile_n == 0: + # Keep as float32 for atomic_add (bf16 not supported for atomics) + bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k + # Use atomic add since multiple K-tiles may write to same expert's bias + tl.atomic_add( + bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed" + ) + + # Go to the next tile by advancing number of programs. + tile += GRID_DIM + tl.device_assert(tile > 0, "tile <= 0 (at update)") + + # Get ready to go to the next MM problem. + + last_mm_tile += num_tiles + # last_mm_tile can be zero if group 0 is skipped + tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)") + + last_m += m + # last_m can be zero if group 0 is skipped + tl.device_assert(last_m >= 0, "last_m < 0 (at update)") + tl.device_assert(last_m <= M, "last_m > M (at update)") + + +# Regular non-persistent TGMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.heuristics({"BLOCK_SIZE_G": lambda META: triton.next_power_of_2(META["G"])}) +@triton.jit +def tgmm_non_persistent_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_grad_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_LHS: tl.constexpr, + BLOCK_SIZE_G: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE: tl.constexpr, + COMPUTE_BIAS_GRAD: tl.constexpr, + ACCUMULATE: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + # Get group ID from grid. + g = tl.program_id(0) + tl.device_assert(g >= 0, "g < 0") + tl.device_assert(g < G, "g >= G") + + # Get m dimension of current MM group. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty. + tl.device_assert(m >= 0, "m < 0") + + # Skip empty groups. + if m == 0: + return + + # Compute sum(group_sizes) until current group g. + # It's the starting column of lhs and starting row of rhs. + offs_g = tl.arange(0, BLOCK_SIZE_G) + group_sizes = tl.load(group_sizes_ptr + offs_g, mask=offs_g < g, other=0) + start_m = tl.sum(group_sizes) + + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0") + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + # Get MM tile from grid. + tile_in_mm = tl.program_id(1) + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_k, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_k = (tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K + offs_rhs_n = (tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + + if TRANS_LHS: + lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] + (start_m + offs_m[None, :]) * K + else: + lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] * M + (start_m + offs_m[None, :]) + + rhs_ptrs = rhs_ptr + (start_m + offs_m[:, None]) * N + offs_rhs_n[None, :] + + loop_m = tl.cdiv(m, BLOCK_SIZE_M) + m_divisible_by_block_m = m % BLOCK_SIZE_M == 0 + if not m_divisible_by_block_m: + loop_m -= 1 + + acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32) + # Initialize bias accumulator + bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32) + + for _ in range(0, loop_m): + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + + acc = tl.dot(lhs, rhs, acc=acc) + + # Accumulate for bias gradient: sum lhs across M dimension + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) # [K, M] -> [K] + + if TRANS_LHS: + lhs_ptrs += BLOCK_SIZE_M * K + else: + lhs_ptrs += BLOCK_SIZE_M + + rhs_ptrs += BLOCK_SIZE_M * N + + if not m_divisible_by_block_m: + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0) + rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0) + acc = tl.dot(lhs, rhs, acc=acc) + # Accumulate last chunk for bias gradient + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) + + acc = acc.to(out_ptr.type.element_ty) + + offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + g.to(tl.int64) * K * N + offs_out_k[:, None] * N + offs_out_n[None, :] + ) + + mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N) + if ACCUMULATE: + # Load existing values and add to them (like beta=1 in BLAS) + old_vals = tl.load(out_ptrs, mask=mask, other=0.0) + tl.store(out_ptrs, acc + old_vals, mask=mask) + else: + # Overwrite output (like beta=0 in BLAS) + tl.store(out_ptrs, acc, mask=mask) + + # Store bias gradient (only for first N tile, sum across all M) + if COMPUTE_BIAS_GRAD and tile_n == 0: + # Keep as float32 for atomic_add (bf16/fp16 not supported for atomics) + bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k + # Use atomic add since multiple K-tiles may write to same expert's bias + tl.atomic_add(bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed") diff --git a/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/adapter.py b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..98c224244f27445384e0c2377d73516406927536 --- /dev/null +++ b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/adapter.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Adapt AITER's Triton grouped GEMM to MegaBlocks' ``gmm`` calling convention. + +MegaBlocks (following tgale96/grouped_gemm) uses a single ``gmm`` entry point +with ``trans_a`` / ``trans_b`` flags: + +* ``trans_a=False, trans_b=False``: a(M,K) @ b(G,K,N) -> c(M,N) +* ``trans_a=False, trans_b=True`` : a(M,K) @ b(G,N,K)^T -> c(M,N) (dgrad) +* ``trans_a=True`` : a(M,K)^T @ b(M,N) per group -> c(G,K,N) (wgrad) + +AITER exposes these as two kernels: ``gmm`` ((M,K)@(G,K,N)->(M,N), transposition +of the 3D operand inferred from strides) and ``ptgmm`` ((K,M)@(M,N)->(G,K,N), +transposition of the 2D operand inferred from strides). +""" + +import torch + +from .gmm import gmm as _aiter_gmm +from .gmm import ptgmm as _aiter_ptgmm + + +def gmm(a, b, c, batch_sizes, trans_a=False, trans_b=False): + # AITER requires group sizes to be int32 and to live on the compute device. + group_sizes = batch_sizes.to(device=a.device, dtype=torch.int32) + + # AITER asserts exact strides: gmm wants lhs/rhs row-major (a transposed + # 3D operand must be exactly column-major), tgmm wants rhs row-major and + # lhs row/column-major. Make operands contiguous first so the transposed + # views have the precise strides the kernels expect. `.contiguous()` is a + # no-op when the tensor is already contiguous. + if trans_a: + # Weight gradient: a(M,K), b(M,N) -> c(G,K,N). + # Pass a transposed so AITER sees lhs(K,M) column-major (TRANS_LHS). + _aiter_ptgmm( + a.contiguous().transpose(0, 1), + b.contiguous(), + group_sizes, + preferred_element_type=c.dtype, + existing_out=c, + ) + else: + # trans_b contracts b's last dim: pass a column-major (G,K,N) view. + rhs = b.contiguous() + if trans_b: + rhs = rhs.transpose(1, 2) + _aiter_gmm( + a.contiguous(), + rhs, + group_sizes, + preferred_element_type=c.dtype, + existing_out=c, + ) + return c diff --git a/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/configs.py b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..9a4fe5617d8100869aa76dba9b7d22c7bcab814f --- /dev/null +++ b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/configs.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: MIT +# Tuned GMM configs vendored from ROCm/aiter (aiter/ops/triton/configs/). +# Inlined as a Python module so packaging always includes them. + +CONFIGS = {'gfx1250': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}, 'gfx942': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}, 'gfx950': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}} diff --git a/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/gmm.py b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/gmm.py new file mode 100644 index 0000000000000000000000000000000000000000..e30c9326c6d4e4836d1303e2761ea2440a7f4750 --- /dev/null +++ b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/gmm.py @@ -0,0 +1,567 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. + + +# Imports. +# ------------------------------------------------------------------------------ + +# PyTorch +import torch +from torch import Tensor + +# Triton +import triton + +# AITER: GMM utility functions +from .utils.gmm_common import ( + DTYPE, + is_power_of_2, + check_input_device_dtype, + check_bias_shape_stride, + get_gmm_shape, + get_gmm_output, + get_gmm_transposition, + get_tgmm_shape, + get_tgmm_output, + get_tgmm_bias_grad, + get_tgmm_transposition, +) + +# AITER: GMM Triton kernels +from ._triton_kernels.gmm import ( + gmm_kernel, + tgmm_persistent_kernel, + tgmm_non_persistent_kernel, + get_config, +) + +# GMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _gmm_grid( + N: int, + block_size_m: int, + block_size_n: int, + group_sizes: Tensor, + grid_dim: int, +) -> tuple[int]: + assert N > 0, f"N must be positive, it's {N}." + assert is_power_of_2( + block_size_m + ), f"M-dimension tile size must be a power of 2 (it's {block_size_m})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + assert torch.all(group_sizes >= 0).item(), "All group_sizes must be non-negative." + assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})." + num_m_tiles = (group_sizes + block_size_m - 1) // block_size_m + assert torch.all(num_m_tiles >= 0).item(), "All num_m_tiles must be non-negative." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles = torch.sum(num_m_tiles * num_n_tiles).item() + assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}." + num_programs = int(min(grid_dim, num_tiles)) + assert num_programs > 0, f"num_programs must be positive, it's {num_programs}." + return (num_programs,) + + +def gmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias: Tensor | None = None, +) -> Tensor: + """ + Perform Group Matrix Multiplication (GMM): out = lhs @ rhs + bias + + lhs rows are divided into G groups. Each group of lhs rows is matrix multiplied with a plane of + rhs 3D tensor and then stored in a slice of out. In PyTorch parlance, it can be implemented as + follows for a given group g: + out[group_start:group_end, :] = lhs[group_start:group_end, :] @ rhs[g] + bias[g] + + The size of each group, and their respective start and end positions are specified by + group_sizes tensor. For instance, suppose that group_sizes = [3, 2, 4, 1]. In this particular + case we have 4 groups. The 1st group starts at 0 and ends at 2, the second group starts at 3 and + ends at 4, the third group starts at 5 and ends at 8, and the fourth and final group consists of + just the 10th (last) row of lhs. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (M, K). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 3D input tensor. Shape: (G, K, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (M, N), its data type must match preferred_element_type + and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias : torch.Tensor or None, optional + Optional bias tensor. Shape: (G, N). + If provided, bias data type must match lhs and rhs data type, and bias must be on the same + device as other input tensors. Each group g adds bias[g] to the output. + + Returns + ------- + torch.Tensor + The computed output 2D tensor. Shape: (M, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - GMM is implemented with a persistent Triton kernel. + - lhs must be row-major (lhs.stride() == (K, 1)). + - rhs can be row-major (rhs.stride() == (K * N, N, 1)) or column-major (rhs.stride() == + (K * N, 1, K)). If rhs is row-major then kernel parameter TRANS_RHS == False, this is useful + for implementing forward pass. If rhs is column-major then kernel parameter TRANS_RHS == True, + this is useful for computing the lhs derivative in the backward pass, while fusing the + transposition. + - out must be row-major (out.stride() == (N, 1)). + - bias must be row-major (bias.stride() == (N, 1)) if provided. + """ + use_bias = bias is not None + check_input_device_dtype(lhs, rhs, group_sizes, bias) + + M, K, N, G = get_gmm_shape(lhs, rhs, group_sizes) + + if use_bias: + check_bias_shape_stride(bias, G, N) + + out = get_gmm_output( + M, + N, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_rhs, _ = get_gmm_transposition(lhs, rhs, out) + + if config is None: + config = get_config("gmm", M, K, N, G) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + "GRID_DIM", + } + ), "Invalid GMM kernel config." + + grid = _gmm_grid( + N, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + group_sizes, + config["GRID_DIM"], + ) + + # fmt: off + gmm_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_RHS=trans_rhs, + USE_BIAS=use_bias, + **config, + ) + # fmt: on + + return out + + +# Persistent TGMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _ptgmm_grid( + K: int, + N: int, + G: int, + block_size_k: int, + block_size_n: int, + grid_dim: int, +) -> tuple[int]: + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}." + assert G > 0, f"G must be positive, it's {G}." + assert is_power_of_2( + block_size_k + ), f"K-dimension tile size must be a power of 2 (it's {block_size_k})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})." + num_k_tiles = triton.cdiv(K, block_size_k) + assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles = G * num_k_tiles * num_n_tiles + assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}." + num_programs = min(grid_dim, num_tiles) + assert num_programs > 0, f"num_programs must be positive, it's {num_programs}." + return (num_programs,) + + +def ptgmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias_grad: Tensor | None = None, + accumulate: bool = False, +) -> Tensor: + """ + Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs + + lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with + the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch + parlance, it can be implemented as follows for a given group g: + out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :] + + The 't' in the operator name derives from MaxText implementation + (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py), + which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor + shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N). + + The 'p' in the operator name means that it is implemented with a persistent kernel. There is + also the non-persistent variation, which is implemented with a regular kernel. Please take a + look at nptgmm operator. Both ptgmm and nptgmm implement the same computation, choosing one or + the other is a matter of performance for the target workload. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (K, M). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 2D input tensor. Shape: (M, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (G, K, N), its data type must match + preferred_element_type and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias_grad : torch.Tensor or None, optional + Optional bias gradient output tensor. Shape: (G, K). + If provided, the kernel will compute the bias gradient and write it to this tensor. + bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), + accumulate : bool, optional + Whether to accumulate into existing output tensor values. Default is False. + If False, output will be overwritten with fresh computation. + If True, results will be added to existing output tensor values. + + Returns + ------- + torch.Tensor + The computed output 3D tensor. Shape: (G, K, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - PTGMM is implemented with a persistent Triton kernel. + - lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs + is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel + parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward + pass, while fusing the transposition. + - rhs must be row-major (rhs.stride() == (N, 1)). + - out must be row-major (out.stride() == (K * N, N, 1)). + """ + check_input_device_dtype(lhs, rhs, group_sizes) + + M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes) + + out = get_tgmm_output( + K, + N, + G, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out) + + if config is None: + config = get_config("ptgmm", M, K, N, G, accumulate) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + "GRID_DIM", + } + ), "Invalid PTGMM kernel config." + + # Bias gradient handling. + # ----------------------- + # Get or validate bias gradient tensor. + compute_bias_grad = bias_grad is not None + bias_grad_ptr = get_tgmm_bias_grad( + K, + G, + device=lhs.device, + existing_bias_grad=bias_grad, + ) + + grid = _ptgmm_grid( + K, + N, + G, + config["BLOCK_SIZE_K"], + config["BLOCK_SIZE_N"], + config["GRID_DIM"], + ) + + # fmt: off + tgmm_persistent_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias_grad_ptr, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_LHS=trans_lhs, + COMPUTE_BIAS_GRAD=compute_bias_grad, + ACCUMULATE=accumulate, + **config, + ) + # fmt: on + + return out + + +# Regular non-persistent TGMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _nptgmm_grid( + K: int, + N: int, + G: int, + block_size_k: int, + block_size_n: int, +) -> tuple[int, int]: + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}." + assert G > 0, f"G must be positive, it's {G}." + assert is_power_of_2( + block_size_k + ), f"K-dimension tile size must be a power of 2 (it's {block_size_k})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + num_k_tiles = triton.cdiv(K, block_size_k) + assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles_per_mm = num_k_tiles * num_n_tiles + assert ( + num_tiles_per_mm > 0 + ), f"num_tiles_per_mm must be positive, it's {num_tiles_per_mm}." + return (G, num_tiles_per_mm) + + +def nptgmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias_grad: Tensor | None = None, + accumulate: bool = False, +) -> Tensor: + """ + Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs + + lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with + the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch + parlance, it can be implemented as follows for a given group g: + out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :] + + The 't' in the operator name derives from MaxText implementation + (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py), + which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor + shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N). + + The 'np' in the operator name means that it is implemented with a non-persistent, i.e. regular + kernel. There is also the persistent variation, which is implemented with a persistent kernel. + Please take a look at ptgmm operator. Both nptgmm and ptgmm implement the same computation, + choosing one or the other is a matter of performance for the target workload. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (K, M). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 2D input tensor. Shape: (M, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (G, K, N), its data type must match + preferred_element_type and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias_grad : torch.Tensor or None, optional + Optional bias gradient output tensor. Shape: (G, K). + If provided, the kernel will compute the bias gradient and write it to this tensor. + bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), + accumulate : bool, optional + Whether to accumulate into existing output tensor values. Default is False. + If False, output will be overwritten with fresh computation. + If True, results will be added to existing output tensor values. + + Returns + ------- + torch.Tensor + The computed output 3D tensor. Shape: (G, K, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - NPTGMM is implemented with a non-persistent regular Triton kernel. + - lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs + is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel + parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward + pass, while fusing the transposition. + - rhs must be row-major (rhs.stride() == (N, 1)). + - out must be row-major (out.stride() == (K * N, N, 1)). + """ + check_input_device_dtype(lhs, rhs, group_sizes) + + M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes) + + out = get_tgmm_output( + K, + N, + G, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out) + + # Bias gradient handling. + # ----------------------- + # Get or validate bias gradient tensor. + compute_bias_grad = bias_grad is not None + bias_grad_ptr = get_tgmm_bias_grad( + K, + G, + device=lhs.device, + existing_bias_grad=bias_grad, + ) + + if config is None: + config = get_config("nptgmm", M, K, N, G, accumulate) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + } + ), "Invalid NPTGMM kernel config." + + grid = _nptgmm_grid( + K, + N, + G, + config["BLOCK_SIZE_K"], + config["BLOCK_SIZE_N"], + ) + + # fmt: off + tgmm_non_persistent_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias_grad_ptr, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_LHS=trans_lhs, + COMPUTE_BIAS_GRAD=compute_bias_grad, + ACCUMULATE=accumulate, + **config, + ) + # fmt: on + + return out diff --git a/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/__init__.py b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py new file mode 100644 index 0000000000000000000000000000000000000000..3f6c88581a64044518125623f116082c53bd5474 --- /dev/null +++ b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py @@ -0,0 +1,46 @@ +import triton + +# Detect the GPU arch lazily: querying the triton driver at import time fails +# in headless environments (e.g. the kernel-builder ABI check sandbox has no +# GPU), and the original JAX fallback pulled in an unrelated runtime dep. The +# arch is only actually needed when a GMM kernel is dispatched, so resolve and +# cache on first call. +_CACHED_ARCH = None + + +def get_arch(): + global _CACHED_ARCH + if _CACHED_ARCH is not None: + return _CACHED_ARCH + try: + _CACHED_ARCH = triton.runtime.driver.active.get_current_target().arch + except RuntimeError: + try: + from jax._src.lib import gpu_triton as triton_kernel_call_lib + _CACHED_ARCH = triton_kernel_call_lib.get_arch_details("0").split(":")[0] + except ImportError as e: + raise RuntimeError( + "Cannot determine GPU arch: triton driver is inactive and " + "JAX is not available. A GPU is required for grouped GEMM." + ) from e + return _CACHED_ARCH + + +def is_gluon_avail(): + return get_arch() in ("gfx950", "gfx1250") + + +def is_fp4_avail(): + return get_arch() in ("gfx950", "gfx1250") + + +def is_fp8_avail(): + return get_arch() in ("gfx942", "gfx950", "gfx1250", "gfx1200", "gfx1201") + + +def is_mx_scale_preshuffling_avail(): + return get_arch() in ("gfx950", "gfx1250") + + +def is_tdm_avail(): + return get_arch() in ("gfx1250",) diff --git a/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..99792bb3ba2fab8fff223bba733ced1eb6e6df53 --- /dev/null +++ b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: MIT + +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl + + +@triton.jit +def remap_xcd_chunked( + pid, GRID_MN, NUM_XCDS: tl.constexpr = 8, CHUNK_SIZE: tl.constexpr = 2 +): + # Compute current XCD and local PID + xcd = pid % NUM_XCDS + # distribute the modulo pids in round robin + if pid > (GRID_MN // (NUM_XCDS * CHUNK_SIZE)) * (NUM_XCDS * CHUNK_SIZE): + return pid + local_pid = pid // NUM_XCDS + # Calculate chunk index and position within chunk + chunk_idx = local_pid // CHUNK_SIZE + pos_in_chunk = local_pid % CHUNK_SIZE + # Calculate new PID + new_pid = chunk_idx * NUM_XCDS * CHUNK_SIZE + xcd * CHUNK_SIZE + pos_in_chunk + return new_pid + + +@triton.jit +def remap_xcd(pid, GRID_MN, NUM_XCDS: tl.constexpr = 8): + ## pid remapping on xcds + # Number of pids per XCD in the new arrangement + pids_per_xcd = (GRID_MN + NUM_XCDS - 1) // NUM_XCDS + # When GRID_MN cannot divide NUM_XCDS, some xcds will have + # pids_per_xcd pids, the other will have pids_per_xcd - 1 pids. + # We calculate the number of xcds that have pids_per_xcd pids as + # tall_xcds + tall_xcds = GRID_MN % NUM_XCDS + tall_xcds = NUM_XCDS if tall_xcds == 0 else tall_xcds + # Compute current XCD and local pid within the XCD + xcd = pid % NUM_XCDS + local_pid = pid // NUM_XCDS + # Calculate new pid based on the new grouping + # Note that we need to consider the following two cases: + # 1. the current pid is on a tall xcd + # 2. the current pid is on a short xcd + if xcd < tall_xcds: + pid = xcd * pids_per_xcd + local_pid + else: + pid = ( + tall_xcds * pids_per_xcd + + (xcd - tall_xcds) * (pids_per_xcd - 1) + + local_pid + ) + + return pid + + +@triton.jit +def pid_grid(pid: int, num_pid_m: int, num_pid_n: int, GROUP_SIZE_M: tl.constexpr = 1): + """ + Maps 1D pid to 2D grid coords (pid_m, pid_n). + + Args: + - pid: 1D pid + - num_pid_m: grid m size + - num_pid_n: grid n size + - GROUP_SIZE_M: tl.constexpr: default is 1 + """ + if GROUP_SIZE_M == 1: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + tl.assume(group_size_m >= 0) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + return pid_m, pid_n + + +@triton.jit +def pid_grid_3d(pid: int, num_pid_m: int, num_pid_n: int, num_pid_k): + """ + Maps 1D pid to 3D grid coords (pid_m, pid_n, pid_k). + Args: + - pid: 1D pid + - num_pid_m: grid m size + - num_pid_n: grid n size + - num_pid_k: grid k size + + Returns: + - pid_m, pid_n, pid_k: 3D grid coordinates + """ + pid_m = pid % num_pid_m + pid_n = (pid // num_pid_m) % num_pid_n + pid_k = pid // (num_pid_m * num_pid_n) % num_pid_k + + return pid_m, pid_n, pid_k diff --git a/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py new file mode 100644 index 0000000000000000000000000000000000000000..153dee65b50ab5f833262481889d2184d1ca639f --- /dev/null +++ b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py @@ -0,0 +1,752 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. + +# Imports. +# ------------------------------------------------------------------------------ + +# PyTorch +import torch +from torch import Tensor + +# AITER: logging +from .logger import AiterTritonLogger + +_LOGGER: AiterTritonLogger = AiterTritonLogger() + + +# Supported data types. +# ------------------------------------------------------------------------------ + +# Supported data types, as strings. +SUPPORTED_DTYPES_STR: set[str] = {"fp16", "bf16"} + + +# Convert string data type to PyTorch data type. +def dtype_from_str(dtype_str: str) -> torch.dtype: + dtype_str = dtype_str.strip().lower() + dtype_str = dtype_str[1:] if dtype_str[0] in {"i", "o"} else dtype_str + assert ( + dtype_str in SUPPORTED_DTYPES_STR + ), "String data type isn't in set of supported string data types." + return {"fp16": torch.float16, "bf16": torch.bfloat16}[dtype_str] + + +# Supported data types, as PyTorch types. +SUPPORTED_DTYPES: set[torch.dtype] = { + dtype_from_str(dtype_str) for dtype_str in SUPPORTED_DTYPES_STR +} + + +# Convert PyTorch data type to string data type. +def str_from_dtype(dtype: torch.dtype) -> str: + assert ( + dtype in SUPPORTED_DTYPES + ), "PyTorch data type isn't in set of supported PyTorch data types." + return {torch.float16: "fp16", torch.bfloat16: "bf16"}[dtype] + + +# Default data type, as string. +DTYPE_STR: str = "bf16" +assert ( + DTYPE_STR in SUPPORTED_DTYPES_STR +), "Default string data type isn't in set of supported string data types." + + +# Default data type, as PyTorch type. +DTYPE: torch.dtype = dtype_from_str(DTYPE_STR) + + +# Other defaults. +# ------------------------------------------------------------------------------ + +# Default device. +DEVICE: torch.device | str = "cuda" + +# Default RNG seed for input generation. +RNG_SEED: int = 0 + +# Default number of group sizes. +NUM_GROUP_SIZES: int = 1 + +# Default transposition (NN). +TRANS_LHS: bool = False +TRANS_RHS: bool = False + + +# Parameter checking functions. +# ------------------------------------------------------------------------------ + + +def is_power_of_2(x: int) -> bool: + return (x > 0) and (x & (x - 1) == 0) + + +def check_input_device_dtype( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor, bias: Tensor | None = None +) -> None: + assert ( + lhs.device == rhs.device == group_sizes.device + ), f"All input tensors must be in the same device (lhs = {lhs.device}, rhs = {rhs.device}, group_sizes = {group_sizes.device})." + assert ( + lhs.dtype == rhs.dtype + ), f"lhs and rhs types must match (lhs = {lhs.dtype}, rhs = {rhs.dtype})." + assert group_sizes.dtype == torch.int32, "group_sizes type must be int32." + + if bias is not None: + assert ( + bias.device == lhs.device + ), f"bias must be on the same device as lhs (bias = {bias.device}, lhs = {lhs.device})." + assert ( + bias.dtype == lhs.dtype + ), f"bias dtype must match lhs dtype (bias = {bias.dtype}, lhs = {lhs.dtype})." + + +def check_bias_shape_stride(bias: Tensor, G: int, N: int) -> None: + assert bias.shape == ( + G, + N, + ), f"bias must have shape (G, N) = ({G}, {N}), got {bias.shape}." + assert bias.stride() == (N, 1), "bias must be row-major (bias.stride() == (N, 1))." + + +# Generation of group sizes. +# ------------------------------------------------------------------------------ + + +# Probabilities for generating random group sizes. +UNUSED_TOKENS_PROB: float = 0.0 +UNUSED_EXPERTS_PROB: float = 0.1 + + +def gen_uniform_group_sizes( + M: int, + G: int, + device: torch.device | str = DEVICE, +) -> Tensor: + assert M >= 0, f"Number of tokens M must be non-negative (it's {M})." + assert G > 0, f"Number of experts G must be positive (it's {G})." + + base = M // G + remainder = M % G + group_sizes = torch.full((G,), base, dtype=torch.int32, device=device) + if remainder > 0: + group_sizes[:remainder] += 1 + + assert ( + len(group_sizes) == G + ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})." + assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative." + assert ( + torch.sum(group_sizes).item() == M + ), f"Group sizes don't add up to total tokens {M}." + assert group_sizes.dtype == torch.int32, "Group sizes must be int32." + + return group_sizes + + +def gen_group_sizes( + M: int, + G: int, + device: torch.device | str = DEVICE, + rng_seed: int | None = RNG_SEED, + unused_tokens_prob: float = UNUSED_TOKENS_PROB, + unused_experts_prob: float = UNUSED_EXPERTS_PROB, +) -> Tensor: + assert M >= 0, f"Number of tokens M must be non-negative (it's {M})." + assert G > 0, f"Number of experts G must be positive (it's {G})." + assert ( + 0 <= unused_tokens_prob <= 1 + ), f"Probability of unused tokens must be in [0, 1] interval (it's {unused_tokens_prob})." + assert ( + 0 <= unused_experts_prob <= 1 + ), f"Probability of unused experts must be in [0, 1] interval (it's {unused_experts_prob})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + if unused_tokens_prob > 0: + # Optionally drop tokens to simulate routing sparsity, some tokens may not be routed. + num_unused_tokens = M + while num_unused_tokens == M: + num_unused_tokens = int( + torch.binomial( + torch.tensor(float(M), device=device), + torch.tensor(unused_tokens_prob, device=device), + ).item() + ) + else: + num_unused_tokens = 0 + num_used_tokens = M - num_unused_tokens + assert ( + num_unused_tokens >= 0 + ), f"Number of unused tokens must be non-negative (it's {num_unused_tokens})." + assert ( + num_used_tokens > 0 + ), f"Number of used tokens must be positive (it's {num_used_tokens})." + assert ( + num_used_tokens + num_unused_tokens == M + ), f"Unused + used tokens don't add up total tokens ({num_used_tokens} + {num_unused_tokens} != {M})." + + if num_unused_tokens > 0: + _LOGGER.debug( + f"Group sizes generation: dropped {num_unused_tokens} token{'s' if num_unused_tokens > 1 else ''}.", + ) + + if unused_experts_prob > 0: + # Some experts may have zero tokens assigned to them. + num_used_experts = 0 + while num_used_experts == 0: + used_experts = torch.nonzero( + torch.rand((G,), device=device) >= unused_experts_prob + ).squeeze() + num_used_experts = used_experts.numel() + else: + used_experts = torch.arange(0, G, device=device) + num_used_experts = G + num_unused_experts = G - num_used_experts + assert ( + num_unused_experts >= 0 + ), f"Number of unused experts must be non-negative (it's {num_unused_experts})." + assert ( + num_used_experts >= 1 + ), f"At least one expert must be used (it's {num_used_experts})." + assert ( + num_unused_experts + num_used_experts == G + ), f"Unused + used experts don't add up total experts ({num_unused_experts} + {num_used_experts} != {G})." + + if num_unused_experts > 0: + _LOGGER.debug( + f"Group sizes generation: dropped {num_unused_experts} expert{'s' if num_unused_experts > 1 else ''}.", + ) + + group_sizes = torch.bincount( + used_experts[ + torch.randint(low=0, high=num_used_experts, size=(num_used_tokens,)) + ], + minlength=G, + ).to(torch.int32) + + assert ( + len(group_sizes) == G + ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})." + assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative." + assert ( + torch.sum(group_sizes).item() == num_used_tokens + ), f"Group sizes don't add up to used tokens {num_used_tokens}." + assert group_sizes.dtype == torch.int32, "Group sizes must be int32." + + return group_sizes + + +def gen_multiple_group_sizes( + num_group_sizes: int, + M: int, + G: int, + device: torch.device | str = DEVICE, + rng_seed: int | None = RNG_SEED, + unused_tokens_prob: float = UNUSED_TOKENS_PROB, + unused_experts_prob: float = UNUSED_EXPERTS_PROB, + group_sizes_0: Tensor | None = None, +) -> list[Tensor]: + assert ( + num_group_sizes > 0 + ), f"Number of group sizes to be generated must be positive, it's {num_group_sizes}." + multiple_group_sizes = [ + gen_group_sizes( + M, + G, + device=device, + rng_seed=rng_seed if g == 0 else None, + unused_tokens_prob=unused_tokens_prob, + unused_experts_prob=unused_experts_prob, + ) + for g in range( + num_group_sizes if group_sizes_0 is None else num_group_sizes - 1 + ) + ] + if group_sizes_0 is not None: + multiple_group_sizes.insert(0, group_sizes_0) + assert ( + len(multiple_group_sizes) == num_group_sizes + ), f"Expecting {num_group_sizes} distinct group sizes (it's {len(multiple_group_sizes)})." + return multiple_group_sizes + + +# GMM helpers: tensor generation. +# ------------------------------------------------------------------------------ + + +def gen_gmm_input( + M: int, + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + trans_rhs: bool = TRANS_RHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, +) -> tuple[Tensor, Tensor, Tensor]: + assert M > 0, f"Number of lhs rows M must be positive (M = {M})." + assert K > 0, f"Number of lhs columns / rhs rows K must be positive (K = {K})." + assert N > 0, f"Number of rhs columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + lhs = torch.randn((M, K), dtype=torch.float32, device=device) + lhs = lhs.to(preferred_element_type) + + if trans_rhs: + rhs = torch.randn((G, N, K), dtype=torch.float32, device=device).permute( + 0, 2, 1 + ) + else: + rhs = torch.randn((G, K, N), dtype=torch.float32, device=device) + rhs = rhs.to(preferred_element_type) + + group_sizes = ( + gen_uniform_group_sizes(M, G, device=device) + if unif_group_sizes + else gen_group_sizes(M, G, device=device, rng_seed=None) + ) + + return lhs, rhs, group_sizes + + +def gen_gmm_output( + M: int, + N: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, +) -> Tensor: + assert M > 0, f"Number of out rows M must be positive (M = {M})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + + out = torch.empty((M, N), dtype=preferred_element_type, device=device) + + return out + + +def gen_gmm_tensors( + M: int, + K: int, + N: int, + G: int, + num_group_sizes: int, + device: torch.device | str = DEVICE, + input_type: torch.dtype = DTYPE, + output_type: torch.dtype = DTYPE, + trans_lhs: bool = False, + trans_rhs: bool = TRANS_RHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, + use_bias: bool = False, +) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]: + lhs, rhs, group_sizes_0 = gen_gmm_input( + M, + K, + N, + G, + device=device, + preferred_element_type=input_type, + trans_rhs=trans_rhs, + rng_seed=rng_seed, + unif_group_sizes=unif_group_sizes, + ) + multiple_group_sizes = gen_multiple_group_sizes( + num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0 + ) + out = gen_gmm_output(M, N, device=device, preferred_element_type=output_type) + bias = None + if use_bias: + torch.manual_seed(rng_seed + 1000) # Different seed for bias + bias = torch.randn(G, N, dtype=input_type, device=device) + + return lhs, rhs, multiple_group_sizes, out, bias + + +# GMM helpers: get information from tensors. +# ------------------------------------------------------------------------------ + + +def get_gmm_shape( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor +) -> tuple[int, int, int, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})." + assert ( + group_sizes.dim() == 1 + ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})." + + M, lhs_k = lhs.shape + rhs_g, rhs_k, N = rhs.shape + group_sizes_g = group_sizes.shape[0] + + assert ( + lhs_k == rhs_k + ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})." + K = lhs_k + assert ( + rhs_g == group_sizes_g + ), f"G dimension of rhs and group_sizes don't match (rhs = {rhs_g}, group_sizes = {group_sizes_g})." + G = rhs_g + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + return M, K, N, G + + +def get_gmm_output( + M: int, + N: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, +) -> Tensor: + assert M > 0, f"Number of out rows M must be positive (M = {M})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + + if existing_out is not None: + assert ( + existing_out.device == device + ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})." + assert ( + existing_out.dtype == preferred_element_type + ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})." + assert existing_out.shape == ( + M, + N, + ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(M, N)})." + return existing_out + + return gen_gmm_output( + M, + N, + device=device, + preferred_element_type=preferred_element_type, + ) + + +def get_gmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})." + assert out.dim() == 2, f"out must have 2 dimensions (it's {out.dim()})." + + lhs_m, lhs_k = lhs.shape + G, rhs_k, rhs_n = rhs.shape + out_m, out_n = out.shape + + assert ( + lhs_m == out_m + ), f"M dimension of lhs and out don't match (lhs = {lhs_m}, rhs = {out_m})." + M = lhs_m + assert ( + lhs_k == rhs_k + ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})." + K = lhs_k + assert ( + rhs_n == out_n + ), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})." + N = rhs_n + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + is_lhs_row_major = lhs.stride() == (K, 1) + assert is_lhs_row_major, "lhs must be row-major." + is_rhs_row_major = rhs.stride() == (K * N, N, 1) + is_rhs_col_major = rhs.stride() == (K * N, 1, K) + assert ( + is_rhs_row_major != is_rhs_col_major + ), "rhs must be row-major or column-major." + is_out_row_major = out.stride() == (N, 1) + assert is_out_row_major, "out must be row-major." + + # Get rhs leading dimension according to transposition configuration. + ld_rhs = N if is_rhs_row_major else K + + return is_rhs_col_major, ld_rhs + + +# TGMM helpers: tensor generation. +# ------------------------------------------------------------------------------ + + +def gen_tgmm_input( + M: int, + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + trans_lhs: bool = TRANS_LHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, +) -> tuple[Tensor, Tensor, Tensor]: + assert K > 0, f"Number of lhs rows K must be positive (M = {K})." + assert M > 0, f"Number of lhs columns / rhs rows M must be positive (K = {M})." + assert N > 0, f"Number of rhs columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + if trans_lhs: + lhs = torch.randn((M, K), dtype=torch.float32, device=device).T + else: + lhs = torch.randn((K, M), dtype=torch.float32, device=device) + lhs = lhs.to(preferred_element_type) + + rhs = torch.randn((M, N), dtype=torch.float32, device=device) + rhs = rhs.to(preferred_element_type) + + group_sizes = ( + gen_uniform_group_sizes(M, G, device=device) + if unif_group_sizes + else gen_group_sizes(M, G, device=device, rng_seed=None) + ) + + return lhs, rhs, group_sizes + + +def gen_tgmm_output( + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, +) -> Tensor: + assert K > 0, f"Number of out rows K must be positive (K = {K})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + out = torch.empty((G, K, N), dtype=preferred_element_type, device=device) + + return out + + +def gen_tgmm_bias_grad( + K: int, + G: int, + device: torch.device | str = DEVICE, + with_bias_grad: bool = False, +) -> Tensor: + if with_bias_grad: + assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + return torch.empty((G, K), device=device, dtype=torch.float32) + else: + # Return dummy pointer when bias_grad is not needed. + # Must be float32 because atomic_add does not support bf16/fp16, + # and Triton validates the pointer dtype even in dead branches. + return torch.tensor([], device=device, dtype=torch.float32) + + +def gen_tgmm_tensors( + M: int, + K: int, + N: int, + G: int, + num_group_sizes: int, + device: torch.device | str = DEVICE, + input_type: torch.dtype = DTYPE, + output_type: torch.dtype = DTYPE, + trans_lhs: bool = TRANS_LHS, + trans_rhs: bool = False, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, + use_bias: bool = False, +) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]: + lhs, rhs, group_sizes_0 = gen_tgmm_input( + M, + K, + N, + G, + device=device, + preferred_element_type=input_type, + trans_lhs=trans_lhs, + rng_seed=rng_seed, + unif_group_sizes=unif_group_sizes, + ) + multiple_group_sizes = gen_multiple_group_sizes( + num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0 + ) + out = gen_tgmm_output(K, N, G, device=device, preferred_element_type=output_type) + if use_bias: + bias_grad = gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=True) + else: + bias_grad = None + return lhs, rhs, multiple_group_sizes, out, bias_grad + + +# TGMM helpers: get information from tensors. +# ------------------------------------------------------------------------------ + + +def get_tgmm_shape( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor +) -> tuple[int, int, int, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})." + assert ( + group_sizes.dim() == 1 + ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})." + + K, lhs_m = lhs.shape + rhs_m, N = rhs.shape + G = group_sizes.shape[0] + + assert ( + lhs_m == rhs_m + ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})." + M = lhs_m + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + return M, K, N, G + + +def get_tgmm_output( + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, +) -> Tensor: + assert K > 0, f"Number of out rows K must be positive (K = {K})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if existing_out is not None: + assert ( + existing_out.device == device + ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})." + assert ( + existing_out.dtype == preferred_element_type + ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})." + assert existing_out.shape == ( + G, + K, + N, + ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(G, K, N)})." + return existing_out + + return gen_tgmm_output( + K, + N, + G, + device=device, + preferred_element_type=preferred_element_type, + ) + + +def get_tgmm_bias_grad( + K: int, + G: int, + device: torch.device | str = DEVICE, + existing_bias_grad: Tensor | None = None, +) -> Tensor: + """ + Get or validate bias gradient tensor for TGMM. + + If existing_bias_grad is provided, validates its shape, device, dtype, and stride, + and always zeros it before returning (since the kernel uses atomic_add). + If existing_bias_grad is None, returns a dummy tensor (for use when COMPUTE_BIAS_GRAD=False). + Parameters + ---------- + K : int + Number of rows in the bias gradient tensor. + G : int + Number of groups. + device : torch.device or str + Device for the tensor. + existing_bias_grad : torch.Tensor or None + Existing bias gradient tensor to validate and use. + Returns + ------- + torch.Tensor + Valid bias gradient tensor or dummy tensor. + """ + assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if existing_bias_grad is not None: + # Validate existing bias_grad tensor. + expected_shape = (G, K) + assert ( + tuple(existing_bias_grad.shape) == expected_shape + ), f"bias_grad must have shape {expected_shape}, got {tuple(existing_bias_grad.shape)}." + assert ( + existing_bias_grad.device == device + ), f"bias_grad must be on the same device (bias_grad = {existing_bias_grad.device}, device = {device})." + assert ( + existing_bias_grad.dtype == torch.float32 + ), f"bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), got {existing_bias_grad.dtype}." + assert existing_bias_grad.stride() == ( + K, + 1, + ), f"bias_grad must be row-major with stride (K, 1) = ({K}, 1), got {existing_bias_grad.stride()}." + + # Always zero the tensor since bias_grad represents gradients for the current + # computation and should start fresh. The kernel uses atomic_add which adds to + # existing values, so we must zero before the kernel runs. + existing_bias_grad.zero_() + + return existing_bias_grad + + else: + return gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=False) + + +def get_tgmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})." + assert out.dim() == 3, f"out must have 3 dimensions (it's {out.dim()})." + + lhs_k, lhs_m = lhs.shape + rhs_m, rhs_n = rhs.shape + G, out_k, out_n = out.shape + + assert ( + lhs_m == rhs_m + ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})." + M = lhs_m + assert ( + lhs_k == out_k + ), f"K dimension of lhs and out don't match (lhs = {lhs_k}, rhs = {out_k})." + K = lhs_k + assert ( + rhs_n == out_n + ), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})." + N = rhs_n + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + is_lhs_row_major = lhs.stride() == (M, 1) + is_lhs_col_major = lhs.stride() == (1, K) + assert ( + is_lhs_row_major != is_lhs_col_major + ), "lhs must be row-major or column-major." + is_rhs_row_major = rhs.stride() == (N, 1) + assert is_rhs_row_major, "rhs must be row-major." + is_out_row_major = out.stride() == (K * N, N, 1) + assert is_out_row_major, "out must be row-major." + + # Get lhs leading dimension according to transposition configuration. + ld_lhs = M if is_lhs_row_major else K + + return is_lhs_col_major, ld_lhs diff --git a/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/logger.py b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..391ddf9b6543f5244e7f4932c8568d60748e15cd --- /dev/null +++ b/build/torch212-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/logger.py @@ -0,0 +1,47 @@ +import os +import logging + + +# AITER Triton Logger which is singleton object around python logging. +# Note: Python logging is also a singleton object, but we want to read the +# env var AITER_LOG_LEVEL once at the beginning. Another alternative is to do +# this in __init__.py. In fact, that's how CK logger is setup. We can look at +# switching to that at some point +# +# AITER_LOG_LEVEL follows python logging levels +# DEBUG +# INFO +# WARNING +# ERROR +# CRITICAL +# +class AiterTritonLogger(object): + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(AiterTritonLogger, cls).__new__(cls) + log_level_str = os.getenv("AITER_TRITON_LOG_LEVEL", "WARNING").upper() + numeric_level = getattr(logging, log_level_str, logging.WARNING) + cls._instance._logger = logging.getLogger("AITER_TRITON") + cls._instance._logger.setLevel(numeric_level) + + return cls._instance + + def get_logger(self): + return self._logger + + def debug(self, msg): + self._logger.debug(msg) + + def info(self, msg): + self._logger.info(msg) + + def warning(self, msg): + self._logger.warning(msg) + + def error(self, msg): + self._logger.error(msg) + + def critical(self, msg): + self._logger.critical(msg) diff --git a/build/torch212-cxx11-cu126-x86_64-linux/_megablocks_cuda_ae601bb.abi3.so b/build/torch212-cxx11-cu126-x86_64-linux/_megablocks_cuda_ae601bb.abi3.so deleted file mode 100644 index 31f3206b0d25a64c7aed091cf43f5b54e0f642ac..0000000000000000000000000000000000000000 --- a/build/torch212-cxx11-cu126-x86_64-linux/_megablocks_cuda_ae601bb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b90e38bb520661f87dee787a1dbe6350c84d0268372114ef0b68008dbf3f7ebc -size 13175016 diff --git a/build/torch212-cxx11-cu126-x86_64-linux/_megablocks_cuda_f8f8b50.abi3.so b/build/torch212-cxx11-cu126-x86_64-linux/_megablocks_cuda_f8f8b50.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..127003414e1596f6bd3f2731eeaa5dd1c89e284f --- /dev/null +++ b/build/torch212-cxx11-cu126-x86_64-linux/_megablocks_cuda_f8f8b50.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7091c728a7e3f83ed19fae6ce415a089693ac8ef64121abd99538a1777e17631 +size 13818088 diff --git a/build/torch212-cxx11-cu126-x86_64-linux/_ops.py b/build/torch212-cxx11-cu126-x86_64-linux/_ops.py index 8dd1b7bcf680d2d32dd4ac912487118eafcee4ea..69afb8c26a3fa2691be277b0270d600d29a5865e 100644 --- a/build/torch212-cxx11-cu126-x86_64-linux/_ops.py +++ b/build/torch212-cxx11-cu126-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_cuda_ae601bb -ops = torch.ops._megablocks_cuda_ae601bb +from . import _megablocks_cuda_f8f8b50 +ops = torch.ops._megablocks_cuda_f8f8b50 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_cuda_ae601bb::{op_name}" + return f"_megablocks_cuda_f8f8b50::{op_name}" diff --git a/build/torch212-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py b/build/torch212-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py index 76037d8039cbfc2f0577275c78e4bc0be762592a..c7ef28ced79c830dae934177f059c1f4ddc24aad 100644 --- a/build/torch212-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py +++ b/build/torch212-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py @@ -2,16 +2,16 @@ # extensions. Otherwise libc10.so cannot be found. import torch -# # TODO(tgale): Wrap this in a try-block with better -# # error message and instructions for building the -# # c++ operations. -# import grouped_gemm_backend as backend +# On ROCm there is no CUTLASS grouped GEMM; dispatch to the vendored AITER +# Triton kernels instead. On CUDA we use the compiled CUTLASS `gmm` op. +_IS_ROCM = torch.version.hip is not None -# We import the backend operations from the megablocks package as -# grouped_gemm is vendored in megablocks in this repository. -# from ... import _ops as backend -# from megablocks._ops import ops as backend # type: ignore -from .._ops import ops as backend # type: ignore +if _IS_ROCM: + from .._grouped_gemm_triton import adapter as backend +else: + # We import the backend operations from the megablocks package as + # grouped_gemm is vendored in megablocks in this repository. + from .._ops import ops as backend # type: ignore def _allocate_output(a, b, batch_sizes, trans_a, trans_b): assert not (trans_a and trans_b) diff --git a/build/torch212-cxx11-cu126-x86_64-linux/metadata.json b/build/torch212-cxx11-cu126-x86_64-linux/metadata.json index bc7202ab8d715cad1aee4e42b2a479b869543603..0843e9de0dd35d448c75ae7fc5a2eff09da63271 100644 --- a/build/torch212-cxx11-cu126-x86_64-linux/metadata.json +++ b/build/torch212-cxx11-cu126-x86_64-linux/metadata.json @@ -1,6 +1,6 @@ { "name": "megablocks", - "id": "_megablocks_cuda_ae601bb", + "id": "_megablocks_cuda_f8f8b50", "version": 1, "license": "Apache-2.0", "python-depends": [], @@ -14,7 +14,8 @@ "8.6", "8.7", "8.9", - "9.0" + "9.0", + "9.0+PTX" ] } } diff --git a/build/torch212-cxx11-cu130-x86_64-linux/__init__.py b/build/torch212-cxx11-cu130-x86_64-linux/__init__.py index 38075732c6d8fa0e1e6ef493145e1aca3851ae6b..0766d7b8da4f97baca212177b4bb989bc6374bf8 100644 --- a/build/torch212-cxx11-cu130-x86_64-linux/__init__.py +++ b/build/torch212-cxx11-cu130-x86_64-linux/__init__.py @@ -3,7 +3,9 @@ import torch -from ._ops import ops +# Stable alias: bare `ops` is shadowed by `from . import layers` below. +from ._ops import ops as _compiled_ops +from . import ops from .grouped_gemm import backend as gg_backend from .grouped_gemm import ops as gg_ops @@ -136,7 +138,8 @@ def sort( Returns: The sorted values tensor """ - return ops.sort(x, end_bit, x_out, iota_out) + _compiled_ops.sort(x, end_bit, x_out, iota_out) + return x_out # Convenience functions for common use cases diff --git a/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/__init__.py b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py new file mode 100644 index 0000000000000000000000000000000000000000..8c101d07cea416f9390b708e5a35fdc466e48aed --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py @@ -0,0 +1,574 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + + +# Imports. +# ------------------------------------------------------------------------------ + +# Python standard library +import functools + +# Triton +import triton +import triton.language as tl + +# AITER +from ..configs import CONFIGS as _CONFIGS +from ..utils._triton import arch_info +from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd + +# Kernel config. +# ------------------------------------------------------------------------------ + + +@functools.lru_cache() +def get_config( + gmm_type: str, M: int, K: int, N: int, G: int, accumulate: bool = False +) -> dict[str, int]: + assert gmm_type in { + "gmm", + "ptgmm", + "nptgmm", + }, f"'{gmm_type}' is an invalid GMM variant." + dev = arch_info.get_arch() + assert ( + dev in _CONFIGS + ), f"No GMM configuration tuned for arch '{dev}'. Supported: {sorted(_CONFIGS)}." + arch_configs = _CONFIGS[dev] + assert ( + "default" in arch_configs[gmm_type] + ), "Default configuration is absent." + key = "accumulate" if accumulate else "default" + return arch_configs[gmm_type][key] + + +# Common code shared by GMM and TGMM kernels. +# ------------------------------------------------------------------------------ + + +# XCD remapping followed by 1D PID to 2D grid mapping. +@triton.jit +def _remap_xcd_tile_grid( + tile_in_mm, + num_row_tiles, + num_col_tiles, + GROUP_SIZE: tl.constexpr = 1, + NUM_XCDS: tl.constexpr = 8, +): + return pid_grid( + remap_xcd(tile_in_mm, num_row_tiles * num_col_tiles, NUM_XCDS=NUM_XCDS), + num_row_tiles, + num_col_tiles, + GROUP_SIZE_M=GROUP_SIZE, + ) + + +# GMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.heuristics( + { + "K_DIVISIBLE_BY_BLOCK_SIZE_K": lambda META: META["K"] % META["BLOCK_SIZE_K"] + == 0, + } +) +@triton.jit +def gmm_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_RHS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + K_DIVISIBLE_BY_BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE: tl.constexpr, + GRID_DIM: tl.constexpr, + USE_BIAS: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + # Current tile. Each program computes multiple tiles of each group. + tile = tl.program_id(0) + tl.device_assert(tile >= 0, "tile < 0 (at initialization)") + + # Tile limit of last MM problem (inclusive). + last_mm_tile = 0 + + # Last input row of lhs and output row of out. Each group reads some rows of + # lhs and writes some rows to out. + last_m = 0 + + # Loop through all (m, K, N) MM problems: + # (m, K) x (K, N) = (m, N) + # sum(m) = M + for g in range(G): + # Get m dimension of current MM problem. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty + tl.device_assert(m >= 0, "m < 0") + + num_m_tiles = tl.cdiv(m, BLOCK_SIZE_M) + # num_m_tiles can be zero if group is empty + tl.device_assert(num_m_tiles >= 0, "num_m_tiles < 0") + + num_tiles = num_m_tiles * num_n_tiles + # num_tiles can be zero if group is empty + tl.device_assert(num_tiles >= 0, "num_tiles < 0") + + # Loop through tiles of current MM problem. + while tile >= last_mm_tile and tile < last_mm_tile + num_tiles: + # Figure out tile coordinates in current MM problem. + tile_in_mm = tile - last_mm_tile + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_m, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_m_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + # Do regular MM: + + tl.device_assert(tile_m * BLOCK_SIZE_M >= 0, "tile_m * BLOCK_SIZE_M < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_m = ( + tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + ) % m + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + + lhs_ptrs = lhs_ptr + (last_m + offs_lhs_m[:, None]) * K + offs_k[None, :] + + if TRANS_RHS: + rhs_ptrs = ( + rhs_ptr + + g.to(tl.int64) * K * N + + offs_k[:, None] + + offs_rhs_n[None, :] * K + ) + else: + rhs_ptrs = ( + rhs_ptr + + g.to(tl.int64) * K * N + + offs_k[:, None] * N + + offs_rhs_n[None, :] + ) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if K_DIVISIBLE_BY_BLOCK_SIZE_K: + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + else: + k_mask_limit = K - k * BLOCK_SIZE_K + lhs = tl.load( + lhs_ptrs, mask=offs_k[None, :] < k_mask_limit, other=0 + ) + rhs = tl.load( + rhs_ptrs, mask=offs_k[:, None] < k_mask_limit, other=0 + ) + + acc = tl.dot(lhs, rhs, acc=acc) + + lhs_ptrs += BLOCK_SIZE_K + + if TRANS_RHS: + rhs_ptrs += BLOCK_SIZE_K + else: + rhs_ptrs += BLOCK_SIZE_K * N + + # Add bias if enabled + if USE_BIAS: + offs_bias_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange( + 0, BLOCK_SIZE_N + ) + bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bias_n + bias = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0) + # Convert bias to float32 to match accumulator precision + bias = bias.to(tl.float32) + # Broadcast bias across M dimension and add in float32 + acc += bias[None, :] + + # Convert to output dtype after all computations + acc = acc.to(out_ptr.type.element_ty) + + offs_out_m = tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + (last_m + offs_out_m[:, None]) * N + offs_out_n[None, :] + ) + + tl.store( + out_ptrs, + acc, + mask=(offs_out_m[:, None] < m) & (offs_out_n[None, :] < N), + ) + + # Go to the next tile by advancing number of programs. + tile += GRID_DIM + tl.device_assert(tile > 0, "tile <= 0 (at update)") + + # Get ready to go to the next MM problem. + + last_mm_tile += num_tiles + # last_mm_tile can be zero if group 0 is skipped + tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)") + + last_m += m + # last_m can be zero if group 0 is skipped + tl.device_assert(last_m >= 0, "last_m < 0 (at update)") + tl.device_assert(last_m <= M, "last_m > M (at update)") + + +# Persistent TGMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.jit +def tgmm_persistent_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_grad_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_LHS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE: tl.constexpr, + GRID_DIM: tl.constexpr, + COMPUTE_BIAS_GRAD: tl.constexpr, + ACCUMULATE: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0") + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + num_tiles = num_k_tiles * num_n_tiles + tl.device_assert(num_tiles > 0, "num_tiles <= 0") + + # Current tile. Each program computes multiple tiles of each group. + tile = tl.program_id(0) + tl.device_assert(tile >= 0, "tile < 0 (at initialization)") + + # Tile limit of last MM problem (inclusive). + last_mm_tile = 0 + + # Last input column of lhs and input row of rhs. Each group reads some + # columns of lhs and some rows of rhs. + last_m = 0 + + # Loop through all (K, m, N) MM problems: + # (K, m) x (m, N) = (K, N) + # sum(m) = M + for g in range(G): + # Get m dimension of current MM problem. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty + tl.device_assert(m >= 0, "m < 0") + + # Loop through tiles of current MM problem. + while tile >= last_mm_tile and tile < last_mm_tile + num_tiles: + # Figure out tile coordinates in current MM problem. + tile_in_mm = tile - last_mm_tile + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_k, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + # Do regular MM: + + tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + + if TRANS_LHS: + lhs_ptrs = ( + lhs_ptr + offs_lhs_k[:, None] + (last_m + offs_m[None, :]) * K + ) + else: + lhs_ptrs = ( + lhs_ptr + offs_lhs_k[:, None] * M + (last_m + offs_m[None, :]) + ) + + rhs_ptrs = rhs_ptr + (last_m + offs_m[:, None]) * N + offs_rhs_n[None, :] + + loop_m = tl.cdiv(m, BLOCK_SIZE_M) + m_divisible_by_block_m = m % BLOCK_SIZE_M == 0 + if not m_divisible_by_block_m: + loop_m -= 1 + + acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32) + + # Initialize bias accumulator + bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32) + + for _ in range(0, loop_m): + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + + acc = tl.dot(lhs, rhs, acc=acc) + + # Accumulate for bias gradient: sum lhs across M dimension + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum( + lhs, axis=1 + ) # Sum across M dimension [K, M] -> [K] + + if TRANS_LHS: + lhs_ptrs += BLOCK_SIZE_M * K + else: + lhs_ptrs += BLOCK_SIZE_M + + rhs_ptrs += BLOCK_SIZE_M * N + + if not m_divisible_by_block_m: + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0) + rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0) + acc = tl.dot(lhs, rhs, acc=acc) + + # Accumulate last chunk for bias gradient + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) + + acc = acc.to(out_ptr.type.element_ty) + + offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + + g.to(tl.int64) * K * N + + offs_out_k[:, None] * N + + offs_out_n[None, :] + ) + + mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N) + if ACCUMULATE: + # Load existing values and add to them (like beta=1 in BLAS) + old_vals = tl.load(out_ptrs, mask=mask, other=0.0) + tl.store(out_ptrs, acc + old_vals, mask=mask) + else: + # Overwrite output (like beta=0 in BLAS) + tl.store(out_ptrs, acc, mask=mask) + + # Store bias gradient (only for first N tile, sum across all M) + if COMPUTE_BIAS_GRAD and tile_n == 0: + # Keep as float32 for atomic_add (bf16 not supported for atomics) + bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k + # Use atomic add since multiple K-tiles may write to same expert's bias + tl.atomic_add( + bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed" + ) + + # Go to the next tile by advancing number of programs. + tile += GRID_DIM + tl.device_assert(tile > 0, "tile <= 0 (at update)") + + # Get ready to go to the next MM problem. + + last_mm_tile += num_tiles + # last_mm_tile can be zero if group 0 is skipped + tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)") + + last_m += m + # last_m can be zero if group 0 is skipped + tl.device_assert(last_m >= 0, "last_m < 0 (at update)") + tl.device_assert(last_m <= M, "last_m > M (at update)") + + +# Regular non-persistent TGMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.heuristics({"BLOCK_SIZE_G": lambda META: triton.next_power_of_2(META["G"])}) +@triton.jit +def tgmm_non_persistent_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_grad_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_LHS: tl.constexpr, + BLOCK_SIZE_G: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE: tl.constexpr, + COMPUTE_BIAS_GRAD: tl.constexpr, + ACCUMULATE: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + # Get group ID from grid. + g = tl.program_id(0) + tl.device_assert(g >= 0, "g < 0") + tl.device_assert(g < G, "g >= G") + + # Get m dimension of current MM group. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty. + tl.device_assert(m >= 0, "m < 0") + + # Skip empty groups. + if m == 0: + return + + # Compute sum(group_sizes) until current group g. + # It's the starting column of lhs and starting row of rhs. + offs_g = tl.arange(0, BLOCK_SIZE_G) + group_sizes = tl.load(group_sizes_ptr + offs_g, mask=offs_g < g, other=0) + start_m = tl.sum(group_sizes) + + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0") + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + # Get MM tile from grid. + tile_in_mm = tl.program_id(1) + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_k, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_k = (tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K + offs_rhs_n = (tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + + if TRANS_LHS: + lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] + (start_m + offs_m[None, :]) * K + else: + lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] * M + (start_m + offs_m[None, :]) + + rhs_ptrs = rhs_ptr + (start_m + offs_m[:, None]) * N + offs_rhs_n[None, :] + + loop_m = tl.cdiv(m, BLOCK_SIZE_M) + m_divisible_by_block_m = m % BLOCK_SIZE_M == 0 + if not m_divisible_by_block_m: + loop_m -= 1 + + acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32) + # Initialize bias accumulator + bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32) + + for _ in range(0, loop_m): + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + + acc = tl.dot(lhs, rhs, acc=acc) + + # Accumulate for bias gradient: sum lhs across M dimension + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) # [K, M] -> [K] + + if TRANS_LHS: + lhs_ptrs += BLOCK_SIZE_M * K + else: + lhs_ptrs += BLOCK_SIZE_M + + rhs_ptrs += BLOCK_SIZE_M * N + + if not m_divisible_by_block_m: + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0) + rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0) + acc = tl.dot(lhs, rhs, acc=acc) + # Accumulate last chunk for bias gradient + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) + + acc = acc.to(out_ptr.type.element_ty) + + offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + g.to(tl.int64) * K * N + offs_out_k[:, None] * N + offs_out_n[None, :] + ) + + mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N) + if ACCUMULATE: + # Load existing values and add to them (like beta=1 in BLAS) + old_vals = tl.load(out_ptrs, mask=mask, other=0.0) + tl.store(out_ptrs, acc + old_vals, mask=mask) + else: + # Overwrite output (like beta=0 in BLAS) + tl.store(out_ptrs, acc, mask=mask) + + # Store bias gradient (only for first N tile, sum across all M) + if COMPUTE_BIAS_GRAD and tile_n == 0: + # Keep as float32 for atomic_add (bf16/fp16 not supported for atomics) + bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k + # Use atomic add since multiple K-tiles may write to same expert's bias + tl.atomic_add(bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed") diff --git a/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/adapter.py b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..98c224244f27445384e0c2377d73516406927536 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/adapter.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Adapt AITER's Triton grouped GEMM to MegaBlocks' ``gmm`` calling convention. + +MegaBlocks (following tgale96/grouped_gemm) uses a single ``gmm`` entry point +with ``trans_a`` / ``trans_b`` flags: + +* ``trans_a=False, trans_b=False``: a(M,K) @ b(G,K,N) -> c(M,N) +* ``trans_a=False, trans_b=True`` : a(M,K) @ b(G,N,K)^T -> c(M,N) (dgrad) +* ``trans_a=True`` : a(M,K)^T @ b(M,N) per group -> c(G,K,N) (wgrad) + +AITER exposes these as two kernels: ``gmm`` ((M,K)@(G,K,N)->(M,N), transposition +of the 3D operand inferred from strides) and ``ptgmm`` ((K,M)@(M,N)->(G,K,N), +transposition of the 2D operand inferred from strides). +""" + +import torch + +from .gmm import gmm as _aiter_gmm +from .gmm import ptgmm as _aiter_ptgmm + + +def gmm(a, b, c, batch_sizes, trans_a=False, trans_b=False): + # AITER requires group sizes to be int32 and to live on the compute device. + group_sizes = batch_sizes.to(device=a.device, dtype=torch.int32) + + # AITER asserts exact strides: gmm wants lhs/rhs row-major (a transposed + # 3D operand must be exactly column-major), tgmm wants rhs row-major and + # lhs row/column-major. Make operands contiguous first so the transposed + # views have the precise strides the kernels expect. `.contiguous()` is a + # no-op when the tensor is already contiguous. + if trans_a: + # Weight gradient: a(M,K), b(M,N) -> c(G,K,N). + # Pass a transposed so AITER sees lhs(K,M) column-major (TRANS_LHS). + _aiter_ptgmm( + a.contiguous().transpose(0, 1), + b.contiguous(), + group_sizes, + preferred_element_type=c.dtype, + existing_out=c, + ) + else: + # trans_b contracts b's last dim: pass a column-major (G,K,N) view. + rhs = b.contiguous() + if trans_b: + rhs = rhs.transpose(1, 2) + _aiter_gmm( + a.contiguous(), + rhs, + group_sizes, + preferred_element_type=c.dtype, + existing_out=c, + ) + return c diff --git a/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/configs.py b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..9a4fe5617d8100869aa76dba9b7d22c7bcab814f --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/configs.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: MIT +# Tuned GMM configs vendored from ROCm/aiter (aiter/ops/triton/configs/). +# Inlined as a Python module so packaging always includes them. + +CONFIGS = {'gfx1250': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}, 'gfx942': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}, 'gfx950': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}} diff --git a/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/gmm.py b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/gmm.py new file mode 100644 index 0000000000000000000000000000000000000000..e30c9326c6d4e4836d1303e2761ea2440a7f4750 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/gmm.py @@ -0,0 +1,567 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. + + +# Imports. +# ------------------------------------------------------------------------------ + +# PyTorch +import torch +from torch import Tensor + +# Triton +import triton + +# AITER: GMM utility functions +from .utils.gmm_common import ( + DTYPE, + is_power_of_2, + check_input_device_dtype, + check_bias_shape_stride, + get_gmm_shape, + get_gmm_output, + get_gmm_transposition, + get_tgmm_shape, + get_tgmm_output, + get_tgmm_bias_grad, + get_tgmm_transposition, +) + +# AITER: GMM Triton kernels +from ._triton_kernels.gmm import ( + gmm_kernel, + tgmm_persistent_kernel, + tgmm_non_persistent_kernel, + get_config, +) + +# GMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _gmm_grid( + N: int, + block_size_m: int, + block_size_n: int, + group_sizes: Tensor, + grid_dim: int, +) -> tuple[int]: + assert N > 0, f"N must be positive, it's {N}." + assert is_power_of_2( + block_size_m + ), f"M-dimension tile size must be a power of 2 (it's {block_size_m})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + assert torch.all(group_sizes >= 0).item(), "All group_sizes must be non-negative." + assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})." + num_m_tiles = (group_sizes + block_size_m - 1) // block_size_m + assert torch.all(num_m_tiles >= 0).item(), "All num_m_tiles must be non-negative." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles = torch.sum(num_m_tiles * num_n_tiles).item() + assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}." + num_programs = int(min(grid_dim, num_tiles)) + assert num_programs > 0, f"num_programs must be positive, it's {num_programs}." + return (num_programs,) + + +def gmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias: Tensor | None = None, +) -> Tensor: + """ + Perform Group Matrix Multiplication (GMM): out = lhs @ rhs + bias + + lhs rows are divided into G groups. Each group of lhs rows is matrix multiplied with a plane of + rhs 3D tensor and then stored in a slice of out. In PyTorch parlance, it can be implemented as + follows for a given group g: + out[group_start:group_end, :] = lhs[group_start:group_end, :] @ rhs[g] + bias[g] + + The size of each group, and their respective start and end positions are specified by + group_sizes tensor. For instance, suppose that group_sizes = [3, 2, 4, 1]. In this particular + case we have 4 groups. The 1st group starts at 0 and ends at 2, the second group starts at 3 and + ends at 4, the third group starts at 5 and ends at 8, and the fourth and final group consists of + just the 10th (last) row of lhs. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (M, K). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 3D input tensor. Shape: (G, K, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (M, N), its data type must match preferred_element_type + and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias : torch.Tensor or None, optional + Optional bias tensor. Shape: (G, N). + If provided, bias data type must match lhs and rhs data type, and bias must be on the same + device as other input tensors. Each group g adds bias[g] to the output. + + Returns + ------- + torch.Tensor + The computed output 2D tensor. Shape: (M, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - GMM is implemented with a persistent Triton kernel. + - lhs must be row-major (lhs.stride() == (K, 1)). + - rhs can be row-major (rhs.stride() == (K * N, N, 1)) or column-major (rhs.stride() == + (K * N, 1, K)). If rhs is row-major then kernel parameter TRANS_RHS == False, this is useful + for implementing forward pass. If rhs is column-major then kernel parameter TRANS_RHS == True, + this is useful for computing the lhs derivative in the backward pass, while fusing the + transposition. + - out must be row-major (out.stride() == (N, 1)). + - bias must be row-major (bias.stride() == (N, 1)) if provided. + """ + use_bias = bias is not None + check_input_device_dtype(lhs, rhs, group_sizes, bias) + + M, K, N, G = get_gmm_shape(lhs, rhs, group_sizes) + + if use_bias: + check_bias_shape_stride(bias, G, N) + + out = get_gmm_output( + M, + N, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_rhs, _ = get_gmm_transposition(lhs, rhs, out) + + if config is None: + config = get_config("gmm", M, K, N, G) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + "GRID_DIM", + } + ), "Invalid GMM kernel config." + + grid = _gmm_grid( + N, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + group_sizes, + config["GRID_DIM"], + ) + + # fmt: off + gmm_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_RHS=trans_rhs, + USE_BIAS=use_bias, + **config, + ) + # fmt: on + + return out + + +# Persistent TGMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _ptgmm_grid( + K: int, + N: int, + G: int, + block_size_k: int, + block_size_n: int, + grid_dim: int, +) -> tuple[int]: + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}." + assert G > 0, f"G must be positive, it's {G}." + assert is_power_of_2( + block_size_k + ), f"K-dimension tile size must be a power of 2 (it's {block_size_k})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})." + num_k_tiles = triton.cdiv(K, block_size_k) + assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles = G * num_k_tiles * num_n_tiles + assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}." + num_programs = min(grid_dim, num_tiles) + assert num_programs > 0, f"num_programs must be positive, it's {num_programs}." + return (num_programs,) + + +def ptgmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias_grad: Tensor | None = None, + accumulate: bool = False, +) -> Tensor: + """ + Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs + + lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with + the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch + parlance, it can be implemented as follows for a given group g: + out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :] + + The 't' in the operator name derives from MaxText implementation + (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py), + which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor + shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N). + + The 'p' in the operator name means that it is implemented with a persistent kernel. There is + also the non-persistent variation, which is implemented with a regular kernel. Please take a + look at nptgmm operator. Both ptgmm and nptgmm implement the same computation, choosing one or + the other is a matter of performance for the target workload. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (K, M). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 2D input tensor. Shape: (M, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (G, K, N), its data type must match + preferred_element_type and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias_grad : torch.Tensor or None, optional + Optional bias gradient output tensor. Shape: (G, K). + If provided, the kernel will compute the bias gradient and write it to this tensor. + bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), + accumulate : bool, optional + Whether to accumulate into existing output tensor values. Default is False. + If False, output will be overwritten with fresh computation. + If True, results will be added to existing output tensor values. + + Returns + ------- + torch.Tensor + The computed output 3D tensor. Shape: (G, K, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - PTGMM is implemented with a persistent Triton kernel. + - lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs + is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel + parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward + pass, while fusing the transposition. + - rhs must be row-major (rhs.stride() == (N, 1)). + - out must be row-major (out.stride() == (K * N, N, 1)). + """ + check_input_device_dtype(lhs, rhs, group_sizes) + + M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes) + + out = get_tgmm_output( + K, + N, + G, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out) + + if config is None: + config = get_config("ptgmm", M, K, N, G, accumulate) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + "GRID_DIM", + } + ), "Invalid PTGMM kernel config." + + # Bias gradient handling. + # ----------------------- + # Get or validate bias gradient tensor. + compute_bias_grad = bias_grad is not None + bias_grad_ptr = get_tgmm_bias_grad( + K, + G, + device=lhs.device, + existing_bias_grad=bias_grad, + ) + + grid = _ptgmm_grid( + K, + N, + G, + config["BLOCK_SIZE_K"], + config["BLOCK_SIZE_N"], + config["GRID_DIM"], + ) + + # fmt: off + tgmm_persistent_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias_grad_ptr, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_LHS=trans_lhs, + COMPUTE_BIAS_GRAD=compute_bias_grad, + ACCUMULATE=accumulate, + **config, + ) + # fmt: on + + return out + + +# Regular non-persistent TGMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _nptgmm_grid( + K: int, + N: int, + G: int, + block_size_k: int, + block_size_n: int, +) -> tuple[int, int]: + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}." + assert G > 0, f"G must be positive, it's {G}." + assert is_power_of_2( + block_size_k + ), f"K-dimension tile size must be a power of 2 (it's {block_size_k})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + num_k_tiles = triton.cdiv(K, block_size_k) + assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles_per_mm = num_k_tiles * num_n_tiles + assert ( + num_tiles_per_mm > 0 + ), f"num_tiles_per_mm must be positive, it's {num_tiles_per_mm}." + return (G, num_tiles_per_mm) + + +def nptgmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias_grad: Tensor | None = None, + accumulate: bool = False, +) -> Tensor: + """ + Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs + + lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with + the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch + parlance, it can be implemented as follows for a given group g: + out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :] + + The 't' in the operator name derives from MaxText implementation + (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py), + which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor + shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N). + + The 'np' in the operator name means that it is implemented with a non-persistent, i.e. regular + kernel. There is also the persistent variation, which is implemented with a persistent kernel. + Please take a look at ptgmm operator. Both nptgmm and ptgmm implement the same computation, + choosing one or the other is a matter of performance for the target workload. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (K, M). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 2D input tensor. Shape: (M, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (G, K, N), its data type must match + preferred_element_type and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias_grad : torch.Tensor or None, optional + Optional bias gradient output tensor. Shape: (G, K). + If provided, the kernel will compute the bias gradient and write it to this tensor. + bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), + accumulate : bool, optional + Whether to accumulate into existing output tensor values. Default is False. + If False, output will be overwritten with fresh computation. + If True, results will be added to existing output tensor values. + + Returns + ------- + torch.Tensor + The computed output 3D tensor. Shape: (G, K, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - NPTGMM is implemented with a non-persistent regular Triton kernel. + - lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs + is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel + parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward + pass, while fusing the transposition. + - rhs must be row-major (rhs.stride() == (N, 1)). + - out must be row-major (out.stride() == (K * N, N, 1)). + """ + check_input_device_dtype(lhs, rhs, group_sizes) + + M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes) + + out = get_tgmm_output( + K, + N, + G, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out) + + # Bias gradient handling. + # ----------------------- + # Get or validate bias gradient tensor. + compute_bias_grad = bias_grad is not None + bias_grad_ptr = get_tgmm_bias_grad( + K, + G, + device=lhs.device, + existing_bias_grad=bias_grad, + ) + + if config is None: + config = get_config("nptgmm", M, K, N, G, accumulate) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + } + ), "Invalid NPTGMM kernel config." + + grid = _nptgmm_grid( + K, + N, + G, + config["BLOCK_SIZE_K"], + config["BLOCK_SIZE_N"], + ) + + # fmt: off + tgmm_non_persistent_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias_grad_ptr, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_LHS=trans_lhs, + COMPUTE_BIAS_GRAD=compute_bias_grad, + ACCUMULATE=accumulate, + **config, + ) + # fmt: on + + return out diff --git a/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/__init__.py b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py new file mode 100644 index 0000000000000000000000000000000000000000..3f6c88581a64044518125623f116082c53bd5474 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py @@ -0,0 +1,46 @@ +import triton + +# Detect the GPU arch lazily: querying the triton driver at import time fails +# in headless environments (e.g. the kernel-builder ABI check sandbox has no +# GPU), and the original JAX fallback pulled in an unrelated runtime dep. The +# arch is only actually needed when a GMM kernel is dispatched, so resolve and +# cache on first call. +_CACHED_ARCH = None + + +def get_arch(): + global _CACHED_ARCH + if _CACHED_ARCH is not None: + return _CACHED_ARCH + try: + _CACHED_ARCH = triton.runtime.driver.active.get_current_target().arch + except RuntimeError: + try: + from jax._src.lib import gpu_triton as triton_kernel_call_lib + _CACHED_ARCH = triton_kernel_call_lib.get_arch_details("0").split(":")[0] + except ImportError as e: + raise RuntimeError( + "Cannot determine GPU arch: triton driver is inactive and " + "JAX is not available. A GPU is required for grouped GEMM." + ) from e + return _CACHED_ARCH + + +def is_gluon_avail(): + return get_arch() in ("gfx950", "gfx1250") + + +def is_fp4_avail(): + return get_arch() in ("gfx950", "gfx1250") + + +def is_fp8_avail(): + return get_arch() in ("gfx942", "gfx950", "gfx1250", "gfx1200", "gfx1201") + + +def is_mx_scale_preshuffling_avail(): + return get_arch() in ("gfx950", "gfx1250") + + +def is_tdm_avail(): + return get_arch() in ("gfx1250",) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..99792bb3ba2fab8fff223bba733ced1eb6e6df53 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: MIT + +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl + + +@triton.jit +def remap_xcd_chunked( + pid, GRID_MN, NUM_XCDS: tl.constexpr = 8, CHUNK_SIZE: tl.constexpr = 2 +): + # Compute current XCD and local PID + xcd = pid % NUM_XCDS + # distribute the modulo pids in round robin + if pid > (GRID_MN // (NUM_XCDS * CHUNK_SIZE)) * (NUM_XCDS * CHUNK_SIZE): + return pid + local_pid = pid // NUM_XCDS + # Calculate chunk index and position within chunk + chunk_idx = local_pid // CHUNK_SIZE + pos_in_chunk = local_pid % CHUNK_SIZE + # Calculate new PID + new_pid = chunk_idx * NUM_XCDS * CHUNK_SIZE + xcd * CHUNK_SIZE + pos_in_chunk + return new_pid + + +@triton.jit +def remap_xcd(pid, GRID_MN, NUM_XCDS: tl.constexpr = 8): + ## pid remapping on xcds + # Number of pids per XCD in the new arrangement + pids_per_xcd = (GRID_MN + NUM_XCDS - 1) // NUM_XCDS + # When GRID_MN cannot divide NUM_XCDS, some xcds will have + # pids_per_xcd pids, the other will have pids_per_xcd - 1 pids. + # We calculate the number of xcds that have pids_per_xcd pids as + # tall_xcds + tall_xcds = GRID_MN % NUM_XCDS + tall_xcds = NUM_XCDS if tall_xcds == 0 else tall_xcds + # Compute current XCD and local pid within the XCD + xcd = pid % NUM_XCDS + local_pid = pid // NUM_XCDS + # Calculate new pid based on the new grouping + # Note that we need to consider the following two cases: + # 1. the current pid is on a tall xcd + # 2. the current pid is on a short xcd + if xcd < tall_xcds: + pid = xcd * pids_per_xcd + local_pid + else: + pid = ( + tall_xcds * pids_per_xcd + + (xcd - tall_xcds) * (pids_per_xcd - 1) + + local_pid + ) + + return pid + + +@triton.jit +def pid_grid(pid: int, num_pid_m: int, num_pid_n: int, GROUP_SIZE_M: tl.constexpr = 1): + """ + Maps 1D pid to 2D grid coords (pid_m, pid_n). + + Args: + - pid: 1D pid + - num_pid_m: grid m size + - num_pid_n: grid n size + - GROUP_SIZE_M: tl.constexpr: default is 1 + """ + if GROUP_SIZE_M == 1: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + tl.assume(group_size_m >= 0) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + return pid_m, pid_n + + +@triton.jit +def pid_grid_3d(pid: int, num_pid_m: int, num_pid_n: int, num_pid_k): + """ + Maps 1D pid to 3D grid coords (pid_m, pid_n, pid_k). + Args: + - pid: 1D pid + - num_pid_m: grid m size + - num_pid_n: grid n size + - num_pid_k: grid k size + + Returns: + - pid_m, pid_n, pid_k: 3D grid coordinates + """ + pid_m = pid % num_pid_m + pid_n = (pid // num_pid_m) % num_pid_n + pid_k = pid // (num_pid_m * num_pid_n) % num_pid_k + + return pid_m, pid_n, pid_k diff --git a/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py new file mode 100644 index 0000000000000000000000000000000000000000..153dee65b50ab5f833262481889d2184d1ca639f --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py @@ -0,0 +1,752 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. + +# Imports. +# ------------------------------------------------------------------------------ + +# PyTorch +import torch +from torch import Tensor + +# AITER: logging +from .logger import AiterTritonLogger + +_LOGGER: AiterTritonLogger = AiterTritonLogger() + + +# Supported data types. +# ------------------------------------------------------------------------------ + +# Supported data types, as strings. +SUPPORTED_DTYPES_STR: set[str] = {"fp16", "bf16"} + + +# Convert string data type to PyTorch data type. +def dtype_from_str(dtype_str: str) -> torch.dtype: + dtype_str = dtype_str.strip().lower() + dtype_str = dtype_str[1:] if dtype_str[0] in {"i", "o"} else dtype_str + assert ( + dtype_str in SUPPORTED_DTYPES_STR + ), "String data type isn't in set of supported string data types." + return {"fp16": torch.float16, "bf16": torch.bfloat16}[dtype_str] + + +# Supported data types, as PyTorch types. +SUPPORTED_DTYPES: set[torch.dtype] = { + dtype_from_str(dtype_str) for dtype_str in SUPPORTED_DTYPES_STR +} + + +# Convert PyTorch data type to string data type. +def str_from_dtype(dtype: torch.dtype) -> str: + assert ( + dtype in SUPPORTED_DTYPES + ), "PyTorch data type isn't in set of supported PyTorch data types." + return {torch.float16: "fp16", torch.bfloat16: "bf16"}[dtype] + + +# Default data type, as string. +DTYPE_STR: str = "bf16" +assert ( + DTYPE_STR in SUPPORTED_DTYPES_STR +), "Default string data type isn't in set of supported string data types." + + +# Default data type, as PyTorch type. +DTYPE: torch.dtype = dtype_from_str(DTYPE_STR) + + +# Other defaults. +# ------------------------------------------------------------------------------ + +# Default device. +DEVICE: torch.device | str = "cuda" + +# Default RNG seed for input generation. +RNG_SEED: int = 0 + +# Default number of group sizes. +NUM_GROUP_SIZES: int = 1 + +# Default transposition (NN). +TRANS_LHS: bool = False +TRANS_RHS: bool = False + + +# Parameter checking functions. +# ------------------------------------------------------------------------------ + + +def is_power_of_2(x: int) -> bool: + return (x > 0) and (x & (x - 1) == 0) + + +def check_input_device_dtype( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor, bias: Tensor | None = None +) -> None: + assert ( + lhs.device == rhs.device == group_sizes.device + ), f"All input tensors must be in the same device (lhs = {lhs.device}, rhs = {rhs.device}, group_sizes = {group_sizes.device})." + assert ( + lhs.dtype == rhs.dtype + ), f"lhs and rhs types must match (lhs = {lhs.dtype}, rhs = {rhs.dtype})." + assert group_sizes.dtype == torch.int32, "group_sizes type must be int32." + + if bias is not None: + assert ( + bias.device == lhs.device + ), f"bias must be on the same device as lhs (bias = {bias.device}, lhs = {lhs.device})." + assert ( + bias.dtype == lhs.dtype + ), f"bias dtype must match lhs dtype (bias = {bias.dtype}, lhs = {lhs.dtype})." + + +def check_bias_shape_stride(bias: Tensor, G: int, N: int) -> None: + assert bias.shape == ( + G, + N, + ), f"bias must have shape (G, N) = ({G}, {N}), got {bias.shape}." + assert bias.stride() == (N, 1), "bias must be row-major (bias.stride() == (N, 1))." + + +# Generation of group sizes. +# ------------------------------------------------------------------------------ + + +# Probabilities for generating random group sizes. +UNUSED_TOKENS_PROB: float = 0.0 +UNUSED_EXPERTS_PROB: float = 0.1 + + +def gen_uniform_group_sizes( + M: int, + G: int, + device: torch.device | str = DEVICE, +) -> Tensor: + assert M >= 0, f"Number of tokens M must be non-negative (it's {M})." + assert G > 0, f"Number of experts G must be positive (it's {G})." + + base = M // G + remainder = M % G + group_sizes = torch.full((G,), base, dtype=torch.int32, device=device) + if remainder > 0: + group_sizes[:remainder] += 1 + + assert ( + len(group_sizes) == G + ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})." + assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative." + assert ( + torch.sum(group_sizes).item() == M + ), f"Group sizes don't add up to total tokens {M}." + assert group_sizes.dtype == torch.int32, "Group sizes must be int32." + + return group_sizes + + +def gen_group_sizes( + M: int, + G: int, + device: torch.device | str = DEVICE, + rng_seed: int | None = RNG_SEED, + unused_tokens_prob: float = UNUSED_TOKENS_PROB, + unused_experts_prob: float = UNUSED_EXPERTS_PROB, +) -> Tensor: + assert M >= 0, f"Number of tokens M must be non-negative (it's {M})." + assert G > 0, f"Number of experts G must be positive (it's {G})." + assert ( + 0 <= unused_tokens_prob <= 1 + ), f"Probability of unused tokens must be in [0, 1] interval (it's {unused_tokens_prob})." + assert ( + 0 <= unused_experts_prob <= 1 + ), f"Probability of unused experts must be in [0, 1] interval (it's {unused_experts_prob})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + if unused_tokens_prob > 0: + # Optionally drop tokens to simulate routing sparsity, some tokens may not be routed. + num_unused_tokens = M + while num_unused_tokens == M: + num_unused_tokens = int( + torch.binomial( + torch.tensor(float(M), device=device), + torch.tensor(unused_tokens_prob, device=device), + ).item() + ) + else: + num_unused_tokens = 0 + num_used_tokens = M - num_unused_tokens + assert ( + num_unused_tokens >= 0 + ), f"Number of unused tokens must be non-negative (it's {num_unused_tokens})." + assert ( + num_used_tokens > 0 + ), f"Number of used tokens must be positive (it's {num_used_tokens})." + assert ( + num_used_tokens + num_unused_tokens == M + ), f"Unused + used tokens don't add up total tokens ({num_used_tokens} + {num_unused_tokens} != {M})." + + if num_unused_tokens > 0: + _LOGGER.debug( + f"Group sizes generation: dropped {num_unused_tokens} token{'s' if num_unused_tokens > 1 else ''}.", + ) + + if unused_experts_prob > 0: + # Some experts may have zero tokens assigned to them. + num_used_experts = 0 + while num_used_experts == 0: + used_experts = torch.nonzero( + torch.rand((G,), device=device) >= unused_experts_prob + ).squeeze() + num_used_experts = used_experts.numel() + else: + used_experts = torch.arange(0, G, device=device) + num_used_experts = G + num_unused_experts = G - num_used_experts + assert ( + num_unused_experts >= 0 + ), f"Number of unused experts must be non-negative (it's {num_unused_experts})." + assert ( + num_used_experts >= 1 + ), f"At least one expert must be used (it's {num_used_experts})." + assert ( + num_unused_experts + num_used_experts == G + ), f"Unused + used experts don't add up total experts ({num_unused_experts} + {num_used_experts} != {G})." + + if num_unused_experts > 0: + _LOGGER.debug( + f"Group sizes generation: dropped {num_unused_experts} expert{'s' if num_unused_experts > 1 else ''}.", + ) + + group_sizes = torch.bincount( + used_experts[ + torch.randint(low=0, high=num_used_experts, size=(num_used_tokens,)) + ], + minlength=G, + ).to(torch.int32) + + assert ( + len(group_sizes) == G + ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})." + assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative." + assert ( + torch.sum(group_sizes).item() == num_used_tokens + ), f"Group sizes don't add up to used tokens {num_used_tokens}." + assert group_sizes.dtype == torch.int32, "Group sizes must be int32." + + return group_sizes + + +def gen_multiple_group_sizes( + num_group_sizes: int, + M: int, + G: int, + device: torch.device | str = DEVICE, + rng_seed: int | None = RNG_SEED, + unused_tokens_prob: float = UNUSED_TOKENS_PROB, + unused_experts_prob: float = UNUSED_EXPERTS_PROB, + group_sizes_0: Tensor | None = None, +) -> list[Tensor]: + assert ( + num_group_sizes > 0 + ), f"Number of group sizes to be generated must be positive, it's {num_group_sizes}." + multiple_group_sizes = [ + gen_group_sizes( + M, + G, + device=device, + rng_seed=rng_seed if g == 0 else None, + unused_tokens_prob=unused_tokens_prob, + unused_experts_prob=unused_experts_prob, + ) + for g in range( + num_group_sizes if group_sizes_0 is None else num_group_sizes - 1 + ) + ] + if group_sizes_0 is not None: + multiple_group_sizes.insert(0, group_sizes_0) + assert ( + len(multiple_group_sizes) == num_group_sizes + ), f"Expecting {num_group_sizes} distinct group sizes (it's {len(multiple_group_sizes)})." + return multiple_group_sizes + + +# GMM helpers: tensor generation. +# ------------------------------------------------------------------------------ + + +def gen_gmm_input( + M: int, + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + trans_rhs: bool = TRANS_RHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, +) -> tuple[Tensor, Tensor, Tensor]: + assert M > 0, f"Number of lhs rows M must be positive (M = {M})." + assert K > 0, f"Number of lhs columns / rhs rows K must be positive (K = {K})." + assert N > 0, f"Number of rhs columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + lhs = torch.randn((M, K), dtype=torch.float32, device=device) + lhs = lhs.to(preferred_element_type) + + if trans_rhs: + rhs = torch.randn((G, N, K), dtype=torch.float32, device=device).permute( + 0, 2, 1 + ) + else: + rhs = torch.randn((G, K, N), dtype=torch.float32, device=device) + rhs = rhs.to(preferred_element_type) + + group_sizes = ( + gen_uniform_group_sizes(M, G, device=device) + if unif_group_sizes + else gen_group_sizes(M, G, device=device, rng_seed=None) + ) + + return lhs, rhs, group_sizes + + +def gen_gmm_output( + M: int, + N: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, +) -> Tensor: + assert M > 0, f"Number of out rows M must be positive (M = {M})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + + out = torch.empty((M, N), dtype=preferred_element_type, device=device) + + return out + + +def gen_gmm_tensors( + M: int, + K: int, + N: int, + G: int, + num_group_sizes: int, + device: torch.device | str = DEVICE, + input_type: torch.dtype = DTYPE, + output_type: torch.dtype = DTYPE, + trans_lhs: bool = False, + trans_rhs: bool = TRANS_RHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, + use_bias: bool = False, +) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]: + lhs, rhs, group_sizes_0 = gen_gmm_input( + M, + K, + N, + G, + device=device, + preferred_element_type=input_type, + trans_rhs=trans_rhs, + rng_seed=rng_seed, + unif_group_sizes=unif_group_sizes, + ) + multiple_group_sizes = gen_multiple_group_sizes( + num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0 + ) + out = gen_gmm_output(M, N, device=device, preferred_element_type=output_type) + bias = None + if use_bias: + torch.manual_seed(rng_seed + 1000) # Different seed for bias + bias = torch.randn(G, N, dtype=input_type, device=device) + + return lhs, rhs, multiple_group_sizes, out, bias + + +# GMM helpers: get information from tensors. +# ------------------------------------------------------------------------------ + + +def get_gmm_shape( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor +) -> tuple[int, int, int, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})." + assert ( + group_sizes.dim() == 1 + ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})." + + M, lhs_k = lhs.shape + rhs_g, rhs_k, N = rhs.shape + group_sizes_g = group_sizes.shape[0] + + assert ( + lhs_k == rhs_k + ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})." + K = lhs_k + assert ( + rhs_g == group_sizes_g + ), f"G dimension of rhs and group_sizes don't match (rhs = {rhs_g}, group_sizes = {group_sizes_g})." + G = rhs_g + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + return M, K, N, G + + +def get_gmm_output( + M: int, + N: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, +) -> Tensor: + assert M > 0, f"Number of out rows M must be positive (M = {M})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + + if existing_out is not None: + assert ( + existing_out.device == device + ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})." + assert ( + existing_out.dtype == preferred_element_type + ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})." + assert existing_out.shape == ( + M, + N, + ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(M, N)})." + return existing_out + + return gen_gmm_output( + M, + N, + device=device, + preferred_element_type=preferred_element_type, + ) + + +def get_gmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})." + assert out.dim() == 2, f"out must have 2 dimensions (it's {out.dim()})." + + lhs_m, lhs_k = lhs.shape + G, rhs_k, rhs_n = rhs.shape + out_m, out_n = out.shape + + assert ( + lhs_m == out_m + ), f"M dimension of lhs and out don't match (lhs = {lhs_m}, rhs = {out_m})." + M = lhs_m + assert ( + lhs_k == rhs_k + ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})." + K = lhs_k + assert ( + rhs_n == out_n + ), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})." + N = rhs_n + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + is_lhs_row_major = lhs.stride() == (K, 1) + assert is_lhs_row_major, "lhs must be row-major." + is_rhs_row_major = rhs.stride() == (K * N, N, 1) + is_rhs_col_major = rhs.stride() == (K * N, 1, K) + assert ( + is_rhs_row_major != is_rhs_col_major + ), "rhs must be row-major or column-major." + is_out_row_major = out.stride() == (N, 1) + assert is_out_row_major, "out must be row-major." + + # Get rhs leading dimension according to transposition configuration. + ld_rhs = N if is_rhs_row_major else K + + return is_rhs_col_major, ld_rhs + + +# TGMM helpers: tensor generation. +# ------------------------------------------------------------------------------ + + +def gen_tgmm_input( + M: int, + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + trans_lhs: bool = TRANS_LHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, +) -> tuple[Tensor, Tensor, Tensor]: + assert K > 0, f"Number of lhs rows K must be positive (M = {K})." + assert M > 0, f"Number of lhs columns / rhs rows M must be positive (K = {M})." + assert N > 0, f"Number of rhs columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + if trans_lhs: + lhs = torch.randn((M, K), dtype=torch.float32, device=device).T + else: + lhs = torch.randn((K, M), dtype=torch.float32, device=device) + lhs = lhs.to(preferred_element_type) + + rhs = torch.randn((M, N), dtype=torch.float32, device=device) + rhs = rhs.to(preferred_element_type) + + group_sizes = ( + gen_uniform_group_sizes(M, G, device=device) + if unif_group_sizes + else gen_group_sizes(M, G, device=device, rng_seed=None) + ) + + return lhs, rhs, group_sizes + + +def gen_tgmm_output( + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, +) -> Tensor: + assert K > 0, f"Number of out rows K must be positive (K = {K})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + out = torch.empty((G, K, N), dtype=preferred_element_type, device=device) + + return out + + +def gen_tgmm_bias_grad( + K: int, + G: int, + device: torch.device | str = DEVICE, + with_bias_grad: bool = False, +) -> Tensor: + if with_bias_grad: + assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + return torch.empty((G, K), device=device, dtype=torch.float32) + else: + # Return dummy pointer when bias_grad is not needed. + # Must be float32 because atomic_add does not support bf16/fp16, + # and Triton validates the pointer dtype even in dead branches. + return torch.tensor([], device=device, dtype=torch.float32) + + +def gen_tgmm_tensors( + M: int, + K: int, + N: int, + G: int, + num_group_sizes: int, + device: torch.device | str = DEVICE, + input_type: torch.dtype = DTYPE, + output_type: torch.dtype = DTYPE, + trans_lhs: bool = TRANS_LHS, + trans_rhs: bool = False, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, + use_bias: bool = False, +) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]: + lhs, rhs, group_sizes_0 = gen_tgmm_input( + M, + K, + N, + G, + device=device, + preferred_element_type=input_type, + trans_lhs=trans_lhs, + rng_seed=rng_seed, + unif_group_sizes=unif_group_sizes, + ) + multiple_group_sizes = gen_multiple_group_sizes( + num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0 + ) + out = gen_tgmm_output(K, N, G, device=device, preferred_element_type=output_type) + if use_bias: + bias_grad = gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=True) + else: + bias_grad = None + return lhs, rhs, multiple_group_sizes, out, bias_grad + + +# TGMM helpers: get information from tensors. +# ------------------------------------------------------------------------------ + + +def get_tgmm_shape( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor +) -> tuple[int, int, int, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})." + assert ( + group_sizes.dim() == 1 + ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})." + + K, lhs_m = lhs.shape + rhs_m, N = rhs.shape + G = group_sizes.shape[0] + + assert ( + lhs_m == rhs_m + ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})." + M = lhs_m + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + return M, K, N, G + + +def get_tgmm_output( + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, +) -> Tensor: + assert K > 0, f"Number of out rows K must be positive (K = {K})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if existing_out is not None: + assert ( + existing_out.device == device + ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})." + assert ( + existing_out.dtype == preferred_element_type + ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})." + assert existing_out.shape == ( + G, + K, + N, + ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(G, K, N)})." + return existing_out + + return gen_tgmm_output( + K, + N, + G, + device=device, + preferred_element_type=preferred_element_type, + ) + + +def get_tgmm_bias_grad( + K: int, + G: int, + device: torch.device | str = DEVICE, + existing_bias_grad: Tensor | None = None, +) -> Tensor: + """ + Get or validate bias gradient tensor for TGMM. + + If existing_bias_grad is provided, validates its shape, device, dtype, and stride, + and always zeros it before returning (since the kernel uses atomic_add). + If existing_bias_grad is None, returns a dummy tensor (for use when COMPUTE_BIAS_GRAD=False). + Parameters + ---------- + K : int + Number of rows in the bias gradient tensor. + G : int + Number of groups. + device : torch.device or str + Device for the tensor. + existing_bias_grad : torch.Tensor or None + Existing bias gradient tensor to validate and use. + Returns + ------- + torch.Tensor + Valid bias gradient tensor or dummy tensor. + """ + assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if existing_bias_grad is not None: + # Validate existing bias_grad tensor. + expected_shape = (G, K) + assert ( + tuple(existing_bias_grad.shape) == expected_shape + ), f"bias_grad must have shape {expected_shape}, got {tuple(existing_bias_grad.shape)}." + assert ( + existing_bias_grad.device == device + ), f"bias_grad must be on the same device (bias_grad = {existing_bias_grad.device}, device = {device})." + assert ( + existing_bias_grad.dtype == torch.float32 + ), f"bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), got {existing_bias_grad.dtype}." + assert existing_bias_grad.stride() == ( + K, + 1, + ), f"bias_grad must be row-major with stride (K, 1) = ({K}, 1), got {existing_bias_grad.stride()}." + + # Always zero the tensor since bias_grad represents gradients for the current + # computation and should start fresh. The kernel uses atomic_add which adds to + # existing values, so we must zero before the kernel runs. + existing_bias_grad.zero_() + + return existing_bias_grad + + else: + return gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=False) + + +def get_tgmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})." + assert out.dim() == 3, f"out must have 3 dimensions (it's {out.dim()})." + + lhs_k, lhs_m = lhs.shape + rhs_m, rhs_n = rhs.shape + G, out_k, out_n = out.shape + + assert ( + lhs_m == rhs_m + ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})." + M = lhs_m + assert ( + lhs_k == out_k + ), f"K dimension of lhs and out don't match (lhs = {lhs_k}, rhs = {out_k})." + K = lhs_k + assert ( + rhs_n == out_n + ), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})." + N = rhs_n + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + is_lhs_row_major = lhs.stride() == (M, 1) + is_lhs_col_major = lhs.stride() == (1, K) + assert ( + is_lhs_row_major != is_lhs_col_major + ), "lhs must be row-major or column-major." + is_rhs_row_major = rhs.stride() == (N, 1) + assert is_rhs_row_major, "rhs must be row-major." + is_out_row_major = out.stride() == (K * N, N, 1) + assert is_out_row_major, "out must be row-major." + + # Get lhs leading dimension according to transposition configuration. + ld_lhs = M if is_lhs_row_major else K + + return is_lhs_col_major, ld_lhs diff --git a/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/logger.py b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..391ddf9b6543f5244e7f4932c8568d60748e15cd --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/logger.py @@ -0,0 +1,47 @@ +import os +import logging + + +# AITER Triton Logger which is singleton object around python logging. +# Note: Python logging is also a singleton object, but we want to read the +# env var AITER_LOG_LEVEL once at the beginning. Another alternative is to do +# this in __init__.py. In fact, that's how CK logger is setup. We can look at +# switching to that at some point +# +# AITER_LOG_LEVEL follows python logging levels +# DEBUG +# INFO +# WARNING +# ERROR +# CRITICAL +# +class AiterTritonLogger(object): + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(AiterTritonLogger, cls).__new__(cls) + log_level_str = os.getenv("AITER_TRITON_LOG_LEVEL", "WARNING").upper() + numeric_level = getattr(logging, log_level_str, logging.WARNING) + cls._instance._logger = logging.getLogger("AITER_TRITON") + cls._instance._logger.setLevel(numeric_level) + + return cls._instance + + def get_logger(self): + return self._logger + + def debug(self, msg): + self._logger.debug(msg) + + def info(self, msg): + self._logger.info(msg) + + def warning(self, msg): + self._logger.warning(msg) + + def error(self, msg): + self._logger.error(msg) + + def critical(self, msg): + self._logger.critical(msg) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/_megablocks_cuda_ae601bb.abi3.so b/build/torch212-cxx11-cu130-x86_64-linux/_megablocks_cuda_ae601bb.abi3.so deleted file mode 100644 index 2523284b87ac13bce49b6c77a57b5719c6c6bc85..0000000000000000000000000000000000000000 --- a/build/torch212-cxx11-cu130-x86_64-linux/_megablocks_cuda_ae601bb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8875ffe3dc9a444ce2110476d0bb3b1a2825db65eedaf8983109cdb8553e8bd5 -size 10113176 diff --git a/build/torch212-cxx11-cu130-x86_64-linux/_megablocks_cuda_f8f8b50.abi3.so b/build/torch212-cxx11-cu130-x86_64-linux/_megablocks_cuda_f8f8b50.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..c4a781e7ab05b54e39e6ccad8500e63e8d9d9c87 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/_megablocks_cuda_f8f8b50.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:53280c5e4952e8dbdcad5f9cd7006d073c2ec1f218b9edb62e4b8fae5cb04b2a +size 11769344 diff --git a/build/torch212-cxx11-cu130-x86_64-linux/_ops.py b/build/torch212-cxx11-cu130-x86_64-linux/_ops.py index 8dd1b7bcf680d2d32dd4ac912487118eafcee4ea..69afb8c26a3fa2691be277b0270d600d29a5865e 100644 --- a/build/torch212-cxx11-cu130-x86_64-linux/_ops.py +++ b/build/torch212-cxx11-cu130-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_cuda_ae601bb -ops = torch.ops._megablocks_cuda_ae601bb +from . import _megablocks_cuda_f8f8b50 +ops = torch.ops._megablocks_cuda_f8f8b50 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_cuda_ae601bb::{op_name}" + return f"_megablocks_cuda_f8f8b50::{op_name}" diff --git a/build/torch212-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py b/build/torch212-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py index 76037d8039cbfc2f0577275c78e4bc0be762592a..c7ef28ced79c830dae934177f059c1f4ddc24aad 100644 --- a/build/torch212-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py +++ b/build/torch212-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py @@ -2,16 +2,16 @@ # extensions. Otherwise libc10.so cannot be found. import torch -# # TODO(tgale): Wrap this in a try-block with better -# # error message and instructions for building the -# # c++ operations. -# import grouped_gemm_backend as backend +# On ROCm there is no CUTLASS grouped GEMM; dispatch to the vendored AITER +# Triton kernels instead. On CUDA we use the compiled CUTLASS `gmm` op. +_IS_ROCM = torch.version.hip is not None -# We import the backend operations from the megablocks package as -# grouped_gemm is vendored in megablocks in this repository. -# from ... import _ops as backend -# from megablocks._ops import ops as backend # type: ignore -from .._ops import ops as backend # type: ignore +if _IS_ROCM: + from .._grouped_gemm_triton import adapter as backend +else: + # We import the backend operations from the megablocks package as + # grouped_gemm is vendored in megablocks in this repository. + from .._ops import ops as backend # type: ignore def _allocate_output(a, b, batch_sizes, trans_a, trans_b): assert not (trans_a and trans_b) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/metadata.json b/build/torch212-cxx11-cu130-x86_64-linux/metadata.json index dae1319c841f27d4cd7a5a4b31fbde6ae4d4cacd..436ad3fc85ff69b069290830671db574d1045671 100644 --- a/build/torch212-cxx11-cu130-x86_64-linux/metadata.json +++ b/build/torch212-cxx11-cu130-x86_64-linux/metadata.json @@ -1,6 +1,6 @@ { "name": "megablocks", - "id": "_megablocks_cuda_ae601bb", + "id": "_megablocks_cuda_f8f8b50", "version": 1, "license": "Apache-2.0", "python-depends": [], @@ -8,7 +8,9 @@ "type": "cuda", "archs": [ "10.0", + "11.0", "12.0", + "12.0+PTX", "7.5", "8.0", "8.6", diff --git a/build/torch212-cxx11-cu132-x86_64-linux/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/__init__.py index 38075732c6d8fa0e1e6ef493145e1aca3851ae6b..0766d7b8da4f97baca212177b4bb989bc6374bf8 100644 --- a/build/torch212-cxx11-cu132-x86_64-linux/__init__.py +++ b/build/torch212-cxx11-cu132-x86_64-linux/__init__.py @@ -3,7 +3,9 @@ import torch -from ._ops import ops +# Stable alias: bare `ops` is shadowed by `from . import layers` below. +from ._ops import ops as _compiled_ops +from . import ops from .grouped_gemm import backend as gg_backend from .grouped_gemm import ops as gg_ops @@ -136,7 +138,8 @@ def sort( Returns: The sorted values tensor """ - return ops.sort(x, end_bit, x_out, iota_out) + _compiled_ops.sort(x, end_bit, x_out, iota_out) + return x_out # Convenience functions for common use cases diff --git a/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py new file mode 100644 index 0000000000000000000000000000000000000000..8c101d07cea416f9390b708e5a35fdc466e48aed --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py @@ -0,0 +1,574 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + + +# Imports. +# ------------------------------------------------------------------------------ + +# Python standard library +import functools + +# Triton +import triton +import triton.language as tl + +# AITER +from ..configs import CONFIGS as _CONFIGS +from ..utils._triton import arch_info +from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd + +# Kernel config. +# ------------------------------------------------------------------------------ + + +@functools.lru_cache() +def get_config( + gmm_type: str, M: int, K: int, N: int, G: int, accumulate: bool = False +) -> dict[str, int]: + assert gmm_type in { + "gmm", + "ptgmm", + "nptgmm", + }, f"'{gmm_type}' is an invalid GMM variant." + dev = arch_info.get_arch() + assert ( + dev in _CONFIGS + ), f"No GMM configuration tuned for arch '{dev}'. Supported: {sorted(_CONFIGS)}." + arch_configs = _CONFIGS[dev] + assert ( + "default" in arch_configs[gmm_type] + ), "Default configuration is absent." + key = "accumulate" if accumulate else "default" + return arch_configs[gmm_type][key] + + +# Common code shared by GMM and TGMM kernels. +# ------------------------------------------------------------------------------ + + +# XCD remapping followed by 1D PID to 2D grid mapping. +@triton.jit +def _remap_xcd_tile_grid( + tile_in_mm, + num_row_tiles, + num_col_tiles, + GROUP_SIZE: tl.constexpr = 1, + NUM_XCDS: tl.constexpr = 8, +): + return pid_grid( + remap_xcd(tile_in_mm, num_row_tiles * num_col_tiles, NUM_XCDS=NUM_XCDS), + num_row_tiles, + num_col_tiles, + GROUP_SIZE_M=GROUP_SIZE, + ) + + +# GMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.heuristics( + { + "K_DIVISIBLE_BY_BLOCK_SIZE_K": lambda META: META["K"] % META["BLOCK_SIZE_K"] + == 0, + } +) +@triton.jit +def gmm_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_RHS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + K_DIVISIBLE_BY_BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE: tl.constexpr, + GRID_DIM: tl.constexpr, + USE_BIAS: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + # Current tile. Each program computes multiple tiles of each group. + tile = tl.program_id(0) + tl.device_assert(tile >= 0, "tile < 0 (at initialization)") + + # Tile limit of last MM problem (inclusive). + last_mm_tile = 0 + + # Last input row of lhs and output row of out. Each group reads some rows of + # lhs and writes some rows to out. + last_m = 0 + + # Loop through all (m, K, N) MM problems: + # (m, K) x (K, N) = (m, N) + # sum(m) = M + for g in range(G): + # Get m dimension of current MM problem. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty + tl.device_assert(m >= 0, "m < 0") + + num_m_tiles = tl.cdiv(m, BLOCK_SIZE_M) + # num_m_tiles can be zero if group is empty + tl.device_assert(num_m_tiles >= 0, "num_m_tiles < 0") + + num_tiles = num_m_tiles * num_n_tiles + # num_tiles can be zero if group is empty + tl.device_assert(num_tiles >= 0, "num_tiles < 0") + + # Loop through tiles of current MM problem. + while tile >= last_mm_tile and tile < last_mm_tile + num_tiles: + # Figure out tile coordinates in current MM problem. + tile_in_mm = tile - last_mm_tile + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_m, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_m_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + # Do regular MM: + + tl.device_assert(tile_m * BLOCK_SIZE_M >= 0, "tile_m * BLOCK_SIZE_M < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_m = ( + tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + ) % m + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + + lhs_ptrs = lhs_ptr + (last_m + offs_lhs_m[:, None]) * K + offs_k[None, :] + + if TRANS_RHS: + rhs_ptrs = ( + rhs_ptr + + g.to(tl.int64) * K * N + + offs_k[:, None] + + offs_rhs_n[None, :] * K + ) + else: + rhs_ptrs = ( + rhs_ptr + + g.to(tl.int64) * K * N + + offs_k[:, None] * N + + offs_rhs_n[None, :] + ) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if K_DIVISIBLE_BY_BLOCK_SIZE_K: + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + else: + k_mask_limit = K - k * BLOCK_SIZE_K + lhs = tl.load( + lhs_ptrs, mask=offs_k[None, :] < k_mask_limit, other=0 + ) + rhs = tl.load( + rhs_ptrs, mask=offs_k[:, None] < k_mask_limit, other=0 + ) + + acc = tl.dot(lhs, rhs, acc=acc) + + lhs_ptrs += BLOCK_SIZE_K + + if TRANS_RHS: + rhs_ptrs += BLOCK_SIZE_K + else: + rhs_ptrs += BLOCK_SIZE_K * N + + # Add bias if enabled + if USE_BIAS: + offs_bias_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange( + 0, BLOCK_SIZE_N + ) + bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bias_n + bias = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0) + # Convert bias to float32 to match accumulator precision + bias = bias.to(tl.float32) + # Broadcast bias across M dimension and add in float32 + acc += bias[None, :] + + # Convert to output dtype after all computations + acc = acc.to(out_ptr.type.element_ty) + + offs_out_m = tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + (last_m + offs_out_m[:, None]) * N + offs_out_n[None, :] + ) + + tl.store( + out_ptrs, + acc, + mask=(offs_out_m[:, None] < m) & (offs_out_n[None, :] < N), + ) + + # Go to the next tile by advancing number of programs. + tile += GRID_DIM + tl.device_assert(tile > 0, "tile <= 0 (at update)") + + # Get ready to go to the next MM problem. + + last_mm_tile += num_tiles + # last_mm_tile can be zero if group 0 is skipped + tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)") + + last_m += m + # last_m can be zero if group 0 is skipped + tl.device_assert(last_m >= 0, "last_m < 0 (at update)") + tl.device_assert(last_m <= M, "last_m > M (at update)") + + +# Persistent TGMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.jit +def tgmm_persistent_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_grad_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_LHS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE: tl.constexpr, + GRID_DIM: tl.constexpr, + COMPUTE_BIAS_GRAD: tl.constexpr, + ACCUMULATE: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0") + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + num_tiles = num_k_tiles * num_n_tiles + tl.device_assert(num_tiles > 0, "num_tiles <= 0") + + # Current tile. Each program computes multiple tiles of each group. + tile = tl.program_id(0) + tl.device_assert(tile >= 0, "tile < 0 (at initialization)") + + # Tile limit of last MM problem (inclusive). + last_mm_tile = 0 + + # Last input column of lhs and input row of rhs. Each group reads some + # columns of lhs and some rows of rhs. + last_m = 0 + + # Loop through all (K, m, N) MM problems: + # (K, m) x (m, N) = (K, N) + # sum(m) = M + for g in range(G): + # Get m dimension of current MM problem. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty + tl.device_assert(m >= 0, "m < 0") + + # Loop through tiles of current MM problem. + while tile >= last_mm_tile and tile < last_mm_tile + num_tiles: + # Figure out tile coordinates in current MM problem. + tile_in_mm = tile - last_mm_tile + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_k, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + # Do regular MM: + + tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + + if TRANS_LHS: + lhs_ptrs = ( + lhs_ptr + offs_lhs_k[:, None] + (last_m + offs_m[None, :]) * K + ) + else: + lhs_ptrs = ( + lhs_ptr + offs_lhs_k[:, None] * M + (last_m + offs_m[None, :]) + ) + + rhs_ptrs = rhs_ptr + (last_m + offs_m[:, None]) * N + offs_rhs_n[None, :] + + loop_m = tl.cdiv(m, BLOCK_SIZE_M) + m_divisible_by_block_m = m % BLOCK_SIZE_M == 0 + if not m_divisible_by_block_m: + loop_m -= 1 + + acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32) + + # Initialize bias accumulator + bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32) + + for _ in range(0, loop_m): + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + + acc = tl.dot(lhs, rhs, acc=acc) + + # Accumulate for bias gradient: sum lhs across M dimension + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum( + lhs, axis=1 + ) # Sum across M dimension [K, M] -> [K] + + if TRANS_LHS: + lhs_ptrs += BLOCK_SIZE_M * K + else: + lhs_ptrs += BLOCK_SIZE_M + + rhs_ptrs += BLOCK_SIZE_M * N + + if not m_divisible_by_block_m: + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0) + rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0) + acc = tl.dot(lhs, rhs, acc=acc) + + # Accumulate last chunk for bias gradient + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) + + acc = acc.to(out_ptr.type.element_ty) + + offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + + g.to(tl.int64) * K * N + + offs_out_k[:, None] * N + + offs_out_n[None, :] + ) + + mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N) + if ACCUMULATE: + # Load existing values and add to them (like beta=1 in BLAS) + old_vals = tl.load(out_ptrs, mask=mask, other=0.0) + tl.store(out_ptrs, acc + old_vals, mask=mask) + else: + # Overwrite output (like beta=0 in BLAS) + tl.store(out_ptrs, acc, mask=mask) + + # Store bias gradient (only for first N tile, sum across all M) + if COMPUTE_BIAS_GRAD and tile_n == 0: + # Keep as float32 for atomic_add (bf16 not supported for atomics) + bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k + # Use atomic add since multiple K-tiles may write to same expert's bias + tl.atomic_add( + bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed" + ) + + # Go to the next tile by advancing number of programs. + tile += GRID_DIM + tl.device_assert(tile > 0, "tile <= 0 (at update)") + + # Get ready to go to the next MM problem. + + last_mm_tile += num_tiles + # last_mm_tile can be zero if group 0 is skipped + tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)") + + last_m += m + # last_m can be zero if group 0 is skipped + tl.device_assert(last_m >= 0, "last_m < 0 (at update)") + tl.device_assert(last_m <= M, "last_m > M (at update)") + + +# Regular non-persistent TGMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.heuristics({"BLOCK_SIZE_G": lambda META: triton.next_power_of_2(META["G"])}) +@triton.jit +def tgmm_non_persistent_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_grad_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_LHS: tl.constexpr, + BLOCK_SIZE_G: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE: tl.constexpr, + COMPUTE_BIAS_GRAD: tl.constexpr, + ACCUMULATE: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + # Get group ID from grid. + g = tl.program_id(0) + tl.device_assert(g >= 0, "g < 0") + tl.device_assert(g < G, "g >= G") + + # Get m dimension of current MM group. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty. + tl.device_assert(m >= 0, "m < 0") + + # Skip empty groups. + if m == 0: + return + + # Compute sum(group_sizes) until current group g. + # It's the starting column of lhs and starting row of rhs. + offs_g = tl.arange(0, BLOCK_SIZE_G) + group_sizes = tl.load(group_sizes_ptr + offs_g, mask=offs_g < g, other=0) + start_m = tl.sum(group_sizes) + + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0") + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + # Get MM tile from grid. + tile_in_mm = tl.program_id(1) + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_k, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_k = (tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K + offs_rhs_n = (tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + + if TRANS_LHS: + lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] + (start_m + offs_m[None, :]) * K + else: + lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] * M + (start_m + offs_m[None, :]) + + rhs_ptrs = rhs_ptr + (start_m + offs_m[:, None]) * N + offs_rhs_n[None, :] + + loop_m = tl.cdiv(m, BLOCK_SIZE_M) + m_divisible_by_block_m = m % BLOCK_SIZE_M == 0 + if not m_divisible_by_block_m: + loop_m -= 1 + + acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32) + # Initialize bias accumulator + bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32) + + for _ in range(0, loop_m): + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + + acc = tl.dot(lhs, rhs, acc=acc) + + # Accumulate for bias gradient: sum lhs across M dimension + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) # [K, M] -> [K] + + if TRANS_LHS: + lhs_ptrs += BLOCK_SIZE_M * K + else: + lhs_ptrs += BLOCK_SIZE_M + + rhs_ptrs += BLOCK_SIZE_M * N + + if not m_divisible_by_block_m: + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0) + rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0) + acc = tl.dot(lhs, rhs, acc=acc) + # Accumulate last chunk for bias gradient + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) + + acc = acc.to(out_ptr.type.element_ty) + + offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + g.to(tl.int64) * K * N + offs_out_k[:, None] * N + offs_out_n[None, :] + ) + + mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N) + if ACCUMULATE: + # Load existing values and add to them (like beta=1 in BLAS) + old_vals = tl.load(out_ptrs, mask=mask, other=0.0) + tl.store(out_ptrs, acc + old_vals, mask=mask) + else: + # Overwrite output (like beta=0 in BLAS) + tl.store(out_ptrs, acc, mask=mask) + + # Store bias gradient (only for first N tile, sum across all M) + if COMPUTE_BIAS_GRAD and tile_n == 0: + # Keep as float32 for atomic_add (bf16/fp16 not supported for atomics) + bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k + # Use atomic add since multiple K-tiles may write to same expert's bias + tl.atomic_add(bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed") diff --git a/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/adapter.py b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..98c224244f27445384e0c2377d73516406927536 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/adapter.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Adapt AITER's Triton grouped GEMM to MegaBlocks' ``gmm`` calling convention. + +MegaBlocks (following tgale96/grouped_gemm) uses a single ``gmm`` entry point +with ``trans_a`` / ``trans_b`` flags: + +* ``trans_a=False, trans_b=False``: a(M,K) @ b(G,K,N) -> c(M,N) +* ``trans_a=False, trans_b=True`` : a(M,K) @ b(G,N,K)^T -> c(M,N) (dgrad) +* ``trans_a=True`` : a(M,K)^T @ b(M,N) per group -> c(G,K,N) (wgrad) + +AITER exposes these as two kernels: ``gmm`` ((M,K)@(G,K,N)->(M,N), transposition +of the 3D operand inferred from strides) and ``ptgmm`` ((K,M)@(M,N)->(G,K,N), +transposition of the 2D operand inferred from strides). +""" + +import torch + +from .gmm import gmm as _aiter_gmm +from .gmm import ptgmm as _aiter_ptgmm + + +def gmm(a, b, c, batch_sizes, trans_a=False, trans_b=False): + # AITER requires group sizes to be int32 and to live on the compute device. + group_sizes = batch_sizes.to(device=a.device, dtype=torch.int32) + + # AITER asserts exact strides: gmm wants lhs/rhs row-major (a transposed + # 3D operand must be exactly column-major), tgmm wants rhs row-major and + # lhs row/column-major. Make operands contiguous first so the transposed + # views have the precise strides the kernels expect. `.contiguous()` is a + # no-op when the tensor is already contiguous. + if trans_a: + # Weight gradient: a(M,K), b(M,N) -> c(G,K,N). + # Pass a transposed so AITER sees lhs(K,M) column-major (TRANS_LHS). + _aiter_ptgmm( + a.contiguous().transpose(0, 1), + b.contiguous(), + group_sizes, + preferred_element_type=c.dtype, + existing_out=c, + ) + else: + # trans_b contracts b's last dim: pass a column-major (G,K,N) view. + rhs = b.contiguous() + if trans_b: + rhs = rhs.transpose(1, 2) + _aiter_gmm( + a.contiguous(), + rhs, + group_sizes, + preferred_element_type=c.dtype, + existing_out=c, + ) + return c diff --git a/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/configs.py b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..9a4fe5617d8100869aa76dba9b7d22c7bcab814f --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/configs.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: MIT +# Tuned GMM configs vendored from ROCm/aiter (aiter/ops/triton/configs/). +# Inlined as a Python module so packaging always includes them. + +CONFIGS = {'gfx1250': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}, 'gfx942': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 304, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}, 'gfx950': {'gmm': {'default': {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'ptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'GRID_DIM': 256, 'num_warps': 8, 'num_stages': 1}}, 'nptgmm': {'default': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 256, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}, 'accumulate': {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE': 1, 'num_warps': 8, 'num_stages': 1}}}} diff --git a/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/gmm.py b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/gmm.py new file mode 100644 index 0000000000000000000000000000000000000000..e30c9326c6d4e4836d1303e2761ea2440a7f4750 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/gmm.py @@ -0,0 +1,567 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. + + +# Imports. +# ------------------------------------------------------------------------------ + +# PyTorch +import torch +from torch import Tensor + +# Triton +import triton + +# AITER: GMM utility functions +from .utils.gmm_common import ( + DTYPE, + is_power_of_2, + check_input_device_dtype, + check_bias_shape_stride, + get_gmm_shape, + get_gmm_output, + get_gmm_transposition, + get_tgmm_shape, + get_tgmm_output, + get_tgmm_bias_grad, + get_tgmm_transposition, +) + +# AITER: GMM Triton kernels +from ._triton_kernels.gmm import ( + gmm_kernel, + tgmm_persistent_kernel, + tgmm_non_persistent_kernel, + get_config, +) + +# GMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _gmm_grid( + N: int, + block_size_m: int, + block_size_n: int, + group_sizes: Tensor, + grid_dim: int, +) -> tuple[int]: + assert N > 0, f"N must be positive, it's {N}." + assert is_power_of_2( + block_size_m + ), f"M-dimension tile size must be a power of 2 (it's {block_size_m})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + assert torch.all(group_sizes >= 0).item(), "All group_sizes must be non-negative." + assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})." + num_m_tiles = (group_sizes + block_size_m - 1) // block_size_m + assert torch.all(num_m_tiles >= 0).item(), "All num_m_tiles must be non-negative." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles = torch.sum(num_m_tiles * num_n_tiles).item() + assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}." + num_programs = int(min(grid_dim, num_tiles)) + assert num_programs > 0, f"num_programs must be positive, it's {num_programs}." + return (num_programs,) + + +def gmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias: Tensor | None = None, +) -> Tensor: + """ + Perform Group Matrix Multiplication (GMM): out = lhs @ rhs + bias + + lhs rows are divided into G groups. Each group of lhs rows is matrix multiplied with a plane of + rhs 3D tensor and then stored in a slice of out. In PyTorch parlance, it can be implemented as + follows for a given group g: + out[group_start:group_end, :] = lhs[group_start:group_end, :] @ rhs[g] + bias[g] + + The size of each group, and their respective start and end positions are specified by + group_sizes tensor. For instance, suppose that group_sizes = [3, 2, 4, 1]. In this particular + case we have 4 groups. The 1st group starts at 0 and ends at 2, the second group starts at 3 and + ends at 4, the third group starts at 5 and ends at 8, and the fourth and final group consists of + just the 10th (last) row of lhs. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (M, K). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 3D input tensor. Shape: (G, K, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (M, N), its data type must match preferred_element_type + and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias : torch.Tensor or None, optional + Optional bias tensor. Shape: (G, N). + If provided, bias data type must match lhs and rhs data type, and bias must be on the same + device as other input tensors. Each group g adds bias[g] to the output. + + Returns + ------- + torch.Tensor + The computed output 2D tensor. Shape: (M, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - GMM is implemented with a persistent Triton kernel. + - lhs must be row-major (lhs.stride() == (K, 1)). + - rhs can be row-major (rhs.stride() == (K * N, N, 1)) or column-major (rhs.stride() == + (K * N, 1, K)). If rhs is row-major then kernel parameter TRANS_RHS == False, this is useful + for implementing forward pass. If rhs is column-major then kernel parameter TRANS_RHS == True, + this is useful for computing the lhs derivative in the backward pass, while fusing the + transposition. + - out must be row-major (out.stride() == (N, 1)). + - bias must be row-major (bias.stride() == (N, 1)) if provided. + """ + use_bias = bias is not None + check_input_device_dtype(lhs, rhs, group_sizes, bias) + + M, K, N, G = get_gmm_shape(lhs, rhs, group_sizes) + + if use_bias: + check_bias_shape_stride(bias, G, N) + + out = get_gmm_output( + M, + N, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_rhs, _ = get_gmm_transposition(lhs, rhs, out) + + if config is None: + config = get_config("gmm", M, K, N, G) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + "GRID_DIM", + } + ), "Invalid GMM kernel config." + + grid = _gmm_grid( + N, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + group_sizes, + config["GRID_DIM"], + ) + + # fmt: off + gmm_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_RHS=trans_rhs, + USE_BIAS=use_bias, + **config, + ) + # fmt: on + + return out + + +# Persistent TGMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _ptgmm_grid( + K: int, + N: int, + G: int, + block_size_k: int, + block_size_n: int, + grid_dim: int, +) -> tuple[int]: + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}." + assert G > 0, f"G must be positive, it's {G}." + assert is_power_of_2( + block_size_k + ), f"K-dimension tile size must be a power of 2 (it's {block_size_k})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})." + num_k_tiles = triton.cdiv(K, block_size_k) + assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles = G * num_k_tiles * num_n_tiles + assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}." + num_programs = min(grid_dim, num_tiles) + assert num_programs > 0, f"num_programs must be positive, it's {num_programs}." + return (num_programs,) + + +def ptgmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias_grad: Tensor | None = None, + accumulate: bool = False, +) -> Tensor: + """ + Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs + + lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with + the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch + parlance, it can be implemented as follows for a given group g: + out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :] + + The 't' in the operator name derives from MaxText implementation + (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py), + which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor + shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N). + + The 'p' in the operator name means that it is implemented with a persistent kernel. There is + also the non-persistent variation, which is implemented with a regular kernel. Please take a + look at nptgmm operator. Both ptgmm and nptgmm implement the same computation, choosing one or + the other is a matter of performance for the target workload. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (K, M). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 2D input tensor. Shape: (M, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (G, K, N), its data type must match + preferred_element_type and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias_grad : torch.Tensor or None, optional + Optional bias gradient output tensor. Shape: (G, K). + If provided, the kernel will compute the bias gradient and write it to this tensor. + bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), + accumulate : bool, optional + Whether to accumulate into existing output tensor values. Default is False. + If False, output will be overwritten with fresh computation. + If True, results will be added to existing output tensor values. + + Returns + ------- + torch.Tensor + The computed output 3D tensor. Shape: (G, K, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - PTGMM is implemented with a persistent Triton kernel. + - lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs + is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel + parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward + pass, while fusing the transposition. + - rhs must be row-major (rhs.stride() == (N, 1)). + - out must be row-major (out.stride() == (K * N, N, 1)). + """ + check_input_device_dtype(lhs, rhs, group_sizes) + + M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes) + + out = get_tgmm_output( + K, + N, + G, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out) + + if config is None: + config = get_config("ptgmm", M, K, N, G, accumulate) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + "GRID_DIM", + } + ), "Invalid PTGMM kernel config." + + # Bias gradient handling. + # ----------------------- + # Get or validate bias gradient tensor. + compute_bias_grad = bias_grad is not None + bias_grad_ptr = get_tgmm_bias_grad( + K, + G, + device=lhs.device, + existing_bias_grad=bias_grad, + ) + + grid = _ptgmm_grid( + K, + N, + G, + config["BLOCK_SIZE_K"], + config["BLOCK_SIZE_N"], + config["GRID_DIM"], + ) + + # fmt: off + tgmm_persistent_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias_grad_ptr, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_LHS=trans_lhs, + COMPUTE_BIAS_GRAD=compute_bias_grad, + ACCUMULATE=accumulate, + **config, + ) + # fmt: on + + return out + + +# Regular non-persistent TGMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _nptgmm_grid( + K: int, + N: int, + G: int, + block_size_k: int, + block_size_n: int, +) -> tuple[int, int]: + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}." + assert G > 0, f"G must be positive, it's {G}." + assert is_power_of_2( + block_size_k + ), f"K-dimension tile size must be a power of 2 (it's {block_size_k})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + num_k_tiles = triton.cdiv(K, block_size_k) + assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles_per_mm = num_k_tiles * num_n_tiles + assert ( + num_tiles_per_mm > 0 + ), f"num_tiles_per_mm must be positive, it's {num_tiles_per_mm}." + return (G, num_tiles_per_mm) + + +def nptgmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias_grad: Tensor | None = None, + accumulate: bool = False, +) -> Tensor: + """ + Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs + + lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with + the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch + parlance, it can be implemented as follows for a given group g: + out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :] + + The 't' in the operator name derives from MaxText implementation + (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py), + which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor + shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N). + + The 'np' in the operator name means that it is implemented with a non-persistent, i.e. regular + kernel. There is also the persistent variation, which is implemented with a persistent kernel. + Please take a look at ptgmm operator. Both nptgmm and ptgmm implement the same computation, + choosing one or the other is a matter of performance for the target workload. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (K, M). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 2D input tensor. Shape: (M, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (G, K, N), its data type must match + preferred_element_type and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias_grad : torch.Tensor or None, optional + Optional bias gradient output tensor. Shape: (G, K). + If provided, the kernel will compute the bias gradient and write it to this tensor. + bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), + accumulate : bool, optional + Whether to accumulate into existing output tensor values. Default is False. + If False, output will be overwritten with fresh computation. + If True, results will be added to existing output tensor values. + + Returns + ------- + torch.Tensor + The computed output 3D tensor. Shape: (G, K, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - NPTGMM is implemented with a non-persistent regular Triton kernel. + - lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs + is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel + parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward + pass, while fusing the transposition. + - rhs must be row-major (rhs.stride() == (N, 1)). + - out must be row-major (out.stride() == (K * N, N, 1)). + """ + check_input_device_dtype(lhs, rhs, group_sizes) + + M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes) + + out = get_tgmm_output( + K, + N, + G, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out) + + # Bias gradient handling. + # ----------------------- + # Get or validate bias gradient tensor. + compute_bias_grad = bias_grad is not None + bias_grad_ptr = get_tgmm_bias_grad( + K, + G, + device=lhs.device, + existing_bias_grad=bias_grad, + ) + + if config is None: + config = get_config("nptgmm", M, K, N, G, accumulate) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + } + ), "Invalid NPTGMM kernel config." + + grid = _nptgmm_grid( + K, + N, + G, + config["BLOCK_SIZE_K"], + config["BLOCK_SIZE_N"], + ) + + # fmt: off + tgmm_non_persistent_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias_grad_ptr, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_LHS=trans_lhs, + COMPUTE_BIAS_GRAD=compute_bias_grad, + ACCUMULATE=accumulate, + **config, + ) + # fmt: on + + return out diff --git a/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/utils/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py new file mode 100644 index 0000000000000000000000000000000000000000..3f6c88581a64044518125623f116082c53bd5474 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py @@ -0,0 +1,46 @@ +import triton + +# Detect the GPU arch lazily: querying the triton driver at import time fails +# in headless environments (e.g. the kernel-builder ABI check sandbox has no +# GPU), and the original JAX fallback pulled in an unrelated runtime dep. The +# arch is only actually needed when a GMM kernel is dispatched, so resolve and +# cache on first call. +_CACHED_ARCH = None + + +def get_arch(): + global _CACHED_ARCH + if _CACHED_ARCH is not None: + return _CACHED_ARCH + try: + _CACHED_ARCH = triton.runtime.driver.active.get_current_target().arch + except RuntimeError: + try: + from jax._src.lib import gpu_triton as triton_kernel_call_lib + _CACHED_ARCH = triton_kernel_call_lib.get_arch_details("0").split(":")[0] + except ImportError as e: + raise RuntimeError( + "Cannot determine GPU arch: triton driver is inactive and " + "JAX is not available. A GPU is required for grouped GEMM." + ) from e + return _CACHED_ARCH + + +def is_gluon_avail(): + return get_arch() in ("gfx950", "gfx1250") + + +def is_fp4_avail(): + return get_arch() in ("gfx950", "gfx1250") + + +def is_fp8_avail(): + return get_arch() in ("gfx942", "gfx950", "gfx1250", "gfx1200", "gfx1201") + + +def is_mx_scale_preshuffling_avail(): + return get_arch() in ("gfx950", "gfx1250") + + +def is_tdm_avail(): + return get_arch() in ("gfx1250",) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..99792bb3ba2fab8fff223bba733ced1eb6e6df53 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: MIT + +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl + + +@triton.jit +def remap_xcd_chunked( + pid, GRID_MN, NUM_XCDS: tl.constexpr = 8, CHUNK_SIZE: tl.constexpr = 2 +): + # Compute current XCD and local PID + xcd = pid % NUM_XCDS + # distribute the modulo pids in round robin + if pid > (GRID_MN // (NUM_XCDS * CHUNK_SIZE)) * (NUM_XCDS * CHUNK_SIZE): + return pid + local_pid = pid // NUM_XCDS + # Calculate chunk index and position within chunk + chunk_idx = local_pid // CHUNK_SIZE + pos_in_chunk = local_pid % CHUNK_SIZE + # Calculate new PID + new_pid = chunk_idx * NUM_XCDS * CHUNK_SIZE + xcd * CHUNK_SIZE + pos_in_chunk + return new_pid + + +@triton.jit +def remap_xcd(pid, GRID_MN, NUM_XCDS: tl.constexpr = 8): + ## pid remapping on xcds + # Number of pids per XCD in the new arrangement + pids_per_xcd = (GRID_MN + NUM_XCDS - 1) // NUM_XCDS + # When GRID_MN cannot divide NUM_XCDS, some xcds will have + # pids_per_xcd pids, the other will have pids_per_xcd - 1 pids. + # We calculate the number of xcds that have pids_per_xcd pids as + # tall_xcds + tall_xcds = GRID_MN % NUM_XCDS + tall_xcds = NUM_XCDS if tall_xcds == 0 else tall_xcds + # Compute current XCD and local pid within the XCD + xcd = pid % NUM_XCDS + local_pid = pid // NUM_XCDS + # Calculate new pid based on the new grouping + # Note that we need to consider the following two cases: + # 1. the current pid is on a tall xcd + # 2. the current pid is on a short xcd + if xcd < tall_xcds: + pid = xcd * pids_per_xcd + local_pid + else: + pid = ( + tall_xcds * pids_per_xcd + + (xcd - tall_xcds) * (pids_per_xcd - 1) + + local_pid + ) + + return pid + + +@triton.jit +def pid_grid(pid: int, num_pid_m: int, num_pid_n: int, GROUP_SIZE_M: tl.constexpr = 1): + """ + Maps 1D pid to 2D grid coords (pid_m, pid_n). + + Args: + - pid: 1D pid + - num_pid_m: grid m size + - num_pid_n: grid n size + - GROUP_SIZE_M: tl.constexpr: default is 1 + """ + if GROUP_SIZE_M == 1: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + tl.assume(group_size_m >= 0) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + return pid_m, pid_n + + +@triton.jit +def pid_grid_3d(pid: int, num_pid_m: int, num_pid_n: int, num_pid_k): + """ + Maps 1D pid to 3D grid coords (pid_m, pid_n, pid_k). + Args: + - pid: 1D pid + - num_pid_m: grid m size + - num_pid_n: grid n size + - num_pid_k: grid k size + + Returns: + - pid_m, pid_n, pid_k: 3D grid coordinates + """ + pid_m = pid % num_pid_m + pid_n = (pid // num_pid_m) % num_pid_n + pid_k = pid // (num_pid_m * num_pid_n) % num_pid_k + + return pid_m, pid_n, pid_k diff --git a/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py new file mode 100644 index 0000000000000000000000000000000000000000..153dee65b50ab5f833262481889d2184d1ca639f --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py @@ -0,0 +1,752 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. + +# Imports. +# ------------------------------------------------------------------------------ + +# PyTorch +import torch +from torch import Tensor + +# AITER: logging +from .logger import AiterTritonLogger + +_LOGGER: AiterTritonLogger = AiterTritonLogger() + + +# Supported data types. +# ------------------------------------------------------------------------------ + +# Supported data types, as strings. +SUPPORTED_DTYPES_STR: set[str] = {"fp16", "bf16"} + + +# Convert string data type to PyTorch data type. +def dtype_from_str(dtype_str: str) -> torch.dtype: + dtype_str = dtype_str.strip().lower() + dtype_str = dtype_str[1:] if dtype_str[0] in {"i", "o"} else dtype_str + assert ( + dtype_str in SUPPORTED_DTYPES_STR + ), "String data type isn't in set of supported string data types." + return {"fp16": torch.float16, "bf16": torch.bfloat16}[dtype_str] + + +# Supported data types, as PyTorch types. +SUPPORTED_DTYPES: set[torch.dtype] = { + dtype_from_str(dtype_str) for dtype_str in SUPPORTED_DTYPES_STR +} + + +# Convert PyTorch data type to string data type. +def str_from_dtype(dtype: torch.dtype) -> str: + assert ( + dtype in SUPPORTED_DTYPES + ), "PyTorch data type isn't in set of supported PyTorch data types." + return {torch.float16: "fp16", torch.bfloat16: "bf16"}[dtype] + + +# Default data type, as string. +DTYPE_STR: str = "bf16" +assert ( + DTYPE_STR in SUPPORTED_DTYPES_STR +), "Default string data type isn't in set of supported string data types." + + +# Default data type, as PyTorch type. +DTYPE: torch.dtype = dtype_from_str(DTYPE_STR) + + +# Other defaults. +# ------------------------------------------------------------------------------ + +# Default device. +DEVICE: torch.device | str = "cuda" + +# Default RNG seed for input generation. +RNG_SEED: int = 0 + +# Default number of group sizes. +NUM_GROUP_SIZES: int = 1 + +# Default transposition (NN). +TRANS_LHS: bool = False +TRANS_RHS: bool = False + + +# Parameter checking functions. +# ------------------------------------------------------------------------------ + + +def is_power_of_2(x: int) -> bool: + return (x > 0) and (x & (x - 1) == 0) + + +def check_input_device_dtype( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor, bias: Tensor | None = None +) -> None: + assert ( + lhs.device == rhs.device == group_sizes.device + ), f"All input tensors must be in the same device (lhs = {lhs.device}, rhs = {rhs.device}, group_sizes = {group_sizes.device})." + assert ( + lhs.dtype == rhs.dtype + ), f"lhs and rhs types must match (lhs = {lhs.dtype}, rhs = {rhs.dtype})." + assert group_sizes.dtype == torch.int32, "group_sizes type must be int32." + + if bias is not None: + assert ( + bias.device == lhs.device + ), f"bias must be on the same device as lhs (bias = {bias.device}, lhs = {lhs.device})." + assert ( + bias.dtype == lhs.dtype + ), f"bias dtype must match lhs dtype (bias = {bias.dtype}, lhs = {lhs.dtype})." + + +def check_bias_shape_stride(bias: Tensor, G: int, N: int) -> None: + assert bias.shape == ( + G, + N, + ), f"bias must have shape (G, N) = ({G}, {N}), got {bias.shape}." + assert bias.stride() == (N, 1), "bias must be row-major (bias.stride() == (N, 1))." + + +# Generation of group sizes. +# ------------------------------------------------------------------------------ + + +# Probabilities for generating random group sizes. +UNUSED_TOKENS_PROB: float = 0.0 +UNUSED_EXPERTS_PROB: float = 0.1 + + +def gen_uniform_group_sizes( + M: int, + G: int, + device: torch.device | str = DEVICE, +) -> Tensor: + assert M >= 0, f"Number of tokens M must be non-negative (it's {M})." + assert G > 0, f"Number of experts G must be positive (it's {G})." + + base = M // G + remainder = M % G + group_sizes = torch.full((G,), base, dtype=torch.int32, device=device) + if remainder > 0: + group_sizes[:remainder] += 1 + + assert ( + len(group_sizes) == G + ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})." + assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative." + assert ( + torch.sum(group_sizes).item() == M + ), f"Group sizes don't add up to total tokens {M}." + assert group_sizes.dtype == torch.int32, "Group sizes must be int32." + + return group_sizes + + +def gen_group_sizes( + M: int, + G: int, + device: torch.device | str = DEVICE, + rng_seed: int | None = RNG_SEED, + unused_tokens_prob: float = UNUSED_TOKENS_PROB, + unused_experts_prob: float = UNUSED_EXPERTS_PROB, +) -> Tensor: + assert M >= 0, f"Number of tokens M must be non-negative (it's {M})." + assert G > 0, f"Number of experts G must be positive (it's {G})." + assert ( + 0 <= unused_tokens_prob <= 1 + ), f"Probability of unused tokens must be in [0, 1] interval (it's {unused_tokens_prob})." + assert ( + 0 <= unused_experts_prob <= 1 + ), f"Probability of unused experts must be in [0, 1] interval (it's {unused_experts_prob})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + if unused_tokens_prob > 0: + # Optionally drop tokens to simulate routing sparsity, some tokens may not be routed. + num_unused_tokens = M + while num_unused_tokens == M: + num_unused_tokens = int( + torch.binomial( + torch.tensor(float(M), device=device), + torch.tensor(unused_tokens_prob, device=device), + ).item() + ) + else: + num_unused_tokens = 0 + num_used_tokens = M - num_unused_tokens + assert ( + num_unused_tokens >= 0 + ), f"Number of unused tokens must be non-negative (it's {num_unused_tokens})." + assert ( + num_used_tokens > 0 + ), f"Number of used tokens must be positive (it's {num_used_tokens})." + assert ( + num_used_tokens + num_unused_tokens == M + ), f"Unused + used tokens don't add up total tokens ({num_used_tokens} + {num_unused_tokens} != {M})." + + if num_unused_tokens > 0: + _LOGGER.debug( + f"Group sizes generation: dropped {num_unused_tokens} token{'s' if num_unused_tokens > 1 else ''}.", + ) + + if unused_experts_prob > 0: + # Some experts may have zero tokens assigned to them. + num_used_experts = 0 + while num_used_experts == 0: + used_experts = torch.nonzero( + torch.rand((G,), device=device) >= unused_experts_prob + ).squeeze() + num_used_experts = used_experts.numel() + else: + used_experts = torch.arange(0, G, device=device) + num_used_experts = G + num_unused_experts = G - num_used_experts + assert ( + num_unused_experts >= 0 + ), f"Number of unused experts must be non-negative (it's {num_unused_experts})." + assert ( + num_used_experts >= 1 + ), f"At least one expert must be used (it's {num_used_experts})." + assert ( + num_unused_experts + num_used_experts == G + ), f"Unused + used experts don't add up total experts ({num_unused_experts} + {num_used_experts} != {G})." + + if num_unused_experts > 0: + _LOGGER.debug( + f"Group sizes generation: dropped {num_unused_experts} expert{'s' if num_unused_experts > 1 else ''}.", + ) + + group_sizes = torch.bincount( + used_experts[ + torch.randint(low=0, high=num_used_experts, size=(num_used_tokens,)) + ], + minlength=G, + ).to(torch.int32) + + assert ( + len(group_sizes) == G + ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})." + assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative." + assert ( + torch.sum(group_sizes).item() == num_used_tokens + ), f"Group sizes don't add up to used tokens {num_used_tokens}." + assert group_sizes.dtype == torch.int32, "Group sizes must be int32." + + return group_sizes + + +def gen_multiple_group_sizes( + num_group_sizes: int, + M: int, + G: int, + device: torch.device | str = DEVICE, + rng_seed: int | None = RNG_SEED, + unused_tokens_prob: float = UNUSED_TOKENS_PROB, + unused_experts_prob: float = UNUSED_EXPERTS_PROB, + group_sizes_0: Tensor | None = None, +) -> list[Tensor]: + assert ( + num_group_sizes > 0 + ), f"Number of group sizes to be generated must be positive, it's {num_group_sizes}." + multiple_group_sizes = [ + gen_group_sizes( + M, + G, + device=device, + rng_seed=rng_seed if g == 0 else None, + unused_tokens_prob=unused_tokens_prob, + unused_experts_prob=unused_experts_prob, + ) + for g in range( + num_group_sizes if group_sizes_0 is None else num_group_sizes - 1 + ) + ] + if group_sizes_0 is not None: + multiple_group_sizes.insert(0, group_sizes_0) + assert ( + len(multiple_group_sizes) == num_group_sizes + ), f"Expecting {num_group_sizes} distinct group sizes (it's {len(multiple_group_sizes)})." + return multiple_group_sizes + + +# GMM helpers: tensor generation. +# ------------------------------------------------------------------------------ + + +def gen_gmm_input( + M: int, + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + trans_rhs: bool = TRANS_RHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, +) -> tuple[Tensor, Tensor, Tensor]: + assert M > 0, f"Number of lhs rows M must be positive (M = {M})." + assert K > 0, f"Number of lhs columns / rhs rows K must be positive (K = {K})." + assert N > 0, f"Number of rhs columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + lhs = torch.randn((M, K), dtype=torch.float32, device=device) + lhs = lhs.to(preferred_element_type) + + if trans_rhs: + rhs = torch.randn((G, N, K), dtype=torch.float32, device=device).permute( + 0, 2, 1 + ) + else: + rhs = torch.randn((G, K, N), dtype=torch.float32, device=device) + rhs = rhs.to(preferred_element_type) + + group_sizes = ( + gen_uniform_group_sizes(M, G, device=device) + if unif_group_sizes + else gen_group_sizes(M, G, device=device, rng_seed=None) + ) + + return lhs, rhs, group_sizes + + +def gen_gmm_output( + M: int, + N: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, +) -> Tensor: + assert M > 0, f"Number of out rows M must be positive (M = {M})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + + out = torch.empty((M, N), dtype=preferred_element_type, device=device) + + return out + + +def gen_gmm_tensors( + M: int, + K: int, + N: int, + G: int, + num_group_sizes: int, + device: torch.device | str = DEVICE, + input_type: torch.dtype = DTYPE, + output_type: torch.dtype = DTYPE, + trans_lhs: bool = False, + trans_rhs: bool = TRANS_RHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, + use_bias: bool = False, +) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]: + lhs, rhs, group_sizes_0 = gen_gmm_input( + M, + K, + N, + G, + device=device, + preferred_element_type=input_type, + trans_rhs=trans_rhs, + rng_seed=rng_seed, + unif_group_sizes=unif_group_sizes, + ) + multiple_group_sizes = gen_multiple_group_sizes( + num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0 + ) + out = gen_gmm_output(M, N, device=device, preferred_element_type=output_type) + bias = None + if use_bias: + torch.manual_seed(rng_seed + 1000) # Different seed for bias + bias = torch.randn(G, N, dtype=input_type, device=device) + + return lhs, rhs, multiple_group_sizes, out, bias + + +# GMM helpers: get information from tensors. +# ------------------------------------------------------------------------------ + + +def get_gmm_shape( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor +) -> tuple[int, int, int, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})." + assert ( + group_sizes.dim() == 1 + ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})." + + M, lhs_k = lhs.shape + rhs_g, rhs_k, N = rhs.shape + group_sizes_g = group_sizes.shape[0] + + assert ( + lhs_k == rhs_k + ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})." + K = lhs_k + assert ( + rhs_g == group_sizes_g + ), f"G dimension of rhs and group_sizes don't match (rhs = {rhs_g}, group_sizes = {group_sizes_g})." + G = rhs_g + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + return M, K, N, G + + +def get_gmm_output( + M: int, + N: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, +) -> Tensor: + assert M > 0, f"Number of out rows M must be positive (M = {M})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + + if existing_out is not None: + assert ( + existing_out.device == device + ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})." + assert ( + existing_out.dtype == preferred_element_type + ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})." + assert existing_out.shape == ( + M, + N, + ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(M, N)})." + return existing_out + + return gen_gmm_output( + M, + N, + device=device, + preferred_element_type=preferred_element_type, + ) + + +def get_gmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})." + assert out.dim() == 2, f"out must have 2 dimensions (it's {out.dim()})." + + lhs_m, lhs_k = lhs.shape + G, rhs_k, rhs_n = rhs.shape + out_m, out_n = out.shape + + assert ( + lhs_m == out_m + ), f"M dimension of lhs and out don't match (lhs = {lhs_m}, rhs = {out_m})." + M = lhs_m + assert ( + lhs_k == rhs_k + ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})." + K = lhs_k + assert ( + rhs_n == out_n + ), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})." + N = rhs_n + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + is_lhs_row_major = lhs.stride() == (K, 1) + assert is_lhs_row_major, "lhs must be row-major." + is_rhs_row_major = rhs.stride() == (K * N, N, 1) + is_rhs_col_major = rhs.stride() == (K * N, 1, K) + assert ( + is_rhs_row_major != is_rhs_col_major + ), "rhs must be row-major or column-major." + is_out_row_major = out.stride() == (N, 1) + assert is_out_row_major, "out must be row-major." + + # Get rhs leading dimension according to transposition configuration. + ld_rhs = N if is_rhs_row_major else K + + return is_rhs_col_major, ld_rhs + + +# TGMM helpers: tensor generation. +# ------------------------------------------------------------------------------ + + +def gen_tgmm_input( + M: int, + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + trans_lhs: bool = TRANS_LHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, +) -> tuple[Tensor, Tensor, Tensor]: + assert K > 0, f"Number of lhs rows K must be positive (M = {K})." + assert M > 0, f"Number of lhs columns / rhs rows M must be positive (K = {M})." + assert N > 0, f"Number of rhs columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + if trans_lhs: + lhs = torch.randn((M, K), dtype=torch.float32, device=device).T + else: + lhs = torch.randn((K, M), dtype=torch.float32, device=device) + lhs = lhs.to(preferred_element_type) + + rhs = torch.randn((M, N), dtype=torch.float32, device=device) + rhs = rhs.to(preferred_element_type) + + group_sizes = ( + gen_uniform_group_sizes(M, G, device=device) + if unif_group_sizes + else gen_group_sizes(M, G, device=device, rng_seed=None) + ) + + return lhs, rhs, group_sizes + + +def gen_tgmm_output( + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, +) -> Tensor: + assert K > 0, f"Number of out rows K must be positive (K = {K})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + out = torch.empty((G, K, N), dtype=preferred_element_type, device=device) + + return out + + +def gen_tgmm_bias_grad( + K: int, + G: int, + device: torch.device | str = DEVICE, + with_bias_grad: bool = False, +) -> Tensor: + if with_bias_grad: + assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + return torch.empty((G, K), device=device, dtype=torch.float32) + else: + # Return dummy pointer when bias_grad is not needed. + # Must be float32 because atomic_add does not support bf16/fp16, + # and Triton validates the pointer dtype even in dead branches. + return torch.tensor([], device=device, dtype=torch.float32) + + +def gen_tgmm_tensors( + M: int, + K: int, + N: int, + G: int, + num_group_sizes: int, + device: torch.device | str = DEVICE, + input_type: torch.dtype = DTYPE, + output_type: torch.dtype = DTYPE, + trans_lhs: bool = TRANS_LHS, + trans_rhs: bool = False, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, + use_bias: bool = False, +) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]: + lhs, rhs, group_sizes_0 = gen_tgmm_input( + M, + K, + N, + G, + device=device, + preferred_element_type=input_type, + trans_lhs=trans_lhs, + rng_seed=rng_seed, + unif_group_sizes=unif_group_sizes, + ) + multiple_group_sizes = gen_multiple_group_sizes( + num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0 + ) + out = gen_tgmm_output(K, N, G, device=device, preferred_element_type=output_type) + if use_bias: + bias_grad = gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=True) + else: + bias_grad = None + return lhs, rhs, multiple_group_sizes, out, bias_grad + + +# TGMM helpers: get information from tensors. +# ------------------------------------------------------------------------------ + + +def get_tgmm_shape( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor +) -> tuple[int, int, int, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})." + assert ( + group_sizes.dim() == 1 + ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})." + + K, lhs_m = lhs.shape + rhs_m, N = rhs.shape + G = group_sizes.shape[0] + + assert ( + lhs_m == rhs_m + ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})." + M = lhs_m + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + return M, K, N, G + + +def get_tgmm_output( + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, +) -> Tensor: + assert K > 0, f"Number of out rows K must be positive (K = {K})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if existing_out is not None: + assert ( + existing_out.device == device + ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})." + assert ( + existing_out.dtype == preferred_element_type + ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})." + assert existing_out.shape == ( + G, + K, + N, + ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(G, K, N)})." + return existing_out + + return gen_tgmm_output( + K, + N, + G, + device=device, + preferred_element_type=preferred_element_type, + ) + + +def get_tgmm_bias_grad( + K: int, + G: int, + device: torch.device | str = DEVICE, + existing_bias_grad: Tensor | None = None, +) -> Tensor: + """ + Get or validate bias gradient tensor for TGMM. + + If existing_bias_grad is provided, validates its shape, device, dtype, and stride, + and always zeros it before returning (since the kernel uses atomic_add). + If existing_bias_grad is None, returns a dummy tensor (for use when COMPUTE_BIAS_GRAD=False). + Parameters + ---------- + K : int + Number of rows in the bias gradient tensor. + G : int + Number of groups. + device : torch.device or str + Device for the tensor. + existing_bias_grad : torch.Tensor or None + Existing bias gradient tensor to validate and use. + Returns + ------- + torch.Tensor + Valid bias gradient tensor or dummy tensor. + """ + assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if existing_bias_grad is not None: + # Validate existing bias_grad tensor. + expected_shape = (G, K) + assert ( + tuple(existing_bias_grad.shape) == expected_shape + ), f"bias_grad must have shape {expected_shape}, got {tuple(existing_bias_grad.shape)}." + assert ( + existing_bias_grad.device == device + ), f"bias_grad must be on the same device (bias_grad = {existing_bias_grad.device}, device = {device})." + assert ( + existing_bias_grad.dtype == torch.float32 + ), f"bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), got {existing_bias_grad.dtype}." + assert existing_bias_grad.stride() == ( + K, + 1, + ), f"bias_grad must be row-major with stride (K, 1) = ({K}, 1), got {existing_bias_grad.stride()}." + + # Always zero the tensor since bias_grad represents gradients for the current + # computation and should start fresh. The kernel uses atomic_add which adds to + # existing values, so we must zero before the kernel runs. + existing_bias_grad.zero_() + + return existing_bias_grad + + else: + return gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=False) + + +def get_tgmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})." + assert out.dim() == 3, f"out must have 3 dimensions (it's {out.dim()})." + + lhs_k, lhs_m = lhs.shape + rhs_m, rhs_n = rhs.shape + G, out_k, out_n = out.shape + + assert ( + lhs_m == rhs_m + ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})." + M = lhs_m + assert ( + lhs_k == out_k + ), f"K dimension of lhs and out don't match (lhs = {lhs_k}, rhs = {out_k})." + K = lhs_k + assert ( + rhs_n == out_n + ), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})." + N = rhs_n + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + is_lhs_row_major = lhs.stride() == (M, 1) + is_lhs_col_major = lhs.stride() == (1, K) + assert ( + is_lhs_row_major != is_lhs_col_major + ), "lhs must be row-major or column-major." + is_rhs_row_major = rhs.stride() == (N, 1) + assert is_rhs_row_major, "rhs must be row-major." + is_out_row_major = out.stride() == (K * N, N, 1) + assert is_out_row_major, "out must be row-major." + + # Get lhs leading dimension according to transposition configuration. + ld_lhs = M if is_lhs_row_major else K + + return is_lhs_col_major, ld_lhs diff --git a/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/utils/logger.py b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..391ddf9b6543f5244e7f4932c8568d60748e15cd --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/_grouped_gemm_triton/utils/logger.py @@ -0,0 +1,47 @@ +import os +import logging + + +# AITER Triton Logger which is singleton object around python logging. +# Note: Python logging is also a singleton object, but we want to read the +# env var AITER_LOG_LEVEL once at the beginning. Another alternative is to do +# this in __init__.py. In fact, that's how CK logger is setup. We can look at +# switching to that at some point +# +# AITER_LOG_LEVEL follows python logging levels +# DEBUG +# INFO +# WARNING +# ERROR +# CRITICAL +# +class AiterTritonLogger(object): + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(AiterTritonLogger, cls).__new__(cls) + log_level_str = os.getenv("AITER_TRITON_LOG_LEVEL", "WARNING").upper() + numeric_level = getattr(logging, log_level_str, logging.WARNING) + cls._instance._logger = logging.getLogger("AITER_TRITON") + cls._instance._logger.setLevel(numeric_level) + + return cls._instance + + def get_logger(self): + return self._logger + + def debug(self, msg): + self._logger.debug(msg) + + def info(self, msg): + self._logger.info(msg) + + def warning(self, msg): + self._logger.warning(msg) + + def error(self, msg): + self._logger.error(msg) + + def critical(self, msg): + self._logger.critical(msg) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/_megablocks_cuda_ae601bb.abi3.so b/build/torch212-cxx11-cu132-x86_64-linux/_megablocks_cuda_ae601bb.abi3.so deleted file mode 100644 index 83cd0cb8c2b0a02d06f3ca50ec6cc7d858fb64b3..0000000000000000000000000000000000000000 --- a/build/torch212-cxx11-cu132-x86_64-linux/_megablocks_cuda_ae601bb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:807a5230a22b932e3a6c4ef230bfad68b096aa7c532de2b5752b34e9769523d9 -size 10402512 diff --git a/build/torch212-cxx11-cu132-x86_64-linux/_megablocks_cuda_f8f8b50.abi3.so b/build/torch212-cxx11-cu132-x86_64-linux/_megablocks_cuda_f8f8b50.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..f0a1c79c6a0637f39294468c4569ccbbc2d2f73d --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/_megablocks_cuda_f8f8b50.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec4902f27d9e29c4debdcab985fa0ef63ff71e2c762767bf2789ddd780720f71 +size 12079912 diff --git a/build/torch212-cxx11-cu132-x86_64-linux/_ops.py b/build/torch212-cxx11-cu132-x86_64-linux/_ops.py index 8dd1b7bcf680d2d32dd4ac912487118eafcee4ea..69afb8c26a3fa2691be277b0270d600d29a5865e 100644 --- a/build/torch212-cxx11-cu132-x86_64-linux/_ops.py +++ b/build/torch212-cxx11-cu132-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_cuda_ae601bb -ops = torch.ops._megablocks_cuda_ae601bb +from . import _megablocks_cuda_f8f8b50 +ops = torch.ops._megablocks_cuda_f8f8b50 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_cuda_ae601bb::{op_name}" + return f"_megablocks_cuda_f8f8b50::{op_name}" diff --git a/build/torch212-cxx11-cu132-x86_64-linux/grouped_gemm/backend.py b/build/torch212-cxx11-cu132-x86_64-linux/grouped_gemm/backend.py index 76037d8039cbfc2f0577275c78e4bc0be762592a..c7ef28ced79c830dae934177f059c1f4ddc24aad 100644 --- a/build/torch212-cxx11-cu132-x86_64-linux/grouped_gemm/backend.py +++ b/build/torch212-cxx11-cu132-x86_64-linux/grouped_gemm/backend.py @@ -2,16 +2,16 @@ # extensions. Otherwise libc10.so cannot be found. import torch -# # TODO(tgale): Wrap this in a try-block with better -# # error message and instructions for building the -# # c++ operations. -# import grouped_gemm_backend as backend +# On ROCm there is no CUTLASS grouped GEMM; dispatch to the vendored AITER +# Triton kernels instead. On CUDA we use the compiled CUTLASS `gmm` op. +_IS_ROCM = torch.version.hip is not None -# We import the backend operations from the megablocks package as -# grouped_gemm is vendored in megablocks in this repository. -# from ... import _ops as backend -# from megablocks._ops import ops as backend # type: ignore -from .._ops import ops as backend # type: ignore +if _IS_ROCM: + from .._grouped_gemm_triton import adapter as backend +else: + # We import the backend operations from the megablocks package as + # grouped_gemm is vendored in megablocks in this repository. + from .._ops import ops as backend # type: ignore def _allocate_output(a, b, batch_sizes, trans_a, trans_b): assert not (trans_a and trans_b) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/metadata.json b/build/torch212-cxx11-cu132-x86_64-linux/metadata.json index dae1319c841f27d4cd7a5a4b31fbde6ae4d4cacd..436ad3fc85ff69b069290830671db574d1045671 100644 --- a/build/torch212-cxx11-cu132-x86_64-linux/metadata.json +++ b/build/torch212-cxx11-cu132-x86_64-linux/metadata.json @@ -1,6 +1,6 @@ { "name": "megablocks", - "id": "_megablocks_cuda_ae601bb", + "id": "_megablocks_cuda_f8f8b50", "version": 1, "license": "Apache-2.0", "python-depends": [], @@ -8,7 +8,9 @@ "type": "cuda", "archs": [ "10.0", + "11.0", "12.0", + "12.0+PTX", "7.5", "8.0", "8.6",