import torch
from typing import Tuple

from .tuner import jit_tuner
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout

# C++ code templates
includes = ('"deep_gemm/fp8_gemm.cuh"', )
template = """
using namespace deep_gemm;

// Templated args from Python JIT call
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};

// Make a templated GEMM
using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, 1, kNumStages, kNumTMAMulticast, GemmType::Normal>;

// Launch kernel
auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs);
auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m);
auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m);
GemmType::run(out, rhs_scales, nullptr,
              m,
              tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
              stream, num_sms, smem_size);
"""
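# For illustration only: with hypothetical tuned keys N=4096, K=7168, BLOCK_M=128, BLOCK_N=128,
# NUM_STAGES=5 and NUM_TMA_MULTICAST=2, the `{...}` placeholders above would be substituted as
#     constexpr auto N = 4096, K = 7168;
#     constexpr auto BLOCK_M = 128;
#     constexpr auto BLOCK_N = 128;
#     constexpr auto kNumStages = 5;
#     constexpr auto kNumTMAMulticast = 2;
# The remaining literal `Gemm<...>` arguments (128, 1) are the fixed block K and the number of groups.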


def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms: int) -> bool:
    if num_tma_multicast == 1:
        return True
    return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
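# A quick worked check (a sketch; 132 SMs corresponds to a full H100, other parts differ):
# with n=4096, block_n=128 and num_tma_multicast=2, 4096 % (128 * 2) == 0 and 132 % 2 == 0,
# so a 2-way multicast is legal; with n=4032 it would not be, since 4032 % 256 != 0.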


def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int:
    # Shared memory usage in bytes: a BF16 output tile, per-stage FP8 tiles of A and B,
    # per-stage FP32 scales for A, a single set of FP32 scales for B, and TMA barriers
    smem_d = block_m * block_n * 2
    smem_a_per_stage = block_m * block_k
    smem_scales_a_per_stage = block_m * 4
    smem_b_per_stage = block_n * block_k
    smem_scales_b = ceil_div(k, block_k) * 4
    smem_barrier = num_stages * 8 * 2

    smem_size = 0
    smem_size += smem_d
    smem_size += num_stages * smem_a_per_stage
    smem_size += num_stages * smem_scales_a_per_stage
    smem_size += num_stages * smem_b_per_stage
    # B scales are kept twice when `block_k` is not a multiple of `block_n`, padded to 8 bytes
    smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
    smem_size += smem_barrier
    return smem_size
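# A worked example (a sketch, not a tuned configuration): for k=7168, block_m=block_n=128 and
# num_stages=5, this gives 32768 (D tile) + 5 * (16384 + 512 + 16384) (A, A scales, B per stage)
# + 224 (B scales) + 80 (barriers) = 199472 bytes, within the 232448-byte budget used below.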


def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
                     is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int]:
    if not is_grouped_contiguous:
        # Use a smaller block M only when `m` itself is small
        block_ms = (64 if m <= 64 else 128, )
    else:
        block_ms = (get_m_alignment_for_contiguous_layout(), )
    block_ns = tuple(range(16, 129, 8))

    fix_wave_saturate = lambda x: num_sms if x == 0 else x
    get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
    get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)

    # Decide the block sizes by the number of waves
    best_block_m, best_block_n = None, None
    for block_m in block_ms:
        for block_n in block_ns:
            success = False
            num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
            if best_block_m is None or best_block_n is None:
                success = True
            elif num_waves < best_num_waves:
                success = True
            elif num_waves == best_num_waves:
                # On a tie, prefer a higher last-wave utilization, then a larger block M, then a smaller block N
                util = get_last_wave_util(block_m, block_n)
                best_util = get_last_wave_util(best_block_m, best_block_n)
                success = util > best_util or (util == best_util and (block_m > best_block_m or (block_m == best_block_m and block_n < best_block_n)))
            best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
    assert best_block_m is not None and best_block_n is not None

    # Always pick the largest number of stages that fits into shared memory
    # NOTES: when the B scales are stored twice (128 % block_n != 0), fewer stages are tried
    best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
    for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
        best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
        if best_smem_size <= sm90_capacity:
            best_num_stages = num_stages
            break
    assert best_num_stages is not None

    # Decide the number of TMA multicasts
    best_num_tma_multicast = 1
    if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1:
        best_num_tma_multicast = 2

    return best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size
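# A worked sketch (assuming a 132-SM H100; actual results depend on the device): for
# m=n=4096, k=7168, num_groups=1 this selects BLOCK_M=128 and BLOCK_N=128 (8 waves of
# 1024 tiles), then NUM_STAGES=5 (199472 bytes of shared memory, since 6 stages exceed
# the capacity) and a 2-way TMA multicast, returning (128, 128, 5, 2, 199472).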


def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
                         rhs: Tuple[torch.Tensor, torch.Tensor],
                         out: torch.Tensor) -> None:
    """
    Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
    LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
    RHS and RHS scaling factors are required to be transposed.
    The LHS scaling tensor requires a TMA-aligned transposed format; if your input does not match
    this layout, this function will transpose it with a set of slow PyTorch operations.

    Arguments:
        lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
            the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
        rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`,
            the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
        out: the BF16 output tensor of shape `[m, n]`, representing the result.
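
    Usage (a minimal sketch requiring a Hopper GPU; the all-ones scaling factors are
    placeholders for illustration, not a real quantization scheme):
        >>> m, n, k = 256, 4096, 7168
        >>> lhs = (torch.randn(m, k, device='cuda').to(torch.float8_e4m3fn),
        ...        torch.ones(m, k // 128, device='cuda', dtype=torch.float32))
        >>> rhs = (torch.randn(n, k, device='cuda').to(torch.float8_e4m3fn),
        ...        torch.ones(n // 128, k // 128, device='cuda', dtype=torch.float32))
        >>> out = torch.empty(m, n, device='cuda', dtype=torch.bfloat16)
        >>> gemm_fp8_fp8_bf16_nt(lhs, rhs, out)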
    """
    lhs, lhs_scales = lhs
    rhs, rhs_scales = rhs
    m, k = lhs.shape
    n, k_ = rhs.shape
    m_, n_ = out.shape

    # Shape requirements of the kernel
    assert n % 64 == 0 and k % 128 == 0

    # Type and shape checks
    assert m == m_ and n == n_ and k == k_
    assert n > 0 and k > 0
    assert lhs_scales.shape == (m, (k + 127) // 128)
    assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
    assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
    assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
    assert out.dtype == torch.bfloat16
    assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()

    # LHS scales must be transposed for TMA loads, but not the RHS scales
    # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if the input is not already aligned
    lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
    assert rhs_scales.is_contiguous()

    # Do nothing if `m` is zero
    if m == 0:
        return

    # Auto-tuning with compilation
    global includes, template
    num_sms = get_num_sms()
    block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
    args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
    runtime = jit_tuner.compile_and_tune(
        name='gemm_fp8_fp8_bf16_nt',
        keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
              'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast},
        space=(),
        includes=includes,
        arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
                  ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
                  ('out', torch.bfloat16), ('m', int),
                  ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
        template=template,
        args=args
    )

    # Run the kernel
    runtime(*args)