| """ |
| Tests for BF16 GEMM kernels. |
| |
| Tests correctness of: |
| - bf16_gemm_nt (and layout aliases nn, tn, tt) |
| - m_grouped_bf16_gemm_nt_contiguous (and nn alias) |
| - m_grouped_bf16_gemm_nt_masked |
| - k_grouped_bf16_gemm_tn_contiguous |
| - cublaslt_gemm_nt |
| """ |
| import copy |
| import random |
| import pytest |
| import torch |
|
|
| import deep_gemm |
| from deep_gemm.testing import calc_diff, get_arch_major |
| from generators import ( |
| MajorTypeAB, KernelType, QuantConfig, |
| reset_seed, align, |
| generate_normal, generate_m_grouped_contiguous, |
| generate_m_grouped_masked, generate_k_grouped_contiguous, |
| layout_masked_to_psum, get_psum_layout_usage, |
| get_mk_alignment_for_contiguous_layout |
| ) |
|
|
# Probe CUDA availability once at import time so both markers share the result.
cuda_available = torch.cuda.is_available()
# Skip marker for tests that need any CUDA device (e.g. the cuBLASLt wrapper).
requires_cuda = pytest.mark.skipif(not cuda_available, reason="CUDA is required")
# Skip marker for tests that need Hopper (SM90) or newer.  The short-circuit on
# `cuda_available` also guards get_arch_major(), which needs a device to query.
requires_sm90 = pytest.mark.skipif(
    not cuda_available or get_arch_major() < 9,
    reason="Requires SM90+ (Hopper or newer)"
)
|
|
|
|
| |
| |
| |
# Shapes for the plain NT BF16 GEMM test.
# Each tuple is (m, n, k, accumulate, out_dtype).
BF16_GEMM_SHAPES = [
    # Regular BF16-output cases, from tiny-m (decode-like) to large square-ish.
    (1, 2112, 7168, False, torch.bfloat16),
    (128, 576, 7168, False, torch.bfloat16),
    (4096, 7168, 2048, False, torch.bfloat16),
    (4096, 4096, 7168, False, torch.bfloat16),
    # FP32-output variant.
    (128, 256, 7168, False, torch.float),
    # Accumulation path (accumulate=True) with BF16 output.
    (128, 2112, 7168, True, torch.bfloat16),
]
|
|
|
|
@requires_sm90
@pytest.mark.parametrize("m,n,k,accumulate,out_dtype", BF16_GEMM_SHAPES)
def test_bf16_gemm_nt(m, n, k, accumulate, out_dtype):
    """Check the NT-layout BF16 GEMM against its reference output."""
    reset_seed()
    # NT layout: both operands are K-major; no scale factors for pure BF16.
    a, b, c, d, ref_d = generate_normal(
        m, n, k,
        MajorTypeAB.KMajor, MajorTypeAB.KMajor,
        accumulate, out_dtype,
        KernelType.KernelNoSF, use_bf16=True
    )
    deep_gemm.bf16_gemm_nt(a, b, d, c=c)
    diff = calc_diff(d, ref_d)
    assert diff < 1e-5, f"{m=}, {n=}, {k=}, {accumulate=}, {out_dtype=}, {diff:.5f}"
|
|
|
|
@requires_sm90
@pytest.mark.parametrize("layout_name,func_name", [
    ("nn", "bf16_gemm_nn"),
    ("tn", "bf16_gemm_tn"),
    ("tt", "bf16_gemm_tt"),
])
def test_bf16_gemm_aliases(layout_name, func_name):
    """Test BF16 GEMM layout aliases (nn, tn, tt) with contiguous inputs."""
    reset_seed()
    m, n, k = 128, 4096, 7168
    # Decode the two-letter layout: 't' for A and 'n' for B select MN-major.
    a_char, b_char = layout_name
    major_a = MajorTypeAB.MNMajor if a_char == 't' else MajorTypeAB.KMajor
    major_b = MajorTypeAB.MNMajor if b_char == 'n' else MajorTypeAB.KMajor

    a, b, c, d, ref_d = generate_normal(
        m, n, k, major_a, major_b, False, torch.bfloat16,
        KernelType.KernelNoSF, use_bf16=True
    )
    # Hand MN-major operands to the alias as their transposed, contiguous views.
    if not major_a.is_k_major():
        a = a.T
    if not major_b.is_k_major():
        b = b.T
    assert a.is_contiguous() and b.is_contiguous()

    getattr(deep_gemm, func_name)(a, b, d)
    diff = calc_diff(d, ref_d)
    assert diff < 1e-5, f"{layout_name=}, {diff:.5f}"
|
|
|
|
| |
| |
| |
# Shapes for the m-grouped contiguous GEMM tests.
# Each tuple is (num_groups, expected_m, n, k); the generator decides the
# actual total m (returned as `m` by generate_m_grouped_contiguous).
M_GROUPED_CONT_PARAMS = [

    (4, 8192, 6144, 7168),
    (8, 4096, 7168, 3072),
    (4, 8192, 4096, 4096),
]
|
|
|
|
@requires_sm90
@pytest.mark.parametrize("num_groups,expected_m,n,k", M_GROUPED_CONT_PARAMS)
def test_m_grouped_bf16_gemm_nt_contiguous(num_groups, expected_m, n, k):
    """Test m-grouped contiguous BF16 GEMM."""
    reset_seed()
    # NT layout uses K-major for both A and B.
    k_major = MajorTypeAB.KMajor
    m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(
        num_groups, expected_m, n, k, k_major, k_major, use_bf16=True
    )
    deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout)
    diff = calc_diff(d, ref_d)
    assert diff < 1e-5, f"{m=}, {n=}, {k=}, {diff:.5f}"
|
|
|
|
@requires_sm90
@pytest.mark.parametrize("num_groups,expected_m,n,k", M_GROUPED_CONT_PARAMS[:2])
def test_m_grouped_bf16_gemm_nn_contiguous_alias(num_groups, expected_m, n, k):
    """Test m-grouped contiguous BF16 GEMM with NN alias."""
    reset_seed()
    m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(
        num_groups, expected_m, n, k,
        MajorTypeAB.KMajor, MajorTypeAB.MNMajor, use_bf16=True
    )
    # The NN alias consumes B through its batch-transposed view; the first
    # slice of each operand must still be contiguous for the kernel.
    b = b.mT
    assert a[0:1].is_contiguous() and b[0:1].is_contiguous()
    deep_gemm.m_grouped_bf16_gemm_nn_contiguous(a, b, d, grouped_layout)
    diff = calc_diff(d, ref_d)
    assert diff < 1e-5, f"{m=}, {n=}, {k=}, {diff:.5f}"
|
|
|
|
| |
| |
| |
# Shapes for the m-grouped masked GEMM tests.
# Each tuple is (num_groups, max_m, expected_m, n, k): max_m bounds each
# group's row count, expected_m is the hint passed to the kernel.
M_GROUPED_MASKED_PARAMS = [

    (6, 4096, 1024, 6144, 7168),
    (32, 4096, 192, 7168, 3072),
    (32, 4096, 50, 4096, 4096),
]
|
|
|
|
@requires_sm90
@pytest.mark.parametrize("num_groups,max_m,expected_m,n,k", M_GROUPED_MASKED_PARAMS)
def test_m_grouped_bf16_gemm_nt_masked(num_groups, max_m, expected_m, n, k):
    """Test m-grouped masked BF16 GEMM."""
    reset_seed()
    # psum layout info from the generator is unused by this kernel variant.
    a, b, masked_m, _psum_m, d, ref_d = generate_m_grouped_masked(
        num_groups, max_m, expected_m, n, k, use_bf16=True
    )
    deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m)

    # Only the first masked_m[group] rows of each group hold defined output,
    # so compare group by group and skip fully-masked (empty) groups.
    for group, rows in enumerate(masked_m.tolist()):
        if not rows:
            continue
        diff = calc_diff(d[group, :rows], ref_d[group, :rows])
        assert diff < 1e-5, f"{max_m=}, {n=}, {k=}, group={group}, masked_m={rows}, {diff:.5f}"
|
|
|
|
| |
| |
| |
# Shapes for the k-grouped contiguous GEMM tests.
# Each tuple is (num_groups, m, n, expected_k); per-group k values are
# randomized around expected_k inside the tests.
K_GROUPED_PARAMS = [

    (4, 4096, 7168, 8192),
    (8, 4096, 7168, 4096),
    (16, 4096, 7168, 2048),
]
|
|
|
|
@requires_sm90
@pytest.mark.parametrize("num_groups,m,n,expected_k", K_GROUPED_PARAMS)
def test_k_grouped_bf16_gemm_tn_contiguous(num_groups, m, n, expected_k):
    """Test k-grouped contiguous BF16 GEMM."""
    reset_seed()
    mn_major = MajorTypeAB.MNMajor
    alignment = get_mk_alignment_for_contiguous_layout()
    # Draw each group's K near expected_k, rounded up to the layout alignment.
    ks = [align(int(expected_k * random.uniform(0.7, 1.3)), alignment)
          for _ in range(num_groups)]

    k, a, b, c, d, ref_d = generate_k_grouped_contiguous(
        num_groups, m, n, mn_major, mn_major, ks, use_bf16=True
    )
    # The kernel takes the group sizes both as a host list and a device tensor.
    deep_gemm.k_grouped_bf16_gemm_tn_contiguous(
        a, b, d, ks, torch.tensor(ks, dtype=torch.int, device='cuda'), c
    )

    diff = calc_diff(d, ref_d)
    assert diff < 1e-5, f"{m=}, {n=}, k_total={k}, {ks=}, {diff:.7f}"
|
|
|
|
@requires_sm90
@pytest.mark.parametrize("num_groups,m,n,expected_k", K_GROUPED_PARAMS[:1])
def test_k_grouped_bf16_gemm_tn_with_empty_groups(num_groups, m, n, expected_k):
    """Test k-grouped contiguous BF16 GEMM with an empty group."""
    reset_seed()
    mn_major = MajorTypeAB.MNMajor
    ks = [align(int(expected_k * random.uniform(0.7, 1.3)),
                get_mk_alignment_for_contiguous_layout()) for _ in range(num_groups)]
    # Zero out one randomly-chosen group to exercise the empty-group path.
    ks[random.randint(0, num_groups - 1)] = 0

    k, a, b, c, d, ref_d = generate_k_grouped_contiguous(
        num_groups, m, n, mn_major, mn_major, ks, use_bf16=True
    )
    ks_gpu = torch.tensor(ks, dtype=torch.int, device='cuda')
    deep_gemm.k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_gpu, c)

    diff = calc_diff(d, ref_d)
    assert diff < 1e-5, f"{m=}, {n=}, {ks=}, {diff:.7f}"
|
|
|
|
| |
| |
| |
# Shapes for the cuBLASLt wrapper test.
# Each tuple is (m, n, k, accumulate, out_dtype); the last entry exercises
# the accumulate=True path.
CUBLASLT_SHAPES = [

    (1, 2112, 7168, False, torch.bfloat16),
    (128, 576, 7168, False, torch.bfloat16),
    (4096, 4096, 7168, False, torch.bfloat16),
    (128, 2112, 7168, True, torch.bfloat16),
]
|
|
|
|
@requires_cuda
@pytest.mark.parametrize("m,n,k,accumulate,out_dtype", CUBLASLT_SHAPES)
def test_cublaslt_gemm_nt(m, n, k, accumulate, out_dtype):
    """Test cuBLASLt GEMM wrapper."""
    reset_seed()
    # Same NT, no-scale-factor setup as the custom-kernel test above.
    a, b, c, d, ref_d = generate_normal(
        m, n, k,
        MajorTypeAB.KMajor, MajorTypeAB.KMajor,
        accumulate, out_dtype,
        KernelType.KernelNoSF, use_bf16=True
    )
    deep_gemm.cublaslt_gemm_nt(a, b, d, c=c)
    diff = calc_diff(d, ref_d)
    # Library path is held to a much tighter tolerance than the custom kernels.
    assert diff < 6e-7, f"{m=}, {n=}, {k=}, {accumulate=}, {out_dtype=}, {diff=}"
|
|
|
|
if __name__ == '__main__':
    # Allow running this file directly; delegates to pytest's CLI runner.
    pytest.main([__file__, '-v'])
|
|