# flash-mla/scripts/readme_example.py
# Hosting-page metadata (kept as comments so the file parses as Python):
#   medmekk's picture
#   Upload folder using huggingface_hub
#   ccef021 verified
# /// script
# dependencies = [
# "numpy",
# "torch",
# "kernels"
# ]
# ///
"""
Flash-MLA (Multi-head Latent Attention) Example
This script demonstrates the usage of the Flash-MLA kernel for efficient
attention computation on Hopper (SM90) GPUs.
Flash-MLA is optimized for DeepSeek-style MLA attention patterns.
"""
import math
import torch
from kernels import get_kernel, get_local_kernel
from pathlib import Path
# Setup: seed the RNG so the random inputs created below are reproducible.
torch.manual_seed(42)

# Load the Flash-MLA kernel through the `kernels` package.
flash_mla = get_kernel("drbh/tmp-kernel-123")
# Alternative: load a locally built kernel instead.
# flash_mla = get_local_kernel(Path("build"), "flash-mla")
device = torch.device("cuda")

# Report the GPU architecture; warn (but keep going) when it is not 9.x.
cc_major, cc_minor = torch.cuda.get_device_capability()
print(f"GPU Compute Capability: {cc_major}.{cc_minor}")
if cc_major != 9:
    print(
        "Warning: Flash-MLA dense decoding is optimized for SM90 (Hopper) GPUs.",
        "Some features may not work on other architectures.",
        sep="\n",
    )
def cdiv(a, b):
    """Ceiling division.

    For an integer ``a`` and a positive integer ``b`` this returns the
    smallest integer ``n`` such that ``n * b >= a``.
    """
    # Bumping the numerator by (b - 1) makes floor division round up.
    bumped = a + b - 1
    return bumped // b
# =============================================================================
# Test 1: Dense MLA Decoding (SM90)
# =============================================================================
print("\n" + "=" * 60, "Test 1: Dense MLA Decoding", "=" * 60, sep="\n")

# Configuration matching DeepSeek V3 architecture
batch_size = 2
seq_len_q = 1         # Typically 1 for decoding
num_heads_q = 64      # Number of query heads (must be 64 or 128)
num_heads_k = 1       # MLA uses single KV head
head_dim = 576        # Q/K head dimension (576 or 512)
head_dim_v = 512      # V head dimension (must be 512)
page_block_size = 64  # Page block size (must be 64)
seq_len_k = 256       # KV cache sequence length

# Number of cache pages each batch element needs, and the pool size overall.
max_num_blocks = cdiv(seq_len_k, page_block_size)
total_blocks = batch_size * max_num_blocks

# Query tensor: scaled-down random values, clamped in place to [-1, 1].
q = (
    torch.randn(
        batch_size, seq_len_q, num_heads_q, head_dim,
        device=device, dtype=torch.bfloat16,
    )
    / 10
).clamp_(min=-1.0, max=1.0)

# KV cache in blocked format: [num_blocks, page_block_size, num_heads_k, head_dim]
blocked_k = (
    torch.randn(
        total_blocks, page_block_size, num_heads_k, head_dim,
        device=device, dtype=torch.bfloat16,
    )
    / 10
).clamp_(min=-1.0, max=1.0)

# Block table: row i lists the cache blocks owned by batch element i.
# Here blocks are simply handed out consecutively.
block_table = torch.arange(
    total_blocks, device=device, dtype=torch.int32
).view(batch_size, max_num_blocks)

# Every batch element uses the full seq_len_k tokens of cache.
cache_seqlens = torch.full(
    (batch_size,), seq_len_k, device=device, dtype=torch.int32
)

# Scheduler metadata object that flash_mla_with_kvcache takes as an argument.
tile_scheduler_metadata, _ = flash_mla.get_mla_metadata()

print(f"Query shape: {q.shape}")
print(f"KV cache shape: {blocked_k.shape}")
print(f"Block table shape: {block_table.shape}")
print(f"Cache seqlens: {cache_seqlens}")

# Run Flash-MLA dense decoding
with torch.inference_mode():
    out, lse = flash_mla.flash_mla_with_kvcache(
        q=q,
        k_cache=blocked_k,
        block_table=block_table,
        cache_seqlens=cache_seqlens,
        head_dim_v=head_dim_v,
        tile_scheduler_metadata=tile_scheduler_metadata,
        num_splits=None,
        causal=False,  # causal masking is disabled for this example
    )

print(f"Output shape: {out.shape}")  # [batch_size, seq_len_q, num_heads_q, head_dim_v]
print(f"LSE shape: {lse.shape}")  # [batch_size, num_heads_q, seq_len_q]
print("Dense MLA decoding: SUCCESS")
# =============================================================================
# Test 2: Reference comparison for correctness
# =============================================================================
print(
    "\n" + "=" * 60,
    "Test 2: Correctness Check vs PyTorch Reference",
    "=" * 60,
    sep="\n",
)
def reference_attention(q, blocked_k, block_table, cache_seqlens, dv, is_causal=False):
    """Compute attention over a paged KV cache with plain PyTorch ops.

    Intended as a slow, readable reference for verifying a fused kernel.

    Args:
        q: Query tensor of shape [batch, s_q, h_q, d].
        blocked_k: Paged KV cache of shape [num_blocks, block_size, h_kv, d].
            Keys use all ``d`` channels; values are the first ``dv`` channels
            of the same rows.
        block_table: Row ``i`` holds the indices into ``blocked_k`` of the
            blocks belonging to batch element ``i``.
        cache_seqlens: Number of valid cached tokens per batch element.
        dv: Value head dimension.
        is_causal: When true and ``s_q > 1``, apply a causal mask aligned to
            the end of the key sequence.

    Returns:
        A tuple ``(out, lse)``. ``out`` has shape [batch, s_q, h_q, dv] and is
        cast to ``q.dtype``. ``lse`` is the float32 log-sum-exp of the scaled
        scores with shape [batch, h_q, s_q].
    """
    batch, s_q, h_q, d = q.shape
    block_size = blocked_k.shape[1]
    h_kv = blocked_k.shape[2]

    f32_on_q_device = {"dtype": torch.float32, "device": q.device}
    out_ref = torch.empty(batch, s_q, h_q, dv, **f32_on_q_device)
    lse_ref = torch.empty(batch, h_q, s_q, **f32_on_q_device)

    # Move the lengths to the host once instead of syncing per element.
    host_seqlens = cache_seqlens.cpu()

    for i in range(batch):
        cur_len = int(host_seqlens[i])
        n_blocks = cdiv(cur_len, block_size)
        block_ids = block_table[i][:n_blocks]

        # Gather this element's pages, flatten them into a token axis, and
        # drop the unused tail of the last page.
        tokens = blocked_k[block_ids].view(-1, h_kv, d)[:cur_len]

        # Put heads first and do all math in float32.
        query = q[i].transpose(0, 1).float()   # [h_q, s_q, d]
        kv = tokens.transpose(0, 1).float()    # [h_kv, s_k, d]

        # Replicate KV heads so every query head has a matching KV head.
        if h_kv != h_q:
            kv = kv.repeat_interleave(h_q // h_kv, dim=0)
        s_k = kv.size(1)

        # Raw scores: Q @ K^T
        scores = torch.matmul(query, kv.transpose(-2, -1))

        # Causal mask: query row r may see keys up to column r + (s_k - s_q).
        if is_causal and s_q > 1:
            visible = torch.ones(
                s_q, s_k, dtype=torch.bool, device=q.device
            ).tril(diagonal=s_k - s_q)
            scores.masked_fill_(~visible, float("-inf"))

        # Scale, then take log-sum-exp and softmax over the key axis.
        scores = scores / math.sqrt(d)
        row_lse = scores.logsumexp(dim=-1)
        probs = scores.softmax(dim=-1)

        # Weighted sum of values, stored back in [s_q, h_q, dv] order.
        weighted = torch.matmul(probs, kv[..., :dv])
        out_ref[i] = weighted.transpose(0, 1)
        lse_ref[i] = row_lse

    return out_ref.to(q.dtype), lse_ref
# Run the PyTorch reference on the same inputs used in Test 1.
out_ref, lse_ref = reference_attention(
    q, blocked_k, block_table, cache_seqlens, head_dim_v, is_causal=False
)

# Compare kernel results with the reference in float32.
out_close = torch.allclose(out.float(), out_ref.float(), rtol=1e-2, atol=1e-3)
lse_close = torch.allclose(lse.float(), lse_ref.float(), rtol=1e-3, atol=1e-4)
print(f"Output close to reference: {out_close}")
print(f"LSE close to reference: {lse_close}")

if not (out_close and lse_close):
    # Show the worst-case absolute errors to help diagnose the mismatch.
    out_abs_err = (out.float() - out_ref.float()).abs()
    lse_abs_err = (lse.float() - lse_ref.float()).abs()
    max_out_diff = out_abs_err.max().item()
    max_lse_diff = lse_abs_err.max().item()
    print(f"Max output diff: {max_out_diff}")
    print(f"Max LSE diff: {max_lse_diff}")
    print("Correctness check: Check differences above")
else:
    print("Correctness check: PASSED")
# =============================================================================
# Test 3: Different configurations
# =============================================================================
print(
    "\n" + "=" * 60,
    "Test 3: Testing different configurations",
    "=" * 60,
    sep="\n",
)

# Each entry varies batch size, query length, query heads and cache length.
configs = [
    {"batch": 1, "seq_q": 1, "heads_q": 64, "seq_k": 128},
    {"batch": 4, "seq_q": 1, "heads_q": 128, "seq_k": 512},
    {"batch": 8, "seq_q": 2, "heads_q": 64, "seq_k": 1024},
]

for cfg in configs:
    b, s_q, h_q, s_k = cfg["batch"], cfg["seq_q"], cfg["heads_q"], cfg["seq_k"]

    # Size the paged cache for this configuration.
    max_blocks = cdiv(s_k, page_block_size)
    total_blks = b * max_blocks

    # Random inputs, built the same way as in Test 1 but without clamping.
    q_test = torch.randn(
        b, s_q, h_q, head_dim, device=device, dtype=torch.bfloat16
    ) / 10
    k_test = torch.randn(
        total_blks, page_block_size, num_heads_k, head_dim,
        device=device, dtype=torch.bfloat16,
    ) / 10
    bt_test = torch.arange(
        total_blks, device=device, dtype=torch.int32
    ).view(b, max_blocks)
    sl_test = torch.full((b,), s_k, device=device, dtype=torch.int32)

    # Request a fresh scheduler metadata object for each configuration.
    sched_meta, _ = flash_mla.get_mla_metadata()

    with torch.inference_mode():
        out_test, lse_test = flash_mla.flash_mla_with_kvcache(
            q=q_test,
            k_cache=k_test,
            block_table=bt_test,
            cache_seqlens=sl_test,
            head_dim_v=head_dim_v,
            tile_scheduler_metadata=sched_meta,
        )

    print(f"Config: batch={b}, seq_q={s_q}, heads_q={h_q}, seq_k={s_k} -> Output: {out_test.shape} SUCCESS")

print("\n" + "=" * 60, "All tests completed successfully!", "=" * 60, sep="\n")