""" Tests for hyperconnection prenorm GEMM kernel. Tests correctness of: - tf32_hc_prenorm_gemm """ import random import pytest import torch import deep_gemm from deep_gemm.testing import calc_diff, get_arch_major cuda_available = torch.cuda.is_available() requires_sm90 = pytest.mark.skipif( not cuda_available or get_arch_major() < 9, reason="Requires SM90+ (Hopper or newer)" ) HC_PRENORM_PARAMS = [ # (m, n, k, num_splits) (13, 24, 28672, None), (137, 24, 28672, None), (4096, 24, 28672, None), (8192, 24, 28672, None), (13, 24, 7680, None), (4096, 24, 7680, None), (13, 24, 7168, None), (4096, 24, 7168, None), # With split-K (13, 24, 28672, 16), (4096, 24, 28672, 16), (8192, 24, 28672, 16), (13, 24, 7680, 16), (4096, 24, 7168, 16), ] @requires_sm90 @pytest.mark.parametrize("m,n,k,num_splits", HC_PRENORM_PARAMS) def test_tf32_hc_prenorm_gemm(m, n, k, num_splits): """Test TF32 hyperconnection prenorm GEMM. This kernel computes both: d = a @ b.T (GEMM) s = sum(a^2, dim=-1) (squared norm) in a single fused kernel. """ torch.manual_seed(0) random.seed(0) # TF32 precision required for reference torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True a = torch.randn((m, k), dtype=torch.bfloat16, device='cuda') b = torch.randn((n, k), dtype=torch.float, device='cuda') if num_splits is None: d = torch.empty((m, n), dtype=torch.float, device='cuda') s = torch.empty((m,), dtype=torch.float, device='cuda') else: d = torch.empty((num_splits, m, n), dtype=torch.float, device='cuda') s = torch.empty((num_splits, m), dtype=torch.float, device='cuda') deep_gemm.tf32_hc_prenorm_gemm(a, b, d, s, num_splits=num_splits) final_d = d if num_splits is None else d.sum(0) final_s = s if num_splits is None else s.sum(0) ref_d = a.float() @ b.T ref_s = a.float().square().sum(-1) diff = max(calc_diff(final_d, ref_d), calc_diff(final_s, ref_s)) assert diff < 1e-8, f"{m=}, {n=}, {k=}, {num_splits=}, {diff:.10f}" if __name__ == '__main__': pytest.main([__file__, '-v'])