diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/__init__.py b/.venv/lib/python3.11/site-packages/vllm/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..85c5715faba7f11c95db33a9eb82a7f3946613f8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/__init__.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, AttentionType) +from vllm.attention.layer import Attention +from vllm.attention.selector import get_attn_backend + +__all__ = [ + "Attention", + "AttentionBackend", + "AttentionMetadata", + "AttentionType", + "AttentionMetadataBuilder", + "Attention", + "AttentionState", + "get_attn_backend", +] diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38b735df77e39d5f9b9578081b1143c3fb700a2a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/__pycache__/layer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/__pycache__/layer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5d31260886a4b4fbc8928f5459c0abf7c932691 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/__pycache__/layer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/__pycache__/selector.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/__pycache__/selector.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d95ff4f893e34dd6db0393138009a85de5bd54f3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/__pycache__/selector.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/__init__.py b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..682629b7b2b38a5d5da4ce21665db588d2a691d5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/abstract.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/abstract.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e73955973af6b1e8c6b29947f2e4b7a118451448 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/abstract.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/blocksparse_attn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/blocksparse_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2991fbaa89d15db214a66a83dca2fc251a1c753c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/blocksparse_attn.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/flash_attn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/flash_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d21f31a25943bf66c49e6aea31fa4d7b6c801f0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/flash_attn.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/flashinfer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/flashinfer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50c83e72a16084e544643658540a11d7987ae209 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/flashinfer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/hpu_attn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/hpu_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e274afae93d07d982fcaeffbb638db05861566a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/hpu_attn.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/ipex_attn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/ipex_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..743aba393088c0592d296cfc5f98ad94a97dd7e2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/ipex_attn.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/openvino.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/openvino.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73abd2d0d16462b7cb1f33033fc19bd7bf7b74fc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/openvino.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/pallas.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/pallas.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43cc26c259dddd75a873200dbbea365e4b614ad2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/pallas.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/placeholder_attn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/placeholder_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74d6ea1d2125cfde2743931c2406f1b4b61d79c9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/placeholder_attn.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/rocm_flash_attn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/rocm_flash_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bcab9898d154cf53d642217c71a322855faa708d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/rocm_flash_attn.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/torch_sdpa.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/torch_sdpa.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2312cbb629bb4b3176ef00f58c163824672d8663 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/torch_sdpa.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/triton_mla.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/triton_mla.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02b9dad249e04f73a011c2b5ed887e8e294a1b9d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/triton_mla.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae09ab64eb55f54990ee48506441c8ffe6c3152d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/xformers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/xformers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba7f01c7d93c1bdecd2ee7fbf2693473dbd5348a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/xformers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/abstract.py b/.venv/lib/python3.11/site-packages/vllm/attention/backends/abstract.py new file mode 100644 index 0000000000000000000000000000000000000000..5f0a540135402e6afb9a37ea94a7d0b5bacc9325 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/backends/abstract.py @@ -0,0 +1,296 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from contextlib import contextmanager +from dataclasses import dataclass, fields +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, + Protocol, Set, Tuple, Type, TypeVar) + +import torch + +from vllm.multimodal import MultiModalPlaceholderMap + +if TYPE_CHECKING: + from vllm.worker.model_runner_base import (ModelRunnerBase, + ModelRunnerInputBase, + ModelRunnerInputBuilderBase) + + +class AttentionType: + """ + Attention type. + Use string to be compatible with `torch.compile`. + """ + # Decoder attention between previous layer Q/K/V + DECODER = "decoder" + # Encoder attention between previous layer Q/K/V for encoder-decoder + ENCODER = "encoder" + # Encoder attention between previous layer Q/K/V + ENCODER_ONLY = "encoder_only" + # Attention between dec. Q and enc. K/V for encoder-decoder + ENCODER_DECODER = "encoder_decoder" + + +class AttentionBackend(ABC): + """Abstract class for attention backends.""" + # For some attention backends, we allocate an output tensor before + # calling the custom op. When piecewise cudagraph is enabled, this + # makes sure the output tensor is allocated inside the cudagraph. + accept_output_buffer: bool = False + + @staticmethod + @abstractmethod + def get_name() -> str: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_impl_cls() -> Type["AttentionImpl"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_state_cls() -> Type["AttentionState"]: + raise NotImplementedError + + @classmethod + def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": + return cls.get_metadata_cls()(*args, **kwargs) + + @staticmethod + @abstractmethod + def get_builder_cls() -> Type["AttentionMetadataBuilder"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise NotImplementedError + + @staticmethod + @abstractmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + raise NotImplementedError + + def advance_step(self, model_input: "ModelRunnerInputBase", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, num_seqs: int, num_queries: int) -> None: + raise NotImplementedError + + +@dataclass +class AttentionMetadata: + """Attention metadata for prefill and decode batched together.""" + # Total number of prefill requests. + num_prefills: int + # Number of prefill tokens. + num_prefill_tokens: int + # Number of decode tokens. Note that it is equivalent to the number of + # decode requests. + num_decode_tokens: int + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor + + # The index maps that relate multi-modal embeddings to the corresponding + # placeholders. + # + # N.B. These aren't really related to attention and don't belong on this + # type -- this is just a temporary solution to make them available to + # `model_executable`. + multi_modal_placeholder_index_maps: Optional[Dict[ + str, MultiModalPlaceholderMap.IndexMap]] + + # Enable/disable KV scales calculation. This is so that we can disable the + # calculation until after prefill and cuda graph capture. + enable_kv_scales_calculation: bool + + @property + @abstractmethod + def prefill_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run prefill + attention.""" + pass + + @property + @abstractmethod + def decode_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run decode + attention.""" + pass + + def asdict_zerocopy(self, + skip_fields: Optional[Set[str]] = None + ) -> Dict[str, Any]: + """Similar to dataclasses.asdict, but avoids deepcopying.""" + if skip_fields is None: + skip_fields = set() + # Note that if we add dataclasses as fields, they will need + # similar handling. + return { + field.name: getattr(self, field.name) + for field in fields(self) if field.name not in skip_fields + } + + +T = TypeVar("T", bound=AttentionMetadata) + + +class AttentionState(ABC, Generic[T]): + """Holds attention backend-specific objects reused during the + lifetime of the model runner.""" + + @abstractmethod + def __init__(self, runner: "ModelRunnerBase"): + ... + + @abstractmethod + @contextmanager + def graph_capture(self, max_batch_size: int): + """Context manager used when capturing CUDA graphs.""" + yield + + @abstractmethod + def graph_clone(self, batch_size: int) -> "AttentionState[T]": + """Clone attention state to save in CUDA graph metadata.""" + ... + + @abstractmethod + def graph_capture_get_metadata_for_batch( + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> T: + """Get attention metadata for CUDA graph capture of batch_size.""" + ... + + @abstractmethod + def get_graph_input_buffers( + self, + attn_metadata: T, + is_encoder_decoder_model: bool = False) -> Dict[str, Any]: + """Get attention-specific input buffers for CUDA graph capture.""" + ... + + @abstractmethod + def prepare_graph_input_buffers( + self, + input_buffers: Dict[str, Any], + attn_metadata: T, + is_encoder_decoder_model: bool = False) -> None: + """In-place modify input buffers dict for CUDA graph replay.""" + ... + + @abstractmethod + def begin_forward(self, model_input: "ModelRunnerInputBase") -> None: + """Prepare state for forward pass.""" + ... + + +class AttentionMetadataBuilder(ABC, Generic[T]): + """Abstract class for attention metadata builders.""" + + @abstractmethod + def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None: + """Create the builder, remember some configuration and parameters.""" + raise NotImplementedError + + @abstractmethod + def prepare(self) -> None: + """Prepare for one batch.""" + raise NotImplementedError + + @abstractmethod + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int) -> T: + """Build attention metadata with on-device tensors.""" + raise NotImplementedError + + +class AttentionLayer(Protocol): + + _k_scale: torch.Tensor + _v_scale: torch.Tensor + _k_scale_float: float + _v_scale_float: float + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + ... + + +class AttentionImpl(ABC, Generic[T]): + + @abstractmethod + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + raise NotImplementedError + + @abstractmethod + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: T, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + +class MLAAttentionImpl(AttentionImpl[T], Generic[T]): + + @abstractmethod + def forward( + self, + layer: AttentionLayer, + hidden_states_or_cq: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: T, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/blocksparse_attn.py b/.venv/lib/python3.11/site-packages/vllm/attention/backends/blocksparse_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..9765e7881ad9d61611026f69a05bdb4bacd4578e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/backends/blocksparse_attn.py @@ -0,0 +1,457 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import (CommonAttentionState, + CommonMetadataBuilder) +from vllm.attention.ops.blocksparse_attention.interface import ( + LocalStridedBlockSparseAttn, get_head_sliding_step) +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) + + +@dataclass +class BlocksparseParams: + max_seqlen: int + + # Num q heads per tensor-parallel rank/partition + num_heads: int # per TP partition + # Num kv heads per tensor-parallel rank/partition + num_kv_heads: int + + # block size used for blocksparse attention. + # This is the block_size used in `local_blocks`, `vert_stride`. + block_size: int + + # Number of blocks for local attention, i.e., number of + # local attended tokens / `sparse_block_size` + local_blocks: int + + # Attend to one block per every `vert_stride` blocks. + # Controlling the sparsity + vert_stride: int + """ + If to use the same vertical stride offset for all heads, + i.e., attend to the same block of tokens on all heads. + By default, it is False, i.e., attention on the non-local + blocks depends on the `head_idx`, that is on + blocks satisfying + `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0` + where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`, + `block_idx = position_id // sparse_block_size`. + See `..ops.blocksparse_attention.utils:get_sparse_attn_mask` + for more detail. + """ + homo_head: bool = False + + # If within a group, the kv offsets that each q attends is the same or no. + homo_head_group: bool = False + + # Decided by homo_head and homo_head group + head_sliding_step: int = field(init=False) + + # range of q heads to for a TP rank + active_head_range: Tuple = field(init=False) + + def __post_init__(self): + assert self.block_size > 0 + assert self.local_blocks >= 0 + assert self.vert_stride >= 1 + assert self.num_heads % self.num_kv_heads == 0 + + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + total_heads = tp_size * self.num_heads + total_kv_heads = tp_size * self.num_kv_heads + + if self.homo_head: + self.head_sliding_step = 0 + elif self.homo_head_group: + head_sliding_step = get_head_sliding_step(total_kv_heads, + self.vert_stride) + # negative indicates sliding along kv heads, i.e., homo q group + self.head_sliding_step = -head_sliding_step + else: + self.head_sliding_step = get_head_sliding_step( + total_heads, self.vert_stride) + + self.active_head_range = ( + tp_rank * self.num_heads, + (tp_rank + 1) * self.num_heads, + ) + + +class BlocksparseFlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "BLOCK_SPARSE_FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]: + return BlocksparseFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return BlocksparseFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]: + return BlocksparseFlashAttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class BlocksparseFlashAttentionMetadata(AttentionMetadata): + """A copy of Metadata for FlashAttentionBackend, + to avoid having to install flash_attn. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # Max number of query tokens for among request in the batch. + max_decode_query_len: Optional[int] = None + + _cached_prefill_metadata: Optional[ + "BlocksparseFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional[ + "BlocksparseFlashAttentionMetadata"] = None + + @property + def prefill_metadata( + self) -> Optional["BlocksparseFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = BlocksparseFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + +class BlocksparseFlashAttentionMetadataBuilder( + CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]): + + _metadata_cls = BlocksparseFlashAttentionMetadata + + +class BlocksparseFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + assert blocksparse_params is not None + assert alibi_slopes is None, ValueError( + "Alibi not support for blocksparse flash attention.") + assert sliding_window is None, ValueError( + "sliding_window is invalid for blocksparse attention.") + assert logits_soft_cap is None, ValueError( + "logits_soft_cap is invalid for blocksparse attention.") + + if "num_heads" not in blocksparse_params: + blocksparse_params["num_heads"] = num_heads + if "num_kv_heads" not in blocksparse_params: + blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads + self.blocksparse_params = BlocksparseParams(**blocksparse_params) + self.kv_cache_dtype = kv_cache_dtype + + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.alibi_slopes = alibi_slopes + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.local_blocks = self.blocksparse_params.local_blocks + self.vert_stride = self.blocksparse_params.vert_stride + self.sparse_block_size = self.blocksparse_params.block_size + self.head_sliding_step = self.blocksparse_params.head_sliding_step + + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + total_num_heads = num_heads * self.tp_size + self.bs_attn = LocalStridedBlockSparseAttn( + total_num_heads, + self.blocksparse_params.max_seqlen, + self.blocksparse_params.local_blocks, + self.blocksparse_params.vert_stride, + self.blocksparse_params.block_size, + homo_head=self.blocksparse_params.homo_head, + active_head_range=self.blocksparse_params.active_head_range, + ) + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "BlocksparseFlashAttentionImpl") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: BlocksparseFlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache.numel() > 0: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if prefill_meta := attn_metadata.prefill_metadata: + + # Prompt run. + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + + assert kv_cache.numel() == 0 \ + or prefill_meta.block_tables is None \ + or prefill_meta.block_tables.numel() == 0, \ + "Does not support prefix-enabled attention." + + output = self.bs_attn( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + sm_scale=self.scale, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + output = PagedAttention.forward_decode( + query, + key_cache, + value_cache, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + self.blocksparse_params.max_seqlen, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + layer._k_scale, + layer._v_scale, + tp_rank=self.tp_rank, + blocksparse_local_blocks=self.local_blocks, + blocksparse_vert_stride=self.vert_stride, + blocksparse_block_size=self.sparse_block_size, + blocksparse_head_sliding_step=self.head_sliding_step, + ) + + assert output is not None + # Reshape the output tensor. + return output.view(num_tokens, hidden_size) diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/flash_attn.py b/.venv/lib/python3.11/site-packages/vllm/attention/backends/flash_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..6a82127acdf7af7b64ca4b25d0e978d747cbea02 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/backends/flash_attn.py @@ -0,0 +1,942 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer with FlashAttention.""" +from collections import defaultdict +from dataclasses import dataclass +from itertools import accumulate +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType) +from vllm.attention.backends.utils import ( + PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, + compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, + get_seq_len_block_table_args, is_all_cross_attn_metadata_set, + is_all_encoder_attn_metadata_set, is_block_tables_empty) +from vllm.envs import VLLM_FLASH_ATTN_VERSION +from vllm.logger import init_logger +from vllm.multimodal import MultiModalPlaceholderMap +from vllm.platforms import current_platform +from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from vllm.vllm_flash_attn import (fa_version_unsupported_reason, + flash_attn_varlen_func, + flash_attn_with_kvcache, + is_fa_version_supported) + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + +logger = init_logger(__name__) + + +class FlashAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> Type["FlashAttentionImpl"]: + return FlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: + return FlashAttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +@dataclass +class FlashAttentionMetadata(AttentionMetadata): + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + + use_cuda_graph: bool + + # Maximum query length in the batch. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + encoder_seq_start_loc: Optional[torch.Tensor] = None + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return is_all_encoder_attn_metadata_set(self) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return is_all_cross_attn_metadata_set(self) + + @property + def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + + self._cached_prefill_metadata = FlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + + self._cached_decode_metadata = FlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=self.max_query_len, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + # Batch may be composed of prefill|decodes, adjust query start + # indices to refer to the start of decodes. E.g. + # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=block_tables, + use_cuda_graph=self.use_cuda_graph, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + if turn_prefills_into_decodes: + # When Mutli-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. + assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class FlashAttentionMetadataBuilder( + AttentionMetadataBuilder[FlashAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def _get_graph_runner_block_tables( + self, num_seqs: int, + block_tables: List[List[int]]) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + graph_block_tables = self.runner.graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + num_seqs = len(seq_lens) + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + + return FlashAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + +class FlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window - 1, + 0) if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + self.attn_type = attn_type + + # if hopper default to FA3, otherwise stick to FA2 for now + # TODO(lucas): profile FA3 on ampere to see if it makes sense to + # use FA3 as default for both + if current_platform.get_device_capability()[0] >= 9: + self.fa_version = 3 if is_fa_version_supported(3) else 2 + else: + self.fa_version = 2 + + if VLLM_FLASH_ATTN_VERSION is not None: + assert VLLM_FLASH_ATTN_VERSION in [2, 3] + self.fa_version = VLLM_FLASH_ATTN_VERSION + + if not is_fa_version_supported(self.fa_version): + logger.error("Cannot use FA version %d is not supported due to %s", + self.fa_version, + fa_version_unsupported_reason(self.fa_version)) + + assert is_fa_version_supported(self.fa_version) + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + output: shape = [num_tokens, num_heads, head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + NOTE: It in-place updates the output tensor. + """ + # NOTE(woosuk): FlashAttention does not support FP8 KV cache. + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0, ( + "key/v_scale is not supported in FlashAttention.") + + assert output is not None, "Output tensor must be provided." + + attn_type = self.attn_type + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + + kv_cache_dtype: str = self.kv_cache_dtype + softmax_scale: float = self.scale + window_size = self.sliding_window + alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes + logits_soft_cap: Optional[float] = self.logits_soft_cap + + if kv_cache.numel() > 0: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + # We skip updating the KV cache under two conditions: + # a. When the Attention Type is ENCODER. In this phase, we compute + # only the encoder attention without updating the cache. + # b. When both Key and Value are None. This occurs during + # cross-attention computation in the decoding phase, where the + # KV cache is already populated with the cross-attention + # tensor. Thus, we skip cache updates during this time. + if (attn_type != AttentionType.ENCODER) and (key is not None) and ( + value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory + # profiling run. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[0], + kv_cache[1], + updated_slot_mapping.flatten(), # type: ignore[union-attr] + kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) + decode_query = query[num_prefill_query_tokens:] + decode_output = output[num_prefill_query_tokens:] + # QKV for prefill. + query = query[:num_prefill_query_tokens] + prefill_output = output[:num_prefill_query_tokens] + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if (kv_cache.numel() == 0 or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \ + _get_query_key_seq_metadata(prefill_meta, True, attn_type) + + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] + + flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=q_seq_start_loc, + cu_seqlens_k=k_seq_start_loc, + max_seqlen_q=q_seq_len, + max_seqlen_k=k_seq_len, + softmax_scale=softmax_scale, + causal=_get_causal_option(attn_type), + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + out=prefill_output, + fa_version=self.fa_version, + ) + else: + # prefix-enabled attention + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support prefix caching") + assert prefill_meta.seq_lens is not None + max_seq_len = max(prefill_meta.seq_lens) + flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + seqused_k=prefill_meta.seq_lens_tensor, + max_seqlen_k=max_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + block_table=prefill_meta.block_tables, + softcap=logits_soft_cap, + out=prefill_output, + fa_version=self.fa_version, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + # Use flash_attn_varlen_func kernel for speculative decoding + # because different queries might have different lengths. + + assert decode_meta.max_decode_query_len is not None + # use only for actual varlen decoding + if decode_meta.max_decode_query_len > 1: + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support max_decode_query_len > 1" + ) + flash_attn_varlen_func( + q=decode_query, + k=key_cache, + v=value_cache, + cu_seqlens_q=decode_meta.query_start_loc, + max_seqlen_q=decode_meta.max_decode_query_len, + seqused_k=decode_meta.seq_lens_tensor, + max_seqlen_k=decode_meta.max_decode_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + block_table=decode_meta.block_tables, + out=decode_output, + fa_version=self.fa_version, + ) + else: + # Use flash_attn_with_kvcache for normal decoding. + ( + seq_lens_arg, + _, + block_tables_arg, + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) + flash_attn_with_kvcache( + q=decode_query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + block_table=block_tables_arg, + cache_seqlens=seq_lens_arg, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + out=decode_output.unsqueeze(1), + fa_version=self.fa_version, + ) + return output + + +def _get_query_key_seq_metadata( + attn_metadata, + is_prompt: bool, + attn_type: str, +) -> tuple: + """ + Returns sequence metadata for key and query based on the specified + attention type and whether input is a prompt. + + This function computes the starting locations and maximum sequence lengths + for key and query sequences for different attention types. + + Args: + attn_metadata: The attention metadata object + is_prompt (bool): A flag indicating if the input is a prompt + attn_type (AttentionType): The type of attention being used. + + Returns: + tuple: A tuple containing four integers: + - Starting location for the query sequence. + - Maximum sequence length for the query sequence. + - Starting location for the key sequence. + - Maximum sequence length for the key sequence. + + Raises: + AttributeError: If an invalid attention type is provided. + """ + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.seq_start_loc, max_seq_len) + + elif attn_type == AttentionType.ENCODER_DECODER: + # This is cross attention between the where the key + # is the precomputed encoder attention and query + # is the input sequence. + # Choose query max length based on whether it is prompt + # or not. + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER: + # For encoder attention both the query and the key are same i.e the + # encoder sequence. + return (attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER_ONLY: + assert is_prompt, "Should not have decode for encoder only model." + return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len, + attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def _get_causal_option(attn_type: str) -> bool: + """ + Determine whether the given attention type is suitable for causal + attention mechanisms. + + Args: + attn_type (AttentionType): The type of attention being evaluated + + Returns: + bool: Returns `True` if the attention type is suitable for causal + attention (i.e., not encoder, encoder-only, or encoder-decoder), + otherwise returns `False`. + """ + return not (attn_type == AttentionType.ENCODER + or attn_type == AttentionType.ENCODER_ONLY + or attn_type == AttentionType.ENCODER_DECODER) diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/flashinfer.py b/.venv/lib/python3.11/site-packages/vllm/attention/backends/flashinfer.py new file mode 100644 index 0000000000000000000000000000000000000000..715ed6748b846efcfdb40c124d004254afd90e36 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/backends/flashinfer.py @@ -0,0 +1,1066 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type + +from vllm.multimodal import MultiModalPlaceholderMap + +try: + from flashinfer import BatchDecodeWithPagedKVCacheWrapper + from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper + from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper + + from vllm.vllm_flash_attn import flash_attn_varlen_func + FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +except ImportError: + # Avoid turning these types into variables during type checking + if not TYPE_CHECKING: + BatchDecodeWithPagedKVCacheWrapper = None + CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None + BatchPrefillWithPagedKVCacheWrapper = None + FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, AttentionType) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.attention.layer import Attention +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, + make_tensor_with_pad) + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + + +class FlashInferBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "FLASHINFER" + + @staticmethod + def get_impl_cls() -> Type["FlashInferImpl"]: + return FlashInferImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashInferMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashInferMetadataBuilder"]: + return FlashInferMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["FlashInferState"]: + return FlashInferState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, 2, block_size, num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 128, 256] + + @staticmethod + def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + return torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + return torch.float8_e5m2 + else: + raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + + +@dataclass +class PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters. + """ + + window_left: int + logits_soft_cap: Optional[float] + sm_scale: float + + +def get_per_layer_parameters( + vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]: + """ + Scan all attention layers and determine some hyperparameters + to use during `plan`. + """ + + layers = vllm_config.compilation_config.static_forward_context + per_layer_params: Dict[str, PerLayerParameters] = {} + + for key, layer in layers.items(): + assert isinstance(layer, Attention) + + impl = layer.impl + assert isinstance(impl, FlashInferImpl) + + # Infer hyperparameters from the attention layer + window_size = impl.sliding_window + window_left = window_size[0] if window_size is not None else -1 + logits_soft_cap = impl.logits_soft_cap + sm_scale = impl.scale + + per_layer_params[key] = PerLayerParameters(window_left, + logits_soft_cap, sm_scale) + + return per_layer_params + + +def infer_global_hyperparameters( + per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters: + - `window_left` + - `logits_soft_cap` + - `sm_scale` + + So this function asserts that all layers share the same values for these + hyperparameters and returns the global values. + """ + + assert len(per_layer_params) > 0, "No attention layers found in the model." + + param_sets = list(per_layer_params.values()) + global_params = param_sets[0] + for params in param_sets: + assert params == global_params, ( + "FlashInfer backend currently only supports models in which all " + "layers share the same values for the following hyperparameters: " + "`window_left`, `logits_soft_cap`, `sm_scale`.") + + return global_params + + +class FlashInferState(AttentionState): + + def __init__(self, runner): + self.runner = runner + self._is_graph_capturing = False + self._workspace_buffer = None + self._decode_wrapper = None + self._prefill_wrapper = None + + # Global hyperparameters shared by all attention layers + self.global_hyperparameters: Optional[PerLayerParameters] = None + + self.vllm_config = get_current_vllm_config() + + def _get_workspace_buffer(self): + if self._workspace_buffer is None: + self._workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.runner.device) + return self._workspace_buffer + + def _get_prefill_wrapper(self): + if self._prefill_wrapper is None: + self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self._get_workspace_buffer(), "NHD") + return self._prefill_wrapper + + def _get_decode_wrapper(self): + if self._decode_wrapper is None: + num_qo_heads = (self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config)) + num_kv_heads = self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config) + use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( + num_qo_heads // num_kv_heads > 4) + self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self._get_workspace_buffer(), + "NHD", + use_tensor_cores=use_tensor_cores) + return self._decode_wrapper + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + self._graph_decode_wrapper = None + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + self._graph_decode_workspace_buffer = self._get_workspace_buffer() + self._graph_indices_buffer = torch.empty( + max_batch_size * self.runner.cache_config.num_gpu_blocks, + dtype=torch.int32, + device=self.runner.device) + self._graph_indptr_buffer = torch.empty(max_batch_size + 1, + dtype=torch.int32, + device=self.runner.device) + self._graph_last_page_len_buffer = torch.empty( + max_batch_size, dtype=torch.int32, device=self.runner.device) + yield + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + del self._graph_decode_workspace_buffer + del self._graph_indices_buffer + del self._graph_indptr_buffer + del self._graph_last_page_len_buffer + del self._graph_decode_wrapper + + def graph_clone(self, batch_size: int): + assert self._is_graph_capturing + state = self.__class__(self.runner) + state._workspace_buffer = self._graph_decode_workspace_buffer + state._decode_wrapper = self._graph_decode_wrapper + state._prefill_wrapper = self._get_prefill_wrapper() + return state + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + assert self._is_graph_capturing + _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1] + _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size] + + num_qo_heads = (self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config)) + num_kv_heads = self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config) + use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( + num_qo_heads // num_kv_heads > 4) + self._graph_decode_wrapper = \ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + self._graph_decode_workspace_buffer, _indptr_buffer, + self._graph_indices_buffer, _last_page_len_buffer, "NHD", + use_tensor_cores) + if self.runner.kv_cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.runner.kv_cache_dtype) + else: + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) + + paged_kv_indptr_tensor_host = torch.arange(0, + batch_size + 1, + dtype=torch.int32) + paged_kv_indices_tensor_host = torch.arange(0, + batch_size, + dtype=torch.int32) + paged_kv_last_page_len_tensor_host = torch.full((batch_size, ), + self.runner.block_size, + dtype=torch.int32) + query_start_loc_host = torch.arange(0, + batch_size + 1, + dtype=torch.int32) + + global_params = infer_global_hyperparameters( + get_per_layer_parameters(self.vllm_config)) + + attn_metadata = self.runner.attn_backend.make_metadata( + num_prefills=0, + slot_mapping=self._graph_slot_mapping[:batch_size], + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + max_prefill_seq_len=0, + block_tables=self._graph_block_tables, + paged_kv_indptr=paged_kv_indptr_tensor_host, + paged_kv_indices=paged_kv_indices_tensor_host, + paged_kv_last_page_len=paged_kv_last_page_len_tensor_host, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=self.runner.model_config.get_head_size(), + page_size=self.runner.block_size, + seq_start_loc=None, + query_start_loc=query_start_loc_host, + device=self.runner.device, + data_type=kv_cache_dtype, + q_data_type=self.runner.model_config.dtype, + use_cuda_graph=True, + decode_wrapper=self._graph_decode_wrapper, + prefill_wrapper=None, + **dataclasses.asdict(global_params), + ) + attn_metadata.begin_forward() + return attn_metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + return { + "slot_mapping": attn_metadata.slot_mapping, + } + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + return + + def begin_forward(self, model_input): + assert not self._is_graph_capturing + state = self + use_cuda_graph = model_input.attn_metadata.use_cuda_graph + is_decode = model_input.attn_metadata.num_prefills == 0 + # In case of multistep chunked-prefill, there might be prefill requests + # scheduled while CUDA graph mode is enabled. We don't run graph in that + # case. + if use_cuda_graph and is_decode: + batch_size = model_input.input_tokens.shape[0] + state = (self.runner.graph_runners[model_input.virtual_engine] + [batch_size].attn_state) + model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper( + ) + model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() + model_input.attn_metadata.begin_forward() + + +@dataclass +class FlashInferMetadata(AttentionMetadata): + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. + decode_query_len: Optional[int] = 1 + + use_cuda_graph: bool = True + + prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None + decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None + + # Metadata for the prefill stage + seq_start_loc: Optional[torch.Tensor] = None + query_start_loc: Optional[torch.Tensor] = None + block_tables: Optional[torch.Tensor] = None + + # used for GPU in-place advance_step + seq_lens_tensor: Optional[torch.Tensor] = None + block_table_bound: Optional[torch.Tensor] = None + + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: Optional[torch.Tensor] = None + # The number of query/output heads + num_qo_heads: Optional[int] = None + # The number of key/value heads + num_kv_heads: Optional[int] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + # Block size of vllm + page_size: Optional[int] = None + # The data type of the paged kv cache + data_type: torch.dtype = None + # The data type of the query + q_data_type: torch.dtype = None + # FlashInfer 0.2 encourages passing host tensors + device: torch.device = torch.device("cpu") + is_profile_run: bool = False + + # The FlashInfer backend currently supports only models in which all layers + # share the same following hyperparameters: + + # The left (inclusive) window size for the attention window, when + # set to `-1`, the window size will be set to the full length of + # the sequence. Defaults to `-1`. + window_left: int = -1 + # The attention logits soft capping value (used in Gemini, Grok and + # Gemma-2, etc.), if not provided, will be set to `0`. If greater + # than 0, the logits will be capped according to formula: + # $$\texttt{logits\_soft\_cap} \times + # \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$, + # where $x$ is the input logits. + logits_soft_cap: Optional[float] = None + # The scale used in softmax, if not provided, will be set to + # `1.0 / sqrt(head_dim)`. + sm_scale: Optional[float] = None + + def __post_init__(self): + # Refer to + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + supported_head_sizes = FlashInferBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f"received {self.head_dim}.") + + def begin_forward(self): + if self.num_prefill_tokens > 0: + if self.paged_kv_indices is None: + return + + assert self.prefill_wrapper is not None + assert self.query_start_loc is not None + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + assert self.block_table_bound is not None + assert self.seq_lens_tensor is not None + self.query_start_loc = self.query_start_loc[:self.num_prefills + 1] + batch_size = self.query_start_loc.shape[0] - 1 + assert batch_size >= 0 + # We will use flash attention for profiling to + # determine the number of blocks. Therefore, + # we don't need to prepare the input for flashinfer for profile run. + if not self.is_profile_run: + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + self.block_table_bound = self.block_table_bound.to(self.device) + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.prefill_wrapper.plan( + self.query_start_loc, + self.paged_kv_indptr[:self.num_prefills + 1], + self.paged_kv_indices, + self.paged_kv_last_page_len[:self.num_prefills], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.data_type) + if self.num_decode_tokens > 0: + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + # handle model warmup path + if self.block_table_bound is not None: + self.block_table_bound = self.block_table_bound.to(self.device) + if self.seq_lens_tensor is not None: + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + + assert self.decode_wrapper is not None + self.decode_wrapper.plan( + self.paged_kv_indptr[self.num_prefills:], + self.paged_kv_indices, + self.paged_kv_last_page_len[self.num_prefills:], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + sm_scale=self.sm_scale, + # kv-cache data type. + kv_data_type=self.data_type, + # query data type. + q_data_type=self.q_data_type) + + def asdict_zerocopy(self, + skip_fields: Optional[Set[str]] = None + ) -> Dict[str, Any]: + if skip_fields is None: + skip_fields = set() + # We need to skip the prefill/decode_wrapper field since it cannot be + # broadcasted with nccl when TP is enabled. + skip_fields.add('prefill_wrapper') + skip_fields.add('decode_wrapper') + return super().asdict_zerocopy(skip_fields) + + @property + def prefill_metadata(self) -> Optional["FlashInferMetadata"]: + if self.num_prefills == 0: + return None + return self + + @property + def decode_metadata(self) -> Optional["FlashInferMetadata"]: + if self.num_decode_tokens == 0: + return None + return self + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + + if turn_prefills_into_decodes: + # When Multi-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. + assert self.num_decode_tokens + self.num_prefills == num_seqs + # Flashinfer doesn't support speculative decoding + chunked-prefill + # + multi-step scheduling yet. + assert self.decode_query_len == 1 + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens_tensor is not None + + assert num_seqs > 0 + assert num_queries > 0 + assert model_input.attn_metadata is not None + assert sampled_token_ids is not None + + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + model_input.input_tokens[:num_queries] = sampled_token_ids.flatten() + + # Update GPU tensors + ops.advance_step_flashinfer( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=model_input.input_tokens, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables, + paged_kv_indices=self.paged_kv_indices, + paged_kv_indptr=self.paged_kv_indptr, + paged_kv_last_page_len=self.paged_kv_last_page_len, + block_table_bound=self.block_table_bound) + + +class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + + self.input_builder = input_builder + self.runner = input_builder.runner + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + # Global hyperparameters shared by all attention layers + self.global_hyperparameters: Optional[PerLayerParameters] = None + + self.vllm_config = get_current_vllm_config() + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. + # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + self.paged_kv_indices: List[int] = [] + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. + self.paged_kv_indptr: List[int] = [0] + # paged_kv_last_page_len is the length of the last page of each request + self.paged_kv_last_page_len: List[int] = [] + self.total_blocks = 0 + self.is_profile_run: bool = False + + if self.global_hyperparameters is None: + # Infer global hyperparameters, since currently we only support + # models in which all layers share the same values for the + # following hyperparameters: + # - `window_left` + # - `logits_soft_cap` + # - `sm_scale` + inferred_params = infer_global_hyperparameters( + get_per_layer_parameters(self.vllm_config)) + self.global_hyperparameters = inferred_params + self.window_left = inferred_params.window_left + self.logits_soft_cap = inferred_params.logits_soft_cap + self.sm_scale = inferred_params.sm_scale + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + computed_block_nums = inter_data.computed_block_nums + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if inter_data.prefix_cache_hit: + block_table = computed_block_nums + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + block_table = block_tables[seq_id][-curr_sliding_window_block:] + self.block_tables.append(block_table) + + is_profile_run = is_block_tables_empty(block_tables) + + # Compute slot mapping. + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + # It is not necessary to add paged_kv_indices, paged_kv_indptr, + # and paged_kv_last_page_len for profile run because we will + # create dummy inputs. + if is_profile_run: + self.is_profile_run = is_profile_run + return + + block_table = block_tables[seq_id] + self._update_paged_kv_tensors(block_table, seq_len) + + def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int): + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. + self.total_blocks += len(block_table) + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_len.append(last_page_len) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + decode_query_len = max(query_lens[self.num_prefills:], default=1) + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = self.runner.graph_block_tables[:batch_size] + max_blocks = input_block_tables.shape[1] + for i, block_table in enumerate(self.block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + input_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + input_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + block_tables = torch.from_numpy(input_block_tables).to( + device, non_blocking=True) + + last_paged_kv_indptr = self.paged_kv_indptr[-1] + self.paged_kv_indptr.extend([last_paged_kv_indptr] * + cuda_graph_pad_size) + self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + + assert device is not None + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + if len(self.paged_kv_indptr) > 0: + # extend to the maximum number of blocks as returned by the + # scheduler + self.paged_kv_indices.extend( + [0] * (self.total_blocks - len(self.paged_kv_indices))) + paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, + device="cpu", + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, + device="cpu", + dtype=torch.int) + paged_kv_last_page_len_tensor = torch.tensor( + self.paged_kv_last_page_len, device="cpu", dtype=torch.int) + block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - + 1, + device="cpu", + dtype=torch.int) + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_len_tensor = None + block_table_bound_tensor = None + + if self.runner.kv_cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.runner.kv_cache_dtype) + else: + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) + + return FlashInferMetadata( + decode_query_len=decode_query_len, + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + max_prefill_seq_len=max_prefill_seq_len, + block_tables=block_tables, + paged_kv_indptr=paged_kv_indptr_tensor, + paged_kv_indices=paged_kv_indices_tensor, + paged_kv_last_page_len=paged_kv_last_page_len_tensor, + block_table_bound=block_table_bound_tensor, + seq_lens_tensor=seq_lens_tensor, + num_qo_heads=self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config), + num_kv_heads=self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config), + head_dim=self.runner.model_config.get_head_size(), + page_size=self.block_size, + seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc, + device=device, + data_type=kv_cache_dtype, + q_data_type=self.runner.model_config.dtype, + use_cuda_graph=use_captured_graph, + is_profile_run=self.is_profile_run, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + sm_scale=self.sm_scale, + ) + + +class FlashInferImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window - 1, + 0) if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferImpl") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashInferMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # TODO: directly write to output tensor + num_heads: int = self.num_heads + head_size: int = self.head_size + num_kv_heads: int = self.num_kv_heads + kv_cache_dtype: str = self.kv_cache_dtype + softmax_scale: float = self.scale + window_size = self.sliding_window + alibi_slopes = self.alibi_slopes + logits_soft_cap = self.logits_soft_cap + + num_tokens, hidden_size = query.shape + query = query.view(-1, num_heads, head_size) + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + + if kv_cache.numel() > 0: + # Use the same reshape and cache kernel as flash attention. + ops.reshape_and_cache_flash( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 + # to process the cache when the kv_cache_dtype is fp8 + if kv_cache_dtype.startswith("fp8"): + torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + kv_cache_dtype) + kv_cache = kv_cache.view(torch_dtype) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + query = query.contiguous( + ) # Flashinfer requires query to be contiguous + # Query for decode. KV is not needed because it is already cached. + # QKV for prefill. + decode_query = query[num_prefill_tokens:] + query = query[:num_prefill_tokens] + + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + window_left = window_size[0] if window_size is not None else -1 + + prefill_output: Optional[torch.Tensor] = None + decode_output: Optional[torch.Tensor] = None + if prefill_meta := attn_metadata.prefill_metadata: + # We will use flash attention for prefill + # when kv_cache is not provided. + # This happens when vllm runs the profiling to + # determine the number of blocks. + if kv_cache.numel() == 0: + prefill_output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + ) + else: + assert prefill_meta is not None + assert prefill_meta.prefill_wrapper is not None + + assert prefill_meta.prefill_wrapper._causal + assert prefill_meta.prefill_wrapper._window_left == window_left + assert prefill_meta.prefill_wrapper._logits_soft_cap == ( + logits_soft_cap or 0.0) + assert prefill_meta.prefill_wrapper._sm_scale == softmax_scale + + prefill_output = prefill_meta.prefill_wrapper.run( + query, + kv_cache, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + ) + if decode_meta := attn_metadata.decode_metadata: + assert decode_meta is not None + assert decode_meta.decode_wrapper is not None + + assert decode_meta.decode_wrapper._window_left == window_left + assert decode_meta.decode_wrapper._logits_soft_cap == ( + logits_soft_cap or 0.0) + assert decode_meta.decode_wrapper._sm_scale == softmax_scale + + decode_output = decode_meta.decode_wrapper.run( + decode_query, + kv_cache, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + ) + + if prefill_output is None and decode_output is not None: + # Decode only batch. + output, num_tokens = decode_output, num_decode_tokens + elif decode_output is None and prefill_output is not None: + # Prefill only batch. + output, num_tokens = prefill_output, num_prefill_tokens + else: + # Chunked prefill batch does not work with speculative decoding in + # FlashInfer backend, so the query length for decode should be 1. + assert prefill_output is not None + assert decode_output is not None + assert decode_meta is not None + assert decode_meta.decode_query_len == 1 + decode_output = decode_output.squeeze(1) + output = torch.cat([prefill_output, decode_output], dim=0) + return output.view(num_tokens, hidden_size) diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/hpu_attn.py b/.venv/lib/python3.11/site-packages/vllm/attention/backends/hpu_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..1ad5e6e8e4e17e57f1f8e073b90b812b72b4765c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/backends/hpu_attn.py @@ -0,0 +1,293 @@ +# SPDX-License-Identifier: Apache-2.0 + +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +############################################################################### + +import os +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import vllm_hpu_extension.ops as ops +from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax, + VLLMKVCache) + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention, + HPUPagedAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class HPUAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "HPU_ATTN" + + @staticmethod + def get_impl_cls() -> Type["HPUAttentionImpl"]: + return HPUAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return HPUAttentionMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + HPUPagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): + """Metadata for HPUAttentionbackend.""" + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + attn_bias: Optional[torch.Tensor] + seq_lens_tensor: Optional[torch.Tensor] + + +class HPUAttentionImpl(AttentionImpl, torch.nn.Module): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + max_seq_len: int = 4096, + attn_type: str = AttentionType.DECODER, + ) -> None: + super(AttentionImpl, self).__init__() + self.kv_cache_dtype = kv_cache_dtype + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.matmul_qk = Matmul() + self.softmax = Softmax() + self.matmul_av = Matmul() + self.batch2block_matmul = Matmul() + self.block2batch_matmul = Matmul() + # NOTE(kzawora): Contiguous PA is off until model runner supports it + self.k_cache = VLLMKVCache() + self.k_cache.use_contiguous_pa = False + self.v_cache = VLLMKVCache() + self.v_cache.use_contiguous_pa = False + # NOTE(kzawora): Pipelined PA is off until model runner supports it + ops.pa_impl = ops.pa + + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = sliding_window + self.alibi_slopes = alibi_slopes + if alibi_slopes is not None: + alibi_slopes_tensor = torch.tensor(alibi_slopes, + dtype=torch.bfloat16) + self.alibi_slopes = alibi_slopes_tensor + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', + '0').lower() in ['1', 'true'] + self.fused_scaled_dot_product_attention = None + if self.prefill_usefusedsdpa: + assert alibi_slopes is None, \ + 'Prefill with FusedSDPA not supported with alibi slopes!' + try: + from habana_frameworks.torch.hpex.kernels import FusedSDPA + self.fused_scaled_dot_product_attention = ModuleFusedSDPA( + FusedSDPA) + except ImportError: + logger().warning("Could not import HPU FusedSDPA kernel. " + "vLLM will use native implementation.") + + suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "HPUAttentionImpl") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: HPUAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with xFormers and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + batch_size, seq_len, hidden_size = query.shape + _, seq_len_kv, _ = key.shape + + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + block_indices = attn_metadata.block_indices + block_offsets = attn_metadata.block_offsets + if attn_metadata.is_prompt: + key = key.unflatten(0, (block_indices.size(0), -1)) + value = value.unflatten(0, (block_indices.size(0), -1)) + if kv_cache is not None: + key_cache, value_cache = HPUPagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + key_cache = self.k_cache(key, key_cache, block_indices, + block_offsets) + value_cache = self.v_cache(value, value_cache, block_indices, + block_offsets) + + if attn_metadata.is_prompt: + # Prompt run. + if not self.prefill_usefusedsdpa: + # TODO: move this outside of model + assert attn_metadata.attn_bias is not None, \ + 'attn_bias must be set before calling model.forward!' + attn_bias = attn_metadata.attn_bias + if self.alibi_slopes is not None: + position_bias = _make_alibi_bias(self.alibi_slopes, + self.num_kv_heads, + attn_bias.dtype, + attn_bias.shape[-1]) + attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1)) + attn_bias.add_(position_bias) + else: + attn_bias = None + + query_shape = (batch_size, seq_len, self.num_heads, self.head_size) + kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, + self.head_size) + out = ops.prompt_attention( + query.view(query_shape), + key.view(kv_shape), + value.view(kv_shape), + attn_bias=attn_bias, + p=0.0, + scale=self.scale, + matmul_qk_op=self.matmul_qk, + softmax_op=self.softmax, + matmul_av_op=self.matmul_av, + fsdpa_op=self.fused_scaled_dot_product_attention, + ) + output = out.reshape(batch_size, seq_len, hidden_size) + else: + # Decoding run. + output = HPUPagedAttention.forward_decode( + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_list=attn_metadata.block_list, + block_mapping=attn_metadata.block_mapping, + block_bias=attn_metadata.attn_bias, + block_scales=attn_metadata.block_scales, + block_groups=None, + scale=self.scale, + matmul_qk_op=self.matmul_qk, + matmul_av_op=self.matmul_av, + batch2block_matmul_op=self.batch2block_matmul, + block2batch_matmul_op=self.block2batch_matmul, + keys_fetch_func=self.k_cache.fetch_from_cache, + values_fetch_func=self.v_cache.fetch_from_cache) + # Reshape the output tensor. + return output.view(batch_size, seq_len, hidden_size) + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + num_kv_heads: int, + dtype: torch.dtype, + seq_len: int, +) -> torch.Tensor: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + # Calculate a matrix where each element represents ith element- jth + # element. + bias = bias[None, :] - bias[:, None] + + padded_len = (seq_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + 1, # batch size + num_heads, + seq_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :seq_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + if num_heads != num_kv_heads: + bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) + return bias diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/ipex_attn.py b/.venv/lib/python3.11/site-packages/vllm/attention/backends/ipex_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..b4879af4cf20e93fcc9156cca0853b137a14f757 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/backends/ipex_attn.py @@ -0,0 +1,387 @@ +# SPDX-License-Identifier: Apache-2.0 +""" Attention layer with torch scaled_dot_product_attention + and PagedAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm._ipex_ops import ipex_ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) + +_PARTITION_SIZE = 512 + + +class IpexAttnBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "IPEX" + + @staticmethod + def get_impl_cls() -> Type["IpexAttnBackendImpl"]: + return IpexAttnBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["IpexAttnMetadata"]: + return IpexAttnMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + from vllm._ipex_ops import ipex_ops as ops + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + from vllm._ipex_ops import ipex_ops as ops + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +@dataclass +class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for IpexAttnBackend. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + slot_mapping: torch.Tensor + seq_lens: Optional[List[int]] + seqlen_q: Optional[torch.Tensor] + max_seqlen: Optional[int] + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[torch.Tensor]] = None + + @property + def prefill_metadata(self) -> Optional["IpexAttnMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self + + return None + + @property + def decode_metadata(self) -> Optional["IpexAttnMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0 + return None + + return self + + +class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "IPEX backend does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + if logits_soft_cap is None: + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + if kv_cache_dtype != "auto": + raise NotImplementedError( + "IPEX backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "IpexAttnBackendImpl") + + def split_kv_cache( + self, + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 1 + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: IpexAttnMetadata, # type: ignore + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with IPEX varlen_attention and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache.numel() > 0: + key_cache, value_cache = self.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + ipex_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if attn_metadata.is_prompt: + assert attn_metadata.seq_lens is not None + if (kv_cache.numel() == 0 + or attn_metadata.block_tables.numel() == 0): + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=1) + + if attn_metadata.attn_bias is None: + if self.alibi_slopes is not None: + att_masks = _make_alibi_bias( + self.alibi_slopes, query.dtype, + attn_metadata.seq_lens) # type: ignore + elif self.sliding_window is not None: + att_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, self.sliding_window, + query.dtype) # type: ignore + else: + att_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, None, dtype=query.dtype) + attn_metadata.attn_bias = att_masks + + output = torch.empty( + (num_tokens, self.num_heads, self.head_size), + dtype=query.dtype, + device=query.device) + ipex_ops.varlen_attention( + query, + key, + value, + output, + attn_metadata.seqlen_q, + attn_metadata.seqlen_q, + attn_metadata.max_seqlen, + attn_metadata.max_seqlen, + pdropout=0.0, + softmax_scale=self.scale, + zero_tensors=False, + is_causal=True, + return_softmax=False, + gen_=None, + logits_soft_cap=self.logits_soft_cap, + ) + else: + # prefix-enabled attention + raise RuntimeError( + "IPEX backend doesn't support prefix decoding.") + + else: + # Decoding run. + max_seq_len = attn_metadata.max_decode_seq_len + output = torch.empty_like(query) + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory + # shortage. + use_v1 = (max_seq_len <= 8192 and + (max_num_partitions == 1 or num_seqs * num_heads > 512)) + if use_v1: + # Run PagedAttention V1. + ipex_ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + attn_metadata.block_tables, + attn_metadata.seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ipex_ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + attn_metadata.block_tables, + attn_metadata.seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat((num_heads, 1, 1)) + bias.mul_(alibi_slopes[:, None, None]) + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype, + device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_sliding_window_bias( + seq_lens: List[int], + window_size: Optional[int], + dtype: torch.dtype, +) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + tensor = torch.full( + (1, seq_len, seq_len), + dtype=dtype, + fill_value=1, + ) + shift = 0 + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + attn_biases.append(mask.to(dtype)) + + return attn_biases diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__init__.py b/.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d2463ebdf4c18cb442469778cf80922f86f8829 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91cdb5fe4b366e1837d3752f5b12dabc214790fa Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/utils.py b/.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cd8c08e5ab47260af39beb7d8f4370a6bae46b72 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/utils.py @@ -0,0 +1,541 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, Generic, List, Optional, Tuple + +import torch +from compressed_tensors.quantization import QuantizationStrategy + +from vllm import _custom_ops as ops +from vllm import envs +from vllm.attention.backends.abstract import (AttentionLayer, + AttentionMetadata, + MLAAttentionImpl, T) +from vllm.distributed import (get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, RowParallelLinear, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsLinearMethod) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsW8A8Fp8) +from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + scaled_dequantize, scaled_quantize) +from vllm.model_executor.layers.rotary_embedding import ( + DeepseekScalingRotaryEmbedding, RotaryEmbedding) + +try: + from vllm.vllm_flash_attn import flash_attn_varlen_func +except ImportError: + from flash_attn import flash_attn_varlen_func + + +@dataclass +class MLACommonMetadata(AttentionMetadata): + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + + +class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): + """ + Common class for implementing repeated parts + + Main reference: DeepseekV2 paper, and FlashInfer Implementation + (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). + + Deepseek's MLA attention works the following way: + * Use a single latent vector to represent the entire KV cache. + * The attention "simulates" a multi-head attention, while the compute is + similar to multi-query attention. + * The dataflow is as follows, + + * B: batch/sequence length + * H: hidden size + * N: number of attention heads + * Lq: latent dimension for Q + * Lkv: latent dimension for K/V + * P: nope dimension, P+R is the actual head_dim in common attention. + * R: rope dimension, this slide of the head_dim goes through rope. + * V: V head dim. + * kv_c: latent/compressed KV + * q_c: latent/compressed Q + + # + # Outside the MLA attention backend + # + + 1. The hidden states (B, H) are projected down into cq (B, Lq) and + kv_c_k_pe (B, Lkv+R). + 2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq + and kv_c are normalized. + + # + # Inside the MLA attention backend + # + + * if prefill: + + 3. The q_c is then projected up into the multi-head version. + * q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope + (B, N, P) and q_pe (B, N, R). + 4. q_pe, k_pe are then passed through rotary embeddings. + 5. kv_c and k_pe are concatenated and inserted into the cache + 6. The kv_c is then projected up into the multi-head version. + * kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope + dimensions for K and V, which is split into k_nope (B, N, P) + and v (B, N, V). + 7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from + q_nope, q_pe, k_nope, k_pe. + 8. Attention is computued with q, k, v. + 9. The attention computation returns (B, N, V), which is projected back + to (B, H) using out projection. + + * if decode: + + 3. Here's the change, we do not perform up the full up projection for + q_c, and there is no up projection at all for kv_c. This is + achieved by the technique of "weight absorption". The paper says + "Fortunately, due to the associative law of matrix multiplication, + we can absorb WUK into WUQ, and WUV into WO" + * The q up projection turns (B, Lq) into (B, N, (P+R)), we split it + into W_UQ (Lq, N, P) and W_QR (Lq, N, R). + * The kv_c up projection turns (B, Lkv) into (B, N, (P+V)), we split + it into W_UK (Lkv, N, P) and W_UV (Lkv, N, V). + * The out projection shape W_O (N*V, H) turns (B, N, V) into (B, H). + * We can precompute the product of W_UQ and W_UK into + W_UQ_UK (Lq, N, Lkv), which is possible due to QK^T operation in + attention. + * We can precompute the product of W_UV and W_O into + W_UV_O (N, Lkv, H), which is possible due to V@O as the + "epilogue" of attention + 4. We still need to compute q_pe (B, N, R) by applying W_QR to q_latent. + 5. q_pe, k_pe are then passed through rotary embeddings. + 6. kv_c and k_pe are concatenated and inserted into the cache + 7. By applying W_UQ_UK to q_latent, we have the new q_nope of shape + (B, N, Lkv). + 8. q (B, N, (Lkv+R)), k (B, (Lkv+R)) are assembled from q_nope, q_pe, + kv_a, k_pe. v (B, Lkv) is exactly the same vector as kv_a. + 9. The attention is computed with q, k, v. Note that we just performed + a MQA attention with (LKv+R) as our head dim. + 10. The KV cache is updated using the new entries k (B, N, (Lkv+R)), + which included the v and rope values. + 11. The attention computation returns (B, N, Lkv), which is projected + back to (B, H) using W_UV_O. + + From @tsu-bin's calculation, we only want to use the absorption technique + for decode. The prefill algorithm should still use the up-projected MHA + for less flops and memory usage. + + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + q_lora_rank: Optional[int], + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + rotary_emb: RotaryEmbedding, + # q_proj should be q_b_proj if q_lora_rank is not None, but from an + # attention backend perspective we rely on the layer to pass in the + # correct matrix + q_proj: ColumnParallelLinear, + kv_b_proj: ColumnParallelLinear, + o_proj: RowParallelLinear, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_head_dim + self.v_head_dim = v_head_dim + + self.rotary_emb = rotary_emb + self.use_yarn_rope = isinstance(rotary_emb, + DeepseekScalingRotaryEmbedding) + self.q_proj = q_proj + self.kv_b_proj = kv_b_proj + self.o_proj = o_proj + + def _v_up_proj_and_o_proj(self, x): + if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + if is_fp8(self.W_UV_O): + output_parallel = apply_fp8_linear_generic( + x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales, + self.reqaunt_input_group_shape, + self.reqaunt_weight_group_shape) + else: + output_parallel = torch.matmul(x.flatten(start_dim=1), + self.W_UV_O) + if self.tp_size > 1: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + return output + else: + x = torch.einsum("bnl,lnv->bnv", x, self.W_UV) + return self.o_proj(x.reshape(-1, + self.num_heads * self.v_head_dim))[0] + + def _q_proj_and_k_up_proj(self, x): + if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + if is_fp8(self.W_Q_UK): + return apply_fp8_linear_generic( + x, self.W_Q_UK, self.W_Q_UK_scales, + self.reqaunt_input_group_shape, + self.reqaunt_weight_group_shape).view( + -1, self.num_heads, self.kv_lora_rank) + return torch.matmul(x, self.W_Q_UK)\ + .view(-1, self.num_heads, self.kv_lora_rank) + else: + x = torch.matmul(x, self.W_Q)\ + .view(-1, self.num_heads, self.qk_nope_head_dim) + return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\ + .view(-1, self.num_heads, self.kv_lora_rank) + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def is_layer_fp8(layer: LinearBase) -> bool: + return isinstance(layer.quant_method, Fp8LinearMethod) or\ + (isinstance(layer.quant_method, CompressedTensorsLinearMethod)\ + and isinstance(layer.scheme, CompressedTensorsW8A8Fp8)) + + def quantization_scheme_supported(layer: LinearBase) -> bool: + return isinstance(layer.quant_method, UnquantizedLinearMethod) or \ + is_layer_fp8(layer) + + # TODO(lucas) This is very gross, we need a more wide scale refactor of + # all the FP8 code with a more standard way of + # defining schemes/group-shapes, we should also potentially force + # quant_methods to support a decompress function + # + # returns input_group_shape, weight_group_shape + def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \ + Tuple[Tuple[int, int], Tuple[int, int]]: + if isinstance(layer.quant_method, Fp8LinearMethod): + if layer.quant_method.block_quant is not None: + weight_block_size = \ + layer.quant_method.quant_config.weight_block_size + # per-token-group (1, X), block-quantized (X, Y) + return (1, weight_block_size[-1]), weight_block_size + else: + return (-1, -1), (-1, -1) # per-tensor, per-tensor + elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\ + and isinstance(layer.scheme, CompressedTensorsW8A8Fp8): + # this is hacky but we always assume the for + # CompressedTensorsW8A8Fp8 the input is dynamic per-token + # we ignore if it is static-per-tensor since we are going to + # requantize after later anyways + strategy = layer.scheme.strategy + if strategy == QuantizationStrategy.TENSOR: + return (1, -1), (-1, -1) # per-token, per-tensor + elif strategy == QuantizationStrategy.CHANNEL: + return (1, -1), (-1, 1) # per-token, per-channel + else: + raise NotImplementedError( + f"QuantizationStrategy.{strategy} is not supported for " + "fp8 MLA, please run with VLLM_MLA_DISABLE=1") + else: + raise NotImplementedError( + "Can't determine scale group shapes for " + f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1" + ) + + def get_scales(layer: LinearBase) -> torch.Tensor: + if hasattr(layer, "weight_scale_inv"): + return layer.weight_scale_inv + return layer.weight_scale + + def get_and_maybe_dequant_weights(layer: LinearBase): + if is_layer_fp8(layer): + if isinstance(layer.quant_method, \ + CompressedTensorsLinearMethod) and \ + isinstance(layer.scheme, CompressedTensorsW8A8Fp8): + # NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8` + # seems to store weights as (input, output) instead of + # (output, input) so we need to transpose + weight = layer.weight.T # standardize to (output, input) + else: + weight = layer.weight + _, weight_scale_group_shape = \ + get_scale_group_shapes_for_fp8(layer) + scales = get_scales(layer) + + return scaled_dequantize(weight, scales, + weight_scale_group_shape) + else: + return layer.weight + + if not (quantization_scheme_supported(self.kv_b_proj) and\ + quantization_scheme_supported(self.q_proj) and\ + quantization_scheme_supported(self.o_proj)): + raise NotImplementedError( + "Only FP8 and UnquantizedLinearMethod are supported for MLA" + ", please run with VLLM_MLA_DISABLE=1") + + weight_dtype = self.kv_b_proj.weight.dtype + assert self.o_proj.weight.dtype == weight_dtype + assert self.q_proj.weight.dtype == weight_dtype + + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\ + .view(-1, self.num_heads, self.qk_head_dim) + + # can be W_Q or W_UQ depending q_lora_rank, the former if + # q_lora_rank is None, the latter otherwise. From the Attention backend + # perspective though we call these both W_Q and rely on the layer + # to pass in the correct matrix + W_Q = q_proj_weight[..., :self.qk_nope_head_dim] + self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\ + .flatten(start_dim=1).contiguous() + + # W_QR is small so for simplicity we dont bother requantizing it + self.W_QR = self.W_QR.to(act_dtype) + + if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION + if is_fp8(weight_dtype) and requantization_enabled: + # This assumes it wise to requantize using the same group shapes + # (i.e. strategy, per-tensor, per-channel, block etc.) that the + # weights were originally quantized + requant_input_group_shape, requant_weight_group_shape = \ + get_scale_group_shapes_for_fp8(self.q_proj) + assert (requant_input_group_shape, requant_weight_group_shape)\ + == get_scale_group_shapes_for_fp8(self.kv_b_proj) + assert (requant_input_group_shape, requant_weight_group_shape)\ + == get_scale_group_shapes_for_fp8(self.o_proj) + self.reqaunt_input_group_shape = requant_input_group_shape + self.reqaunt_weight_group_shape = requant_weight_group_shape + + # + # Perform matrix-absorption following + # https://github.com/flashinfer-ai/flashinfer/pull/551 + # for decode, as a result we end up with absorbed weights for decode + # and another copy of raw weights for prefill. + # + self.W_UK, self.W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + # We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK + # depending q_lora_rank, the former if q_lora_rank is None, the + # latter otherwise + # basically if q_lora_rank is none we are absorbing into q_proj + # instead of UQ + W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\ + .flatten(start_dim=1).contiguous() + + if is_fp8(weight_dtype) and requantization_enabled: + W_Q_UK, W_Q_UK_scales = scaled_quantize( + W_Q_UK, + self.reqaunt_weight_group_shape, + quant_dtype=current_platform_fp8_dtype) + # For FP8 save the transpose so we can use + # `apply_w8a8_block_fp8_linear` directly + self.W_Q_UK = W_Q_UK.T.contiguous() + self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous() + else: + self.W_Q_UK = W_Q_UK.to(act_dtype) + + W_O = get_and_maybe_dequant_weights(self.o_proj)\ + .view(-1, self.num_heads, self.v_head_dim) + W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\ + .flatten(start_dim=0, end_dim=1).contiguous() + + if is_fp8(weight_dtype) and requantization_enabled: + W_UV_O, W_UV_O_scales = scaled_quantize( + W_UV_O, + self.reqaunt_weight_group_shape, + quant_dtype=current_platform_fp8_dtype) + # For FP8 save the transpose so we can use + # `apply_w8a8_block_fp8_linear` directly + self.W_UV_O = W_UV_O.T.contiguous() + self.W_UV_O_scales = W_UV_O_scales.T.contiguous() + else: + self.W_UV_O = W_UV_O.to(act_dtype) + + self.tp_size = get_tensor_model_parallel_world_size() + else: + if is_fp8(weight_dtype): + raise NotImplementedError( + "Currently fp8 requires matrix absorption") + + self.W_UV = W_UV + self.W_UK = W_UK + self.W_Q = W_Q.flatten(start_dim=1) + + @abstractmethod + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + attn_metadata: T, + ) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: T, + ) -> torch.Tensor: + raise NotImplementedError + + def apply_pure_rope( + self, + input_positions: torch.Tensor, + q_pe: torch.Tensor, + k_pe: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + seq_len = input_positions.size(0) + ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape + + q_pe, k_pe = self.rotary_emb( + input_positions, + q_pe.reshape(seq_len, -1), + k_pe.reshape(seq_len, -1), + ) + q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape) + + return q_pe, k_pe + + def forward( + self, + layer: AttentionLayer, + hidden_states_or_q_c: torch.Tensor, # query in unified attn + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: T, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if output is not None: + raise NotImplementedError( + "output is not yet supported for MLAImplBase") + + is_decode = attn_metadata.decode_metadata is not None + is_prefill = attn_metadata.prefill_metadata is not None + + if (is_decode and is_prefill): + raise NotImplementedError( + "chunked prefill is not supported for MLAImplBase") + + # Restore head dim (for rotary embedding) + k_pe = k_pe.unsqueeze(1) + assert hasattr(attn_metadata, "input_positions") + rope_fn = (self.rotary_emb + if self.use_yarn_rope else self.apply_pure_rope) + + if is_decode: + q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c) + q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\ + .view(-1, self.num_heads, self.qk_rope_head_dim) + q_pe, k_pe = rope_fn(attn_metadata.input_positions, q_pe, k_pe) + else: + assert is_prefill + q = self.q_proj(hidden_states_or_q_c)[0]\ + .view(-1, self.num_heads, self.qk_head_dim) + + # TODO(lucas): there must be a nicer way to write this line + q[..., self.qk_nope_head_dim:], k_pe = \ + rope_fn( + attn_metadata.input_positions, + q[..., self.qk_nope_head_dim:], k_pe) + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + if attn_metadata.prefill_metadata is not None: + return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata) + + if attn_metadata.decode_metadata is not None: + return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata) + + # Optional common flash-attn based prefill + def _forward_prefill_flash( + self, + q: torch.Tensor, + k_c_normed: torch.Tensor, + k_pe: torch.Tensor, + seq_start_loc: torch.Tensor, + max_prefill_seq_len: int, + ) -> torch.Tensor: + + kv_nope = self.kv_b_proj(k_c_normed)[0]\ + .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + + attn_output = flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=seq_start_loc, + cu_seqlens_k=seq_start_loc, + max_seqlen_q=max_prefill_seq_len, + max_seqlen_k=max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + ) + attn_output = attn_output\ + .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ + .reshape(-1, self.num_heads * v.shape[-1]) + + return self.o_proj(attn_output)[0] diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/openvino.py b/.venv/lib/python3.11/site-packages/vllm/attention/backends/openvino.py new file mode 100644 index 0000000000000000000000000000000000000000..9908620a32a238ac31202fe570ca0752cdc47213 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/backends/openvino.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Type + +import openvino as ov +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.multimodal import MultiModalPlaceholderMap + + +def copy_cache_block(src_tensor: ov.Tensor, dst_tensor: ov.Tensor, + src_offset: int, dst_offset: int) -> None: + + def create_roi_tensor( + tensor: ov.Tensor, + block_number: int, + ) -> ov.Tensor: + roi_begin = ov.runtime.Coordinate([0, 0, 0, 0]) + roi_end = ov.runtime.Coordinate(tensor.get_shape()) + + roi_begin[0] = block_number + roi_end[0] = block_number + 1 + + if isinstance(tensor, ov.Tensor): + return ov.Tensor(tensor, roi_begin, roi_end) + else: + return ov.RemoteTensor(tensor, roi_begin, roi_end) + + src_roi_tensor = \ + create_roi_tensor(src_tensor, src_offset) + dst_roi_tensor = \ + create_roi_tensor(dst_tensor, dst_offset) + src_roi_tensor.copy_to(dst_roi_tensor) + + +class OpenVINOAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "OPENVINO" + + @staticmethod + def get_impl_cls(): + # OpenVINO implements PagedAttention as part of the Optimum + # exported model + raise NotImplementedError + + @staticmethod + def make_metadata(*args, **kwargs) -> "AttentionMetadata": + raise NotImplementedError + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata": + return OpenVINOAttentionMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, num_kv_heads, block_size, head_size) + + @staticmethod + def swap_blocks( + src_tensor: ov.Tensor, + dst_tensor: ov.Tensor, + src_to_dists: List[Tuple[int, int]], + ) -> None: + for src, dst in src_to_dists: + copy_cache_block(src_tensor, dst_tensor, src, dst) + + @staticmethod + def copy_blocks( + kv_caches: List[Tuple[ov.Tensor, ov.Tensor]], + src_to_dists: List[Tuple[int, int]], + ) -> None: + for src, dst in src_to_dists: + for key_cache, value_cache in kv_caches: + copy_cache_block(key_cache, key_cache, src, dst) + copy_cache_block(value_cache, value_cache, src, dst) + + +@dataclass +class OpenVINOAttentionMetadata: + """Metadata for OpenVINOAttentionBackend. + + Basic terms used below: + - batch_size_in_sequences - total number of sequences to execute​ + - prompt_lens – per sequence size number of scheduled tokens​ + - batch_size_in_tokens = sum(prompt_lens)​ + - max_context_len = max(context_lens)​ + - max_num_blocks = div_up(max_context_len / BLOCK_SIZE)​ + - num_blocks – total number of blocks in block_indices​ + """ + + # Describes past KV cache size for each sequence within a batch + # Shape: [batch_size_in_sequences] + # Type: i32​ + past_lens: torch.Tensor + + # Describes start indices of input / speculative tokens from + # current sequences within a batch sequence​ + # Shape: [batch_size_in_sequences + 1]​ + # Type: i32 + subsequence_begins: torch.Tensor + + # Describes block tables for each sequence within a batch​ - + # indices along 0th dimension in key_cache and value_cache inputs​ + # Shape: [num_blocks] + # Type: i32​ + block_indices: torch.Tensor + + # Describes block tables for each sequence within a batch​ - + # for i-th element, it is an index in block_indices with the + # first block belonging to i-th sequence​ + # Shape: [batch_size_in_sequences + 1] + # Type: i32​ + block_indices_begins: torch.Tensor + + # Describes max context length + # Shape: scalar + # Type: i32 + max_context_len: torch.Tensor + + # The index maps that relate multi-modal embeddings to the corresponding + # placeholders. + # + # N.B. These aren't really related to attention and don't belong on this + # type -- this is just a temporary solution to make them available to + # `model_executable`. + multi_modal_placeholder_index_maps: Optional[Dict[ + str, MultiModalPlaceholderMap.IndexMap]] + + # Enable/disable KV scales calculation. This is so that we can disable the + # calculation until after prefill and cuda graph capture. + enable_kv_scales_calculation: bool diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/pallas.py b/.venv/lib/python3.11/site-packages/vllm/attention/backends/pallas.py new file mode 100644 index 0000000000000000000000000000000000000000..b61dfe63ddcaaffea67b617b4746d71eb0e07805 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/backends/pallas.py @@ -0,0 +1,337 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import torch_xla.experimental.custom_kernel # Required to register custom ops. + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState + + +class PallasAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "PALLAS" + + @staticmethod + def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: + return PallasAttentionBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["PallasMetadata"]: + return PallasMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_kv_heads, num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise RuntimeError("swap_blocks is not used for the TPU backend.") + + @torch.compile(backend="openxla") + @staticmethod + def copy_blocks( + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + src_to_dists: Tuple[torch.Tensor, torch.Tensor], + ) -> None: + src_indices, dst_indices = src_to_dists + for k_cache, v_cache in kv_caches: + torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True) + k_cache[:, dst_indices] = k_cache[:, src_indices] + torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True) + v_cache[:, dst_indices] = v_cache[:, src_indices] + + +@dataclass +class PallasMetadata(AttentionMetadata): + + # Currently, input sequences can only contain all prefills + # or all decoding. + block_tables: Optional[torch.Tensor] = None + context_lens: Optional[torch.Tensor] = None + effective_query_lens: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self) -> Optional["PallasMetadata"]: + if self.num_prefills == 0: + return None + + assert self.num_decode_tokens == 0 + return self + + @property + def decode_metadata(self) -> Optional["PallasMetadata"]: + if self.num_decode_tokens == 0: + return None + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.block_tables is not None + assert self.context_lens is not None + return self + + +class PallasAttentionBackendImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.logits_soft_cap = logits_soft_cap + if head_size % 128 != 0: + raise NotImplementedError("Head size must be a multiple of 128.") + if alibi_slopes is not None: + raise NotImplementedError("Alibi slopes is not supported.") + if sliding_window is not None: + raise NotImplementedError("Sliding window is not supported.") + if kv_cache_dtype != "auto": + raise NotImplementedError("FP8 KV cache dtype is not supported.") + if blocksparse_params is not None: + raise NotImplementedError("Blocksparse is not supported.") + + if torch_xla.tpu.version() < 4: + raise NotImplementedError("TPU version must be 4 or higher.") + + self.megacore_mode = None + tpu_env = torch_xla.tpu.get_tpu_env() + tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) + or tpu_env.get("TYPE", None) + or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) + assert tpu_type is not None + tpu_type = tpu_type.lower() + + if (("lite" not in tpu_type) and ("v6" not in tpu_type)): + if self.num_kv_heads % 2 == 0: + self.megacore_mode = "kv_head" + else: + # NOTE(woosuk): If the batch size is not a multiple of 2, the + # megacore mode will be None. + self.megacore_mode = "batch" + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + attn_metadata: PallasMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with Pallas attention. + + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size] + kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size] + NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor + with shape [0] for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [batch_size, seq_len, num_heads * head_size] + """ + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + batch_size, seq_len, hidden_size = query.shape + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) + value = value.view(batch_size, seq_len, self.num_kv_heads, + self.head_size) + + if kv_cache[0].numel() > 0: + slot_mapping = attn_metadata.slot_mapping + key_cache, value_cache = kv_cache + write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) + + query = query * self.scale + if attn_metadata.num_prefills > 0: + if attn_metadata.block_tables is None: + # Prefill without paged KV cache. + assert seq_len % 16 == 0, ( + "Pallas FlashAttention kernel requires seq_len to be a " + f"multiple of 16 but got {seq_len}") + + # Handle GQA/MQA. + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, + dim=-2) + key = key.view(batch_size, seq_len, self.num_heads, + self.head_size) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=-2) + value = value.view(batch_size, seq_len, self.num_heads, + self.head_size) + # FlashAttention kernel requires the input shape to be + # [batch_size, num_heads, seq_len, d_model] + # while the input is [batch_size, seq_len, num_heads, d_model]. + # Permute the input to match the required format. + output = torch.ops.xla.flash_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + True, + ) + output = output.permute(0, 2, 1, 3) + else: + # Prefill with paged KV cache. + # TODO(woosuk): Tune the below knobs. + num_kv_pages_per_compute_block = 16 + num_queries_per_compute_block = 16 + assert seq_len % num_queries_per_compute_block == 0 + output = torch.ops.xla.multi_queries_paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + attn_metadata.effective_query_lens, + num_kv_pages_per_compute_block, + num_queries_per_compute_block, + use_kernel=True, + attn_logits_soft_cap=self.logits_soft_cap, + ) + else: + # Decoding run. + assert kv_cache[0].numel() > 0 + query = query.squeeze(dim=1) + pages_per_compute_block = 16 # TODO(woosuk): Tune this value. + + assert attn_metadata.block_tables is not None + assert attn_metadata.context_lens is not None + # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire + # block table in SMEM. Therefore, if the block table is too large, + # the kernel compilation will fail. To avoid this, we split the + # batch dimension into smaller chunks and run the kernel multiple + # times. + MAX_SMEM_USAGE = 512 * 1024 + size_per_seq = 4 * attn_metadata.block_tables.shape[1] + max_num_seq = MAX_SMEM_USAGE // size_per_seq + + if batch_size <= max_num_seq: + output = paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + pages_per_compute_block, + self.megacore_mode, + attn_logits_soft_cap=self.logits_soft_cap, + ) + else: + chunk_size = max_num_seq + # Make sure the chunk size is a multiple of 2. + chunk_size = chunk_size // 2 * 2 + num_chunks = (batch_size + chunk_size - 1) // chunk_size + + output = torch.empty_like(query) + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = chunk_start + chunk_size + # NOTE(woosuk): We skip this line because it causes Dynamo + # compilation error. Instead, we rely on the slice operation + # to handle the out-of-bound case. + # chunk_end = min(chunk_end, batch_size) + chunk_output = paged_attention( + query[chunk_start:chunk_end], + key_cache, + value_cache, + attn_metadata.context_lens[chunk_start:chunk_end], + attn_metadata.block_tables[chunk_start:chunk_end], + pages_per_compute_block, + self.megacore_mode, + attn_logits_soft_cap=self.logits_soft_cap, + ) + output[chunk_start:chunk_end] = chunk_output + + # Reshape the output tensor. + return output.reshape(batch_size, seq_len, hidden_size) + + +def write_to_kv_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) + torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) + + key = key.flatten(0, 2) + value = value.flatten(0, 2) + key_cache = key_cache.flatten(0, 2) + value_cache = value_cache.flatten(0, 2) + key_cache.index_copy_(0, slot_mapping, key) + value_cache.index_copy_(0, slot_mapping, value) + + +def paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + pages_per_compute_block: int, + megacore_mode: Optional[str], + *, + attn_logits_soft_cap: Optional[float], +) -> torch.Tensor: + batch_size = query.shape[0] + if megacore_mode == "batch" and batch_size % 2 != 0: + megacore_mode = None + else: + megacore_mode = megacore_mode + + return torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + megacore_mode=megacore_mode, + attn_logits_soft_cap=attn_logits_soft_cap, + ) diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/placeholder_attn.py b/.venv/lib/python3.11/site-packages/vllm/attention/backends/placeholder_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..9f6e731afd1930214cd35b9b700b3a00a70952c5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/backends/placeholder_attn.py @@ -0,0 +1,410 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections import defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.multimodal import MultiModalPlaceholderMap + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + +# Placeholder attention backend for models like Mamba and pooling models that +# lack attention. + + +class PlaceholderAttentionBackend(AttentionBackend): + """Placeholder backend for when no attention is needed.""" + + @staticmethod + def get_name() -> str: + return "NO_ATTENTION" + + @staticmethod + def get_impl_cls() -> Type["PlaceholderAttentionImpl"]: + return PlaceholderAttentionImpl + + @staticmethod + def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]: + return PlaceholderAttentionMetadataBuilder + + @staticmethod + def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]: + return PlaceholderAttentionMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (1, 1, 1, 1, 1) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + return + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + return + + +@dataclass +class PlaceholderAttentionMetadata(AttentionMetadata): + """Attention metadata for prefill and decode batched together.""" + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # Maximum query length in the batch. + max_query_len: Optional[int] + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None + _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.seq_start_loc is not None + + # Placeholders + slot_mapping = torch.empty(0) + block_tables = torch.empty(0) + + self._cached_prefill_metadata = PlaceholderAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_decode_query_len=0, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=block_tables, + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.seq_lens_tensor is not None + + # Placeholders + slot_mapping = torch.empty(0) + block_tables = torch.empty(0) + + self._cached_decode_metadata = PlaceholderAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_decode_query_len=self.max_decode_query_len, + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=block_tables, + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + assert not turn_prefills_into_decodes, \ + ("Multi-Step + Chunked-Prefill is not supported for attention-free" + "models. turn_prefills_into_decodes is a " + "Multi-Step + Chunked-Prefill specific parameter.") + + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + # Update sequences, masking off entries greater than num_queries + device = self.seq_lens_tensor.device + mask = torch.arange(self.seq_lens_tensor.size(0), + device=device) < num_queries + self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype) + if sampled_token_ids is not None: + model_input.input_tokens.masked_scatter_( + mask, sampled_token_ids[:num_queries]) + + +class PlaceholderAttentionMetadataBuilder( + AttentionMetadataBuilder[PlaceholderAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + + self.input_builder = input_builder + self.runner = input_builder.runner + + def prepare(self): + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + """ + is_prompt = inter_data.is_prompt + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + logits_soft_cap = getattr(self.runner.model_config.hf_config, + "attn_logit_softcapping", None) + if logits_soft_cap is not None: + raise ValueError( + "Please use Flashinfer backend for models with logits_soft_cap" + " (i.e., Gemma-2). Otherwise, the output might be wrong." + " Set Flashinfer backend by " + "export VLLM_ATTENTION_BACKEND=FLASHINFER.") + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + if use_captured_graph: + num_decode_tokens = batch_size + + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + context_lens_tensor = torch.tensor(self.context_lens, + dtype=torch.int, + device=device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + # Placeholders + slot_mapping = torch.empty(0) + block_tables = torch.empty(0) + + return PlaceholderAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + +class PlaceholderAttentionImpl(AttentionImpl): + + def __init__(self, *args, **kwargs) -> None: + return + + def forward(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/rocm_flash_attn.py b/.venv/lib/python3.11/site-packages/vllm/attention/backends/rocm_flash_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..02bff57a62b7cb506c82d28c6aaafbd683be4f02 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/backends/rocm_flash_attn.py @@ -0,0 +1,891 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer ROCm GPUs.""" +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import (CommonAttentionState, + CommonMetadataBuilder) +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) +from vllm.logger import init_logger +from vllm.platforms import current_platform + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + +_PARTITION_SIZE_ROCM = 512 +_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName +_ON_NAVI = "gfx1" in _GPU_ARCH +_ON_MI250_MI300 = any(arch in _GPU_ARCH + for arch in ["gfx90a", "gfx940", "gfx941", "gfx942"]) + + +class ROCmFlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "ROCM_FLASH" + + @staticmethod + def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: + return ROCmFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return ROCmFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]: + return ROCmFlashAttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] = None + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None + + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.block_tables is not None + + self._cached_prefill_metadata = ROCmFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = ROCmFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + # Batch may be composed of prefill|decodes, adjust query start indices + # to refer to the start of decodes when the two are split apart. + # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + if self._cached_decode_metadata.query_start_loc is not None: + qs = self._cached_decode_metadata.query_start_loc + self._cached_decode_metadata.query_start_loc = qs - qs[0] + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + + assert not turn_prefills_into_decodes, \ + ("Chunked prefill is not supported with rocm_flash_attn yet." + "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " + "specific parameter.") + + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class ROCmFlashAttentionMetadataBuilder( + CommonMetadataBuilder[ROCmFlashAttentionMetadata]): + + _metadata_cls = ROCmFlashAttentionMetadata + + +def _make_alibi_bias(alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: Optional[List[int]], + make_attn_mask: bool = True) -> List[torch.Tensor]: + attn_biases = [] + if seq_lens: + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat( + (num_heads, 1, 1)).to(alibi_slopes.device) + bias.mul_(alibi_slopes[:, None, None]) + if make_attn_mask: + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to( + alibi_slopes.device) + attn_biases.append((bias + inf_mask).to(dtype)) + else: + attn_biases.append(bias.to(dtype)) + + return attn_biases + + +def _get_seq_len_block_table_args( + attn_metadata: ROCmFlashAttentionMetadata, + attn_type: str, +) -> tuple: + ''' + The particular choice of sequence-length + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. + + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths + Encoder attn -> select encoder sequence lengths fields + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensors for query and key + * Appropriate max sequence-length scalar + ''' + + partial_prefix_sum = 0 + if attn_type == AttentionType.ENCODER: + assert attn_metadata.encoder_seq_lens is not None + assert attn_metadata.encoder_seq_lens_tensor is not None + query_seq_start_loc = torch.tensor( + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.encoder_seq_lens + ], + device=attn_metadata.encoder_seq_lens_tensor.device, + dtype=attn_metadata.encoder_seq_lens_tensor.dtype) + causal_mask = False + + # No block tables associated with encoder attention + return (query_seq_start_loc, attn_metadata.max_encoder_seq_len, + query_seq_start_loc, attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_lens, causal_mask) + elif attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + assert attn_metadata.seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None + query_seq_start_loc = torch.tensor( + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.seq_lens + ], + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) + max_seq_len = attn_metadata.max_prefill_seq_len + causal_mask = True + + return (query_seq_start_loc, max_seq_len, query_seq_start_loc, + max_seq_len, attn_metadata.seq_lens, causal_mask) + elif attn_type == AttentionType.ENCODER_DECODER: + assert attn_metadata.seq_lens is not None + assert attn_metadata.encoder_seq_lens_tensor is not None + query_start_loc = torch.tensor( + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.seq_lens + ], + device=attn_metadata.encoder_seq_lens_tensor.device, + dtype=attn_metadata.encoder_seq_lens_tensor.dtype) + + partial_prefix_sum = 0 + assert attn_metadata.encoder_seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None + key_seq_start_loc = torch.tensor( + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.encoder_seq_lens + ], + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) + causal_mask = False + + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (query_start_loc, attn_metadata.max_prefill_seq_len, + key_seq_start_loc, attn_metadata.max_encoder_seq_len, + attn_metadata.seq_lens, causal_mask) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +class ROCmFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->| + |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "ROCmFlashAttention does not support blocksparse attention.") + + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + self.logits_soft_cap = 0.0 + else: + self.logits_soft_cap = logits_soft_cap + self.attn_type = attn_type + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + + self.use_naive_attn = False + # NOTE: Allow for switching between Triton and CK. Defaulting to triton. + self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN + if self.use_triton_flash_attn: + if logits_soft_cap is not None: + raise ValueError( + "ROCm Triton FlashAttention does not support attention" + "logits soft capping." + " please try using the ROCm CK " + "FA backend instead by setting the env var " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") + + from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 + triton_attention) + self.attn_func = triton_attention + logger.debug("Using Triton FA in ROCmBackend") + if self.sliding_window != (-1, -1): + logger.warning("ROCm Triton FA does not currently support " + "sliding window attention. If using half " + "precision, please try using the ROCm CK " + "FA backend instead by setting the env var " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") + else: + # if not using triton, navi3x/navi21/navi10 do not use flash-attn + # either + if not current_platform.has_device_capability(90): + self.use_naive_attn = True + else: + try: + from flash_attn import flash_attn_varlen_func # noqa: F401 + self.attn_func = flash_attn_varlen_func + logger.debug("Using CK FA in ROCmBackend") + except ModuleNotFoundError: + self.use_naive_attn = True + + if self.use_naive_attn: + if logits_soft_cap is not None: + raise ValueError( + "ROCm Naive FlashAttention does not support" + "attention logits soft capping.") + + self.attn_func = _sdpa_attention + logger.debug("Using naive (SDPA) attention in ROCmBackend") + + def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" + tokens, n_kv_heads, head_dim = x.shape + return (x[:, :, + None, :].expand(tokens, n_kv_heads, n_rep, + head_dim).reshape(tokens, n_kv_heads * n_rep, + head_dim)) + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: ROCmFlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. + + For decoder-only models: query, key and value must be non-None. + + For encoder/decoder models: + * ROCmFlashAttentionImpl.forward() may be invoked for both self- and + cross-attention layers. + * For self-attention: query, key and value must be non-None. + * For cross-attention: + * Query must be non-None + * During prefill, key and value must be non-None; key and value + get cached for use during decode. + * During decode, key and value may be None, since: + (1) key and value tensors were cached during prefill, and + (2) cross-attention key and value tensors do not grow during + decode + + A note on how the attn_type (attention type enum) argument impacts + attention forward() behavior: + + * DECODER: normal decoder-only behavior; + use decoder self-attention block table + * ENCODER: no KV caching; pass encoder sequence + attributes (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) to kernel, in lieu of decoder + sequence attributes (seq_lens/seq_lens_tensor/max_seq_len) + * ENCODER_DECODER: cross-attention behavior; + use cross-attention block table for caching KVs derived + from encoder hidden states; since KV sequence lengths + will match encoder sequence lengths, pass encoder sequence + attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally + Returns: + shape = [num_tokens, num_heads * head_size] + """ + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None + + if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + if key is not None and value is not None: + # Reshape the input keys and values and store them in the + # cache. If kv_cache is not provided, the new key and value + # tensors are not cached. This happens during the initial + # memory profiling run. + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping + if self.attn_type != AttentionType.ENCODER_DECODER else + attn_metadata.cross_slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.attn_type != AttentionType.ENCODER: + num_prefill_tokens = attn_metadata.num_prefill_tokens + else: + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + + if key is not None and value is not None \ + and self.attn_type != AttentionType.ENCODER_DECODER: + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + # normal attention and DECODER + if self.attn_type == AttentionType.DECODER and ( + kv_cache.numel() == 0 or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, + key_max_seq_len, seq_lens, + causal_mask) = (prefill_meta.seq_start_loc, + prefill_meta.max_prefill_seq_len, + prefill_meta.seq_start_loc, + prefill_meta.max_prefill_seq_len, + attn_metadata.seq_lens, True) + # prefix-enabled attention and ENCODER/ENCODER_DECODER + else: + (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, + key_max_seq_len, seq_lens, + causal_mask) = _get_seq_len_block_table_args( + prefill_meta, self.attn_type) + # Prompt run. + if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: + # triton attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + attn_masks = None + if self.use_triton_flash_attn: + if self.alibi_slopes is not None: + attn_masks = _make_alibi_bias( + self.alibi_slopes, + query.dtype, + seq_lens, + make_attn_mask=False) # type: ignore + out, _ = self.attn_func( + query, + key, + value, + None, + query_seq_start_loc, + key_seq_start_loc, + query_max_seq_len, + key_max_seq_len, + causal_mask, + self.scale, + attn_masks[0][None] + if attn_masks is not None else None, + ) + elif self.use_naive_attn: + if self.num_kv_heads != self.num_heads: + # Interleave for MQA workaround. + key = self.repeat_kv(key, self.num_queries_per_kv) + value = self.repeat_kv(value, self.num_queries_per_kv) + if self.alibi_slopes is not None: + attn_masks = _make_alibi_bias( + self.alibi_slopes, + query.dtype, + attn_metadata.seq_lens, + make_attn_mask=True) # type: ignore + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + # sdpa math backend attention + out = self.attn_func( + query, + key, + value, + query_seq_start_loc, + num_prefill_tokens, + self.num_heads, + self.head_size, + self.scale, + causal_mask, + attn_masks, + ) + else: + out = self.attn_func( + q=query, + k=key, + v=value, + cu_seqlens_q=query_seq_start_loc, + cu_seqlens_k=key_seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=key_max_seq_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ) + + # common code for prefill + assert output[:num_prefill_tokens].shape == out.shape + if output.shape[0] > num_prefill_tokens: + output[:num_prefill_tokens] = out + else: + output = out + else: + # prefix-enabled attention + output[:num_prefill_tokens] = PagedAttention.forward_prefix( + query, + key, + value, + self.kv_cache_dtype, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.query_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window[0], + layer._k_scale, + layer._v_scale, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + # Whether to use rocm custom paged attention or not + num_seqs, num_heads, head_size = decode_query.shape + block_size = value_cache.shape[3] + gqa_ratio = num_heads // self.num_kv_heads + use_custom = _use_rocm_custom_paged_attention( + decode_query.dtype, head_size, block_size, gqa_ratio, + decode_meta.max_decode_seq_len) + if use_custom: + max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type + != AttentionType.ENCODER_DECODER else + decode_meta.max_encoder_seq_len) + assert max_seq_len is not None + max_num_partitions = ( + (max_seq_len + _PARTITION_SIZE_ROCM - 1) // + _PARTITION_SIZE_ROCM) + assert _PARTITION_SIZE_ROCM % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + if num_prefill_tokens > 0: + out = output[num_prefill_tokens:] + else: + out = output + ops.paged_attention_rocm( + out, + exp_sums, + max_logits, + tmp_output, + decode_query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + decode_meta.block_tables + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.cross_block_tables, + decode_meta.seq_lens_tensor + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.encoder_seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + else: + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, + key_cache, + value_cache, + decode_meta.block_tables + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.cross_block_tables, + decode_meta.seq_lens_tensor + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.encoder_seq_lens_tensor, + decode_meta.max_decode_seq_len + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.max_encoder_seq_len, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + layer._k_scale, + layer._v_scale, + ) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + +def _sdpa_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + seq_lens: List[int], + num_tokens: int, + num_heads: int, + head_size: int, + scale: float, + attn_masks: Optional[List[torch.Tensor]] = None, +) -> torch.Tensor: + start = 0 + output = torch.empty((num_tokens, num_heads, head_size), + dtype=query.dtype, + device=query.device) + + for i, seq_len in enumerate(seq_lens): + end = start + seq_len + with torch.backends.cuda.sdp_kernel(enable_math=True, + enable_flash=False, + enable_mem_efficient=False): + sub_out = torch.nn.functional.scaled_dot_product_attention( + query[:, start:end, :], + key[:, start:end, :], + value[:, start:end, :], + dropout_p=0.0, + is_causal=attn_masks is None, + attn_mask=attn_masks[i] if attn_masks else None, + scale=scale).movedim(query.dim() - 2, 0) + output[start:end, :, :] = sub_out + start = end + + return output + + +def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, + block_size: int, gqa_ratio: int, + max_seq_len: int) -> bool: + # rocm custom page attention not support on navi (gfx1*) + return (_ON_MI250_MI300 and not _ON_NAVI + and (qtype == torch.half or qtype == torch.bfloat16) + and (head_size == 64 or head_size == 128) + and (block_size == 16 or block_size == 32) + and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/torch_sdpa.py b/.venv/lib/python3.11/site-packages/vllm/attention/backends/torch_sdpa.py new file mode 100644 index 0000000000000000000000000000000000000000..25fe6ed95c5dff1ab32b43c9e841d376f9266602 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/backends/torch_sdpa.py @@ -0,0 +1,681 @@ +# SPDX-License-Identifier: Apache-2.0 +""" Attention layer with torch scaled_dot_product_attention + and PagedAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +from torch.nn.functional import scaled_dot_product_attention + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.ipex_attn import PagedAttention +from vllm.attention.ops.paged_attn import PagedAttentionMetadata +from vllm.logger import init_logger +from vllm.utils import make_tensor_with_pad +from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder + +logger = init_logger(__name__) + + +class TorchSDPABackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "TORCH_SDPA" + + @staticmethod + def get_impl_cls() -> Type["TorchSDPABackendImpl"]: + return TorchSDPABackendImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return TorchSDPAMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]: + return TorchSDPAMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for TorchSDPABackend. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + chunked_prefill: bool + seq_lens: Optional[List[int]] = None # For non-chunked prefill + + # For chunked prefill only + max_query_len: Optional[int] = None + max_kv_len: Optional[int] = None + query_start_loc: Optional[torch.Tensor] = None + kv_start_loc: Optional[torch.Tensor] = None + prefill_block_tables: Optional[torch.Tensor] = None + + # Begin encoder attn & enc/dec cross-attn fields... + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[torch.Tensor]] = None + self.encoder_attn_bias: Optional[List[torch.Tensor]] = None + self.cross_attn_bias: Optional[List[torch.Tensor]] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((self.encoder_seq_lens is not None) + and (self.encoder_seq_lens_tensor is not None) + and (self.max_encoder_seq_len is not None)) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return (self.is_all_encoder_attn_metadata_set + and (self.cross_slot_mapping is not None) + and (self.cross_block_tables is not None)) + + @property + def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: + if self.num_prefill_tokens == 0: + return None + return self + + @property + def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: + if self.num_decode_tokens == 0: + return None + return self + + def get_seq_lens( + self, + attn_type: str, + ): + ''' + Extract appropriate sequence lengths from attention metadata + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate sequence lengths tensor for query + * Appropriate sequence lengths tensor for key & value + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + seq_lens_q = self.seq_lens + seq_lens_kv = self.seq_lens + elif attn_type == AttentionType.ENCODER: + seq_lens_q = self.encoder_seq_lens + seq_lens_kv = self.encoder_seq_lens + elif attn_type == AttentionType.ENCODER_DECODER: + seq_lens_q = self.seq_lens + seq_lens_kv = self.encoder_seq_lens + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + return seq_lens_q, seq_lens_kv + + def get_attn_bias( + self, + attn_type: str, + ) -> Optional[List[torch.Tensor]]: + ''' + Extract appropriate attention bias from attention metadata + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate attention bias value given the attention type + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + return self.attn_bias + elif attn_type == AttentionType.ENCODER: + return self.encoder_attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + return self.cross_attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + def set_attn_bias( + self, + attn_bias: List[torch.Tensor], + attn_type: str, + ) -> None: + ''' + Update appropriate attention bias field of attention metadata, + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_bias: The desired attention bias value + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + self.attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER: + self.encoder_attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + self.cross_attn_bias = attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + def get_seq_len_block_table_args( + self, + attn_type: str, + ) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. + + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + return (self.seq_lens_tensor, self.max_decode_seq_len, + self.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, + self.cross_block_tables) + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, + None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]): + + def __init__(self, input_builder: ModelInputForCPUBuilder) -> None: + self.chunked_prefill = input_builder.chunked_prefill + self.input_builder = input_builder + + def prepare(self): + self.input_data = self.input_builder.input_data + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata: + input_data = self.input_data + prefill_seq_lens = seq_lens[0:input_data.num_prefills] + prefill_query_lens = query_lens[0:input_data.num_prefills] + slot_mapping = torch.tensor(input_data.slot_mapping, + dtype=torch.long, + device="cpu") + + # For chunked-prefill + if self.chunked_prefill and input_data.num_prefill_tokens != 0: + prefill_block_tables = make_tensor_with_pad( + self.input_data.prefill_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + query_lens_tensor = torch.tensor(prefill_query_lens, + dtype=torch.int32, + device="cpu") + kv_lens_tensor = torch.tensor(prefill_seq_lens, + dtype=torch.int32, + device="cpu") + query_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + kv_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + torch.cumsum(query_lens_tensor, + dim=0, + dtype=torch.int32, + out=query_start_loc[1:]) + torch.cumsum(kv_lens_tensor, + dim=0, + dtype=torch.int32, + out=kv_start_loc[1:]) + max_query_len = max(prefill_query_lens) + max_kv_len = max(prefill_seq_lens) + else: + prefill_block_tables = None + query_start_loc = None + kv_start_loc = None + max_query_len = None + max_kv_len = None + + # For paged attention + if input_data.num_decode_tokens != 0: + seq_lens_tensor = torch.tensor( + input_data.seq_lens[input_data.num_prefills:], + dtype=torch.int32, + device="cpu", + ) + block_tables = make_tensor_with_pad( + self.input_data.decode_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + else: + block_tables = torch.tensor([]) + seq_lens_tensor = torch.tensor( + input_data.seq_lens[:input_data.num_prefills], + dtype=torch.int32, + device="cpu", + ) + + # For multi-modal models + placeholder_index_maps = None + if len(input_data.multi_modal_inputs_list) != 0: + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + input_data.multi_modal_placeholder_maps.items() + } + + attn_metadata = TorchSDPAMetadata( + chunked_prefill=self.chunked_prefill, + seq_lens=prefill_seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_kv_len=max_kv_len, + query_start_loc=query_start_loc, + kv_start_loc=kv_start_loc, + max_decode_seq_len=input_data.max_decode_seq_len, + num_prefills=input_data.num_prefills, + num_prefill_tokens=input_data.num_prefill_tokens, + num_decode_tokens=input_data.num_decode_tokens, + block_tables=block_tables, + prefill_block_tables=prefill_block_tables, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, + ) + + return attn_metadata + + +class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "Torch SPDA does not support block-sparse attention.") + if logits_soft_cap is not None: + logger.warning_once("Torch SPDA does not support logits soft cap. " + "Outputs may be slightly off.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + if kv_cache_dtype != "auto": + raise NotImplementedError( + "Torch SDPA backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") + self.attn_type = attn_type + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: TorchSDPAMetadata, # type: ignore + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with torch SDPA and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + attn_type = self.attn_type + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None + + if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): + # KV-cache during decoder-self- or + # encoder-decoder-cross-attention, but not + # during encoder attention. + # + # Even if there are no new key/value pairs to cache, + # we still need to break out key_cache and value_cache + # i.e. for later use by paged attention + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + if (key is not None) and (value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + PagedAttention.write_to_paged_cache( + key, value, key_cache, value_cache, updated_slot_mapping, + self.kv_cache_dtype, layer._k_scale, layer._v_scale) + + if attn_type != AttentionType.ENCODER: + # Decoder self-attention supports chunked prefill. + # Encoder/decoder cross-attention requires no chunked + # prefill (100% prefill or 100% decode tokens, no mix) + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + else: + # Encoder attention - chunked prefill is not applicable; + # derive token-count from query shape & and treat them + # as 100% prefill tokens + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + num_decode_tokens = 0 + + if attn_type == AttentionType.DECODER: + # Only enforce this shape-constraint for decoder + # self-attention + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + if prefill_meta := attn_metadata.prefill_metadata: + assert attn_metadata.seq_lens is not None + if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore + self._run_sdpa_forward(output, + query, + key, + value, + prefill_meta, + attn_type=attn_type) + else: + # prefix-enabled attention + assert not self.need_mask + import intel_extension_for_pytorch.llm.modules as ipex_modules + output = torch.empty_like(query) + ipex_modules.PagedAttention.flash_attn_varlen_func( + output[:prefill_meta.num_prefill_tokens, :, :], + query[:prefill_meta.num_prefill_tokens, :, :], + key_cache, + value_cache, + prefill_meta.query_start_loc, + prefill_meta.kv_start_loc, + prefill_meta.max_query_len, + prefill_meta.max_kv_len, + self.scale, + True, + prefill_meta.prefill_block_tables, + self.alibi_slopes, + ) + + if decode_meta := attn_metadata.decode_metadata: + assert attn_type != AttentionType.ENCODER_ONLY, ( + "Encoder-only models should not have decode metadata.") + # Decoding run. + ( + seq_lens_arg, + max_seq_len_arg, + block_tables_arg, + ) = decode_meta.get_seq_len_block_table_args(attn_type) + + PagedAttention.forward_decode( + output[attn_metadata.num_prefill_tokens:, :, :], + query[attn_metadata.num_prefill_tokens:, :, :], + key_cache, + value_cache, + block_tables_arg, + seq_lens_arg, + max_seq_len_arg, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + layer._k_scale, + layer._v_scale, + ) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + def _run_sdpa_forward( + self, + output: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: TorchSDPAMetadata, + attn_type: str = AttentionType.DECODER, + ) -> None: + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + attn_masks = attn_metadata.get_attn_bias(attn_type) + if attn_masks is None: + if self.alibi_slopes is not None: + attn_masks = _make_alibi_bias( + self.alibi_slopes, query.dtype, + attn_metadata.seq_lens) # type: ignore + elif self.sliding_window is not None: + assert attn_metadata.seq_lens is not None + attn_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, self.sliding_window, + query.dtype) # type: ignore + else: + seq_lens, _ = attn_metadata.get_seq_lens(attn_type) + attn_masks = [None] * len(seq_lens) + attn_metadata.set_attn_bias(attn_masks, attn_type) + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + causal_attn = (attn_type == AttentionType.DECODER) + + seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) + start_q, start_kv = 0, 0 + for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, + attn_masks): + end_q = start_q + seq_len_q + end_kv = start_kv + seq_len_kv + sub_out = scaled_dot_product_attention( + query[None, :, start_q:end_q, :], + key[None, :, start_kv:end_kv, :], + value[None, :, start_kv:end_kv, :], + attn_mask=mask, + dropout_p=0.0, + is_causal=causal_attn and mask is None, + scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + output[start_q:end_q, :, :] = sub_out + start_q, start_kv = end_q, end_kv + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[torch.Tensor]: + attn_biases: List[torch.Tensor] = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat((num_heads, 1, 1)) + bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0) + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_sliding_window_bias( + seq_lens: List[int], + window_size: Optional[int], + dtype: torch.dtype, +) -> List[torch.Tensor]: + attn_biases: List[torch.Tensor] = [] + for seq_len in seq_lens: + tensor = torch.full( + (1, seq_len, seq_len), + dtype=dtype, + fill_value=1, + ) + shift = 0 + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + attn_biases.append(mask.to(dtype)) + + return attn_biases diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/triton_mla.py b/.venv/lib/python3.11/site-packages/vllm/attention/backends/triton_mla.py new file mode 100644 index 0000000000000000000000000000000000000000..9a1984a931b55ab62e52243c88d7806ee1160cad --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/backends/triton_mla.py @@ -0,0 +1,746 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from itertools import accumulate +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +from vllm.multimodal import MultiModalPlaceholderMap + +try: + from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper + FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +except ImportError: + BatchDecodeMlaWithPagedKVCacheWrapper = None + FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 + +import torch + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, AttentionType) +from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.attention.ops.triton_decode_attention import decode_attention_fwd +from vllm.utils import async_tensor_h2d, make_tensor_with_pad + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + + +class TritonMLABackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "TRITON_MLA" + + @staticmethod + def get_impl_cls() -> Type["TritonMLAImpl"]: + return TritonMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return TritonMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["TritonMLAMetadataBuilder"]: + return TritonMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["TritonMLAState"]: + return TritonMLAState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + ops.copy_blocks_mla(kv_caches, src_to_dists) + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [576] + + +class TritonMLAState(AttentionState): + + def __init__(self, runner): + self.runner = runner + self._is_graph_capturing = False + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + + self._positions = torch.zeros((max_batch_size, ), + dtype=torch.long, + device=self.runner.device) + + yield + + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + del self._positions + + def graph_clone(self, batch_size: int): + assert self._is_graph_capturing + return self.__class__(self.runner) + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + assert self._is_graph_capturing + + attn_metadata = self.runner.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=self._graph_slot_mapping[:batch_size], + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=self._graph_seq_lens[:batch_size], + max_query_len=1, + max_decode_query_len=1, + max_prefill_seq_len=0, + max_decode_seq_len=self.runner.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self._graph_block_tables[:batch_size], + use_cuda_graph=True, + input_positions=self._positions[:batch_size], + head_dim=self.runner.model_config.get_head_size()) + + if is_encoder_decoder_model: + raise NotImplementedError( + "TritonMLAState does not support encoder/decoder yet") + + return attn_metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_buffers = { + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + "input_positions": attn_metadata.decode_metadata.input_positions, + } + if is_encoder_decoder_model: + raise NotImplementedError( + "TritonMLAState does not support encoder/decoder yet") + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_positions = attn_metadata.input_positions + num_positions = input_positions.shape[0] + input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) + # CUDA graph buffer is padded so only perform a partial copy based on + # num_positions + input_buffers["input_positions"][:num_positions].copy_( + input_positions, non_blocking=True) + if is_encoder_decoder_model: + raise NotImplementedError( + "TritonMLAState does not support encoder/decoder yet") + + def begin_forward(self, model_input): + return + + +@dataclass +class TritonMLAMetadata(MLACommonMetadata): + """Metadata for TritonMLAMetadata. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + + use_cuda_graph: bool + + # Maximum query length in the batch. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + _cached_prefill_metadata: Optional["TritonMLAMetadata"] = None + _cached_decode_metadata: Optional["TritonMLAMetadata"] = None + + num_prefill_tokens: int + + num_kv_splits: int = 4 # TODO(lucas) add heuristic + attn_logits: Optional[torch.Tensor] = None + req_idx: Optional[torch.Tensor] = None + + # The dimension of the attention heads + head_dim: Optional[int] = None + + def __post_init__(self): + supported_head_sizes = TritonMLABackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f"received {self.head_dim}.") + + @property + def prefill_metadata(self) -> Optional["TritonMLAMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + input_positions = (None if self.input_positions is None else + self.input_positions[:self.num_prefill_tokens]) + + self._cached_prefill_metadata = TritonMLAMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + input_positions=input_positions, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + head_dim=self.head_dim) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["TritonMLAMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.seq_lens_tensor is not None + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + input_positions = (None if self.input_positions is None else + self.input_positions[self.num_prefill_tokens:]) + + self._cached_decode_metadata = TritonMLAMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=self.max_query_len, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + # Batch may be composed of prefill|decodes, adjust query start + # indices to refer to the start of decodes. E.g. + # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=block_tables, + use_cuda_graph=self.use_cuda_graph, + input_positions=input_positions, + head_dim=self.head_dim) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + if turn_prefills_into_decodes: + # When Mutli-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. + assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.input_positions: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block, input_positions) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks, + inter_data.input_positions): + self.input_positions.extend(input_positions) + self.context_lens.append(context_len) + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def _get_graph_runner_block_tables( + self, num_seqs: int, + block_tables: List[List[int]]) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + graph_block_tables = self.runner.graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + num_seqs = len(seq_lens) + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + input_positions = async_tensor_h2d(self.input_positions, torch.long, + device, self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + + return TritonMLAMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, + input_positions=input_positions, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + num_kv_splits=4, # TODO(lucas) add heuristic + head_dim=self.runner.model_config.get_head_size(), + ) + + +class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **kwargs) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **kwargs) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "TritonMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TritonMLAImpl") + + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + attn_metadata: TritonMLAMetadata, + ) -> torch.Tensor: + assert isinstance(attn_metadata, TritonMLAMetadata) + return self._forward_prefill_flash(q, kv_c_normed, k_pe, + attn_metadata.seq_start_loc, + attn_metadata.max_prefill_seq_len) + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: TritonMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError("FP8 Triton MLA not yet supported") + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + B = q_nope.shape[0] + + q = torch.cat([q_nope, q_pe], dim=-1) + o = torch.zeros(B, + self.num_heads, + self.kv_lora_rank, + dtype=q.dtype, + device=q.device) + + # TODO(lucas) Allocate ahead of time + attn_logits = torch.empty( + ( + B, + self.num_heads, + attn_metadata.num_kv_splits, + # NOTE(lucas) idk why the +1 is here but sglang has it so we + # just mirror that + self.kv_lora_rank + 1, + ), + dtype=torch.float32, + device=q.device, + ) + + # Add a head dim of 1 + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) + kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] + PAGE_SIZE = kv_c_and_k_pe_cache.size(1) + + # Run MQA + decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, attn_logits, + attn_metadata.num_kv_splits, self.scale, + PAGE_SIZE) + + return self._v_up_proj_and_o_proj(o) diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/utils.py b/.venv/lib/python3.11/site-packages/vllm/attention/backends/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ad53e4e70b0fade7409f2c48dc706ea49a6d3c46 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/backends/utils.py @@ -0,0 +1,582 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention backend utils""" +from collections import defaultdict +from contextlib import contextmanager +from itertools import accumulate +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union + +import numpy as np +import torch + +from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, + AttentionState) +from vllm.attention.backends.abstract import AttentionType +from vllm.multimodal import MultiModalPlaceholderMap +from vllm.utils import async_tensor_h2d, make_tensor_with_pad + +if TYPE_CHECKING: + from vllm.worker.model_runner_base import ModelRunnerBase + +# Error string(s) for encoder/decoder +# unsupported attention scenarios +STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " + "with encoder/decoder models.") + +PAD_SLOT_ID = -1 + +# Switch to numpy implementation of compute_slot_mapping +# if we have at least this many elements. Could be tuned further. +_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256 + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + + +def is_block_tables_empty(block_tables: Union[None, Dict]): + """ + Check if block_tables is None or a dictionary with all None values. + """ + if block_tables is None: + return True + return (isinstance(block_tables, dict) + and all(value is None for value in block_tables.values())) + + +def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int, + context_len: int, sliding_window: int): + """ + Compute the start index of slot mapping. + """ + start_idx = 0 + if is_prompt and sliding_window is not None: + start_idx = max(0, query_len - sliding_window) + return start_idx + + +def _compute_slot_mapping_python(slot_mapping: List[int], + block_table: List[int], range_start: int, + range_end: int, block_size: int): + for i in range(range_start, range_end): + block_number = block_table[i // block_size] + block_offset = i % block_size + slot = block_number * block_size + block_offset + slot_mapping.append(slot) + + +def _compute_slot_mapping_numpy(slot_mapping: List[int], + block_table: List[int], range_start: int, + range_end: int, block_size: int): + block_table_array = np.array(block_table) + idx = np.arange(range_start, range_end) + block_offset = idx % block_size + idx //= block_size + seq_slot_mapping_array = block_table_array[idx] + seq_slot_mapping_array *= block_size + seq_slot_mapping_array += block_offset + slot_mapping.extend(seq_slot_mapping_array) + + +def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], + seq_id: int, seq_len: int, context_len: int, + start_idx: int, block_size: int, + block_tables: Dict[int, List[int]]): + """ + Compute slot mapping. + """ + if is_profile_run: + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + slot_mapping.extend([PAD_SLOT_ID] * seq_len) + return + + # Mask the [0, start_idx) tokens of the prompt with + # PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + padding_mask_len = max(0, start_idx - context_len) + slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len) + + range_start = max(start_idx, context_len) + range_end = seq_len + numel = range_end - range_start + block_table = block_tables[seq_id] + + # numpy implementation will be faster than python if we have + # many elements, otherwise it will be slower. + if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL: + _compute_slot_mapping_python(slot_mapping, block_table, range_start, + range_end, block_size) + else: + _compute_slot_mapping_numpy(slot_mapping, block_table, range_start, + range_end, block_size) + + +TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata') + + +class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): + + _metadata_cls: Type[TAttentionMetadata] + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if inter_data.prefix_cache_hit: + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = self.runner.graph_block_tables[:batch_size] + for i, block_table in enumerate(self.block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.from_numpy(input_block_tables).to( + device, non_blocking=True) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, "query_lens: {}".format(query_lens) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + + return self._metadata_cls( # type: ignore + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + +class CommonAttentionState(AttentionState): + + def __init__(self, runner: "ModelRunnerBase"): + self.runner = runner + self._is_graph_capturing = False + + @contextmanager + def graph_capture(self, max_batch_size: int): + + self._is_graph_capturing = True + + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + + yield + + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + + def graph_clone(self, batch_size: int) -> "CommonAttentionState": + assert self._is_graph_capturing + return self.__class__(self.runner) + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + assert self._is_graph_capturing + attn_metadata = self.runner.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=self._graph_slot_mapping[:batch_size], + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=self._graph_seq_lens[:batch_size], + max_query_len=1, + max_decode_query_len=1, + max_prefill_seq_len=0, + max_decode_seq_len=self.runner.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self._graph_block_tables[:batch_size], + use_cuda_graph=True, + ) + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or " \ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" + self._update_captured_metadata_for_enc_dec_model( + batch_size=batch_size, attn_metadata=attn_metadata) + + return attn_metadata + + def get_graph_input_buffers( + self, + attn_metadata, + is_encoder_decoder_model: bool = False) -> Dict[str, Any]: + input_buffers = { + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + } + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or "\ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" + self._add_additonal_input_buffers_for_enc_dec_model( + attn_metadata=attn_metadata, input_buffers=input_buffers) + return input_buffers + + def prepare_graph_input_buffers( + self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False) -> None: + input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or "\ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" + self._prepare_input_buffers_for_enc_dec_model( + attn_metadata, input_buffers) + + def begin_forward(self, model_input) -> None: + return + + def _update_captured_metadata_for_enc_dec_model(self, batch_size: int, + attn_metadata): + """ + Updates the attention metadata parameters for CUDA graph capture in an + encoder-decoder model. + + This method modifies attention-related tensors and metadata required + for CUDA graph capture in encoder-decoder models. Specifically, it + updates the cross-attention and encoder sequence tensors in the + AttentionMetadata object. + """ + # During decode phase the cross_slot_mapping will be empty. Hence set + # an empty tensor for CUDA Graph capture. + attn_metadata.cross_slot_mapping = torch.tensor( + [], dtype=torch.int).cuda() + attn_metadata.cross_block_tables = torch.full( + (batch_size, self.runner.get_max_block_per_batch()), + 1, + dtype=torch.int).cuda() + attn_metadata.encoder_seq_lens = torch.full((batch_size, ), + 1, + dtype=torch.int).cuda() + attn_metadata.encoder_seq_lens_tensor = torch.full( + (batch_size, ), 1, dtype=torch.int).cuda() + attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture + attn_metadata.num_encoder_tokens = 0 + + def _add_additonal_input_buffers_for_enc_dec_model( + self, attn_metadata, input_buffers: Dict[str, Any]): + """ + Saves additional input buffers specific to the encoder-decoder model + from the attention metadata. + + This method extracts and stores encoder-decoder related input buffers + from the `attn_metadata` into the `input_buffers` dictionary. The + buffers include encoder sequence lengths, cross-slot mappings, and + cross-block tables, which are essential for the encoder-decoder model + during CUDA graph replay. + """ + input_buffers["encoder_seq_lens_tensor"] = ( + attn_metadata.decode_metadata.encoder_seq_lens_tensor) + input_buffers["cross_slot_mapping"] = ( + attn_metadata.decode_metadata.cross_slot_mapping) + input_buffers["cross_block_tables"] = ( + attn_metadata.decode_metadata.cross_block_tables) + + def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata, + input_buffers: Dict[str, + Any]): + """ + Populates input buffers with data from the encoder-decoder model's + attention metadata. + + This method fills the input buffers with encoder-decoder specific + tensors. It copies data from the `attn_metadata` and keyword arguments + (`kwargs`) into corresponding buffers in the `input_buffers` dictionary. + The copied data includes attention-related metadata as well as input + IDs and positional information for the encoder. + """ + input_buffers["encoder_seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.encoder_seq_lens_tensor, + non_blocking=True) + input_buffers["cross_slot_mapping"].copy_( + attn_metadata.decode_metadata.cross_slot_mapping, + non_blocking=True) + input_buffers["cross_block_tables"].copy_( + attn_metadata.decode_metadata.cross_block_tables, + non_blocking=True) + + +def is_all_encoder_attn_metadata_set(attn_metadata): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((attn_metadata.encoder_seq_lens is not None) + and (attn_metadata.encoder_seq_lens_tensor is not None) + and (attn_metadata.max_encoder_seq_len is not None)) + + +def is_all_cross_attn_metadata_set(attn_metadata): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return (attn_metadata.is_all_encoder_attn_metadata_set + and (attn_metadata.cross_slot_mapping is not None) + and (attn_metadata.cross_block_tables is not None)) + + +def get_seq_len_block_table_args( + attn_metadata, + is_prompt: bool, + attn_type: str, +) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. + + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_lens_tensor, max_seq_len, + attn_metadata.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_block_tables) + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def get_num_prefill_decode_query_kv_tokens( + attn_metadata, + attn_type: str, +) -> Tuple[int, int, int]: + """ + Calculate the number of prefill and decode tokens for query, key/value + based on the attention metadata and the specified attention type. + + Args: + attn_metadata (FlashAttentionMetadata): Attention Metadata object. + attn_type (AttentionType): The type of attention being used. + Returns: + Tuple[int, int, int]: A tuple containing three integers: + - The number of prefill query tokens. + - The number of prefill key/value tokens. + - The number of decode query tokens. + + Raises: + AssertionError: If the number of encoder tokens in `attn_metadata` + is `None` when required for the calculations. + """ + num_prefill_query_tokens = 0 + num_decode_query_tokens = 0 + num_prefill_kv_tokens = 0 + if attn_type == AttentionType.ENCODER: + # Encoder attention is only invoked during prefill phase. + # The same input servers a both query and key. + assert attn_metadata.num_encoder_tokens is not None + num_prefill_query_tokens = attn_metadata.num_encoder_tokens + num_prefill_kv_tokens = attn_metadata.num_encoder_tokens + num_decode_query_tokens = 0 + elif attn_type == AttentionType.ENCODER_DECODER: + assert attn_metadata.num_encoder_tokens is not None + num_prefill_query_tokens = attn_metadata.num_prefill_tokens + # The key is the encoder/cross-attention. + num_prefill_kv_tokens = attn_metadata.num_encoder_tokens + num_decode_query_tokens = attn_metadata.num_decode_tokens + else: # attn_type == AttentionType.DECODER or + # attn_type == AttentionType.ENCODER_ONLY + num_prefill_query_tokens = attn_metadata.num_prefill_tokens + num_prefill_kv_tokens = attn_metadata.num_prefill_tokens + num_decode_query_tokens = attn_metadata.num_decode_tokens + + return (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/backends/xformers.py b/.venv/lib/python3.11/site-packages/vllm/attention/backends/xformers.py new file mode 100644 index 0000000000000000000000000000000000000000..723a4558d0b35af255608581d2f57ab4a7e6ddb4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/backends/xformers.py @@ -0,0 +1,794 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer with xFormers and PagedAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import (AttentionBias, + BlockDiagonalCausalMask, + BlockDiagonalMask, + LowerTriangularMaskWithTensorBias) + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import ( + CommonAttentionState, CommonMetadataBuilder, + get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, + is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set) +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class XFormersBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "XFORMERS" + + @staticmethod + def get_impl_cls() -> Type["XFormersImpl"]: + return XFormersImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return XFormersMetadata + + @staticmethod + def get_builder_cls() -> Type["XFormersMetadataBuilder"]: + return XFormersMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for XFormersbackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # FIXME: It is for flash attn. + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] = None + + # FIXME: It is for flash attn. + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] = None + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + + # Self-attention prefill/decode metadata cache + _cached_prefill_metadata: Optional["XFormersMetadata"] = None + _cached_decode_metadata: Optional["XFormersMetadata"] = None + + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # FIXME: It is for flash attn. + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + encoder_seq_start_loc: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[AttentionBias]] = None + self.encoder_attn_bias: Optional[List[AttentionBias]] = None + self.cross_attn_bias: Optional[List[AttentionBias]] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return is_all_encoder_attn_metadata_set(self) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return is_all_cross_attn_metadata_set(self) + + @property + def prefill_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + # Recover cached prefill-phase attention + # metadata structure + return self._cached_prefill_metadata + + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + + # Construct & cache prefill-phase attention metadata structure + self._cached_prefill_metadata = XFormersMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + # Recover cached decode-phase attention + # metadata structure + return self._cached_decode_metadata + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + + # Construct & cache decode-phase attention metadata structure + self._cached_decode_metadata = XFormersMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens_tensor=seq_lens_tensor, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + block_tables=block_tables, + use_cuda_graph=self.use_cuda_graph, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + + # Batch may be composed of prefill|decodes, adjust query start indices + # to refer to the start of decodes when the two are split apart. + # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + if self._cached_decode_metadata.query_start_loc is not None: + qs = self._cached_decode_metadata.query_start_loc + self._cached_decode_metadata.query_start_loc = qs - qs[0] + return self._cached_decode_metadata + + +def _get_attn_bias( + attn_metadata: XFormersMetadata, + attn_type: str, +) -> Optional[AttentionBias]: + ''' + Extract appropriate attention bias from attention metadata + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate attention bias value given the attention type + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + return attn_metadata.attn_bias + elif attn_type == AttentionType.ENCODER: + return attn_metadata.encoder_attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + return attn_metadata.cross_attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def _set_attn_bias( + attn_metadata: XFormersMetadata, + attn_bias: List[Optional[AttentionBias]], + attn_type: str, +) -> None: + ''' + Update appropriate attention bias field of attention metadata, + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_bias: The desired attention bias value + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + attn_metadata.attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER: + attn_metadata.encoder_attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + attn_metadata.cross_attn_bias = attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]): + + _metadata_cls = XFormersMetadata + + +class XFormersImpl(AttentionImpl[XFormersMetadata]): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "XFormers does not support block-sparse attention.") + if logits_soft_cap is not None: + logger.warning_once("XFormers does not support logits soft cap. " + "Outputs may be slightly off.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + self.attn_type = attn_type + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + kv_cache: torch.Tensor, + attn_metadata: "XFormersMetadata", + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with xFormers and PagedAttention. + + For decoder-only models: query, key and value must be non-None. + + For encoder/decoder models: + * XFormersImpl.forward() may be invoked for both self- and cross- + attention layers. + * For self-attention: query, key and value must be non-None. + * For cross-attention: + * Query must be non-None + * During prefill, key and value must be non-None; key and value + get cached for use during decode. + * During decode, key and value may be None, since: + (1) key and value tensors were cached during prefill, and + (2) cross-attention key and value tensors do not grow during + decode + + A note on how the attn_type (attention type enum) argument impacts + attention forward() behavior: + + * DECODER: normal decoder-only behavior; + use decoder self-attention block table + * ENCODER: no KV caching; pass encoder sequence + attributes (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) to kernel, in lieu of decoder + sequence attributes (seq_lens/seq_lens_tensor/max_seq_len). + Used for encoder branch of encoder-decoder models. + * ENCODER_ONLY: no kv_caching, uses the normal attention + attributes (seq_lens/seq_lens_tensor/max_seq_len). + * ENCODER_DECODER: cross-attention behavior; + use cross-attention block table for caching KVs derived + from encoder hidden states; since KV sequence lengths + will match encoder sequence lengths, pass encoder sequence + attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally + Returns: + shape = [num_tokens, num_heads * head_size] + """ + attn_type = self.attn_type + # Check that appropriate attention metadata attributes are + # selected for the desired attention type + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None + + # Self-attention vs. cross-attention will impact + # which KV cache memory-mapping & which + # seqlen datastructures we utilize + + if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): + # KV-cache during decoder-self- or + # encoder-decoder-cross-attention, but not + # during encoder attention. + # + # Even if there are no new key/value pairs to cache, + # we still need to break out key_cache and value_cache + # i.e. for later use by paged attention + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + if (key is not None) and (value is not None): + + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory + # profiling run. + PagedAttention.write_to_paged_cache( + key, value, key_cache, value_cache, updated_slot_mapping, + self.kv_cache_dtype, layer._k_scale, layer._v_scale) + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_query_tokens:] + # QKV for prefill. + query = query[:num_prefill_query_tokens] + if key is not None and value is not None: + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] + + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: + # normal attention. + # block tables are empty if the prompt does not have a cached + # prefix. + out = self._run_memory_efficient_xformers_forward( + query, key, value, prefill_meta, attn_type=attn_type) + assert out.shape == output[:num_prefill_query_tokens].shape + output[:num_prefill_query_tokens] = out + else: + assert attn_type != AttentionType.ENCODER_ONLY, ( + "Encoder-only models should not have prefix attention.") + + assert prefill_meta.query_start_loc is not None + assert prefill_meta.max_query_len is not None + + # prefix-enabled attention + # TODO(Hai) this triton kernel has regression issue (broke) to + # deal with different data types between KV and FP8 KV cache, + # to be addressed separately. + out = PagedAttention.forward_prefix( + query, + key, + value, + self.kv_cache_dtype, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.query_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window, + layer._k_scale, + layer._v_scale, + ) + assert output[:num_prefill_query_tokens].shape == out.shape + output[:num_prefill_query_tokens] = out + + if decode_meta := attn_metadata.decode_metadata: + assert attn_type != AttentionType.ENCODER_ONLY, ( + "Encoder-only models should not have decode metadata.") + + ( + seq_lens_arg, + max_seq_len_arg, + block_tables_arg, + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) + + output[num_prefill_query_tokens:] = PagedAttention.forward_decode( + decode_query, + key_cache, + value_cache, + block_tables_arg, + seq_lens_arg, + max_seq_len_arg, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + layer._k_scale, + layer._v_scale, + ) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + def _run_memory_efficient_xformers_forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: XFormersMetadata, + attn_type: str = AttentionType.DECODER, + ) -> torch.Tensor: + """Attention for 1D query of multiple prompts. Multiple prompt + tokens are flattened in to `query` input. + + See https://facebookresearch.github.io/xformers/components/ops.html + for API spec. + + Args: + output: shape = [num_prefill_tokens, num_heads, head_size] + query: shape = [num_prefill_tokens, num_heads, head_size] + key: shape = [num_prefill_tokens, num_kv_heads, head_size] + value: shape = [num_prefill_tokens, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally + """ + + original_query = query + if self.num_kv_heads != self.num_heads: + # GQA/MQA requires the shape [B, M, G, H, K]. + # Note that the output also has the same shape (which is different + # from a spec from the doc). + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) + + # Set attention bias if not provided. This typically happens at + # the very attention layer of every iteration. + # FIXME(woosuk): This is a hack. + attn_bias = _get_attn_bias(attn_metadata, attn_type) + if attn_bias is None: + if self.alibi_slopes is None: + + # Cross attention block of decoder branch of encoder-decoder + # model uses seq_lens for dec / encoder_seq_lens for enc + if (attn_type == AttentionType.ENCODER_DECODER): + assert attn_metadata.seq_lens is not None + assert attn_metadata.encoder_seq_lens is not None + + # Cross-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.seq_lens, attn_metadata.encoder_seq_lens) + + # Encoder branch of encoder-decoder model uses + # attn_metadata.encoder_seq_lens + elif attn_type == AttentionType.ENCODER: + + assert attn_metadata.encoder_seq_lens is not None + + # Encoder self-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.encoder_seq_lens) + + # Self-attention block of encoder-only model just + # uses the seq_lens directly. + elif attn_type == AttentionType.ENCODER_ONLY: + assert attn_metadata.seq_lens is not None + + # Encoder self-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.seq_lens) + + # Self-attention block of decoder branch just + # uses the seq_lens directly + elif attn_type == AttentionType.DECODER: + assert attn_metadata.seq_lens is not None + + # Decoder self-attention mask is causal + attn_bias = BlockDiagonalCausalMask.from_seqlens( + attn_metadata.seq_lens) + else: + raise ValueError("Unknown AttentionType: %s", attn_type) + + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention( + self.sliding_window) + attn_bias = [attn_bias] + else: + assert attn_type == AttentionType.DECODER + assert attn_metadata.seq_lens is not None + attn_bias = _make_alibi_bias(self.alibi_slopes, + self.num_kv_heads, query.dtype, + attn_metadata.seq_lens) + + _set_attn_bias(attn_metadata, attn_bias, attn_type) + + # No alibi slopes. + # TODO(woosuk): Too many view operations. Let's try to reduce + # them in the future for code readability. + if self.alibi_slopes is None: + # Add the batch dimension. + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=attn_bias[0], + p=0.0, + scale=self.scale) + return out.view_as(original_query) + + # Attention with alibi slopes. + # FIXME(woosuk): Because xformers does not support dynamic sequence + # lengths with custom attention bias, we process each prompt one by + # one. This is inefficient, especially when we have many short prompts. + assert attn_metadata.seq_lens is not None + output = torch.empty_like(original_query) + start = 0 + for i, seq_len in enumerate(attn_metadata.seq_lens): + end = start + seq_len + out = xops.memory_efficient_attention_forward( + query[None, start:end], + key[None, start:end], + value[None, start:end], + attn_bias=attn_bias[i], + p=0.0, + scale=self.scale) + # TODO(woosuk): Unnecessary copy. Optimize. + output[start:end].copy_(out.view_as(original_query[start:end])) + start += seq_len + return output + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + num_kv_heads: int, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[AttentionBias]: + attn_biases: List[AttentionBias] = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + # Calculate a matrix where each element represents ith element- jth + # element. + bias = bias[None, :] - bias[:, None] + + padded_len = (seq_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + 1, # batch size + num_heads, + seq_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :seq_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + if num_heads != num_kv_heads: + bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) + attn_biases.append(LowerTriangularMaskWithTensorBias(bias)) + + return attn_biases diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/__init__.py b/.venv/lib/python3.11/site-packages/vllm/attention/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..979a5bc5485177064163b93dea11c77dca93ad68 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/hpu_paged_attn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/hpu_paged_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e06cfbd317f4a52eb8ba9db7d5b4c6e0abcd6eee Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/hpu_paged_attn.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/ipex_attn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/ipex_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2568fc55b17f7f0d9f291c39238f4d2a90040e6f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/ipex_attn.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/nki_flash_attn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/nki_flash_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58c91e56cb6a165bd85dccc36fb34a180a0da5df Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/nki_flash_attn.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/paged_attn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/paged_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16d190441d5fbb014fdd248531da401e20494f5c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/paged_attn.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/prefix_prefill.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/prefix_prefill.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac22ea9e49b57977f84f2ba257bdee3872c160d8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/prefix_prefill.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/triton_decode_attention.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/triton_decode_attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1218084639938b9571198ae856d39f801d9c2bf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/triton_decode_attention.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/triton_flash_attention.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/triton_flash_attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a84f91d6634418d48c51814bbdceec88b3d42aae Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/triton_flash_attention.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__init__.py b/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46175ab453dba555039cd89f3001288302f5acf3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/blocksparse_attention_kernel.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/blocksparse_attention_kernel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ed1db2f17c8d90d4e00b753c0ee7dab7b2f7744 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/blocksparse_attention_kernel.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/interface.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/interface.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82e1bd20938a4a774863b39186a071466ed26493 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/interface.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6106872127e89e2679b1e6648151cbb847193ef9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py b/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..71caf3cbac02c65095484eda96de4c5219bc56a5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py @@ -0,0 +1,432 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + + +def blocksparse_flash_attn_varlen_fwd( + q, + k, + v, # (#tokens, n_heads, head_size) + cu_seqlens_k, + cu_seqlens_q, + sm_scale, + sparse_layout, + *, + block_size=64, + q_block_size=None, + max_seqlen=None): + # split q to blocks + + assert isinstance(sparse_layout, (list, tuple)) + + _, n_heads, head_size = q.shape + batch_size = cu_seqlens_k.size(0) - 1 + q_block_size = q_block_size or block_size + + assert q.dim() == k.dim() == v.dim() == 3 + assert q.size(1) % k.size(1) == 0 + assert q.size(2) == k.size(2) + # TODO(linxihui): allow k, v to have different head_size + assert k.shape == v.shape + assert cu_seqlens_k.dim() == 1 + + q_k_ratio = q.size(1) // k.size(1) + + if cu_seqlens_q is None: + if q.size(0) == batch_size: # decoding only + cu_seqlens_q = torch.arange( + 0, + batch_size + 1, + dtype=cu_seqlens_k.dtype, + device=cu_seqlens_k.device, + ) + elif q.size(0) == k.size(0): + cu_seqlens_q = cu_seqlens_k + else: + raise ValueError("cu_seqlens_q must be specified\ + if it mix of prefilling and decoding.") + else: + assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0) + + # switch to use cpu to avoid too many kernel launches when iterated over + q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu() + k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu() + + assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), ( + "length of q should either be 1 (decoding) or same as k (prefilling).") + + if max_seqlen: + assert k_lens.max() <= max_seqlen + + n_blocks = (q_lens + q_block_size - 1) // q_block_size + + q_batch_ids = torch.tensor( + [i for i, n in enumerate(n_blocks) for _ in range(n)], + dtype=cu_seqlens_q.dtype, + device=cu_seqlens_q.device, + ) + q_start_sids = torch.tensor( + [i * q_block_size for n in n_blocks for i in range(n)], + dtype=cu_seqlens_q.dtype, + device=cu_seqlens_q.device, + ) + + out = q.new_empty(q.shape) + cu_seqlens_q = cu_seqlens_q.contiguous() + cu_seqlens_k = cu_seqlens_k.contiguous() + + layout_crow_indices, layout_col_indices = sparse_layout + block_d = triton.next_power_of_2(head_size) + + decoding_only = (q_lens == 1).all().item() + grid = (len(q_start_sids), n_heads, 1) + + _fwd_kernel_batch_inference[grid]( + q, + k, + v, + out, + sm_scale, + cu_seqlens_q[:-1], + cu_seqlens_q[1:], + cu_seqlens_k[:-1], + cu_seqlens_k[1:], + q_batch_ids, + q_start_sids, + 0, + *q.stride(), + 0, + *k.stride(), + 0, + *v.stride(), + 0, + *out.stride(), + layout_crow_indices, + layout_col_indices, + *layout_crow_indices.stride(), + *layout_col_indices.stride(), + q_k_ratio, + HAS_BATCH_DIM=False, + D_HEAD=head_size, + BLOCK_M=q_block_size, + BLOCK_N=block_size, + BLOCK_D=block_d, + BLOCK_M_LOADING=(16 if decoding_only else + q_block_size), # smaller for decoding + EVEN_D=block_d == head_size, + num_warps=1 if decoding_only else 4, + num_stages=3) + + return out + + +@triton.jit +def _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_col_idx, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + LAST_K_BLOCK: tl.constexpr, + BLOCK_M_LOADING: tl.constexpr, + BLOCK_N: tl.constexpr, + D_HEAD: tl.constexpr, + EVEN_D: tl.constexpr, + M_LT_N: tl.constexpr, +): + k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + + k_block_col_idx * layout_col_stride_m).to(tl.int32) + start_n = k_block_id * BLOCK_N + if LAST_K_BLOCK: + if EVEN_D: + k = tl.load( + k_ptrs + start_n * stride_kt, + mask=offs_n[None, :] + start_n < k_seqlen, + other=0.0, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kt, + mask=(offs_n[None, :] + start_n < k_seqlen) & + (offs_d[:, None] < D_HEAD), + other=0.0, + ) + else: + if EVEN_D: + k = tl.load(k_ptrs + start_n * stride_kt) + else: + k = tl.load(k_ptrs + start_n * stride_kt, + mask=offs_d[:, None] < D_HEAD, + other=0.0) + + qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N + if LAST_K_BLOCK | M_LT_N: + qk += tl.where( + offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), + 0, + float("-inf"), + ) + + # flash-attn2 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + # update m_i + m_i = m_ij + l_i = l_i * alpha + l_ij + + p = p.to(Q.dtype.element_ty) + # update acc + if LAST_K_BLOCK: + if EVEN_D: + v = tl.load( + v_ptrs + start_n * stride_vt, + mask=offs_n[:, None] + start_n < k_seqlen, + other=0.0, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vt, + mask=(offs_n[:, None] + start_n < k_seqlen) & + (offs_d[None, :] < D_HEAD), + other=0.0, + ) + else: + if EVEN_D: + v = tl.load(v_ptrs + start_n * stride_vt) + else: + v = tl.load(v_ptrs + start_n * stride_vt, + mask=offs_d[None, :] < D_HEAD, + other=0.0) + + acc += tl.dot(p, v) + + return acc, l_i, m_i + + +@triton.heuristics({ + "M_LT_N": + lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"], +}) +@triton.jit +def _fwd_kernel_batch_inference( + Q, + K, + V, + Out, + sm_scale, + q_batch_starts, + q_batch_ends, + k_batch_starts, + k_batch_ends, + q_batch_ids, + q_start_sids, + stride_qb, + stride_qt, + stride_qh, + stride_qd, + stride_kb, + stride_kt, + stride_kh, + stride_kd, + stride_vb, + stride_vt, + stride_vh, + stride_vd, + stride_ob, + stride_ot, + stride_oh, + stride_od, + layout_crow_ptr, + layout_col_ptr, + layout_crow_stride_h, + layout_crow_stride_m, + layout_col_stride_h, + layout_col_stride_m, + q_k_ratio, + HAS_BATCH_DIM: tl.constexpr, + D_HEAD: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_M_LOADING: tl.constexpr, + EVEN_D: tl.constexpr, + M_LT_N: tl.constexpr, +): + """ + NOTATION: + pid: position id + sid: storage id + sbid: storage block id + pbid: position block id + offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col) + + TODO(linxihui): + Optimize grouped-attn + """ + off_zm = tl.program_id(0) + off_h = tl.program_id(1) + + off_h_for_kv = off_h // q_k_ratio + + if HAS_BATCH_DIM: + off_z = tl.program_id(2) + Q += off_z * stride_qb + K += off_z * stride_kb + V += off_z * stride_vb + Out += off_z * stride_ob + start_m = off_zm + q_start_sid = start_m * BLOCK_M # always 0 for decoding + else: + off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1] + q_start_sid = tl.load(q_start_sids + off_zm) + start_m = q_start_sid // BLOCK_M # q_sbid + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32) + q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start + k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32) + k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start + past_len = k_seqlen - q_seqlen + + Q += q_cu_start * stride_qt + off_h * stride_qh + K += k_cu_start * stride_kt + off_h_for_kv * stride_kh + V += k_cu_start * stride_vt + off_h_for_kv * stride_vh + Out += q_cu_start * stride_ot + off_h * stride_oh + + q_pbid = (past_len + q_start_sid) // BLOCK_M + + if EVEN_D: + q = tl.load( + Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, + mask=offs_m[:, None] < q_seqlen, + other=0.0, + ) + else: + q = tl.load( + Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, + mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), + other=0.0, + ) + + sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h + + q_pbid * layout_crow_stride_m) + + # TODO(linxihui): load at once, with any Triton version + # that supports `tl.split`, e.g., Triton 3.0 + k_block_start = tl.load(sparse_crow_ptr).to(tl.int32) + k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32) + + m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) + acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32) + + k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd + v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd + + sm_scale *= ( + 1.44269504 # 1/log2 as we use base2 for exponential and logarithm + ) + + for k_block_col_idx in range(k_block_start, k_block_end - 1): + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_col_idx, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + False, + BLOCK_M_LOADING, + BLOCK_N, + D_HEAD, + EVEN_D, + M_LT_N, + ) + + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_end - 1, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + True, + BLOCK_M_LOADING, + BLOCK_N, + D_HEAD, + EVEN_D, + M_LT_N, + ) + + # flash-attn 2 + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + + # write output + if EVEN_D: + tl.store( + Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, + acc, + mask=offs_m[:, None] < q_seqlen, + ) + else: + tl.store( + Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, + acc, + mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), + ) diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/interface.py b/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..6ab69ea5b4098d8a1cc0f3728c5f013f9564718d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/interface.py @@ -0,0 +1,238 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math + +import torch + +from vllm.platforms import current_platform + +from .utils import (dense_to_crow_col, get_head_sliding_step, + get_sparse_attn_mask) + +IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80) + +if IS_COMPUTE_8_OR_ABOVE: + from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd + + +class LocalStridedBlockSparseAttn(torch.nn.Module): + + def __init__( + self, + n_heads, + max_seqlen, + local_blocks, + vert_stride, + block_size, + device=None, + dtype=None, + homo_head=False, + active_head_range=None, + q_block_size=None, + use_spda=None, + ): + super().__init__() + if use_spda is None: + use_spda = current_platform.is_rocm() or \ + current_platform.is_cpu() or not \ + IS_COMPUTE_8_OR_ABOVE + device = device or (torch.cuda.current_device() + if current_platform.is_cuda_alike() else "cpu") + device = torch.device(device) + # NOTE: vllm CPU backend support BF16 instead of FP16. + dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE + or device.type == "cpu" else torch.half) + + self.n_heads = n_heads + self.max_seqlen = max_seqlen + self.local_blocks = local_blocks + self.vert_stride = vert_stride + self.use_spda = use_spda + self.dtype = dtype + self.device = device + self.block_size = block_size + self.q_block_size = q_block_size + self.homo_head = homo_head + self.active_head_range = active_head_range + self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride, + homo_head) + + sparse_layout, sparse_pattern, self.dense_attn_mask = ( + self.get_attn_pattern(dtype, device)) + + if q_block_size is not None and q_block_size != block_size: + if q_block_size > block_size: + assert q_block_size % block_size == 0 + blocks_to_merge = q_block_size // block_size + shape = sparse_pattern.shape + sparse_pattern = sparse_pattern.view(shape[0], -1, + blocks_to_merge, + shape[-1]) + sparse_pattern = sparse_pattern.sum(2) + sparse_layout = dense_to_crow_col(sparse_pattern) + else: + raise ValueError( + "Does not support smaller q_block_size. It will be slower." + ) + + self.sparse_layout = sparse_layout + + def get_attn_pattern(self, dtype, device): + sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask( + self.n_heads, + self.max_seqlen, + self.max_seqlen, + dtype, + device, + block_size=self.block_size, + local_blocks=self.local_blocks, + vert_stride=self.vert_stride, + homo_head=self.homo_head, + return_dense=self.use_spda, + dense_mask_type="bias", + ) + if (not self.homo_head) and (self.active_head_range is not None): + assert isinstance(self.active_head_range, tuple) + assert (len(self.active_head_range) == 2) + h_start, h_end = self.active_head_range + sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout) + if self.use_spda: + dense_attn_mask = dense_attn_mask[h_start:h_end] + return sparse_layout, sparse_pattern, dense_attn_mask + + def varlen_attn(self, + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=None, + sm_scale=None): + """ + q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). + Support grouped attention, with `q[:, i*r:(i*r + r)]` + is correspondent to `k[:, i]`, where `r` is the q/k ratio. + cu_seqlens_k: shape=(batch_size + 1,), + indicating segment of samples, + e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i + cu_seqlens_q: shape=(batch_size + 1, ). + Default None: same as cu_seqlens_k for prefilling or + [0, 1, .., batch_size] for decoding. + The only case you need to specify is when q is a mix of + prefilling and decoding. + sm_scale: softmax scale, default to 1/sqrt(head_size). + + return: tensor of shape as q. + """ + assert ( + IS_COMPUTE_8_OR_ABOVE + ), "Requires compute capability of 8 or above (Ampere or newer) to use \ + Triton kernel." + + sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) + + return blocksparse_flash_attn_varlen_fwd( + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q, + sm_scale, + self.sparse_layout, + block_size=self.block_size, + q_block_size=self.q_block_size, + max_seqlen=self.max_seqlen, + ) + + @staticmethod + def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1): + """ + :param x: (total_tokens, n_heads, head_size) + :return: (batch, n_heads, length, head_size) + """ + x_padded = x.new_empty( + len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2)) + cu_seqlens = cu_seqlens.cpu() + for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): + x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0, + 1).unsqueeze(1)) + return x_padded.flatten(1, 2) + + @staticmethod + def transpose_and_unpad(x_padded, cu_seqlens): + """ + :param x_padded: (batch, n_heads, length, head_size) + :return: (total_tokens, n_heads, head_size) + """ + cu_seqlens = cu_seqlens.cpu() + total_n_tokens = cu_seqlens[-1] + x = x_padded.new_empty(total_n_tokens, x_padded.size(1), + x_padded.size(3)) + for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): + x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1)) + return x + + def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): + """For CPU, V100 or other older GPUs. + NOTE: torch SPDA supports nested tensor, + but seems extremely slow. Choose to pad instead. + """ + assert (cu_seqlens_q is None or + (cu_seqlens_q + == cu_seqlens_k).all()), "Can only handle prompt with SPDA." + assert q.size(0) == k.size(0), "can only handle prompt with SPDA." + + assert q.size(1) % k.size(1) == 0 + q_k_ratio = q.size(1) // k.size(1) + sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) + cu_seqlens = cu_seqlens_k.cpu() + maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + + if (self.dense_attn_mask.dtype != q.dtype + or self.dense_attn_mask.device != q.device): + _, _, self.dense_attn_mask = self.get_attn_pattern( + q.dtype, q.device) + attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen] + + q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1) + k2, v2 = (self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio) + for x in [k, v]) + spda_output = torch.nn.functional.scaled_dot_product_attention( + q2, k2, v2, attn_mask=attn_mask, scale=sm_scale) + return self.transpose_and_unpad(spda_output, cu_seqlens) + + def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): + """Dispatch to `varlen_attn` (Ampere or newer) or + `self.spda`(cpu, Volta, Turing or older)based on + the type of device used and cuda compute capability. + + q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). + Support grouped attention, with `q[:, i*r:(i*r + r)]` + is correspondent to `k[:, i]`, where `r` is the q/k ratio. + cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples, + e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i + cu_seqlens_q: shape=(batch_size + 1, ). + Default None: same as cu_seqlens_k for prefilling or + [0, 1, .., batch_size] for decoding. + The only case you need to specify + is when q is a mix of prefilling + and decoding. + sm_scale: softmax scale, default to 1/sqrt(head_size). + + return: tensor of shape as q. + """ + assert k.dim() == 3 + if self.use_spda: + return self.spda( + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=cu_seqlens_q, + sm_scale=sm_scale, + ) + return self.varlen_attn(q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=cu_seqlens_q, + sm_scale=sm_scale) diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/utils.py b/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4de9bd530642806c3d842674442be98f4342f4db --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/utils.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Helper functions for 3D sparse pattern +# These function are not optimized and very inefficient. +# Avoid calling them too frequent or use a cache mechanism. + +from functools import lru_cache + +import numpy as np +import torch +import triton + + +class csr_matrix: + """Simple implementation of CSR matrix conversion without scipy. + This replaced scipy.sparse.csr_matrix() previously used.""" + + def __init__(self, input_array): + if not isinstance(input_array, np.ndarray): + raise ValueError("Input must be a NumPy array") + + self.shape = input_array.shape + rows, cols = self.shape + data = [] + indices = [] + indptr = [0] + + for i in range(rows): + for j in range(cols): + if input_array[i, j]: + data.append(input_array[i, j]) + indices.append(j) + indptr.append(len(indices)) + + self.data = np.array(data) + self.indices = np.array(indices) + self.indptr = np.array(indptr) + + +def dense_to_crow_col(x: torch.Tensor): + """Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing. + NOTE: col_indices padded -1 + """ + device = x.device + pad = -1 + dim = x.dim() + assert x.dim() in (2, 3) + if x.dim() == 2: + x = x[None] + x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x] + crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x]) + cols = [torch.from_numpy(xi.indices) for xi in x] + max_cols = max(len(xi) for xi in cols) + cols = [ + torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) + for xi in cols + ] + cols = torch.vstack(cols) + if dim == 2: + crows = crows[0] + cols = cols[0] + return crows.to(device), cols.to(device) + + +def crow_col_to_dense(crows: torch.Tensor, + cols: torch.Tensor, + dtype: torch.dtype = torch.float16): + dim = crows.dim() + if dim == 1: + crows = crows[None] + cols = cols[None] + device = crows.device + crows, cols = crows.cpu(), cols.cpu() # faster in cpu + shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1) + x = torch.zeros(shape, dtype=dtype) + for i in range(shape[0]): + for j in range(shape[1]): + x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1 + if dim == 1: + x = x[0] + return x.to(device) + + +def dense_to_ccol_row(x: torch.Tensor): + """Similar, but to CSC format""" + x = x.transpose(-2, -1) + return dense_to_crow_col(x) + + +def ccol_row_to_dense(ccol: torch.Tensor, + rows: torch.Tensor, + dtype: torch.dtype = torch.float16): + return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous() + + +def _get_sparse_attn_mask_homo_head( + q_len: int, + max_seqlen: int, + dtype: torch.dtype, + device: torch.device, + block_size: int = 128, + local_blocks: int = 4, + vert_stride: int = 4, + return_dense: bool = False, +): + """ + :return: a tuple of 3: + - tuple of crow_indices, col_indices representation + of CSR format. + - block dense mask + - all token dense mask (be aware that it can be + OOM if it is too big) if `return_dense==True`, + otherwise, None + """ + with torch.no_grad(): + num_blocks = triton.cdiv(max_seqlen, block_size) + q_pos = torch.arange(num_blocks)[:, None] + k_pos = torch.arange(num_blocks)[None] + mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0 + block_mask_dense = (((q_pos >= k_pos) + & ((q_pos - k_pos < local_blocks) + | mask_vert_strided)).to(device).to(dtype)) + num_blocks_q = triton.cdiv(q_len, block_size) + block_mask_dense_output = (dense_to_crow_col( + block_mask_dense[-num_blocks_q:].contiguous())) + if return_dense: + mask_dense = torch.kron( + block_mask_dense, + block_mask_dense.new_ones((block_size, block_size)), + ) + causal_mask = torch.tril(torch.ones( + max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] + mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask + return ( + block_mask_dense_output, + block_mask_dense, + mask_dense, + ) + else: + return ( + block_mask_dense_output, + block_mask_dense, + None, + ) + + +def binary_mask_to_bias(mask_dense: torch.Tensor): + mask_dense = 1 - mask_dense + mask_dense.masked_fill_(mask_dense.bool(), -torch.inf) + return mask_dense + + +def get_head_sliding_step(n_heads: int, + vert_stride: int, + homo_head: bool = False): + if homo_head: + return 0 + return max(1, int(vert_stride / n_heads)) + + +@lru_cache +def get_sparse_attn_mask( + n_heads: int, + q_len: int, + max_seqlen: int, + dtype: torch.dtype, + device: torch.device, + block_size: int = 64, + local_blocks: int = 4, + vert_stride: int = 4, + homo_head: bool = True, + return_dense: bool = False, + dense_mask_type: str = "binary", +): + """ + :param dense_mask_type: "binary" (0 for skip token, 1 for others) + or "bias" (-inf for skip token, 0 or others) + :return: a tuple of 3: + - tuple of crow_indices, col_indices representation + of CSR format. + - block dense mask + - all token dense mask (be aware that it can be OOM if it + is too big) if `return_dense==True`, otherwise, None + """ + assert dense_mask_type in ("binary", "bias") + if homo_head: + with torch.no_grad(): + (crow, col), block_mask_dense, mask_dense = ( + _get_sparse_attn_mask_homo_head( + q_len, + max_seqlen, + dtype, + device, + block_size, + local_blocks, + vert_stride, + return_dense, + )) + crow = crow[None].expand(n_heads, crow.shape[0]) + col = col[None].expand(n_heads, col.shape[0]) + if return_dense: + mask_dense = mask_dense[None].expand(n_heads, + *mask_dense.shape) + if dense_mask_type == "bias": + mask_dense = binary_mask_to_bias(mask_dense) + return (crow, col), block_mask_dense, mask_dense + + with torch.no_grad(): + num_blocks = triton.cdiv(max_seqlen, block_size) + q_pos = torch.arange(num_blocks)[None, :, None] + k_pos = torch.arange(num_blocks)[None, None] + head_sliding_step = get_head_sliding_step(n_heads, vert_stride) + mask_vert_strided = [ + (torch.arange(num_blocks) + h * head_sliding_step + 1) % + vert_stride == 0 for h in range(n_heads) + ] + mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1) + block_mask_dense = (((q_pos >= k_pos) + & ((q_pos - k_pos < local_blocks) + | mask_vert_strided)).to(device).to(dtype)) + num_blocks_q = triton.cdiv(q_len, block_size) + block_mask_dense_output = block_mask_dense[:, -num_blocks_q:] + if return_dense: + mask_dense = torch.kron( + block_mask_dense, + block_mask_dense.new_ones((block_size, block_size)), + ) + causal_mask = torch.tril(torch.ones( + max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] + mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None] + if dense_mask_type == "bias": + mask_dense = binary_mask_to_bias(mask_dense) + + return ( + dense_to_crow_col(block_mask_dense_output), + block_mask_dense, + mask_dense, + ) + else: + return ( + dense_to_crow_col(block_mask_dense_output), + block_mask_dense, + None, + ) diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/hpu_paged_attn.py b/.venv/lib/python3.11/site-packages/vllm/attention/ops/hpu_paged_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..8bb536343ed8c45d166d13e642daf23464e90118 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/ops/hpu_paged_attn.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: Apache-2.0 + +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +############################################################################### + +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch +from vllm_hpu_extension import cache_ops, ops + +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. +_PARTITION_SIZE = 512 + + +@dataclass +class HPUPagedAttentionMetadata: + """Metadata for PagedAttention.""" + block_list: Optional[torch.Tensor] + block_mapping: Optional[torch.Tensor] + block_usage: Optional[torch.Tensor] + block_indices: Optional[torch.Tensor] + block_offsets: Optional[torch.Tensor] + block_scales: Optional[torch.Tensor] + + +class HPUPagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 80, 96, 112, 128, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, kv_cache_dtype: str, + is_prompt: bool) -> None: + cache_ops.reshape_and_cache(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype, is_prompt) + + @staticmethod + def forward_decode(**kwargs) -> torch.Tensor: + return ops.flat_pa(**kwargs) + + @staticmethod + def forward_prefix( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + subquery_start_loc: torch.Tensor, + seq_lens_tensor: torch.Tensor, + context_lens: torch.Tensor, + max_query_len: int, + alibi_slopes: Optional[torch.Tensor], + sliding_window: Optional[int], + ) -> torch.Tensor: + raise NotImplementedError( + "forward_prefix is not implemented for HPUPagedAttention") + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/ipex_attn.py b/.venv/lib/python3.11/site-packages/vllm/attention/ops/ipex_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..598ceea130d97b029cc766ebf22b8a01588b8e70 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/ops/ipex_attn.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, List, Optional, Tuple + +try: + import intel_extension_for_pytorch.llm.modules as ipex_modules + _use_ipex = True +except ImportError: + _use_ipex = False + +import torch + +from vllm import _custom_ops as ops + + +class _PagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 80, 96, 112, 128, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size * num_kv_heads * head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 16 // kv_cache.element_size() + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + @staticmethod + def forward_decode( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + tp_rank: int = 0 + blocksparse_local_blocks: int = 0 + blocksparse_vert_stride: int = 0 + blocksparse_block_size: int = 64 + blocksparse_head_sliding_step: int = 0 + block_size = value_cache.shape[3] + + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + *args, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +class _IPEXPagedAttention(_PagedAttention): + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[torch.Tensor, torch.Tensor]: + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + ipex_modules.PagedAttention.reshape_and_cache( + key, value, key_cache, value_cache, + slot_mapping.flatten().int()) + + @staticmethod + def forward_decode( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + block_size = value_cache.shape[2] + head_mapping = torch.arange( + 0, + num_kv_heads, + device="cpu", + dtype=torch.int32, + ).view(num_kv_heads, + 1).repeat_interleave(query.size(1) // num_kv_heads).flatten() + ipex_modules.PagedAttention.single_query_cached_kv_attention( + output, query.contiguous(), key_cache, value_cache, head_mapping, + scale, block_tables, context_lens, block_size, max_context_len, + alibi_slopes) + + +PagedAttention = _IPEXPagedAttention if _use_ipex else _PagedAttention diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/nki_flash_attn.py b/.venv/lib/python3.11/site-packages/vllm/attention/ops/nki_flash_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..68aa63f5ac16c3f094fed3579a95a4716695c712 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/ops/nki_flash_attn.py @@ -0,0 +1,684 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass + +import neuronxcc.nki.isa as nisa +import neuronxcc.nki.language as nl +import numpy as np +from neuronxcc import nki +from neuronxcc.nki.language import par_dim + + +@dataclass(frozen=True) +class FlashConfig: + """ + Config class for flash attention with default values + """ + + seq_tile_size: int = 2048 + should_transpose_v: bool = False + + __annotations__ = { + "seq_tile_size": int, + "should_transpose_v": bool, + } + + +@nki.jit +def transpose_p_local(p_local_transposed, + p_local, + LARGE_TILE_SZ, + forward_mask, + B_F_SIZE=512): + for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE): + if nisa.get_nc_version() == nisa.nc_version.gen3: + p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE), + buffer=nl.sbuf, + dtype=p_local.dtype) + else: + p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE), + buffer=nl.psum, + dtype=np.float32) + + for j in nl.affine_range(B_F_SIZE // 128): + j_128_slice = nl.ds(j * 128, 128) + i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128) + + if nisa.get_nc_version() == nisa.nc_version.gen3: + p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose( + p_local[:, i_j_128_slice], mask=forward_mask) + else: + p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose( + p_local[:, i_j_128_slice], mask=forward_mask) + + p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy( + p_local_t_tmp, dtype=p_local_transposed.dtype, mask=forward_mask) + + +@nki.jit +def _flash_attention_core( + q_local_tile, + k, + v, + q_h_per_k_h, + seqlen_q, + nheads, + o_buffer, + l_buffer, + m_buffer, + batch_id, + head_id, + gqa_head_idx, + q_tile_idx, + local_k_large_tile_idx, + kernel_dtype, + acc_type, + flash_config: FlashConfig, + use_causal_mask=False, + continuous_batching_mask=None, + initialize=False, + B_P_SIZE=128, + B_F_SIZE=512, + B_D_SIZE=128, + dropout_p=0.0, + dropout_p_tensor=None, + seed_tensor=None, + logit_bias_tile=None, + qk_res_buffer=None, +): + """ + The flash attention core function to calculate self attention between a tile + of q and a block of K and V. + The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF + already. The block size of K and V + is defined in the seq_tile_size of the flash_config. The results are stored + in the following three buffers + o_buffer: (B_P_SIZE, d) + l_buffer: (B_P_SIZE, 1) + m_buffer: (B_P_SIZE, 1) + """ + LARGE_TILE_SZ = flash_config.seq_tile_size + num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE + seqlen_k = k.shape[-1] + seqlen_q // B_P_SIZE + seqlen_k // B_F_SIZE + + # TODO : support logit_bias with continuous_batching_mask + assert not use_causal_mask, "causal mask is not supported." + assert (continuous_batching_mask + is not None), "continuous_batching_mask input is required." + if continuous_batching_mask is not None: + assert ( + logit_bias_tile + is None), "continuous_batching_mask does not support logit_bias!" + + # mask are used to only apply computation to the lower half of the matrix, + # which reduce the arithmetic intensity by half + forward_mask = (q_tile_idx * B_P_SIZE >= local_k_large_tile_idx * + LARGE_TILE_SZ if use_causal_mask else None) + + qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), + buffer=nl.sbuf, + dtype=acc_type) + max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile), + dtype=acc_type) + for k_i in nl.affine_range(num_k_tile_per_large_tile): + k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE) + + qk_psum = nl.zeros((par_dim(B_P_SIZE), B_F_SIZE), + dtype=np.float32, + buffer=nl.psum) # (128, 512) + qk_psum[:, :] = nl.matmul(q_local_tile, + k[:, k_i_b_f_slice], + transpose_x=True, + mask=None) # (p(128), 512) + + qk_res_buf[:, k_i_b_f_slice] = nl.where( + continuous_batching_mask[:, k_i_b_f_slice], + qk_psum[:, nl.ds(0, B_F_SIZE)], + -9984.0, + dtype=acc_type, + ) + + # Calculate max of the current tile + max_local[:, k_i] = nisa.tensor_reduce( + np.max, + qk_res_buf[:, k_i_b_f_slice], + axis=(1, ), + dtype=acc_type, + negate=False, + mask=forward_mask, + ) + + if qk_res_buffer is not None: + qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :]) + + max_ = nisa.tensor_reduce( + np.max, + max_local[:, :], + axis=(1, ), + dtype=acc_type, + negate=False, + mask=forward_mask, + ) + + o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE), + dtype=o_buffer.dtype) + + if initialize: + m_buffer[:, 0] = nl.copy(max_) + m_current = max_ + else: + m_previous = nl.copy(m_buffer[:, 0]) + m_buffer[:, 0] = nl.maximum(m_previous, max_, + mask=forward_mask) # (128,1) + + m_current = m_buffer[:, 0] + # Compute scaling factor + alpha = nisa.activation( + np.exp, + m_previous, + bias=-1 * m_current, + scale=1.0, + mask=forward_mask, + ) + o_previous_scaled[...] = nl.multiply(o_buffer[:, :], + alpha, + mask=forward_mask) + + p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype) + REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2) + + p_partial_sum = nl.ndarray( + (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), dtype=acc_type) + + for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE): + k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE) + + # compute exp(qk - max) + # Compute partial row - tile sum of exp(qk - max)) + # FIXME : Use activation accumulate to accumulate over k_r_i loop ? + p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce( + np.exp, + qk_res_buf[:, k_r_i_reduce_slice], + bias=-1 * m_current, + scale=1.0, + reduce_op=nl.add, + reduce_res=p_partial_sum[:, k_r_i], + dtype=kernel_dtype, + mask=forward_mask, + ) + + ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type, mask=forward_mask) + + p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype) + transpose_p_local( + p_local_transposed=p_local_transposed, + p_local=p_local, + LARGE_TILE_SZ=LARGE_TILE_SZ, + forward_mask=forward_mask, + B_F_SIZE=B_F_SIZE, + ) + + pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE), + dtype=np.float32, + buffer=nl.psum) + for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE): + pv_psum[:, :] += nl.matmul( + p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)], + v[k_i, :, :], + transpose_x=True, + mask=forward_mask, + ) # (128, 128) (p(Br), d) + + if initialize: + o_buffer[:, :] = nl.copy(pv_psum[:, :]) + l_buffer[:, 0] = nl.add(nl.log(ps), max_) + else: + o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum, mask=forward_mask) + + l_prev = l_buffer[:, 0] + l_exp = nl.add( + nl.exp( + nl.subtract(l_prev, m_current, mask=forward_mask), + mask=forward_mask, + ), + ps, + mask=forward_mask, + ) + l_buffer[:, 0] = nl.add(m_current, + nl.log(l_exp, mask=forward_mask), + mask=forward_mask) + + +@nki.jit +def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config): + LARGE_TILE_SZ = config.seq_tile_size + B_P_SIZE = 128 + + if not config.should_transpose_v: + cur_v_tile[v_i, :, :] = nl.load( + v_hbm_tile[nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), :], + dtype=cur_v_tile.dtype, + ) + return + + if nisa.get_nc_version() == nisa.nc_version.gen3: + cur_v_tile_transposed = nisa.dma_transpose( + v_hbm_tile[:, + nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)]) + cur_v_tile[v_i, :, :] = nisa.tensor_copy(cur_v_tile_transposed, + dtype=cur_v_tile.dtype) + return + + cur_v_tile[v_i, :, :] = nl.load_transpose2d( + v_hbm_tile[:, nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)], + dtype=cur_v_tile.dtype, + ) + + +@nki.jit +def flash_paged_attention( + query, + key, + value, + key_cache, + value_cache, + block_tables, + mask, + softmax_scale=None, + mixed_precision=True, + config=None, + return_debug_tensors=False, +): + """ + Flash PagedAttention Forward Kernel. + - PagedAttention Paper: https://arxiv.org/abs/2309.06180 + - Chunked Prefill Paper: https://arxiv.org/abs/2403.02310 + + IO tensor layouts: + - query: shape (1, n_heads, d, seq_q) + - key: shape (1, n_kv_heads, d, seq_k) + - value: shape (1, n_kv_heads, seq_v, d) + - key_cache: (num_blocks, block_size, n_kv_heads, d) + - value_cache: (num_blocks, block_size, n_kv_heads, d) + - block_tables: (num_active_blocks, ) + - mask: (seq_q, num_active_blocks * block_size) + - o: shape (1, n_heads, seq_q, d) + - l_m: shape (1, n_heads, seq_q, 2) + + - This kernel requires seq_k == seq_v + - We use continuous batching by default, so the batch dimension is + always 1, and different requests are concatenated along sequence + dimension. + - We use paged cache blocks (key_cache, value_cache) to store KV cache. + + IO tensor dtypes: + - This kernel assumes all IO tensors have the same dtype except for + block_tables (int32) and mask (int32) + - If mixed_percision is True, then all Tensor Engine operation will be + performed in bfloat16 and accumulation will be performed in float32. + Otherwise the intermediates will be in the same type as the inputs. + + Compile-time Constants: + - softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)` + - mixed_precision: flag to set non-matmul ops in fp32 precision, default + is set to `true`, if false, we use same precision as input types + - config: Instance of dataclass :class:`nki.kernels.attention.FlashConfig` + with Performance config parameters for flash attention with default + values + seq_tile_size: `default=2048`, size of the kv tile size for attention + computation reduction + + GQA support Notes: + the spmd kernel for launching kernel should be on kv_heads instead of + nheads + + Example usage: + MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d] + usage: `flash_fwd[b, h](q, k, v, ...)` + GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d] + usage: `flash_fwd[b, kv_h](q, k, v, ...)` + """ + config = config or FlashConfig() + B_F_SIZE = 512 + B_P_SIZE = 128 + b, h, d, seqlen_q = query.shape + B_D_SIZE = d + LARGE_TILE_SZ = config.seq_tile_size + n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine + num_blocks, block_size, k_h, _ = key_cache.shape + q_h_per_k_h = h // k_h + assert tuple(key_cache.shape) == ( + num_blocks, + block_size, + k_h, + d, + ), "Input shape mismatch!" + assert tuple(value_cache.shape) == ( + num_blocks, + block_size, + k_h, + d, + ), "Input shape mismatch!" + assert b == 1, f"invalid batch size {b=}" + assert d <= 128, f" we do not support head_dim > 128, got head dim {d}" + kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype + acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype + + o = nl.ndarray((b, h, seqlen_q, d), + dtype=query.dtype, + buffer=nl.shared_hbm) + hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = ( + None, + None, + None, + None, + ) + if return_debug_tensors: + hbm_l_buffer = nl.ndarray((b, h, seqlen_q), + dtype=acc_type, + buffer=nl.shared_hbm) + hbm_m_buffer = nl.ndarray((b, h, seqlen_q), + dtype=acc_type, + buffer=nl.shared_hbm) + hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q), + dtype=acc_type, + buffer=nl.shared_hbm) + qk_res_buffer = nl.zeros( + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + + assert ( + nl.program_ndim() == 2 + ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!" + batch_id = nl.program_id(axis=0) + head_id = nl.program_id(axis=1) + + softmax_scale = softmax_scale or (1.0 / (d**0.5)) + + (num_active_blocks, ) = block_tables.shape + context_kv_len = num_active_blocks * block_size + assert (config.seq_tile_size >= 512 + ), f" seq tile_size {config.seq_tile_size} cannot be less than 512" + assert (context_kv_len % LARGE_TILE_SZ == 0 + ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}" + assert ( + LARGE_TILE_SZ % B_P_SIZE == 0 + ), f"Need LARGE_TILE_SZ ({LARGE_TILE_SZ}) to be divisible by {B_P_SIZE=}" + assert (B_P_SIZE % block_size == 0 + ), f"Need B_P_SIZE ({B_P_SIZE}) to be divisible by {block_size=}" + num_large_k_tile = context_kv_len // LARGE_TILE_SZ + num_blocks_per_large_tile = LARGE_TILE_SZ // block_size + assert (num_blocks_per_large_tile <= B_P_SIZE + ), f"The number of blocks in each large tile " \ + f"({num_blocks_per_large_tile}) shouldn't exceed partition size {B_P_SIZE}" + + block_tables_sbuf = nl.full((par_dim(B_P_SIZE), num_large_k_tile), + 0, + dtype=np.int32, + buffer=nl.sbuf) + for j in nl.affine_range(num_large_k_tile): + i_p = nl.arange(num_blocks_per_large_tile)[:, None] + block_tables_sbuf[i_p, j] = nl.load( + block_tables[j * num_blocks_per_large_tile + i_p], dtype=np.int32) + + # Global Flash Attention accumulators + o_buffer = nl.zeros( + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + l_buffer = nl.zeros( + (par_dim(B_P_SIZE), n_tile_q, q_h_per_k_h), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + m_buffer = nl.zeros( + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + + for j in nl.sequential_range(0, num_large_k_tile): + cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype) + cur_v_tile = nl.ndarray( + (LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), + dtype=kernel_dtype, + ) + + for k_i in nl.affine_range(num_blocks_per_large_tile): + loaded = nl.load(key_cache[block_tables_sbuf[k_i, j], :, + head_id, :]) + cur_k_tile[:, nl.ds(k_i * + block_size, block_size)] = nl.transpose(loaded) + + load_tile_size = B_P_SIZE + num_blocks_per_partition = load_tile_size // block_size + for partition_idx in nl.affine_range(LARGE_TILE_SZ // load_tile_size): + for block_in_partition in nl.affine_range( + num_blocks_per_partition): + v_i = (partition_idx * num_blocks_per_partition + + block_in_partition) + loaded_v = nl.load(value_cache[block_tables_sbuf[v_i, j], :, + head_id, :]) + cur_v_tile[ + partition_idx, + nl.ds(block_in_partition * block_size, block_size), + :, + ] = loaded_v + + cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), + dtype=mask.dtype) + for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE): + cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load( + mask[:, nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE)]) + + for i_q_h in nl.affine_range(q_h_per_k_h): + for i in nl.affine_range(n_tile_q): + q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) + q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] + q_sbuf_tile = nl.load( + q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)], + dtype=kernel_dtype, + ) # load (d, 128) tile in SBUF + q_tile[:, :] = q_sbuf_tile * softmax_scale + + _flash_attention_core( + q_local_tile=q_tile, + k=cur_k_tile, + v=cur_v_tile, + q_h_per_k_h=q_h_per_k_h, + seqlen_q=seqlen_q, + nheads=h, + o_buffer=o_buffer[i, i_q_h], + l_buffer=l_buffer[:, i, i_q_h], + m_buffer=m_buffer[i, i_q_h], + batch_id=batch_id, + head_id=head_id, + gqa_head_idx=i_q_h, + q_tile_idx=i, + local_k_large_tile_idx=j, + kernel_dtype=kernel_dtype, + acc_type=acc_type, + flash_config=config, + use_causal_mask=False, + continuous_batching_mask=cur_mask, + initialize=j == 0, + B_P_SIZE=B_P_SIZE, + B_F_SIZE=B_F_SIZE, + B_D_SIZE=B_D_SIZE, + dropout_p=0.0, + dropout_p_tensor=None, + seed_tensor=None, + logit_bias_tile=None, + ) + + # compute attention between input query, key and value + if key is not None and value is not None: + B_F_SIZE = seqlen_q + LARGE_TILE_SZ = seqlen_q + active_config = FlashConfig( + seq_tile_size=LARGE_TILE_SZ, + should_transpose_v=config.should_transpose_v, + ) + + cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype) + cur_v_tile = nl.ndarray( + (LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), + dtype=kernel_dtype, + ) + + cur_k_tile[:, :] = nl.load(key[batch_id, head_id, :, :]) + + load_tile_size = B_P_SIZE + v_hbm_tile = value[batch_id, head_id] + for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size): + load_v_tile( + v_hbm_tile=v_hbm_tile, + cur_v_tile=cur_v_tile, + j=0, + v_i=v_i, + config=active_config, + ) + + cur_mask = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE), dtype=mask.dtype) + cur_mask[:, :] = nl.load(mask[:, nl.ds(context_kv_len, B_F_SIZE)]) + + for i_q_h in nl.affine_range(q_h_per_k_h): + for i in nl.affine_range(n_tile_q): + q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) + q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] + q_sbuf_tile = nl.load( + q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)], + dtype=kernel_dtype, + ) # load (d, 128) tile in SBUF + q_tile[:, :] = q_sbuf_tile * softmax_scale + _flash_attention_core( + q_local_tile=q_tile, + k=cur_k_tile, + v=cur_v_tile, + q_h_per_k_h=q_h_per_k_h, + seqlen_q=seqlen_q, + nheads=h, + o_buffer=o_buffer[i, i_q_h], + l_buffer=l_buffer[:, i, i_q_h], + m_buffer=m_buffer[i, i_q_h], + batch_id=batch_id, + head_id=head_id, + gqa_head_idx=i_q_h, + q_tile_idx=i, + local_k_large_tile_idx=0, + kernel_dtype=kernel_dtype, + acc_type=acc_type, + flash_config=active_config, + use_causal_mask=False, + continuous_batching_mask=cur_mask, + initialize=False, + B_P_SIZE=B_P_SIZE, + B_F_SIZE=B_F_SIZE, + B_D_SIZE=B_D_SIZE, + dropout_p=0.0, + dropout_p_tensor=None, + seed_tensor=None, + logit_bias_tile=None, + qk_res_buffer=qk_res_buffer[i, i_q_h] + if qk_res_buffer is not None else None, + ) + + # -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- # + for i_q_h in nl.affine_range(q_h_per_k_h): + for i in nl.affine_range(n_tile_q): + out = nl.multiply( + o_buffer[i, i_q_h, :, :], + nl.exp(m_buffer[i, i_q_h, :, :] - l_buffer[:, i, i_q_h]), + dtype=kernel_dtype, + ) + + nl.store( + o[ + batch_id, + head_id * q_h_per_k_h + i_q_h, + nl.ds(i * B_P_SIZE, B_P_SIZE), + :, + ], + out, + ) + # maximum and summation statistics + if return_debug_tensors: + nl.store( + hbm_m_buffer[ + batch_id, + head_id * q_h_per_k_h + i_q_h, + nl.ds(i * B_P_SIZE, B_P_SIZE), + ], + m_buffer[i, i_q_h, :, :], + ) + nl.store( + hbm_l_buffer[ + batch_id, + head_id * q_h_per_k_h + i_q_h, + nl.ds(i * B_P_SIZE, B_P_SIZE), + ], + l_buffer[:, i, i_q_h], + ) + nl.store( + hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :], + qk_res_buffer[batch_id, i_q_h, :, :], + ) + + if return_debug_tensors: + return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res + return o + + +def flash_attn_varlen_nkifunc( + query, + key, + value, + key_cache, + value_cache, + block_table, + attn_mask, + n_kv_head=None, + head_size=None, + B_P_SIZE=128, + LARGE_TILE_SZ=2048, + return_debug_tensors=False, + mixed_precision=True, +): + config = FlashConfig( + seq_tile_size=LARGE_TILE_SZ, + should_transpose_v=False, + ) + kwargs = dict( + query=query, + key=key, + value=value, + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_table, + mask=attn_mask, + softmax_scale=1.0 / (head_size**0.5), + config=config, + mixed_precision=mixed_precision, + return_debug_tensors=return_debug_tensors, + ) + _, n_kv_head, _, _ = key.shape + + if return_debug_tensors: + o, *debug_tensors = flash_paged_attention[1, n_kv_head](**kwargs) + return o, *debug_tensors + else: + o = flash_paged_attention[1, n_kv_head](**kwargs) + return o diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/paged_attn.py b/.venv/lib/python3.11/site-packages/vllm/attention/ops/paged_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..2c60bd0c38d66e9e4f31caa2688b9ca4df92c701 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/ops/paged_attn.py @@ -0,0 +1,255 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.attention.ops.prefix_prefill import context_attention_fwd + +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. +_PARTITION_SIZE = 512 + + +@dataclass +class PagedAttentionMetadata: + """Metadata for PagedAttention.""" + # (batch_size,). The length of sequences (entire tokens seen so far) per + # sequence. + seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length in the batch. 0 if it is prefill-only batch. + max_decode_seq_len: int + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + +class PagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 80, 96, 112, 120, 128, 192, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size * num_kv_heads * head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 16 // kv_cache.element_size() + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + ) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + @staticmethod + def forward_decode( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> torch.Tensor: + if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: + # use blocksparse paged attention + block_size = value_cache.size(-1) + assert (blocksparse_block_size > 0 and + blocksparse_block_size % block_size == 0), \ + (f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables.") + + output = torch.empty_like(query) + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory shortage. + use_v1 = (max_seq_len <= 8192 + and (max_num_partitions == 1 or num_seqs * num_heads > 512)) + + if use_v1: + # Run PagedAttention V1. + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + return output + + @staticmethod + def forward_prefix( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache_dtype: str, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + query_start_loc: torch.Tensor, + seq_lens_tensor: torch.Tensor, + context_lens: torch.Tensor, + max_query_len: int, + alibi_slopes: Optional[torch.Tensor], + sliding_window: Optional[int], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + ) -> torch.Tensor: + output = torch.empty_like(query) + context_attention_fwd( + query, + key, + value, + output, + kv_cache_dtype, + key_cache, + value_cache, + block_tables, + # query_start_loc is (batch_size + 1,) + query_start_loc[:-1], + seq_lens_tensor, + context_lens, + max_query_len, + k_scale, + v_scale, + alibi_slopes, + sliding_window, + ) + return output + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/prefix_prefill.py b/.venv/lib/python3.11/site-packages/vllm/attention/ops/prefix_prefill.py new file mode 100644 index 0000000000000000000000000000000000000000..5fca1639363e002468d56a6f28c32d8121bab6b8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/ops/prefix_prefill.py @@ -0,0 +1,878 @@ +# SPDX-License-Identifier: Apache-2.0 + +# The kernels in this file are adapted from LightLLM's context_attention_fwd: +# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py + +import torch +import triton +import triton.language as tl + +from vllm.platforms import current_platform + +# Static kernels parameters +BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64 +NUM_WARPS = 4 if current_platform.is_rocm() else 8 + +# To check compatibility +IS_TURING = current_platform.get_device_capability() == (7, 5) + +if triton.__version__ >= "2.1.0": + + @triton.jit + def _fwd_kernel( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len + + # start position inside of the query + # generally, N goes over kv, while M goes over query_len + block_start_loc = BLOCK_M * start_m + + # initialize offsets + # [N]; starts at 0 + offs_n = tl.arange(0, BLOCK_N) + # [D]; starts at 0 + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + # [M]; starts at current position in query + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # [M,D] + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, + 0).to(tl.int1) # [D] + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len), + other=0.0) # [M,D] + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M] + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M] + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], + dtype=tl.float32) # [M,D] + + # compute query against context (no causal mask here) + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) # [N] + # [D,N] + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + # [N,D] + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k_load = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N] + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + if SLIDING_WINDOW > 0: + # (cur_batch_ctx_len + offs_m[:, None]) are the positions of + # Q entries in sequence + # (start_n + offs_n[None, :]) are the positions of + # KV entries in sequence + # So the condition makes sure each entry in Q only attends + # to KV entries not more than SLIDING_WINDOW away. + # + # We can't use -inf here, because the + # sliding window may lead to the entire row being masked. + # This then makes m_ij contain -inf, which causes NaNs in + # exp(). + qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - + (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, + -10000) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) # [M] + p = tl.exp(qk - m_ij[:, None]) # [M,N] + l_ij = tl.sum(p, 1) # [M] + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) # [M] + alpha = tl.exp(m_i - m_i_new) # [M] + beta = tl.exp(m_ij - m_i_new) # [M] + l_i_new = alpha * l_i + beta * l_ij # [M] + + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v_load = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) # [N,D] + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + # block_mask is 0 when we're already past the current query length + block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) + + # compute query against itself (with causal mask) + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_query_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk *= sm_scale + # apply causal mask + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + if SLIDING_WINDOW > 0: + qk = tl.where( + offs_m[:, None] - (start_n + offs_n[None, :]) + < SLIDING_WINDOW, qk, -10000) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_query_len), + other=0.0) + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len)) + return + + @triton.jit + def _fwd_kernel_flash_attn_v2( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load(Q + off_q, + mask=offs_m[:, None] + < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) + < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) + < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + # acc /= l_i[:, None] + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + @triton.jit + def _fwd_kernel_alibi( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + Alibi_slopes, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + ): + # attn_bias[] + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + # cur_batch_seq_len: the length of prompts + # cur_batch_ctx_len: the length of prefix + # cur_batch_in_all_start_index: the start id of the dim=0 + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) + + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange( + 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = 0 + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k_load = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v_load = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision='ieee') + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + # init alibi + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange( + 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = cur_batch_ctx_len + # # init debugger + # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc + # offset_db_k = tl.arange(0, BLOCK_N) + # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) + < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision='ieee') + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) + < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision='ieee') + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) + return + + @torch.inference_mode() + def context_attention_fwd(q, + k, + v, + o, + kv_cache_dtype: str, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + alibi_slopes=None, + sliding_window=None): + + q_dtype_is_f32 = q.dtype is torch.float32 + # need to reduce num. blocks when using fp32 + # due to increased use of GPU shared memory + # if q.dtype is torch.float32: + BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK + + # Turing does have tensor core for float32 multiplication + # use ieee as fallback for triton kernels work. There is also + # warning on vllm/config.py to inform users this fallback + # implementation + IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None + + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert (k_cache.dtype == torch.uint8) + assert (v_cache.dtype == torch.uint8) + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + k_cache = k_cache.view(target_dtype) + v_cache = v_cache.view(target_dtype) + + if (k_cache.dtype == torch.uint8 + or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): + raise ValueError("kv_cache_dtype='auto' unsupported for\ + FP8 KV Cache prefill kernel") + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + # round up Lk to a power of 2 - this is required for Triton block size + Lk_padded = triton.next_power_of_2(Lk) + + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + num_queries_per_kv = q.shape[1] // k.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + + # 0 means "disable" + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + + if alibi_slopes is not None: + _fwd_kernel_alibi[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + k_scale, + v_scale, + b_start_loc, + b_seq_len, + b_ctx_len, + alibi_slopes, + v_cache.shape[3], + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4 + ), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, + IN_PRECISION=IN_PRECISION, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + BLOCK_N=BLOCK, + num_warps=NUM_WARPS, + num_stages=1, + ) + return + + _fwd_kernel[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + k_scale, + v_scale, + b_start_loc, + b_seq_len, + b_ctx_len, + v_cache.shape[3], + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, + IN_PRECISION=IN_PRECISION, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + BLOCK_N=BLOCK, + SLIDING_WINDOW=sliding_window, + num_warps=NUM_WARPS, + num_stages=1, + ) + return diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/triton_decode_attention.py b/.venv/lib/python3.11/site-packages/vllm/attention/ops/triton_decode_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..057fccb5e5981b7884db23984490145ab8d80d18 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/ops/triton_decode_attention.py @@ -0,0 +1,669 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +# which was originally adapted from +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py + +# Changes: +# - Add support for page size >= 1. + +# Copyright 2025 vLLM Team +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Memory-efficient attention for decoding. +It supports page size >= 1. +""" + +import logging + +import triton +import triton.language as tl + +from vllm.platforms import current_platform + +is_hip_ = current_platform.is_rocm() + +logger = logging.getLogger(__name__) + +# TODO: Remove this when triton>=3.2.0. This issue will not affect performance +# and accuracy. +logger.warning( + "The following error message 'operation scheduled before its operands' " + "can be ignored.") + + +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def _fwd_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + split_kv_id = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = cur_batch + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + q = tl.load(Q + off_q, mask=mask_d, other=0.0) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, + cur_batch_seq_len) + + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_buf_k = (kv_loc[:, None] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + offs_d[None, :]) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), + other=0.0, + ) + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(offs_n < split_kv_end, qk, float("-inf")) + + offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + offs_dv[None, :]) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) + + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max + + offs_mid_o = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + offs_dv) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum, + mask=(mask_dv), + ) + + offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + Lv) + + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + ) + + +def _decode_att_m_fwd( + q, + k_buffer, + v_buffer, + att_out, + Req_to_tokens, + B_Seqlen, + num_kv_splits, + sm_scale, + page_size, + logit_cap, +): + BLOCK = 64 + NUM_KV_SPLITS = num_kv_splits + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + batch, head_num = q.shape[0], q.shape[1] + + grid = (batch, head_num, NUM_KV_SPLITS) + kv_group_num = q.shape[1] // k_buffer.shape[-2] + + num_warps = 4 if kv_group_num == 1 else 2 + + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) + + _fwd_kernel_stage1[grid]( + q, + k_buffer, + v_buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + q.stride(0), + q.stride(1), + k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + NUM_KV_SPLITS=NUM_KV_SPLITS, + PAGE_SIZE=page_size, + logit_cap=logit_cap, + num_warps=num_warps, + num_stages=2, + Lk=Lk, + Lv=Lv, + ) + + +@triton.jit +def _fwd_grouped_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + q_head_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head_id = tl.program_id(1) + cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) + split_kv_id = tl.program_id(2) + + if kv_group_num > BLOCK_H: + VALID_BLOCK_H: tl.constexpr = BLOCK_H + else: + VALID_BLOCK_H: tl.constexpr = kv_group_num + cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H + mask_h = mask_h & (cur_head < q_head_num) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = cur_batch + + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[ + None, :] + q = tl.load(Q + offs_q, + mask=(mask_h[:, None]) & (mask_d[None, :]), + other=0.0) + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + mask_dpe = offs_dpe < Lk + off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + + offs_dpe[None, :]) + qpe = tl.load(Q + off_qpe, + mask=(mask_h[:, None]) & (mask_dpe[None, :]), + other=0.0) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, + cur_batch_seq_len) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_buf_k = (kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + offs_d[:, None]) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), + other=0.0, + ) + qk = tl.dot(q, k.to(q.dtype)) + if BLOCK_DPE > 0: + offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None]) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=(offs_n[None, :] < split_kv_end) & + (mask_dpe[:, None]), + other=0.0, + ) + qk += tl.dot(qpe, kpe.to(qpe.dtype)) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), + qk, float("-inf")) + + offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + offs_dv[None, :]) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v.dtype), v) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + offs_mid_o = (cur_batch * stride_mid_ob + + cur_head[:, None] * stride_mid_oh + + split_kv_id * stride_mid_os + offs_dv[None, :]) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum[:, None], + mask=(mask_h[:, None]) & (mask_dv[None, :]), + ) + + offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + Lv) + + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + mask=mask_h, + ) + + +def _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + att_out, + Req_to_tokens, + B_Seqlen, + num_kv_splits, + sm_scale, + page_size, + logit_cap, +): + BLOCK = 32 + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + # [TODO] work around shmem limit on MI3xx + if is_hip_ and Lk >= 576: + BLOCK = 16 + + if Lk == 576: + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + elif Lk == 288: + BLOCK_DMODEL = 256 + BLOCK_DPE = 32 + else: + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DPE = 0 + BLOCK_DV = triton.next_power_of_2(Lv) + + batch, head_num = q.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k_buffer.shape[-2] + + BLOCK_H = 16 + NUM_KV_SPLITS = num_kv_splits + grid = ( + batch, + triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), + NUM_KV_SPLITS, + ) + + extra_kargs = {} + if is_hip_: + # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html + # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py + extra_kargs = { + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } + + _fwd_grouped_kernel_stage1[grid]( + q, + k_buffer, + v_buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + q.stride(0), + q.stride(1), + k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + q_head_num=head_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DPE=BLOCK_DPE, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + BLOCK_H=BLOCK_H, + NUM_KV_SPLITS=NUM_KV_SPLITS, + PAGE_SIZE=page_size, + logit_cap=logit_cap, + num_warps=4, + num_stages=2, + Lk=Lk, + Lv=Lv, + **extra_kargs, + ) + + +@triton.jit +def _fwd_kernel_stage2( + Mid_O, + o, + B_Seqlen, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + NUM_KV_SPLITS: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, + cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, + mask=mask_d, + other=0.0) + tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + o + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def _decode_softmax_reducev_fwd( + logits, + q, + o, + v_buffer, + b_seq_len, + num_kv_splits, +): + batch, head_num = q.shape[0], q.shape[1] + Lv = v_buffer.shape[-1] + BLOCK_DV = triton.next_power_of_2(Lv) + + NUM_KV_SPLITS = num_kv_splits + + extra_kargs = {} + if is_hip_: + # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html + # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py + extra_kargs = { + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } + + grid = (batch, head_num) + _fwd_kernel_stage2[grid]( + logits, + o, + b_seq_len, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=NUM_KV_SPLITS, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + num_warps=4, + num_stages=2, + **extra_kargs, + ) + + +def decode_attention_fwd_normal( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap=0.0, +): + _decode_att_m_fwd( + q, + k_buffer, + v_buffer, + attn_logits, + req_to_token, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, + num_kv_splits) + + +def decode_attention_fwd_grouped( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap=0.0, +): + _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + attn_logits, + req_to_token, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, + num_kv_splits) + + +def decode_attention_fwd( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size=1, + logit_cap=0.0, +): + assert num_kv_splits == attn_logits.shape[2] + kv_group_num = q.shape[1] // v_buffer.shape[-2] + + if kv_group_num == 1: + # MHA + decode_attention_fwd_normal( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) + else: + # GQA/MQA/MLA + decode_attention_fwd_grouped( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/ops/triton_flash_attention.py b/.venv/lib/python3.11/site-packages/vllm/attention/ops/triton_flash_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..745818eb6cfff349284d8965f09e2b3b953bff1a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/ops/triton_flash_attention.py @@ -0,0 +1,821 @@ +#!/usr/bin/env python +# SPDX-License-Identifier: Apache-2.0 +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao +(https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team, AMD ML Frameworks Triton team + +Features supported: + +1) Fwd with causal masking +2) Any sequence lengths without padding (currently fwd kernel only) +3) Support for different sequence lengths for q and k +4) Nested tensor API currently does not support dropout or bias. + +Not currently supported: + +1) Non power of two head dims + +""" + +import torch +import triton +import triton.language as tl + +torch_dtype: tl.constexpr = torch.float16 + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, + stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, + stride) + rng_keep = rng_output > dropout_p + return rng_keep + + +@triton.jit +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) + else: + tensor = tl.load(block_ptr) + return tensor + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + actual_seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr, +): + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + k = load_fn( + K_block_ptr, + PADDED_HEAD, + MASK_STEPS and (n_extra_tokens != 0), + "zero", + ) + if PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: # noqa: SIM102 + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps + # if not is_modulo_mn. last step might get wasted but that is okay. + # check if this masking works for that case. + if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], + actual_seqlen_k, + dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- + qk += tl.dot(q, k) + if bias_ptr is not None: + bias = load_fn(bias_ptr, False, MASK_STEPS + and (n_extra_tokens != 0), "zero") + # While bias is added after multiplying qk with sm_scale, our + # optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += bias * 1.44269504089 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = (batch_philox_offset + + start_m * BLOCK_M * actual_seqlen_k + start_n - + BLOCK_N) + keep = dropout_mask( + philox_seed, + philox_offset, + dropout_p, + BLOCK_M, + BLOCK_N, + actual_seqlen_k, + ) + if RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + tl.where(keep, p, + -p).to(encoded_softmax_block_ptr.type.element_ty), + ) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + p.to(encoded_softmax_block_ptr.type.element_ty), + ) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, BLOCK_N)) + return acc, l_i, m_i + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 1, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 3, + "PRE_LOAD_V": True, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 3, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 64, + "waves_per_eu": 4, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 32, + "waves_per_eu": 4, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + # TODO: This config fails with head_size not pow2 with data mismatches. + # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, + # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config( + { + "BLOCK_M": 16, + "BLOCK_N": 16, + "waves_per_eu": 1, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + ], + key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], +) +@triton.jit +def attn_fwd( + Q, + K, + V, + bias, + sm_scale, + L, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_bz, + stride_bh, + stride_bm, + stride_bn, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset_base, + encoded_softmax, + HQ: tl.constexpr, + HK: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, +): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if IS_CAUSAL: + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn + # matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is + # part of the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + # We still need to write 0s to the result + # tl.store(O_block_ptr, + # acc.to(Out.type.element_ty), boundary_check=(0,1)) + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + # + offs_m + # We store inf to LSE, not -inf because in the bwd pass, + # we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 + # for these masked blocks. + # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + # tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here? + return + + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q + + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL + + # Compute pointers for all the tensors used in this kernel. + q_offset = (off_z * stride_qz + off_h_q * stride_qh + + cu_seqlens_q_start * stride_qm) + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_offset = (off_z * stride_kz + off_h_k * stride_kh + + cu_seqlens_k_start * stride_kn) + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + v_offset = (off_z * stride_vz + off_h_k * stride_vh + + cu_seqlens_k_start * stride_vk) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + if BIAS_TYPE != 0: + bias_ptr = tl.make_block_ptr( + base=bias + off_h_q * stride_bh, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + bias_ptr = None + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base \ + + (off_z * HQ + off_h_q) \ + * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. + # In this case, we return an invalid pointer so indicate the mask is not i + # valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr( + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + encoded_softmax_block_ptr = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q = load_fn(Q_block_ptr, True, padded_head, "zero") + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional + # block. In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, + block_max, + 0, + 0, + 0, + bias_ptr, + # IS_CAUSAL, .... + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + False, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + ) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if masked_blocks > 0: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, n_full_blocks)) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + True, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + ) + # epilogue + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: # noqa: SIM102 + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL, ), + causal_start_idx, + dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = (mask_m_offsets[:, None] + >= out_mask_boundary[None, :]) + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last + # few rows. This is only true for the last M block. For others, + # overflow_size will be -ve + # overflow_size = end_m_idx - seqlen_q + # if overflow_size > 0: + # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # # This is a > check because mask being 0 blocks the store. + # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + # else: + # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. + tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + + +def check_args( + q, + k, + v, + o, + varlen=True, + max_seqlens=None, + cu_seqlens_q=None, + cu_seqlens_k=None, +): + assert q.dim() == k.dim() and q.dim() == v.dim() + if varlen: + assert q.dim() == 3 + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + assert cu_seqlens_q is not None + assert cu_seqlens_k is not None + assert len(cu_seqlens_q) == len(cu_seqlens_k) + else: + assert q.dim() == 4 + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert max_seqlens > 0 + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + assert head_size <= 256 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + o, + cu_seqlens_q, + cu_seqlens_k, + max_seqlens_q, + max_seqlens_k, + causal=False, + sm_scale=1.0, + bias=None, + ): + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + + check_args( + q, + k, + v, + o, + varlen=True, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + if True: # varlen + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + batch = len(cu_seqlens_q) - 1 + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + batch, seqlen_q, nheads_q, head_size = q.shape + _, seqlen_k, nheads_k, _ = k.shape + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + + # Get closest power of 2 over or equal to 32. + unpadded_head_dims = {32, 64, 128, 256} + if head_size not in unpadded_head_dims: + padded_d_model = None + for i in unpadded_head_dims: + if i > head_size: + padded_d_model = i + break + assert padded_d_model is not None + else: + padded_d_model = head_size + + grid = lambda META: ( + triton.cdiv(max_seqlens_q, META["BLOCK_M"]), + nheads_q, + batch, + ) + + encoded_softmax = None + + # Seed the RNG so we get reproducible results for testing. + philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if bias is not None: + bias_strides = ( + bias.stride(0), + bias.stride(1), + bias.stride(2), + bias.stride(3), + ) + else: + bias_strides = (0, 0, 0, 0) + + attn_fwd[grid]( + q, + k, + v, + bias, + sm_scale, + None, + o, + *q_strides, + *k_strides, + *v_strides, + *o_strides, + *bias_strides, + cu_seqlens_q, + cu_seqlens_k, + dropout_p=0.0, + philox_seed=philox_seed, + philox_offset_base=philox_offset, + encoded_softmax=encoded_softmax, + HQ=nheads_q, + HK=nheads_k, + ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=max_seqlens_q, + MAX_SEQLENS_K=max_seqlens_k, + IS_CAUSAL=causal, + VARLEN=True, + BLOCK_DMODEL=padded_d_model, + BIAS_TYPE=0 if bias is None else 1, + ENABLE_DROPOUT=False, + RETURN_ENCODED_SOFTMAX=False, + ) + + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = causal + ctx.dropout_p = 0.0 + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = False + return o, encoded_softmax + + +triton_attention = _attention.apply diff --git a/.venv/lib/python3.11/site-packages/vllm/attention/selector.py b/.venv/lib/python3.11/site-packages/vllm/attention/selector.py new file mode 100644 index 0000000000000000000000000000000000000000..26c6ac812a12507af0d8f91f13a7cd42a2ccfa59 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/attention/selector.py @@ -0,0 +1,186 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from contextlib import contextmanager +from functools import cache +from typing import Generator, Optional, Type + +import torch + +import vllm.envs as envs +from vllm.attention.backends.abstract import AttentionBackend +from vllm.logger import init_logger +from vllm.platforms import _Backend, current_platform +from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname + +logger = init_logger(__name__) + + +def backend_name_to_enum(backend_name: str) -> Optional[_Backend]: + """ + Convert a string backend name to a _Backend enum value. + + Returns: + * _Backend: enum value if backend_name is a valid in-tree type + * None: otherwise it's an invalid in-tree type or an out-of-tree platform is + loaded. + """ + assert backend_name is not None + return _Backend[backend_name] if backend_name in _Backend.__members__ else \ + None + + +def get_env_variable_attn_backend() -> Optional[_Backend]: + ''' + Get the backend override specified by the vLLM attention + backend environment variable, if one is specified. + + Returns: + + * _Backend enum value if an override is specified + * None otherwise + ''' + backend_name = os.environ.get(STR_BACKEND_ENV_VAR) + return (None + if backend_name is None else backend_name_to_enum(backend_name)) + + +# Global state allows a particular choice of backend +# to be forced, overriding the logic which auto-selects +# a backend based on system & workload configuration +# (default behavior if this variable is None) +# +# THIS SELECTION TAKES PRECEDENCE OVER THE +# VLLM ATTENTION BACKEND ENVIRONMENT VARIABLE +forced_attn_backend: Optional[_Backend] = None + + +def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None: + ''' + Force all attention operations to use a specified backend. + + Passing `None` for the argument re-enables automatic + backend selection., + + Arguments: + + * attn_backend: backend selection (None to revert to auto) + ''' + global forced_attn_backend + forced_attn_backend = attn_backend + + +def get_global_forced_attn_backend() -> Optional[_Backend]: + ''' + Get the currently-forced choice of attention backend, + or None if auto-selection is currently enabled. + ''' + return forced_attn_backend + + +def get_attn_backend( + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + is_attention_free: bool, + is_blocksparse: bool = False, + use_mla: bool = False, +) -> Type[AttentionBackend]: + """Selects which attention backend to use and lazily imports it.""" + # Accessing envs.* behind an @lru_cache decorator can cause the wrong + # value to be returned from the cache if the value changes between calls. + # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the + # private function. + return _cached_get_attn_backend( + head_size=head_size, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + block_size=block_size, + is_attention_free=is_attention_free, + is_blocksparse=is_blocksparse, + use_v1=envs.VLLM_USE_V1, + use_mla=use_mla, + ) + + +@cache +def _cached_get_attn_backend( + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + is_attention_free: bool, + is_blocksparse: bool = False, + use_v1: bool = False, + use_mla: bool = False, +) -> Type[AttentionBackend]: + if is_blocksparse: + logger.info("Using BlocksparseFlashAttention backend.") + from vllm.attention.backends.blocksparse_attn import ( + BlocksparseFlashAttentionBackend) + return BlocksparseFlashAttentionBackend + + # If there are no attention layers (e.g. we are running Mamba), + # use the placeholder NO_ATTENTION + if is_attention_free: + from vllm.attention.backends.placeholder_attn import ( + PlaceholderAttentionBackend) + return PlaceholderAttentionBackend + + # Check whether a particular choice of backend was + # previously forced. + # + # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND + # ENVIRONMENT VARIABLE. + selected_backend = None + backend_by_global_setting: Optional[_Backend] = ( + get_global_forced_attn_backend()) + if backend_by_global_setting is not None: + selected_backend = backend_by_global_setting + else: + # Check the environment variable and override if specified + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = backend_name_to_enum(backend_by_env_var) + + # get device-specific attn_backend + attention_cls = current_platform.get_attn_backend_cls( + selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, + use_mla) + if not attention_cls: + raise ValueError( + f"Invalid attention backend for {current_platform.device_name}") + return resolve_obj_by_qualname(attention_cls) + + +@contextmanager +def global_force_attn_backend_context_manager( + attn_backend: _Backend) -> Generator[None, None, None]: + ''' + Globally force a vLLM attention backend override within a + context manager, reverting the global attention backend + override to its prior state upon exiting the context + manager. + + Arguments: + + * attn_backend: attention backend to force + + Returns: + + * Generator + ''' + + # Save the current state of the global backend override (if any) + original_value = get_global_forced_attn_backend() + + # Globally force the new backend override + global_force_attn_backend(attn_backend) + + # Yield control back to the enclosed code block + try: + yield + finally: + # Revert the original global backend override, if any + global_force_attn_backend(original_value) diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cf51cfe768dc33d75ac8445e787e5fb6db42379 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/logits_processor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/logits_processor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e20bd2dfed864ef4535759fe34acb25bba6c75d4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/logits_processor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/rejection_sampler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/rejection_sampler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43f880a1ebff1c326f003c8cd9e24d8e7da6cba0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/rejection_sampler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/resampler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/resampler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..600610de8338f997e9c6bb3cff4c21250ae2ad27 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/resampler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/rotary_embedding.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/rotary_embedding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2b496b689c39a94536f5b7653b9eb5d717429b8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/rotary_embedding.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/sampler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/sampler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fecbe9dc1e1d57a235a015b2ab8e9098ac107b2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/sampler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/spec_decode_base_sampler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/spec_decode_base_sampler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ee54931f37278605d87e0157822a6e96bfd3888 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/spec_decode_base_sampler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/typical_acceptance_sampler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/typical_acceptance_sampler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3700f1b25b7ccd4572d7ff596537bd5c6a83f1db Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/typical_acceptance_sampler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cbc785f165ae22c79ac6c8b7cf810aedbc5236c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/mamba/__init__.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/mamba/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/triton_utils/custom_cache_manager.py b/.venv/lib/python3.11/site-packages/vllm/triton_utils/custom_cache_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..4163969c9a52df5db3216e2cc45f895eb9baa56f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/triton_utils/custom_cache_manager.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os + +from triton.runtime.cache import (FileCacheManager, default_cache_dir, + default_dump_dir, default_override_dir) + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def maybe_set_triton_cache_manager() -> None: + """Set environment variable to tell Triton to use a + custom cache manager""" + cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None) + if cache_manger is None: + manager = "vllm.triton_utils.custom_cache_manager:CustomCacheManager" + logger.info("Setting Triton cache manager to: %s", manager) + os.environ["TRITON_CACHE_MANAGER"] = manager + + +class CustomCacheManager(FileCacheManager): + """Re-implements Triton's cache manager, ensuring that a + unique cache directory is created for each process. This is + needed to avoid collisions when running with tp>1 and + using multi-processing as the distributed backend. + + Note this issue was fixed by triton-lang/triton/pull/4295, + but the fix is not yet included in triton==v3.0.0. However, + it should be included in the subsequent version. + """ + + def __init__(self, key, override=False, dump=False): + self.key = key + self.lock_path = None + if dump: + self.cache_dir = default_dump_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + elif override: + self.cache_dir = default_override_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + else: + # create cache directory if it doesn't exist + self.cache_dir = os.getenv("TRITON_CACHE_DIR", + "").strip() or default_cache_dir() + if self.cache_dir: + self.cache_dir = f"{self.cache_dir}_{os.getpid()}" + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") diff --git a/.venv/lib/python3.11/site-packages/vllm/triton_utils/importing.py b/.venv/lib/python3.11/site-packages/vllm/triton_utils/importing.py new file mode 100644 index 0000000000000000000000000000000000000000..a20700248c26872202e703b3283289b7e99785a9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/triton_utils/importing.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 + +from importlib.util import find_spec + +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +HAS_TRITON = ( + find_spec("triton") is not None + and not current_platform.is_xpu() # Not compatible +) + +if not HAS_TRITON: + logger.info("Triton not installed or not compatible; certain GPU-related" + " functions will not be available.")