# deep-gemm / tests / test_hyperconnection.py
# medmekk's picture
# Upload folder using huggingface_hub
# c67ae40 verified
"""
Tests for hyperconnection prenorm GEMM kernel.
Tests correctness of:
- tf32_hc_prenorm_gemm
"""
import random
import pytest
import torch
import deep_gemm
from deep_gemm.testing import calc_diff, get_arch_major
# Whether torch reports CUDA as available in this process.
cuda_available = torch.cuda.is_available()

# Skip marker for tests that need an SM90-or-newer device. The architecture
# query sits on the right of `and`, so it is only evaluated when CUDA is
# available.
requires_sm90 = pytest.mark.skipif(
    not (cuda_available and get_arch_major() >= 9),
    reason="Requires SM90+ (Hopper or newer)",
)
# (m, k) problem sizes exercised without split-K (num_splits=None).
_HC_PLAIN_MK = (
    (13, 28672),
    (137, 28672),
    (4096, 28672),
    (8192, 28672),
    (13, 7680),
    (4096, 7680),
    (13, 7168),
    (4096, 7168),
)

# (m, k) problem sizes exercised with split-K (num_splits=16).
_HC_SPLIT_K_MK = (
    (13, 28672),
    (4096, 28672),
    (8192, 28672),
    (13, 7680),
    (4096, 7168),
)

# Parametrization table of (m, n, k, num_splits) tuples; n is 24 in every
# case. Plain cases come first, followed by the split-K cases.
HC_PRENORM_PARAMS = (
    [(rows, 24, depth, None) for rows, depth in _HC_PLAIN_MK]
    + [(rows, 24, depth, 16) for rows, depth in _HC_SPLIT_K_MK]
)
@requires_sm90
@pytest.mark.parametrize("m,n,k,num_splits", HC_PRENORM_PARAMS)
def test_tf32_hc_prenorm_gemm(m, n, k, num_splits):
    """Compare ``deep_gemm.tf32_hc_prenorm_gemm`` against a PyTorch reference.

    The fused kernel is expected to produce, in a single call:
        d = a @ b.T            (GEMM)
        s = sum(a^2, dim=-1)   (squared norm)

    Args:
        m: Number of rows of ``a``.
        n: Number of rows of ``b``.
        k: Shared inner dimension of ``a`` and ``b``.
        num_splits: ``None`` for plain ``(m, n)`` / ``(m,)`` outputs;
            otherwise the output buffers get a leading dimension of this
            size and are summed over it before the comparison.
    """
    torch.manual_seed(0)
    random.seed(0)

    # TF32 precision is required for the reference. These flags are
    # process-wide, so remember the previous values and put them back in the
    # `finally` below to avoid leaking the setting into other tests.
    prev_matmul_tf32 = torch.backends.cuda.matmul.allow_tf32
    prev_cudnn_tf32 = torch.backends.cudnn.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    try:
        a = torch.randn((m, k), dtype=torch.bfloat16, device='cuda')
        b = torch.randn((n, k), dtype=torch.float, device='cuda')

        # Output buffers are handed to the kernel uninitialized.
        if num_splits is None:
            d = torch.empty((m, n), dtype=torch.float, device='cuda')
            s = torch.empty((m,), dtype=torch.float, device='cuda')
        else:
            d = torch.empty((num_splits, m, n), dtype=torch.float, device='cuda')
            s = torch.empty((num_splits, m), dtype=torch.float, device='cuda')

        deep_gemm.tf32_hc_prenorm_gemm(a, b, d, s, num_splits=num_splits)

        # With split-K, reduce over the leading split dimension first.
        final_d = d if num_splits is None else d.sum(0)
        final_s = s if num_splits is None else s.sum(0)

        # Reference results computed with plain PyTorch ops on float32 inputs.
        ref_d = a.float() @ b.T
        ref_s = a.float().square().sum(-1)

        diff = max(calc_diff(final_d, ref_d), calc_diff(final_s, ref_s))
        assert diff < 1e-8, f"{m=}, {n=}, {k=}, {num_splits=}, {diff:.10f}"
    finally:
        torch.backends.cuda.matmul.allow_tf32 = prev_matmul_tf32
        torch.backends.cudnn.allow_tf32 = prev_cudnn_tf32
if __name__ == '__main__':
    # Propagate pytest's exit code so that running this file as a script
    # reports failure to the calling shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, '-v']))