File size: 3,051 Bytes
4754707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""Isolate which operations in score_mod break Inductor lowering on this PT/triton."""
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# flex_attention wrapped with torch.compile (default backend, i.e. Inductor);
# dynamic=False asks for static-shape specialization. Shared by every case.
flex_c = torch.compile(flex_attention, dynamic=False)


def _summarize_exception(exc, keywords=('AssertionError', 'wrong ndim',
                                        'FlexibleLayout', 'NotImplementedError'),
                         width=100):
    """Return a one-line summary of *exc* for the FAIL report.

    If the exception text contains one of *keywords* (checked in the given
    order), the detail is the *width*-character window starting at the first
    occurrence of that keyword -- compile errors tend to bury the real
    assertion deep inside a long wrapper message. Otherwise the detail is the
    first line of the text. Whitespace runs (including newlines) are collapsed
    so the report stays on one line, and the exception type name is prepended
    because some assertions carry no message at all.
    """
    text = str(exc)
    detail = text.split('\n', 1)[0]
    for kw in keywords:
        idx = text.find(kw)
        if idx >= 0:
            detail = text[idx:idx + width]
            break
    detail = ' '.join(detail.split())
    name = type(exc).__name__
    return f'{name}: {detail}' if detail else name


def try_score_mod(label, score_mod_fn, captures=None, *,
                  batch=4, heads=16, seq_len=2048, head_dim=48):
    """Run compiled flex_attention forward+backward with *score_mod_fn*.

    Prints ``=== label ===`` followed by ``PASS shape=...`` or a one-line
    ``FAIL: ...`` summary.

    Args:
        label: Heading printed before the result.
        score_mod_fn: Passed to ``flex_attention`` as ``score_mod``; called as
            ``score_mod_fn(score, b, h, q_idx, kv_idx)``.
        captures: Unused; accepted only so existing call sites keep working.
        batch, heads, seq_len, head_dim: Shape of the random bf16 Q/K/V
            inputs (defaults 4, 16, 2048, 48). Score mods that index captured
            tensors with b/h/q/kv must be sized to match.

    Inputs live on CUDA and a causal block mask is applied. Any exception
    raised while compiling or running is caught and reported, so one failing
    case does not stop the remaining ones.
    """
    print(f'\n=== {label} ===')
    device = 'cuda'
    # Fresh Dynamo cache per case: each score_mod is a different function, so
    # every call recompiles flex_c. Once Dynamo's recompile limit is reached it
    # runs the frame eagerly instead, and later cases would no longer exercise
    # Inductor lowering at all.
    torch.compiler.reset()
    shape = (batch, heads, seq_len, head_dim)
    Q = torch.randn(shape, device=device, dtype=torch.bfloat16, requires_grad=True)
    K = torch.randn(shape, device=device, dtype=torch.bfloat16, requires_grad=True)
    V = torch.randn(shape, device=device, dtype=torch.bfloat16, requires_grad=True)

    def causal(b, h, q, kv):
        return q >= kv

    bm = create_block_mask(causal, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len,
                           device=device)
    try:
        with torch.autocast('cuda', dtype=torch.bfloat16):
            O = flex_c(Q, K, V, score_mod=score_mod_fn, block_mask=bm, scale=1.0)
        O.sum().backward()
        # Kernels launch asynchronously; synchronize so a device-side error is
        # attributed to this case instead of surfacing in the next one.
        torch.cuda.synchronize()
        print(f'  PASS  shape={O.shape}')
    except Exception as e:  # deliberate: report every failure kind, keep going
        print(f'  FAIL: {_summarize_exception(e)}')


def main():
    """Run the score_mod variants in order of increasing complexity.

    Each case adds one ingredient: division by a Python float or by a CUDA
    scalar tensor, a per-head value gathered from a captured tensor, q/kv
    index arithmetic, and finally a captured 4-D tensor. The first FAIL
    therefore points at the operation that does not lower.
    """
    n_heads = 16
    # Per-head multipliers 1, 2, 4, ..., 2**(n_heads - 1), gathered as slopes_f[h].
    slopes_f = torch.tensor(
        [2 ** i for i in range(n_heads)], dtype=torch.float32, device='cuda')
    tau_t = torch.tensor(0.1, device='cuda')    # 0-dim tensor
    tau_1 = torch.tensor([0.1], device='cuda')  # shape (1,)

    # Baselines: no-op and scalar division in its different spellings.
    try_score_mod('1. identity', lambda s, b, h, q, kv: s)
    try_score_mod('2. s / 0.1 (python float)', lambda s, b, h, q, kv: s / 0.1)
    try_score_mod('3. s / tau_t (0-dim tensor)', lambda s, b, h, q, kv: s / tau_t)
    tau_squeezed = tau_1.squeeze()
    try_score_mod('4. s / tau_1.squeeze()', lambda s, b, h, q, kv: s / tau_squeezed)

    # Terms that depend on the head / query / key indices.
    try_score_mod('5. s - slopes_f[h]', lambda s, b, h, q, kv: s - slopes_f[h])
    try_score_mod('6. s + (q - kv)', lambda s, b, h, q, kv: s + (q - kv))
    try_score_mod('7. s + abs(q - kv)', lambda s, b, h, q, kv: s + (q - kv).abs())

    def alibi(s, b, h, q, kv):
        # Slope times |q - kv| with no explicit .float() on the indices.
        return s - slopes_f[h] * (q - kv).abs()

    def alibi_over_tau(s, b, h, q, kv):
        return (s - slopes_f[h] * (q - kv).abs()) / tau_t

    try_score_mod('8. s - slopes_f[h] * (q - kv).abs()', alibi)
    try_score_mod('9. (s - slopes_f[h] * (q - kv).abs()) / tau_t', alibi_over_tau)

    # Captured 4-D tensor indexed by all four indices; its leading dims match
    # the B=4, H=16, T=2048 inputs built by try_score_mod.
    g = torch.zeros(4, n_heads, 2048, 2048, device='cuda', dtype=torch.float32)

    def with_captured_g(s, b, h, q, kv):
        return (s + g[b, h, q, kv] - slopes_f[h] * (q - kv).abs()) / tau_t

    try_score_mod('10. + g[b,h,q,kv]', with_captured_g)


# Run every score_mod case only when executed as a script, not on import.
if __name__ == '__main__':
    main()