"""
Tests for FP8/FP4 GEMM kernels.
Tests correctness of:
- fp8_fp4_gemm_nt (and layout aliases)
- m_grouped_fp8_fp4_gemm_nt_contiguous
- m_grouped_fp8_fp4_gemm_nt_masked
- k_grouped_fp8_gemm_nt_contiguous / k_grouped_fp8_gemm_tn_contiguous
"""
import random

import pytest
import torch

import deep_gemm
from deep_gemm.testing import calc_diff, get_arch_major

from generators import (
    MajorTypeAB, KernelType, QuantConfig,
    reset_seed, align,
    get_kernel_types, get_ue8m0_usage,
    generate_normal, generate_m_grouped_contiguous,
    generate_m_grouped_masked, generate_k_grouped_contiguous,
    get_mk_alignment_for_contiguous_layout
)

cuda_available = torch.cuda.is_available()

requires_sm90 = pytest.mark.skipif(
    not cuda_available or get_arch_major() < 9,
    reason="Requires SM90+ (Hopper or newer)"
)
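
# get_kernel_types() is assumed to return the kernel variants supported on the
# current architecture; these tests simply use the first entry as the default.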
def _get_default_kernel_type():
    return get_kernel_types(torch.float8_e4m3fn)[0]

# ---------------------------------------------------------------------------
# FP8 GEMM (standard, forward-like shapes)
# ---------------------------------------------------------------------------
FP8_GEMM_FWD_SHAPES = [
    # (m, n, k, accumulate, out_dtype)
    (1, 2112, 7168, False, torch.bfloat16),
    (128, 576, 7168, False, torch.bfloat16),
    (4096, 4096, 7168, False, torch.bfloat16),
    (4096, 7168, 2048, False, torch.bfloat16),
    (128, 7168, 16384, False, torch.bfloat16),
]


@requires_sm90
@pytest.mark.parametrize("m,n,k,accumulate,out_dtype", FP8_GEMM_FWD_SHAPES)
def test_fp8_gemm_nt(m, n, k, accumulate, out_dtype):
    """Test standard FP8 GEMM with NT layout (forward pass)."""
    reset_seed()
    kernel_type = _get_default_kernel_type()
    use_ue8m0 = get_ue8m0_usage(kernel_type)
    quant_config = QuantConfig()
    recipe, recipe_a, recipe_b = quant_config.get_recipes()
    a, b, c, d, ref_d = generate_normal(
        m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor,
        accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0,
        quant_config=quant_config
    )
    deep_gemm.fp8_fp4_gemm_nt(
        a, b, d, c=c, disable_ue8m0_cast=not use_ue8m0,
        recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b
    )
    diff = calc_diff(d, ref_d)
    assert diff < quant_config.max_diff(), f"{m=}, {n=}, {k=}, {diff:.5f}"
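
# For reference: calc_diff (from deep_gemm.testing) is assumed to report a
# similarity-based relative error rather than a max-abs error, roughly
#     1 - 2 * (d * ref_d).sum() / (d.square() + ref_d.square()).sum()
# so the thresholds used in these tests (quant_config.max_diff(), 0.001) are
# small, dimensionless quantities.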


# Backward-like shapes: dgrad (B may be MN-major) and wgrad (accumulate=True)
FP8_GEMM_BWD_SHAPES = [
    # dgrad-like
    (4096, 7168, 2112, False, torch.bfloat16),
    # wgrad-like (accumulate + FP32 output)
    (2112, 4096, 7168, True, torch.float),
    (2112, 4096, 7168, False, torch.bfloat16),
]


@requires_sm90
@pytest.mark.parametrize("m,n,k,accumulate,out_dtype", FP8_GEMM_BWD_SHAPES)
def test_fp8_gemm_nt_backward_shapes(m, n, k, accumulate, out_dtype):
    """Test FP8 GEMM with backward-pass-like shapes."""
    reset_seed()
    kernel_type = _get_default_kernel_type()
    use_ue8m0 = get_ue8m0_usage(kernel_type)
    quant_config = QuantConfig()
    recipe, recipe_a, recipe_b = quant_config.get_recipes(is_wgrad=accumulate)
    # For backward, B may be MN-major on SM100; SM90 keeps B K-major and
    # switches to the 1D1D kernel for the accumulating (wgrad-like) case.
    major_b = MajorTypeAB.KMajor if get_arch_major() == 9 else MajorTypeAB.MNMajor
    override_kernel_type = kernel_type
    if get_arch_major() == 9 and accumulate:
        override_kernel_type = KernelType.Kernel1D1D
    a, b, c, d, ref_d = generate_normal(
        m, n, k, MajorTypeAB.KMajor, major_b,
        accumulate, out_dtype, override_kernel_type,
        use_ue8m0=use_ue8m0, quant_config=quant_config
    )
    deep_gemm.fp8_fp4_gemm_nt(
        a, b, d, c=c, disable_ue8m0_cast=not use_ue8m0,
        recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b
    )
    diff = calc_diff(d, ref_d)
    assert diff < quant_config.max_diff(), f"{m=}, {n=}, {k=}, {accumulate=}, {diff:.5f}"
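
# The wgrad-like case above exercises accumulation: accumulate=True with an
# FP32 output dtype, and get_recipes(is_wgrad=accumulate) selects the matching
# quantization recipe; the c tensor returned by generate_normal is passed back
# in as the buffer to accumulate into.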


@requires_sm90
@pytest.mark.parametrize("layout_name,func_name", [
    ("nn", "fp8_fp4_gemm_nn"),
    ("tn", "fp8_fp4_gemm_tn"),
    ("tt", "fp8_fp4_gemm_tt"),
])
def test_fp8_gemm_aliases(layout_name, func_name):
    """Test FP8 GEMM layout aliases with contiguous inputs."""
    reset_seed()
    m, n, k = 128, 4096, 7168
    kernel_type = _get_default_kernel_type()
    use_ue8m0 = get_ue8m0_usage(kernel_type)
    quant_config = QuantConfig()
    recipe, recipe_a, recipe_b = quant_config.get_recipes()
    major_a = MajorTypeAB.MNMajor if layout_name[0] == 't' else MajorTypeAB.KMajor
    major_b = MajorTypeAB.MNMajor if layout_name[1] == 'n' else MajorTypeAB.KMajor
    a, b, c, d, ref_d = generate_normal(
        m, n, k, major_a, major_b, False, torch.bfloat16, kernel_type,
        use_ue8m0=use_ue8m0, quant_config=quant_config
    )
    # Make contiguous for alias path
    a = a if major_a.is_k_major() else (a[0].T, a[1].T)
    b = b if major_b.is_k_major() else (b[0].T, b[1].T)
    assert a[0].is_contiguous() and b[0].is_contiguous()
    getattr(deep_gemm, func_name)(
        a, b, d, disable_ue8m0_cast=not use_ue8m0,
        recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b
    )
    diff = calc_diff(d, ref_d)
    assert diff < quant_config.max_diff(), f"{layout_name=}, {diff:.5f}"
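
# The base "nt" layout (A and B both K-major) is already covered by
# test_fp8_gemm_nt above, so only the "nn", "tn" and "tt" aliases are
# parametrized here; the first letter selects A's majorness, the second B's.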
# ---------------------------------------------------------------------------
# FP8 m-grouped contiguous GEMM
# ---------------------------------------------------------------------------
M_GROUPED_CONT_PARAMS = [
    # (num_groups, expected_m_per_group, n, k)
    (4, 8192, 6144, 7168),
    (8, 4096, 7168, 3072),
    (4, 8192, 4096, 4096),
]


@requires_sm90
@pytest.mark.parametrize("num_groups,expected_m,n,k", M_GROUPED_CONT_PARAMS)
def test_m_grouped_fp8_gemm_nt_contiguous(num_groups, expected_m, n, k):
    """Test m-grouped contiguous FP8 GEMM."""
    reset_seed()
    kernel_type = _get_default_kernel_type()
    use_ue8m0 = get_ue8m0_usage(kernel_type)
    quant_config = QuantConfig()
    recipe, recipe_a, recipe_b = quant_config.get_recipes()
    major_a, major_b = MajorTypeAB.KMajor, MajorTypeAB.KMajor
    m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(
        num_groups, expected_m, n, k, major_a, major_b,
        use_ue8m0=use_ue8m0, quant_config=quant_config
    )
    deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(
        a, b, d, grouped_layout,
        disable_ue8m0_cast=not use_ue8m0,
        recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b
    )
    diff = calc_diff(d, ref_d)
    assert diff < quant_config.max_diff(), f"{m=}, {n=}, {k=}, {diff:.5f}"


@requires_sm90
@pytest.mark.parametrize("num_groups,expected_m,n,k", M_GROUPED_CONT_PARAMS[:2])
def test_m_grouped_fp8_gemm_nn_contiguous_alias(num_groups, expected_m, n, k):
    """Test m-grouped contiguous FP8 GEMM with NN alias."""
    if get_arch_major() == 9:
        pytest.skip("NN alias requires B MN-major support (SM100+)")
    reset_seed()
    kernel_type = _get_default_kernel_type()
    use_ue8m0 = get_ue8m0_usage(kernel_type)
    quant_config = QuantConfig()
    recipe, recipe_a, recipe_b = quant_config.get_recipes()
    major_a, major_b = MajorTypeAB.KMajor, MajorTypeAB.MNMajor
    m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(
        num_groups, expected_m, n, k, major_a, major_b,
        use_ue8m0=use_ue8m0, quant_config=quant_config
    )
    b = b if major_b.is_k_major() else (b[0].mT, b[1].mT)
    assert a[0].is_contiguous() and b[0].is_contiguous()
    deep_gemm.m_grouped_fp8_fp4_gemm_nn_contiguous(
        a, b, d, grouped_layout,
        disable_ue8m0_cast=not use_ue8m0,
        recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b
    )
    diff = calc_diff(d, ref_d)
    assert diff < quant_config.max_diff(), f"{m=}, {n=}, {k=}, {diff:.5f}"
# ---------------------------------------------------------------------------
# FP8 m-grouped masked GEMM
# ---------------------------------------------------------------------------
M_GROUPED_MASKED_PARAMS = [
    # (num_groups, max_m, expected_m_per_group, n, k)
    (6, 4096, 1024, 6144, 7168),
    (32, 4096, 192, 7168, 3072),
    (32, 4096, 50, 4096, 4096),
]


@requires_sm90
@pytest.mark.parametrize("num_groups,max_m,expected_m,n,k", M_GROUPED_MASKED_PARAMS)
def test_m_grouped_fp8_gemm_nt_masked(num_groups, max_m, expected_m, n, k):
    """Test m-grouped masked FP8 GEMM."""
    reset_seed()
    kernel_type = _get_default_kernel_type()
    use_ue8m0 = get_ue8m0_usage(kernel_type)
    quant_config = QuantConfig()
    recipe, recipe_a, recipe_b = quant_config.get_recipes()
    a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(
        num_groups, max_m, expected_m, n, k,
        use_ue8m0=use_ue8m0, quant_config=quant_config
    )
    deep_gemm.m_grouped_fp8_fp4_gemm_nt_masked(
        a, b, d, masked_m, expected_m,
        disable_ue8m0_cast=not use_ue8m0,
        recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b
    )
    for j in range(num_groups):
        mj = masked_m[j].item()
        if mj == 0:
            continue
        diff = calc_diff(d[j, :mj], ref_d[j, :mj])
        assert diff < quant_config.max_diff(), (
            f"{max_m=}, {n=}, {k=}, group={j}, masked_m={mj}, {diff:.5f}"
        )
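
# Only the first masked_m[j] rows of each group are expected to hold meaningful
# output, so the loop above compares just that prefix and skips groups whose
# mask is zero.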
# ---------------------------------------------------------------------------
# FP8 k-grouped contiguous GEMM
# ---------------------------------------------------------------------------
K_GROUPED_PARAMS = [
    # (num_groups, m, n, expected_k_per_group)
    (4, 4096, 7168, 8192),
    (8, 4096, 7168, 4096),
    (16, 4096, 7168, 2048),
]
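
# The tests below draw each group's K from roughly 0.7x-1.3x of expected_k and
# round it with align() to get_mk_alignment_for_contiguous_layout(), so every
# group boundary respects the contiguous layout's K alignment.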


@requires_sm90
@pytest.mark.parametrize("num_groups,m,n,expected_k", K_GROUPED_PARAMS)
def test_k_grouped_fp8_gemm_contiguous(num_groups, m, n, expected_k):
    """Test k-grouped contiguous FP8 GEMM."""
    reset_seed()
    kernel_type = KernelType.Kernel1D1D
    use_ue8m0 = get_ue8m0_usage(kernel_type)
    # Layout depends on architecture
    if get_arch_major() == 9:
        major_a, major_b = MajorTypeAB.KMajor, MajorTypeAB.KMajor
    else:
        major_a, major_b = MajorTypeAB.MNMajor, MajorTypeAB.MNMajor
    ks = [align(int(expected_k * random.uniform(0.7, 1.3)),
                get_mk_alignment_for_contiguous_layout()) for _ in range(num_groups)]
    k, a, b, c, d, ref_d = generate_k_grouped_contiguous(
        num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0
    )
    ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')
    # Use the appropriate function for the architecture
    k_grouped_func = (deep_gemm.k_grouped_fp8_gemm_nt_contiguous if get_arch_major() == 9
                      else deep_gemm.k_grouped_fp8_gemm_tn_contiguous)
    k_grouped_func(a, b, d, ks, ks_tensor, c)
    diff = calc_diff(d, ref_d)
    assert diff < 0.001, f"{m=}, {n=}, k_total={k}, {ks=}, {diff:.5f}"


@requires_sm90
@pytest.mark.parametrize("num_groups,m,n,expected_k", K_GROUPED_PARAMS[:1])
def test_k_grouped_fp8_gemm_with_empty_groups(num_groups, m, n, expected_k):
    """Test k-grouped contiguous FP8 GEMM with an empty group."""
    reset_seed()
    kernel_type = KernelType.Kernel1D1D
    use_ue8m0 = get_ue8m0_usage(kernel_type)
    if get_arch_major() == 9:
        major_a, major_b = MajorTypeAB.KMajor, MajorTypeAB.KMajor
    else:
        major_a, major_b = MajorTypeAB.MNMajor, MajorTypeAB.MNMajor
    ks = [align(int(expected_k * random.uniform(0.7, 1.3)),
                get_mk_alignment_for_contiguous_layout()) for _ in range(num_groups)]
    ks[random.randint(0, num_groups - 1)] = 0
    k, a, b, c, d, ref_d = generate_k_grouped_contiguous(
        num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0
    )
    ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')
    k_grouped_func = (deep_gemm.k_grouped_fp8_gemm_nt_contiguous if get_arch_major() == 9
                      else deep_gemm.k_grouped_fp8_gemm_tn_contiguous)
    k_grouped_func(a, b, d, ks, ks_tensor, c)
    diff = calc_diff(d, ref_d)
    assert diff < 0.001, f"{m=}, {n=}, {ks=}, {diff:.5f}"


if __name__ == '__main__':
    pytest.main([__file__, '-v'])