"""
Tests for hyperconnection prenorm GEMM kernel.

Tests correctness of:
  - tf32_hc_prenorm_gemm
"""
import random
import pytest
import torch

import deep_gemm
from deep_gemm.testing import calc_diff, get_arch_major

cuda_available = torch.cuda.is_available()
requires_sm90 = pytest.mark.skipif(
    not cuda_available or get_arch_major() < 9,
    reason="Requires SM90+ (Hopper or newer)"
)


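# Shape coverage: tiny m (13, 137) stresses partial output tiles, large m
# (4096, 8192) stresses full occupancy; n is fixed at 24 and k spans the
# 7168 / 7680 / 28672 reduction sizes, with and without split-K.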
HC_PRENORM_PARAMS = [
    # (m, n, k, num_splits)
    (13, 24, 28672, None),
    (137, 24, 28672, None),
    (4096, 24, 28672, None),
    (8192, 24, 28672, None),
    (13, 24, 7680, None),
    (4096, 24, 7680, None),
    (13, 24, 7168, None),
    (4096, 24, 7168, None),
    # With split-K
    (13, 24, 28672, 16),
    (4096, 24, 28672, 16),
    (8192, 24, 28672, 16),
    (13, 24, 7680, 16),
    (4096, 24, 7168, 16),
]


@requires_sm90
@pytest.mark.parametrize("m,n,k,num_splits", HC_PRENORM_PARAMS)
def test_tf32_hc_prenorm_gemm(m, n, k, num_splits):
    """Test TF32 hyperconnection prenorm GEMM.

    This kernel computes both:
      d = a @ b.T  (GEMM)
      s = sum(a^2, dim=-1)  (squared norm)
    in a single fused kernel.
    """
    torch.manual_seed(0)
    random.seed(0)

    # Enable TF32 so the reference matmul below matches the kernel's TF32 math
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

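    # a holds bf16 activations and b fp32 weights; every bf16 value is exactly
    # representable in TF32 (bf16's 8-bit exponent and 7-bit mantissa both fit
    # within TF32), so upcasting `a` for the reference adds no extra rounding.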
    a = torch.randn((m, k), dtype=torch.bfloat16, device='cuda')
    b = torch.randn((n, k), dtype=torch.float, device='cuda')

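    # Without split-K the kernel writes final outputs directly; with split-K it
    # writes one partial result per K-split, reduced on the host afterwards.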
    if num_splits is None:
        d = torch.empty((m, n), dtype=torch.float, device='cuda')
        s = torch.empty((m,), dtype=torch.float, device='cuda')
    else:
        d = torch.empty((num_splits, m, n), dtype=torch.float, device='cuda')
        s = torch.empty((num_splits, m), dtype=torch.float, device='cuda')

    deep_gemm.tf32_hc_prenorm_gemm(a, b, d, s, num_splits=num_splits)

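    # Fold split-K partials over the leading split axis (pass-through otherwise).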
    final_d = d if num_splits is None else d.sum(0)
    final_s = s if num_splits is None else s.sum(0)

    ref_d = a.float() @ b.T
    ref_s = a.float().square().sum(-1)

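    # Gate on the worse of the two outputs' error metrics.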
    diff = max(calc_diff(final_d, ref_d), calc_diff(final_s, ref_s))
    assert diff < 1e-8, f"{m=}, {n=}, {k=}, {num_splits=}, {diff:.10f}"


if __name__ == '__main__':
    pytest.main([__file__, '-v'])