File size: 7,765 Bytes
ccef021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
# /// script
# dependencies = [
#   "numpy",
#   "torch",
#   "kernels"
# ]
# ///
"""
Flash-MLA (Multi-head Latent Attention) Example

This script demonstrates the usage of the Flash-MLA kernel for efficient
attention computation on Hopper (SM90) GPUs.

Flash-MLA is optimized for DeepSeek-style MLA attention patterns.
"""
import math
import torch
from kernels import get_kernel, get_local_kernel
from pathlib import Path

# Setup: fix the RNG seed and resolve the Flash-MLA kernel from the hub.
torch.manual_seed(42)
flash_mla = get_kernel("drbh/tmp-kernel-123")
# flash_mla = get_local_kernel(Path("build"), "flash-mla")
device = torch.device("cuda")

# Report the GPU architecture; warn when it is not Hopper (SM90).
major, minor = torch.cuda.get_device_capability()
print(f"GPU Compute Capability: {major}.{minor}")
if major != 9:
    print("Warning: Flash-MLA dense decoding is optimized for SM90 (Hopper) GPUs.")
    print("Some features may not work on other architectures.")

def cdiv(a, b):
    """Ceiling division: smallest integer quotient >= a / b."""
    quotient, remainder = divmod(a, b)
    return quotient + (1 if remainder else 0)


# =============================================================================
# Test 1: Dense MLA Decoding (SM90)
# =============================================================================
print("\n" + "=" * 60)
print("Test 1: Dense MLA Decoding")
print("=" * 60)

# Configuration matching DeepSeek V3 architecture
batch_size = 2
seq_len_q = 1  # Query length; typically 1 for autoregressive decoding
num_heads_q = 64  # Number of query heads (must be 64 or 128)
num_heads_k = 1   # MLA compresses KV into a single latent head
head_dim = 576    # Q/K head dimension (576 or 512)
head_dim_v = 512  # V head dimension (must be 512)
page_block_size = 64  # Paged-KV block size (must be 64)
seq_len_k = 256   # KV cache sequence length

# Number of paged-KV blocks needed to hold seq_len_k tokens per sequence
max_num_blocks = cdiv(seq_len_k, page_block_size)

# Create input tensors (scaled down and clamped to keep bf16 magnitudes small)
q = torch.randn(batch_size, seq_len_q, num_heads_q, head_dim,
                device=device, dtype=torch.bfloat16) / 10
q.clamp_(min=-1.0, max=1.0)

# KV cache in blocked format: [num_blocks, page_block_size, num_heads_k, head_dim]
total_blocks = batch_size * max_num_blocks
blocked_k = torch.randn(total_blocks, page_block_size, num_heads_k, head_dim,
                        device=device, dtype=torch.bfloat16) / 10
blocked_k.clamp_(min=-1.0, max=1.0)

# Block table maps batch elements to their cache blocks. Here the layout is
# trivially contiguous: batch i owns blocks [i*max_num_blocks, (i+1)*max_num_blocks).
block_table = torch.arange(total_blocks, device=device, dtype=torch.int32).view(batch_size, max_num_blocks)

# Per-batch KV lengths (all sequences are full-length in this test)
cache_seqlens = torch.full((batch_size,), seq_len_k, device=device, dtype=torch.int32)

# Get scheduler metadata (required for flash_mla_with_kvcache).
# NOTE(review): upstream FlashMLA's get_mla_metadata takes arguments
# (cache_seqlens, num_q_tokens_per_head_k, num_heads_k); this no-arg call
# presumably relies on defaults baked into this kernel build — confirm
# against the kernel's API.
tile_scheduler_metadata, _ = flash_mla.get_mla_metadata()

print(f"Query shape: {q.shape}")
print(f"KV cache shape: {blocked_k.shape}")
print(f"Block table shape: {block_table.shape}")
print(f"Cache seqlens: {cache_seqlens}")

# Run Flash-MLA dense decoding
with torch.inference_mode():
    out, lse = flash_mla.flash_mla_with_kvcache(
        q=q,
        k_cache=blocked_k,
        block_table=block_table,
        cache_seqlens=cache_seqlens,
        head_dim_v=head_dim_v,
        tile_scheduler_metadata=tile_scheduler_metadata,
        num_splits=None,
        causal=False,  # no causal mask: decode queries attend to the whole cache
    )

print(f"Output shape: {out.shape}")  # [batch_size, seq_len_q, num_heads_q, head_dim_v]
print(f"LSE shape: {lse.shape}")     # [batch_size, num_heads_q, seq_len_q]
print("Dense MLA decoding: SUCCESS")


# =============================================================================
# Test 2: Reference comparison for correctness
# =============================================================================
print("\n" + "=" * 60)
print("Test 2: Correctness Check vs PyTorch Reference")
print("=" * 60)

def reference_attention(q, blocked_k, block_table, cache_seqlens, dv, is_causal=False):
    """
    Pure-PyTorch attention reference used to verify the kernel output.

    Arguments mirror the kernel call: q is [b, s_q, h_q, d], blocked_k is the
    paged KV cache [num_blocks, block_size, h_kv, d], block_table maps each
    batch element to its cache blocks, cache_seqlens gives per-batch KV
    lengths, and dv selects the first dv channels of K as V (MLA layout).

    Returns (out, lse): out is [b, s_q, h_q, dv] cast back to q's dtype,
    lse is [b, h_q, s_q] in float32.
    """
    b, s_q, h_q, d = q.size()
    block_size = blocked_k.size(1)
    h_kv = blocked_k.size(2)

    out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device=q.device)
    lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32, device=q.device)

    seqlens = cache_seqlens.cpu()

    for batch_idx in range(b):
        s_k = int(seqlens[batch_idx].item())
        num_blocks = cdiv(s_k, block_size)
        block_ids = block_table[batch_idx][:num_blocks]

        # Gather this batch element's first s_k KV rows from the paged cache.
        kv_rows = blocked_k[block_ids].view(-1, h_kv, d)[:s_k, ...]

        query = q[batch_idx].transpose(0, 1).float()  # [h_q, s_q, d]
        kv = kv_rows.transpose(0, 1).float()          # [h_kv, s_k, d]

        # Broadcast the latent KV head across all query heads when they differ.
        if h_kv != h_q:
            kv = kv.repeat_interleave(h_q // h_kv, dim=0)

        # Raw attention scores: Q @ K^T.
        scores = query @ kv.transpose(-2, -1)

        # Optional causal mask (only meaningful when more than one query token).
        keys_len = kv.size(1)
        if is_causal and s_q > 1:
            keep = torch.ones(s_q, keys_len, dtype=torch.bool, device=q.device).tril(diagonal=keys_len - s_q)
            scores.masked_fill_(~keep, float("-inf"))

        # Scale, then take the log-sum-exp (the kernel's LSE) and the softmax.
        scores = scores / math.sqrt(d)
        lse_ref[batch_idx] = scores.logsumexp(dim=-1)
        probs = torch.softmax(scores, dim=-1)

        # V is the leading dv channels of the (expanded) KV tensor.
        out_ref[batch_idx] = (probs @ kv[..., :dv]).transpose(0, 1)

    return out_ref.to(q.dtype), lse_ref

# Compute the PyTorch reference for the same inputs as Test 1.
ref_out, ref_lse = reference_attention(q, blocked_k, block_table, cache_seqlens, head_dim_v, is_causal=False)

# Compare kernel results against the reference within loose bf16 tolerances.
outputs_match = torch.allclose(out.float(), ref_out.float(), atol=1e-3, rtol=1e-2)
lse_match = torch.allclose(lse.float(), ref_lse.float(), atol=1e-4, rtol=1e-3)

print(f"Output close to reference: {outputs_match}")
print(f"LSE close to reference: {lse_match}")

if outputs_match and lse_match:
    print("Correctness check: PASSED")
else:
    # On mismatch, report the worst absolute deviations to aid debugging.
    worst_out = (out.float() - ref_out.float()).abs().max().item()
    worst_lse = (lse.float() - ref_lse.float()).abs().max().item()
    print(f"Max output diff: {worst_out}")
    print(f"Max LSE diff: {worst_lse}")
    print("Correctness check: Check differences above")


# =============================================================================
# Test 3: Different configurations
# =============================================================================
print("\n" + "=" * 60)
print("Test 3: Testing different configurations")
print("=" * 60)

# Sweep of (batch, query length, query heads, KV length) combinations.
configs = [
    (1, 1, 64, 128),
    (4, 1, 128, 512),
    (8, 2, 64, 1024),
]

for bsz, q_len, q_heads, kv_len in configs:
    blocks_per_seq = cdiv(kv_len, page_block_size)
    blocks_total = bsz * blocks_per_seq

    # Fresh inputs per configuration; same scaling as Test 1.
    q_test = torch.randn(bsz, q_len, q_heads, head_dim, device=device, dtype=torch.bfloat16) / 10
    k_test = torch.randn(blocks_total, page_block_size, num_heads_k, head_dim, device=device, dtype=torch.bfloat16) / 10
    bt_test = torch.arange(blocks_total, device=device, dtype=torch.int32).view(bsz, blocks_per_seq)
    sl_test = torch.full((bsz,), kv_len, device=device, dtype=torch.int32)

    sched_meta, _ = flash_mla.get_mla_metadata()

    with torch.inference_mode():
        out_test, lse_test = flash_mla.flash_mla_with_kvcache(
            q=q_test,
            k_cache=k_test,
            block_table=bt_test,
            cache_seqlens=sl_test,
            head_dim_v=head_dim_v,
            tile_scheduler_metadata=sched_meta,
        )

    print(f"Config: batch={bsz}, seq_q={q_len}, heads_q={q_heads}, seq_k={kv_len} -> Output: {out_test.shape} SUCCESS")


print("\n" + "=" * 60)
print("All tests completed successfully!")
print("=" * 60)