Upload inference/vllm-ascend_v0.11.0rc0.patch with huggingface_hub
Browse files
inference/vllm-ascend_v0.11.0rc0.patch
ADDED
|
@@ -0,0 +1,847 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
|
| 2 |
+
index d289bb4..0357b50 100644
|
| 3 |
+
--- a/vllm_ascend/attention/attention_v1.py
|
| 4 |
+
+++ b/vllm_ascend/attention/attention_v1.py
|
| 5 |
+
@@ -21,6 +21,7 @@ from typing import ClassVar, List, Optional, Tuple, Type
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
+import torch.nn.functional as F
|
| 10 |
+
import torch_npu
|
| 11 |
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
| 12 |
+
AttentionLayer, AttentionType)
|
| 13 |
+
@@ -30,6 +31,7 @@ from vllm.utils import cdiv, direct_register_custom_op
|
| 14 |
+
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
| 15 |
+
from vllm.v1.core.sched.output import SchedulerOutput
|
| 16 |
+
from vllm.v1.kv_cache_interface import AttentionSpec
|
| 17 |
+
+from vllm.model_executor.models.utils import extract_layer_index
|
| 18 |
+
|
| 19 |
+
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
| 20 |
+
maybe_save_kv_layer_to_connector,
|
| 21 |
+
@@ -39,6 +41,9 @@ from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
| 22 |
+
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
| 23 |
+
nd_to_nz_2d, nd_to_nz_spec)
|
| 24 |
+
|
| 25 |
+
+if torch.version.cann.startswith("8.3"):
|
| 26 |
+
+ import omni_custom_ops
|
| 27 |
+
+
|
| 28 |
+
|
| 29 |
+
class AscendAttentionBackend(AttentionBackend):
|
| 30 |
+
accept_output_buffer: bool = True
|
| 31 |
+
@@ -115,6 +120,7 @@ class AscendAttentionBackend(AttentionBackend):
|
| 32 |
+
return [64]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
+
|
| 36 |
+
class AscendAttentionState(Enum):
|
| 37 |
+
PrefillNoCache = 0
|
| 38 |
+
PrefillCacheHit = 1
|
| 39 |
+
@@ -135,8 +141,8 @@ class AscendMetadata:
|
| 40 |
+
num_actual_tokens: int = 0
|
| 41 |
+
|
| 42 |
+
# The sequence length per sequence. Sequence length means the computed
|
| 43 |
+
- # tokens + new tokens (is None if it is a decoding).
|
| 44 |
+
- # (batch_size,)
|
| 45 |
+
+ # tokens + new tokens (is None if it is a decoding).(batch_size,)
|
| 46 |
+
+
|
| 47 |
+
seq_lens: torch.Tensor = None
|
| 48 |
+
|
| 49 |
+
query_start_loc: torch.Tensor = None
|
| 50 |
+
@@ -145,20 +151,25 @@ class AscendMetadata:
|
| 51 |
+
max_query_len: Optional[int] = None
|
| 52 |
+
|
| 53 |
+
# ********************** KV Cache Related Properties ********************* #
|
| 54 |
+
- # Block addresses per sequence (Seq id -> list of physical block).
|
| 55 |
+
- # (batch_size, max_blocks_per_seq)
|
| 56 |
+
+ # Block addresses per sequence (Seq id -> list of physical block).(batch_size, max_blocks_per_seq)
|
| 57 |
+
+
|
| 58 |
+
block_tables: torch.Tensor = None
|
| 59 |
+
|
| 60 |
+
# The indices of the token slots that input tokens will be stored into.
|
| 61 |
+
# E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the
|
| 62 |
+
# three tokens are stored in the 3rd slot in block 2, 2nd slot in block 0,
|
| 63 |
+
- # and 1st slot in block 1, respectively.
|
| 64 |
+
- # (num_tokens,)
|
| 65 |
+
+ # and 1st slot in block 1, respectively. (num_tokens,)
|
| 66 |
+
+
|
| 67 |
+
slot_mapping: torch.Tensor = None
|
| 68 |
+
|
| 69 |
+
# *************************** Other Properties *************************** #
|
| 70 |
+
enable_dbo_across_dp: bool = False
|
| 71 |
+
|
| 72 |
+
+ # Patch for param sink
|
| 73 |
+
+ sink_block_tables: Optional[List[torch.Tensor]] = None
|
| 74 |
+
+ sink_attn_mask: Optional[torch.Tensor] = None
|
| 75 |
+
+ sink_seq_kvlens: torch.Tensor = None
|
| 76 |
+
+ swa_seq_qlens: torch.Tensor = None
|
| 77 |
+
|
| 78 |
+
class AscendAttentionMetadataBuilder:
|
| 79 |
+
# Does this backend/builder support ACL Graphs for attention (default: no).
|
| 80 |
+
@@ -182,6 +193,7 @@ class AscendAttentionMetadataBuilder:
|
| 81 |
+
self.max_num_blocks_per_req = cdiv(
|
| 82 |
+
self.model_config.max_model_len,
|
| 83 |
+
AscendAttentionBackend.get_supported_block_size()[0])
|
| 84 |
+
+ self.param_sink_number = self.model_config.hf_config.param_sink_number
|
| 85 |
+
|
| 86 |
+
def reorder_batch(self, input_batch,
|
| 87 |
+
scheduler_output: "SchedulerOutput") -> bool:
|
| 88 |
+
@@ -210,6 +222,33 @@ class AscendAttentionMetadataBuilder:
|
| 89 |
+
query_start_loc = query_start_loc_cpu.to(self.device,
|
| 90 |
+
non_blocking=True)
|
| 91 |
+
|
| 92 |
+
+ num_input_tokens = common_attn_metadata.num_input_tokens
|
| 93 |
+
+
|
| 94 |
+
+
|
| 95 |
+
+ if num_input_tokens > num_reqs and attn_state == AscendAttentionState.DecodeOnly:
|
| 96 |
+
+ tokens_gap_num = num_input_tokens-num_reqs
|
| 97 |
+
+
|
| 98 |
+
+ sink_block_tables = F.pad(block_table, (1, 0, 0, tokens_gap_num), value=0)
|
| 99 |
+
+
|
| 100 |
+
+ sink_seq_kvlens = seq_lens + self.param_sink_number
|
| 101 |
+
+ sink_seq_kvlens = torch.cat([sink_seq_kvlens, torch.full((tokens_gap_num,), \
|
| 102 |
+
+ self.param_sink_number, dtype=torch.int32)], dim=0)
|
| 103 |
+
+
|
| 104 |
+
+ gap_query_lens = torch.cat([query_lens, torch.ones(tokens_gap_num, dtype=torch.int32)], dim=0)
|
| 105 |
+
+ swa_seq_qlens = torch.cumsum(gap_query_lens, dim=0).to(dtype=torch.int32)
|
| 106 |
+
+ else:
|
| 107 |
+
+ sink_block_tables = F.pad(block_table, (1, 0, 0, 0), value=0)
|
| 108 |
+
+ sink_seq_kvlens = seq_lens + self.param_sink_number
|
| 109 |
+
+ swa_seq_qlens = torch.cumsum(query_lens, dim=0).to(dtype=torch.int32)
|
| 110 |
+
+
|
| 111 |
+
+
|
| 112 |
+
+ if attn_mask is not None:
|
| 113 |
+
+ sink_attn_mask = F.pad(attn_mask, (self.param_sink_number, 0, 0, 0), value=0)
|
| 114 |
+
+ else:
|
| 115 |
+
+ sink_attn_mask = None
|
| 116 |
+
+
|
| 117 |
+
+
|
| 118 |
+
+
|
| 119 |
+
if is_310p():
|
| 120 |
+
if attn_state == AscendAttentionState.PrefillNoCache:
|
| 121 |
+
mask_nz = nd_to_nz_2d(attn_mask)
|
| 122 |
+
@@ -230,7 +269,12 @@ class AscendAttentionMetadataBuilder:
|
| 123 |
+
slot_mapping=slot_mapping,
|
| 124 |
+
attn_mask=attn_mask,
|
| 125 |
+
attn_state=attn_state,
|
| 126 |
+
- enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
|
| 127 |
+
+ enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
|
| 128 |
+
+ sink_block_tables=sink_block_tables,
|
| 129 |
+
+ sink_attn_mask=sink_attn_mask,
|
| 130 |
+
+ sink_seq_kvlens=sink_seq_kvlens,
|
| 131 |
+
+ swa_seq_qlens=swa_seq_qlens
|
| 132 |
+
+ )
|
| 133 |
+
return attn_metadata
|
| 134 |
+
|
| 135 |
+
def build_for_graph_capture(
|
| 136 |
+
@@ -265,6 +309,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
| 137 |
+
kv_cache_dtype: str,
|
| 138 |
+
logits_soft_cap: Optional[float],
|
| 139 |
+
attn_type: str,
|
| 140 |
+
+ layer_name: str,
|
| 141 |
+
kv_sharing_target_layer_name: Optional[str],
|
| 142 |
+
**kwargs,
|
| 143 |
+
) -> None:
|
| 144 |
+
@@ -287,6 +332,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
| 145 |
+
self.key_cache = None
|
| 146 |
+
self.value_cache = None
|
| 147 |
+
|
| 148 |
+
+ self.layer_idx = extract_layer_index(layer_name)
|
| 149 |
+
+
|
| 150 |
+
+ # Patch for Sink
|
| 151 |
+
+ self.sink_cached = False
|
| 152 |
+
+ self.attn_mask = torch.ones((2048, 2048), dtype=torch.int8, device="npu").triu_(diagonal=1)
|
| 153 |
+
+ self.attn_mask = self.attn_mask.to(torch.bool)
|
| 154 |
+
+
|
| 155 |
+
def _forward_prefill_no_cache(
|
| 156 |
+
self,
|
| 157 |
+
query: torch.Tensor,
|
| 158 |
+
@@ -295,6 +347,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
| 159 |
+
attn_metadata: AscendMetadata,
|
| 160 |
+
output: Optional[torch.Tensor] = None,
|
| 161 |
+
num_tokens=0,
|
| 162 |
+
+ param_sink_number: Optional[int] = 0
|
| 163 |
+
) -> torch.Tensor:
|
| 164 |
+
assert attn_metadata is not None
|
| 165 |
+
assert attn_metadata.attn_mask is not None
|
| 166 |
+
@@ -311,18 +364,72 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
| 167 |
+
mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
|
| 168 |
+
mask = torch_npu.npu_format_cast(mask.contiguous(),
|
| 169 |
+
ACL_FORMAT_FRACTAL_NZ)
|
| 170 |
+
+ if torch.version.cann.startswith("8.3"):
|
| 171 |
+
+ mask = torch.ones((2048, 2048), dtype=torch.int8, device=mask.device).triu_(diagonal=1)
|
| 172 |
+
+ # TODO: nocache swa
|
| 173 |
+
+ if param_sink_number > 0:
|
| 174 |
+
+ query_lens = attn_metadata.query_lens
|
| 175 |
+
+ seq_lens = attn_metadata.seq_lens + param_sink_number
|
| 176 |
+
+ output, _ = torch.ops.custom.npu_fused_infer_attention_sink(
|
| 177 |
+
+ query,
|
| 178 |
+
+ key,
|
| 179 |
+
+ value,
|
| 180 |
+
+ atten_mask=mask,
|
| 181 |
+
+ actual_seq_qlen=query_lens,
|
| 182 |
+
+ actual_seq_kvlen=seq_lens,
|
| 183 |
+
+ num_query_heads=self.num_heads,
|
| 184 |
+
+ num_key_value_heads=self.num_kv_heads,
|
| 185 |
+
+ input_layout='TND',
|
| 186 |
+
+ sparse_mode=3,
|
| 187 |
+
+ sink_number=param_sink_number,
|
| 188 |
+
+ softmax_scale=self.scale,
|
| 189 |
+
+ )
|
| 190 |
+
+ else:
|
| 191 |
+
+ output, _ = torch_npu.npu_fused_infer_attention_score(
|
| 192 |
+
+ query=query,
|
| 193 |
+
+ key=key,
|
| 194 |
+
+ value=value,
|
| 195 |
+
+ atten_mask=mask,
|
| 196 |
+
+ input_layout="TND",
|
| 197 |
+
+ actual_seq_lengths=attn_metadata.query_start_loc[1:],
|
| 198 |
+
+ actual_seq_lengths_kv=attn_metadata.seq_lens,
|
| 199 |
+
+ num_key_value_heads=self.num_kv_heads,
|
| 200 |
+
+ num_heads=self.num_heads,
|
| 201 |
+
+ scale=self.scale,
|
| 202 |
+
+ sparse_mode=3,
|
| 203 |
+
+ )
|
| 204 |
+
+ return output
|
| 205 |
+
+ # Patch for sink on CANN8.2
|
| 206 |
+
+ if param_sink_number > 0:
|
| 207 |
+
+ seq_lens = attn_metadata.seq_lens + param_sink_number
|
| 208 |
+
+ # TODO: _npu_flash_attention only allows qlen==kvlen,
|
| 209 |
+
+ mask_elem = mask[0, -1]
|
| 210 |
+
+ sink_mask = torch.full((mask.size(0) + param_sink_number,
|
| 211 |
+
+ mask.size(1) + param_sink_number),
|
| 212 |
+
+ mask_elem, dtype=mask.dtype, device=mask.device)
|
| 213 |
+
+ sink_mask[param_sink_number:, :param_sink_number] = 0.0
|
| 214 |
+
+ sink_mask[param_sink_number:, param_sink_number:] = mask
|
| 215 |
+
+ sink_mask[:param_sink_number, :param_sink_number].triu_(diagonal=1)
|
| 216 |
+
+ mask = sink_mask
|
| 217 |
+
+
|
| 218 |
+
+ output = torch.zeros((output.size(0) + param_sink_number,
|
| 219 |
+
+ output.size(1), output.size(2)),
|
| 220 |
+
+ dtype=output.dtype,
|
| 221 |
+
+ device=output.device)
|
| 222 |
+
+ else:
|
| 223 |
+
+ seq_lens = attn_metadata.seq_lens
|
| 224 |
+
|
| 225 |
+
torch_npu._npu_flash_attention(query=query,
|
| 226 |
+
key=key,
|
| 227 |
+
value=value,
|
| 228 |
+
mask=mask,
|
| 229 |
+
- seq_len=attn_metadata.seq_lens,
|
| 230 |
+
+ seq_len=seq_lens,
|
| 231 |
+
scale_value=self.scale,
|
| 232 |
+
num_heads=self.num_heads,
|
| 233 |
+
num_kv_heads=self.num_kv_heads,
|
| 234 |
+
out=output)
|
| 235 |
+
assert output is not None
|
| 236 |
+
- return output[:num_tokens, :, :]
|
| 237 |
+
+ return output[param_sink_number:param_sink_number + num_tokens, :, :]
|
| 238 |
+
|
| 239 |
+
def _forward_prefill_cache_hit(
|
| 240 |
+
self,
|
| 241 |
+
@@ -356,6 +463,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
| 242 |
+
query: torch.Tensor,
|
| 243 |
+
attn_metadata: AscendMetadata,
|
| 244 |
+
output: Optional[torch.Tensor] = None,
|
| 245 |
+
+ layer: AttentionLayer = None,
|
| 246 |
+
+ param_sink_number: Optional[int] = 0
|
| 247 |
+
) -> torch.Tensor:
|
| 248 |
+
if is_310p():
|
| 249 |
+
# seq_lens_tensor needs to be transferred to the device for 310P.
|
| 250 |
+
@@ -426,16 +535,46 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
| 251 |
+
handle = torch.npu.graph_task_group_end(stream)
|
| 252 |
+
graph_params.handles[num_tokens].append(handle)
|
| 253 |
+
else:
|
| 254 |
+
- torch_npu._npu_paged_attention(
|
| 255 |
+
- query=query,
|
| 256 |
+
- key_cache=self.key_cache,
|
| 257 |
+
- value_cache=self.value_cache,
|
| 258 |
+
- num_kv_heads=self.num_kv_heads,
|
| 259 |
+
- num_heads=self.num_heads,
|
| 260 |
+
- scale_value=self.scale,
|
| 261 |
+
- block_table=attn_metadata.block_tables,
|
| 262 |
+
- context_lens=attn_metadata.seq_lens,
|
| 263 |
+
- out=output)
|
| 264 |
+
+ # Patch for Sparse KV cache of SWA.
|
| 265 |
+
+ num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
| 266 |
+
+ key = self.key_cache.view( # type: ignore
|
| 267 |
+
+ num_block, block_size, -1)
|
| 268 |
+
+ value = self.value_cache.view( # type: ignore
|
| 269 |
+
+ num_block, block_size, -1)
|
| 270 |
+
+ block_tables = attn_metadata.sink_block_tables
|
| 271 |
+
+ use_swa = (self.layer_idx % 2 == 0)
|
| 272 |
+
+ seq_kvlens = attn_metadata.sink_seq_kvlens
|
| 273 |
+
+ if use_swa:
|
| 274 |
+
+ attn_mask = self.attn_mask.to(query.device, non_blocking=True)
|
| 275 |
+
+
|
| 276 |
+
+ output, _ = torch.ops.custom.npu_fused_infer_attention_sink(
|
| 277 |
+
+ query,
|
| 278 |
+
+ key,
|
| 279 |
+
+ value,
|
| 280 |
+
+ atten_mask=attn_mask,
|
| 281 |
+
+ actual_seq_qlen=attn_metadata.swa_seq_qlens,
|
| 282 |
+
+ actual_seq_kvlen=seq_kvlens,
|
| 283 |
+
+ block_table=block_tables,
|
| 284 |
+
+ pre_tokens=128,
|
| 285 |
+
+ next_tokens=0,
|
| 286 |
+
+ num_query_heads=self.num_heads,
|
| 287 |
+
+ num_key_value_heads=self.num_kv_heads,
|
| 288 |
+
+ input_layout='TND',
|
| 289 |
+
+ sparse_mode=4,
|
| 290 |
+
+ block_size=block_size,
|
| 291 |
+
+ sink_number=param_sink_number,
|
| 292 |
+
+ softmax_scale=self.scale)
|
| 293 |
+
+ else:
|
| 294 |
+
+ torch_npu._npu_paged_attention(
|
| 295 |
+
+ query=query,
|
| 296 |
+
+ key_cache=self.key_cache,
|
| 297 |
+
+ value_cache=self.value_cache,
|
| 298 |
+
+ num_kv_heads=self.num_kv_heads,
|
| 299 |
+
+ num_heads=self.num_heads,
|
| 300 |
+
+ scale_value=self.scale,
|
| 301 |
+
+ block_table=block_tables,
|
| 302 |
+
+ context_lens=seq_kvlens,
|
| 303 |
+
+ out=output)
|
| 304 |
+
return output
|
| 305 |
+
|
| 306 |
+
def _forward_v1_style(
|
| 307 |
+
@@ -443,6 +582,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
| 308 |
+
query: torch.Tensor,
|
| 309 |
+
attn_metadata: AscendMetadata,
|
| 310 |
+
output: Optional[torch.Tensor] = None,
|
| 311 |
+
+ param_sink_number: Optional[int] = 0
|
| 312 |
+
) -> torch.Tensor:
|
| 313 |
+
# Use chunked prefill for head size 192 scenario, like deepseek
|
| 314 |
+
# paged_attention_splitfuse maybe crash at such scenario.
|
| 315 |
+
@@ -485,34 +625,87 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
| 316 |
+
value = self.value_cache.view( # type: ignore
|
| 317 |
+
num_block, block_size, -1)
|
| 318 |
+
|
| 319 |
+
- output, _ = torch_npu.npu_fused_infer_attention_score(
|
| 320 |
+
- query=query,
|
| 321 |
+
- key=key,
|
| 322 |
+
- value=value,
|
| 323 |
+
- atten_mask=attn_metadata.attn_mask,
|
| 324 |
+
- block_table=attn_metadata.block_tables,
|
| 325 |
+
- input_layout="TND",
|
| 326 |
+
- block_size=block_size,
|
| 327 |
+
- actual_seq_lengths=attn_metadata.query_start_loc[1:],
|
| 328 |
+
- actual_seq_lengths_kv=attn_metadata.seq_lens,
|
| 329 |
+
- num_key_value_heads=self.num_kv_heads,
|
| 330 |
+
- num_heads=self.num_heads,
|
| 331 |
+
- scale=self.scale,
|
| 332 |
+
- sparse_mode=3,
|
| 333 |
+
- )
|
| 334 |
+
+ #TODO: swa层,window长度 传参
|
| 335 |
+
+ use_swa = (self.layer_idx % 2 == 0)
|
| 336 |
+
+ sparse_mode = 4 if use_swa else 3
|
| 337 |
+
+ if param_sink_number > 0:
|
| 338 |
+
+ if sparse_mode == 4:
|
| 339 |
+
+ output, _ = torch.ops.custom.npu_fused_infer_attention_sink(
|
| 340 |
+
+ query,
|
| 341 |
+
+ key,
|
| 342 |
+
+ value,
|
| 343 |
+
+ atten_mask=self.attn_mask,
|
| 344 |
+
+ actual_seq_qlen=attn_metadata.swa_seq_qlens,
|
| 345 |
+
+ actual_seq_kvlen=attn_metadata.sink_seq_kvlens,
|
| 346 |
+
+ block_table=attn_metadata.sink_block_tables,
|
| 347 |
+
+ pre_tokens=128,
|
| 348 |
+
+ next_tokens=0,
|
| 349 |
+
+ num_query_heads=self.num_heads,
|
| 350 |
+
+ num_key_value_heads=self.num_kv_heads,
|
| 351 |
+
+ input_layout='TND',
|
| 352 |
+
+ sparse_mode=4,
|
| 353 |
+
+ block_size=block_size,
|
| 354 |
+
+ sink_number=param_sink_number,
|
| 355 |
+
+ softmax_scale=self.scale
|
| 356 |
+
+ )
|
| 357 |
+
+ elif sparse_mode == 3:
|
| 358 |
+
+ output, _ = torch.ops.custom.npu_fused_infer_attention_sink(
|
| 359 |
+
+ query,
|
| 360 |
+
+ key,
|
| 361 |
+
+ value,
|
| 362 |
+
+ atten_mask=self.attn_mask,
|
| 363 |
+
+ actual_seq_qlen=attn_metadata.swa_seq_qlens,
|
| 364 |
+
+ actual_seq_kvlen=attn_metadata.sink_seq_kvlens,
|
| 365 |
+
+ block_table=attn_metadata.sink_block_tables,
|
| 366 |
+
+ num_query_heads=self.num_heads,
|
| 367 |
+
+ num_key_value_heads=self.num_kv_heads,
|
| 368 |
+
+ input_layout='TND',
|
| 369 |
+
+ sparse_mode=3,
|
| 370 |
+
+ block_size=block_size,
|
| 371 |
+
+ sink_number=param_sink_number,
|
| 372 |
+
+ softmax_scale=self.scale
|
| 373 |
+
+ )
|
| 374 |
+
+
|
| 375 |
+
+ else:
|
| 376 |
+
+ output, _ = torch_npu.npu_fused_infer_attention_score(
|
| 377 |
+
+ query=query,
|
| 378 |
+
+ key=key,
|
| 379 |
+
+ value=value,
|
| 380 |
+
+ atten_mask=attn_metadata.attn_mask,
|
| 381 |
+
+ block_table=attn_metadata.block_tables,
|
| 382 |
+
+ input_layout="TND",
|
| 383 |
+
+ block_size=block_size,
|
| 384 |
+
+ actual_seq_lengths=attn_metadata.query_start_loc[1:],
|
| 385 |
+
+ actual_seq_lengths_kv=attn_metadata.seq_lens,
|
| 386 |
+
+ num_key_value_heads=self.num_kv_heads,
|
| 387 |
+
+ num_heads=self.num_heads,
|
| 388 |
+
+ scale=self.scale,
|
| 389 |
+
+ sparse_mode=3,
|
| 390 |
+
+ )
|
| 391 |
+
else:
|
| 392 |
+
+ # Patch for sink on CANN 8.2
|
| 393 |
+
+ if param_sink_number > 0:
|
| 394 |
+
+ seq_kvlens = attn_metadata.seq_lens + param_sink_number
|
| 395 |
+
+ block_tables = F.pad(attn_metadata.block_tables, (1, 0, 0, 0), value=0)
|
| 396 |
+
+ mask = F.pad(attn_metadata.attn_mask, (param_sink_number, 0, 0, 0), value=0)
|
| 397 |
+
+ else:
|
| 398 |
+
+ seq_kvlens = attn_metadata.seq_lens
|
| 399 |
+
+ block_tables = attn_metadata.block_tables
|
| 400 |
+
+ mask = attn_metadata.attn_mask
|
| 401 |
+
+
|
| 402 |
+
torch_npu._npu_paged_attention_splitfuse(
|
| 403 |
+
query=query,
|
| 404 |
+
key_cache=self.key_cache,
|
| 405 |
+
value_cache=self.value_cache,
|
| 406 |
+
- mask=attn_metadata.attn_mask,
|
| 407 |
+
- block_table=attn_metadata.block_tables,
|
| 408 |
+
+ mask=mask,
|
| 409 |
+
+ block_table=block_tables,
|
| 410 |
+
seq_len=attn_metadata.query_lens,
|
| 411 |
+
- context_lens=attn_metadata.seq_lens,
|
| 412 |
+
+ context_lens=seq_kvlens,
|
| 413 |
+
num_kv_heads=self.num_kv_heads,
|
| 414 |
+
num_heads=self.num_heads,
|
| 415 |
+
scale_value=self.scale,
|
| 416 |
+
out=output)
|
| 417 |
+
+
|
| 418 |
+
return output
|
| 419 |
+
|
| 420 |
+
def forward(
|
| 421 |
+
@@ -525,6 +718,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
| 422 |
+
attn_metadata: AscendMetadata,
|
| 423 |
+
output: Optional[torch.Tensor] = None,
|
| 424 |
+
trace_flag: bool = True,
|
| 425 |
+
+ sink_query: Optional[torch.Tensor] = None,
|
| 426 |
+
+ sink_key: Optional[torch.Tensor] = None,
|
| 427 |
+
+ sink_value: Optional[torch.Tensor] = None,
|
| 428 |
+
+ v_head_size: Optional[int] = None,
|
| 429 |
+
) -> torch.Tensor:
|
| 430 |
+
"""Forward pass with Ascend attention.
|
| 431 |
+
Args:
|
| 432 |
+
@@ -556,7 +753,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
| 433 |
+
key=key,
|
| 434 |
+
value=value,
|
| 435 |
+
output=output,
|
| 436 |
+
- layer_name=layer.layer_name)
|
| 437 |
+
+ layer_name=layer.layer_name,
|
| 438 |
+
+ sink_query=sink_query,
|
| 439 |
+
+ sink_key=sink_key,
|
| 440 |
+
+ sink_value=sink_value,
|
| 441 |
+
+ v_head_size=v_head_size
|
| 442 |
+
+ )
|
| 443 |
+
|
| 444 |
+
elif hasattr(layer, 'quant_method') and use_kv_cache_int8:
|
| 445 |
+
output = layer.quant_method.apply(layer, query, key, value,
|
| 446 |
+
@@ -575,10 +777,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
| 447 |
+
"encoder/decoder cross-attention "
|
| 448 |
+
"are not implemented for "
|
| 449 |
+
"PallasAttentionBackendImpl")
|
| 450 |
+
+ sink_key_flag = (sink_key is not None)
|
| 451 |
+
+ param_sink_number = sink_key.shape[0] if sink_key_flag else 0
|
| 452 |
+
# View q k v to BSH.
|
| 453 |
+
query = query.view(-1, self.num_heads, self.head_size)
|
| 454 |
+
key = key.view(-1, self.num_kv_heads, self.head_size)
|
| 455 |
+
- value = value.view(-1, self.num_kv_heads, self.head_size)
|
| 456 |
+
+ value = value.view(-1, self.num_kv_heads,
|
| 457 |
+
+ v_head_size if v_head_size is not None else self.head_size)
|
| 458 |
+
# TODO: Remove this contiguous in the future.
|
| 459 |
+
value = value.contiguous()
|
| 460 |
+
|
| 461 |
+
@@ -586,33 +791,63 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
| 462 |
+
if self.key_cache is None:
|
| 463 |
+
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
| 464 |
+
slots = attn_metadata.slot_mapping
|
| 465 |
+
+
|
| 466 |
+
torch_npu._npu_reshape_and_cache(
|
| 467 |
+
key=key[:num_actual_tokens],
|
| 468 |
+
value=value[:num_actual_tokens],
|
| 469 |
+
key_cache=self.key_cache,
|
| 470 |
+
value_cache=self.value_cache,
|
| 471 |
+
slot_indices=slots)
|
| 472 |
+
-
|
| 473 |
+
+ if sink_key_flag and not self.sink_cached:
|
| 474 |
+
+ # kv cache start from block 1 and slots 128, so we store sink in block 0.
|
| 475 |
+
+ slots = torch.arange(0, param_sink_number,
|
| 476 |
+
+ dtype=attn_metadata.slot_mapping.dtype,
|
| 477 |
+
+ device=attn_metadata.slot_mapping.device)
|
| 478 |
+
+ torch_npu._npu_reshape_and_cache(
|
| 479 |
+
+ key=sink_key,
|
| 480 |
+
+ value=sink_value,
|
| 481 |
+
+ key_cache=self.key_cache,
|
| 482 |
+
+ value_cache=self.value_cache,
|
| 483 |
+
+ slot_indices=slots)
|
| 484 |
+
+ self.sink_cached = True
|
| 485 |
+
+
|
| 486 |
+
+ # TODO: 暂不进PrefillCacheHit分支,不更新sink实现
|
| 487 |
+
+ if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache and sink_key_flag:
|
| 488 |
+
+ attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
|
| 489 |
+
# V0-Style scheduler situation.
|
| 490 |
+
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
| 491 |
+
+ if torch.version.cann.startswith("8.3"):
|
| 492 |
+
+ # npu_fused_infer_attention_score and npu_fused_infer_attention_sink
|
| 493 |
+
+ # does not support cases where query.shape[0] != actual_seq_lengths
|
| 494 |
+
+ # Thus we need unpad it here.
|
| 495 |
+
+ num_tokens = attn_metadata.query_start_loc[-1]
|
| 496 |
+
+ query = query[:num_tokens]
|
| 497 |
+
+ key = key[:num_tokens]
|
| 498 |
+
+ value = value[:num_tokens]
|
| 499 |
+
+ elif sink_key_flag:
|
| 500 |
+
+ query = torch.cat([sink_query, query], dim=0)
|
| 501 |
+
+ if sink_key_flag:
|
| 502 |
+
+ key = torch.cat([sink_key, key], dim=0)
|
| 503 |
+
+ value = torch.cat([sink_value, value], dim=0)
|
| 504 |
+
output = self._forward_prefill_no_cache(
|
| 505 |
+
- query, key, value, attn_metadata, output, num_tokens)
|
| 506 |
+
+ query, key, value, attn_metadata, output, num_tokens,
|
| 507 |
+
+ param_sink_number
|
| 508 |
+
+ )
|
| 509 |
+
elif attn_metadata.attn_state == \
|
| 510 |
+
AscendAttentionState.PrefillCacheHit:
|
| 511 |
+
output = self._forward_prefill_cache_hit(
|
| 512 |
+
query, attn_metadata, output)
|
| 513 |
+
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
| 514 |
+
output = self._forward_decode_only(query, attn_metadata,
|
| 515 |
+
- output)
|
| 516 |
+
+ output, layer,
|
| 517 |
+
+ param_sink_number)
|
| 518 |
+
# Normal V1 situation.
|
| 519 |
+
else:
|
| 520 |
+
if torch.version.cann.startswith("8.3"):
|
| 521 |
+
- # npu_fused_infer_attention_score does not support cases
|
| 522 |
+
- # where query.shape[0] != attn_metadata.query_start_loc[-1].
|
| 523 |
+
- # Thus we need unpad it here.
|
| 524 |
+
num_tokens = attn_metadata.query_start_loc[-1]
|
| 525 |
+
query = query[:num_tokens]
|
| 526 |
+
- output = self._forward_v1_style(query, attn_metadata, output)
|
| 527 |
+
+ output = self._forward_v1_style(query, attn_metadata, output,
|
| 528 |
+
+ param_sink_number)
|
| 529 |
+
|
| 530 |
+
# to make in-place change to the output tensor
|
| 531 |
+
if hasattr(layer, 'quant_method') and use_kv_cache_int8:
|
| 532 |
+
@@ -627,6 +862,10 @@ def unified_ascend_attention_with_output(
|
| 533 |
+
value: torch.Tensor,
|
| 534 |
+
output: torch.Tensor,
|
| 535 |
+
layer_name: str,
|
| 536 |
+
+ sink_query: Optional[torch.Tensor] = None,
|
| 537 |
+
+ sink_key: Optional[torch.Tensor] = None,
|
| 538 |
+
+ sink_value: Optional[torch.Tensor] = None,
|
| 539 |
+
+ v_head_size: Optional[int] = None,
|
| 540 |
+
) -> None:
|
| 541 |
+
wait_for_kv_layer_from_connector(layer_name)
|
| 542 |
+
forward_context: ForwardContext = get_forward_context()
|
| 543 |
+
@@ -642,7 +881,11 @@ def unified_ascend_attention_with_output(
|
| 544 |
+
kv_cache,
|
| 545 |
+
attn_metadata,
|
| 546 |
+
output,
|
| 547 |
+
- trace_flag=False)
|
| 548 |
+
+ trace_flag=False,
|
| 549 |
+
+ sink_query=sink_query,
|
| 550 |
+
+ sink_key=sink_key,
|
| 551 |
+
+ sink_value=sink_value,
|
| 552 |
+
+ v_head_size=v_head_size)
|
| 553 |
+
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
| 554 |
+
return
|
| 555 |
+
|
| 556 |
+
@@ -653,6 +896,11 @@ def unified_attention_with_output_fake(
|
| 557 |
+
value: torch.Tensor,
|
| 558 |
+
output: torch.Tensor,
|
| 559 |
+
layer_name: str,
|
| 560 |
+
+ # patch for pangu with attention sink
|
| 561 |
+
+ sink_query: Optional[torch.Tensor] = None,
|
| 562 |
+
+ sink_key: Optional[torch.Tensor] = None,
|
| 563 |
+
+ sink_value: Optional[torch.Tensor] = None,
|
| 564 |
+
+ v_head_size: Optional[int] = None,
|
| 565 |
+
) -> None:
|
| 566 |
+
return
|
| 567 |
+
|
| 568 |
+
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
|
| 569 |
+
index 519cde0..93e1c95 100644
|
| 570 |
+
--- a/vllm_ascend/attention/utils.py
|
| 571 |
+
+++ b/vllm_ascend/attention/utils.py
|
| 572 |
+
@@ -63,6 +63,8 @@ class AscendCommonAttentionMetadata:
|
| 573 |
+
|
| 574 |
+
graph_pad_size: int = -1
|
| 575 |
+
|
| 576 |
+
+ num_input_tokens: int = -1
|
| 577 |
+
+
|
| 578 |
+
|
| 579 |
+
def split_decodes_and_prefills(
|
| 580 |
+
common_attn_metadata: AscendCommonAttentionMetadata,
|
| 581 |
+
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
|
| 582 |
+
index f1581df..b690bcb 100644
|
| 583 |
+
--- a/vllm_ascend/platform.py
|
| 584 |
+
+++ b/vllm_ascend/platform.py
|
| 585 |
+
@@ -216,6 +216,9 @@ class NPUPlatform(Platform):
|
| 586 |
+
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
|
| 587 |
+
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
| 588 |
+
|
| 589 |
+
+ if compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
|
| 590 |
+
+ compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE_DECODE_ONLY
|
| 591 |
+
+
|
| 592 |
+
if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
|
| 593 |
+
compilation_config.level = CompilationLevel.NO_COMPILATION
|
| 594 |
+
# TODO: Currently MLA does not support FULL_DECODE_ONLY, remove the second condition
|
| 595 |
+
@@ -223,7 +226,8 @@ class NPUPlatform(Platform):
|
| 596 |
+
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or (
|
| 597 |
+
compilation_config.cudagraph_mode
|
| 598 |
+
== CUDAGraphMode.FULL_DECODE_ONLY and model_config is not None
|
| 599 |
+
- and model_config.use_mla):
|
| 600 |
+
+ and model_config.use_mla) or (
|
| 601 |
+
+ compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE_DECODE_ONLY):
|
| 602 |
+
logger.info(
|
| 603 |
+
"PIECEWISE compilation enabled on NPU. use_inductor not supported - "
|
| 604 |
+
"using only ACL Graph mode")
|
| 605 |
+
@@ -232,7 +236,8 @@ class NPUPlatform(Platform):
|
| 606 |
+
compilation_config.set_splitting_ops_for_v1()
|
| 607 |
+
compilation_config.use_inductor = False
|
| 608 |
+
compilation_config.splitting_ops.extend([
|
| 609 |
+
- "vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
|
| 610 |
+
+ "vllm.unified_ascend_attention_with_output", "vllm.mla_forward",
|
| 611 |
+
+ "vllm.aggregate_hiddden",
|
| 612 |
+
])
|
| 613 |
+
update_aclgraph_sizes(vllm_config)
|
| 614 |
+
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
| 615 |
+
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
|
| 616 |
+
index 9281dd7..34808ec 100644
|
| 617 |
+
--- a/vllm_ascend/worker/model_runner_v1.py
|
| 618 |
+
+++ b/vllm_ascend/worker/model_runner_v1.py
|
| 619 |
+
@@ -281,6 +281,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
| 620 |
+
self.encoder_cache: Dict[str, torch.Tensor] = {}
|
| 621 |
+
self.attn_mask = None
|
| 622 |
+
self.attn_state = None
|
| 623 |
+
+ self.with_prefill = False
|
| 624 |
+
self.requests: Dict[str, CachedRequestState] = {}
|
| 625 |
+
self.intermediate_tensors: Optional[IntermediateTensors] = None
|
| 626 |
+
self.runner_only_attn_layers: set[str] = set()
|
| 627 |
+
@@ -509,6 +510,48 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
| 628 |
+
self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
|
| 629 |
+
dtype=torch.int32)
|
| 630 |
+
|
| 631 |
+
+ # Patch for conv cache
|
| 632 |
+
+ self.router_sliding_window = getattr(self.model_config.hf_text_config, "router_sliding_window", 0)
|
| 633 |
+
+ if self.router_sliding_window > 1:
|
| 634 |
+
+ self.cache_length = self.router_sliding_window - 1
|
| 635 |
+
+ self.req_cache_map = {}
|
| 636 |
+
+ self.occupied_cache = [0]*(self.max_num_reqs)
|
| 637 |
+
+ self.q_offsets = torch.arange(-self.cache_length, 0, device=self.device)
|
| 638 |
+
+ self.cache_slot_id = torch.empty(self.max_num_reqs,
|
| 639 |
+
+ dtype=torch.long, device=self.device)
|
| 640 |
+
+ self.is_first_chunk = torch.empty(self.max_num_reqs, dtype=torch.bool, device=self.device) # For chunked prefill
|
| 641 |
+
+
|
| 642 |
+
+ def _build_conv_context(self, with_prefill:bool = False, dummy:bool = False, num_tokens:int = 0):
|
| 643 |
+
+ # conv cache slot & prefill hiddenstates loc
|
| 644 |
+
+ cache_slot_id = self.cache_slot_id[:self.input_batch.num_reqs]
|
| 645 |
+
+ query_start_loc = self.query_start_loc[:self.input_batch.num_reqs + 1]
|
| 646 |
+
+ is_first_chunk = self.is_first_chunk[:self.input_batch.num_reqs]
|
| 647 |
+
+
|
| 648 |
+
+ if with_prefill:
|
| 649 |
+
+ for idx, req_id in enumerate(self.input_batch.req_ids):
|
| 650 |
+
+ if req_id in self.req_cache_map:
|
| 651 |
+
+ cache_id = self.req_cache_map[req_id]
|
| 652 |
+
+ cache_slot_id[idx] = cache_id
|
| 653 |
+
+ is_first_chunk[idx] = False
|
| 654 |
+
+ else:
|
| 655 |
+
+ # new request with the first chunk
|
| 656 |
+
+ new_cahce_id = self.occupied_cache.index(0)
|
| 657 |
+
+ self.occupied_cache[new_cahce_id] = 1
|
| 658 |
+
+ self.req_cache_map[req_id] = new_cahce_id
|
| 659 |
+
+ cache_slot_id[idx] = new_cahce_id
|
| 660 |
+
+ is_first_chunk[idx] = True
|
| 661 |
+
+ else:
|
| 662 |
+
+ for idx, req_id in enumerate(self.input_batch.req_ids):
|
| 663 |
+
+ cache_id = self.req_cache_map[req_id]
|
| 664 |
+
+ cache_slot_id[idx] = cache_id
|
| 665 |
+
+ is_first_chunk[idx] = False
|
| 666 |
+
+
|
| 667 |
+
+ forward_context = get_forward_context()
|
| 668 |
+
+ forward_context.cache_slot_id = cache_slot_id
|
| 669 |
+
+ forward_context.is_first_chunk = is_first_chunk
|
| 670 |
+
+ forward_context.query_start_loc = query_start_loc
|
| 671 |
+
+
|
| 672 |
+
+
|
| 673 |
+
def _make_buffer(self,
|
| 674 |
+
*size: Union[int, torch.SymInt],
|
| 675 |
+
dtype: torch.dtype,
|
| 676 |
+
@@ -548,12 +591,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
| 677 |
+
self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
|
| 678 |
+
|
| 679 |
+
def _use_aclgraph(self) -> bool:
|
| 680 |
+
- return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
|
| 681 |
+
+ return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and \
|
| 682 |
+
+ self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
|
| 683 |
+
|
| 684 |
+
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
| 685 |
+
# Remove finished requests from the cached states.
|
| 686 |
+
for req_id in scheduler_output.finished_req_ids:
|
| 687 |
+
- self.requests.pop(req_id, None)
|
| 688 |
+
+ self.requests.pop(req_id, None)
|
| 689 |
+
+ if self.router_sliding_window > 1 and req_id in self.req_cache_map:
|
| 690 |
+
+ cache_id = self.req_cache_map.pop(req_id)
|
| 691 |
+
+ self.occupied_cache[cache_id] = 0
|
| 692 |
+
|
| 693 |
+
# Remove the finished requests from the persistent batch.
|
| 694 |
+
# NOTE(woosuk): There could be an edge case where finished_req_ids and
|
| 695 |
+
@@ -891,7 +938,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
| 696 |
+
def _make_attention_mask(self, seq_lens, position,
|
| 697 |
+
attn_state) -> torch.Tensor:
|
| 698 |
+
# Chunk Prefill situation.
|
| 699 |
+
- if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa:
|
| 700 |
+
+ if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not \
|
| 701 |
+
+ self.ascend_config.use_sfa:
|
| 702 |
+
if torch.version.cann.startswith("8.3"):
|
| 703 |
+
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
| 704 |
+
else:
|
| 705 |
+
@@ -942,7 +990,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
| 706 |
+
src_end = num_computed_tokens + prompt_part_len
|
| 707 |
+
|
| 708 |
+
self.mrope_positions_cpu[:, dst_start:dst_end] = \
|
| 709 |
+
- req.mrope_positions[:,src_start:src_end]
|
| 710 |
+
+ req.mrope_positions[:, src_start:src_end]
|
| 711 |
+
|
| 712 |
+
mrope_pos_ptr += prompt_part_len
|
| 713 |
+
|
| 714 |
+
@@ -1126,9 +1174,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
| 715 |
+
cumsum_dtype: Optional[np.dtype] = None,
|
| 716 |
+
) -> tuple[np.ndarray, np.ndarray]:
|
| 717 |
+
"""Get the cumulative sum and batched arange of the given array.
|
| 718 |
+
- # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
|
| 719 |
+
- # Equivalent to but faster than:
|
| 720 |
+
- # np.concatenate([np.arange(n) for n in num_tokens])
|
| 721 |
+
+ E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
|
| 722 |
+
+ Equivalent to but faster than:
|
| 723 |
+
+ np.concatenate([np.arange(n) for n in num_tokens])
|
| 724 |
+
"""
|
| 725 |
+
# Step 1. [2, 5, 3] -> [2, 7, 10]
|
| 726 |
+
cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
|
| 727 |
+
@@ -1518,6 +1566,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
| 728 |
+
max_query_len=max_num_scheduled_tokens,
|
| 729 |
+
graph_pad_size=self.graph_pad_size,
|
| 730 |
+
decode_token_per_req=self.decode_token_per_req,
|
| 731 |
+
+ num_input_tokens=num_input_tokens
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
if self.speculative_config and \
|
| 735 |
+
@@ -1964,6 +2013,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
| 736 |
+
model_instance=self.model):
|
| 737 |
+
self.maybe_setup_kv_connector(scheduler_output)
|
| 738 |
+
|
| 739 |
+
+ if self.router_sliding_window > 1:
|
| 740 |
+
+ self._build_conv_context(self.with_prefill)
|
| 741 |
+
+
|
| 742 |
+
hidden_states = self._generate_process_reqs_hidden_states(
|
| 743 |
+
attn_metadata, self.with_prefill, maybe_padded_num_tokens,
|
| 744 |
+
input_ids, positions, intermediate_tensors, inputs_embeds)
|
| 745 |
+
@@ -2339,7 +2391,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
| 746 |
+
) -> torch.Tensor:
|
| 747 |
+
# only support eager mode and piecewise graph now
|
| 748 |
+
assert aclgraph_runtime_mode in {
|
| 749 |
+
- CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
|
| 750 |
+
+ CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL,
|
| 751 |
+
+ CUDAGraphMode.PIECEWISE_DECODE_ONLY
|
| 752 |
+
}
|
| 753 |
+
|
| 754 |
+
# Padding for DP
|
| 755 |
+
@@ -2472,6 +2525,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
| 756 |
+
batch_descriptor=batch_descriptor,
|
| 757 |
+
prefetch_stream=self.prefetch_stream,
|
| 758 |
+
model_instance=self.model):
|
| 759 |
+
+ if self.router_sliding_window > 1:
|
| 760 |
+
+ self._build_conv_context(with_prefill, dummy=True, num_tokens=num_tokens)
|
| 761 |
+
hidden_states = self._generate_dummy_run_hidden_states(
|
| 762 |
+
with_prefill, is_torchair_compile, input_ids, positions,
|
| 763 |
+
attn_metadata, num_tokens, intermediate_tensors,
|
| 764 |
+
@@ -2789,8 +2844,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
| 765 |
+
|
| 766 |
+
# In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory
|
| 767 |
+
# address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but
|
| 768 |
+
- # we found there are also some exceptions during test, so we manual align those memory here, this part
|
| 769 |
+
- # of code may consume 2M * 2 * elem_size memory every layer.
|
| 770 |
+
+ # we found there are also some exceptions during test, so we manual align those memory here,
|
| 771 |
+
+ # this part of code may consume 2M * 2 * elem_size memory every layer.
|
| 772 |
+
nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
|
| 773 |
+
nope_allocate_shape_alignment = nope_allocate_shape + alignment
|
| 774 |
+
rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
|
| 775 |
+
@@ -2888,8 +2943,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
| 776 |
+
|
| 777 |
+
# In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory
|
| 778 |
+
# address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but
|
| 779 |
+
- # we found there are also some exceptions during test, so we manual align those memory here, this part
|
| 780 |
+
- # of code may consume 2M * 2 * elem_size memory every layer.
|
| 781 |
+
+ # we found there are also some exceptions during test, so we manual align those memory here,
|
| 782 |
+
+ # this part of code may consume 2M * 2 * elem_size memory every layer.
|
| 783 |
+
nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
|
| 784 |
+
nope_allocate_shape_alignment = nope_allocate_shape + alignment
|
| 785 |
+
rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
|
| 786 |
+
@@ -3432,6 +3487,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
| 787 |
+
and all(op in self.compilation_config.splitting_ops for op in [
|
| 788 |
+
"vllm.unified_ascend_attention_with_output",
|
| 789 |
+
"vllm.mla_forward",
|
| 790 |
+
+ "vllm.aggregate_hiddden",
|
| 791 |
+
]))
|
| 792 |
+
|
| 793 |
+
# Flexible resolve the aclgraph mode
|
| 794 |
+
@@ -3495,7 +3551,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
| 795 |
+
uniform_decode: bool):
|
| 796 |
+
assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \
|
| 797 |
+
aclgraph_runtime_mode in [CUDAGraphMode.FULL,
|
| 798 |
+
- CUDAGraphMode.PIECEWISE]
|
| 799 |
+
+ CUDAGraphMode.PIECEWISE,
|
| 800 |
+
+ CUDAGraphMode.PIECEWISE_DECODE_ONLY]
|
| 801 |
+
|
| 802 |
+
# Only rank 0 should print progress bar during capture
|
| 803 |
+
if is_global_first_rank():
|
| 804 |
+
@@ -3519,10 +3576,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
| 805 |
+
# attention while `PIECEWISE` implies no attention.
|
| 806 |
+
force_attention = (aclgraph_runtime_mode == CUDAGraphMode.FULL)
|
| 807 |
+
self._dummy_run(num_tokens,
|
| 808 |
+
+ with_prefill = (uniform_decode == False),
|
| 809 |
+
aclgraph_runtime_mode=CUDAGraphMode.NONE,
|
| 810 |
+
force_attention=force_attention,
|
| 811 |
+
uniform_decode=uniform_decode)
|
| 812 |
+
self._dummy_run(num_tokens,
|
| 813 |
+
+ with_prefill = (uniform_decode == False),
|
| 814 |
+
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
| 815 |
+
force_attention=force_attention,
|
| 816 |
+
uniform_decode=uniform_decode)
|
| 817 |
+
@@ -3556,7 +3615,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
| 818 |
+
logger.error(
|
| 819 |
+
f"ACLgraph sizes capture fail: {type(e).__name__}:\n"
|
| 820 |
+
"ACLgraph has insufficient available streams to capture the configured number of sizes. "
|
| 821 |
+
- "Please verify both the availability of adequate streams and the appropriateness of the configured size count.\n\n"
|
| 822 |
+
+ "Please verify both the availability of adequate streams "
|
| 823 |
+
+ "and the appropriateness of the configured size count.\n\n"
|
| 824 |
+
"Recommended solutions:\n"
|
| 825 |
+
"1. Manually configure the compilation_config parameter "
|
| 826 |
+
"with a reduced set of sizes: '{\"cudagraph_capture_sizes\":[size1, size2, size3, ...]}'.\n"
|
| 827 |
+
@@ -3564,8 +3624,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
| 828 |
+
f"{str(e)}")
|
| 829 |
+
raise
|
| 830 |
+
|
| 831 |
+
- if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
|
| 832 |
+
- aclgraph_mode.separate_routine():
|
| 833 |
+
+ if aclgraph_mode.separate_routine() and \
|
| 834 |
+
+ (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL or \
|
| 835 |
+
+ aclgraph_mode.decode_mode() == CUDAGraphMode.PIECEWISE):
|
| 836 |
+
max_num_tokens = self.scheduler_config.max_num_seqs * \
|
| 837 |
+
self.uniform_decode_query_len
|
| 838 |
+
decode_cudagraph_batch_sizes = [
|
| 839 |
+
@@ -3576,7 +3637,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
| 840 |
+
reversed(decode_cudagraph_batch_sizes))
|
| 841 |
+
self._capture_aclgraphs(
|
| 842 |
+
compilation_cases=compilation_cases_decode,
|
| 843 |
+
- aclgraph_runtime_mode=CUDAGraphMode.FULL,
|
| 844 |
+
+ aclgraph_runtime_mode=aclgraph_mode.decode_mode(),
|
| 845 |
+
uniform_decode=True)
|
| 846 |
+
|
| 847 |
+
# Disable aclgraph capturing globally, so any unexpected aclgraph
|