drizzlezyk committed on
Commit
cb4f943
·
verified ·
1 Parent(s): e19cb2e

Upload inference/vllm-ascend_v0.11.0rc0.patch with huggingface_hub

Browse files
inference/vllm-ascend_v0.11.0rc0.patch ADDED
@@ -0,0 +1,847 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
2
+ index d289bb4..0357b50 100644
3
+ --- a/vllm_ascend/attention/attention_v1.py
4
+ +++ b/vllm_ascend/attention/attention_v1.py
5
+ @@ -21,6 +21,7 @@ from typing import ClassVar, List, Optional, Tuple, Type
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ +import torch.nn.functional as F
10
+ import torch_npu
11
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
12
+ AttentionLayer, AttentionType)
13
+ @@ -30,6 +31,7 @@ from vllm.utils import cdiv, direct_register_custom_op
14
+ from vllm.v1.attention.backends.utils import AttentionCGSupport
15
+ from vllm.v1.core.sched.output import SchedulerOutput
16
+ from vllm.v1.kv_cache_interface import AttentionSpec
17
+ +from vllm.model_executor.models.utils import extract_layer_index
18
+
19
+ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
20
+ maybe_save_kv_layer_to_connector,
21
+ @@ -39,6 +41,9 @@ from vllm_ascend.ops.attention import vanilla_chunked_prefill
22
+ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
23
+ nd_to_nz_2d, nd_to_nz_spec)
24
+
25
+ +if torch.version.cann.startswith("8.3"):
26
+ + import omni_custom_ops
27
+ +
28
+
29
+ class AscendAttentionBackend(AttentionBackend):
30
+ accept_output_buffer: bool = True
31
+ @@ -115,6 +120,7 @@ class AscendAttentionBackend(AttentionBackend):
32
+ return [64]
33
+
34
+
35
+ +
36
+ class AscendAttentionState(Enum):
37
+ PrefillNoCache = 0
38
+ PrefillCacheHit = 1
39
+ @@ -135,8 +141,8 @@ class AscendMetadata:
40
+ num_actual_tokens: int = 0
41
+
42
+ # The sequence length per sequence. Sequence length means the computed
43
+ - # tokens + new tokens (is None if it is a decoding).
44
+ - # (batch_size,)
45
+ + # tokens + new tokens (is None if it is a decoding). (batch_size,)
46
+ +
47
+ seq_lens: torch.Tensor = None
48
+
49
+ query_start_loc: torch.Tensor = None
50
+ @@ -145,20 +151,25 @@ class AscendMetadata:
51
+ max_query_len: Optional[int] = None
52
+
53
+ # ********************** KV Cache Related Properties ********************* #
54
+ - # Block addresses per sequence (Seq id -> list of physical block).
55
+ - # (batch_size, max_blocks_per_seq)
56
+ + # Block addresses per sequence (Seq id -> list of physical blocks). (batch_size, max_blocks_per_seq)
57
+ +
58
+ block_tables: torch.Tensor = None
59
+
60
+ # The indices of the token slots that input tokens will be stored into.
61
+ # E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the
62
+ # three tokens are stored in the 3rd slot in block 2, 2nd slot in block 0,
63
+ - # and 1st slot in block 1, respectively.
64
+ - # (num_tokens,)
65
+ + # and 1st slot in block 1, respectively. (num_tokens,)
66
+ +
67
+ slot_mapping: torch.Tensor = None
68
+
69
+ # *************************** Other Properties *************************** #
70
+ enable_dbo_across_dp: bool = False
71
+
72
+ + # Patch for param sink
73
+ + sink_block_tables: Optional[List[torch.Tensor]] = None
74
+ + sink_attn_mask: Optional[torch.Tensor] = None
75
+ + sink_seq_kvlens: torch.Tensor = None
76
+ + swa_seq_qlens: torch.Tensor = None
77
+
78
+ class AscendAttentionMetadataBuilder:
79
+ # Does this backend/builder support ACL Graphs for attention (default: no).
80
+ @@ -182,6 +193,7 @@ class AscendAttentionMetadataBuilder:
81
+ self.max_num_blocks_per_req = cdiv(
82
+ self.model_config.max_model_len,
83
+ AscendAttentionBackend.get_supported_block_size()[0])
84
+ + self.param_sink_number = self.model_config.hf_config.param_sink_number
85
+
86
+ def reorder_batch(self, input_batch,
87
+ scheduler_output: "SchedulerOutput") -> bool:
88
+ @@ -210,6 +222,33 @@ class AscendAttentionMetadataBuilder:
89
+ query_start_loc = query_start_loc_cpu.to(self.device,
90
+ non_blocking=True)
91
+
92
+ + num_input_tokens = common_attn_metadata.num_input_tokens
93
+ +
94
+ +
95
+ + if num_input_tokens > num_reqs and attn_state == AscendAttentionState.DecodeOnly:
96
+ + tokens_gap_num = num_input_tokens-num_reqs
97
+ +
98
+ + sink_block_tables = F.pad(block_table, (1, 0, 0, tokens_gap_num), value=0)
99
+ +
100
+ + sink_seq_kvlens = seq_lens + self.param_sink_number
101
+ + sink_seq_kvlens = torch.cat([sink_seq_kvlens, torch.full((tokens_gap_num,), \
102
+ + self.param_sink_number, dtype=torch.int32)], dim=0)
103
+ +
104
+ + gap_query_lens = torch.cat([query_lens, torch.ones(tokens_gap_num, dtype=torch.int32)], dim=0)
105
+ + swa_seq_qlens = torch.cumsum(gap_query_lens, dim=0).to(dtype=torch.int32)
106
+ + else:
107
+ + sink_block_tables = F.pad(block_table, (1, 0, 0, 0), value=0)
108
+ + sink_seq_kvlens = seq_lens + self.param_sink_number
109
+ + swa_seq_qlens = torch.cumsum(query_lens, dim=0).to(dtype=torch.int32)
110
+ +
111
+ +
112
+ + if attn_mask is not None:
113
+ + sink_attn_mask = F.pad(attn_mask, (self.param_sink_number, 0, 0, 0), value=0)
114
+ + else:
115
+ + sink_attn_mask = None
116
+ +
117
+ +
118
+ +
119
+ if is_310p():
120
+ if attn_state == AscendAttentionState.PrefillNoCache:
121
+ mask_nz = nd_to_nz_2d(attn_mask)
122
+ @@ -230,7 +269,12 @@ class AscendAttentionMetadataBuilder:
123
+ slot_mapping=slot_mapping,
124
+ attn_mask=attn_mask,
125
+ attn_state=attn_state,
126
+ - enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
127
+ + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
128
+ + sink_block_tables=sink_block_tables,
129
+ + sink_attn_mask=sink_attn_mask,
130
+ + sink_seq_kvlens=sink_seq_kvlens,
131
+ + swa_seq_qlens=swa_seq_qlens
132
+ + )
133
+ return attn_metadata
134
+
135
+ def build_for_graph_capture(
136
+ @@ -265,6 +309,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
137
+ kv_cache_dtype: str,
138
+ logits_soft_cap: Optional[float],
139
+ attn_type: str,
140
+ + layer_name: str,
141
+ kv_sharing_target_layer_name: Optional[str],
142
+ **kwargs,
143
+ ) -> None:
144
+ @@ -287,6 +332,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
145
+ self.key_cache = None
146
+ self.value_cache = None
147
+
148
+ + self.layer_idx = extract_layer_index(layer_name)
149
+ +
150
+ + # Patch for Sink
151
+ + self.sink_cached = False
152
+ + self.attn_mask = torch.ones((2048, 2048), dtype=torch.int8, device="npu").triu_(diagonal=1)
153
+ + self.attn_mask = self.attn_mask.to(torch.bool)
154
+ +
155
+ def _forward_prefill_no_cache(
156
+ self,
157
+ query: torch.Tensor,
158
+ @@ -295,6 +347,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
159
+ attn_metadata: AscendMetadata,
160
+ output: Optional[torch.Tensor] = None,
161
+ num_tokens=0,
162
+ + param_sink_number: Optional[int] = 0
163
+ ) -> torch.Tensor:
164
+ assert attn_metadata is not None
165
+ assert attn_metadata.attn_mask is not None
166
+ @@ -311,18 +364,72 @@ class AscendAttentionBackendImpl(AttentionImpl):
167
+ mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
168
+ mask = torch_npu.npu_format_cast(mask.contiguous(),
169
+ ACL_FORMAT_FRACTAL_NZ)
170
+ + if torch.version.cann.startswith("8.3"):
171
+ + mask = torch.ones((2048, 2048), dtype=torch.int8, device=mask.device).triu_(diagonal=1)
172
+ + # TODO: nocache swa
173
+ + if param_sink_number > 0:
174
+ + query_lens = attn_metadata.query_lens
175
+ + seq_lens = attn_metadata.seq_lens + param_sink_number
176
+ + output, _ = torch.ops.custom.npu_fused_infer_attention_sink(
177
+ + query,
178
+ + key,
179
+ + value,
180
+ + atten_mask=mask,
181
+ + actual_seq_qlen=query_lens,
182
+ + actual_seq_kvlen=seq_lens,
183
+ + num_query_heads=self.num_heads,
184
+ + num_key_value_heads=self.num_kv_heads,
185
+ + input_layout='TND',
186
+ + sparse_mode=3,
187
+ + sink_number=param_sink_number,
188
+ + softmax_scale=self.scale,
189
+ + )
190
+ + else:
191
+ + output, _ = torch_npu.npu_fused_infer_attention_score(
192
+ + query=query,
193
+ + key=key,
194
+ + value=value,
195
+ + atten_mask=mask,
196
+ + input_layout="TND",
197
+ + actual_seq_lengths=attn_metadata.query_start_loc[1:],
198
+ + actual_seq_lengths_kv=attn_metadata.seq_lens,
199
+ + num_key_value_heads=self.num_kv_heads,
200
+ + num_heads=self.num_heads,
201
+ + scale=self.scale,
202
+ + sparse_mode=3,
203
+ + )
204
+ + return output
205
+ + # Patch for sink on CANN8.2
206
+ + if param_sink_number > 0:
207
+ + seq_lens = attn_metadata.seq_lens + param_sink_number
208
+ + # TODO: _npu_flash_attention only allows qlen==kvlen, so the mask and output are padded by param_sink_number rows.
209
+ + mask_elem = mask[0, -1]
210
+ + sink_mask = torch.full((mask.size(0) + param_sink_number,
211
+ + mask.size(1) + param_sink_number),
212
+ + mask_elem, dtype=mask.dtype, device=mask.device)
213
+ + sink_mask[param_sink_number:, :param_sink_number] = 0.0
214
+ + sink_mask[param_sink_number:, param_sink_number:] = mask
215
+ + sink_mask[:param_sink_number, :param_sink_number].triu_(diagonal=1)
216
+ + mask = sink_mask
217
+ +
218
+ + output = torch.zeros((output.size(0) + param_sink_number,
219
+ + output.size(1), output.size(2)),
220
+ + dtype=output.dtype,
221
+ + device=output.device)
222
+ + else:
223
+ + seq_lens = attn_metadata.seq_lens
224
+
225
+ torch_npu._npu_flash_attention(query=query,
226
+ key=key,
227
+ value=value,
228
+ mask=mask,
229
+ - seq_len=attn_metadata.seq_lens,
230
+ + seq_len=seq_lens,
231
+ scale_value=self.scale,
232
+ num_heads=self.num_heads,
233
+ num_kv_heads=self.num_kv_heads,
234
+ out=output)
235
+ assert output is not None
236
+ - return output[:num_tokens, :, :]
237
+ + return output[param_sink_number:param_sink_number + num_tokens, :, :]
238
+
239
+ def _forward_prefill_cache_hit(
240
+ self,
241
+ @@ -356,6 +463,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
242
+ query: torch.Tensor,
243
+ attn_metadata: AscendMetadata,
244
+ output: Optional[torch.Tensor] = None,
245
+ + layer: AttentionLayer = None,
246
+ + param_sink_number: Optional[int] = 0
247
+ ) -> torch.Tensor:
248
+ if is_310p():
249
+ # seq_lens_tensor needs to be transferred to the device for 310P.
250
+ @@ -426,16 +535,46 @@ class AscendAttentionBackendImpl(AttentionImpl):
251
+ handle = torch.npu.graph_task_group_end(stream)
252
+ graph_params.handles[num_tokens].append(handle)
253
+ else:
254
+ - torch_npu._npu_paged_attention(
255
+ - query=query,
256
+ - key_cache=self.key_cache,
257
+ - value_cache=self.value_cache,
258
+ - num_kv_heads=self.num_kv_heads,
259
+ - num_heads=self.num_heads,
260
+ - scale_value=self.scale,
261
+ - block_table=attn_metadata.block_tables,
262
+ - context_lens=attn_metadata.seq_lens,
263
+ - out=output)
264
+ + # Patch for Sparse KV cache of SWA.
265
+ + num_block, block_size, _, _ = self.key_cache.shape # type: ignore
266
+ + key = self.key_cache.view( # type: ignore
267
+ + num_block, block_size, -1)
268
+ + value = self.value_cache.view( # type: ignore
269
+ + num_block, block_size, -1)
270
+ + block_tables = attn_metadata.sink_block_tables
271
+ + use_swa = (self.layer_idx % 2 == 0)
272
+ + seq_kvlens = attn_metadata.sink_seq_kvlens
273
+ + if use_swa:
274
+ + attn_mask = self.attn_mask.to(query.device, non_blocking=True)
275
+ +
276
+ + output, _ = torch.ops.custom.npu_fused_infer_attention_sink(
277
+ + query,
278
+ + key,
279
+ + value,
280
+ + atten_mask=attn_mask,
281
+ + actual_seq_qlen=attn_metadata.swa_seq_qlens,
282
+ + actual_seq_kvlen=seq_kvlens,
283
+ + block_table=block_tables,
284
+ + pre_tokens=128,
285
+ + next_tokens=0,
286
+ + num_query_heads=self.num_heads,
287
+ + num_key_value_heads=self.num_kv_heads,
288
+ + input_layout='TND',
289
+ + sparse_mode=4,
290
+ + block_size=block_size,
291
+ + sink_number=param_sink_number,
292
+ + softmax_scale=self.scale)
293
+ + else:
294
+ + torch_npu._npu_paged_attention(
295
+ + query=query,
296
+ + key_cache=self.key_cache,
297
+ + value_cache=self.value_cache,
298
+ + num_kv_heads=self.num_kv_heads,
299
+ + num_heads=self.num_heads,
300
+ + scale_value=self.scale,
301
+ + block_table=block_tables,
302
+ + context_lens=seq_kvlens,
303
+ + out=output)
304
+ return output
305
+
306
+ def _forward_v1_style(
307
+ @@ -443,6 +582,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
308
+ query: torch.Tensor,
309
+ attn_metadata: AscendMetadata,
310
+ output: Optional[torch.Tensor] = None,
311
+ + param_sink_number: Optional[int] = 0
312
+ ) -> torch.Tensor:
313
+ # Use chunked prefill for head size 192 scenario, like deepseek
314
+ # paged_attention_splitfuse maybe crash at such scenario.
315
+ @@ -485,34 +625,87 @@ class AscendAttentionBackendImpl(AttentionImpl):
316
+ value = self.value_cache.view( # type: ignore
317
+ num_block, block_size, -1)
318
+
319
+ - output, _ = torch_npu.npu_fused_infer_attention_score(
320
+ - query=query,
321
+ - key=key,
322
+ - value=value,
323
+ - atten_mask=attn_metadata.attn_mask,
324
+ - block_table=attn_metadata.block_tables,
325
+ - input_layout="TND",
326
+ - block_size=block_size,
327
+ - actual_seq_lengths=attn_metadata.query_start_loc[1:],
328
+ - actual_seq_lengths_kv=attn_metadata.seq_lens,
329
+ - num_key_value_heads=self.num_kv_heads,
330
+ - num_heads=self.num_heads,
331
+ - scale=self.scale,
332
+ - sparse_mode=3,
333
+ - )
334
+ + # TODO: pass the SWA layer selection and the window length in as parameters
335
+ + use_swa = (self.layer_idx % 2 == 0)
336
+ + sparse_mode = 4 if use_swa else 3
337
+ + if param_sink_number > 0:
338
+ + if sparse_mode == 4:
339
+ + output, _ = torch.ops.custom.npu_fused_infer_attention_sink(
340
+ + query,
341
+ + key,
342
+ + value,
343
+ + atten_mask=self.attn_mask,
344
+ + actual_seq_qlen=attn_metadata.swa_seq_qlens,
345
+ + actual_seq_kvlen=attn_metadata.sink_seq_kvlens,
346
+ + block_table=attn_metadata.sink_block_tables,
347
+ + pre_tokens=128,
348
+ + next_tokens=0,
349
+ + num_query_heads=self.num_heads,
350
+ + num_key_value_heads=self.num_kv_heads,
351
+ + input_layout='TND',
352
+ + sparse_mode=4,
353
+ + block_size=block_size,
354
+ + sink_number=param_sink_number,
355
+ + softmax_scale=self.scale
356
+ + )
357
+ + elif sparse_mode == 3:
358
+ + output, _ = torch.ops.custom.npu_fused_infer_attention_sink(
359
+ + query,
360
+ + key,
361
+ + value,
362
+ + atten_mask=self.attn_mask,
363
+ + actual_seq_qlen=attn_metadata.swa_seq_qlens,
364
+ + actual_seq_kvlen=attn_metadata.sink_seq_kvlens,
365
+ + block_table=attn_metadata.sink_block_tables,
366
+ + num_query_heads=self.num_heads,
367
+ + num_key_value_heads=self.num_kv_heads,
368
+ + input_layout='TND',
369
+ + sparse_mode=3,
370
+ + block_size=block_size,
371
+ + sink_number=param_sink_number,
372
+ + softmax_scale=self.scale
373
+ + )
374
+ +
375
+ + else:
376
+ + output, _ = torch_npu.npu_fused_infer_attention_score(
377
+ + query=query,
378
+ + key=key,
379
+ + value=value,
380
+ + atten_mask=attn_metadata.attn_mask,
381
+ + block_table=attn_metadata.block_tables,
382
+ + input_layout="TND",
383
+ + block_size=block_size,
384
+ + actual_seq_lengths=attn_metadata.query_start_loc[1:],
385
+ + actual_seq_lengths_kv=attn_metadata.seq_lens,
386
+ + num_key_value_heads=self.num_kv_heads,
387
+ + num_heads=self.num_heads,
388
+ + scale=self.scale,
389
+ + sparse_mode=3,
390
+ + )
391
+ else:
392
+ + # Patch for sink on CANN 8.2
393
+ + if param_sink_number > 0:
394
+ + seq_kvlens = attn_metadata.seq_lens + param_sink_number
395
+ + block_tables = F.pad(attn_metadata.block_tables, (1, 0, 0, 0), value=0)
396
+ + mask = F.pad(attn_metadata.attn_mask, (param_sink_number, 0, 0, 0), value=0)
397
+ + else:
398
+ + seq_kvlens = attn_metadata.seq_lens
399
+ + block_tables = attn_metadata.block_tables
400
+ + mask = attn_metadata.attn_mask
401
+ +
402
+ torch_npu._npu_paged_attention_splitfuse(
403
+ query=query,
404
+ key_cache=self.key_cache,
405
+ value_cache=self.value_cache,
406
+ - mask=attn_metadata.attn_mask,
407
+ - block_table=attn_metadata.block_tables,
408
+ + mask=mask,
409
+ + block_table=block_tables,
410
+ seq_len=attn_metadata.query_lens,
411
+ - context_lens=attn_metadata.seq_lens,
412
+ + context_lens=seq_kvlens,
413
+ num_kv_heads=self.num_kv_heads,
414
+ num_heads=self.num_heads,
415
+ scale_value=self.scale,
416
+ out=output)
417
+ +
418
+ return output
419
+
420
+ def forward(
421
+ @@ -525,6 +718,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
422
+ attn_metadata: AscendMetadata,
423
+ output: Optional[torch.Tensor] = None,
424
+ trace_flag: bool = True,
425
+ + sink_query: Optional[torch.Tensor] = None,
426
+ + sink_key: Optional[torch.Tensor] = None,
427
+ + sink_value: Optional[torch.Tensor] = None,
428
+ + v_head_size: Optional[int] = None,
429
+ ) -> torch.Tensor:
430
+ """Forward pass with Ascend attention.
431
+ Args:
432
+ @@ -556,7 +753,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
433
+ key=key,
434
+ value=value,
435
+ output=output,
436
+ - layer_name=layer.layer_name)
437
+ + layer_name=layer.layer_name,
438
+ + sink_query=sink_query,
439
+ + sink_key=sink_key,
440
+ + sink_value=sink_value,
441
+ + v_head_size=v_head_size
442
+ + )
443
+
444
+ elif hasattr(layer, 'quant_method') and use_kv_cache_int8:
445
+ output = layer.quant_method.apply(layer, query, key, value,
446
+ @@ -575,10 +777,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
447
+ "encoder/decoder cross-attention "
448
+ "are not implemented for "
449
+ "PallasAttentionBackendImpl")
450
+ + sink_key_flag = (sink_key is not None)
451
+ + param_sink_number = sink_key.shape[0] if sink_key_flag else 0
452
+ # View q k v to BSH.
453
+ query = query.view(-1, self.num_heads, self.head_size)
454
+ key = key.view(-1, self.num_kv_heads, self.head_size)
455
+ - value = value.view(-1, self.num_kv_heads, self.head_size)
456
+ + value = value.view(-1, self.num_kv_heads,
457
+ + v_head_size if v_head_size is not None else self.head_size)
458
+ # TODO: Remove this contiguous in the future.
459
+ value = value.contiguous()
460
+
461
+ @@ -586,33 +791,63 @@ class AscendAttentionBackendImpl(AttentionImpl):
462
+ if self.key_cache is None:
463
+ self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
464
+ slots = attn_metadata.slot_mapping
465
+ +
466
+ torch_npu._npu_reshape_and_cache(
467
+ key=key[:num_actual_tokens],
468
+ value=value[:num_actual_tokens],
469
+ key_cache=self.key_cache,
470
+ value_cache=self.value_cache,
471
+ slot_indices=slots)
472
+ -
473
+ + if sink_key_flag and not self.sink_cached:
474
+ + # The KV cache starts from block 1 (slot 128), so we store the sink in block 0.
475
+ + slots = torch.arange(0, param_sink_number,
476
+ + dtype=attn_metadata.slot_mapping.dtype,
477
+ + device=attn_metadata.slot_mapping.device)
478
+ + torch_npu._npu_reshape_and_cache(
479
+ + key=sink_key,
480
+ + value=sink_value,
481
+ + key_cache=self.key_cache,
482
+ + value_cache=self.value_cache,
483
+ + slot_indices=slots)
484
+ + self.sink_cached = True
485
+ +
486
+ + # TODO: the PrefillCacheHit branch is not entered for now, so its sink implementation is not updated
487
+ + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache and sink_key_flag:
488
+ + attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
489
+ # V0-Style scheduler situation.
490
+ if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
491
+ + if torch.version.cann.startswith("8.3"):
492
+ + # npu_fused_infer_attention_score and npu_fused_infer_attention_sink
493
+ + # do not support cases where query.shape[0] != actual_seq_lengths
494
+ + # Thus we need to unpad it here.
495
+ + num_tokens = attn_metadata.query_start_loc[-1]
496
+ + query = query[:num_tokens]
497
+ + key = key[:num_tokens]
498
+ + value = value[:num_tokens]
499
+ + elif sink_key_flag:
500
+ + query = torch.cat([sink_query, query], dim=0)
501
+ + if sink_key_flag:
502
+ + key = torch.cat([sink_key, key], dim=0)
503
+ + value = torch.cat([sink_value, value], dim=0)
504
+ output = self._forward_prefill_no_cache(
505
+ - query, key, value, attn_metadata, output, num_tokens)
506
+ + query, key, value, attn_metadata, output, num_tokens,
507
+ + param_sink_number
508
+ + )
509
+ elif attn_metadata.attn_state == \
510
+ AscendAttentionState.PrefillCacheHit:
511
+ output = self._forward_prefill_cache_hit(
512
+ query, attn_metadata, output)
513
+ elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
514
+ output = self._forward_decode_only(query, attn_metadata,
515
+ - output)
516
+ + output, layer,
517
+ + param_sink_number)
518
+ # Normal V1 situation.
519
+ else:
520
+ if torch.version.cann.startswith("8.3"):
521
+ - # npu_fused_infer_attention_score does not support cases
522
+ - # where query.shape[0] != attn_metadata.query_start_loc[-1].
523
+ - # Thus we need unpad it here.
524
+ num_tokens = attn_metadata.query_start_loc[-1]
525
+ query = query[:num_tokens]
526
+ - output = self._forward_v1_style(query, attn_metadata, output)
527
+ + output = self._forward_v1_style(query, attn_metadata, output,
528
+ + param_sink_number)
529
+
530
+ # to make in-place change to the output tensor
531
+ if hasattr(layer, 'quant_method') and use_kv_cache_int8:
532
+ @@ -627,6 +862,10 @@ def unified_ascend_attention_with_output(
533
+ value: torch.Tensor,
534
+ output: torch.Tensor,
535
+ layer_name: str,
536
+ + sink_query: Optional[torch.Tensor] = None,
537
+ + sink_key: Optional[torch.Tensor] = None,
538
+ + sink_value: Optional[torch.Tensor] = None,
539
+ + v_head_size: Optional[int] = None,
540
+ ) -> None:
541
+ wait_for_kv_layer_from_connector(layer_name)
542
+ forward_context: ForwardContext = get_forward_context()
543
+ @@ -642,7 +881,11 @@ def unified_ascend_attention_with_output(
544
+ kv_cache,
545
+ attn_metadata,
546
+ output,
547
+ - trace_flag=False)
548
+ + trace_flag=False,
549
+ + sink_query=sink_query,
550
+ + sink_key=sink_key,
551
+ + sink_value=sink_value,
552
+ + v_head_size=v_head_size)
553
+ maybe_save_kv_layer_to_connector(layer_name, kv_cache)
554
+ return
555
+
556
+ @@ -653,6 +896,11 @@ def unified_attention_with_output_fake(
557
+ value: torch.Tensor,
558
+ output: torch.Tensor,
559
+ layer_name: str,
560
+ + # patch for pangu with attention sink
561
+ + sink_query: Optional[torch.Tensor] = None,
562
+ + sink_key: Optional[torch.Tensor] = None,
563
+ + sink_value: Optional[torch.Tensor] = None,
564
+ + v_head_size: Optional[int] = None,
565
+ ) -> None:
566
+ return
567
+
568
+ diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
569
+ index 519cde0..93e1c95 100644
570
+ --- a/vllm_ascend/attention/utils.py
571
+ +++ b/vllm_ascend/attention/utils.py
572
+ @@ -63,6 +63,8 @@ class AscendCommonAttentionMetadata:
573
+
574
+ graph_pad_size: int = -1
575
+
576
+ + num_input_tokens: int = -1
577
+ +
578
+
579
+ def split_decodes_and_prefills(
580
+ common_attn_metadata: AscendCommonAttentionMetadata,
581
+ diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
582
+ index f1581df..b690bcb 100644
583
+ --- a/vllm_ascend/platform.py
584
+ +++ b/vllm_ascend/platform.py
585
+ @@ -216,6 +216,9 @@ class NPUPlatform(Platform):
586
+ if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
587
+ compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
588
+
589
+ + if compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
590
+ + compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE_DECODE_ONLY
591
+ +
592
+ if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
593
+ compilation_config.level = CompilationLevel.NO_COMPILATION
594
+ # TODO: Currently MLA does not support FULL_DECODE_ONLY, remove the second condition
595
+ @@ -223,7 +226,8 @@ class NPUPlatform(Platform):
596
+ elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or (
597
+ compilation_config.cudagraph_mode
598
+ == CUDAGraphMode.FULL_DECODE_ONLY and model_config is not None
599
+ - and model_config.use_mla):
600
+ + and model_config.use_mla) or (
601
+ + compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE_DECODE_ONLY):
602
+ logger.info(
603
+ "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
604
+ "using only ACL Graph mode")
605
+ @@ -232,7 +236,8 @@ class NPUPlatform(Platform):
606
+ compilation_config.set_splitting_ops_for_v1()
607
+ compilation_config.use_inductor = False
608
+ compilation_config.splitting_ops.extend([
609
+ - "vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
610
+ + "vllm.unified_ascend_attention_with_output", "vllm.mla_forward",
611
+ + "vllm.aggregate_hiddden",
612
+ ])
613
+ update_aclgraph_sizes(vllm_config)
614
+ elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
615
+ diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
616
+ index 9281dd7..34808ec 100644
617
+ --- a/vllm_ascend/worker/model_runner_v1.py
618
+ +++ b/vllm_ascend/worker/model_runner_v1.py
619
+ @@ -281,6 +281,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
620
+ self.encoder_cache: Dict[str, torch.Tensor] = {}
621
+ self.attn_mask = None
622
+ self.attn_state = None
623
+ + self.with_prefill = False
624
+ self.requests: Dict[str, CachedRequestState] = {}
625
+ self.intermediate_tensors: Optional[IntermediateTensors] = None
626
+ self.runner_only_attn_layers: set[str] = set()
627
+ @@ -509,6 +510,48 @@ class NPUModelRunner(LoRAModelRunnerMixin):
628
+ self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
629
+ dtype=torch.int32)
630
+
631
+ + # Patch for conv cache
632
+ + self.router_sliding_window = getattr(self.model_config.hf_text_config, "router_sliding_window", 0)
633
+ + if self.router_sliding_window > 1:
634
+ + self.cache_length = self.router_sliding_window - 1
635
+ + self.req_cache_map = {}
636
+ + self.occupied_cache = [0]*(self.max_num_reqs)
637
+ + self.q_offsets = torch.arange(-self.cache_length, 0, device=self.device)
638
+ + self.cache_slot_id = torch.empty(self.max_num_reqs,
639
+ + dtype=torch.long, device=self.device)
640
+ + self.is_first_chunk = torch.empty(self.max_num_reqs, dtype=torch.bool, device=self.device) # For chunked prefill
641
+ +
642
+ + def _build_conv_context(self, with_prefill:bool = False, dummy:bool = False, num_tokens:int = 0):
643
+ + # conv cache slot & prefill hiddenstates loc
644
+ + cache_slot_id = self.cache_slot_id[:self.input_batch.num_reqs]
645
+ + query_start_loc = self.query_start_loc[:self.input_batch.num_reqs + 1]
646
+ + is_first_chunk = self.is_first_chunk[:self.input_batch.num_reqs]
647
+ +
648
+ + if with_prefill:
649
+ + for idx, req_id in enumerate(self.input_batch.req_ids):
650
+ + if req_id in self.req_cache_map:
651
+ + cache_id = self.req_cache_map[req_id]
652
+ + cache_slot_id[idx] = cache_id
653
+ + is_first_chunk[idx] = False
654
+ + else:
655
+ + # new request with the first chunk
656
+ + new_cahce_id = self.occupied_cache.index(0)
657
+ + self.occupied_cache[new_cahce_id] = 1
658
+ + self.req_cache_map[req_id] = new_cahce_id
659
+ + cache_slot_id[idx] = new_cahce_id
660
+ + is_first_chunk[idx] = True
661
+ + else:
662
+ + for idx, req_id in enumerate(self.input_batch.req_ids):
663
+ + cache_id = self.req_cache_map[req_id]
664
+ + cache_slot_id[idx] = cache_id
665
+ + is_first_chunk[idx] = False
666
+ +
667
+ + forward_context = get_forward_context()
668
+ + forward_context.cache_slot_id = cache_slot_id
669
+ + forward_context.is_first_chunk = is_first_chunk
670
+ + forward_context.query_start_loc = query_start_loc
671
+ +
672
+ +
673
+ def _make_buffer(self,
674
+ *size: Union[int, torch.SymInt],
675
+ dtype: torch.dtype,
676
+ @@ -548,12 +591,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
677
+ self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
678
+
679
+ def _use_aclgraph(self) -> bool:
680
+ - return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
681
+ + return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and \
682
+ + self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
683
+
684
+ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
685
+ # Remove finished requests from the cached states.
686
+ for req_id in scheduler_output.finished_req_ids:
687
+ - self.requests.pop(req_id, None)
688
+ + self.requests.pop(req_id, None)
689
+ + if self.router_sliding_window > 1 and req_id in self.req_cache_map:
690
+ + cache_id = self.req_cache_map.pop(req_id)
691
+ + self.occupied_cache[cache_id] = 0
692
+
693
+ # Remove the finished requests from the persistent batch.
694
+ # NOTE(woosuk): There could be an edge case where finished_req_ids and
695
+ @@ -891,7 +938,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
696
+ def _make_attention_mask(self, seq_lens, position,
697
+ attn_state) -> torch.Tensor:
698
+ # Chunk Prefill situation.
699
+ - if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa:
700
+ + if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not \
701
+ + self.ascend_config.use_sfa:
702
+ if torch.version.cann.startswith("8.3"):
703
+ return self.attn_mask_builder.get_splitfuse_attn_mask()
704
+ else:
705
+ @@ -942,7 +990,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
706
+ src_end = num_computed_tokens + prompt_part_len
707
+
708
+ self.mrope_positions_cpu[:, dst_start:dst_end] = \
709
+ - req.mrope_positions[:,src_start:src_end]
710
+ + req.mrope_positions[:, src_start:src_end]
711
+
712
+ mrope_pos_ptr += prompt_part_len
713
+
714
+ @@ -1126,9 +1174,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
715
+ cumsum_dtype: Optional[np.dtype] = None,
716
+ ) -> tuple[np.ndarray, np.ndarray]:
717
+ """Get the cumulative sum and batched arange of the given array.
718
+ - # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
719
+ - # Equivalent to but faster than:
720
+ - # np.concatenate([np.arange(n) for n in num_tokens])
721
+ + E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
722
+ + Equivalent to but faster than:
723
+ + np.concatenate([np.arange(n) for n in num_tokens])
724
+ """
725
+ # Step 1. [2, 5, 3] -> [2, 7, 10]
726
+ cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
727
+ @@ -1518,6 +1566,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
728
+ max_query_len=max_num_scheduled_tokens,
729
+ graph_pad_size=self.graph_pad_size,
730
+ decode_token_per_req=self.decode_token_per_req,
731
+ + num_input_tokens=num_input_tokens
732
+ )
733
+
734
+ if self.speculative_config and \
735
+ @@ -1964,6 +2013,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
736
+ model_instance=self.model):
737
+ self.maybe_setup_kv_connector(scheduler_output)
738
+
739
+ + if self.router_sliding_window > 1:
740
+ + self._build_conv_context(self.with_prefill)
741
+ +
742
+ hidden_states = self._generate_process_reqs_hidden_states(
743
+ attn_metadata, self.with_prefill, maybe_padded_num_tokens,
744
+ input_ids, positions, intermediate_tensors, inputs_embeds)
745
+ @@ -2339,7 +2391,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
746
+ ) -> torch.Tensor:
747
+ # only support eager mode and piecewise graph now
748
+ assert aclgraph_runtime_mode in {
749
+ - CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
750
+ + CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL,
751
+ + CUDAGraphMode.PIECEWISE_DECODE_ONLY
752
+ }
753
+
754
+ # Padding for DP
755
+ @@ -2472,6 +2525,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
756
+ batch_descriptor=batch_descriptor,
757
+ prefetch_stream=self.prefetch_stream,
758
+ model_instance=self.model):
759
+ + if self.router_sliding_window > 1:
760
+ + self._build_conv_context(with_prefill, dummy=True, num_tokens=num_tokens)
761
+ hidden_states = self._generate_dummy_run_hidden_states(
762
+ with_prefill, is_torchair_compile, input_ids, positions,
763
+ attn_metadata, num_tokens, intermediate_tensors,
764
+ @@ -2789,8 +2844,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
765
+
766
+ # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory
767
+ # address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but
768
+ - # we found there are also some exceptions during test, so we manual align those memory here, this part
769
+ - # of code may consume 2M * 2 * elem_size memory every layer.
770
+ + # we found there are also some exceptions during test, so we manually align that memory here,
771
+ + # this part of code may consume 2M * 2 * elem_size memory every layer.
772
+ nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
773
+ nope_allocate_shape_alignment = nope_allocate_shape + alignment
774
+ rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
775
+ @@ -2888,8 +2943,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
776
+
777
+ # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory
778
+ # address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but
779
+ - # we found there are also some exceptions during test, so we manual align those memory here, this part
780
+ - # of code may consume 2M * 2 * elem_size memory every layer.
781
+ + # we found there are also some exceptions during test, so we manually align that memory here,
782
+ + # this part of code may consume 2M * 2 * elem_size memory every layer.
783
+ nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
784
+ nope_allocate_shape_alignment = nope_allocate_shape + alignment
785
+ rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
786
+ @@ -3432,6 +3487,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
787
+ and all(op in self.compilation_config.splitting_ops for op in [
788
+ "vllm.unified_ascend_attention_with_output",
789
+ "vllm.mla_forward",
790
+ + "vllm.aggregate_hiddden",
791
+ ]))
792
+
793
+ # Flexible resolve the aclgraph mode
794
+ @@ -3495,7 +3551,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
795
+ uniform_decode: bool):
796
+ assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \
797
+ aclgraph_runtime_mode in [CUDAGraphMode.FULL,
798
+ - CUDAGraphMode.PIECEWISE]
799
+ + CUDAGraphMode.PIECEWISE,
800
+ + CUDAGraphMode.PIECEWISE_DECODE_ONLY]
801
+
802
+ # Only rank 0 should print progress bar during capture
803
+ if is_global_first_rank():
804
+ @@ -3519,10 +3576,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
805
+ # attention while `PIECEWISE` implies no attention.
806
+ force_attention = (aclgraph_runtime_mode == CUDAGraphMode.FULL)
807
+ self._dummy_run(num_tokens,
808
+ + with_prefill = (uniform_decode == False),
809
+ aclgraph_runtime_mode=CUDAGraphMode.NONE,
810
+ force_attention=force_attention,
811
+ uniform_decode=uniform_decode)
812
+ self._dummy_run(num_tokens,
813
+ + with_prefill = (uniform_decode == False),
814
+ aclgraph_runtime_mode=aclgraph_runtime_mode,
815
+ force_attention=force_attention,
816
+ uniform_decode=uniform_decode)
817
+ @@ -3556,7 +3615,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
818
+ logger.error(
819
+ f"ACLgraph sizes capture fail: {type(e).__name__}:\n"
820
+ "ACLgraph has insufficient available streams to capture the configured number of sizes. "
821
+ - "Please verify both the availability of adequate streams and the appropriateness of the configured size count.\n\n"
822
+ + "Please verify both the availability of adequate streams "
823
+ + "and the appropriateness of the configured size count.\n\n"
824
+ "Recommended solutions:\n"
825
+ "1. Manually configure the compilation_config parameter "
826
+ "with a reduced set of sizes: '{\"cudagraph_capture_sizes\":[size1, size2, size3, ...]}'.\n"
827
+ @@ -3564,8 +3624,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
828
+ f"{str(e)}")
829
+ raise
830
+
831
+ - if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
832
+ - aclgraph_mode.separate_routine():
833
+ + if aclgraph_mode.separate_routine() and \
834
+ + (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL or \
835
+ + aclgraph_mode.decode_mode() == CUDAGraphMode.PIECEWISE):
836
+ max_num_tokens = self.scheduler_config.max_num_seqs * \
837
+ self.uniform_decode_query_len
838
+ decode_cudagraph_batch_sizes = [
839
+ @@ -3576,7 +3637,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
840
+ reversed(decode_cudagraph_batch_sizes))
841
+ self._capture_aclgraphs(
842
+ compilation_cases=compilation_cases_decode,
843
+ - aclgraph_runtime_mode=CUDAGraphMode.FULL,
844
+ + aclgraph_runtime_mode=aclgraph_mode.decode_mode(),
845
+ uniform_decode=True)
846
+
847
+ # Disable aclgraph capturing globally, so any unexpected aclgraph