| | from typing import List, Optional, Tuple, Union |
| |
|
| | import paged_attention as ops |
| | import pytest |
| | import torch |
| |
|
| |
|
@pytest.fixture()
def kv_cache_factory():
    """Provide the paged-layout KV-cache factory (create_kv_caches_with_random) to tests."""
    return create_kv_caches_with_random
| |
|
| |
|
@pytest.fixture()
def kv_cache_factory_flashinfer():
    """Provide the combined-layout KV-cache factory (create_kv_caches_with_random_flash) to tests."""
    return create_kv_caches_with_random_flash
| |
|
| |
|
# Maps dtype-name strings to the torch dtype used for KV-cache storage.
# All fp8 variants are stored in uint8 tensors (raw byte representation).
STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.uint8,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
}
| |
|
| |
|
def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Create per-layer key and value caches filled with random data.

    Key caches use the tiled layout
    (num_blocks, num_heads, head_size // x, block_size, x), where ``x`` is
    the number of elements that fit in 16 bytes; value caches use
    (num_blocks, num_heads, head_size, block_size).

    Args:
        num_blocks: Number of cache blocks per layer.
        block_size: Number of token slots per block.
        num_layers: Number of (key, value) cache pairs to create.
        num_heads: Number of KV attention heads.
        head_size: Per-head hidden dimension.
        cache_dtype: Cache dtype name; "auto" defers to ``model_dtype``.
        model_dtype: Dtype used to resolve ``cache_dtype == "auto"``.
        seed: RNG seed for reproducible cache contents.
        device: Device on which to allocate the caches.

    Returns:
        A ``(key_caches, value_caches)`` pair of lists, one tensor per layer.

    Raises:
        ValueError: If fp8 is requested with a head size not divisible by 16,
            or if ``cache_dtype`` is not supported.
    """
    if cache_dtype == "fp8" and head_size % 16:
        raise ValueError(
            f"Does not support key cache of type fp8 with head_size {head_size}"
        )
    # Imported lazily so the module can be imported without the platform layer.
    from paged_attention.platforms import current_platform

    current_platform.seed_everything(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)

    scale = head_size**-0.5

    def _fill_random(cache: torch.Tensor, kind: str) -> None:
        # Fill one cache tensor in place according to the requested dtype;
        # `kind` ("key"/"value") only affects the error message.
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(cache, -scale, scale)
        else:
            raise ValueError(f"Does not support {kind} cache of type {cache_dtype}")

    # x = elements per 16 bytes; the key cache packs the head dimension in
    # chunks of x (vectorized-access layout).
    x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device)
        _fill_random(key_cache, "key")
        key_caches.append(key_cache)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        value_cache = torch.empty(
            size=value_cache_shape, dtype=torch_dtype, device=device
        )
        _fill_random(value_cache, "value")
        value_caches.append(value_cache)
    return key_caches, value_caches
| |
|
| |
|
def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Create per-layer KV caches in the combined flash layout.

    Each layer allocates one tensor of shape
    (num_blocks, 2, block_size, num_heads, head_size); index 0 along the
    second dimension holds the keys and index 1 the values. The returned
    lists contain views into those combined tensors, not copies.
    """
    from paged_attention.platforms import current_platform

    current_platform.seed_everything(seed)

    dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    combined_shape = (num_blocks, 2, block_size, num_heads, head_size)
    scale = head_size**-0.5

    key_caches: List[torch.Tensor] = []
    value_caches: List[torch.Tensor] = []

    for _ in range(num_layers):
        combined = torch.empty(size=combined_shape, dtype=dtype, device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            combined.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(combined, -scale, scale)
        else:
            raise ValueError(f"Does not support key cache of type {cache_dtype}")
        # Split the combined allocation into key/value views.
        key_caches.append(combined[:, 0])
        value_caches.append(combined[:, 1])
    return key_caches, value_caches
| |
|
| |
|
def get_kv_cache_torch_dtype(
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
) -> torch.dtype:
    """Resolve the torch dtype used to store the KV cache.

    Args:
        cache_dtype: A torch.dtype (returned as-is), a dtype name listed in
            STR_DTYPE_TO_TORCH_DTYPE, or "auto" to fall back to model_dtype.
        model_dtype: Dtype name or torch.dtype used when cache_dtype is
            "auto".

    Returns:
        The storage dtype for the cache; all fp8 variants ("fp8",
        "fp8_e4m3", "fp8_e5m2") are stored as torch.uint8, consistent with
        STR_DTYPE_TO_TORCH_DTYPE.

    Raises:
        ValueError: If cache_dtype (or, for "auto", model_dtype) is not a
            recognized dtype.
    """
    if isinstance(cache_dtype, torch.dtype):
        return cache_dtype
    if isinstance(cache_dtype, str):
        if cache_dtype == "auto":
            if isinstance(model_dtype, str):
                return STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
            if isinstance(model_dtype, torch.dtype):
                return model_dtype
            raise ValueError(f"Invalid model dtype: {model_dtype}")
        # Single lookup table covers "half"/"bfloat16"/"float" and every fp8
        # variant; previously "fp8_e4m3"/"fp8_e5m2" were rejected despite
        # being declared in STR_DTYPE_TO_TORCH_DTYPE.
        if cache_dtype in STR_DTYPE_TO_TORCH_DTYPE:
            return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
    raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
| |
|
| |
|
def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    """Fill ``tensor`` in place with random fp8 values drawn from [low, high].

    Values are sampled uniformly into an fp16 staging buffer and then
    converted into ``tensor`` (presumably uint8-backed fp8 storage — see
    STR_DTYPE_TO_TORCH_DTYPE) via ops.convert_fp8.
    """
    staging = torch.empty_like(tensor, dtype=torch.float16)
    staging.uniform_(low, high)
    ops.convert_fp8(tensor, staging)
    # Drop the fp16 staging buffer promptly; caches can be large.
    del staging
| |
|