"""
Tests for BF16 GEMM kernels.

Tests correctness of:
- bf16_gemm_nt (and layout aliases nn, tn, tt)
- m_grouped_bf16_gemm_nt_contiguous (and nn alias)
- m_grouped_bf16_gemm_nt_masked
- k_grouped_bf16_gemm_tn_contiguous
- cublaslt_gemm_nt
"""
import copy
import random

import pytest
import torch

import deep_gemm
from deep_gemm.testing import calc_diff, get_arch_major

from generators import (
    MajorTypeAB, KernelType, QuantConfig, reset_seed, align,
    generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked,
    generate_k_grouped_contiguous, layout_masked_to_psum, get_psum_layout_usage,
    get_mk_alignment_for_contiguous_layout,
)

cuda_available = torch.cuda.is_available()
requires_cuda = pytest.mark.skipif(not cuda_available, reason="CUDA is required")
# `or` short-circuits, so get_arch_major() is only queried when CUDA is present.
requires_sm90 = pytest.mark.skipif(
    not cuda_available or get_arch_major() < 9,
    reason="Requires SM90+ (Hopper or newer)",
)


# ---------------------------------------------------------------------------
# BF16 GEMM (standard)
# ---------------------------------------------------------------------------

BF16_GEMM_SHAPES = [
    # (m, n, k, accumulate, out_dtype)
    (1, 2112, 7168, False, torch.bfloat16),
    (128, 576, 7168, False, torch.bfloat16),
    (4096, 7168, 2048, False, torch.bfloat16),
    (4096, 4096, 7168, False, torch.bfloat16),
    # FP32 output (only BF16 GEMMs)
    (128, 256, 7168, False, torch.float),
    # With accumulation
    (128, 2112, 7168, True, torch.bfloat16),
]


@requires_sm90
@pytest.mark.parametrize("m,n,k,accumulate,out_dtype", BF16_GEMM_SHAPES)
def test_bf16_gemm_nt(m, n, k, accumulate, out_dtype):
    """Test BF16 GEMM with NT layout (both operands K-major)."""
    reset_seed()
    kernel_type = KernelType.KernelNoSF
    major_a, major_b = MajorTypeAB.KMajor, MajorTypeAB.KMajor

    a, b, c, d, ref_d = generate_normal(
        m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True
    )
    # `c` is forwarded as-is; whether it is populated depends on `accumulate`
    # inside generate_normal (not visible here).
    deep_gemm.bf16_gemm_nt(a, b, d, c=c)

    diff = calc_diff(d, ref_d)
    assert diff < 1e-5, f"{m=}, {n=}, {k=}, {accumulate=}, {out_dtype=}, {diff:.5f}"


@requires_sm90
@pytest.mark.parametrize("layout_name,func_name", [
    ("nn", "bf16_gemm_nn"),
    ("tn", "bf16_gemm_tn"),
    ("tt", "bf16_gemm_tt"),
])
def test_bf16_gemm_aliases(layout_name, func_name):
    """Test BF16 GEMM layout aliases (nn, tn, tt) with contiguous inputs."""
    reset_seed()
    m, n, k = 128, 4096, 7168
    kernel_type = KernelType.KernelNoSF
    # A 't' in the A position or an 'n' in the B position selects an
    # MN-major operand; every other letter selects K-major.
    major_a = MajorTypeAB.MNMajor if layout_name[0] == 't' else MajorTypeAB.KMajor
    major_b = MajorTypeAB.MNMajor if layout_name[1] == 'n' else MajorTypeAB.KMajor

    a, b, c, d, ref_d = generate_normal(
        m, n, k, major_a, major_b, False, torch.bfloat16, kernel_type, use_bf16=True
    )

    # Make contiguous for alias path: MN-major operands are transposed so the
    # alias entry point receives plain contiguous tensors.
    a = a if major_a.is_k_major() else a.T
    b = b if major_b.is_k_major() else b.T
    assert a.is_contiguous() and b.is_contiguous()

    getattr(deep_gemm, func_name)(a, b, d)

    diff = calc_diff(d, ref_d)
    assert diff < 1e-5, f"{layout_name=}, {diff:.5f}"


# ---------------------------------------------------------------------------
# BF16 m-grouped contiguous GEMM
# ---------------------------------------------------------------------------

M_GROUPED_CONT_PARAMS = [
    # (num_groups, expected_m_per_group, n, k)
    (4, 8192, 6144, 7168),
    (8, 4096, 7168, 3072),
    (4, 8192, 4096, 4096),
]


@requires_sm90
@pytest.mark.parametrize("num_groups,expected_m,n,k", M_GROUPED_CONT_PARAMS)
def test_m_grouped_bf16_gemm_nt_contiguous(num_groups, expected_m, n, k):
    """Test m-grouped contiguous BF16 GEMM."""
    reset_seed()
    major_a, major_b = MajorTypeAB.KMajor, MajorTypeAB.KMajor

    m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(
        num_groups, expected_m, n, k, major_a, major_b, use_bf16=True
    )
    deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout)

    diff = calc_diff(d, ref_d)
    assert diff < 1e-5, f"{m=}, {n=}, {k=}, {diff:.5f}"


@requires_sm90
@pytest.mark.parametrize("num_groups,expected_m,n,k", M_GROUPED_CONT_PARAMS[:2])
def test_m_grouped_bf16_gemm_nn_contiguous_alias(num_groups, expected_m, n, k):
    """Test m-grouped contiguous BF16 GEMM with NN alias."""
    reset_seed()
    major_a = MajorTypeAB.KMajor
    major_b = MajorTypeAB.MNMajor

    m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(
        num_groups, expected_m, n, k, major_a, major_b, use_bf16=True
    )
    # `.mT` swaps the last two dims of B so that a single-group slice is
    # contiguous, which the assert below checks.
    b = b.mT
    assert a[0:1].is_contiguous() and b[0:1].is_contiguous()

    deep_gemm.m_grouped_bf16_gemm_nn_contiguous(a, b, d, grouped_layout)

    diff = calc_diff(d, ref_d)
    assert diff < 1e-5, f"{m=}, {n=}, {k=}, {diff:.5f}"


# ---------------------------------------------------------------------------
# BF16 m-grouped masked GEMM
# ---------------------------------------------------------------------------

M_GROUPED_MASKED_PARAMS = [
    # (num_groups, max_m, expected_m_per_group, n, k)
    (6, 4096, 1024, 6144, 7168),
    (32, 4096, 192, 7168, 3072),
    (32, 4096, 50, 4096, 4096),
]


@requires_sm90
@pytest.mark.parametrize("num_groups,max_m,expected_m,n,k", M_GROUPED_MASKED_PARAMS)
def test_m_grouped_bf16_gemm_nt_masked(num_groups, max_m, expected_m, n, k):
    """Test m-grouped masked BF16 GEMM.

    Only the first ``masked_m[j]`` rows of each group ``j`` are compared;
    groups with no valid rows are skipped.
    """
    reset_seed()
    a, b, masked_m, _psum_m, d, ref_d = generate_m_grouped_masked(
        num_groups, max_m, expected_m, n, k, use_bf16=True
    )
    deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m)

    # Convert the per-group row counts to Python ints once, instead of one
    # `.item()` call (a host sync for a device tensor) per group.
    for j, mj in enumerate(masked_m.tolist()):
        if mj == 0:
            continue
        diff = calc_diff(d[j, :mj], ref_d[j, :mj])
        assert diff < 1e-5, f"{max_m=}, {n=}, {k=}, group={j}, masked_m={mj}, {diff:.5f}"


# ---------------------------------------------------------------------------
# BF16 k-grouped contiguous GEMM
# ---------------------------------------------------------------------------

K_GROUPED_PARAMS = [
    # (num_groups, m, n, expected_k_per_group)
    (4, 4096, 7168, 8192),
    (8, 4096, 7168, 4096),
    (16, 4096, 7168, 2048),
]


def _random_k_sizes(num_groups, expected_k):
    """Return ``num_groups`` random per-group K sizes for the k-grouped tests.

    Each size is ``expected_k`` scaled by a factor drawn uniformly from
    [0.7, 1.3], truncated to ``int``, then passed through ``align`` with the
    contiguous-layout alignment. Draws from Python's ``random`` module, so
    results are reproducible only if that module has been seeded beforehand
    (presumably by ``reset_seed()`` — confirm in ``generators``).
    """
    alignment = get_mk_alignment_for_contiguous_layout()
    return [align(int(expected_k * random.uniform(0.7, 1.3)), alignment)
            for _ in range(num_groups)]


@requires_sm90
@pytest.mark.parametrize("num_groups,m,n,expected_k", K_GROUPED_PARAMS)
def test_k_grouped_bf16_gemm_tn_contiguous(num_groups, m, n, expected_k):
    """Test k-grouped contiguous BF16 GEMM."""
    reset_seed()
    major_a, major_b = MajorTypeAB.MNMajor, MajorTypeAB.MNMajor
    ks = _random_k_sizes(num_groups, expected_k)

    k, a, b, c, d, ref_d = generate_k_grouped_contiguous(
        num_groups, m, n, major_a, major_b, ks, use_bf16=True
    )
    # The kernel takes the group sizes both as a Python list and as a device tensor.
    ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')
    deep_gemm.k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c)

    diff = calc_diff(d, ref_d)
    assert diff < 1e-5, f"{m=}, {n=}, k_total={k}, {ks=}, {diff:.7f}"


@requires_sm90
@pytest.mark.parametrize("num_groups,m,n,expected_k", K_GROUPED_PARAMS[:1])
def test_k_grouped_bf16_gemm_tn_with_empty_groups(num_groups, m, n, expected_k):
    """Test k-grouped contiguous BF16 GEMM with an empty group."""
    reset_seed()
    major_a, major_b = MajorTypeAB.MNMajor, MajorTypeAB.MNMajor
    ks = _random_k_sizes(num_groups, expected_k)
    # Force one randomly chosen group to have K = 0.
    ks[random.randint(0, num_groups - 1)] = 0

    k, a, b, c, d, ref_d = generate_k_grouped_contiguous(
        num_groups, m, n, major_a, major_b, ks, use_bf16=True
    )
    ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')
    deep_gemm.k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c)

    diff = calc_diff(d, ref_d)
    assert diff < 1e-5, f"{m=}, {n=}, {ks=}, {diff:.7f}"


# ---------------------------------------------------------------------------
# cuBLASLt GEMM
# ---------------------------------------------------------------------------

CUBLASLT_SHAPES = [
    # (m, n, k, accumulate, out_dtype)
    (1, 2112, 7168, False, torch.bfloat16),
    (128, 576, 7168, False, torch.bfloat16),
    (4096, 4096, 7168, False, torch.bfloat16),
    (128, 2112, 7168, True, torch.bfloat16),
]


@requires_cuda
@pytest.mark.parametrize("m,n,k,accumulate,out_dtype", CUBLASLT_SHAPES)
def test_cublaslt_gemm_nt(m, n, k, accumulate, out_dtype):
    """Test cuBLASLt GEMM wrapper (only gated on CUDA, not on SM90+)."""
    reset_seed()
    kernel_type = KernelType.KernelNoSF
    major_a, major_b = MajorTypeAB.KMajor, MajorTypeAB.KMajor

    a, b, c, d, ref_d = generate_normal(
        m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True
    )
    deep_gemm.cublaslt_gemm_nt(a, b, d, c=c)

    # Tighter tolerance than the DeepGEMM kernels above.
    diff = calc_diff(d, ref_d)
    assert diff < 6e-7, f"{m=}, {n=}, {k=}, {accumulate=}, {out_dtype=}, {diff=}"


if __name__ == '__main__':
    pytest.main([__file__, '-v'])