Kernels
Benchmarks uploaded using `kernels`.
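The benchmark below runs a `paged_attention_v1` kernel on a small and a large decode workload; the `verify_*` methods return the output of a pure-PyTorch paged-attention reference for comparison against the kernel result.
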
import torch
from kernels.benchmark import Benchmark


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    # query: (q, h, d), key: (k, h, d), value: (k, h, d)
    # Transpose to (h, q, d) and (h, k, d) for batched matmul
    q = query.transpose(0, 1)  # (h, q, d)
    k = key.transpose(0, 1)  # (h, k, d)
    v = value.transpose(0, 1)  # (h, k, d)
    # Compute attention scores: (h, q, d) @ (h, d, k) -> (h, q, k)
    attn_weights = (scale * torch.matmul(q, k.transpose(-1, -2))).float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    # Compute output: (h, q, k) @ (h, k, d) -> (h, q, d)
    out = torch.matmul(attn_weights, v)
    # Transpose back to (q, h, d)
    return out.transpose(0, 1)
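

# NOTE: Not part of the uploaded benchmark -- a hypothetical cross-check of the
# dense reference above against torch.nn.functional.scaled_dot_product_attention
# (the `scale=` keyword assumes PyTorch >= 2.1). Call it manually if desired.
def _sanity_check_ref_masked_attention(
    q_len: int = 4, k_len: int = 7, num_heads: int = 2, head_size: int = 8
) -> None:
    query = torch.randn(q_len, num_heads, head_size)
    key = torch.randn(k_len, num_heads, head_size)
    value = torch.randn(k_len, num_heads, head_size)
    scale = 1.0 / (head_size**0.5)
    ref = ref_masked_attention(query, key, value, scale)
    # SDPA expects (..., heads, seq, head_size), so move the head dim first.
    sdpa = torch.nn.functional.scaled_dot_product_attention(
        query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1), scale=scale
    ).transpose(0, 1)
    assert torch.allclose(ref, sdpa, rtol=1e-4, atol=1e-4)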


def ref_paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    num_seqs = query.shape[0]
    num_heads = query.shape[1]
    head_size = query.shape[2]
    block_size = value_cache.shape[3]
    max_seq_len = int(seq_lens.max().item())
    # Create position indices for all sequences up to max_seq_len
    positions = torch.arange(max_seq_len, device=query.device)
    block_indices = positions // block_size  # (max_seq_len,)
    block_offsets = positions % block_size  # (max_seq_len,)
    # Gather block numbers for all sequences: (num_seqs, max_seq_len)
    block_numbers = block_tables[:, block_indices.long()]
    # Flatten for gathering: (num_seqs * max_seq_len,)
    flat_block_numbers = block_numbers.reshape(-1)
    flat_offsets = block_offsets.repeat(num_seqs)
    # Gather keys: key_cache is (num_blocks, num_heads, head_size // x, block_size, x)
    # Index into [block_number, :, :, offset, :] and reshape
    keys = key_cache[flat_block_numbers, :, :, flat_offsets, :]
    keys = keys.reshape(num_seqs, max_seq_len, num_heads, head_size)
    keys = keys.transpose(1, 2)  # (num_seqs, num_heads, max_seq_len, head_size)
    # Gather values: value_cache is (num_blocks, num_heads, head_size, block_size)
    values = value_cache[flat_block_numbers, :, :, flat_offsets]
    values = values.reshape(num_seqs, max_seq_len, num_heads, head_size)
    values = values.transpose(1, 2)  # (num_seqs, num_heads, max_seq_len, head_size)
    # Query: (num_seqs, num_heads, head_size) -> (num_seqs, num_heads, 1, head_size)
    q = query.unsqueeze(2)
    # Compute attention scores: (num_seqs, num_heads, 1, head_size) @ (num_seqs, num_heads, head_size, max_seq_len)
    attn_weights = (scale * torch.matmul(q, keys.transpose(-1, -2))).float()
    # Mask out cache positions at or beyond each sequence's length; the single
    # decode query attends to the entire valid prefix, so no triangular mask is needed.
    seq_mask = positions.unsqueeze(0) >= seq_lens.unsqueeze(1)  # (num_seqs, max_seq_len)
    seq_mask = seq_mask.unsqueeze(1).unsqueeze(2)  # (num_seqs, 1, 1, max_seq_len)
    attn_weights = attn_weights.masked_fill(seq_mask, float("-inf"))
    attn_weights = torch.softmax(attn_weights, dim=-1).to(values.dtype)
    # Compute output: (num_seqs, num_heads, 1, max_seq_len) @ (num_seqs, num_heads, max_seq_len, head_size)
    out = torch.matmul(attn_weights, values)
    return out.squeeze(2)  # (num_seqs, num_heads, head_size)
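

# Example of the block-table arithmetic above: with block_size == 16, token
# position 37 of a sequence lives in block_tables[seq, 37 // 16] == block_tables[seq, 2],
# at offset 37 % 16 == 5 within that physical block.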


class PagedAttentionBenchmark(Benchmark):
    seed: int = 42

    def setup(self):
        num_seqs = 4
        num_heads = 8
        head_size = 64
        block_size = 16
        max_seq_len = 128
        num_blocks = 64
        dtype = torch.float16

        self.num_heads = num_heads
        self.block_size = block_size
        self.max_seq_len = max_seq_len
        self.scale = 1.0 / (head_size**0.5)

        # Query tensor (current token)
        self.query = torch.randn(
            num_seqs, num_heads, head_size, device=self.device, dtype=dtype
        )

        # KV cache with proper layout for the kernel
        # x = 16 // element_size, for float16 x = 8
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        self.key_cache = torch.randn(
            num_blocks,
            num_heads,
            head_size // x,
            block_size,
            x,
            device=self.device,
            dtype=dtype,
        )
        self.value_cache = torch.randn(
            num_blocks,
            num_heads,
            head_size,
            block_size,
            device=self.device,
            dtype=dtype,
        )

        # Block tables: mapping from sequences to memory blocks
        max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
        self.block_tables = torch.randint(
            0,
            num_blocks,
            (num_seqs, max_num_blocks_per_seq),
            device=self.device,
            dtype=torch.int32,
        )

        # Sequence lengths
        self.seq_lens = torch.tensor(
            [64, 96, 48, 128], device=self.device, dtype=torch.int32
        )

        # KV scales
        self.k_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device)
        self.v_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device)

        # Output tensor
        self.out = torch.empty_like(self.query)
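
    # With the sizes in setup() above and float16 (element_size == 2), x == 8:
    # key_cache is (64, 8, 8, 16, 8) and value_cache is (64, 8, 64, 16), i.e.
    # each block stores block_size == 16 tokens per head, with the key's head
    # dimension split into chunks of x == 8 contiguous elements.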

    def benchmark_base(self):
        # Run the paged attention kernel, writing the result into self.out.
        self.kernel.paged_attention_v1(
            self.out,
            self.query,
            self.key_cache,
            self.value_cache,
            num_kv_heads=self.num_heads,
            scale=self.scale,
            block_tables=self.block_tables,
            seq_lens=self.seq_lens,
            block_size=self.block_size,
            max_seq_len=self.max_seq_len,
            alibi_slopes=None,
            kv_cache_dtype="auto",
            k_scale=self.k_scale,
            v_scale=self.v_scale,
        )

    def verify_base(self) -> torch.Tensor:
        # Pure-PyTorch reference output used to check the kernel result.
        return ref_paged_attention(
            self.query,
            self.key_cache,
            self.value_cache,
            self.block_tables,
            self.seq_lens,
            self.scale,
        )

    def setup_large(self):
        # Larger decode workload: 16 sequences, 32 heads, head size 128, contexts up to 512 tokens.
        num_seqs = 16
        num_heads = 32
        head_size = 128
        block_size = 16
        max_seq_len = 512
        num_blocks = 256
        dtype = torch.float16

        self.num_heads = num_heads
        self.block_size = block_size
        self.max_seq_len = max_seq_len
        self.scale = 1.0 / (head_size**0.5)

        self.query = torch.randn(
            num_seqs, num_heads, head_size, device=self.device, dtype=dtype
        )

        x = 16 // torch.tensor([], dtype=dtype).element_size()
        self.key_cache = torch.randn(
            num_blocks,
            num_heads,
            head_size // x,
            block_size,
            x,
            device=self.device,
            dtype=dtype,
        )
        self.value_cache = torch.randn(
            num_blocks,
            num_heads,
            head_size,
            block_size,
            device=self.device,
            dtype=dtype,
        )

        max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
        self.block_tables = torch.randint(
            0,
            num_blocks,
            (num_seqs, max_num_blocks_per_seq),
            device=self.device,
            dtype=torch.int32,
        )

        # Variable sequence lengths
        self.seq_lens = torch.randint(
            64, max_seq_len + 1, (num_seqs,), device=self.device, dtype=torch.int32
        )

        self.k_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device)
        self.v_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device)

        self.out = torch.empty_like(self.query)

    def benchmark_large(self):
        self.kernel.paged_attention_v1(
            self.out,
            self.query,
            self.key_cache,
            self.value_cache,
            num_kv_heads=self.num_heads,
            scale=self.scale,
            block_tables=self.block_tables,
            seq_lens=self.seq_lens,
            block_size=self.block_size,
            max_seq_len=self.max_seq_len,
            alibi_slopes=None,
            kv_cache_dtype="auto",
            k_scale=self.k_scale,
            v_scale=self.v_scale,
        )

    def verify_large(self) -> torch.Tensor:
        return ref_paged_attention(
            self.query,
            self.key_cache,
            self.value_cache,
            self.block_tables,
            self.seq_lens,
            self.scale,
        )
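

# NOTE: Not part of the uploaded benchmark -- a hypothetical standalone smoke test
# for the reference path only (no GPU or kernel required). It mirrors the small
# setup() configuration but uses float32 on CPU for portability.
if __name__ == "__main__":
    torch.manual_seed(42)
    num_seqs, num_heads, head_size = 4, 8, 64
    block_size, num_blocks, max_seq_len = 16, 64, 128
    dtype = torch.float32
    x = 16 // torch.tensor([], dtype=dtype).element_size()  # x == 4 for float32
    query = torch.randn(num_seqs, num_heads, head_size, dtype=dtype)
    key_cache = torch.randn(
        num_blocks, num_heads, head_size // x, block_size, x, dtype=dtype
    )
    value_cache = torch.randn(num_blocks, num_heads, head_size, block_size, dtype=dtype)
    block_tables = torch.randint(
        0,
        num_blocks,
        (num_seqs, (max_seq_len + block_size - 1) // block_size),
        dtype=torch.int32,
    )
    seq_lens = torch.tensor([64, 96, 48, 128], dtype=torch.int32)
    out = ref_paged_attention(
        query, key_cache, value_cache, block_tables, seq_lens, 1.0 / (head_size**0.5)
    )
    assert out.shape == (num_seqs, num_heads, head_size)
    print("reference paged attention output:", tuple(out.shape), out.dtype)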