| | import torch |
| |
|
| | from kernels.benchmark import Benchmark |
| |
|
| |
|
def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """Reference scaled dot-product attention over per-head batches.

    All three inputs are laid out as [seq_len, num_heads, head_size]; heads
    are moved to the batch dimension, attention is computed per head, and the
    result is returned in the original [seq_len, num_heads, head_size] layout.
    No mask is applied here (despite the name's lineage).
    """
    # [num_heads, seq_len, head_size] views so matmul batches over heads.
    q, k, v = (t.transpose(0, 1) for t in (query, key, value))

    # Scale in the input dtype, then upcast: softmax runs in float32 for
    # numerical stability before dropping back to the value dtype.
    scores = (scale * torch.matmul(q, k.transpose(-1, -2))).float()
    probs = torch.softmax(scores, dim=-1).to(value.dtype)

    # Weighted sum of values, restored to [seq_len, num_heads, head_size].
    return torch.matmul(probs, v).transpose(0, 1)
| |
|
| |
|
def ref_paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """Reference paged attention: gather K/V out of block-structured caches,
    then run masked scaled dot-product attention per sequence.

    Shapes (as used by this function):
      query:       [num_seqs, num_heads, head_size] (one query token per seq)
      key_cache:   [num_blocks, num_heads, head_size // x, block_size, x]
      value_cache: [num_blocks, num_heads, head_size, block_size]
      block_tables:[num_seqs, max_blocks_per_seq] physical block ids
      seq_lens:    [num_seqs] number of valid tokens per sequence
    Returns [num_seqs, num_heads, head_size].
    """
    num_seqs, num_heads, head_size = query.shape
    block_size = value_cache.shape[3]
    longest = int(seq_lens.max().item())

    # For each token position: which logical block it lives in, and its slot
    # inside that block.
    token_pos = torch.arange(longest, device=query.device)
    slot_in_block = token_pos % block_size
    # Map logical block index -> physical block id, per sequence.
    phys_block = block_tables[:, (token_pos // block_size).long()]

    gather_blocks = phys_block.reshape(-1)
    gather_slots = slot_in_block.repeat(num_seqs)

    # Fancy-index dims 0 and 3; the gathered axis moves to the front, giving
    # [num_seqs * longest, num_heads, head_size // x, x] for keys.
    k = key_cache[gather_blocks, :, :, gather_slots, :]
    k = k.reshape(num_seqs, longest, num_heads, head_size).transpose(1, 2)

    # Same gather for values: [num_seqs * longest, num_heads, head_size].
    v = value_cache[gather_blocks, :, :, gather_slots]
    v = v.reshape(num_seqs, longest, num_heads, head_size).transpose(1, 2)

    # One query token per sequence: [num_seqs, num_heads, 1, head_size].
    q = query.unsqueeze(2)

    # Scale in the input dtype, then upcast for a stable float32 softmax.
    logits = (scale * torch.matmul(q, k.transpose(-1, -2))).float()

    # Positions at or beyond a sequence's length get -inf so their softmax
    # weight is exactly zero.
    invalid = (token_pos[None, :] >= seq_lens[:, None])[:, None, None, :]
    logits = logits.masked_fill(invalid, float("-inf"))
    probs = torch.softmax(logits, dim=-1).to(v.dtype)

    # [num_seqs, num_heads, 1, head_size] -> drop the query-length axis.
    return torch.matmul(probs, v).squeeze(2)
| |
|
| |
|
class PagedAttentionBenchmark(Benchmark):
    """Benchmarks ``paged_attention_v1`` against a pure-PyTorch reference.

    The "base" and "large" variants run identical setup/benchmark/verify
    logic and differ only in problem size, so the shared work lives in the
    private helpers ``_build_inputs``, ``_run_kernel`` and ``_reference``.
    """

    # RNG seed — presumably consumed by the Benchmark base class before
    # setup() runs (TODO confirm against kernels.benchmark).
    seed: int = 42

    def _build_inputs(
        self,
        num_seqs: int,
        num_heads: int,
        head_size: int,
        block_size: int,
        max_seq_len: int,
        num_blocks: int,
        make_seq_lens,
        dtype: torch.dtype = torch.float16,
    ) -> None:
        """Allocate query, KV caches, block tables and scales for one variant.

        ``make_seq_lens`` is a zero-arg callable invoked *after* the block
        tables are drawn, so the RNG consumption order matches the original
        per-variant setup exactly (setup_large draws seq_lens randomly).
        """
        self.num_heads = num_heads
        self.block_size = block_size
        self.max_seq_len = max_seq_len
        self.scale = 1.0 / (head_size**0.5)

        self.query = torch.randn(
            num_seqs, num_heads, head_size, device=self.device, dtype=dtype
        )

        # vLLM-style key-cache layout: the trailing dim packs `x` elements so
        # that 16 bytes hold x values of this dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        self.key_cache = torch.randn(
            num_blocks,
            num_heads,
            head_size // x,
            block_size,
            x,
            device=self.device,
            dtype=dtype,
        )
        self.value_cache = torch.randn(
            num_blocks,
            num_heads,
            head_size,
            block_size,
            device=self.device,
            dtype=dtype,
        )

        # Random logical->physical block mapping per sequence.
        max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
        self.block_tables = torch.randint(
            0,
            num_blocks,
            (num_seqs, max_num_blocks_per_seq),
            device=self.device,
            dtype=torch.int32,
        )

        self.seq_lens = make_seq_lens()

        # Quantization scale factors; identity (1.0) for non-quantized caches.
        self.k_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device)
        self.v_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device)

        # Kernel output buffer, same shape/dtype as the query.
        self.out = torch.empty_like(self.query)

    def _run_kernel(self) -> None:
        """Invoke the kernel under test on the tensors built by setup."""
        self.kernel.paged_attention_v1(
            self.out,
            self.query,
            self.key_cache,
            self.value_cache,
            num_kv_heads=self.num_heads,
            scale=self.scale,
            block_tables=self.block_tables,
            seq_lens=self.seq_lens,
            block_size=self.block_size,
            max_seq_len=self.max_seq_len,
            alibi_slopes=None,
            kv_cache_dtype="auto",
            k_scale=self.k_scale,
            v_scale=self.v_scale,
        )

    def _reference(self) -> torch.Tensor:
        """Pure-PyTorch reference result for the current inputs."""
        return ref_paged_attention(
            self.query,
            self.key_cache,
            self.value_cache,
            self.block_tables,
            self.seq_lens,
            self.scale,
        )

    def setup(self):
        """Small problem: 4 seqs, 8 heads, fixed sequence lengths."""
        self._build_inputs(
            num_seqs=4,
            num_heads=8,
            head_size=64,
            block_size=16,
            max_seq_len=128,
            num_blocks=64,
            make_seq_lens=lambda: torch.tensor(
                [64, 96, 48, 128], device=self.device, dtype=torch.int32
            ),
        )

    def benchmark_base(self):
        self._run_kernel()

    def verify_base(self) -> torch.Tensor:
        return self._reference()

    def setup_large(self):
        """Large problem: 16 seqs, 32 heads, random sequence lengths in [64, 512]."""
        self._build_inputs(
            num_seqs=16,
            num_heads=32,
            head_size=128,
            block_size=16,
            max_seq_len=512,
            num_blocks=256,
            make_seq_lens=lambda: torch.randint(
                64, 512 + 1, (16,), device=self.device, dtype=torch.int32
            ),
        )

    def benchmark_large(self):
        self._run_kernel()

    def verify_large(self) -> torch.Tensor:
        return self._reference()
| |
|