koichi12 commited on
Commit
2bdf65d
·
verified ·
1 Parent(s): 7921a79

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/vllm/attention/__init__.py +19 -0
  2. .venv/lib/python3.11/site-packages/vllm/attention/__pycache__/__init__.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/vllm/attention/__pycache__/layer.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/vllm/attention/__pycache__/selector.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/vllm/attention/backends/__init__.py +0 -0
  6. .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/__init__.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/abstract.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/blocksparse_attn.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/flash_attn.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/flashinfer.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/hpu_attn.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/ipex_attn.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/openvino.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/pallas.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/placeholder_attn.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/rocm_flash_attn.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/torch_sdpa.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/triton_mla.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/utils.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/xformers.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/vllm/attention/backends/abstract.py +296 -0
  22. .venv/lib/python3.11/site-packages/vllm/attention/backends/blocksparse_attn.py +457 -0
  23. .venv/lib/python3.11/site-packages/vllm/attention/backends/flash_attn.py +942 -0
  24. .venv/lib/python3.11/site-packages/vllm/attention/backends/flashinfer.py +1066 -0
  25. .venv/lib/python3.11/site-packages/vllm/attention/backends/hpu_attn.py +293 -0
  26. .venv/lib/python3.11/site-packages/vllm/attention/backends/ipex_attn.py +387 -0
  27. .venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__init__.py +0 -0
  28. .venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__pycache__/__init__.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__pycache__/utils.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/vllm/attention/backends/mla/utils.py +541 -0
  31. .venv/lib/python3.11/site-packages/vllm/attention/backends/openvino.py +146 -0
  32. .venv/lib/python3.11/site-packages/vllm/attention/backends/pallas.py +337 -0
  33. .venv/lib/python3.11/site-packages/vllm/attention/backends/placeholder_attn.py +410 -0
  34. .venv/lib/python3.11/site-packages/vllm/attention/backends/rocm_flash_attn.py +891 -0
  35. .venv/lib/python3.11/site-packages/vllm/attention/backends/torch_sdpa.py +681 -0
  36. .venv/lib/python3.11/site-packages/vllm/attention/backends/triton_mla.py +746 -0
  37. .venv/lib/python3.11/site-packages/vllm/attention/backends/utils.py +582 -0
  38. .venv/lib/python3.11/site-packages/vllm/attention/backends/xformers.py +794 -0
  39. .venv/lib/python3.11/site-packages/vllm/attention/ops/__init__.py +0 -0
  40. .venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/__init__.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/hpu_paged_attn.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/ipex_attn.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/nki_flash_attn.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/paged_attn.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/prefix_prefill.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/triton_decode_attention.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/triton_flash_attention.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__init__.py +0 -0
  49. .venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/__init__.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/blocksparse_attention_kernel.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/vllm/attention/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from vllm.attention.backends.abstract import (AttentionBackend,
4
+ AttentionMetadata,
5
+ AttentionMetadataBuilder,
6
+ AttentionState, AttentionType)
7
+ from vllm.attention.layer import Attention
8
+ from vllm.attention.selector import get_attn_backend
9
+
10
+ __all__ = [
11
+ "Attention",
12
+ "AttentionBackend",
13
+ "AttentionMetadata",
14
+ "AttentionType",
15
+ "AttentionMetadataBuilder",
16
+ "Attention",
17
+ "AttentionState",
18
+ "get_attn_backend",
19
+ ]
.venv/lib/python3.11/site-packages/vllm/attention/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (678 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/__pycache__/layer.cpython-311.pyc ADDED
Binary file (16.1 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/__pycache__/selector.cpython-311.pyc ADDED
Binary file (5.85 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/backends/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (196 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/abstract.cpython-311.pyc ADDED
Binary file (14.2 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/blocksparse_attn.cpython-311.pyc ADDED
Binary file (17.4 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/flash_attn.cpython-311.pyc ADDED
Binary file (36.6 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/flashinfer.cpython-311.pyc ADDED
Binary file (44 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/hpu_attn.cpython-311.pyc ADDED
Binary file (13.7 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/ipex_attn.cpython-311.pyc ADDED
Binary file (16 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/openvino.cpython-311.pyc ADDED
Binary file (6.34 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/pallas.cpython-311.pyc ADDED
Binary file (14.8 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/placeholder_attn.cpython-311.pyc ADDED
Binary file (18 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/rocm_flash_attn.cpython-311.pyc ADDED
Binary file (35.3 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/torch_sdpa.cpython-311.pyc ADDED
Binary file (28.4 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/triton_mla.cpython-311.pyc ADDED
Binary file (31.3 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/utils.cpython-311.pyc ADDED
Binary file (25.4 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/backends/__pycache__/xformers.cpython-311.pyc ADDED
Binary file (29 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/backends/abstract.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from abc import ABC, abstractmethod
4
+ from contextlib import contextmanager
5
+ from dataclasses import dataclass, fields
6
+ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
7
+ Protocol, Set, Tuple, Type, TypeVar)
8
+
9
+ import torch
10
+
11
+ from vllm.multimodal import MultiModalPlaceholderMap
12
+
13
+ if TYPE_CHECKING:
14
+ from vllm.worker.model_runner_base import (ModelRunnerBase,
15
+ ModelRunnerInputBase,
16
+ ModelRunnerInputBuilderBase)
17
+
18
+
19
+ class AttentionType:
20
+ """
21
+ Attention type.
22
+ Use string to be compatible with `torch.compile`.
23
+ """
24
+ # Decoder attention between previous layer Q/K/V
25
+ DECODER = "decoder"
26
+ # Encoder attention between previous layer Q/K/V for encoder-decoder
27
+ ENCODER = "encoder"
28
+ # Encoder attention between previous layer Q/K/V
29
+ ENCODER_ONLY = "encoder_only"
30
+ # Attention between dec. Q and enc. K/V for encoder-decoder
31
+ ENCODER_DECODER = "encoder_decoder"
32
+
33
+
34
+ class AttentionBackend(ABC):
35
+ """Abstract class for attention backends."""
36
+ # For some attention backends, we allocate an output tensor before
37
+ # calling the custom op. When piecewise cudagraph is enabled, this
38
+ # makes sure the output tensor is allocated inside the cudagraph.
39
+ accept_output_buffer: bool = False
40
+
41
+ @staticmethod
42
+ @abstractmethod
43
+ def get_name() -> str:
44
+ raise NotImplementedError
45
+
46
+ @staticmethod
47
+ @abstractmethod
48
+ def get_impl_cls() -> Type["AttentionImpl"]:
49
+ raise NotImplementedError
50
+
51
+ @staticmethod
52
+ @abstractmethod
53
+ def get_metadata_cls() -> Type["AttentionMetadata"]:
54
+ raise NotImplementedError
55
+
56
+ @staticmethod
57
+ @abstractmethod
58
+ def get_state_cls() -> Type["AttentionState"]:
59
+ raise NotImplementedError
60
+
61
+ @classmethod
62
+ def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
63
+ return cls.get_metadata_cls()(*args, **kwargs)
64
+
65
+ @staticmethod
66
+ @abstractmethod
67
+ def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
68
+ raise NotImplementedError
69
+
70
+ @staticmethod
71
+ @abstractmethod
72
+ def get_kv_cache_shape(
73
+ num_blocks: int,
74
+ block_size: int,
75
+ num_kv_heads: int,
76
+ head_size: int,
77
+ ) -> Tuple[int, ...]:
78
+ raise NotImplementedError
79
+
80
+ @staticmethod
81
+ @abstractmethod
82
+ def swap_blocks(
83
+ src_kv_cache: torch.Tensor,
84
+ dst_kv_cache: torch.Tensor,
85
+ src_to_dst: torch.Tensor,
86
+ ) -> None:
87
+ raise NotImplementedError
88
+
89
+ @staticmethod
90
+ @abstractmethod
91
+ def copy_blocks(
92
+ kv_caches: List[torch.Tensor],
93
+ src_to_dists: torch.Tensor,
94
+ ) -> None:
95
+ raise NotImplementedError
96
+
97
+ def advance_step(self, model_input: "ModelRunnerInputBase",
98
+ sampled_token_ids: Optional[torch.Tensor],
99
+ block_size: int, num_seqs: int, num_queries: int) -> None:
100
+ raise NotImplementedError
101
+
102
+
103
+ @dataclass
104
+ class AttentionMetadata:
105
+ """Attention metadata for prefill and decode batched together."""
106
+ # Total number of prefill requests.
107
+ num_prefills: int
108
+ # Number of prefill tokens.
109
+ num_prefill_tokens: int
110
+ # Number of decode tokens. Note that it is equivalent to the number of
111
+ # decode requests.
112
+ num_decode_tokens: int
113
+ # (num_tokens,). The indices of the token slots that input tokens will be
114
+ # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
115
+ # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
116
+ # in block 0, and 1st slot in block 1, respectively.
117
+ slot_mapping: torch.Tensor
118
+
119
+ # The index maps that relate multi-modal embeddings to the corresponding
120
+ # placeholders.
121
+ #
122
+ # N.B. These aren't really related to attention and don't belong on this
123
+ # type -- this is just a temporary solution to make them available to
124
+ # `model_executable`.
125
+ multi_modal_placeholder_index_maps: Optional[Dict[
126
+ str, MultiModalPlaceholderMap.IndexMap]]
127
+
128
+ # Enable/disable KV scales calculation. This is so that we can disable the
129
+ # calculation until after prefill and cuda graph capture.
130
+ enable_kv_scales_calculation: bool
131
+
132
+ @property
133
+ @abstractmethod
134
+ def prefill_metadata(self) -> Optional["AttentionMetadata"]:
135
+ """Return the attention metadata that's required to run prefill
136
+ attention."""
137
+ pass
138
+
139
+ @property
140
+ @abstractmethod
141
+ def decode_metadata(self) -> Optional["AttentionMetadata"]:
142
+ """Return the attention metadata that's required to run decode
143
+ attention."""
144
+ pass
145
+
146
+ def asdict_zerocopy(self,
147
+ skip_fields: Optional[Set[str]] = None
148
+ ) -> Dict[str, Any]:
149
+ """Similar to dataclasses.asdict, but avoids deepcopying."""
150
+ if skip_fields is None:
151
+ skip_fields = set()
152
+ # Note that if we add dataclasses as fields, they will need
153
+ # similar handling.
154
+ return {
155
+ field.name: getattr(self, field.name)
156
+ for field in fields(self) if field.name not in skip_fields
157
+ }
158
+
159
+
160
+ T = TypeVar("T", bound=AttentionMetadata)
161
+
162
+
163
+ class AttentionState(ABC, Generic[T]):
164
+ """Holds attention backend-specific objects reused during the
165
+ lifetime of the model runner."""
166
+
167
+ @abstractmethod
168
+ def __init__(self, runner: "ModelRunnerBase"):
169
+ ...
170
+
171
+ @abstractmethod
172
+ @contextmanager
173
+ def graph_capture(self, max_batch_size: int):
174
+ """Context manager used when capturing CUDA graphs."""
175
+ yield
176
+
177
+ @abstractmethod
178
+ def graph_clone(self, batch_size: int) -> "AttentionState[T]":
179
+ """Clone attention state to save in CUDA graph metadata."""
180
+ ...
181
+
182
+ @abstractmethod
183
+ def graph_capture_get_metadata_for_batch(
184
+ self,
185
+ batch_size: int,
186
+ is_encoder_decoder_model: bool = False) -> T:
187
+ """Get attention metadata for CUDA graph capture of batch_size."""
188
+ ...
189
+
190
+ @abstractmethod
191
+ def get_graph_input_buffers(
192
+ self,
193
+ attn_metadata: T,
194
+ is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
195
+ """Get attention-specific input buffers for CUDA graph capture."""
196
+ ...
197
+
198
+ @abstractmethod
199
+ def prepare_graph_input_buffers(
200
+ self,
201
+ input_buffers: Dict[str, Any],
202
+ attn_metadata: T,
203
+ is_encoder_decoder_model: bool = False) -> None:
204
+ """In-place modify input buffers dict for CUDA graph replay."""
205
+ ...
206
+
207
+ @abstractmethod
208
+ def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
209
+ """Prepare state for forward pass."""
210
+ ...
211
+
212
+
213
+ class AttentionMetadataBuilder(ABC, Generic[T]):
214
+ """Abstract class for attention metadata builders."""
215
+
216
+ @abstractmethod
217
+ def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
218
+ """Create the builder, remember some configuration and parameters."""
219
+ raise NotImplementedError
220
+
221
+ @abstractmethod
222
+ def prepare(self) -> None:
223
+ """Prepare for one batch."""
224
+ raise NotImplementedError
225
+
226
+ @abstractmethod
227
+ def build(self, seq_lens: List[int], query_lens: List[int],
228
+ cuda_graph_pad_size: int, batch_size: int) -> T:
229
+ """Build attention metadata with on-device tensors."""
230
+ raise NotImplementedError
231
+
232
+
233
+ class AttentionLayer(Protocol):
234
+
235
+ _k_scale: torch.Tensor
236
+ _v_scale: torch.Tensor
237
+ _k_scale_float: float
238
+ _v_scale_float: float
239
+
240
+ def forward(
241
+ self,
242
+ query: torch.Tensor,
243
+ key: torch.Tensor,
244
+ value: torch.Tensor,
245
+ kv_cache: torch.Tensor,
246
+ attn_metadata: AttentionMetadata,
247
+ ) -> torch.Tensor:
248
+ ...
249
+
250
+
251
+ class AttentionImpl(ABC, Generic[T]):
252
+
253
+ @abstractmethod
254
+ def __init__(
255
+ self,
256
+ num_heads: int,
257
+ head_size: int,
258
+ scale: float,
259
+ num_kv_heads: Optional[int] = None,
260
+ alibi_slopes: Optional[List[float]] = None,
261
+ sliding_window: Optional[int] = None,
262
+ kv_cache_dtype: str = "auto",
263
+ blocksparse_params: Optional[Dict[str, Any]] = None,
264
+ logits_soft_cap: Optional[float] = None,
265
+ attn_type: str = AttentionType.DECODER,
266
+ ) -> None:
267
+ raise NotImplementedError
268
+
269
+ @abstractmethod
270
+ def forward(
271
+ self,
272
+ layer: AttentionLayer,
273
+ query: torch.Tensor,
274
+ key: torch.Tensor,
275
+ value: torch.Tensor,
276
+ kv_cache: torch.Tensor,
277
+ attn_metadata: T,
278
+ output: Optional[torch.Tensor] = None,
279
+ ) -> torch.Tensor:
280
+ raise NotImplementedError
281
+
282
+
283
+ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
284
+
285
+ @abstractmethod
286
+ def forward(
287
+ self,
288
+ layer: AttentionLayer,
289
+ hidden_states_or_cq: torch.Tensor,
290
+ kv_c_normed: torch.Tensor,
291
+ k_pe: torch.Tensor,
292
+ kv_cache: torch.Tensor,
293
+ attn_metadata: T,
294
+ output: Optional[torch.Tensor] = None,
295
+ ) -> torch.Tensor:
296
+ raise NotImplementedError
.venv/lib/python3.11/site-packages/vllm/attention/backends/blocksparse_attn.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Dict, List, Optional, Tuple, Type
5
+
6
+ import torch
7
+
8
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
9
+ AttentionLayer,
10
+ AttentionMetadata, AttentionType)
11
+ from vllm.attention.backends.utils import (CommonAttentionState,
12
+ CommonMetadataBuilder)
13
+ from vllm.attention.ops.blocksparse_attention.interface import (
14
+ LocalStridedBlockSparseAttn, get_head_sliding_step)
15
+ from vllm.attention.ops.paged_attn import PagedAttention
16
+ from vllm.distributed import (get_tensor_model_parallel_rank,
17
+ get_tensor_model_parallel_world_size)
18
+
19
+
20
+ @dataclass
21
+ class BlocksparseParams:
22
+ max_seqlen: int
23
+
24
+ # Num q heads per tensor-parallel rank/partition
25
+ num_heads: int # per TP partition
26
+ # Num kv heads per tensor-parallel rank/partition
27
+ num_kv_heads: int
28
+
29
+ # block size used for blocksparse attention.
30
+ # This is the block_size used in `local_blocks`, `vert_stride`.
31
+ block_size: int
32
+
33
+ # Number of blocks for local attention, i.e., number of
34
+ # local attended tokens / `sparse_block_size`
35
+ local_blocks: int
36
+
37
+ # Attend to one block per every `vert_stride` blocks.
38
+ # Controlling the sparsity
39
+ vert_stride: int
40
+ """
41
+ If to use the same vertical stride offset for all heads,
42
+ i.e., attend to the same block of tokens on all heads.
43
+ By default, it is False, i.e., attention on the non-local
44
+ blocks depends on the `head_idx`, that is on
45
+ blocks satisfying
46
+ `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`
47
+ where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`,
48
+ `block_idx = position_id // sparse_block_size`.
49
+ See `..ops.blocksparse_attention.utils:get_sparse_attn_mask`
50
+ for more detail.
51
+ """
52
+ homo_head: bool = False
53
+
54
+ # If within a group, the kv offsets that each q attends is the same or no.
55
+ homo_head_group: bool = False
56
+
57
+ # Decided by homo_head and homo_head group
58
+ head_sliding_step: int = field(init=False)
59
+
60
+ # range of q heads to for a TP rank
61
+ active_head_range: Tuple = field(init=False)
62
+
63
+ def __post_init__(self):
64
+ assert self.block_size > 0
65
+ assert self.local_blocks >= 0
66
+ assert self.vert_stride >= 1
67
+ assert self.num_heads % self.num_kv_heads == 0
68
+
69
+ tp_size = get_tensor_model_parallel_world_size()
70
+ tp_rank = get_tensor_model_parallel_rank()
71
+ total_heads = tp_size * self.num_heads
72
+ total_kv_heads = tp_size * self.num_kv_heads
73
+
74
+ if self.homo_head:
75
+ self.head_sliding_step = 0
76
+ elif self.homo_head_group:
77
+ head_sliding_step = get_head_sliding_step(total_kv_heads,
78
+ self.vert_stride)
79
+ # negative indicates sliding along kv heads, i.e., homo q group
80
+ self.head_sliding_step = -head_sliding_step
81
+ else:
82
+ self.head_sliding_step = get_head_sliding_step(
83
+ total_heads, self.vert_stride)
84
+
85
+ self.active_head_range = (
86
+ tp_rank * self.num_heads,
87
+ (tp_rank + 1) * self.num_heads,
88
+ )
89
+
90
+
91
+ class BlocksparseFlashAttentionBackend(AttentionBackend):
92
+
93
+ @staticmethod
94
+ def get_name() -> str:
95
+ return "BLOCK_SPARSE_FLASH_ATTN"
96
+
97
+ @staticmethod
98
+ def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
99
+ return BlocksparseFlashAttentionImpl
100
+
101
+ @staticmethod
102
+ def get_metadata_cls() -> Type["AttentionMetadata"]:
103
+ return BlocksparseFlashAttentionMetadata
104
+
105
+ @staticmethod
106
+ def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
107
+ return BlocksparseFlashAttentionMetadataBuilder
108
+
109
+ @staticmethod
110
+ def get_state_cls() -> Type["CommonAttentionState"]:
111
+ return CommonAttentionState
112
+
113
+ @staticmethod
114
+ def get_kv_cache_shape(
115
+ num_blocks: int,
116
+ block_size: int,
117
+ num_kv_heads: int,
118
+ head_size: int,
119
+ ) -> Tuple[int, ...]:
120
+ return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
121
+ num_kv_heads, head_size)
122
+
123
+ @staticmethod
124
+ def swap_blocks(
125
+ src_kv_cache: torch.Tensor,
126
+ dst_kv_cache: torch.Tensor,
127
+ src_to_dst: Dict[int, int],
128
+ ) -> None:
129
+ PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
130
+
131
+ @staticmethod
132
+ def copy_blocks(
133
+ kv_caches: List[torch.Tensor],
134
+ src_to_dists: Dict[int, List[int]],
135
+ ) -> None:
136
+ PagedAttention.copy_blocks(kv_caches, src_to_dists)
137
+
138
+
139
+ @dataclass
140
+ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
141
+ """A copy of Metadata for FlashAttentionBackend,
142
+ to avoid having to install flash_attn.
143
+
144
+ NOTE: Any python object stored here is not updated when it is
145
+ cuda-graph replayed. If you have values that need to be changed
146
+ dynamically, it should be stored in tensor. The tensor has to be
147
+ updated from `CUDAGraphRunner.forward` API.
148
+ """
149
+ # (batch_size,). The sequence length per sequence. Sequence length means
150
+ # the computed tokens + new tokens None if it is a decoding.
151
+ seq_lens: Optional[List[int]]
152
+ # seq_lens stored as a tensor.
153
+ seq_lens_tensor: Optional[torch.Tensor]
154
+
155
+ # NOTE(sang): Definition of context_len, query_len, and seq_len.
156
+ # |---------- N-1 iteration --------|
157
+ # |---------------- N iteration ---------------------|
158
+ # |- tokenA -|......................|-- newTokens ---|
159
+ # |---------- context_len ----------|
160
+ # |-------------------- seq_len ----------------------|
161
+ # |-- query_len ---|
162
+
163
+ # Maximum query length in the batch. None for decoding.
164
+ max_query_len: Optional[int]
165
+ # Maximum sequence length among prefill batch. 0 if there are decoding
166
+ # requests only.
167
+ max_prefill_seq_len: int
168
+ # Maximum sequence length among decode batch. 0 if there are prefill
169
+ # requests only.
170
+ max_decode_seq_len: int
171
+ # (batch_size + 1,). The cumulative subquery lengths of the sequences in
172
+ # the batch, used to index into subquery. E.g., if the subquery length
173
+ # is [4, 6], it is [0, 4, 10].
174
+ query_start_loc: Optional[torch.Tensor]
175
+ # (batch_size + 1,). The cumulative sequence lengths of the sequences in
176
+ # the batch, used to index into sequence. E.g., if the sequence length is
177
+ # [4, 6], it is [0, 4, 10].
178
+ seq_start_loc: Optional[torch.Tensor]
179
+ # (batch_size,) A tensor of context lengths (tokens that are computed
180
+ # so far).
181
+ context_lens_tensor: Optional[torch.Tensor]
182
+
183
+ # (batch_size, max_blocks_per_seq).
184
+ # Block addresses per sequence. (Seq id -> list of physical block)
185
+ # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
186
+ # in the kv cache. Each block can contain up to block_size tokens.
187
+ # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
188
+ # captured.
189
+ block_tables: Optional[torch.Tensor]
190
+
191
+ # Whether or not if cuda graph is enabled.
192
+ # Cuda-graph is currently enabled for decoding only.
193
+ # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
194
+ use_cuda_graph: bool
195
+
196
+ # Max number of query tokens for among request in the batch.
197
+ max_decode_query_len: Optional[int] = None
198
+
199
+ _cached_prefill_metadata: Optional[
200
+ "BlocksparseFlashAttentionMetadata"] = None
201
+ _cached_decode_metadata: Optional[
202
+ "BlocksparseFlashAttentionMetadata"] = None
203
+
204
+ @property
205
+ def prefill_metadata(
206
+ self) -> Optional["BlocksparseFlashAttentionMetadata"]:
207
+ if self.num_prefills == 0:
208
+ return None
209
+
210
+ if self._cached_prefill_metadata is not None:
211
+ return self._cached_prefill_metadata
212
+
213
+ assert self.seq_lens is not None
214
+ assert self.seq_lens_tensor is not None
215
+ assert self.query_start_loc is not None
216
+ assert self.context_lens_tensor is not None
217
+ assert self.block_tables is not None
218
+ assert self.seq_start_loc is not None
219
+
220
+ self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata(
221
+ num_prefills=self.num_prefills,
222
+ num_prefill_tokens=self.num_prefill_tokens,
223
+ num_decode_tokens=0,
224
+ slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
225
+ multi_modal_placeholder_index_maps=self.
226
+ multi_modal_placeholder_index_maps,
227
+ enable_kv_scales_calculation=self.enable_kv_scales_calculation,
228
+ seq_lens=self.seq_lens[:self.num_prefills],
229
+ seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
230
+ max_query_len=self.max_query_len,
231
+ max_prefill_seq_len=self.max_prefill_seq_len,
232
+ max_decode_seq_len=0,
233
+ query_start_loc=self.query_start_loc[:self.num_prefills + 1],
234
+ seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
235
+ context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
236
+ block_tables=self.block_tables[:self.num_prefills],
237
+ use_cuda_graph=False,
238
+ )
239
+ return self._cached_prefill_metadata
240
+
241
+ @property
242
+ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
243
+ if self.num_decode_tokens == 0:
244
+ return None
245
+
246
+ if self._cached_decode_metadata is not None:
247
+ return self._cached_decode_metadata
248
+ assert self.block_tables is not None
249
+ assert self.seq_lens_tensor is not None
250
+
251
+ self._cached_decode_metadata = BlocksparseFlashAttentionMetadata(
252
+ num_prefills=0,
253
+ num_prefill_tokens=0,
254
+ num_decode_tokens=self.num_decode_tokens,
255
+ slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
256
+ multi_modal_placeholder_index_maps=None,
257
+ enable_kv_scales_calculation=False,
258
+ seq_lens=None,
259
+ seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
260
+ max_query_len=None,
261
+ max_prefill_seq_len=0,
262
+ max_decode_seq_len=self.max_decode_seq_len,
263
+ query_start_loc=None,
264
+ seq_start_loc=None,
265
+ context_lens_tensor=None,
266
+ block_tables=self.block_tables[self.num_prefills:],
267
+ use_cuda_graph=self.use_cuda_graph,
268
+ )
269
+ return self._cached_decode_metadata
270
+
271
+
272
+ class BlocksparseFlashAttentionMetadataBuilder(
273
+ CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):
274
+
275
+ _metadata_cls = BlocksparseFlashAttentionMetadata
276
+
277
+
278
+ class BlocksparseFlashAttentionImpl(AttentionImpl):
279
+ """
280
+ If the input tensors contain prompt tokens, the layout is as follows:
281
+ |<--------------- num_prompt_tokens -------------->|
282
+ |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
283
+
284
+ Otherwise, the layout is as follows:
285
+ |<------------------ num_generation_tokens (M) ----------------->|
286
+ |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
287
+
288
+ Generation tokens can contain padding when cuda-graph is used.
289
+ Currently, prompt tokens don't contain any padding.
290
+
291
+ The prompts might have different lengths, while the generation tokens
292
+ always have length 1.
293
+
294
+ """
295
+
296
+ def __init__(
297
+ self,
298
+ num_heads: int,
299
+ head_size: int,
300
+ scale: float,
301
+ num_kv_heads: int,
302
+ alibi_slopes: Optional[List[float]],
303
+ sliding_window: Optional[int],
304
+ kv_cache_dtype: str,
305
+ blocksparse_params: Optional[Dict[str, Any]] = None,
306
+ logits_soft_cap: Optional[float] = None,
307
+ attn_type: str = AttentionType.DECODER,
308
+ ) -> None:
309
+ assert blocksparse_params is not None
310
+ assert alibi_slopes is None, ValueError(
311
+ "Alibi not support for blocksparse flash attention.")
312
+ assert sliding_window is None, ValueError(
313
+ "sliding_window is invalid for blocksparse attention.")
314
+ assert logits_soft_cap is None, ValueError(
315
+ "logits_soft_cap is invalid for blocksparse attention.")
316
+
317
+ if "num_heads" not in blocksparse_params:
318
+ blocksparse_params["num_heads"] = num_heads
319
+ if "num_kv_heads" not in blocksparse_params:
320
+ blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads
321
+ self.blocksparse_params = BlocksparseParams(**blocksparse_params)
322
+ self.kv_cache_dtype = kv_cache_dtype
323
+
324
+ self.num_heads = num_heads
325
+ self.head_size = head_size
326
+ self.scale = float(scale)
327
+ self.alibi_slopes = alibi_slopes
328
+ self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
329
+
330
+ assert self.num_heads % self.num_kv_heads == 0
331
+ self.num_queries_per_kv = self.num_heads // self.num_kv_heads
332
+
333
+ self.local_blocks = self.blocksparse_params.local_blocks
334
+ self.vert_stride = self.blocksparse_params.vert_stride
335
+ self.sparse_block_size = self.blocksparse_params.block_size
336
+ self.head_sliding_step = self.blocksparse_params.head_sliding_step
337
+
338
+ suppored_head_sizes = PagedAttention.get_supported_head_sizes()
339
+ if head_size not in suppored_head_sizes:
340
+ raise ValueError(
341
+ f"Head size {head_size} is not supported by PagedAttention. "
342
+ f"Supported head sizes are: {suppored_head_sizes}.")
343
+
344
+ self.tp_size = get_tensor_model_parallel_world_size()
345
+ self.tp_rank = get_tensor_model_parallel_rank()
346
+
347
+ total_num_heads = num_heads * self.tp_size
348
+ self.bs_attn = LocalStridedBlockSparseAttn(
349
+ total_num_heads,
350
+ self.blocksparse_params.max_seqlen,
351
+ self.blocksparse_params.local_blocks,
352
+ self.blocksparse_params.vert_stride,
353
+ self.blocksparse_params.block_size,
354
+ homo_head=self.blocksparse_params.homo_head,
355
+ active_head_range=self.blocksparse_params.active_head_range,
356
+ )
357
+
358
+ if attn_type != AttentionType.DECODER:
359
+ raise NotImplementedError("Encoder self-attention and "
360
+ "encoder/decoder cross-attention "
361
+ "are not implemented for "
362
+ "BlocksparseFlashAttentionImpl")
363
+
364
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: BlocksparseFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Prefill tokens are handled by the block-sparse Triton kernel
        (`self.bs_attn`); decode tokens go through PagedAttention with the
        sparse-layout parameters passed through.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        # Reshape the flat [num_tokens, hidden] tensors into per-head form.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # kv_cache is empty ([0]) during the memory-profiling run; skip the
        # cache write in that case.
        if kv_cache.numel() > 0:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.

            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if prefill_meta := attn_metadata.prefill_metadata:

            # Prompt run.
            # normal attention
            # When block_tables are not filled, it means q and k are the
            # prompt, and they have the same length.

            # Prefix caching (non-empty block tables with a populated cache)
            # is not supported by the block-sparse prefill kernel.
            assert kv_cache.numel() == 0 \
                or prefill_meta.block_tables is None \
                or prefill_meta.block_tables.numel() == 0, \
                "Does not support prefix-enabled attention."

            # Varlen block-sparse attention over the packed prompt tokens;
            # q and k share the same cumulative sequence offsets.
            output = self.bs_attn(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=prefill_meta.seq_start_loc,
                cu_seqlens_k=prefill_meta.seq_start_loc,
                sm_scale=self.scale,
            )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run: PagedAttention reads K/V from the paged cache,
            # with the block-sparse layout parameters forwarded so the kernel
            # only attends to local + vertically-strided blocks.
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                self.blocksparse_params.max_seqlen,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                layer._k_scale,
                layer._v_scale,
                tp_rank=self.tp_rank,
                blocksparse_local_blocks=self.local_blocks,
                blocksparse_vert_stride=self.vert_stride,
                blocksparse_block_size=self.sparse_block_size,
                blocksparse_head_sliding_step=self.head_sliding_step,
            )

        # At least one of the prefill/decode branches must have produced
        # output.
        assert output is not None
        # Flatten heads back to [num_tokens, hidden_size].
        return output.view(num_tokens, hidden_size)
.venv/lib/python3.11/site-packages/vllm/attention/backends/flash_attn.py ADDED
@@ -0,0 +1,942 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Attention layer with FlashAttention."""
3
+ from collections import defaultdict
4
+ from dataclasses import dataclass
5
+ from itertools import accumulate
6
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
7
+
8
+ import torch
9
+
10
+ from vllm import _custom_ops as ops
11
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
12
+ AttentionLayer,
13
+ AttentionMetadata,
14
+ AttentionMetadataBuilder,
15
+ AttentionType)
16
+ from vllm.attention.backends.utils import (
17
+ PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
18
+ compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
19
+ get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
20
+ is_all_encoder_attn_metadata_set, is_block_tables_empty)
21
+ from vllm.envs import VLLM_FLASH_ATTN_VERSION
22
+ from vllm.logger import init_logger
23
+ from vllm.multimodal import MultiModalPlaceholderMap
24
+ from vllm.platforms import current_platform
25
+ from vllm.utils import async_tensor_h2d, make_tensor_with_pad
26
+ from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
27
+ flash_attn_varlen_func,
28
+ flash_attn_with_kvcache,
29
+ is_fa_version_supported)
30
+
31
+ if TYPE_CHECKING:
32
+ from vllm.worker.model_runner import (ModelInputForGPUBuilder,
33
+ ModelInputForGPUWithSamplingMetadata)
34
+
35
+ logger = init_logger(__name__)
36
+
37
+
38
class FlashAttentionBackend(AttentionBackend):
    """Static descriptor for the FlashAttention backend.

    Exposes the implementation/metadata/builder/state classes, the KV-cache
    tensor layout, and block swap/copy helpers used by the cache engine.
    """

    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        # Head dimensions accepted by the FlashAttention kernels.
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return FlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
        return FlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        # The leading 2 separates the key cache from the value cache.
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        # Index 0 is the key cache, index 1 the value cache; swap the mapped
        # blocks in each, keys first.
        for cache_idx in (0, 1):
            ops.swap_blocks(src_kv_cache[cache_idx], dst_kv_cache[cache_idx],
                            src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        # Split each layer's cache into key/value halves for the batched
        # copy kernel.
        key_caches = [layer_cache[0] for layer_cache in kv_caches]
        value_caches = [layer_cache[1] for layer_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)
99
+
100
+
101
@dataclass
class FlashAttentionMetadata(AttentionMetadata):
    """Metadata for FlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]

    # Whether or not if cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.

    use_cuda_graph: bool

    # Maximum query length in the batch.
    max_query_len: Optional[int] = None

    # Max number of query tokens among request in the batch.
    max_decode_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # Lazily-built views of this metadata restricted to the prefill-only and
    # decode-only slices of the batch (see the two properties below).
    _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    encoder_seq_start_loc: Optional[torch.Tensor] = None
    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None
    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        '''
        All attention metadata required for encoder attention is set.
        '''
        return is_all_encoder_attn_metadata_set(self)

    @property
    def is_all_cross_attn_metadata_set(self):
        '''
        All attention metadata required for enc/dec cross-attention is set.

        Superset of encoder attention required metadata.
        '''
        return is_all_cross_attn_metadata_set(self)

    @property
    def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
        """Metadata restricted to the prefill requests of the batch.

        Returns None when there are no prefills. The result is cached; it is
        built by slicing the batch-ordered tensors with the convention that
        prefills come before decodes.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert ((self.seq_lens is not None)
                or (self.encoder_seq_lens is not None))
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None.
        # Prefills occupy the leading entries of every batch-ordered tensor.
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:self.num_prefills + 1])
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:self.num_prefill_tokens])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])
        seq_start_loc = (None if self.seq_start_loc is None else
                         self.seq_start_loc[:self.num_prefills + 1])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:self.num_prefills])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:self.num_prefills])

        self._cached_prefill_metadata = FlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_query_len=0,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            # Cuda-graph is only used for decode, never for the prefill view.
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            encoder_seq_start_loc=self.encoder_seq_start_loc,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
        """Metadata restricted to the decode requests of the batch.

        Returns None when there are no decode tokens. The result is cached;
        decodes occupy the trailing entries of every batch-ordered tensor.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])

        self._cached_decode_metadata = FlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            # Batch may be composed of prefill|decodes, adjust query start
            # indices to refer to the start of decodes. E.g.
            # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
            query_start_loc=(self.query_start_loc[self.num_prefills:] -
                             self.query_start_loc[self.num_prefills])
            if self.query_start_loc is not None else None,
            seq_start_loc=self.seq_start_loc[self.num_prefills:]
            if self.seq_start_loc is not None else None,
            context_lens_tensor=None,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            encoder_seq_start_loc=self.encoder_seq_start_loc,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.

        Used by multi-step scheduling: bumps every active sequence length by
        one and delegates slot-mapping/block-table updates to the
        `advance_step_flashattn` CUDA op.
        """
        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        if turn_prefills_into_decodes:
            # When Mutli-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
            assert self.num_decode_tokens + self.num_prefills == num_seqs
            self.num_decode_tokens += self.num_prefills
            self.num_prefills = 0
            self.num_prefill_tokens = 0
            self.max_prefill_seq_len = 0
            self.max_query_len = 1

            self.slot_mapping = self.slot_mapping[:num_seqs]
        else:
            assert self.seq_lens is not None
            assert self.max_decode_seq_len == max(self.seq_lens)

        # Sanity-check that the metadata now describes a decode-only batch
        # with consistent shapes before mutating it.
        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        # Device-side update of input tokens/positions, seq_lens_tensor,
        # slot mapping, and block tables for the next step.
        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=model_input.input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=model_input.input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)
381
+
382
+
383
class FlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlashAttentionMetadata]):
    """Accumulates per-sequence-group state and materializes it into a
    FlashAttentionMetadata with on-device tensors via build()."""

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.input_builder = input_builder
        self.runner = input_builder.runner
        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def prepare(self):
        # Reset all per-batch accumulation state; called once per batch
        # before the _add_seq_group() calls.
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0
        self.has_prefix_cache_hit = False

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool, prefix_cache_hit: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)

            if is_prompt:
                # Collect multimodal placeholder maps only for prompt runs.
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)

                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                # NOTE(woosuk): For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    # Keep only the trailing blocks covered by the sliding
                    # window.
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def _get_graph_runner_block_tables(
            self, num_seqs: int,
            block_tables: List[List[int]]) -> torch.Tensor:
        """Copy block tables into the runner's pre-allocated (numpy) graph
        block-table buffer and upload it to the device."""
        # The shape of graph_block_tables is
        # [max batch size, max context len // block size].
        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
        assert max_batch_size >= num_seqs

        graph_block_tables = self.runner.graph_block_tables[:num_seqs]
        for i, block_table in enumerate(block_tables):
            if block_table:
                num_blocks = len(block_table)
                if num_blocks <= max_blocks:
                    graph_block_tables[i, :num_blocks] = block_table
                else:
                    # It may be possible to have more blocks allocated due
                    # to lookahead slots of multi-step, however, they are
                    # not used anyway, so can be safely ignored.
                    graph_block_tables[
                        i, :max_blocks] = block_table[:max_blocks]

        return torch.from_numpy(graph_block_tables).to(
            device=self.runner.device, non_blocking=True)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        prefix_cache_hit = any([
            inter_data.prefix_cache_hit
            for inter_data in self.input_builder.inter_data_list
        ])
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled,
                                prefix_cache_hit)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        decode_query_lens = query_lens[self.num_prefills:]
        if len(decode_query_lens) > 0:
            max_decode_query_len = max(decode_query_lens)
        else:
            max_decode_query_len = 1
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        # Cumulative offsets: [0, l0, l0+l1, ...].
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        num_seqs = len(seq_lens)
        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # NOTE(review): `[] * n` is always [], so this extend is a no-op;
            # padding rows come from the zero-initialized graph buffer in
            # _get_graph_runner_block_tables — confirm this is intentional.
            self.block_tables.extend([] * cuda_graph_pad_size)
            num_decode_tokens = batch_size - self.num_prefill_tokens
            block_tables = self._get_graph_runner_block_tables(
                num_seqs, self.block_tables)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        assert device is not None
        # Asynchronously upload the host-side lists to the device.
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                                  device,
                                                  self.runner.pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        return FlashAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=True,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
578
+
579
+
580
+ class FlashAttentionImpl(AttentionImpl):
581
+ """
582
+ If the input tensors contain prompt tokens, the layout is as follows:
583
+ |<--------------- num_prefill_tokens ----------------->|
584
+ |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
585
+
586
+ Otherwise, the layout is as follows:
587
+ |<----------------- num_decode_tokens ------------------>|
588
+ |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
589
+
590
+ Generation tokens can contain padding when cuda-graph is used.
591
+ Currently, prompt tokens don't contain any padding.
592
+
593
+ The prompts might have different lengths, while the generation tokens
594
+ always have length 1.
595
+
596
+ If chunked prefill is enabled, prefill tokens and decode tokens can be
597
+ batched together in a flattened 1D query.
598
+
599
+ |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
600
+ |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
601
+
602
+ Currently, cuda graph is disabled for chunked prefill, meaning there's no
603
+ padding between prefill and decode tokens.
604
+ """
605
+
606
+ def __init__(
607
+ self,
608
+ num_heads: int,
609
+ head_size: int,
610
+ scale: float,
611
+ num_kv_heads: int,
612
+ alibi_slopes: Optional[List[float]],
613
+ sliding_window: Optional[int],
614
+ kv_cache_dtype: str,
615
+ blocksparse_params: Optional[Dict[str, Any]] = None,
616
+ logits_soft_cap: Optional[float] = None,
617
+ attn_type: str = AttentionType.DECODER,
618
+ ) -> None:
619
+ if blocksparse_params is not None:
620
+ raise ValueError(
621
+ "FlashAttention does not support block-sparse attention.")
622
+ self.num_heads = num_heads
623
+ self.head_size = head_size
624
+ self.scale = float(scale)
625
+ self.num_kv_heads = num_kv_heads
626
+ if alibi_slopes is not None:
627
+ alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
628
+ self.alibi_slopes = alibi_slopes
629
+ self.sliding_window = ((sliding_window - 1,
630
+ 0) if sliding_window is not None else (-1, -1))
631
+ self.kv_cache_dtype = kv_cache_dtype
632
+ if logits_soft_cap is None:
633
+ # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
634
+ logits_soft_cap = 0
635
+ self.logits_soft_cap = logits_soft_cap
636
+
637
+ assert self.num_heads % self.num_kv_heads == 0
638
+ self.num_queries_per_kv = self.num_heads // self.num_kv_heads
639
+
640
+ support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
641
+ if head_size not in support_head_sizes:
642
+ raise ValueError(
643
+ f"Head size {head_size} is not supported by FlashAttention. "
644
+ f"Supported head sizes are: {support_head_sizes}.")
645
+ self.attn_type = attn_type
646
+
647
+ # if hopper default to FA3, otherwise stick to FA2 for now
648
+ # TODO(lucas): profile FA3 on ampere to see if it makes sense to
649
+ # use FA3 as default for both
650
+ if current_platform.get_device_capability()[0] >= 9:
651
+ self.fa_version = 3 if is_fa_version_supported(3) else 2
652
+ else:
653
+ self.fa_version = 2
654
+
655
+ if VLLM_FLASH_ATTN_VERSION is not None:
656
+ assert VLLM_FLASH_ATTN_VERSION in [2, 3]
657
+ self.fa_version = VLLM_FLASH_ATTN_VERSION
658
+
659
+ if not is_fa_version_supported(self.fa_version):
660
+ logger.error("Cannot use FA version %d is not supported due to %s",
661
+ self.fa_version,
662
+ fa_version_unsupported_reason(self.fa_version))
663
+
664
+ assert is_fa_version_supported(self.fa_version)
665
+
666
+ def forward(
667
+ self,
668
+ layer: AttentionLayer,
669
+ query: torch.Tensor,
670
+ key: torch.Tensor,
671
+ value: torch.Tensor,
672
+ kv_cache: torch.Tensor,
673
+ attn_metadata: FlashAttentionMetadata,
674
+ output: Optional[torch.Tensor] = None,
675
+ ) -> torch.Tensor:
676
+ """Forward pass with FlashAttention.
677
+
678
+ Args:
679
+ query: shape = [num_tokens, num_heads, head_size]
680
+ key: shape = [num_tokens, num_kv_heads, head_size]
681
+ value: shape = [num_tokens, num_kv_heads, head_size]
682
+ output: shape = [num_tokens, num_heads, head_size]
683
+ kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
684
+ NOTE: kv_cache will be an empty tensor with shape [0]
685
+ for profiling run.
686
+ attn_metadata: Metadata for attention.
687
+ NOTE: It in-place updates the output tensor.
688
+ """
689
+ # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
690
+ assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0, (
691
+ "key/v_scale is not supported in FlashAttention.")
692
+
693
+ assert output is not None, "Output tensor must be provided."
694
+
695
+ attn_type = self.attn_type
696
+ if (attn_type == AttentionType.ENCODER
697
+ and (not attn_metadata.is_all_encoder_attn_metadata_set)):
698
+ raise AttributeError("Encoder attention requires setting "
699
+ "encoder metadata attributes.")
700
+ elif (attn_type == AttentionType.ENCODER_DECODER
701
+ and (not attn_metadata.is_all_cross_attn_metadata_set)):
702
+ raise AttributeError("Encoder/decoder cross-attention "
703
+ "requires setting cross-attention "
704
+ "metadata attributes.")
705
+
706
+ kv_cache_dtype: str = self.kv_cache_dtype
707
+ softmax_scale: float = self.scale
708
+ window_size = self.sliding_window
709
+ alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes
710
+ logits_soft_cap: Optional[float] = self.logits_soft_cap
711
+
712
+ if kv_cache.numel() > 0:
713
+ key_cache = kv_cache[0]
714
+ value_cache = kv_cache[1]
715
+ # We skip updating the KV cache under two conditions:
716
+ # a. When the Attention Type is ENCODER. In this phase, we compute
717
+ # only the encoder attention without updating the cache.
718
+ # b. When both Key and Value are None. This occurs during
719
+ # cross-attention computation in the decoding phase, where the
720
+ # KV cache is already populated with the cross-attention
721
+ # tensor. Thus, we skip cache updates during this time.
722
+ if (attn_type != AttentionType.ENCODER) and (key is not None) and (
723
+ value is not None):
724
+ if attn_type == AttentionType.ENCODER_DECODER:
725
+ # Update cross-attention KV cache (prefill-only)
726
+ updated_slot_mapping = attn_metadata.cross_slot_mapping
727
+ else:
728
+ # Update self-attention KV cache (prefill/decode)
729
+ updated_slot_mapping = attn_metadata.slot_mapping
730
+
731
+ # Reshape the input keys and values and store them in the cache.
732
+ # If kv_cache is not provided, the new key and value tensors are
733
+ # not cached. This happens during the initial memory
734
+ # profiling run.
735
+ torch.ops._C_cache_ops.reshape_and_cache_flash(
736
+ key,
737
+ value,
738
+ kv_cache[0],
739
+ kv_cache[1],
740
+ updated_slot_mapping.flatten(), # type: ignore[union-attr]
741
+ kv_cache_dtype,
742
+ layer._k_scale,
743
+ layer._v_scale,
744
+ )
745
+
746
+ (num_prefill_query_tokens, num_prefill_kv_tokens,
747
+ num_decode_query_tokens) = \
748
+ get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
749
+ decode_query = query[num_prefill_query_tokens:]
750
+ decode_output = output[num_prefill_query_tokens:]
751
+ # QKV for prefill.
752
+ query = query[:num_prefill_query_tokens]
753
+ prefill_output = output[:num_prefill_query_tokens]
754
+ assert query.shape[0] == num_prefill_query_tokens
755
+ assert decode_query.shape[0] == num_decode_query_tokens
756
+
757
+ if prefill_meta := attn_metadata.prefill_metadata:
758
+ # Prompt run.
759
+ if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
760
+ or prefill_meta.block_tables.numel() == 0):
761
+ # normal attention
762
+ # When block_tables are not filled, it means q and k are the
763
+ # prompt, and they have the same length.
764
+ q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
765
+ _get_query_key_seq_metadata(prefill_meta, True, attn_type)
766
+
767
+ key = key[:num_prefill_kv_tokens]
768
+ value = value[:num_prefill_kv_tokens]
769
+
770
+ flash_attn_varlen_func(
771
+ q=query,
772
+ k=key,
773
+ v=value,
774
+ cu_seqlens_q=q_seq_start_loc,
775
+ cu_seqlens_k=k_seq_start_loc,
776
+ max_seqlen_q=q_seq_len,
777
+ max_seqlen_k=k_seq_len,
778
+ softmax_scale=softmax_scale,
779
+ causal=_get_causal_option(attn_type),
780
+ window_size=window_size,
781
+ alibi_slopes=alibi_slopes,
782
+ softcap=logits_soft_cap,
783
+ out=prefill_output,
784
+ fa_version=self.fa_version,
785
+ )
786
+ else:
787
+ # prefix-enabled attention
788
+ assert attn_type == AttentionType.DECODER, (
789
+ "Only decoder-only models support prefix caching")
790
+ assert prefill_meta.seq_lens is not None
791
+ max_seq_len = max(prefill_meta.seq_lens)
792
+ flash_attn_varlen_func( # noqa
793
+ q=query,
794
+ k=key_cache,
795
+ v=value_cache,
796
+ cu_seqlens_q=prefill_meta.query_start_loc,
797
+ max_seqlen_q=prefill_meta.max_query_len,
798
+ seqused_k=prefill_meta.seq_lens_tensor,
799
+ max_seqlen_k=max_seq_len,
800
+ softmax_scale=softmax_scale,
801
+ causal=True,
802
+ window_size=window_size,
803
+ alibi_slopes=alibi_slopes,
804
+ block_table=prefill_meta.block_tables,
805
+ softcap=logits_soft_cap,
806
+ out=prefill_output,
807
+ fa_version=self.fa_version,
808
+ )
809
+
810
+ if decode_meta := attn_metadata.decode_metadata:
811
+ # Decoding run.
812
+ # Use flash_attn_varlen_func kernel for speculative decoding
813
+ # because different queries might have different lengths.
814
+
815
+ assert decode_meta.max_decode_query_len is not None
816
+ # use only for actual varlen decoding
817
+ if decode_meta.max_decode_query_len > 1:
818
+ assert attn_type == AttentionType.DECODER, (
819
+ "Only decoder-only models support max_decode_query_len > 1"
820
+ )
821
+ flash_attn_varlen_func(
822
+ q=decode_query,
823
+ k=key_cache,
824
+ v=value_cache,
825
+ cu_seqlens_q=decode_meta.query_start_loc,
826
+ max_seqlen_q=decode_meta.max_decode_query_len,
827
+ seqused_k=decode_meta.seq_lens_tensor,
828
+ max_seqlen_k=decode_meta.max_decode_seq_len,
829
+ softmax_scale=softmax_scale,
830
+ causal=True,
831
+ window_size=window_size,
832
+ alibi_slopes=alibi_slopes,
833
+ softcap=logits_soft_cap,
834
+ block_table=decode_meta.block_tables,
835
+ out=decode_output,
836
+ fa_version=self.fa_version,
837
+ )
838
+ else:
839
+ # Use flash_attn_with_kvcache for normal decoding.
840
+ (
841
+ seq_lens_arg,
842
+ _,
843
+ block_tables_arg,
844
+ ) = get_seq_len_block_table_args(decode_meta, False, attn_type)
845
+ flash_attn_with_kvcache(
846
+ q=decode_query.unsqueeze(1),
847
+ k_cache=key_cache,
848
+ v_cache=value_cache,
849
+ block_table=block_tables_arg,
850
+ cache_seqlens=seq_lens_arg,
851
+ softmax_scale=softmax_scale,
852
+ causal=True,
853
+ window_size=window_size,
854
+ alibi_slopes=alibi_slopes,
855
+ softcap=logits_soft_cap,
856
+ out=decode_output.unsqueeze(1),
857
+ fa_version=self.fa_version,
858
+ )
859
+ return output
860
+
861
+
862
+ def _get_query_key_seq_metadata(
863
+ attn_metadata,
864
+ is_prompt: bool,
865
+ attn_type: str,
866
+ ) -> tuple:
867
+ """
868
+ Returns sequence metadata for key and query based on the specified
869
+ attention type and whether input is a prompt.
870
+
871
+ This function computes the starting locations and maximum sequence lengths
872
+ for key and query sequences for different attention types.
873
+
874
+ Args:
875
+ attn_metadata: The attention metadata object
876
+ is_prompt (bool): A flag indicating if the input is a prompt
877
+ attn_type (AttentionType): The type of attention being used.
878
+
879
+ Returns:
880
+ tuple: A tuple containing four integers:
881
+ - Starting location for the query sequence.
882
+ - Maximum sequence length for the query sequence.
883
+ - Starting location for the key sequence.
884
+ - Maximum sequence length for the key sequence.
885
+
886
+ Raises:
887
+ AttributeError: If an invalid attention type is provided.
888
+ """
889
+ if attn_type == AttentionType.DECODER:
890
+ # Decoder self-attention
891
+ # Choose max_seq_len based on whether we are in prompt_run
892
+ if is_prompt:
893
+ max_seq_len = attn_metadata.max_prefill_seq_len
894
+ else:
895
+ max_seq_len = attn_metadata.max_decode_seq_len
896
+ return (attn_metadata.seq_start_loc, max_seq_len,
897
+ attn_metadata.seq_start_loc, max_seq_len)
898
+
899
+ elif attn_type == AttentionType.ENCODER_DECODER:
900
+ # This is cross attention between the where the key
901
+ # is the precomputed encoder attention and query
902
+ # is the input sequence.
903
+ # Choose query max length based on whether it is prompt
904
+ # or not.
905
+ if is_prompt:
906
+ max_seq_len = attn_metadata.max_prefill_seq_len
907
+ else:
908
+ max_seq_len = attn_metadata.max_decode_seq_len
909
+ return (attn_metadata.seq_start_loc, max_seq_len,
910
+ attn_metadata.encoder_seq_start_loc,
911
+ attn_metadata.max_encoder_seq_len)
912
+ elif attn_type == AttentionType.ENCODER:
913
+ # For encoder attention both the query and the key are same i.e the
914
+ # encoder sequence.
915
+ return (attn_metadata.encoder_seq_start_loc,
916
+ attn_metadata.max_encoder_seq_len,
917
+ attn_metadata.encoder_seq_start_loc,
918
+ attn_metadata.max_encoder_seq_len)
919
+ elif attn_type == AttentionType.ENCODER_ONLY:
920
+ assert is_prompt, "Should not have decode for encoder only model."
921
+ return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len,
922
+ attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len)
923
+ else:
924
+ raise AttributeError(f"Invalid attention type {str(attn_type)}")
925
+
926
+
927
+ def _get_causal_option(attn_type: str) -> bool:
928
+ """
929
+ Determine whether the given attention type is suitable for causal
930
+ attention mechanisms.
931
+
932
+ Args:
933
+ attn_type (AttentionType): The type of attention being evaluated
934
+
935
+ Returns:
936
+ bool: Returns `True` if the attention type is suitable for causal
937
+ attention (i.e., not encoder, encoder-only, or encoder-decoder),
938
+ otherwise returns `False`.
939
+ """
940
+ return not (attn_type == AttentionType.ENCODER
941
+ or attn_type == AttentionType.ENCODER_ONLY
942
+ or attn_type == AttentionType.ENCODER_DECODER)
.venv/lib/python3.11/site-packages/vllm/attention/backends/flashinfer.py ADDED
@@ -0,0 +1,1066 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import dataclasses
4
+ from collections import defaultdict
5
+ from contextlib import contextmanager
6
+ from dataclasses import dataclass
7
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
8
+
9
+ from vllm.multimodal import MultiModalPlaceholderMap
10
+
11
+ try:
12
+ from flashinfer import BatchDecodeWithPagedKVCacheWrapper
13
+ from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
14
+ from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
15
+
16
+ from vllm.vllm_flash_attn import flash_attn_varlen_func
17
+ FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
18
+ except ImportError:
19
+ # Avoid turning these types into variables during type checking
20
+ if not TYPE_CHECKING:
21
+ BatchDecodeWithPagedKVCacheWrapper = None
22
+ CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
23
+ BatchPrefillWithPagedKVCacheWrapper = None
24
+ FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
25
+
26
+ import torch
27
+
28
+ import vllm.envs as envs
29
+ from vllm import _custom_ops as ops
30
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
31
+ AttentionLayer,
32
+ AttentionMetadata,
33
+ AttentionMetadataBuilder,
34
+ AttentionState, AttentionType)
35
+ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
36
+ compute_slot_mapping_start_idx,
37
+ is_block_tables_empty)
38
+ from vllm.attention.layer import Attention
39
+ from vllm.attention.ops.paged_attn import PagedAttention
40
+ from vllm.config import VllmConfig, get_current_vllm_config
41
+ from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
42
+ make_tensor_with_pad)
43
+
44
+ if TYPE_CHECKING:
45
+ from vllm.worker.model_runner import (ModelInputForGPUBuilder,
46
+ ModelInputForGPUWithSamplingMetadata)
47
+
48
+
49
+ class FlashInferBackend(AttentionBackend):
50
+
51
+ @staticmethod
52
+ def get_name() -> str:
53
+ return "FLASHINFER"
54
+
55
+ @staticmethod
56
+ def get_impl_cls() -> Type["FlashInferImpl"]:
57
+ return FlashInferImpl
58
+
59
+ @staticmethod
60
+ def get_metadata_cls() -> Type["AttentionMetadata"]:
61
+ return FlashInferMetadata
62
+
63
+ @staticmethod
64
+ def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
65
+ return FlashInferMetadataBuilder
66
+
67
+ @staticmethod
68
+ def get_state_cls() -> Type["FlashInferState"]:
69
+ return FlashInferState
70
+
71
+ @staticmethod
72
+ def get_kv_cache_shape(
73
+ num_blocks: int,
74
+ block_size: int,
75
+ num_kv_heads: int,
76
+ head_size: int,
77
+ ) -> Tuple[int, ...]:
78
+ return (num_blocks, 2, block_size, num_kv_heads, head_size)
79
+
80
+ @staticmethod
81
+ def swap_blocks(
82
+ src_kv_cache: torch.Tensor,
83
+ dst_kv_cache: torch.Tensor,
84
+ src_to_dst: torch.Tensor,
85
+ ) -> None:
86
+ PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
87
+
88
+ @staticmethod
89
+ def copy_blocks(
90
+ kv_caches: List[torch.Tensor],
91
+ src_to_dists: torch.Tensor,
92
+ ) -> None:
93
+ PagedAttention.copy_blocks(kv_caches, src_to_dists)
94
+
95
+ @staticmethod
96
+ def get_supported_head_sizes() -> List[int]:
97
+ return [64, 128, 256]
98
+
99
+ @staticmethod
100
+ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
101
+ if kv_cache_dtype in ("fp8", "fp8_e4m3"):
102
+ return torch.float8_e4m3fn
103
+ elif kv_cache_dtype == "fp8_e5m2":
104
+ return torch.float8_e5m2
105
+ else:
106
+ raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
107
+
108
+
109
+ @dataclass
110
+ class PerLayerParameters:
111
+ """
112
+ Currently, FlashInfer backend only support models in which all layers share
113
+ the same values for the following hyperparameters.
114
+ """
115
+
116
+ window_left: int
117
+ logits_soft_cap: Optional[float]
118
+ sm_scale: float
119
+
120
+
121
+ def get_per_layer_parameters(
122
+ vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]:
123
+ """
124
+ Scan all attention layers and determine some hyperparameters
125
+ to use during `plan`.
126
+ """
127
+
128
+ layers = vllm_config.compilation_config.static_forward_context
129
+ per_layer_params: Dict[str, PerLayerParameters] = {}
130
+
131
+ for key, layer in layers.items():
132
+ assert isinstance(layer, Attention)
133
+
134
+ impl = layer.impl
135
+ assert isinstance(impl, FlashInferImpl)
136
+
137
+ # Infer hyperparameters from the attention layer
138
+ window_size = impl.sliding_window
139
+ window_left = window_size[0] if window_size is not None else -1
140
+ logits_soft_cap = impl.logits_soft_cap
141
+ sm_scale = impl.scale
142
+
143
+ per_layer_params[key] = PerLayerParameters(window_left,
144
+ logits_soft_cap, sm_scale)
145
+
146
+ return per_layer_params
147
+
148
+
149
+ def infer_global_hyperparameters(
150
+ per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters:
151
+ """
152
+ Currently, FlashInfer backend only support models in which all layers share
153
+ the same values for the following hyperparameters:
154
+ - `window_left`
155
+ - `logits_soft_cap`
156
+ - `sm_scale`
157
+
158
+ So this function asserts that all layers share the same values for these
159
+ hyperparameters and returns the global values.
160
+ """
161
+
162
+ assert len(per_layer_params) > 0, "No attention layers found in the model."
163
+
164
+ param_sets = list(per_layer_params.values())
165
+ global_params = param_sets[0]
166
+ for params in param_sets:
167
+ assert params == global_params, (
168
+ "FlashInfer backend currently only supports models in which all "
169
+ "layers share the same values for the following hyperparameters: "
170
+ "`window_left`, `logits_soft_cap`, `sm_scale`.")
171
+
172
+ return global_params
173
+
174
+
175
+ class FlashInferState(AttentionState):
176
+
177
+ def __init__(self, runner):
178
+ self.runner = runner
179
+ self._is_graph_capturing = False
180
+ self._workspace_buffer = None
181
+ self._decode_wrapper = None
182
+ self._prefill_wrapper = None
183
+
184
+ # Global hyperparameters shared by all attention layers
185
+ self.global_hyperparameters: Optional[PerLayerParameters] = None
186
+
187
+ self.vllm_config = get_current_vllm_config()
188
+
189
+ def _get_workspace_buffer(self):
190
+ if self._workspace_buffer is None:
191
+ self._workspace_buffer = torch.empty(
192
+ FLASHINFER_WORKSPACE_BUFFER_SIZE,
193
+ dtype=torch.uint8,
194
+ device=self.runner.device)
195
+ return self._workspace_buffer
196
+
197
+ def _get_prefill_wrapper(self):
198
+ if self._prefill_wrapper is None:
199
+ self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
200
+ self._get_workspace_buffer(), "NHD")
201
+ return self._prefill_wrapper
202
+
203
+ def _get_decode_wrapper(self):
204
+ if self._decode_wrapper is None:
205
+ num_qo_heads = (self.runner.model_config.get_num_attention_heads(
206
+ self.runner.parallel_config))
207
+ num_kv_heads = self.runner.model_config.get_num_kv_heads(
208
+ self.runner.parallel_config)
209
+ use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
210
+ num_qo_heads // num_kv_heads > 4)
211
+ self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
212
+ self._get_workspace_buffer(),
213
+ "NHD",
214
+ use_tensor_cores=use_tensor_cores)
215
+ return self._decode_wrapper
216
+
217
+ @contextmanager
218
+ def graph_capture(self, max_batch_size: int):
219
+ self._is_graph_capturing = True
220
+ self._graph_decode_wrapper = None
221
+ self._graph_slot_mapping = torch.full((max_batch_size, ),
222
+ PAD_SLOT_ID,
223
+ dtype=torch.long,
224
+ device=self.runner.device)
225
+ self._graph_seq_lens = torch.ones(max_batch_size,
226
+ dtype=torch.int32,
227
+ device=self.runner.device)
228
+ self._graph_block_tables = torch.from_numpy(
229
+ self.runner.graph_block_tables).to(device=self.runner.device)
230
+ self._graph_decode_workspace_buffer = self._get_workspace_buffer()
231
+ self._graph_indices_buffer = torch.empty(
232
+ max_batch_size * self.runner.cache_config.num_gpu_blocks,
233
+ dtype=torch.int32,
234
+ device=self.runner.device)
235
+ self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
236
+ dtype=torch.int32,
237
+ device=self.runner.device)
238
+ self._graph_last_page_len_buffer = torch.empty(
239
+ max_batch_size, dtype=torch.int32, device=self.runner.device)
240
+ yield
241
+ self._is_graph_capturing = False
242
+ del self._graph_slot_mapping
243
+ del self._graph_seq_lens
244
+ del self._graph_block_tables
245
+ del self._graph_decode_workspace_buffer
246
+ del self._graph_indices_buffer
247
+ del self._graph_indptr_buffer
248
+ del self._graph_last_page_len_buffer
249
+ del self._graph_decode_wrapper
250
+
251
+ def graph_clone(self, batch_size: int):
252
+ assert self._is_graph_capturing
253
+ state = self.__class__(self.runner)
254
+ state._workspace_buffer = self._graph_decode_workspace_buffer
255
+ state._decode_wrapper = self._graph_decode_wrapper
256
+ state._prefill_wrapper = self._get_prefill_wrapper()
257
+ return state
258
+
259
+ def graph_capture_get_metadata_for_batch(
260
+ self, batch_size: int, is_encoder_decoder_model: bool = False):
261
+ assert self._is_graph_capturing
262
+ _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
263
+ _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]
264
+
265
+ num_qo_heads = (self.runner.model_config.get_num_attention_heads(
266
+ self.runner.parallel_config))
267
+ num_kv_heads = self.runner.model_config.get_num_kv_heads(
268
+ self.runner.parallel_config)
269
+ use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
270
+ num_qo_heads // num_kv_heads > 4)
271
+ self._graph_decode_wrapper = \
272
+ CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
273
+ self._graph_decode_workspace_buffer, _indptr_buffer,
274
+ self._graph_indices_buffer, _last_page_len_buffer, "NHD",
275
+ use_tensor_cores)
276
+ if self.runner.kv_cache_dtype.startswith("fp8"):
277
+ kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
278
+ self.runner.kv_cache_dtype)
279
+ else:
280
+ kv_cache_dtype = get_kv_cache_torch_dtype(
281
+ self.runner.kv_cache_dtype, self.runner.model_config.dtype)
282
+
283
+ paged_kv_indptr_tensor_host = torch.arange(0,
284
+ batch_size + 1,
285
+ dtype=torch.int32)
286
+ paged_kv_indices_tensor_host = torch.arange(0,
287
+ batch_size,
288
+ dtype=torch.int32)
289
+ paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
290
+ self.runner.block_size,
291
+ dtype=torch.int32)
292
+ query_start_loc_host = torch.arange(0,
293
+ batch_size + 1,
294
+ dtype=torch.int32)
295
+
296
+ global_params = infer_global_hyperparameters(
297
+ get_per_layer_parameters(self.vllm_config))
298
+
299
+ attn_metadata = self.runner.attn_backend.make_metadata(
300
+ num_prefills=0,
301
+ slot_mapping=self._graph_slot_mapping[:batch_size],
302
+ multi_modal_placeholder_index_maps=None,
303
+ enable_kv_scales_calculation=False,
304
+ num_prefill_tokens=0,
305
+ num_decode_tokens=batch_size,
306
+ max_prefill_seq_len=0,
307
+ block_tables=self._graph_block_tables,
308
+ paged_kv_indptr=paged_kv_indptr_tensor_host,
309
+ paged_kv_indices=paged_kv_indices_tensor_host,
310
+ paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
311
+ num_qo_heads=num_qo_heads,
312
+ num_kv_heads=num_kv_heads,
313
+ head_dim=self.runner.model_config.get_head_size(),
314
+ page_size=self.runner.block_size,
315
+ seq_start_loc=None,
316
+ query_start_loc=query_start_loc_host,
317
+ device=self.runner.device,
318
+ data_type=kv_cache_dtype,
319
+ q_data_type=self.runner.model_config.dtype,
320
+ use_cuda_graph=True,
321
+ decode_wrapper=self._graph_decode_wrapper,
322
+ prefill_wrapper=None,
323
+ **dataclasses.asdict(global_params),
324
+ )
325
+ attn_metadata.begin_forward()
326
+ return attn_metadata
327
+
328
+ def get_graph_input_buffers(self,
329
+ attn_metadata,
330
+ is_encoder_decoder_model: bool = False):
331
+ return {
332
+ "slot_mapping": attn_metadata.slot_mapping,
333
+ }
334
+
335
+ def prepare_graph_input_buffers(self,
336
+ input_buffers,
337
+ attn_metadata,
338
+ is_encoder_decoder_model: bool = False):
339
+ return
340
+
341
+ def begin_forward(self, model_input):
342
+ assert not self._is_graph_capturing
343
+ state = self
344
+ use_cuda_graph = model_input.attn_metadata.use_cuda_graph
345
+ is_decode = model_input.attn_metadata.num_prefills == 0
346
+ # In case of multistep chunked-prefill, there might be prefill requests
347
+ # scheduled while CUDA graph mode is enabled. We don't run graph in that
348
+ # case.
349
+ if use_cuda_graph and is_decode:
350
+ batch_size = model_input.input_tokens.shape[0]
351
+ state = (self.runner.graph_runners[model_input.virtual_engine]
352
+ [batch_size].attn_state)
353
+ model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
354
+ )
355
+ model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
356
+ model_input.attn_metadata.begin_forward()
357
+
358
+
359
+ @dataclass
360
+ class FlashInferMetadata(AttentionMetadata):
361
+ # Maximum sequence length among prefill batch. 0 if there are decoding
362
+ # requests only.
363
+ max_prefill_seq_len: int
364
+ # Number of query tokens for each request in the batch.
365
+ # Currently, we require that all requests have the same number of query
366
+ # tokens during the decoding phase. When speculavie decoding is enabled,
367
+ # decode_query_len might be greater than 1. In all other cases, it is 1.
368
+ decode_query_len: Optional[int] = 1
369
+
370
+ use_cuda_graph: bool = True
371
+
372
+ prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
373
+ decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
374
+
375
+ # Metadata for the prefill stage
376
+ seq_start_loc: Optional[torch.Tensor] = None
377
+ query_start_loc: Optional[torch.Tensor] = None
378
+ block_tables: Optional[torch.Tensor] = None
379
+
380
+ # used for GPU in-place advance_step
381
+ seq_lens_tensor: Optional[torch.Tensor] = None
382
+ block_table_bound: Optional[torch.Tensor] = None
383
+
384
+ # An example for paged_kv_indices, paged_kv_indptr:
385
+ # request 1, page indices [0, 5, 8]
386
+ # request 2, page indices [1, 6, 7]
387
+ # request 3, page indices [3, 4]
388
+ # paged_kv_indices is a concatenation of page indices of all requests:
389
+ # [0, 5, 8, 1, 6, 7, 3, 4]
390
+ # paged_kv_indptr is used to index into paged_kv_indices:
391
+ # [0, 3, 6, 8]
392
+ # The indptr of the paged kv cache, shape: [batch_size + 1]
393
+ paged_kv_indptr: Optional[torch.Tensor] = None
394
+ # The page indices of the paged kv cache
395
+ paged_kv_indices: Optional[torch.Tensor] = None
396
+ # The number of entries in the last page of each request in
397
+ # the paged kv cache, shape: [batch_size]
398
+ paged_kv_last_page_len: Optional[torch.Tensor] = None
399
+ # The number of query/output heads
400
+ num_qo_heads: Optional[int] = None
401
+ # The number of key/value heads
402
+ num_kv_heads: Optional[int] = None
403
+ # The dimension of the attention heads
404
+ head_dim: Optional[int] = None
405
+ # Block size of vllm
406
+ page_size: Optional[int] = None
407
+ # The data type of the paged kv cache
408
+ data_type: torch.dtype = None
409
+ # The data type of the query
410
+ q_data_type: torch.dtype = None
411
+ # FlashInfer 0.2 encourages passing host tensors
412
+ device: torch.device = torch.device("cpu")
413
+ is_profile_run: bool = False
414
+
415
+ # The FlashInfer backend currently supports only models in which all layers
416
+ # share the same following hyperparameters:
417
+
418
+ # The left (inclusive) window size for the attention window, when
419
+ # set to `-1`, the window size will be set to the full length of
420
+ # the sequence. Defaults to `-1`.
421
+ window_left: int = -1
422
+ # The attention logits soft capping value (used in Gemini, Grok and
423
+ # Gemma-2, etc.), if not provided, will be set to `0`. If greater
424
+ # than 0, the logits will be capped according to formula:
425
+ # $$\texttt{logits\_soft\_cap} \times
426
+ # \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$,
427
+ # where $x$ is the input logits.
428
+ logits_soft_cap: Optional[float] = None
429
+ # The scale used in softmax, if not provided, will be set to
430
+ # `1.0 / sqrt(head_dim)`.
431
+ sm_scale: Optional[float] = None
432
+
433
+ def __post_init__(self):
434
+ # Refer to
435
+ # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
436
+ supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
437
+ if self.head_dim is not None and self.head_dim \
438
+ not in supported_head_sizes:
439
+ raise ValueError(
440
+ f"Only {supported_head_sizes} are supported for head_dim,",
441
+ f"received {self.head_dim}.")
442
+
443
def begin_forward(self):
    """Plan the FlashInfer kernels for the current batch.

    Moves the paged-KV bookkeeping tensors onto ``self.device`` and calls
    ``plan`` on the prefill and/or decode wrappers, depending on which
    token counts are non-zero.  During the memory-profile run the prefill
    planning is skipped entirely (flash-attention is used for profiling).
    """
    if self.num_prefill_tokens > 0:
        # Nothing was collected for this batch; nothing to plan.
        if self.paged_kv_indices is None:
            return

        assert self.prefill_wrapper is not None
        assert self.query_start_loc is not None
        assert self.paged_kv_indices is not None
        assert self.paged_kv_indptr is not None
        assert self.paged_kv_last_page_len is not None
        assert self.block_table_bound is not None
        assert self.seq_lens_tensor is not None

        # Keep only the prefill portion of the cumulative query offsets.
        self.query_start_loc = self.query_start_loc[:self.num_prefills + 1]
        num_prefill_seqs = self.query_start_loc.shape[0] - 1
        assert num_prefill_seqs >= 0
        # The number of KV blocks is determined with flash attention
        # during profiling, so FlashInfer inputs are only prepared for
        # real (non-profile) runs.
        if not self.is_profile_run:
            self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
            self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                self.device)
            self.block_table_bound = self.block_table_bound.to(self.device)
            self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
            self.prefill_wrapper.plan(
                self.query_start_loc,
                self.paged_kv_indptr[:self.num_prefills + 1],
                self.paged_kv_indices,
                self.paged_kv_last_page_len[:self.num_prefills],
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                causal=True,
                sm_scale=self.sm_scale,
                window_left=self.window_left,
                logits_soft_cap=self.logits_soft_cap,
                q_data_type=self.q_data_type,
                kv_data_type=self.data_type)
    if self.num_decode_tokens > 0:
        assert self.paged_kv_indices is not None
        assert self.paged_kv_indptr is not None
        assert self.paged_kv_last_page_len is not None
        self.paged_kv_indices = self.paged_kv_indices.to(self.device)
        self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
        self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
            self.device)
        # These two may be absent on the model-warmup path.
        if self.block_table_bound is not None:
            self.block_table_bound = self.block_table_bound.to(self.device)
        if self.seq_lens_tensor is not None:
            self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)

        assert self.decode_wrapper is not None
        self.decode_wrapper.plan(
            self.paged_kv_indptr[self.num_prefills:],
            self.paged_kv_indices,
            self.paged_kv_last_page_len[self.num_prefills:],
            self.num_qo_heads,
            self.num_kv_heads,
            self.head_dim,
            self.page_size,
            # Disable flashinfer's pos encoding and use vllm's rope.
            pos_encoding_mode="NONE",
            window_left=self.window_left,
            logits_soft_cap=self.logits_soft_cap,
            sm_scale=self.sm_scale,
            # kv-cache data type.
            kv_data_type=self.data_type,
            # query data type.
            q_data_type=self.q_data_type)
+
516
def asdict_zerocopy(self,
                    skip_fields: Optional[Set[str]] = None
                    ) -> Dict[str, Any]:
    """Serialize to a dict, always excluding the FlashInfer wrappers.

    The prefill/decode wrapper objects hold device-side state that
    cannot be broadcast over NCCL when tensor parallelism is enabled,
    so they are unconditionally added to ``skip_fields``.
    """
    fields_to_skip = set() if skip_fields is None else skip_fields
    fields_to_skip.update(('prefill_wrapper', 'decode_wrapper'))
    return super().asdict_zerocopy(fields_to_skip)
+
527
@property
def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
    """Metadata view for the prefill tokens (``None`` if no prefills)."""
    return None if self.num_prefills == 0 else self
+
533
@property
def decode_metadata(self) -> Optional["FlashInferMetadata"]:
    """Metadata view for the decode tokens (``None`` if no decodes)."""
    return None if self.num_decode_tokens == 0 else self
+
539
def advance_step(self,
                 model_input: "ModelInputForGPUWithSamplingMetadata",
                 sampled_token_ids: Optional[torch.Tensor],
                 block_size: int,
                 num_seqs: int,
                 num_queries: int,
                 turn_prefills_into_decodes: bool = False):
    """Update metadata in-place to advance one decode step.

    With multi-step + chunked-prefill, the first step converts every
    scheduled prefill into a decode (``turn_prefills_into_decodes``);
    afterwards the GPU-side tensors are advanced by a fused CUDA op.
    """
    if turn_prefills_into_decodes:
        # When Multi-Step is enabled with Chunked-Prefill, prefills and
        # decodes are scheduled together. In the first step, all the
        # prefills turn into decodes. This update reflects that
        # conversion.
        assert self.num_decode_tokens + self.num_prefills == num_seqs
        # Flashinfer doesn't support speculative decoding + chunked-prefill
        # + multi-step scheduling yet.
        assert self.decode_query_len == 1
        self.num_decode_tokens += self.num_prefills
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.max_prefill_seq_len = 0
        self.max_query_len = 1

        self.slot_mapping = self.slot_mapping[:num_seqs]
    else:
        assert self.seq_lens_tensor is not None

    assert num_seqs > 0
    assert num_queries > 0
    assert model_input.attn_metadata is not None
    assert sampled_token_ids is not None

    # With cudagraph, num_seqs is padded up to the captured batch size
    # while num_queries is the real request count; with --enforce-eager
    # the two are equal.
    if num_seqs != num_queries:
        assert num_seqs > num_queries
        assert self.use_cuda_graph

    # Feed the freshly sampled tokens back in as the next step's input.
    model_input.input_tokens[:num_queries] = sampled_token_ids.flatten()

    # Advance positions, seq lens, slot mapping and paged-KV tensors on
    # the GPU in a single fused op.  input_tokens already holds the
    # sampled ids (written just above), so it doubles as the source.
    ops.advance_step_flashinfer(
        num_seqs=num_seqs,
        num_queries=num_queries,
        block_size=block_size,
        input_tokens=model_input.input_tokens,
        sampled_token_ids=model_input.input_tokens,
        input_positions=model_input.input_positions,
        seq_lens=self.seq_lens_tensor,
        slot_mapping=self.slot_mapping,
        block_tables=self.block_tables,
        paged_kv_indices=self.paged_kv_indices,
        paged_kv_indptr=self.paged_kv_indptr,
        paged_kv_last_page_len=self.paged_kv_last_page_len,
        block_table_bound=self.block_table_bound)
+
599
+
600
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
    """Accumulates per-sequence-group state and builds FlashInferMetadata."""

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

        # Global hyperparameters shared by all attention layers
        self.global_hyperparameters: Optional[PerLayerParameters] = None

        self.vllm_config = get_current_vllm_config()

    def prepare(self):
        """Reset all per-batch accumulators before a new build cycle."""
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        # Paged-KV layout; see
        # https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
        # paged_kv_indices concatenates the page ids of every request, e.g.
        #   request 1 pages [0, 5, 8], request 2 [1, 6, 7], request 3 [3, 4]
        #   -> paged_kv_indices = [0, 5, 8, 1, 6, 7, 3, 4]
        # and paged_kv_indptr delimits each request's slice: [0, 3, 6, 8].
        self.paged_kv_indices: List[int] = []
        # The leading 0 marks where the first request's pages start.
        self.paged_kv_indptr: List[int] = [0]
        # Number of valid entries in the last page of each request.
        self.paged_kv_last_page_len: List[int] = []
        self.total_blocks = 0
        self.is_profile_run: bool = False

        if self.global_hyperparameters is None:
            # Only models whose layers all agree on `window_left`,
            # `logits_soft_cap` and `sm_scale` are supported, so the
            # shared values are inferred once and cached.
            inferred_params = infer_global_hyperparameters(
                get_per_layer_parameters(self.vllm_config))
            self.global_hyperparameters = inferred_params
            self.window_left = inferred_params.window_left
            self.logits_soft_cap = inferred_params.logits_soft_cap
            self.sm_scale = inferred_params.sm_scale

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables
        computed_block_nums = inter_data.computed_block_nums

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                # Decodes always contribute exactly one query token here.
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if inter_data.prefix_cache_hit:
                block_table = computed_block_nums
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                block_table = block_tables[seq_id][-curr_sliding_window_block:]
            self.block_tables.append(block_table)

            is_profile_run = is_block_tables_empty(block_tables)

            # Compute slot mapping.
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

            # Profile runs get dummy inputs later, so the paged-KV
            # bookkeeping tensors are not needed for them.
            if is_profile_run:
                self.is_profile_run = is_profile_run
                return

            self._update_paged_kv_tensors(block_tables[seq_id], seq_len)

    def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
        """Append one request's pages to the paged-KV index structures.

        Only the blocks actually covered by ``seq_len`` are recorded:
        seq_len=16, block_size=16 -> 1 valid block (bound 1);
        seq_len=15, block_size=16 -> 1 valid block (bound 0 + 1).
        """
        self.total_blocks += len(block_table)
        full_blocks, remainder = divmod(seq_len, self.block_size)
        block_table_bound = full_blocks + 1 if remainder != 0 else full_blocks
        self.paged_kv_indices.extend(block_table[:block_table_bound])
        self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                    block_table_bound)

        # A full final page is reported as block_size, not 0.
        last_page_len = remainder if remainder != 0 else self.block_size
        self.paged_kv_last_page_len.append(last_page_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        decode_query_len = max(query_lens[self.num_prefills:], default=1)

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # NOTE(review): `[] * n` is always `[]`, so this extend is a
            # no-op; padding of block tables is handled via
            # graph_block_tables below — confirm against upstream intent.
            self.block_tables.extend([] * cuda_graph_pad_size)
            num_decode_tokens = batch_size - self.num_prefill_tokens

            # graph_block_tables has shape
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            max_blocks = input_block_tables.shape[1]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    num_blocks = len(block_table)
                    if num_blocks <= max_blocks:
                        input_block_tables[i, :num_blocks] = block_table
                    else:
                        # It may be possible to have more blocks allocated due
                        # to lookahead slots of multi-step, however, they are
                        # not used anyway, so can be safely ignored.
                        input_block_tables[
                            i, :max_blocks] = block_table[:max_blocks]

            block_tables = torch.from_numpy(input_block_tables).to(
                device, non_blocking=True)

            # Pad the indptr/last-page-len arrays so their length matches
            # the padded batch size.
            last_paged_kv_indptr = self.paged_kv_indptr[-1]
            self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                        cuda_graph_pad_size)
            self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )

        assert device is not None
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
                                             self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }
        # Cumulative offsets (with a leading zero) of seqs and queries.
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        if len(self.paged_kv_indptr) > 0:
            # extend to the maximum number of blocks as returned by the
            # scheduler
            self.paged_kv_indices.extend(
                [0] * (self.total_blocks - len(self.paged_kv_indices)))
            paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                   device="cpu",
                                                   dtype=torch.int)
            paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
                                                  device="cpu",
                                                  dtype=torch.int)
            paged_kv_last_page_len_tensor = torch.tensor(
                self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
            block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
                                                   1,
                                                   device="cpu",
                                                   dtype=torch.int)
        else:
            paged_kv_indices_tensor = None
            paged_kv_indptr_tensor = None
            paged_kv_last_page_len_tensor = None
            block_table_bound_tensor = None

        if self.runner.kv_cache_dtype.startswith("fp8"):
            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.runner.kv_cache_dtype)
        else:
            kv_cache_dtype = get_kv_cache_torch_dtype(
                self.runner.kv_cache_dtype, self.runner.model_config.dtype)

        return FlashInferMetadata(
            decode_query_len=decode_query_len,
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            max_prefill_seq_len=max_prefill_seq_len,
            block_tables=block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor,
            paged_kv_indices=paged_kv_indices_tensor,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor,
            block_table_bound=block_table_bound_tensor,
            seq_lens_tensor=seq_lens_tensor,
            num_qo_heads=self.runner.model_config.get_num_attention_heads(
                self.runner.parallel_config),
            num_kv_heads=self.runner.model_config.get_num_kv_heads(
                self.runner.parallel_config),
            head_dim=self.runner.model_config.get_head_size(),
            page_size=self.block_size,
            seq_start_loc=seq_start_loc,
            query_start_loc=query_start_loc,
            device=device,
            data_type=kv_cache_dtype,
            q_data_type=self.runner.model_config.dtype,
            use_cuda_graph=use_captured_graph,
            is_profile_run=self.is_profile_run,
            window_left=self.window_left,
            logits_soft_cap=self.logits_soft_cap,
            sm_scale=self.sm_scale,
        )
+
895
+
896
class FlashInferImpl(AttentionImpl):
    """FlashInfer-backed attention implementation (decoder-only)."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # (left, right) attention window; (-1, -1) means unlimited.
        self.sliding_window = ((sliding_window - 1,
                                0) if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention over a mixed prefill/decode batch.

        Prefill tokens come first in ``query``/``key``/``value``; decode
        tokens follow.  The ``output`` parameter is currently ignored.
        """
        # TODO: directly write to output tensor
        num_heads: int = self.num_heads
        head_size: int = self.head_size
        num_kv_heads: int = self.num_kv_heads
        kv_cache_dtype: str = self.kv_cache_dtype
        softmax_scale: float = self.scale
        window_size = self.sliding_window
        alibi_slopes = self.alibi_slopes
        logits_soft_cap = self.logits_soft_cap

        num_tokens, hidden_size = query.shape
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_kv_heads, head_size)
        value = value.view(-1, num_kv_heads, head_size)

        if kv_cache.numel() > 0:
            # Use the same reshape and cache kernel as flash attention.
            ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )
            # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
            # to process the cache when the kv_cache_dtype is fp8
            if kv_cache_dtype.startswith("fp8"):
                torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                    kv_cache_dtype)
                kv_cache = kv_cache.view(torch_dtype)

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
            f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}"  # noqa
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
            f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}"  # noqa
        query = query.contiguous(
        )  # Flashinfer requires query to be contiguous
        # Split the batch: decode queries do not need new K/V (already
        # cached), so K/V are trimmed to the prefill portion only.
        decode_query = query[num_prefill_tokens:]
        query = query[:num_prefill_tokens]

        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        window_left = window_size[0] if window_size is not None else -1

        prefill_output: Optional[torch.Tensor] = None
        decode_output: Optional[torch.Tensor] = None
        if prefill_meta := attn_metadata.prefill_metadata:
            if kv_cache.numel() == 0:
                # Profiling run (no KV cache yet): fall back to flash
                # attention to determine the number of blocks.
                prefill_output = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
                    softmax_scale=softmax_scale,
                    causal=True,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                )
            else:
                assert prefill_meta is not None
                assert prefill_meta.prefill_wrapper is not None

                # The wrapper must have been planned with the same
                # hyperparameters this layer uses.
                assert prefill_meta.prefill_wrapper._causal
                assert prefill_meta.prefill_wrapper._window_left == window_left
                assert prefill_meta.prefill_wrapper._logits_soft_cap == (
                    logits_soft_cap or 0.0)
                assert prefill_meta.prefill_wrapper._sm_scale == softmax_scale

                prefill_output = prefill_meta.prefill_wrapper.run(
                    query,
                    kv_cache,
                    k_scale=layer._k_scale_float,
                    v_scale=layer._v_scale_float,
                )
        if decode_meta := attn_metadata.decode_metadata:
            assert decode_meta is not None
            assert decode_meta.decode_wrapper is not None

            assert decode_meta.decode_wrapper._window_left == window_left
            assert decode_meta.decode_wrapper._logits_soft_cap == (
                logits_soft_cap or 0.0)
            assert decode_meta.decode_wrapper._sm_scale == softmax_scale

            decode_output = decode_meta.decode_wrapper.run(
                decode_query,
                kv_cache,
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
            )

        if prefill_output is None and decode_output is not None:
            # Decode only batch.
            output, num_tokens = decode_output, num_decode_tokens
        elif decode_output is None and prefill_output is not None:
            # Prefill only batch.
            output, num_tokens = prefill_output, num_prefill_tokens
        else:
            # Chunked prefill batch does not work with speculative decoding in
            # FlashInfer backend, so the query length for decode should be 1.
            assert prefill_output is not None
            assert decode_output is not None
            assert decode_meta is not None
            assert decode_meta.decode_query_len == 1
            decode_output = decode_output.squeeze(1)
            output = torch.cat([prefill_output, decode_output], dim=0)
        return output.view(num_tokens, hidden_size)
.venv/lib/python3.11/site-packages/vllm/attention/backends/hpu_attn.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ ###############################################################################
4
+ # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
5
+ ###############################################################################
6
+
7
+ import os
8
+ from dataclasses import dataclass
9
+ from typing import Any, Dict, List, Optional, Tuple, Type
10
+
11
+ import torch
12
+ import vllm_hpu_extension.ops as ops
13
+ from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
14
+ VLLMKVCache)
15
+
16
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
17
+ AttentionLayer,
18
+ AttentionMetadata, AttentionType)
19
+ from vllm.attention.backends.utils import CommonAttentionState
20
+ from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
21
+ HPUPagedAttentionMetadata)
22
+ from vllm.logger import init_logger
23
+
24
+ logger = init_logger(__name__)
25
+
26
+
27
class HPUAttentionBackend(AttentionBackend):
    """Attention backend for Intel Gaudi (HPU) devices.

    Pure factory/dispatch class: every method is a staticmethod that
    either names the HPU implementation classes or delegates cache
    management to :class:`HPUPagedAttention`.
    """

    @staticmethod
    def get_name() -> str:
        return "HPU_ATTN"

    @staticmethod
    def get_impl_cls() -> Type["HPUAttentionImpl"]:
        return HPUAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return HPUAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        # Cache layout is owned by the paged-attention op library.
        return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                    num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        HPUPagedAttention.copy_blocks(kv_caches, src_to_dists)
+
70
+
71
@dataclass
class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
    """Metadata for HPUAttentionbackend."""
    # A batch is either all prompts or all decodes; True for prompts.
    is_prompt: bool
    # Precomputed attention bias (set by the model runner before forward).
    attn_bias: Optional[torch.Tensor]
    # Per-sequence lengths; may be None on some paths.
    seq_lens_tensor: Optional[torch.Tensor]
+
80
+
81
class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        max_seq_len: int = 4096,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        # Skip AttentionImpl in the MRO so torch.nn.Module.__init__ runs.
        super(AttentionImpl, self).__init__()
        self.kv_cache_dtype = kv_cache_dtype
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.matmul_qk = Matmul()
        self.softmax = Softmax()
        self.matmul_av = Matmul()
        self.batch2block_matmul = Matmul()
        self.block2batch_matmul = Matmul()
        # NOTE(kzawora): Contiguous PA is off until model runner supports it
        self.k_cache = VLLMKVCache()
        self.k_cache.use_contiguous_pa = False
        self.v_cache = VLLMKVCache()
        self.v_cache.use_contiguous_pa = False
        # NOTE(kzawora): Pipelined PA is off until model runner supports it
        ops.pa_impl = ops.pa

        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        self.alibi_slopes = alibi_slopes
        if alibi_slopes is not None:
            alibi_slopes_tensor = torch.tensor(alibi_slopes,
                                               dtype=torch.bfloat16)
            self.alibi_slopes = alibi_slopes_tensor
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                              '0').lower() in ['1', 'true']
        self.fused_scaled_dot_product_attention = None
        if self.prefill_usefusedsdpa:
            assert alibi_slopes is None, \
                'Prefill with FusedSDPA not supported with alibi slopes!'
            try:
                from habana_frameworks.torch.hpex.kernels import FusedSDPA
                self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
                    FusedSDPA)
            except ImportError:
                # BUGFIX: `logger` is a Logger instance, not a callable;
                # `logger().warning(...)` raised TypeError on this path.
                logger.warning("Could not import HPU FusedSDPA kernel. "
                               "vLLM will use native implementation.")

        supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "HPUAttentionImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        batch_size, seq_len, hidden_size = query.shape
        _, seq_len_kv, _ = key.shape

        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        block_indices = attn_metadata.block_indices
        block_offsets = attn_metadata.block_offsets
        if attn_metadata.is_prompt:
            key = key.unflatten(0, (block_indices.size(0), -1))
            value = value.unflatten(0, (block_indices.size(0), -1))
        if kv_cache is not None:
            key_cache, value_cache = HPUPagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            key_cache = self.k_cache(key, key_cache, block_indices,
                                     block_offsets)
            value_cache = self.v_cache(value, value_cache, block_indices,
                                       block_offsets)

        if attn_metadata.is_prompt:
            # Prompt run.
            if not self.prefill_usefusedsdpa:
                # TODO: move this outside of model
                assert attn_metadata.attn_bias is not None, \
                    'attn_bias must be set before calling model.forward!'
                attn_bias = attn_metadata.attn_bias
                if self.alibi_slopes is not None:
                    # Add ALiBi position bias on top of the provided mask.
                    position_bias = _make_alibi_bias(self.alibi_slopes,
                                                     self.num_kv_heads,
                                                     attn_bias.dtype,
                                                     attn_bias.shape[-1])
                    attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
                    attn_bias.add_(position_bias)
            else:
                # FusedSDPA builds its own causal mask.
                attn_bias = None

            query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
            kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
                        self.head_size)
            out = ops.prompt_attention(
                query.view(query_shape),
                key.view(kv_shape),
                value.view(kv_shape),
                attn_bias=attn_bias,
                p=0.0,
                scale=self.scale,
                matmul_qk_op=self.matmul_qk,
                softmax_op=self.softmax,
                matmul_av_op=self.matmul_av,
                fsdpa_op=self.fused_scaled_dot_product_attention,
            )
            output = out.reshape(batch_size, seq_len, hidden_size)
        else:
            # Decoding run.
            output = HPUPagedAttention.forward_decode(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                block_list=attn_metadata.block_list,
                block_mapping=attn_metadata.block_mapping,
                block_bias=attn_metadata.attn_bias,
                block_scales=attn_metadata.block_scales,
                block_groups=None,
                scale=self.scale,
                matmul_qk_op=self.matmul_qk,
                matmul_av_op=self.matmul_av,
                batch2block_matmul_op=self.batch2block_matmul,
                block2batch_matmul_op=self.block2batch_matmul,
                keys_fetch_func=self.k_cache.fetch_from_cache,
                values_fetch_func=self.v_cache.fetch_from_cache)
        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)
+
263
+
264
+ def _make_alibi_bias(
265
+ alibi_slopes: torch.Tensor,
266
+ num_kv_heads: int,
267
+ dtype: torch.dtype,
268
+ seq_len: int,
269
+ ) -> torch.Tensor:
270
+ bias = torch.arange(seq_len, dtype=dtype)
271
+ # NOTE(zhuohan): HF uses
272
+ # `bias = bias[None, :].repeat(seq_len, 1)`
273
+ # here. We find that both biases give the same results, but
274
+ # the bias below more accurately follows the original ALiBi
275
+ # paper.
276
+ # Calculate a matrix where each element represents ith element- jth
277
+ # element.
278
+ bias = bias[None, :] - bias[:, None]
279
+
280
+ padded_len = (seq_len + 7) // 8 * 8
281
+ num_heads = alibi_slopes.shape[0]
282
+ bias = torch.empty(
283
+ 1, # batch size
284
+ num_heads,
285
+ seq_len,
286
+ padded_len,
287
+ device=alibi_slopes.device,
288
+ dtype=dtype,
289
+ )[:, :, :, :seq_len].copy_(bias)
290
+ bias.mul_(alibi_slopes[:, None, None])
291
+ if num_heads != num_kv_heads:
292
+ bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
293
+ return bias
.venv/lib/python3.11/site-packages/vllm/attention/backends/ipex_attn.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """ Attention layer with torch scaled_dot_product_attention
3
+ and PagedAttention."""
4
+ from dataclasses import dataclass
5
+ from typing import Any, Dict, List, Optional, Tuple, Type
6
+
7
+ import torch
8
+
9
+ from vllm._ipex_ops import ipex_ops
10
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
11
+ AttentionLayer,
12
+ AttentionMetadata, AttentionType)
13
+ from vllm.attention.backends.utils import CommonAttentionState
14
+ from vllm.attention.ops.paged_attn import (PagedAttention,
15
+ PagedAttentionMetadata)
16
+
17
+ _PARTITION_SIZE = 512
18
+
19
+
20
+ class IpexAttnBackend(AttentionBackend):
21
+
22
+ @staticmethod
23
+ def get_name() -> str:
24
+ return "IPEX"
25
+
26
+ @staticmethod
27
+ def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
28
+ return IpexAttnBackendImpl
29
+
30
+ @staticmethod
31
+ def get_metadata_cls() -> Type["IpexAttnMetadata"]:
32
+ return IpexAttnMetadata
33
+
34
+ @staticmethod
35
+ def get_state_cls() -> Type["CommonAttentionState"]:
36
+ return CommonAttentionState
37
+
38
+ @staticmethod
39
+ def get_kv_cache_shape(
40
+ num_blocks: int,
41
+ block_size: int,
42
+ num_kv_heads: int,
43
+ head_size: int,
44
+ ) -> Tuple[int, ...]:
45
+ return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
46
+ num_kv_heads, head_size)
47
+
48
+ @staticmethod
49
+ def swap_blocks(
50
+ src_kv_cache: torch.Tensor,
51
+ dst_kv_cache: torch.Tensor,
52
+ src_to_dst: torch.Tensor,
53
+ ) -> None:
54
+ from vllm._ipex_ops import ipex_ops as ops
55
+ ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
56
+
57
+ @staticmethod
58
+ def copy_blocks(
59
+ kv_caches: List[torch.Tensor],
60
+ src_to_dists: torch.Tensor,
61
+ ) -> None:
62
+ from vllm._ipex_ops import ipex_ops as ops
63
+ key_caches = [kv_cache[0] for kv_cache in kv_caches]
64
+ value_caches = [kv_cache[1] for kv_cache in kv_caches]
65
+ ops.copy_blocks(key_caches, value_caches, src_to_dists)
66
+
67
+
68
+ @dataclass
69
+ class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata):
70
+ """Metadata for IpexAttnBackend.
71
+ """
72
+ # Currently, input sequences can only contain all prompts
73
+ # or all decoding. True if all sequences are prompts.
74
+ is_prompt: bool
75
+ slot_mapping: torch.Tensor
76
+ seq_lens: Optional[List[int]]
77
+ seqlen_q: Optional[torch.Tensor]
78
+ max_seqlen: Optional[int]
79
+
80
+ def __post_init__(self):
81
+ # Set during the execution of the first attention op.
82
+ # It is a list because it is needed to set per prompt
83
+ # when alibi slopes is used. It is because of the limitation
84
+ # from xformer API.
85
+ # will not appear in the __repr__ and __init__
86
+ self.attn_bias: Optional[List[torch.Tensor]] = None
87
+
88
+ @property
89
+ def prefill_metadata(self) -> Optional["IpexAttnMetadata"]:
90
+ # Currently chunked prefill is not supported
91
+ if self.num_decode_tokens == 0:
92
+ assert self.num_prefills > 0
93
+ return self
94
+
95
+ return None
96
+
97
+ @property
98
+ def decode_metadata(self) -> Optional["IpexAttnMetadata"]:
99
+ # Currently chunked prefill is not supported
100
+ if self.num_prefills > 0:
101
+ assert self.num_decode_tokens == 0
102
+ return None
103
+
104
+ return self
105
+
106
+
107
+ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
108
+
109
+ def __init__(
110
+ self,
111
+ num_heads: int,
112
+ head_size: int,
113
+ scale: float,
114
+ num_kv_heads: int,
115
+ alibi_slopes: Optional[List[float]],
116
+ sliding_window: Optional[int],
117
+ kv_cache_dtype: str,
118
+ blocksparse_params: Optional[Dict[str, Any]] = None,
119
+ logits_soft_cap: Optional[float] = None,
120
+ attn_type: str = AttentionType.DECODER,
121
+ ) -> None:
122
+ if blocksparse_params is not None:
123
+ raise ValueError(
124
+ "IPEX backend does not support block-sparse attention.")
125
+ self.num_heads = num_heads
126
+ self.head_size = head_size
127
+ self.scale = float(scale)
128
+ self.num_kv_heads = num_kv_heads
129
+ if alibi_slopes is not None:
130
+ alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
131
+ self.alibi_slopes = alibi_slopes
132
+ self.sliding_window = sliding_window
133
+ self.kv_cache_dtype = kv_cache_dtype
134
+
135
+ assert self.num_heads % self.num_kv_heads == 0
136
+ self.num_queries_per_kv = self.num_heads // self.num_kv_heads
137
+ self.need_mask = (self.alibi_slopes is not None
138
+ or self.sliding_window is not None)
139
+ if logits_soft_cap is None:
140
+ logits_soft_cap = 0
141
+ self.logits_soft_cap = logits_soft_cap
142
+
143
+ supported_head_sizes = PagedAttention.get_supported_head_sizes()
144
+ if head_size not in supported_head_sizes:
145
+ raise ValueError(
146
+ f"Head size {head_size} is not supported by PagedAttention. "
147
+ f"Supported head sizes are: {supported_head_sizes}.")
148
+ if kv_cache_dtype != "auto":
149
+ raise NotImplementedError(
150
+ "IPEX backend does not support FP8 KV cache. "
151
+ "Please use xFormers backend instead.")
152
+ if attn_type != AttentionType.DECODER:
153
+ raise NotImplementedError("Encoder self-attention and "
154
+ "encoder/decoder cross-attention "
155
+ "are not implemented for "
156
+ "IpexAttnBackendImpl")
157
+
158
+ def split_kv_cache(
159
+ self,
160
+ kv_cache: torch.Tensor,
161
+ num_kv_heads: int,
162
+ head_size: int,
163
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
164
+ x = 1
165
+ num_blocks = kv_cache.shape[1]
166
+
167
+ key_cache = kv_cache[0]
168
+ key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
169
+ -1, x)
170
+ value_cache = kv_cache[1]
171
+ value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
172
+ return key_cache, value_cache
173
+
174
+ def forward(
175
+ self,
176
+ layer: AttentionLayer,
177
+ query: torch.Tensor,
178
+ key: torch.Tensor,
179
+ value: torch.Tensor,
180
+ kv_cache: torch.Tensor,
181
+ attn_metadata: IpexAttnMetadata, # type: ignore
182
+ output: Optional[torch.Tensor] = None,
183
+ ) -> torch.Tensor:
184
+ """Forward pass with IPEX varlen_attention and PagedAttention.
185
+
186
+ Args:
187
+ query: shape = [num_tokens, num_heads * head_size]
188
+ key: shape = [num_tokens, num_kv_heads * head_size]
189
+ value: shape = [num_tokens, num_kv_heads * head_size]
190
+ kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
191
+ NOTE: kv_cache will be an empty tensor with shape [0]
192
+ for profiling run.
193
+ attn_metadata: Metadata for attention.
194
+ Returns:
195
+ shape = [num_tokens, num_heads * head_size]
196
+ """
197
+ assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
198
+ num_tokens, hidden_size = query.shape
199
+ # Reshape the query, key, and value tensors.
200
+ query = query.view(-1, self.num_heads, self.head_size)
201
+ key = key.view(-1, self.num_kv_heads, self.head_size)
202
+ value = value.view(-1, self.num_kv_heads, self.head_size)
203
+
204
+ if kv_cache.numel() > 0:
205
+ key_cache, value_cache = self.split_kv_cache(
206
+ kv_cache, self.num_kv_heads, self.head_size)
207
+ ipex_ops.reshape_and_cache(
208
+ key,
209
+ value,
210
+ key_cache,
211
+ value_cache,
212
+ attn_metadata.slot_mapping.flatten(),
213
+ self.kv_cache_dtype,
214
+ layer._k_scale,
215
+ layer._v_scale,
216
+ )
217
+
218
+ if attn_metadata.is_prompt:
219
+ assert attn_metadata.seq_lens is not None
220
+ if (kv_cache.numel() == 0
221
+ or attn_metadata.block_tables.numel() == 0):
222
+ if self.num_kv_heads != self.num_heads:
223
+ key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
224
+ value = value.repeat_interleave(self.num_queries_per_kv,
225
+ dim=1)
226
+
227
+ if attn_metadata.attn_bias is None:
228
+ if self.alibi_slopes is not None:
229
+ att_masks = _make_alibi_bias(
230
+ self.alibi_slopes, query.dtype,
231
+ attn_metadata.seq_lens) # type: ignore
232
+ elif self.sliding_window is not None:
233
+ att_masks = _make_sliding_window_bias(
234
+ attn_metadata.seq_lens, self.sliding_window,
235
+ query.dtype) # type: ignore
236
+ else:
237
+ att_masks = _make_sliding_window_bias(
238
+ attn_metadata.seq_lens, None, dtype=query.dtype)
239
+ attn_metadata.attn_bias = att_masks
240
+
241
+ output = torch.empty(
242
+ (num_tokens, self.num_heads, self.head_size),
243
+ dtype=query.dtype,
244
+ device=query.device)
245
+ ipex_ops.varlen_attention(
246
+ query,
247
+ key,
248
+ value,
249
+ output,
250
+ attn_metadata.seqlen_q,
251
+ attn_metadata.seqlen_q,
252
+ attn_metadata.max_seqlen,
253
+ attn_metadata.max_seqlen,
254
+ pdropout=0.0,
255
+ softmax_scale=self.scale,
256
+ zero_tensors=False,
257
+ is_causal=True,
258
+ return_softmax=False,
259
+ gen_=None,
260
+ logits_soft_cap=self.logits_soft_cap,
261
+ )
262
+ else:
263
+ # prefix-enabled attention
264
+ raise RuntimeError(
265
+ "IPEX backend doesn't support prefix decoding.")
266
+
267
+ else:
268
+ # Decoding run.
269
+ max_seq_len = attn_metadata.max_decode_seq_len
270
+ output = torch.empty_like(query)
271
+ block_size = value_cache.shape[3]
272
+ num_seqs, num_heads, head_size = query.shape
273
+ max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
274
+ _PARTITION_SIZE)
275
+ # NOTE(woosuk): We use a simple heuristic to decide whether to use
276
+ # PagedAttention V1 or V2. If the number of partitions is 1, we use
277
+ # V1 to avoid the overhead of reduction. Also, if the number of
278
+ # sequences or heads is large, we use V1 since there is enough work
279
+ # to parallelize.
280
+ # TODO(woosuk): Tune this heuristic.
281
+ # For context len > 8192, use V2 kernel to avoid shared memory
282
+ # shortage.
283
+ use_v1 = (max_seq_len <= 8192 and
284
+ (max_num_partitions == 1 or num_seqs * num_heads > 512))
285
+ if use_v1:
286
+ # Run PagedAttention V1.
287
+ ipex_ops.paged_attention_v1(
288
+ output,
289
+ query,
290
+ key_cache,
291
+ value_cache,
292
+ self.num_kv_heads,
293
+ self.scale,
294
+ attn_metadata.block_tables,
295
+ attn_metadata.seq_lens_tensor,
296
+ block_size,
297
+ max_seq_len,
298
+ self.alibi_slopes,
299
+ self.kv_cache_dtype,
300
+ layer._k_scale,
301
+ layer._v_scale,
302
+ )
303
+ else:
304
+ # Run PagedAttention V2.
305
+ assert _PARTITION_SIZE % block_size == 0
306
+ tmp_output = torch.empty(
307
+ size=(num_seqs, num_heads, max_num_partitions, head_size),
308
+ dtype=output.dtype,
309
+ device=output.device,
310
+ )
311
+ exp_sums = torch.empty(
312
+ size=(num_seqs, num_heads, max_num_partitions),
313
+ dtype=torch.float32,
314
+ device=output.device,
315
+ )
316
+ max_logits = torch.empty_like(exp_sums)
317
+ ipex_ops.paged_attention_v2(
318
+ output,
319
+ exp_sums,
320
+ max_logits,
321
+ tmp_output,
322
+ query,
323
+ key_cache,
324
+ value_cache,
325
+ self.num_kv_heads,
326
+ self.scale,
327
+ attn_metadata.block_tables,
328
+ attn_metadata.seq_lens_tensor,
329
+ block_size,
330
+ max_seq_len,
331
+ self.alibi_slopes,
332
+ self.kv_cache_dtype,
333
+ layer._k_scale,
334
+ layer._v_scale,
335
+ )
336
+
337
+ # Reshape the output tensor.
338
+ return output.view(-1, self.num_heads * self.head_size)
339
+
340
+
341
+ def _make_alibi_bias(
342
+ alibi_slopes: torch.Tensor,
343
+ dtype: torch.dtype,
344
+ seq_lens: List[int],
345
+ ) -> List[torch.Tensor]:
346
+ attn_biases = []
347
+ for seq_len in seq_lens:
348
+ bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
349
+ # NOTE(zhuohan): HF uses
350
+ # `bias = bias[None, :].repeat(seq_len, 1)`
351
+ # here. We find that both biases give the same results, but
352
+ # the bias below more accurately follows the original ALiBi
353
+ # paper.
354
+ bias = bias[None, :] - bias[:, None]
355
+
356
+ num_heads = alibi_slopes.shape[0]
357
+ bias = bias[None, :].repeat((num_heads, 1, 1))
358
+ bias.mul_(alibi_slopes[:, None, None])
359
+ inf_mask = torch.empty(
360
+ (1, seq_len, seq_len),
361
+ dtype=bias.dtype,
362
+ device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1)
363
+ attn_biases.append((bias + inf_mask).to(dtype))
364
+
365
+ return attn_biases
366
+
367
+
368
+ def _make_sliding_window_bias(
369
+ seq_lens: List[int],
370
+ window_size: Optional[int],
371
+ dtype: torch.dtype,
372
+ ) -> List[torch.Tensor]:
373
+ attn_biases = []
374
+ for seq_len in seq_lens:
375
+ tensor = torch.full(
376
+ (1, seq_len, seq_len),
377
+ dtype=dtype,
378
+ fill_value=1,
379
+ )
380
+ shift = 0
381
+ mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
382
+ if window_size is not None:
383
+ mask = torch.triu(mask, diagonal=shift - window_size + 1)
384
+ mask = torch.log(mask)
385
+ attn_biases.append(mask.to(dtype))
386
+
387
+ return attn_biases
.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (200 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/__pycache__/utils.cpython-311.pyc ADDED
Binary file (25.3 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/backends/mla/utils.py ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from abc import abstractmethod
4
+ from dataclasses import dataclass
5
+ from typing import Any, Dict, Generic, List, Optional, Tuple
6
+
7
+ import torch
8
+ from compressed_tensors.quantization import QuantizationStrategy
9
+
10
+ from vllm import _custom_ops as ops
11
+ from vllm import envs
12
+ from vllm.attention.backends.abstract import (AttentionLayer,
13
+ AttentionMetadata,
14
+ MLAAttentionImpl, T)
15
+ from vllm.distributed import (get_tensor_model_parallel_world_size,
16
+ tensor_model_parallel_all_reduce)
17
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
18
+ LinearBase, RowParallelLinear,
19
+ UnquantizedLinearMethod)
20
+ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
21
+ CompressedTensorsLinearMethod)
22
+ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
23
+ CompressedTensorsW8A8Fp8)
24
+ from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
25
+ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
26
+ apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
27
+ from vllm.model_executor.layers.quantization.utils.quant_utils import (
28
+ scaled_dequantize, scaled_quantize)
29
+ from vllm.model_executor.layers.rotary_embedding import (
30
+ DeepseekScalingRotaryEmbedding, RotaryEmbedding)
31
+
32
+ try:
33
+ from vllm.vllm_flash_attn import flash_attn_varlen_func
34
+ except ImportError:
35
+ from flash_attn import flash_attn_varlen_func
36
+
37
+
38
+ @dataclass
39
+ class MLACommonMetadata(AttentionMetadata):
40
+ # Input positions for rotrary embeddings since for MLA the rotary
41
+ # position embeddings are applied inside the attention backend
42
+ input_positions: torch.Tensor
43
+
44
+
45
+ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
46
+ """
47
+ Common class for implementing repeated parts
48
+
49
+ Main reference: DeepseekV2 paper, and FlashInfer Implementation
50
+ (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
51
+
52
+ Deepseek's MLA attention works the following way:
53
+ * Use a single latent vector to represent the entire KV cache.
54
+ * The attention "simulates" a multi-head attention, while the compute is
55
+ similar to multi-query attention.
56
+ * The dataflow is as follows,
57
+
58
+ * B: batch/sequence length
59
+ * H: hidden size
60
+ * N: number of attention heads
61
+ * Lq: latent dimension for Q
62
+ * Lkv: latent dimension for K/V
63
+ * P: nope dimension, P+R is the actual head_dim in common attention.
64
+ * R: rope dimension, this slide of the head_dim goes through rope.
65
+ * V: V head dim.
66
+ * kv_c: latent/compressed KV
67
+ * q_c: latent/compressed Q
68
+
69
+ #
70
+ # Outside the MLA attention backend
71
+ #
72
+
73
+ 1. The hidden states (B, H) are projected down into cq (B, Lq) and
74
+ kv_c_k_pe (B, Lkv+R).
75
+ 2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq
76
+ and kv_c are normalized.
77
+
78
+ #
79
+ # Inside the MLA attention backend
80
+ #
81
+
82
+ * if prefill:
83
+
84
+ 3. The q_c is then projected up into the multi-head version.
85
+ * q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
86
+ (B, N, P) and q_pe (B, N, R).
87
+ 4. q_pe, k_pe are then passed through rotary embeddings.
88
+ 5. kv_c and k_pe are concatenated and inserted into the cache
89
+ 6. The kv_c is then projected up into the multi-head version.
90
+ * kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
91
+ dimensions for K and V, which is split into k_nope (B, N, P)
92
+ and v (B, N, V).
93
+ 7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from
94
+ q_nope, q_pe, k_nope, k_pe.
95
+ 8. Attention is computued with q, k, v.
96
+ 9. The attention computation returns (B, N, V), which is projected back
97
+ to (B, H) using out projection.
98
+
99
+ * if decode:
100
+
101
+ 3. Here's the change, we do not perform up the full up projection for
102
+ q_c, and there is no up projection at all for kv_c. This is
103
+ achieved by the technique of "weight absorption". The paper says
104
+ "Fortunately, due to the associative law of matrix multiplication,
105
+ we can absorb WUK into WUQ, and WUV into WO"
106
+ * The q up projection turns (B, Lq) into (B, N, (P+R)), we split it
107
+ into W_UQ (Lq, N, P) and W_QR (Lq, N, R).
108
+ * The kv_c up projection turns (B, Lkv) into (B, N, (P+V)), we split
109
+ it into W_UK (Lkv, N, P) and W_UV (Lkv, N, V).
110
+ * The out projection shape W_O (N*V, H) turns (B, N, V) into (B, H).
111
+ * We can precompute the product of W_UQ and W_UK into
112
+ W_UQ_UK (Lq, N, Lkv), which is possible due to QK^T operation in
113
+ attention.
114
+ * We can precompute the product of W_UV and W_O into
115
+ W_UV_O (N, Lkv, H), which is possible due to V@O as the
116
+ "epilogue" of attention
117
+ 4. We still need to compute q_pe (B, N, R) by applying W_QR to q_latent.
118
+ 5. q_pe, k_pe are then passed through rotary embeddings.
119
+ 6. kv_c and k_pe are concatenated and inserted into the cache
120
+ 7. By applying W_UQ_UK to q_latent, we have the new q_nope of shape
121
+ (B, N, Lkv).
122
+ 8. q (B, N, (Lkv+R)), k (B, (Lkv+R)) are assembled from q_nope, q_pe,
123
+ kv_a, k_pe. v (B, Lkv) is exactly the same vector as kv_a.
124
+ 9. The attention is computed with q, k, v. Note that we just performed
125
+ a MQA attention with (LKv+R) as our head dim.
126
+ 10. The KV cache is updated using the new entries k (B, N, (Lkv+R)),
127
+ which included the v and rope values.
128
+ 11. The attention computation returns (B, N, Lkv), which is projected
129
+ back to (B, H) using W_UV_O.
130
+
131
+ From @tsu-bin's calculation, we only want to use the absorption technique
132
+ for decode. The prefill algorithm should still use the up-projected MHA
133
+ for less flops and memory usage.
134
+
135
+ """
136
+
137
+ def __init__(
138
+ self,
139
+ num_heads: int,
140
+ head_size: int,
141
+ scale: float,
142
+ num_kv_heads: int,
143
+ alibi_slopes: Optional[List[float]],
144
+ sliding_window: Optional[int],
145
+ kv_cache_dtype: str,
146
+ blocksparse_params: Optional[Dict[str, Any]],
147
+ logits_soft_cap: Optional[float],
148
+ attn_type: str,
149
+ # MLA Specific Arguments
150
+ q_lora_rank: Optional[int],
151
+ kv_lora_rank: int,
152
+ qk_nope_head_dim: int,
153
+ qk_rope_head_dim: int,
154
+ qk_head_dim: int,
155
+ v_head_dim: int,
156
+ rotary_emb: RotaryEmbedding,
157
+ # q_proj should be q_b_proj if q_lora_rank is not None, but from an
158
+ # attention backend perspective we rely on the layer to pass in the
159
+ # correct matrix
160
+ q_proj: ColumnParallelLinear,
161
+ kv_b_proj: ColumnParallelLinear,
162
+ o_proj: RowParallelLinear,
163
+ ) -> None:
164
+ self.num_heads = num_heads
165
+ self.head_size = head_size
166
+ self.scale = float(scale)
167
+ self.num_kv_heads = num_kv_heads
168
+ self.kv_cache_dtype = kv_cache_dtype
169
+
170
+ self.q_lora_rank = q_lora_rank
171
+ self.kv_lora_rank = kv_lora_rank
172
+ self.qk_nope_head_dim = qk_nope_head_dim
173
+ self.qk_rope_head_dim = qk_rope_head_dim
174
+ self.qk_head_dim = qk_head_dim
175
+ self.v_head_dim = v_head_dim
176
+
177
+ self.rotary_emb = rotary_emb
178
+ self.use_yarn_rope = isinstance(rotary_emb,
179
+ DeepseekScalingRotaryEmbedding)
180
+ self.q_proj = q_proj
181
+ self.kv_b_proj = kv_b_proj
182
+ self.o_proj = o_proj
183
+
184
+ def _v_up_proj_and_o_proj(self, x):
185
+ if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
186
+ if is_fp8(self.W_UV_O):
187
+ output_parallel = apply_fp8_linear_generic(
188
+ x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
189
+ self.reqaunt_input_group_shape,
190
+ self.reqaunt_weight_group_shape)
191
+ else:
192
+ output_parallel = torch.matmul(x.flatten(start_dim=1),
193
+ self.W_UV_O)
194
+ if self.tp_size > 1:
195
+ output = tensor_model_parallel_all_reduce(output_parallel)
196
+ else:
197
+ output = output_parallel
198
+ return output
199
+ else:
200
+ x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
201
+ return self.o_proj(x.reshape(-1,
202
+ self.num_heads * self.v_head_dim))[0]
203
+
204
+ def _q_proj_and_k_up_proj(self, x):
205
+ if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
206
+ if is_fp8(self.W_Q_UK):
207
+ return apply_fp8_linear_generic(
208
+ x, self.W_Q_UK, self.W_Q_UK_scales,
209
+ self.reqaunt_input_group_shape,
210
+ self.reqaunt_weight_group_shape).view(
211
+ -1, self.num_heads, self.kv_lora_rank)
212
+ return torch.matmul(x, self.W_Q_UK)\
213
+ .view(-1, self.num_heads, self.kv_lora_rank)
214
+ else:
215
+ x = torch.matmul(x, self.W_Q)\
216
+ .view(-1, self.num_heads, self.qk_nope_head_dim)
217
+ return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
218
+ .view(-1, self.num_heads, self.kv_lora_rank)
219
+
220
+ def process_weights_after_loading(self, act_dtype: torch.dtype):
221
+
222
+ def is_layer_fp8(layer: LinearBase) -> bool:
223
+ return isinstance(layer.quant_method, Fp8LinearMethod) or\
224
+ (isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
225
+ and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
226
+
227
+ def quantization_scheme_supported(layer: LinearBase) -> bool:
228
+ return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
229
+ is_layer_fp8(layer)
230
+
231
+ # TODO(lucas) This is very gross, we need a more wide scale refactor of
232
+ # all the FP8 code with a more standard way of
233
+ # defining schemes/group-shapes, we should also potentially force
234
+ # quant_methods to support a decompress function
235
+ #
236
+ # returns input_group_shape, weight_group_shape
237
+ def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
238
+ Tuple[Tuple[int, int], Tuple[int, int]]:
239
+ if isinstance(layer.quant_method, Fp8LinearMethod):
240
+ if layer.quant_method.block_quant is not None:
241
+ weight_block_size = \
242
+ layer.quant_method.quant_config.weight_block_size
243
+ # per-token-group (1, X), block-quantized (X, Y)
244
+ return (1, weight_block_size[-1]), weight_block_size
245
+ else:
246
+ return (-1, -1), (-1, -1) # per-tensor, per-tensor
247
+ elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
248
+ and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
249
+ # this is hacky but we always assume the for
250
+ # CompressedTensorsW8A8Fp8 the input is dynamic per-token
251
+ # we ignore if it is static-per-tensor since we are going to
252
+ # requantize after later anyways
253
+ strategy = layer.scheme.strategy
254
+ if strategy == QuantizationStrategy.TENSOR:
255
+ return (1, -1), (-1, -1) # per-token, per-tensor
256
+ elif strategy == QuantizationStrategy.CHANNEL:
257
+ return (1, -1), (-1, 1) # per-token, per-channel
258
+ else:
259
+ raise NotImplementedError(
260
+ f"QuantizationStrategy.{strategy} is not supported for "
261
+ "fp8 MLA, please run with VLLM_MLA_DISABLE=1")
262
+ else:
263
+ raise NotImplementedError(
264
+ "Can't determine scale group shapes for "
265
+ f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
266
+ )
267
+
268
+ def get_scales(layer: LinearBase) -> torch.Tensor:
269
+ if hasattr(layer, "weight_scale_inv"):
270
+ return layer.weight_scale_inv
271
+ return layer.weight_scale
272
+
273
+ def get_and_maybe_dequant_weights(layer: LinearBase):
274
+ if is_layer_fp8(layer):
275
+ if isinstance(layer.quant_method, \
276
+ CompressedTensorsLinearMethod) and \
277
+ isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
278
+ # NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
279
+ # seems to store weights as (input, output) instead of
280
+ # (output, input) so we need to transpose
281
+ weight = layer.weight.T # standardize to (output, input)
282
+ else:
283
+ weight = layer.weight
284
+ _, weight_scale_group_shape = \
285
+ get_scale_group_shapes_for_fp8(layer)
286
+ scales = get_scales(layer)
287
+
288
+ return scaled_dequantize(weight, scales,
289
+ weight_scale_group_shape)
290
+ else:
291
+ return layer.weight
292
+
293
+ if not (quantization_scheme_supported(self.kv_b_proj) and\
294
+ quantization_scheme_supported(self.q_proj) and\
295
+ quantization_scheme_supported(self.o_proj)):
296
+ raise NotImplementedError(
297
+ "Only FP8 and UnquantizedLinearMethod are supported for MLA"
298
+ ", please run with VLLM_MLA_DISABLE=1")
299
+
300
+ weight_dtype = self.kv_b_proj.weight.dtype
301
+ assert self.o_proj.weight.dtype == weight_dtype
302
+ assert self.q_proj.weight.dtype == weight_dtype
303
+
304
+ kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
305
+ assert kv_b_proj_weight.shape == (
306
+ self.kv_lora_rank,
307
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
308
+ f"{kv_b_proj_weight.shape=}, "
309
+ f"{self.kv_lora_rank=}, "
310
+ f"{self.num_heads=}, "
311
+ f"{self.qk_nope_head_dim=}, "
312
+ f"{self.v_head_dim=}")
313
+ kv_b_proj_weight = kv_b_proj_weight.view(
314
+ self.kv_lora_rank,
315
+ self.num_heads,
316
+ self.qk_nope_head_dim + self.v_head_dim,
317
+ )
318
+
319
+ W_UK, W_UV = kv_b_proj_weight.split(
320
+ [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
321
+
322
+ q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
323
+ .view(-1, self.num_heads, self.qk_head_dim)
324
+
325
+ # can be W_Q or W_UQ depending q_lora_rank, the former if
326
+ # q_lora_rank is None, the latter otherwise. From the Attention backend
327
+ # perspective though we call these both W_Q and rely on the layer
328
+ # to pass in the correct matrix
329
+ W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
330
+ self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
331
+ .flatten(start_dim=1).contiguous()
332
+
333
+ # W_QR is small so for simplicity we dont bother requantizing it
334
+ self.W_QR = self.W_QR.to(act_dtype)
335
+
336
+ if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
337
+ requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
338
+ if is_fp8(weight_dtype) and requantization_enabled:
339
+ # This assumes it wise to requantize using the same group shapes
340
+ # (i.e. strategy, per-tensor, per-channel, block etc.) that the
341
+ # weights were originally quantized
342
+ requant_input_group_shape, requant_weight_group_shape = \
343
+ get_scale_group_shapes_for_fp8(self.q_proj)
344
+ assert (requant_input_group_shape, requant_weight_group_shape)\
345
+ == get_scale_group_shapes_for_fp8(self.kv_b_proj)
346
+ assert (requant_input_group_shape, requant_weight_group_shape)\
347
+ == get_scale_group_shapes_for_fp8(self.o_proj)
348
+ self.reqaunt_input_group_shape = requant_input_group_shape
349
+ self.reqaunt_weight_group_shape = requant_weight_group_shape
350
+
351
+ #
352
+ # Perform matrix-absorption following
353
+ # https://github.com/flashinfer-ai/flashinfer/pull/551
354
+ # for decode, as a result we end up with absorbed weights for decode
355
+ # and another copy of raw weights for prefill.
356
+ #
357
+ self.W_UK, self.W_UV = kv_b_proj_weight.split(
358
+ [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
359
+ # We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
360
+ # depending q_lora_rank, the former if q_lora_rank is None, the
361
+ # latter otherwise
362
+ # basically if q_lora_rank is none we are absorbing into q_proj
363
+ # instead of UQ
364
+ W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
365
+ .flatten(start_dim=1).contiguous()
366
+
367
+ if is_fp8(weight_dtype) and requantization_enabled:
368
+ W_Q_UK, W_Q_UK_scales = scaled_quantize(
369
+ W_Q_UK,
370
+ self.reqaunt_weight_group_shape,
371
+ quant_dtype=current_platform_fp8_dtype)
372
+ # For FP8 save the transpose so we can use
373
+ # `apply_w8a8_block_fp8_linear` directly
374
+ self.W_Q_UK = W_Q_UK.T.contiguous()
375
+ self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
376
+ else:
377
+ self.W_Q_UK = W_Q_UK.to(act_dtype)
378
+
379
+ W_O = get_and_maybe_dequant_weights(self.o_proj)\
380
+ .view(-1, self.num_heads, self.v_head_dim)
381
+ W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
382
+ .flatten(start_dim=0, end_dim=1).contiguous()
383
+
384
+ if is_fp8(weight_dtype) and requantization_enabled:
385
+ W_UV_O, W_UV_O_scales = scaled_quantize(
386
+ W_UV_O,
387
+ self.reqaunt_weight_group_shape,
388
+ quant_dtype=current_platform_fp8_dtype)
389
+ # For FP8 save the transpose so we can use
390
+ # `apply_w8a8_block_fp8_linear` directly
391
+ self.W_UV_O = W_UV_O.T.contiguous()
392
+ self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
393
+ else:
394
+ self.W_UV_O = W_UV_O.to(act_dtype)
395
+
396
+ self.tp_size = get_tensor_model_parallel_world_size()
397
+ else:
398
+ if is_fp8(weight_dtype):
399
+ raise NotImplementedError(
400
+ "Currently fp8 requires matrix absorption")
401
+
402
+ self.W_UV = W_UV
403
+ self.W_UK = W_UK
404
+ self.W_Q = W_Q.flatten(start_dim=1)
405
+
406
+ @abstractmethod
407
+ def _forward_prefill(
408
+ self,
409
+ q: torch.Tensor,
410
+ kv_c_normed: torch.Tensor,
411
+ k_pe: torch.Tensor,
412
+ attn_metadata: T,
413
+ ) -> torch.Tensor:
414
+ raise NotImplementedError
415
+
416
    @abstractmethod
    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
    ) -> torch.Tensor:
        """Backend-specific decode attention over the latent KV cache.

        Args:
            q_nope: Query after absorbing the K up-projection (latent space).
            q_pe: Rotary (positional) part of the query.
            kv_cache: Cached compressed KV (latent + rope parts).
            attn_metadata: Backend-specific attention metadata.

        Returns:
            Attention output for the decode tokens.
        """
        raise NotImplementedError
425
+
426
+ def apply_pure_rope(
427
+ self,
428
+ input_positions: torch.Tensor,
429
+ q_pe: torch.Tensor,
430
+ k_pe: torch.Tensor,
431
+ ) -> tuple[torch.Tensor, torch.Tensor]:
432
+ seq_len = input_positions.size(0)
433
+ ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
434
+
435
+ q_pe, k_pe = self.rotary_emb(
436
+ input_positions,
437
+ q_pe.reshape(seq_len, -1),
438
+ k_pe.reshape(seq_len, -1),
439
+ )
440
+ q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
441
+
442
+ return q_pe, k_pe
443
+
444
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_q_c: torch.Tensor,  # query in unified attn
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run MLA attention for an all-prefill or all-decode batch.

        Args:
            layer: Attention layer (supplies the per-layer `_k_scale` used
                when writing to the cache).
            hidden_states_or_q_c: Hidden states or compressed query `q_c`
                (which one depends on the model's q_lora_rank configuration
                set up elsewhere in this class); occupies the "query" slot
                of the unified attention interface.
            k_c_normed: Normalized compressed KV latent ("key" slot).
            k_pe: Decoupled rotary key component ("value" slot).
            kv_cache: MLA KV cache holding latent + rope parts; may be
                empty (numel() == 0), in which case nothing is cached.
            attn_metadata: Must carry `input_positions` and exactly one of
                prefill/decode metadata (mixed batches are rejected).
            output: Not supported; must be None.

        Returns:
            Attention output from `_forward_prefill` or `_forward_decode`.
        """
        if output is not None:
            raise NotImplementedError(
                "output is not yet supported for MLAImplBase")

        is_decode = attn_metadata.decode_metadata is not None
        is_prefill = attn_metadata.prefill_metadata is not None

        # A batch with both prefill and decode tokens would require
        # chunked-prefill support, which this base impl does not have.
        if (is_decode and is_prefill):
            raise NotImplementedError(
                "chunked prefill is not supported for MLAImplBase")

        # Restore head dim (for rotary embedding)
        k_pe = k_pe.unsqueeze(1)
        assert hasattr(attn_metadata, "input_positions")
        # YaRN rope takes the tensors as-is; otherwise apply_pure_rope
        # flattens/restores shapes around the rotary embedding.
        rope_fn = (self.rotary_emb
                   if self.use_yarn_rope else self.apply_pure_rope)

        if is_decode:
            # Decode: project q with the absorbed K up-projection (latent
            # space) and compute the rope part separately via W_QR.
            q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
            q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
                .view(-1, self.num_heads, self.qk_rope_head_dim)
            q_pe, k_pe = rope_fn(attn_metadata.input_positions, q_pe, k_pe)
        else:
            assert is_prefill
            q = self.q_proj(hidden_states_or_q_c)[0]\
                .view(-1, self.num_heads, self.qk_head_dim)

            # Apply rope in-place to the rotary slice of q.
            # TODO(lucas): there must be a nicer way to write this line
            q[..., self.qk_nope_head_dim:], k_pe = \
                rope_fn(
                    attn_metadata.input_positions,
                    q[..., self.qk_nope_head_dim:], k_pe)

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=layer._k_scale,
            )

        if attn_metadata.prefill_metadata is not None:
            return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata)

        if attn_metadata.decode_metadata is not None:
            return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata)
503
+
504
    # Optional common flash-attn based prefill
    def _forward_prefill_flash(
        self,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        seq_start_loc: torch.Tensor,
        max_prefill_seq_len: int,
    ) -> torch.Tensor:
        """Prefill attention via flash-attn on the decompressed K/V.

        Decompresses the latent `k_c_normed` through `kv_b_proj` into the
        per-head nope-key and value parts, rebuilds the full key by
        appending the shared rope component, and runs varlen causal
        flash attention.

        Args:
            q: Full query, viewed as (tokens, num_heads, qk_head_dim).
            k_c_normed: Normalized compressed KV latent.
            k_pe: Rope key component, broadcast across heads.
            seq_start_loc: Cumulative sequence lengths (used for both q
                and k since prefill is self-attention over the same span).
            max_prefill_seq_len: Longest prefill sequence in the batch.

        Returns:
            Output projection of the attention result.
        """
        kv_nope = self.kv_b_proj(k_c_normed)[0]\
            .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope\
            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # k_pe is shared across heads; expand (no copy) to every head.
        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

        # For MLA the v head dim is smaller than qk head dim so we pad out
        # v with 0s to match the qk head dim
        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                           value=0)

        attn_output = flash_attn_varlen_func(
            q=q,
            k=k,
            v=v_padded,
            cu_seqlens_q=seq_start_loc,
            cu_seqlens_k=seq_start_loc,
            max_seqlen_q=max_prefill_seq_len,
            max_seqlen_k=max_prefill_seq_len,
            softmax_scale=self.scale,
            causal=True,
        )
        # Drop the zero padding from v before the output projection.
        attn_output = attn_output\
            .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
            .reshape(-1, self.num_heads * v.shape[-1])

        return self.o_proj(attn_output)[0]
.venv/lib/python3.11/site-packages/vllm/attention/backends/openvino.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, List, Optional, Tuple, Type
5
+
6
+ import openvino as ov
7
+ import torch
8
+
9
+ from vllm.attention.backends.abstract import (AttentionBackend,
10
+ AttentionMetadata)
11
+ from vllm.attention.backends.utils import CommonAttentionState
12
+ from vllm.multimodal import MultiModalPlaceholderMap
13
+
14
+
15
def copy_cache_block(src_tensor: ov.Tensor, dst_tensor: ov.Tensor,
                     src_offset: int, dst_offset: int) -> None:
    """Copy one KV-cache block between OpenVINO tensors.

    Builds ROI views selecting block `src_offset` of `src_tensor` and
    block `dst_offset` of `dst_tensor` (the block index is dimension 0)
    and copies the source view into the destination view.
    """

    def _block_view(tensor: ov.Tensor, block_number: int) -> ov.Tensor:
        # ROI spanning exactly one block along dimension 0, with the full
        # extent of every remaining dimension.
        begin = ov.runtime.Coordinate([0, 0, 0, 0])
        end = ov.runtime.Coordinate(tensor.get_shape())
        begin[0] = block_number
        end[0] = block_number + 1

        # Device-side tensors need the RemoteTensor ROI constructor.
        if isinstance(tensor, ov.Tensor):
            return ov.Tensor(tensor, begin, end)
        return ov.RemoteTensor(tensor, begin, end)

    _block_view(src_tensor, src_offset).copy_to(
        _block_view(dst_tensor, dst_offset))
38
+
39
+
40
class OpenVINOAttentionBackend(AttentionBackend):
    """Attention backend descriptor for the OpenVINO runtime.

    The attention computation itself is part of the Optimum-exported
    OpenVINO model, so this class only exposes naming, metadata
    construction, and KV-cache block-management hooks.
    """

    @staticmethod
    def get_name() -> str:
        return "OPENVINO"

    @staticmethod
    def get_impl_cls():
        # OpenVINO implements PagedAttention as part of the Optimum
        # exported model, so there is no separate impl class.
        raise NotImplementedError

    @staticmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
        # Use make_openvino_metadata instead.
        raise NotImplementedError

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
        return OpenVINOAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        # Key and value caches are stacked along the leading dimension.
        return (2, num_blocks, num_kv_heads, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_tensor: ov.Tensor,
        dst_tensor: ov.Tensor,
        src_to_dists: List[Tuple[int, int]],
    ) -> None:
        """Copy blocks from one cache tensor into another."""
        for src_block, dst_block in src_to_dists:
            copy_cache_block(src_tensor, dst_tensor, src_block, dst_block)

    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
        src_to_dists: List[Tuple[int, int]],
    ) -> None:
        """Duplicate blocks within every layer's key and value caches."""
        for src_block, dst_block in src_to_dists:
            for key_cache, value_cache in kv_caches:
                copy_cache_block(key_cache, key_cache, src_block, dst_block)
                copy_cache_block(value_cache, value_cache, src_block,
                                 dst_block)
91
+
92
+
93
@dataclass
class OpenVINOAttentionMetadata:
    """Metadata for OpenVINOAttentionBackend.

    Basic terms used below:
    - batch_size_in_sequences - total number of sequences to execute
    - prompt_lens - per-sequence number of scheduled tokens
    - batch_size_in_tokens = sum(prompt_lens)
    - max_context_len = max(context_lens)
    - max_num_blocks = div_up(max_context_len / BLOCK_SIZE)
    - num_blocks - total number of blocks in block_indices
    """

    # Describes past KV cache size for each sequence within a batch
    # Shape: [batch_size_in_sequences]
    # Type: i32
    past_lens: torch.Tensor

    # Describes start indices of input / speculative tokens from
    # current sequences within a batch sequence
    # Shape: [batch_size_in_sequences + 1]
    # Type: i32
    subsequence_begins: torch.Tensor

    # Describes block tables for each sequence within a batch -
    # indices along 0th dimension in key_cache and value_cache inputs
    # Shape: [num_blocks]
    # Type: i32
    block_indices: torch.Tensor

    # Describes block tables for each sequence within a batch -
    # for i-th element, it is an index in block_indices with the
    # first block belonging to i-th sequence
    # Shape: [batch_size_in_sequences + 1]
    # Type: i32
    block_indices_begins: torch.Tensor

    # Describes max context length
    # Shape: scalar
    # Type: i32
    max_context_len: torch.Tensor

    # The index maps that relate multi-modal embeddings to the corresponding
    # placeholders.
    #
    # N.B. These aren't really related to attention and don't belong on this
    # type -- this is just a temporary solution to make them available to
    # `model_executable`.
    multi_modal_placeholder_index_maps: Optional[Dict[
        str, MultiModalPlaceholderMap.IndexMap]]

    # Enable/disable KV scales calculation. This is so that we can disable the
    # calculation until after prefill and cuda graph capture.
    enable_kv_scales_calculation: bool
.venv/lib/python3.11/site-packages/vllm/attention/backends/pallas.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Tuple, Type
5
+
6
+ import torch
7
+ import torch_xla.experimental.custom_kernel # Required to register custom ops.
8
+
9
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
10
+ AttentionLayer,
11
+ AttentionMetadata, AttentionType)
12
+ from vllm.attention.backends.utils import CommonAttentionState
13
+
14
+
15
class PallasAttentionBackend(AttentionBackend):
    """Attention backend for TPUs using Pallas kernels via torch_xla."""

    @staticmethod
    def get_name() -> str:
        return "PALLAS"

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
        return PallasAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["PallasMetadata"]:
        return PallasMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        # Head dimension leads, so each head's blocks are contiguous.
        return (num_kv_heads, num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        # CPU<->device block swapping is not part of this backend.
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    # NOTE(review): torch.compile is stacked above @staticmethod here, so it
    # wraps the staticmethod object; presumably intentional for the XLA
    # path -- confirm before reordering.
    @torch.compile(backend="openxla")
    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        # Vectorized block copy; caches are marked as donated buffers so
        # XLA may perform the update in place.
        src_indices, dst_indices = src_to_dists
        for k_cache, v_cache in kv_caches:
            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
            k_cache[:, dst_indices] = k_cache[:, src_indices]
            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
            v_cache[:, dst_indices] = v_cache[:, src_indices]
62
+
63
+
64
@dataclass
class PallasMetadata(AttentionMetadata):
    """Attention metadata for the Pallas (TPU) backend.

    A batch is homogeneous: it contains either only prefill requests or
    only decode requests, so the prefill/decode accessors return `self`
    unchanged (or None) rather than a sliced copy.
    """

    # Currently, input sequences can only contain all prefills
    # or all decoding.
    block_tables: Optional[torch.Tensor] = None
    context_lens: Optional[torch.Tensor] = None
    effective_query_lens: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["PallasMetadata"]:
        """This metadata when the batch is all-prefill, else None."""
        if self.num_prefills > 0:
            assert self.num_decode_tokens == 0
            return self
        return None

    @property
    def decode_metadata(self) -> Optional["PallasMetadata"]:
        """This metadata when the batch is all-decode, else None."""
        if self.num_decode_tokens > 0:
            assert self.num_prefills == 0
            assert self.num_prefill_tokens == 0
            assert self.block_tables is not None
            assert self.context_lens is not None
            return self
        return None
91
+
92
+
93
class PallasAttentionBackendImpl(AttentionImpl):
    """Pallas (TPU) attention implementation.

    Prefill runs either the flash-attention kernel (no paged KV cache) or
    the multi-queries paged-attention kernel; decode runs the
    paged-attention kernel, chunked over the batch when the block table
    would overflow SMEM.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.logits_soft_cap = logits_soft_cap
        # Reject configurations the Pallas kernels do not support.
        if head_size % 128 != 0:
            raise NotImplementedError("Head size must be a multiple of 128.")
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")

        if torch_xla.tpu.version() < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")

        # Megacore mode splits kernel work across the two TensorCores of a
        # chip; only set for non-lite, pre-v6 TPU types below.
        self.megacore_mode = None
        tpu_env = torch_xla.tpu.get_tpu_env()
        tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
                    or tpu_env.get("TYPE", None)
                    or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
        assert tpu_type is not None
        tpu_type = tpu_type.lower()

        if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
            if self.num_kv_heads % 2 == 0:
                self.megacore_mode = "kv_head"
            else:
                # NOTE(woosuk): If the batch size is not a multiple of 2, the
                # megacore mode will be None.
                self.megacore_mode = "batch"

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        attn_metadata: PallasMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
            kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
                NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
                with shape [0] for profiling run.
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        # KV scaling is not supported on this backend.
        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        batch_size, seq_len, hidden_size = query.shape
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        # Write the new tokens into the paged cache (skipped for the
        # empty-cache profiling run).
        if kv_cache[0].numel() > 0:
            slot_mapping = attn_metadata.slot_mapping
            key_cache, value_cache = kv_cache
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

        # Scale q once up front instead of inside the kernels.
        query = query * self.scale
        if attn_metadata.num_prefills > 0:
            if attn_metadata.block_tables is None:
                # Prefill without paged KV cache.
                assert seq_len % 16 == 0, (
                    "Pallas FlashAttention kernel requires seq_len to be a "
                    f"multiple of 16 but got {seq_len}")

                # Handle GQA/MQA.
                if self.num_kv_heads != self.num_heads:
                    key = key.repeat_interleave(self.num_queries_per_kv,
                                                dim=-2)
                    key = key.view(batch_size, seq_len, self.num_heads,
                                   self.head_size)
                    value = value.repeat_interleave(self.num_queries_per_kv,
                                                    dim=-2)
                    value = value.view(batch_size, seq_len, self.num_heads,
                                       self.head_size)
                # FlashAttention kernel requires the input shape to be
                # [batch_size, num_heads, seq_len, d_model]
                # while the input is [batch_size, seq_len, num_heads, d_model].
                # Permute the input to match the required format.
                output = torch.ops.xla.flash_attention(
                    query.permute(0, 2, 1, 3),
                    key.permute(0, 2, 1, 3),
                    value.permute(0, 2, 1, 3),
                    True,
                )
                output = output.permute(0, 2, 1, 3)
            else:
                # Prefill with paged KV cache.
                # NOTE(review): key_cache/value_cache are only bound above
                # when kv_cache is non-empty; this branch presumably never
                # runs during profiling -- confirm.
                # TODO(woosuk): Tune the below knobs.
                num_kv_pages_per_compute_block = 16
                num_queries_per_compute_block = 16
                assert seq_len % num_queries_per_compute_block == 0
                output = torch.ops.xla.multi_queries_paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    attn_metadata.effective_query_lens,
                    num_kv_pages_per_compute_block,
                    num_queries_per_compute_block,
                    use_kernel=True,
                    attn_logits_soft_cap=self.logits_soft_cap,
                )
        else:
            # Decoding run.
            assert kv_cache[0].numel() > 0
            query = query.squeeze(dim=1)
            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.

            assert attn_metadata.block_tables is not None
            assert attn_metadata.context_lens is not None
            # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
            # block table in SMEM. Therefore, if the block table is too large,
            # the kernel compilation will fail. To avoid this, we split the
            # batch dimension into smaller chunks and run the kernel multiple
            # times.
            MAX_SMEM_USAGE = 512 * 1024
            size_per_seq = 4 * attn_metadata.block_tables.shape[1]
            max_num_seq = MAX_SMEM_USAGE // size_per_seq

            if batch_size <= max_num_seq:
                output = paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                    self.megacore_mode,
                    attn_logits_soft_cap=self.logits_soft_cap,
                )
            else:
                chunk_size = max_num_seq
                # Make sure the chunk size is a multiple of 2.
                chunk_size = chunk_size // 2 * 2
                num_chunks = (batch_size + chunk_size - 1) // chunk_size

                output = torch.empty_like(query)
                for chunk_idx in range(num_chunks):
                    chunk_start = chunk_idx * chunk_size
                    chunk_end = chunk_start + chunk_size
                    # NOTE(woosuk): We skip this line because it causes Dynamo
                    # compilation error. Instead, we rely on the slice operation
                    # to handle the out-of-bound case.
                    # chunk_end = min(chunk_end, batch_size)
                    chunk_output = paged_attention(
                        query[chunk_start:chunk_end],
                        key_cache,
                        value_cache,
                        attn_metadata.context_lens[chunk_start:chunk_end],
                        attn_metadata.block_tables[chunk_start:chunk_end],
                        pages_per_compute_block,
                        self.megacore_mode,
                        attn_logits_soft_cap=self.logits_soft_cap,
                    )
                    output[chunk_start:chunk_end] = chunk_output

        # Reshape the output tensor.
        return output.reshape(batch_size, seq_len, hidden_size)
291
+
292
+
293
def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Scatter new key/value tokens into the paged KV caches in place.

    Both caches are first marked as donated buffers so XLA may alias the
    update; then the leading three dimensions of each tensor are flattened
    into a single slot dimension and every token is copied to the slot
    given by `slot_mapping`.
    """
    for cache in (key_cache, value_cache):
        torch.ops.xla.dynamo_set_buffer_donor_(cache, True)

    flat_key = key.flatten(0, 2)
    flat_value = value.flatten(0, 2)
    key_cache.flatten(0, 2).index_copy_(0, slot_mapping, flat_key)
    value_cache.flatten(0, 2).index_copy_(0, slot_mapping, flat_value)
309
+
310
+
311
def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    pages_per_compute_block: int,
    megacore_mode: Optional[str],
    *,
    attn_logits_soft_cap: Optional[float],
) -> torch.Tensor:
    """Invoke the Pallas paged-attention kernel.

    The "batch" megacore mode requires an even batch size; for odd
    batches megacore parallelism is disabled for this call.
    """
    if megacore_mode == "batch" and query.shape[0] % 2 != 0:
        effective_megacore_mode = None
    else:
        effective_megacore_mode = megacore_mode

    return torch.ops.xla.paged_attention(
        query,
        key_cache,
        value_cache,
        context_lens,
        block_tables,
        pages_per_compute_block,
        megacore_mode=effective_megacore_mode,
        attn_logits_soft_cap=attn_logits_soft_cap,
    )
+ )
.venv/lib/python3.11/site-packages/vllm/attention/backends/placeholder_attn.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from collections import defaultdict
4
+ from dataclasses import dataclass
5
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
6
+
7
+ import torch
8
+
9
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
10
+ AttentionMetadata,
11
+ AttentionMetadataBuilder)
12
+ from vllm.attention.backends.utils import CommonAttentionState
13
+ from vllm.multimodal import MultiModalPlaceholderMap
14
+
15
+ if TYPE_CHECKING:
16
+ from vllm.worker.model_runner import (ModelInputForGPUBuilder,
17
+ ModelInputForGPUWithSamplingMetadata)
18
+
19
+ # Placeholder attention backend for models like Mamba and pooling models that
20
+ # lack attention.
21
+
22
+
23
class PlaceholderAttentionBackend(AttentionBackend):
    """Placeholder backend for when no attention is needed."""

    @staticmethod
    def get_name() -> str:
        return "NO_ATTENTION"

    @staticmethod
    def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
        return PlaceholderAttentionImpl

    @staticmethod
    def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
        return PlaceholderAttentionMetadataBuilder

    @staticmethod
    def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
        return PlaceholderAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        # No real KV cache is needed; return a minimal dummy shape.
        return (1, 1, 1, 1, 1)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        # Nothing to swap: there is no real KV cache.
        return

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        # Nothing to copy: there is no real KV cache.
        return
69
+
70
+
71
@dataclass
class PlaceholderAttentionMetadata(AttentionMetadata):
    """Attention metadata for prefill and decode batched together."""
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # Maximum query length in the batch.
    max_query_len: Optional[int]

    # Max number of query tokens among requests in the batch.
    max_decode_query_len: Optional[int]

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Lazily-built views over the prefill / decode portions of this batch.
    _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
    _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        """Metadata sliced down to the prefill requests (cached), or None."""
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.seq_start_loc is not None

        # Placeholders: no real slot mapping or block tables exist for
        # attention-free models.
        slot_mapping = torch.empty(0)
        block_tables = torch.empty(0)

        self._cached_prefill_metadata = PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_decode_query_len=0,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
            block_tables=block_tables,
            use_cuda_graph=False,
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        """Metadata sliced down to the decode requests (cached), or None."""
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.seq_lens_tensor is not None

        # Placeholders: no real slot mapping or block tables exist for
        # attention-free models.
        slot_mapping = torch.empty(0)
        block_tables = torch.empty(0)

        self._cached_decode_metadata = PlaceholderAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
        )
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.
        """
        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        # FIX: the original message concatenated "attention-free" and
        # "models." without a space ("attention-freemodels.").
        assert not turn_prefills_into_decodes, \
            ("Multi-Step + Chunked-Prefill is not supported for "
             "attention-free models. turn_prefills_into_decodes is a "
             "Multi-Step + Chunked-Prefill specific parameter.")

        assert self.seq_lens is not None
        assert self.max_decode_seq_len == max(self.seq_lens)

        # A decode step only makes sense for an all-decode batch.
        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        # Update sequences, masking off entries greater than num_queries
        device = self.seq_lens_tensor.device
        mask = torch.arange(self.seq_lens_tensor.size(0),
                            device=device) < num_queries
        self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype)
        if sampled_token_ids is not None:
            model_input.input_tokens.masked_scatter_(
                mask, sampled_token_ids[:num_queries])
254
+
255
+
256
class PlaceholderAttentionMetadataBuilder(
        AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
    """Builds PlaceholderAttentionMetadata for attention-free models.

    Accumulates per-sequence-group lengths via ``_add_seq_group`` and turns
    them into on-device tensors in ``build``. ``slot_mapping`` and
    ``block_tables`` are emitted as empty placeholder tensors since no KV
    cache is involved.
    """

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):

        self.input_builder = input_builder
        self.runner = input_builder.runner

    def prepare(self):
        """Reset all per-batch accumulators before a new build pass."""
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        """
        is_prompt = inter_data.is_prompt

        # NOTE: seq_id and curr_sliding_window_block are unpacked but unused;
        # the zip mirrors the layout of InterDataForSeqGroup's parallel lists.
        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)

            if is_prompt:
                # Collect multimodal placeholder positions for prefills only.
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)

                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                # A decode step processes exactly one new token per sequence.
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        # Soft-capped logits (e.g. Gemma-2) are not supported by this
        # backend; fail loudly instead of producing wrong output.
        logits_soft_cap = getattr(self.runner.model_config.hf_config,
                                  "attn_logit_softcapping", None)
        if logits_soft_cap is not None:
            raise ValueError(
                "Please use Flashinfer backend for models with logits_soft_cap"
                " (i.e., Gemma-2). Otherwise, the output might be wrong."
                " Set Flashinfer backend by "
                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")

        max_query_len = max(query_lens)
        decode_query_lens = query_lens[self.num_prefills:]
        if len(decode_query_lens) > 0:
            max_decode_query_len = max(decode_query_lens)
        else:
            max_decode_query_len = 1
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens

        if use_captured_graph:
            # Pad the decode-token count up to the captured graph batch size.
            num_decode_tokens = batch_size

        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        context_lens_tensor = torch.tensor(self.context_lens,
                                           dtype=torch.int,
                                           device=device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=device)
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=device)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }
        # Cumulative offsets land in positions 1..N; position 0 stays 0.
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        # Placeholders
        slot_mapping = torch.empty(0)
        block_tables = torch.empty(0)

        return PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=True,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
402
+
403
+
404
class PlaceholderAttentionImpl(AttentionImpl):
    """No-op attention implementation for attention-free models.

    Construction accepts and ignores any arguments; ``forward`` must never
    actually be invoked, so it raises ``NotImplementedError``.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Nothing to configure for a placeholder implementation.
        pass

    def forward(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError
.venv/lib/python3.11/site-packages/vllm/attention/backends/rocm_flash_attn.py ADDED
@@ -0,0 +1,891 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Attention layer ROCm GPUs."""
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
5
+
6
+ import torch
7
+
8
+ import vllm.envs as envs
9
+ from vllm import _custom_ops as ops
10
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
11
+ AttentionLayer,
12
+ AttentionMetadata, AttentionType)
13
+ from vllm.attention.backends.utils import (CommonAttentionState,
14
+ CommonMetadataBuilder)
15
+ from vllm.attention.ops.paged_attn import (PagedAttention,
16
+ PagedAttentionMetadata)
17
+ from vllm.logger import init_logger
18
+ from vllm.platforms import current_platform
19
+
20
+ if TYPE_CHECKING:
21
+ from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
22
+
23
+ logger = init_logger(__name__)
24
+
25
# Partition size used by the ROCm paged-attention decode kernel.
_PARTITION_SIZE_ROCM = 512
# gcnArchName of the current device, e.g. "gfx90a" or "gfx1100".
# NOTE: evaluated at import time; assumes a ROCm GPU is visible.
_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
# Navi (gfx1xxx, RDNA) consumer GPUs.
_ON_NAVI = "gfx1" in _GPU_ARCH
# CDNA datacenter GPUs (MI250 / MI300 series per the arch codes below).
_ON_MI250_MI300 = any(arch in _GPU_ARCH
                      for arch in ["gfx90a", "gfx940", "gfx941", "gfx942"])
30
+
31
+
32
class ROCmFlashAttentionBackend(AttentionBackend):
    """Registry entry for the ROCm FlashAttention backend.

    Ties together the backend name with its implementation, metadata,
    builder, and state classes, and forwards all KV-cache layout and
    block-copy operations to the PagedAttention helpers.
    """

    @staticmethod
    def get_name() -> str:
        return "ROCM_FLASH"

    @staticmethod
    def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
        return ROCmFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return ROCmFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
        return ROCmFlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int,
                           num_kv_heads: int,
                           head_size: int) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape; layout is owned by
        PagedAttention."""
        shape = PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                  num_kv_heads, head_size)
        return shape

    @staticmethod
    def swap_blocks(src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor,
                    src_to_dst: torch.Tensor) -> None:
        """Swap cache blocks between two KV caches (delegated)."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(kv_caches: List[torch.Tensor],
                    src_to_dists: torch.Tensor) -> None:
        """Copy cache blocks within the given KV caches (delegated)."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
78
+
79
+
80
@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for FlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int

    # Whether or not if cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor] = None

    # Max number of query tokens among request in the batch.
    max_decode_query_len: Optional[int] = None

    # Lazily-built views of this metadata restricted to the prefill/decode
    # halves of the batch; populated by the properties below.
    _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        """Metadata restricted to the prefill portion of the batch.

        Returns None when there are no prefills; otherwise slices the
        token-indexed and sequence-indexed tensors to the first
        num_prefill_tokens / num_prefills entries and caches the result.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.block_tables is not None

        self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=None if self.query_start_loc is None else
            self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=None if self.seq_start_loc is None else
            self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=None if self.context_lens_tensor is None else
            self.context_lens_tensor[:self.num_prefills],
            block_tables=self.block_tables[:self.num_prefills],
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        """Metadata restricted to the decode portion of the batch.

        Returns None when there are no decode tokens; otherwise slices the
        tensors past the prefill entries and caches the result.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        self._cached_decode_metadata = ROCmFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        # Batch may be composed of prefill|decodes, adjust query start indices
        # to refer to the start of decodes when the two are split apart.
        # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
        # NOTE(review): query_start_loc is constructed as None just above, so
        # this rebasing never fires in this backend; it looks carried over
        # from the CUDA flash-attn backend — confirm before relying on it.
        if self._cached_decode_metadata.query_start_loc is not None:
            qs = self._cached_decode_metadata.query_start_loc
            self._cached_decode_metadata.query_start_loc = qs - qs[0]
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.

        Only valid for pure-decode batches (no prefills); the asserts below
        pin down every invariant this relies on.
        """

        assert not turn_prefills_into_decodes, \
            ("Chunked prefill is not supported with rocm_flash_attn yet."
             "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
             "specific parameter.")

        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0
        assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        # Custom op receives the tokens/positions/seq_lens/slot_mapping
        # tensors and the block tables; presumably it updates them in place
        # for the next step — see vllm._custom_ops.advance_step_flashattn.
        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=model_input.input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=model_input.input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)
295
+
296
+
297
class ROCmFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
    # All building logic lives in CommonMetadataBuilder; this subclass only
    # pins the concrete metadata dataclass to instantiate.

    _metadata_cls = ROCmFlashAttentionMetadata
301
+
302
+
303
+ def _make_alibi_bias(alibi_slopes: torch.Tensor,
304
+ dtype: torch.dtype,
305
+ seq_lens: Optional[List[int]],
306
+ make_attn_mask: bool = True) -> List[torch.Tensor]:
307
+ attn_biases = []
308
+ if seq_lens:
309
+ for seq_len in seq_lens:
310
+ bias = torch.arange(seq_len, dtype=dtype)
311
+ # NOTE(zhuohan): HF uses
312
+ # `bias = bias[None, :].repeat(seq_len, 1)`
313
+ # here. We find that both biases give the same results, but
314
+ # the bias below more accurately follows the original ALiBi
315
+ # paper.
316
+ bias = bias[None, :] - bias[:, None]
317
+
318
+ num_heads = alibi_slopes.shape[0]
319
+ bias = bias[None, :].repeat(
320
+ (num_heads, 1, 1)).to(alibi_slopes.device)
321
+ bias.mul_(alibi_slopes[:, None, None])
322
+ if make_attn_mask:
323
+ inf_mask = torch.empty(
324
+ (1, seq_len, seq_len),
325
+ dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
326
+ alibi_slopes.device)
327
+ attn_biases.append((bias + inf_mask).to(dtype))
328
+ else:
329
+ attn_biases.append(bias.to(dtype))
330
+
331
+ return attn_biases
332
+
333
+
334
def _get_seq_len_block_table_args(
    attn_metadata: ROCmFlashAttentionMetadata,
    attn_type: str,
) -> tuple:
    '''
    The particular choice of sequence-length
    attributes which should be extracted from attn_metadata is dependent
    on the type of attention operation.

    Decoder attn -> select entirely decoder self-attention-related fields
    Encoder/decoder cross-attn -> select encoder sequence lengths
    Encoder attn -> select encoder sequence lengths fields

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Appropriate sequence-lengths tensors for query and key
    * Appropriate max sequence-length scalar
    '''

    def _cu_seq_start_loc(seq_lens: List[int],
                          like: torch.Tensor) -> torch.Tensor:
        # Cumulative start offsets [0, l0, l0+l1, ...] on the same
        # device/dtype as `like`; shared by all three branches below
        # instead of repeating the running-sum inline.
        start_loc = [0]
        for seq_len in seq_lens:
            start_loc.append(start_loc[-1] + seq_len)
        return torch.tensor(start_loc, device=like.device, dtype=like.dtype)

    if attn_type == AttentionType.ENCODER:
        assert attn_metadata.encoder_seq_lens is not None
        assert attn_metadata.encoder_seq_lens_tensor is not None
        query_seq_start_loc = _cu_seq_start_loc(
            attn_metadata.encoder_seq_lens,
            attn_metadata.encoder_seq_lens_tensor)
        causal_mask = False

        # No block tables associated with encoder attention
        return (query_seq_start_loc, attn_metadata.max_encoder_seq_len,
                query_seq_start_loc, attn_metadata.max_encoder_seq_len,
                attn_metadata.encoder_seq_lens, causal_mask)
    elif attn_type == AttentionType.DECODER:
        # Decoder self-attention
        # Choose max_seq_len based on whether we are in prompt_run
        assert attn_metadata.seq_lens is not None
        assert attn_metadata.seq_lens_tensor is not None
        query_seq_start_loc = _cu_seq_start_loc(
            attn_metadata.seq_lens, attn_metadata.seq_lens_tensor)
        max_seq_len = attn_metadata.max_prefill_seq_len
        causal_mask = True

        return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
                max_seq_len, attn_metadata.seq_lens, causal_mask)
    elif attn_type == AttentionType.ENCODER_DECODER:
        assert attn_metadata.seq_lens is not None
        assert attn_metadata.encoder_seq_lens_tensor is not None
        # NOTE(review): query offsets are built from decoder seq_lens but
        # take the encoder tensor's device/dtype (and vice versa for the
        # key offsets below); this preserves the original pairing exactly.
        query_start_loc = _cu_seq_start_loc(
            attn_metadata.seq_lens, attn_metadata.encoder_seq_lens_tensor)

        assert attn_metadata.encoder_seq_lens is not None
        assert attn_metadata.seq_lens_tensor is not None
        key_seq_start_loc = _cu_seq_start_loc(
            attn_metadata.encoder_seq_lens, attn_metadata.seq_lens_tensor)
        causal_mask = False

        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        return (query_start_loc, attn_metadata.max_prefill_seq_len,
                key_seq_start_loc, attn_metadata.max_encoder_seq_len,
                attn_metadata.seq_lens, causal_mask)
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")
423
+
424
+
425
+ class ROCmFlashAttentionImpl(AttentionImpl):
426
+ """
427
+ If the input tensors contain prompt tokens, the layout is as follows:
428
+ |<--------------- num_prompt_tokens -------------->|
429
+ |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
430
+
431
+ Otherwise, the layout is as follows:
432
+ |<------------------ num_generation_tokens (M) ----------------->|
433
+ |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
434
+
435
+ Generation tokens can contain padding when cuda-graph is used.
436
+ Currently, prompt tokens don't contain any padding.
437
+
438
+ The prompts might have different lengths, while the generation tokens
439
+ always have length 1.
440
+
441
+ If chunked prefill is enabled, prefill tokens and decode tokens can be
442
+ batched together in a flattened 1D query.
443
+
444
+ |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
445
+ |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
446
+
447
+ Currently, cuda graph is disabled for chunked prefill, meaning there's no
448
+ padding between prefill and decode tokens.
449
+ """
450
+
451
+ def __init__(
452
+ self,
453
+ num_heads: int,
454
+ head_size: int,
455
+ scale: float,
456
+ num_kv_heads: int,
457
+ alibi_slopes: Optional[List[float]],
458
+ sliding_window: Optional[int],
459
+ kv_cache_dtype: str,
460
+ blocksparse_params: Optional[Dict[str, Any]] = None,
461
+ logits_soft_cap: Optional[float] = None,
462
+ attn_type: str = AttentionType.DECODER,
463
+ ) -> None:
464
+ if blocksparse_params is not None:
465
+ raise ValueError(
466
+ "ROCmFlashAttention does not support blocksparse attention.")
467
+
468
+ if logits_soft_cap is None:
469
+ # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
470
+ self.logits_soft_cap = 0.0
471
+ else:
472
+ self.logits_soft_cap = logits_soft_cap
473
+ self.attn_type = attn_type
474
+ self.num_heads = num_heads
475
+ self.head_size = head_size
476
+ self.scale = float(scale)
477
+ self.num_kv_heads = num_kv_heads
478
+ if alibi_slopes is not None:
479
+ alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
480
+ self.alibi_slopes = alibi_slopes
481
+ self.sliding_window = ((sliding_window, sliding_window)
482
+ if sliding_window is not None else (-1, -1))
483
+ self.kv_cache_dtype = kv_cache_dtype
484
+
485
+ assert self.num_heads % self.num_kv_heads == 0
486
+ self.num_queries_per_kv = self.num_heads // self.num_kv_heads
487
+
488
+ supported_head_sizes = PagedAttention.get_supported_head_sizes()
489
+ if head_size not in supported_head_sizes:
490
+ raise ValueError(
491
+ f"Head size {head_size} is not supported by PagedAttention. "
492
+ f"Supported head sizes are: {supported_head_sizes}.")
493
+
494
+ self.use_naive_attn = False
495
+ # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
496
+ self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
497
+ if self.use_triton_flash_attn:
498
+ if logits_soft_cap is not None:
499
+ raise ValueError(
500
+ "ROCm Triton FlashAttention does not support attention"
501
+ "logits soft capping."
502
+ " please try using the ROCm CK "
503
+ "FA backend instead by setting the env var "
504
+ "`VLLM_USE_TRITON_FLASH_ATTN=0`")
505
+
506
+ from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
507
+ triton_attention)
508
+ self.attn_func = triton_attention
509
+ logger.debug("Using Triton FA in ROCmBackend")
510
+ if self.sliding_window != (-1, -1):
511
+ logger.warning("ROCm Triton FA does not currently support "
512
+ "sliding window attention. If using half "
513
+ "precision, please try using the ROCm CK "
514
+ "FA backend instead by setting the env var "
515
+ "`VLLM_USE_TRITON_FLASH_ATTN=0`")
516
+ else:
517
+ # if not using triton, navi3x/navi21/navi10 do not use flash-attn
518
+ # either
519
+ if not current_platform.has_device_capability(90):
520
+ self.use_naive_attn = True
521
+ else:
522
+ try:
523
+ from flash_attn import flash_attn_varlen_func # noqa: F401
524
+ self.attn_func = flash_attn_varlen_func
525
+ logger.debug("Using CK FA in ROCmBackend")
526
+ except ModuleNotFoundError:
527
+ self.use_naive_attn = True
528
+
529
+ if self.use_naive_attn:
530
+ if logits_soft_cap is not None:
531
+ raise ValueError(
532
+ "ROCm Naive FlashAttention does not support"
533
+ "attention logits soft capping.")
534
+
535
+ self.attn_func = _sdpa_attention
536
+ logger.debug("Using naive (SDPA) attention in ROCmBackend")
537
+
538
+ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
539
+ """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
540
+ tokens, n_kv_heads, head_dim = x.shape
541
+ return (x[:, :,
542
+ None, :].expand(tokens, n_kv_heads, n_rep,
543
+ head_dim).reshape(tokens, n_kv_heads * n_rep,
544
+ head_dim))
545
+
546
+ def forward(
547
+ self,
548
+ layer: AttentionLayer,
549
+ query: torch.Tensor,
550
+ key: torch.Tensor,
551
+ value: torch.Tensor,
552
+ kv_cache: torch.Tensor,
553
+ attn_metadata: ROCmFlashAttentionMetadata,
554
+ output: Optional[torch.Tensor] = None,
555
+ ) -> torch.Tensor:
556
+ """Forward pass with FlashAttention and PagedAttention.
557
+
558
+ For decoder-only models: query, key and value must be non-None.
559
+
560
+ For encoder/decoder models:
561
+ * ROCmFlashAttentionImpl.forward() may be invoked for both self- and
562
+ cross-attention layers.
563
+ * For self-attention: query, key and value must be non-None.
564
+ * For cross-attention:
565
+ * Query must be non-None
566
+ * During prefill, key and value must be non-None; key and value
567
+ get cached for use during decode.
568
+ * During decode, key and value may be None, since:
569
+ (1) key and value tensors were cached during prefill, and
570
+ (2) cross-attention key and value tensors do not grow during
571
+ decode
572
+
573
+ A note on how the attn_type (attention type enum) argument impacts
574
+ attention forward() behavior:
575
+
576
+ * DECODER: normal decoder-only behavior;
577
+ use decoder self-attention block table
578
+ * ENCODER: no KV caching; pass encoder sequence
579
+ attributes (encoder_seq_lens/encoder_seq_lens_tensor/
580
+ max_encoder_seq_len) to kernel, in lieu of decoder
581
+ sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
582
+ * ENCODER_DECODER: cross-attention behavior;
583
+ use cross-attention block table for caching KVs derived
584
+ from encoder hidden states; since KV sequence lengths
585
+ will match encoder sequence lengths, pass encoder sequence
586
+ attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
587
+ max_encoder_seq_len)
588
+
589
+ Args:
590
+ query: shape = [num_tokens, num_heads * head_size]
591
+ key: shape = [num_tokens, num_kv_heads * head_size]
592
+ value: shape = [num_tokens, num_kv_heads * head_size]
593
+ kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
594
+ NOTE: kv_cache will be an empty tensor with shape [0]
595
+ for profiling run.
596
+ attn_metadata: Metadata for attention.
597
+ attn_type: Select attention type, between encoder attention,
598
+ decoder self-attention, or encoder/decoder cross-
599
+ attention. Defaults to decoder self-attention,
600
+ which is the vLLM default generally
601
+ Returns:
602
+ shape = [num_tokens, num_heads * head_size]
603
+ """
604
+ query = query.view(-1, self.num_heads, self.head_size)
605
+ if key is not None:
606
+ assert value is not None
607
+ key = key.view(-1, self.num_kv_heads, self.head_size)
608
+ value = value.view(-1, self.num_kv_heads, self.head_size)
609
+ else:
610
+ assert value is None
611
+
612
+ if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
613
+ key_cache, value_cache = PagedAttention.split_kv_cache(
614
+ kv_cache, self.num_kv_heads, self.head_size)
615
+
616
+ if key is not None and value is not None:
617
+ # Reshape the input keys and values and store them in the
618
+ # cache. If kv_cache is not provided, the new key and value
619
+ # tensors are not cached. This happens during the initial
620
+ # memory profiling run.
621
+ PagedAttention.write_to_paged_cache(
622
+ key,
623
+ value,
624
+ key_cache,
625
+ value_cache,
626
+ attn_metadata.slot_mapping
627
+ if self.attn_type != AttentionType.ENCODER_DECODER else
628
+ attn_metadata.cross_slot_mapping,
629
+ self.kv_cache_dtype,
630
+ layer._k_scale,
631
+ layer._v_scale,
632
+ )
633
+
634
+ if self.attn_type != AttentionType.ENCODER:
635
+ num_prefill_tokens = attn_metadata.num_prefill_tokens
636
+ else:
637
+ assert attn_metadata.num_encoder_tokens is not None
638
+ num_prefill_tokens = attn_metadata.num_encoder_tokens
639
+
640
+ output = torch.empty_like(query)
641
+ # Query for decode. KV is not needed because it is already cached.
642
+ decode_query = query[num_prefill_tokens:]
643
+ # QKV for prefill.
644
+ query = query[:num_prefill_tokens]
645
+
646
+ if key is not None and value is not None \
647
+ and self.attn_type != AttentionType.ENCODER_DECODER:
648
+ key = key[:num_prefill_tokens]
649
+ value = value[:num_prefill_tokens]
650
+
651
+ if prefill_meta := attn_metadata.prefill_metadata:
652
+ # Prompt run.
653
+ # normal attention and DECODER
654
+ if self.attn_type == AttentionType.DECODER and (
655
+ kv_cache.numel() == 0 or prefill_meta.block_tables is None
656
+ or prefill_meta.block_tables.numel() == 0):
657
+ (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
658
+ key_max_seq_len, seq_lens,
659
+ causal_mask) = (prefill_meta.seq_start_loc,
660
+ prefill_meta.max_prefill_seq_len,
661
+ prefill_meta.seq_start_loc,
662
+ prefill_meta.max_prefill_seq_len,
663
+ attn_metadata.seq_lens, True)
664
+ # prefix-enabled attention and ENCODER/ENCODER_DECODER
665
+ else:
666
+ (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
667
+ key_max_seq_len, seq_lens,
668
+ causal_mask) = _get_seq_len_block_table_args(
669
+ prefill_meta, self.attn_type)
670
+ # Prompt run.
671
+ if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
672
+ # triton attention
673
+ # When block_tables are not filled, it means q and k are the
674
+ # prompt, and they have the same length.
675
+ attn_masks = None
676
+ if self.use_triton_flash_attn:
677
+ if self.alibi_slopes is not None:
678
+ attn_masks = _make_alibi_bias(
679
+ self.alibi_slopes,
680
+ query.dtype,
681
+ seq_lens,
682
+ make_attn_mask=False) # type: ignore
683
+ out, _ = self.attn_func(
684
+ query,
685
+ key,
686
+ value,
687
+ None,
688
+ query_seq_start_loc,
689
+ key_seq_start_loc,
690
+ query_max_seq_len,
691
+ key_max_seq_len,
692
+ causal_mask,
693
+ self.scale,
694
+ attn_masks[0][None]
695
+ if attn_masks is not None else None,
696
+ )
697
+ elif self.use_naive_attn:
698
+ if self.num_kv_heads != self.num_heads:
699
+ # Interleave for MQA workaround.
700
+ key = self.repeat_kv(key, self.num_queries_per_kv)
701
+ value = self.repeat_kv(value, self.num_queries_per_kv)
702
+ if self.alibi_slopes is not None:
703
+ attn_masks = _make_alibi_bias(
704
+ self.alibi_slopes,
705
+ query.dtype,
706
+ attn_metadata.seq_lens,
707
+ make_attn_mask=True) # type: ignore
708
+ query = query.movedim(0, query.dim() - 2)
709
+ key = key.movedim(0, key.dim() - 2)
710
+ value = value.movedim(0, value.dim() - 2)
711
+ # sdpa math backend attention
712
+ out = self.attn_func(
713
+ query,
714
+ key,
715
+ value,
716
+ query_seq_start_loc,
717
+ num_prefill_tokens,
718
+ self.num_heads,
719
+ self.head_size,
720
+ self.scale,
721
+ causal_mask,
722
+ attn_masks,
723
+ )
724
+ else:
725
+ out = self.attn_func(
726
+ q=query,
727
+ k=key,
728
+ v=value,
729
+ cu_seqlens_q=query_seq_start_loc,
730
+ cu_seqlens_k=key_seq_start_loc,
731
+ max_seqlen_q=prefill_meta.max_prefill_seq_len,
732
+ max_seqlen_k=key_max_seq_len,
733
+ softmax_scale=self.scale,
734
+ causal=True,
735
+ window_size=self.sliding_window,
736
+ alibi_slopes=self.alibi_slopes,
737
+ softcap=self.logits_soft_cap,
738
+ )
739
+
740
+ # common code for prefill
741
+ assert output[:num_prefill_tokens].shape == out.shape
742
+ if output.shape[0] > num_prefill_tokens:
743
+ output[:num_prefill_tokens] = out
744
+ else:
745
+ output = out
746
+ else:
747
+ # prefix-enabled attention
748
+ output[:num_prefill_tokens] = PagedAttention.forward_prefix(
749
+ query,
750
+ key,
751
+ value,
752
+ self.kv_cache_dtype,
753
+ key_cache,
754
+ value_cache,
755
+ prefill_meta.block_tables,
756
+ prefill_meta.query_start_loc,
757
+ prefill_meta.seq_lens_tensor,
758
+ prefill_meta.context_lens_tensor,
759
+ prefill_meta.max_query_len,
760
+ self.alibi_slopes,
761
+ self.sliding_window[0],
762
+ layer._k_scale,
763
+ layer._v_scale,
764
+ )
765
+
766
+ if decode_meta := attn_metadata.decode_metadata:
767
+ # Decoding run.
768
+ # Whether to use rocm custom paged attention or not
769
+ num_seqs, num_heads, head_size = decode_query.shape
770
+ block_size = value_cache.shape[3]
771
+ gqa_ratio = num_heads // self.num_kv_heads
772
+ use_custom = _use_rocm_custom_paged_attention(
773
+ decode_query.dtype, head_size, block_size, gqa_ratio,
774
+ decode_meta.max_decode_seq_len)
775
+ if use_custom:
776
+ max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
777
+ != AttentionType.ENCODER_DECODER else
778
+ decode_meta.max_encoder_seq_len)
779
+ assert max_seq_len is not None
780
+ max_num_partitions = (
781
+ (max_seq_len + _PARTITION_SIZE_ROCM - 1) //
782
+ _PARTITION_SIZE_ROCM)
783
+ assert _PARTITION_SIZE_ROCM % block_size == 0
784
+ tmp_output = torch.empty(
785
+ size=(num_seqs, num_heads, max_num_partitions, head_size),
786
+ dtype=output.dtype,
787
+ device=output.device,
788
+ )
789
+ exp_sums = torch.empty(
790
+ size=(num_seqs, num_heads, max_num_partitions),
791
+ dtype=torch.float32,
792
+ device=output.device,
793
+ )
794
+ max_logits = torch.empty_like(exp_sums)
795
+ if num_prefill_tokens > 0:
796
+ out = output[num_prefill_tokens:]
797
+ else:
798
+ out = output
799
+ ops.paged_attention_rocm(
800
+ out,
801
+ exp_sums,
802
+ max_logits,
803
+ tmp_output,
804
+ decode_query,
805
+ key_cache,
806
+ value_cache,
807
+ self.num_kv_heads,
808
+ self.scale,
809
+ decode_meta.block_tables
810
+ if self.attn_type != AttentionType.ENCODER_DECODER else
811
+ decode_meta.cross_block_tables,
812
+ decode_meta.seq_lens_tensor
813
+ if self.attn_type != AttentionType.ENCODER_DECODER else
814
+ decode_meta.encoder_seq_lens_tensor,
815
+ block_size,
816
+ max_seq_len,
817
+ self.alibi_slopes,
818
+ self.kv_cache_dtype,
819
+ layer._k_scale,
820
+ layer._v_scale,
821
+ )
822
+ else:
823
+ output[num_prefill_tokens:] = PagedAttention.forward_decode(
824
+ decode_query,
825
+ key_cache,
826
+ value_cache,
827
+ decode_meta.block_tables
828
+ if self.attn_type != AttentionType.ENCODER_DECODER else
829
+ decode_meta.cross_block_tables,
830
+ decode_meta.seq_lens_tensor
831
+ if self.attn_type != AttentionType.ENCODER_DECODER else
832
+ decode_meta.encoder_seq_lens_tensor,
833
+ decode_meta.max_decode_seq_len
834
+ if self.attn_type != AttentionType.ENCODER_DECODER else
835
+ decode_meta.max_encoder_seq_len,
836
+ self.kv_cache_dtype,
837
+ self.num_kv_heads,
838
+ self.scale,
839
+ self.alibi_slopes,
840
+ layer._k_scale,
841
+ layer._v_scale,
842
+ )
843
+
844
+ # Reshape the output tensor.
845
+ return output.view(-1, self.num_heads * self.head_size)
846
+
847
+
848
+ def _sdpa_attention(
849
+ query: torch.Tensor,
850
+ key: torch.Tensor,
851
+ value: torch.Tensor,
852
+ seq_lens: List[int],
853
+ num_tokens: int,
854
+ num_heads: int,
855
+ head_size: int,
856
+ scale: float,
857
+ attn_masks: Optional[List[torch.Tensor]] = None,
858
+ ) -> torch.Tensor:
859
+ start = 0
860
+ output = torch.empty((num_tokens, num_heads, head_size),
861
+ dtype=query.dtype,
862
+ device=query.device)
863
+
864
+ for i, seq_len in enumerate(seq_lens):
865
+ end = start + seq_len
866
+ with torch.backends.cuda.sdp_kernel(enable_math=True,
867
+ enable_flash=False,
868
+ enable_mem_efficient=False):
869
+ sub_out = torch.nn.functional.scaled_dot_product_attention(
870
+ query[:, start:end, :],
871
+ key[:, start:end, :],
872
+ value[:, start:end, :],
873
+ dropout_p=0.0,
874
+ is_causal=attn_masks is None,
875
+ attn_mask=attn_masks[i] if attn_masks else None,
876
+ scale=scale).movedim(query.dim() - 2, 0)
877
+ output[start:end, :, :] = sub_out
878
+ start = end
879
+
880
+ return output
881
+
882
+
883
def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                     block_size: int, gqa_ratio: int,
                                     max_seq_len: int) -> bool:
    """Whether ROCm's custom paged-attention kernel supports this config.

    The custom kernel is unavailable on Navi (gfx1*) hardware and only
    handles fp16/bf16 queries, head sizes 64/128, block sizes 16/32,
    GQA ratios 1..16 and sequences up to 32768 tokens.
    """
    if not _ON_MI250_MI300 or _ON_NAVI:
        return False
    return ((qtype == torch.half or qtype == torch.bfloat16)
            and (head_size == 64 or head_size == 128)
            and (block_size == 16 or block_size == 32)
            and 1 <= gqa_ratio <= 16
            and max_seq_len <= 32768)
.venv/lib/python3.11/site-packages/vllm/attention/backends/torch_sdpa.py ADDED
@@ -0,0 +1,681 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """ Attention layer with torch scaled_dot_product_attention
3
+ and PagedAttention."""
4
+ from dataclasses import dataclass
5
+ from typing import Any, Dict, List, Optional, Tuple, Type
6
+
7
+ import torch
8
+ from torch.nn.functional import scaled_dot_product_attention
9
+
10
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
11
+ AttentionLayer,
12
+ AttentionMetadata,
13
+ AttentionMetadataBuilder,
14
+ AttentionType)
15
+ from vllm.attention.backends.utils import CommonAttentionState
16
+ from vllm.attention.ops.ipex_attn import PagedAttention
17
+ from vllm.attention.ops.paged_attn import PagedAttentionMetadata
18
+ from vllm.logger import init_logger
19
+ from vllm.utils import make_tensor_with_pad
20
+ from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
21
+
22
+ logger = init_logger(__name__)
23
+
24
+
25
class TorchSDPABackend(AttentionBackend):
    """CPU attention backend: torch scaled_dot_product_attention for
    prefill combined with (IPEX) PagedAttention ops for KV-cache
    management and decode. All methods are pure static delegation."""

    @staticmethod
    def get_name() -> str:
        # Registry key used by the backend selector.
        return "TORCH_SDPA"

    @staticmethod
    def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
        return TorchSDPABackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return TorchSDPAMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
        return TorchSDPAMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        # KV-cache layout is owned by the PagedAttention ops module.
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        # Copy whole cache blocks between devices per the src->dst mapping.
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        # Duplicate cache blocks in place (e.g. for beam-search forks).
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
71
+
72
+
73
@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for TorchSDPABackend.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    chunked_prefill: bool
    seq_lens: Optional[List[int]] = None  # For non-chunked prefill

    # For chunked prefill only
    max_query_len: Optional[int] = None
    max_kv_len: Optional[int] = None
    query_start_loc: Optional[torch.Tensor] = None
    kv_start_loc: Optional[torch.Tensor] = None
    prefill_block_tables: Optional[torch.Tensor] = None

    # Begin encoder attn & enc/dec cross-attn fields...
    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because it is needed to set per prompt
        # when alibi slopes is used. It is because of the limitation
        # from xformer API.
        # will not appear in the __repr__ and __init__
        self.attn_bias: Optional[List[torch.Tensor]] = None
        self.encoder_attn_bias: Optional[List[torch.Tensor]] = None
        self.cross_attn_bias: Optional[List[torch.Tensor]] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        '''
        All attention metadata required for encoder attention is set.
        '''
        return ((self.encoder_seq_lens is not None)
                and (self.encoder_seq_lens_tensor is not None)
                and (self.max_encoder_seq_len is not None))

    @property
    def is_all_cross_attn_metadata_set(self):
        '''
        All attention metadata required for enc/dec cross-attention is set.

        Superset of encoder attention required metadata.
        '''
        return (self.is_all_encoder_attn_metadata_set
                and (self.cross_slot_mapping is not None)
                and (self.cross_block_tables is not None))

    @property
    def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
        '''
        Return self when any prefill tokens are present, else None.

        NOTE: unlike backends that slice out a prefill-only view, this
        returns the whole metadata object unchanged.
        '''
        if self.num_prefill_tokens == 0:
            return None
        return self

    @property
    def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
        '''
        Return self when any decode tokens are present, else None.
        '''
        if self.num_decode_tokens == 0:
            return None
        return self

    def get_seq_lens(
        self,
        attn_type: str,
    ):
        '''
        Extract appropriate sequence lengths from attention metadata
        according to attention type.

        Arguments:

        * attn_metadata: Attention metadata structure associated with attention
        * attn_type: encoder attention, decoder self-attention,
                     encoder/decoder cross-attention

        Returns:
        * Appropriate sequence lengths tensor for query
        * Appropriate sequence lengths tensor for key & value
        '''

        if (attn_type == AttentionType.DECODER
                or attn_type == AttentionType.ENCODER_ONLY):
            seq_lens_q = self.seq_lens
            seq_lens_kv = self.seq_lens
        elif attn_type == AttentionType.ENCODER:
            seq_lens_q = self.encoder_seq_lens
            seq_lens_kv = self.encoder_seq_lens
        elif attn_type == AttentionType.ENCODER_DECODER:
            # Cross-attention: queries come from the decoder, keys/values
            # from the encoder, hence the mixed lengths.
            seq_lens_q = self.seq_lens
            seq_lens_kv = self.encoder_seq_lens
        else:
            raise AttributeError(f"Invalid attention type {str(attn_type)}")
        return seq_lens_q, seq_lens_kv

    def get_attn_bias(
        self,
        attn_type: str,
    ) -> Optional[List[torch.Tensor]]:
        '''
        Extract appropriate attention bias from attention metadata
        according to attention type.

        Arguments:

        * attn_metadata: Attention metadata structure associated with attention
        * attn_type: encoder attention, decoder self-attention,
                     encoder/decoder cross-attention

        Returns:
        * Appropriate attention bias value given the attention type
        '''

        if (attn_type == AttentionType.DECODER
                or attn_type == AttentionType.ENCODER_ONLY):
            return self.attn_bias
        elif attn_type == AttentionType.ENCODER:
            return self.encoder_attn_bias
        elif attn_type == AttentionType.ENCODER_DECODER:
            return self.cross_attn_bias
        else:
            raise AttributeError(f"Invalid attention type {str(attn_type)}")

    def set_attn_bias(
        self,
        attn_bias: List[torch.Tensor],
        attn_type: str,
    ) -> None:
        '''
        Update appropriate attention bias field of attention metadata,
        according to attention type.

        Arguments:

        * attn_metadata: Attention metadata structure associated with attention
        * attn_bias: The desired attention bias value
        * attn_type: encoder attention, decoder self-attention,
                     encoder/decoder cross-attention
        '''

        if (attn_type == AttentionType.DECODER
                or attn_type == AttentionType.ENCODER_ONLY):
            self.attn_bias = attn_bias
        elif attn_type == AttentionType.ENCODER:
            self.encoder_attn_bias = attn_bias
        elif attn_type == AttentionType.ENCODER_DECODER:
            self.cross_attn_bias = attn_bias
        else:
            raise AttributeError(f"Invalid attention type {str(attn_type)}")

    def get_seq_len_block_table_args(
        self,
        attn_type: str,
    ) -> tuple:
        '''
        The particular choice of sequence-length- and block-table-related
        attributes which should be extracted from attn_metadata is dependent
        on the type of attention operation.

        Decoder attn -> select entirely decoder self-attention-related fields
        Encoder/decoder cross-attn -> select encoder sequence lengths &
                                      cross-attn block-tables fields
        Encoder attn -> select encoder sequence lengths fields & no block tables

        Arguments:

        * attn_metadata: Attention metadata structure associated with attention
        * is_prompt: True if prefill, False otherwise
        * attn_type: encoder attention, decoder self-attention,
                     encoder/decoder cross-attention

        Returns:

        * Appropriate sequence-lengths tensor
        * Appropriate max sequence-length scalar
        * Appropriate block tables (or None)
        '''

        if (attn_type == AttentionType.DECODER
                or attn_type == AttentionType.ENCODER_ONLY):
            # Decoder self-attention
            # Choose max_seq_len based on whether we are in prompt_run
            return (self.seq_lens_tensor, self.max_decode_seq_len,
                    self.block_tables)
        elif attn_type == AttentionType.ENCODER_DECODER:
            # Enc/dec cross-attention KVs match encoder sequence length;
            # cross-attention utilizes special "cross" block tables
            return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
                    self.cross_block_tables)
        elif attn_type == AttentionType.ENCODER:
            # No block tables associated with encoder attention
            return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
                    None)
        else:
            raise AttributeError(f"Invalid attention type {str(attn_type)}")
280
+
281
+
282
class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
    """Builds TorchSDPAMetadata from the CPU model-input builder's
    per-batch data (prefill/decode split, slot mapping, block tables)."""

    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
        self.chunked_prefill = input_builder.chunked_prefill
        self.input_builder = input_builder

    def prepare(self):
        # Re-snapshot the builder's input data at the start of each batch.
        self.input_data = self.input_builder.input_data

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
        """Assemble attention metadata for one scheduled batch.

        Args:
            seq_lens: Full sequence lengths; prefills come first, so the
                first ``num_prefills`` entries belong to prefill requests.
            query_lens: Per-request query lengths (differs from seq_lens
                under chunked prefill / prefix caching).
            cuda_graph_pad_size: Unused on CPU.
            batch_size: Unused on CPU.
        """
        input_data = self.input_data
        prefill_seq_lens = seq_lens[0:input_data.num_prefills]
        prefill_query_lens = query_lens[0:input_data.num_prefills]
        slot_mapping = torch.tensor(input_data.slot_mapping,
                                    dtype=torch.long,
                                    device="cpu")

        # For chunked-prefill
        if self.chunked_prefill and input_data.num_prefill_tokens != 0:
            prefill_block_tables = make_tensor_with_pad(
                self.input_data.prefill_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )
            query_lens_tensor = torch.tensor(prefill_query_lens,
                                             dtype=torch.int32,
                                             device="cpu")
            kv_lens_tensor = torch.tensor(prefill_seq_lens,
                                          dtype=torch.int32,
                                          device="cpu")
            # Exclusive-prefix-sum start offsets; leading zero kept so the
            # cumsum output lands in [1:].
            query_start_loc = torch.zeros(input_data.num_prefills + 1,
                                          dtype=torch.int32,
                                          device="cpu")
            kv_start_loc = torch.zeros(input_data.num_prefills + 1,
                                       dtype=torch.int32,
                                       device="cpu")
            torch.cumsum(query_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=query_start_loc[1:])
            torch.cumsum(kv_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=kv_start_loc[1:])
            max_query_len = max(prefill_query_lens)
            max_kv_len = max(prefill_seq_lens)
        else:
            # Non-chunked prefill path needs none of the varlen fields.
            prefill_block_tables = None
            query_start_loc = None
            kv_start_loc = None
            max_query_len = None
            max_kv_len = None

        # For paged attention
        if input_data.num_decode_tokens != 0:
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[input_data.num_prefills:],
                dtype=torch.int32,
                device="cpu",
            )
            block_tables = make_tensor_with_pad(
                self.input_data.decode_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )
        else:
            block_tables = torch.tensor([])
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[:input_data.num_prefills],
                dtype=torch.int32,
                device="cpu",
            )

        # For multi-modal models
        placeholder_index_maps = None
        if len(input_data.multi_modal_inputs_list) != 0:
            placeholder_index_maps = {
                modality: placeholder_map.index_map()
                for modality, placeholder_map in
                input_data.multi_modal_placeholder_maps.items()
            }

        attn_metadata = TorchSDPAMetadata(
            chunked_prefill=self.chunked_prefill,
            seq_lens=prefill_seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_kv_len=max_kv_len,
            query_start_loc=query_start_loc,
            kv_start_loc=kv_start_loc,
            max_decode_seq_len=input_data.max_decode_seq_len,
            num_prefills=input_data.num_prefills,
            num_prefill_tokens=input_data.num_prefill_tokens,
            num_decode_tokens=input_data.num_decode_tokens,
            block_tables=block_tables,
            prefill_block_tables=prefill_block_tables,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
        )

        return attn_metadata
387
+
388
+
389
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
    """Attention implementation using torch SDPA for (non-chunked)
    prefill and PagedAttention kernels for chunked prefill and decode."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "Torch SPDA does not support block-sparse attention.")
        if logits_soft_cap is not None:
            # Soft cap is ignored (warn-only) rather than rejected outright.
            logger.warning_once("Torch SPDA does not support logits soft cap. "
                                "Outputs may be slightly off.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        # A custom attention bias is required whenever ALiBi or a sliding
        # window is active (SDPA's built-in causal mask can't express them).
        self.need_mask = (self.alibi_slopes is not None
                          or self.sliding_window is not None)

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError(
                "Torch SDPA backend does not support FP8 KV cache. "
                "Please use xFormers backend instead.")
        self.attn_type = attn_type

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TorchSDPAMetadata,  # type: ignore
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with torch SDPA and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        attn_type = self.attn_type
        if (attn_type == AttentionType.ENCODER
                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")
        elif (attn_type == AttentionType.ENCODER_DECODER
              and (not attn_metadata.is_all_cross_attn_metadata_set)):
            raise AttributeError("Encoder/decoder cross-attention "
                                 "requires setting cross-attention "
                                 "metadata attributes.")

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
            # KV-cache during decoder-self- or
            # encoder-decoder-cross-attention, but not
            # during encoder attention.
            #
            # Even if there are no new key/value pairs to cache,
            # we still need to break out key_cache and value_cache
            # i.e. for later use by paged attention
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if (key is not None) and (value is not None):
                if attn_type == AttentionType.ENCODER_DECODER:
                    # Update cross-attention KV cache (prefill-only)
                    # During cross-attention decode, key & value will be None,
                    # preventing this IF-statement branch from running
                    updated_slot_mapping = attn_metadata.cross_slot_mapping
                else:
                    # Update self-attention KV cache (prefill/decode)
                    updated_slot_mapping = attn_metadata.slot_mapping

                PagedAttention.write_to_paged_cache(
                    key, value, key_cache, value_cache, updated_slot_mapping,
                    self.kv_cache_dtype, layer._k_scale, layer._v_scale)
        # NOTE(review): key_cache/value_cache remain unbound when kv_cache is
        # empty (profiling run); later branches that read them appear to be
        # unreachable in that case — confirm against the scheduler's
        # invariants.

        if attn_type != AttentionType.ENCODER:
            # Decoder self-attention supports chunked prefill.
            # Encoder/decoder cross-attention requires no chunked
            # prefill (100% prefill or 100% decode tokens, no mix)
            num_prefill_tokens = attn_metadata.num_prefill_tokens
            num_decode_tokens = attn_metadata.num_decode_tokens
        else:
            # Encoder attention - chunked prefill is not applicable;
            # derive token-count from query shape & and treat them
            # as 100% prefill tokens
            assert attn_metadata.num_encoder_tokens is not None
            num_prefill_tokens = attn_metadata.num_encoder_tokens
            num_decode_tokens = 0

        if attn_type == AttentionType.DECODER:
            # Only enforce this shape-constraint for decoder
            # self-attention
            assert key.shape[0] == num_prefill_tokens + num_decode_tokens
            assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        if prefill_meta := attn_metadata.prefill_metadata:
            assert attn_metadata.seq_lens is not None
            # NOTE(review): `prefill_meta.prefill_metadata` is redundant —
            # the property returns self — so this reads
            # `prefill_meta.chunked_prefill`.
            if not prefill_meta.prefill_metadata.chunked_prefill:  # type: ignore
                self._run_sdpa_forward(output,
                                       query,
                                       key,
                                       value,
                                       prefill_meta,
                                       attn_type=attn_type)
            else:
                # prefix-enabled attention
                assert not self.need_mask
                import intel_extension_for_pytorch.llm.modules as ipex_modules
                # NOTE(review): this re-allocation shadows the empty_like
                # above; the earlier buffer is discarded on this path.
                output = torch.empty_like(query)
                ipex_modules.PagedAttention.flash_attn_varlen_func(
                    output[:prefill_meta.num_prefill_tokens, :, :],
                    query[:prefill_meta.num_prefill_tokens, :, :],
                    key_cache,
                    value_cache,
                    prefill_meta.query_start_loc,
                    prefill_meta.kv_start_loc,
                    prefill_meta.max_query_len,
                    prefill_meta.max_kv_len,
                    self.scale,
                    True,
                    prefill_meta.prefill_block_tables,
                    self.alibi_slopes,
                )

        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
                "Encoder-only models should not have decode metadata.")
            # Decoding run.
            (
                seq_lens_arg,
                max_seq_len_arg,
                block_tables_arg,
            ) = decode_meta.get_seq_len_block_table_args(attn_type)

            # Decode tokens sit after the prefill tokens in the flattened
            # batch, hence the [num_prefill_tokens:] slices.
            PagedAttention.forward_decode(
                output[attn_metadata.num_prefill_tokens:, :, :],
                query[attn_metadata.num_prefill_tokens:, :, :],
                key_cache,
                value_cache,
                block_tables_arg,
                seq_lens_arg,
                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                layer._k_scale,
                layer._v_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_sdpa_forward(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: TorchSDPAMetadata,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Run per-sequence torch SDPA over packed prefill tokens,
        writing results into ``output`` in place."""
        if self.num_kv_heads != self.num_heads:
            # GQA/MQA: expand KV heads to match the query head count.
            key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
            value = value.repeat_interleave(self.num_queries_per_kv, dim=1)

        attn_masks = attn_metadata.get_attn_bias(attn_type)
        if attn_masks is None:
            if self.alibi_slopes is not None:
                attn_masks = _make_alibi_bias(
                    self.alibi_slopes, query.dtype,
                    attn_metadata.seq_lens)  # type: ignore
            elif self.sliding_window is not None:
                assert attn_metadata.seq_lens is not None
                attn_masks = _make_sliding_window_bias(
                    attn_metadata.seq_lens, self.sliding_window,
                    query.dtype)  # type: ignore
            else:
                seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
                attn_masks = [None] * len(seq_lens)
            # Cache the biases on the metadata so subsequent layers reuse
            # them instead of rebuilding per layer.
            attn_metadata.set_attn_bias(attn_masks, attn_type)

        # SDPA expects the token dim second-to-last: [heads, tokens, dim].
        query = query.movedim(0, query.dim() - 2)
        key = key.movedim(0, key.dim() - 2)
        value = value.movedim(0, value.dim() - 2)

        causal_attn = (attn_type == AttentionType.DECODER)

        seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
        start_q, start_kv = 0, 0
        for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
                                               attn_masks):
            end_q = start_q + seq_len_q
            end_kv = start_kv + seq_len_kv
            sub_out = scaled_dot_product_attention(
                query[None, :, start_q:end_q, :],
                key[None, :, start_kv:end_kv, :],
                value[None, :, start_kv:end_kv, :],
                attn_mask=mask,
                dropout_p=0.0,
                is_causal=causal_attn and mask is None,
                scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
            output[start_q:end_q, :, :] = sub_out
            start_q, start_kv = end_q, end_kv
634
+
635
+
636
+ def _make_alibi_bias(
637
+ alibi_slopes: torch.Tensor,
638
+ dtype: torch.dtype,
639
+ seq_lens: List[int],
640
+ ) -> List[torch.Tensor]:
641
+ attn_biases: List[torch.Tensor] = []
642
+ for seq_len in seq_lens:
643
+ bias = torch.arange(seq_len, dtype=dtype)
644
+ # NOTE(zhuohan): HF uses
645
+ # `bias = bias[None, :].repeat(seq_len, 1)`
646
+ # here. We find that both biases give the same results, but
647
+ # the bias below more accurately follows the original ALiBi
648
+ # paper.
649
+ bias = bias[None, :] - bias[:, None]
650
+
651
+ num_heads = alibi_slopes.shape[0]
652
+ bias = bias[None, :].repeat((num_heads, 1, 1))
653
+ bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
654
+ inf_mask = torch.empty(
655
+ (1, seq_len, seq_len),
656
+ dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
657
+ attn_biases.append((bias + inf_mask).to(dtype))
658
+
659
+ return attn_biases
660
+
661
+
662
+ def _make_sliding_window_bias(
663
+ seq_lens: List[int],
664
+ window_size: Optional[int],
665
+ dtype: torch.dtype,
666
+ ) -> List[torch.Tensor]:
667
+ attn_biases: List[torch.Tensor] = []
668
+ for seq_len in seq_lens:
669
+ tensor = torch.full(
670
+ (1, seq_len, seq_len),
671
+ dtype=dtype,
672
+ fill_value=1,
673
+ )
674
+ shift = 0
675
+ mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
676
+ if window_size is not None:
677
+ mask = torch.triu(mask, diagonal=shift - window_size + 1)
678
+ mask = torch.log(mask)
679
+ attn_biases.append(mask.to(dtype))
680
+
681
+ return attn_biases
.venv/lib/python3.11/site-packages/vllm/attention/backends/triton_mla.py ADDED
@@ -0,0 +1,746 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from collections import defaultdict
4
+ from contextlib import contextmanager
5
+ from dataclasses import dataclass
6
+ from itertools import accumulate
7
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
8
+
9
+ from vllm.multimodal import MultiModalPlaceholderMap
10
+
11
+ try:
12
+ from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
13
+ FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
14
+ except ImportError:
15
+ BatchDecodeMlaWithPagedKVCacheWrapper = None
16
+ FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
17
+
18
+ import torch
19
+
20
+ from vllm import _custom_ops as ops
21
+ from vllm.attention.backends.abstract import (AttentionBackend,
22
+ AttentionMetadata,
23
+ AttentionMetadataBuilder,
24
+ AttentionState, AttentionType)
25
+ from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata
26
+ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
27
+ compute_slot_mapping_start_idx,
28
+ is_block_tables_empty)
29
+ from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
30
+ from vllm.utils import async_tensor_h2d, make_tensor_with_pad
31
+
32
+ if TYPE_CHECKING:
33
+ from vllm.worker.model_runner import (ModelInputForGPUBuilder,
34
+ ModelInputForGPUWithSamplingMetadata)
35
+
36
+
37
+ class TritonMLABackend(AttentionBackend):
38
+
39
+ @staticmethod
40
+ def get_name() -> str:
41
+ return "TRITON_MLA"
42
+
43
+ @staticmethod
44
+ def get_impl_cls() -> Type["TritonMLAImpl"]:
45
+ return TritonMLAImpl
46
+
47
+ @staticmethod
48
+ def get_metadata_cls() -> Type["AttentionMetadata"]:
49
+ return TritonMLAMetadata
50
+
51
+ @staticmethod
52
+ def get_builder_cls() -> Type["TritonMLAMetadataBuilder"]:
53
+ return TritonMLAMetadataBuilder
54
+
55
+ @staticmethod
56
+ def get_state_cls() -> Type["TritonMLAState"]:
57
+ return TritonMLAState
58
+
59
+ @staticmethod
60
+ def get_kv_cache_shape(
61
+ num_blocks: int,
62
+ block_size: int,
63
+ num_kv_heads: int, # assumed to be 1 for MLA
64
+ head_size: int,
65
+ ) -> Tuple[int, ...]:
66
+ return (num_blocks, block_size, head_size)
67
+
68
+ @staticmethod
69
+ def swap_blocks(
70
+ src_kv_cache: torch.Tensor,
71
+ dst_kv_cache: torch.Tensor,
72
+ src_to_dst: torch.Tensor,
73
+ ) -> None:
74
+ ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
75
+
76
+ @staticmethod
77
+ def copy_blocks(
78
+ kv_caches: List[torch.Tensor],
79
+ src_to_dists: torch.Tensor,
80
+ ) -> None:
81
+ ops.copy_blocks_mla(kv_caches, src_to_dists)
82
+
83
+ @staticmethod
84
+ def get_supported_head_sizes() -> List[int]:
85
+ return [576]
86
+
87
+
88
+ class TritonMLAState(AttentionState):
89
+
90
+ def __init__(self, runner):
91
+ self.runner = runner
92
+ self._is_graph_capturing = False
93
+
94
+ @contextmanager
95
+ def graph_capture(self, max_batch_size: int):
96
+ self._is_graph_capturing = True
97
+
98
+ self._graph_slot_mapping = torch.full((max_batch_size, ),
99
+ PAD_SLOT_ID,
100
+ dtype=torch.long,
101
+ device=self.runner.device)
102
+ self._graph_seq_lens = torch.ones(max_batch_size,
103
+ dtype=torch.int32,
104
+ device=self.runner.device)
105
+ self._graph_block_tables = torch.from_numpy(
106
+ self.runner.graph_block_tables).to(device=self.runner.device)
107
+
108
+ self._positions = torch.zeros((max_batch_size, ),
109
+ dtype=torch.long,
110
+ device=self.runner.device)
111
+
112
+ yield
113
+
114
+ self._is_graph_capturing = False
115
+ del self._graph_slot_mapping
116
+ del self._graph_seq_lens
117
+ del self._graph_block_tables
118
+ del self._positions
119
+
120
+ def graph_clone(self, batch_size: int):
121
+ assert self._is_graph_capturing
122
+ return self.__class__(self.runner)
123
+
124
+ def graph_capture_get_metadata_for_batch(
125
+ self, batch_size: int, is_encoder_decoder_model: bool = False):
126
+ assert self._is_graph_capturing
127
+
128
+ attn_metadata = self.runner.attn_backend.make_metadata(
129
+ num_prefills=0,
130
+ num_prefill_tokens=0,
131
+ num_decode_tokens=batch_size,
132
+ slot_mapping=self._graph_slot_mapping[:batch_size],
133
+ multi_modal_placeholder_index_maps=None,
134
+ enable_kv_scales_calculation=True,
135
+ seq_lens=None,
136
+ seq_lens_tensor=self._graph_seq_lens[:batch_size],
137
+ max_query_len=1,
138
+ max_decode_query_len=1,
139
+ max_prefill_seq_len=0,
140
+ max_decode_seq_len=self.runner.max_seq_len_to_capture,
141
+ query_start_loc=None,
142
+ seq_start_loc=None,
143
+ context_lens_tensor=None,
144
+ block_tables=self._graph_block_tables[:batch_size],
145
+ use_cuda_graph=True,
146
+ input_positions=self._positions[:batch_size],
147
+ head_dim=self.runner.model_config.get_head_size())
148
+
149
+ if is_encoder_decoder_model:
150
+ raise NotImplementedError(
151
+ "TritonMLAState does not support encoder/decoder yet")
152
+
153
+ return attn_metadata
154
+
155
+ def get_graph_input_buffers(self,
156
+ attn_metadata,
157
+ is_encoder_decoder_model: bool = False):
158
+ input_buffers = {
159
+ "slot_mapping": attn_metadata.slot_mapping,
160
+ "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
161
+ "block_tables": attn_metadata.decode_metadata.block_tables,
162
+ "input_positions": attn_metadata.decode_metadata.input_positions,
163
+ }
164
+ if is_encoder_decoder_model:
165
+ raise NotImplementedError(
166
+ "TritonMLAState does not support encoder/decoder yet")
167
+
168
+ return input_buffers
169
+
170
+ def prepare_graph_input_buffers(self,
171
+ input_buffers,
172
+ attn_metadata,
173
+ is_encoder_decoder_model: bool = False):
174
+ input_positions = attn_metadata.input_positions
175
+ num_positions = input_positions.shape[0]
176
+ input_buffers["seq_lens_tensor"].copy_(
177
+ attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
178
+ input_buffers["block_tables"].copy_(
179
+ attn_metadata.decode_metadata.block_tables, non_blocking=True)
180
+ # CUDA graph buffer is padded so only perform a partial copy based on
181
+ # num_positions
182
+ input_buffers["input_positions"][:num_positions].copy_(
183
+ input_positions, non_blocking=True)
184
+ if is_encoder_decoder_model:
185
+ raise NotImplementedError(
186
+ "TritonMLAState does not support encoder/decoder yet")
187
+
188
+ def begin_forward(self, model_input):
189
+ return
190
+
191
+
192
+ @dataclass
193
+ class TritonMLAMetadata(MLACommonMetadata):
194
+ """Metadata for TritonMLAMetadata.
195
+
196
+ NOTE: Any python object stored here is not updated when it is
197
+ cuda-graph replayed. If you have values that need to be changed
198
+ dynamically, it should be stored in tensor. The tensor has to be
199
+ updated from `CUDAGraphRunner.forward` API.
200
+ """
201
+ # (batch_size,). The sequence length per sequence. Sequence length means
202
+ # the computed tokens + new tokens None if it is a decoding.
203
+ seq_lens: Optional[List[int]]
204
+ # seq_lens stored as a tensor.
205
+ seq_lens_tensor: Optional[torch.Tensor]
206
+
207
+ # NOTE(sang): Definition of context_len, query_len, and seq_len.
208
+ # |---------- N-1 iteration --------|
209
+ # |---------------- N iteration ---------------------|
210
+ # |- tokenA -|......................|-- newTokens ---|
211
+ # |---------- context_len ----------|
212
+ # |-------------------- seq_len ---------------------|
213
+ # |-- query_len ---|
214
+
215
+ # Maximum sequence length among prefill batch. 0 if there are decoding
216
+ # requests only.
217
+ max_prefill_seq_len: int
218
+ # Maximum sequence length among decode batch. 0 if there are prefill
219
+ # requests only.
220
+ max_decode_seq_len: int
221
+ # (batch_size,) A tensor of context lengths (tokens that are computed
222
+ # so far).
223
+ context_lens_tensor: Optional[torch.Tensor]
224
+
225
+ # (batch_size, max_blocks_per_seq).
226
+ # Block addresses per sequence. (Seq id -> list of physical block)
227
+ # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
228
+ # in the kv cache. Each block can contain up to block_size tokens.
229
+ # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
230
+ # captured.
231
+ block_tables: Optional[torch.Tensor]
232
+
233
+ # Whether or not if cuda graph is enabled.
234
+ # Cuda-graph is currently enabled for decoding only.
235
+ # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
236
+
237
+ use_cuda_graph: bool
238
+
239
+ # Maximum query length in the batch.
240
+ max_query_len: Optional[int] = None
241
+
242
+ # Max number of query tokens among request in the batch.
243
+ max_decode_query_len: Optional[int] = None
244
+
245
+ # (batch_size + 1,). The cumulative subquery lengths of the sequences in
246
+ # the batch, used to index into subquery. E.g., if the subquery length
247
+ # is [4, 6], it is [0, 4, 10].
248
+ query_start_loc: Optional[torch.Tensor] = None
249
+ # (batch_size + 1,). The cumulative sequence lengths of the sequences in
250
+ # the batch, used to index into sequence. E.g., if the sequence length is
251
+ # [4, 6], it is [0, 4, 10].
252
+ seq_start_loc: Optional[torch.Tensor] = None
253
+
254
+ _cached_prefill_metadata: Optional["TritonMLAMetadata"] = None
255
+ _cached_decode_metadata: Optional["TritonMLAMetadata"] = None
256
+
257
+ num_prefill_tokens: int
258
+
259
+ num_kv_splits: int = 4 # TODO(lucas) add heuristic
260
+ attn_logits: Optional[torch.Tensor] = None
261
+ req_idx: Optional[torch.Tensor] = None
262
+
263
+ # The dimension of the attention heads
264
+ head_dim: Optional[int] = None
265
+
266
+ def __post_init__(self):
267
+ supported_head_sizes = TritonMLABackend.get_supported_head_sizes()
268
+ if self.head_dim is not None and self.head_dim \
269
+ not in supported_head_sizes:
270
+ raise ValueError(
271
+ f"Only {supported_head_sizes} are supported for head_dim,",
272
+ f"received {self.head_dim}.")
273
+
274
+ @property
275
+ def prefill_metadata(self) -> Optional["TritonMLAMetadata"]:
276
+ if self.num_prefills == 0:
277
+ return None
278
+
279
+ if self._cached_prefill_metadata is not None:
280
+ return self._cached_prefill_metadata
281
+
282
+ assert self.seq_lens is not None
283
+ assert self.seq_lens_tensor is not None
284
+
285
+ # Compute some attn_metadata fields which default to None
286
+ query_start_loc = (None if self.query_start_loc is None else
287
+ self.query_start_loc[:self.num_prefills + 1])
288
+ slot_mapping = (None if self.slot_mapping is None else
289
+ self.slot_mapping[:self.num_prefill_tokens])
290
+ seq_lens = (None if self.seq_lens is None else
291
+ self.seq_lens[:self.num_prefills])
292
+ seq_lens_tensor = (None if self.seq_lens_tensor is None else
293
+ self.seq_lens_tensor[:self.num_prefills])
294
+ seq_start_loc = (None if self.seq_start_loc is None else
295
+ self.seq_start_loc[:self.num_prefills + 1])
296
+ context_lens_tensor = (None if self.context_lens_tensor is None else
297
+ self.context_lens_tensor[:self.num_prefills])
298
+ block_tables = (None if self.block_tables is None else
299
+ self.block_tables[:self.num_prefills])
300
+ input_positions = (None if self.input_positions is None else
301
+ self.input_positions[:self.num_prefill_tokens])
302
+
303
+ self._cached_prefill_metadata = TritonMLAMetadata(
304
+ num_prefills=self.num_prefills,
305
+ num_prefill_tokens=self.num_prefill_tokens,
306
+ num_decode_tokens=0,
307
+ slot_mapping=slot_mapping,
308
+ multi_modal_placeholder_index_maps=self.
309
+ multi_modal_placeholder_index_maps,
310
+ enable_kv_scales_calculation=self.enable_kv_scales_calculation,
311
+ input_positions=input_positions,
312
+ seq_lens=seq_lens,
313
+ seq_lens_tensor=seq_lens_tensor,
314
+ max_query_len=self.max_query_len,
315
+ max_prefill_seq_len=self.max_prefill_seq_len,
316
+ max_decode_query_len=0,
317
+ max_decode_seq_len=0,
318
+ query_start_loc=query_start_loc,
319
+ seq_start_loc=seq_start_loc,
320
+ context_lens_tensor=context_lens_tensor,
321
+ block_tables=block_tables,
322
+ use_cuda_graph=False,
323
+ head_dim=self.head_dim)
324
+ return self._cached_prefill_metadata
325
+
326
+ @property
327
+ def decode_metadata(self) -> Optional["TritonMLAMetadata"]:
328
+ if self.num_decode_tokens == 0:
329
+ return None
330
+
331
+ if self._cached_decode_metadata is not None:
332
+ return self._cached_decode_metadata
333
+ assert self.seq_lens_tensor is not None
334
+
335
+ # Compute some attn_metadata fields which default to None
336
+ slot_mapping = (None if self.slot_mapping is None else
337
+ self.slot_mapping[self.num_prefill_tokens:])
338
+ seq_lens_tensor = (None if self.seq_lens_tensor is None else
339
+ self.seq_lens_tensor[self.num_prefills:])
340
+ block_tables = (None if self.block_tables is None else
341
+ self.block_tables[self.num_prefills:])
342
+ input_positions = (None if self.input_positions is None else
343
+ self.input_positions[self.num_prefill_tokens:])
344
+
345
+ self._cached_decode_metadata = TritonMLAMetadata(
346
+ num_prefills=0,
347
+ num_prefill_tokens=0,
348
+ num_decode_tokens=self.num_decode_tokens,
349
+ slot_mapping=slot_mapping,
350
+ multi_modal_placeholder_index_maps=None,
351
+ enable_kv_scales_calculation=True,
352
+ seq_lens=None,
353
+ seq_lens_tensor=seq_lens_tensor,
354
+ max_decode_query_len=self.max_decode_query_len,
355
+ max_query_len=self.max_query_len,
356
+ max_prefill_seq_len=0,
357
+ max_decode_seq_len=self.max_decode_seq_len,
358
+ # Batch may be composed of prefill|decodes, adjust query start
359
+ # indices to refer to the start of decodes. E.g.
360
+ # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
361
+ query_start_loc=(self.query_start_loc[self.num_prefills:] -
362
+ self.query_start_loc[self.num_prefills])
363
+ if self.query_start_loc is not None else None,
364
+ seq_start_loc=self.seq_start_loc[self.num_prefills:]
365
+ if self.seq_start_loc is not None else None,
366
+ context_lens_tensor=None,
367
+ block_tables=block_tables,
368
+ use_cuda_graph=self.use_cuda_graph,
369
+ input_positions=input_positions,
370
+ head_dim=self.head_dim)
371
+ return self._cached_decode_metadata
372
+
373
+ def advance_step(self,
374
+ model_input: "ModelInputForGPUWithSamplingMetadata",
375
+ sampled_token_ids: Optional[torch.Tensor],
376
+ block_size: int,
377
+ num_seqs: int,
378
+ num_queries: int,
379
+ turn_prefills_into_decodes: bool = False):
380
+ """
381
+ Update metadata in-place to advance one decode step.
382
+ """
383
+ # When using cudagraph, the num_seqs is padded to the next captured
384
+ # batch sized, but num_queries tracks the actual number of requests in
385
+ # the batch. For --enforce-eager mode, num_seqs == num_queries
386
+ if num_seqs != num_queries:
387
+ assert num_seqs > num_queries
388
+ assert self.use_cuda_graph
389
+
390
+ if turn_prefills_into_decodes:
391
+ # When Mutli-Step is enabled with Chunked-Prefill, prefills and
392
+ # decodes are scheduled together. In the first step, all the
393
+ # prefills turn into decodes. This update reflects that
394
+ # conversion.
395
+ assert self.num_decode_tokens + self.num_prefills == num_seqs
396
+ self.num_decode_tokens += self.num_prefills
397
+ self.num_prefills = 0
398
+ self.num_prefill_tokens = 0
399
+ self.max_prefill_seq_len = 0
400
+ self.max_query_len = 1
401
+
402
+ self.slot_mapping = self.slot_mapping[:num_seqs]
403
+ else:
404
+ assert self.seq_lens is not None
405
+ assert self.max_decode_seq_len == max(self.seq_lens)
406
+
407
+ assert self.num_prefills == 0
408
+ assert self.num_prefill_tokens == 0
409
+ assert self.num_decode_tokens == num_seqs
410
+ assert self.slot_mapping.shape == (num_seqs, )
411
+
412
+ assert self.seq_lens is not None
413
+ assert len(self.seq_lens) == num_seqs
414
+ assert self.seq_lens_tensor is not None
415
+ assert self.seq_lens_tensor.shape == (num_seqs, )
416
+ assert self.max_query_len == 1
417
+ assert self.max_prefill_seq_len == 0
418
+
419
+ assert self.query_start_loc is not None
420
+ assert self.query_start_loc.shape == (num_queries + 1, )
421
+ assert self.seq_start_loc is not None
422
+ assert self.seq_start_loc.shape == (num_seqs + 1, )
423
+
424
+ assert self.context_lens_tensor is not None
425
+ assert self.context_lens_tensor.shape == (num_queries, )
426
+
427
+ assert self.block_tables is not None
428
+ assert self.block_tables.shape[0] == num_seqs
429
+
430
+ # Update query lengths. Note that we update only queries and not seqs,
431
+ # since tensors may be padded due to captured cuda graph batch size
432
+ for i in range(num_queries):
433
+ self.seq_lens[i] += 1
434
+ self.max_decode_seq_len = max(self.seq_lens)
435
+
436
+ ops.advance_step_flashattn(num_seqs=num_seqs,
437
+ num_queries=num_queries,
438
+ block_size=block_size,
439
+ input_tokens=model_input.input_tokens,
440
+ sampled_token_ids=sampled_token_ids,
441
+ input_positions=model_input.input_positions,
442
+ seq_lens=self.seq_lens_tensor,
443
+ slot_mapping=self.slot_mapping,
444
+ block_tables=self.block_tables)
445
+
446
+
447
+ class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
448
+
449
+ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
450
+ self.input_builder = input_builder
451
+ self.runner = input_builder.runner
452
+ self.sliding_window = input_builder.sliding_window
453
+ self.block_size = input_builder.block_size
454
+
455
+ def prepare(self):
456
+ self.slot_mapping: List[int] = []
457
+ self.prefill_seq_lens: List[int] = []
458
+ self.context_lens: List[int] = []
459
+ self.block_tables: List[List[int]] = []
460
+ self.curr_seq_lens: List[int] = []
461
+ self.input_positions: List[int] = []
462
+ self.multimodal_placeholder_maps: Dict[
463
+ str,
464
+ MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
465
+ self.num_prefills = 0
466
+ self.num_prefill_tokens = 0
467
+ self.num_decode_tokens = 0
468
+ self.has_prefix_cache_hit = False
469
+
470
+ def _add_seq_group(
471
+ self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
472
+ chunked_prefill_enabled: bool, prefix_cache_hit: bool):
473
+ """Add a sequence group to the metadata. Specifically update/append
474
+ 1. context length.
475
+ 2. block table.
476
+ 3. slot mapping.
477
+ """
478
+ is_prompt = inter_data.is_prompt
479
+ block_tables = inter_data.block_tables
480
+
481
+ for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
482
+ curr_sliding_window_block, input_positions) in zip(
483
+ inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
484
+ inter_data.orig_seq_lens, inter_data.seq_lens,
485
+ inter_data.query_lens, inter_data.context_lens,
486
+ inter_data.curr_sliding_window_blocks,
487
+ inter_data.input_positions):
488
+ self.input_positions.extend(input_positions)
489
+ self.context_lens.append(context_len)
490
+ if is_prompt:
491
+ mm_maps = inter_data.multi_modal_placeholder_maps
492
+ if mm_maps:
493
+ for modality, placeholders in mm_maps.items():
494
+ self.multimodal_placeholder_maps[modality].extend(
495
+ placeholders)
496
+
497
+ self.num_prefills += 1
498
+ self.num_prefill_tokens += token_len
499
+ self.prefill_seq_lens.append(seq_len)
500
+ else:
501
+ self.num_decode_tokens += query_len
502
+ self.curr_seq_lens.append(curr_seq_len)
503
+
504
+ # Compute block table.
505
+ # TODO(sang): Combine chunked prefill and prefix caching by
506
+ # only allowing multiple of block_size chunk size.
507
+ # NOTE: This only works for oooooooxxx style attention.
508
+ block_table = []
509
+ if prefix_cache_hit:
510
+ # NOTE(woosuk): For flash-attn, the block table should
511
+ # include the entries for the incoming prefill tokens.
512
+ block_table = block_tables[seq_id]
513
+ elif ((chunked_prefill_enabled or not is_prompt)
514
+ and block_tables is not None):
515
+ if curr_sliding_window_block == 0:
516
+ block_table = block_tables[seq_id]
517
+ else:
518
+ block_table = block_tables[seq_id][
519
+ -curr_sliding_window_block:]
520
+ self.block_tables.append(block_table)
521
+
522
+ # Compute slot mapping.
523
+ is_profile_run = is_block_tables_empty(block_tables)
524
+ start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
525
+ context_len,
526
+ self.sliding_window)
527
+ compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
528
+ seq_len, context_len, start_idx,
529
+ self.block_size, inter_data.block_tables)
530
+
531
+ def _get_graph_runner_block_tables(
532
+ self, num_seqs: int,
533
+ block_tables: List[List[int]]) -> torch.Tensor:
534
+ # The shape of graph_block_tables is
535
+ # [max batch size, max context len // block size].
536
+ max_batch_size, max_blocks = self.runner.graph_block_tables.shape
537
+ assert max_batch_size >= num_seqs
538
+
539
+ graph_block_tables = self.runner.graph_block_tables[:num_seqs]
540
+ for i, block_table in enumerate(block_tables):
541
+ if block_table:
542
+ num_blocks = len(block_table)
543
+ if num_blocks <= max_blocks:
544
+ graph_block_tables[i, :num_blocks] = block_table
545
+ else:
546
+ # It may be possible to have more blocks allocated due
547
+ # to lookahead slots of multi-step, however, they are
548
+ # not used anyway, so can be safely ignored.
549
+ graph_block_tables[
550
+ i, :max_blocks] = block_table[:max_blocks]
551
+
552
+ return torch.from_numpy(graph_block_tables).to(
553
+ device=self.runner.device, non_blocking=True)
554
+
555
+ def build(self, seq_lens: List[int], query_lens: List[int],
556
+ cuda_graph_pad_size: int, batch_size: int):
557
+ """Build attention metadata with on-device tensors.
558
+
559
+ Args:
560
+ seq_lens: The maybe padded sequence lengths of the input sequences.
561
+ query_lens: The query lengths of the input sequences.
562
+ cuda_graph_pad_size: The padding size for cuda graph.
563
+ -1 if cuda graph is not used.
564
+ batch_size: The maybe padded batch size.
565
+ """
566
+ prefix_cache_hit = any([
567
+ inter_data.prefix_cache_hit
568
+ for inter_data in self.input_builder.inter_data_list
569
+ ])
570
+ for inter_data in self.input_builder.inter_data_list:
571
+ self._add_seq_group(inter_data,
572
+ self.input_builder.chunked_prefill_enabled,
573
+ prefix_cache_hit)
574
+
575
+ device = self.runner.device
576
+ use_captured_graph = cuda_graph_pad_size != -1
577
+
578
+ max_query_len = max(query_lens)
579
+ decode_query_lens = query_lens[self.num_prefills:]
580
+ if len(decode_query_lens) > 0:
581
+ max_decode_query_len = max(decode_query_lens)
582
+ else:
583
+ max_decode_query_len = 1
584
+ max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
585
+ max_decode_seq_len = max(self.curr_seq_lens, default=0)
586
+ num_decode_tokens = self.num_decode_tokens
587
+ query_start_loc = list(accumulate(query_lens, initial=0))
588
+ seq_start_loc = list(accumulate(seq_lens, initial=0))
589
+
590
+ num_seqs = len(seq_lens)
591
+ if use_captured_graph:
592
+ self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
593
+ self.block_tables.extend([] * cuda_graph_pad_size)
594
+ num_decode_tokens = batch_size - self.num_prefill_tokens
595
+ block_tables = self._get_graph_runner_block_tables(
596
+ num_seqs, self.block_tables)
597
+ else:
598
+ block_tables = make_tensor_with_pad(
599
+ self.block_tables,
600
+ pad=0,
601
+ dtype=torch.int,
602
+ device=device,
603
+ )
604
+ assert max_query_len > 0, ("query_lens: {}".format(query_lens))
605
+
606
+ assert device is not None
607
+ context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
608
+ device, self.runner.pin_memory)
609
+ seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
610
+ self.runner.pin_memory)
611
+ input_positions = async_tensor_h2d(self.input_positions, torch.long,
612
+ device, self.runner.pin_memory)
613
+ slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
614
+ device, self.runner.pin_memory)
615
+ query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
616
+ device,
617
+ self.runner.pin_memory)
618
+ seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
619
+ device, self.runner.pin_memory)
620
+ placeholder_index_maps = {
621
+ modality: placeholder_map.index_map()
622
+ for modality, placeholder_map in
623
+ self.multimodal_placeholder_maps.items()
624
+ }
625
+
626
+ return TritonMLAMetadata(
627
+ num_prefills=self.num_prefills,
628
+ slot_mapping=slot_mapping_tensor,
629
+ num_prefill_tokens=self.num_prefill_tokens,
630
+ num_decode_tokens=num_decode_tokens,
631
+ seq_lens=seq_lens,
632
+ multi_modal_placeholder_index_maps=placeholder_index_maps,
633
+ enable_kv_scales_calculation=True,
634
+ input_positions=input_positions,
635
+ seq_lens_tensor=seq_lens_tensor,
636
+ max_query_len=max_query_len,
637
+ max_decode_query_len=max_decode_query_len,
638
+ max_prefill_seq_len=max_prefill_seq_len,
639
+ max_decode_seq_len=max_decode_seq_len,
640
+ query_start_loc=query_start_loc_tensor,
641
+ seq_start_loc=seq_start_loc_tensor,
642
+ context_lens_tensor=context_lens_tensor,
643
+ block_tables=block_tables,
644
+ use_cuda_graph=use_captured_graph,
645
+ num_kv_splits=4, # TODO(lucas) add heuristic
646
+ head_dim=self.runner.model_config.get_head_size(),
647
+ )
648
+
649
+
650
+ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
651
+
652
+ def __init__(
653
+ self,
654
+ num_heads: int,
655
+ head_size: int,
656
+ scale: float,
657
+ num_kv_heads: int,
658
+ alibi_slopes: Optional[List[float]],
659
+ sliding_window: Optional[int],
660
+ kv_cache_dtype: str,
661
+ blocksparse_params: Optional[Dict[str, Any]],
662
+ logits_soft_cap: Optional[float],
663
+ attn_type: str,
664
+ # MLA Specific Arguments
665
+ **kwargs) -> None:
666
+ super().__init__(num_heads, head_size, scale, num_kv_heads,
667
+ alibi_slopes, sliding_window, kv_cache_dtype,
668
+ blocksparse_params, logits_soft_cap, attn_type,
669
+ **kwargs)
670
+
671
+ unsupported_features = [
672
+ alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
673
+ ]
674
+ if any(unsupported_features):
675
+ raise NotImplementedError(
676
+ "TritonMLAImpl does not support one of the following: "
677
+ "alibi_slopes, sliding_window, blocksparse_params, "
678
+ "logits_soft_cap")
679
+
680
+ if attn_type != AttentionType.DECODER:
681
+ raise NotImplementedError("Encoder self-attention and "
682
+ "encoder/decoder cross-attention "
683
+ "are not implemented for "
684
+ "TritonMLAImpl")
685
+
686
+ def _forward_prefill(
687
+ self,
688
+ q: torch.Tensor,
689
+ kv_c_normed: torch.Tensor,
690
+ k_pe: torch.Tensor,
691
+ attn_metadata: TritonMLAMetadata,
692
+ ) -> torch.Tensor:
693
+ assert isinstance(attn_metadata, TritonMLAMetadata)
694
+ return self._forward_prefill_flash(q, kv_c_normed, k_pe,
695
+ attn_metadata.seq_start_loc,
696
+ attn_metadata.max_prefill_seq_len)
697
+
698
+ def _forward_decode(
699
+ self,
700
+ q_nope: torch.Tensor,
701
+ q_pe: torch.Tensor,
702
+ kv_c_and_k_pe_cache: torch.Tensor,
703
+ attn_metadata: TritonMLAMetadata,
704
+ ) -> torch.Tensor:
705
+ assert kv_c_and_k_pe_cache.numel() > 0
706
+ if self.kv_cache_dtype.startswith("fp8"):
707
+ raise NotImplementedError("FP8 Triton MLA not yet supported")
708
+
709
+ decode_meta = attn_metadata.decode_metadata
710
+ assert decode_meta is not None
711
+ B = q_nope.shape[0]
712
+
713
+ q = torch.cat([q_nope, q_pe], dim=-1)
714
+ o = torch.zeros(B,
715
+ self.num_heads,
716
+ self.kv_lora_rank,
717
+ dtype=q.dtype,
718
+ device=q.device)
719
+
720
+ # TODO(lucas) Allocate ahead of time
721
+ attn_logits = torch.empty(
722
+ (
723
+ B,
724
+ self.num_heads,
725
+ attn_metadata.num_kv_splits,
726
+ # NOTE(lucas) idk why the +1 is here but sglang has it so we
727
+ # just mirror that
728
+ self.kv_lora_rank + 1,
729
+ ),
730
+ dtype=torch.float32,
731
+ device=q.device,
732
+ )
733
+
734
+ # Add a head dim of 1
735
+ kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
736
+ kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
737
+ PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
738
+
739
+ # Run MQA
740
+ decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
741
+ decode_meta.block_tables,
742
+ decode_meta.seq_lens_tensor, attn_logits,
743
+ attn_metadata.num_kv_splits, self.scale,
744
+ PAGE_SIZE)
745
+
746
+ return self._v_up_proj_and_o_proj(o)
.venv/lib/python3.11/site-packages/vllm/attention/backends/utils.py ADDED
@@ -0,0 +1,582 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Attention backend utils"""
3
+ from collections import defaultdict
4
+ from contextlib import contextmanager
5
+ from itertools import accumulate
6
+ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
12
+ AttentionState)
13
+ from vllm.attention.backends.abstract import AttentionType
14
+ from vllm.multimodal import MultiModalPlaceholderMap
15
+ from vllm.utils import async_tensor_h2d, make_tensor_with_pad
16
+
17
+ if TYPE_CHECKING:
18
+ from vllm.worker.model_runner_base import ModelRunnerBase
19
+
20
+ # Error string(s) for encoder/decoder
21
+ # unsupported attention scenarios
22
+ STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
23
+ "with encoder/decoder models.")
24
+
25
+ PAD_SLOT_ID = -1
26
+
27
+ # Switch to numpy implementation of compute_slot_mapping
28
+ # if we have at least this many elements. Could be tuned further.
29
+ _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256
30
+
31
+ if TYPE_CHECKING:
32
+ from vllm.worker.model_runner import ModelInputForGPUBuilder
33
+
34
+
35
def is_block_tables_empty(block_tables: Union[None, Dict]):
    """Return True when no block tables exist.

    That is, when ``block_tables`` is None, or is a dict whose values
    are all None (the shape seen during memory-profiling runs).
    """
    if block_tables is None:
        return True
    if not isinstance(block_tables, dict):
        return False
    return all(table is None for table in block_tables.values())
43
+
44
+
45
def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
                                   context_len: int, sliding_window: int):
    """Return the first token index that receives a real slot.

    For prompt runs with a sliding window, tokens before
    ``query_len - sliding_window`` fall outside the window and are
    masked, so the mapping starts after them; in every other case the
    mapping starts at index 0.  (``context_len`` is accepted for
    interface symmetry with the other helpers but is unused here.)
    """
    if not is_prompt or sliding_window is None:
        return 0
    return max(0, query_len - sliding_window)
54
+
55
+
56
def _compute_slot_mapping_python(slot_mapping: List[int],
                                 block_table: List[int], range_start: int,
                                 range_end: int, block_size: int):
    """Append the physical slot of each token in [range_start, range_end)
    to ``slot_mapping`` — pure-Python path for small ranges."""
    for token_idx in range(range_start, range_end):
        block_idx, offset = divmod(token_idx, block_size)
        slot_mapping.append(block_table[block_idx] * block_size + offset)
64
+
65
+
66
def _compute_slot_mapping_numpy(slot_mapping: List[int],
                                block_table: List[int], range_start: int,
                                range_end: int, block_size: int):
    """Append the physical slot of each token in [range_start, range_end)
    to ``slot_mapping`` — vectorized numpy path for large ranges."""
    positions = np.arange(range_start, range_end)
    offsets = positions % block_size
    # Gather the owning block of every position, then convert
    # (block_number, offset) pairs into flat slot indices.
    slots = np.array(block_table)[positions // block_size]
    slots *= block_size
    slots += offsets
    slot_mapping.extend(slots)
77
+
78
+
79
def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
                         seq_id: int, seq_len: int, context_len: int,
                         start_idx: int, block_size: int,
                         block_tables: Dict[int, List[int]]):
    """Append the slot mapping for one sequence to ``slot_mapping``.

    Pads with PAD_SLOT_ID for profiling runs and for prompt tokens that
    fall outside the sliding window, then maps the remaining tokens to
    physical slots via the sequence's block table.
    """
    if is_profile_run:
        # During memory profiling, the block tables are not initialized
        # yet (in embeddings they are {seq_id: None}), so just emit a
        # dummy mapping for every token.
        slot_mapping.extend([PAD_SLOT_ID] * seq_len)
        return

    # Mask the [0, start_idx) tokens of the prompt with PAD_SLOT_ID,
    # where start_idx is max(0, seq_len - sliding_window). For example,
    # with prompt len 10, sliding window 8 and block size 4, the first
    # two tokens are masked and the mapping becomes
    # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
    num_masked = max(0, start_idx - context_len)
    slot_mapping.extend([PAD_SLOT_ID] * num_masked)

    range_start = max(start_idx, context_len)
    range_end = seq_len
    block_table = block_tables[seq_id]

    # The numpy path only wins once there are enough elements to
    # amortize the array-construction overhead.
    if range_end - range_start < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
        _compute_slot_mapping_python(slot_mapping, block_table, range_start,
                                     range_end, block_size)
    else:
        _compute_slot_mapping_numpy(slot_mapping, block_table, range_start,
                                    range_end, block_size)
116
+
117
+
118
# Type variable ranging over the concrete AttentionMetadata subclass a
# CommonMetadataBuilder subclass produces.
TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
119
+
120
+
121
class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
    """Shared attention-metadata builder.

    Accumulates per-sequence-group bookkeeping (slot mapping, block
    tables, sequence/context lengths, multimodal placeholder maps) and
    materializes a backend-specific metadata object (``_metadata_cls``)
    holding on-device tensors.
    """

    # Concrete subclasses set this to their AttentionMetadata subclass.
    _metadata_cls: Type[TAttentionMetadata]

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def prepare(self):
        """Reset all per-batch accumulators before sequence groups are
        added via _add_seq_group()."""
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Accumulate metadata for one sequence group into the builder's
        per-batch lists/counters."""
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)

                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                # Decode steps process exactly one new token per sequence.
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if inter_data.prefix_cache_hit:
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    # Keep only the trailing blocks still covered by the
                    # sliding window.
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        # Cumulative offsets, e.g. query_lens [4, 6] -> [0, 4, 10].
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # NOTE(review): `[] * n` is always the empty list, so this
            # appends nothing; padded rows appear to be covered by the
            # zero-filled graph_block_tables below — confirm intended.
            self.block_tables.extend([] * cuda_graph_pad_size)
            num_decode_tokens = batch_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.from_numpy(input_block_tables).to(
                device, non_blocking=True)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                                  device,
                                                  self.runner.pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        return self._metadata_cls(  # type: ignore
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=True,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
283
+
284
+
285
class CommonAttentionState(AttentionState):
    """Attention state shared by backends that support CUDA graphs.

    Owns the persistent buffers (slot mapping, seq lens, block tables)
    that CUDA-graph capture and replay read from.
    """

    def __init__(self, runner: "ModelRunnerBase"):
        self.runner = runner
        self._is_graph_capturing = False

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Allocate the persistent capture buffers for the duration of a
        CUDA-graph capture; drop them (and the capturing flag) on exit."""

        self._is_graph_capturing = True

        self._graph_slot_mapping = torch.full((max_batch_size, ),
                                              PAD_SLOT_ID,
                                              dtype=torch.long,
                                              device=self.runner.device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=self.runner.device)

        yield

        self._is_graph_capturing = False
        del self._graph_slot_mapping
        del self._graph_seq_lens
        del self._graph_block_tables

    def graph_clone(self, batch_size: int) -> "CommonAttentionState":
        """Return a fresh state bound to the same runner (only valid
        while a capture is in progress)."""
        assert self._is_graph_capturing
        return self.__class__(self.runner)

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        """Build decode-only attention metadata backed by the capture
        buffers, sliced to ``batch_size``."""
        assert self._is_graph_capturing
        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=self._graph_seq_lens[:batch_size],
            max_query_len=1,
            max_decode_query_len=1,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.runner.max_seq_len_to_capture,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self._graph_block_tables[:batch_size],
            use_cuda_graph=True,
        )
        if is_encoder_decoder_model:
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            assert self.runner.attn_backend.get_name() in\
                ["XFORMERS", "FLASH_ATTN"], \
                f"Expected attn_backend name to be either 'XFORMERS' or " \
                f"'FLASH_ATTN', but "\
                f"got '{self.runner.attn_backend.get_name()}'"
            self._update_captured_metadata_for_enc_dec_model(
                batch_size=batch_size, attn_metadata=attn_metadata)

        return attn_metadata

    def get_graph_input_buffers(
            self,
            attn_metadata,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        """Collect the tensors that a replayed CUDA graph reads, keyed by
        name, so they can be refilled before each replay."""
        input_buffers = {
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }
        if is_encoder_decoder_model:
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            assert self.runner.attn_backend.get_name() in\
                ["XFORMERS", "FLASH_ATTN"], \
                f"Expected attn_backend name to be either 'XFORMERS' or "\
                f"'FLASH_ATTN', but "\
                f"got '{self.runner.attn_backend.get_name()}'"
            self._add_additonal_input_buffers_for_enc_dec_model(
                attn_metadata=attn_metadata, input_buffers=input_buffers)
        return input_buffers

    def prepare_graph_input_buffers(
            self,
            input_buffers,
            attn_metadata,
            is_encoder_decoder_model: bool = False) -> None:
        """Copy the current batch's metadata tensors into the captured
        graph's input buffers (in place, async where possible)."""
        input_buffers["seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
        input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)
        if is_encoder_decoder_model:
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            assert self.runner.attn_backend.get_name() in\
                ["XFORMERS", "FLASH_ATTN"], \
                f"Expected attn_backend name to be either 'XFORMERS' or "\
                f"'FLASH_ATTN', but "\
                f"got '{self.runner.attn_backend.get_name()}'"
            self._prepare_input_buffers_for_enc_dec_model(
                attn_metadata, input_buffers)

    def begin_forward(self, model_input) -> None:
        # No per-forward preparation is required by this state.
        return

    def _update_captured_metadata_for_enc_dec_model(self, batch_size: int,
                                                    attn_metadata):
        """
        Updates the attention metadata parameters for CUDA graph capture in an
        encoder-decoder model.

        This method modifies attention-related tensors and metadata required
        for CUDA graph capture in encoder-decoder models. Specifically, it
        updates the cross-attention and encoder sequence tensors in the
        AttentionMetadata object.
        """
        # During decode phase the cross_slot_mapping will be empty. Hence set
        # an empty tensor for CUDA Graph capture.
        attn_metadata.cross_slot_mapping = torch.tensor(
            [], dtype=torch.int).cuda()
        attn_metadata.cross_block_tables = torch.full(
            (batch_size, self.runner.get_max_block_per_batch()),
            1,
            dtype=torch.int).cuda()
        attn_metadata.encoder_seq_lens = torch.full((batch_size, ),
                                                    1,
                                                    dtype=torch.int).cuda()
        attn_metadata.encoder_seq_lens_tensor = torch.full(
            (batch_size, ), 1, dtype=torch.int).cuda()
        attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
        attn_metadata.num_encoder_tokens = 0

    def _add_additonal_input_buffers_for_enc_dec_model(
            self, attn_metadata, input_buffers: Dict[str, Any]):
        """
        Saves additional input buffers specific to the encoder-decoder model
        from the attention metadata.

        This method extracts and stores encoder-decoder related input buffers
        from the `attn_metadata` into the `input_buffers` dictionary. The
        buffers include encoder sequence lengths, cross-slot mappings, and
        cross-block tables, which are essential for the encoder-decoder model
        during CUDA graph replay.
        """
        input_buffers["encoder_seq_lens_tensor"] = (
            attn_metadata.decode_metadata.encoder_seq_lens_tensor)
        input_buffers["cross_slot_mapping"] = (
            attn_metadata.decode_metadata.cross_slot_mapping)
        input_buffers["cross_block_tables"] = (
            attn_metadata.decode_metadata.cross_block_tables)

    def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata,
                                                 input_buffers: Dict[str,
                                                                     Any]):
        """
        Populates input buffers with data from the encoder-decoder model's
        attention metadata.

        This method fills the input buffers with encoder-decoder specific
        tensors. It copies data from the `attn_metadata` and keyword arguments
        (`kwargs`) into corresponding buffers in the `input_buffers` dictionary.
        The copied data includes attention-related metadata as well as input
        IDs and positional information for the encoder.
        """
        input_buffers["encoder_seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.encoder_seq_lens_tensor,
            non_blocking=True)
        input_buffers["cross_slot_mapping"].copy_(
            attn_metadata.decode_metadata.cross_slot_mapping,
            non_blocking=True)
        input_buffers["cross_block_tables"].copy_(
            attn_metadata.decode_metadata.cross_block_tables,
            non_blocking=True)
464
+
465
+
466
def is_all_encoder_attn_metadata_set(attn_metadata):
    """Return True if every field required for encoder attention is set
    on ``attn_metadata``."""
    required = (
        attn_metadata.encoder_seq_lens,
        attn_metadata.encoder_seq_lens_tensor,
        attn_metadata.max_encoder_seq_len,
    )
    return all(field is not None for field in required)
473
+
474
+
475
def is_all_cross_attn_metadata_set(attn_metadata):
    """Return True if every field required for enc/dec cross-attention
    is set (a superset of the encoder-attention requirements)."""
    encoder_ok = attn_metadata.is_all_encoder_attn_metadata_set
    if not encoder_ok:
        # Preserve short-circuiting: don't touch the cross fields when
        # the encoder prerequisites already fail.
        return encoder_ok
    return (attn_metadata.cross_slot_mapping is not None
            and attn_metadata.cross_block_tables is not None)
484
+
485
+
486
def get_seq_len_block_table_args(
    attn_metadata,
    is_prompt: bool,
    attn_type: str,
) -> tuple:
    """Select seq-length and block-table fields for an attention op.

    Which fields of ``attn_metadata`` apply depends on the attention
    type:

    * Decoder self-attn -> decoder seq lengths & decoder block tables
    * Enc/dec cross-attn -> encoder seq lengths & "cross" block tables
    * Encoder attn -> encoder seq lengths, no block tables

    Arguments:

    * attn_metadata: Attention metadata structure associated with
      attention op
    * is_prompt: True if prefill, False otherwise
    * attn_type: encoder attention, decoder self-attention,
      encoder/decoder cross-attention

    Returns a 3-tuple of (seq-lengths tensor, max seq-length scalar,
    block tables or None).
    """
    if attn_type == AttentionType.ENCODER:
        # No block tables are associated with encoder attention.
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len, None)

    if attn_type == AttentionType.ENCODER_DECODER:
        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables.
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.cross_block_tables)

    if attn_type == AttentionType.DECODER:
        # Decoder self-attention: the max length depends on whether we
        # are in a prompt (prefill) or decode run.
        max_seq_len = (attn_metadata.max_prefill_seq_len
                       if is_prompt else attn_metadata.max_decode_seq_len)
        return (attn_metadata.seq_lens_tensor, max_seq_len,
                attn_metadata.block_tables)

    raise AttributeError(f"Invalid attention type {str(attn_type)}")
536
+
537
+
538
def get_num_prefill_decode_query_kv_tokens(
    attn_metadata,
    attn_type: str,
) -> Tuple[int, int, int]:
    """
    Calculate the number of prefill and decode tokens for query, key/value
    based on the attention metadata and the specified attention type.

    Args:
        attn_metadata (FlashAttentionMetadata): Attention Metadata object.
        attn_type (AttentionType): The type of attention being used.
    Returns:
        Tuple[int, int, int]: (num prefill query tokens,
        num prefill key/value tokens, num decode query tokens).

    Raises:
        AssertionError: If the number of encoder tokens in `attn_metadata`
        is `None` when required for the calculations.
    """
    if attn_type == AttentionType.ENCODER:
        # Encoder attention is only invoked during the prefill phase and
        # the same input serves as both query and key/value.
        assert attn_metadata.num_encoder_tokens is not None
        return (attn_metadata.num_encoder_tokens,
                attn_metadata.num_encoder_tokens, 0)

    if attn_type == AttentionType.ENCODER_DECODER:
        # Queries come from the decoder; keys/values come from the
        # encoder (cross-attention).
        assert attn_metadata.num_encoder_tokens is not None
        return (attn_metadata.num_prefill_tokens,
                attn_metadata.num_encoder_tokens,
                attn_metadata.num_decode_tokens)

    # DECODER or ENCODER_ONLY: query and key/value counts coincide.
    return (attn_metadata.num_prefill_tokens,
            attn_metadata.num_prefill_tokens,
            attn_metadata.num_decode_tokens)
.venv/lib/python3.11/site-packages/vllm/attention/backends/xformers.py ADDED
@@ -0,0 +1,794 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Attention layer with xFormers and PagedAttention."""
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Tuple, Type
5
+
6
+ import torch
7
+ from xformers import ops as xops
8
+ from xformers.ops.fmha.attn_bias import (AttentionBias,
9
+ BlockDiagonalCausalMask,
10
+ BlockDiagonalMask,
11
+ LowerTriangularMaskWithTensorBias)
12
+
13
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
14
+ AttentionLayer,
15
+ AttentionMetadata, AttentionType)
16
+ from vllm.attention.backends.utils import (
17
+ CommonAttentionState, CommonMetadataBuilder,
18
+ get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
19
+ is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set)
20
+ from vllm.attention.ops.paged_attn import (PagedAttention,
21
+ PagedAttentionMetadata)
22
+ from vllm.logger import init_logger
23
+
24
+ logger = init_logger(__name__)
25
+
26
+
27
class XFormersBackend(AttentionBackend):
    """Attention backend wiring xFormers implementation classes together
    and delegating KV-cache layout/management to PagedAttention."""

    @staticmethod
    def get_name() -> str:
        return "XFORMERS"

    @staticmethod
    def get_impl_cls() -> Type["XFormersImpl"]:
        return XFormersImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return XFormersMetadata

    @staticmethod
    def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
        return XFormersMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        # The KV-cache layout is defined by PagedAttention.
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        # Delegate block swapping (e.g. CPU <-> GPU) to PagedAttention.
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        # Delegate intra-cache block copies to PagedAttention.
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
73
+
74
+
75
+ @dataclass
76
+ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
77
+ """Metadata for XFormersbackend.
78
+
79
+ NOTE: Any python object stored here is not updated when it is
80
+ cuda-graph replayed. If you have values that need to be changed
81
+ dynamically, it should be stored in tensor. The tensor has to be
82
+ updated from `CUDAGraphRunner.forward` API.
83
+ """
84
+
85
+ # |---------- N-1 iteration --------|
86
+ # |---------------- N iteration ---------------------|
87
+ # |- tokenA -|......................|-- newTokens ---|
88
+ # |---------- context_len ----------|
89
+ # |-------------------- seq_len ----------------------|
90
+ # |-- query_len ---|
91
+
92
+ # seq_lens stored as a tensor.
93
+ seq_lens_tensor: Optional[torch.Tensor]
94
+
95
+ # FIXME: It is for flash attn.
96
+ # Maximum sequence length among prefill batch. 0 if there are decoding
97
+ # requests only.
98
+ max_prefill_seq_len: int
99
+ # Maximum sequence length among decode batch. 0 if there are prefill
100
+ # requests only.
101
+ max_decode_seq_len: int
102
+
103
+ # Whether or not if cuda graph is enabled.
104
+ # Cuda-graph is currently enabled for decoding only.
105
+ # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
106
+ use_cuda_graph: bool
107
+
108
+ # (batch_size,). The sequence length per sequence. Sequence length means
109
+ # the computed tokens + new tokens None if it is a decoding.
110
+ seq_lens: Optional[List[int]] = None
111
+
112
+ # FIXME: It is for flash attn.
113
+ # (batch_size + 1,). The cumulative sequence lengths of the sequences in
114
+ # the batch, used to index into sequence. E.g., if the sequence length is
115
+ # [4, 6], it is [0, 4, 10].
116
+ seq_start_loc: Optional[torch.Tensor] = None
117
+
118
+ # (batch_size,) A tensor of context lengths (tokens that are computed
119
+ # so far).
120
+ context_lens_tensor: Optional[torch.Tensor] = None
121
+
122
+ # Maximum query length in the batch. None for decoding.
123
+ max_query_len: Optional[int] = None
124
+
125
+ # Max number of query tokens among request in the batch.
126
+ max_decode_query_len: Optional[int] = None
127
+
128
+ # (batch_size + 1,). The cumulative subquery lengths of the sequences in
129
+ # the batch, used to index into subquery. E.g., if the subquery length
130
+ # is [4, 6], it is [0, 4, 10].
131
+ query_start_loc: Optional[torch.Tensor] = None
132
+
133
+ # Self-attention prefill/decode metadata cache
134
+ _cached_prefill_metadata: Optional["XFormersMetadata"] = None
135
+ _cached_decode_metadata: Optional["XFormersMetadata"] = None
136
+
137
+ # Begin encoder attn & enc/dec cross-attn fields...
138
+
139
+ # Encoder sequence lengths representation
140
+ encoder_seq_lens: Optional[List[int]] = None
141
+ encoder_seq_lens_tensor: Optional[torch.Tensor] = None
142
+ # FIXME: It is for flash attn.
143
+ # (batch_size + 1,). The cumulative sequence lengths of the sequences in
144
+ # the batch, used to index into sequence. E.g., if the sequence length is
145
+ # [4, 6], it is [0, 4, 10].
146
+ encoder_seq_start_loc: Optional[torch.Tensor] = None
147
+
148
+ # Maximum sequence length among encoder sequences
149
+ max_encoder_seq_len: Optional[int] = None
150
+
151
+ # Number of tokens input to encoder
152
+ num_encoder_tokens: Optional[int] = None
153
+
154
+ # Cross-attention memory-mapping data structures: slot mapping
155
+ # and block tables
156
+ cross_slot_mapping: Optional[torch.Tensor] = None
157
+ cross_block_tables: Optional[torch.Tensor] = None
158
+
159
+ def __post_init__(self):
160
+ # Set during the execution of the first attention op.
161
+ # It is a list because it is needed to set per prompt
162
+ # when alibi slopes is used. It is because of the limitation
163
+ # from xformer API.
164
+ # will not appear in the __repr__ and __init__
165
+ self.attn_bias: Optional[List[AttentionBias]] = None
166
+ self.encoder_attn_bias: Optional[List[AttentionBias]] = None
167
+ self.cross_attn_bias: Optional[List[AttentionBias]] = None
168
+
169
+ @property
170
+ def is_all_encoder_attn_metadata_set(self):
171
+ '''
172
+ All attention metadata required for encoder attention is set.
173
+ '''
174
+ return is_all_encoder_attn_metadata_set(self)
175
+
176
+ @property
177
+ def is_all_cross_attn_metadata_set(self):
178
+ '''
179
+ All attention metadata required for enc/dec cross-attention is set.
180
+
181
+ Superset of encoder attention required metadata.
182
+ '''
183
+ return is_all_cross_attn_metadata_set(self)
184
+
185
    @property
    def prefill_metadata(self) -> Optional["XFormersMetadata"]:
        """Return a view of this metadata restricted to the prefill tokens.

        The batch is laid out as [prefill tokens | decode tokens]; this
        property slices every per-sequence / per-token field down to the
        first ``num_prefills`` sequences (``num_prefill_tokens`` tokens)
        and caches the result so repeated accesses are cheap.

        Returns:
            A prefill-only ``XFormersMetadata``, or ``None`` when the batch
            contains no prefills.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            # Recover cached prefill-phase attention metadata structure
            return self._cached_prefill_metadata

        # Decoder-only batches carry seq_lens; encoder-only batches may
        # carry only the encoder variants — at least one must be present.
        assert ((self.seq_lens is not None)
                or (self.encoder_seq_lens is not None))
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None.
        # Per-sequence fields are sliced by num_prefills; per-token fields
        # (slot_mapping) by num_prefill_tokens; the *_start_loc arrays need
        # num_prefills + 1 entries (cumulative offsets include the end).
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:self.num_prefills + 1])
        seq_start_loc = (None if self.seq_start_loc is None else
                         self.seq_start_loc[:self.num_prefills + 1])
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:self.num_prefill_tokens])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:self.num_prefills])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:self.num_prefills])

        # Construct & cache prefill-phase attention metadata structure
        self._cached_prefill_metadata = XFormersMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            # CUDA graphs are only captured for decode-only batches.
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            # Encoder fields are not token-indexed by the prefill/decode
            # split, so they are passed through unsliced.
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_prefill_metadata
242
+
243
    @property
    def decode_metadata(self) -> Optional["XFormersMetadata"]:
        """Return a view of this metadata restricted to the decode tokens.

        Complement of :meth:`prefill_metadata`: slices every relevant field
        from index ``num_prefills`` / ``num_prefill_tokens`` onward and
        caches the result.

        Returns:
            A decode-only ``XFormersMetadata``, or ``None`` when the batch
            contains no decode tokens.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            # Recover cached decode-phase attention metadata structure
            return self._cached_decode_metadata
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None.
        # slot_mapping is per-token; the others are per-sequence.
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])

        # Construct & cache decode-phase attention metadata structure
        self._cached_decode_metadata = XFormersMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            # Multi-modal placeholders only apply to prefill tokens.
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)

        # Batch may be composed of prefill|decodes, adjust query start indices
        # to refer to the start of decodes when the two are split apart.
        # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
        # NOTE(review): query_start_loc is not passed to the constructor
        # above, so this branch only fires if the dataclass default is
        # non-None — confirm whether that is intentional (dead code?).
        if self._cached_decode_metadata.query_start_loc is not None:
            qs = self._cached_decode_metadata.query_start_loc
            self._cached_decode_metadata.query_start_loc = qs - qs[0]
        return self._cached_decode_metadata
290
+
291
+
292
def _get_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_type: str,
) -> Optional[AttentionBias]:
    """Fetch the attention bias stored on ``attn_metadata`` for a given
    attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:
    * Appropriate attention bias value given the attention type
    """

    # Decoder self-attention and encoder-only attention share the same
    # bias slot; encoder attention and cross-attention each have their own.
    if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
        return attn_metadata.attn_bias
    if attn_type == AttentionType.ENCODER:
        return attn_metadata.encoder_attn_bias
    if attn_type == AttentionType.ENCODER_DECODER:
        return attn_metadata.cross_attn_bias
    raise AttributeError(f"Invalid attention type {str(attn_type)}")
319
+
320
+
321
def _set_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_bias: List[Optional[AttentionBias]],
    attn_type: str,
) -> None:
    """Store ``attn_bias`` on the metadata field matching ``attn_type``.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_bias: The desired attention bias value
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention
    """

    # Mirror of _get_attn_bias: one writable slot per attention type,
    # with decoder and encoder-only attention sharing attn_bias.
    if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
        attn_metadata.attn_bias = attn_bias
    elif attn_type == AttentionType.ENCODER:
        attn_metadata.encoder_attn_bias = attn_bias
    elif attn_type == AttentionType.ENCODER_DECODER:
        attn_metadata.cross_attn_bias = attn_bias
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")
347
+
348
+
349
class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
    """Metadata builder for the xFormers backend.

    All build logic lives in ``CommonMetadataBuilder``; this subclass only
    pins the concrete metadata class to instantiate.
    """

    _metadata_cls = XFormersMetadata
352
+
353
+
354
class XFormersImpl(AttentionImpl[XFormersMetadata]):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Configure the xFormers attention implementation.

        Raises:
            ValueError: if ``blocksparse_params`` is supplied (unsupported)
                or ``head_size`` is not supported by PagedAttention.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "XFormers does not support block-sparse attention.")
        if logits_soft_cap is not None:
            # Soft cap is ignored beyond this one-time warning.
            logger.warning_once("XFormers does not support logits soft cap. "
                                "Outputs may be slightly off.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        # GQA/MQA: several query heads share each KV head.
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        suppored_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in suppored_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {suppored_head_sizes}.")

        self.attn_type = attn_type

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        kv_cache: torch.Tensor,
        attn_metadata: "XFormersMetadata",
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

        For decoder-only models: query, key and value must be non-None.

        For encoder/decoder models:
        * XFormersImpl.forward() may be invoked for both self- and cross-
          attention layers.
        * For self-attention: query, key and value must be non-None.
        * For cross-attention:
            * Query must be non-None
            * During prefill, key and value must be non-None; key and value
              get cached for use during decode.
            * During decode, key and value may be None, since:
              (1) key and value tensors were cached during prefill, and
              (2) cross-attention key and value tensors do not grow during
                  decode

        A note on how the attn_type (attention type enum) argument impacts
        attention forward() behavior:

            * DECODER: normal decoder-only behavior;
                use decoder self-attention block table
            * ENCODER: no KV caching; pass encoder sequence
                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len) to kernel, in lieu of decoder
                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len).
                Used for encoder branch of encoder-decoder models.
            * ENCODER_ONLY: no kv_caching, uses the normal attention
                attributes (seq_lens/seq_lens_tensor/max_seq_len).
            * ENCODER_DECODER: cross-attention behavior;
                use cross-attention block table for caching KVs derived
                from encoder hidden states; since KV sequence lengths
                will match encoder sequence lengths, pass encoder sequence
                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len)

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        attn_type = self.attn_type
        # Check that appropriate attention metadata attributes are
        # selected for the desired attention type
        if (attn_type == AttentionType.ENCODER
                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")

        elif (attn_type == AttentionType.ENCODER_DECODER
              and (not attn_metadata.is_all_cross_attn_metadata_set)):
            raise AttributeError("Encoder/decoder cross-attention "
                                 "requires setting cross-attention "
                                 "metadata attributes.")

        # Unflatten [num_tokens, H*D] -> [num_tokens, H, D].
        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            # Cross-attention decode: KV already cached from prefill.
            assert value is None

        # Self-attention vs. cross-attention will impact
        # which KV cache memory-mapping & which
        # seqlen datastructures we utilize

        if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
            # KV-cache during decoder-self- or
            # encoder-decoder-cross-attention, but not
            # during encoder attention.
            #
            # Even if there are no new key/value pairs to cache,
            # we still need to break out key_cache and value_cache
            # i.e. for later use by paged attention
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if (key is not None) and (value is not None):

                if attn_type == AttentionType.ENCODER_DECODER:
                    # Update cross-attention KV cache (prefill-only)
                    # During cross-attention decode, key & value will be None,
                    # preventing this IF-statement branch from running
                    updated_slot_mapping = attn_metadata.cross_slot_mapping
                else:
                    # Update self-attention KV cache (prefill/decode)
                    updated_slot_mapping = attn_metadata.slot_mapping

                # Reshape the input keys and values and store them in the cache.
                # If kv_cache is not provided, the new key and value tensors are
                # not cached. This happens during the initial memory
                # profiling run.
                PagedAttention.write_to_paged_cache(
                    key, value, key_cache, value_cache, updated_slot_mapping,
                    self.kv_cache_dtype, layer._k_scale, layer._v_scale)
        (num_prefill_query_tokens, num_prefill_kv_tokens,
        num_decode_query_tokens) = \
            get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)

        # NOTE(review): a fresh buffer is always allocated here; the
        # `output` parameter passed in (if any) is ignored.
        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_query_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_query_tokens]
        if key is not None and value is not None:
            key = key[:num_prefill_kv_tokens]
            value = value[:num_prefill_kv_tokens]

        assert query.shape[0] == num_prefill_query_tokens
        assert decode_query.shape[0] == num_decode_query_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
                # normal attention.
                # block tables are empty if the prompt does not have a cached
                # prefix.
                out = self._run_memory_efficient_xformers_forward(
                    query, key, value, prefill_meta, attn_type=attn_type)
                assert out.shape == output[:num_prefill_query_tokens].shape
                output[:num_prefill_query_tokens] = out
            else:
                assert attn_type != AttentionType.ENCODER_ONLY, (
                    "Encoder-only models should not have prefix attention.")

                assert prefill_meta.query_start_loc is not None
                assert prefill_meta.max_query_len is not None

                # prefix-enabled attention
                # TODO(Hai) this triton kernel has regression issue (broke) to
                # deal with different data types between KV and FP8 KV cache,
                # to be addressed separately.
                # NOTE(review): key_cache/value_cache are only bound when
                # kv_cache.numel() > 0 above, which holds in this branch.
                out = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    self.kv_cache_dtype,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window,
                    layer._k_scale,
                    layer._v_scale,
                )
                assert output[:num_prefill_query_tokens].shape == out.shape
                output[:num_prefill_query_tokens] = out

        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
                "Encoder-only models should not have decode metadata.")

            # Select decoder vs. encoder seqlen/block-table fields
            # depending on self- vs. cross-attention.
            (
                seq_lens_arg,
                max_seq_len_arg,
                block_tables_arg,
            ) = get_seq_len_block_table_args(decode_meta, False, attn_type)

            output[num_prefill_query_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                block_tables_arg,
                seq_lens_arg,
                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                layer._k_scale,
                layer._v_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_memory_efficient_xformers_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: XFormersMetadata,
        attn_type: str = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Attention for 1D query of multiple prompts. Multiple prompt
        tokens are flattened in to `query` input.

        See https://facebookresearch.github.io/xformers/components/ops.html
        for API spec.

        Args:
            query: shape = [num_prefill_tokens, num_heads, head_size]
            key: shape = [num_prefill_tokens, num_kv_heads, head_size]
            value: shape = [num_prefill_tokens, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the vLLM default generally
        """

        original_query = query
        if self.num_kv_heads != self.num_heads:
            # GQA/MQA requires the shape [B, M, G, H, K].
            # Note that the output also has the same shape (which is different
            # from a spec from the doc).
            query = query.view(query.shape[0], self.num_kv_heads,
                               self.num_queries_per_kv, query.shape[-1])
            # expand() creates broadcast views, not copies.
            key = key[:, :,
                      None, :].expand(key.shape[0], self.num_kv_heads,
                                      self.num_queries_per_kv, key.shape[-1])
            value = value[:, :,
                          None, :].expand(value.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          value.shape[-1])

        # Set attention bias if not provided. This typically happens at
        # the very attention layer of every iteration.
        # FIXME(woosuk): This is a hack.
        attn_bias = _get_attn_bias(attn_metadata, attn_type)
        if attn_bias is None:
            if self.alibi_slopes is None:

                # Cross attention block of decoder branch of encoder-decoder
                # model uses seq_lens for dec / encoder_seq_lens for enc
                if (attn_type == AttentionType.ENCODER_DECODER):
                    assert attn_metadata.seq_lens is not None
                    assert attn_metadata.encoder_seq_lens is not None

                    # Cross-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)

                # Encoder branch of encoder-decoder model uses
                # attn_metadata.encoder_seq_lens
                elif attn_type == AttentionType.ENCODER:

                    assert attn_metadata.encoder_seq_lens is not None

                    # Encoder self-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.encoder_seq_lens)

                # Self-attention block of encoder-only model just
                # uses the seq_lens directly.
                elif attn_type == AttentionType.ENCODER_ONLY:
                    assert attn_metadata.seq_lens is not None

                    # Encoder self-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.seq_lens)

                # Self-attention block of decoder branch just
                # uses the seq_lens directly
                elif attn_type == AttentionType.DECODER:
                    assert attn_metadata.seq_lens is not None

                    # Decoder self-attention mask is causal
                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
                        attn_metadata.seq_lens)
                else:
                    # NOTE(review): logging-style args; the message is not
                    # %-formatted by ValueError — confirm intent.
                    raise ValueError("Unknown AttentionType: %s", attn_type)

                if self.sliding_window is not None:
                    attn_bias = attn_bias.make_local_attention(
                        self.sliding_window)
                # Wrapped in a list so the non-alibi and alibi paths share
                # the same per-prompt indexing convention below.
                attn_bias = [attn_bias]
            else:
                assert attn_type == AttentionType.DECODER
                assert attn_metadata.seq_lens is not None
                attn_bias = _make_alibi_bias(self.alibi_slopes,
                                             self.num_kv_heads, query.dtype,
                                             attn_metadata.seq_lens)

            _set_attn_bias(attn_metadata, attn_bias, attn_type)

        # No alibi slopes.
        # TODO(woosuk): Too many view operations. Let's try to reduce
        # them in the future for code readability.
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=attn_bias[0],
                p=0.0,
                scale=self.scale)
            return out.view_as(original_query)

        # Attention with alibi slopes.
        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
        assert attn_metadata.seq_lens is not None
        output = torch.empty_like(original_query)
        start = 0
        for i, seq_len in enumerate(attn_metadata.seq_lens):
            end = start + seq_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_bias[i],
                p=0.0,
                scale=self.scale)
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.view_as(original_query[start:end]))
            start += seq_len
        return output
759
+
760
+
761
def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[AttentionBias]:
    """Build one ALiBi attention bias per prompt in ``seq_lens``.

    Args:
        alibi_slopes: per-head ALiBi slopes, shape [num_heads].
        num_kv_heads: number of KV heads; when it differs from num_heads the
            bias is unflattened into the GQA head-group layout.
        dtype: dtype of the produced bias tensors.
        seq_lens: prompt lengths, one bias per prompt.

    Returns:
        List of ``LowerTriangularMaskWithTensorBias`` objects, one per prompt.
    """
    attn_biases: List[AttentionBias] = []
    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        # Calculate a matrix where each element represents ith element- jth
        # element.
        bias = bias[None, :] - bias[:, None]

        # Last dim padded up to a multiple of 8, presumably for kernel
        # alignment — TODO confirm against xformers requirements. The
        # empty-then-slice-then-copy_ pattern keeps the padded stride
        # while exposing only the first seq_len columns.
        padded_len = (seq_len + 7) // 8 * 8
        num_heads = alibi_slopes.shape[0]
        bias = torch.empty(
            1,  # batch size
            num_heads,
            seq_len,
            padded_len,
            device=alibi_slopes.device,
            dtype=dtype,
        )[:, :, :, :seq_len].copy_(bias)
        # Scale the relative-position matrix by each head's slope.
        bias.mul_(alibi_slopes[:, None, None])
        if num_heads != num_kv_heads:
            # Regroup heads as (kv_head, queries_per_kv) for GQA/MQA.
            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
        attn_biases.append(LowerTriangularMaskWithTensorBias(bias))

    return attn_biases
.venv/lib/python3.11/site-packages/vllm/attention/ops/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (191 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/hpu_paged_attn.cpython-311.pyc ADDED
Binary file (5.95 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/ipex_attn.cpython-311.pyc ADDED
Binary file (8.01 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/nki_flash_attn.cpython-311.pyc ADDED
Binary file (25.6 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/paged_attn.cpython-311.pyc ADDED
Binary file (8.6 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/prefix_prefill.cpython-311.pyc ADDED
Binary file (31.2 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/triton_decode_attention.cpython-311.pyc ADDED
Binary file (19.3 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/ops/__pycache__/triton_flash_attention.cpython-311.pyc ADDED
Binary file (20.5 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (213 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/attention/ops/blocksparse_attention/__pycache__/blocksparse_attention_kernel.cpython-311.pyc ADDED
Binary file (14.8 kB). View file