Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/vllm/attention/__init__.py +19 -0
- .venv/lib/python3.11/site-packages/vllm/attention/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/__pycache__/layer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/__pycache__/selector.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/abstract.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/blocksparse_attn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/flash_attn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/flashinfer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/hpu_attn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/ipex_attn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/openvino.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/pallas.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/placeholder_attn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/rocm_flash_attn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/torch_sdpa.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/triton_mla.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/xformers.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/abstract.py +296 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/blocksparse_attn.py +457 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/flash_attn.py +942 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/flashinfer.py +1066 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/hpu_attn.py +293 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/ipex_attn.py +387 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/mla/utils.py +541 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/openvino.py +146 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/pallas.py +337 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/placeholder_attn.py +410 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/rocm_flash_attn.py +891 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/torch_sdpa.py +681 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/triton_mla.py +746 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/utils.py +582 -0
- .venv/lib/python3.11/site-packages/vllm/attention/backends/xformers.py +794 -0
- .venv/lib/python3.11/site-packages/vllm/attention/ops/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/hpu_paged_attn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/ipex_attn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/nki_flash_attn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/paged_attn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/prefix_prefill.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/triton_decode_attention.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/triton_flash_attention.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/blocksparse_attention_kernel.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/vllm/attention/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from vllm.attention.backends.abstract import (AttentionBackend,
|
| 4 |
+
AttentionMetadata,
|
| 5 |
+
AttentionMetadataBuilder,
|
| 6 |
+
AttentionState, AttentionType)
|
| 7 |
+
from vllm.attention.layer import Attention
|
| 8 |
+
from vllm.attention.selector import get_attn_backend
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"Attention",
|
| 12 |
+
"AttentionBackend",
|
| 13 |
+
"AttentionMetadata",
|
| 14 |
+
"AttentionType",
|
| 15 |
+
"AttentionMetadataBuilder",
|
| 16 |
+
"Attention",
|
| 17 |
+
"AttentionState",
|
| 18 |
+
"get_attn_backend",
|
| 19 |
+
]
|
.venv/lib/python3.11/site-packages/vllm/attention/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (678 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/__pycache__/layer.cpython-311.pyc
ADDED
|
Binary file (16.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/__pycache__/selector.cpython-311.pyc
ADDED
|
Binary file (5.85 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (196 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/abstract.cpython-311.pyc
ADDED
|
Binary file (14.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/blocksparse_attn.cpython-311.pyc
ADDED
|
Binary file (17.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/flash_attn.cpython-311.pyc
ADDED
|
Binary file (36.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/flashinfer.cpython-311.pyc
ADDED
|
Binary file (44 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/hpu_attn.cpython-311.pyc
ADDED
|
Binary file (13.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/ipex_attn.cpython-311.pyc
ADDED
|
Binary file (16 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/openvino.cpython-311.pyc
ADDED
|
Binary file (6.34 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/pallas.cpython-311.pyc
ADDED
|
Binary file (14.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/placeholder_attn.cpython-311.pyc
ADDED
|
Binary file (18 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/rocm_flash_attn.cpython-311.pyc
ADDED
|
Binary file (35.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/torch_sdpa.cpython-311.pyc
ADDED
|
Binary file (28.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/triton_mla.cpython-311.pyc
ADDED
|
Binary file (31.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (25.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/xformers.cpython-311.pyc
ADDED
|
Binary file (29 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/abstract.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from contextlib import contextmanager
|
| 5 |
+
from dataclasses import dataclass, fields
|
| 6 |
+
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
|
| 7 |
+
Protocol, Set, Tuple, Type, TypeVar)
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from vllm.multimodal import MultiModalPlaceholderMap
|
| 12 |
+
|
| 13 |
+
if TYPE_CHECKING:
|
| 14 |
+
from vllm.worker.model_runner_base import (ModelRunnerBase,
|
| 15 |
+
ModelRunnerInputBase,
|
| 16 |
+
ModelRunnerInputBuilderBase)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class AttentionType:
|
| 20 |
+
"""
|
| 21 |
+
Attention type.
|
| 22 |
+
Use string to be compatible with `torch.compile`.
|
| 23 |
+
"""
|
| 24 |
+
# Decoder attention between previous layer Q/K/V
|
| 25 |
+
DECODER = "decoder"
|
| 26 |
+
# Encoder attention between previous layer Q/K/V for encoder-decoder
|
| 27 |
+
ENCODER = "encoder"
|
| 28 |
+
# Encoder attention between previous layer Q/K/V
|
| 29 |
+
ENCODER_ONLY = "encoder_only"
|
| 30 |
+
# Attention between dec. Q and enc. K/V for encoder-decoder
|
| 31 |
+
ENCODER_DECODER = "encoder_decoder"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class AttentionBackend(ABC):
|
| 35 |
+
"""Abstract class for attention backends."""
|
| 36 |
+
# For some attention backends, we allocate an output tensor before
|
| 37 |
+
# calling the custom op. When piecewise cudagraph is enabled, this
|
| 38 |
+
# makes sure the output tensor is allocated inside the cudagraph.
|
| 39 |
+
accept_output_buffer: bool = False
|
| 40 |
+
|
| 41 |
+
@staticmethod
|
| 42 |
+
@abstractmethod
|
| 43 |
+
def get_name() -> str:
|
| 44 |
+
raise NotImplementedError
|
| 45 |
+
|
| 46 |
+
@staticmethod
|
| 47 |
+
@abstractmethod
|
| 48 |
+
def get_impl_cls() -> Type["AttentionImpl"]:
|
| 49 |
+
raise NotImplementedError
|
| 50 |
+
|
| 51 |
+
@staticmethod
|
| 52 |
+
@abstractmethod
|
| 53 |
+
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
| 54 |
+
raise NotImplementedError
|
| 55 |
+
|
| 56 |
+
@staticmethod
|
| 57 |
+
@abstractmethod
|
| 58 |
+
def get_state_cls() -> Type["AttentionState"]:
|
| 59 |
+
raise NotImplementedError
|
| 60 |
+
|
| 61 |
+
@classmethod
|
| 62 |
+
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
|
| 63 |
+
return cls.get_metadata_cls()(*args, **kwargs)
|
| 64 |
+
|
| 65 |
+
@staticmethod
|
| 66 |
+
@abstractmethod
|
| 67 |
+
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
|
| 68 |
+
raise NotImplementedError
|
| 69 |
+
|
| 70 |
+
@staticmethod
|
| 71 |
+
@abstractmethod
|
| 72 |
+
def get_kv_cache_shape(
|
| 73 |
+
num_blocks: int,
|
| 74 |
+
block_size: int,
|
| 75 |
+
num_kv_heads: int,
|
| 76 |
+
head_size: int,
|
| 77 |
+
) -> Tuple[int, ...]:
|
| 78 |
+
raise NotImplementedError
|
| 79 |
+
|
| 80 |
+
@staticmethod
|
| 81 |
+
@abstractmethod
|
| 82 |
+
def swap_blocks(
|
| 83 |
+
src_kv_cache: torch.Tensor,
|
| 84 |
+
dst_kv_cache: torch.Tensor,
|
| 85 |
+
src_to_dst: torch.Tensor,
|
| 86 |
+
) -> None:
|
| 87 |
+
raise NotImplementedError
|
| 88 |
+
|
| 89 |
+
@staticmethod
|
| 90 |
+
@abstractmethod
|
| 91 |
+
def copy_blocks(
|
| 92 |
+
kv_caches: List[torch.Tensor],
|
| 93 |
+
src_to_dists: torch.Tensor,
|
| 94 |
+
) -> None:
|
| 95 |
+
raise NotImplementedError
|
| 96 |
+
|
| 97 |
+
def advance_step(self, model_input: "ModelRunnerInputBase",
|
| 98 |
+
sampled_token_ids: Optional[torch.Tensor],
|
| 99 |
+
block_size: int, num_seqs: int, num_queries: int) -> None:
|
| 100 |
+
raise NotImplementedError
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@dataclass
|
| 104 |
+
class AttentionMetadata:
|
| 105 |
+
"""Attention metadata for prefill and decode batched together."""
|
| 106 |
+
# Total number of prefill requests.
|
| 107 |
+
num_prefills: int
|
| 108 |
+
# Number of prefill tokens.
|
| 109 |
+
num_prefill_tokens: int
|
| 110 |
+
# Number of decode tokens. Note that it is equivalent to the number of
|
| 111 |
+
# decode requests.
|
| 112 |
+
num_decode_tokens: int
|
| 113 |
+
# (num_tokens,). The indices of the token slots that input tokens will be
|
| 114 |
+
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
|
| 115 |
+
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
| 116 |
+
# in block 0, and 1st slot in block 1, respectively.
|
| 117 |
+
slot_mapping: torch.Tensor
|
| 118 |
+
|
| 119 |
+
# The index maps that relate multi-modal embeddings to the corresponding
|
| 120 |
+
# placeholders.
|
| 121 |
+
#
|
| 122 |
+
# N.B. These aren't really related to attention and don't belong on this
|
| 123 |
+
# type -- this is just a temporary solution to make them available to
|
| 124 |
+
# `model_executable`.
|
| 125 |
+
multi_modal_placeholder_index_maps: Optional[Dict[
|
| 126 |
+
str, MultiModalPlaceholderMap.IndexMap]]
|
| 127 |
+
|
| 128 |
+
# Enable/disable KV scales calculation. This is so that we can disable the
|
| 129 |
+
# calculation until after prefill and cuda graph capture.
|
| 130 |
+
enable_kv_scales_calculation: bool
|
| 131 |
+
|
| 132 |
+
@property
|
| 133 |
+
@abstractmethod
|
| 134 |
+
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
|
| 135 |
+
"""Return the attention metadata that's required to run prefill
|
| 136 |
+
attention."""
|
| 137 |
+
pass
|
| 138 |
+
|
| 139 |
+
@property
|
| 140 |
+
@abstractmethod
|
| 141 |
+
def decode_metadata(self) -> Optional["AttentionMetadata"]:
|
| 142 |
+
"""Return the attention metadata that's required to run decode
|
| 143 |
+
attention."""
|
| 144 |
+
pass
|
| 145 |
+
|
| 146 |
+
def asdict_zerocopy(self,
|
| 147 |
+
skip_fields: Optional[Set[str]] = None
|
| 148 |
+
) -> Dict[str, Any]:
|
| 149 |
+
"""Similar to dataclasses.asdict, but avoids deepcopying."""
|
| 150 |
+
if skip_fields is None:
|
| 151 |
+
skip_fields = set()
|
| 152 |
+
# Note that if we add dataclasses as fields, they will need
|
| 153 |
+
# similar handling.
|
| 154 |
+
return {
|
| 155 |
+
field.name: getattr(self, field.name)
|
| 156 |
+
for field in fields(self) if field.name not in skip_fields
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
T = TypeVar("T", bound=AttentionMetadata)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class AttentionState(ABC, Generic[T]):
|
| 164 |
+
"""Holds attention backend-specific objects reused during the
|
| 165 |
+
lifetime of the model runner."""
|
| 166 |
+
|
| 167 |
+
@abstractmethod
|
| 168 |
+
def __init__(self, runner: "ModelRunnerBase"):
|
| 169 |
+
...
|
| 170 |
+
|
| 171 |
+
@abstractmethod
|
| 172 |
+
@contextmanager
|
| 173 |
+
def graph_capture(self, max_batch_size: int):
|
| 174 |
+
"""Context manager used when capturing CUDA graphs."""
|
| 175 |
+
yield
|
| 176 |
+
|
| 177 |
+
@abstractmethod
|
| 178 |
+
def graph_clone(self, batch_size: int) -> "AttentionState[T]":
|
| 179 |
+
"""Clone attention state to save in CUDA graph metadata."""
|
| 180 |
+
...
|
| 181 |
+
|
| 182 |
+
@abstractmethod
|
| 183 |
+
def graph_capture_get_metadata_for_batch(
|
| 184 |
+
self,
|
| 185 |
+
batch_size: int,
|
| 186 |
+
is_encoder_decoder_model: bool = False) -> T:
|
| 187 |
+
"""Get attention metadata for CUDA graph capture of batch_size."""
|
| 188 |
+
...
|
| 189 |
+
|
| 190 |
+
@abstractmethod
|
| 191 |
+
def get_graph_input_buffers(
|
| 192 |
+
self,
|
| 193 |
+
attn_metadata: T,
|
| 194 |
+
is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
|
| 195 |
+
"""Get attention-specific input buffers for CUDA graph capture."""
|
| 196 |
+
...
|
| 197 |
+
|
| 198 |
+
@abstractmethod
|
| 199 |
+
def prepare_graph_input_buffers(
|
| 200 |
+
self,
|
| 201 |
+
input_buffers: Dict[str, Any],
|
| 202 |
+
attn_metadata: T,
|
| 203 |
+
is_encoder_decoder_model: bool = False) -> None:
|
| 204 |
+
"""In-place modify input buffers dict for CUDA graph replay."""
|
| 205 |
+
...
|
| 206 |
+
|
| 207 |
+
@abstractmethod
|
| 208 |
+
def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
|
| 209 |
+
"""Prepare state for forward pass."""
|
| 210 |
+
...
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class AttentionMetadataBuilder(ABC, Generic[T]):
|
| 214 |
+
"""Abstract class for attention metadata builders."""
|
| 215 |
+
|
| 216 |
+
@abstractmethod
|
| 217 |
+
def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
|
| 218 |
+
"""Create the builder, remember some configuration and parameters."""
|
| 219 |
+
raise NotImplementedError
|
| 220 |
+
|
| 221 |
+
@abstractmethod
|
| 222 |
+
def prepare(self) -> None:
|
| 223 |
+
"""Prepare for one batch."""
|
| 224 |
+
raise NotImplementedError
|
| 225 |
+
|
| 226 |
+
@abstractmethod
|
| 227 |
+
def build(self, seq_lens: List[int], query_lens: List[int],
|
| 228 |
+
cuda_graph_pad_size: int, batch_size: int) -> T:
|
| 229 |
+
"""Build attention metadata with on-device tensors."""
|
| 230 |
+
raise NotImplementedError
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class AttentionLayer(Protocol):
|
| 234 |
+
|
| 235 |
+
_k_scale: torch.Tensor
|
| 236 |
+
_v_scale: torch.Tensor
|
| 237 |
+
_k_scale_float: float
|
| 238 |
+
_v_scale_float: float
|
| 239 |
+
|
| 240 |
+
def forward(
|
| 241 |
+
self,
|
| 242 |
+
query: torch.Tensor,
|
| 243 |
+
key: torch.Tensor,
|
| 244 |
+
value: torch.Tensor,
|
| 245 |
+
kv_cache: torch.Tensor,
|
| 246 |
+
attn_metadata: AttentionMetadata,
|
| 247 |
+
) -> torch.Tensor:
|
| 248 |
+
...
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class AttentionImpl(ABC, Generic[T]):
|
| 252 |
+
|
| 253 |
+
@abstractmethod
|
| 254 |
+
def __init__(
|
| 255 |
+
self,
|
| 256 |
+
num_heads: int,
|
| 257 |
+
head_size: int,
|
| 258 |
+
scale: float,
|
| 259 |
+
num_kv_heads: Optional[int] = None,
|
| 260 |
+
alibi_slopes: Optional[List[float]] = None,
|
| 261 |
+
sliding_window: Optional[int] = None,
|
| 262 |
+
kv_cache_dtype: str = "auto",
|
| 263 |
+
blocksparse_params: Optional[Dict[str, Any]] = None,
|
| 264 |
+
logits_soft_cap: Optional[float] = None,
|
| 265 |
+
attn_type: str = AttentionType.DECODER,
|
| 266 |
+
) -> None:
|
| 267 |
+
raise NotImplementedError
|
| 268 |
+
|
| 269 |
+
@abstractmethod
|
| 270 |
+
def forward(
|
| 271 |
+
self,
|
| 272 |
+
layer: AttentionLayer,
|
| 273 |
+
query: torch.Tensor,
|
| 274 |
+
key: torch.Tensor,
|
| 275 |
+
value: torch.Tensor,
|
| 276 |
+
kv_cache: torch.Tensor,
|
| 277 |
+
attn_metadata: T,
|
| 278 |
+
output: Optional[torch.Tensor] = None,
|
| 279 |
+
) -> torch.Tensor:
|
| 280 |
+
raise NotImplementedError
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
|
| 284 |
+
|
| 285 |
+
@abstractmethod
|
| 286 |
+
def forward(
|
| 287 |
+
self,
|
| 288 |
+
layer: AttentionLayer,
|
| 289 |
+
hidden_states_or_cq: torch.Tensor,
|
| 290 |
+
kv_c_normed: torch.Tensor,
|
| 291 |
+
k_pe: torch.Tensor,
|
| 292 |
+
kv_cache: torch.Tensor,
|
| 293 |
+
attn_metadata: T,
|
| 294 |
+
output: Optional[torch.Tensor] = None,
|
| 295 |
+
) -> torch.Tensor:
|
| 296 |
+
raise NotImplementedError
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/blocksparse_attn.py
ADDED
|
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Any, Dict, List, Optional, Tuple, Type
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
| 9 |
+
AttentionLayer,
|
| 10 |
+
AttentionMetadata, AttentionType)
|
| 11 |
+
from vllm.attention.backends.utils import (CommonAttentionState,
|
| 12 |
+
CommonMetadataBuilder)
|
| 13 |
+
from vllm.attention.ops.blocksparse_attention.interface import (
|
| 14 |
+
LocalStridedBlockSparseAttn, get_head_sliding_step)
|
| 15 |
+
from vllm.attention.ops.paged_attn import PagedAttention
|
| 16 |
+
from vllm.distributed import (get_tensor_model_parallel_rank,
|
| 17 |
+
get_tensor_model_parallel_world_size)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class BlocksparseParams:
|
| 22 |
+
max_seqlen: int
|
| 23 |
+
|
| 24 |
+
# Num q heads per tensor-parallel rank/partition
|
| 25 |
+
num_heads: int # per TP partition
|
| 26 |
+
# Num kv heads per tensor-parallel rank/partition
|
| 27 |
+
num_kv_heads: int
|
| 28 |
+
|
| 29 |
+
# block size used for blocksparse attention.
|
| 30 |
+
# This is the block_size used in `local_blocks`, `vert_stride`.
|
| 31 |
+
block_size: int
|
| 32 |
+
|
| 33 |
+
# Number of blocks for local attention, i.e., number of
|
| 34 |
+
# local attended tokens / `sparse_block_size`
|
| 35 |
+
local_blocks: int
|
| 36 |
+
|
| 37 |
+
# Attend to one block per every `vert_stride` blocks.
|
| 38 |
+
# Controlling the sparsity
|
| 39 |
+
vert_stride: int
|
| 40 |
+
"""
|
| 41 |
+
If to use the same vertical stride offset for all heads,
|
| 42 |
+
i.e., attend to the same block of tokens on all heads.
|
| 43 |
+
By default, it is False, i.e., attention on the non-local
|
| 44 |
+
blocks depends on the `head_idx`, that is on
|
| 45 |
+
blocks satisfying
|
| 46 |
+
`(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`
|
| 47 |
+
where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`,
|
| 48 |
+
`block_idx = position_id // sparse_block_size`.
|
| 49 |
+
See `..ops.blocksparse_attention.utils:get_sparse_attn_mask`
|
| 50 |
+
for more detail.
|
| 51 |
+
"""
|
| 52 |
+
homo_head: bool = False
|
| 53 |
+
|
| 54 |
+
# If within a group, the kv offsets that each q attends is the same or no.
|
| 55 |
+
homo_head_group: bool = False
|
| 56 |
+
|
| 57 |
+
# Decided by homo_head and homo_head group
|
| 58 |
+
head_sliding_step: int = field(init=False)
|
| 59 |
+
|
| 60 |
+
# range of q heads to for a TP rank
|
| 61 |
+
active_head_range: Tuple = field(init=False)
|
| 62 |
+
|
| 63 |
+
def __post_init__(self):
|
| 64 |
+
assert self.block_size > 0
|
| 65 |
+
assert self.local_blocks >= 0
|
| 66 |
+
assert self.vert_stride >= 1
|
| 67 |
+
assert self.num_heads % self.num_kv_heads == 0
|
| 68 |
+
|
| 69 |
+
tp_size = get_tensor_model_parallel_world_size()
|
| 70 |
+
tp_rank = get_tensor_model_parallel_rank()
|
| 71 |
+
total_heads = tp_size * self.num_heads
|
| 72 |
+
total_kv_heads = tp_size * self.num_kv_heads
|
| 73 |
+
|
| 74 |
+
if self.homo_head:
|
| 75 |
+
self.head_sliding_step = 0
|
| 76 |
+
elif self.homo_head_group:
|
| 77 |
+
head_sliding_step = get_head_sliding_step(total_kv_heads,
|
| 78 |
+
self.vert_stride)
|
| 79 |
+
# negative indicates sliding along kv heads, i.e., homo q group
|
| 80 |
+
self.head_sliding_step = -head_sliding_step
|
| 81 |
+
else:
|
| 82 |
+
self.head_sliding_step = get_head_sliding_step(
|
| 83 |
+
total_heads, self.vert_stride)
|
| 84 |
+
|
| 85 |
+
self.active_head_range = (
|
| 86 |
+
tp_rank * self.num_heads,
|
| 87 |
+
(tp_rank + 1) * self.num_heads,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class BlocksparseFlashAttentionBackend(AttentionBackend):
|
| 92 |
+
|
| 93 |
+
@staticmethod
|
| 94 |
+
def get_name() -> str:
|
| 95 |
+
return "BLOCK_SPARSE_FLASH_ATTN"
|
| 96 |
+
|
| 97 |
+
@staticmethod
|
| 98 |
+
def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
|
| 99 |
+
return BlocksparseFlashAttentionImpl
|
| 100 |
+
|
| 101 |
+
@staticmethod
|
| 102 |
+
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
| 103 |
+
return BlocksparseFlashAttentionMetadata
|
| 104 |
+
|
| 105 |
+
@staticmethod
|
| 106 |
+
def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
|
| 107 |
+
return BlocksparseFlashAttentionMetadataBuilder
|
| 108 |
+
|
| 109 |
+
@staticmethod
|
| 110 |
+
def get_state_cls() -> Type["CommonAttentionState"]:
|
| 111 |
+
return CommonAttentionState
|
| 112 |
+
|
| 113 |
+
@staticmethod
|
| 114 |
+
def get_kv_cache_shape(
|
| 115 |
+
num_blocks: int,
|
| 116 |
+
block_size: int,
|
| 117 |
+
num_kv_heads: int,
|
| 118 |
+
head_size: int,
|
| 119 |
+
) -> Tuple[int, ...]:
|
| 120 |
+
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
| 121 |
+
num_kv_heads, head_size)
|
| 122 |
+
|
| 123 |
+
@staticmethod
|
| 124 |
+
def swap_blocks(
|
| 125 |
+
src_kv_cache: torch.Tensor,
|
| 126 |
+
dst_kv_cache: torch.Tensor,
|
| 127 |
+
src_to_dst: Dict[int, int],
|
| 128 |
+
) -> None:
|
| 129 |
+
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
| 130 |
+
|
| 131 |
+
@staticmethod
|
| 132 |
+
def copy_blocks(
|
| 133 |
+
kv_caches: List[torch.Tensor],
|
| 134 |
+
src_to_dists: Dict[int, List[int]],
|
| 135 |
+
) -> None:
|
| 136 |
+
PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@dataclass
|
| 140 |
+
class BlocksparseFlashAttentionMetadata(AttentionMetadata):
|
| 141 |
+
"""A copy of Metadata for FlashAttentionBackend,
|
| 142 |
+
to avoid having to install flash_attn.
|
| 143 |
+
|
| 144 |
+
NOTE: Any python object stored here is not updated when it is
|
| 145 |
+
cuda-graph replayed. If you have values that need to be changed
|
| 146 |
+
dynamically, it should be stored in tensor. The tensor has to be
|
| 147 |
+
updated from `CUDAGraphRunner.forward` API.
|
| 148 |
+
"""
|
| 149 |
+
# (batch_size,). The sequence length per sequence. Sequence length means
|
| 150 |
+
# the computed tokens + new tokens None if it is a decoding.
|
| 151 |
+
seq_lens: Optional[List[int]]
|
| 152 |
+
# seq_lens stored as a tensor.
|
| 153 |
+
seq_lens_tensor: Optional[torch.Tensor]
|
| 154 |
+
|
| 155 |
+
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
| 156 |
+
# |---------- N-1 iteration --------|
|
| 157 |
+
# |---------------- N iteration ---------------------|
|
| 158 |
+
# |- tokenA -|......................|-- newTokens ---|
|
| 159 |
+
# |---------- context_len ----------|
|
| 160 |
+
# |-------------------- seq_len ----------------------|
|
| 161 |
+
# |-- query_len ---|
|
| 162 |
+
|
| 163 |
+
# Maximum query length in the batch. None for decoding.
|
| 164 |
+
max_query_len: Optional[int]
|
| 165 |
+
# Maximum sequence length among prefill batch. 0 if there are decoding
|
| 166 |
+
# requests only.
|
| 167 |
+
max_prefill_seq_len: int
|
| 168 |
+
# Maximum sequence length among decode batch. 0 if there are prefill
|
| 169 |
+
# requests only.
|
| 170 |
+
max_decode_seq_len: int
|
| 171 |
+
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
| 172 |
+
# the batch, used to index into subquery. E.g., if the subquery length
|
| 173 |
+
# is [4, 6], it is [0, 4, 10].
|
| 174 |
+
query_start_loc: Optional[torch.Tensor]
|
| 175 |
+
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
| 176 |
+
# the batch, used to index into sequence. E.g., if the sequence length is
|
| 177 |
+
# [4, 6], it is [0, 4, 10].
|
| 178 |
+
seq_start_loc: Optional[torch.Tensor]
|
| 179 |
+
# (batch_size,) A tensor of context lengths (tokens that are computed
|
| 180 |
+
# so far).
|
| 181 |
+
context_lens_tensor: Optional[torch.Tensor]
|
| 182 |
+
|
| 183 |
+
# (batch_size, max_blocks_per_seq).
|
| 184 |
+
# Block addresses per sequence. (Seq id -> list of physical block)
|
| 185 |
+
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
| 186 |
+
# in the kv cache. Each block can contain up to block_size tokens.
|
| 187 |
+
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
| 188 |
+
# captured.
|
| 189 |
+
block_tables: Optional[torch.Tensor]
|
| 190 |
+
|
| 191 |
+
# Whether or not if cuda graph is enabled.
|
| 192 |
+
# Cuda-graph is currently enabled for decoding only.
|
| 193 |
+
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
| 194 |
+
use_cuda_graph: bool
|
| 195 |
+
|
| 196 |
+
# Max number of query tokens for among request in the batch.
|
| 197 |
+
max_decode_query_len: Optional[int] = None
|
| 198 |
+
|
| 199 |
+
_cached_prefill_metadata: Optional[
|
| 200 |
+
"BlocksparseFlashAttentionMetadata"] = None
|
| 201 |
+
_cached_decode_metadata: Optional[
|
| 202 |
+
"BlocksparseFlashAttentionMetadata"] = None
|
| 203 |
+
|
| 204 |
+
@property
def prefill_metadata(
        self) -> Optional["BlocksparseFlashAttentionMetadata"]:
    """Return the metadata restricted to the prefill requests.

    Returns:
        ``None`` when ``num_prefills`` is zero; otherwise a
        ``BlocksparseFlashAttentionMetadata`` built from the leading
        slices of the per-sequence / per-token fields. The object is
        built once and memoised in ``_cached_prefill_metadata``.
    """
    n_prefills = self.num_prefills
    if n_prefills == 0:
        return None

    cached = self._cached_prefill_metadata
    if cached is not None:
        return cached

    # All of these fields are sliced below, so none may be missing.
    assert self.seq_lens is not None
    assert self.seq_lens_tensor is not None
    assert self.query_start_loc is not None
    assert self.context_lens_tensor is not None
    assert self.block_tables is not None
    assert self.seq_start_loc is not None

    n_prefill_tokens = self.num_prefill_tokens
    # The cumulative-offset tensors carry one more entry than there
    # are sequences, hence the ``+ 1`` on their slice bound.
    n_offsets = n_prefills + 1

    prefill_view = BlocksparseFlashAttentionMetadata(
        num_prefills=n_prefills,
        num_prefill_tokens=n_prefill_tokens,
        num_decode_tokens=0,
        slot_mapping=self.slot_mapping[:n_prefill_tokens],
        multi_modal_placeholder_index_maps=(
            self.multi_modal_placeholder_index_maps),
        enable_kv_scales_calculation=self.enable_kv_scales_calculation,
        seq_lens=self.seq_lens[:n_prefills],
        seq_lens_tensor=self.seq_lens_tensor[:n_prefills],
        max_query_len=self.max_query_len,
        max_prefill_seq_len=self.max_prefill_seq_len,
        max_decode_seq_len=0,
        query_start_loc=self.query_start_loc[:n_offsets],
        seq_start_loc=self.seq_start_loc[:n_offsets],
        context_lens_tensor=self.context_lens_tensor[:n_prefills],
        block_tables=self.block_tables[:n_prefills],
        use_cuda_graph=False,
    )
    self._cached_prefill_metadata = prefill_view
    return prefill_view
|
| 240 |
+
|
| 241 |
+
@property
def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
    """Return the metadata restricted to the decode requests.

    Returns:
        ``None`` when ``num_decode_tokens`` is zero; otherwise a
        ``BlocksparseFlashAttentionMetadata`` built from the trailing
        slices (everything after the prefill entries). The object is
        built once and memoised in ``_cached_decode_metadata``.
    """
    if self.num_decode_tokens == 0:
        return None

    cached = self._cached_decode_metadata
    if cached is not None:
        return cached

    # Both fields are sliced below, so they must be present.
    assert self.block_tables is not None
    assert self.seq_lens_tensor is not None

    first_decode_seq = self.num_prefills
    first_decode_token = self.num_prefill_tokens

    decode_view = BlocksparseFlashAttentionMetadata(
        num_prefills=0,
        num_prefill_tokens=0,
        num_decode_tokens=self.num_decode_tokens,
        slot_mapping=self.slot_mapping[first_decode_token:],
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=False,
        seq_lens=None,
        seq_lens_tensor=self.seq_lens_tensor[first_decode_seq:],
        max_query_len=None,
        max_prefill_seq_len=0,
        max_decode_seq_len=self.max_decode_seq_len,
        query_start_loc=None,
        seq_start_loc=None,
        context_lens_tensor=None,
        block_tables=self.block_tables[first_decode_seq:],
        use_cuda_graph=self.use_cuda_graph,
    )
    self._cached_decode_metadata = decode_view
    return decode_view
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class BlocksparseFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):
    """Metadata builder specialised for the blocksparse backend.

    All behaviour comes from ``CommonMetadataBuilder``; this subclass
    only sets ``_metadata_cls`` to the blocksparse metadata class.
    """

    _metadata_cls = BlocksparseFlashAttentionMetadata
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class BlocksparseFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Validate the configuration and build the blocksparse kernel.

        Args:
            num_heads: Number of query heads handled by this rank.
            head_size: Size of each attention head; must be one of
                ``PagedAttention.get_supported_head_sizes()``.
            scale: Softmax scale, stored as a ``float``.
            num_kv_heads: Number of key/value heads; ``None`` falls back
                to ``num_heads``.
            alibi_slopes: Must be ``None`` (unsupported here).
            sliding_window: Must be ``None`` (unsupported here).
            kv_cache_dtype: KV-cache dtype string passed to the paged
                attention ops.
            blocksparse_params: Keyword arguments for
                ``BlocksparseParams``; required. ``num_heads`` and
                ``num_kv_heads`` are filled in when absent.
            logits_soft_cap: Must be ``None`` (unsupported here).
            attn_type: Only ``AttentionType.DECODER`` is implemented.

        Raises:
            ValueError: If ``alibi_slopes``, ``sliding_window`` or
                ``logits_soft_cap`` is given, or ``head_size`` is not
                supported.
            NotImplementedError: If ``attn_type`` is not the decoder type.
        """
        assert blocksparse_params is not None
        # Raise real exceptions for unsupported options: an ``assert``
        # would be stripped under ``python -O`` and would surface as an
        # AssertionError instead of the intended ValueError.
        if alibi_slopes is not None:
            raise ValueError(
                "Alibi not support for blocksparse flash attention.")
        if sliding_window is not None:
            raise ValueError(
                "sliding_window is invalid for blocksparse attention.")
        if logits_soft_cap is not None:
            raise ValueError(
                "logits_soft_cap is invalid for blocksparse attention.")

        # Fill the defaults on a copy so the caller's dict is not mutated.
        params = dict(blocksparse_params)
        params.setdefault("num_heads", num_heads)
        params.setdefault("num_kv_heads", num_kv_heads or num_heads)
        self.blocksparse_params = BlocksparseParams(**params)
        self.kv_cache_dtype = kv_cache_dtype

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.alibi_slopes = alibi_slopes
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.local_blocks = self.blocksparse_params.local_blocks
        self.vert_stride = self.blocksparse_params.vert_stride
        self.sparse_block_size = self.blocksparse_params.block_size
        self.head_sliding_step = self.blocksparse_params.head_sliding_step

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        # The kernel is configured with the head count across all
        # tensor-parallel ranks, not just this rank's share.
        total_num_heads = num_heads * self.tp_size
        self.bs_attn = LocalStridedBlockSparseAttn(
            total_num_heads,
            self.blocksparse_params.max_seqlen,
            self.blocksparse_params.local_blocks,
            self.blocksparse_params.vert_stride,
            self.blocksparse_params.block_size,
            homo_head=self.blocksparse_params.homo_head,
            active_head_range=self.blocksparse_params.active_head_range,
        )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "BlocksparseFlashAttentionImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: BlocksparseFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            layer: Attention layer; its ``_k_scale`` / ``_v_scale`` are
                passed to the paged-cache ops.
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            output: Not written into; it is only returned (reshaped) when
                neither the prefill nor the decode branch runs.
        Returns:
            shape = [num_tokens, num_heads * head_size]

        Raises:
            RuntimeError: If decode metadata is present but ``kv_cache``
                is empty, since decoding reads from the paged cache.
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Bound up front so the decode branch can detect a missing cache
        # instead of failing with an UnboundLocalError.
        key_cache: Optional[torch.Tensor] = None
        value_cache: Optional[torch.Tensor] = None
        if kv_cache.numel() > 0:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            # normal attention
            # When block_tables are not filled, it means q and k are the
            # prompt, and they have the same length.
            assert kv_cache.numel() == 0 \
                    or prefill_meta.block_tables is None \
                    or prefill_meta.block_tables.numel() == 0, \
                "Does not support prefix-enabled attention."

            output = self.bs_attn(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=prefill_meta.seq_start_loc,
                cu_seqlens_k=prefill_meta.seq_start_loc,
                sm_scale=self.scale,
            )

        # NOTE: if both prefill and decode metadata are present, the decode
        # result below replaces the prefill result computed above.
        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            if key_cache is None or value_cache is None:
                raise RuntimeError(
                    "Decoding requires a non-empty kv_cache.")
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                self.blocksparse_params.max_seqlen,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                layer._k_scale,
                layer._v_scale,
                tp_rank=self.tp_rank,
                blocksparse_local_blocks=self.local_blocks,
                blocksparse_vert_stride=self.vert_stride,
                blocksparse_block_size=self.sparse_block_size,
                blocksparse_head_sliding_step=self.head_sliding_step,
            )

        assert output is not None
        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/flash_attn.py
ADDED
|
@@ -0,0 +1,942 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Attention layer with FlashAttention."""
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from itertools import accumulate
|
| 6 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from vllm import _custom_ops as ops
|
| 11 |
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
| 12 |
+
AttentionLayer,
|
| 13 |
+
AttentionMetadata,
|
| 14 |
+
AttentionMetadataBuilder,
|
| 15 |
+
AttentionType)
|
| 16 |
+
from vllm.attention.backends.utils import (
|
| 17 |
+
PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
|
| 18 |
+
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
|
| 19 |
+
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
|
| 20 |
+
is_all_encoder_attn_metadata_set, is_block_tables_empty)
|
| 21 |
+
from vllm.envs import VLLM_FLASH_ATTN_VERSION
|
| 22 |
+
from vllm.logger import init_logger
|
| 23 |
+
from vllm.multimodal import MultiModalPlaceholderMap
|
| 24 |
+
from vllm.platforms import current_platform
|
| 25 |
+
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
| 26 |
+
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
|
| 27 |
+
flash_attn_varlen_func,
|
| 28 |
+
flash_attn_with_kvcache,
|
| 29 |
+
is_fa_version_supported)
|
| 30 |
+
|
| 31 |
+
if TYPE_CHECKING:
|
| 32 |
+
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
| 33 |
+
ModelInputForGPUWithSamplingMetadata)
|
| 34 |
+
|
| 35 |
+
logger = init_logger(__name__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class FlashAttentionBackend(AttentionBackend):
    """Descriptor for the FlashAttention backend.

    Exposes the implementation, metadata, builder and state classes as
    well as the KV-cache layout and block swap/copy helpers.
    """

    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the supported head sizes: multiples of 32 up to 256."""
        return list(range(32, 257, 32))

    @staticmethod
    def get_name() -> str:
        """Return the backend's identifier string."""
        return "FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        """Return the attention implementation class."""
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return FlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return FlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape for the given geometry.

        Returns:
            ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

        Raises:
            ValueError: If ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        shape = (2, num_blocks, block_size, num_kv_heads, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap blocks between two KV caches using ``src_to_dst``.

        Index 0 of each cache holds the keys and index 1 the values;
        the key cache is processed first.
        """
        for kv_index in (0, 1):
            ops.swap_blocks(src_kv_cache[kv_index], dst_kv_cache[kv_index],
                            src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks within every KV cache using ``src_to_dists``."""
        key_caches = []
        value_caches = []
        for cache in kv_caches:
            key_caches.append(cache[0])
            value_caches.append(cache[1])

        ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@dataclass
|
| 102 |
+
class FlashAttentionMetadata(AttentionMetadata):
|
| 103 |
+
"""Metadata for FlashAttentionBackend.
|
| 104 |
+
|
| 105 |
+
NOTE: Any python object stored here is not updated when it is
|
| 106 |
+
cuda-graph replayed. If you have values that need to be changed
|
| 107 |
+
dynamically, it should be stored in tensor. The tensor has to be
|
| 108 |
+
updated from `CUDAGraphRunner.forward` API.
|
| 109 |
+
"""
|
| 110 |
+
# (batch_size,). The sequence length per sequence. Sequence length means
|
| 111 |
+
# the computed tokens + new tokens. None if it is decoding.
|
| 112 |
+
seq_lens: Optional[List[int]]
|
| 113 |
+
# seq_lens stored as a tensor.
|
| 114 |
+
seq_lens_tensor: Optional[torch.Tensor]
|
| 115 |
+
|
| 116 |
+
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
| 117 |
+
# |---------- N-1 iteration --------|
|
| 118 |
+
# |---------------- N iteration ---------------------|
|
| 119 |
+
# |- tokenA -|......................|-- newTokens ---|
|
| 120 |
+
# |---------- context_len ----------|
|
| 121 |
+
# |-------------------- seq_len ---------------------|
|
| 122 |
+
# |-- query_len ---|
|
| 123 |
+
|
| 124 |
+
# Maximum sequence length among prefill batch. 0 if there are decoding
|
| 125 |
+
# requests only.
|
| 126 |
+
max_prefill_seq_len: int
|
| 127 |
+
# Maximum sequence length among decode batch. 0 if there are prefill
|
| 128 |
+
# requests only.
|
| 129 |
+
max_decode_seq_len: int
|
| 130 |
+
# (batch_size,) A tensor of context lengths (tokens that are computed
|
| 131 |
+
# so far).
|
| 132 |
+
context_lens_tensor: Optional[torch.Tensor]
|
| 133 |
+
|
| 134 |
+
# (batch_size, max_blocks_per_seq).
|
| 135 |
+
# Block addresses per sequence. (Seq id -> list of physical block)
|
| 136 |
+
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
| 137 |
+
# in the kv cache. Each block can contain up to block_size tokens.
|
| 138 |
+
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
| 139 |
+
# captured.
|
| 140 |
+
block_tables: Optional[torch.Tensor]
|
| 141 |
+
|
| 142 |
+
# Whether or not cuda graph is enabled.
|
| 143 |
+
# Cuda-graph is currently enabled for decoding only.
|
| 144 |
+
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
| 145 |
+
|
| 146 |
+
use_cuda_graph: bool
|
| 147 |
+
|
| 148 |
+
# Maximum query length in the batch.
|
| 149 |
+
max_query_len: Optional[int] = None
|
| 150 |
+
|
| 151 |
+
# Max number of query tokens among requests in the batch.
|
| 152 |
+
max_decode_query_len: Optional[int] = None
|
| 153 |
+
|
| 154 |
+
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
| 155 |
+
# the batch, used to index into subquery. E.g., if the subquery length
|
| 156 |
+
# is [4, 6], it is [0, 4, 10].
|
| 157 |
+
query_start_loc: Optional[torch.Tensor] = None
|
| 158 |
+
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
| 159 |
+
# the batch, used to index into sequence. E.g., if the sequence length is
|
| 160 |
+
# [4, 6], it is [0, 4, 10].
|
| 161 |
+
seq_start_loc: Optional[torch.Tensor] = None
|
| 162 |
+
|
| 163 |
+
_cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
|
| 164 |
+
_cached_decode_metadata: Optional["FlashAttentionMetadata"] = None
|
| 165 |
+
|
| 166 |
+
# Begin encoder attn & enc/dec cross-attn fields...
|
| 167 |
+
|
| 168 |
+
# Encoder sequence lengths representation
|
| 169 |
+
encoder_seq_lens: Optional[List[int]] = None
|
| 170 |
+
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
| 171 |
+
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
| 172 |
+
# the batch, used to index into sequence. E.g., if the sequence length is
|
| 173 |
+
# [4, 6], it is [0, 4, 10].
|
| 174 |
+
encoder_seq_start_loc: Optional[torch.Tensor] = None
|
| 175 |
+
# Maximum sequence length among encoder sequences
|
| 176 |
+
max_encoder_seq_len: Optional[int] = None
|
| 177 |
+
# Number of tokens input to encoder
|
| 178 |
+
num_encoder_tokens: Optional[int] = None
|
| 179 |
+
|
| 180 |
+
# Cross-attention memory-mapping data structures: slot mapping
|
| 181 |
+
# and block tables
|
| 182 |
+
cross_slot_mapping: Optional[torch.Tensor] = None
|
| 183 |
+
cross_block_tables: Optional[torch.Tensor] = None
|
| 184 |
+
|
| 185 |
+
@property
def is_all_encoder_attn_metadata_set(self):
    """
    Whether all attention metadata required for encoder attention is set.

    Delegates to the module-level helper of the same name.
    """
    encoder_metadata_ready = is_all_encoder_attn_metadata_set(self)
    return encoder_metadata_ready
|
| 191 |
+
|
| 192 |
+
@property
def is_all_cross_attn_metadata_set(self):
    """
    Whether all attention metadata required for enc/dec cross-attention
    is set.

    Superset of encoder attention required metadata. Delegates to the
    module-level helper of the same name.
    """
    cross_metadata_ready = is_all_cross_attn_metadata_set(self)
    return cross_metadata_ready
|
| 200 |
+
|
| 201 |
+
@property
def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
    """Return the metadata restricted to the prefill requests.

    Returns:
        ``None`` when ``num_prefills`` is zero; otherwise a
        ``FlashAttentionMetadata`` built from the leading slices of the
        per-sequence / per-token fields, with encoder and
        cross-attention fields passed through unchanged. The object is
        built once and memoised in ``_cached_prefill_metadata``.
    """
    n_prefills = self.num_prefills
    if n_prefills == 0:
        return None

    cached = self._cached_prefill_metadata
    if cached is not None:
        return cached

    # At least one of the decoder / encoder length representations
    # must be available, both as a list and as a tensor.
    assert not (self.seq_lens is None and self.encoder_seq_lens is None)
    assert not (self.seq_lens_tensor is None
                and self.encoder_seq_lens_tensor is None)

    def _leading(field, stop):
        # Fields that default to None stay None; others are truncated.
        return None if field is None else field[:stop]

    # Cumulative-offset tensors hold one more entry than sequences.
    n_offsets = n_prefills + 1
    n_prefill_tokens = self.num_prefill_tokens

    prefill_view = FlashAttentionMetadata(
        num_prefills=n_prefills,
        num_prefill_tokens=n_prefill_tokens,
        num_decode_tokens=0,
        slot_mapping=_leading(self.slot_mapping, n_prefill_tokens),
        multi_modal_placeholder_index_maps=(
            self.multi_modal_placeholder_index_maps),
        enable_kv_scales_calculation=self.enable_kv_scales_calculation,
        seq_lens=_leading(self.seq_lens, n_prefills),
        seq_lens_tensor=_leading(self.seq_lens_tensor, n_prefills),
        max_query_len=self.max_query_len,
        max_prefill_seq_len=self.max_prefill_seq_len,
        max_decode_query_len=0,
        max_decode_seq_len=0,
        query_start_loc=_leading(self.query_start_loc, n_offsets),
        seq_start_loc=_leading(self.seq_start_loc, n_offsets),
        context_lens_tensor=_leading(self.context_lens_tensor, n_prefills),
        block_tables=_leading(self.block_tables, n_prefills),
        use_cuda_graph=False,
        # Begin encoder & cross attn fields below...
        encoder_seq_lens=self.encoder_seq_lens,
        encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
        encoder_seq_start_loc=self.encoder_seq_start_loc,
        max_encoder_seq_len=self.max_encoder_seq_len,
        cross_slot_mapping=self.cross_slot_mapping,
        cross_block_tables=self.cross_block_tables,
    )
    self._cached_prefill_metadata = prefill_view
    return prefill_view
|
| 257 |
+
|
| 258 |
+
@property
def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
    """Return the metadata restricted to the decode requests.

    Returns:
        ``None`` when ``num_decode_tokens`` is zero; otherwise a
        ``FlashAttentionMetadata`` built from the trailing slices
        (everything after the prefill entries), with encoder and
        cross-attention fields passed through unchanged. The object is
        built once and memoised in ``_cached_decode_metadata``.
    """
    if self.num_decode_tokens == 0:
        return None

    cached = self._cached_decode_metadata
    if cached is not None:
        return cached

    # Either the decoder or the encoder length tensor must be available.
    assert not (self.seq_lens_tensor is None
                and self.encoder_seq_lens_tensor is None)

    first_decode_seq = self.num_prefills
    first_decode_token = self.num_prefill_tokens

    def _trailing(field, start):
        # Fields that default to None stay None; others drop the prefix.
        return None if field is None else field[start:]

    # Batch may be composed of prefill|decodes, adjust query start
    # indices to refer to the start of decodes. E.g.
    # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
    all_query_starts = self.query_start_loc
    if all_query_starts is None:
        decode_query_starts = None
    else:
        decode_query_starts = (all_query_starts[first_decode_seq:] -
                               all_query_starts[first_decode_seq])

    decode_view = FlashAttentionMetadata(
        num_prefills=0,
        num_prefill_tokens=0,
        num_decode_tokens=self.num_decode_tokens,
        slot_mapping=_trailing(self.slot_mapping, first_decode_token),
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=True,
        seq_lens=None,
        seq_lens_tensor=_trailing(self.seq_lens_tensor, first_decode_seq),
        max_decode_query_len=self.max_decode_query_len,
        max_query_len=self.max_query_len,
        max_prefill_seq_len=0,
        max_decode_seq_len=self.max_decode_seq_len,
        query_start_loc=decode_query_starts,
        # Unlike query_start_loc, these offsets are sliced only.
        seq_start_loc=_trailing(self.seq_start_loc, first_decode_seq),
        context_lens_tensor=None,
        block_tables=_trailing(self.block_tables, first_decode_seq),
        use_cuda_graph=self.use_cuda_graph,
        # Begin encoder & cross attn fields below...
        encoder_seq_lens=self.encoder_seq_lens,
        encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
        encoder_seq_start_loc=self.encoder_seq_start_loc,
        max_encoder_seq_len=self.max_encoder_seq_len,
        cross_slot_mapping=self.cross_slot_mapping,
        cross_block_tables=self.cross_block_tables,
    )
    self._cached_decode_metadata = decode_view
    return decode_view
|
| 308 |
+
|
| 309 |
+
def advance_step(self,
                 model_input: "ModelInputForGPUWithSamplingMetadata",
                 sampled_token_ids: Optional[torch.Tensor],
                 block_size: int,
                 num_seqs: int,
                 num_queries: int,
                 turn_prefills_into_decodes: bool = False):
    """
    Advance this metadata by one decode step, mutating it in place.

    Args:
        model_input: Supplies the ``input_tokens`` and ``input_positions``
            buffers handed to ``ops.advance_step_flashattn``.
        sampled_token_ids: Passed through to ``ops.advance_step_flashattn``.
        block_size: Passed through to ``ops.advance_step_flashattn``.
        num_seqs: Batch size, including cuda-graph padding if any.
        num_queries: Number of actual requests in the batch.
        turn_prefills_into_decodes: If True, fold the prefills of this
            batch into the decode counters before advancing.
    """
    # With cuda graphs, num_seqs is padded up to the next captured batch
    # size while num_queries counts the actual requests in the batch.
    # In --enforce-eager mode the two are equal.
    if num_seqs != num_queries:
        assert num_seqs > num_queries
        assert self.use_cuda_graph

    if not turn_prefills_into_decodes:
        assert self.seq_lens is not None
        assert self.max_decode_seq_len == max(self.seq_lens)
    else:
        # When Multi-Step is enabled together with Chunked-Prefill,
        # prefills and decodes are scheduled together. In the first step
        # every prefill turns into a decode; mirror that conversion here.
        assert self.num_decode_tokens + self.num_prefills == num_seqs
        self.num_decode_tokens += self.num_prefills
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.max_prefill_seq_len = 0
        self.max_query_len = 1

        self.slot_mapping = self.slot_mapping[:num_seqs]

    # From here on the batch must look like a pure decode batch.
    assert self.num_prefills == 0
    assert self.num_prefill_tokens == 0
    assert self.num_decode_tokens == num_seqs
    assert self.slot_mapping.shape == (num_seqs, )

    assert self.seq_lens is not None
    assert len(self.seq_lens) == num_seqs
    assert self.seq_lens_tensor is not None
    assert self.seq_lens_tensor.shape == (num_seqs, )
    assert self.max_query_len == 1
    assert self.max_prefill_seq_len == 0

    assert self.query_start_loc is not None
    assert self.query_start_loc.shape == (num_queries + 1, )
    assert self.seq_start_loc is not None
    assert self.seq_start_loc.shape == (num_seqs + 1, )

    assert self.context_lens_tensor is not None
    assert self.context_lens_tensor.shape == (num_queries, )

    assert self.block_tables is not None
    assert self.block_tables.shape[0] == num_seqs

    # Only the real queries grow by one token; entries beyond num_queries
    # may be padding added for the captured cuda graph batch size.
    for query_idx in range(num_queries):
        self.seq_lens[query_idx] += 1
    self.max_decode_seq_len = max(self.seq_lens)

    ops.advance_step_flashattn(num_seqs=num_seqs,
                               num_queries=num_queries,
                               block_size=block_size,
                               input_tokens=model_input.input_tokens,
                               sampled_token_ids=sampled_token_ids,
                               input_positions=model_input.input_positions,
                               seq_lens=self.seq_lens_tensor,
                               slot_mapping=self.slot_mapping,
                               block_tables=self.block_tables)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
class FlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlashAttentionMetadata]):
    """Accumulates per-sequence attention state and builds the metadata.

    ``prepare()`` resets the accumulators; ``build()`` walks
    ``input_builder.inter_data_list``, fills them via ``_add_seq_group()``
    and returns a ``FlashAttentionMetadata``.
    """

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.input_builder = input_builder
        self.runner = input_builder.runner
        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def prepare(self):
        """Reset every per-batch accumulator to its empty state."""
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0
        self.has_prefix_cache_hit = False

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool, prefix_cache_hit: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.

        Prompt groups also bump the prefill counters and merge their
        multi-modal placeholder maps; non-prompt groups bump the decode
        token count and record their current sequence length.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)

            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)

                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                # NOTE(woosuk): For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    # Keep only the trailing blocks covered by the window.
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def _get_graph_runner_block_tables(
            self, num_seqs: int,
            block_tables: List[List[int]]) -> torch.Tensor:
        """Copy ``block_tables`` into the runner's preallocated buffer.

        Returns the first ``num_seqs`` rows of
        ``self.runner.graph_block_tables`` as a tensor on the runner device.
        Empty per-sequence tables leave their row untouched.
        """
        # The shape of graph_block_tables is
        # [max batch size, max context len // block size].
        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
        assert max_batch_size >= num_seqs

        graph_block_tables = self.runner.graph_block_tables[:num_seqs]
        for i, block_table in enumerate(block_tables):
            if block_table:
                num_blocks = len(block_table)
                if num_blocks <= max_blocks:
                    graph_block_tables[i, :num_blocks] = block_table
                else:
                    # It may be possible to have more blocks allocated due
                    # to lookahead slots of multi-step, however, they are
                    # not used anyway, so can be safely ignored.
                    graph_block_tables[
                        i, :max_blocks] = block_table[:max_blocks]

        return torch.from_numpy(graph_block_tables).to(
            device=self.runner.device, non_blocking=True)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.

        Returns:
            A ``FlashAttentionMetadata`` describing the whole batch.
        """
        inter_data_list = self.input_builder.inter_data_list
        # A prefix-cache hit on any group switches the whole batch to the
        # "prefix cache hit" block-table layout in _add_seq_group.
        prefix_cache_hit = any(inter_data.prefix_cache_hit
                               for inter_data in inter_data_list)
        for inter_data in inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled,
                                prefix_cache_hit)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        decode_query_lens = query_lens[self.num_prefills:]
        if len(decode_query_lens) > 0:
            max_decode_query_len = max(decode_query_lens)
        else:
            max_decode_query_len = 1
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        num_seqs = len(seq_lens)
        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # One empty block table per padding slot. Note that
            # ``[] * n`` is always ``[]`` and would append nothing.
            self.block_tables.extend(
                [] for _ in range(cuda_graph_pad_size))
            num_decode_tokens = batch_size - self.num_prefill_tokens
            block_tables = self._get_graph_runner_block_tables(
                num_seqs, self.block_tables)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, f"query_lens: {query_lens}"

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                                  device,
                                                  self.runner.pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        return FlashAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=True,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
class FlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Validate the configuration and pick the FlashAttention version.

        Raises:
            ValueError: If ``blocksparse_params`` is given, or ``head_size``
                is not among the backend's supported head sizes.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "FlashAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # Stored as the (left, right) pair passed as ``window_size`` to the
        # kernels; (-1, -1) when no sliding window is configured.
        self.sliding_window = ((sliding_window - 1,
                                0) if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by FlashAttention. "
                f"Supported head sizes are: {support_head_sizes}.")
        self.attn_type = attn_type

        # if hopper default to FA3, otherwise stick to FA2 for now
        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
        # use FA3 as default for both
        if current_platform.get_device_capability()[0] >= 9:
            self.fa_version = 3 if is_fa_version_supported(3) else 2
        else:
            self.fa_version = 2

        # An explicitly configured version overrides the default above.
        if VLLM_FLASH_ATTN_VERSION is not None:
            assert VLLM_FLASH_ATTN_VERSION in [2, 3]
            self.fa_version = VLLM_FLASH_ATTN_VERSION

        if not is_fa_version_supported(self.fa_version):
            logger.error(
                "Cannot use FA version %d: it is not supported due to %s",
                self.fa_version,
                fa_version_unsupported_reason(self.fa_version))

        assert is_fa_version_supported(self.fa_version)

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            output: shape = [num_tokens, num_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
        NOTE: It in-place updates the output tensor.

        Returns:
            The ``output`` tensor that was passed in.

        Raises:
            AttributeError: If the attention type is ENCODER or
                ENCODER_DECODER and the matching metadata is not fully set.
        """
        # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0, (
            "key/v_scale is not supported in FlashAttention.")

        assert output is not None, "Output tensor must be provided."

        attn_type = self.attn_type
        if (attn_type == AttentionType.ENCODER
                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")
        elif (attn_type == AttentionType.ENCODER_DECODER
              and (not attn_metadata.is_all_cross_attn_metadata_set)):
            raise AttributeError("Encoder/decoder cross-attention "
                                 "requires setting cross-attention "
                                 "metadata attributes.")

        kv_cache_dtype: str = self.kv_cache_dtype
        softmax_scale: float = self.scale
        window_size = self.sliding_window
        alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes
        logits_soft_cap: Optional[float] = self.logits_soft_cap

        if kv_cache.numel() > 0:
            key_cache = kv_cache[0]
            value_cache = kv_cache[1]
            # We skip updating the KV cache under two conditions:
            #  a. When the Attention Type is ENCODER. In this phase, we compute
            #     only the encoder attention without updating the cache.
            #  b. When both Key and Value are None. This occurs during
            #     cross-attention computation in the decoding phase, where the
            #     KV cache is already populated with the cross-attention
            #     tensor. Thus, we skip cache updates during this time.
            if (attn_type != AttentionType.ENCODER) and (key is not None) and (
                    value is not None):
                if attn_type == AttentionType.ENCODER_DECODER:
                    # Update cross-attention KV cache (prefill-only)
                    updated_slot_mapping = attn_metadata.cross_slot_mapping
                else:
                    # Update self-attention KV cache (prefill/decode)
                    updated_slot_mapping = attn_metadata.slot_mapping

                # Reshape the input keys and values and store them in the
                # cache. If kv_cache is not provided, the new key and value
                # tensors are not cached. This happens during the initial
                # memory profiling run.
                torch.ops._C_cache_ops.reshape_and_cache_flash(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    updated_slot_mapping.flatten(),  # type: ignore[union-attr]
                    kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )

        (num_prefill_query_tokens, num_prefill_kv_tokens,
         num_decode_query_tokens) = \
            get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
        # Prefill tokens come first in the flattened batch, decodes after.
        decode_query = query[num_prefill_query_tokens:]
        decode_output = output[num_prefill_query_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_query_tokens]
        prefill_output = output[:num_prefill_query_tokens]
        assert query.shape[0] == num_prefill_query_tokens
        assert decode_query.shape[0] == num_decode_query_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
                    or prefill_meta.block_tables.numel() == 0):
                # normal attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
                    _get_query_key_seq_metadata(prefill_meta, True, attn_type)

                key = key[:num_prefill_kv_tokens]
                value = value[:num_prefill_kv_tokens]

                flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=q_seq_start_loc,
                    cu_seqlens_k=k_seq_start_loc,
                    max_seqlen_q=q_seq_len,
                    max_seqlen_k=k_seq_len,
                    softmax_scale=softmax_scale,
                    causal=_get_causal_option(attn_type),
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                    softcap=logits_soft_cap,
                    out=prefill_output,
                    fa_version=self.fa_version,
                )
            else:
                # prefix-enabled attention
                assert attn_type == AttentionType.DECODER, (
                    "Only decoder-only models support prefix caching")
                assert prefill_meta.seq_lens is not None
                max_seq_len = max(prefill_meta.seq_lens)
                flash_attn_varlen_func(  # noqa
                    q=query,
                    k=key_cache,
                    v=value_cache,
                    cu_seqlens_q=prefill_meta.query_start_loc,
                    max_seqlen_q=prefill_meta.max_query_len,
                    seqused_k=prefill_meta.seq_lens_tensor,
                    max_seqlen_k=max_seq_len,
                    softmax_scale=softmax_scale,
                    causal=True,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                    block_table=prefill_meta.block_tables,
                    softcap=logits_soft_cap,
                    out=prefill_output,
                    fa_version=self.fa_version,
                )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            # Use flash_attn_varlen_func kernel for speculative decoding
            # because different queries might have different lengths.

            assert decode_meta.max_decode_query_len is not None
            # use only for actual varlen decoding
            if decode_meta.max_decode_query_len > 1:
                assert attn_type == AttentionType.DECODER, (
                    "Only decoder-only models support max_decode_query_len > 1"
                )
                flash_attn_varlen_func(
                    q=decode_query,
                    k=key_cache,
                    v=value_cache,
                    cu_seqlens_q=decode_meta.query_start_loc,
                    max_seqlen_q=decode_meta.max_decode_query_len,
                    seqused_k=decode_meta.seq_lens_tensor,
                    max_seqlen_k=decode_meta.max_decode_seq_len,
                    softmax_scale=softmax_scale,
                    causal=True,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                    softcap=logits_soft_cap,
                    block_table=decode_meta.block_tables,
                    out=decode_output,
                    fa_version=self.fa_version,
                )
            else:
                # Use flash_attn_with_kvcache for normal decoding.
                (
                    seq_lens_arg,
                    _,
                    block_tables_arg,
                ) = get_seq_len_block_table_args(decode_meta, False, attn_type)
                flash_attn_with_kvcache(
                    q=decode_query.unsqueeze(1),
                    k_cache=key_cache,
                    v_cache=value_cache,
                    block_table=block_tables_arg,
                    cache_seqlens=seq_lens_arg,
                    softmax_scale=softmax_scale,
                    causal=True,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                    softcap=logits_soft_cap,
                    out=decode_output.unsqueeze(1),
                    fa_version=self.fa_version,
                )
        return output
|
| 860 |
+
|
| 861 |
+
|
| 862 |
+
def _get_query_key_seq_metadata(
    attn_metadata,
    is_prompt: bool,
    attn_type: str,
) -> tuple:
    """
    Pick the query-side and key-side sequence metadata for an attention type.

    Args:
        attn_metadata: The attention metadata object
        is_prompt (bool): A flag indicating if the input is a prompt
        attn_type (AttentionType): The type of attention being used.

    Returns:
        tuple: A 4-tuple read from ``attn_metadata``:
            - Start locations for the query sequences.
            - Maximum sequence length for the query sequences.
            - Start locations for the key sequences.
            - Maximum sequence length for the key sequences.

    Raises:
        AttributeError: If an invalid attention type is provided.
    """
    if attn_type == AttentionType.DECODER:
        # Decoder self-attention: query and key describe the same
        # sequences, so both halves of the result are identical.
        max_seq_len = (attn_metadata.max_prefill_seq_len
                       if is_prompt else attn_metadata.max_decode_seq_len)
        return (attn_metadata.seq_start_loc, max_seq_len,
                attn_metadata.seq_start_loc, max_seq_len)

    if attn_type == AttentionType.ENCODER_DECODER:
        # Cross attention: the query is the input sequence, the key is the
        # precomputed encoder attention. The query max length depends on
        # whether this is a prompt run.
        max_seq_len = (attn_metadata.max_prefill_seq_len
                       if is_prompt else attn_metadata.max_decode_seq_len)
        return (attn_metadata.seq_start_loc, max_seq_len,
                attn_metadata.encoder_seq_start_loc,
                attn_metadata.max_encoder_seq_len)

    if attn_type == AttentionType.ENCODER:
        # Encoder attention: query and key are both the encoder sequence.
        return (attn_metadata.encoder_seq_start_loc,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.encoder_seq_start_loc,
                attn_metadata.max_encoder_seq_len)

    if attn_type == AttentionType.ENCODER_ONLY:
        assert is_prompt, "Should not have decode for encoder only model."
        return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len,
                attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len)

    raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
| 925 |
+
|
| 926 |
+
|
| 927 |
+
def _get_causal_option(attn_type: str) -> bool:
    """
    Tell whether causal attention should be used for an attention type.

    Args:
        attn_type (AttentionType): The type of attention being evaluated

    Returns:
        bool: `False` when the attention type is encoder, encoder-only or
        encoder-decoder; `True` for every other value.
    """
    is_encoder = attn_type == AttentionType.ENCODER
    if is_encoder:
        return False
    is_encoder_only = attn_type == AttentionType.ENCODER_ONLY
    if is_encoder_only:
        return False
    is_cross = attn_type == AttentionType.ENCODER_DECODER
    return not is_cross
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/flashinfer.py
ADDED
|
@@ -0,0 +1,1066 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import dataclasses
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
from contextlib import contextmanager
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
|
| 8 |
+
|
| 9 |
+
from vllm.multimodal import MultiModalPlaceholderMap
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
|
| 13 |
+
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
|
| 14 |
+
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
|
| 15 |
+
|
| 16 |
+
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
| 17 |
+
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
| 18 |
+
except ImportError:
|
| 19 |
+
# Avoid turning these types into variables during type checking
|
| 20 |
+
if not TYPE_CHECKING:
|
| 21 |
+
BatchDecodeWithPagedKVCacheWrapper = None
|
| 22 |
+
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
|
| 23 |
+
BatchPrefillWithPagedKVCacheWrapper = None
|
| 24 |
+
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
|
| 28 |
+
import vllm.envs as envs
|
| 29 |
+
from vllm import _custom_ops as ops
|
| 30 |
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
| 31 |
+
AttentionLayer,
|
| 32 |
+
AttentionMetadata,
|
| 33 |
+
AttentionMetadataBuilder,
|
| 34 |
+
AttentionState, AttentionType)
|
| 35 |
+
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
| 36 |
+
compute_slot_mapping_start_idx,
|
| 37 |
+
is_block_tables_empty)
|
| 38 |
+
from vllm.attention.layer import Attention
|
| 39 |
+
from vllm.attention.ops.paged_attn import PagedAttention
|
| 40 |
+
from vllm.config import VllmConfig, get_current_vllm_config
|
| 41 |
+
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
|
| 42 |
+
make_tensor_with_pad)
|
| 43 |
+
|
| 44 |
+
if TYPE_CHECKING:
|
| 45 |
+
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
| 46 |
+
ModelInputForGPUWithSamplingMetadata)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class FlashInferBackend(AttentionBackend):
    """Static descriptor that wires the FlashInfer classes into vLLM.

    Every method is a ``@staticmethod``; the class only reports which
    implementation / metadata / builder / state classes belong to this
    backend and forwards KV-cache block operations to ``PagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the backend identifier string."""
        return "FLASHINFER"

    @staticmethod
    def get_impl_cls() -> Type["FlashInferImpl"]:
        """Return the attention implementation class of this backend."""
        return FlashInferImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the metadata class of this backend."""
        return FlashInferMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
        """Return the metadata builder class of this backend."""
        return FlashInferMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["FlashInferState"]:
        """Return the attention state class of this backend."""
        return FlashInferState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape.

        The result is ``(num_blocks, 2, block_size, num_kv_heads,
        head_size)``.
        """
        # NOTE(review): the constant 2 presumably separates the key and value
        # planes of each block -- confirm against the kernel's layout.
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Delegate block swapping to ``PagedAttention.swap_blocks``."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Delegate block copying to ``PagedAttention.copy_blocks``."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the head sizes accepted by this backend."""
        return [64, 128, 256]

    @staticmethod
    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
        """Map an FP8 KV-cache dtype name to the matching torch dtype.

        Raises:
            ValueError: if ``kv_cache_dtype`` is not one of ``"fp8"``,
                ``"fp8_e4m3"`` or ``"fp8_e5m2"``.
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        elif kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@dataclass
class PerLayerParameters:
    """
    Currently, the FlashInfer backend only supports models in which all
    layers share the same values for the following hyperparameters.
    """

    # Left window size of the attention window; -1 is used when the layer
    # has no sliding window.
    window_left: int
    # Soft-capping value for the attention logits, if any.
    logits_soft_cap: Optional[float]
    # Scale applied in softmax.
    sm_scale: float
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def get_per_layer_parameters(
        vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]:
    """
    Scan all attention layers and determine some hyperparameters
    to use during `plan`.

    Args:
        vllm_config: configuration whose
            ``compilation_config.static_forward_context`` maps layer names
            to attention layers.

    Returns:
        A dict keyed like ``static_forward_context`` holding one
        ``PerLayerParameters`` per layer.

    Raises:
        AssertionError: if a registered layer is not an ``Attention`` layer
            or its implementation is not a ``FlashInferImpl``.
    """

    layers = vllm_config.compilation_config.static_forward_context
    per_layer_params: Dict[str, PerLayerParameters] = {}

    for key, layer in layers.items():
        assert isinstance(layer, Attention)

        impl = layer.impl
        assert isinstance(impl, FlashInferImpl)

        # Infer hyperparameters from the attention layer
        window_size = impl.sliding_window
        # Only the first element of the window is used; -1 stands for
        # "no sliding window".
        window_left = window_size[0] if window_size is not None else -1
        logits_soft_cap = impl.logits_soft_cap
        sm_scale = impl.scale

        per_layer_params[key] = PerLayerParameters(window_left,
                                                   logits_soft_cap, sm_scale)

    return per_layer_params
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def infer_global_hyperparameters(
        per_layer_params: Dict[str, "PerLayerParameters"]
) -> "PerLayerParameters":
    """
    Currently, the FlashInfer backend only supports models in which all
    layers share the same values for the following hyperparameters:
    - `window_left`
    - `logits_soft_cap`
    - `sm_scale`

    So this function asserts that all layers share the same values for these
    hyperparameters and returns the global values.

    Args:
        per_layer_params: mapping from layer name to that layer's parameters.

    Returns:
        The parameters of the first layer, which every other layer equals.

    Raises:
        AssertionError: if the mapping is empty or two layers differ.
    """

    assert len(per_layer_params) > 0, "No attention layers found in the model."

    param_sets = list(per_layer_params.values())
    # The first entry is the reference every layer is compared against.
    global_params = param_sets[0]
    for params in param_sets:
        assert params == global_params, (
            "FlashInfer backend currently only supports models in which all "
            "layers share the same values for the following hyperparameters: "
            "`window_left`, `logits_soft_cap`, `sm_scale`.")

    return global_params
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class FlashInferState(AttentionState):
    """Per-runner holder of the FlashInfer workspace buffer and wrappers.

    The workspace buffer and the prefill/decode wrappers are created lazily
    on first use. During CUDA graph capture (see ``graph_capture``) a set of
    ``_graph_*`` buffers is allocated and handed to the graph-specific decode
    wrapper.
    """

    def __init__(self, runner):
        self.runner = runner
        self._is_graph_capturing = False
        self._workspace_buffer = None
        self._decode_wrapper = None
        self._prefill_wrapper = None

        # Global hyperparameters shared by all attention layers
        self.global_hyperparameters: Optional[PerLayerParameters] = None

        self.vllm_config = get_current_vllm_config()

    def _get_workspace_buffer(self):
        """Return the workspace buffer, allocating it on first call."""
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.runner.device)
        return self._workspace_buffer

    def _get_prefill_wrapper(self):
        """Return the prefill wrapper, creating it on first call."""
        if self._prefill_wrapper is None:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self._get_workspace_buffer(), "NHD")
        return self._prefill_wrapper

    def _get_decode_wrapper(self):
        """Return the decode wrapper, creating it on first call."""
        if self._decode_wrapper is None:
            num_qo_heads = (self.runner.model_config.get_num_attention_heads(
                self.runner.parallel_config))
            num_kv_heads = self.runner.model_config.get_num_kv_heads(
                self.runner.parallel_config)
            # Tensor cores are used when forced through the environment or
            # when more than 4 query heads share each KV head.
            use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
                num_qo_heads // num_kv_heads > 4)
            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                "NHD",
                use_tensor_cores=use_tensor_cores)
        return self._decode_wrapper

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager that owns the buffers used while capturing graphs.

        The ``_graph_*`` buffers are allocated for ``max_batch_size`` on
        entry and removed again on exit, whether or not the body raised.
        """
        self._is_graph_capturing = True
        self._graph_decode_wrapper = None
        self._graph_slot_mapping = torch.full((max_batch_size, ),
                                              PAD_SLOT_ID,
                                              dtype=torch.long,
                                              device=self.runner.device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=self.runner.device)
        self._graph_decode_workspace_buffer = self._get_workspace_buffer()
        self._graph_indices_buffer = torch.empty(
            max_batch_size * self.runner.cache_config.num_gpu_blocks,
            dtype=torch.int32,
            device=self.runner.device)
        self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
                                                dtype=torch.int32,
                                                device=self.runner.device)
        self._graph_last_page_len_buffer = torch.empty(
            max_batch_size, dtype=torch.int32, device=self.runner.device)
        try:
            yield
        finally:
            # Run the cleanup even when capture fails; otherwise the state
            # stays flagged as "capturing" and keeps the buffers alive.
            self._is_graph_capturing = False
            del self._graph_slot_mapping
            del self._graph_seq_lens
            del self._graph_block_tables
            del self._graph_decode_workspace_buffer
            del self._graph_indices_buffer
            del self._graph_indptr_buffer
            del self._graph_last_page_len_buffer
            del self._graph_decode_wrapper

    def graph_clone(self, batch_size: int):
        """Return a new state sharing this state's capture-time objects.

        The clone reuses the capture workspace buffer, the current graph
        decode wrapper and this state's prefill wrapper. ``batch_size`` is
        not used here. Must be called inside ``graph_capture``.
        """
        assert self._is_graph_capturing
        state = self.__class__(self.runner)
        state._workspace_buffer = self._graph_decode_workspace_buffer
        state._decode_wrapper = self._graph_decode_wrapper
        state._prefill_wrapper = self._get_prefill_wrapper()
        return state

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        """Build decode-only metadata for capturing a graph of ``batch_size``.

        A new ``CUDAGraphBatchDecodeWithPagedKVCacheWrapper`` is created over
        slices of the capture buffers, placeholder paged-KV tensors are
        generated on the host, and ``begin_forward`` is called on the
        resulting metadata before it is returned.
        ``is_encoder_decoder_model`` is not used here.
        """
        assert self._is_graph_capturing
        _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
        _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]

        num_qo_heads = (self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config))
        num_kv_heads = self.runner.model_config.get_num_kv_heads(
            self.runner.parallel_config)
        use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
            num_qo_heads // num_kv_heads > 4)
        self._graph_decode_wrapper = \
            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
            self._graph_decode_workspace_buffer, _indptr_buffer,
            self._graph_indices_buffer, _last_page_len_buffer, "NHD",
            use_tensor_cores)
        if self.runner.kv_cache_dtype.startswith("fp8"):
            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.runner.kv_cache_dtype)
        else:
            kv_cache_dtype = get_kv_cache_torch_dtype(
                self.runner.kv_cache_dtype, self.runner.model_config.dtype)

        # Placeholder layout: one page per request, each page full.
        paged_kv_indptr_tensor_host = torch.arange(0,
                                                   batch_size + 1,
                                                   dtype=torch.int32)
        paged_kv_indices_tensor_host = torch.arange(0,
                                                    batch_size,
                                                    dtype=torch.int32)
        paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
                                                        self.runner.block_size,
                                                        dtype=torch.int32)
        query_start_loc_host = torch.arange(0,
                                            batch_size + 1,
                                            dtype=torch.int32)

        global_params = infer_global_hyperparameters(
            get_per_layer_parameters(self.vllm_config))

        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            max_prefill_seq_len=0,
            block_tables=self._graph_block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor_host,
            paged_kv_indices=paged_kv_indices_tensor_host,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim=self.runner.model_config.get_head_size(),
            page_size=self.runner.block_size,
            seq_start_loc=None,
            query_start_loc=query_start_loc_host,
            device=self.runner.device,
            data_type=kv_cache_dtype,
            q_data_type=self.runner.model_config.dtype,
            use_cuda_graph=True,
            decode_wrapper=self._graph_decode_wrapper,
            prefill_wrapper=None,
            **dataclasses.asdict(global_params),
        )
        attn_metadata.begin_forward()
        return attn_metadata

    def get_graph_input_buffers(self,
                                attn_metadata,
                                is_encoder_decoder_model: bool = False):
        """Return the input buffers a captured graph reads: the slot mapping."""
        return {
            "slot_mapping": attn_metadata.slot_mapping,
        }

    def prepare_graph_input_buffers(self,
                                    input_buffers,
                                    attn_metadata,
                                    is_encoder_decoder_model: bool = False):
        """No-op: this backend does nothing before a graph replay."""
        return

    def begin_forward(self, model_input):
        """Attach the wrappers to ``model_input.attn_metadata`` and plan it.

        Must not be called while graph capture is in progress.
        """
        assert not self._is_graph_capturing
        state = self
        use_cuda_graph = model_input.attn_metadata.use_cuda_graph
        is_decode = model_input.attn_metadata.num_prefills == 0
        # In case of multistep chunked-prefill, there might be prefill requests
        # scheduled while CUDA graph mode is enabled. We don't run graph in that
        # case.
        if use_cuda_graph and is_decode:
            # Use the state stored with the graph runner for this batch size.
            batch_size = model_input.input_tokens.shape[0]
            state = (self.runner.graph_runners[model_input.virtual_engine]
                     [batch_size].attn_state)
        model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
        )
        model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
        model_input.attn_metadata.begin_forward()
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
@dataclass
class FlashInferMetadata(AttentionMetadata):
    """Attention metadata consumed by the FlashInfer prefill/decode wrappers.

    ``begin_forward`` moves the paged-KV tensors to ``device`` and calls
    ``plan`` on the wrappers; ``advance_step`` updates the metadata in place
    for the next decode step.
    """

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Number of query tokens for each request in the batch.
    # Currently, we require that all requests have the same number of query
    # tokens during the decoding phase. When speculative decoding is enabled,
    # decode_query_len might be greater than 1. In all other cases, it is 1.
    decode_query_len: Optional[int] = 1

    use_cuda_graph: bool = True

    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

    # Metadata for the prefill stage
    seq_start_loc: Optional[torch.Tensor] = None
    query_start_loc: Optional[torch.Tensor] = None
    block_tables: Optional[torch.Tensor] = None

    # used for GPU in-place advance_step
    seq_lens_tensor: Optional[torch.Tensor] = None
    block_table_bound: Optional[torch.Tensor] = None

    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
    # request 3, page indices [3, 4]
    # paged_kv_indices is a concatenation of page indices of all requests:
    # [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr is used to index into paged_kv_indices:
    # [0, 3, 6, 8]
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: Optional[torch.Tensor] = None
    # The number of query/output heads
    num_qo_heads: Optional[int] = None
    # The number of key/value heads
    num_kv_heads: Optional[int] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    # Block size of vllm
    page_size: Optional[int] = None
    # The data type of the paged kv cache
    data_type: torch.dtype = None
    # The data type of the query
    q_data_type: torch.dtype = None
    # FlashInfer 0.2 encourages passing host tensors
    device: torch.device = torch.device("cpu")
    is_profile_run: bool = False

    # The FlashInfer backend currently supports only models in which all layers
    # share the same following hyperparameters:

    # The left (inclusive) window size for the attention window, when
    # set to `-1`, the window size will be set to the full length of
    # the sequence. Defaults to `-1`.
    window_left: int = -1
    # The attention logits soft capping value (used in Gemini, Grok and
    # Gemma-2, etc.), if not provided, will be set to `0`. If greater
    # than 0, the logits will be capped according to formula:
    # $$\texttt{logits\_soft\_cap} \times
    # \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$,
    # where $x$ is the input logits.
    logits_soft_cap: Optional[float] = None
    # The scale used in softmax, if not provided, will be set to
    # `1.0 / sqrt(head_dim)`.
    sm_scale: Optional[float] = None

    def __post_init__(self):
        """Reject head sizes that the backend does not support.

        Raises:
            ValueError: if ``head_dim`` is set and is not one of
                ``FlashInferBackend.get_supported_head_sizes()``.
        """
        # Refer to
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            # A single concatenated message: a comma between the two parts
            # would hand ValueError two arguments and render as a tuple.
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")

    def begin_forward(self):
        """Move the paged-KV tensors to ``device`` and plan the wrappers.

        The prefill wrapper is planned when there are prefill tokens (and
        this is not a profile run); the decode wrapper is planned when there
        are decode tokens.
        """
        if self.num_prefill_tokens > 0:
            if self.paged_kv_indices is None:
                return

            assert self.prefill_wrapper is not None
            assert self.query_start_loc is not None
            assert self.paged_kv_indices is not None
            assert self.paged_kv_indptr is not None
            assert self.paged_kv_last_page_len is not None
            assert self.block_table_bound is not None
            assert self.seq_lens_tensor is not None
            self.query_start_loc = self.query_start_loc[:self.num_prefills + 1]
            batch_size = self.query_start_loc.shape[0] - 1
            assert batch_size >= 0
            # We will use flash attention for profiling to
            # determine the number of blocks. Therefore,
            # we don't need to prepare the input for flashinfer for profile run.
            if not self.is_profile_run:
                self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
                self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                    self.device)
                self.block_table_bound = self.block_table_bound.to(self.device)
                self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
                self.paged_kv_indices = self.paged_kv_indices.to(self.device)
                # Prefill entries occupy the leading num_prefills slots.
                self.prefill_wrapper.plan(
                    self.query_start_loc,
                    self.paged_kv_indptr[:self.num_prefills + 1],
                    self.paged_kv_indices,
                    self.paged_kv_last_page_len[:self.num_prefills],
                    self.num_qo_heads,
                    self.num_kv_heads,
                    self.head_dim,
                    self.page_size,
                    causal=True,
                    sm_scale=self.sm_scale,
                    window_left=self.window_left,
                    logits_soft_cap=self.logits_soft_cap,
                    q_data_type=self.q_data_type,
                    kv_data_type=self.data_type)
        if self.num_decode_tokens > 0:
            assert self.paged_kv_indices is not None
            assert self.paged_kv_indptr is not None
            assert self.paged_kv_last_page_len is not None
            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
            self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
            self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                self.device)
            # handle model warmup path
            if self.block_table_bound is not None:
                self.block_table_bound = self.block_table_bound.to(self.device)
            if self.seq_lens_tensor is not None:
                self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)

            assert self.decode_wrapper is not None
            # Decode entries follow the prefill entries.
            self.decode_wrapper.plan(
                self.paged_kv_indptr[self.num_prefills:],
                self.paged_kv_indices,
                self.paged_kv_last_page_len[self.num_prefills:],
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                # Disable flashinfer's pos encoding and use vllm's rope.
                pos_encoding_mode="NONE",
                window_left=self.window_left,
                logits_soft_cap=self.logits_soft_cap,
                sm_scale=self.sm_scale,
                # kv-cache data type.
                kv_data_type=self.data_type,
                # query data type.
                q_data_type=self.q_data_type)

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Call the parent ``asdict_zerocopy`` with both wrappers skipped.

        Note that a caller-supplied ``skip_fields`` set is extended in place.
        """
        if skip_fields is None:
            skip_fields = set()
        # We need to skip the prefill/decode_wrapper field since it cannot be
        # broadcasted with nccl when TP is enabled.
        skip_fields.add('prefill_wrapper')
        skip_fields.add('decode_wrapper')
        return super().asdict_zerocopy(skip_fields)

    @property
    def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
        """Return ``self`` when the batch has prefills, otherwise ``None``."""
        if self.num_prefills == 0:
            return None
        return self

    @property
    def decode_metadata(self) -> Optional["FlashInferMetadata"]:
        """Return ``self`` when the batch has decode tokens, else ``None``."""
        if self.num_decode_tokens == 0:
            return None
        return self

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.
        """

        if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
            assert self.num_decode_tokens + self.num_prefills == num_seqs
            # Flashinfer doesn't support speculative decoding + chunked-prefill
            # + multi-step scheduling yet.
            assert self.decode_query_len == 1
            self.num_decode_tokens += self.num_prefills
            self.num_prefills = 0
            self.num_prefill_tokens = 0
            self.max_prefill_seq_len = 0
            self.max_query_len = 1

            self.slot_mapping = self.slot_mapping[:num_seqs]
        else:
            assert self.seq_lens_tensor is not None

        assert num_seqs > 0
        assert num_queries > 0
        assert model_input.attn_metadata is not None
        assert sampled_token_ids is not None

        # When using cudagraph, the num_seqs is padded to the next captured
        # batch size, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        model_input.input_tokens[:num_queries] = sampled_token_ids.flatten()

        # Update GPU tensors
        ops.advance_step_flashinfer(
            num_seqs=num_seqs,
            num_queries=num_queries,
            block_size=block_size,
            input_tokens=model_input.input_tokens,
            # The sampled ids were copied into input_tokens just above, so
            # the same tensor is passed for both arguments.
            sampled_token_ids=model_input.input_tokens,
            input_positions=model_input.input_positions,
            seq_lens=self.seq_lens_tensor,
            slot_mapping=self.slot_mapping,
            block_tables=self.block_tables,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_last_page_len=self.paged_kv_last_page_len,
            block_table_bound=self.block_table_bound)
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
| 601 |
+
|
| 602 |
+
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
| 603 |
+
|
| 604 |
+
self.input_builder = input_builder
|
| 605 |
+
self.runner = input_builder.runner
|
| 606 |
+
|
| 607 |
+
self.sliding_window = input_builder.sliding_window
|
| 608 |
+
self.block_size = input_builder.block_size
|
| 609 |
+
|
| 610 |
+
# Global hyperparameters shared by all attention layers
|
| 611 |
+
self.global_hyperparameters: Optional[PerLayerParameters] = None
|
| 612 |
+
|
| 613 |
+
self.vllm_config = get_current_vllm_config()
|
| 614 |
+
|
| 615 |
+
def prepare(self):
|
| 616 |
+
self.slot_mapping: List[int] = []
|
| 617 |
+
self.prefill_seq_lens: List[int] = []
|
| 618 |
+
self.context_lens: List[int] = []
|
| 619 |
+
self.block_tables: List[List[int]] = []
|
| 620 |
+
self.curr_seq_lens: List[int] = []
|
| 621 |
+
self.multimodal_placeholder_maps: Dict[
|
| 622 |
+
str,
|
| 623 |
+
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
|
| 624 |
+
self.num_prefills = 0
|
| 625 |
+
self.num_prefill_tokens = 0
|
| 626 |
+
self.num_decode_tokens = 0
|
| 627 |
+
|
| 628 |
+
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
|
| 629 |
+
# for the precise definition of the following fields.
|
| 630 |
+
# An example:
|
| 631 |
+
# request 1, page indices [0, 5, 8]
|
| 632 |
+
# request 2, page indices [1, 6, 7]
|
| 633 |
+
# request 3, page indices [3, 4]
|
| 634 |
+
# paged_kv_indices is a concatenation of page indices of all requests:
|
| 635 |
+
# [0, 5, 8, 1, 6, 7, 3, 4]
|
| 636 |
+
# paged_kv_indptr is used to index into paged_kv_indices:
|
| 637 |
+
# [0, 3, 6, 8]
|
| 638 |
+
self.paged_kv_indices: List[int] = []
|
| 639 |
+
# 0 at the beginning of paged_kv_indptr indicates the start of the
|
| 640 |
+
# first request’s page indices in the paged_kv_indices list.
|
| 641 |
+
self.paged_kv_indptr: List[int] = [0]
|
| 642 |
+
# paged_kv_last_page_len is the length of the last page of each request
|
| 643 |
+
self.paged_kv_last_page_len: List[int] = []
|
| 644 |
+
self.total_blocks = 0
|
| 645 |
+
self.is_profile_run: bool = False
|
| 646 |
+
|
| 647 |
+
if self.global_hyperparameters is None:
|
| 648 |
+
# Infer global hyperparameters, since currently we only support
|
| 649 |
+
# models in which all layers share the same values for the
|
| 650 |
+
# following hyperparameters:
|
| 651 |
+
# - `window_left`
|
| 652 |
+
# - `logits_soft_cap`
|
| 653 |
+
# - `sm_scale`
|
| 654 |
+
inferred_params = infer_global_hyperparameters(
|
| 655 |
+
get_per_layer_parameters(self.vllm_config))
|
| 656 |
+
self.global_hyperparameters = inferred_params
|
| 657 |
+
self.window_left = inferred_params.window_left
|
| 658 |
+
self.logits_soft_cap = inferred_params.logits_soft_cap
|
| 659 |
+
self.sm_scale = inferred_params.sm_scale
|
| 660 |
+
|
| 661 |
+
def _add_seq_group(
        self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
        chunked_prefill_enabled: bool):
    """Fold one sequence group into the builder's running metadata.

    For every sequence in ``inter_data`` this records:
    1. the context length and the prefill/decode counters,
    2. the block table,
    3. the slot mapping,
    and, unless the block tables are empty (profile run), the paged KV
    bookkeeping via ``_update_paged_kv_tensors``.
    """
    is_prompt = inter_data.is_prompt
    block_tables = inter_data.block_tables
    computed_block_nums = inter_data.computed_block_nums
    token_lens = [len(tokens) for tokens in inter_data.input_tokens]
    # Decode sequences, and prefills when chunking is on, read their own
    # block table instead of starting from an empty one.
    wants_seq_block_table = chunked_prefill_enabled or not is_prompt

    per_sequence = zip(inter_data.seq_ids, token_lens,
                       inter_data.orig_seq_lens, inter_data.seq_lens,
                       inter_data.query_lens, inter_data.context_lens,
                       inter_data.curr_sliding_window_blocks)
    for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
         sliding_window_blocks) in per_sequence:
        self.context_lens.append(context_len)

        if not is_prompt:
            assert query_len == 1, (
                "seq_len: {}, context_len: {}, query_len: {}".format(
                    seq_len, context_len, query_len))
            self.num_decode_tokens += query_len
            self.curr_seq_lens.append(curr_seq_len)
        else:
            mm_maps = inter_data.multi_modal_placeholder_maps
            if mm_maps:
                for modality, placeholders in mm_maps.items():
                    self.multimodal_placeholder_maps[modality].extend(
                        placeholders)
            self.num_prefills += 1
            self.num_prefill_tokens += token_len
            self.prefill_seq_lens.append(seq_len)

        # Compute block table.
        # TODO(sang): Combine chunked prefill and prefix caching by
        # only allowing multiple of block_size chunk size.
        # NOTE: This only works for oooooooxxx style attention.
        if inter_data.prefix_cache_hit:
            selected_table = computed_block_nums
        elif wants_seq_block_table and block_tables is not None:
            selected_table = block_tables[seq_id][-sliding_window_blocks:]
        else:
            selected_table = []
        self.block_tables.append(selected_table)

        is_profile_run = is_block_tables_empty(block_tables)

        # Compute slot mapping.
        start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                   context_len,
                                                   self.sliding_window)
        compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                             seq_len, context_len, start_idx,
                             self.block_size, inter_data.block_tables)

        # paged_kv_indices, paged_kv_indptr and paged_kv_last_page_len are
        # not needed for a profile run because dummy inputs are created
        # for it, so stop here.
        if is_profile_run:
            self.is_profile_run = is_profile_run
            return

        # The paged KV lists use the sequence's full block table, not the
        # sliding-window slice chosen above.
        self._update_paged_kv_tensors(block_tables[seq_id], seq_len)
|
| 727 |
+
|
| 728 |
+
def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
|
| 729 |
+
# Get the number of valid blocks based on sequence length.
|
| 730 |
+
# If seq_len = 16, block_size = 16,
|
| 731 |
+
# block_table_bound is 1 with 1 valid block.
|
| 732 |
+
# If seq_len = 15, block_size = 16,
|
| 733 |
+
# block_table_bound is 0 + 1 with 1 valid block.
|
| 734 |
+
self.total_blocks += len(block_table)
|
| 735 |
+
block_table_bound = seq_len // self.block_size + 1 \
|
| 736 |
+
if seq_len % self.block_size != 0 \
|
| 737 |
+
else seq_len // self.block_size
|
| 738 |
+
self.paged_kv_indices.extend(block_table[:block_table_bound])
|
| 739 |
+
self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
|
| 740 |
+
block_table_bound)
|
| 741 |
+
|
| 742 |
+
last_page_len = seq_len % self.block_size
|
| 743 |
+
if last_page_len == 0:
|
| 744 |
+
last_page_len = self.block_size
|
| 745 |
+
self.paged_kv_last_page_len.append(last_page_len)
|
| 746 |
+
|
| 747 |
+
def build(self, seq_lens: List[int], query_lens: List[int],
          cuda_graph_pad_size: int, batch_size: int):
    """Build attention metadata with on-device tensors.

    Every entry of ``self.input_builder.inter_data_list`` is first folded
    into the builder via ``_add_seq_group``; the accumulated Python lists
    are then converted to tensors and packed into a ``FlashInferMetadata``.

    Args:
        seq_lens: The maybe padded sequence lengths of the input sequences.
        query_lens: The query lengths of the input sequences.
        cuda_graph_pad_size: The padding size for cuda graph.
                             -1 if cuda graph is not used.
        batch_size: The maybe padded batch size.

    Returns:
        The populated ``FlashInferMetadata``.
    """
    for inter_data in self.input_builder.inter_data_list:
        self._add_seq_group(inter_data,
                            self.input_builder.chunked_prefill_enabled)

    device = self.runner.device
    use_captured_graph = cuda_graph_pad_size != -1

    max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
    num_decode_tokens = self.num_decode_tokens
    # Query lengths after the first num_prefills entries belong to decodes.
    decode_query_len = max(query_lens[self.num_prefills:], default=1)

    if use_captured_graph:
        self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
        # Pad with one empty block table per padded slot. The previous
        # `[] * cuda_graph_pad_size` always evaluated to `[]`, so nothing
        # was appended. Empty tables are skipped by the copy loop below.
        self.block_tables.extend([] for _ in range(cuda_graph_pad_size))
        num_decode_tokens = batch_size - self.num_prefill_tokens

        # The shape of graph_block_tables is
        # [max batch size, max context len // block size].
        input_block_tables = self.runner.graph_block_tables[:batch_size]
        max_blocks = input_block_tables.shape[1]
        for i, block_table in enumerate(self.block_tables):
            if block_table:
                num_blocks = len(block_table)
                if num_blocks <= max_blocks:
                    input_block_tables[i, :num_blocks] = block_table
                else:
                    # It may be possible to have more blocks allocated due
                    # to lookahead slots of multi-step, however, they are
                    # not used anyway, so can be safely ignored.
                    input_block_tables[
                        i, :max_blocks] = block_table[:max_blocks]

        block_tables = torch.from_numpy(input_block_tables).to(
            device, non_blocking=True)

        # Padded slots own no pages: repeat the last indptr value and use
        # a last-page length of 0 for each of them.
        last_paged_kv_indptr = self.paged_kv_indptr[-1]
        self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                    cuda_graph_pad_size)
        self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
    else:
        block_tables = make_tensor_with_pad(
            self.block_tables,
            pad=0,
            dtype=torch.int,
            device=device,
        )

    assert device is not None
    seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                       self.runner.pin_memory)
    query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
                                         self.runner.pin_memory)
    slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                           device, self.runner.pin_memory)
    # Cumulative start offsets; element 0 stays 0 and the cumsums below
    # fill elements 1..N in place.
    query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                  dtype=torch.int32,
                                  device=device)
    seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                dtype=torch.int32,
                                device=device)
    placeholder_index_maps = {
        modality: placeholder_map.index_map()
        for modality, placeholder_map in
        self.multimodal_placeholder_maps.items()
    }
    torch.cumsum(seq_lens_tensor,
                 dim=0,
                 dtype=seq_start_loc.dtype,
                 out=seq_start_loc[1:])
    torch.cumsum(query_lens_tensor,
                 dim=0,
                 dtype=query_start_loc.dtype,
                 out=query_start_loc[1:])

    if len(self.paged_kv_indptr) > 0:
        # extend to the maximum number of blocks as returned by the
        # scheduler
        self.paged_kv_indices.extend(
            [0] * (self.total_blocks - len(self.paged_kv_indices)))
        # The paged KV tensors are created on the CPU here.
        paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                               device="cpu",
                                               dtype=torch.int)
        paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
                                              device="cpu",
                                              dtype=torch.int)
        paged_kv_last_page_len_tensor = torch.tensor(
            self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
        # One zero-initialised entry per indptr interval.
        block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
                                               1,
                                               device="cpu",
                                               dtype=torch.int)
    else:
        paged_kv_indices_tensor = None
        paged_kv_indptr_tensor = None
        paged_kv_last_page_len_tensor = None
        block_table_bound_tensor = None

    if self.runner.kv_cache_dtype.startswith("fp8"):
        kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
            self.runner.kv_cache_dtype)
    else:
        kv_cache_dtype = get_kv_cache_torch_dtype(
            self.runner.kv_cache_dtype, self.runner.model_config.dtype)

    return FlashInferMetadata(
        decode_query_len=decode_query_len,
        num_prefills=self.num_prefills,
        slot_mapping=slot_mapping_tensor,
        multi_modal_placeholder_index_maps=placeholder_index_maps,
        enable_kv_scales_calculation=False,
        num_prefill_tokens=self.num_prefill_tokens,
        num_decode_tokens=num_decode_tokens,
        max_prefill_seq_len=max_prefill_seq_len,
        block_tables=block_tables,
        paged_kv_indptr=paged_kv_indptr_tensor,
        paged_kv_indices=paged_kv_indices_tensor,
        paged_kv_last_page_len=paged_kv_last_page_len_tensor,
        block_table_bound=block_table_bound_tensor,
        seq_lens_tensor=seq_lens_tensor,
        num_qo_heads=self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config),
        num_kv_heads=self.runner.model_config.get_num_kv_heads(
            self.runner.parallel_config),
        head_dim=self.runner.model_config.get_head_size(),
        page_size=self.block_size,
        seq_start_loc=seq_start_loc,
        query_start_loc=query_start_loc,
        device=device,
        data_type=kv_cache_dtype,
        q_data_type=self.runner.model_config.dtype,
        use_cuda_graph=use_captured_graph,
        is_profile_run=self.is_profile_run,
        window_left=self.window_left,
        logits_soft_cap=self.logits_soft_cap,
        sm_scale=self.sm_scale,
    )
|
| 894 |
+
|
| 895 |
+
|
| 896 |
+
class FlashInferImpl(AttentionImpl):
    """Decoder attention that runs through the FlashInfer wrappers carried
    by ``FlashInferMetadata``."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Store the attention hyperparameters.

        Raises:
            NotImplementedError: if ``attn_type`` is not
                ``AttentionType.DECODER``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
            alibi_slopes, dtype=torch.float32))
        # Stored as a (left, right) pair; (-1, -1) when no window is set.
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Write key/value into the cache and run prefill/decode attention.

        ``query`` must be 2-D (it is unpacked as
        ``num_tokens, hidden_size``). The result is returned as a 2-D view
        with ``hidden_size`` columns.
        """
        # TODO: directly write to output tensor
        cache_dtype: str = self.kv_cache_dtype
        sm_scale: float = self.scale
        window = self.sliding_window
        soft_cap = self.logits_soft_cap

        num_tokens, hidden_size = query.shape
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache.numel() > 0:
            # Use the same reshape and cache kernel as flash attention.
            ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping.flatten(),
                cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )
            # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
            # to process the cache when the kv_cache_dtype is fp8
            if cache_dtype.startswith("fp8"):
                fp8_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                    cache_dtype)
                kv_cache = kv_cache.view(fp8_dtype)

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
                    f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
                    f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa

        # Flashinfer requires query to be contiguous
        query = query.contiguous()
        # Prefill tokens come first; decode tokens follow. Decode only
        # needs the query because its KV is already cached.
        decode_query = query[num_prefill_tokens:]
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        window_left = window[0] if window is not None else -1
        expected_soft_cap = soft_cap or 0.0

        prefill_output: Optional[torch.Tensor] = None
        decode_output: Optional[torch.Tensor] = None

        prefill_meta = attn_metadata.prefill_metadata
        if prefill_meta:
            if kv_cache.numel() == 0:
                # No kv_cache is provided while vllm profiles to determine
                # the number of blocks; flash attention is used for
                # prefill in that case.
                prefill_output = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
                    softmax_scale=sm_scale,
                    causal=True,
                    window_size=window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                assert prefill_meta is not None
                prefill_wrapper = prefill_meta.prefill_wrapper
                assert prefill_wrapper is not None

                # The wrapper must have been set up with this layer's
                # hyperparameters.
                assert prefill_wrapper._causal
                assert prefill_wrapper._window_left == window_left
                assert prefill_wrapper._logits_soft_cap == (
                    expected_soft_cap)
                assert prefill_wrapper._sm_scale == sm_scale

                prefill_output = prefill_wrapper.run(
                    query,
                    kv_cache,
                    k_scale=layer._k_scale_float,
                    v_scale=layer._v_scale_float,
                )

        decode_meta = attn_metadata.decode_metadata
        if decode_meta:
            assert decode_meta is not None
            decode_wrapper = decode_meta.decode_wrapper
            assert decode_wrapper is not None

            assert decode_wrapper._window_left == window_left
            assert decode_wrapper._logits_soft_cap == (expected_soft_cap)
            assert decode_wrapper._sm_scale == sm_scale

            decode_output = decode_wrapper.run(
                decode_query,
                kv_cache,
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
            )

        if prefill_output is None and decode_output is not None:
            # Decode only batch.
            output = decode_output
            num_tokens = num_decode_tokens
        elif decode_output is None and prefill_output is not None:
            # Prefill only batch.
            output = prefill_output
            num_tokens = num_prefill_tokens
        else:
            # Chunked prefill batch does not work with speculative decoding in
            # FlashInfer backend, so the query length for decode should be 1.
            assert prefill_output is not None
            assert decode_output is not None
            assert decode_meta is not None
            assert decode_meta.decode_query_len == 1
            decode_output = decode_output.squeeze(1)
            output = torch.cat([prefill_output, decode_output], dim=0)
        return output.view(num_tokens, hidden_size)
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/hpu_attn.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
###############################################################################
|
| 4 |
+
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
|
| 5 |
+
###############################################################################
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Any, Dict, List, Optional, Tuple, Type
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import vllm_hpu_extension.ops as ops
|
| 13 |
+
from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
|
| 14 |
+
VLLMKVCache)
|
| 15 |
+
|
| 16 |
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
| 17 |
+
AttentionLayer,
|
| 18 |
+
AttentionMetadata, AttentionType)
|
| 19 |
+
from vllm.attention.backends.utils import CommonAttentionState
|
| 20 |
+
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
|
| 21 |
+
HPUPagedAttentionMetadata)
|
| 22 |
+
from vllm.logger import init_logger
|
| 23 |
+
|
| 24 |
+
logger = init_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class HPUAttentionBackend(AttentionBackend):
    """Backend descriptor for HPU attention.

    Names the implementation, metadata and state classes, and forwards the
    KV-cache helpers to ``HPUPagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the backend identifier string."""
        return "HPU_ATTN"

    @staticmethod
    def get_impl_cls() -> Type["HPUAttentionImpl"]:
        """Return the attention implementation class."""
        return HPUAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return HPUAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape computed by ``HPUPagedAttention``."""
        cache_shape = HPUPagedAttention.get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Forward a block swap request to ``HPUPagedAttention``."""
        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Forward a block copy request to ``HPUPagedAttention``."""
        HPUPagedAttention.copy_blocks(kv_caches, src_to_dists)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
    """Metadata for HPUAttentionBackend."""
    # A batch currently holds either only prompts or only decodes.
    # True when every sequence in the batch is a prompt.
    is_prompt: bool
    # Attention bias tensor, if any.
    attn_bias: Optional[torch.Tensor]
    # NOTE(review): presumably the per-sequence lengths as a tensor; not
    # read in the visible part of this file - confirm against the caller.
    seq_lens_tensor: Optional[torch.Tensor]
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        max_seq_len: int = 4096,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Store the attention hyperparameters and create the HPU op modules.

        Raises:
            ValueError: if ``head_size`` is not among the head sizes
                reported by ``HPUPagedAttention``.
            NotImplementedError: if ``attn_type`` is not
                ``AttentionType.DECODER``.
        """
        # Starts the MRO lookup *after* AttentionImpl, i.e. AttentionImpl's
        # own __init__ is skipped on purpose.
        super(AttentionImpl, self).__init__()
        self.kv_cache_dtype = kv_cache_dtype
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.matmul_qk = Matmul()
        self.softmax = Softmax()
        self.matmul_av = Matmul()
        self.batch2block_matmul = Matmul()
        self.block2batch_matmul = Matmul()
        # NOTE(kzawora): Contiguous PA is off until model runner supports it
        self.k_cache = VLLMKVCache()
        self.k_cache.use_contiguous_pa = False
        self.v_cache = VLLMKVCache()
        self.v_cache.use_contiguous_pa = False
        # NOTE(kzawora): Pipelined PA is off until model runner supports it
        # (this rebinds a module-level attribute of the ops module).
        ops.pa_impl = ops.pa

        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        self.alibi_slopes = alibi_slopes
        if alibi_slopes is not None:
            alibi_slopes_tensor = torch.tensor(alibi_slopes,
                                               dtype=torch.bfloat16)
            self.alibi_slopes = alibi_slopes_tensor
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # Opt-in via environment variable; accepts '1' or 'true' in any
        # letter case.
        self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                              '0').lower() in ['1', 'true']
        self.fused_scaled_dot_product_attention = None
        if self.prefill_usefusedsdpa:
            assert alibi_slopes is None, \
                'Prefill with FusedSDPA not supported with alibi slopes!'
            try:
                from habana_frameworks.torch.hpex.kernels import FusedSDPA
                self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
                    FusedSDPA)
            except ImportError:
                # `logger` is a logger object, not a factory: the previous
                # `logger()` call raised TypeError here instead of warning.
                logger.warning("Could not import HPU FusedSDPA kernel. "
                               "vLLM will use native implementation.")

        supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "HPUAttentionImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with HPU prompt attention and HPUPagedAttention.

        Args:
            query: 3-D tensor, unpacked as
                [batch_size, seq_len, hidden_size]
            key: 3-D tensor; its second dimension is used as the KV
                sequence length
            value: reshaped to [-1, num_kv_heads, head_size] like key
            kv_cache: combined KV cache that is split by
                HPUPagedAttention.split_kv_cache, or None
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size, seq_len, hidden_size]
        """
        batch_size, seq_len, hidden_size = query.shape
        _, seq_len_kv, _ = key.shape

        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        block_indices = attn_metadata.block_indices
        block_offsets = attn_metadata.block_offsets
        if attn_metadata.is_prompt:
            # Regroup the flattened tokens so the leading dimension matches
            # the number of block indices.
            key = key.unflatten(0, (block_indices.size(0), -1))
            value = value.unflatten(0, (block_indices.size(0), -1))
        if kv_cache is not None:
            key_cache, value_cache = HPUPagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            key_cache = self.k_cache(key, key_cache, block_indices,
                                     block_offsets)
            value_cache = self.v_cache(value, value_cache, block_indices,
                                       block_offsets)

        if attn_metadata.is_prompt:
            # Prompt run.
            if not self.prefill_usefusedsdpa:
                # TODO: move this outside of model
                assert attn_metadata.attn_bias is not None, \
                    'attn_bias must be set before calling model.forward!'
                attn_bias = attn_metadata.attn_bias
                if self.alibi_slopes is not None:
                    position_bias = _make_alibi_bias(self.alibi_slopes,
                                                     self.num_kv_heads,
                                                     attn_bias.dtype,
                                                     attn_bias.shape[-1])
                    # tile() yields a new tensor, so the in-place add does
                    # not modify attn_metadata.attn_bias.
                    attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
                    attn_bias.add_(position_bias)
            else:
                attn_bias = None

            query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
            kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
                        self.head_size)
            out = ops.prompt_attention(
                query.view(query_shape),
                key.view(kv_shape),
                value.view(kv_shape),
                attn_bias=attn_bias,
                p=0.0,
                scale=self.scale,
                matmul_qk_op=self.matmul_qk,
                softmax_op=self.softmax,
                matmul_av_op=self.matmul_av,
                fsdpa_op=self.fused_scaled_dot_product_attention,
            )
            output = out.reshape(batch_size, seq_len, hidden_size)
        else:
            # Decoding run.
            # NOTE(review): key_cache/value_cache are only bound when
            # kv_cache is not None - decode presumably always has a cache.
            output = HPUPagedAttention.forward_decode(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                block_list=attn_metadata.block_list,
                block_mapping=attn_metadata.block_mapping,
                block_bias=attn_metadata.attn_bias,
                block_scales=attn_metadata.block_scales,
                block_groups=None,
                scale=self.scale,
                matmul_qk_op=self.matmul_qk,
                matmul_av_op=self.matmul_av,
                batch2block_matmul_op=self.batch2block_matmul,
                block2batch_matmul_op=self.block2batch_matmul,
                keys_fetch_func=self.k_cache.fetch_from_cache,
                values_fetch_func=self.v_cache.fetch_from_cache)
        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def _make_alibi_bias(
|
| 265 |
+
alibi_slopes: torch.Tensor,
|
| 266 |
+
num_kv_heads: int,
|
| 267 |
+
dtype: torch.dtype,
|
| 268 |
+
seq_len: int,
|
| 269 |
+
) -> torch.Tensor:
|
| 270 |
+
bias = torch.arange(seq_len, dtype=dtype)
|
| 271 |
+
# NOTE(zhuohan): HF uses
|
| 272 |
+
# `bias = bias[None, :].repeat(seq_len, 1)`
|
| 273 |
+
# here. We find that both biases give the same results, but
|
| 274 |
+
# the bias below more accurately follows the original ALiBi
|
| 275 |
+
# paper.
|
| 276 |
+
# Calculate a matrix where each element represents ith element- jth
|
| 277 |
+
# element.
|
| 278 |
+
bias = bias[None, :] - bias[:, None]
|
| 279 |
+
|
| 280 |
+
padded_len = (seq_len + 7) // 8 * 8
|
| 281 |
+
num_heads = alibi_slopes.shape[0]
|
| 282 |
+
bias = torch.empty(
|
| 283 |
+
1, # batch size
|
| 284 |
+
num_heads,
|
| 285 |
+
seq_len,
|
| 286 |
+
padded_len,
|
| 287 |
+
device=alibi_slopes.device,
|
| 288 |
+
dtype=dtype,
|
| 289 |
+
)[:, :, :, :seq_len].copy_(bias)
|
| 290 |
+
bias.mul_(alibi_slopes[:, None, None])
|
| 291 |
+
if num_heads != num_kv_heads:
|
| 292 |
+
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
|
| 293 |
+
return bias
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/ipex_attn.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
""" Attention layer with IPEX varlen_attention
|
| 3 |
+
and PagedAttention."""
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any, Dict, List, Optional, Tuple, Type
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from vllm._ipex_ops import ipex_ops
|
| 10 |
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
| 11 |
+
AttentionLayer,
|
| 12 |
+
AttentionMetadata, AttentionType)
|
| 13 |
+
from vllm.attention.backends.utils import CommonAttentionState
|
| 14 |
+
from vllm.attention.ops.paged_attn import (PagedAttention,
|
| 15 |
+
PagedAttentionMetadata)
|
| 16 |
+
|
| 17 |
+
_PARTITION_SIZE = 512
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class IpexAttnBackend(AttentionBackend):
    """Backend descriptor that wires the IPEX attention pieces together."""

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "IPEX"

    @staticmethod
    def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
        """Return the class implementing the attention forward pass."""
        return IpexAttnBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["IpexAttnMetadata"]:
        """Return the metadata class consumed by the implementation."""
        return IpexAttnMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention-state class used with this backend."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape; delegated to ``PagedAttention``."""
        shape = PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                  num_kv_heads, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Forward a block swap request to ``ipex_ops.swap_blocks``."""
        ipex_ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Forward a block copy request to ``ipex_ops.copy_blocks``.

        Each entry of ``kv_caches`` is indexed with ``[0]`` for its key
        cache and ``[1]`` for its value cache.
        """
        keys: List[torch.Tensor] = []
        values: List[torch.Tensor] = []
        for layer_cache in kv_caches:
            keys.append(layer_cache[0])
            values.append(layer_cache[1])
        ipex_ops.copy_blocks(keys, values, src_to_dists)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@dataclass
class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Per-batch metadata consumed by the IPEX attention backend.

    A batch is either all prompts or all decodes; the two are never mixed.
    """
    # True when every sequence in the batch is a prompt.
    is_prompt: bool
    slot_mapping: torch.Tensor
    seq_lens: Optional[List[int]]
    seqlen_q: Optional[torch.Tensor]
    max_seqlen: Optional[int]

    def __post_init__(self):
        # Filled in lazily by the first attention op that runs. It is a
        # list because, with ALiBi slopes, a bias is needed per prompt.
        # Being assigned here, it is not part of __init__ or __repr__.
        # NOTE(review): the original comment attributes the per-prompt list
        # to an xformers API limitation - confirm it still applies to IPEX.
        self.attn_bias: Optional[List[torch.Tensor]] = None

    @property
    def prefill_metadata(self) -> Optional["IpexAttnMetadata"]:
        """Return ``self`` for a prompt-only batch, otherwise ``None``."""
        # Chunked prefill is not supported, so any decode token means this
        # batch carries no prefill work.
        if self.num_decode_tokens != 0:
            return None

        assert self.num_prefills > 0
        return self

    @property
    def decode_metadata(self) -> Optional["IpexAttnMetadata"]:
        """Return ``self`` for a decode-only batch, otherwise ``None``."""
        # Chunked prefill is not supported, so a batch with prefills must
        # not contain decode tokens.
        if not self.num_prefills > 0:
            return self

        assert self.num_decode_tokens == 0
        return None
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
    """Attention implementation built on the ``ipex_ops`` kernels.

    Prompt batches run through ``ipex_ops.varlen_attention``; decode
    batches run through ``ipex_ops.paged_attention_v1`` or
    ``ipex_ops.paged_attention_v2``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Record the attention configuration and reject unsupported setups.

        Raises:
            ValueError: If block-sparse parameters are given or the head
                size is not one PagedAttention supports.
            NotImplementedError: If ``kv_cache_dtype`` is not ``"auto"`` or
                ``attn_type`` is not ``AttentionType.DECODER``.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "IPEX backend does not support block-sparse attention.")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        # Slopes are kept as a float32 tensor, or None when ALiBi is off.
        self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
            alibi_slopes, dtype=torch.float32))
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        # Query heads must divide evenly across the KV heads.
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.need_mask = not (self.alibi_slopes is None
                              and self.sliding_window is None)
        # A missing soft cap is stored as 0.
        self.logits_soft_cap = (0 if logits_soft_cap is None else
                                logits_soft_cap)

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError(
                "IPEX backend does not support FP8 KV cache. "
                "Please use xFormers backend instead.")
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "IpexAttnBackendImpl")

    def split_kv_cache(
        self,
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split ``kv_cache`` into reshaped key and value cache views.

        ``kv_cache[0]`` is viewed as
        ``(num_blocks, num_kv_heads, head_size, -1, 1)`` and
        ``kv_cache[1]`` as ``(num_blocks, num_kv_heads, head_size, -1)``,
        where ``num_blocks`` is ``kv_cache.shape[1]``.
        """
        # Packing factor of the innermost key axis; fixed to 1 here.
        pack = 1
        block_count = kv_cache.shape[1]

        keys = kv_cache[0].view(block_count, num_kv_heads, head_size // pack,
                                -1, pack)
        values = kv_cache[1].view(block_count, num_kv_heads, head_size, -1)
        return keys, values

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: IpexAttnMetadata,  # type: ignore
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with IPEX varlen_attention and PagedAttention.

        Args:
            layer: Attention layer supplying the k/v scales; both float
                scales must be exactly 1.0.
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            output: Not read; a fresh output tensor is always allocated.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            RuntimeError: For a prompt batch that has both a non-empty KV
                cache and non-empty block tables (prefix attention).
        """
        assert layer._k_scale_float == 1.0
        assert layer._v_scale_float == 1.0
        num_tokens, _ = query.shape

        # Expose the per-head layout: [tokens, heads, head_size].
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache.numel() > 0:
            key_cache, value_cache = self.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
            # Write the new keys/values into their cache slots.
            ipex_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if attn_metadata.is_prompt:
            assert attn_metadata.seq_lens is not None
            has_prefix = (kv_cache.numel() != 0
                          and attn_metadata.block_tables.numel() != 0)
            if has_prefix:
                # prefix-enabled attention
                raise RuntimeError(
                    "IPEX backend doesn't support prefix decoding.")

            if self.num_kv_heads != self.num_heads:
                # Expand the KV heads so each query head has its own copy.
                key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                value = value.repeat_interleave(self.num_queries_per_kv,
                                                dim=1)

            if attn_metadata.attn_bias is None:
                prompt_lens = attn_metadata.seq_lens
                if self.alibi_slopes is not None:
                    masks = _make_alibi_bias(self.alibi_slopes, query.dtype,
                                             prompt_lens)  # type: ignore
                else:
                    # With no sliding window this passes None as the window.
                    masks = _make_sliding_window_bias(
                        prompt_lens, self.sliding_window,
                        dtype=query.dtype)  # type: ignore
                # NOTE(review): the cached bias is not passed to the
                # varlen_attention call below - confirm this is intended.
                attn_metadata.attn_bias = masks

            output = torch.empty(
                (num_tokens, self.num_heads, self.head_size),
                dtype=query.dtype,
                device=query.device)
            ipex_ops.varlen_attention(
                query,
                key,
                value,
                output,
                attn_metadata.seqlen_q,
                attn_metadata.seqlen_q,
                attn_metadata.max_seqlen,
                attn_metadata.max_seqlen,
                pdropout=0.0,
                softmax_scale=self.scale,
                zero_tensors=False,
                is_causal=True,
                return_softmax=False,
                gen_=None,
                logits_soft_cap=self.logits_soft_cap,
            )
        else:
            # Decoding run.
            max_seq_len = attn_metadata.max_decode_seq_len
            output = torch.empty_like(query)
            block_size = value_cache.shape[3]
            seq_count, head_count, head_dim = query.shape
            partition_count = ((max_seq_len + _PARTITION_SIZE - 1) //
                               _PARTITION_SIZE)

            def _shared_kernel_args():
                # Trailing arguments common to both paged-attention
                # kernels, in the order the kernels take them.
                return (
                    query,
                    key_cache,
                    value_cache,
                    self.num_kv_heads,
                    self.scale,
                    attn_metadata.block_tables,
                    attn_metadata.seq_lens_tensor,
                    block_size,
                    max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )

            # NOTE(woosuk): simple heuristic for choosing V1 or V2. With a
            # single partition V1 avoids the reduction overhead; with many
            # sequences or heads V1 already has enough parallel work.
            # TODO(woosuk): Tune this heuristic.
            # Contexts longer than 8192 always use V2 to avoid a shared
            # memory shortage.
            enough_parallel_work = seq_count * head_count > 512
            use_v1 = (max_seq_len <= 8192
                      and (partition_count == 1 or enough_parallel_work))
            if use_v1:
                # Run PagedAttention V1.
                ipex_ops.paged_attention_v1(output, *_shared_kernel_args())
            else:
                # Run PagedAttention V2.
                assert _PARTITION_SIZE % block_size == 0
                tmp_output = torch.empty(
                    size=(seq_count, head_count, partition_count, head_dim),
                    dtype=output.dtype,
                    device=output.device,
                )
                exp_sums = torch.empty(
                    size=(seq_count, head_count, partition_count),
                    dtype=torch.float32,
                    device=output.device,
                )
                max_logits = torch.empty_like(exp_sums)
                ipex_ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    *_shared_kernel_args(),
                )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def _make_alibi_bias(
|
| 342 |
+
alibi_slopes: torch.Tensor,
|
| 343 |
+
dtype: torch.dtype,
|
| 344 |
+
seq_lens: List[int],
|
| 345 |
+
) -> List[torch.Tensor]:
|
| 346 |
+
attn_biases = []
|
| 347 |
+
for seq_len in seq_lens:
|
| 348 |
+
bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
|
| 349 |
+
# NOTE(zhuohan): HF uses
|
| 350 |
+
# `bias = bias[None, :].repeat(seq_len, 1)`
|
| 351 |
+
# here. We find that both biases give the same results, but
|
| 352 |
+
# the bias below more accurately follows the original ALiBi
|
| 353 |
+
# paper.
|
| 354 |
+
bias = bias[None, :] - bias[:, None]
|
| 355 |
+
|
| 356 |
+
num_heads = alibi_slopes.shape[0]
|
| 357 |
+
bias = bias[None, :].repeat((num_heads, 1, 1))
|
| 358 |
+
bias.mul_(alibi_slopes[:, None, None])
|
| 359 |
+
inf_mask = torch.empty(
|
| 360 |
+
(1, seq_len, seq_len),
|
| 361 |
+
dtype=bias.dtype,
|
| 362 |
+
device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1)
|
| 363 |
+
attn_biases.append((bias + inf_mask).to(dtype))
|
| 364 |
+
|
| 365 |
+
return attn_biases
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def _make_sliding_window_bias(
|
| 369 |
+
seq_lens: List[int],
|
| 370 |
+
window_size: Optional[int],
|
| 371 |
+
dtype: torch.dtype,
|
| 372 |
+
) -> List[torch.Tensor]:
|
| 373 |
+
attn_biases = []
|
| 374 |
+
for seq_len in seq_lens:
|
| 375 |
+
tensor = torch.full(
|
| 376 |
+
(1, seq_len, seq_len),
|
| 377 |
+
dtype=dtype,
|
| 378 |
+
fill_value=1,
|
| 379 |
+
)
|
| 380 |
+
shift = 0
|
| 381 |
+
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
|
| 382 |
+
if window_size is not None:
|
| 383 |
+
mask = torch.triu(mask, diagonal=shift - window_size + 1)
|
| 384 |
+
mask = torch.log(mask)
|
| 385 |
+
attn_biases.append(mask.to(dtype))
|
| 386 |
+
|
| 387 |
+
return attn_biases
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (200 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (25.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/utils.py
ADDED
|
@@ -0,0 +1,541 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from abc import abstractmethod
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any, Dict, Generic, List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from compressed_tensors.quantization import QuantizationStrategy
|
| 9 |
+
|
| 10 |
+
from vllm import _custom_ops as ops
|
| 11 |
+
from vllm import envs
|
| 12 |
+
from vllm.attention.backends.abstract import (AttentionLayer,
|
| 13 |
+
AttentionMetadata,
|
| 14 |
+
MLAAttentionImpl, T)
|
| 15 |
+
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
| 16 |
+
tensor_model_parallel_all_reduce)
|
| 17 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 18 |
+
LinearBase, RowParallelLinear,
|
| 19 |
+
UnquantizedLinearMethod)
|
| 20 |
+
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
| 21 |
+
CompressedTensorsLinearMethod)
|
| 22 |
+
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
| 23 |
+
CompressedTensorsW8A8Fp8)
|
| 24 |
+
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
| 25 |
+
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
| 26 |
+
apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
|
| 27 |
+
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
| 28 |
+
scaled_dequantize, scaled_quantize)
|
| 29 |
+
from vllm.model_executor.layers.rotary_embedding import (
|
| 30 |
+
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
| 34 |
+
except ImportError:
|
| 35 |
+
from flash_attn import flash_attn_varlen_func
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
class MLACommonMetadata(AttentionMetadata):
    # Input positions for rotary embeddings. They are carried in the
    # metadata because, for MLA, the rotary position embeddings are applied
    # inside the attention backend.
    input_positions: torch.Tensor
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
| 46 |
+
"""
|
| 47 |
+
Common class for implementing repeated parts
|
| 48 |
+
|
| 49 |
+
Main reference: DeepseekV2 paper, and FlashInfer Implementation
|
| 50 |
+
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
|
| 51 |
+
|
| 52 |
+
Deepseek's MLA attention works the following way:
|
| 53 |
+
* Use a single latent vector to represent the entire KV cache.
|
| 54 |
+
* The attention "simulates" a multi-head attention, while the compute is
|
| 55 |
+
similar to multi-query attention.
|
| 56 |
+
* The dataflow is as follows,
|
| 57 |
+
|
| 58 |
+
* B: batch/sequence length
|
| 59 |
+
* H: hidden size
|
| 60 |
+
* N: number of attention heads
|
| 61 |
+
* Lq: latent dimension for Q
|
| 62 |
+
* Lkv: latent dimension for K/V
|
| 63 |
+
* P: nope dimension, P+R is the actual head_dim in common attention.
|
| 64 |
+
* R: rope dimension, this slice of the head_dim goes through rope.
|
| 65 |
+
* V: V head dim.
|
| 66 |
+
* kv_c: latent/compressed KV
|
| 67 |
+
* q_c: latent/compressed Q
|
| 68 |
+
|
| 69 |
+
#
|
| 70 |
+
# Outside the MLA attention backend
|
| 71 |
+
#
|
| 72 |
+
|
| 73 |
+
1. The hidden states (B, H) are projected down into cq (B, Lq) and
|
| 74 |
+
kv_c_k_pe (B, Lkv+R).
|
| 75 |
+
2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq
|
| 76 |
+
and kv_c are normalized.
|
| 77 |
+
|
| 78 |
+
#
|
| 79 |
+
# Inside the MLA attention backend
|
| 80 |
+
#
|
| 81 |
+
|
| 82 |
+
* if prefill:
|
| 83 |
+
|
| 84 |
+
3. The q_c is then projected up into the multi-head version.
|
| 85 |
+
* q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
|
| 86 |
+
(B, N, P) and q_pe (B, N, R).
|
| 87 |
+
4. q_pe, k_pe are then passed through rotary embeddings.
|
| 88 |
+
5. kv_c and k_pe are concatenated and inserted into the cache
|
| 89 |
+
6. The kv_c is then projected up into the multi-head version.
|
| 90 |
+
* kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
|
| 91 |
+
dimensions for K and V, which is split into k_nope (B, N, P)
|
| 92 |
+
and v (B, N, V).
|
| 93 |
+
7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from
|
| 94 |
+
q_nope, q_pe, k_nope, k_pe.
|
| 95 |
+
8. Attention is computed with q, k, v.
|
| 96 |
+
9. The attention computation returns (B, N, V), which is projected back
|
| 97 |
+
to (B, H) using out projection.
|
| 98 |
+
|
| 99 |
+
* if decode:
|
| 100 |
+
|
| 101 |
+
3. Here's the change: we do not perform the full up projection for
|
| 102 |
+
q_c, and there is no up projection at all for kv_c. This is
|
| 103 |
+
achieved by the technique of "weight absorption". The paper says
|
| 104 |
+
"Fortunately, due to the associative law of matrix multiplication,
|
| 105 |
+
we can absorb WUK into WUQ, and WUV into WO"
|
| 106 |
+
* The q up projection turns (B, Lq) into (B, N, (P+R)), we split it
|
| 107 |
+
into W_UQ (Lq, N, P) and W_QR (Lq, N, R).
|
| 108 |
+
* The kv_c up projection turns (B, Lkv) into (B, N, (P+V)), we split
|
| 109 |
+
it into W_UK (Lkv, N, P) and W_UV (Lkv, N, V).
|
| 110 |
+
* The out projection shape W_O (N*V, H) turns (B, N, V) into (B, H).
|
| 111 |
+
* We can precompute the product of W_UQ and W_UK into
|
| 112 |
+
W_UQ_UK (Lq, N, Lkv), which is possible due to QK^T operation in
|
| 113 |
+
attention.
|
| 114 |
+
* We can precompute the product of W_UV and W_O into
|
| 115 |
+
W_UV_O (N, Lkv, H), which is possible due to V@O as the
|
| 116 |
+
"epilogue" of attention
|
| 117 |
+
4. We still need to compute q_pe (B, N, R) by applying W_QR to q_latent.
|
| 118 |
+
5. q_pe, k_pe are then passed through rotary embeddings.
|
| 119 |
+
6. kv_c and k_pe are concatenated and inserted into the cache
|
| 120 |
+
7. By applying W_UQ_UK to q_latent, we have the new q_nope of shape
|
| 121 |
+
(B, N, Lkv).
|
| 122 |
+
8. q (B, N, (Lkv+R)), k (B, (Lkv+R)) are assembled from q_nope, q_pe,
|
| 123 |
+
kv_a, k_pe. v (B, Lkv) is exactly the same vector as kv_a.
|
| 124 |
+
9. The attention is computed with q, k, v. Note that we just performed
|
| 125 |
+
an MQA attention with (Lkv+R) as our head dim.
|
| 126 |
+
10. The KV cache is updated using the new entries k (B, N, (Lkv+R)),
|
| 127 |
+
which included the v and rope values.
|
| 128 |
+
11. The attention computation returns (B, N, Lkv), which is projected
|
| 129 |
+
back to (B, H) using W_UV_O.
|
| 130 |
+
|
| 131 |
+
From @tsu-bin's calculation, we only want to use the absorption technique
|
| 132 |
+
for decode. The prefill algorithm should still use the up-projected MHA
|
| 133 |
+
for less flops and memory usage.
|
| 134 |
+
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[List[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[Dict[str, Any]],
    logits_soft_cap: Optional[float],
    attn_type: str,
    # MLA Specific Arguments
    q_lora_rank: Optional[int],
    kv_lora_rank: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    qk_head_dim: int,
    v_head_dim: int,
    rotary_emb: RotaryEmbedding,
    # From the backend's point of view the layer is trusted to pass the
    # right matrix here: q_b_proj when q_lora_rank is set, q_proj otherwise.
    q_proj: ColumnParallelLinear,
    kv_b_proj: ColumnParallelLinear,
    o_proj: RowParallelLinear,
) -> None:
    """Store the attention geometry, rotary embedding and projections.

    ``alibi_slopes``, ``sliding_window``, ``blocksparse_params``,
    ``logits_soft_cap`` and ``attn_type`` are accepted but not stored by
    this initializer.
    """
    # Generic attention configuration.
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    self.kv_cache_dtype = kv_cache_dtype

    # MLA latent ranks and per-head dimensions.
    self.q_lora_rank = q_lora_rank
    self.kv_lora_rank = kv_lora_rank
    self.qk_nope_head_dim = qk_nope_head_dim
    self.qk_rope_head_dim = qk_rope_head_dim
    self.qk_head_dim = qk_head_dim
    self.v_head_dim = v_head_dim

    # Rotary embedding; flag the DeepSeek scaling variant.
    self.rotary_emb = rotary_emb
    is_deepseek_scaling = isinstance(rotary_emb,
                                     DeepseekScalingRotaryEmbedding)
    self.use_yarn_rope = is_deepseek_scaling

    # Projection layers handed in by the owning layer.
    self.q_proj = q_proj
    self.kv_b_proj = kv_b_proj
    self.o_proj = o_proj
|
| 183 |
+
|
| 184 |
+
def _v_up_proj_and_o_proj(self, x):
    """Apply the V up-projection followed by the output projection.

    With ``envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION`` set, the combined
    ``W_UV_O`` weight is applied to ``x`` flattened from dim 1, using the
    FP8 linear helper when that weight is FP8, and the result is
    all-reduced when ``tp_size > 1``. Otherwise ``x`` is contracted with
    ``W_UV`` and passed through ``o_proj``.
    """
    if not envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
        # Unabsorbed path: explicit V up-projection, then o_proj.
        per_head = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
        flat_heads = per_head.reshape(-1, self.num_heads * self.v_head_dim)
        return self.o_proj(flat_heads)[0]

    # Absorbed path: a single matmul against the precomputed W_UV_O.
    if is_fp8(self.W_UV_O):
        partial = apply_fp8_linear_generic(
            x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
            self.reqaunt_input_group_shape,
            self.reqaunt_weight_group_shape)
    else:
        partial = torch.matmul(x.flatten(start_dim=1), self.W_UV_O)

    if self.tp_size > 1:
        # Combine the partial results across tensor-parallel ranks.
        return tensor_model_parallel_all_reduce(partial)
    return partial
|
| 203 |
+
|
| 204 |
+
def _q_proj_and_k_up_proj(self, x):
    """Project ``x`` to a ``(-1, num_heads, kv_lora_rank)`` query tensor.

    With ``envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION`` set, the combined
    ``W_Q_UK`` weight is applied directly, through the FP8 linear helper
    when that weight is FP8. Otherwise ``x`` is multiplied by ``W_Q``,
    viewed per head, and contracted with ``W_UK``.
    """
    if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
        if is_fp8(self.W_Q_UK):
            latent = apply_fp8_linear_generic(
                x, self.W_Q_UK, self.W_Q_UK_scales,
                self.reqaunt_input_group_shape,
                self.reqaunt_weight_group_shape)
        else:
            latent = torch.matmul(x, self.W_Q_UK)
    else:
        # Unabsorbed path: per-head q_nope first, then the K up-projection.
        q_nope = torch.matmul(x, self.W_Q)
        q_nope = q_nope.view(-1, self.num_heads, self.qk_nope_head_dim)
        latent = torch.einsum("bnp,lnp->bnl", q_nope, self.W_UK)

    return latent.view(-1, self.num_heads, self.kv_lora_rank)
|
| 219 |
+
|
| 220 |
+
def process_weights_after_loading(self, act_dtype: torch.dtype):
    """Derive the MLA projection matrices from the loaded linear layers.

    Reads the (possibly FP8-quantized) weights of ``self.q_proj``,
    ``self.kv_b_proj`` and ``self.o_proj``, dequantizes them when needed and
    stores the matrices used by the decode path on ``self``:

    * always: ``W_QR`` (rope part of the query projection, in ``act_dtype``)
    * with ``VLLM_MLA_PERFORM_MATRIX_ABSORPTION``: ``W_UK``, ``W_UV``,
      ``W_Q_UK``, ``W_UV_O``, ``tp_size`` and, when the weights are FP8 and
      requantization is enabled, ``W_Q_UK_scales``, ``W_UV_O_scales`` and
      the two ``reqaunt_*_group_shape`` attributes
    * otherwise: ``W_UV``, ``W_UK`` and ``W_Q``

    Args:
        act_dtype: dtype the derived weights are cast to whenever they are
            not requantized to FP8.

    Raises:
        NotImplementedError: if a layer uses an unsupported quantization
            method or strategy, or if the weights are FP8 while matrix
            absorption is disabled.
    """

    def is_layer_fp8(layer: LinearBase) -> bool:
        # FP8 either through the native FP8 method or through the
        # compressed-tensors W8A8-FP8 scheme.
        return isinstance(layer.quant_method, Fp8LinearMethod) or\
            (isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
            and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))

    def quantization_scheme_supported(layer: LinearBase) -> bool:
        return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
            is_layer_fp8(layer)

    # TODO(lucas) This is very gross, we need a more wide scale refactor of
    # all the FP8 code with a more standard way of
    # defining schemes/group-shapes, we should also potentially force
    # quant_methods to support a decompress function
    #
    # returns input_group_shape, weight_group_shape
    def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
        Tuple[Tuple[int, int], Tuple[int, int]]:
        if isinstance(layer.quant_method, Fp8LinearMethod):
            # Test truthiness rather than `is not None`: if `block_quant`
            # is a boolean flag, `is not None` is always true, which would
            # make the per-tensor branch unreachable and index a missing
            # `weight_block_size`.
            if layer.quant_method.block_quant:
                weight_block_size = \
                    layer.quant_method.quant_config.weight_block_size
                # per-token-group (1, X), block-quantized (X, Y)
                return (1, weight_block_size[-1]), weight_block_size
            else:
                return (-1, -1), (-1, -1)  # per-tensor, per-tensor
        elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
            and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
            # This is hacky, but for CompressedTensorsW8A8Fp8 we always
            # assume the input is dynamic per-token. A static per-tensor
            # input scale is ignored since we requantize later anyway.
            strategy = layer.scheme.strategy
            if strategy == QuantizationStrategy.TENSOR:
                return (1, -1), (-1, -1)  # per-token, per-tensor
            elif strategy == QuantizationStrategy.CHANNEL:
                return (1, -1), (-1, 1)  # per-token, per-channel
            else:
                raise NotImplementedError(
                    f"QuantizationStrategy.{strategy} is not supported for "
                    "fp8 MLA, please run with VLLM_MLA_DISABLE=1")
        else:
            raise NotImplementedError(
                "Can't determine scale group shapes for "
                f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
            )

    def get_scales(layer: LinearBase) -> torch.Tensor:
        # Prefer `weight_scale_inv` when the layer exposes it.
        if hasattr(layer, "weight_scale_inv"):
            return layer.weight_scale_inv
        return layer.weight_scale

    def get_and_maybe_dequant_weights(layer: LinearBase):
        if is_layer_fp8(layer):
            if isinstance(layer.quant_method, \
                CompressedTensorsLinearMethod) and \
                isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
                # NOTE(lucas): not sure why but `CompressedTensorsW8A8Fp8`
                # seems to store weights as (input, output) instead of
                # (output, input) so we need to transpose
                weight = layer.weight.T  # standardize to (output, input)
            else:
                weight = layer.weight
            _, weight_scale_group_shape = \
                get_scale_group_shapes_for_fp8(layer)
            scales = get_scales(layer)

            return scaled_dequantize(weight, scales,
                                     weight_scale_group_shape)
        else:
            return layer.weight

    if not (quantization_scheme_supported(self.kv_b_proj) and\
        quantization_scheme_supported(self.q_proj) and\
            quantization_scheme_supported(self.o_proj)):
        raise NotImplementedError(
            "Only FP8 and UnquantizedLinearMethod are supported for MLA"
            ", please run with VLLM_MLA_DISABLE=1")

    # All three projections must share one storage dtype.
    weight_dtype = self.kv_b_proj.weight.dtype
    assert self.o_proj.weight.dtype == weight_dtype
    assert self.q_proj.weight.dtype == weight_dtype

    kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
    assert kv_b_proj_weight.shape == (
        self.kv_lora_rank,
        self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
            f"{kv_b_proj_weight.shape=}, "
            f"{self.kv_lora_rank=}, "
            f"{self.num_heads=}, "
            f"{self.qk_nope_head_dim=}, "
            f"{self.v_head_dim=}")
    kv_b_proj_weight = kv_b_proj_weight.view(
        self.kv_lora_rank,
        self.num_heads,
        self.qk_nope_head_dim + self.v_head_dim,
    )

    # Per head, the leading qk_nope_head_dim columns up-project keys and
    # the trailing v_head_dim columns up-project values.
    W_UK, W_UV = kv_b_proj_weight.split(
        [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

    q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
        .view(-1, self.num_heads, self.qk_head_dim)

    # can be W_Q or W_UQ depending q_lora_rank, the former if
    # q_lora_rank is None, the latter otherwise. From the Attention backend
    # perspective though we call these both W_Q and rely on the layer
    # to pass in the correct matrix
    W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
    self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
        .flatten(start_dim=1).contiguous()

    # W_QR is small so for simplicity we don't bother requantizing it
    self.W_QR = self.W_QR.to(act_dtype)

    if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
        requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
        if is_fp8(weight_dtype) and requantization_enabled:
            # This assumes it is wise to requantize using the same group
            # shapes (i.e. strategy, per-tensor, per-channel, block etc.)
            # that the weights were originally quantized with
            requant_input_group_shape, requant_weight_group_shape = \
                get_scale_group_shapes_for_fp8(self.q_proj)
            assert (requant_input_group_shape, requant_weight_group_shape)\
                == get_scale_group_shapes_for_fp8(self.kv_b_proj)
            assert (requant_input_group_shape, requant_weight_group_shape)\
                == get_scale_group_shapes_for_fp8(self.o_proj)
            # NOTE: the misspelled `reqaunt_*` attribute names are kept on
            # purpose; `_q_proj_and_k_up_proj` reads them.
            self.reqaunt_input_group_shape = requant_input_group_shape
            self.reqaunt_weight_group_shape = requant_weight_group_shape

        #
        # Perform matrix-absorption following
        #     https://github.com/flashinfer-ai/flashinfer/pull/551
        # for decode, as a result we end up with absorbed weights for decode
        # and another copy of raw weights for prefill.
        #
        # Reuse the split computed above instead of splitting again; both
        # are views over the same `kv_b_proj_weight` storage.
        self.W_UK, self.W_UV = W_UK, W_UV
        # We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
        # depending q_lora_rank, the former if q_lora_rank is None, the
        # latter otherwise
        # basically if q_lora_rank is none we are absorbing into q_proj
        # instead of UQ
        W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
            .flatten(start_dim=1).contiguous()

        if is_fp8(weight_dtype) and requantization_enabled:
            W_Q_UK, W_Q_UK_scales = scaled_quantize(
                W_Q_UK,
                self.reqaunt_weight_group_shape,
                quant_dtype=current_platform_fp8_dtype)
            # For FP8 save the transpose so we can use
            # `apply_w8a8_block_fp8_linear` directly
            self.W_Q_UK = W_Q_UK.T.contiguous()
            self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
        else:
            self.W_Q_UK = W_Q_UK.to(act_dtype)

        W_O = get_and_maybe_dequant_weights(self.o_proj)\
            .view(-1, self.num_heads, self.v_head_dim)
        W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
            .flatten(start_dim=0, end_dim=1).contiguous()

        if is_fp8(weight_dtype) and requantization_enabled:
            W_UV_O, W_UV_O_scales = scaled_quantize(
                W_UV_O,
                self.reqaunt_weight_group_shape,
                quant_dtype=current_platform_fp8_dtype)
            # For FP8 save the transpose so we can use
            # `apply_w8a8_block_fp8_linear` directly
            self.W_UV_O = W_UV_O.T.contiguous()
            self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
        else:
            self.W_UV_O = W_UV_O.to(act_dtype)

        self.tp_size = get_tensor_model_parallel_world_size()
    else:
        if is_fp8(weight_dtype):
            raise NotImplementedError(
                "Currently fp8 requires matrix absorption")

        self.W_UV = W_UV
        self.W_UK = W_UK
        self.W_Q = W_Q.flatten(start_dim=1)
|
| 405 |
+
|
| 406 |
+
@abstractmethod
def _forward_prefill(
    self,
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    attn_metadata: T,
) -> torch.Tensor:
    """Backend-specific prefill attention; concrete subclasses override it.

    ``forward`` calls this with the query viewed as
    ``(-1, num_heads, qk_head_dim)`` after rotary embedding has been
    applied to its trailing rope slice, the ``k_c_normed`` tensor it
    received, the rotated ``k_pe`` and the attention metadata.
    """
    raise NotImplementedError
|
| 415 |
+
|
| 416 |
+
@abstractmethod
def _forward_decode(
    self,
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: T,
) -> torch.Tensor:
    """Backend-specific decode attention; concrete subclasses override it.

    ``forward`` calls this with the output of ``_q_proj_and_k_up_proj``
    as ``q_nope``, the rotated rope query ``q_pe`` viewed as
    ``(-1, num_heads, qk_rope_head_dim)``, the KV cache and the attention
    metadata.
    """
    raise NotImplementedError
|
| 425 |
+
|
| 426 |
+
def apply_pure_rope(
    self,
    input_positions: torch.Tensor,
    q_pe: torch.Tensor,
    k_pe: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply ``self.rotary_emb`` to 2-D flattened copies of ``q_pe``/``k_pe``.

    Both inputs are reshaped to ``(input_positions.size(0), -1)`` before the
    rotary embedding is called, and the results are viewed back to the
    shapes the inputs arrived with.

    Returns:
        The rotated ``(q_pe, k_pe)`` pair in their original shapes.
    """
    num_positions = input_positions.size(0)
    q_shape = q_pe.shape
    k_shape = k_pe.shape

    rotated_q, rotated_k = self.rotary_emb(
        input_positions,
        q_pe.reshape(num_positions, -1),
        k_pe.reshape(num_positions, -1),
    )

    return rotated_q.view(q_shape), rotated_k.view(k_shape)
|
| 443 |
+
|
| 444 |
+
def forward(
    self,
    layer: AttentionLayer,
    hidden_states_or_q_c: torch.Tensor,  # query in unified attn
    k_c_normed: torch.Tensor,  # key in unified attn
    k_pe: torch.Tensor,  # value in unified attn
    kv_cache: torch.Tensor,
    attn_metadata: T,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run MLA attention for a batch that is entirely prefill or decode.

    Builds the query for the active phase, applies rotary embedding to the
    rope parts of the query and key, writes ``k_c_normed`` and the rotated
    ``k_pe`` into ``kv_cache`` when the cache is non-empty, and dispatches
    to ``_forward_prefill`` or ``_forward_decode``.

    Raises:
        NotImplementedError: if ``output`` is supplied, or if the metadata
            carries both prefill and decode work (chunked prefill).
    """
    if output is not None:
        raise NotImplementedError(
            "output is not yet supported for MLAImplBase")

    running_decode = attn_metadata.decode_metadata is not None
    running_prefill = attn_metadata.prefill_metadata is not None
    if running_decode and running_prefill:
        raise NotImplementedError(
            "chunked prefill is not supported for MLAImplBase")

    # The rotary embedding works on a per-head layout, so give k_pe a
    # head axis; it is squeezed away again before the cache write.
    k_pe = k_pe.unsqueeze(1)
    assert hasattr(attn_metadata, "input_positions")
    if self.use_yarn_rope:
        rope_fn = self.rotary_emb
    else:
        rope_fn = self.apply_pure_rope

    if running_decode:
        q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
        q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)
        q_pe = q_pe.view(-1, self.num_heads, self.qk_rope_head_dim)
        q_pe, k_pe = rope_fn(attn_metadata.input_positions, q_pe, k_pe)
    else:
        assert running_prefill
        q = self.q_proj(hidden_states_or_q_c)[0]
        q = q.view(-1, self.num_heads, self.qk_head_dim)

        # Rotate only the trailing rope slice of q and write it back in
        # place; the leading no-rope slice is left untouched.
        rotated_q_pe, k_pe = rope_fn(
            attn_metadata.input_positions,
            q[..., self.qk_nope_head_dim:],
            k_pe,
        )
        q[..., self.qk_nope_head_dim:] = rotated_q_pe

    # Store the latent and the rotated rope key in the KV cache.
    if kv_cache.numel() > 0:
        ops.concat_and_cache_mla(
            k_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            attn_metadata.slot_mapping.flatten(),
            kv_cache_dtype=self.kv_cache_dtype,
            scale=layer._k_scale,
        )

    if running_prefill:
        return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata)

    if running_decode:
        return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata)
|
| 503 |
+
|
| 504 |
+
# Optional common flash-attn based prefill
|
| 505 |
+
def _forward_prefill_flash(
    self,
    q: torch.Tensor,
    k_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    seq_start_loc: torch.Tensor,
    max_prefill_seq_len: int,
) -> torch.Tensor:
    """Prefill attention implemented with ``flash_attn_varlen_func``.

    Up-projects ``k_c_normed`` through ``kv_b_proj`` into per-head no-rope
    keys and values, appends ``k_pe`` (expanded across heads) to the keys,
    runs causal variable-length flash attention and returns the result of
    ``o_proj`` on the attention output.
    """
    nope_dim = self.qk_nope_head_dim
    value_dim = self.v_head_dim

    kv_nope = self.kv_b_proj(k_c_normed)[0]
    kv_nope = kv_nope.view(-1, self.num_heads, nope_dim + value_dim)
    k_nope, v = kv_nope.split([nope_dim, value_dim], dim=-1)

    # Expand k_pe over the leading axes of k_nope so every head gets the
    # same rope key.
    shared_k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
    k = torch.cat((k_nope, shared_k_pe), dim=-1)

    # For MLA the v head dim is smaller than the qk head dim, so v is
    # zero-padded up to the qk head dim for the kernel call.
    qk_dim = q.shape[-1]
    v_dim = v.shape[-1]
    v_padded = torch.nn.functional.pad(v, [0, qk_dim - v_dim], value=0)

    attn_output = flash_attn_varlen_func(
        q=q,
        k=k,
        v=v_padded,
        cu_seqlens_q=seq_start_loc,
        cu_seqlens_k=seq_start_loc,
        max_seqlen_q=max_prefill_seq_len,
        max_seqlen_k=max_prefill_seq_len,
        softmax_scale=self.scale,
        causal=True,
    )

    # Drop the padding columns again and merge the head axis.
    per_head = attn_output.view(-1, self.num_heads, qk_dim)
    unpadded = per_head[..., :v_dim]
    merged = unpadded.reshape(-1, self.num_heads * v_dim)

    return self.o_proj(merged)[0]
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/openvino.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Dict, List, Optional, Tuple, Type
|
| 5 |
+
|
| 6 |
+
import openvino as ov
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from vllm.attention.backends.abstract import (AttentionBackend,
|
| 10 |
+
AttentionMetadata)
|
| 11 |
+
from vllm.attention.backends.utils import CommonAttentionState
|
| 12 |
+
from vllm.multimodal import MultiModalPlaceholderMap
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def copy_cache_block(src_tensor: ov.Tensor, dst_tensor: ov.Tensor,
                     src_offset: int, dst_offset: int) -> None:
    """Copy block ``src_offset`` of ``src_tensor`` to block ``dst_offset``
    of ``dst_tensor``.

    Blocks are indexed along axis 0. Each side is wrapped in a
    region-of-interest tensor covering exactly one block and the source
    region is copied into the destination region.
    """

    def block_view(cache: ov.Tensor, block_index: int) -> ov.Tensor:
        # The region spans the full extent of every axis except axis 0,
        # which is narrowed to [block_index, block_index + 1).
        begin = ov.runtime.Coordinate([0, 0, 0, 0])
        end = ov.runtime.Coordinate(cache.get_shape())
        begin[0] = block_index
        end[0] = block_index + 1

        if not isinstance(cache, ov.Tensor):
            return ov.RemoteTensor(cache, begin, end)
        return ov.Tensor(cache, begin, end)

    source_view = block_view(src_tensor, src_offset)
    target_view = block_view(dst_tensor, dst_offset)
    source_view.copy_to(target_view)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class OpenVINOAttentionBackend(AttentionBackend):
    """Attention backend descriptor for OpenVINO.

    Only the cache-shape and block-copy helpers are implemented here; the
    implementation-class and generic metadata factories raise
    ``NotImplementedError``.
    """

    @staticmethod
    def get_name() -> str:
        return "OPENVINO"

    @staticmethod
    def get_impl_cls():
        # OpenVINO implements PagedAttention as part of the Optimum
        # exported model
        raise NotImplementedError

    @staticmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
        raise NotImplementedError

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
        """Build an ``OpenVINOAttentionMetadata`` from the given arguments."""
        return OpenVINOAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return ``(2, num_blocks, num_kv_heads, block_size, head_size)``."""
        return (2, num_blocks, num_kv_heads, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_tensor: ov.Tensor,
        dst_tensor: ov.Tensor,
        src_to_dists: List[Tuple[int, int]],
    ) -> None:
        """Copy each ``(src, dst)`` block pair from ``src_tensor`` into
        ``dst_tensor``."""
        for src_block, dst_block in src_to_dists:
            copy_cache_block(src_tensor, dst_tensor, src_block, dst_block)

    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
        src_to_dists: List[Tuple[int, int]],
    ) -> None:
        """Duplicate blocks inside every key cache and value cache."""
        for src_block, dst_block in src_to_dists:
            for kv_pair in kv_caches:
                key_cache, value_cache = kv_pair
                # Key first, then value, for each layer's cache pair.
                for cache in (key_cache, value_cache):
                    copy_cache_block(cache, cache, src_block, dst_block)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@dataclass
class OpenVINOAttentionMetadata:
    """Per-batch metadata passed to the OpenVINO attention backend.

    Terminology used in the field comments:
        - batch_size_in_sequences: total number of sequences to execute
        - prompt_lens: number of scheduled tokens for each sequence
        - batch_size_in_tokens = sum(prompt_lens)
        - max_context_len = max(context_lens)
        - max_num_blocks = div_up(max_context_len, BLOCK_SIZE)
        - num_blocks: total number of blocks in block_indices
    """

    # Size of the past KV cache for each sequence in the batch.
    # Shape: [batch_size_in_sequences]
    # Type: i32
    past_lens: torch.Tensor

    # Start indices of the input / speculative tokens of each sequence
    # within the batch.
    # Shape: [batch_size_in_sequences + 1]
    # Type: i32
    subsequence_begins: torch.Tensor

    # Block tables of all sequences in the batch: indices along the 0th
    # dimension of the key_cache and value_cache inputs.
    # Shape: [num_blocks]
    # Type: i32
    block_indices: torch.Tensor

    # For the i-th sequence, the index into block_indices of the first
    # block belonging to that sequence.
    # Shape: [batch_size_in_sequences + 1]
    # Type: i32
    block_indices_begins: torch.Tensor

    # Maximum context length.
    # Shape: scalar
    # Type: i32
    max_context_len: torch.Tensor

    # The index maps that relate multi-modal embeddings to the corresponding
    # placeholders.
    #
    # N.B. These aren't really related to attention and don't belong on this
    # type -- this is just a temporary solution to make them available to
    # `model_executable`.
    multi_modal_placeholder_index_maps: Optional[Dict[
        str, MultiModalPlaceholderMap.IndexMap]]

    # Enable/disable KV scales calculation, so that the calculation can be
    # switched off until after prefill and cuda graph capture.
    enable_kv_scales_calculation: bool
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/pallas.py
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Dict, List, Optional, Tuple, Type
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch_xla.experimental.custom_kernel # Required to register custom ops.
|
| 8 |
+
|
| 9 |
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
| 10 |
+
AttentionLayer,
|
| 11 |
+
AttentionMetadata, AttentionType)
|
| 12 |
+
from vllm.attention.backends.utils import CommonAttentionState
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class PallasAttentionBackend(AttentionBackend):
    """Attention backend descriptor for the Pallas (TPU) kernels."""

    @staticmethod
    def get_name() -> str:
        return "PALLAS"

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
        return PallasAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["PallasMetadata"]:
        return PallasMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return ``(num_kv_heads, num_blocks, block_size, head_size)``."""
        return (num_kv_heads, num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Always raises: block swapping is not used for the TPU backend."""
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    @torch.compile(backend="openxla")
    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        """Copy blocks in place inside every key cache and value cache.

        ``src_to_dists`` is a ``(source indices, destination indices)``
        pair that indexes axis 1 of each cache tensor.
        """
        from_blocks, to_blocks = src_to_dists
        for cache_pair in kv_caches:
            # Key cache first, then value cache, marking each one as a
            # buffer donor right before it is updated.
            for cache in cache_pair:
                torch.ops.xla.dynamo_set_buffer_donor_(cache, True)
                cache[:, to_blocks] = cache[:, from_blocks]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@dataclass
class PallasMetadata(AttentionMetadata):
    """Attention metadata for the Pallas backend.

    A batch may currently contain only prefills or only decodes; the two
    properties below assert that the counters of the other phase are zero.
    """

    block_tables: Optional[torch.Tensor] = None
    context_lens: Optional[torch.Tensor] = None
    effective_query_lens: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a prefill batch, ``None`` otherwise."""
        if self.num_prefills != 0:
            assert self.num_decode_tokens == 0
            return self
        return None

    @property
    def decode_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a decode batch, ``None`` otherwise."""
        if self.num_decode_tokens != 0:
            assert self.num_prefills == 0
            assert self.num_prefill_tokens == 0
            # Decoding always reads from the paged cache.
            assert self.block_tables is not None
            assert self.context_lens is not None
            return self
        return None
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class PallasAttentionBackendImpl(AttentionImpl):
    """Attention implementation that dispatches to Pallas/XLA kernels."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Validate the configuration and pick the megacore mode.

        Raises:
            NotImplementedError: if ``head_size`` is not a multiple of 128,
                if alibi slopes, a sliding window, a non-"auto" KV cache
                dtype or blocksparse parameters are requested, if the TPU
                version is below 4, or if ``attn_type`` is not
                ``AttentionType.DECODER``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.logits_soft_cap = logits_soft_cap
        if head_size % 128 != 0:
            raise NotImplementedError("Head size must be a multiple of 128.")
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")

        if torch_xla.tpu.version() < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")

        self.megacore_mode = None
        tpu_env = torch_xla.tpu.get_tpu_env()
        # The accelerator type can be reported under any of these keys.
        tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
                    or tpu_env.get("TYPE", None)
                    or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
        assert tpu_type is not None
        tpu_type = tpu_type.lower()

        # Megacore mode is enabled only when the type string contains
        # neither "lite" nor "v6".
        if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
            if self.num_kv_heads % 2 == 0:
                self.megacore_mode = "kv_head"
            else:
                # NOTE(woosuk): If the batch size is not a multiple of 2, the
                # megacore mode will be None.
                self.megacore_mode = "batch"

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        attn_metadata: PallasMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
            kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
                NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
                with shape [0] for profiling run.
            attn_metadata: Metadata for attention.
            output: Ignored on entry; the name is rebound to the result.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        batch_size, seq_len, hidden_size = query.shape
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        # Unpack unconditionally so `key_cache`/`value_cache` are always
        # bound: the paged-prefill branch below reads them even when the
        # cache write is skipped.
        key_cache, value_cache = kv_cache
        if key_cache.numel() > 0:
            slot_mapping = attn_metadata.slot_mapping
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

        query = query * self.scale
        if attn_metadata.num_prefills > 0:
            if attn_metadata.block_tables is None:
                # Prefill without paged KV cache.
                assert seq_len % 16 == 0, (
                    "Pallas FlashAttention kernel requires seq_len to be a "
                    f"multiple of 16 but got {seq_len}")

                # Handle GQA/MQA.
                if self.num_kv_heads != self.num_heads:
                    key = key.repeat_interleave(self.num_queries_per_kv,
                                                dim=-2)
                    key = key.view(batch_size, seq_len, self.num_heads,
                                   self.head_size)
                    value = value.repeat_interleave(self.num_queries_per_kv,
                                                    dim=-2)
                    value = value.view(batch_size, seq_len, self.num_heads,
                                       self.head_size)
                # FlashAttention kernel requires the input shape to be
                # [batch_size, num_heads, seq_len, d_model]
                # while the input is [batch_size, seq_len, num_heads, d_model].
                # Permute the input to match the required format.
                output = torch.ops.xla.flash_attention(
                    query.permute(0, 2, 1, 3),
                    key.permute(0, 2, 1, 3),
                    value.permute(0, 2, 1, 3),
                    True,
                )
                output = output.permute(0, 2, 1, 3)
            else:
                # Prefill with paged KV cache.
                # TODO(woosuk): Tune the below knobs.
                num_kv_pages_per_compute_block = 16
                num_queries_per_compute_block = 16
                assert seq_len % num_queries_per_compute_block == 0
                output = torch.ops.xla.multi_queries_paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    attn_metadata.effective_query_lens,
                    num_kv_pages_per_compute_block,
                    num_queries_per_compute_block,
                    use_kernel=True,
                    attn_logits_soft_cap=self.logits_soft_cap,
                )
        else:
            # Decoding run.
            assert kv_cache[0].numel() > 0
            query = query.squeeze(dim=1)
            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.

            assert attn_metadata.block_tables is not None
            assert attn_metadata.context_lens is not None
            # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
            # block table in SMEM. Therefore, if the block table is too large,
            # the kernel compilation will fail. To avoid this, we split the
            # batch dimension into smaller chunks and run the kernel multiple
            # times.
            MAX_SMEM_USAGE = 512 * 1024
            size_per_seq = 4 * attn_metadata.block_tables.shape[1]
            max_num_seq = MAX_SMEM_USAGE // size_per_seq

            if batch_size <= max_num_seq:
                output = paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                    self.megacore_mode,
                    attn_logits_soft_cap=self.logits_soft_cap,
                )
            else:
                chunk_size = max_num_seq
                # Make sure the chunk size is a multiple of 2.
                chunk_size = chunk_size // 2 * 2
                num_chunks = (batch_size + chunk_size - 1) // chunk_size

                output = torch.empty_like(query)
                for chunk_idx in range(num_chunks):
                    chunk_start = chunk_idx * chunk_size
                    chunk_end = chunk_start + chunk_size
                    # NOTE(woosuk): We skip this line because it causes Dynamo
                    # compilation error. Instead, we rely on the slice operation
                    # to handle the out-of-bound case.
                    # chunk_end = min(chunk_end, batch_size)
                    chunk_output = paged_attention(
                        query[chunk_start:chunk_end],
                        key_cache,
                        value_cache,
                        attn_metadata.context_lens[chunk_start:chunk_end],
                        attn_metadata.block_tables[chunk_start:chunk_end],
                        pages_per_compute_block,
                        self.megacore_mode,
                        attn_logits_soft_cap=self.logits_soft_cap,
                    )
                    output[chunk_start:chunk_end] = chunk_output

        # Reshape the output tensor.
        return output.reshape(batch_size, seq_len, hidden_size)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Copy new key/value rows into the KV caches at ``slot_mapping``.

    Both caches are first passed to
    ``torch.ops.xla.dynamo_set_buffer_donor_(cache, True)``. The first three
    dimensions of every tensor are then collapsed into one, and the
    flattened ``key`` / ``value`` rows are written into the flattened caches
    along dim 0 at the indices given by ``slot_mapping``.

    Args:
        key: New keys; its first three dims are flattened before the copy.
        value: New values; flattened the same way as ``key``.
        key_cache: Destination key cache, updated via ``index_copy_``.
        value_cache: Destination value cache, updated via ``index_copy_``.
        slot_mapping: Indices along the flattened dim 0 of the caches that
            receive the corresponding flattened ``key`` / ``value`` rows.
    """
    for cache in (key_cache, value_cache):
        torch.ops.xla.dynamo_set_buffer_donor_(cache, True)

    # Collapse dims 0..2 so that one index along dim 0 addresses one slot.
    # NOTE(review): the write below reaches the caller's cache only when
    # flatten() yields a view of it (or the tracing backend propagates the
    # mutation) -- confirm against the cache layout used by the caller.
    flat_key_cache = key_cache.flatten(0, 2)
    flat_value_cache = value_cache.flatten(0, 2)

    flat_key_cache.index_copy_(0, slot_mapping, key.flatten(0, 2))
    flat_value_cache.index_copy_(0, slot_mapping, value.flatten(0, 2))
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    pages_per_compute_block: int,
    megacore_mode: Optional[str],
    *,
    attn_logits_soft_cap: Optional[float],
) -> torch.Tensor:
    """Invoke ``torch.ops.xla.paged_attention`` with a sanitized megacore mode.

    The ``"batch"`` megacore mode is replaced by ``None`` whenever the
    leading dimension of ``query`` is odd; every other mode value is
    forwarded unchanged.

    Args:
        query: Query tensor; ``query.shape[0]`` is treated as the batch size.
        key_cache: Key cache forwarded to the XLA op.
        value_cache: Value cache forwarded to the XLA op.
        context_lens: Per-sequence context lengths forwarded to the XLA op.
        block_tables: Block tables forwarded to the XLA op.
        pages_per_compute_block: Forwarded to the XLA op.
        megacore_mode: Requested megacore mode, or ``None``.
        attn_logits_soft_cap: Optional soft cap forwarded to the XLA op.

    Returns:
        The tensor returned by ``torch.ops.xla.paged_attention``.
    """
    num_seqs = query.shape[0]
    # NOTE(review): presumably "batch" mode needs an even batch so it can be
    # split in two -- confirm against the kernel documentation.
    drop_batch_mode = megacore_mode == "batch" and num_seqs % 2 != 0
    effective_mode = None if drop_batch_mode else megacore_mode

    return torch.ops.xla.paged_attention(
        query,
        key_cache,
        value_cache,
        context_lens,
        block_tables,
        pages_per_compute_block,
        megacore_mode=effective_mode,
        attn_logits_soft_cap=attn_logits_soft_cap,
    )
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/placeholder_attn.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
| 10 |
+
AttentionMetadata,
|
| 11 |
+
AttentionMetadataBuilder)
|
| 12 |
+
from vllm.attention.backends.utils import CommonAttentionState
|
| 13 |
+
from vllm.multimodal import MultiModalPlaceholderMap
|
| 14 |
+
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
| 17 |
+
ModelInputForGPUWithSamplingMetadata)
|
| 18 |
+
|
| 19 |
+
# Placeholder attention backend for models like Mamba and pooling models that
|
| 20 |
+
# lack attention.
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class PlaceholderAttentionBackend(AttentionBackend):
    """Placeholder backend for when no attention is needed."""

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "NO_ATTENTION"

    @staticmethod
    def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
        """Return the attention implementation class of this backend."""
        return PlaceholderAttentionImpl

    @staticmethod
    def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
        """Return the metadata builder class of this backend."""
        return PlaceholderAttentionMetadataBuilder

    @staticmethod
    def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
        """Return the attention metadata class of this backend."""
        return PlaceholderAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class of this backend."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return a fixed five-dimensional shape of ones.

        None of the arguments influence the result.
        """
        return (1, ) * 5

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Do nothing; the arguments are ignored."""
        return None

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Do nothing; the arguments are ignored."""
        return None
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
class PlaceholderAttentionMetadata(AttentionMetadata):
    """Attention metadata for prefill and decode batched together.

    The prefill and decode views derived from this object
    (``prefill_metadata`` / ``decode_metadata``) are built with empty
    placeholder tensors for ``slot_mapping`` and ``block_tables``.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # Maximum query length in the batch.
    max_query_len: Optional[int]

    # Max number of query tokens among requests in the batch.
    max_decode_query_len: Optional[int]

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq if it is
    # cuda-graph captured.
    block_tables: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Lazily built views returned by the two properties below.
    _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
    _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        """Return the metadata restricted to the prefill part of the batch.

        Returns ``None`` when the batch contains no prefills. The view is
        built on first access and cached on the instance afterwards.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.seq_start_loc is not None

        # Placeholders
        slot_mapping = torch.empty(0)
        block_tables = torch.empty(0)

        # The per-sequence fields are sliced to the first `num_prefills`
        # entries of the batch.
        self._cached_prefill_metadata = PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_decode_query_len=0,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
            block_tables=block_tables,
            use_cuda_graph=False,
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        """Return the metadata restricted to the decode part of the batch.

        Returns ``None`` when the batch contains no decode tokens. The view
        is built on first access and cached on the instance afterwards.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.seq_lens_tensor is not None

        # Placeholders
        slot_mapping = torch.empty(0)
        block_tables = torch.empty(0)

        # Only the entries after the first `num_prefills` sequences are kept
        # in seq_lens_tensor.
        self._cached_decode_metadata = PlaceholderAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
        )
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """Update metadata in-place to advance one decode step.

        Args:
            model_input: Model input whose ``input_tokens`` tensor is
                overwritten in-place with the sampled ids, when provided.
            sampled_token_ids: Token ids sampled in the previous step, or
                ``None`` to leave ``model_input.input_tokens`` untouched.
            block_size: Not used by this implementation.
            num_seqs: Number of sequences, possibly padded (see below).
            num_queries: Actual number of requests in the batch.
            turn_prefills_into_decodes: Must be ``False``; this parameter is
                specific to Multi-Step + Chunked-Prefill.
        """
        # When using cudagraph, the num_seqs is padded to the next captured
        # batch size, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        assert not turn_prefills_into_decodes, \
            ("Multi-Step + Chunked-Prefill is not supported for attention-free "
             "models. turn_prefills_into_decodes is a "
             "Multi-Step + Chunked-Prefill specific parameter.")

        assert self.seq_lens is not None
        assert self.max_decode_seq_len == max(self.seq_lens)

        # This method only handles pure-decode batches.
        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs

        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        # Update sequences, masking off entries at index >= num_queries
        device = self.seq_lens_tensor.device
        mask = torch.arange(self.seq_lens_tensor.size(0),
                            device=device) < num_queries
        self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype)
        if sampled_token_ids is not None:
            model_input.input_tokens.masked_scatter_(
                mask, sampled_token_ids[:num_queries])
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class PlaceholderAttentionMetadataBuilder(
        AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
    """Accumulate per-sequence lengths and build placeholder metadata."""

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        """Remember the input builder and the runner it belongs to."""
        self.input_builder = input_builder
        self.runner = input_builder.runner

    def prepare(self):
        """Reset all accumulators that ``build`` fills in."""
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Fold one sequence group into the accumulators.

        The context length of every sequence is recorded. Prompt sequences
        additionally update the prefill counters and the multimodal
        placeholder maps; all other sequences must have a query length of 1
        and update the decode counters.
        """
        is_prompt = inter_data.is_prompt
        token_lens = [len(t) for t in inter_data.input_tokens]

        per_sequence = zip(inter_data.seq_ids, token_lens,
                           inter_data.orig_seq_lens, inter_data.seq_lens,
                           inter_data.query_lens, inter_data.context_lens,
                           inter_data.curr_sliding_window_blocks)
        for (_seq_id, token_len, seq_len, curr_seq_len, query_len,
             context_len, _sliding_window_block) in per_sequence:
            self.context_lens.append(context_len)

            if not is_prompt:
                # Decode path: exactly one new query token per sequence.
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)
                continue

            # Prefill path.
            mm_maps = inter_data.multi_modal_placeholder_maps
            if mm_maps:
                for modality, placeholders in mm_maps.items():
                    self.multimodal_placeholder_maps[modality].extend(
                        placeholders)

            self.num_prefills += 1
            self.num_prefill_tokens += token_len
            self.prefill_seq_lens.append(seq_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.

        Raises:
            ValueError: If the model's HF config defines a non-None
                ``attn_logit_softcapping``.
        """
        builder = self.input_builder
        for inter_data in builder.inter_data_list:
            self._add_seq_group(inter_data, builder.chunked_prefill_enabled)

        target_device = self.runner.device
        graph_captured = cuda_graph_pad_size != -1

        soft_cap = getattr(self.runner.model_config.hf_config,
                           "attn_logit_softcapping", None)
        if soft_cap is not None:
            raise ValueError(
                "Please use Flashinfer backend for models with logits_soft_cap"
                " (i.e., Gemma-2). Otherwise, the output might be wrong."
                " Set Flashinfer backend by "
                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")

        max_query_len = max(query_lens)
        # Entries after the prefills belong to decode requests; fall back to
        # 1 when there are none.
        max_decode_query_len = max(query_lens[self.num_prefills:], default=1)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)

        # With a captured graph the decode token count is the (maybe padded)
        # batch size rather than the accumulated count.
        if graph_captured:
            num_decode_tokens = batch_size
        else:
            num_decode_tokens = self.num_decode_tokens

        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        context_lens_tensor = torch.tensor(self.context_lens,
                                           dtype=torch.int,
                                           device=target_device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=target_device)
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=target_device)

        # Start-location buffers hold one leading zero followed by the
        # running totals written by cumsum below.
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=target_device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=target_device)

        index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        # Placeholders
        empty_slot_mapping = torch.empty(0)
        empty_block_tables = torch.empty(0)

        return PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            slot_mapping=empty_slot_mapping,
            multi_modal_placeholder_index_maps=index_maps,
            enable_kv_scales_calculation=True,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=empty_block_tables,
            use_cuda_graph=graph_captured,
        )
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
class PlaceholderAttentionImpl(AttentionImpl):
    """Attention implementation stub of the placeholder backend."""

    def __init__(self, *args, **kwargs) -> None:
        """Accept and ignore any constructor arguments."""
        return None

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Always raise ``NotImplementedError``; arguments are ignored."""
        raise NotImplementedError
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/rocm_flash_attn.py
ADDED
|
@@ -0,0 +1,891 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Attention layer ROCm GPUs."""
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
import vllm.envs as envs
|
| 9 |
+
from vllm import _custom_ops as ops
|
| 10 |
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
| 11 |
+
AttentionLayer,
|
| 12 |
+
AttentionMetadata, AttentionType)
|
| 13 |
+
from vllm.attention.backends.utils import (CommonAttentionState,
|
| 14 |
+
CommonMetadataBuilder)
|
| 15 |
+
from vllm.attention.ops.paged_attn import (PagedAttention,
|
| 16 |
+
PagedAttentionMetadata)
|
| 17 |
+
from vllm.logger import init_logger
|
| 18 |
+
from vllm.platforms import current_platform
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
| 22 |
+
|
| 23 |
+
# Module-level logger created through vLLM's `init_logger` factory.
logger = init_logger(__name__)

# NOTE(review): presumably the partition size used by the ROCm
# paged-attention path -- its consumer is outside this view; confirm.
_PARTITION_SIZE_ROCM = 512
# Architecture name reported for the "cuda" device. This is queried once, at
# import time, so importing this module requires that device query to succeed.
_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
# True when the architecture name contains the substring "gfx1".
_ON_NAVI = "gfx1" in _GPU_ARCH
# True when the architecture name contains any of the listed identifiers.
_ON_MI250_MI300 = any(arch in _GPU_ARCH
                      for arch in ["gfx90a", "gfx940", "gfx941", "gfx942"])
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ROCmFlashAttentionBackend(AttentionBackend):
    """Backend descriptor wiring the ROCm flash-attention classes together.

    KV-cache shape computation and block swap/copy are delegated to
    ``PagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "ROCM_FLASH"

    @staticmethod
    def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
        """Return the attention implementation class of this backend."""
        return ROCmFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the attention metadata class of this backend."""
        return ROCmFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
        """Return the metadata builder class of this backend."""
        return ROCmFlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class of this backend."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape computed by ``PagedAttention``."""
        cache_shape = PagedAttention.get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Forward the block swap request to ``PagedAttention.swap_blocks``."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Forward the block copy request to ``PagedAttention.copy_blocks``."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@dataclass
|
| 81 |
+
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
| 82 |
+
"""Metadata for FlashAttentionBackend.
|
| 83 |
+
|
| 84 |
+
NOTE: Any python object stored here is not updated when it is
|
| 85 |
+
cuda-graph replayed. If you have values that need to be changed
|
| 86 |
+
dynamically, it should be stored in tensor. The tensor has to be
|
| 87 |
+
updated from `CUDAGraphRunner.forward` API.
|
| 88 |
+
"""
|
| 89 |
+
# (batch_size,). The sequence length per sequence. Sequence length means
|
| 90 |
+
# the computed tokens + new tokens None if it is a decoding.
|
| 91 |
+
seq_lens: Optional[List[int]]
|
| 92 |
+
# seq_lens stored as a tensor.
|
| 93 |
+
seq_lens_tensor: Optional[torch.Tensor]
|
| 94 |
+
# Maximum sequence length among prefill batch. 0 if there are decoding
|
| 95 |
+
# requests only.
|
| 96 |
+
max_prefill_seq_len: int
|
| 97 |
+
# Maximum sequence length among decode batch. 0 if there are prefill
|
| 98 |
+
# requests only.
|
| 99 |
+
max_decode_seq_len: int
|
| 100 |
+
|
| 101 |
+
# Whether or not if cuda graph is enabled.
|
| 102 |
+
# Cuda-graph is currently enabled for decoding only.
|
| 103 |
+
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
| 104 |
+
use_cuda_graph: bool
|
| 105 |
+
|
| 106 |
+
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
| 107 |
+
# |---------- N-1 iteration --------|
|
| 108 |
+
# |---------------- N iteration ---------------------|
|
| 109 |
+
# |- tokenA -|......................|-- newTokens ---|
|
| 110 |
+
# |---------- context_len ----------|
|
| 111 |
+
# |-------------------- seq_len ----------------------|
|
| 112 |
+
# |-- query_len ---|
|
| 113 |
+
|
| 114 |
+
# Maximum query length in the batch. None for decoding.
|
| 115 |
+
max_query_len: Optional[int] = None
|
| 116 |
+
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
| 117 |
+
# the batch, used to index into subquery. E.g., if the subquery length
|
| 118 |
+
# is [4, 6], it is [0, 4, 10].
|
| 119 |
+
query_start_loc: Optional[torch.Tensor] = None
|
| 120 |
+
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
| 121 |
+
# the batch, used to index into sequence. E.g., if the sequence length is
|
| 122 |
+
# [4, 6], it is [0, 4, 10].
|
| 123 |
+
seq_start_loc: Optional[torch.Tensor] = None
|
| 124 |
+
# (batch_size,) A tensor of context lengths (tokens that are computed
|
| 125 |
+
# so far).
|
| 126 |
+
context_lens_tensor: Optional[torch.Tensor] = None
|
| 127 |
+
|
| 128 |
+
# Max number of query tokens among request in the batch.
|
| 129 |
+
max_decode_query_len: Optional[int] = None
|
| 130 |
+
|
| 131 |
+
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
|
| 132 |
+
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
|
| 133 |
+
|
| 134 |
+
# Begin encoder attn & enc/dec cross-attn fields...
|
| 135 |
+
|
| 136 |
+
# Encoder sequence lengths representation
|
| 137 |
+
encoder_seq_lens: Optional[List[int]] = None
|
| 138 |
+
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
| 139 |
+
|
| 140 |
+
# Maximum sequence length among encoder sequences
|
| 141 |
+
max_encoder_seq_len: Optional[int] = None
|
| 142 |
+
|
| 143 |
+
# Number of tokens input to encoder
|
| 144 |
+
num_encoder_tokens: Optional[int] = None
|
| 145 |
+
|
| 146 |
+
# Cross-attention memory-mapping data structures: slot mapping
|
| 147 |
+
# and block tables
|
| 148 |
+
cross_slot_mapping: Optional[torch.Tensor] = None
|
| 149 |
+
cross_block_tables: Optional[torch.Tensor] = None
|
| 150 |
+
|
| 151 |
+
@property
def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
    """Return a metadata object restricted to the prefill requests.

    Returns None when the batch has no prefills. The result is built once
    and memoized in ``_cached_prefill_metadata``; later accesses return the
    cached object.
    """
    prefill_count = self.num_prefills
    if prefill_count == 0:
        return None

    cached = self._cached_prefill_metadata
    if cached is not None:
        return cached

    assert self.seq_lens is not None
    assert self.seq_lens_tensor is not None
    assert self.block_tables is not None

    def _leading(tensor, count):
        # Optional per-batch tensors stay None; otherwise keep the first
        # `count` entries.
        return None if tensor is None else tensor[:count]

    prefill_view = ROCmFlashAttentionMetadata(
        num_prefills=prefill_count,
        num_prefill_tokens=self.num_prefill_tokens,
        num_decode_tokens=0,
        slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
        multi_modal_placeholder_index_maps=(
            self.multi_modal_placeholder_index_maps),
        enable_kv_scales_calculation=self.enable_kv_scales_calculation,
        seq_lens=self.seq_lens[:prefill_count],
        seq_lens_tensor=self.seq_lens_tensor[:prefill_count],
        max_query_len=self.max_query_len,
        max_prefill_seq_len=self.max_prefill_seq_len,
        max_decode_seq_len=0,
        # The *_start_loc tensors carry one more entry than there are
        # sequences, hence the "+ 1".
        query_start_loc=_leading(self.query_start_loc, prefill_count + 1),
        seq_start_loc=_leading(self.seq_start_loc, prefill_count + 1),
        context_lens_tensor=_leading(self.context_lens_tensor,
                                     prefill_count),
        block_tables=self.block_tables[:prefill_count],
        use_cuda_graph=False,
        # Encoder & cross-attention fields are forwarded unchanged.
        encoder_seq_lens=self.encoder_seq_lens,
        encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
        max_encoder_seq_len=self.max_encoder_seq_len,
        cross_slot_mapping=self.cross_slot_mapping,
        cross_block_tables=self.cross_block_tables)
    self._cached_prefill_metadata = prefill_view
    return prefill_view
|
| 191 |
+
|
| 192 |
+
@property
def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
    """Return a metadata object restricted to the decode requests.

    Returns None when the batch has no decode tokens. The result is built
    once and memoized in ``_cached_decode_metadata``.
    """
    if self.num_decode_tokens == 0:
        return None

    cached = self._cached_decode_metadata
    if cached is not None:
        return cached

    assert self.block_tables is not None
    assert self.seq_lens_tensor is not None

    first_decode_seq = self.num_prefills
    first_decode_token = self.num_prefill_tokens

    decode_view = ROCmFlashAttentionMetadata(
        num_prefills=0,
        num_prefill_tokens=0,
        num_decode_tokens=self.num_decode_tokens,
        slot_mapping=self.slot_mapping[first_decode_token:],
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=True,
        seq_lens=None,
        seq_lens_tensor=self.seq_lens_tensor[first_decode_seq:],
        max_query_len=None,
        max_prefill_seq_len=0,
        max_decode_seq_len=self.max_decode_seq_len,
        query_start_loc=None,
        seq_start_loc=None,
        context_lens_tensor=None,
        block_tables=self.block_tables[first_decode_seq:],
        use_cuda_graph=self.use_cuda_graph,
        # Encoder & cross-attention fields are forwarded unchanged.
        encoder_seq_lens=self.encoder_seq_lens,
        encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
        max_encoder_seq_len=self.max_encoder_seq_len,
        cross_slot_mapping=self.cross_slot_mapping,
        cross_block_tables=self.cross_block_tables)
    self._cached_decode_metadata = decode_view

    # Batch may be composed of prefill|decodes, adjust query start indices
    # to refer to the start of decodes when the two are split apart.
    # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
    # NOTE(review): query_start_loc is passed as None just above, so this
    # only takes effect if the constructor fills it in - confirm.
    starts = self._cached_decode_metadata.query_start_loc
    if starts is not None:
        self._cached_decode_metadata.query_start_loc = starts - starts[0]
    return self._cached_decode_metadata
|
| 232 |
+
|
| 233 |
+
def advance_step(self,
                 model_input: "ModelInputForGPUWithSamplingMetadata",
                 sampled_token_ids: Optional[torch.Tensor],
                 block_size: int,
                 num_seqs: int,
                 num_queries: int,
                 turn_prefills_into_decodes: bool = False):
    """
    Update metadata in-place to advance one decode step.

    Args:
        model_input: Model input whose ``input_tokens`` and
            ``input_positions`` are handed to the advance-step kernel.
        sampled_token_ids: Token ids sampled in the previous step
            (forwarded to the kernel as-is).
        block_size: KV-cache block size forwarded to the kernel.
        num_seqs: Number of sequences in the (possibly padded) batch.
        num_queries: Actual number of requests in the batch.
        turn_prefills_into_decodes: Must be False; this backend does not
            support chunked prefill.

    Raises:
        AssertionError: if ``turn_prefills_into_decodes`` is set or the
            metadata is not in the pure-decode state checked below.
    """

    # The message pieces are joined by implicit concatenation, so each
    # piece must carry its own separating space.
    assert not turn_prefills_into_decodes, \
        ("Chunked prefill is not supported with rocm_flash_attn yet. "
         "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
         "specific parameter.")

    # When using cudagraph, the num_seqs is padded to the next captured
    # batch sized, but num_queries tracks the actual number of requests in
    # the batch. For --enforce-eager mode, num_seqs == num_queries
    if num_seqs != num_queries:
        assert num_seqs > num_queries
        assert self.use_cuda_graph

    # Only a decode-only batch can be advanced.
    assert self.num_prefills == 0
    assert self.num_prefill_tokens == 0
    assert self.num_decode_tokens == num_seqs
    assert self.slot_mapping.shape == (num_seqs, )

    assert self.seq_lens is not None
    assert len(self.seq_lens) == num_seqs
    assert self.seq_lens_tensor is not None
    assert self.seq_lens_tensor.shape == (num_seqs, )
    assert self.max_query_len == 1
    assert self.max_prefill_seq_len == 0
    assert self.max_decode_seq_len == max(self.seq_lens)

    assert self.query_start_loc is not None
    assert self.query_start_loc.shape == (num_queries + 1, )
    assert self.seq_start_loc is not None
    assert self.seq_start_loc.shape == (num_seqs + 1, )

    assert self.context_lens_tensor is not None
    assert self.context_lens_tensor.shape == (num_queries, )

    assert self.block_tables is not None
    assert self.block_tables.shape[0] == num_seqs

    # Update query lengths. Note that we update only queries and not seqs,
    # since tensors may be padded due to captured cuda graph batch size
    for i in range(num_queries):
        self.seq_lens[i] += 1
    self.max_decode_seq_len = max(self.seq_lens)

    # The tensor-side state is handed to the advance-step op.
    ops.advance_step_flashattn(num_seqs=num_seqs,
                               num_queries=num_queries,
                               block_size=block_size,
                               input_tokens=model_input.input_tokens,
                               sampled_token_ids=sampled_token_ids,
                               input_positions=model_input.input_positions,
                               seq_lens=self.seq_lens_tensor,
                               slot_mapping=self.slot_mapping,
                               block_tables=self.block_tables)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class ROCmFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
    # Specialization of CommonMetadataBuilder for this backend; the only
    # thing it defines is the metadata class below.

    _metadata_cls = ROCmFlashAttentionMetadata
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def _make_alibi_bias(alibi_slopes: torch.Tensor,
|
| 304 |
+
dtype: torch.dtype,
|
| 305 |
+
seq_lens: Optional[List[int]],
|
| 306 |
+
make_attn_mask: bool = True) -> List[torch.Tensor]:
|
| 307 |
+
attn_biases = []
|
| 308 |
+
if seq_lens:
|
| 309 |
+
for seq_len in seq_lens:
|
| 310 |
+
bias = torch.arange(seq_len, dtype=dtype)
|
| 311 |
+
# NOTE(zhuohan): HF uses
|
| 312 |
+
# `bias = bias[None, :].repeat(seq_len, 1)`
|
| 313 |
+
# here. We find that both biases give the same results, but
|
| 314 |
+
# the bias below more accurately follows the original ALiBi
|
| 315 |
+
# paper.
|
| 316 |
+
bias = bias[None, :] - bias[:, None]
|
| 317 |
+
|
| 318 |
+
num_heads = alibi_slopes.shape[0]
|
| 319 |
+
bias = bias[None, :].repeat(
|
| 320 |
+
(num_heads, 1, 1)).to(alibi_slopes.device)
|
| 321 |
+
bias.mul_(alibi_slopes[:, None, None])
|
| 322 |
+
if make_attn_mask:
|
| 323 |
+
inf_mask = torch.empty(
|
| 324 |
+
(1, seq_len, seq_len),
|
| 325 |
+
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
|
| 326 |
+
alibi_slopes.device)
|
| 327 |
+
attn_biases.append((bias + inf_mask).to(dtype))
|
| 328 |
+
else:
|
| 329 |
+
attn_biases.append(bias.to(dtype))
|
| 330 |
+
|
| 331 |
+
return attn_biases
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def _get_seq_len_block_table_args(
    attn_metadata: ROCmFlashAttentionMetadata,
    attn_type: str,
) -> tuple:
    '''
    Pick the sequence-length attributes of attn_metadata that match the
    kind of attention being computed.

    Decoder attn -> decoder self-attention fields only
    Encoder/decoder cross-attn -> decoder lengths for the query side,
                                  encoder lengths for the key side
    Encoder attn -> encoder sequence-length fields only

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns a 6-tuple:

    * cumulative start offsets for the query sequences (tensor)
    * max query sequence length
    * cumulative start offsets for the key sequences (tensor)
    * max key sequence length
    * list of sequence lengths
    * whether a causal mask applies

    Raises AttributeError for an unrecognized attn_type.
    '''

    def _start_offsets(lengths, like):
        # Running sum with a leading zero: [4, 6] -> [0, 4, 10]. The
        # result takes the device and dtype of `like`.
        offsets = [0]
        for length in lengths:
            offsets.append(offsets[-1] + length)
        return torch.tensor(offsets, device=like.device, dtype=like.dtype)

    if attn_type == AttentionType.ENCODER:
        assert attn_metadata.encoder_seq_lens is not None
        assert attn_metadata.encoder_seq_lens_tensor is not None
        enc_start_loc = _start_offsets(attn_metadata.encoder_seq_lens,
                                       attn_metadata.encoder_seq_lens_tensor)
        enc_max_len = attn_metadata.max_encoder_seq_len

        # No block tables associated with encoder attention; query and key
        # share the same offsets and the mask is non-causal.
        return (enc_start_loc, enc_max_len, enc_start_loc, enc_max_len,
                attn_metadata.encoder_seq_lens, False)

    if attn_type == AttentionType.DECODER:
        # Decoder self-attention
        assert attn_metadata.seq_lens is not None
        assert attn_metadata.seq_lens_tensor is not None
        dec_start_loc = _start_offsets(attn_metadata.seq_lens,
                                       attn_metadata.seq_lens_tensor)
        dec_max_len = attn_metadata.max_prefill_seq_len

        return (dec_start_loc, dec_max_len, dec_start_loc, dec_max_len,
                attn_metadata.seq_lens, True)

    if attn_type == AttentionType.ENCODER_DECODER:
        assert attn_metadata.seq_lens is not None
        assert attn_metadata.encoder_seq_lens_tensor is not None
        q_start_loc = _start_offsets(attn_metadata.seq_lens,
                                     attn_metadata.encoder_seq_lens_tensor)

        assert attn_metadata.encoder_seq_lens is not None
        assert attn_metadata.seq_lens_tensor is not None
        k_start_loc = _start_offsets(attn_metadata.encoder_seq_lens,
                                     attn_metadata.seq_lens_tensor)

        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        return (q_start_loc, attn_metadata.max_prefill_seq_len,
                k_start_loc, attn_metadata.max_encoder_seq_len,
                attn_metadata.seq_lens, False)

    raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
class ROCmFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
    |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Configure the layer and pick the prefill attention kernel.

        The kernel is chosen as follows: Triton flash attention when
        ``envs.VLLM_USE_TRITON_FLASH_ATTN`` is set; otherwise CK flash
        attention (``flash_attn_varlen_func``) when the device capability
        check passes and ``flash_attn`` is importable; otherwise the naive
        SDPA implementation.

        Raises:
            ValueError: if ``blocksparse_params`` is given, if ``head_size``
                is not supported by PagedAttention, or if
                ``logits_soft_cap`` is requested with the Triton or naive
                kernels.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "ROCmFlashAttention does not support blocksparse attention.")

        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            self.logits_soft_cap = 0.0
        else:
            self.logits_soft_cap = logits_soft_cap
        self.attn_type = attn_type
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # (-1, -1) is the "no sliding window" sentinel checked further down.
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.use_naive_attn = False
        # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
        self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
        if self.use_triton_flash_attn:
            if logits_soft_cap is not None:
                raise ValueError(
                    "ROCm Triton FlashAttention does not support attention "
                    "logits soft capping."
                    " please try using the ROCm CK "
                    "FA backend instead by setting the env var "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")

            from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                triton_attention)
            self.attn_func = triton_attention
            logger.debug("Using Triton FA in ROCmBackend")
            if self.sliding_window != (-1, -1):
                logger.warning("ROCm Triton FA does not currently support "
                               "sliding window attention. If using half "
                               "precision, please try using the ROCm CK "
                               "FA backend instead by setting the env var "
                               "`VLLM_USE_TRITON_FLASH_ATTN=0`")
        else:
            # if not using triton, navi3x/navi21/navi10 do not use flash-attn
            # either
            if not current_platform.has_device_capability(90):
                self.use_naive_attn = True
            else:
                try:
                    from flash_attn import flash_attn_varlen_func  # noqa: F401
                    self.attn_func = flash_attn_varlen_func
                    logger.debug("Using CK FA in ROCmBackend")
                except ModuleNotFoundError:
                    self.use_naive_attn = True

            if self.use_naive_attn:
                if logits_soft_cap is not None:
                    raise ValueError(
                        "ROCm Naive FlashAttention does not support "
                        "attention logits soft capping.")

                self.attn_func = _sdpa_attention
                logger.debug("Using naive (SDPA) attention in ROCmBackend")

    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
        tokens, n_kv_heads, head_dim = x.shape
        return (x[:, :,
                  None, :].expand(tokens, n_kv_heads, n_rep,
                                  head_dim).reshape(tokens, n_kv_heads * n_rep,
                                                    head_dim))

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: ROCmFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        For decoder-only models: query, key and value must be non-None.

        For encoder/decoder models:
        * ROCmFlashAttentionImpl.forward() may be invoked for both self- and
            cross-attention layers.
        * For self-attention: query, key and value must be non-None.
        * For cross-attention:
            * Query must be non-None
            * During prefill, key and value must be non-None; key and value
              get cached for use during decode.
            * During decode, key and value may be None, since:
              (1) key and value tensors were cached during prefill, and
              (2) cross-attention key and value tensors do not grow during
                  decode

        A note on how the attention type (self.attn_type, set in the
        constructor) impacts attention forward() behavior:

            * DECODER: normal decoder-only behavior;
                use decoder self-attention block table
            * ENCODER: no KV caching; pass encoder sequence
                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len) to kernel, in lieu of decoder
                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
            * ENCODER_DECODER: cross-attention behavior;
                use cross-attention block table for caching KVs derived
                from encoder hidden states; since KV sequence lengths
                will match encoder sequence lengths, pass encoder sequence
                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len)

        Args:
            layer: Attention layer providing the _k_scale/_v_scale values
                passed to the cache and paged-attention ops.
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            output: Not read; a fresh output buffer is allocated inside.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if key is not None and value is not None:
                # Reshape the input keys and values and store them in the
                # cache. If kv_cache is not provided, the new key and value
                # tensors are not cached. This happens during the initial
                # memory profiling run.
                PagedAttention.write_to_paged_cache(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    attn_metadata.cross_slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )

        if self.attn_type != AttentionType.ENCODER:
            num_prefill_tokens = attn_metadata.num_prefill_tokens
        else:
            assert attn_metadata.num_encoder_tokens is not None
            num_prefill_tokens = attn_metadata.num_encoder_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]

        if key is not None and value is not None \
            and self.attn_type != AttentionType.ENCODER_DECODER:
            key = key[:num_prefill_tokens]
            value = value[:num_prefill_tokens]

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            # normal attention and DECODER
            if self.attn_type == AttentionType.DECODER and (
                    kv_cache.numel() == 0 or prefill_meta.block_tables is None
                    or prefill_meta.block_tables.numel() == 0):
                (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
                 key_max_seq_len, seq_lens,
                 causal_mask) = (prefill_meta.seq_start_loc,
                                 prefill_meta.max_prefill_seq_len,
                                 prefill_meta.seq_start_loc,
                                 prefill_meta.max_prefill_seq_len,
                                 attn_metadata.seq_lens, True)
            # prefix-enabled attention and ENCODER/ENCODER_DECODER
            else:
                (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
                 key_max_seq_len, seq_lens,
                 causal_mask) = _get_seq_len_block_table_args(
                     prefill_meta, self.attn_type)
            # Prompt run.
            if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
                # triton attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                attn_masks = None
                if self.use_triton_flash_attn:
                    if self.alibi_slopes is not None:
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
                            seq_lens,
                            make_attn_mask=False)  # type: ignore
                    out, _ = self.attn_func(
                        query,
                        key,
                        value,
                        None,
                        query_seq_start_loc,
                        key_seq_start_loc,
                        query_max_seq_len,
                        key_max_seq_len,
                        causal_mask,
                        self.scale,
                        attn_masks[0][None]
                        if attn_masks is not None else None,
                    )
                elif self.use_naive_attn:
                    if self.num_kv_heads != self.num_heads:
                        # Interleave for MQA workaround.
                        key = self.repeat_kv(key, self.num_queries_per_kv)
                        value = self.repeat_kv(value, self.num_queries_per_kv)
                    if self.alibi_slopes is not None:
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
                            attn_metadata.seq_lens,
                            make_attn_mask=True)  # type: ignore
                    # Move the token dimension next to head_size, i.e.
                    # [tokens, heads, head_size] -> [heads, tokens, head_size].
                    query = query.movedim(0, query.dim() - 2)
                    key = key.movedim(0, key.dim() - 2)
                    value = value.movedim(0, value.dim() - 2)
                    # sdpa math backend attention.
                    # _sdpa_attention takes per-sequence lengths (it keeps its
                    # own running offset) and has no causal-flag parameter, so
                    # pass seq_lens rather than the cumulative start offsets
                    # and do not pass causal_mask; the extra positional
                    # argument previously made this call raise TypeError.
                    out = self.attn_func(
                        query,
                        key,
                        value,
                        seq_lens,
                        num_prefill_tokens,
                        self.num_heads,
                        self.head_size,
                        self.scale,
                        attn_masks,
                    )
                else:
                    # NOTE(review): causal is hard-coded to True and
                    # max_seqlen_q ignores query_max_seq_len here, so
                    # causal_mask computed above is not honoured on the CK
                    # path - confirm this is intended for ENCODER /
                    # ENCODER_DECODER attention.
                    out = self.attn_func(
                        q=query,
                        k=key,
                        v=value,
                        cu_seqlens_q=query_seq_start_loc,
                        cu_seqlens_k=key_seq_start_loc,
                        max_seqlen_q=prefill_meta.max_prefill_seq_len,
                        max_seqlen_k=key_max_seq_len,
                        softmax_scale=self.scale,
                        causal=True,
                        window_size=self.sliding_window,
                        alibi_slopes=self.alibi_slopes,
                        softcap=self.logits_soft_cap,
                    )

                # common code for prefill
                assert output[:num_prefill_tokens].shape == out.shape
                if output.shape[0] > num_prefill_tokens:
                    output[:num_prefill_tokens] = out
                else:
                    output = out
            else:
                # prefix-enabled attention
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    self.kv_cache_dtype,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window[0],
                    layer._k_scale,
                    layer._v_scale,
                )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            # Whether to use rocm custom paged attention or not
            num_seqs, num_heads, head_size = decode_query.shape
            block_size = value_cache.shape[3]
            gqa_ratio = num_heads // self.num_kv_heads
            use_custom = _use_rocm_custom_paged_attention(
                decode_query.dtype, head_size, block_size, gqa_ratio,
                decode_meta.max_decode_seq_len)
            if use_custom:
                max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
                               != AttentionType.ENCODER_DECODER else
                               decode_meta.max_encoder_seq_len)
                assert max_seq_len is not None
                # Ceiling division: partitions needed to cover max_seq_len.
                max_num_partitions = (
                    (max_seq_len + _PARTITION_SIZE_ROCM - 1) //
                    _PARTITION_SIZE_ROCM)
                assert _PARTITION_SIZE_ROCM % block_size == 0
                tmp_output = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions, head_size),
                    dtype=output.dtype,
                    device=output.device,
                )
                exp_sums = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions),
                    dtype=torch.float32,
                    device=output.device,
                )
                max_logits = torch.empty_like(exp_sums)
                # The op writes into `out`, a view of the decode rows of
                # `output`.
                if num_prefill_tokens > 0:
                    out = output[num_prefill_tokens:]
                else:
                    out = output
                ops.paged_attention_rocm(
                    out,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    decode_query,
                    key_cache,
                    value_cache,
                    self.num_kv_heads,
                    self.scale,
                    decode_meta.block_tables
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.cross_block_tables,
                    decode_meta.seq_lens_tensor
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.encoder_seq_lens_tensor,
                    block_size,
                    max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )
            else:
                output[num_prefill_tokens:] = PagedAttention.forward_decode(
                    decode_query,
                    key_cache,
                    value_cache,
                    decode_meta.block_tables
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.cross_block_tables,
                    decode_meta.seq_lens_tensor
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.encoder_seq_lens_tensor,
                    decode_meta.max_decode_seq_len
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.max_encoder_seq_len,
                    self.kv_cache_dtype,
                    self.num_kv_heads,
                    self.scale,
                    self.alibi_slopes,
                    layer._k_scale,
                    layer._v_scale,
                )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
def _sdpa_attention(
|
| 849 |
+
query: torch.Tensor,
|
| 850 |
+
key: torch.Tensor,
|
| 851 |
+
value: torch.Tensor,
|
| 852 |
+
seq_lens: List[int],
|
| 853 |
+
num_tokens: int,
|
| 854 |
+
num_heads: int,
|
| 855 |
+
head_size: int,
|
| 856 |
+
scale: float,
|
| 857 |
+
attn_masks: Optional[List[torch.Tensor]] = None,
|
| 858 |
+
) -> torch.Tensor:
|
| 859 |
+
start = 0
|
| 860 |
+
output = torch.empty((num_tokens, num_heads, head_size),
|
| 861 |
+
dtype=query.dtype,
|
| 862 |
+
device=query.device)
|
| 863 |
+
|
| 864 |
+
for i, seq_len in enumerate(seq_lens):
|
| 865 |
+
end = start + seq_len
|
| 866 |
+
with torch.backends.cuda.sdp_kernel(enable_math=True,
|
| 867 |
+
enable_flash=False,
|
| 868 |
+
enable_mem_efficient=False):
|
| 869 |
+
sub_out = torch.nn.functional.scaled_dot_product_attention(
|
| 870 |
+
query[:, start:end, :],
|
| 871 |
+
key[:, start:end, :],
|
| 872 |
+
value[:, start:end, :],
|
| 873 |
+
dropout_p=0.0,
|
| 874 |
+
is_causal=attn_masks is None,
|
| 875 |
+
attn_mask=attn_masks[i] if attn_masks else None,
|
| 876 |
+
scale=scale).movedim(query.dim() - 2, 0)
|
| 877 |
+
output[start:end, :, :] = sub_out
|
| 878 |
+
start = end
|
| 879 |
+
|
| 880 |
+
return output
|
| 881 |
+
|
| 882 |
+
|
| 883 |
+
def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                     block_size: int, gqa_ratio: int,
                                     max_seq_len: int) -> bool:
    """Decide whether the ROCm custom paged-attention kernel can be used.

    All of the following must hold: the platform flags allow it, the query
    dtype is half or bfloat16, head_size is 64 or 128, block_size is 16 or
    32, gqa_ratio lies in [1, 16] and max_seq_len is at most 32768.
    """
    # rocm custom page attention not support on navi (gfx1*)
    if not _ON_MI250_MI300 or _ON_NAVI:
        return False

    dtype_ok = qtype in (torch.half, torch.bfloat16)
    head_ok = head_size in (64, 128)
    block_ok = block_size in (16, 32)
    gqa_ok = 1 <= gqa_ratio <= 16
    length_ok = max_seq_len <= 32768
    return dtype_ok and head_ok and block_ok and gqa_ok and length_ok
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/torch_sdpa.py
ADDED
|
@@ -0,0 +1,681 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
""" Attention layer with torch scaled_dot_product_attention
|
| 3 |
+
and PagedAttention."""
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any, Dict, List, Optional, Tuple, Type
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch.nn.functional import scaled_dot_product_attention
|
| 9 |
+
|
| 10 |
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
| 11 |
+
AttentionLayer,
|
| 12 |
+
AttentionMetadata,
|
| 13 |
+
AttentionMetadataBuilder,
|
| 14 |
+
AttentionType)
|
| 15 |
+
from vllm.attention.backends.utils import CommonAttentionState
|
| 16 |
+
from vllm.attention.ops.ipex_attn import PagedAttention
|
| 17 |
+
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
|
| 18 |
+
from vllm.logger import init_logger
|
| 19 |
+
from vllm.utils import make_tensor_with_pad
|
| 20 |
+
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
|
| 21 |
+
|
| 22 |
+
logger = init_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class TorchSDPABackend(AttentionBackend):
    """Backend descriptor for torch SDPA combined with PagedAttention.

    Exposes the implementation, metadata, state and builder classes of the
    backend and forwards KV-cache layout and block maintenance calls to
    ``PagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "TORCH_SDPA"

    @staticmethod
    def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
        """Return the attention implementation class."""
        return TorchSDPABackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return TorchSDPAMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
        """Return the metadata builder class."""
        return TorchSDPAMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape as defined by PagedAttention."""
        shape = PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                  num_kv_heads, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Delegate swapping of cache blocks to PagedAttention."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Delegate copying of cache blocks to PagedAttention."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Attention metadata consumed by the torch SDPA backend."""

    # Whether the chunked-prefill code path is used for prefill tokens.
    chunked_prefill: bool
    # Per-sequence lengths used by the non-chunked prefill path.
    seq_lens: Optional[List[int]] = None

    # Fields below are only populated for chunked prefill.
    max_query_len: Optional[int] = None
    max_kv_len: Optional[int] = None
    query_start_loc: Optional[torch.Tensor] = None
    kv_start_loc: Optional[torch.Tensor] = None
    prefill_block_tables: Optional[torch.Tensor] = None

    # Encoder attention and encoder/decoder cross-attention fields.
    # Encoder sequence lengths, as a list and as a tensor.
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Largest encoder sequence length.
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens fed to the encoder.
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory mapping: slot mapping and block tables.
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Bias caches start empty and are filled through set_attn_bias().
        # Each one is a list with one entry per sequence. They are plain
        # attributes rather than dataclass fields, so they are absent from
        # the generated __init__ and __repr__.
        self.attn_bias: Optional[List[torch.Tensor]] = None
        self.encoder_attn_bias: Optional[List[torch.Tensor]] = None
        self.cross_attn_bias: Optional[List[torch.Tensor]] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        """True when every field needed for encoder attention is set."""
        required = (self.encoder_seq_lens, self.encoder_seq_lens_tensor,
                    self.max_encoder_seq_len)
        return all(item is not None for item in required)

    @property
    def is_all_cross_attn_metadata_set(self):
        """True when every field needed for cross-attention is set.

        This is a superset of the encoder attention requirements.
        """
        return (self.is_all_encoder_attn_metadata_set
                and self.cross_slot_mapping is not None
                and self.cross_block_tables is not None)

    @property
    def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
        """Return this object when there are prefill tokens, else None."""
        return None if self.num_prefill_tokens == 0 else self

    @property
    def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
        """Return this object when there are decode tokens, else None."""
        return None if self.num_decode_tokens == 0 else self

    def get_seq_lens(
        self,
        attn_type: str,
    ):
        """Select query and key/value sequence lengths for ``attn_type``.

        Args:
            attn_type: encoder attention, decoder self-attention,
                encoder-only attention or encoder/decoder cross-attention.

        Returns:
            A ``(seq_lens_q, seq_lens_kv)`` pair.

        Raises:
            AttributeError: if ``attn_type`` is not a known attention type.
        """
        if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
            return self.seq_lens, self.seq_lens
        if attn_type == AttentionType.ENCODER:
            return self.encoder_seq_lens, self.encoder_seq_lens
        if attn_type == AttentionType.ENCODER_DECODER:
            # Decoder-side queries attend to encoder-side keys/values.
            return self.seq_lens, self.encoder_seq_lens
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    def get_attn_bias(
        self,
        attn_type: str,
    ) -> Optional[List[torch.Tensor]]:
        """Return the cached attention bias belonging to ``attn_type``.

        Args:
            attn_type: encoder attention, decoder self-attention,
                encoder-only attention or encoder/decoder cross-attention.

        Returns:
            The stored bias list, or None when nothing has been set yet.

        Raises:
            AttributeError: if ``attn_type`` is not a known attention type.
        """
        if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
            return self.attn_bias
        if attn_type == AttentionType.ENCODER:
            return self.encoder_attn_bias
        if attn_type == AttentionType.ENCODER_DECODER:
            return self.cross_attn_bias
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    def set_attn_bias(
        self,
        attn_bias: List[torch.Tensor],
        attn_type: str,
    ) -> None:
        """Store ``attn_bias`` in the slot belonging to ``attn_type``.

        Args:
            attn_bias: the bias list to cache.
            attn_type: encoder attention, decoder self-attention,
                encoder-only attention or encoder/decoder cross-attention.

        Raises:
            AttributeError: if ``attn_type`` is not a known attention type.
        """
        if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
            self.attn_bias = attn_bias
            return
        if attn_type == AttentionType.ENCODER:
            self.encoder_attn_bias = attn_bias
            return
        if attn_type == AttentionType.ENCODER_DECODER:
            self.cross_attn_bias = attn_bias
            return
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    def get_seq_len_block_table_args(
        self,
        attn_type: str,
    ) -> tuple:
        """Pick sequence-length and block-table fields for ``attn_type``.

        Decoder / encoder-only attention uses the decoder self-attention
        fields, cross-attention uses the encoder lengths with the cross
        block tables, and encoder attention uses the encoder lengths with
        no block tables.

        Args:
            attn_type: encoder attention, decoder self-attention,
                encoder-only attention or encoder/decoder cross-attention.

        Returns:
            A ``(seq_lens_tensor, max_seq_len, block_tables)`` triple;
            ``block_tables`` is None for encoder attention.

        Raises:
            AttributeError: if ``attn_type`` is not a known attention type.
        """
        if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
            return (self.seq_lens_tensor, self.max_decode_seq_len,
                    self.block_tables)
        if attn_type == AttentionType.ENCODER_DECODER:
            # Cross-attention KVs follow the encoder sequence lengths and
            # live in the dedicated "cross" block tables.
            return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
                    self.cross_block_tables)
        if attn_type == AttentionType.ENCODER:
            # Encoder attention has no block tables.
            return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
                    None)
        raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
    """Builds ``TorchSDPAMetadata`` from the CPU model-input builder state."""

    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
        """Remember the input builder and its chunked-prefill setting."""
        self.chunked_prefill = input_builder.chunked_prefill
        self.input_builder = input_builder

    def prepare(self):
        """Capture the input data currently held by the input builder."""
        self.input_data = self.input_builder.input_data

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
        """Assemble the attention metadata for the current batch.

        Args:
            seq_lens: sequence lengths; the leading ``num_prefills`` entries
                are treated as the prefill sequences.
            query_lens: query lengths, sliced the same way as ``seq_lens``.
            cuda_graph_pad_size: not used by this builder.
            batch_size: not used by this builder.

        Returns:
            A ``TorchSDPAMetadata`` whose tensors are all created on CPU.
        """
        data = self.input_data
        num_prefills = data.num_prefills
        prefill_seq_lens = seq_lens[0:num_prefills]
        prefill_query_lens = query_lens[0:num_prefills]
        slot_mapping = torch.tensor(data.slot_mapping,
                                    dtype=torch.long,
                                    device="cpu")

        # Chunked-prefill fields stay None unless that path is active and
        # there is at least one prefill token.
        prefill_block_tables = None
        query_start_loc = None
        kv_start_loc = None
        max_query_len = None
        max_kv_len = None
        if self.chunked_prefill and data.num_prefill_tokens != 0:
            prefill_block_tables = make_tensor_with_pad(
                data.prefill_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )
            query_lens_tensor = torch.tensor(prefill_query_lens,
                                             dtype=torch.int32,
                                             device="cpu")
            kv_lens_tensor = torch.tensor(prefill_seq_lens,
                                          dtype=torch.int32,
                                          device="cpu")
            # Start offsets: a leading zero followed by the running sums.
            query_start_loc = torch.zeros(num_prefills + 1,
                                          dtype=torch.int32,
                                          device="cpu")
            kv_start_loc = torch.zeros(num_prefills + 1,
                                       dtype=torch.int32,
                                       device="cpu")
            torch.cumsum(query_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=query_start_loc[1:])
            torch.cumsum(kv_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=kv_start_loc[1:])
            max_query_len = max(prefill_query_lens)
            max_kv_len = max(prefill_seq_lens)

        # Paged-attention fields: with decode tokens present the lengths
        # come from the entries after the prefills, otherwise from the
        # prefill entries and the block table is left empty.
        has_decodes = data.num_decode_tokens != 0
        if has_decodes:
            lens_for_tensor = data.seq_lens[num_prefills:]
        else:
            lens_for_tensor = data.seq_lens[:num_prefills]
        seq_lens_tensor = torch.tensor(
            lens_for_tensor,
            dtype=torch.int32,
            device="cpu",
        )
        if has_decodes:
            block_tables = make_tensor_with_pad(
                data.decode_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )
        else:
            block_tables = torch.tensor([])

        # Multi-modal placeholder maps are only built when there are
        # multi-modal inputs.
        placeholder_index_maps = None
        if len(data.multi_modal_inputs_list) > 0:
            placeholder_index_maps = {}
            for modality, placeholder_map in (
                    data.multi_modal_placeholder_maps.items()):
                placeholder_index_maps[modality] = placeholder_map.index_map()

        return TorchSDPAMetadata(
            chunked_prefill=self.chunked_prefill,
            seq_lens=prefill_seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_kv_len=max_kv_len,
            query_start_loc=query_start_loc,
            kv_start_loc=kv_start_loc,
            max_decode_seq_len=data.max_decode_seq_len,
            num_prefills=num_prefills,
            num_prefill_tokens=data.num_prefill_tokens,
            num_decode_tokens=data.num_decode_tokens,
            block_tables=block_tables,
            prefill_block_tables=prefill_block_tables,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
        )
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
    """Attention implementation using torch SDPA and PagedAttention.

    Non-chunked prefill runs per-sequence torch SDPA, chunked prefill runs
    the IPEX varlen flash-attention kernel over the paged KV cache, and
    decode runs ``PagedAttention.forward_decode``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Validate the configuration and store the attention parameters.

        Raises:
            ValueError: if block-sparse parameters are given or the head
                size is not supported by PagedAttention.
            NotImplementedError: if ``kv_cache_dtype`` is not ``"auto"``.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "Torch SDPA does not support block-sparse attention.")
        if logits_soft_cap is not None:
            # The soft cap is ignored (not applied) by this backend.
            logger.warning_once("Torch SDPA does not support logits soft cap. "
                                "Outputs may be slightly off.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        # An explicit mask is required for ALiBi or sliding-window attention.
        self.need_mask = (self.alibi_slopes is not None
                          or self.sliding_window is not None)

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError(
                "Torch SDPA backend does not support FP8 KV cache. "
                "Please use xFormers backend instead.")
        self.attn_type = attn_type

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TorchSDPAMetadata,  # type: ignore
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with torch SDPA and PagedAttention.

        Args:
            layer: attention layer providing ``_k_scale`` and ``_v_scale``.
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            output: accepted for interface compatibility; the value passed
                in is ignored and a fresh tensor is always allocated.
        Returns:
            shape = [num_tokens, num_heads * head_size]

        Raises:
            AttributeError: if encoder or cross-attention metadata required
                by ``self.attn_type`` has not been set.
        """
        attn_type = self.attn_type
        if (attn_type == AttentionType.ENCODER
                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")
        elif (attn_type == AttentionType.ENCODER_DECODER
              and (not attn_metadata.is_all_cross_attn_metadata_set)):
            raise AttributeError("Encoder/decoder cross-attention "
                                 "requires setting cross-attention "
                                 "metadata attributes.")

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        # NOTE(review): key_cache / value_cache are only bound inside this
        # branch; the chunked-prefill and decode paths below rely on it
        # having run (non-encoder attention with a non-empty kv_cache).
        if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
            # KV-cache during decoder-self- or
            # encoder-decoder-cross-attention, but not
            # during encoder attention.
            #
            # Even if there are no new key/value pairs to cache,
            # we still need to break out key_cache and value_cache
            # i.e. for later use by paged attention
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if (key is not None) and (value is not None):
                if attn_type == AttentionType.ENCODER_DECODER:
                    # Update cross-attention KV cache (prefill-only)
                    # During cross-attention decode, key & value will be None,
                    # preventing this IF-statement branch from running
                    updated_slot_mapping = attn_metadata.cross_slot_mapping
                else:
                    # Update self-attention KV cache (prefill/decode)
                    updated_slot_mapping = attn_metadata.slot_mapping

                PagedAttention.write_to_paged_cache(
                    key, value, key_cache, value_cache, updated_slot_mapping,
                    self.kv_cache_dtype, layer._k_scale, layer._v_scale)

        if attn_type != AttentionType.ENCODER:
            # Decoder self-attention supports chunked prefill.
            # Encoder/decoder cross-attention requires no chunked
            # prefill (100% prefill or 100% decode tokens, no mix)
            num_prefill_tokens = attn_metadata.num_prefill_tokens
            num_decode_tokens = attn_metadata.num_decode_tokens
        else:
            # Encoder attention - chunked prefill is not applicable;
            # take the token count from the encoder metadata and treat
            # all of them as prefill tokens
            assert attn_metadata.num_encoder_tokens is not None
            num_prefill_tokens = attn_metadata.num_encoder_tokens
            num_decode_tokens = 0

        if attn_type == AttentionType.DECODER:
            # Only enforce this shape-constraint for decoder
            # self-attention
            assert key.shape[0] == num_prefill_tokens + num_decode_tokens
            assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        # Single allocation shared by the prefill and decode paths.
        output = torch.empty_like(query)
        if prefill_meta := attn_metadata.prefill_metadata:
            assert attn_metadata.seq_lens is not None
            if not prefill_meta.chunked_prefill:
                self._run_sdpa_forward(output,
                                       query,
                                       key,
                                       value,
                                       prefill_meta,
                                       attn_type=attn_type)
            else:
                # prefix-enabled attention
                assert not self.need_mask
                # Imported lazily so IPEX is only needed on this path.
                import intel_extension_for_pytorch.llm.modules as ipex_modules
                ipex_modules.PagedAttention.flash_attn_varlen_func(
                    output[:prefill_meta.num_prefill_tokens, :, :],
                    query[:prefill_meta.num_prefill_tokens, :, :],
                    key_cache,
                    value_cache,
                    prefill_meta.query_start_loc,
                    prefill_meta.kv_start_loc,
                    prefill_meta.max_query_len,
                    prefill_meta.max_kv_len,
                    self.scale,
                    True,
                    prefill_meta.prefill_block_tables,
                    self.alibi_slopes,
                )

        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
                "Encoder-only models should not have decode metadata.")
            # Decoding run.
            (
                seq_lens_arg,
                max_seq_len_arg,
                block_tables_arg,
            ) = decode_meta.get_seq_len_block_table_args(attn_type)

            # Decode tokens occupy the rows after the prefill tokens.
            PagedAttention.forward_decode(
                output[attn_metadata.num_prefill_tokens:, :, :],
                query[attn_metadata.num_prefill_tokens:, :, :],
                key_cache,
                value_cache,
                block_tables_arg,
                seq_lens_arg,
                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                layer._k_scale,
                layer._v_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_sdpa_forward(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: TorchSDPAMetadata,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Run per-sequence torch SDPA and write the result into ``output``.

        Args:
            output: destination tensor, written in place one sequence span
                at a time along dim 0.
            query, key, value: tensors with tokens on dim 0 and heads on
                dim 1 (as produced by the reshape in ``forward``).
            attn_metadata: supplies sequence lengths and caches the
                attention biases between calls.
            attn_type: selects which lengths/bias slot to use; causal
                masking is applied only for decoder self-attention.
        """
        # Expand KV heads so every query head has a matching KV head.
        if self.num_kv_heads != self.num_heads:
            key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
            value = value.repeat_interleave(self.num_queries_per_kv, dim=1)

        attn_masks = attn_metadata.get_attn_bias(attn_type)
        if attn_masks is None:
            # Build the masks once and cache them on the metadata so later
            # calls with the same metadata reuse them.
            if self.alibi_slopes is not None:
                attn_masks = _make_alibi_bias(
                    self.alibi_slopes, query.dtype,
                    attn_metadata.seq_lens)  # type: ignore
            elif self.sliding_window is not None:
                assert attn_metadata.seq_lens is not None
                attn_masks = _make_sliding_window_bias(
                    attn_metadata.seq_lens, self.sliding_window,
                    query.dtype)  # type: ignore
            else:
                seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
                attn_masks = [None] * len(seq_lens)
            attn_metadata.set_attn_bias(attn_masks, attn_type)

        # Move the token axis next to the head-size axis:
        # (tokens, heads, head_size) -> (heads, tokens, head_size).
        query = query.movedim(0, query.dim() - 2)
        key = key.movedim(0, key.dim() - 2)
        value = value.movedim(0, value.dim() - 2)

        causal_attn = (attn_type == AttentionType.DECODER)

        seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
        start_q, start_kv = 0, 0
        for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
                                               attn_masks):
            end_q = start_q + seq_len_q
            end_kv = start_kv + seq_len_kv
            # A leading batch axis of size 1 is added for SDPA and removed
            # again; is_causal is only used when no explicit mask exists.
            sub_out = scaled_dot_product_attention(
                query[None, :, start_q:end_q, :],
                key[None, :, start_kv:end_kv, :],
                value[None, :, start_kv:end_kv, :],
                attn_mask=mask,
                dropout_p=0.0,
                is_causal=causal_attn and mask is None,
                scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
            output[start_q:end_q, :, :] = sub_out
            start_q, start_kv = end_q, end_kv
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
def _make_alibi_bias(
|
| 637 |
+
alibi_slopes: torch.Tensor,
|
| 638 |
+
dtype: torch.dtype,
|
| 639 |
+
seq_lens: List[int],
|
| 640 |
+
) -> List[torch.Tensor]:
|
| 641 |
+
attn_biases: List[torch.Tensor] = []
|
| 642 |
+
for seq_len in seq_lens:
|
| 643 |
+
bias = torch.arange(seq_len, dtype=dtype)
|
| 644 |
+
# NOTE(zhuohan): HF uses
|
| 645 |
+
# `bias = bias[None, :].repeat(seq_len, 1)`
|
| 646 |
+
# here. We find that both biases give the same results, but
|
| 647 |
+
# the bias below more accurately follows the original ALiBi
|
| 648 |
+
# paper.
|
| 649 |
+
bias = bias[None, :] - bias[:, None]
|
| 650 |
+
|
| 651 |
+
num_heads = alibi_slopes.shape[0]
|
| 652 |
+
bias = bias[None, :].repeat((num_heads, 1, 1))
|
| 653 |
+
bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
|
| 654 |
+
inf_mask = torch.empty(
|
| 655 |
+
(1, seq_len, seq_len),
|
| 656 |
+
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
|
| 657 |
+
attn_biases.append((bias + inf_mask).to(dtype))
|
| 658 |
+
|
| 659 |
+
return attn_biases
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
def _make_sliding_window_bias(
|
| 663 |
+
seq_lens: List[int],
|
| 664 |
+
window_size: Optional[int],
|
| 665 |
+
dtype: torch.dtype,
|
| 666 |
+
) -> List[torch.Tensor]:
|
| 667 |
+
attn_biases: List[torch.Tensor] = []
|
| 668 |
+
for seq_len in seq_lens:
|
| 669 |
+
tensor = torch.full(
|
| 670 |
+
(1, seq_len, seq_len),
|
| 671 |
+
dtype=dtype,
|
| 672 |
+
fill_value=1,
|
| 673 |
+
)
|
| 674 |
+
shift = 0
|
| 675 |
+
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
|
| 676 |
+
if window_size is not None:
|
| 677 |
+
mask = torch.triu(mask, diagonal=shift - window_size + 1)
|
| 678 |
+
mask = torch.log(mask)
|
| 679 |
+
attn_biases.append(mask.to(dtype))
|
| 680 |
+
|
| 681 |
+
return attn_biases
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/triton_mla.py
ADDED
|
@@ -0,0 +1,746 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from contextlib import contextmanager
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from itertools import accumulate
|
| 7 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
| 8 |
+
|
| 9 |
+
from vllm.multimodal import MultiModalPlaceholderMap
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
|
| 13 |
+
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
| 14 |
+
except ImportError:
|
| 15 |
+
BatchDecodeMlaWithPagedKVCacheWrapper = None
|
| 16 |
+
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from vllm import _custom_ops as ops
|
| 21 |
+
from vllm.attention.backends.abstract import (AttentionBackend,
|
| 22 |
+
AttentionMetadata,
|
| 23 |
+
AttentionMetadataBuilder,
|
| 24 |
+
AttentionState, AttentionType)
|
| 25 |
+
from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata
|
| 26 |
+
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
| 27 |
+
compute_slot_mapping_start_idx,
|
| 28 |
+
is_block_tables_empty)
|
| 29 |
+
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
| 30 |
+
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
| 31 |
+
|
| 32 |
+
if TYPE_CHECKING:
|
| 33 |
+
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
| 34 |
+
ModelInputForGPUWithSamplingMetadata)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class TritonMLABackend(AttentionBackend):
    """Attention backend descriptor that wires together the Triton MLA classes.

    Every member is a static accessor: it either names the implementation,
    metadata, builder and state classes of this backend or forwards a cache
    operation to ``vllm._custom_ops``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "TRITON_MLA"

    @staticmethod
    def get_impl_cls() -> Type["TritonMLAImpl"]:
        """Return the attention implementation class."""
        return TritonMLAImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return TritonMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["TritonMLAMetadataBuilder"]:
        """Return the metadata builder class."""
        return TritonMLAMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["TritonMLAState"]:
        """Return the attention state class."""
        return TritonMLAState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape ``(num_blocks, block_size, head_size)``.

        ``num_kv_heads`` is accepted for interface compatibility but does not
        contribute to the shape.
        """
        shape = (num_blocks, block_size, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Forward the block swap to ``ops.swap_blocks``."""
        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Forward the block copy to ``ops.copy_blocks_mla``."""
        ops.copy_blocks_mla(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the head sizes this backend accepts."""
        supported = [576]
        return supported
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class TritonMLAState(AttentionState):
    """Attention state that owns the buffers used during graph capture."""

    # Attributes that exist only while `graph_capture` is active.
    _GRAPH_CAPTURE_ATTRS = ("_graph_slot_mapping", "_graph_seq_lens",
                            "_graph_block_tables", "_positions")

    def __init__(self, runner):
        self.runner = runner
        self._is_graph_capturing = False

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Allocate the capture buffers for the duration of the context.

        The buffers are sized for ``max_batch_size`` and are released, and
        the capturing flag reset, when the context exits -- including when
        the body (or an allocation) raises.
        """
        try:
            self._is_graph_capturing = True

            self._graph_slot_mapping = torch.full((max_batch_size, ),
                                                  PAD_SLOT_ID,
                                                  dtype=torch.long,
                                                  device=self.runner.device)
            self._graph_seq_lens = torch.ones(max_batch_size,
                                              dtype=torch.int32,
                                              device=self.runner.device)
            self._graph_block_tables = torch.from_numpy(
                self.runner.graph_block_tables).to(device=self.runner.device)

            self._positions = torch.zeros((max_batch_size, ),
                                          dtype=torch.long,
                                          device=self.runner.device)

            yield
        finally:
            # Always restore the idle state so a failed capture does not
            # leave the flag set or keep the buffers alive.
            self._is_graph_capturing = False
            for attr in self._GRAPH_CAPTURE_ATTRS:
                if hasattr(self, attr):
                    delattr(self, attr)

    def graph_clone(self, batch_size: int):
        """Return a fresh state bound to the same runner.

        Must be called while capturing; ``batch_size`` is not used.
        """
        assert self._is_graph_capturing
        return self.__class__(self.runner)

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        """Build decode-only metadata backed by the capture buffers.

        Raises:
            NotImplementedError: if ``is_encoder_decoder_model`` is true.
        """
        assert self._is_graph_capturing

        # Fail fast instead of building metadata that is then discarded.
        if is_encoder_decoder_model:
            raise NotImplementedError(
                "TritonMLAState does not support encoder/decoder yet")

        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=self._graph_seq_lens[:batch_size],
            max_query_len=1,
            max_decode_query_len=1,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.runner.max_seq_len_to_capture,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self._graph_block_tables[:batch_size],
            use_cuda_graph=True,
            input_positions=self._positions[:batch_size],
            head_dim=self.runner.model_config.get_head_size())

        return attn_metadata

    def get_graph_input_buffers(self,
                                attn_metadata,
                                is_encoder_decoder_model: bool = False):
        """Return the tensors of ``attn_metadata`` that act as graph inputs.

        Raises:
            NotImplementedError: if ``is_encoder_decoder_model`` is true.
        """
        if is_encoder_decoder_model:
            raise NotImplementedError(
                "TritonMLAState does not support encoder/decoder yet")

        decode_meta = attn_metadata.decode_metadata
        input_buffers = {
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": decode_meta.seq_lens_tensor,
            "block_tables": decode_meta.block_tables,
            "input_positions": decode_meta.input_positions,
        }
        return input_buffers

    def prepare_graph_input_buffers(self,
                                    input_buffers,
                                    attn_metadata,
                                    is_encoder_decoder_model: bool = False):
        """Copy the current batch's metadata into the graph input buffers.

        Raises:
            NotImplementedError: if ``is_encoder_decoder_model`` is true;
                raised before any buffer is modified.
        """
        if is_encoder_decoder_model:
            raise NotImplementedError(
                "TritonMLAState does not support encoder/decoder yet")

        input_positions = attn_metadata.input_positions
        num_positions = input_positions.shape[0]
        decode_meta = attn_metadata.decode_metadata
        input_buffers["seq_lens_tensor"].copy_(decode_meta.seq_lens_tensor,
                                               non_blocking=True)
        input_buffers["block_tables"].copy_(decode_meta.block_tables,
                                            non_blocking=True)
        # CUDA graph buffer is padded so only perform a partial copy based on
        # num_positions
        input_buffers["input_positions"][:num_positions].copy_(
            input_positions, non_blocking=True)

    def begin_forward(self, model_input):
        """No per-forward preparation is needed for this state."""
        return
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
@dataclass
class TritonMLAMetadata(MLACommonMetadata):
    """Metadata for TritonMLAMetadata.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.

    use_cuda_graph: bool

    # Maximum query length in the batch.
    max_query_len: Optional[int] = None

    # Max number of query tokens among request in the batch.
    max_decode_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # Lazily built views returned by `prefill_metadata` / `decode_metadata`.
    _cached_prefill_metadata: Optional["TritonMLAMetadata"] = None
    _cached_decode_metadata: Optional["TritonMLAMetadata"] = None

    # NOTE(review): declared without a default after defaulted fields; this
    # presumably re-annotates a field inherited from the base metadata class
    # (which keeps its original position) -- confirm before reordering.
    num_prefill_tokens: int

    num_kv_splits: int = 4  # TODO(lucas) add heuristic
    attn_logits: Optional[torch.Tensor] = None
    req_idx: Optional[torch.Tensor] = None

    # The dimension of the attention heads
    head_dim: Optional[int] = None

    def __post_init__(self):
        """Reject a ``head_dim`` the backend does not support.

        Raises:
            ValueError: if ``head_dim`` is set and is not one of
                ``TritonMLABackend.get_supported_head_sizes()``.
        """
        supported_head_sizes = TritonMLABackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            # The two fragments must be concatenated into ONE message; a
            # comma between them would pass two separate exception args.
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")

    @property
    def prefill_metadata(self) -> Optional["TritonMLAMetadata"]:
        """Return metadata restricted to the prefill part of the batch.

        Returns None when the batch has no prefills. The result is built once
        and cached on this instance.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None

        # Compute some attn_metadata fields which default to None
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:self.num_prefills + 1])
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:self.num_prefill_tokens])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])
        seq_start_loc = (None if self.seq_start_loc is None else
                         self.seq_start_loc[:self.num_prefills + 1])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:self.num_prefills])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:self.num_prefills])
        input_positions = (None if self.input_positions is None else
                           self.input_positions[:self.num_prefill_tokens])

        self._cached_prefill_metadata = TritonMLAMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            input_positions=input_positions,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_query_len=0,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            head_dim=self.head_dim)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["TritonMLAMetadata"]:
        """Return metadata restricted to the decode part of the batch.

        Returns None when the batch has no decode tokens. The result is built
        once and cached on this instance.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.seq_lens_tensor is not None

        # Compute some attn_metadata fields which default to None
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])
        input_positions = (None if self.input_positions is None else
                           self.input_positions[self.num_prefill_tokens:])

        self._cached_decode_metadata = TritonMLAMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            # Batch may be composed of prefill|decodes, adjust query start
            # indices to refer to the start of decodes. E.g.
            # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
            query_start_loc=(self.query_start_loc[self.num_prefills:] -
                             self.query_start_loc[self.num_prefills])
            if self.query_start_loc is not None else None,
            seq_start_loc=self.seq_start_loc[self.num_prefills:]
            if self.seq_start_loc is not None else None,
            context_lens_tensor=None,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            input_positions=input_positions,
            head_dim=self.head_dim)
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.
        """
        # When using cudagraph, the num_seqs is padded to the next captured
        # batch size, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
            assert self.num_decode_tokens + self.num_prefills == num_seqs
            self.num_decode_tokens += self.num_prefills
            self.num_prefills = 0
            self.num_prefill_tokens = 0
            self.max_prefill_seq_len = 0
            self.max_query_len = 1

            self.slot_mapping = self.slot_mapping[:num_seqs]
        else:
            assert self.seq_lens is not None
            assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=model_input.input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=model_input.input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
    """Accumulates per-sequence data and builds a `TritonMLAMetadata`."""

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.input_builder = input_builder
        self.runner = input_builder.runner
        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def prepare(self):
        """Reset every accumulator that `build` fills in."""
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.input_positions: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0
        self.has_prefix_cache_hit = False

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool, prefix_cache_hit: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block, input_positions) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks,
                 inter_data.input_positions):
            self.input_positions.extend(input_positions)
            self.context_lens.append(context_len)
            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)

                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                # NOTE(woosuk): For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def _get_graph_runner_block_tables(
            self, num_seqs: int,
            block_tables: List[List[int]]) -> torch.Tensor:
        """Write `block_tables` into the runner's graph table and upload it.

        Rows of ``self.runner.graph_block_tables`` are updated in place for
        every non-empty entry of ``block_tables``; empty entries leave their
        row untouched. Returns the first ``num_seqs`` rows as a tensor on the
        runner's device.
        """
        # The shape of graph_block_tables is
        # [max batch size, max context len // block size].
        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
        assert max_batch_size >= num_seqs

        graph_block_tables = self.runner.graph_block_tables[:num_seqs]
        for i, block_table in enumerate(block_tables):
            if block_table:
                num_blocks = len(block_table)
                if num_blocks <= max_blocks:
                    graph_block_tables[i, :num_blocks] = block_table
                else:
                    # It may be possible to have more blocks allocated due
                    # to lookahead slots of multi-step, however, they are
                    # not used anyway, so can be safely ignored.
                    graph_block_tables[
                        i, :max_blocks] = block_table[:max_blocks]

        return torch.from_numpy(graph_block_tables).to(
            device=self.runner.device, non_blocking=True)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        prefix_cache_hit = any(
            inter_data.prefix_cache_hit
            for inter_data in self.input_builder.inter_data_list)
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled,
                                prefix_cache_hit)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        decode_query_lens = query_lens[self.num_prefills:]
        if len(decode_query_lens) > 0:
            max_decode_query_len = max(decode_query_lens)
        else:
            max_decode_query_len = 1
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        num_seqs = len(seq_lens)
        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # One empty block table per padding slot. (`[] * n` is always
            # `[]`, which would add nothing.) Empty tables are skipped by
            # `_get_graph_runner_block_tables`, so the resulting tensor is
            # unaffected.
            self.block_tables.extend([] for _ in range(cuda_graph_pad_size))
            num_decode_tokens = batch_size - self.num_prefill_tokens
            block_tables = self._get_graph_runner_block_tables(
                num_seqs, self.block_tables)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, f"query_lens: {query_lens}"

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        input_positions = async_tensor_h2d(self.input_positions, torch.long,
                                           device, self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                                  device,
                                                  self.runner.pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        return TritonMLAMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=True,
            input_positions=input_positions,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
            num_kv_splits=4,  # TODO(lucas) add heuristic
            head_dim=self.runner.model_config.get_head_size(),
        )
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
    """MLA attention implementation driven by ``TritonMLAMetadata``.

    Prefill is delegated to ``_forward_prefill_flash`` and decode to
    ``decode_attention_fwd``; both are defined outside this class.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            # MLA Specific Arguments
            **kwargs) -> None:
        """Initialise the base class, then reject unsupported options.

        Raises:
            NotImplementedError: if any of ``alibi_slopes``,
                ``sliding_window``, ``blocksparse_params`` or
                ``logits_soft_cap`` is truthy, or if ``attn_type`` is not
                ``AttentionType.DECODER``.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         **kwargs)

        # Any truthy value among these means the caller asked for a feature
        # this implementation does not provide.
        rejected_options = (alibi_slopes, sliding_window, blocksparse_params,
                            logits_soft_cap)
        if any(rejected_options):
            raise NotImplementedError(
                "TritonMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TritonMLAImpl")

    def _forward_prefill(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        attn_metadata: TritonMLAMetadata,
    ) -> torch.Tensor:
        """Run prefill attention through ``_forward_prefill_flash``.

        The sequence start offsets and the maximum prefill sequence length
        are taken from ``attn_metadata``.
        """
        assert isinstance(attn_metadata, TritonMLAMetadata)
        seq_start_loc = attn_metadata.seq_start_loc
        max_prefill_seq_len = attn_metadata.max_prefill_seq_len
        return self._forward_prefill_flash(q, kv_c_normed, k_pe,
                                           seq_start_loc,
                                           max_prefill_seq_len)

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: TritonMLAMetadata,
    ) -> torch.Tensor:
        """Run decode attention against the combined KV cache.

        Raises:
            NotImplementedError: if ``self.kv_cache_dtype`` starts with
                ``"fp8"``.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 Triton MLA not yet supported")

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None

        batch = q_nope.shape[0]
        num_kv_splits = attn_metadata.num_kv_splits

        # The kernel consumes one query tensor: the two query parts
        # concatenated along the last dimension.
        q = torch.cat([q_nope, q_pe], dim=-1)
        out = torch.zeros(batch,
                          self.num_heads,
                          self.kv_lora_rank,
                          dtype=q.dtype,
                          device=q.device)

        # TODO(lucas) Allocate ahead of time
        # NOTE(lucas) idk why the +1 is here but sglang has it so we
        # just mirror that
        logits_shape = (batch, self.num_heads, num_kv_splits,
                        self.kv_lora_rank + 1)
        attn_logits = torch.empty(logits_shape,
                                  dtype=torch.float32,
                                  device=q.device)

        # Add a head dim of 1
        cache_with_head = kv_c_and_k_pe_cache.unsqueeze(2)
        # The first kv_lora_rank entries of the last dim are passed
        # separately as the third kernel argument.
        kv_c_cache = cache_with_head[..., :self.kv_lora_rank]
        page_size = cache_with_head.size(1)

        # Run MQA
        decode_attention_fwd(q, cache_with_head, kv_c_cache, out,
                             decode_meta.block_tables,
                             decode_meta.seq_lens_tensor, attn_logits,
                             num_kv_splits, self.scale, page_size)

        return self._v_up_proj_and_o_proj(out)
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/utils.py
ADDED
|
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Attention backend utils"""
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from contextlib import contextmanager
|
| 5 |
+
from itertools import accumulate
|
| 6 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
|
| 12 |
+
AttentionState)
|
| 13 |
+
from vllm.attention.backends.abstract import AttentionType
|
| 14 |
+
from vllm.multimodal import MultiModalPlaceholderMap
|
| 15 |
+
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING:
|
| 18 |
+
from vllm.worker.model_runner_base import ModelRunnerBase
|
| 19 |
+
|
| 20 |
+
# Error string(s) for encoder/decoder
|
| 21 |
+
# unsupported attention scenarios
|
| 22 |
+
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
|
| 23 |
+
"with encoder/decoder models.")
|
| 24 |
+
|
| 25 |
+
PAD_SLOT_ID = -1
|
| 26 |
+
|
| 27 |
+
# Switch to numpy implementation of compute_slot_mapping
|
| 28 |
+
# if we have at least this many elements. Could be tuned further.
|
| 29 |
+
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256
|
| 30 |
+
|
| 31 |
+
if TYPE_CHECKING:
|
| 32 |
+
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def is_block_tables_empty(block_tables: Union[None, Dict]):
    """Report whether ``block_tables`` carries no block table at all.

    Returns True when ``block_tables`` is None, or when it is a dict whose
    values are all None (an empty dict included). Any other object yields
    False.
    """
    if block_tables is None:
        return True
    if not isinstance(block_tables, dict):
        return False
    for table in block_tables.values():
        if table is not None:
            return False
    return True
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def compute_slot_mapping_start_idx(
        is_prompt: bool, query_len: int, context_len: int,
        sliding_window: Union[int, None]) -> int:
    """Compute the start index of slot mapping.

    Args:
        is_prompt: Whether the sequence is being processed as a prompt.
        query_len: Number of query tokens of the sequence.
        context_len: Accepted for signature compatibility; not used here.
        sliding_window: Sliding window size, or None when no sliding
            window applies.

    Returns:
        ``max(0, query_len - sliding_window)`` for a prompt with a sliding
        window, otherwise 0.
    """
    # Only prompts under a sliding window skip leading tokens.
    if not is_prompt or sliding_window is None:
        return 0
    return max(0, query_len - sliding_window)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _compute_slot_mapping_python(slot_mapping: List[int],
|
| 57 |
+
block_table: List[int], range_start: int,
|
| 58 |
+
range_end: int, block_size: int):
|
| 59 |
+
for i in range(range_start, range_end):
|
| 60 |
+
block_number = block_table[i // block_size]
|
| 61 |
+
block_offset = i % block_size
|
| 62 |
+
slot = block_number * block_size + block_offset
|
| 63 |
+
slot_mapping.append(slot)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _compute_slot_mapping_numpy(slot_mapping: List[int],
|
| 67 |
+
block_table: List[int], range_start: int,
|
| 68 |
+
range_end: int, block_size: int):
|
| 69 |
+
block_table_array = np.array(block_table)
|
| 70 |
+
idx = np.arange(range_start, range_end)
|
| 71 |
+
block_offset = idx % block_size
|
| 72 |
+
idx //= block_size
|
| 73 |
+
seq_slot_mapping_array = block_table_array[idx]
|
| 74 |
+
seq_slot_mapping_array *= block_size
|
| 75 |
+
seq_slot_mapping_array += block_offset
|
| 76 |
+
slot_mapping.extend(seq_slot_mapping_array)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
                         seq_id: int, seq_len: int, context_len: int,
                         start_idx: int, block_size: int,
                         block_tables: Dict[int, List[int]]):
    """Extend ``slot_mapping`` in place with the slots of one sequence.

    On a profile run ``seq_len`` padding slots are appended and the block
    tables are not consulted. Otherwise positions in
    ``[context_len, start_idx)`` are padded and the remaining positions up
    to ``seq_len`` are resolved through ``block_tables[seq_id]``.
    """
    if is_profile_run:
        # During memory profiling, the block tables are not
        # initialized yet. In this case, we just use a dummy
        # slot mapping.
        # In embeddings, the block tables are {seq_id: None}.
        slot_mapping.extend([PAD_SLOT_ID] * seq_len)
        return

    # Mask the [0, start_idx) tokens of the prompt with
    # PAD_SLOT_ID, where start_idx is max(0, seq_len -
    # sliding_window). For example, if the prompt len is 10,
    # sliding window is 8, and block size is 4, the first two
    # tokens are masked and the slot mapping will be
    # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
    num_masked = start_idx - context_len
    if num_masked > 0:
        slot_mapping.extend([PAD_SLOT_ID] * num_masked)

    first_pos = max(start_idx, context_len)
    seq_block_table = block_tables[seq_id]

    # numpy implementation will be faster than python if we have
    # many elements, otherwise it will be slower.
    if seq_len - first_pos < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
        fill_slots = _compute_slot_mapping_python
    else:
        fill_slots = _compute_slot_mapping_numpy
    fill_slots(slot_mapping, seq_block_table, first_pos, seq_len, block_size)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# Type variable bounded by AttentionMetadata (forward reference by name).
TAttentionMetadata = TypeVar("TAttentionMetadata", bound="AttentionMetadata")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
    """Accumulates per-sequence data and builds an attention metadata object.

    Subclasses set ``_metadata_cls``; ``build`` instantiates it with the
    tensors assembled from the input builder's sequence groups.
    """

    _metadata_cls: Type[TAttentionMetadata]

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        """Keep references to the input builder and the settings read from it."""
        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def prepare(self):
        """Reset all per-batch accumulators to their empty state."""
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Fold every sequence of one sequence group into the accumulators.

        Updates the prefill/decode counters, the per-sequence length lists,
        the block tables and the slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)

                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if inter_data.prefix_cache_hit:
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    # Keep only the trailing blocks inside the window.
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.

        Returns:
            An instance of ``self._metadata_cls``.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # One empty block table per padding entry. The previous
            # ``[] * cuda_graph_pad_size`` evaluated to ``[]`` and therefore
            # appended nothing.
            self.block_tables.extend(
                [] for _ in range(cuda_graph_pad_size))
            num_decode_tokens = batch_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                # Empty tables (including padding) leave their row untouched.
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.from_numpy(input_block_tables).to(
                device, non_blocking=True)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                                  device,
                                                  self.runner.pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        return self._metadata_cls(  # type: ignore
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=True,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class CommonAttentionState(AttentionState):
    """Attention state holding the buffers used while capturing CUDA graphs."""

    def __init__(self, runner: "ModelRunnerBase"):
        """Remember the runner; no capture is in progress initially."""
        self.runner = runner
        self._is_graph_capturing = False

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager that allocates the capture buffers for its body.

        The capturing flag is cleared and the buffers are released when the
        body exits, including when it raises.
        """
        self._is_graph_capturing = True
        try:
            self._graph_slot_mapping = torch.full((max_batch_size, ),
                                                  PAD_SLOT_ID,
                                                  dtype=torch.long,
                                                  device=self.runner.device)
            self._graph_seq_lens = torch.ones(max_batch_size,
                                              dtype=torch.int32,
                                              device=self.runner.device)
            self._graph_block_tables = torch.from_numpy(
                self.runner.graph_block_tables).to(device=self.runner.device)

            yield
        finally:
            # Runs on failure too, so a failed capture cannot leave the
            # state flagged as capturing with the buffers still referenced.
            self._is_graph_capturing = False
            for buffer_name in ("_graph_slot_mapping", "_graph_seq_lens",
                                "_graph_block_tables"):
                # A buffer may be missing if allocation itself failed.
                if hasattr(self, buffer_name):
                    delattr(self, buffer_name)

    def graph_clone(self, batch_size: int) -> "CommonAttentionState":
        """Return a fresh state of the same class bound to the same runner."""
        assert self._is_graph_capturing
        return self.__class__(self.runner)

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        """Create decode-only metadata over the first ``batch_size`` entries
        of the capture buffers.

        For encoder/decoder models the cross-attention and encoder fields
        are filled in as well.
        """
        assert self._is_graph_capturing
        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=self._graph_seq_lens[:batch_size],
            max_query_len=1,
            max_decode_query_len=1,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.runner.max_seq_len_to_capture,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self._graph_block_tables[:batch_size],
            use_cuda_graph=True,
        )
        if is_encoder_decoder_model:
            self._assert_enc_dec_backend_supported()
            self._update_captured_metadata_for_enc_dec_model(
                batch_size=batch_size, attn_metadata=attn_metadata)

        return attn_metadata

    def get_graph_input_buffers(
            self,
            attn_metadata,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        """Collect the tensors from ``attn_metadata`` that act as graph
        input buffers, keyed by name."""
        input_buffers = {
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }
        if is_encoder_decoder_model:
            self._assert_enc_dec_backend_supported()
            self._add_additonal_input_buffers_for_enc_dec_model(
                attn_metadata=attn_metadata, input_buffers=input_buffers)
        return input_buffers

    def prepare_graph_input_buffers(
            self,
            input_buffers,
            attn_metadata,
            is_encoder_decoder_model: bool = False) -> None:
        """Copy the decode tensors of ``attn_metadata`` into ``input_buffers``.

        ``slot_mapping`` is not copied here.
        """
        input_buffers["seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
        input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)
        if is_encoder_decoder_model:
            self._assert_enc_dec_backend_supported()
            self._prepare_input_buffers_for_enc_dec_model(
                attn_metadata, input_buffers)

    def begin_forward(self, model_input) -> None:
        """No-op hook."""
        return

    def _assert_enc_dec_backend_supported(self) -> None:
        """Assert that the runner's backend is XFORMERS or FLASH_ATTN."""
        # The encoder decoder model works only with XFormers and
        # Flash Attention backend. Assert the same.
        backend_name = self.runner.attn_backend.get_name()
        assert backend_name in ["XFORMERS", "FLASH_ATTN"], \
            f"Expected attn_backend name to be either 'XFORMERS' or " \
            f"'FLASH_ATTN', but " \
            f"got '{backend_name}'"

    def _update_captured_metadata_for_enc_dec_model(self, batch_size: int,
                                                    attn_metadata):
        """
        Updates the attention metadata parameters for CUDA graph capture in an
        encoder-decoder model.

        This method modifies attention-related tensors and metadata required
        for CUDA graph capture in encoder-decoder models. Specifically, it
        updates the cross-attention and encoder sequence tensors in the
        AttentionMetadata object.
        """
        # During decode phase the cross_slot_mapping will be empty. Hence set
        # an empty tensor for CUDA Graph capture.
        attn_metadata.cross_slot_mapping = torch.tensor(
            [], dtype=torch.int).cuda()
        attn_metadata.cross_block_tables = torch.full(
            (batch_size, self.runner.get_max_block_per_batch()),
            1,
            dtype=torch.int).cuda()
        attn_metadata.encoder_seq_lens = torch.full((batch_size, ),
                                                    1,
                                                    dtype=torch.int).cuda()
        attn_metadata.encoder_seq_lens_tensor = torch.full(
            (batch_size, ), 1, dtype=torch.int).cuda()
        attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
        attn_metadata.num_encoder_tokens = 0

    def _add_additonal_input_buffers_for_enc_dec_model(
            self, attn_metadata, input_buffers: Dict[str, Any]):
        """
        Saves additional input buffers specific to the encoder-decoder model
        from the attention metadata.

        This method extracts and stores encoder-decoder related input buffers
        from the `attn_metadata` into the `input_buffers` dictionary. The
        buffers include encoder sequence lengths, cross-slot mappings, and
        cross-block tables, which are essential for the encoder-decoder model
        during CUDA graph replay.
        """
        input_buffers["encoder_seq_lens_tensor"] = (
            attn_metadata.decode_metadata.encoder_seq_lens_tensor)
        input_buffers["cross_slot_mapping"] = (
            attn_metadata.decode_metadata.cross_slot_mapping)
        input_buffers["cross_block_tables"] = (
            attn_metadata.decode_metadata.cross_block_tables)

    def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata,
                                                 input_buffers: Dict[str,
                                                                     Any]):
        """
        Populates input buffers with data from the encoder-decoder model's
        attention metadata.

        Copies the encoder sequence lengths, the cross-slot mapping and the
        cross-block tables from `attn_metadata.decode_metadata` into the
        matching entries of the `input_buffers` dictionary.
        """
        input_buffers["encoder_seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.encoder_seq_lens_tensor,
            non_blocking=True)
        input_buffers["cross_slot_mapping"].copy_(
            attn_metadata.decode_metadata.cross_slot_mapping,
            non_blocking=True)
        input_buffers["cross_block_tables"].copy_(
            attn_metadata.decode_metadata.cross_block_tables,
            non_blocking=True)
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def is_all_encoder_attn_metadata_set(attn_metadata):
    """Tell whether every encoder-attention field of the metadata is set.

    Checks, in order, that ``encoder_seq_lens``, ``encoder_seq_lens_tensor``
    and ``max_encoder_seq_len`` are not None, stopping at the first one that
    is None.
    """
    required_fields = ("encoder_seq_lens", "encoder_seq_lens_tensor",
                       "max_encoder_seq_len")
    return all(
        getattr(attn_metadata, field) is not None
        for field in required_fields)
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def is_all_cross_attn_metadata_set(attn_metadata):
    """Tell whether the metadata needed for enc/dec cross-attention is set.

    This is a superset of the encoder-attention requirement: the metadata's
    ``is_all_encoder_attn_metadata_set`` attribute must be truthy, and both
    ``cross_slot_mapping`` and ``cross_block_tables`` must be non-None.
    """
    encoder_ready = attn_metadata.is_all_encoder_attn_metadata_set
    if not encoder_ready:
        # Hand back the falsy attribute value itself.
        return encoder_ready
    if attn_metadata.cross_slot_mapping is None:
        return False
    return attn_metadata.cross_block_tables is not None
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def get_seq_len_block_table_args(
    attn_metadata,
    is_prompt: bool,
    attn_type: str,
) -> tuple:
    """Pick the sequence-length and block-table fields for an attention op.

    Which attributes of ``attn_metadata`` are returned depends on the
    attention type:

    * Decoder self-attention -> decoder sequence lengths, the prefill or
      decode max sequence length (chosen by ``is_prompt``), and the decoder
      block tables.
    * Encoder/decoder cross-attention -> encoder sequence lengths, max
      encoder sequence length, and the cross-attention block tables.
    * Encoder attention -> encoder sequence lengths, max encoder sequence
      length, and None (no block tables).

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * is_prompt: True if prefill, False otherwise
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Appropriate sequence-lengths tensor
    * Appropriate max sequence-length scalar
    * Appropriate block tables (or None)

    Raises:

    * AttributeError if ``attn_type`` is none of the three types above
    """
    if attn_type == AttentionType.DECODER:
        # Decoder self-attention: the max length depends on the phase.
        max_seq_len = (attn_metadata.max_prefill_seq_len
                       if is_prompt else attn_metadata.max_decode_seq_len)
        return (attn_metadata.seq_lens_tensor, max_seq_len,
                attn_metadata.block_tables)

    if attn_type == AttentionType.ENCODER_DECODER:
        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.cross_block_tables)

    if attn_type == AttentionType.ENCODER:
        # No block tables associated with encoder attention
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len, None)

    raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def get_num_prefill_decode_query_kv_tokens(
    attn_metadata,
    attn_type: str,
) -> Tuple[int, int, int]:
    """
    Work out how many query and key/value tokens belong to the prefill
    and decode phases for the given attention type.

    Args:
        attn_metadata: Attention metadata object. Only its
            `num_encoder_tokens`, `num_prefill_tokens` and
            `num_decode_tokens` attributes are read here.
        attn_type: The type of attention being used (compared against
            the `AttentionType` constants).

    Returns:
        Tuple[int, int, int]: A tuple containing three integers:
            - The number of prefill query tokens.
            - The number of prefill key/value tokens.
            - The number of decode query tokens.

    Raises:
        AssertionError: If `attn_metadata.num_encoder_tokens` is `None`
            for encoder or encoder/decoder cross-attention.
    """
    if attn_type == AttentionType.ENCODER:
        # Encoder attention is only invoked during the prefill phase.
        # The same input serves as both query and key, so there are no
        # decode query tokens.
        assert attn_metadata.num_encoder_tokens is not None
        encoder_tokens = attn_metadata.num_encoder_tokens
        return (encoder_tokens, encoder_tokens, 0)

    if attn_type == AttentionType.ENCODER_DECODER:
        # Queries come from the decoder; the keys/values are the
        # encoder (cross-attention) tokens.
        assert attn_metadata.num_encoder_tokens is not None
        return (attn_metadata.num_prefill_tokens,
                attn_metadata.num_encoder_tokens,
                attn_metadata.num_decode_tokens)

    # Any other attention type (the original comment names
    # AttentionType.DECODER and AttentionType.ENCODER_ONLY): queries and
    # keys/values both use the prefill token count.
    return (attn_metadata.num_prefill_tokens,
            attn_metadata.num_prefill_tokens,
            attn_metadata.num_decode_tokens)
|
.venv/lib/python3.11/site-packages/vllm/attention/backends/xformers.py
ADDED
|
@@ -0,0 +1,794 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Attention layer with xFormers and PagedAttention."""
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Dict, List, Optional, Tuple, Type
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from xformers import ops as xops
|
| 8 |
+
from xformers.ops.fmha.attn_bias import (AttentionBias,
|
| 9 |
+
BlockDiagonalCausalMask,
|
| 10 |
+
BlockDiagonalMask,
|
| 11 |
+
LowerTriangularMaskWithTensorBias)
|
| 12 |
+
|
| 13 |
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
| 14 |
+
AttentionLayer,
|
| 15 |
+
AttentionMetadata, AttentionType)
|
| 16 |
+
from vllm.attention.backends.utils import (
|
| 17 |
+
CommonAttentionState, CommonMetadataBuilder,
|
| 18 |
+
get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
|
| 19 |
+
is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set)
|
| 20 |
+
from vllm.attention.ops.paged_attn import (PagedAttention,
|
| 21 |
+
PagedAttentionMetadata)
|
| 22 |
+
from vllm.logger import init_logger
|
| 23 |
+
|
| 24 |
+
logger = init_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class XFormersBackend(AttentionBackend):
    """Attention backend descriptor for xFormers.

    Every method is a static accessor: it either names one of the classes
    that make up this backend or forwards a KV-cache operation to
    `PagedAttention`.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "XFORMERS"

    @staticmethod
    def get_impl_cls() -> Type["XFormersImpl"]:
        """Return the attention implementation class."""
        return XFormersImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return XFormersMetadata

    @staticmethod
    def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
        """Return the metadata builder class."""
        return XFormersMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape as computed by `PagedAttention`."""
        cache_shape = PagedAttention.get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Forward the block swap to `PagedAttention.swap_blocks`."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Forward the block copy to `PagedAttention.copy_blocks`."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for the XFormers backend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # FIXME: It is for flash attn.
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]] = None

    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor] = None

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # Max number of query tokens among request in the batch.
    max_decode_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None

    # Self-attention prefill/decode metadata cache
    _cached_prefill_metadata: Optional["XFormersMetadata"] = None
    _cached_decode_metadata: Optional["XFormersMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    encoder_seq_start_loc: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    def __post_init__(self):
        """Initialise the attention-bias slots to None.

        The biases are set during the execution of the first attention op.
        Each is a list because a bias must be set per prompt when alibi
        slopes are used (a limitation of the xformers API). Being plain
        instance attributes, they do not appear in `__repr__` or `__init__`.
        """
        self.attn_bias: Optional[List[AttentionBias]] = None
        self.encoder_attn_bias: Optional[List[AttentionBias]] = None
        self.cross_attn_bias: Optional[List[AttentionBias]] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        '''
        Whether all attention metadata required for encoder attention is set.
        '''
        return is_all_encoder_attn_metadata_set(self)

    @property
    def is_all_cross_attn_metadata_set(self):
        '''
        Whether all attention metadata required for enc/dec cross-attention
        is set.

        Superset of encoder attention required metadata.
        '''
        return is_all_cross_attn_metadata_set(self)

    @property
    def prefill_metadata(self) -> Optional["XFormersMetadata"]:
        """Return the prefill-phase view of this metadata, or None.

        None is returned when `num_prefills` is 0. Otherwise the prefill
        view is built once, stored in `_cached_prefill_metadata`, and
        returned from the cache on later accesses.
        """
        if self.num_prefills == 0:
            return None

        cached = self._cached_prefill_metadata
        if cached is not None:
            # Recover cached prefill-phase attention metadata structure
            return cached

        # At least one of the decoder / encoder length representations
        # must be available.
        assert not (self.seq_lens is None and self.encoder_seq_lens is None)
        assert not (self.seq_lens_tensor is None
                    and self.encoder_seq_lens_tensor is None)

        def _head(values, count):
            # Leading `count` entries, passing None through untouched.
            return None if values is None else values[:count]

        num_prefills = self.num_prefills
        num_prefill_tokens = self.num_prefill_tokens

        # Construct & cache prefill-phase attention metadata structure
        prefill_meta = XFormersMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=_head(self.slot_mapping, num_prefill_tokens),
            multi_modal_placeholder_index_maps=(
                self.multi_modal_placeholder_index_maps),
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=_head(self.seq_lens, num_prefills),
            seq_lens_tensor=_head(self.seq_lens_tensor, num_prefills),
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=_head(self.query_start_loc, num_prefills + 1),
            seq_start_loc=_head(self.seq_start_loc, num_prefills + 1),
            context_lens_tensor=_head(self.context_lens_tensor,
                                      num_prefills),
            block_tables=_head(self.block_tables, num_prefills),
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        self._cached_prefill_metadata = prefill_meta
        return prefill_meta

    @property
    def decode_metadata(self) -> Optional["XFormersMetadata"]:
        """Return the decode-phase view of this metadata, or None.

        None is returned when `num_decode_tokens` is 0. Otherwise the
        decode view is built once, stored in `_cached_decode_metadata`,
        and returned from the cache on later accesses.
        """
        if self.num_decode_tokens == 0:
            return None

        cached = self._cached_decode_metadata
        if cached is not None:
            # Recover cached decode-phase attention metadata structure
            return cached

        assert not (self.seq_lens_tensor is None
                    and self.encoder_seq_lens_tensor is None)

        def _tail(values, start):
            # Entries from `start` onwards, passing None through untouched.
            return None if values is None else values[start:]

        # Construct & cache decode-phase attention metadata structure
        decode_meta = XFormersMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=_tail(self.slot_mapping, self.num_prefill_tokens),
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens_tensor=_tail(self.seq_lens_tensor, self.num_prefills),
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            block_tables=_tail(self.block_tables, self.num_prefills),
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        self._cached_decode_metadata = decode_meta

        # Batch may be composed of prefill|decodes, adjust query start indices
        # to refer to the start of decodes when the two are split apart.
        # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
        # NOTE(review): query_start_loc is not passed to the constructor
        # above, so it keeps its default of None and this branch does not
        # run as written - confirm whether that is intended.
        if decode_meta.query_start_loc is not None:
            decode_starts = decode_meta.query_start_loc
            decode_meta.query_start_loc = decode_starts - decode_starts[0]
        return decode_meta
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def _get_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_type: str,
) -> Optional[List[Optional[AttentionBias]]]:
    '''
    Extract appropriate attention bias from attention metadata
    according to attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:
    * The attention bias list stored for the given attention type, or
      None if it has not been set yet. (The metadata fields hold a list
      of biases, not a single bias, hence the list return type.)

    Raises:
    * AttributeError if attn_type is not one of the AttentionType values
      handled below.
    '''

    # Decoder self-attention and encoder-only attention share one field.
    if (attn_type == AttentionType.DECODER
            or attn_type == AttentionType.ENCODER_ONLY):
        return attn_metadata.attn_bias
    if attn_type == AttentionType.ENCODER:
        return attn_metadata.encoder_attn_bias
    if attn_type == AttentionType.ENCODER_DECODER:
        return attn_metadata.cross_attn_bias
    raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def _set_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_bias: List[Optional[AttentionBias]],
    attn_type: str,
) -> None:
    '''
    Store an attention bias on the attention-metadata field that matches
    the attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_bias: The desired attention bias value
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Raises:
    * AttributeError if attn_type is not one of the AttentionType values
      handled below.
    '''

    # Pick the destination field first, then assign once.
    if (attn_type == AttentionType.DECODER
            or attn_type == AttentionType.ENCODER_ONLY):
        target_field = "attn_bias"
    elif attn_type == AttentionType.ENCODER:
        target_field = "encoder_attn_bias"
    elif attn_type == AttentionType.ENCODER_DECODER:
        target_field = "cross_attn_bias"
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    setattr(attn_metadata, target_field, attn_bias)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
    """Metadata builder for the xFormers backend.

    Sets `_metadata_cls` to `XFormersMetadata`; everything else is
    inherited from `CommonMetadataBuilder`.
    """

    _metadata_cls = XFormersMetadata
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
| 355 |
+
"""
|
| 356 |
+
If the input tensors contain prompt tokens, the layout is as follows:
|
| 357 |
+
|<--------------- num_prefill_tokens ----------------->|
|
| 358 |
+
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
|
| 359 |
+
|
| 360 |
+
Otherwise, the layout is as follows:
|
| 361 |
+
|<----------------- num_decode_tokens ------------------>|
|
| 362 |
+
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
|
| 363 |
+
|
| 364 |
+
Generation tokens can contain padding when cuda-graph is used.
|
| 365 |
+
Currently, prompt tokens don't contain any padding.
|
| 366 |
+
|
| 367 |
+
The prompts might have different lengths, while the generation tokens
|
| 368 |
+
always have length 1.
|
| 369 |
+
|
| 370 |
+
If chunked prefill is enabled, prefill tokens and decode tokens can be
|
| 371 |
+
batched together in a flattened 1D query.
|
| 372 |
+
|
| 373 |
+
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|
| 374 |
+
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
|
| 375 |
+
|
| 376 |
+
Currently, cuda graph is disabled for chunked prefill, meaning there's no
|
| 377 |
+
padding between prefill and decode tokens.
|
| 378 |
+
"""
|
| 379 |
+
|
| 380 |
+
def __init__(
|
| 381 |
+
self,
|
| 382 |
+
num_heads: int,
|
| 383 |
+
head_size: int,
|
| 384 |
+
scale: float,
|
| 385 |
+
num_kv_heads: int,
|
| 386 |
+
alibi_slopes: Optional[List[float]],
|
| 387 |
+
sliding_window: Optional[int],
|
| 388 |
+
kv_cache_dtype: str,
|
| 389 |
+
blocksparse_params: Optional[Dict[str, Any]] = None,
|
| 390 |
+
logits_soft_cap: Optional[float] = None,
|
| 391 |
+
attn_type: str = AttentionType.DECODER,
|
| 392 |
+
) -> None:
|
| 393 |
+
if blocksparse_params is not None:
|
| 394 |
+
raise ValueError(
|
| 395 |
+
"XFormers does not support block-sparse attention.")
|
| 396 |
+
if logits_soft_cap is not None:
|
| 397 |
+
logger.warning_once("XFormers does not support logits soft cap. "
|
| 398 |
+
"Outputs may be slightly off.")
|
| 399 |
+
self.num_heads = num_heads
|
| 400 |
+
self.head_size = head_size
|
| 401 |
+
self.scale = float(scale)
|
| 402 |
+
self.num_kv_heads = num_kv_heads
|
| 403 |
+
if alibi_slopes is not None:
|
| 404 |
+
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
| 405 |
+
self.alibi_slopes = alibi_slopes
|
| 406 |
+
self.sliding_window = sliding_window
|
| 407 |
+
self.kv_cache_dtype = kv_cache_dtype
|
| 408 |
+
|
| 409 |
+
assert self.num_heads % self.num_kv_heads == 0
|
| 410 |
+
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
| 411 |
+
|
| 412 |
+
suppored_head_sizes = PagedAttention.get_supported_head_sizes()
|
| 413 |
+
if head_size not in suppored_head_sizes:
|
| 414 |
+
raise ValueError(
|
| 415 |
+
f"Head size {head_size} is not supported by PagedAttention. "
|
| 416 |
+
f"Supported head sizes are: {suppored_head_sizes}.")
|
| 417 |
+
|
| 418 |
+
self.attn_type = attn_type
|
| 419 |
+
|
| 420 |
+
def forward(
|
| 421 |
+
self,
|
| 422 |
+
layer: AttentionLayer,
|
| 423 |
+
query: torch.Tensor,
|
| 424 |
+
key: Optional[torch.Tensor],
|
| 425 |
+
value: Optional[torch.Tensor],
|
| 426 |
+
kv_cache: torch.Tensor,
|
| 427 |
+
attn_metadata: "XFormersMetadata",
|
| 428 |
+
output: Optional[torch.Tensor] = None,
|
| 429 |
+
) -> torch.Tensor:
|
| 430 |
+
"""Forward pass with xFormers and PagedAttention.
|
| 431 |
+
|
| 432 |
+
For decoder-only models: query, key and value must be non-None.
|
| 433 |
+
|
| 434 |
+
For encoder/decoder models:
|
| 435 |
+
* XFormersImpl.forward() may be invoked for both self- and cross-
|
| 436 |
+
attention layers.
|
| 437 |
+
* For self-attention: query, key and value must be non-None.
|
| 438 |
+
* For cross-attention:
|
| 439 |
+
* Query must be non-None
|
| 440 |
+
* During prefill, key and value must be non-None; key and value
|
| 441 |
+
get cached for use during decode.
|
| 442 |
+
* During decode, key and value may be None, since:
|
| 443 |
+
(1) key and value tensors were cached during prefill, and
|
| 444 |
+
(2) cross-attention key and value tensors do not grow during
|
| 445 |
+
decode
|
| 446 |
+
|
| 447 |
+
A note on how the attn_type (attention type enum) argument impacts
|
| 448 |
+
attention forward() behavior:
|
| 449 |
+
|
| 450 |
+
* DECODER: normal decoder-only behavior;
|
| 451 |
+
use decoder self-attention block table
|
| 452 |
+
* ENCODER: no KV caching; pass encoder sequence
|
| 453 |
+
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
|
| 454 |
+
max_encoder_seq_len) to kernel, in lieu of decoder
|
| 455 |
+
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len).
|
| 456 |
+
Used for encoder branch of encoder-decoder models.
|
| 457 |
+
* ENCODER_ONLY: no kv_caching, uses the normal attention
|
| 458 |
+
attributes (seq_lens/seq_lens_tensor/max_seq_len).
|
| 459 |
+
* ENCODER_DECODER: cross-attention behavior;
|
| 460 |
+
use cross-attention block table for caching KVs derived
|
| 461 |
+
from encoder hidden states; since KV sequence lengths
|
| 462 |
+
will match encoder sequence lengths, pass encoder sequence
|
| 463 |
+
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
|
| 464 |
+
max_encoder_seq_len)
|
| 465 |
+
|
| 466 |
+
Args:
|
| 467 |
+
query: shape = [num_tokens, num_heads * head_size]
|
| 468 |
+
key: shape = [num_tokens, num_kv_heads * head_size]
|
| 469 |
+
value: shape = [num_tokens, num_kv_heads * head_size]
|
| 470 |
+
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
|
| 471 |
+
NOTE: kv_cache will be an empty tensor with shape [0]
|
| 472 |
+
for profiling run.
|
| 473 |
+
attn_metadata: Metadata for attention.
|
| 474 |
+
attn_type: Select attention type, between encoder attention,
|
| 475 |
+
decoder self-attention, or encoder/decoder cross-
|
| 476 |
+
attention. Defaults to decoder self-attention,
|
| 477 |
+
which is the vLLM default generally
|
| 478 |
+
Returns:
|
| 479 |
+
shape = [num_tokens, num_heads * head_size]
|
| 480 |
+
"""
|
| 481 |
+
attn_type = self.attn_type
|
| 482 |
+
# Check that appropriate attention metadata attributes are
|
| 483 |
+
# selected for the desired attention type
|
| 484 |
+
if (attn_type == AttentionType.ENCODER
|
| 485 |
+
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
| 486 |
+
raise AttributeError("Encoder attention requires setting "
|
| 487 |
+
"encoder metadata attributes.")
|
| 488 |
+
|
| 489 |
+
elif (attn_type == AttentionType.ENCODER_DECODER
|
| 490 |
+
and (not attn_metadata.is_all_cross_attn_metadata_set)):
|
| 491 |
+
raise AttributeError("Encoder/decoder cross-attention "
|
| 492 |
+
"requires setting cross-attention "
|
| 493 |
+
"metadata attributes.")
|
| 494 |
+
|
| 495 |
+
query = query.view(-1, self.num_heads, self.head_size)
|
| 496 |
+
if key is not None:
|
| 497 |
+
assert value is not None
|
| 498 |
+
key = key.view(-1, self.num_kv_heads, self.head_size)
|
| 499 |
+
value = value.view(-1, self.num_kv_heads, self.head_size)
|
| 500 |
+
else:
|
| 501 |
+
assert value is None
|
| 502 |
+
|
| 503 |
+
# Self-attention vs. cross-attention will impact
|
| 504 |
+
# which KV cache memory-mapping & which
|
| 505 |
+
# seqlen datastructures we utilize
|
| 506 |
+
|
| 507 |
+
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
|
| 508 |
+
# KV-cache during decoder-self- or
|
| 509 |
+
# encoder-decoder-cross-attention, but not
|
| 510 |
+
# during encoder attention.
|
| 511 |
+
#
|
| 512 |
+
# Even if there are no new key/value pairs to cache,
|
| 513 |
+
# we still need to break out key_cache and value_cache
|
| 514 |
+
# i.e. for later use by paged attention
|
| 515 |
+
key_cache, value_cache = PagedAttention.split_kv_cache(
|
| 516 |
+
kv_cache, self.num_kv_heads, self.head_size)
|
| 517 |
+
|
| 518 |
+
if (key is not None) and (value is not None):
|
| 519 |
+
|
| 520 |
+
if attn_type == AttentionType.ENCODER_DECODER:
|
| 521 |
+
# Update cross-attention KV cache (prefill-only)
|
| 522 |
+
# During cross-attention decode, key & value will be None,
|
| 523 |
+
# preventing this IF-statement branch from running
|
| 524 |
+
updated_slot_mapping = attn_metadata.cross_slot_mapping
|
| 525 |
+
else:
|
| 526 |
+
# Update self-attention KV cache (prefill/decode)
|
| 527 |
+
updated_slot_mapping = attn_metadata.slot_mapping
|
| 528 |
+
|
| 529 |
+
# Reshape the input keys and values and store them in the cache.
|
| 530 |
+
# If kv_cache is not provided, the new key and value tensors are
|
| 531 |
+
# not cached. This happens during the initial memory
|
| 532 |
+
# profiling run.
|
| 533 |
+
PagedAttention.write_to_paged_cache(
|
| 534 |
+
key, value, key_cache, value_cache, updated_slot_mapping,
|
| 535 |
+
self.kv_cache_dtype, layer._k_scale, layer._v_scale)
|
| 536 |
+
(num_prefill_query_tokens, num_prefill_kv_tokens,
|
| 537 |
+
num_decode_query_tokens) = \
|
| 538 |
+
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
|
| 539 |
+
|
| 540 |
+
output = torch.empty_like(query)
|
| 541 |
+
# Query for decode. KV is not needed because it is already cached.
|
| 542 |
+
decode_query = query[num_prefill_query_tokens:]
|
| 543 |
+
# QKV for prefill.
|
| 544 |
+
query = query[:num_prefill_query_tokens]
|
| 545 |
+
if key is not None and value is not None:
|
| 546 |
+
key = key[:num_prefill_kv_tokens]
|
| 547 |
+
value = value[:num_prefill_kv_tokens]
|
| 548 |
+
|
| 549 |
+
assert query.shape[0] == num_prefill_query_tokens
|
| 550 |
+
assert decode_query.shape[0] == num_decode_query_tokens
|
| 551 |
+
|
| 552 |
+
if prefill_meta := attn_metadata.prefill_metadata:
|
| 553 |
+
# Prompt run.
|
| 554 |
+
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
|
| 555 |
+
# normal attention.
|
| 556 |
+
# block tables are empty if the prompt does not have a cached
|
| 557 |
+
# prefix.
|
| 558 |
+
out = self._run_memory_efficient_xformers_forward(
|
| 559 |
+
query, key, value, prefill_meta, attn_type=attn_type)
|
| 560 |
+
assert out.shape == output[:num_prefill_query_tokens].shape
|
| 561 |
+
output[:num_prefill_query_tokens] = out
|
| 562 |
+
else:
|
| 563 |
+
assert attn_type != AttentionType.ENCODER_ONLY, (
|
| 564 |
+
"Encoder-only models should not have prefix attention.")
|
| 565 |
+
|
| 566 |
+
assert prefill_meta.query_start_loc is not None
|
| 567 |
+
assert prefill_meta.max_query_len is not None
|
| 568 |
+
|
| 569 |
+
# prefix-enabled attention
|
| 570 |
+
# TODO(Hai) this triton kernel has regression issue (broke) to
|
| 571 |
+
# deal with different data types between KV and FP8 KV cache,
|
| 572 |
+
# to be addressed separately.
|
| 573 |
+
out = PagedAttention.forward_prefix(
|
| 574 |
+
query,
|
| 575 |
+
key,
|
| 576 |
+
value,
|
| 577 |
+
self.kv_cache_dtype,
|
| 578 |
+
key_cache,
|
| 579 |
+
value_cache,
|
| 580 |
+
prefill_meta.block_tables,
|
| 581 |
+
prefill_meta.query_start_loc,
|
| 582 |
+
prefill_meta.seq_lens_tensor,
|
| 583 |
+
prefill_meta.context_lens_tensor,
|
| 584 |
+
prefill_meta.max_query_len,
|
| 585 |
+
self.alibi_slopes,
|
| 586 |
+
self.sliding_window,
|
| 587 |
+
layer._k_scale,
|
| 588 |
+
layer._v_scale,
|
| 589 |
+
)
|
| 590 |
+
assert output[:num_prefill_query_tokens].shape == out.shape
|
| 591 |
+
output[:num_prefill_query_tokens] = out
|
| 592 |
+
|
| 593 |
+
if decode_meta := attn_metadata.decode_metadata:
|
| 594 |
+
assert attn_type != AttentionType.ENCODER_ONLY, (
|
| 595 |
+
"Encoder-only models should not have decode metadata.")
|
| 596 |
+
|
| 597 |
+
(
|
| 598 |
+
seq_lens_arg,
|
| 599 |
+
max_seq_len_arg,
|
| 600 |
+
block_tables_arg,
|
| 601 |
+
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
|
| 602 |
+
|
| 603 |
+
output[num_prefill_query_tokens:] = PagedAttention.forward_decode(
|
| 604 |
+
decode_query,
|
| 605 |
+
key_cache,
|
| 606 |
+
value_cache,
|
| 607 |
+
block_tables_arg,
|
| 608 |
+
seq_lens_arg,
|
| 609 |
+
max_seq_len_arg,
|
| 610 |
+
self.kv_cache_dtype,
|
| 611 |
+
self.num_kv_heads,
|
| 612 |
+
self.scale,
|
| 613 |
+
self.alibi_slopes,
|
| 614 |
+
layer._k_scale,
|
| 615 |
+
layer._v_scale,
|
| 616 |
+
)
|
| 617 |
+
|
| 618 |
+
# Reshape the output tensor.
|
| 619 |
+
return output.view(-1, self.num_heads * self.head_size)
|
| 620 |
+
|
| 621 |
+
def _run_memory_efficient_xformers_forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_metadata: XFormersMetadata,
    attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
    """Attention for 1D query of multiple prompts. Multiple prompt
    tokens are flattened in to `query` input.

    See https://facebookresearch.github.io/xformers/components/ops.html
    for API spec.

    Args:
        query: shape = [num_prefill_tokens, num_heads, head_size]
        key: shape = [num_prefill_tokens, num_kv_heads, head_size]
        value: shape = [num_prefill_tokens, num_kv_heads, head_size]
        attn_metadata: Metadata for attention. The attention bias built
            here is stored back on it via `_set_attn_bias` so that later
            calls can reuse it.
        attn_type: Select attention type, between encoder attention,
                   decoder self-attention, or encoder/decoder cross-
                   attention. Defaults to decoder self-attention,
                   which is the vLLM default generally

    Returns:
        The attention output, viewed with the same shape as the `query`
        that was passed in.

    Raises:
        ValueError: if no bias is cached, ALiBi is disabled and
            `attn_type` is not one of the known attention types.
    """

    # Keep a handle on the caller's layout; `query` may be reshaped below
    # and the result must be returned in the original shape.
    original_query = query
    if self.num_kv_heads != self.num_heads:
        # GQA/MQA requires the shape [B, M, G, H, K].
        # Note that the output also has the same shape (which is different
        # from a spec from the doc).
        query = query.view(query.shape[0], self.num_kv_heads,
                           self.num_queries_per_kv, query.shape[-1])
        key = key[:, :, None, :].expand(key.shape[0], self.num_kv_heads,
                                        self.num_queries_per_kv,
                                        key.shape[-1])
        value = value[:, :, None, :].expand(value.shape[0],
                                            self.num_kv_heads,
                                            self.num_queries_per_kv,
                                            value.shape[-1])

    # Set attention bias if not provided. This typically happens at
    # the very first attention layer of every iteration.
    # FIXME(woosuk): This is a hack.
    attn_bias = _get_attn_bias(attn_metadata, attn_type)
    if attn_bias is None:
        if self.alibi_slopes is None:

            # Cross attention block of decoder branch of encoder-decoder
            # model uses seq_lens for dec / encoder_seq_lens for enc
            if attn_type == AttentionType.ENCODER_DECODER:
                assert attn_metadata.seq_lens is not None
                assert attn_metadata.encoder_seq_lens is not None

                # Cross-attention mask is non-causal
                attn_bias = BlockDiagonalMask.from_seqlens(
                    attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)

            # Encoder branch of encoder-decoder model uses
            # attn_metadata.encoder_seq_lens
            elif attn_type == AttentionType.ENCODER:

                assert attn_metadata.encoder_seq_lens is not None

                # Encoder self-attention mask is non-causal
                attn_bias = BlockDiagonalMask.from_seqlens(
                    attn_metadata.encoder_seq_lens)

            # Self-attention block of encoder-only model just
            # uses the seq_lens directly.
            elif attn_type == AttentionType.ENCODER_ONLY:
                assert attn_metadata.seq_lens is not None

                # Encoder self-attention mask is non-causal
                attn_bias = BlockDiagonalMask.from_seqlens(
                    attn_metadata.seq_lens)

            # Self-attention block of decoder branch just
            # uses the seq_lens directly
            elif attn_type == AttentionType.DECODER:
                assert attn_metadata.seq_lens is not None

                # Decoder self-attention mask is causal
                attn_bias = BlockDiagonalCausalMask.from_seqlens(
                    attn_metadata.seq_lens)
            else:
                # Exceptions do not interpolate logging-style arguments,
                # so the message has to be formatted explicitly.
                raise ValueError(f"Unknown AttentionType: {attn_type}")

            if self.sliding_window is not None:
                attn_bias = attn_bias.make_local_attention(
                    self.sliding_window)
            # Wrap in a list so both the ALiBi and non-ALiBi paths store
            # the bias in the same (indexable) form.
            attn_bias = [attn_bias]
        else:
            assert attn_type == AttentionType.DECODER
            assert attn_metadata.seq_lens is not None
            attn_bias = _make_alibi_bias(self.alibi_slopes,
                                         self.num_kv_heads, query.dtype,
                                         attn_metadata.seq_lens)

        _set_attn_bias(attn_metadata, attn_bias, attn_type)

    # No alibi slopes.
    # TODO(woosuk): Too many view operations. Let's try to reduce
    # them in the future for code readability.
    if self.alibi_slopes is None:
        # Add the batch dimension.
        query = query.unsqueeze(0)
        key = key.unsqueeze(0)
        value = value.unsqueeze(0)
        out = xops.memory_efficient_attention_forward(
            query,
            key,
            value,
            attn_bias=attn_bias[0],
            p=0.0,
            scale=self.scale)
        return out.view_as(original_query)

    # Attention with alibi slopes.
    # FIXME(woosuk): Because xformers does not support dynamic sequence
    # lengths with custom attention bias, we process each prompt one by
    # one. This is inefficient, especially when we have many short prompts.
    assert attn_metadata.seq_lens is not None
    output = torch.empty_like(original_query)
    start = 0
    for i, seq_len in enumerate(attn_metadata.seq_lens):
        end = start + seq_len
        out = xops.memory_efficient_attention_forward(
            query[None, start:end],
            key[None, start:end],
            value[None, start:end],
            attn_bias=attn_bias[i],
            p=0.0,
            scale=self.scale)
        # TODO(woosuk): Unnecessary copy. Optimize.
        output[start:end].copy_(out.view_as(original_query[start:end]))
        start = end
    return output
|
| 759 |
+
|
| 760 |
+
|
| 761 |
+
def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[AttentionBias]:
    """Build one ALiBi attention bias per sequence.

    Args:
        alibi_slopes: Per-head slopes; its first dimension gives the number
            of heads and its device is used for the bias tensors.
        num_kv_heads: Number of KV heads. When it differs from the number
            of heads, the head dimension of the bias is split into
            (num_kv_heads, num_heads // num_kv_heads).
        dtype: dtype of the generated bias tensors.
        seq_lens: Length of each sequence; one bias is produced per entry.

    Returns:
        A list of `LowerTriangularMaskWithTensorBias`, in the same order
        as `seq_lens`.
    """
    result: List[AttentionBias] = []
    for length in seq_lens:
        positions = torch.arange(length, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        # Relative-position matrix: entry [i, j] holds j - i.
        relative = positions[None, :] - positions[:, None]

        # The backing buffer's last dimension is rounded up to a multiple
        # of 8 and then sliced back to `length`.
        # NOTE(review): presumably an xformers alignment requirement for
        # tensor biases -- confirm.
        width = (length + 7) // 8 * 8
        head_count = alibi_slopes.shape[0]
        buffer = torch.empty(
            1,  # batch size
            head_count,
            length,
            width,
            device=alibi_slopes.device,
            dtype=dtype,
        )
        bias = buffer[:, :, :, :length]
        # Broadcast the 2D relative positions across batch and heads.
        bias.copy_(relative)
        # Scale each head's bias by its slope, in place.
        bias.mul_(alibi_slopes[:, None, None])
        if head_count != num_kv_heads:
            bias = bias.unflatten(
                1, (num_kv_heads, head_count // num_kv_heads))
        result.append(LowerTriangularMaskWithTensorBias(bias))

    return result
|
.venv/lib/python3.11/site-packages/vllm/attention/ops/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (191 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/hpu_paged_attn.cpython-311.pyc
ADDED
|
Binary file (5.95 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/ipex_attn.cpython-311.pyc
ADDED
|
Binary file (8.01 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/nki_flash_attn.cpython-311.pyc
ADDED
|
Binary file (25.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/paged_attn.cpython-311.pyc
ADDED
|
Binary file (8.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/prefix_prefill.cpython-311.pyc
ADDED
|
Binary file (31.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/triton_decode_attention.cpython-311.pyc
ADDED
|
Binary file (19.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/triton_flash_attention.cpython-311.pyc
ADDED
|
Binary file (20.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (213 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/blocksparse_attention_kernel.cpython-311.pyc
ADDED
|
Binary file (14.8 kB). View file
|
|
|