| """ |
| Tests for FP8/FP4 GEMM kernels. |
| |
| Tests correctness of: |
| - fp8_fp4_gemm_nt (and layout aliases) |
| - m_grouped_fp8_fp4_gemm_nt_contiguous |
| - m_grouped_fp8_fp4_gemm_nt_masked |
| - k_grouped_fp8_gemm_nt_contiguous / k_grouped_fp8_gemm_tn_contiguous |
| """ |
| import copy |
| import random |
| import pytest |
| import torch |
|
|
| import deep_gemm |
| from deep_gemm.testing import calc_diff, get_arch_major |
| from generators import ( |
| MajorTypeAB, KernelType, QuantConfig, |
| reset_seed, align, |
| get_kernel_types, get_ue8m0_usage, |
| generate_normal, generate_m_grouped_contiguous, |
| generate_m_grouped_masked, generate_k_grouped_contiguous, |
| layout_masked_to_psum, get_psum_layout_usage, |
| get_mk_alignment_for_contiguous_layout |
| ) |
|
|
# Whether a CUDA device is visible at collection time; checked first so that
# get_arch_major() is only consulted when CUDA is actually available.
cuda_available = torch.cuda.is_available()
# Skip marker for every test in this file: all kernels under test require an
# SM90+ (Hopper or newer) GPU.
requires_sm90 = pytest.mark.skipif(
    not cuda_available or get_arch_major() < 9,
    reason="Requires SM90+ (Hopper or newer)"
)
|
|
|
|
def _get_default_kernel_type():
    """Return the first (default) kernel type available for FP8 e4m3 inputs."""
    supported = get_kernel_types(torch.float8_e4m3fn)
    return supported[0]
|
|
|
|
| |
| |
| |
# Forward-pass GEMM cases as (m, n, k, accumulate, out_dtype) tuples,
# matching the "m,n,k,accumulate,out_dtype" parametrize signature below.
# All forward cases use bf16 output without accumulation.
FP8_GEMM_FWD_SHAPES = [
    (1, 2112, 7168, False, torch.bfloat16),
    (128, 576, 7168, False, torch.bfloat16),
    (4096, 4096, 7168, False, torch.bfloat16),
    (4096, 7168, 2048, False, torch.bfloat16),
    (128, 7168, 16384, False, torch.bfloat16),
]
|
|
|
|
@requires_sm90
@pytest.mark.parametrize("m,n,k,accumulate,out_dtype", FP8_GEMM_FWD_SHAPES)
def test_fp8_gemm_nt(m, n, k, accumulate, out_dtype):
    """Standard FP8 GEMM, NT layout (forward pass): run the kernel and
    check the output against the reference within the quant tolerance."""
    reset_seed()
    ktype = _get_default_kernel_type()
    cfg = QuantConfig()
    ue8m0 = get_ue8m0_usage(ktype)
    recipe, recipe_a, recipe_b = cfg.get_recipes()

    # Both operands K-major -> the canonical NT layout.
    a, b, c, d, ref_d = generate_normal(
        m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor,
        accumulate, out_dtype, ktype, use_ue8m0=ue8m0,
        quant_config=cfg
    )
    deep_gemm.fp8_fp4_gemm_nt(
        a, b, d, c=c, disable_ue8m0_cast=not ue8m0,
        recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b
    )

    diff = calc_diff(d, ref_d)
    assert diff < cfg.max_diff(), f"{m=}, {n=}, {k=}, {diff:.5f}"
|
|
|
|
| |
# Backward-pass-like GEMM cases as (m, n, k, accumulate, out_dtype) tuples.
# The accumulate=True case pairs with a float32 output (wgrad-style
# accumulation, see get_recipes(is_wgrad=...) in the test below).
FP8_GEMM_BWD_SHAPES = [
    (4096, 7168, 2112, False, torch.bfloat16),
    (2112, 4096, 7168, True, torch.float),
    (2112, 4096, 7168, False, torch.bfloat16),
]
|
|
|
|
@requires_sm90
@pytest.mark.parametrize("m,n,k,accumulate,out_dtype", FP8_GEMM_BWD_SHAPES)
def test_fp8_gemm_nt_backward_shapes(m, n, k, accumulate, out_dtype):
    """Test FP8 GEMM with backward-pass-like shapes.

    On SM90 (Hopper) B must be K-major and accumulation forces the 1D1D
    kernel; newer architectures take B in MN-major form with the default
    kernel type.
    """
    reset_seed()
    kernel_type = _get_default_kernel_type()
    # use_ue8m0 is derived from the default kernel type, before any
    # SM90 accumulate override below.
    use_ue8m0 = get_ue8m0_usage(kernel_type)
    quant_config = QuantConfig()
    recipe, recipe_a, recipe_b = quant_config.get_recipes(is_wgrad=accumulate)

    # Single branch for the arch-specific choices (the previous version
    # assigned major_b twice and re-tested `accumulate` at the call site).
    if get_arch_major() == 9:
        major_b = MajorTypeAB.KMajor
        if accumulate:
            kernel_type = KernelType.Kernel1D1D
    else:
        major_b = MajorTypeAB.MNMajor

    a, b, c, d, ref_d = generate_normal(
        m, n, k, MajorTypeAB.KMajor, major_b,
        accumulate, out_dtype, kernel_type,
        use_ue8m0=use_ue8m0, quant_config=quant_config
    )
    deep_gemm.fp8_fp4_gemm_nt(
        a, b, d, c=c, disable_ue8m0_cast=not use_ue8m0,
        recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b
    )
    diff = calc_diff(d, ref_d)
    assert diff < quant_config.max_diff(), f"{m=}, {n=}, {k=}, {accumulate=}, {diff:.5f}"
|
|
|
|
@requires_sm90
@pytest.mark.parametrize("layout_name,func_name", [
    ("nn", "fp8_fp4_gemm_nn"),
    ("tn", "fp8_fp4_gemm_tn"),
    ("tt", "fp8_fp4_gemm_tt"),
])
def test_fp8_gemm_aliases(layout_name, func_name):
    """FP8 GEMM layout aliases: transpose the generated operands into the
    alias's expected contiguous layout, then check the result."""
    reset_seed()
    m, n, k = 128, 4096, 7168
    ktype = _get_default_kernel_type()
    ue8m0 = get_ue8m0_usage(ktype)
    cfg = QuantConfig()
    recipe, recipe_a, recipe_b = cfg.get_recipes()

    # 't' in slot 0 -> A is MN-major; 'n' in slot 1 -> B is MN-major
    # (the NT baseline is K-major / K-major).
    # NOTE(review): other paths in this file force B to K-major on SM90 —
    # confirm the MN-major-B aliases ("nn"/"tt") are expected to run there.
    major_a = MajorTypeAB.MNMajor if layout_name[0] == 't' else MajorTypeAB.KMajor
    major_b = MajorTypeAB.MNMajor if layout_name[1] == 'n' else MajorTypeAB.KMajor

    a, b, c, d, ref_d = generate_normal(
        m, n, k, major_a, major_b, False, torch.bfloat16, ktype,
        use_ue8m0=ue8m0, quant_config=cfg
    )
    # Present MN-major operands (data and scales) as transposed views so the
    # alias receives contiguous tensors.
    if not major_a.is_k_major():
        a = (a[0].T, a[1].T)
    if not major_b.is_k_major():
        b = (b[0].T, b[1].T)
    assert a[0].is_contiguous() and b[0].is_contiguous()

    alias = getattr(deep_gemm, func_name)
    alias(
        a, b, d, disable_ue8m0_cast=not ue8m0,
        recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b
    )
    diff = calc_diff(d, ref_d)
    assert diff < cfg.max_diff(), f"{layout_name=}, {diff:.5f}"
|
|
|
|
| |
| |
| |
# m-grouped contiguous GEMM cases as (num_groups, expected_m, n, k).
# expected_m is the nominal total M handed to the generator; the generator
# returns the actual m used (see the tests below).
M_GROUPED_CONT_PARAMS = [
    (4, 8192, 6144, 7168),
    (8, 4096, 7168, 3072),
    (4, 8192, 4096, 4096),
]
|
|
|
|
@requires_sm90
@pytest.mark.parametrize("num_groups,expected_m,n,k", M_GROUPED_CONT_PARAMS)
def test_m_grouped_fp8_gemm_nt_contiguous(num_groups, expected_m, n, k):
    """m-grouped contiguous FP8 GEMM: groups packed along M, NT layout."""
    reset_seed()
    ktype = _get_default_kernel_type()
    ue8m0 = get_ue8m0_usage(ktype)
    cfg = QuantConfig()
    recipe, recipe_a, recipe_b = cfg.get_recipes()

    # Both operands K-major (canonical NT layout).
    k_major = MajorTypeAB.KMajor
    m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(
        num_groups, expected_m, n, k, k_major, k_major,
        use_ue8m0=ue8m0, quant_config=cfg
    )
    deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(
        a, b, d, grouped_layout,
        disable_ue8m0_cast=not ue8m0,
        recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b
    )

    diff = calc_diff(d, ref_d)
    assert diff < cfg.max_diff(), f"{m=}, {n=}, {k=}, {diff:.5f}"
|
|
|
|
@requires_sm90
@pytest.mark.parametrize("num_groups,expected_m,n,k", M_GROUPED_CONT_PARAMS[:2])
def test_m_grouped_fp8_gemm_nn_contiguous_alias(num_groups, expected_m, n, k):
    """Test the m-grouped contiguous FP8 GEMM NN alias.

    The NN alias takes B in MN-major form, which needs SM100+ support, so
    SM90 runs are skipped.
    """
    # Check architecture support first so skipped runs do no setup work
    # (kernel type, quant config and recipes are only built when running).
    if get_arch_major() == 9:
        pytest.skip("NN alias requires B MN-major support (SM100+)")

    reset_seed()
    kernel_type = _get_default_kernel_type()
    use_ue8m0 = get_ue8m0_usage(kernel_type)
    quant_config = QuantConfig()
    recipe, recipe_a, recipe_b = quant_config.get_recipes()

    major_a, major_b = MajorTypeAB.KMajor, MajorTypeAB.MNMajor
    m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(
        num_groups, expected_m, n, k, major_a, major_b,
        use_ue8m0=use_ue8m0, quant_config=quant_config
    )
    # Hand the MN-major B (data and scales) to the NN entry point as a
    # contiguous transposed view.
    b = b if major_b.is_k_major() else (b[0].mT, b[1].mT)
    assert a[0].is_contiguous() and b[0].is_contiguous()

    deep_gemm.m_grouped_fp8_fp4_gemm_nn_contiguous(
        a, b, d, grouped_layout,
        disable_ue8m0_cast=not use_ue8m0,
        recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b
    )
    diff = calc_diff(d, ref_d)
    assert diff < quant_config.max_diff(), f"{m=}, {n=}, {k=}, {diff:.5f}"
|
|
|
|
| |
| |
| |
# m-grouped masked GEMM cases as (num_groups, max_m, expected_m, n, k).
# max_m is each group's allocated row capacity; expected_m is passed to both
# the generator and the kernel launch (presumably the expected per-group
# valid row count — confirm against generators.py).
M_GROUPED_MASKED_PARAMS = [
    (6, 4096, 1024, 6144, 7168),
    (32, 4096, 192, 7168, 3072),
    (32, 4096, 50, 4096, 4096),
]
|
|
|
|
@requires_sm90
@pytest.mark.parametrize("num_groups,max_m,expected_m,n,k", M_GROUPED_MASKED_PARAMS)
def test_m_grouped_fp8_gemm_nt_masked(num_groups, max_m, expected_m, n, k):
    """m-grouped masked FP8 GEMM: each group computes only its first
    masked_m rows, so comparison is restricted to that valid prefix."""
    reset_seed()
    ktype = _get_default_kernel_type()
    ue8m0 = get_ue8m0_usage(ktype)
    cfg = QuantConfig()
    recipe, recipe_a, recipe_b = cfg.get_recipes()

    # psum_m is produced by the generator but unused by this entry point.
    a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(
        num_groups, max_m, expected_m, n, k,
        use_ue8m0=ue8m0, quant_config=cfg
    )
    deep_gemm.m_grouped_fp8_fp4_gemm_nt_masked(
        a, b, d, masked_m, expected_m,
        disable_ue8m0_cast=not ue8m0,
        recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b
    )

    # Rows at or beyond masked_m[g] are not computed by the kernel, so only
    # the valid prefix of each group is compared.
    tol = cfg.max_diff()
    for group, rows in enumerate(masked_m.tolist()):
        if not rows:
            continue
        diff = calc_diff(d[group, :rows], ref_d[group, :rows])
        assert diff < tol, (
            f"{max_m=}, {n=}, {k=}, group={group}, masked_m={rows}, {diff:.5f}"
        )
|
|
|
|
| |
| |
| |
# k-grouped contiguous GEMM cases as (num_groups, m, n, expected_k).
# expected_k is the nominal per-group K; the tests below randomize each
# group's actual K within +/-30% of it, aligned for the contiguous layout.
K_GROUPED_PARAMS = [
    (4, 4096, 7168, 8192),
    (8, 4096, 7168, 4096),
    (16, 4096, 7168, 2048),
]
|
|
|
|
@requires_sm90
@pytest.mark.parametrize("num_groups,m,n,expected_k", K_GROUPED_PARAMS)
def test_k_grouped_fp8_gemm_contiguous(num_groups, m, n, expected_k):
    """Test k-grouped contiguous FP8 GEMM.

    SM90 uses the NT entry point with K-major operands; newer architectures
    use the TN entry point with MN-major operands. Per-group K sizes are
    randomized around expected_k and aligned for the contiguous layout.
    """
    reset_seed()
    kernel_type = KernelType.Kernel1D1D  # k-grouped path is 1D1D-only here
    use_ue8m0 = get_ue8m0_usage(kernel_type)

    # Query the architecture once; it determines both the operand layout
    # and the entry point (previously selected in two separate spots).
    if get_arch_major() == 9:
        major_a, major_b = MajorTypeAB.KMajor, MajorTypeAB.KMajor
        k_grouped_func = deep_gemm.k_grouped_fp8_gemm_nt_contiguous
    else:
        major_a, major_b = MajorTypeAB.MNMajor, MajorTypeAB.MNMajor
        k_grouped_func = deep_gemm.k_grouped_fp8_gemm_tn_contiguous

    # Randomize each group's K within +/-30% of expected_k; the alignment
    # value is loop-invariant, so fetch it once.
    alignment = get_mk_alignment_for_contiguous_layout()
    ks = [align(int(expected_k * random.uniform(0.7, 1.3)), alignment)
          for _ in range(num_groups)]

    k, a, b, c, d, ref_d = generate_k_grouped_contiguous(
        num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0
    )
    ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')
    k_grouped_func(a, b, d, ks, ks_tensor, c)

    diff = calc_diff(d, ref_d)
    assert diff < 0.001, f"{m=}, {n=}, k_total={k}, {ks=}, {diff:.5f}"
|
|
|
|
@requires_sm90
@pytest.mark.parametrize("num_groups,m,n,expected_k", K_GROUPED_PARAMS[:1])
def test_k_grouped_fp8_gemm_with_empty_groups(num_groups, m, n, expected_k):
    """Test k-grouped contiguous FP8 GEMM when one group has K == 0.

    Same arch-dependent setup as test_k_grouped_fp8_gemm_contiguous, but one
    randomly chosen group's K is zeroed to exercise the empty-group path.
    """
    reset_seed()
    kernel_type = KernelType.Kernel1D1D  # k-grouped path is 1D1D-only here
    use_ue8m0 = get_ue8m0_usage(kernel_type)

    # Query the architecture once; it determines both the operand layout
    # and the entry point (previously selected in two separate spots).
    if get_arch_major() == 9:
        major_a, major_b = MajorTypeAB.KMajor, MajorTypeAB.KMajor
        k_grouped_func = deep_gemm.k_grouped_fp8_gemm_nt_contiguous
    else:
        major_a, major_b = MajorTypeAB.MNMajor, MajorTypeAB.MNMajor
        k_grouped_func = deep_gemm.k_grouped_fp8_gemm_tn_contiguous

    # Randomized, aligned per-group K sizes; the alignment value is
    # loop-invariant, so fetch it once.
    alignment = get_mk_alignment_for_contiguous_layout()
    ks = [align(int(expected_k * random.uniform(0.7, 1.3)), alignment)
          for _ in range(num_groups)]
    # Force one group to be empty (K == 0).
    ks[random.randint(0, num_groups - 1)] = 0

    k, a, b, c, d, ref_d = generate_k_grouped_contiguous(
        num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0
    )
    ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')
    k_grouped_func(a, b, d, ks, ks_tensor, c)

    diff = calc_diff(d, ref_d)
    assert diff < 0.001, f"{m=}, {n=}, {ks=}, {diff:.5f}"
|
|
|
|
# Allow running this file directly; delegates to pytest in verbose mode.
if __name__ == '__main__':
    pytest.main([__file__, '-v'])
|
|