deep-gemm / tests /test_bf16.py
medmekk's picture
Upload folder using huggingface_hub
c67ae40 verified
"""
Tests for BF16 GEMM kernels.
Tests correctness of:
- bf16_gemm_nt (and layout aliases nn, tn, tt)
- m_grouped_bf16_gemm_nt_contiguous (and nn alias)
- m_grouped_bf16_gemm_nt_masked
- k_grouped_bf16_gemm_tn_contiguous
- cublaslt_gemm_nt
"""
import copy
import random
import pytest
import torch
import deep_gemm
from deep_gemm.testing import calc_diff, get_arch_major
from generators import (
MajorTypeAB, KernelType, QuantConfig,
reset_seed, align,
generate_normal, generate_m_grouped_contiguous,
generate_m_grouped_masked, generate_k_grouped_contiguous,
layout_masked_to_psum, get_psum_layout_usage,
get_mk_alignment_for_contiguous_layout
)
# Availability probes evaluated once at import time and reused as skip markers.
cuda_available = torch.cuda.is_available()
requires_cuda = pytest.mark.skipif(not cuda_available, reason="CUDA is required")
# `not (A and B)` short-circuits, so get_arch_major() is only queried when a
# CUDA device is actually present.
requires_sm90 = pytest.mark.skipif(
    not (cuda_available and get_arch_major() >= 9),
    reason="Requires SM90+ (Hopper or newer)",
)
# ---------------------------------------------------------------------------
# BF16 GEMM (standard)
# ---------------------------------------------------------------------------
# Parametrization for test_bf16_gemm_nt; each entry is
# (m, n, k, accumulate, out_dtype).
BF16_GEMM_SHAPES = [
    # BF16 output, no accumulation
    (1, 2112, 7168, False, torch.bfloat16),
    (128, 576, 7168, False, torch.bfloat16),
    (4096, 7168, 2048, False, torch.bfloat16),
    (4096, 4096, 7168, False, torch.bfloat16),
    # FP32 output (only BF16 GEMMs); torch.float32 is the same dtype as torch.float
    (128, 256, 7168, False, torch.float32),
    # accumulate=True
    (128, 2112, 7168, True, torch.bfloat16),
]
@requires_sm90
@pytest.mark.parametrize("m,n,k,accumulate,out_dtype", BF16_GEMM_SHAPES)
def test_bf16_gemm_nt(m, n, k, accumulate, out_dtype):
    """Compare deep_gemm.bf16_gemm_nt against the reference for K-major A and B.

    The optional ``c`` operand comes straight from the generator and is passed
    through unchanged; its value depends on ``accumulate``.
    """
    reset_seed()
    k_major = MajorTypeAB.KMajor
    a, b, c, d, ref_d = generate_normal(
        m, n, k, k_major, k_major, accumulate, out_dtype,
        KernelType.KernelNoSF, use_bf16=True,
    )
    deep_gemm.bf16_gemm_nt(a, b, d, c=c)
    diff = calc_diff(d, ref_d)
    assert diff < 1e-5, f"{m=}, {n=}, {k=}, {accumulate=}, {out_dtype=}, {diff:.5f}"
@requires_sm90
@pytest.mark.parametrize("layout_name,func_name", [
    ("nn", "bf16_gemm_nn"),
    ("tn", "bf16_gemm_tn"),
    ("tt", "bf16_gemm_tt"),
])
def test_bf16_gemm_aliases(layout_name, func_name):
    """Exercise the nn/tn/tt BF16 GEMM aliases with contiguous operands.

    The first layout letter describes A ('t' -> MN-major) and the second
    describes B ('n' -> MN-major); every other letter maps to K-major.
    """
    reset_seed()
    m, n, k = 128, 4096, 7168
    a_code, b_code = layout_name
    major_a = MajorTypeAB.MNMajor if a_code == 't' else MajorTypeAB.KMajor
    major_b = MajorTypeAB.MNMajor if b_code == 'n' else MajorTypeAB.KMajor
    a, b, c, d, ref_d = generate_normal(
        m, n, k, major_a, major_b, False, torch.bfloat16,
        KernelType.KernelNoSF, use_bf16=True,
    )
    # The alias entry points take plain contiguous tensors: transpose each
    # MN-major operand (the assertion below checks the result is contiguous).
    if not major_a.is_k_major():
        a = a.T
    if not major_b.is_k_major():
        b = b.T
    assert a.is_contiguous() and b.is_contiguous()
    alias_fn = getattr(deep_gemm, func_name)
    alias_fn(a, b, d)
    diff = calc_diff(d, ref_d)
    assert diff < 1e-5, f"{layout_name=}, {diff:.5f}"
# ---------------------------------------------------------------------------
# BF16 m-grouped contiguous GEMM
# ---------------------------------------------------------------------------
# Each entry: (num_groups, expected_m_per_group, n, k).
M_GROUPED_CONT_PARAMS = [
    (4, 8192, 6144, 7168),
    (8, 4096, 7168, 3072),
    (4, 8192, 4096, 4096),
]
@requires_sm90
@pytest.mark.parametrize("num_groups,expected_m,n,k", M_GROUPED_CONT_PARAMS)
def test_m_grouped_bf16_gemm_nt_contiguous(num_groups, expected_m, n, k):
    """Compare the m-grouped contiguous BF16 GEMM (K-major A and B) with the reference.

    ``m`` as returned by the generator is only used in the failure message.
    """
    reset_seed()
    k_major = MajorTypeAB.KMajor
    m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(
        num_groups, expected_m, n, k, k_major, k_major, use_bf16=True,
    )
    deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout)
    diff = calc_diff(d, ref_d)
    assert diff < 1e-5, f"{m=}, {n=}, {k=}, {diff:.5f}"
@requires_sm90
@pytest.mark.parametrize("num_groups,expected_m,n,k", M_GROUPED_CONT_PARAMS[:2])
def test_m_grouped_bf16_gemm_nn_contiguous_alias(num_groups, expected_m, n, k):
    """Exercise the NN alias of the m-grouped contiguous BF16 GEMM.

    A is generated K-major and B MN-major. B's last two dims are then swapped
    so that the alias receives per-group contiguous slices, which the
    assertion checks.
    """
    reset_seed()
    m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(
        num_groups, expected_m, n, k,
        MajorTypeAB.KMajor, MajorTypeAB.MNMajor, use_bf16=True,
    )
    b = b.mT
    assert a[0:1].is_contiguous() and b[0:1].is_contiguous()
    deep_gemm.m_grouped_bf16_gemm_nn_contiguous(a, b, d, grouped_layout)
    diff = calc_diff(d, ref_d)
    assert diff < 1e-5, f"{m=}, {n=}, {k=}, {diff:.5f}"
# ---------------------------------------------------------------------------
# BF16 m-grouped masked GEMM
# ---------------------------------------------------------------------------
# Each entry: (num_groups, max_m, expected_m_per_group, n, k).
M_GROUPED_MASKED_PARAMS = [
    (6, 4096, 1024, 6144, 7168),
    (32, 4096, 192, 7168, 3072),
    (32, 4096, 50, 4096, 4096),
]
@requires_sm90
@pytest.mark.parametrize("num_groups,max_m,expected_m,n,k", M_GROUPED_MASKED_PARAMS)
def test_m_grouped_bf16_gemm_nt_masked(num_groups, max_m, expected_m, n, k):
    """Test m-grouped masked BF16 GEMM.

    For each group ``j``, only the first ``masked_m[j]`` rows of ``d`` are
    compared with the reference. Groups with zero valid rows are skipped.
    """
    reset_seed()
    # The psum layout returned by the generator is not needed here.
    a, b, masked_m, _psum_m, d, ref_d = generate_m_grouped_masked(
        num_groups, max_m, expected_m, n, k, use_bf16=True
    )
    deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m)
    # Copy the per-group row counts to the host once instead of calling
    # `.item()` per group; if masked_m lives on the GPU, each `.item()` is a
    # separate device sync. Assumes masked_m is 1-D (one count per group).
    host_masked_m = masked_m.tolist()
    for j in range(num_groups):
        mj = host_masked_m[j]
        if mj == 0:
            continue
        diff = calc_diff(d[j, :mj], ref_d[j, :mj])
        assert diff < 1e-5, f"{max_m=}, {n=}, {k=}, group={j}, masked_m={mj}, {diff:.5f}"
# ---------------------------------------------------------------------------
# BF16 k-grouped contiguous GEMM
# ---------------------------------------------------------------------------
# Each entry: (num_groups, m, n, expected_k_per_group). The tests draw each
# group's K from 0.7x-1.3x of expected_k_per_group and then align it.
K_GROUPED_PARAMS = [
    (4, 4096, 7168, 8192),
    (8, 4096, 7168, 4096),
    (16, 4096, 7168, 2048),
]
@requires_sm90
@pytest.mark.parametrize("num_groups,m,n,expected_k", K_GROUPED_PARAMS)
def test_k_grouped_bf16_gemm_tn_contiguous(num_groups, m, n, expected_k):
    """Compare the k-grouped contiguous BF16 GEMM (MN-major A and B) with the reference.

    Each group's K is a random 0.7x-1.3x multiple of ``expected_k``, rounded by
    ``align`` to the contiguous-layout alignment. The per-group K values are
    passed both as a Python list and as a CUDA int tensor.
    """
    reset_seed()
    mn_major = MajorTypeAB.MNMajor
    # Explicit loop keeps the original call order per group:
    # random.uniform -> alignment lookup -> align.
    ks = []
    for _ in range(num_groups):
        scaled_k = int(expected_k * random.uniform(0.7, 1.3))
        ks.append(align(scaled_k, get_mk_alignment_for_contiguous_layout()))
    k_total, a, b, c, d, ref_d = generate_k_grouped_contiguous(
        num_groups, m, n, mn_major, mn_major, ks, use_bf16=True
    )
    ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')
    deep_gemm.k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c)
    diff = calc_diff(d, ref_d)
    assert diff < 1e-5, f"{m=}, {n=}, k_total={k_total}, {ks=}, {diff:.7f}"
@requires_sm90
@pytest.mark.parametrize("num_groups,m,n,expected_k", K_GROUPED_PARAMS[:1])
def test_k_grouped_bf16_gemm_tn_with_empty_groups(num_groups, m, n, expected_k):
    """Test k-grouped contiguous BF16 GEMM with an empty group.

    Per-group K values are drawn as in the regular k-grouped test, then one
    randomly chosen group gets K = 0.
    """
    reset_seed()
    major_a, major_b = MajorTypeAB.MNMajor, MajorTypeAB.MNMajor
    ks = [align(int(expected_k * random.uniform(0.7, 1.3)),
                get_mk_alignment_for_contiguous_layout()) for _ in range(num_groups)]
    # Force one group to be empty (K == 0).
    ks[random.randint(0, num_groups - 1)] = 0
    k, a, b, c, d, ref_d = generate_k_grouped_contiguous(
        num_groups, m, n, major_a, major_b, ks, use_bf16=True
    )
    ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')
    deep_gemm.k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c)
    diff = calc_diff(d, ref_d)
    # Report the total K like the sibling k-grouped test does, so failures in
    # both tests carry the same diagnostics.
    assert diff < 1e-5, f"{m=}, {n=}, k_total={k}, {ks=}, {diff:.7f}"
# ---------------------------------------------------------------------------
# cuBLASLt GEMM
# ---------------------------------------------------------------------------
# Parametrization for test_cublaslt_gemm_nt; each entry is
# (m, n, k, accumulate, out_dtype). Only the last entry accumulates.
CUBLASLT_SHAPES = [
    (1, 2112, 7168, False, torch.bfloat16),
    (128, 576, 7168, False, torch.bfloat16),
    (4096, 4096, 7168, False, torch.bfloat16),
    (128, 2112, 7168, True, torch.bfloat16),
]
@requires_cuda
@pytest.mark.parametrize("m,n,k,accumulate,out_dtype", CUBLASLT_SHAPES)
def test_cublaslt_gemm_nt(m, n, k, accumulate, out_dtype):
    """Compare the cublaslt_gemm_nt wrapper against the reference result.

    This test only requires CUDA, not SM90. It uses a tighter tolerance (6e-7)
    than the 1e-5 used by the DeepGEMM kernel tests in this file.
    """
    reset_seed()
    k_major = MajorTypeAB.KMajor
    a, b, c, d, ref_d = generate_normal(
        m, n, k, k_major, k_major, accumulate, out_dtype,
        KernelType.KernelNoSF, use_bf16=True,
    )
    deep_gemm.cublaslt_gemm_nt(a, b, d, c=c)
    diff = calc_diff(d, ref_d)
    assert diff < 6e-7, f"{m=}, {n=}, {k=}, {accumulate=}, {out_dtype=}, {diff=}"
if __name__ == '__main__':
    # pytest.main returns an exit code instead of exiting. Re-raise it so that a
    # direct `python test_bf16.py` run exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, '-v']))