| | """ |
| | Test data generators for DeepGEMM kernel tests. |
| | |
| | Adapted from the original DeepGEMM test suite to work with the ported |
| | kernels-community package. |
| | """ |
| | import enum |
| | import random |
| | import torch |
| | from typing import Generator, List, Optional, Tuple |
| |
|
| | from deep_gemm.testing import get_arch_major |
| | from deep_gemm.utils import ( |
| | align, ceil_div, |
| | per_token_cast_to_fp8, per_channel_cast_to_fp8, per_block_cast_to_fp8, |
| | per_token_cast_to_fp4, transpose_packed_fp4, |
| | get_mk_alignment_for_contiguous_layout |
| | ) |
| |
|
| |
|
class KernelType(enum.Enum):
    """Kernel variants exercised by the GEMM tests.

    ``get_kernel_types`` in this module selects ``KernelNoSF`` for bf16
    inputs, ``Kernel1D2D`` when the arch major is 9 and ``Kernel1D1D``
    otherwise. The names presumably describe the scaling-factor layout used
    for A/B (``NoSF`` = no scaling factors) -- TODO confirm against kernels.
    """
    Kernel1D1D = 0
    Kernel1D2D = 1
    KernelNoSF = 2

    def is_1d1d(self):
        """Return True only for ``Kernel1D1D``."""
        return self is KernelType.Kernel1D1D

    def is_1d2d(self):
        """Return True only for ``Kernel1D2D``."""
        return self is KernelType.Kernel1D2D

    def is_nosf(self):
        """Return True only for ``KernelNoSF``."""
        return self is KernelType.KernelNoSF
| |
|
| |
|
class MajorTypeAB(enum.Enum):
    """Memory layout ("major") of a GEMM operand: K-major or MN-major.

    In this module an MN-major tensor keeps its logical shape but gets
    transposed storage (``x.T.contiguous().T`` / ``x.mT.contiguous().mT``).
    """
    KMajor = 0
    MNMajor = 1

    def is_k_major(self):
        """Return True only for ``KMajor``."""
        return self is MajorTypeAB.KMajor

    def is_mn_major(self):
        """Return True only for ``MNMajor``."""
        return self is MajorTypeAB.MNMajor
| |
|
| |
|
class QuantConfig:
    """Quantization settings for the A and B operands of a GEMM test.

    Wraps the tuple ``(gran_k_a, gran_k_b, is_fp4_a, is_fp4_b)``:
    ``gran_k_*`` is the K-dimension granularity handed to the cast helpers
    for that operand, and ``is_fp4_*`` selects an FP4 cast instead of FP8.
    The default ``(128, 128, False, False)`` is the "legacy" configuration.
    """

    _legacy_quant_config = (128, 128, False, False)

    def __init__(self, value: Tuple[int, int, bool, bool] = _legacy_quant_config):
        self.gran_k_a, self.gran_k_b, self.is_fp4_a, self.is_fp4_b = value

    def _as_tuple(self) -> Tuple[int, int, bool, bool]:
        # Same field order as the ``value`` tuple accepted by __init__.
        return self.gran_k_a, self.gran_k_b, self.is_fp4_a, self.is_fp4_b

    def __repr__(self) -> str:
        # Round-trippable, e.g. ``QuantConfig((128, 32, False, True))``.
        return f'{type(self).__name__}({self._as_tuple()!r})'

    def is_legacy(self) -> bool:
        """Return True when the settings equal the legacy default."""
        return self._as_tuple() == self._legacy_quant_config

    def get_recipes(self, is_wgrad: bool = False) -> Tuple[Optional[Tuple[int, ...]],
                                                           Optional[Tuple[int, int]],
                                                           Optional[Tuple[int, int]]]:
        """Return ``(recipe, recipe_a, recipe_b)``; any entry may be None.

        Legacy config: only ``recipe`` is set -- ``(1, 1, 128)`` for wgrad,
        otherwise None. Non-legacy: ``recipe_a`` is ``(1, gran_k_a)``;
        ``recipe_b`` is ``(1, gran_k_b)`` when B is FP4 or for wgrad, and
        ``(gran_k_b, gran_k_b)`` otherwise. These are presumably the
        scaling-factor recipes passed to the kernels -- confirm at call sites.
        """
        recipe, recipe_a, recipe_b = None, None, None
        if self.is_legacy():
            recipe = (1, 1, 128) if is_wgrad else None
        else:
            recipe_a = (1, self.gran_k_a)
            recipe_b = (1, self.gran_k_b) if self.is_fp4_b or is_wgrad else (self.gran_k_b, self.gran_k_b)
        return recipe, recipe_a, recipe_b

    def max_diff(self) -> float:
        """Return the comparison threshold: 0.02 if both operands are FP4,
        0.01 if exactly one is, 0.001 otherwise."""
        if self.is_fp4_a and self.is_fp4_b:
            return 0.02
        if self.is_fp4_a or self.is_fp4_b:
            return 0.01
        return 0.001

    @staticmethod
    def get_list_from_dtype(dtype: torch.dtype) -> List[Optional['QuantConfig']]:
        """Return the configs to test for ``dtype``.

        bf16 needs no quantization (``[None]``). Otherwise the legacy config
        is always included, plus ``(128, 32, False, True)`` (FP4 B) when the
        arch major is 10.
        """
        if dtype == torch.bfloat16:
            return [None]
        quant_config_list = [QuantConfig()]
        if get_arch_major() == 10:
            quant_config_list.append(QuantConfig((128, 32, False, True)))
        return quant_config_list
| |
|
| |
|
def reset_seed(seed: int = 0):
    """Seed Python's ``random`` module, torch's default generator and the
    CUDA generator (via ``torch.cuda.manual_seed``) with the same value."""
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
| |
|
| |
|
def get_ue8m0_usage(kernel_type: KernelType) -> bool:
    """Return the ``use_ue8m0`` flag to pass to the cast helpers.

    Always False when the arch major is 9; otherwise True only for
    ``Kernel1D1D`` kernels.
    """
    return get_arch_major() != 9 and kernel_type.is_1d1d()
| |
|
| |
|
def get_kernel_types(dtype: torch.dtype) -> tuple:
    """Return a one-element tuple with the kernel type to test for ``dtype``.

    bf16 -> ``KernelNoSF``; otherwise ``Kernel1D2D`` when the arch major is 9
    and ``Kernel1D1D`` on any other arch.
    """
    if dtype == torch.bfloat16:
        selected = KernelType.KernelNoSF
    elif get_arch_major() == 9:
        selected = KernelType.Kernel1D2D
    else:
        selected = KernelType.Kernel1D1D
    return (selected, )
| |
|
| |
|
def get_major_ab(allow_a_mn_major: bool, allow_b_mn_major: bool) -> Generator:
    """Lazily yield every permitted ``(major_a, major_b)`` pair.

    K-major is always allowed; MN-major is offered for an operand only when
    its ``allow_*_mn_major`` flag is truthy. A is the outer loop and K-major
    comes before MN-major on both axes.
    """
    both = (MajorTypeAB.KMajor, MajorTypeAB.MNMajor)
    k_only = (MajorTypeAB.KMajor, )
    choices_a = both if allow_a_mn_major else k_only
    choices_b = both if allow_b_mn_major else k_only
    for major_a in choices_a:
        for major_b in choices_b:
            yield major_a, major_b
| |
|
| |
|
def get_psum_layout_usage() -> tuple:
    """Return the ``use_psum_layout`` values to test: ``(False, True)`` when
    the arch major is 10, otherwise only ``(False, )``."""
    if get_arch_major() != 10:
        return (False, )
    return (False, True)
| |
|
| |
|
def cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is_fp4: bool,
                            use_ue8m0: bool, use_block_cast_for_fp8: bool = False):
    """Quantize one operand to FP4 or FP8 and lay it out in the requested major.

    Args:
        x: operand to quantize (callers in this module pass 2-D tensors).
        major: ``KMajor`` returns the cast helper's result untouched;
            ``MNMajor`` re-lays out element 0 (the quantized data) and passes
            element 1 through unchanged.
        gran_k: K granularity forwarded to the cast helper.
        is_fp4: use ``per_token_cast_to_fp4`` instead of an FP8 cast.
        use_ue8m0: forwarded to the cast helper.
        use_block_cast_for_fp8: for FP8 only, use ``per_block_cast_to_fp8``
            instead of ``per_token_cast_to_fp8``.

    Returns:
        The cast helper's result; element 0 is treated as the quantized
        data and element 1 as its scaling factors.
    """
    if is_fp4:
        quantized = per_token_cast_to_fp4(x, use_ue8m0=use_ue8m0, gran_k=gran_k)
        if not major.is_k_major():
            # Transpose the packed FP4 payload, then view it back with .T.
            quantized = (transpose_packed_fp4(quantized[0]).T, quantized[1])
        return quantized

    fp8_cast = per_block_cast_to_fp8 if use_block_cast_for_fp8 else per_token_cast_to_fp8
    quantized = fp8_cast(x, use_ue8m0=use_ue8m0, gran_k=gran_k)
    if not major.is_k_major():
        # Same shape and values, but transposed (column-major) storage.
        quantized = (quantized[0].T.contiguous().T, quantized[1])
    return quantized
| |
|
| |
|
def grouped_cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is_fp4: bool,
                                    use_ue8m0: bool, use_block_cast_for_fp8: bool = False):
    """Quantize a ``(num_groups, mn, k)`` tensor one group at a time.

    Output buffers are preallocated on CUDA and filled per group by the same
    cast helpers used by ``cast_fp8_fp4_with_major``.

    FP4: the data buffer is uint8 with the packed dimension halved --
    ``(num_groups, mn, k // 2)`` for K-major, ``(num_groups, k, mn // 2)``
    for MN-major (returned as a ``.mT`` view) -- and the float32 scaling
    factors are ``(num_groups, mn, ceil(k / gran_k))``.

    FP8: data is ``float8_e4m3fn`` with the shape of ``x``. Scaling factors
    are ``(num_groups, ceil(mn / gran_k), ceil(k / gran_k))`` for block casts
    and ``(num_groups, mn, ceil(k / gran_k))`` for per-token casts. MN-major
    gets transposed storage with the same logical shape.

    Returns:
        ``(data, scaling_factors)``.
    """
    num_groups, mn, k = x.size()
    k_major = major.is_k_major()

    if is_fp4:
        packed_shape = (num_groups, mn, k // 2) if k_major else (num_groups, k, mn // 2)
        data = torch.empty(packed_shape, device='cuda', dtype=torch.uint8)
        sf = torch.empty((num_groups, mn, ceil_div(k, gran_k)), device='cuda', dtype=torch.float)
        for g in range(num_groups):
            group_out = per_token_cast_to_fp4(x[g], use_ue8m0=use_ue8m0, gran_k=gran_k)
            data[g], sf[g] = group_out if k_major else (transpose_packed_fp4(group_out[0]), group_out[1])
        return (data, sf) if k_major else (data.mT, sf)

    data = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    sf_rows = ceil_div(mn, gran_k) if use_block_cast_for_fp8 else mn
    sf = torch.empty((num_groups, sf_rows, ceil_div(k, gran_k)), device='cuda', dtype=torch.float)
    fp8_cast = per_block_cast_to_fp8 if use_block_cast_for_fp8 else per_token_cast_to_fp8
    for g in range(num_groups):
        data[g], sf[g] = fp8_cast(x[g], use_ue8m0=use_ue8m0, gran_k=gran_k)
    return (data, sf) if k_major else (data.mT.contiguous().mT, sf)
| |
|
| |
|
def generate_normal(m: int, n: int, k: int,
                    major_a: MajorTypeAB, major_b: MajorTypeAB,
                    accumulate: bool, out_dtype: torch.dtype,
                    kernel_type: KernelType,
                    use_ue8m0: bool = False, use_bf16: bool = False,
                    quant_config: Optional[QuantConfig] = None):
    """Build inputs and a reference for a plain ``D = A @ B^T (+ C)`` GEMM.

    ``a`` is ``(m, k)`` and ``b`` is ``(n, k)``, both bf16 on CUDA. With
    ``accumulate``, ``d`` starts as ``randn * 32`` and ``c`` is the *same*
    tensor as ``d``; otherwise ``d`` is uninitialized and ``c`` is None.
    ``ref_d`` is computed in float32 from the bf16 inputs, before any
    quantization, then cast to ``out_dtype``.

    With ``use_bf16`` only the requested majors are applied. Otherwise ``a``
    and ``b`` become the cast helper's quantized results. B uses a block
    cast unless the kernel is 1D1D *and* accumulating.

    Returns:
        ``(a, b, c, d, ref_d)``.
    """
    a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
    b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
    if accumulate:
        d = torch.randn((m, n), device='cuda', dtype=out_dtype) * 32
        c = d
    else:
        d = torch.empty((m, n), device='cuda', dtype=out_dtype)
        c = None
    addend = c if accumulate else 0
    ref_d = (a.float() @ b.float().t() + addend).to(out_dtype)

    if use_bf16:
        if not major_a.is_k_major():
            a = a.T.contiguous().T
        if not major_b.is_k_major():
            b = b.T.contiguous().T
        return a, b, c, d, ref_d

    if quant_config is None:
        quant_config = QuantConfig()
    a = cast_fp8_fp4_with_major(a, major_a, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0)
    block_cast_b = not (kernel_type.is_1d1d() and accumulate)
    b = cast_fp8_fp4_with_major(b, major_b, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0,
                                use_block_cast_for_fp8=block_cast_b)
    return a, b, c, d, ref_d
| |
|
| |
|
def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: int, k: int,
                                  major_a: MajorTypeAB, major_b: MajorTypeAB,
                                  use_ue8m0: bool = False, use_bf16: bool = False,
                                  use_psum_layout: bool = False,
                                  quant_config: Optional[QuantConfig] = None):
    """Build inputs for an M-grouped GEMM whose groups are stacked along M.

    Each group gets ``int(expected_m_per_group * U(0.7, 1.3))`` valid rows,
    padded up to ``get_mk_alignment_for_contiguous_layout()``. The padding
    rows of ``a`` are zeroed.

    ``grouped_layout`` (int32) is either one entry per row of ``a`` -- the
    group index for valid rows and -1 for padding -- or, with
    ``use_psum_layout``, one entry per group holding the end offset of that
    group's valid rows.

    Returns:
        ``(m, a, b, grouped_layout, d, ref_d)``. ``m`` is the padded total
        row count, ``d`` is an uninitialized bf16 output, and ``ref_d`` is the
        bf16 per-group product. Unless ``use_bf16``, ``a`` (which must be
        K-major) and ``b`` are replaced by the cast helpers' quantized
        results.
    """
    group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)]
    padded_ms = [align(group_m, get_mk_alignment_for_contiguous_layout()) for group_m in group_ms]
    m = sum(padded_ms)

    a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
    b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
    layout_len = num_groups if use_psum_layout else m
    grouped_layout = torch.empty(layout_len, device='cuda', dtype=torch.int32)
    d = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
    # Every row is overwritten below, but this randn still advances the RNG.
    ref_d = torch.randn((m, n), device='cuda', dtype=torch.bfloat16)

    row = 0
    for group_idx, (group_m, padded_m) in enumerate(zip(group_ms, padded_ms)):
        valid_end, padded_end = row + group_m, row + padded_m
        if use_psum_layout:
            grouped_layout[group_idx] = valid_end
        else:
            grouped_layout[row: valid_end] = group_idx
            grouped_layout[valid_end: padded_end] = -1
        a[valid_end: padded_end] = 0
        ref_d[row: padded_end] = a[row: padded_end] @ b[group_idx].t()
        row = padded_end

    if use_bf16:
        if not major_b.is_k_major():
            b = b.mT.contiguous().mT
        return m, a, b, grouped_layout, d, ref_d

    assert major_a.is_k_major()
    if quant_config is None:
        quant_config = QuantConfig()
    a = cast_fp8_fp4_with_major(a, major_a, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0)
    b = grouped_cast_fp8_fp4_with_major(b, major_b, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0,
                                        use_block_cast_for_fp8=True)
    return m, a, b, grouped_layout, d, ref_d
| |
|
| |
|
def layout_masked_to_psum(x: torch.Tensor, psum_m: torch.Tensor, alignment: int = 128):
    """Pack a masked ``(num_groups, max_m, n)`` tensor into a flat psum layout.

    ``psum_m[i]`` is the exclusive end row of group ``i`` in the flattened
    ``(num_groups * max_m, n)`` output. Each group starts at the previous
    group's end rounded up to ``alignment``; the default of 128 matches the
    offsets built by ``generate_m_grouped_masked``. Group ``i`` contributes
    its first ``psum_m[i] - start`` rows.

    Rows that fall in alignment gaps, and rows after the last group, are
    left uninitialized because the output comes from ``torch.empty_like``.

    Args:
        x: masked per-group tensor.
        psum_m: per-group cumulative end offsets (tensor or sequence of ints).
        alignment: row alignment applied to each group's start offset.

    Returns:
        The flattened tensor.
    """
    num_groups, max_m, _ = x.size()
    x_psum = torch.empty_like(x).view(num_groups * max_m, -1)
    # Read the offsets to the host once. Indexing a CUDA tensor for every
    # slice bound would force a device sync each time.
    ends = psum_m.tolist() if isinstance(psum_m, torch.Tensor) else list(psum_m)
    start = 0
    for i in range(num_groups):
        end = ends[i]
        x_psum[start: end] = x[i, :end - start]
        start = align(end, alignment)
    return x_psum
| |
|
| |
|
def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int,
                              use_ue8m0: bool = False, use_bf16: bool = False,
                              use_psum_layout: bool = False,
                              quant_config: Optional[QuantConfig] = None):
    """Build inputs for a masked M-grouped GEMM.

    ``a`` is ``(num_groups, max_m, k)`` and ``b`` is ``(num_groups, n, k)``,
    both bf16 on CUDA. ``ref_d`` is their per-group product computed by
    ``einsum`` over all ``max_m`` rows. ``masked_m[g]`` holds
    ``int(expected_m_per_group * U(0.7, 1.3))`` valid rows for group ``g``.
    ``psum_m[g]`` is the end offset of group ``g`` when each group starts at
    the previous end rounded up to 128. ``use_psum_layout`` is accepted but
    not read here.

    Returns:
        ``(a, b, masked_m, psum_m, d, ref_d)``. Unless ``use_bf16``, ``a`` and
        ``b`` are replaced by K-major quantized results.
    """
    a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16)
    b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
    d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16)
    ref_d = torch.einsum('gmk,gnk->gmn', a, b)

    # Sizes and offsets are computed on the host, then uploaded in one go.
    host_masked, host_psum = [], []
    for _ in range(num_groups):
        group_m = int(expected_m_per_group * random.uniform(0.7, 1.3))
        group_start = align(host_psum[-1], 128) if host_psum else 0
        host_masked.append(group_m)
        host_psum.append(group_start + group_m)
    masked_m = torch.tensor(host_masked, device='cuda', dtype=torch.int)
    psum_m = torch.tensor(host_psum, device='cuda', dtype=torch.int)
    assert masked_m.amax().item() <= max_m

    if use_bf16:
        return a, b, masked_m, psum_m, d, ref_d

    if quant_config is None:
        quant_config = QuantConfig()
    a = grouped_cast_fp8_fp4_with_major(a, MajorTypeAB.KMajor, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0)
    b = grouped_cast_fp8_fp4_with_major(b, MajorTypeAB.KMajor, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0,
                                        use_block_cast_for_fp8=True)
    return a, b, masked_m, psum_m, d, ref_d
| |
|
| |
|
def generate_k_grouped_contiguous(num_groups: int, m: int, n: int,
                                  major_a: MajorTypeAB, major_b: MajorTypeAB,
                                  ks: List[int],
                                  use_ue8m0: bool = False, use_bf16: bool = False):
    """Build inputs for a K-grouped GEMM (groups stacked along K).

    ``a`` is ``(sum(ks), m)`` and ``b`` is ``(sum(ks), n)``, both bf16 on
    CUDA. Group ``i`` covers ``ks[i]`` consecutive K rows, and
    ``ref_d[i] = c[i] + a_i^T @ b_i``. The matmul runs on the bf16 inputs and
    is then added to the float32 ``c``. ``d`` is the same tensor as ``c``.

    Without ``use_bf16``, ``a`` and ``b`` are per-channel FP8 cast. For
    K-major A/B, each group's ``(K, m)`` / ``(K, n)`` slice is stored
    transposed into a flat 1-D buffer, and the scaling factors are
    transposed. MN-major is returned as produced by the cast.

    Args:
        ks: per-group K extents; exactly ``num_groups`` entries are required.

    Returns:
        ``(k, a, b, c, d, ref_d)`` with ``k = sum(ks)``; in the FP8 path ``a``
        and ``b`` are ``(data, scaling_factors)`` pairs.
    """
    assert get_mk_alignment_for_contiguous_layout() % 128 == 0
    # With fewer sizes than groups, ref_d (from empty_like) would keep
    # uninitialized groups; with more, the loop would index past ref_d.
    assert len(ks) == num_groups, f'expected {num_groups} group sizes, got {len(ks)}'
    k = sum(ks)

    a = torch.randn((k, m), device='cuda', dtype=torch.bfloat16)
    b = torch.randn((k, n), device='cuda', dtype=torch.bfloat16)
    c = torch.randn((num_groups, m, n), device='cuda', dtype=torch.float) * 32
    d = c
    ref_d = torch.empty_like(c)

    start = 0
    for i, group_k in enumerate(ks):
        end = start + group_k
        ref_d[i] = c[i] + (a[start:end].T @ b[start:end])
        start = end

    if use_bf16:
        assert (major_a, major_b) == (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)
        return k, a, b, c, d, ref_d

    a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0)
    b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0)

    if (major_a, major_b) == (MajorTypeAB.KMajor, MajorTypeAB.KMajor):
        a, sfa = a_fp8
        b, sfb = b_fp8
        new_a = torch.empty((k * m, ), dtype=a.dtype, device=a.device)
        new_b = torch.empty((k * n, ), dtype=b.dtype, device=b.device)
        prefix = 0
        for group_k in ks:
            # Store each group's slice transposed (K contiguous) in the flat buffer.
            new_a[prefix * m : (prefix + group_k) * m] = a[prefix : prefix + group_k, ].T.flatten()
            new_b[prefix * n : (prefix + group_k) * n] = b[prefix : prefix + group_k, ].T.flatten()
            prefix += group_k
        a_fp8, b_fp8 = (new_a, sfa.T), (new_b, sfb.T)
    else:
        assert (major_a, major_b) == (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)

    return k, a_fp8, b_fp8, c, d, ref_d
| |
|