| """ | |
| Tests for hyperconnection prenorm GEMM kernel. | |
| Tests correctness of: | |
| - tf32_hc_prenorm_gemm | |
| """ | |
| import random | |
| import pytest | |
| import torch | |
| import deep_gemm | |
| from deep_gemm.testing import calc_diff, get_arch_major | |
# Evaluated once at import time so collection can decide whether to skip.
cuda_available = torch.cuda.is_available()

# Skip marker for tests that need an SM90+ device. The architecture query is
# only reached when CUDA is available (short-circuit of `and`), so importing
# this module on a machine without CUDA does not call get_arch_major().
requires_sm90 = pytest.mark.skipif(
    not (cuda_available and get_arch_major() >= 9),
    reason="Requires SM90+ (Hopper or newer)",
)
# Test cases as (m, n, k, num_splits). n is 24 in every case; only the
# (m, k) pair and the split count vary, so the tuples are built from pairs.
HC_PRENORM_PARAMS = [
    # num_splits is None
    *(
        (m, 24, k, None)
        for m, k in (
            (13, 28672),
            (137, 28672),
            (4096, 28672),
            (8192, 28672),
            (13, 7680),
            (4096, 7680),
            (13, 7168),
            (4096, 7168),
        )
    ),
    # With split-K
    *(
        (m, 24, k, 16)
        for m, k in (
            (13, 28672),
            (4096, 28672),
            (8192, 28672),
            (13, 7680),
            (4096, 7168),
        )
    ),
]
@requires_sm90
@pytest.mark.parametrize("m, n, k, num_splits", HC_PRENORM_PARAMS)
def test_tf32_hc_prenorm_gemm(m, n, k, num_splits):
    """Check the TF32 hyperconnection prenorm GEMM against a PyTorch reference.

    The kernel under test is expected to produce both of the following in a
    single fused call:

        d = a @ b.T            (GEMM)
        s = sum(a^2, dim=-1)   (squared norm of each row of ``a``)

    Args:
        m: Number of rows of ``a`` (bfloat16, shape ``(m, k)``).
        n: Number of rows of ``b`` (float32, shape ``(n, k)``).
        k: Shared inner dimension.
        num_splits: ``None`` for the plain case. Otherwise the outputs are
            allocated with a leading ``num_splits`` axis and the partial
            results are summed over that axis before comparison (split-K).
    """
    torch.manual_seed(0)
    random.seed(0)

    # The reference matmul needs TF32 enabled. These are process-wide
    # switches, so remember the previous values and restore them afterwards
    # to avoid leaking state into other tests in the same session.
    prev_matmul_tf32 = torch.backends.cuda.matmul.allow_tf32
    prev_cudnn_tf32 = torch.backends.cudnn.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    try:
        a = torch.randn((m, k), dtype=torch.bfloat16, device='cuda')
        b = torch.randn((n, k), dtype=torch.float, device='cuda')

        if num_splits is None:
            d = torch.empty((m, n), dtype=torch.float, device='cuda')
            s = torch.empty((m,), dtype=torch.float, device='cuda')
        else:
            d = torch.empty((num_splits, m, n), dtype=torch.float, device='cuda')
            s = torch.empty((num_splits, m), dtype=torch.float, device='cuda')

        deep_gemm.tf32_hc_prenorm_gemm(a, b, d, s, num_splits=num_splits)

        # With splits, each slice along dim 0 holds a partial result.
        final_d = d if num_splits is None else d.sum(0)
        final_s = s if num_splits is None else s.sum(0)

        ref_d = a.float() @ b.T
        ref_s = a.float().square().sum(-1)

        # Use the worse of the two outputs as the reported error.
        diff = max(calc_diff(final_d, ref_d), calc_diff(final_s, ref_s))
        assert diff < 1e-8, f"{m=}, {n=}, {k=}, {num_splits=}, {diff:.10f}"
    finally:
        torch.backends.cuda.matmul.allow_tf32 = prev_matmul_tf32
        torch.backends.cudnn.allow_tf32 = prev_cudnn_tf32
if __name__ == '__main__':
    # Run this file's tests verbosely when executed as a script, and hand
    # pytest's exit code back to the shell so failures are not reported as
    # success.
    raise SystemExit(pytest.main([__file__, '-v']))