import random
from typing import List, Optional, Tuple

import paged_attention as ops
import pytest
import torch
from paged_attention.platforms import current_platform

from .allclose_default import get_default_atol, get_default_rtol
from .utils import get_max_shared_memory_bytes, opcheck

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8

# Bound the sequence length by the device's shared-memory budget, minus a
# 512-element buffer.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512

NUM_BLOCKS = 4321  # Arbitrary value for testing.
PARTITION_SIZE = 512

DTYPES = (
    [torch.half, torch.bfloat16, torch.float]
    if not current_platform.is_rocm()
    else [torch.half, torch.bfloat16]
)
NUM_GEN_SEQS = [7]
NUM_PREFILL_SEQS = [3]
NUM_HEADS = [(40, 40), (64, 8)]  # (num_query_heads, num_kv_heads)

HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
if current_platform.is_mps():
    DEVICES = ["mps:0"]
else:
    DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out
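

# Illustrative sketch (not part of the original test suite): the einsum
# convention above maps query (q, h, d) and key (k, h, d) to attention
# weights (h, q, k), then weights (h, q, k) with value (k, h, d) back to an
# output of shape (q, h, d). This hypothetical helper only documents those
# shapes.
def _example_ref_masked_attention_shapes() -> None:
    q_len, kv_len, num_heads, head_size = 2, 5, 3, 8
    query = torch.randn(q_len, num_heads, head_size)
    key = torch.randn(kv_len, num_heads, head_size)
    value = torch.randn(kv_len, num_heads, head_size)
    out = ref_masked_attention(query, key, value, scale=head_size**-0.5)
    assert out.shape == (q_len, num_heads, head_size)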


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    block_tables_lst = block_tables.cpu().tolist()
    seq_lens_lst = seq_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables_lst[i]
        seq_len = int(seq_lens_lst[i])

        # Gather the keys and values for this sequence from the paged cache.
        keys_lst: List[torch.Tensor] = []
        values_lst: List[torch.Tensor] = []
        for j in range(seq_len):
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_kv_heads, head_size)
            keys_lst.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values_lst.append(v)
        keys = torch.stack(keys_lst, dim=0)
        values = torch.stack(values_lst, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA: repeat the KV heads so that each query head
            # has a matching KV head.
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias: a per-head slope times the (negative)
            # distance of each position from the final token.
            position_ids = torch.arange(seq_len).int()
            alibi_bias = (position_ids - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)
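

# Illustrative sketch (hypothetical helper, not used by the tests): the
# reference above resolves a logical token position j to a physical cache
# slot via block_table[j // block_size] and offset j % block_size.
def _example_logical_to_physical(
    block_table: List[int], j: int, block_size: int
) -> Tuple[int, int]:
    return int(block_table[j // block_size]), j % block_size


# For example, with block_size=16 and block_table=[7, 3], token 0 lives in
# block 7 at offset 0, while token 20 lives in block 3 at offset 4.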


@pytest.mark.parametrize(
    "version", ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]
)
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    if (kv_cache_dtype == "fp8" and head_size % 16) or (
        version == "rocm" and head_size not in (64, 128)
    ):
        pytest.skip(
            "fp8 KV cache requires head_size divisible by 16; "
            "the rocm version only supports head sizes 64 and 128"
        )

    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN  # Always exercise the maximum sequence length.
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: List[List[int]] = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
        ]
        block_tables_lst.append(block_table)

    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(
        NUM_BLOCKS,
        block_size,
        1,
        num_kv_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Use the default (identity) KV scales.
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        # Run opcheck only for the first head size and block size to keep
        # test time down; it is disabled on MPS.
        opcheck(
            ops.ops.paged_attention_v1,
            (
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                0,
                0,
                0,
                64,
                0,
            ),
            cond=(
                head_size == HEAD_SIZES[0]
                and block_size == BLOCK_SIZES[0]
                and not device.startswith("mps")
            ),
        )

    elif version in ("v2", "rocm"):
        # v2 and the ROCm kernel split each sequence into fixed-size
        # partitions and keep per-partition softmax statistics (exp_sums,
        # max_logits) for the final reduction.
        num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
        assert PARTITION_SIZE % block_size == 0
        num_seqs, num_heads, head_size = output.shape
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        if version == "v2":
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(
                ops.ops.paged_attention_v2,
                (
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                    0,
                    0,
                    0,
                    64,
                    0,
                ),
                cond=(
                    head_size == HEAD_SIZES[0]
                    and block_size == BLOCK_SIZES[0]
                    and not device.startswith("mps")
                ),
            )

        else:
            ops.paged_attention_rocm(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(
                torch.ops._rocm_C.paged_attention,
                (
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                ),
                cond=(
                    head_size == HEAD_SIZES[0]
                    and block_size == BLOCK_SIZES[0]
                    and not device.startswith("mps")
                ),
            )

    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Dequantize the FP8 caches back to `dtype` so the reference
        # implementation can consume them.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
        dequantized_key_cache = torch.empty(
            size=key_cache_shape, dtype=dtype, device=device
        )
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(
            size=value_cache_shape, dtype=dtype, device=device
        )
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )
|
| | |
| | |
| | |
| | atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 |
| | rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 |
| |
|
| | |
| | |
| | atol, rtol = 1e-3, 1e-5 |
| | if kv_cache_dtype == "fp8": |
| | atol, rtol = 1e-2, 1e-5 |
| | |
| | elif dtype == torch.bfloat16 and use_alibi: |
| | atol, rtol = 2e-3, 1e-5 |
| | torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) |


def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs: List[torch.Tensor] = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create the causal attention mask.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)

    return torch.cat(ref_outputs, dim=0)
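

# Illustrative sketch (hypothetical, not part of the test suite): cu_seq_lens
# holds cumulative sequence boundaries, so sequences of lengths [3, 4] are
# packed as cu_seq_lens = [0, 3, 7] and sequence i occupies rows
# cu_seq_lens[i]:cu_seq_lens[i + 1] of the packed query/key/value tensors.
def _example_cu_seq_lens(seq_lens: List[int]) -> List[int]:
    cu = [0]
    for n in seq_lens:
        cu.append(cu[-1] + n)
    return cu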