merge vllm-ascend-adaptor into main
Browse filesAdd support for vllm-ascend
Created-by: crazyang
Commit-by: Crazyang
Merged-by: Alex_Ty
Description: This PR implements openPangu-1B model adaptation for vllm-ascend inference framework.
Please refer to README for detailed usage instructions.
See merge request: ascend-tribe/openpangu-embedded-1b-model!1
- .gitignore +1 -0
- README.md +7 -2
- README_EN.md +7 -3
- checklist.chk +20 -0
- inference/vllm_ascend/_build_info.py +3 -0
- inference/vllm_ascend/attention/attention.py +1220 -0
- inference/vllm_ascend/attention/mla_v1.py +1224 -0
- inference/vllm_ascend/entrypoints/openai/reasoning_parsers/__init__.py +6 -0
- inference/vllm_ascend/entrypoints/openai/reasoning_parsers/pangu_reasoning_parser.py +171 -0
- inference/vllm_ascend/entrypoints/openai/tool_parsers/__init__.py +6 -0
- inference/vllm_ascend/entrypoints/openai/tool_parsers/pangu_tool_parser.py +300 -0
- inference/vllm_ascend/envs.py +153 -0
- inference/vllm_ascend/models/__init__.py +68 -0
- inference/vllm_ascend/models/open_pangu.py +1127 -0
- inference/vllm_ascend/ops/fused_moe.py +1530 -0
- inference/vllm_ascend/patch/worker/patch_common/__init__.py +27 -0
- inference/vllm_ascend/patch/worker/patch_common/patch_config.py +97 -0
- inference/vllm_ascend/patch/worker/patch_common/patch_parsers.py +26 -0
- inference/vllm_ascend/patch/worker/patch_common/patch_sampler.py +159 -0
- inference/vllm_ascend/quantization/w8a8.py +757 -0
- inference/vllm_ascend/quantization/w8a8_dynamic.py +831 -0
- inference/vllm_ascend/utils.py +563 -0
- inference/vllm_ascend/worker/model_runner_v1.py +0 -0
- inference/vllm_ascend/worker/npu_input_batch.py +796 -0
- inference/vllm_ascend_for_openpangu_embedded_1b.md +124 -0
- inference/vllm_ascend_for_openpangu_embedded_1b.zh.md +124 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
.DS_Store
|
README.md
CHANGED
|
@@ -55,7 +55,7 @@ Atlas 800T A2 (64GB),驱动与固件安装包获取请参照 [[Atlas 800T A2](
|
|
| 55 |
##### 软件环境
|
| 56 |
|
| 57 |
- 操作系统:Linux(推荐 openEuler>=24.03)
|
| 58 |
-
- CANN==8.1.RC1,安装准备及流程请参照 [CANN Install](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/82RC1alpha002/softwareinst/instg/instg_0001.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)
|
| 59 |
- python==3.10
|
| 60 |
- torch==2.1.0
|
| 61 |
- torch-npu==2.1.0.post12
|
|
@@ -85,7 +85,12 @@ fi
|
|
| 85 |
cd inference
|
| 86 |
python generate.py
|
| 87 |
```
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
## 5. 模型许可证
|
| 91 |
|
|
|
|
| 55 |
##### 软件环境
|
| 56 |
|
| 57 |
- 操作系统:Linux(推荐 openEuler>=24.03)
|
| 58 |
+
- CANN==8.1.RC1,安装准备及流程请参照 [[CANN Install](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/82RC1alpha002/softwareinst/instg/instg_0001.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)]
|
| 59 |
- python==3.10
|
| 60 |
- torch==2.1.0
|
| 61 |
- torch-npu==2.1.0.post12
|
|
|
|
| 85 |
cd inference
|
| 86 |
python generate.py
|
| 87 |
```
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
### 4.4 使用推理框架
|
| 91 |
+
**vllm_ascend:** 参考[[vllm_ascend_for_openpangu_embedded_1b.zh]](inference/vllm_ascend_for_openpangu_embedded_1b.zh.md)
|
| 92 |
+
|
| 93 |
+
**昇腾 Atlas 200I A2推理:** openPangu-Embedded-1B 模型推理已适配昇腾 MindIE 2.2.T10([[下载链接]](https://support.huawei.com/enterprise/zh/ascend-computing/mindie-pid-261803968/software/266130647?idAbsPath=fixnode01|23710424|251366513|254884019|261408772|261803968)),支持 OrangePi AIpro (昇腾 Atlas 200I A2) 推理部署。届时可前往 [[昇腾社区ModelZoo]](https://gitee.com/ascend/ModelZoo-PyTorch/blob/master/MindIE/LLM/Pangu/openPangu-Embedded-1B-OrangePi/README.md) 下载适配,下载镜像前需要申请权限,耐心等待权限申请通过后,根据指南下载对应版本文件和安装指导完成推理部署。
|
| 94 |
|
| 95 |
## 5. 模型许可证
|
| 96 |
|
README_EN.md
CHANGED
|
@@ -43,7 +43,7 @@ The openPangu-Embedded-1B is a high-efficiency fast-thinking language model desi
|
|
| 43 |
| MBPP | Pass@1 | 54.09 |
|
| 44 |
| HumanEval | Pass@1 | 56.71 |
|
| 45 |
|
| 46 |
-
**Note:** The system prompt is left empty.
|
| 47 |
|
| 48 |
|
| 49 |
## 4. Deployment
|
|
@@ -57,7 +57,7 @@ Atlas 800T A2 (64GB), please refer to [[Atlas 800T A2](https://www.hiascend.com/
|
|
| 57 |
#### System Requirements & Dependencies
|
| 58 |
|
| 59 |
- System: Linux (OpenEuler ≥ 24.03 recommended)
|
| 60 |
-
- CANN==8.1.RC1: [CANN Install](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/82RC1alpha002/softwareinst/instg/instg_0001.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)
|
| 61 |
- python==3.10
|
| 62 |
- torch==2.1.0
|
| 63 |
- torch-npu==2.1.0.post12
|
|
@@ -88,7 +88,11 @@ The following provides a simple inference example of openPangu-Embedded-1B based
|
|
| 88 |
cd inference
|
| 89 |
python generate.py
|
| 90 |
```
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
## 5. Model License
|
| 94 |
|
|
|
|
| 43 |
| MBPP | Pass@1 | 54.09 |
|
| 44 |
| HumanEval | Pass@1 | 56.71 |
|
| 45 |
|
| 46 |
+
**Note:** The system prompt is left empty during the evaluation.
|
| 47 |
|
| 48 |
|
| 49 |
## 4. Deployment
|
|
|
|
| 57 |
#### System Requirements & Dependencies
|
| 58 |
|
| 59 |
- System: Linux (OpenEuler ≥ 24.03 recommended)
|
| 60 |
+
- CANN==8.1.RC1: [[CANN Install]](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/82RC1alpha002/softwareinst/instg/instg_0001.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)
|
| 61 |
- python==3.10
|
| 62 |
- torch==2.1.0
|
| 63 |
- torch-npu==2.1.0.post12
|
|
|
|
| 88 |
cd inference
|
| 89 |
python generate.py
|
| 90 |
```
|
| 91 |
+
|
| 92 |
+
### 4.4 Using Inference Framework
|
| 93 |
+
**vllm_ascend:** [[vllm_ascend_for_openpangu_embedded_1b]](inference/vllm_ascend_for_openpangu_embedded_1b.md)
|
| 94 |
+
|
| 95 |
+
**Ascend Atlas 200I A2 Inference:** The openPangu-Embedded-1B model inference has been adapted for Ascend MindIE version 2.2.T10 ([[download link]](https://support.huawei.com/enterprise/zh/ascend-computing/mindie-pid-261803968/software/266130647?idAbsPath=fixnode01|23710424|251366513|254884019|261408772|261803968)), and can be deployed on OrangePi AIpro (Ascend Atlas 200I A2) for inference. The adapted package will be available for download on [[Ascend Community ModelZoo]](https://gitee.com/ascend/ModelZoo-PyTorch/blob/master/MindIE/LLM/Pangu/openPangu-Embedded-1B-OrangePi/README.md). Before downloading the image, you need to apply for permissions. Please wait patiently until the permission application is approved, then follow the guidelines to download the corresponding image file and installation guide to complete the inference deployment.
|
| 96 |
|
| 97 |
## 5. Model License
|
| 98 |
|
checklist.chk
CHANGED
|
@@ -2,6 +2,26 @@
|
|
| 2 |
7694a0e7b59d7ec2eeebc2fd058f02fe4dc4464b27f82839fc9f425a88555a3a *./configuration_openpangu_dense.py
|
| 3 |
a12bff27a61421a0dddff6d814d6a512d423d466f7fdec406460e45eaca2e7ce *./generation_config.json
|
| 4 |
58f15aa7474fcb08d59156d6ecf28df23f187cc84a912a66b2f1d06053dcc988 *./inference/generate.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
10b12467031fcfbce46f280245aa7e24959b912bfe8bbd4f6a44168d012b565e *./model.safetensors
|
| 6 |
f15eaf322af8a0b0f16b26795eb68af836179413d3dbfa4dc44505db6c8b0d6f *./modeling_openpangu_dense.py
|
| 7 |
c1f2d87f855b994039c52b1e83c8a7f3d71a2d1eb52946c4a2e862e99f19d8b3 *./modular_openpangu_dense.py
|
|
|
|
| 2 |
7694a0e7b59d7ec2eeebc2fd058f02fe4dc4464b27f82839fc9f425a88555a3a *./configuration_openpangu_dense.py
|
| 3 |
a12bff27a61421a0dddff6d814d6a512d423d466f7fdec406460e45eaca2e7ce *./generation_config.json
|
| 4 |
58f15aa7474fcb08d59156d6ecf28df23f187cc84a912a66b2f1d06053dcc988 *./inference/generate.py
|
| 5 |
+
ba6d7edcf1cf464d6fd787b12a9bda2a16fea0ac0d5d1e54136baec503d6e696 *./inference/vllm_ascend/attention/attention.py
|
| 6 |
+
2254aeca0be7b8922318e10c4a950f39afb30ba5fe3b46564a58671b237ac612 *./inference/vllm_ascend/attention/mla_v1.py
|
| 7 |
+
f9577c29bc4dc19a4cc41ccfcca17065402c9dd92221bef987c74808b23ed124 *./inference/vllm_ascend/entrypoints/openai/reasoning_parsers/pangu_reasoning_parser.py
|
| 8 |
+
9070682b058a79d2b2874ba5e07ce72beff6efb870f75cdac30cdcf6ba8fadc7 *./inference/vllm_ascend/entrypoints/openai/reasoning_parsers/__init__.py
|
| 9 |
+
91eab52cdc19603b7b705b302e25345d849e18fa66875261a1135d5382392123 *./inference/vllm_ascend/entrypoints/openai/tool_parsers/pangu_tool_parser.py
|
| 10 |
+
d07256c9014f911f81269e65aad6c0d7dd61d4e82f5cb399e05285d5c1bc8fa8 *./inference/vllm_ascend/entrypoints/openai/tool_parsers/__init__.py
|
| 11 |
+
52a968f10ebaebeb626248afd3e1d1b92f8fbfcaad19ebf05cafbc0bd03192cb *./inference/vllm_ascend/envs.py
|
| 12 |
+
b654e72ece161b3f04080e5c4d2476641c024939ac5308115fe1c65a6c5c7215 *./inference/vllm_ascend/models/open_pangu.py
|
| 13 |
+
e98aa2549f02017a35b07499216fe569e86400684087821820cf2d971c8fcbac *./inference/vllm_ascend/models/__init__.py
|
| 14 |
+
09273eb0e4696d2fb530881ba1ad9d331897dd81c0cd2f203ed3d0a475b4d39b *./inference/vllm_ascend/ops/fused_moe.py
|
| 15 |
+
8436ab93933989431160e55627b5dce5326f0fc5ec18263653902764ac8ace7b *./inference/vllm_ascend/patch/worker/patch_common/patch_config.py
|
| 16 |
+
8c59df8086bde0cd4df674403f83000921a34403651a8ff2b31de9b28768247a *./inference/vllm_ascend/patch/worker/patch_common/patch_parsers.py
|
| 17 |
+
e712ea36caf16c2a9dd21c5288f9d8e34c7fd2face444da44dca6db6c21f6c1b *./inference/vllm_ascend/patch/worker/patch_common/patch_sampler.py
|
| 18 |
+
63a6ba0d0b0158d4586219c979bf96d5fe87b74123af93f1c8d9ed842db96500 *./inference/vllm_ascend/patch/worker/patch_common/__init__.py
|
| 19 |
+
743bd96cfc109975a11fe5412c4b5de46f880501dcbbbdd10e11cbeb865fa4f2 *./inference/vllm_ascend/quantization/w8a8.py
|
| 20 |
+
6adfaa8a67ea9b561dec2e6a2392f6fc85ff376fb2030d8761c34c6c6d3f4cbf *./inference/vllm_ascend/quantization/w8a8_dynamic.py
|
| 21 |
+
e2457c558f048876afe069d1226e7080ac214478f1a9ac28ae472928b81b5a06 *./inference/vllm_ascend/utils.py
|
| 22 |
+
62c6734d1283e3d649a6478d2004f46bfee2f7878af7f2849c979b124e355302 *./inference/vllm_ascend/worker/model_runner_v1.py
|
| 23 |
+
bc6505adabc0498ad07b49187858788c65c13dbf9446fd0bcf177a3e1b27220d *./inference/vllm_ascend/worker/npu_input_batch.py
|
| 24 |
+
4aaf57e6f6d2e139b3847b10ee59d738398ebfc4927a22325b27dad384874aec *./inference/vllm_ascend/_build_info.py
|
| 25 |
10b12467031fcfbce46f280245aa7e24959b912bfe8bbd4f6a44168d012b565e *./model.safetensors
|
| 26 |
f15eaf322af8a0b0f16b26795eb68af836179413d3dbfa4dc44505db6c8b0d6f *./modeling_openpangu_dense.py
|
| 27 |
c1f2d87f855b994039c52b1e83c8a7f3d71a2d1eb52946c4a2e862e99f19d8b3 *./modular_openpangu_dense.py
|
inference/vllm_ascend/_build_info.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Auto-generated file
|
| 2 |
+
__soc_version__ = 'ASCEND910B1'
|
| 3 |
+
__sleep_mode_enabled__ = True
|
inference/vllm_ascend/attention/attention.py
ADDED
|
@@ -0,0 +1,1220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# This file is a part of the vllm-ascend project.
|
| 16 |
+
#
|
| 17 |
+
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Any, Dict, List, Optional, Tuple, Type
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torch_npu
|
| 24 |
+
import torchair._contrib.custom_torch_ops # type: ignore # noqa: F401
|
| 25 |
+
from torch.nn.functional import scaled_dot_product_attention
|
| 26 |
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
| 27 |
+
AttentionLayer,
|
| 28 |
+
AttentionMetadata, AttentionType,
|
| 29 |
+
MLAAttentionImpl)
|
| 30 |
+
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
|
| 31 |
+
CommonMetadataBuilder,
|
| 32 |
+
compute_slot_mapping,
|
| 33 |
+
compute_slot_mapping_start_idx,
|
| 34 |
+
is_block_tables_empty)
|
| 35 |
+
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
| 36 |
+
|
| 37 |
+
from vllm_ascend.ascend_config import get_ascend_config
|
| 38 |
+
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
| 39 |
+
from vllm_ascend.ops.cache import concat_and_cache_mla
|
| 40 |
+
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
|
| 41 |
+
enable_custom_op, is_310p, nd_to_nz_2d)
|
| 42 |
+
from vllm_ascend.worker.model_runner import (
|
| 43 |
+
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
| 44 |
+
|
| 45 |
+
_ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class AscendAttentionBackend(AttentionBackend):
|
| 49 |
+
|
| 50 |
+
@staticmethod
|
| 51 |
+
def get_name() -> str:
|
| 52 |
+
return "ASCEND"
|
| 53 |
+
|
| 54 |
+
@staticmethod
|
| 55 |
+
def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
|
| 56 |
+
return AscendAttentionBackendImpl
|
| 57 |
+
|
| 58 |
+
@staticmethod
|
| 59 |
+
def get_metadata_cls() -> Type["AscendMetadata"]:
|
| 60 |
+
return AscendMetadata
|
| 61 |
+
|
| 62 |
+
@staticmethod
|
| 63 |
+
def get_state_cls() -> Type["CommonAttentionState"]:
|
| 64 |
+
return CommonAttentionState
|
| 65 |
+
|
| 66 |
+
@staticmethod
|
| 67 |
+
def get_kv_cache_shape(
|
| 68 |
+
num_blocks: int,
|
| 69 |
+
block_size: int,
|
| 70 |
+
num_kv_heads: int,
|
| 71 |
+
head_size: int,
|
| 72 |
+
) -> Tuple[int, ...]:
|
| 73 |
+
if is_310p():
|
| 74 |
+
return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
|
| 75 |
+
16)
|
| 76 |
+
else:
|
| 77 |
+
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
| 78 |
+
|
| 79 |
+
@staticmethod
|
| 80 |
+
def swap_blocks(
|
| 81 |
+
src_kv_cache: List[torch.Tensor],
|
| 82 |
+
dst_kv_cache: List[torch.Tensor],
|
| 83 |
+
src_to_dst: torch.Tensor,
|
| 84 |
+
) -> None:
|
| 85 |
+
src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
|
| 86 |
+
dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
|
| 87 |
+
src_indices = src_to_dst[:, 0]
|
| 88 |
+
dst_indices = src_to_dst[:, 1]
|
| 89 |
+
|
| 90 |
+
dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
|
| 91 |
+
dst_key_cache.device)
|
| 92 |
+
dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
|
| 93 |
+
dst_key_cache.device)
|
| 94 |
+
|
| 95 |
+
@staticmethod
|
| 96 |
+
def copy_blocks(
|
| 97 |
+
kv_caches: List[torch.Tensor],
|
| 98 |
+
src_to_dists: torch.Tensor,
|
| 99 |
+
) -> None:
|
| 100 |
+
src_indices = src_to_dists[:, 0]
|
| 101 |
+
dst_indices = src_to_dists[:, 1]
|
| 102 |
+
|
| 103 |
+
for kv_cache in kv_caches:
|
| 104 |
+
key_caches = kv_cache[0]
|
| 105 |
+
value_caches = kv_cache[1]
|
| 106 |
+
key_caches[dst_indices] = key_caches[src_indices]
|
| 107 |
+
value_caches[dst_indices] = value_caches[src_indices]
|
| 108 |
+
|
| 109 |
+
@staticmethod
|
| 110 |
+
def get_builder_cls() -> Type["AscendMetadataBuilder"]:
|
| 111 |
+
return AscendMetadataBuilder
|
| 112 |
+
|
| 113 |
+
@classmethod
|
| 114 |
+
def make_metadata_builder(cls, *args, **kwargs) -> "AscendMetadataBuilder":
|
| 115 |
+
return cls.get_builder_cls()(*args, **kwargs)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class AscendMLAAttentionBackend(AscendAttentionBackend):
|
| 119 |
+
|
| 120 |
+
@staticmethod
|
| 121 |
+
def get_impl_cls() -> Type["AscendMLAAttentionBackendImpl"]:
|
| 122 |
+
return AscendMLAAttentionBackendImpl
|
| 123 |
+
|
| 124 |
+
@staticmethod
|
| 125 |
+
def get_kv_cache_shape(
|
| 126 |
+
num_blocks: int,
|
| 127 |
+
block_size: int,
|
| 128 |
+
num_kv_heads: int,
|
| 129 |
+
head_size: int,
|
| 130 |
+
) -> Tuple[int, ...]:
|
| 131 |
+
return (num_blocks, block_size, num_kv_heads, head_size)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@dataclass
|
| 135 |
+
class AscendMetadata(AttentionMetadata):
|
| 136 |
+
"""Metadata for Ascendbackend.
|
| 137 |
+
* modified from XFormersbackend
|
| 138 |
+
NOTE: Any python object stored here is not updated when it is
|
| 139 |
+
cuda-graph replayed. If you have values that need to be changed
|
| 140 |
+
dynamically, it should be stored in tensor. The tensor has to be
|
| 141 |
+
updated from `CUDAGraphRunner.forward` API.
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
# |---------- N-1 iteration --------|
|
| 145 |
+
# |---------------- N iteration ---------------------|
|
| 146 |
+
# |- tokenA -|......................|-- newTokens ---|
|
| 147 |
+
# |---------- context_len ----------|
|
| 148 |
+
# |-------------------- seq_len ----------------------|
|
| 149 |
+
# |-- query_len ---|
|
| 150 |
+
|
| 151 |
+
# FIXME: It is for flash attn.
|
| 152 |
+
# Maximum sequence length among prefill batch. 0 if there are decoding
|
| 153 |
+
# Avoid mypy error
|
| 154 |
+
# Total number of prefill requests.
|
| 155 |
+
num_prefills: int
|
| 156 |
+
# Number of prefill tokens.
|
| 157 |
+
num_prefill_tokens: int
|
| 158 |
+
# (num_tokens,). The indices of the token slots that input tokens will be
|
| 159 |
+
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
|
| 160 |
+
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
| 161 |
+
# in block 0, and 1st slot in block 1, respectively.
|
| 162 |
+
slot_mapping: torch.Tensor
|
| 163 |
+
|
| 164 |
+
# requests only.
|
| 165 |
+
max_prefill_seq_len: int
|
| 166 |
+
# Maximum sequence length among decode batch. 0 if there are prefill
|
| 167 |
+
# requests only.
|
| 168 |
+
max_decode_seq_len: int
|
| 169 |
+
|
| 170 |
+
chunked_prefill_enabled: bool
|
| 171 |
+
|
| 172 |
+
# (batch_size, max_blocks_per_seq).
|
| 173 |
+
# Block addresses per sequence. (Seq id -> list of physical block)
|
| 174 |
+
block_tables: Optional[torch.Tensor]
|
| 175 |
+
|
| 176 |
+
# seq_lens stored as a tensor.
|
| 177 |
+
seq_lens_tensor: Optional[torch.Tensor]
|
| 178 |
+
|
| 179 |
+
# (batch_size,). The sequence length per sequence. Sequence length means
|
| 180 |
+
# the computed tokens + new tokens None if it is a decoding.
|
| 181 |
+
seq_lens: Optional[List[int]] = None
|
| 182 |
+
|
| 183 |
+
# The query lengths of the input sequences
|
| 184 |
+
query_lens: Optional[List[int]] = None
|
| 185 |
+
|
| 186 |
+
# Maximum query length in the batch. None for decoding.
|
| 187 |
+
max_query_len: Optional[int] = None
|
| 188 |
+
|
| 189 |
+
# Self-attention prefill/decode metadata cache
|
| 190 |
+
_cached_prefill_metadata: Optional["AscendMetadata"] = None
|
| 191 |
+
_cached_decode_metadata: Optional["AscendMetadata"] = None
|
| 192 |
+
|
| 193 |
+
# Begin encoder attn & enc/dec cross-attn fields...
|
| 194 |
+
|
| 195 |
+
# Encoder sequence lengths representation
|
| 196 |
+
encoder_seq_lens: Optional[List[int]] = None
|
| 197 |
+
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
| 198 |
+
|
| 199 |
+
# Maximum sequence length among encoder sequences
|
| 200 |
+
max_encoder_seq_len: Optional[int] = None
|
| 201 |
+
|
| 202 |
+
# Number of tokens input to encoder
|
| 203 |
+
num_encoder_tokens: Optional[int] = None
|
| 204 |
+
|
| 205 |
+
# Mask for normal situation
|
| 206 |
+
attn_mask: Optional[torch.Tensor] = None
|
| 207 |
+
|
| 208 |
+
# Mask for prefix caching
|
| 209 |
+
compress_mask: Optional[torch.Tensor] = None
|
| 210 |
+
|
| 211 |
+
# Mask for chunked prefill
|
| 212 |
+
chunk_mask: Optional[torch.Tensor] = None
|
| 213 |
+
|
| 214 |
+
# Cross-attention memory-mapping data structures: slot mapping
|
| 215 |
+
# and block tables
|
| 216 |
+
cross_slot_mapping: Optional[torch.Tensor] = None
|
| 217 |
+
cross_block_tables: Optional[torch.Tensor] = None
|
| 218 |
+
|
| 219 |
+
@property
|
| 220 |
+
def prefill_metadata(self) -> Optional["AscendMetadata"]:
|
| 221 |
+
if self.num_prefills == 0:
|
| 222 |
+
return None
|
| 223 |
+
|
| 224 |
+
if self._cached_prefill_metadata is not None:
|
| 225 |
+
# Recover cached prefill-phase attention
|
| 226 |
+
# metadata structure.
|
| 227 |
+
return self._cached_prefill_metadata
|
| 228 |
+
|
| 229 |
+
assert ((self.seq_lens is not None)
|
| 230 |
+
or (self.encoder_seq_lens is not None))
|
| 231 |
+
|
| 232 |
+
# Compute some attn_metadata fields which default to None.
|
| 233 |
+
slot_mapping = (None if self.slot_mapping is None else
|
| 234 |
+
self.slot_mapping[:self.num_prefill_tokens])
|
| 235 |
+
seq_lens = (None if self.seq_lens is None else
|
| 236 |
+
self.seq_lens[:self.num_prefills])
|
| 237 |
+
query_lens = (None if self.query_lens is None else
|
| 238 |
+
self.query_lens[:self.num_prefills])
|
| 239 |
+
block_tables = (None if self.block_tables is None else
|
| 240 |
+
self.block_tables[:self.num_prefills])
|
| 241 |
+
|
| 242 |
+
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
| 243 |
+
self.seq_lens_tensor[:self.num_prefills])
|
| 244 |
+
|
| 245 |
+
# Construct & cache prefill-phase attention metadata structure.
|
| 246 |
+
self._cached_prefill_metadata = AscendMetadata(
|
| 247 |
+
num_prefills=self.num_prefills,
|
| 248 |
+
num_prefill_tokens=self.num_prefill_tokens,
|
| 249 |
+
num_decode_tokens=0,
|
| 250 |
+
slot_mapping=slot_mapping,
|
| 251 |
+
seq_lens=seq_lens,
|
| 252 |
+
seq_lens_tensor=seq_lens_tensor,
|
| 253 |
+
query_lens=query_lens,
|
| 254 |
+
max_query_len=self.max_query_len,
|
| 255 |
+
max_prefill_seq_len=self.max_prefill_seq_len,
|
| 256 |
+
max_decode_seq_len=0,
|
| 257 |
+
chunked_prefill_enabled=self.chunked_prefill_enabled,
|
| 258 |
+
block_tables=block_tables,
|
| 259 |
+
# Begin encoder & cross attn fields below...
|
| 260 |
+
encoder_seq_lens=self.encoder_seq_lens,
|
| 261 |
+
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
| 262 |
+
max_encoder_seq_len=self.max_encoder_seq_len,
|
| 263 |
+
multi_modal_placeholder_index_maps=self.
|
| 264 |
+
multi_modal_placeholder_index_maps,
|
| 265 |
+
cross_slot_mapping=self.cross_slot_mapping,
|
| 266 |
+
cross_block_tables=self.cross_block_tables,
|
| 267 |
+
enable_kv_scales_calculation=False)
|
| 268 |
+
return self._cached_prefill_metadata
|
| 269 |
+
|
| 270 |
+
@property
|
| 271 |
+
def decode_metadata(self) -> Optional["AscendMetadata"]:
|
| 272 |
+
if self.num_decode_tokens == 0:
|
| 273 |
+
return None
|
| 274 |
+
|
| 275 |
+
if self._cached_decode_metadata is not None:
|
| 276 |
+
# Recover cached decode-phase attention
|
| 277 |
+
# metadata structure.
|
| 278 |
+
return self._cached_decode_metadata
|
| 279 |
+
|
| 280 |
+
# Compute some attn_metadata fields which default to None.
|
| 281 |
+
slot_mapping = (None if self.slot_mapping is None else
|
| 282 |
+
self.slot_mapping[self.num_prefill_tokens:])
|
| 283 |
+
seq_lens = (None if self.seq_lens is None else
|
| 284 |
+
self.seq_lens[self.num_prefills:])
|
| 285 |
+
query_lens = (None if self.query_lens is None else
|
| 286 |
+
self.query_lens[self.num_prefills:])
|
| 287 |
+
block_tables = (None if self.block_tables is None else
|
| 288 |
+
self.block_tables[self.num_prefills:])
|
| 289 |
+
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
| 290 |
+
self.seq_lens_tensor[self.num_prefills:])
|
| 291 |
+
# Construct & cache decode-phase attention metadata structure.
|
| 292 |
+
self._cached_decode_metadata = AscendMetadata(
|
| 293 |
+
num_prefills=0,
|
| 294 |
+
num_prefill_tokens=0,
|
| 295 |
+
num_decode_tokens=self.num_decode_tokens,
|
| 296 |
+
slot_mapping=slot_mapping,
|
| 297 |
+
seq_lens=seq_lens,
|
| 298 |
+
seq_lens_tensor=seq_lens_tensor,
|
| 299 |
+
query_lens=query_lens,
|
| 300 |
+
max_query_len=self.max_query_len,
|
| 301 |
+
max_prefill_seq_len=0,
|
| 302 |
+
max_decode_seq_len=self.max_decode_seq_len,
|
| 303 |
+
chunked_prefill_enabled=self.chunked_prefill_enabled,
|
| 304 |
+
block_tables=block_tables,
|
| 305 |
+
# Begin encoder & cross attn fields below...
|
| 306 |
+
encoder_seq_lens=self.encoder_seq_lens,
|
| 307 |
+
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
| 308 |
+
max_encoder_seq_len=self.max_encoder_seq_len,
|
| 309 |
+
multi_modal_placeholder_index_maps=self.
|
| 310 |
+
multi_modal_placeholder_index_maps,
|
| 311 |
+
cross_slot_mapping=self.cross_slot_mapping,
|
| 312 |
+
cross_block_tables=self.cross_block_tables,
|
| 313 |
+
enable_kv_scales_calculation=False)
|
| 314 |
+
return self._cached_decode_metadata
|
| 315 |
+
|
| 316 |
+
def advance_step(self,
|
| 317 |
+
model_input: "ModelInputForNPUWithSamplingMetadata",
|
| 318 |
+
sampled_token_ids: Optional[torch.Tensor],
|
| 319 |
+
block_size: int,
|
| 320 |
+
num_seqs: int,
|
| 321 |
+
num_queries: int,
|
| 322 |
+
turn_prefills_into_decodes: bool = False):
|
| 323 |
+
"""
|
| 324 |
+
Update metadata in-place to advance one decode step.
|
| 325 |
+
"""
|
| 326 |
+
# When using cudagraph, the num_seqs is padded to the next captured
|
| 327 |
+
# batch sized, but num_queries tracks the actual number of requests in
|
| 328 |
+
# the batch. For --enforce-eager mode, num_seqs == num_queries
|
| 329 |
+
if num_seqs != num_queries:
|
| 330 |
+
assert num_seqs > num_queries
|
| 331 |
+
|
| 332 |
+
if turn_prefills_into_decodes:
|
| 333 |
+
# When Mutli-Step is enabled with Chunked-Prefill, prefills and
|
| 334 |
+
# decodes are scheduled together. In the first step, all the
|
| 335 |
+
# prefills turn into decodes. This update reflects that
|
| 336 |
+
# conversion.
|
| 337 |
+
assert self.num_decode_tokens + self.num_prefills == num_seqs
|
| 338 |
+
self.num_decode_tokens += self.num_prefills
|
| 339 |
+
self.num_prefills = 0
|
| 340 |
+
self.num_prefill_tokens = 0
|
| 341 |
+
self.max_prefill_seq_len = 0
|
| 342 |
+
self.max_query_len = 1
|
| 343 |
+
|
| 344 |
+
self.slot_mapping = self.slot_mapping[:num_seqs]
|
| 345 |
+
else:
|
| 346 |
+
assert self.seq_lens is not None
|
| 347 |
+
assert self.max_decode_seq_len == max(self.seq_lens)
|
| 348 |
+
|
| 349 |
+
assert self.num_prefills == 0
|
| 350 |
+
assert self.num_prefill_tokens == 0
|
| 351 |
+
assert self.num_decode_tokens == num_seqs
|
| 352 |
+
assert self.slot_mapping.shape == (num_seqs, )
|
| 353 |
+
|
| 354 |
+
assert self.seq_lens is not None
|
| 355 |
+
assert len(self.seq_lens) == num_seqs
|
| 356 |
+
assert self.seq_lens_tensor is not None
|
| 357 |
+
assert self.seq_lens_tensor.shape == (num_seqs, )
|
| 358 |
+
assert self.max_query_len == 1
|
| 359 |
+
assert self.max_prefill_seq_len == 0
|
| 360 |
+
|
| 361 |
+
assert self.block_tables is not None
|
| 362 |
+
assert self.block_tables.shape[0] == num_seqs
|
| 363 |
+
|
| 364 |
+
# Update query lengths. Note that we update only queries and not seqs,
|
| 365 |
+
# since tensors may be padded due to captured cuda graph batch size
|
| 366 |
+
for i in range(num_queries):
|
| 367 |
+
self.seq_lens[i] += 1
|
| 368 |
+
self.max_decode_seq_len = max(self.seq_lens)
|
| 369 |
+
if enable_custom_op():
|
| 370 |
+
#advance a step on NPU for existing inputs for a multi-step runner if custom ops is enabled
|
| 371 |
+
torch.ops._C.advance_step_flashattn_ascendc(
|
| 372 |
+
num_seqs=num_seqs,
|
| 373 |
+
num_queries=num_queries,
|
| 374 |
+
block_size=block_size,
|
| 375 |
+
input_tokens=model_input.input_tokens,
|
| 376 |
+
sampled_token_ids=sampled_token_ids,
|
| 377 |
+
input_positions=model_input.input_positions,
|
| 378 |
+
seq_lens=self.seq_lens_tensor,
|
| 379 |
+
slot_mapping=self.slot_mapping,
|
| 380 |
+
block_tables=self.block_tables)
|
| 381 |
+
else:
|
| 382 |
+
# use traditional Pytorch method for updating these tensors.
|
| 383 |
+
# update input_tokens
|
| 384 |
+
sampled_token_ids_list = sampled_token_ids[:
|
| 385 |
+
num_queries].squeeze( # type: ignore
|
| 386 |
+
-1)
|
| 387 |
+
model_input.input_tokens[:
|
| 388 |
+
num_queries] = sampled_token_ids_list # type: ignore
|
| 389 |
+
|
| 390 |
+
# get seq_lens and input_positions
|
| 391 |
+
seq_lens = self.seq_lens_tensor[:num_queries]
|
| 392 |
+
next_seq_lens = seq_lens + 1
|
| 393 |
+
next_input_pos = next_seq_lens - 1
|
| 394 |
+
|
| 395 |
+
# update seq_lens and input_positions
|
| 396 |
+
self.seq_lens_tensor[:num_queries] = next_seq_lens
|
| 397 |
+
model_input.input_positions[:
|
| 398 |
+
num_queries] = next_input_pos # type: ignore
|
| 399 |
+
|
| 400 |
+
# 计算 block index 和 offset
|
| 401 |
+
block_idx = next_input_pos // block_size
|
| 402 |
+
block_offset = next_input_pos % block_size
|
| 403 |
+
|
| 404 |
+
current_block_table = self.block_tables.gather(
|
| 405 |
+
1, block_idx.unsqueeze(-1)).squeeze(-1)
|
| 406 |
+
slot_num = current_block_table * block_size + block_offset
|
| 407 |
+
|
| 408 |
+
# update slot_mapping
|
| 409 |
+
self.slot_mapping[:num_queries] = slot_num
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
| 413 |
+
|
| 414 |
+
_attn_mask_builder = None # noqa
|
| 415 |
+
|
| 416 |
+
def __init__(self, input_builder: "ModelInputForNPUBuilder"):
|
| 417 |
+
self.input_builder = input_builder
|
| 418 |
+
self.runner = input_builder.runner
|
| 419 |
+
self.sliding_window = input_builder.sliding_window
|
| 420 |
+
self.block_size = input_builder.block_size
|
| 421 |
+
|
| 422 |
+
self.attn_mask = None
|
| 423 |
+
self.compress_mask = None
|
| 424 |
+
self.chunk_mask = None
|
| 425 |
+
if AscendMetadataBuilder._attn_mask_builder is None:
|
| 426 |
+
AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder(
|
| 427 |
+
128, self.input_builder.runner.model_config.dtype)
|
| 428 |
+
|
| 429 |
+
def _add_seq_group(
|
| 430 |
+
self, inter_data: ModelInputForNPUBuilder.InterDataForSeqGroup,
|
| 431 |
+
chunked_prefill_enabled: bool):
|
| 432 |
+
"""Add a sequence group to the metadata. Specifically update/append
|
| 433 |
+
1. context length.
|
| 434 |
+
2. block table.
|
| 435 |
+
3. slot mapping.
|
| 436 |
+
"""
|
| 437 |
+
is_prompt = inter_data.is_prompt
|
| 438 |
+
block_tables = inter_data.block_tables
|
| 439 |
+
|
| 440 |
+
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
| 441 |
+
curr_sliding_window_block) in zip(
|
| 442 |
+
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
| 443 |
+
inter_data.orig_seq_lens, inter_data.seq_lens,
|
| 444 |
+
inter_data.query_lens, inter_data.context_lens,
|
| 445 |
+
inter_data.curr_sliding_window_blocks):
|
| 446 |
+
self.context_lens.append(context_len)
|
| 447 |
+
if is_prompt:
|
| 448 |
+
self.num_prefills += 1
|
| 449 |
+
self.num_prefill_tokens += token_len
|
| 450 |
+
self.prefill_seq_lens.append(seq_len)
|
| 451 |
+
else:
|
| 452 |
+
self.num_decode_tokens += query_len
|
| 453 |
+
self.curr_seq_lens.append(curr_seq_len)
|
| 454 |
+
|
| 455 |
+
# Compute block table.
|
| 456 |
+
# TODO(sang): Combine chunked prefill and prefix caching by
|
| 457 |
+
# only allowing multiple of block_size chunk size.
|
| 458 |
+
# NOTE: This only works for oooooooxxx style attention.
|
| 459 |
+
block_table: List[int] = []
|
| 460 |
+
prefix_cache_hit = any([
|
| 461 |
+
inter_data.prefix_cache_hit
|
| 462 |
+
for inter_data in self.input_builder.inter_data_list
|
| 463 |
+
])
|
| 464 |
+
if prefix_cache_hit:
|
| 465 |
+
# NOTE(woosuk): For flash-attn, the block table should
|
| 466 |
+
# include the entries for the incoming prefill tokens.
|
| 467 |
+
if block_tables is not None:
|
| 468 |
+
block_table = block_tables[seq_id]
|
| 469 |
+
elif ((chunked_prefill_enabled or not is_prompt)
|
| 470 |
+
and block_tables is not None):
|
| 471 |
+
if curr_sliding_window_block == 0:
|
| 472 |
+
block_table = block_tables[seq_id]
|
| 473 |
+
else:
|
| 474 |
+
block_table = block_tables[seq_id][
|
| 475 |
+
-curr_sliding_window_block:]
|
| 476 |
+
self.block_tables.append(block_table)
|
| 477 |
+
|
| 478 |
+
# Compute slot mapping.
|
| 479 |
+
is_profile_run = is_block_tables_empty(block_tables)
|
| 480 |
+
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
|
| 481 |
+
context_len,
|
| 482 |
+
self.sliding_window)
|
| 483 |
+
compute_slot_mapping(
|
| 484 |
+
is_profile_run,
|
| 485 |
+
self.slot_mapping,
|
| 486 |
+
seq_id,
|
| 487 |
+
seq_len,
|
| 488 |
+
context_len,
|
| 489 |
+
start_idx,
|
| 490 |
+
self.block_size,
|
| 491 |
+
inter_data.block_tables,
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
def _get_graph_runner_block_tables(
|
| 495 |
+
self, num_seqs: int,
|
| 496 |
+
block_tables: List[List[int]]) -> torch.Tensor:
|
| 497 |
+
# The shape of graph_block_tables is
|
| 498 |
+
# [max batch size, max context len // block size].
|
| 499 |
+
|
| 500 |
+
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
| 501 |
+
assert max_batch_size >= num_seqs
|
| 502 |
+
|
| 503 |
+
graph_block_tables = self.runner.graph_block_tables # [:num_seqs]
|
| 504 |
+
for i, block_table in enumerate(block_tables):
|
| 505 |
+
if block_table:
|
| 506 |
+
num_blocks = len(block_table)
|
| 507 |
+
if num_blocks <= max_blocks:
|
| 508 |
+
graph_block_tables[i, :num_blocks] = block_table
|
| 509 |
+
else:
|
| 510 |
+
graph_block_tables[
|
| 511 |
+
i, :max_blocks] = block_table[:max_blocks]
|
| 512 |
+
|
| 513 |
+
return torch.from_numpy(graph_block_tables).to(
|
| 514 |
+
device=self.runner.device, non_blocking=True)
|
| 515 |
+
|
| 516 |
+
def build(
|
| 517 |
+
self,
|
| 518 |
+
seq_lens: List[int],
|
| 519 |
+
query_lens: List[int],
|
| 520 |
+
graph_pad_size: int,
|
| 521 |
+
):
|
| 522 |
+
"""Build attention metadata with on-device tensors.
|
| 523 |
+
|
| 524 |
+
Args:
|
| 525 |
+
seq_lens: The maybe padded sequence lengths of the input sequences.
|
| 526 |
+
query_lens: The query lengths of the input sequences.
|
| 527 |
+
"""
|
| 528 |
+
for inter_data in self.input_builder.inter_data_list:
|
| 529 |
+
self._add_seq_group(inter_data,
|
| 530 |
+
self.input_builder.chunked_prefill_enabled)
|
| 531 |
+
|
| 532 |
+
device = self.runner.device
|
| 533 |
+
dtype = self.runner.model_config.dtype
|
| 534 |
+
use_npu_graph = graph_pad_size != -1
|
| 535 |
+
|
| 536 |
+
max_query_len = max(query_lens)
|
| 537 |
+
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
| 538 |
+
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
| 539 |
+
max_seq_len = max(max_prefill_seq_len, max_decode_seq_len)
|
| 540 |
+
num_decode_tokens = self.num_decode_tokens
|
| 541 |
+
|
| 542 |
+
if self.num_prefills == 0 and use_npu_graph:
|
| 543 |
+
num_seqs = len(seq_lens)
|
| 544 |
+
self.slot_mapping.extend([PAD_SLOT_ID] * graph_pad_size)
|
| 545 |
+
self.block_tables.extend([[]] * graph_pad_size)
|
| 546 |
+
block_tables = self._get_graph_runner_block_tables(
|
| 547 |
+
num_seqs, self.block_tables)
|
| 548 |
+
else:
|
| 549 |
+
block_tables = make_tensor_with_pad(
|
| 550 |
+
self.block_tables,
|
| 551 |
+
pad=0,
|
| 552 |
+
dtype=torch.int32,
|
| 553 |
+
device=device,
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
+
if self.num_prefills > 0:
|
| 557 |
+
if block_tables is None or block_tables.numel() == 0:
|
| 558 |
+
# normal mask
|
| 559 |
+
self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
| 560 |
+
max_prefill_seq_len, dtype, device)
|
| 561 |
+
if is_310p():
|
| 562 |
+
mask_nz = nd_to_nz_2d(self.attn_mask)
|
| 563 |
+
mask_nz = torch_npu.npu_format_cast(
|
| 564 |
+
mask_nz.contiguous(), ACL_FORMAT_FRACTAL_NZ)
|
| 565 |
+
self.attn_mask = mask_nz
|
| 566 |
+
elif self.num_decode_tokens == 0 and not self.input_builder.chunked_prefill_enabled:
|
| 567 |
+
# compress mask for prefix cache
|
| 568 |
+
self.compress_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
| 569 |
+
128, dtype, device)
|
| 570 |
+
else:
|
| 571 |
+
# chunk_mask for chunk prefill
|
| 572 |
+
attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
| 573 |
+
max_seq_len, dtype, device)
|
| 574 |
+
if attn_mask.numel() > 1 and attn_mask[0][1] > 0:
|
| 575 |
+
# Do not use in-place multiplication to avoid modifying `attn_mask_cache`!
|
| 576 |
+
attn_mask = attn_mask * -10000
|
| 577 |
+
chunk_mask_list = []
|
| 578 |
+
for i, seq_len in enumerate(seq_lens):
|
| 579 |
+
context_len = self.context_lens[i]
|
| 580 |
+
chunk_mask_list.append(attn_mask[context_len:seq_len])
|
| 581 |
+
self.chunk_mask = torch.cat(chunk_mask_list, 0)
|
| 582 |
+
else:
|
| 583 |
+
self.attn_mask = None
|
| 584 |
+
self.compress_mask = None
|
| 585 |
+
self.chunk_mask = None
|
| 586 |
+
|
| 587 |
+
assert max_query_len > 0, "query_lens: {}".format(query_lens)
|
| 588 |
+
|
| 589 |
+
assert device is not None
|
| 590 |
+
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
|
| 591 |
+
device, self.runner.pin_memory)
|
| 592 |
+
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
| 593 |
+
self.runner.pin_memory)
|
| 594 |
+
placeholder_index_maps = {
|
| 595 |
+
modality: placeholder_map.index_map()
|
| 596 |
+
for modality, placeholder_map in
|
| 597 |
+
self.multimodal_placeholder_maps.items()
|
| 598 |
+
}
|
| 599 |
+
|
| 600 |
+
return AscendMetadata(
|
| 601 |
+
num_prefills=self.num_prefills,
|
| 602 |
+
slot_mapping=slot_mapping_tensor,
|
| 603 |
+
num_prefill_tokens=self.num_prefill_tokens,
|
| 604 |
+
num_decode_tokens=num_decode_tokens,
|
| 605 |
+
seq_lens=seq_lens,
|
| 606 |
+
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
| 607 |
+
enable_kv_scales_calculation=True,
|
| 608 |
+
seq_lens_tensor=seq_lens_tensor,
|
| 609 |
+
query_lens=query_lens,
|
| 610 |
+
max_query_len=max_query_len,
|
| 611 |
+
max_prefill_seq_len=max_prefill_seq_len,
|
| 612 |
+
max_decode_seq_len=max_decode_seq_len,
|
| 613 |
+
block_tables=block_tables,
|
| 614 |
+
attn_mask=self.attn_mask,
|
| 615 |
+
compress_mask=self.compress_mask,
|
| 616 |
+
chunk_mask=self.chunk_mask,
|
| 617 |
+
chunked_prefill_enabled=self.input_builder.chunked_prefill_enabled,
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
class AscendAttentionBackendImpl(AttentionImpl):
|
| 622 |
+
|
| 623 |
+
def __init__(
|
| 624 |
+
self,
|
| 625 |
+
num_heads: int,
|
| 626 |
+
head_size: int,
|
| 627 |
+
scale: float,
|
| 628 |
+
num_kv_heads: int,
|
| 629 |
+
alibi_slopes: Optional[List[float]],
|
| 630 |
+
sliding_window: Optional[int],
|
| 631 |
+
kv_cache_dtype: str,
|
| 632 |
+
blocksparse_params: Optional[Dict[str, Any]] = None,
|
| 633 |
+
logits_soft_cap: Optional[float] = None,
|
| 634 |
+
attn_type: str = AttentionType.DECODER,
|
| 635 |
+
kv_sharing_target_layer_name: Optional[str] = None,
|
| 636 |
+
use_irope: bool = False,
|
| 637 |
+
) -> None:
|
| 638 |
+
self.num_heads = num_heads
|
| 639 |
+
self.head_size = head_size
|
| 640 |
+
self.scale = float(scale)
|
| 641 |
+
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
| 642 |
+
self.hidden_size = self.num_heads * self.head_size
|
| 643 |
+
self.kv_cache_dtype = kv_cache_dtype
|
| 644 |
+
self.sliding_window = sliding_window
|
| 645 |
+
if alibi_slopes is not None:
|
| 646 |
+
alibi_slopes = torch.tensor(alibi_slopes,
|
| 647 |
+
dtype=torch.float32,
|
| 648 |
+
device="npu")
|
| 649 |
+
self.alibi_slopes = alibi_slopes
|
| 650 |
+
self.attn_type = attn_type
|
| 651 |
+
|
| 652 |
+
assert self.num_heads % self.num_kv_heads == 0
|
| 653 |
+
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
| 654 |
+
self.seq_len_cpu_tensor = None
|
| 655 |
+
self.query_len_cpu_tensor = None
|
| 656 |
+
self.key_cache = None
|
| 657 |
+
self.value_cache = None
|
| 658 |
+
|
| 659 |
+
def forward(
|
| 660 |
+
self,
|
| 661 |
+
layer: AttentionLayer,
|
| 662 |
+
query: torch.Tensor,
|
| 663 |
+
key: torch.Tensor,
|
| 664 |
+
value: torch.Tensor,
|
| 665 |
+
kv_cache: torch.Tensor,
|
| 666 |
+
attn_metadata: AscendMetadata,
|
| 667 |
+
attn_type: str = AttentionType.DECODER,
|
| 668 |
+
output: Optional[torch.Tensor] = None,
|
| 669 |
+
) -> torch.Tensor:
|
| 670 |
+
"""Forward pass with Ascend attention.
|
| 671 |
+
Args:
|
| 672 |
+
query: shape = [num_tokens, num_heads * head_size]
|
| 673 |
+
num_tokens = batch_size * seq_len
|
| 674 |
+
key: shape = [num_tokens, num_kv_heads * head_size]
|
| 675 |
+
value: shape = [num_tokens, num_kv_heads * head_size]
|
| 676 |
+
kv_cache: shape = [2, num_blocks, block_size,
|
| 677 |
+
num_kv_heads, head_size]
|
| 678 |
+
key_cache = [num_blocks, block_size,
|
| 679 |
+
num_kv_heads, head_size]
|
| 680 |
+
value_cache = [num_blocks, block_size,
|
| 681 |
+
num_kv_heads, head_size]
|
| 682 |
+
attn_metadata: Metadata for attention.
|
| 683 |
+
Returns:
|
| 684 |
+
shape = [batch_size, seq_len * num_heads * head_size]
|
| 685 |
+
"""
|
| 686 |
+
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
| 687 |
+
# View q k v to BSH.
|
| 688 |
+
num_tokens = query.shape[0]
|
| 689 |
+
query = query.view(-1, self.num_heads, self.head_size)
|
| 690 |
+
key = key.view(-1, self.num_kv_heads, self.head_size)
|
| 691 |
+
value = value.view(-1, self.num_kv_heads, self.head_size)
|
| 692 |
+
# TODO: Remove this contiguous in the future.
|
| 693 |
+
value = value.contiguous()
|
| 694 |
+
attn_type = self.attn_type
|
| 695 |
+
|
| 696 |
+
output = torch.empty(num_tokens,
|
| 697 |
+
self.num_heads,
|
| 698 |
+
self.head_size,
|
| 699 |
+
dtype=query.dtype,
|
| 700 |
+
device=query.device)
|
| 701 |
+
|
| 702 |
+
if kv_cache.numel() > 0:
|
| 703 |
+
if self.key_cache is None:
|
| 704 |
+
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
| 705 |
+
slots = attn_metadata.slot_mapping
|
| 706 |
+
|
| 707 |
+
if hasattr(layer, 'quant_method'):
|
| 708 |
+
isPrefill = True if attn_metadata.num_prefills > 0 else False
|
| 709 |
+
if isPrefill:
|
| 710 |
+
assert attn_metadata.prefill_metadata is not None
|
| 711 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 712 |
+
np.array(attn_metadata.prefill_metadata.seq_lens).astype(
|
| 713 |
+
np.int32))
|
| 714 |
+
else:
|
| 715 |
+
assert attn_metadata.decode_metadata is not None
|
| 716 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 717 |
+
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
| 718 |
+
np.int32))
|
| 719 |
+
block_tables = attn_metadata.decode_metadata.block_tables if attn_metadata.decode_metadata else None
|
| 720 |
+
# Details of kv_cache arrangement in attention quantization
|
| 721 |
+
# are implemented by quant_method.
|
| 722 |
+
layer.quant_method.apply(
|
| 723 |
+
layer,
|
| 724 |
+
query,
|
| 725 |
+
key,
|
| 726 |
+
value,
|
| 727 |
+
self.key_cache,
|
| 728 |
+
self.value_cache,
|
| 729 |
+
self.scale,
|
| 730 |
+
block_tables,
|
| 731 |
+
isPrefill,
|
| 732 |
+
attn_metadata,
|
| 733 |
+
output,
|
| 734 |
+
seq_lens_tensor_cpu=self.seq_lens_tensor_cpu)
|
| 735 |
+
else:
|
| 736 |
+
if self.key_cache is not None:
|
| 737 |
+
torch_npu._npu_reshape_and_cache(key=key,
|
| 738 |
+
value=value,
|
| 739 |
+
key_cache=self.key_cache,
|
| 740 |
+
value_cache=self.value_cache,
|
| 741 |
+
slot_indices=slots)
|
| 742 |
+
|
| 743 |
+
if attn_metadata.num_prefills > 0:
|
| 744 |
+
# Prefix cache disabled and chunk prefill disabled or no prefix cache hit
|
| 745 |
+
if (attn_metadata.block_tables is None
|
| 746 |
+
or attn_metadata.block_tables.numel() == 0):
|
| 747 |
+
if attn_type == AttentionType.ENCODER_ONLY:
|
| 748 |
+
# TODO: change to use torch_npu encoder attention op, instead
|
| 749 |
+
# of torch sdpa
|
| 750 |
+
query = query.movedim(0, query.dim() - 2)
|
| 751 |
+
key = key.movedim(0, key.dim() - 2)
|
| 752 |
+
value = value.movedim(0, value.dim() - 2)
|
| 753 |
+
|
| 754 |
+
causal_attn = (attn_type == AttentionType.DECODER)
|
| 755 |
+
if attn_metadata.seq_lens is not None:
|
| 756 |
+
seq_lens_q = seq_lens_kv = attn_metadata.seq_lens
|
| 757 |
+
attn_masks = [None] * len(seq_lens_q)
|
| 758 |
+
start_q, start_kv = 0, 0
|
| 759 |
+
for seq_len_q, seq_len_kv, mask in zip(
|
| 760 |
+
seq_lens_q, seq_lens_kv, attn_masks):
|
| 761 |
+
end_q = start_q + seq_len_q
|
| 762 |
+
end_kv = start_kv + seq_len_kv
|
| 763 |
+
sub_out = scaled_dot_product_attention(
|
| 764 |
+
query[None, :, start_q:end_q, :],
|
| 765 |
+
key[None, :, start_kv:end_kv, :],
|
| 766 |
+
value[None, :, start_kv:end_kv, :],
|
| 767 |
+
attn_mask=mask,
|
| 768 |
+
dropout_p=0.0,
|
| 769 |
+
is_causal=causal_attn and mask is None,
|
| 770 |
+
scale=self.scale).squeeze(0).movedim(
|
| 771 |
+
query.dim() - 2, 0)
|
| 772 |
+
output[start_q:end_q, :, :] = sub_out
|
| 773 |
+
start_q, start_kv = end_q, end_kv
|
| 774 |
+
else:
|
| 775 |
+
assert attn_metadata.attn_mask is not None
|
| 776 |
+
mask = attn_metadata.attn_mask
|
| 777 |
+
assert attn_metadata.prefill_metadata is not None
|
| 778 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 779 |
+
np.array(attn_metadata.prefill_metadata.seq_lens).
|
| 780 |
+
astype(np.int32))
|
| 781 |
+
if is_310p():
|
| 782 |
+
# align q k v output tensors
|
| 783 |
+
query = aligned_16(query)
|
| 784 |
+
key = aligned_16(key)
|
| 785 |
+
value = aligned_16(value)
|
| 786 |
+
output = aligned_16(output)
|
| 787 |
+
|
| 788 |
+
# do reformat in case of broadcasted tensors
|
| 789 |
+
mask = mask.repeat(
|
| 790 |
+
self.seq_lens_tensor_cpu.size(0), 1, 1, 1)
|
| 791 |
+
mask = torch_npu.npu_format_cast(
|
| 792 |
+
mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
|
| 793 |
+
torch_npu._npu_flash_attention(
|
| 794 |
+
query=query,
|
| 795 |
+
key=key,
|
| 796 |
+
value=value,
|
| 797 |
+
mask=mask,
|
| 798 |
+
seq_len=self.seq_lens_tensor_cpu,
|
| 799 |
+
scale_value=self.scale,
|
| 800 |
+
num_heads=self.num_heads,
|
| 801 |
+
num_kv_heads=self.num_kv_heads,
|
| 802 |
+
out=output)
|
| 803 |
+
output = output[:num_tokens, :, :]
|
| 804 |
+
# Prefix cache only and cache hit
|
| 805 |
+
elif attn_metadata.num_decode_tokens == 0 and not attn_metadata.chunked_prefill_enabled:
|
| 806 |
+
assert kv_cache is not None
|
| 807 |
+
assert attn_metadata.prefill_metadata is not None
|
| 808 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 809 |
+
np.array(
|
| 810 |
+
attn_metadata.prefill_metadata.seq_lens).astype(
|
| 811 |
+
np.int32))
|
| 812 |
+
self.query_lens_tensor_cpu = torch.from_numpy(
|
| 813 |
+
np.array(
|
| 814 |
+
attn_metadata.prefill_metadata.query_lens).astype(
|
| 815 |
+
np.int32))
|
| 816 |
+
block_tables = attn_metadata.prefill_metadata.block_tables
|
| 817 |
+
assert attn_metadata.compress_mask is not None
|
| 818 |
+
compress_mask = attn_metadata.compress_mask
|
| 819 |
+
torch_npu._npu_flash_attention_qlens(
|
| 820 |
+
query=query,
|
| 821 |
+
key_cache=self.key_cache,
|
| 822 |
+
value_cache=self.value_cache,
|
| 823 |
+
block_table=block_tables,
|
| 824 |
+
mask=compress_mask,
|
| 825 |
+
seq_len=self.query_lens_tensor_cpu,
|
| 826 |
+
context_lens=self.seq_lens_tensor_cpu,
|
| 827 |
+
num_kv_heads=self.num_kv_heads,
|
| 828 |
+
num_heads=self.num_heads,
|
| 829 |
+
scale_value=self.scale,
|
| 830 |
+
out=output)
|
| 831 |
+
# Splitfuse
|
| 832 |
+
else:
|
| 833 |
+
assert kv_cache is not None
|
| 834 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 835 |
+
np.array(attn_metadata.seq_lens).astype(np.int32))
|
| 836 |
+
self.query_lens_tensor_cpu = torch.from_numpy(
|
| 837 |
+
np.array(attn_metadata.query_lens).astype(np.int32))
|
| 838 |
+
block_tables = attn_metadata.block_tables
|
| 839 |
+
assert attn_metadata.chunk_mask is not None
|
| 840 |
+
chunk_mask = attn_metadata.chunk_mask
|
| 841 |
+
torch_npu._npu_paged_attention_splitfuse(
|
| 842 |
+
query=query,
|
| 843 |
+
key_cache=self.key_cache,
|
| 844 |
+
value_cache=self.value_cache,
|
| 845 |
+
block_table=block_tables,
|
| 846 |
+
context_lens=self.seq_lens_tensor_cpu,
|
| 847 |
+
mask=chunk_mask,
|
| 848 |
+
seq_len=self.query_lens_tensor_cpu,
|
| 849 |
+
num_kv_heads=self.num_kv_heads,
|
| 850 |
+
num_heads=self.num_heads,
|
| 851 |
+
scale_value=self.scale,
|
| 852 |
+
out=output)
|
| 853 |
+
# Decode only
|
| 854 |
+
else:
|
| 855 |
+
assert self.key_cache is not None
|
| 856 |
+
assert self.value_cache is not None
|
| 857 |
+
assert attn_metadata.decode_metadata is not None
|
| 858 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 859 |
+
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
| 860 |
+
np.int32))
|
| 861 |
+
if is_310p():
|
| 862 |
+
# # seq_lens_tensor needs to be transferred to the device for 310P
|
| 863 |
+
self.seq_lens_tensor_cpu = self.seq_lens_tensor_cpu.to(
|
| 864 |
+
device=self.key_cache.device)
|
| 865 |
+
block_tables = attn_metadata.decode_metadata.block_tables
|
| 866 |
+
torch_npu._npu_paged_attention(
|
| 867 |
+
query=query,
|
| 868 |
+
key_cache=self.key_cache,
|
| 869 |
+
value_cache=self.value_cache,
|
| 870 |
+
num_kv_heads=self.num_kv_heads,
|
| 871 |
+
num_heads=self.num_heads,
|
| 872 |
+
scale_value=self.scale,
|
| 873 |
+
block_table=block_tables,
|
| 874 |
+
context_lens=self.seq_lens_tensor_cpu,
|
| 875 |
+
out=output)
|
| 876 |
+
|
| 877 |
+
return output.view(num_tokens, self.hidden_size)
|
| 878 |
+
|
| 879 |
+
|
| 880 |
+
class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
| 881 |
+
|
| 882 |
+
def __init__(
|
| 883 |
+
self,
|
| 884 |
+
num_heads: int,
|
| 885 |
+
head_size: int,
|
| 886 |
+
scale: float,
|
| 887 |
+
num_kv_heads: int,
|
| 888 |
+
alibi_slopes: Optional[List[float]],
|
| 889 |
+
sliding_window: Optional[int],
|
| 890 |
+
kv_cache_dtype: str,
|
| 891 |
+
blocksparse_params: Optional[Dict[str, Any]] = None,
|
| 892 |
+
logits_soft_cap: Optional[float] = None,
|
| 893 |
+
attn_type: str = AttentionType.DECODER,
|
| 894 |
+
kv_sharing_target_layer_name: Optional[str] = None,
|
| 895 |
+
**extra_impl_args,
|
| 896 |
+
) -> None:
|
| 897 |
+
self.num_heads = num_heads
|
| 898 |
+
self.head_size = head_size
|
| 899 |
+
self.scale = float(scale)
|
| 900 |
+
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
| 901 |
+
self.hidden_size = self.num_heads * self.head_size
|
| 902 |
+
self.kv_cache_dtype = kv_cache_dtype
|
| 903 |
+
self.sliding_window = sliding_window
|
| 904 |
+
if alibi_slopes is not None:
|
| 905 |
+
alibi_slopes = torch.tensor(alibi_slopes,
|
| 906 |
+
dtype=torch.float32,
|
| 907 |
+
device="npu")
|
| 908 |
+
self.alibi_slopes = alibi_slopes
|
| 909 |
+
self.attn_type = attn_type
|
| 910 |
+
|
| 911 |
+
assert self.num_heads % self.num_kv_heads == 0
|
| 912 |
+
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
| 913 |
+
self.seq_len_cpu_tensor = None
|
| 914 |
+
|
| 915 |
+
# MLA Args
|
| 916 |
+
self.q_lora_rank = extra_impl_args['q_lora_rank']
|
| 917 |
+
self.kv_lora_rank = extra_impl_args['kv_lora_rank']
|
| 918 |
+
self.qk_nope_head_dim = extra_impl_args['qk_nope_head_dim']
|
| 919 |
+
self.qk_rope_head_dim = extra_impl_args['qk_rope_head_dim']
|
| 920 |
+
self.qk_head_dim = extra_impl_args['qk_head_dim']
|
| 921 |
+
self.v_head_dim = extra_impl_args['v_head_dim']
|
| 922 |
+
self.rotary_emb = extra_impl_args['rotary_emb']
|
| 923 |
+
self.q_proj = extra_impl_args['q_proj']
|
| 924 |
+
self.kv_b_proj = extra_impl_args['kv_b_proj']
|
| 925 |
+
self.o_proj = extra_impl_args['o_proj']
|
| 926 |
+
self.kv_a_proj_with_mqa = extra_impl_args.get('kv_a_proj_with_mqa',
|
| 927 |
+
None)
|
| 928 |
+
self.kv_a_layernorm = extra_impl_args.get('kv_a_layernorm', None)
|
| 929 |
+
self.k_pe_cache = None
|
| 930 |
+
self.k_nope_cache = None
|
| 931 |
+
self.w_kc = None
|
| 932 |
+
self.w_vc = None
|
| 933 |
+
|
| 934 |
+
ascend_config = get_ascend_config()
|
| 935 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 936 |
+
|
| 937 |
+
|
| 938 |
+
def exec_kv(
|
| 939 |
+
self,
|
| 940 |
+
hidden_states: torch.Tensor,
|
| 941 |
+
cos: torch.Tensor,
|
| 942 |
+
sin: torch.Tensor,
|
| 943 |
+
kv_cache: Tuple,
|
| 944 |
+
slots: torch.Tensor,
|
| 945 |
+
):
|
| 946 |
+
B = hidden_states.shape[0]
|
| 947 |
+
N = self.num_kv_heads
|
| 948 |
+
S = 1
|
| 949 |
+
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
|
| 950 |
+
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
| 951 |
+
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
| 952 |
+
|
| 953 |
+
k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
|
| 954 |
+
kv,
|
| 955 |
+
self.kv_a_layernorm.weight,
|
| 956 |
+
cos,
|
| 957 |
+
sin,
|
| 958 |
+
slots.to(torch.int64),
|
| 959 |
+
kv_cache[1],
|
| 960 |
+
kv_cache[0],
|
| 961 |
+
epsilon=self.kv_a_layernorm.variance_epsilon,
|
| 962 |
+
cache_mode="PA",
|
| 963 |
+
)
|
| 964 |
+
|
| 965 |
+
return k_pe, k_nope
|
| 966 |
+
|
| 967 |
+
def apply_rotary_emb(
|
| 968 |
+
self,
|
| 969 |
+
x: torch.Tensor,
|
| 970 |
+
cos: torch.Tensor,
|
| 971 |
+
sin: torch.Tensor,
|
| 972 |
+
is_neox_style: bool,
|
| 973 |
+
) -> torch.Tensor:
|
| 974 |
+
"""
|
| 975 |
+
Args:
|
| 976 |
+
x: [num_tokens, num_heads, head_size]
|
| 977 |
+
cos: [num_tokens, head_size // 2]
|
| 978 |
+
sin: [num_tokens, head_size // 2]
|
| 979 |
+
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
| 980 |
+
positional embeddings.
|
| 981 |
+
"""
|
| 982 |
+
cos = cos.unsqueeze(-2).to(x.dtype)
|
| 983 |
+
sin = sin.unsqueeze(-2).to(x.dtype)
|
| 984 |
+
if is_neox_style:
|
| 985 |
+
x1, x2 = torch.chunk(x, 2, dim=-1)
|
| 986 |
+
else:
|
| 987 |
+
x1 = x[..., ::2]
|
| 988 |
+
x2 = x[..., 1::2]
|
| 989 |
+
o1 = x1 * cos - x2 * sin
|
| 990 |
+
o2 = x2 * cos + x1 * sin
|
| 991 |
+
if is_neox_style:
|
| 992 |
+
return torch.cat((o1, o2), dim=-1)
|
| 993 |
+
else:
|
| 994 |
+
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
| 995 |
+
|
| 996 |
+
def rope_single(
|
| 997 |
+
self,
|
| 998 |
+
x: torch.Tensor,
|
| 999 |
+
cos: torch.Tensor,
|
| 1000 |
+
sin: torch.Tensor,
|
| 1001 |
+
) -> torch.Tensor:
|
| 1002 |
+
B, N, D = x.shape
|
| 1003 |
+
S = 1
|
| 1004 |
+
x = x.view(B, N, S, D)
|
| 1005 |
+
x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
|
| 1006 |
+
return x.view(B, N, D)
|
| 1007 |
+
|
| 1008 |
+
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
| 1009 |
+
if self.w_kc is None or self.w_vc is None:
|
| 1010 |
+
kv_b_proj_weight = self.kv_b_proj.weight.reshape(
|
| 1011 |
+
self.num_heads, self.qk_nope_head_dim + self.v_head_dim,
|
| 1012 |
+
self.kv_lora_rank)
|
| 1013 |
+
self.w_kc = kv_b_proj_weight[:, :self.
|
| 1014 |
+
qk_nope_head_dim, :].contiguous()
|
| 1015 |
+
self.w_vc = kv_b_proj_weight[:,
|
| 1016 |
+
self.qk_nope_head_dim:, :].transpose(
|
| 1017 |
+
1, 2).contiguous()
|
| 1018 |
+
|
| 1019 |
+
def forward(
|
| 1020 |
+
self,
|
| 1021 |
+
layer: AttentionLayer,
|
| 1022 |
+
hidden_states_or_q_c: torch.Tensor,
|
| 1023 |
+
hidden_states_or_kv_c_normed: torch.Tensor,
|
| 1024 |
+
k_pe: torch.Tensor,
|
| 1025 |
+
kv_cache: torch.Tensor,
|
| 1026 |
+
attn_metadata: AscendMetadata,
|
| 1027 |
+
attn_type: str = AttentionType.DECODER,
|
| 1028 |
+
output: Optional[torch.Tensor] = None,
|
| 1029 |
+
) -> torch.Tensor:
|
| 1030 |
+
"""Forward pass with Ascend attention.
|
| 1031 |
+
Args:
|
| 1032 |
+
hidden_states_or_q_c: shape = [num_tokens, num_heads * head_size]
|
| 1033 |
+
num_tokens = batch_size * seq_len
|
| 1034 |
+
hidden_states_or_kv_c_normed: shape = [num_tokens, num_kv_heads * head_size]
|
| 1035 |
+
k_pe: shape = [num_tokens, num_kv_heads * head_size]
|
| 1036 |
+
kv_cache: shape = [1, num_blocks, block_size,
|
| 1037 |
+
num_kv_heads * head_size]
|
| 1038 |
+
attn_metadata: Metadata for attention.
|
| 1039 |
+
Returns:
|
| 1040 |
+
shape = [batch_size, seq_len * num_heads * head_size]
|
| 1041 |
+
"""
|
| 1042 |
+
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
| 1043 |
+
attn_type = self.attn_type
|
| 1044 |
+
if attn_type != AttentionType.DECODER:
|
| 1045 |
+
raise NotImplementedError("Encoder self-attention and "
|
| 1046 |
+
"encoder/decoder cross-attention "
|
| 1047 |
+
"are not implemented for "
|
| 1048 |
+
"PallasAttentionBackendImpl")
|
| 1049 |
+
|
| 1050 |
+
if attn_metadata is None:
|
| 1051 |
+
# for profile run
|
| 1052 |
+
return hidden_states_or_q_c
|
| 1053 |
+
|
| 1054 |
+
num_tokens = hidden_states_or_q_c.shape[0]
|
| 1055 |
+
q = self.q_proj(hidden_states_or_q_c)[0].view(-1, self.num_heads,
|
| 1056 |
+
self.qk_head_dim)
|
| 1057 |
+
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
|
| 1058 |
+
dim=-1)
|
| 1059 |
+
if k_pe is None and attn_metadata.decode_metadata:
|
| 1060 |
+
seq_len = self.rotary_emb.max_position_embeddings
|
| 1061 |
+
|
| 1062 |
+
cos = self.rotary_emb.cos_cached[:seq_len].to(dtype=q_pe.dtype)
|
| 1063 |
+
sin = self.rotary_emb.sin_cached[:seq_len].to(dtype=q_pe.dtype)
|
| 1064 |
+
cos = cos[attn_metadata.input_positions]
|
| 1065 |
+
sin = sin[attn_metadata.input_positions]
|
| 1066 |
+
cos = cos[:, None, None, :]
|
| 1067 |
+
sin = sin[:, None, None, :]
|
| 1068 |
+
|
| 1069 |
+
q_pe = self.rope_single(q_pe, cos, sin)
|
| 1070 |
+
k_pe, k_nope = self.exec_kv(hidden_states_or_kv_c_normed, cos, sin,
|
| 1071 |
+
kv_cache, attn_metadata.slot_mapping)
|
| 1072 |
+
else:
|
| 1073 |
+
if k_pe is None:
|
| 1074 |
+
# NOTE: k_pe is None when graph mode enabled
|
| 1075 |
+
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
| 1076 |
+
hidden_states_or_kv_c_normed)[0].split(
|
| 1077 |
+
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
| 1078 |
+
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
| 1079 |
+
else:
|
| 1080 |
+
kv_c_normed = hidden_states_or_kv_c_normed
|
| 1081 |
+
k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
|
| 1082 |
+
if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
|
| 1083 |
+
# NOTE: When scaling not specified
|
| 1084 |
+
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
|
| 1085 |
+
q_pe = q_pe.reshape(num_tokens, -1)
|
| 1086 |
+
k_pe = k_pe.reshape(num_tokens, -1)
|
| 1087 |
+
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
|
| 1088 |
+
q_pe, k_pe)
|
| 1089 |
+
q_pe = q_pe.view(ori_q_pe_shape)
|
| 1090 |
+
k_pe = k_pe.view(ori_k_pe_shape)
|
| 1091 |
+
else:
|
| 1092 |
+
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
|
| 1093 |
+
q_pe, k_pe)
|
| 1094 |
+
|
| 1095 |
+
if attn_metadata.num_prefills > 0:
|
| 1096 |
+
kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens,
|
| 1097 |
+
self.num_heads, -1)
|
| 1098 |
+
k_nope, value = kv.split([self.qk_nope_head_dim, self.v_head_dim],
|
| 1099 |
+
dim=-1)
|
| 1100 |
+
else:
|
| 1101 |
+
q_nope_t = torch.transpose(q_nope, 0, 1)
|
| 1102 |
+
q_nope_out = torch.bmm(q_nope_t, self.w_kc)
|
| 1103 |
+
q_nope = torch.transpose(q_nope_out, 0, 1)
|
| 1104 |
+
|
| 1105 |
+
query = torch.cat([q_nope, q_pe], dim=-1).view(num_tokens,
|
| 1106 |
+
self.num_heads, -1)
|
| 1107 |
+
|
| 1108 |
+
# TODO: Replace the env with more flexible expressions
|
| 1109 |
+
if self.torchair_graph_enabled:
|
| 1110 |
+
if len(kv_cache) > 0 and kv_cache[0].numel(
|
| 1111 |
+
) > 0 and attn_metadata.num_prefills > 0:
|
| 1112 |
+
slots = attn_metadata.slot_mapping
|
| 1113 |
+
# NOTE: Separate the kv cache in advance to avoid OOM or other issues
|
| 1114 |
+
torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
|
| 1115 |
+
num_tokens, self.num_kv_heads, -1),
|
| 1116 |
+
value=k_pe,
|
| 1117 |
+
key_cache=kv_cache[0],
|
| 1118 |
+
value_cache=kv_cache[1],
|
| 1119 |
+
slot_indices=slots)
|
| 1120 |
+
elif kv_cache.numel() > 0:
|
| 1121 |
+
# TODO replace this naive implement with fusion kernel
|
| 1122 |
+
concat_and_cache_mla(kv_c_normed, k_pe, kv_cache,
|
| 1123 |
+
attn_metadata.slot_mapping)
|
| 1124 |
+
|
| 1125 |
+
if attn_metadata.num_prefills > 0:
|
| 1126 |
+
attn_output = torch.empty(num_tokens,
|
| 1127 |
+
self.num_heads,
|
| 1128 |
+
self.v_head_dim,
|
| 1129 |
+
dtype=query.dtype,
|
| 1130 |
+
device=query.device)
|
| 1131 |
+
if (attn_metadata.block_tables is None
|
| 1132 |
+
or attn_metadata.block_tables.numel() == 0):
|
| 1133 |
+
assert attn_metadata.attn_mask is not None
|
| 1134 |
+
assert attn_metadata.prefill_metadata is not None
|
| 1135 |
+
assert attn_metadata.prefill_metadata.seq_lens is not None
|
| 1136 |
+
mask = attn_metadata.attn_mask
|
| 1137 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 1138 |
+
np.array(attn_metadata.prefill_metadata.seq_lens).astype(
|
| 1139 |
+
np.int32))
|
| 1140 |
+
k_pe = k_pe.repeat(1, self.num_heads, 1)
|
| 1141 |
+
key = torch.cat(
|
| 1142 |
+
[k_nope.view(num_tokens, self.num_heads, -1), k_pe], dim=2)
|
| 1143 |
+
torch_npu._npu_flash_attention(
|
| 1144 |
+
query=query,
|
| 1145 |
+
key=key,
|
| 1146 |
+
value=value,
|
| 1147 |
+
mask=mask,
|
| 1148 |
+
seq_len=self.seq_lens_tensor_cpu,
|
| 1149 |
+
scale_value=self.scale,
|
| 1150 |
+
num_heads=self.num_heads,
|
| 1151 |
+
num_kv_heads=self.num_heads,
|
| 1152 |
+
out=attn_output)
|
| 1153 |
+
else:
|
| 1154 |
+
# TODO: Will support prefix cache and chunked prefill soon.
|
| 1155 |
+
raise RuntimeError(
|
| 1156 |
+
"Prefix cache and chunked prefill are currently not supported."
|
| 1157 |
+
)
|
| 1158 |
+
elif attn_metadata.decode_metadata:
|
| 1159 |
+
assert kv_cache is not None
|
| 1160 |
+
if self.torchair_graph_enabled:
|
| 1161 |
+
# shape of query for npu graph mode should be:
|
| 1162 |
+
# [bs, num_heads_per_rank, seq_len, dim]
|
| 1163 |
+
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
|
| 1164 |
+
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
| 1165 |
+
# shape of knope/k_pe for npu graph mode should be:
|
| 1166 |
+
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
|
| 1167 |
+
block_size = kv_cache[0].shape[1]
|
| 1168 |
+
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
|
| 1169 |
+
self.kv_lora_rank)
|
| 1170 |
+
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
| 1171 |
+
self.qk_rope_head_dim)
|
| 1172 |
+
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
| 1173 |
+
q_nope,
|
| 1174 |
+
k_nope,
|
| 1175 |
+
k_nope,
|
| 1176 |
+
query_rope=q_pe,
|
| 1177 |
+
key_rope=k_pe,
|
| 1178 |
+
num_heads=self.num_heads,
|
| 1179 |
+
num_key_value_heads=self.num_kv_heads,
|
| 1180 |
+
input_layout="BNSD",
|
| 1181 |
+
atten_mask=attn_metadata.attn_mask,
|
| 1182 |
+
scale=self.scale,
|
| 1183 |
+
antiquant_mode=0,
|
| 1184 |
+
antiquant_scale=None,
|
| 1185 |
+
block_table=attn_metadata.block_tables,
|
| 1186 |
+
block_size=block_size,
|
| 1187 |
+
actual_seq_lengths_kv=attn_metadata.seq_lens,
|
| 1188 |
+
)
|
| 1189 |
+
attn_output = attn_output.view(num_tokens, -1,
|
| 1190 |
+
self.kv_lora_rank).transpose(
|
| 1191 |
+
0, 1)
|
| 1192 |
+
attn_output = torch.bmm(attn_output, self.w_vc).transpose(0, 1)
|
| 1193 |
+
else:
|
| 1194 |
+
# if torch.empty is used here, the preemptive scheduling case of
|
| 1195 |
+
# test_mtp_correctness.py will fail to run.
|
| 1196 |
+
attn_output = torch.randn(
|
| 1197 |
+
[num_tokens, self.num_heads, self.kv_lora_rank],
|
| 1198 |
+
dtype=query.dtype,
|
| 1199 |
+
device=query.device)
|
| 1200 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 1201 |
+
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
| 1202 |
+
np.int32))
|
| 1203 |
+
block_tables = attn_metadata.decode_metadata.block_tables
|
| 1204 |
+
torch_npu._npu_paged_attention_mla(
|
| 1205 |
+
query=query,
|
| 1206 |
+
key_cache=kv_cache,
|
| 1207 |
+
num_kv_heads=self.num_kv_heads,
|
| 1208 |
+
num_heads=self.num_heads,
|
| 1209 |
+
scale_value=self.scale,
|
| 1210 |
+
block_table=block_tables,
|
| 1211 |
+
context_lens=self.seq_lens_tensor_cpu,
|
| 1212 |
+
mla_vheadsize=self.kv_lora_rank,
|
| 1213 |
+
out=attn_output)
|
| 1214 |
+
attn_output_t = torch.transpose(attn_output, 0, 1)
|
| 1215 |
+
attn_output_t = torch.bmm(attn_output_t, self.w_vc)
|
| 1216 |
+
attn_output = torch.transpose(attn_output_t, 0, 1)
|
| 1217 |
+
|
| 1218 |
+
output, _ = self.o_proj(attn_output.reshape(num_tokens, -1))
|
| 1219 |
+
|
| 1220 |
+
return output
|
inference/vllm_ascend/attention/mla_v1.py
ADDED
|
@@ -0,0 +1,1224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import TYPE_CHECKING, Any, Optional, Tuple, Type, TypeVar
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch_npu
|
| 7 |
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
| 8 |
+
AttentionMetadata,
|
| 9 |
+
MLAAttentionImpl)
|
| 10 |
+
from vllm.attention.backends.utils import PAD_SLOT_ID
|
| 11 |
+
from vllm.config import get_current_vllm_config
|
| 12 |
+
from vllm.distributed import get_tensor_model_parallel_world_size
|
| 13 |
+
from vllm.model_executor.layers.linear import (LinearBase,
|
| 14 |
+
UnquantizedLinearMethod)
|
| 15 |
+
from vllm.utils import cdiv, round_down
|
| 16 |
+
|
| 17 |
+
from vllm_ascend.ascend_config import get_ascend_config
|
| 18 |
+
from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
|
| 19 |
+
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
| 20 |
+
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
| 21 |
+
from vllm_ascend.multistream.context import get_multistream_comm_context
|
| 22 |
+
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
| 23 |
+
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
| 24 |
+
from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor
|
| 25 |
+
from vllm_ascend.worker.npu_input_batch import InputBatch
|
| 26 |
+
|
| 27 |
+
if TYPE_CHECKING:
|
| 28 |
+
from vllm.v1.core.sched.output import SchedulerOutput
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
|
| 32 |
+
class CommonAttentionMetadata:
|
| 33 |
+
"""
|
| 34 |
+
Attention metadata attributes that can be shared by layers in different KV
|
| 35 |
+
cache groups and thus having different block table.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
query_start_loc: torch.Tensor
|
| 39 |
+
"""(batch_size + 1,), the start location of each request in query Tensor"""
|
| 40 |
+
seq_lens: torch.Tensor
|
| 41 |
+
"""(batch_size,), the length of each request including both computed tokens
|
| 42 |
+
and newly scheduled tokens"""
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class AscendMLABackend(AttentionBackend):
|
| 46 |
+
|
| 47 |
+
accept_output_buffer: bool = True
|
| 48 |
+
|
| 49 |
+
@staticmethod
|
| 50 |
+
def get_name() -> str:
|
| 51 |
+
return "VLLM_ASCEND_MLA"
|
| 52 |
+
|
| 53 |
+
@staticmethod
|
| 54 |
+
def get_metadata_cls() -> type["AttentionMetadata"]:
|
| 55 |
+
return AscendMLAMetadata
|
| 56 |
+
|
| 57 |
+
@staticmethod
|
| 58 |
+
def get_builder_cls():
|
| 59 |
+
return AscendMLAMetadataBuilder
|
| 60 |
+
|
| 61 |
+
@staticmethod
|
| 62 |
+
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
|
| 63 |
+
head_size: int) -> tuple[int, ...]:
|
| 64 |
+
return (num_blocks, block_size, num_kv_heads, head_size)
|
| 65 |
+
|
| 66 |
+
@staticmethod
|
| 67 |
+
def get_impl_cls() -> Type["MLAAttentionImpl"]:
|
| 68 |
+
return AscendMLAImpl
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
|
| 72 |
+
class AscendMLAPrefillMetadata:
|
| 73 |
+
""" Prefill Specific Metadata for Ascend"""
|
| 74 |
+
|
| 75 |
+
@dataclass
|
| 76 |
+
class ChunkedContextMetadata:
|
| 77 |
+
# New for MLA (compared to FlashAttention)
|
| 78 |
+
# For handling chunked prefill
|
| 79 |
+
cu_seq_lens: torch.Tensor
|
| 80 |
+
starts: torch.Tensor
|
| 81 |
+
seq_tot: list[int]
|
| 82 |
+
max_seq_lens: list[int]
|
| 83 |
+
workspace: torch.Tensor
|
| 84 |
+
chunk_seq_lens: torch.Tensor
|
| 85 |
+
|
| 86 |
+
attn_mask: torch.Tensor
|
| 87 |
+
query_lens: list[int]
|
| 88 |
+
seq_lens: list[int]
|
| 89 |
+
context_lens: torch.Tensor
|
| 90 |
+
input_positions: torch.Tensor
|
| 91 |
+
query_start_loc: torch.Tensor
|
| 92 |
+
block_table: torch.Tensor
|
| 93 |
+
max_query_len: int
|
| 94 |
+
max_seq_lens: int
|
| 95 |
+
chunked_context: Optional[ChunkedContextMetadata] = None
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@dataclass
|
| 99 |
+
class AscendMLADecodeMetadata:
|
| 100 |
+
# Input positions for rotrary embeddings since for MLA the rotary
|
| 101 |
+
# position embeddings are applied inside the attention backend
|
| 102 |
+
input_positions: torch.Tensor
|
| 103 |
+
block_table: torch.Tensor
|
| 104 |
+
seq_lens: torch.Tensor
|
| 105 |
+
max_seq_lens: int
|
| 106 |
+
seq_lens_list: list[int]
|
| 107 |
+
attn_mask: Optional[torch.Tensor] = None
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@dataclass
|
| 111 |
+
class AscendMLAMetadata:
|
| 112 |
+
"""Metadata for MLACommon.
|
| 113 |
+
|
| 114 |
+
NOTE: Please read the comment at the top of the file before trying to
|
| 115 |
+
understand this class
|
| 116 |
+
"""
|
| 117 |
+
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
| 118 |
+
# |---------- N-1 iteration --------|
|
| 119 |
+
# |---------------- N iteration ---------------------|
|
| 120 |
+
# |- tokenA -|......................|-- newTokens ---|
|
| 121 |
+
# |---------- context_len ----------|
|
| 122 |
+
# |-------------------- seq_len ---------------------|
|
| 123 |
+
# |-- query_len ---|
|
| 124 |
+
|
| 125 |
+
num_actual_tokens: int # Number of tokens excluding padding.
|
| 126 |
+
slot_mapping: torch.Tensor
|
| 127 |
+
query_start_loc: torch.Tensor
|
| 128 |
+
seq_lens: torch.Tensor
|
| 129 |
+
block_tables: torch.Tensor
|
| 130 |
+
|
| 131 |
+
# New for MLA (compared to FlashAttention)
|
| 132 |
+
# For handling prefill decode split
|
| 133 |
+
num_decodes: int
|
| 134 |
+
num_decode_tokens: int
|
| 135 |
+
num_prefills: int
|
| 136 |
+
|
| 137 |
+
# For logging.
|
| 138 |
+
num_input_tokens: int = 0 # Number of tokens including padding.
|
| 139 |
+
|
| 140 |
+
max_num_tokens_across_dp: int = 0
|
| 141 |
+
with_prefill_across_dp: bool = False
|
| 142 |
+
|
| 143 |
+
query_lens: Optional[list[int]] = None
|
| 144 |
+
# The dimension of the attention heads
|
| 145 |
+
head_dim: Optional[int] = None
|
| 146 |
+
attn_mask: torch.Tensor = None
|
| 147 |
+
# chunked prefill by default if no attn_states passed
|
| 148 |
+
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
| 149 |
+
|
| 150 |
+
decode: Optional[AscendMLADecodeMetadata] = None
|
| 151 |
+
prefill: Optional[AscendMLAPrefillMetadata] = None
|
| 152 |
+
|
| 153 |
+
def __post_init__(self):
|
| 154 |
+
pass
|
| 155 |
+
# supported_head_sizes = AscendMLABackend.get_supported_head_sizes()
|
| 156 |
+
# if self.head_dim is not None and self.head_dim \
|
| 157 |
+
# not in supported_head_sizes:
|
| 158 |
+
# raise ValueError(
|
| 159 |
+
# f"Only {supported_head_sizes} are supported for head_dim,",
|
| 160 |
+
# f"received {self.head_dim}.")
|
| 161 |
+
|
| 162 |
+
def split_metadata_for_multistream(
|
| 163 |
+
self,
|
| 164 |
+
ms_split_config: MSAttentionMetadataSplitConfig,
|
| 165 |
+
) -> list["AscendMLAMetadata"]:
|
| 166 |
+
"""Split metadata for multi-stream with AscendMLAMetadata"""
|
| 167 |
+
return model_input_split_v1_mla_attn(
|
| 168 |
+
ms_split_config=ms_split_config,
|
| 169 |
+
attn_metadata=self,
|
| 170 |
+
_metadata_cls=AscendMLAMetadata,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
M = TypeVar("M", bound=AscendMLAMetadata)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class AscendMLAMetadataBuilder:
|
| 178 |
+
"""
|
| 179 |
+
NOTE: Please read the comment at the top of the file before trying to
|
| 180 |
+
understand this class
|
| 181 |
+
"""
|
| 182 |
+
|
| 183 |
+
# _attn_mask_builder = None
|
| 184 |
+
def __init__(self,
|
| 185 |
+
runner,
|
| 186 |
+
metadata_cls: Optional[AscendMLAMetadata] = None):
|
| 187 |
+
self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
|
| 188 |
+
if metadata_cls is not None else AscendMLAMetadata # type: ignore
|
| 189 |
+
self.runner = runner
|
| 190 |
+
scheduler_config = runner.scheduler_config
|
| 191 |
+
model_config = runner.model_config
|
| 192 |
+
self.block_size = runner.block_size
|
| 193 |
+
self.chunked_prefill_enabled = runner.chunked_prefill_enabled
|
| 194 |
+
if self.chunked_prefill_enabled:
|
| 195 |
+
self.chunked_prefill_workspace_size = min(
|
| 196 |
+
# Max sure there is enough for 8 full length request or at least
|
| 197 |
+
# 4 pages of cache per request
|
| 198 |
+
max(8 * model_config.max_model_len,
|
| 199 |
+
4 * scheduler_config.max_num_seqs * self.block_size),
|
| 200 |
+
# For long-context models try not to over-allocate limiting
|
| 201 |
+
# kv-cache space, limiting it to 64k tokens,
|
| 202 |
+
# which would result in the workspace being:
|
| 203 |
+
# 2*(576)*(64*1024) = 144mb
|
| 204 |
+
# (assuming 576 MLA head dim, and fp16)
|
| 205 |
+
# which would result in up-projected context being
|
| 206 |
+
# 2*(192*128)*(64*1024) = 3gb
|
| 207 |
+
# (assuming 192 QK head dim, 128 heads, and fp16)
|
| 208 |
+
128 * 1024)
|
| 209 |
+
assert self.chunked_prefill_workspace_size >= \
|
| 210 |
+
scheduler_config.max_num_seqs * self.block_size
|
| 211 |
+
self.chunked_prefill_workspace = torch.empty(
|
| 212 |
+
(self.chunked_prefill_workspace_size,
|
| 213 |
+
model_config.get_head_size()),
|
| 214 |
+
dtype=model_config.dtype,
|
| 215 |
+
device=runner.device,
|
| 216 |
+
)
|
| 217 |
+
ascend_config = get_ascend_config()
|
| 218 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 219 |
+
|
| 220 |
+
def reorder_batch(self, input_batch: "InputBatch",
|
| 221 |
+
scheduler_output: "SchedulerOutput") -> bool:
|
| 222 |
+
# We now want to reorder the batch so that the "decode" requests are at
|
| 223 |
+
# the front and the "prefill" requests are at the using the least amount
|
| 224 |
+
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
|
| 225 |
+
# where attention is likely memory-bound and "prefill" to mean requests
|
| 226 |
+
# where attention is likely compute-bound, TODO(lucas): figure out a
|
| 227 |
+
# better naming here)
|
| 228 |
+
decodes = []
|
| 229 |
+
prefills = []
|
| 230 |
+
num_decode_tokens = 0
|
| 231 |
+
num_prefill_tokens = 0
|
| 232 |
+
|
| 233 |
+
for i, req_id in enumerate(input_batch.req_ids):
|
| 234 |
+
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
| 235 |
+
num_spec_tokens = len(
|
| 236 |
+
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
|
| 237 |
+
# For torch air graph mode we treat spec decoding as decode.
|
| 238 |
+
if self.torchair_graph_enabled:
|
| 239 |
+
if num_tokens - num_spec_tokens == 1:
|
| 240 |
+
decodes.append(i)
|
| 241 |
+
num_decode_tokens += num_tokens
|
| 242 |
+
else:
|
| 243 |
+
prefills.append(i)
|
| 244 |
+
num_prefill_tokens += num_tokens
|
| 245 |
+
# For eager mode we treat spec decoding as chunked prefill.
|
| 246 |
+
else:
|
| 247 |
+
if num_tokens == 1:
|
| 248 |
+
decodes.append(i)
|
| 249 |
+
num_decode_tokens += num_tokens
|
| 250 |
+
else:
|
| 251 |
+
prefills.append(i)
|
| 252 |
+
num_prefill_tokens += num_tokens
|
| 253 |
+
|
| 254 |
+
# We hope that this is fairly minimal since decodes
|
| 255 |
+
# should be around for a number of iterations so hopefully they are
|
| 256 |
+
# relatively stationary (and new request are generally appended to the
|
| 257 |
+
# persistent batch so already should be at the back)
|
| 258 |
+
# To achieve this we loop over the decodes in descending order and
|
| 259 |
+
# the prefills in ascending order. We swap decodes from the "back"
|
| 260 |
+
# i.e. past where the last decode should be in the reodorered with
|
| 261 |
+
# prefills from the front of the batch.
|
| 262 |
+
# `decodes` and `prefills` are already in ascending order just based on
|
| 263 |
+
# the above loop
|
| 264 |
+
num_decodes = len(decodes)
|
| 265 |
+
num_prefills = len(prefills)
|
| 266 |
+
first_prefill = 0
|
| 267 |
+
modified_batch = False
|
| 268 |
+
|
| 269 |
+
for i in range(1, min(num_decodes, num_prefills) + 1):
|
| 270 |
+
# If the decode is at the "back" of the batch, i, we can swap it
|
| 271 |
+
# with the prefill closest to the front of the batch
|
| 272 |
+
if decodes[num_decodes - i] >= num_decodes:
|
| 273 |
+
input_batch.swap_states(prefills[first_prefill],
|
| 274 |
+
decodes[num_decodes - i])
|
| 275 |
+
first_prefill += 1
|
| 276 |
+
modified_batch = True
|
| 277 |
+
else:
|
| 278 |
+
break
|
| 279 |
+
|
| 280 |
+
# Save for next `build` call
|
| 281 |
+
# TODO(lucas): this is a bit of a hack, we should probably have a
|
| 282 |
+
# better way of doing this
|
| 283 |
+
self._num_decodes = num_decodes
|
| 284 |
+
self._num_prefills = num_prefills
|
| 285 |
+
self._num_decode_tokens = num_decode_tokens
|
| 286 |
+
self._num_prefill_tokens = num_prefill_tokens
|
| 287 |
+
|
| 288 |
+
return modified_batch
|
| 289 |
+
|
| 290 |
+
def _get_graph_runner_block_tables(
|
| 291 |
+
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
|
| 292 |
+
|
| 293 |
+
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
| 294 |
+
assert max_batch_size >= num_seqs
|
| 295 |
+
|
| 296 |
+
if isinstance(self.runner.graph_block_tables, np.ndarray):
|
| 297 |
+
graph_block_tables = torch.zeros((max_batch_size, max_blocks),
|
| 298 |
+
dtype=block_tables.dtype,
|
| 299 |
+
device=block_tables.device)
|
| 300 |
+
else:
|
| 301 |
+
graph_block_tables = self.runner.graph_block_tables.to(
|
| 302 |
+
device=block_tables.device, dtype=block_tables.dtype)
|
| 303 |
+
|
| 304 |
+
num_blocks = block_tables.size(1)
|
| 305 |
+
if num_blocks <= max_blocks:
|
| 306 |
+
graph_block_tables[:num_seqs, :
|
| 307 |
+
num_blocks] = block_tables[:num_seqs, :
|
| 308 |
+
num_blocks]
|
| 309 |
+
else:
|
| 310 |
+
graph_block_tables[:num_seqs, :
|
| 311 |
+
max_blocks] = block_tables[:num_seqs, :
|
| 312 |
+
max_blocks]
|
| 313 |
+
|
| 314 |
+
return graph_block_tables[:num_seqs, :max_blocks]
|
| 315 |
+
|
| 316 |
+
def build_dummy(self, num_reqs: int,
|
| 317 |
+
num_actual_tokens: int) -> AscendMLAMetadata:
|
| 318 |
+
device = self.runner.device
|
| 319 |
+
_, max_blocks = self.runner.graph_block_tables.shape
|
| 320 |
+
block_table = torch.zeros((num_reqs, max_blocks),
|
| 321 |
+
dtype=torch.int32,
|
| 322 |
+
device=device)
|
| 323 |
+
block_table = self._get_graph_runner_block_tables(
|
| 324 |
+
num_reqs, block_table)
|
| 325 |
+
seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
|
| 326 |
+
input_positions = torch.zeros(num_reqs,
|
| 327 |
+
dtype=torch.int32,
|
| 328 |
+
device=device).long()
|
| 329 |
+
slot_mapping = torch.full((num_reqs, ),
|
| 330 |
+
PAD_SLOT_ID,
|
| 331 |
+
dtype=torch.int32,
|
| 332 |
+
device=device)
|
| 333 |
+
query_start_loc = torch.full((num_reqs, ),
|
| 334 |
+
-1,
|
| 335 |
+
dtype=torch.int32,
|
| 336 |
+
device=device)
|
| 337 |
+
decode_metadata = AscendMLADecodeMetadata(
|
| 338 |
+
input_positions=input_positions,
|
| 339 |
+
block_table=block_table,
|
| 340 |
+
seq_lens=seq_lens,
|
| 341 |
+
seq_lens_list=seq_lens.tolist(),
|
| 342 |
+
max_seq_lens=1,
|
| 343 |
+
attn_mask=self.runner.spec_attn_mask)
|
| 344 |
+
return self.metadata_cls( # type: ignore
|
| 345 |
+
num_input_tokens=num_actual_tokens,
|
| 346 |
+
num_actual_tokens=num_actual_tokens,
|
| 347 |
+
slot_mapping=slot_mapping,
|
| 348 |
+
head_dim=self.runner.model_config.get_head_size(),
|
| 349 |
+
num_decodes=1,
|
| 350 |
+
num_decode_tokens=1,
|
| 351 |
+
num_prefills=0,
|
| 352 |
+
attn_mask=self.runner.attn_mask,
|
| 353 |
+
attn_state=AscendAttentionState.DecodeOnly,
|
| 354 |
+
prefill=None,
|
| 355 |
+
decode=decode_metadata,
|
| 356 |
+
query_start_loc=query_start_loc,
|
| 357 |
+
seq_lens=seq_lens,
|
| 358 |
+
block_tables=block_table,
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
def build(
|
| 362 |
+
self,
|
| 363 |
+
num_reqs: int,
|
| 364 |
+
num_actual_tokens: int,
|
| 365 |
+
max_query_len: int,
|
| 366 |
+
common_attn_metadata: CommonAttentionMetadata,
|
| 367 |
+
common_prefix_len: Optional[int] = None,
|
| 368 |
+
graph_pad_size: int = -1,
|
| 369 |
+
max_num_tokens_across_dp: int = 0,
|
| 370 |
+
with_prefill_across_dp: bool = False,
|
| 371 |
+
) -> AscendMLAMetadata:
|
| 372 |
+
assert self._num_decodes + self._num_prefills == num_reqs
|
| 373 |
+
|
| 374 |
+
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
| 375 |
+
# function. We should avoid GPU -> CPU sync as much as possible because
|
| 376 |
+
# it blocks on all previous kernels.
|
| 377 |
+
device = self.runner.device
|
| 378 |
+
|
| 379 |
+
block_table = (self.runner.input_batch.block_table[0].
|
| 380 |
+
get_device_tensor()[:num_reqs])
|
| 381 |
+
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
|
| 382 |
+
device, non_blocking=True)
|
| 383 |
+
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
|
| 384 |
+
device, non_blocking=True).long()
|
| 385 |
+
|
| 386 |
+
seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
|
| 387 |
+
query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
|
| 388 |
+
num_reqs]
|
| 389 |
+
seq_lens = seq_lens_cpu
|
| 390 |
+
max_query_len = query_lens.max().item()
|
| 391 |
+
max_seq_lens = seq_lens.max().item()
|
| 392 |
+
query_start_loc = common_attn_metadata.query_start_loc
|
| 393 |
+
|
| 394 |
+
prefill_metadata = None
|
| 395 |
+
chunked_context_metadata = None
|
| 396 |
+
if self._num_prefills > 0:
|
| 397 |
+
reqs_start = self._num_decodes # prefill_start
|
| 398 |
+
tokens_start = self._num_decode_tokens
|
| 399 |
+
max_query_len = query_lens[tokens_start:].max().item()
|
| 400 |
+
max_seq_lens = seq_lens[tokens_start:].max().item()
|
| 401 |
+
query_start_loc = common_attn_metadata.query_start_loc
|
| 402 |
+
prefill_query_start_loc = query_start_loc[
|
| 403 |
+
reqs_start:] - query_start_loc[reqs_start]
|
| 404 |
+
|
| 405 |
+
context_lens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[
|
| 406 |
+
reqs_start:num_reqs]
|
| 407 |
+
max_context_len_cpu = context_lens_cpu.max().item()
|
| 408 |
+
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
| 409 |
+
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
|
| 410 |
+
max_context_chunk = (self.chunked_prefill_workspace_size //
|
| 411 |
+
num_prefills_with_context_cpu)
|
| 412 |
+
max_context_chunk = round_down(max_context_chunk,
|
| 413 |
+
self.block_size)
|
| 414 |
+
|
| 415 |
+
assert max_context_chunk > 0
|
| 416 |
+
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
|
| 417 |
+
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
|
| 418 |
+
.unsqueeze(1).expand(-1, self._num_prefills) * max_context_chunk
|
| 419 |
+
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
|
| 420 |
+
chunk_starts + max_context_chunk)
|
| 421 |
+
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
| 422 |
+
cu_seq_lens_cpu = torch.zeros(num_chunks,
|
| 423 |
+
self._num_prefills + 1,
|
| 424 |
+
dtype=torch.int32,
|
| 425 |
+
pin_memory=True)
|
| 426 |
+
torch.cumsum(chunk_seq_lens,
|
| 427 |
+
dim=1,
|
| 428 |
+
out=cu_seq_lens_cpu[:, 1:],
|
| 429 |
+
dtype=torch.int32)
|
| 430 |
+
chunked_context_metadata = \
|
| 431 |
+
AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
| 432 |
+
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
| 433 |
+
starts=chunk_starts.to(device, non_blocking=True),
|
| 434 |
+
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
| 435 |
+
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
| 436 |
+
chunk_seq_lens=chunk_seq_lens,
|
| 437 |
+
workspace=self.chunked_prefill_workspace,
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
prefill_metadata = AscendMLAPrefillMetadata(
|
| 441 |
+
attn_mask=self.runner.attn_mask,
|
| 442 |
+
query_lens=query_lens[tokens_start:],
|
| 443 |
+
seq_lens=seq_lens,
|
| 444 |
+
context_lens=seq_lens[tokens_start:],
|
| 445 |
+
input_positions=input_positions[tokens_start:],
|
| 446 |
+
block_table=block_table[reqs_start:, ...],
|
| 447 |
+
max_query_len=max_query_len,
|
| 448 |
+
max_seq_lens=max_seq_lens,
|
| 449 |
+
query_start_loc=prefill_query_start_loc,
|
| 450 |
+
chunked_context=chunked_context_metadata,
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
decode_metadata = None
|
| 454 |
+
use_torchair_graph = graph_pad_size != -1
|
| 455 |
+
if self._num_decodes > 0:
|
| 456 |
+
max_seq_lens = seq_lens[:self._num_decodes].max().item()
|
| 457 |
+
seq_lens = seq_lens[:self._num_decode_tokens]
|
| 458 |
+
input_positions = input_positions[:self._num_decode_tokens]
|
| 459 |
+
block_table = block_table[:self._num_decode_tokens, ...]
|
| 460 |
+
if use_torchair_graph and self.runner.attn_state in [
|
| 461 |
+
AscendAttentionState.DecodeOnly,
|
| 462 |
+
AscendAttentionState.SpecDecoding
|
| 463 |
+
]:
|
| 464 |
+
num_seqs = len(seq_lens)
|
| 465 |
+
if graph_pad_size != 0:
|
| 466 |
+
pad_value = 1
|
| 467 |
+
padded_seq_lens = seq_lens.tolist() + [pad_value
|
| 468 |
+
] * graph_pad_size
|
| 469 |
+
else:
|
| 470 |
+
padded_seq_lens = seq_lens.tolist()
|
| 471 |
+
|
| 472 |
+
seq_lens = torch.from_numpy(
|
| 473 |
+
np.array(padded_seq_lens).astype(np.int32))
|
| 474 |
+
padding = torch.full((graph_pad_size, ),
|
| 475 |
+
PAD_SLOT_ID,
|
| 476 |
+
dtype=slot_mapping.dtype,
|
| 477 |
+
device=slot_mapping.device)
|
| 478 |
+
slot_mapping = torch.cat([slot_mapping, padding])
|
| 479 |
+
block_table_padding = torch.zeros(
|
| 480 |
+
(graph_pad_size, ) + block_table.shape[1:],
|
| 481 |
+
dtype=block_table.dtype,
|
| 482 |
+
device=block_table.device)
|
| 483 |
+
block_table = torch.cat([block_table, block_table_padding],
|
| 484 |
+
dim=0)
|
| 485 |
+
block_table = self._get_graph_runner_block_tables(
|
| 486 |
+
num_seqs + graph_pad_size, block_table)
|
| 487 |
+
padding_0 = torch.zeros(graph_pad_size,
|
| 488 |
+
dtype=input_positions.dtype,
|
| 489 |
+
device=input_positions.device)
|
| 490 |
+
input_positions = torch.cat([input_positions, padding_0])
|
| 491 |
+
|
| 492 |
+
decode_metadata = AscendMLADecodeMetadata(
|
| 493 |
+
input_positions=input_positions,
|
| 494 |
+
block_table=block_table,
|
| 495 |
+
seq_lens=seq_lens,
|
| 496 |
+
seq_lens_list=seq_lens.tolist(),
|
| 497 |
+
max_seq_lens=max_seq_lens,
|
| 498 |
+
attn_mask=self.runner.spec_attn_mask)
|
| 499 |
+
|
| 500 |
+
return self.metadata_cls( # type: ignore
|
| 501 |
+
num_actual_tokens=num_actual_tokens,
|
| 502 |
+
query_lens=query_lens.tolist(),
|
| 503 |
+
slot_mapping=slot_mapping,
|
| 504 |
+
head_dim=self.runner.model_config.get_head_size(),
|
| 505 |
+
num_decodes=self._num_decodes,
|
| 506 |
+
num_decode_tokens=self._num_decode_tokens,
|
| 507 |
+
num_prefills=self._num_prefills,
|
| 508 |
+
attn_mask=self.runner.attn_mask,
|
| 509 |
+
attn_state=self.runner.attn_state,
|
| 510 |
+
prefill=prefill_metadata,
|
| 511 |
+
decode=decode_metadata,
|
| 512 |
+
query_start_loc=query_start_loc,
|
| 513 |
+
block_tables=block_table,
|
| 514 |
+
seq_lens=seq_lens,
|
| 515 |
+
max_num_tokens_across_dp=max_num_tokens_across_dp,
|
| 516 |
+
with_prefill_across_dp=with_prefill_across_dp,
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
class AscendMLAImpl(MLAAttentionImpl):
|
| 521 |
+
"""
|
| 522 |
+
NOTE: Please read the comment at the top of the file before trying to
|
| 523 |
+
understand this class
|
| 524 |
+
"""
|
| 525 |
+
|
| 526 |
+
def __init__(
|
| 527 |
+
self,
|
| 528 |
+
num_heads: int,
|
| 529 |
+
head_size: int,
|
| 530 |
+
scale: float,
|
| 531 |
+
num_kv_heads: int,
|
| 532 |
+
alibi_slopes: Optional[list[float]],
|
| 533 |
+
sliding_window: Optional[int],
|
| 534 |
+
kv_cache_dtype: str,
|
| 535 |
+
blocksparse_params: Optional[dict[str, Any]],
|
| 536 |
+
logits_soft_cap: Optional[float],
|
| 537 |
+
attn_type: str,
|
| 538 |
+
kv_sharing_target_layer_name: Optional[str] = None,
|
| 539 |
+
**kwargs,
|
| 540 |
+
) -> None:
|
| 541 |
+
self.num_heads = num_heads
|
| 542 |
+
self.head_size = head_size
|
| 543 |
+
self.scale = float(scale)
|
| 544 |
+
self.num_kv_heads = num_kv_heads
|
| 545 |
+
self.kv_cache_dtype = kv_cache_dtype
|
| 546 |
+
|
| 547 |
+
# MLA Args
|
| 548 |
+
self.q_lora_rank = kwargs['q_lora_rank']
|
| 549 |
+
self.kv_lora_rank = kwargs['kv_lora_rank']
|
| 550 |
+
self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
|
| 551 |
+
self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
|
| 552 |
+
self.qk_head_dim = kwargs['qk_head_dim']
|
| 553 |
+
self.v_head_dim = kwargs['v_head_dim']
|
| 554 |
+
self.rotary_emb = kwargs['rotary_emb']
|
| 555 |
+
self.q_proj = kwargs['q_proj']
|
| 556 |
+
self.kv_b_proj = kwargs['kv_b_proj']
|
| 557 |
+
self.o_proj = kwargs['o_proj']
|
| 558 |
+
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
|
| 559 |
+
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
|
| 560 |
+
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
| 561 |
+
self.tp_size = get_tensor_model_parallel_world_size()
|
| 562 |
+
|
| 563 |
+
ascend_config = get_ascend_config()
|
| 564 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 565 |
+
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
| 566 |
+
|
| 567 |
+
# Adapt torch air graph mode with spec decoding.
|
| 568 |
+
speculative_config = get_current_vllm_config().speculative_config
|
| 569 |
+
if speculative_config is not None:
|
| 570 |
+
self.spec_token_num = speculative_config.num_speculative_tokens
|
| 571 |
+
assert self.spec_token_num > 0
|
| 572 |
+
self.SHARE_MASK_TRIL_SPARSE = ~torch.tril(torch.ones((2048, 2048), dtype=torch.bool)).npu()
|
| 573 |
+
|
| 574 |
+
def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
|
| 575 |
+
# Convert from (B, N, L) to (N, B, L)
|
| 576 |
+
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
| 577 |
+
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
| 578 |
+
x = torch.bmm(x, self.W_UV)
|
| 579 |
+
# Convert from (N, B, V) to (B, N * V)
|
| 580 |
+
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
| 581 |
+
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
|
| 582 |
+
npu_prefetch(self.o_proj.weight,
|
| 583 |
+
x,
|
| 584 |
+
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
| 585 |
+
enabled=enable_multistream_mla)
|
| 586 |
+
return self.o_proj(x, is_prefill=False)[0]
|
| 587 |
+
|
| 588 |
+
# Return `ql_nope`, `q_pe`
|
| 589 |
+
def _q_proj_and_k_up_proj(self, x):
|
| 590 |
+
q_nope, q_pe = self.q_proj(x)[0]\
|
| 591 |
+
.view(-1, self.num_heads, self.qk_head_dim)\
|
| 592 |
+
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
| 593 |
+
|
| 594 |
+
# Convert from (B, N, P) to (N, B, P)
|
| 595 |
+
q_nope = q_nope.transpose(0, 1)
|
| 596 |
+
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
| 597 |
+
ql_nope = torch.bmm(q_nope, self.W_UK_T)
|
| 598 |
+
# Convert from (N, B, L) to (B, N, L)
|
| 599 |
+
return ql_nope.transpose(0, 1), q_pe
|
| 600 |
+
|
| 601 |
+
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
| 602 |
+
|
| 603 |
+
def get_layer_weight(layer):
|
| 604 |
+
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
|
| 605 |
+
for attr in WEIGHT_NAMES:
|
| 606 |
+
if hasattr(layer, attr):
|
| 607 |
+
return getattr(layer, attr)
|
| 608 |
+
raise AttributeError(
|
| 609 |
+
f"Layer '{layer}' has no recognized weight attribute:"
|
| 610 |
+
f" {WEIGHT_NAMES}.")
|
| 611 |
+
|
| 612 |
+
def get_and_maybe_dequant_weights(layer: LinearBase):
|
| 613 |
+
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
|
| 614 |
+
# NOTE: This should only be used offline, since it's O(N^3)
|
| 615 |
+
eye = torch.eye(layer.input_size_per_partition,
|
| 616 |
+
dtype=act_dtype,
|
| 617 |
+
device=get_layer_weight(layer).device)
|
| 618 |
+
dequant_weights = layer.quant_method.apply(layer,
|
| 619 |
+
eye,
|
| 620 |
+
bias=None)
|
| 621 |
+
del eye
|
| 622 |
+
# standardize to (output, input)
|
| 623 |
+
return dequant_weights.T
|
| 624 |
+
return layer.weight
|
| 625 |
+
|
| 626 |
+
# we currently do not have quantized bmm's which are needed for
|
| 627 |
+
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
|
| 628 |
+
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
| 629 |
+
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
| 630 |
+
assert kv_b_proj_weight.shape == (
|
| 631 |
+
self.kv_lora_rank,
|
| 632 |
+
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
| 633 |
+
f"{kv_b_proj_weight.shape=}, "
|
| 634 |
+
f"{self.kv_lora_rank=}, "
|
| 635 |
+
f"{self.num_heads=}, "
|
| 636 |
+
f"{self.qk_nope_head_dim=}, "
|
| 637 |
+
f"{self.v_head_dim=}")
|
| 638 |
+
kv_b_proj_weight = kv_b_proj_weight.view(
|
| 639 |
+
self.kv_lora_rank,
|
| 640 |
+
self.num_heads,
|
| 641 |
+
self.qk_nope_head_dim + self.v_head_dim,
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
W_UK, W_UV = kv_b_proj_weight.split(
|
| 645 |
+
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
| 646 |
+
|
| 647 |
+
# Convert from (L, N, V) to (N, L, V)
|
| 648 |
+
self.W_UV = W_UV.transpose(0, 1).contiguous()
|
| 649 |
+
# Convert from (L, N, P) to (N, P, L)
|
| 650 |
+
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
|
| 651 |
+
|
| 652 |
+
# Waiting for BMM NZ support
|
| 653 |
+
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
|
| 654 |
+
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
|
| 655 |
+
|
| 656 |
+
def _compute_prefill_context(
|
| 657 |
+
self,
|
| 658 |
+
query: torch.Tensor,
|
| 659 |
+
kv_c_and_k_pe_cache: torch.Tensor,
|
| 660 |
+
rope_dim: int,
|
| 661 |
+
attn_metadata: AscendMLAMetadata,
|
| 662 |
+
prefix_output: torch.Tensor,
|
| 663 |
+
prefix_lse: torch.Tensor,
|
| 664 |
+
):
|
| 665 |
+
prefill_metadata = attn_metadata.prefill
|
| 666 |
+
if prefill_metadata is None or prefill_metadata.chunked_context is None:
|
| 667 |
+
return prefix_output, prefix_lse
|
| 668 |
+
|
| 669 |
+
iters = len(prefill_metadata.chunked_context.seq_tot)
|
| 670 |
+
q_pe = query[..., self.qk_nope_head_dim:]
|
| 671 |
+
q_nope = query[..., :self.qk_nope_head_dim]
|
| 672 |
+
|
| 673 |
+
seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
|
| 674 |
+
latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
|
| 675 |
+
cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
|
| 676 |
+
cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
|
| 677 |
+
for i in range(iters):
|
| 678 |
+
toks = prefill_metadata.chunked_context.seq_tot[i]
|
| 679 |
+
|
| 680 |
+
seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
|
| 681 |
+
seq_len = torch.stack([seq_len1, seq_len2])
|
| 682 |
+
kv_c_normed = torch.empty(toks,
|
| 683 |
+
kv_c_and_k_pe_cache.size(2),
|
| 684 |
+
latent_kv_dim,
|
| 685 |
+
dtype=query.dtype,
|
| 686 |
+
device=query.device)
|
| 687 |
+
k_pe = torch.empty(toks,
|
| 688 |
+
kv_c_and_k_pe_cache.size(2),
|
| 689 |
+
rope_dim,
|
| 690 |
+
dtype=query.dtype,
|
| 691 |
+
device=query.device)
|
| 692 |
+
|
| 693 |
+
torch_npu.atb.npu_paged_cache_load(
|
| 694 |
+
cache_kv_c,
|
| 695 |
+
cache_k_pe,
|
| 696 |
+
prefill_metadata.block_table,
|
| 697 |
+
seq_len2.to(query.device),
|
| 698 |
+
seq_starts=prefill_metadata.chunked_context.starts[i],
|
| 699 |
+
key=kv_c_normed,
|
| 700 |
+
value=k_pe,
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
kv_c_normed = kv_c_normed.squeeze()
|
| 704 |
+
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
|
| 705 |
+
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
| 706 |
+
k_nope, v = kv_nope\
|
| 707 |
+
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
| 708 |
+
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
| 709 |
+
mask = torch.triu(
|
| 710 |
+
torch.ones(512, 512, device=query.device, dtype=query.dtype),
|
| 711 |
+
1)
|
| 712 |
+
torch_npu.atb.npu_ring_mla(
|
| 713 |
+
q_nope=q_nope,
|
| 714 |
+
q_rope=q_pe,
|
| 715 |
+
k_nope=k_nope,
|
| 716 |
+
k_rope=k_pe,
|
| 717 |
+
value=v,
|
| 718 |
+
mask=mask,
|
| 719 |
+
seqlen=seq_len,
|
| 720 |
+
head_num=self.num_heads,
|
| 721 |
+
kv_head_num=self.num_heads,
|
| 722 |
+
pre_out=prefix_output,
|
| 723 |
+
prev_lse=prefix_lse,
|
| 724 |
+
qk_scale=self.scale,
|
| 725 |
+
kernel_type="kernel_type_high_precision",
|
| 726 |
+
mask_type="no_mask",
|
| 727 |
+
input_layout="type_bsnd",
|
| 728 |
+
calc_type="calc_type_default",
|
| 729 |
+
output=prefix_output,
|
| 730 |
+
softmax_lse=prefix_lse)
|
| 731 |
+
return prefix_output, prefix_lse
|
| 732 |
+
|
| 733 |
+
def _forward_prefill(
|
| 734 |
+
self,
|
| 735 |
+
query: torch.Tensor,
|
| 736 |
+
kv_c_normed: torch.Tensor,
|
| 737 |
+
k_pe: torch.Tensor,
|
| 738 |
+
kv_c_and_k_pe_cache: torch.Tensor,
|
| 739 |
+
attn_metadata: AscendMLAMetadata,
|
| 740 |
+
) -> torch.Tensor:
|
| 741 |
+
assert attn_metadata.prefill is not None
|
| 742 |
+
|
| 743 |
+
num_tokens = query.size(0)
|
| 744 |
+
attn_output = torch.empty(num_tokens,
|
| 745 |
+
self.num_heads,
|
| 746 |
+
self.v_head_dim,
|
| 747 |
+
dtype=query.dtype,
|
| 748 |
+
device=query.device)
|
| 749 |
+
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
|
| 750 |
+
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
|
| 751 |
+
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
| 752 |
+
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
| 753 |
+
# Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
|
| 754 |
+
ascend_config = get_ascend_config()
|
| 755 |
+
|
| 756 |
+
if attn_metadata.attn_state in [
|
| 757 |
+
AscendAttentionState.ChunkedPrefill,
|
| 758 |
+
AscendAttentionState.SpecDecoding,
|
| 759 |
+
AscendAttentionState.PrefillCacheHit
|
| 760 |
+
] and not ascend_config.chunked_prefill_for_mla:
|
| 761 |
+
attn_output_torch = torch.empty(num_tokens,
|
| 762 |
+
self.num_heads * self.v_head_dim,
|
| 763 |
+
dtype=query.dtype,
|
| 764 |
+
device=query.device)
|
| 765 |
+
# current requests is chunked in prefill, disable flash attention with chunked prefill
|
| 766 |
+
vanilla_chunked_prefill_mla(
|
| 767 |
+
output=attn_output_torch,
|
| 768 |
+
query=query,
|
| 769 |
+
kv_cache=kv_c_and_k_pe_cache,
|
| 770 |
+
block_tables=attn_metadata.prefill.block_table,
|
| 771 |
+
query_lens=attn_metadata.prefill.query_lens,
|
| 772 |
+
context_lens=attn_metadata.prefill.context_lens,
|
| 773 |
+
kv_b_proj=self.kv_b_proj,
|
| 774 |
+
max_query_len=attn_metadata.prefill.max_query_len,
|
| 775 |
+
max_context_len=attn_metadata.prefill.max_seq_lens,
|
| 776 |
+
nope_dim=self.qk_nope_head_dim,
|
| 777 |
+
rope_dim=self.qk_rope_head_dim,
|
| 778 |
+
v_head_dim=self.v_head_dim,
|
| 779 |
+
scale=self.scale,
|
| 780 |
+
alibi_slopes=None,
|
| 781 |
+
causal=True)
|
| 782 |
+
elif attn_metadata.attn_state in [
|
| 783 |
+
AscendAttentionState.ChunkedPrefill,
|
| 784 |
+
AscendAttentionState.SpecDecoding,
|
| 785 |
+
AscendAttentionState.PrefillCacheHit
|
| 786 |
+
]:
|
| 787 |
+
attn_lse = torch.empty(self.num_heads,
|
| 788 |
+
num_tokens,
|
| 789 |
+
dtype=torch.float32,
|
| 790 |
+
device=query.device)
|
| 791 |
+
q_pe = query[..., self.qk_nope_head_dim:]
|
| 792 |
+
q_nope = query[..., :self.qk_nope_head_dim]
|
| 793 |
+
mask = torch.triu(
|
| 794 |
+
torch.ones(512, 512, device=query.device, dtype=query.dtype),
|
| 795 |
+
1) # 512: mask only support 512
|
| 796 |
+
if attn_metadata.num_prefills > 1:
|
| 797 |
+
mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
|
| 798 |
+
1)
|
| 799 |
+
torch_npu.atb.npu_ring_mla(
|
| 800 |
+
q_nope=q_nope,
|
| 801 |
+
q_rope=q_pe,
|
| 802 |
+
k_nope=k_nope,
|
| 803 |
+
k_rope=k_pe,
|
| 804 |
+
value=value,
|
| 805 |
+
mask=mask,
|
| 806 |
+
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
|
| 807 |
+
dtype=torch.int32),
|
| 808 |
+
head_num=self.num_heads,
|
| 809 |
+
kv_head_num=self.num_heads,
|
| 810 |
+
pre_out=None,
|
| 811 |
+
prev_lse=None,
|
| 812 |
+
qk_scale=self.scale,
|
| 813 |
+
kernel_type="kernel_type_high_precision",
|
| 814 |
+
mask_type="mask_type_triu",
|
| 815 |
+
input_layout="type_bsnd",
|
| 816 |
+
calc_type="calc_type_first_ring",
|
| 817 |
+
output=attn_output,
|
| 818 |
+
softmax_lse=attn_lse)
|
| 819 |
+
attn_output, attn_lse = self._compute_prefill_context( \
|
| 820 |
+
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
| 821 |
+
|
| 822 |
+
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
| 823 |
+
key = torch.cat((k_nope, k_pe), dim=-1)
|
| 824 |
+
context_lens_list = torch.cumsum(attn_metadata.prefill.context_lens, dim=0).tolist()
|
| 825 |
+
attn_output = torch_npu.npu_fused_infer_attention_score(
|
| 826 |
+
query,
|
| 827 |
+
key,
|
| 828 |
+
value,
|
| 829 |
+
num_heads=self.num_heads,
|
| 830 |
+
input_layout="TND",
|
| 831 |
+
scale=self.scale,
|
| 832 |
+
sparse_mode=3,
|
| 833 |
+
atten_mask=self.SHARE_MASK_TRIL_SPARSE,
|
| 834 |
+
actual_seq_lengths=context_lens_list,
|
| 835 |
+
actual_seq_lengths_kv=context_lens_list,
|
| 836 |
+
inner_precise=0)[0]
|
| 837 |
+
attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
|
| 838 |
+
else:
|
| 839 |
+
raise RuntimeError(
|
| 840 |
+
"Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
|
| 841 |
+
)
|
| 842 |
+
attn_output = attn_output.reshape(
|
| 843 |
+
[num_tokens, self.num_heads * self.v_head_dim])
|
| 844 |
+
if attn_metadata.attn_state in [
|
| 845 |
+
AscendAttentionState.ChunkedPrefill,
|
| 846 |
+
AscendAttentionState.SpecDecoding,
|
| 847 |
+
AscendAttentionState.PrefillCacheHit
|
| 848 |
+
] and not ascend_config.chunked_prefill_for_mla:
|
| 849 |
+
attn_output = attn_output_torch
|
| 850 |
+
|
| 851 |
+
current_ms_metadata = get_multistream_comm_context()
|
| 852 |
+
if current_ms_metadata is None:
|
| 853 |
+
return self.o_proj(attn_output, is_prefill=True)[0]
|
| 854 |
+
else:
|
| 855 |
+
current_ms_metadata.before_comm_event.record()
|
| 856 |
+
with torch.npu.stream(current_ms_metadata.comm_stream):
|
| 857 |
+
current_ms_metadata.before_comm_event.wait()
|
| 858 |
+
return self.o_proj(attn_output, is_prefill=True)[0]
|
| 859 |
+
|
| 860 |
+
def exec_kv(
|
| 861 |
+
self,
|
| 862 |
+
hidden_states: torch.Tensor,
|
| 863 |
+
cos: torch.Tensor,
|
| 864 |
+
sin: torch.Tensor,
|
| 865 |
+
kv_cache: Tuple,
|
| 866 |
+
slots: torch.Tensor,
|
| 867 |
+
):
|
| 868 |
+
|
| 869 |
+
B = hidden_states.shape[0]
|
| 870 |
+
N = self.num_kv_heads
|
| 871 |
+
S = 1
|
| 872 |
+
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
|
| 873 |
+
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
| 874 |
+
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
| 875 |
+
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
|
| 876 |
+
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
|
| 877 |
+
kv,
|
| 878 |
+
self.kv_a_layernorm.weight,
|
| 879 |
+
cos,
|
| 880 |
+
sin,
|
| 881 |
+
slots.to(torch.int64),
|
| 882 |
+
kv_cache[1],
|
| 883 |
+
kv_cache[0],
|
| 884 |
+
epsilon=self.kv_a_layernorm.variance_epsilon,
|
| 885 |
+
cache_mode=cache_mode,
|
| 886 |
+
)
|
| 887 |
+
return k_pe, k_nope, kv
|
| 888 |
+
|
| 889 |
+
def exec_kv_prefill(
|
| 890 |
+
self,
|
| 891 |
+
hidden_states: torch.Tensor,
|
| 892 |
+
cos: torch.Tensor,
|
| 893 |
+
sin: torch.Tensor,
|
| 894 |
+
kv_cache: Tuple,
|
| 895 |
+
slots: torch.Tensor,
|
| 896 |
+
):
|
| 897 |
+
|
| 898 |
+
B = hidden_states.shape[0]
|
| 899 |
+
N = self.num_kv_heads
|
| 900 |
+
S = 1
|
| 901 |
+
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
|
| 902 |
+
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
| 903 |
+
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
| 904 |
+
cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
|
| 905 |
+
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
|
| 906 |
+
kv,
|
| 907 |
+
self.kv_a_layernorm.weight,
|
| 908 |
+
cos,
|
| 909 |
+
sin,
|
| 910 |
+
slots.to(torch.int64),
|
| 911 |
+
kv_cache[1],
|
| 912 |
+
kv_cache[0],
|
| 913 |
+
epsilon=self.kv_a_layernorm.variance_epsilon,
|
| 914 |
+
cache_mode=cache_mode,
|
| 915 |
+
is_output_kv=True,
|
| 916 |
+
)
|
| 917 |
+
return k_pe, k_nope
|
| 918 |
+
|
| 919 |
+
def rope_single(
|
| 920 |
+
self,
|
| 921 |
+
x: torch.Tensor,
|
| 922 |
+
cos: torch.Tensor,
|
| 923 |
+
sin: torch.Tensor,
|
| 924 |
+
) -> torch.Tensor:
|
| 925 |
+
B, N, D = x.shape
|
| 926 |
+
S = 1
|
| 927 |
+
x = x.view(B, N, S, D)
|
| 928 |
+
x = torch_npu.npu_interleave_rope(x, cos, sin)
|
| 929 |
+
return x.view(B, N, D)
|
| 930 |
+
|
| 931 |
+
def _forward_decode(
|
| 932 |
+
self,
|
| 933 |
+
q_nope: torch.Tensor,
|
| 934 |
+
q_pe: torch.Tensor,
|
| 935 |
+
k_nope: torch.Tensor,
|
| 936 |
+
k_pe: torch.Tensor,
|
| 937 |
+
kv_c_and_k_pe_cache: torch.Tensor,
|
| 938 |
+
attn_metadata: AscendMLAMetadata,
|
| 939 |
+
enable_multistream_mla: bool = False,
|
| 940 |
+
) -> torch.Tensor:
|
| 941 |
+
decode_meta = attn_metadata.decode
|
| 942 |
+
assert decode_meta is not None
|
| 943 |
+
|
| 944 |
+
q = torch.cat([q_nope, q_pe], dim=-1)
|
| 945 |
+
num_tokens = q.size(0)
|
| 946 |
+
attn_output = torch.empty(
|
| 947 |
+
[num_tokens, self.num_heads, self.kv_lora_rank],
|
| 948 |
+
dtype=q.dtype,
|
| 949 |
+
device=q.device)
|
| 950 |
+
if self.running_in_graph:
|
| 951 |
+
# TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
|
| 952 |
+
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
|
| 953 |
+
assert num_tokens % self.spec_token_num == 0
|
| 954 |
+
q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1),
|
| 955 |
+
self.spec_token_num + 1, self.num_heads,
|
| 956 |
+
-1)
|
| 957 |
+
q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1),
|
| 958 |
+
self.spec_token_num + 1, self.num_heads, -1)
|
| 959 |
+
if not self.enable_kv_nz:
|
| 960 |
+
q_nope = q_nope.transpose(1, 2).contiguous()
|
| 961 |
+
q_pe = q_pe.transpose(1, 2).contiguous()
|
| 962 |
+
sparse_mode = 3
|
| 963 |
+
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
| 964 |
+
else:
|
| 965 |
+
if self.enable_kv_nz:
|
| 966 |
+
q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
|
| 967 |
+
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
|
| 968 |
+
else:
|
| 969 |
+
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
|
| 970 |
+
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
| 971 |
+
sparse_mode = 0
|
| 972 |
+
spec_attn_mask = None
|
| 973 |
+
# shape of knope/k_pe for npu graph mode should be:
|
| 974 |
+
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
|
| 975 |
+
block_size = kv_c_and_k_pe_cache[0].shape[1]
|
| 976 |
+
if self.enable_kv_nz:
|
| 977 |
+
k_nope = k_nope.view(-1, self.num_kv_heads,
|
| 978 |
+
self.kv_lora_rank // 16, block_size, 16)
|
| 979 |
+
k_pe = k_pe.view(-1, self.num_kv_heads,
|
| 980 |
+
self.qk_rope_head_dim // 16, block_size, 16)
|
| 981 |
+
input_layout = "BSND"
|
| 982 |
+
else:
|
| 983 |
+
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
|
| 984 |
+
self.kv_lora_rank)
|
| 985 |
+
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
| 986 |
+
self.qk_rope_head_dim)
|
| 987 |
+
input_layout = "BNSD"
|
| 988 |
+
|
| 989 |
+
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
| 990 |
+
q_nope,
|
| 991 |
+
k_nope,
|
| 992 |
+
k_nope,
|
| 993 |
+
query_rope=q_pe,
|
| 994 |
+
key_rope=k_pe,
|
| 995 |
+
num_heads=self.num_heads,
|
| 996 |
+
num_key_value_heads=self.num_kv_heads,
|
| 997 |
+
input_layout=input_layout,
|
| 998 |
+
atten_mask=spec_attn_mask,
|
| 999 |
+
sparse_mode=sparse_mode,
|
| 1000 |
+
scale=self.scale,
|
| 1001 |
+
antiquant_mode=0,
|
| 1002 |
+
antiquant_scale=None,
|
| 1003 |
+
block_table=decode_meta.block_table,
|
| 1004 |
+
block_size=block_size,
|
| 1005 |
+
actual_seq_lengths_kv=decode_meta.seq_lens_list,
|
| 1006 |
+
)
|
| 1007 |
+
else:
|
| 1008 |
+
torch_npu._npu_paged_attention_mla(
|
| 1009 |
+
query=q,
|
| 1010 |
+
key_cache=kv_c_and_k_pe_cache,
|
| 1011 |
+
num_kv_heads=self.num_kv_heads,
|
| 1012 |
+
num_heads=self.num_heads,
|
| 1013 |
+
scale_value=self.scale,
|
| 1014 |
+
block_table=attn_metadata.decode.block_table, # type:ignore
|
| 1015 |
+
context_lens=attn_metadata.decode.seq_lens, # type:ignore
|
| 1016 |
+
mla_vheadsize=self.kv_lora_rank,
|
| 1017 |
+
out=attn_output)
|
| 1018 |
+
current_ms_metadata = get_multistream_comm_context()
|
| 1019 |
+
if current_ms_metadata is None:
|
| 1020 |
+
return self._v_up_proj_and_o_proj(attn_output,
|
| 1021 |
+
enable_multistream_mla)
|
| 1022 |
+
else:
|
| 1023 |
+
current_ms_metadata.before_comm_event.record()
|
| 1024 |
+
with torch.npu.stream(current_ms_metadata.comm_stream):
|
| 1025 |
+
current_ms_metadata.before_comm_event.wait()
|
| 1026 |
+
return self._v_up_proj_and_o_proj(attn_output)
|
| 1027 |
+
|
| 1028 |
+
def forward(
|
| 1029 |
+
self,
|
| 1030 |
+
layer: AttentionLayer,
|
| 1031 |
+
hidden_states_or_q_c: torch.Tensor, # query in unified attn
|
| 1032 |
+
hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn
|
| 1033 |
+
k_pe: torch.Tensor, # value in unified attn
|
| 1034 |
+
kv_cache: torch.Tensor,
|
| 1035 |
+
attn_metadata: M,
|
| 1036 |
+
output: Optional[torch.Tensor] = None,
|
| 1037 |
+
enable_multistream_mla: bool = False,
|
| 1038 |
+
ckq: Optional[torch.Tensor] = None,
|
| 1039 |
+
) -> torch.Tensor:
|
| 1040 |
+
assert output is not None, "Output tensor must be provided."
|
| 1041 |
+
if attn_metadata is None:
|
| 1042 |
+
# Profiling run.
|
| 1043 |
+
return output
|
| 1044 |
+
self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
|
| 1045 |
+
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
| 1046 |
+
]
|
| 1047 |
+
num_actual_toks = attn_metadata.num_actual_tokens
|
| 1048 |
+
if k_pe is None and not self.running_in_graph:
|
| 1049 |
+
if not self.torchair_graph_enabled:
|
| 1050 |
+
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
| 1051 |
+
hidden_states_or_kv_c_normed)[0].split(
|
| 1052 |
+
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
| 1053 |
+
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
| 1054 |
+
else:
|
| 1055 |
+
kv_c_normed = hidden_states_or_kv_c_normed
|
| 1056 |
+
assert attn_metadata.num_decodes is not None and \
|
| 1057 |
+
attn_metadata.num_prefills is not None and \
|
| 1058 |
+
attn_metadata.num_decode_tokens is not None
|
| 1059 |
+
has_decode = attn_metadata.num_decodes > 0
|
| 1060 |
+
has_prefill = attn_metadata.num_prefills > 0
|
| 1061 |
+
num_decode_tokens = attn_metadata.num_decode_tokens
|
| 1062 |
+
if not self.running_in_graph:
|
| 1063 |
+
# Inputs and outputs may be padded for CUDA graphs
|
| 1064 |
+
output_padded = output
|
| 1065 |
+
output = output[:num_actual_toks, ...]
|
| 1066 |
+
if not self.torchair_graph_enabled:
|
| 1067 |
+
kv_c_normed = kv_c_normed[:num_actual_toks, ...]
|
| 1068 |
+
prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
|
| 1069 |
+
if not self.running_in_graph:
|
| 1070 |
+
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
|
| 1071 |
+
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
|
| 1072 |
+
if not self.torchair_graph_enabled:
|
| 1073 |
+
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
| 1074 |
+
k_pe = k_pe[:num_actual_toks, ...]
|
| 1075 |
+
k_pe = k_pe.unsqueeze(1)
|
| 1076 |
+
decode_k_pe = k_pe[:num_decode_tokens]
|
| 1077 |
+
prefill_k_pe = k_pe[num_decode_tokens:]
|
| 1078 |
+
else:
|
| 1079 |
+
decode_hs_or_q_c = hidden_states_or_q_c
|
| 1080 |
+
if has_decode:
|
| 1081 |
+
decode_k_nope = None
|
| 1082 |
+
assert attn_metadata.decode is not None
|
| 1083 |
+
if self.running_in_graph:
|
| 1084 |
+
seq_len = self.rotary_emb.max_position_embeddings * \
|
| 1085 |
+
getattr(self.rotary_emb, "scaling_factor", 1)
|
| 1086 |
+
cos = self.rotary_emb.cos_cached[:seq_len].to(
|
| 1087 |
+
dtype=decode_hs_or_q_c.dtype)
|
| 1088 |
+
sin = self.rotary_emb.sin_cached[:seq_len].to(
|
| 1089 |
+
dtype=decode_hs_or_q_c.dtype)
|
| 1090 |
+
cos = cos[attn_metadata.decode.input_positions]
|
| 1091 |
+
sin = sin[attn_metadata.decode.input_positions]
|
| 1092 |
+
cos = cos[:, None, None, :]
|
| 1093 |
+
sin = sin[:, None, None, :]
|
| 1094 |
+
with npu_stream_switch("mla_secondary",
|
| 1095 |
+
0,
|
| 1096 |
+
enabled=enable_multistream_mla):
|
| 1097 |
+
npu_wait_tensor(hidden_states_or_kv_c_normed,
|
| 1098 |
+
ckq,
|
| 1099 |
+
enabled=enable_multistream_mla)
|
| 1100 |
+
decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
|
| 1101 |
+
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
|
| 1102 |
+
attn_metadata.slot_mapping)
|
| 1103 |
+
# Without explicitly controlling the order, IndexByTensor operations
|
| 1104 |
+
# would be placed after `matmul W_KV_T` hindering the overlapping of
|
| 1105 |
+
# KvRmsNormRopeCache and SingleRope.
|
| 1106 |
+
npu_wait_tensor(decode_hs_or_q_c,
|
| 1107 |
+
cos,
|
| 1108 |
+
enabled=enable_multistream_mla)
|
| 1109 |
+
npu_wait_tensor(decode_hs_or_q_c,
|
| 1110 |
+
sin,
|
| 1111 |
+
enabled=enable_multistream_mla)
|
| 1112 |
+
npu_wait_tensor(decode_hs_or_q_c,
|
| 1113 |
+
decode_kv,
|
| 1114 |
+
enabled=enable_multistream_mla)
|
| 1115 |
+
|
| 1116 |
+
decode_ql_nope, decode_q_pe = \
|
| 1117 |
+
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
| 1118 |
+
if self.running_in_graph:
|
| 1119 |
+
with npu_stream_switch("mla_secondary",
|
| 1120 |
+
0,
|
| 1121 |
+
enabled=enable_multistream_mla):
|
| 1122 |
+
npu_wait_tensor(decode_q_pe,
|
| 1123 |
+
decode_k_pe,
|
| 1124 |
+
enabled=enable_multistream_mla)
|
| 1125 |
+
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
|
| 1126 |
+
else:
|
| 1127 |
+
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
| 1128 |
+
attn_metadata.decode.input_positions,
|
| 1129 |
+
decode_q_pe.contiguous(),
|
| 1130 |
+
decode_k_pe,
|
| 1131 |
+
max_seq_len=attn_metadata.decode.max_seq_lens)
|
| 1132 |
+
if has_prefill:
|
| 1133 |
+
assert attn_metadata.prefill is not None
|
| 1134 |
+
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
| 1135 |
+
.view(-1, self.num_heads, self.qk_head_dim)
|
| 1136 |
+
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
| 1137 |
+
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
|
| 1138 |
+
if self.torchair_graph_enabled:
|
| 1139 |
+
num_tokens = prefill_hs_or_q_c.shape[0]
|
| 1140 |
+
seq_len = self.rotary_emb.max_position_embeddings * \
|
| 1141 |
+
getattr(self.rotary_emb, "scaling_factor", 1)
|
| 1142 |
+
cos = self.rotary_emb.cos_cached[:seq_len].to(
|
| 1143 |
+
dtype=prefill_q_pe.dtype)
|
| 1144 |
+
sin = self.rotary_emb.sin_cached[:seq_len].to(
|
| 1145 |
+
dtype=prefill_q_pe.dtype)
|
| 1146 |
+
cos = cos[attn_metadata.prefill.input_positions]
|
| 1147 |
+
sin = sin[attn_metadata.prefill.input_positions]
|
| 1148 |
+
cos = cos[:, None, None, :]
|
| 1149 |
+
sin = sin[:, None, None, :]
|
| 1150 |
+
|
| 1151 |
+
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
|
| 1152 |
+
prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
|
| 1153 |
+
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
|
| 1154 |
+
attn_metadata.slot_mapping)
|
| 1155 |
+
|
| 1156 |
+
kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
|
| 1157 |
+
prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
|
| 1158 |
+
prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
|
| 1159 |
+
-1)
|
| 1160 |
+
prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
|
| 1161 |
+
else:
|
| 1162 |
+
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
| 1163 |
+
attn_metadata.prefill.input_positions,
|
| 1164 |
+
prefill_q_pe.contiguous(),
|
| 1165 |
+
prefill_k_pe,
|
| 1166 |
+
max_seq_len=attn_metadata.prefill.max_seq_lens)
|
| 1167 |
+
if self.torchair_graph_enabled:
|
| 1168 |
+
if len(kv_cache) > 0 and kv_cache[0].numel(
|
| 1169 |
+
) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
| 1170 |
+
slots = attn_metadata.slot_mapping
|
| 1171 |
+
# NOTE: Separate the kv cache in advance to avoid OOM or other issues
|
| 1172 |
+
torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
|
| 1173 |
+
num_tokens, self.num_kv_heads, -1),
|
| 1174 |
+
value=prefill_k_pe,
|
| 1175 |
+
key_cache=kv_cache[0],
|
| 1176 |
+
value_cache=kv_cache[1],
|
| 1177 |
+
slot_indices=slots)
|
| 1178 |
+
elif kv_cache.numel() > 0:
|
| 1179 |
+
key = torch.cat([
|
| 1180 |
+
kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
|
| 1181 |
+
k_pe
|
| 1182 |
+
],
|
| 1183 |
+
dim=2)
|
| 1184 |
+
torch_npu._npu_reshape_and_cache_siso(
|
| 1185 |
+
key=key,
|
| 1186 |
+
key_cache=kv_cache,
|
| 1187 |
+
slot_indices=attn_metadata.slot_mapping.flatten())
|
| 1188 |
+
if has_prefill:
|
| 1189 |
+
# FIX: aicore move should be also placed on the comm stream in dbo,
|
| 1190 |
+
# otherwise it may affect the accuracy
|
| 1191 |
+
# TODO: use an elegant way to overlap
|
| 1192 |
+
output_prefill = self._forward_prefill(prefill_q,
|
| 1193 |
+
prefill_k_c_normed,
|
| 1194 |
+
prefill_k_pe, kv_cache,
|
| 1195 |
+
attn_metadata)
|
| 1196 |
+
current_ms_metadata = get_multistream_comm_context()
|
| 1197 |
+
if current_ms_metadata is not None:
|
| 1198 |
+
with torch.npu.stream(current_ms_metadata.comm_stream):
|
| 1199 |
+
output[num_decode_tokens:] = output_prefill
|
| 1200 |
+
current_ms_metadata.after_comm_event.record()
|
| 1201 |
+
else:
|
| 1202 |
+
output[num_decode_tokens:] = output_prefill
|
| 1203 |
+
|
| 1204 |
+
if has_decode:
|
| 1205 |
+
if self.running_in_graph:
|
| 1206 |
+
return self._forward_decode(decode_ql_nope, decode_q_pe,
|
| 1207 |
+
decode_k_nope, decode_k_pe,
|
| 1208 |
+
kv_cache, attn_metadata,
|
| 1209 |
+
enable_multistream_mla)
|
| 1210 |
+
else:
|
| 1211 |
+
output_decode = self._forward_decode(decode_ql_nope,
|
| 1212 |
+
decode_q_pe,
|
| 1213 |
+
decode_k_nope,
|
| 1214 |
+
decode_k_pe, kv_cache,
|
| 1215 |
+
attn_metadata)
|
| 1216 |
+
current_ms_metadata = get_multistream_comm_context()
|
| 1217 |
+
if current_ms_metadata is not None:
|
| 1218 |
+
with torch.npu.stream(current_ms_metadata.comm_stream):
|
| 1219 |
+
output[:num_decode_tokens] = output_decode
|
| 1220 |
+
current_ms_metadata.after_comm_event.record()
|
| 1221 |
+
else:
|
| 1222 |
+
output[:num_decode_tokens] = output_decode
|
| 1223 |
+
|
| 1224 |
+
return output_padded
|
inference/vllm_ascend/entrypoints/openai/reasoning_parsers/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 2 |
+
from .pangu_reasoning_parser import PanguReasoningParser
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"PanguReasoningParser"
|
| 6 |
+
]
|
inference/vllm_ascend/entrypoints/openai/reasoning_parsers/pangu_reasoning_parser.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 4 |
+
|
| 5 |
+
from collections.abc import Sequence
|
| 6 |
+
from typing import Optional, Union
|
| 7 |
+
|
| 8 |
+
from transformers import PreTrainedTokenizerBase
|
| 9 |
+
|
| 10 |
+
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
| 11 |
+
DeltaMessage)
|
| 12 |
+
from vllm.logger import init_logger
|
| 13 |
+
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
| 14 |
+
|
| 15 |
+
logger = init_logger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@ReasoningParserManager.register_module("pangu")
|
| 19 |
+
class PanguReasoningParser(ReasoningParser):
|
| 20 |
+
"""
|
| 21 |
+
Reasoning parser for Pangu model.
|
| 22 |
+
|
| 23 |
+
The Pangu model uses [unused16]...[unused17] tokens to denote reasoning
|
| 24 |
+
text. This parser extracts the reasoning content from the model output.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
start_token_id: int
|
| 28 |
+
end_token_id: int
|
| 29 |
+
|
| 30 |
+
start_token: str = "[unused16]"
|
| 31 |
+
end_token: str = "[unused17]"
|
| 32 |
+
|
| 33 |
+
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
| 34 |
+
super().__init__(tokenizer)
|
| 35 |
+
|
| 36 |
+
if not self.model_tokenizer:
|
| 37 |
+
raise ValueError(
|
| 38 |
+
"The model tokenizer must be passed to the ReasoningParser "
|
| 39 |
+
"constructor during construction.")
|
| 40 |
+
|
| 41 |
+
self.start_token_id = self.vocab.get(self.start_token)
|
| 42 |
+
self.end_token_id = self.vocab.get(self.end_token)
|
| 43 |
+
if self.start_token_id is None or self.end_token_id is None:
|
| 44 |
+
raise RuntimeError(
|
| 45 |
+
"Pangu reasoning parser could not locate think start/end "
|
| 46 |
+
"tokens in the tokenizer!")
|
| 47 |
+
|
| 48 |
+
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
| 49 |
+
return self.end_token_id in input_ids
|
| 50 |
+
|
| 51 |
+
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
| 52 |
+
"""
|
| 53 |
+
Extract the content after the end tokens
|
| 54 |
+
"""
|
| 55 |
+
if self.end_token_id not in input_ids[:-1]:
|
| 56 |
+
return []
|
| 57 |
+
else:
|
| 58 |
+
return input_ids[input_ids.index(self.end_token_id) + 1:]
|
| 59 |
+
|
| 60 |
+
def extract_reasoning_content_streaming(
|
| 61 |
+
self,
|
| 62 |
+
previous_text: str,
|
| 63 |
+
current_text: str,
|
| 64 |
+
delta_text: str,
|
| 65 |
+
previous_token_ids: Sequence[int],
|
| 66 |
+
current_token_ids: Sequence[int],
|
| 67 |
+
delta_token_ids: Sequence[int],
|
| 68 |
+
) -> Union[DeltaMessage, None]:
|
| 69 |
+
"""
|
| 70 |
+
Extract reasoning content from a delta message.
|
| 71 |
+
Handles streaming output where previous + delta = current.
|
| 72 |
+
Uses token IDs for faster processing.
|
| 73 |
+
For text [unused16]abc[unused17]xyz:
|
| 74 |
+
- 'abc' goes to reasoning_content
|
| 75 |
+
- 'xyz' goes to content
|
| 76 |
+
"""
|
| 77 |
+
# Skip single special tokens
|
| 78 |
+
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
|
| 79 |
+
self.start_token_id, self.end_token_id
|
| 80 |
+
]):
|
| 81 |
+
return None
|
| 82 |
+
|
| 83 |
+
# Check if [unused16] is present in previous or delta.
|
| 84 |
+
# Keep compatibility with models that don't generate [unused16] tokens.
|
| 85 |
+
if self.start_token_id in previous_token_ids:
|
| 86 |
+
if self.end_token_id in delta_token_ids:
|
| 87 |
+
# [unused16] in previous, [unused17] in delta,
|
| 88 |
+
# extract reasoning content
|
| 89 |
+
end_index = delta_text.find(self.end_token)
|
| 90 |
+
reasoning_content = delta_text[:end_index]
|
| 91 |
+
content = delta_text[end_index + len(self.end_token):]
|
| 92 |
+
return DeltaMessage(
|
| 93 |
+
reasoning_content=reasoning_content,
|
| 94 |
+
content=content if content else None,
|
| 95 |
+
)
|
| 96 |
+
elif self.end_token_id in previous_token_ids:
|
| 97 |
+
# [unused16] in previous, [unused17] in previous,
|
| 98 |
+
# reasoning content continues
|
| 99 |
+
return DeltaMessage(content=delta_text)
|
| 100 |
+
else:
|
| 101 |
+
# [unused16] in previous, no [unused17] in previous or delta,
|
| 102 |
+
# reasoning content continues
|
| 103 |
+
return DeltaMessage(reasoning_content=delta_text)
|
| 104 |
+
elif self.start_token_id in delta_token_ids:
|
| 105 |
+
if self.end_token_id in delta_token_ids:
|
| 106 |
+
# [unused16] in delta, [unused17] in delta, extract reasoning content
|
| 107 |
+
start_index = delta_text.find(self.start_token)
|
| 108 |
+
end_index = delta_text.find(self.end_token)
|
| 109 |
+
reasoning_content = delta_text[start_index +
|
| 110 |
+
len(self.start_token):end_index]
|
| 111 |
+
content = delta_text[end_index + len(self.end_token):]
|
| 112 |
+
return DeltaMessage(
|
| 113 |
+
reasoning_content=reasoning_content,
|
| 114 |
+
content=content if content else None,
|
| 115 |
+
)
|
| 116 |
+
else:
|
| 117 |
+
# [unused16] in delta, no [unused17] in delta,
|
| 118 |
+
# reasoning content continues
|
| 119 |
+
return DeltaMessage(reasoning_content=delta_text)
|
| 120 |
+
else:
|
| 121 |
+
# No [unused16] in previous or delta, also need to check for [unused17].
|
| 122 |
+
# Because the model may have generated [unused17] without [unused16]
|
| 123 |
+
if self.end_token_id in delta_token_ids:
|
| 124 |
+
# [unused17] in delta with more tokens,
|
| 125 |
+
# extract reasoning content and content
|
| 126 |
+
end_index = delta_text.find(self.end_token)
|
| 127 |
+
reasoning_content = delta_text[:end_index]
|
| 128 |
+
content = delta_text[end_index + len(self.end_token):]
|
| 129 |
+
return DeltaMessage(
|
| 130 |
+
reasoning_content=reasoning_content,
|
| 131 |
+
content=content if content else None,
|
| 132 |
+
)
|
| 133 |
+
elif self.end_token_id in previous_token_ids:
|
| 134 |
+
# [unused17] in previous, thinking content ends
|
| 135 |
+
return DeltaMessage(content=delta_text)
|
| 136 |
+
else:
|
| 137 |
+
# no [unused17] in previous or delta, reasoning content continues
|
| 138 |
+
return DeltaMessage(reasoning_content=delta_text)
|
| 139 |
+
|
| 140 |
+
def extract_reasoning_content(
|
| 141 |
+
self, model_output: str, request: ChatCompletionRequest
|
| 142 |
+
) -> tuple[Optional[str], Optional[str]]:
|
| 143 |
+
"""
|
| 144 |
+
Extract reasoning content from the model output.
|
| 145 |
+
|
| 146 |
+
For text [unused16]abc[unused17]xyz:
|
| 147 |
+
- 'abc' goes to reasoning_content
|
| 148 |
+
- 'xyz' goes to content
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
tuple[Optional[str], Optional[str]]: reasoning content and content
|
| 152 |
+
"""
|
| 153 |
+
|
| 154 |
+
# Check if the start token is present in the model output, remove it
|
| 155 |
+
# if it is present.
|
| 156 |
+
model_output_parts = model_output.partition(self.start_token)
|
| 157 |
+
model_output = model_output_parts[2] if model_output_parts[
|
| 158 |
+
1] else model_output_parts[0]
|
| 159 |
+
|
| 160 |
+
# Thus we assume the reasoning content is always at the start.
|
| 161 |
+
if self.end_token not in model_output:
|
| 162 |
+
return model_output, None
|
| 163 |
+
else:
|
| 164 |
+
reasoning_content, _, content = model_output.partition(
|
| 165 |
+
self.end_token)
|
| 166 |
+
# If the end token is not found, return the model output as is.
|
| 167 |
+
# It should not happen since we already checked for the presence
|
| 168 |
+
# of the end token.
|
| 169 |
+
# If generation stops right after end-of-think, return null content
|
| 170 |
+
final_content = content or None
|
| 171 |
+
return reasoning_content, final_content
|
inference/vllm_ascend/entrypoints/openai/tool_parsers/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 2 |
+
from .pangu_tool_parser import PanguToolParser
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"PanguToolParser"
|
| 6 |
+
]
|
inference/vllm_ascend/entrypoints/openai/tool_parsers/pangu_tool_parser.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 2 |
+
# Copyright 2023 The vLLM team.
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
from json import JSONDecodeError, JSONDecoder
|
| 7 |
+
from typing import Dict, List, Sequence, Union, Optional
|
| 8 |
+
from pydantic import Field
|
| 9 |
+
import partial_json_parser
|
| 10 |
+
from partial_json_parser.core.options import Allow
|
| 11 |
+
from transformers import PreTrainedTokenizerBase
|
| 12 |
+
|
| 13 |
+
from vllm.entrypoints.chat_utils import random_tool_call_id
|
| 14 |
+
from vllm.entrypoints.openai.tool_parsers.utils import (
|
| 15 |
+
extract_intermediate_diff)
|
| 16 |
+
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
| 17 |
+
DeltaFunctionCall, DeltaMessage,
|
| 18 |
+
DeltaToolCall,
|
| 19 |
+
ExtractedToolCallInformation,
|
| 20 |
+
FunctionCall, ToolCall,
|
| 21 |
+
)
|
| 22 |
+
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
| 23 |
+
ToolParser, ToolParserManager)
|
| 24 |
+
from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix,
|
| 25 |
+
is_complete_json)
|
| 26 |
+
from vllm.logger import init_logger
|
| 27 |
+
import os
|
| 28 |
+
|
| 29 |
+
logger = init_logger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@ToolParserManager.register_module("pangu")
|
| 33 |
+
class PanguToolParser(ToolParser):
|
| 34 |
+
|
| 35 |
+
def __init__(self, tokenizer: PreTrainedTokenizerBase, enable_reasoning=False):
|
| 36 |
+
super().__init__(tokenizer)
|
| 37 |
+
|
| 38 |
+
# initialize properties used for state when parsing tool calls in
|
| 39 |
+
# streaming mode
|
| 40 |
+
self.prev_tool_call_arr: List[Dict] = []
|
| 41 |
+
self.current_tool_id: int = -1
|
| 42 |
+
self.current_tool_name_sent: bool = False
|
| 43 |
+
self.streamed_args_for_tool: List[str] = [
|
| 44 |
+
] # map what has been streamed for each tool so far to a list
|
| 45 |
+
|
| 46 |
+
self.tool_call_start_token = "[unused11]"
|
| 47 |
+
self.tool_call_end_token = "[unused12]"
|
| 48 |
+
self.pattern = re.escape(self.tool_call_start_token) \
|
| 49 |
+
+ "(.*?)" + re.escape(self.tool_call_end_token)
|
| 50 |
+
self.tool_call_regex = re.compile(self.pattern, re.DOTALL)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
self.tool_call_start_token_id = self.vocab.get(
|
| 54 |
+
self.tool_call_start_token)
|
| 55 |
+
self.tool_call_end_token_id = self.vocab.get(
|
| 56 |
+
self.tool_call_end_token)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
if (self.tool_call_start_token_id is None
|
| 60 |
+
or self.tool_call_end_token_id is None):
|
| 61 |
+
raise RuntimeError(
|
| 62 |
+
"Pangu Tool parser could not locate tool calls start/end "
|
| 63 |
+
"tokens in the tokenizer!")
|
| 64 |
+
self.is_complete = []
|
| 65 |
+
self.text_after_start_token = ""
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def extract_tool_calls(
|
| 69 |
+
self, model_output: str,
|
| 70 |
+
request: ChatCompletionRequest
|
| 71 |
+
) -> ExtractedToolCallInformation:
|
| 72 |
+
"""
|
| 73 |
+
Extract the tool calls from a complete model response.
|
| 74 |
+
"""
|
| 75 |
+
# case -- if a tool call token is not present, return a text response
|
| 76 |
+
if not (self.tool_call_start_token in model_output and \
|
| 77 |
+
model_output.find(self.tool_call_end_token) != -1):
|
| 78 |
+
return ExtractedToolCallInformation(tools_called=False,
|
| 79 |
+
tool_calls=[],
|
| 80 |
+
content=model_output)
|
| 81 |
+
|
| 82 |
+
try:
|
| 83 |
+
raw_function_calls = []
|
| 84 |
+
# use a regex to find the tool call between the tags
|
| 85 |
+
function_call_tuples = self.tool_call_regex.findall(model_output)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# load the JSON, and then use it to build the Function and
|
| 89 |
+
# Tool Call
|
| 90 |
+
for function_call_str in function_call_tuples:
|
| 91 |
+
function_call = json.loads(function_call_str)
|
| 92 |
+
raw_function_calls.extend(function_call)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
tool_calls: List[ToolCall] = [
|
| 96 |
+
ToolCall(
|
| 97 |
+
type="function",
|
| 98 |
+
function=FunctionCall(
|
| 99 |
+
name=function_call["name"],
|
| 100 |
+
# function call args are JSON but as a string
|
| 101 |
+
arguments=json.dumps(function_call["arguments"] \
|
| 102 |
+
if "arguments" in function_call \
|
| 103 |
+
else function_call["parameters"], ensure_ascii=False)))
|
| 104 |
+
for function_call in raw_function_calls
|
| 105 |
+
]
|
| 106 |
+
content = model_output[:model_output.
|
| 107 |
+
find(self.tool_call_start_token)]
|
| 108 |
+
|
| 109 |
+
# get any content before the tool call
|
| 110 |
+
ret = ExtractedToolCallInformation(tools_called=True,
|
| 111 |
+
tool_calls=tool_calls,
|
| 112 |
+
content=content if content else None)
|
| 113 |
+
|
| 114 |
+
return ret
|
| 115 |
+
|
| 116 |
+
except Exception:
|
| 117 |
+
logger.exception("Error in extracting tool call from response.")
|
| 118 |
+
# return information to just treat the tool call as regular JSON
|
| 119 |
+
return ExtractedToolCallInformation(tools_called=False,
|
| 120 |
+
tool_calls=[],
|
| 121 |
+
content=model_output)
|
| 122 |
+
|
| 123 |
+
def extract_tool_calls_streaming(
|
| 124 |
+
self,
|
| 125 |
+
previous_text: str,
|
| 126 |
+
current_text: str,
|
| 127 |
+
delta_text: str,
|
| 128 |
+
previous_token_ids: Sequence[int],
|
| 129 |
+
current_token_ids: Sequence[int],
|
| 130 |
+
delta_token_ids: Sequence[int],
|
| 131 |
+
request: ChatCompletionRequest,
|
| 132 |
+
) -> Union[DeltaMessage, None]:
|
| 133 |
+
|
| 134 |
+
if (self.tool_call_end_token_id in delta_token_ids
|
| 135 |
+
and len(delta_token_ids) == 1):
|
| 136 |
+
# if it's the only token, return None, so we don't send a chat
|
| 137 |
+
# completion and don't send a control token
|
| 138 |
+
return None
|
| 139 |
+
|
| 140 |
+
if (self.tool_call_end_token in current_text
|
| 141 |
+
and self.tool_call_end_token not in delta_text):
|
| 142 |
+
return DeltaMessage(content=delta_text)
|
| 143 |
+
|
| 144 |
+
if self.tool_call_start_token not in current_text:
|
| 145 |
+
return DeltaMessage(content=delta_text)
|
| 146 |
+
|
| 147 |
+
if self.tool_call_start_token in delta_text:
|
| 148 |
+
texts = delta_text.split(self.tool_call_start_token)
|
| 149 |
+
text_before_start_token = texts[0]
|
| 150 |
+
if text_before_start_token:
|
| 151 |
+
return DeltaMessage(content=text_before_start_token)
|
| 152 |
+
|
| 153 |
+
if (self.tool_call_start_token_id in delta_token_ids
|
| 154 |
+
and len(delta_token_ids) == 1):
|
| 155 |
+
# if it's the only token, return None, so we don't send a chat
|
| 156 |
+
# completion and don't send a control token
|
| 157 |
+
return None
|
| 158 |
+
|
| 159 |
+
# bit mask flags for partial JSON parsing. If the name hasn't been
|
| 160 |
+
# sent yet, don't allow sending
|
| 161 |
+
# an incomplete string since OpenAI only ever (as far as I have
|
| 162 |
+
# seen) allows sending the entire tool/ function name at once.
|
| 163 |
+
flags = Allow.ALL if self.current_tool_name_sent \
|
| 164 |
+
else Allow.ALL & ~Allow.STR
|
| 165 |
+
try:
|
| 166 |
+
|
| 167 |
+
tool_call_portion = current_text.split(
|
| 168 |
+
self.tool_call_start_token)[-1].split(self.tool_call_end_token)[0]
|
| 169 |
+
try:
|
| 170 |
+
tool_call_arr: list[dict] = partial_json_parser.loads(
|
| 171 |
+
tool_call_portion, flags)
|
| 172 |
+
|
| 173 |
+
self.is_complete.append(
|
| 174 |
+
is_complete_json(tool_call_portion))
|
| 175 |
+
except partial_json_parser.core.exceptions.MalformedJSON:
|
| 176 |
+
logger.debug('not enough tokens to parse into JSON yet')
|
| 177 |
+
return None
|
| 178 |
+
|
| 179 |
+
# select as the current tool call the one we're on the state at
|
| 180 |
+
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
|
| 181 |
+
if len(tool_call_arr) > 0 else {}
|
| 182 |
+
|
| 183 |
+
# case -- if no tokens have been streamed for the tool, e.g.
|
| 184 |
+
# only the array brackets, stream nothing
|
| 185 |
+
if len(tool_call_arr) == 0:
|
| 186 |
+
return None
|
| 187 |
+
|
| 188 |
+
# case: we are starting a new tool in the array
|
| 189 |
+
# -> array has > 0 length AND length has moved past cursor
|
| 190 |
+
elif (len(tool_call_arr) > 0
|
| 191 |
+
and len(tool_call_arr) > self.current_tool_id + 1):
|
| 192 |
+
|
| 193 |
+
# if we're moving on to a new call, first make sure we
|
| 194 |
+
# haven't missed anything in the previous one that was
|
| 195 |
+
# auto-generated due to JSON completions, but wasn't
|
| 196 |
+
# streamed to the client yet.
|
| 197 |
+
if self.current_tool_id >= 0:
|
| 198 |
+
cur_arguments = current_tool_call.get("arguments")
|
| 199 |
+
if cur_arguments:
|
| 200 |
+
cur_args_json = json.dumps(cur_arguments,
|
| 201 |
+
ensure_ascii=False)
|
| 202 |
+
sent = len(
|
| 203 |
+
self.streamed_args_for_tool[self.current_tool_id])
|
| 204 |
+
argument_diff = cur_args_json[sent:]
|
| 205 |
+
|
| 206 |
+
logger.debug("got arguments diff: %s", argument_diff)
|
| 207 |
+
delta = DeltaMessage(tool_calls=[
|
| 208 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 209 |
+
function=DeltaFunctionCall(
|
| 210 |
+
arguments=argument_diff).
|
| 211 |
+
model_dump(exclude_none=True))
|
| 212 |
+
])
|
| 213 |
+
self.streamed_args_for_tool[
|
| 214 |
+
self.current_tool_id] += argument_diff
|
| 215 |
+
else:
|
| 216 |
+
delta = None
|
| 217 |
+
else:
|
| 218 |
+
delta = None
|
| 219 |
+
# re-set stuff pertaining to progress in the current tool
|
| 220 |
+
self.current_tool_id = len(tool_call_arr) - 1
|
| 221 |
+
self.current_tool_name_sent = False
|
| 222 |
+
self.streamed_args_for_tool.append("")
|
| 223 |
+
self.is_complete = []
|
| 224 |
+
logger.debug("starting on new tool %d", self.current_tool_id)
|
| 225 |
+
return delta
|
| 226 |
+
|
| 227 |
+
# if the current tool name hasn't been sent, send if available
|
| 228 |
+
# - otherwise send nothing
|
| 229 |
+
elif not self.current_tool_name_sent:
|
| 230 |
+
function_name = current_tool_call.get("name")
|
| 231 |
+
if function_name:
|
| 232 |
+
delta = DeltaMessage(tool_calls=[
|
| 233 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 234 |
+
type="function",
|
| 235 |
+
id=random_tool_call_id(),
|
| 236 |
+
function=DeltaFunctionCall(
|
| 237 |
+
name=function_name).model_dump(
|
| 238 |
+
exclude_none=True))
|
| 239 |
+
])
|
| 240 |
+
self.current_tool_name_sent = True
|
| 241 |
+
else:
|
| 242 |
+
delta = None
|
| 243 |
+
|
| 244 |
+
# now we know we're on the same tool call and we're streaming
|
| 245 |
+
# arguments
|
| 246 |
+
else:
|
| 247 |
+
cur_arguments = current_tool_call.get("arguments")
|
| 248 |
+
delta = None
|
| 249 |
+
if (self.is_complete[-1] and not cur_arguments
|
| 250 |
+
and not self.streamed_args_for_tool[-1]):
|
| 251 |
+
argument_diff = "{}"
|
| 252 |
+
delta = DeltaMessage(tool_calls=[
|
| 253 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 254 |
+
function=DeltaFunctionCall(
|
| 255 |
+
arguments=argument_diff).
|
| 256 |
+
model_dump(exclude_none=True))
|
| 257 |
+
])
|
| 258 |
+
self.streamed_args_for_tool[
|
| 259 |
+
self.current_tool_id] += argument_diff
|
| 260 |
+
|
| 261 |
+
if cur_arguments:
|
| 262 |
+
sent = len(
|
| 263 |
+
self.streamed_args_for_tool[self.current_tool_id])
|
| 264 |
+
cur_args_json = json.dumps(cur_arguments,
|
| 265 |
+
ensure_ascii=False)
|
| 266 |
+
prev_arguments = self.prev_tool_call_arr[
|
| 267 |
+
self.current_tool_id].get("arguments")
|
| 268 |
+
|
| 269 |
+
argument_diff = None
|
| 270 |
+
if self.is_complete[-1]:
|
| 271 |
+
argument_diff = cur_args_json[sent:]
|
| 272 |
+
elif prev_arguments:
|
| 273 |
+
prev_args_json = json.dumps(prev_arguments,
|
| 274 |
+
ensure_ascii=False)
|
| 275 |
+
if cur_args_json != prev_args_json:
|
| 276 |
+
|
| 277 |
+
prefix = find_common_prefix(
|
| 278 |
+
prev_args_json, cur_args_json)
|
| 279 |
+
argument_diff = prefix[sent:]
|
| 280 |
+
|
| 281 |
+
if argument_diff is not None:
|
| 282 |
+
delta = DeltaMessage(tool_calls=[
|
| 283 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 284 |
+
function=DeltaFunctionCall(
|
| 285 |
+
arguments=argument_diff).
|
| 286 |
+
model_dump(exclude_none=True))
|
| 287 |
+
])
|
| 288 |
+
self.streamed_args_for_tool[
|
| 289 |
+
self.current_tool_id] += argument_diff
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
self.prev_tool_call_arr = tool_call_arr
|
| 293 |
+
return delta
|
| 294 |
+
|
| 295 |
+
except Exception:
|
| 296 |
+
logger.exception("Error trying to handle streaming tool call.")
|
| 297 |
+
logger.debug(
|
| 298 |
+
"Skipping chunk as a result of tool streaming extraction "
|
| 299 |
+
"error")
|
| 300 |
+
return None
|
inference/vllm_ascend/envs.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# This file is a part of the vllm-ascend project.
|
| 4 |
+
#
|
| 5 |
+
# This file is mainly Adapted from vllm-project/vllm/vllm/envs.py
|
| 6 |
+
# Copyright 2023 The vLLM team.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
#
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
from typing import Any, Callable, Dict
|
| 23 |
+
|
| 24 |
+
# The begin-* and end* here are used by the documentation generator
|
| 25 |
+
# to extract the used env vars.
|
| 26 |
+
|
| 27 |
+
# begin-env-vars-definition
|
| 28 |
+
|
| 29 |
+
env_variables: Dict[str, Callable[[], Any]] = {
|
| 30 |
+
# max compile thread number for package building. Usually, it is set to
|
| 31 |
+
# the number of CPU cores. If not set, the default value is None, which
|
| 32 |
+
# means all number of CPU cores will be used.
|
| 33 |
+
"MAX_JOBS":
|
| 34 |
+
lambda: os.getenv("MAX_JOBS", None),
|
| 35 |
+
# The build type of the package. It can be one of the following values:
|
| 36 |
+
# Release, Debug, RelWithDebugInfo. If not set, the default value is Release.
|
| 37 |
+
"CMAKE_BUILD_TYPE":
|
| 38 |
+
lambda: os.getenv("CMAKE_BUILD_TYPE"),
|
| 39 |
+
# Whether to compile custom kernels. If not set, the default value is True.
|
| 40 |
+
# If set to False, the custom kernels will not be compiled. Please note that
|
| 41 |
+
# the sleep mode feature will be disabled as well if custom kernels are not
|
| 42 |
+
# compiled.
|
| 43 |
+
"COMPILE_CUSTOM_KERNELS":
|
| 44 |
+
lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
|
| 45 |
+
# The CXX compiler used for compiling the package. If not set, the default
|
| 46 |
+
# value is None, which means the system default CXX compiler will be used.
|
| 47 |
+
"CXX_COMPILER":
|
| 48 |
+
lambda: os.getenv("CXX_COMPILER", None),
|
| 49 |
+
# The C compiler used for compiling the package. If not set, the default
|
| 50 |
+
# value is None, which means the system default C compiler will be used.
|
| 51 |
+
"C_COMPILER":
|
| 52 |
+
lambda: os.getenv("C_COMPILER", None),
|
| 53 |
+
# The version of the Ascend chip. If not set, the default value is
|
| 54 |
+
# ASCEND910B1. It's used for package building. Please make sure that the
|
| 55 |
+
# version is correct.
|
| 56 |
+
"SOC_VERSION":
|
| 57 |
+
lambda: os.getenv("SOC_VERSION", "ASCEND910B1"),
|
| 58 |
+
# If set, vllm-ascend will print verbose logs during compilation
|
| 59 |
+
"VERBOSE":
|
| 60 |
+
lambda: bool(int(os.getenv('VERBOSE', '0'))),
|
| 61 |
+
# The home path for CANN toolkit. If not set, the default value is
|
| 62 |
+
# /usr/local/Ascend/ascend-toolkit/latest
|
| 63 |
+
"ASCEND_HOME_PATH":
|
| 64 |
+
lambda: os.getenv("ASCEND_HOME_PATH", None),
|
| 65 |
+
# The path for HCCN Tool, the tool will be called by disaggregated prefilling
|
| 66 |
+
# case.
|
| 67 |
+
"HCCN_PATH":
|
| 68 |
+
lambda: os.getenv("HCCN_PATH", "/usr/local/Ascend/driver/tools/hccn_tool"),
|
| 69 |
+
# The path for HCCL library, it's used by pyhccl communicator backend. If
|
| 70 |
+
# not set, the default value is libhccl.so。
|
| 71 |
+
"HCCL_SO_PATH":
|
| 72 |
+
# The prefill device id for disaggregated prefilling case.
|
| 73 |
+
lambda: os.environ.get("HCCL_SO_PATH", None),
|
| 74 |
+
"PROMPT_DEVICE_ID":
|
| 75 |
+
lambda: os.getenv("PROMPT_DEVICE_ID", None),
|
| 76 |
+
# The decode device id for disaggregated prefilling case.
|
| 77 |
+
"DECODE_DEVICE_ID":
|
| 78 |
+
lambda: os.getenv("DECODE_DEVICE_ID", None),
|
| 79 |
+
# The port number for llmdatadist communication. If not set, the default
|
| 80 |
+
# value is 26000.
|
| 81 |
+
"LLMDATADIST_COMM_PORT":
|
| 82 |
+
lambda: os.getenv("LLMDATADIST_COMM_PORT", "26000"),
|
| 83 |
+
# The wait time for llmdatadist sync cache. If not set, the default value is
|
| 84 |
+
# 5000ms.
|
| 85 |
+
"LLMDATADIST_SYNC_CACHE_WAIT_TIME":
|
| 86 |
+
lambda: os.getenv("LLMDATADIST_SYNC_CACHE_WAIT_TIME", "5000"),
|
| 87 |
+
# The version of vllm is installed. This value is used for developers who
|
| 88 |
+
# installed vllm from source locally. In this case, the version of vllm is
|
| 89 |
+
# usually changed. For example, if the version of vllm is "0.9.0", but when
|
| 90 |
+
# it's installed from source, the version of vllm is usually set to "0.9.1".
|
| 91 |
+
# In this case, developers need to set this value to "0.9.0" to make sure
|
| 92 |
+
# that the correct package is installed.
|
| 93 |
+
"VLLM_VERSION":
|
| 94 |
+
lambda: os.getenv("VLLM_VERSION", None),
|
| 95 |
+
# Whether to enable the trace recompiles from pytorch.
|
| 96 |
+
"VLLM_ASCEND_TRACE_RECOMPILES":
|
| 97 |
+
lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
|
| 98 |
+
# Whether to enable fused_experts_allgather_ep. MoeInitRoutingV3 and
|
| 99 |
+
# GroupedMatmulFinalizeRouting operators are combined to implement EP.
|
| 100 |
+
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP":
|
| 101 |
+
lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0'))
|
| 102 |
+
),
|
| 103 |
+
"VLLM_ASCEND_ENABLE_DBO":
|
| 104 |
+
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
|
| 105 |
+
# Whether to enable the model execute time observe profile. Disable it when
|
| 106 |
+
# running vllm ascend in production environment.
|
| 107 |
+
"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
|
| 108 |
+
lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
|
| 109 |
+
),
|
| 110 |
+
# MOE_ALL2ALL_BUFFER:
|
| 111 |
+
# 0: default, normal init.
|
| 112 |
+
# 1: enable moe_all2all_buffer.
|
| 113 |
+
"MOE_ALL2ALL_BUFFER":
|
| 114 |
+
lambda: bool(int(os.getenv("MOE_ALL2ALL_BUFFER", '0'))),
|
| 115 |
+
# Some models are optimized by vllm ascend. While in some case, e.g. rlhf
|
| 116 |
+
# training, the optimized model may not be suitable. In this case, set this
|
| 117 |
+
# value to False to disable the optimized model.
|
| 118 |
+
"USE_OPTIMIZED_MODEL":
|
| 119 |
+
lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
|
| 120 |
+
# SELECT_GATING_TOPK_SOTFMAX_EXPERTS is the equivalent of select_experts in non-quantized scenarios.
|
| 121 |
+
# In theory, it should have better performance than select_experts.
|
| 122 |
+
# Subsequent versions will remove the SELECT_GATING_TOPK_SOTFMAX_EXPERTS tag and use it as the default mode.
|
| 123 |
+
"SELECT_GATING_TOPK_SOTFMAX_EXPERTS":
|
| 124 |
+
lambda: bool(int(os.getenv("SELECT_GATING_TOPK_SOTFMAX_EXPERTS", '0'))),
|
| 125 |
+
# The tolerance of the kv cache size, if the difference between the
|
| 126 |
+
# actual kv cache size and the cached kv cache size is less than this value,
|
| 127 |
+
# then the cached kv cache size will be used.
|
| 128 |
+
"VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE":
|
| 129 |
+
lambda: int(
|
| 130 |
+
os.getenv("VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", 64)),
|
| 131 |
+
# Whether to enable the topk optimization. It's disabled by default for experimental support
|
| 132 |
+
# We'll make it enabled by default in the future.
|
| 133 |
+
"VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION":
|
| 134 |
+
lambda: bool(
|
| 135 |
+
int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '0'))),
|
| 136 |
+
# Whether to enable top n sigma sampling
|
| 137 |
+
"VLLM_ASCEND_ENABLE_TOP_N_SIGMA":
|
| 138 |
+
lambda: bool(
|
| 139 |
+
int(os.getenv("VLLM_ASCEND_ENABLE_TOP_N_SIGMA", '0'))),
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
# end-env-vars-definition
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def __getattr__(name: str):
|
| 146 |
+
# lazy evaluation of environment variables
|
| 147 |
+
if name in env_variables:
|
| 148 |
+
return env_variables[name]()
|
| 149 |
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def __dir__():
|
| 153 |
+
return list(env_variables.keys())
|
inference/vllm_ascend/models/__init__.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from vllm import ModelRegistry
|
| 2 |
+
|
| 3 |
+
import vllm_ascend.envs as envs
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def register_model():
|
| 7 |
+
from .deepseek_dbo import CustomDeepseekDBOForCausalLM # noqa: F401
|
| 8 |
+
from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401
|
| 9 |
+
from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401
|
| 10 |
+
from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401
|
| 11 |
+
from .open_pangu import PanguUltraMoEForCausalLM # noqa: F401
|
| 12 |
+
from .open_pangu import PanguEmbeddedForCausalLM # noqa: F401
|
| 13 |
+
from .qwen2_5_vl import \
|
| 14 |
+
AscendQwen2_5_VLForConditionalGeneration # noqa: F401
|
| 15 |
+
from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401
|
| 16 |
+
|
| 17 |
+
ModelRegistry.register_model(
|
| 18 |
+
"DeepSeekMTPModel",
|
| 19 |
+
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
|
| 20 |
+
|
| 21 |
+
ModelRegistry.register_model(
|
| 22 |
+
"Qwen2VLForConditionalGeneration",
|
| 23 |
+
"vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")
|
| 24 |
+
|
| 25 |
+
if envs.USE_OPTIMIZED_MODEL:
|
| 26 |
+
ModelRegistry.register_model(
|
| 27 |
+
"Qwen2_5_VLForConditionalGeneration",
|
| 28 |
+
"vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration"
|
| 29 |
+
)
|
| 30 |
+
else:
|
| 31 |
+
ModelRegistry.register_model(
|
| 32 |
+
"Qwen2_5_VLForConditionalGeneration",
|
| 33 |
+
"vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding"
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
if envs.VLLM_ASCEND_ENABLE_DBO:
|
| 37 |
+
ModelRegistry.register_model(
|
| 38 |
+
"DeepseekV2ForCausalLM",
|
| 39 |
+
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
|
| 40 |
+
|
| 41 |
+
ModelRegistry.register_model(
|
| 42 |
+
"DeepseekV3ForCausalLM",
|
| 43 |
+
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
|
| 44 |
+
|
| 45 |
+
else:
|
| 46 |
+
ModelRegistry.register_model(
|
| 47 |
+
"DeepseekV2ForCausalLM",
|
| 48 |
+
"vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
|
| 49 |
+
|
| 50 |
+
ModelRegistry.register_model(
|
| 51 |
+
"DeepseekV3ForCausalLM",
|
| 52 |
+
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
|
| 53 |
+
|
| 54 |
+
ModelRegistry.register_model(
|
| 55 |
+
"Qwen3MoeForCausalLM",
|
| 56 |
+
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
|
| 57 |
+
|
| 58 |
+
ModelRegistry.register_model(
|
| 59 |
+
"PanguProMoEForCausalLM",
|
| 60 |
+
"vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")
|
| 61 |
+
|
| 62 |
+
ModelRegistry.register_model(
|
| 63 |
+
"PanguUltraMoEForCausalLM",
|
| 64 |
+
"vllm_ascend.models.open_pangu:PanguUltraMoEForCausalLM")
|
| 65 |
+
|
| 66 |
+
ModelRegistry.register_model(
|
| 67 |
+
"PanguEmbeddedForCausalLM",
|
| 68 |
+
"vllm_ascend.models.open_pangu:PanguEmbeddedForCausalLM")
|
inference/vllm_ascend/models/open_pangu.py
ADDED
|
@@ -0,0 +1,1127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# Copyright 2023 The vLLM team.
|
| 4 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 7 |
+
# and OPT implementations in this library. It has been modified from its
|
| 8 |
+
# original forms to accommodate minor architectural differences compared
|
| 9 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 10 |
+
#
|
| 11 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 12 |
+
# you may not use this file except in compliance with the License.
|
| 13 |
+
# You may obtain a copy of the License at
|
| 14 |
+
#
|
| 15 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 16 |
+
#
|
| 17 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 18 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 19 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 20 |
+
# See the License for the specific language governing permissions and
|
| 21 |
+
# limitations under the License.
|
| 22 |
+
|
| 23 |
+
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
| 24 |
+
import torch
|
| 25 |
+
import torch_npu
|
| 26 |
+
import vllm.envs as envs
|
| 27 |
+
from torch import nn
|
| 28 |
+
from transformers import PretrainedConfig
|
| 29 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 30 |
+
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
| 31 |
+
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
| 32 |
+
from vllm.distributed import (get_tensor_model_parallel_rank,
|
| 33 |
+
get_tensor_model_parallel_world_size,
|
| 34 |
+
get_tp_group, split_tensor_along_last_dim,
|
| 35 |
+
tensor_model_parallel_all_gather,
|
| 36 |
+
tensor_model_parallel_all_reduce,
|
| 37 |
+
tensor_model_parallel_reduce_scatter)
|
| 38 |
+
from vllm.distributed.parallel_state import get_dp_group
|
| 39 |
+
from vllm.forward_context import get_forward_context
|
| 40 |
+
from vllm.model_executor.layers.activation import SiluAndMul
|
| 41 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 42 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 43 |
+
MergedColumnParallelLinear,
|
| 44 |
+
ReplicatedLinear,
|
| 45 |
+
RowParallelLinear,
|
| 46 |
+
UnquantizedLinearMethod,
|
| 47 |
+
QKVParallelLinear)
|
| 48 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 49 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 50 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope, _rotate_gptj
|
| 51 |
+
from vllm.model_executor.layers.sampler import get_sampler
|
| 52 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 53 |
+
ParallelLMHead, VocabParallelEmbedding)
|
| 54 |
+
from vllm.model_executor.model_loader.weight_utils import (
|
| 55 |
+
default_weight_loader, maybe_remap_kv_scale_name)
|
| 56 |
+
from vllm.model_executor.models.utils import (
|
| 57 |
+
make_layers, maybe_prefix, extract_layer_index)
|
| 58 |
+
from vllm_ascend.ascend_config import get_ascend_config
|
| 59 |
+
from vllm_ascend.distributed.parallel_state import get_ep_group
|
| 60 |
+
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
| 61 |
+
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
| 62 |
+
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
| 63 |
+
from vllm_ascend.utils import dispose_tensor, npu_prefetch, get_fused_moe_state, FusedMoEState
|
| 64 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class OpenPanguMergedReplicatedLinear(ReplicatedLinear):
|
| 68 |
+
|
| 69 |
+
def __init__(
|
| 70 |
+
self,
|
| 71 |
+
input_size: int,
|
| 72 |
+
output_sizes: list[int],
|
| 73 |
+
bias: bool = True,
|
| 74 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 75 |
+
prefix: str = "",
|
| 76 |
+
):
|
| 77 |
+
self.output_sizes = output_sizes
|
| 78 |
+
super().__init__(input_size,
|
| 79 |
+
sum(output_sizes),
|
| 80 |
+
bias=bias,
|
| 81 |
+
quant_config=quant_config,
|
| 82 |
+
prefix=prefix)
|
| 83 |
+
|
| 84 |
+
def weight_loader(self, param: torch.nn.Parameter,
|
| 85 |
+
loaded_weight: torch.Tensor, loaded_shard_id: int):
|
| 86 |
+
# With no support for GGUF format yet.
|
| 87 |
+
if getattr(param, "is_gguf_weight", False) or getattr(param, "is_gguf_weight_type", False):
|
| 88 |
+
raise ValueError('With no support for GGUF format yet.')
|
| 89 |
+
if loaded_shard_id >= len(self.output_sizes):
|
| 90 |
+
raise ValueError(f'loaded_shard_id {loaded_shard_id} >= len(self.output_sizes) {len(self.output_sizes)}.')
|
| 91 |
+
shard_offset = sum(self.output_sizes[:loaded_shard_id])
|
| 92 |
+
shard_size = self.output_sizes[loaded_shard_id]
|
| 93 |
+
shard = param.data.narrow(param.output_dim, shard_offset, shard_size)
|
| 94 |
+
if shard.size() != loaded_weight.size():
|
| 95 |
+
raise ValueError(f"Tried to load weights of size {loaded_weight.size()} "
|
| 96 |
+
f"to a parameter shard of id {loaded_shard_id} size {shard.size()}.")
|
| 97 |
+
shard.copy_(loaded_weight)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class OpenPanguRowParallelLinearReplaceAllreduce(RowParallelLinear):
|
| 101 |
+
|
| 102 |
+
def forward(
|
| 103 |
+
self,
|
| 104 |
+
input_,
|
| 105 |
+
is_prefill=True
|
| 106 |
+
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
|
| 107 |
+
if self.input_is_parallel:
|
| 108 |
+
input_parallel = input_
|
| 109 |
+
else:
|
| 110 |
+
tp_rank = get_tensor_model_parallel_rank()
|
| 111 |
+
splitted_input = split_tensor_along_last_dim(
|
| 112 |
+
input_, num_partitions=self.tp_size)
|
| 113 |
+
input_parallel = splitted_input[tp_rank].contiguous()
|
| 114 |
+
|
| 115 |
+
# Matrix multiply.
|
| 116 |
+
if self.quant_method is None:
|
| 117 |
+
raise ValueError('self.quant_method is None.')
|
| 118 |
+
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
| 119 |
+
# bias will not get added more than once in TP>1 case)
|
| 120 |
+
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
| 121 |
+
output_parallel = self.quant_method.apply(self,
|
| 122 |
+
input_parallel,
|
| 123 |
+
bias=bias_)
|
| 124 |
+
if self.reduce_results and self.tp_size > 1:
|
| 125 |
+
if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
|
| 126 |
+
output = tensor_model_parallel_reduce_scatter(output_parallel,
|
| 127 |
+
dim=0)
|
| 128 |
+
else:
|
| 129 |
+
output = tensor_model_parallel_all_reduce(output_parallel)
|
| 130 |
+
else:
|
| 131 |
+
output = output_parallel
|
| 132 |
+
|
| 133 |
+
output_bias = self.bias if self.skip_bias_add else None
|
| 134 |
+
|
| 135 |
+
if not self.return_bias:
|
| 136 |
+
return output
|
| 137 |
+
return output, output_bias
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class OpenPanguRowParallelLinear(RowParallelLinear):
|
| 141 |
+
|
| 142 |
+
def forward(
|
| 143 |
+
self,
|
| 144 |
+
input_,
|
| 145 |
+
is_prefill=True
|
| 146 |
+
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
|
| 147 |
+
return super().forward(input_)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class OpenPanguRotaryEmbedding(nn.Module):
|
| 151 |
+
def __init__(self,
|
| 152 |
+
head_size: int,
|
| 153 |
+
rotary_dim: int,
|
| 154 |
+
max_position_embeddings: int,
|
| 155 |
+
base: float,
|
| 156 |
+
):
|
| 157 |
+
super().__init__()
|
| 158 |
+
self.dim = rotary_dim
|
| 159 |
+
self.max_position_embeddings = max_position_embeddings
|
| 160 |
+
self.base = base
|
| 161 |
+
self._set_cos_sin_cache(
|
| 162 |
+
seq_len=max_position_embeddings,
|
| 163 |
+
device='npu',
|
| 164 |
+
dtype=torch.get_default_dtype(),
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
def _set_cos_sin_cache(self,
|
| 168 |
+
seq_len: int,
|
| 169 |
+
device: str,
|
| 170 |
+
dtype: torch.dtype
|
| 171 |
+
):
|
| 172 |
+
self.max_seq_len = seq_len
|
| 173 |
+
inv_freq = 1.0 / (
|
| 174 |
+
self.base
|
| 175 |
+
** (torch.arange(0, self.dim, 2, dtype=torch.float32, device='npu') / self.dim)
|
| 176 |
+
)
|
| 177 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 178 |
+
t = torch.arange(seq_len, device='npu', dtype=torch.float32)
|
| 179 |
+
freqs = torch.outer(t, inv_freq)
|
| 180 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 181 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 182 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 183 |
+
|
| 184 |
+
def forward(self,
|
| 185 |
+
positions: torch.Tensor,
|
| 186 |
+
query: torch.Tensor,
|
| 187 |
+
key: torch.Tensor,
|
| 188 |
+
offsets: Optional[torch.Tensor] = None,
|
| 189 |
+
max_seq_len: Optional[int] = None,
|
| 190 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 191 |
+
if max_seq_len is not None and max_seq_len > self.max_seq_len:
|
| 192 |
+
self._set_cos_sin_cache(max_seq_len, query.device, query.dtype)
|
| 193 |
+
idx = torch.add(positions, offsets) if offsets is not None else positions
|
| 194 |
+
cos = self.cos_cached[idx]
|
| 195 |
+
sin = self.sin_cached[idx]
|
| 196 |
+
# Adapt: adapt cos and sin shape
|
| 197 |
+
cos = cos.view(-1, 1, cos.shape[-1])
|
| 198 |
+
sin = sin.view(-1, 1, sin.shape[-1])
|
| 199 |
+
# Adapt end.
|
| 200 |
+
query_rot = query * cos + _rotate_gptj(query) * sin
|
| 201 |
+
if key is not None:
|
| 202 |
+
key_rot = key * cos + _rotate_gptj(key) * sin
|
| 203 |
+
return query_rot, key_rot
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class OpenPanguSiluAndMul(SiluAndMul):
|
| 207 |
+
|
| 208 |
+
def __init__(self,
|
| 209 |
+
*,
|
| 210 |
+
weight_scale: Optional[Callable[[], torch.Tensor]] = None):
|
| 211 |
+
super().__init__()
|
| 212 |
+
self.weight_scale = weight_scale
|
| 213 |
+
|
| 214 |
+
def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
|
| 215 |
+
torch.Tensor]]):
|
| 216 |
+
if isinstance(x, tuple):
|
| 217 |
+
if self.weight_scale is None:
|
| 218 |
+
raise ValueError('self.weight_scale is None.')
|
| 219 |
+
quantized_x, dynamic_scale = x
|
| 220 |
+
return torch_npu.npu_dequant_swiglu_quant(
|
| 221 |
+
x=quantized_x,
|
| 222 |
+
weight_scale=self.weight_scale(),
|
| 223 |
+
activation_scale=dynamic_scale,
|
| 224 |
+
activate_left=True,
|
| 225 |
+
quant_mode=1)
|
| 226 |
+
else:
|
| 227 |
+
return super().forward_oot(x)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def check_ffn_act_fn(act_fn: str):
|
| 231 |
+
if act_fn != "silu":
|
| 232 |
+
raise ValueError(
|
| 233 |
+
f"Unsupported activation: {act_fn}. Only silu is supported for now.")
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class OpenPanguMLP(nn.Module):
|
| 237 |
+
|
| 238 |
+
def __init__(
|
| 239 |
+
self,
|
| 240 |
+
hidden_size: int,
|
| 241 |
+
intermediate_size: int,
|
| 242 |
+
hidden_act: str,
|
| 243 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 244 |
+
bias: bool = False,
|
| 245 |
+
reduce_results: bool = True,
|
| 246 |
+
force_replicate: bool = False,
|
| 247 |
+
prefix: str = "",
|
| 248 |
+
) -> None:
|
| 249 |
+
super().__init__()
|
| 250 |
+
if not force_replicate:
|
| 251 |
+
self.gate_up_proj = MergedColumnParallelLinear(
|
| 252 |
+
hidden_size, [intermediate_size] * 2,
|
| 253 |
+
bias=bias,
|
| 254 |
+
quant_config=quant_config,
|
| 255 |
+
prefix=f"{prefix}.gate_up_proj")
|
| 256 |
+
self.down_proj = RowParallelLinear(intermediate_size,
|
| 257 |
+
hidden_size,
|
| 258 |
+
bias=bias,
|
| 259 |
+
quant_config=quant_config,
|
| 260 |
+
reduce_results=reduce_results,
|
| 261 |
+
prefix=f"{prefix}.down_proj")
|
| 262 |
+
else:
|
| 263 |
+
self.gate_up_proj = OpenPanguMergedReplicatedLinear(
|
| 264 |
+
hidden_size, [intermediate_size] * 2,
|
| 265 |
+
bias=bias,
|
| 266 |
+
quant_config=quant_config,
|
| 267 |
+
prefix=f"{prefix}.gate_up_proj")
|
| 268 |
+
self.down_proj = ReplicatedLinear(intermediate_size,
|
| 269 |
+
hidden_size,
|
| 270 |
+
bias=bias,
|
| 271 |
+
quant_config=quant_config,
|
| 272 |
+
prefix=f"{prefix}.down_proj")
|
| 273 |
+
|
| 274 |
+
check_ffn_act_fn(hidden_act)
|
| 275 |
+
|
| 276 |
+
quant_method = self.gate_up_proj.quant_method
|
| 277 |
+
if isinstance(quant_method, UnquantizedLinearMethod):
|
| 278 |
+
self.act_fn = OpenPanguSiluAndMul()
|
| 279 |
+
elif (isinstance(quant_method, AscendLinearMethod) and isinstance(
|
| 280 |
+
quant_method.quant_method, AscendW8A8DynamicLinearMethod)):
|
| 281 |
+
# TODO(sdmyzlp): Currently preserved as before:
|
| 282 |
+
# 1. The only quantization supported for silu is W8A8Dynamic
|
| 283 |
+
# 2. Output dtype of gate_up/down is fixed to be int32/bfloat16
|
| 284 |
+
#
|
| 285 |
+
# Maybe one can implement a better and more general configuration
|
| 286 |
+
# scheme, e.g. by somehow passing around the tweaked `quant_config`
|
| 287 |
+
self.act_fn = OpenPanguSiluAndMul(
|
| 288 |
+
# Use lazy binding, for `weight_scale_fp32` is accessible
|
| 289 |
+
# only after `process_weights_after_loading`.
|
| 290 |
+
weight_scale=lambda: self.gate_up_proj.weight_scale_fp32)
|
| 291 |
+
# To be consumed by AscendW8A8DynamicLinearMethod.apply()
|
| 292 |
+
self.gate_up_proj._ascend_quant_config = {
|
| 293 |
+
"output_dtype": torch.int32,
|
| 294 |
+
"pertoken_scale": False,
|
| 295 |
+
"return_scale": True,
|
| 296 |
+
}
|
| 297 |
+
self.down_proj._ascend_quant_config = {
|
| 298 |
+
"output_dtype": torch.bfloat16,
|
| 299 |
+
"pertoken_scale": True,
|
| 300 |
+
"return_scale": False,
|
| 301 |
+
}
|
| 302 |
+
else:
|
| 303 |
+
raise NotImplementedError(
|
| 304 |
+
f"Quantization with [{type(quant_method)}] is NOT supported")
|
| 305 |
+
|
| 306 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 307 |
+
return self.down_proj(self.act_fn(self.gate_up_proj(x)[0]))[0]
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class OpenPanguMoE(nn.Module):
|
| 311 |
+
|
| 312 |
+
top_k: int
|
| 313 |
+
|
| 314 |
+
def __init__(
|
| 315 |
+
self,
|
| 316 |
+
config: PretrainedConfig,
|
| 317 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 318 |
+
prefix: str = "",
|
| 319 |
+
):
|
| 320 |
+
super().__init__()
|
| 321 |
+
ascend_config = get_ascend_config()
|
| 322 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 323 |
+
self.enable_multistream_moe = \
|
| 324 |
+
ascend_config.torchair_graph_config.enable_multistream_moe
|
| 325 |
+
self.routed_scaling_factor = config.routed_scaling_factor
|
| 326 |
+
check_ffn_act_fn(config.hidden_act)
|
| 327 |
+
|
| 328 |
+
self.gate = ReplicatedLinear(config.hidden_size,
|
| 329 |
+
config.num_routed_experts,
|
| 330 |
+
bias=False,
|
| 331 |
+
quant_config=None,
|
| 332 |
+
prefix=f"{prefix}.gate")
|
| 333 |
+
|
| 334 |
+
self.experts = AscendFusedMoE(
|
| 335 |
+
num_experts=config.num_routed_experts,
|
| 336 |
+
top_k=config.num_experts_per_tok,
|
| 337 |
+
hidden_size=config.hidden_size,
|
| 338 |
+
intermediate_size=config.moe_intermediate_size,
|
| 339 |
+
reduce_results=False,
|
| 340 |
+
renormalize=config.norm_topk_prob,
|
| 341 |
+
quant_config=quant_config,
|
| 342 |
+
use_grouped_topk=True,
|
| 343 |
+
num_expert_group=1,
|
| 344 |
+
topk_group=1,
|
| 345 |
+
prefix=f"{prefix}.experts",
|
| 346 |
+
scoring_func='sigmoid',
|
| 347 |
+
e_score_correction_bias=None)
|
| 348 |
+
|
| 349 |
+
if config.num_shared_experts is not None:
|
| 350 |
+
self.all_reduce_merge = self.experts.all_reduce_merge
|
| 351 |
+
reduce_results = not self.all_reduce_merge
|
| 352 |
+
intermediate_size = (config.moe_intermediate_size * config.num_shared_experts)
|
| 353 |
+
self.shared_experts = OpenPanguMLP(
|
| 354 |
+
hidden_size=config.hidden_size,
|
| 355 |
+
intermediate_size=intermediate_size,
|
| 356 |
+
hidden_act=config.hidden_act,
|
| 357 |
+
quant_config=quant_config,
|
| 358 |
+
reduce_results=reduce_results,
|
| 359 |
+
force_replicate=self.enable_multistream_moe,
|
| 360 |
+
prefix=f"{prefix}.shared_experts",
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
self.shared_experts = None # type: ignore
|
| 364 |
+
|
| 365 |
+
self.tp_size = get_tensor_model_parallel_world_size()
|
| 366 |
+
self.dp_size = get_dp_group().world_size
|
| 367 |
+
self.tp_group = get_tp_group().device_group
|
| 368 |
+
self.tp_rank = get_tp_group().rank_in_group
|
| 369 |
+
self.ep_group = get_ep_group()
|
| 370 |
+
|
| 371 |
+
self.params_dtype = torch.get_default_dtype()
|
| 372 |
+
self.rm_router_logits = self.experts.rm_router_logits
|
| 373 |
+
|
| 374 |
+
self.__class__.top_k = config.num_experts_per_tok
|
| 375 |
+
|
| 376 |
+
def forward(self,
|
| 377 |
+
hidden_states: torch.Tensor,
|
| 378 |
+
attn_metadata: Optional[AttentionMetadata] = None,
|
| 379 |
+
replace_allreduce: bool = False) -> torch.Tensor:
|
| 380 |
+
|
| 381 |
+
if attn_metadata is None:
|
| 382 |
+
attn_metadata = get_forward_context().attn_metadata
|
| 383 |
+
# when profile runs, force experts to load balanced tokens
|
| 384 |
+
# to avoid high memory consumption on a single rank.
|
| 385 |
+
# TODO: need a better flag to indicate whether in profile run or not.
|
| 386 |
+
if attn_metadata is None:
|
| 387 |
+
# for profile run
|
| 388 |
+
is_prefill = True
|
| 389 |
+
fused_moe_state = get_fused_moe_state(self.ep_group.world_size, is_prefill, True)
|
| 390 |
+
enable_force_load_balance = fused_moe_state != FusedMoEState.AllGatherEP
|
| 391 |
+
else:
|
| 392 |
+
is_prefill = attn_metadata.num_prefills > 0
|
| 393 |
+
enable_force_load_balance = False
|
| 394 |
+
if hasattr(attn_metadata, 'with_prefill_across_dp'):
|
| 395 |
+
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
|
| 396 |
+
fused_moe_state = get_fused_moe_state(self.ep_group.world_size, is_prefill, True)
|
| 397 |
+
|
| 398 |
+
# router_logits: (num_tokens, n_experts)
|
| 399 |
+
router_logits = None
|
| 400 |
+
if not self.rm_router_logits or fused_moe_state == FusedMoEState.All2All:
|
| 401 |
+
router_logits, _ = self.gate(hidden_states.float())
|
| 402 |
+
|
| 403 |
+
routed_hidden_states, shared_hidden_states = self.experts(
|
| 404 |
+
hidden_states=hidden_states,
|
| 405 |
+
router_logits=router_logits,
|
| 406 |
+
is_prefill=is_prefill,
|
| 407 |
+
top_k=self.__class__.top_k,
|
| 408 |
+
enable_force_load_balance=enable_force_load_balance,
|
| 409 |
+
shared_experts=self.shared_experts,
|
| 410 |
+
gate=self.gate,
|
| 411 |
+
replace_allreduce=replace_allreduce)
|
| 412 |
+
|
| 413 |
+
if self.all_reduce_merge and fused_moe_state == FusedMoEState.All2All:
|
| 414 |
+
shared_hidden_states = tensor_model_parallel_all_reduce(shared_hidden_states)
|
| 415 |
+
hidden_states = routed_hidden_states * self.routed_scaling_factor + shared_hidden_states
|
| 416 |
+
if self.all_reduce_merge and fused_moe_state != FusedMoEState.All2All:
|
| 417 |
+
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
| 418 |
+
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
| 419 |
+
|
| 420 |
+
return hidden_states
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
class OpenPanguMLAAttention(nn.Module):
|
| 424 |
+
|
| 425 |
+
def __init__(
|
| 426 |
+
self,
|
| 427 |
+
config: PretrainedConfig,
|
| 428 |
+
hidden_size: int,
|
| 429 |
+
num_heads: int,
|
| 430 |
+
attention_qk_dim: int,
|
| 431 |
+
attention_qk_rope_dim: int,
|
| 432 |
+
attention_v_dim: int,
|
| 433 |
+
attention_q_lora_dim: Optional[int],
|
| 434 |
+
attention_kv_lora_dim: int,
|
| 435 |
+
rope_theta: float = 10000,
|
| 436 |
+
max_position_embeddings: int = 8192,
|
| 437 |
+
cache_config: Optional[CacheConfig] = None,
|
| 438 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 439 |
+
prefix: str = "",
|
| 440 |
+
) -> None:
|
| 441 |
+
super().__init__()
|
| 442 |
+
ascend_config = get_ascend_config()
|
| 443 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 444 |
+
self.enable_multistream_mla = ascend_config.torchair_graph_config.enable_multistream_mla
|
| 445 |
+
|
| 446 |
+
self.hidden_size = hidden_size
|
| 447 |
+
self.num_heads = num_heads
|
| 448 |
+
self.attention_qk_dim = attention_qk_dim
|
| 449 |
+
self.attention_qk_rope_dim = attention_qk_rope_dim
|
| 450 |
+
self.qk_head_dim = attention_qk_dim + attention_qk_rope_dim
|
| 451 |
+
self.attention_v_dim = attention_v_dim
|
| 452 |
+
self.attention_q_lora_dim = attention_q_lora_dim
|
| 453 |
+
self.attention_kv_lora_dim = attention_kv_lora_dim
|
| 454 |
+
self.rope_theta = rope_theta
|
| 455 |
+
|
| 456 |
+
tp_size = get_tensor_model_parallel_world_size()
|
| 457 |
+
if num_heads % tp_size != 0:
|
| 458 |
+
raise ValueError(f'num_heads {num_heads} is not divisible by tp_size {tp_size}.')
|
| 459 |
+
self.num_local_heads = num_heads // tp_size
|
| 460 |
+
|
| 461 |
+
self.scaling = self.qk_head_dim**-0.5
|
| 462 |
+
self.max_position_embeddings = max_position_embeddings
|
| 463 |
+
|
| 464 |
+
self.prefix = prefix
|
| 465 |
+
self.debug_layer_idx = int(self.prefix.split(".")[-2])
|
| 466 |
+
|
| 467 |
+
if self.attention_q_lora_dim is not None:
|
| 468 |
+
self.q_a_proj = ReplicatedLinear(self.hidden_size,
|
| 469 |
+
self.attention_q_lora_dim,
|
| 470 |
+
bias=False,
|
| 471 |
+
quant_config=quant_config,
|
| 472 |
+
prefix=f"{prefix}.q_a_proj")
|
| 473 |
+
self.q_a_layernorm = RMSNorm(self.attention_q_lora_dim, eps=config.rms_norm_eps)
|
| 474 |
+
self.q_b_proj = ColumnParallelLinear(attention_q_lora_dim,
|
| 475 |
+
self.num_heads * self.qk_head_dim,
|
| 476 |
+
bias=False,
|
| 477 |
+
quant_config=quant_config,
|
| 478 |
+
prefix=f"{prefix}.q_b_proj")
|
| 479 |
+
else:
|
| 480 |
+
self.q_proj = ColumnParallelLinear(self.hidden_size,
|
| 481 |
+
self.num_heads * self.qk_head_dim,
|
| 482 |
+
bias=False,
|
| 483 |
+
quant_config=quant_config,
|
| 484 |
+
prefix=f"{prefix}.q_proj")
|
| 485 |
+
|
| 486 |
+
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
| 487 |
+
self.hidden_size,
|
| 488 |
+
self.attention_kv_lora_dim + self.attention_qk_rope_dim,
|
| 489 |
+
bias=False,
|
| 490 |
+
quant_config=quant_config,
|
| 491 |
+
prefix=f"{prefix}.kv_a_proj_with_mqa")
|
| 492 |
+
self.kv_a_layernorm = RMSNorm(self.attention_kv_lora_dim,
|
| 493 |
+
eps=config.rms_norm_eps)
|
| 494 |
+
self.kv_b_proj = ColumnParallelLinear(
|
| 495 |
+
self.attention_kv_lora_dim,
|
| 496 |
+
self.num_heads * (self.attention_qk_dim + self.attention_v_dim),
|
| 497 |
+
bias=False,
|
| 498 |
+
quant_config=quant_config,
|
| 499 |
+
prefix=f"{prefix}.kv_b_proj")
|
| 500 |
+
if (config.num_routed_experts is not None
|
| 501 |
+
and self.debug_layer_idx >= config.num_dense_layers and
|
| 502 |
+
ascend_config.torchair_graph_config.enable_multistream_moe):
|
| 503 |
+
self.o_proj = OpenPanguRowParallelLinearReplaceAllreduce(
|
| 504 |
+
self.num_heads * self.attention_v_dim,
|
| 505 |
+
self.hidden_size,
|
| 506 |
+
bias=False,
|
| 507 |
+
quant_config=quant_config,
|
| 508 |
+
prefix=f"{prefix}.o_proj")
|
| 509 |
+
else:
|
| 510 |
+
self.o_proj = OpenPanguRowParallelLinear(
|
| 511 |
+
self.num_heads * self.attention_v_dim,
|
| 512 |
+
self.hidden_size,
|
| 513 |
+
bias=False,
|
| 514 |
+
quant_config=quant_config,
|
| 515 |
+
prefix=f"{prefix}.o_proj")
|
| 516 |
+
|
| 517 |
+
self.rotary_emb = OpenPanguRotaryEmbedding(attention_qk_rope_dim,
|
| 518 |
+
rotary_dim=attention_qk_rope_dim,
|
| 519 |
+
max_position_embeddings=max_position_embeddings,
|
| 520 |
+
base=rope_theta)
|
| 521 |
+
|
| 522 |
+
self.mla_attn = Attention(
|
| 523 |
+
num_heads=self.num_local_heads,
|
| 524 |
+
head_size=self.attention_kv_lora_dim + self.attention_qk_rope_dim,
|
| 525 |
+
scale=self.scaling,
|
| 526 |
+
num_kv_heads=1,
|
| 527 |
+
cache_config=cache_config,
|
| 528 |
+
quant_config=quant_config,
|
| 529 |
+
prefix=f"{prefix}.attn",
|
| 530 |
+
use_mla=True,
|
| 531 |
+
# MLA Args
|
| 532 |
+
q_lora_rank=self.attention_q_lora_dim,
|
| 533 |
+
kv_lora_rank=self.attention_kv_lora_dim,
|
| 534 |
+
qk_nope_head_dim=self.attention_qk_dim,
|
| 535 |
+
qk_rope_head_dim=self.attention_qk_rope_dim,
|
| 536 |
+
qk_head_dim=self.qk_head_dim,
|
| 537 |
+
v_head_dim=self.attention_v_dim,
|
| 538 |
+
rotary_emb=self.rotary_emb,
|
| 539 |
+
q_proj=self.q_proj if self.attention_q_lora_dim is None else self.q_b_proj,
|
| 540 |
+
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
|
| 541 |
+
kv_a_layernorm=self.kv_a_layernorm,
|
| 542 |
+
kv_b_proj=self.kv_b_proj,
|
| 543 |
+
o_proj=self.o_proj,
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
+
def forward(
|
| 547 |
+
self,
|
| 548 |
+
positions: torch.Tensor,
|
| 549 |
+
hidden_states: torch.Tensor,
|
| 550 |
+
kv_cache: Optional[torch.Tensor] = None,
|
| 551 |
+
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
| 552 |
+
enable_multistream_mla = (self.enable_multistream_mla
|
| 553 |
+
and attn_metadata is not None
|
| 554 |
+
and not attn_metadata.with_prefill_across_dp
|
| 555 |
+
and attn_metadata.num_decodes > 0)
|
| 556 |
+
forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
|
| 557 |
+
if self.attention_q_lora_dim is not None:
|
| 558 |
+
npu_prefetch(self.q_a_proj.weight,
|
| 559 |
+
hidden_states,
|
| 560 |
+
enabled=enable_multistream_mla)
|
| 561 |
+
ckq = self.q_a_proj(hidden_states)[0]
|
| 562 |
+
hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
| 563 |
+
forward_kwargs['ckq'] = ckq
|
| 564 |
+
else:
|
| 565 |
+
hidden_states_or_q_c = hidden_states
|
| 566 |
+
if self.torchair_graph_enabled:
|
| 567 |
+
if envs.VLLM_USE_V1:
|
| 568 |
+
output_shape = hidden_states.shape
|
| 569 |
+
output = torch.empty(output_shape,
|
| 570 |
+
dtype=hidden_states_or_q_c.dtype,
|
| 571 |
+
device=hidden_states_or_q_c.device)
|
| 572 |
+
forward_kwargs['output'] = output
|
| 573 |
+
|
| 574 |
+
output = self.mla_attn.impl.forward(self.mla_attn,
|
| 575 |
+
hidden_states_or_q_c,
|
| 576 |
+
hidden_states, None, kv_cache,
|
| 577 |
+
attn_metadata,
|
| 578 |
+
**forward_kwargs)
|
| 579 |
+
if envs.VLLM_USE_V1:
|
| 580 |
+
output = output.view(-1, output_shape[-1])
|
| 581 |
+
return output
|
| 582 |
+
else:
|
| 583 |
+
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
| 584 |
+
[self.attention_kv_lora_dim, self.attention_qk_rope_dim], dim=-1)
|
| 585 |
+
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
| 586 |
+
return self.mla_attn(hidden_states_or_q_c,
|
| 587 |
+
kv_c_normed,
|
| 588 |
+
k_pe,
|
| 589 |
+
output_shape=hidden_states.shape)
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
class OpenPanguEmbeddedAttention(nn.Module):
|
| 593 |
+
|
| 594 |
+
def __init__(
|
| 595 |
+
self,
|
| 596 |
+
config: PretrainedConfig,
|
| 597 |
+
hidden_size: int,
|
| 598 |
+
num_heads: int,
|
| 599 |
+
num_kv_heads: int,
|
| 600 |
+
rope_theta: float = 10000,
|
| 601 |
+
rope_scaling: Optional[dict[str, Any]] = None,
|
| 602 |
+
max_position_embeddings: int = 8192,
|
| 603 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 604 |
+
bias: bool = False,
|
| 605 |
+
bias_o_proj: bool = False,
|
| 606 |
+
cache_config: Optional[CacheConfig] = None,
|
| 607 |
+
prefix: str = "",
|
| 608 |
+
attn_type: str = AttentionType.DECODER,
|
| 609 |
+
) -> None:
|
| 610 |
+
super().__init__()
|
| 611 |
+
layer_idx = extract_layer_index(prefix)
|
| 612 |
+
self.hidden_size = hidden_size
|
| 613 |
+
tp_size = get_tensor_model_parallel_world_size()
|
| 614 |
+
self.total_num_heads = num_heads
|
| 615 |
+
if self.total_num_heads % tp_size != 0:
|
| 616 |
+
raise ValueError(f'total_num_heads {total_num_heads} is not divisible by tp_size {tp_size}.')
|
| 617 |
+
self.num_heads = self.total_num_heads // tp_size
|
| 618 |
+
self.total_num_kv_heads = num_kv_heads
|
| 619 |
+
if self.total_num_kv_heads >= tp_size and self.total_num_kv_heads % tp_size != 0:
|
| 620 |
+
# Number of KV heads is greater than TP size, so we partition
|
| 621 |
+
# the KV heads across multiple tensor parallel NPUs.
|
| 622 |
+
raise ValueError(f'Number of KV heads is less than TP size, but total_num_kv_heads {self.total_num_kv_heads} '
|
| 623 |
+
f'is not divisible by tp_size {tp_size}.')
|
| 624 |
+
elif self.total_num_kv_heads < tp_size and tp_size % self.total_num_kv_heads != 0:
|
| 625 |
+
# Number of KV heads is less than TP size, so we replicate
|
| 626 |
+
# the KV heads across multiple tensor parallel NPUs.
|
| 627 |
+
raise ValueError(f'Number of KV heads is less than TP size, but tp_size {tp_size} '
|
| 628 |
+
f'is not divisible by total_num_kv_heads {self.total_num_kv_heads}.')
|
| 629 |
+
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
| 630 |
+
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
|
| 631 |
+
head_dim = getattr(config, "head_dim", None)
|
| 632 |
+
if head_dim is None:
|
| 633 |
+
head_dim = self.hidden_size // self.total_num_heads
|
| 634 |
+
self.head_dim = head_dim
|
| 635 |
+
# Phi models introduced a partial_rotary_factor parameter in the config
|
| 636 |
+
self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
|
| 637 |
+
self.q_size = self.num_heads * self.head_dim
|
| 638 |
+
self.kv_size = self.num_kv_heads * self.head_dim
|
| 639 |
+
self.scaling = self.head_dim**-0.5
|
| 640 |
+
self.rope_theta = rope_theta
|
| 641 |
+
self.max_position_embeddings = max_position_embeddings
|
| 642 |
+
|
| 643 |
+
self.qkv_proj = QKVParallelLinear(
|
| 644 |
+
hidden_size=hidden_size,
|
| 645 |
+
head_size=self.head_dim,
|
| 646 |
+
total_num_heads=self.total_num_heads,
|
| 647 |
+
total_num_kv_heads=self.total_num_kv_heads,
|
| 648 |
+
bias=bias,
|
| 649 |
+
quant_config=quant_config,
|
| 650 |
+
prefix=f"{prefix}.qkv_proj",
|
| 651 |
+
)
|
| 652 |
+
|
| 653 |
+
self.o_proj = RowParallelLinear(
|
| 654 |
+
input_size=self.total_num_heads * self.head_dim,
|
| 655 |
+
output_size=hidden_size,
|
| 656 |
+
bias=bias_o_proj,
|
| 657 |
+
quant_config=quant_config,
|
| 658 |
+
prefix=f"{prefix}.o_proj",
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
self._init_rotary_emb(config,
|
| 662 |
+
rope_scaling=rope_scaling,
|
| 663 |
+
quant_config=quant_config)
|
| 664 |
+
|
| 665 |
+
if hasattr(config, "interleaved_sliding_window"):
|
| 666 |
+
interleaved_sliding_window = config.interleaved_sliding_window
|
| 667 |
+
if isinstance(interleaved_sliding_window, int):
|
| 668 |
+
sliding_window = interleaved_sliding_window
|
| 669 |
+
elif isinstance(interleaved_sliding_window, list):
|
| 670 |
+
sw_idx = layer_idx % len(interleaved_sliding_window)
|
| 671 |
+
sliding_window = interleaved_sliding_window[sw_idx]
|
| 672 |
+
else:
|
| 673 |
+
raise ValueError(
|
| 674 |
+
f"{type(interleaved_sliding_window)} is not supported.")
|
| 675 |
+
else:
|
| 676 |
+
sliding_window = None
|
| 677 |
+
|
| 678 |
+
self.attn = Attention(
|
| 679 |
+
self.num_heads,
|
| 680 |
+
self.head_dim,
|
| 681 |
+
self.scaling,
|
| 682 |
+
num_kv_heads=self.num_kv_heads,
|
| 683 |
+
cache_config=cache_config,
|
| 684 |
+
quant_config=quant_config,
|
| 685 |
+
per_layer_sliding_window=sliding_window,
|
| 686 |
+
attn_type=attn_type,
|
| 687 |
+
prefix=f"{prefix}.attn",
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
def forward(
|
| 691 |
+
self,
|
| 692 |
+
positions: torch.Tensor,
|
| 693 |
+
hidden_states: torch.Tensor,
|
| 694 |
+
kv_cache: Optional[torch.Tensor] = None,
|
| 695 |
+
attn_metadata: Optional[AttentionMetadata] = None
|
| 696 |
+
) -> torch.Tensor:
|
| 697 |
+
qkv, _ = self.qkv_proj(hidden_states)
|
| 698 |
+
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
| 699 |
+
q, k = self.rotary_emb(positions, q, k)
|
| 700 |
+
attn_output = self.attn(q, k, v)
|
| 701 |
+
output, _ = self.o_proj(attn_output)
|
| 702 |
+
return output
|
| 703 |
+
|
| 704 |
+
def _init_rotary_emb(self, config: PretrainedConfig,
|
| 705 |
+
rope_scaling: Optional[dict[str, Any]],
|
| 706 |
+
quant_config: Optional[QuantizationConfig]) -> None:
|
| 707 |
+
is_neox_style = True
|
| 708 |
+
is_gguf = quant_config and quant_config.get_name() == "gguf"
|
| 709 |
+
if is_gguf and config.model_type == "Pangu":
|
| 710 |
+
is_neox_style = False
|
| 711 |
+
|
| 712 |
+
self.rotary_emb = get_rope(
|
| 713 |
+
self.head_dim,
|
| 714 |
+
rotary_dim=self.head_dim,
|
| 715 |
+
max_position=self.max_position_embeddings,
|
| 716 |
+
base=self.rope_theta,
|
| 717 |
+
rope_scaling=rope_scaling,
|
| 718 |
+
is_neox_style=is_neox_style,
|
| 719 |
+
#partial_rotary_factor=self.partial_rotary_factor,
|
| 720 |
+
)
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
class OpenPanguDecoderLayer(nn.Module):
|
| 724 |
+
|
| 725 |
+
def __init__(
|
| 726 |
+
self,
|
| 727 |
+
config: PretrainedConfig,
|
| 728 |
+
prefix: str,
|
| 729 |
+
model_config: ModelConfig,
|
| 730 |
+
cache_config: Optional[CacheConfig] = None,
|
| 731 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 732 |
+
) -> None:
|
| 733 |
+
super().__init__()
|
| 734 |
+
self.hidden_size = config.hidden_size
|
| 735 |
+
rope_theta = getattr(config, "rope_theta", 10000)
|
| 736 |
+
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
| 737 |
+
|
| 738 |
+
layer_idx = int(prefix.split(sep='.')[-1])
|
| 739 |
+
self.layer_idx = layer_idx
|
| 740 |
+
self.layers = config.num_hidden_layers
|
| 741 |
+
self.tp_size = get_tensor_model_parallel_world_size()
|
| 742 |
+
self.tp_rank = get_tp_group().rank_in_group
|
| 743 |
+
ascend_config = get_ascend_config()
|
| 744 |
+
|
| 745 |
+
self.use_mla = hasattr(config, 'attention_qk_dim') and hasattr(config, 'attention_qk_rope_dim') \
|
| 746 |
+
and hasattr(config, 'attention_v_dim') and hasattr(config, 'attention_kv_lora_dim')
|
| 747 |
+
if self.use_mla:
|
| 748 |
+
self.self_attn = OpenPanguMLAAttention(
|
| 749 |
+
config=config,
|
| 750 |
+
hidden_size=self.hidden_size,
|
| 751 |
+
num_heads=config.num_attention_heads,
|
| 752 |
+
attention_qk_dim=config.attention_qk_dim,
|
| 753 |
+
attention_qk_rope_dim=config.attention_qk_rope_dim,
|
| 754 |
+
attention_v_dim=config.attention_v_dim,
|
| 755 |
+
attention_q_lora_dim=config.attention_q_lora_dim
|
| 756 |
+
if hasattr(config, "attention_q_lora_dim") else None,
|
| 757 |
+
attention_kv_lora_dim=config.attention_kv_lora_dim,
|
| 758 |
+
rope_theta=rope_theta,
|
| 759 |
+
max_position_embeddings=max_position_embeddings,
|
| 760 |
+
cache_config=cache_config,
|
| 761 |
+
quant_config=quant_config,
|
| 762 |
+
prefix=f"{prefix}.self_attn",
|
| 763 |
+
)
|
| 764 |
+
else:
|
| 765 |
+
attention_bias = getattr(config, "attention_bias", False) or getattr(
|
| 766 |
+
config, "bias", False)
|
| 767 |
+
bias_o_proj = attention_bias
|
| 768 |
+
if hasattr(config, 'qkv_bias'):
|
| 769 |
+
attention_bias = config.qkv_bias
|
| 770 |
+
# By default, PanguEmbedded uses causal attention as it is a decoder-only model.
|
| 771 |
+
# You can override the HF config with `is_causal=False` to enable
|
| 772 |
+
# bidirectional attention, which is used in some embedding models
|
| 773 |
+
if getattr(config, "is_causal", True):
|
| 774 |
+
attn_type = AttentionType.DECODER
|
| 775 |
+
else:
|
| 776 |
+
attn_type = AttentionType.ENCODER_ONLY
|
| 777 |
+
self.self_attn = OpenPanguEmbeddedAttention(
|
| 778 |
+
config=config,
|
| 779 |
+
hidden_size=self.hidden_size,
|
| 780 |
+
num_heads=config.num_attention_heads,
|
| 781 |
+
num_kv_heads=getattr(config, "num_key_value_heads", config.num_attention_heads),
|
| 782 |
+
rope_theta=rope_theta,
|
| 783 |
+
rope_scaling=getattr(config, "rope_scaling", None),
|
| 784 |
+
max_position_embeddings=max_position_embeddings,
|
| 785 |
+
quant_config=quant_config,
|
| 786 |
+
bias=attention_bias,
|
| 787 |
+
bias_o_proj=bias_o_proj,
|
| 788 |
+
cache_config=cache_config,
|
| 789 |
+
prefix=f"{prefix}.self_attn",
|
| 790 |
+
attn_type=attn_type,
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
+
if getattr(config, 'num_routed_experts', None) is not None and layer_idx >= config.num_dense_layers:
|
| 794 |
+
self.mlp = OpenPanguMoE(
|
| 795 |
+
config=config,
|
| 796 |
+
quant_config=quant_config,
|
| 797 |
+
prefix=f"{prefix}.mlp",
|
| 798 |
+
)
|
| 799 |
+
self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \
|
| 800 |
+
and model_config.use_mla and envs.VLLM_USE_V1 and self.tp_size > 1
|
| 801 |
+
else:
|
| 802 |
+
self.mlp = OpenPanguMLP(
|
| 803 |
+
hidden_size=self.hidden_size,
|
| 804 |
+
intermediate_size=config.intermediate_size,
|
| 805 |
+
hidden_act=config.hidden_act,
|
| 806 |
+
quant_config=quant_config,
|
| 807 |
+
bias=getattr(config, "mlp_bias", False),
|
| 808 |
+
prefix=f"{prefix}.mlp",
|
| 809 |
+
)
|
| 810 |
+
self.mla_moe_communication = False
|
| 811 |
+
self.routed_scaling_factor = getattr(config, 'routed_scaling_factor', None)
|
| 812 |
+
self.num_dense_layers = getattr(config, 'num_dense_layers', None)
|
| 813 |
+
|
| 814 |
+
self.input_layernorm = RMSNorm(config.hidden_size,
|
| 815 |
+
eps=config.rms_norm_eps)
|
| 816 |
+
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
| 817 |
+
eps=config.rms_norm_eps)
|
| 818 |
+
if getattr(config, 'sandwich_norm', False):
|
| 819 |
+
self.sandwich_norm = True
|
| 820 |
+
self.pre_mlp_layernorm = RMSNorm(config.hidden_size,
|
| 821 |
+
eps=config.rms_norm_eps)
|
| 822 |
+
self.post_mlp_layernorm = RMSNorm(config.hidden_size,
|
| 823 |
+
eps=config.rms_norm_eps)
|
| 824 |
+
else:
|
| 825 |
+
self.sandwich_norm = False
|
| 826 |
+
|
| 827 |
+
def forward(
|
| 828 |
+
self,
|
| 829 |
+
positions: torch.Tensor,
|
| 830 |
+
hidden_states: torch.Tensor,
|
| 831 |
+
residual: Optional[torch.Tensor],
|
| 832 |
+
kv_cache: Optional[torch.Tensor] = None,
|
| 833 |
+
attn_metadata: Optional[AttentionMetadata] = None,
|
| 834 |
+
replace_allreduce: bool = False,
|
| 835 |
+
) -> torch.Tensor:
|
| 836 |
+
# Self Attention
|
| 837 |
+
if self.use_mla and attn_metadata is not None and attn_metadata.num_decodes > 0:
|
| 838 |
+
mla_moe_communication = self.mla_moe_communication and replace_allreduce
|
| 839 |
+
else:
|
| 840 |
+
mla_moe_communication = False
|
| 841 |
+
if residual is None:
|
| 842 |
+
residual = hidden_states
|
| 843 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 844 |
+
else:
|
| 845 |
+
previous_hidden_states, previous_residual = hidden_states, residual
|
| 846 |
+
hidden_states, residual = self.input_layernorm(
|
| 847 |
+
hidden_states, residual)
|
| 848 |
+
# Dispose hidden_states and residual from the previous layer
|
| 849 |
+
# to save npu memory because they're no longer used.
|
| 850 |
+
dispose_tensor(previous_hidden_states)
|
| 851 |
+
dispose_tensor(previous_residual)
|
| 852 |
+
if mla_moe_communication and self.layer_idx > self.num_dense_layers:
|
| 853 |
+
hidden_states = tensor_model_parallel_all_gather(hidden_states,
|
| 854 |
+
dim=0)
|
| 855 |
+
|
| 856 |
+
hidden_states = self.self_attn(
|
| 857 |
+
positions=positions,
|
| 858 |
+
hidden_states=hidden_states,
|
| 859 |
+
kv_cache=kv_cache,
|
| 860 |
+
attn_metadata=attn_metadata,
|
| 861 |
+
)
|
| 862 |
+
|
| 863 |
+
if mla_moe_communication and residual.shape[0] != hidden_states.shape[0]:
|
| 864 |
+
chunk_hidden_states = torch.tensor_split(residual,
|
| 865 |
+
self.tp_size,
|
| 866 |
+
dim=0)
|
| 867 |
+
residual = chunk_hidden_states[self.tp_rank]
|
| 868 |
+
|
| 869 |
+
if self.routed_scaling_factor is not None and hidden_states.dtype == torch.float16:
|
| 870 |
+
# Fix FP16 overflow
|
| 871 |
+
# We scale both hidden_states and residual before
|
| 872 |
+
# rmsnorm, and rmsnorm result would not affect by scale.
|
| 873 |
+
hidden_states *= 1. / self.routed_scaling_factor
|
| 874 |
+
if self.layer_idx == 0:
|
| 875 |
+
# The residual is shared by all layers, we only scale it on
|
| 876 |
+
# first layer.
|
| 877 |
+
residual *= 1. / self.routed_scaling_factor
|
| 878 |
+
|
| 879 |
+
if self.sandwich_norm:
|
| 880 |
+
hidden_states = self.post_attention_layernorm(
|
| 881 |
+
hidden_states)
|
| 882 |
+
hidden_states, residual = self.pre_mlp_layernorm(
|
| 883 |
+
hidden_states, residual)
|
| 884 |
+
else:
|
| 885 |
+
hidden_states, residual = self.post_attention_layernorm(
|
| 886 |
+
hidden_states, residual)
|
| 887 |
+
|
| 888 |
+
# Fully Connected
|
| 889 |
+
if isinstance(self.mlp, OpenPanguMoE):
|
| 890 |
+
hidden_states = self.mlp(hidden_states,
|
| 891 |
+
attn_metadata,
|
| 892 |
+
replace_allreduce=mla_moe_communication)
|
| 893 |
+
else:
|
| 894 |
+
hidden_states = self.mlp(hidden_states)
|
| 895 |
+
|
| 896 |
+
if self.routed_scaling_factor is not None and isinstance(self.mlp, OpenPanguMLP) \
|
| 897 |
+
and hidden_states.dtype == torch.float16:
|
| 898 |
+
hidden_states *= 1. / self.routed_scaling_factor
|
| 899 |
+
|
| 900 |
+
if self.sandwich_norm:
|
| 901 |
+
hidden_states = self.post_mlp_layernorm(hidden_states)
|
| 902 |
+
|
| 903 |
+
if mla_moe_communication and self.layer_idx == self.layers - 1:
|
| 904 |
+
hidden_states = tensor_model_parallel_all_gather(hidden_states,
|
| 905 |
+
dim=0)
|
| 906 |
+
residual = tensor_model_parallel_all_gather(residual, dim=0)
|
| 907 |
+
|
| 908 |
+
return hidden_states, residual
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
@support_torch_compile
|
| 912 |
+
class OpenPanguModel(nn.Module):
|
| 913 |
+
|
| 914 |
+
fall_back_to_pt_during_load = False
|
| 915 |
+
|
| 916 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 917 |
+
super().__init__()
|
| 918 |
+
|
| 919 |
+
config = vllm_config.model_config.hf_config
|
| 920 |
+
model_config = vllm_config.model_config
|
| 921 |
+
cache_config = vllm_config.cache_config
|
| 922 |
+
quant_config = vllm_config.quant_config
|
| 923 |
+
|
| 924 |
+
self.padding_idx = config.pad_token_id
|
| 925 |
+
self.vocab_size = config.vocab_size
|
| 926 |
+
self.tp_size = get_tensor_model_parallel_world_size()
|
| 927 |
+
|
| 928 |
+
self.embed_tokens = VocabParallelEmbedding(
|
| 929 |
+
config.vocab_size,
|
| 930 |
+
config.hidden_size,
|
| 931 |
+
quant_config=quant_config,
|
| 932 |
+
prefix=f"{prefix}.embed_tokens")
|
| 933 |
+
|
| 934 |
+
self.start_layer, self.end_layer, self.layers = make_layers(
|
| 935 |
+
config.num_hidden_layers,
|
| 936 |
+
lambda prefix: OpenPanguDecoderLayer(
|
| 937 |
+
config,
|
| 938 |
+
prefix,
|
| 939 |
+
model_config=model_config,
|
| 940 |
+
cache_config=cache_config,
|
| 941 |
+
quant_config=quant_config,
|
| 942 |
+
),
|
| 943 |
+
prefix=f"{prefix}.layers")
|
| 944 |
+
|
| 945 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 946 |
+
|
| 947 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 948 |
+
return self.embed_tokens(input_ids)
|
| 949 |
+
|
| 950 |
+
def forward(
|
| 951 |
+
self,
|
| 952 |
+
input_ids: torch.Tensor,
|
| 953 |
+
positions: torch.Tensor,
|
| 954 |
+
kv_caches: Optional[List[torch.Tensor]] = None,
|
| 955 |
+
attn_metadata: Optional[AttentionMetadata] = None,
|
| 956 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 957 |
+
**kwargs,
|
| 958 |
+
) -> torch.Tensor:
|
| 959 |
+
if inputs_embeds is not None:
|
| 960 |
+
hidden_states = inputs_embeds
|
| 961 |
+
else:
|
| 962 |
+
hidden_states = self.get_input_embeddings(input_ids)
|
| 963 |
+
residual = None
|
| 964 |
+
|
| 965 |
+
replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
|
| 966 |
+
|
| 967 |
+
for i in range(self.start_layer, self.end_layer):
|
| 968 |
+
layer = self.layers[i]
|
| 969 |
+
hidden_states, residual = layer(
|
| 970 |
+
positions,
|
| 971 |
+
hidden_states,
|
| 972 |
+
residual,
|
| 973 |
+
kv_caches[i -
|
| 974 |
+
self.start_layer] if kv_caches is not None else None,
|
| 975 |
+
attn_metadata,
|
| 976 |
+
replace_allreduce=replace_allreduce)
|
| 977 |
+
|
| 978 |
+
hidden_states, _ = self.norm(hidden_states, residual)
|
| 979 |
+
return hidden_states
|
| 980 |
+
|
| 981 |
+
|
| 982 |
+
class OpenPanguForCausalLM(nn.Module):
|
| 983 |
+
packed_modules_mapping = {
|
| 984 |
+
"gate_up_proj": ["gate_proj", "up_proj"],
|
| 985 |
+
"experts":
|
| 986 |
+
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
| 987 |
+
}
|
| 988 |
+
|
| 989 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 990 |
+
super().__init__()
|
| 991 |
+
config = vllm_config.model_config.hf_config
|
| 992 |
+
quant_config = vllm_config.quant_config
|
| 993 |
+
self.config = config
|
| 994 |
+
self.quant_config = quant_config
|
| 995 |
+
self.model = OpenPanguModel(vllm_config=vllm_config,
|
| 996 |
+
prefix=maybe_prefix(prefix, "model"))
|
| 997 |
+
self.lm_head = ParallelLMHead(config.vocab_size,
|
| 998 |
+
config.hidden_size,
|
| 999 |
+
quant_config=quant_config,
|
| 1000 |
+
prefix=maybe_prefix(prefix, "lm_head"))
|
| 1001 |
+
self.logits_processor = LogitsProcessor(config.vocab_size)
|
| 1002 |
+
self.sampler = get_sampler()
|
| 1003 |
+
|
| 1004 |
+
def load_attn_mlp_weight(self,
|
| 1005 |
+
attn_mlp_replace_mapping: List[Tuple[str, str, int]],
|
| 1006 |
+
params_dict: Dict[str, Any],
|
| 1007 |
+
weight_name: str,
|
| 1008 |
+
loaded_weight: torch.Tensor,
|
| 1009 |
+
loaded_params: set[str]) -> bool:
|
| 1010 |
+
for (param_name, origin_name, shard_id) in attn_mlp_replace_mapping:
|
| 1011 |
+
if origin_name not in weight_name or \
|
| 1012 |
+
(("mlp.experts." in weight_name) and weight_name not in params_dict):
|
| 1013 |
+
continue
|
| 1014 |
+
weight_name = weight_name.replace(origin_name, param_name)
|
| 1015 |
+
if weight_name.endswith(".bias") and weight_name not in params_dict:
|
| 1016 |
+
continue
|
| 1017 |
+
param = params_dict[weight_name]
|
| 1018 |
+
weight_loader = param.weight_loader
|
| 1019 |
+
weight_loader(param, loaded_weight, shard_id)
|
| 1020 |
+
loaded_params.add(weight_name)
|
| 1021 |
+
return True
|
| 1022 |
+
return False
|
| 1023 |
+
|
| 1024 |
+
def load_expert_weight(self,
|
| 1025 |
+
expert_merge_mapping: List[Tuple[str, str, int, str]],
|
| 1026 |
+
params_dict: Dict[str, Any],
|
| 1027 |
+
weight_name: str,
|
| 1028 |
+
loaded_weight: torch.Tensor,
|
| 1029 |
+
loaded_params: set[str]) -> bool:
|
| 1030 |
+
for mapping in expert_merge_mapping:
|
| 1031 |
+
param_name, origin_name, expert_id, shard_id = mapping
|
| 1032 |
+
if origin_name not in weight_name:
|
| 1033 |
+
continue
|
| 1034 |
+
weight_name = weight_name.replace(origin_name, param_name)
|
| 1035 |
+
param = params_dict[weight_name]
|
| 1036 |
+
weight_loader = param.weight_loader
|
| 1037 |
+
weight_loader(param,
|
| 1038 |
+
loaded_weight,
|
| 1039 |
+
weight_name,
|
| 1040 |
+
shard_id=shard_id,
|
| 1041 |
+
expert_id=expert_id,
|
| 1042 |
+
return_success=False)
|
| 1043 |
+
loaded_params.add(weight_name)
|
| 1044 |
+
return True
|
| 1045 |
+
return False
|
| 1046 |
+
|
| 1047 |
+
def load_weights(self, weights: Iterable[tuple[str,
|
| 1048 |
+
torch.Tensor]]) -> set[str]:
|
| 1049 |
+
# (param_name, shard_name, shard_id)
|
| 1050 |
+
attn_mlp_replace_mapping = [
|
| 1051 |
+
(".qkv_proj", ".q_proj", "q"),
|
| 1052 |
+
(".qkv_proj", ".k_proj", "k"),
|
| 1053 |
+
(".qkv_proj", ".v_proj", "v"),
|
| 1054 |
+
(".gate_up_proj", ".gate_proj", 0),
|
| 1055 |
+
(".gate_up_proj", ".up_proj", 1),
|
| 1056 |
+
]
|
| 1057 |
+
has_experts = hasattr(self.config, 'num_routed_experts')
|
| 1058 |
+
if has_experts:
|
| 1059 |
+
expert_merge_mapping = AscendFusedMoE.make_expert_params_mapping(
|
| 1060 |
+
ckpt_gate_proj_name="gate_proj",
|
| 1061 |
+
ckpt_down_proj_name="down_proj",
|
| 1062 |
+
ckpt_up_proj_name="up_proj",
|
| 1063 |
+
num_experts=self.config.num_routed_experts)
|
| 1064 |
+
|
| 1065 |
+
params_dict = dict(self.named_parameters())
|
| 1066 |
+
loaded_params: set[str] = set()
|
| 1067 |
+
for name, loaded_weight in weights:
|
| 1068 |
+
if "rotary_emb.inv_freq" in name:
|
| 1069 |
+
continue
|
| 1070 |
+
if 'layers' in name: # skip spec decode layers for main model
|
| 1071 |
+
layer_idx = int(name.split('layers.')[-1].split('.')[0])
|
| 1072 |
+
if layer_idx > self.config.num_hidden_layers:
|
| 1073 |
+
continue
|
| 1074 |
+
|
| 1075 |
+
if 'layers' in name and hasattr(self.config, "num_mtp_layers") \
|
| 1076 |
+
and (self.config.num_mtp_layers > 0):
|
| 1077 |
+
layer_idx = int(name.split('layers.')[-1].split('.')[0])
|
| 1078 |
+
mtp_idx = layer_idx - self.config.num_hidden_layers
|
| 1079 |
+
if mtp_idx >= 0 and mtp_idx < self.config.num_mtp_layers:
|
| 1080 |
+
continue # skip spec decode layers for main model
|
| 1081 |
+
if self.load_attn_mlp_weight(attn_mlp_replace_mapping, params_dict, name, loaded_weight, loaded_params):
|
| 1082 |
+
continue
|
| 1083 |
+
elif has_experts and self.load_expert_weight(expert_merge_mapping, params_dict, name, loaded_weight, loaded_params):
|
| 1084 |
+
continue
|
| 1085 |
+
else:
|
| 1086 |
+
if name.endswith(".bias") and name not in params_dict:
|
| 1087 |
+
continue
|
| 1088 |
+
name = maybe_remap_kv_scale_name(name, params_dict)
|
| 1089 |
+
if name is None:
|
| 1090 |
+
continue
|
| 1091 |
+
param = params_dict[name]
|
| 1092 |
+
weight_loader = getattr(param, "weight_loader",
|
| 1093 |
+
default_weight_loader)
|
| 1094 |
+
weight_loader(param, loaded_weight)
|
| 1095 |
+
loaded_params.add(name)
|
| 1096 |
+
if self.config.tie_word_embeddings:
|
| 1097 |
+
self.lm_head.weight = self.model.embed_tokens.weight
|
| 1098 |
+
return loaded_params
|
| 1099 |
+
|
| 1100 |
+
def forward(
|
| 1101 |
+
self,
|
| 1102 |
+
input_ids: torch.Tensor,
|
| 1103 |
+
positions: torch.Tensor,
|
| 1104 |
+
kv_caches: Optional[List[torch.Tensor]] = None,
|
| 1105 |
+
attn_metadata: Optional[AttentionMetadata] = None,
|
| 1106 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1107 |
+
**kwargs,
|
| 1108 |
+
) -> torch.Tensor:
|
| 1109 |
+
hidden_states = self.model(input_ids, positions, kv_caches,
|
| 1110 |
+
attn_metadata, inputs_embeds)
|
| 1111 |
+
return hidden_states
|
| 1112 |
+
|
| 1113 |
+
def compute_logits(
|
| 1114 |
+
self,
|
| 1115 |
+
hidden_states: torch.Tensor,
|
| 1116 |
+
sampling_metadata: SamplingMetadata,
|
| 1117 |
+
) -> Optional[torch.Tensor]:
|
| 1118 |
+
logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
|
| 1119 |
+
return logits
|
| 1120 |
+
|
| 1121 |
+
|
| 1122 |
+
class PanguUltraMoEForCausalLM(OpenPanguForCausalLM):
|
| 1123 |
+
pass
|
| 1124 |
+
|
| 1125 |
+
|
| 1126 |
+
class PanguEmbeddedForCausalLM(OpenPanguForCausalLM):
|
| 1127 |
+
pass
|
inference/vllm_ascend/ops/fused_moe.py
ADDED
|
@@ -0,0 +1,1530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 2 |
+
# Copyright 2023 The vLLM team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# This file is a part of the vllm-ascend project.
|
| 16 |
+
# Adapted from vllm/tests/kernels/test_moe.py
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.distributed as dist
|
| 23 |
+
import torch_npu
|
| 24 |
+
from torch import nn
|
| 25 |
+
from vllm.config import get_current_vllm_config
|
| 26 |
+
from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
|
| 27 |
+
get_tensor_model_parallel_world_size,
|
| 28 |
+
tensor_model_parallel_all_reduce)
|
| 29 |
+
from vllm.distributed.parallel_state import get_dp_group, get_tp_group
|
| 30 |
+
from vllm.forward_context import get_forward_context
|
| 31 |
+
from vllm.model_executor.layers.fused_moe.config import \
|
| 32 |
+
FusedMoEConfig # isort: skip
|
| 33 |
+
from vllm.model_executor.layers.fused_moe.config import \
|
| 34 |
+
FusedMoEParallelConfig # isort: skip
|
| 35 |
+
from vllm.model_executor.layers.fused_moe.layer import (
|
| 36 |
+
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
| 37 |
+
from vllm.model_executor.layers.quantization.base_config import \
|
| 38 |
+
QuantizationConfig
|
| 39 |
+
|
| 40 |
+
import vllm_ascend.envs as envs_ascend
|
| 41 |
+
from vllm_ascend.ascend_config import get_ascend_config
|
| 42 |
+
from vllm_ascend.distributed.communication_op import \
|
| 43 |
+
data_parallel_reduce_scatter
|
| 44 |
+
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
|
| 45 |
+
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
| 46 |
+
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
|
| 47 |
+
get_all_reduce_merge_state, get_fused_moe_state,
|
| 48 |
+
get_rm_router_logits_state, is_310p,
|
| 49 |
+
npu_stream_switch, npu_wait_tensor)
|
| 50 |
+
|
| 51 |
+
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
| 52 |
+
SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
|
| 56 |
+
max_row_per_ep_rank: int, num_tokens: int,
|
| 57 |
+
top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
| 58 |
+
original_total_elements = num_tokens * top_k
|
| 59 |
+
device = topk_ids.device
|
| 60 |
+
original_dtype = topk_ids.dtype
|
| 61 |
+
|
| 62 |
+
if original_total_elements == 0:
|
| 63 |
+
output_len = ep_size * max_row_per_ep_rank
|
| 64 |
+
topk_ids_pad = torch.full((output_len, ),
|
| 65 |
+
expert_num,
|
| 66 |
+
dtype=original_dtype,
|
| 67 |
+
device=device)
|
| 68 |
+
unpad_indices = torch.full((original_total_elements, ),
|
| 69 |
+
-1,
|
| 70 |
+
dtype=torch.long,
|
| 71 |
+
device=device)
|
| 72 |
+
return topk_ids_pad, unpad_indices
|
| 73 |
+
|
| 74 |
+
experts_per_ep_rank_val = expert_num // ep_size
|
| 75 |
+
if experts_per_ep_rank_val == 0:
|
| 76 |
+
raise ValueError(
|
| 77 |
+
"expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. "
|
| 78 |
+
"Ensure expert_num >= ep_size.")
|
| 79 |
+
|
| 80 |
+
assigned_ep_rank = (topk_ids.float() /
|
| 81 |
+
experts_per_ep_rank_val).to(original_dtype)
|
| 82 |
+
indices_arange = torch.arange(topk_ids.shape[0], device=device)
|
| 83 |
+
|
| 84 |
+
is_new_segment = torch.cat(
|
| 85 |
+
(torch.tensor([True], device=device), assigned_ep_rank[1:]
|
| 86 |
+
!= assigned_ep_rank[:-1]))
|
| 87 |
+
temp_start_markers = torch.full_like(indices_arange,
|
| 88 |
+
-1,
|
| 89 |
+
dtype=indices_arange.dtype)
|
| 90 |
+
temp_start_markers[is_new_segment] = indices_arange[is_new_segment]
|
| 91 |
+
start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0]
|
| 92 |
+
token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token
|
| 93 |
+
is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank
|
| 94 |
+
cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long)
|
| 95 |
+
indices_in_rec_cond_list_for_all = cumsum_kept - 1
|
| 96 |
+
unpad_indices = torch.where(
|
| 97 |
+
is_kept_mask, indices_in_rec_cond_list_for_all,
|
| 98 |
+
torch.tensor(-1, device=device, dtype=torch.long))
|
| 99 |
+
output_len = ep_size * max_row_per_ep_rank
|
| 100 |
+
topk_ids_pad = torch.full((output_len, ),
|
| 101 |
+
expert_num,
|
| 102 |
+
dtype=original_dtype,
|
| 103 |
+
device=device)
|
| 104 |
+
if topk_ids.shape[0] > 0:
|
| 105 |
+
all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx
|
| 106 |
+
temp_pad_buffer = torch.full((output_len + 1, ),
|
| 107 |
+
expert_num,
|
| 108 |
+
dtype=original_dtype,
|
| 109 |
+
device=device)
|
| 110 |
+
output_len_tensor = torch.tensor(output_len,
|
| 111 |
+
dtype=torch.long,
|
| 112 |
+
device=device)
|
| 113 |
+
scatter_indices = torch.where(is_kept_mask, all_destination_indices,
|
| 114 |
+
output_len_tensor)
|
| 115 |
+
temp_pad_buffer.scatter_(0, scatter_indices, topk_ids)
|
| 116 |
+
topk_ids_pad = temp_pad_buffer[:output_len]
|
| 117 |
+
return topk_ids_pad, unpad_indices
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def fused_experts_with_mc2(
|
| 121 |
+
hidden_states: torch.Tensor,
|
| 122 |
+
w1: torch.Tensor,
|
| 123 |
+
w2: torch.Tensor,
|
| 124 |
+
topk_weights: torch.Tensor,
|
| 125 |
+
topk_ids: torch.Tensor,
|
| 126 |
+
top_k: int,
|
| 127 |
+
expert_map: torch.Tensor = None,
|
| 128 |
+
moe_all_to_all_group_name: Optional[str] = None,
|
| 129 |
+
shared_experts: Optional[Any] = None
|
| 130 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 131 |
+
global_bs = 0
|
| 132 |
+
moe_expert_num = len(expert_map)
|
| 133 |
+
kwargs_mc2 = {
|
| 134 |
+
"x": hidden_states,
|
| 135 |
+
"expert_ids": topk_ids,
|
| 136 |
+
"expert_shard_type": 0,
|
| 137 |
+
"shared_expert_rank_num": 0,
|
| 138 |
+
"moe_expert_num": moe_expert_num,
|
| 139 |
+
"global_bs": global_bs,
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
rank = torch.distributed.get_rank()
|
| 143 |
+
|
| 144 |
+
quant_mode = 0
|
| 145 |
+
ep_group = get_ep_group().device_group
|
| 146 |
+
local_rank = torch.distributed.get_rank(group=ep_group)
|
| 147 |
+
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
|
| 148 |
+
|
| 149 |
+
tp_size = get_etp_group().world_size
|
| 150 |
+
tp_rank = rank % tp_size
|
| 151 |
+
|
| 152 |
+
stage1_kwargs = {
|
| 153 |
+
"scales": None,
|
| 154 |
+
"quant_mode": quant_mode,
|
| 155 |
+
"group_ep": moe_all_to_all_group_name,
|
| 156 |
+
"ep_world_size": all_to_all_group_size,
|
| 157 |
+
"ep_rank_id": local_rank,
|
| 158 |
+
# "group_tp": self.moe_rs_group_name,
|
| 159 |
+
"group_tp": moe_all_to_all_group_name,
|
| 160 |
+
"tp_world_size": tp_size,
|
| 161 |
+
"tp_rank_id": tp_rank,
|
| 162 |
+
}
|
| 163 |
+
kwargs_mc2.update(stage1_kwargs)
|
| 164 |
+
|
| 165 |
+
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
| 166 |
+
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
|
| 167 |
+
0:5]
|
| 168 |
+
|
| 169 |
+
if shared_experts is not None:
|
| 170 |
+
with npu_stream_switch("moe_secondary", 0):
|
| 171 |
+
npu_wait_tensor(hidden_states, topk_weights)
|
| 172 |
+
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
|
| 173 |
+
npu_wait_tensor(shared_gate_up, expand_x)
|
| 174 |
+
shared_act = shared_experts.act_fn(shared_gate_up)
|
| 175 |
+
|
| 176 |
+
w1 = w1.transpose(1, 2)
|
| 177 |
+
|
| 178 |
+
group_list = expert_token_nums.to(torch.int64)
|
| 179 |
+
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
| 180 |
+
x=[expand_x],
|
| 181 |
+
weight=[w1],
|
| 182 |
+
split_item=2,
|
| 183 |
+
# 1 means count mode, to avoid cumulative operation of the group list
|
| 184 |
+
group_list_type=1,
|
| 185 |
+
group_type=0,
|
| 186 |
+
group_list=group_list,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# TODO: Remove this in the future.
|
| 190 |
+
gate_up_out = torch.cat(gate_up_out_list, dim=0)
|
| 191 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
| 192 |
+
|
| 193 |
+
w2 = w2.transpose(1, 2)
|
| 194 |
+
down_out_list = torch_npu.npu_grouped_matmul(
|
| 195 |
+
x=[gate_up_out],
|
| 196 |
+
weight=[w2],
|
| 197 |
+
split_item=2,
|
| 198 |
+
group_list_type=1,
|
| 199 |
+
group_type=0,
|
| 200 |
+
group_list=group_list,
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
down_out_list = torch.cat(down_out_list, dim=0)
|
| 204 |
+
|
| 205 |
+
# moeCombine
|
| 206 |
+
kwargs_mc2 = {
|
| 207 |
+
"expand_x": down_out_list,
|
| 208 |
+
"expert_ids": topk_ids,
|
| 209 |
+
"expand_idx": expand_idx,
|
| 210 |
+
"expert_scales": topk_weights.to(torch.float32),
|
| 211 |
+
"expert_shard_type": 0,
|
| 212 |
+
"shared_expert_rank_num": 0,
|
| 213 |
+
"moe_expert_num": moe_expert_num,
|
| 214 |
+
"global_bs": 0,
|
| 215 |
+
}
|
| 216 |
+
tp_recv_counts = output[5]
|
| 217 |
+
stage3_kwargs = {
|
| 218 |
+
"ep_send_counts": ep_recv_counts,
|
| 219 |
+
"group_ep": moe_all_to_all_group_name,
|
| 220 |
+
"ep_world_size": all_to_all_group_size,
|
| 221 |
+
"ep_rank_id": local_rank,
|
| 222 |
+
"tp_send_counts": tp_recv_counts,
|
| 223 |
+
# "group_tp": self.moe_rs_group_name,
|
| 224 |
+
"group_tp": moe_all_to_all_group_name,
|
| 225 |
+
"tp_world_size": tp_size,
|
| 226 |
+
"tp_rank_id": tp_rank,
|
| 227 |
+
}
|
| 228 |
+
kwargs_mc2.update(stage3_kwargs)
|
| 229 |
+
|
| 230 |
+
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
| 231 |
+
|
| 232 |
+
if shared_experts is None:
|
| 233 |
+
return hidden_states
|
| 234 |
+
else:
|
| 235 |
+
with npu_stream_switch("moe_secondary", 0):
|
| 236 |
+
npu_wait_tensor(shared_act, down_out_list)
|
| 237 |
+
shared_hidden_states, _ = shared_experts.down_proj(shared_act)
|
| 238 |
+
return hidden_states, shared_hidden_states
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
| 242 |
+
w1: torch.Tensor,
|
| 243 |
+
w2: torch.Tensor,
|
| 244 |
+
group_list: torch.Tensor,
|
| 245 |
+
group_list_type: int = 1) -> torch.Tensor:
|
| 246 |
+
"""
|
| 247 |
+
apply MLP: gate_up_proj -> swiglu -> down_proj
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
|
| 251 |
+
w1: expert weights1 with shape
|
| 252 |
+
(num_experts, hidden_size, intermediate_size * 2)
|
| 253 |
+
w2: expert weights2 with shape
|
| 254 |
+
(num_experts, intermediate_size, hidden_size)
|
| 255 |
+
group_list: number of tokens for each expert, follow cumsum mode, and
|
| 256 |
+
with shape (num_experts).
|
| 257 |
+
transpose_weight:
|
| 258 |
+
w1: (num_experts, intermediate_size * 2, hidden_size) ->
|
| 259 |
+
(num_experts, hidden_size, intermediate_size * 2)
|
| 260 |
+
w2: (num_experts, hidden_size, intermediate_size) ->
|
| 261 |
+
(num_experts, intermediate_size, hidden_size)
|
| 262 |
+
|
| 263 |
+
Returns:
|
| 264 |
+
hidden_states: output hidden states after MLP.
|
| 265 |
+
"""
|
| 266 |
+
|
| 267 |
+
assert len(hidden_states_wrapper) == 1
|
| 268 |
+
hidden_states = hidden_states_wrapper.pop()
|
| 269 |
+
|
| 270 |
+
w1 = w1.transpose(1, 2)
|
| 271 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 272 |
+
x=[hidden_states],
|
| 273 |
+
weight=[w1],
|
| 274 |
+
split_item=2,
|
| 275 |
+
group_list_type=group_list_type,
|
| 276 |
+
group_type=0,
|
| 277 |
+
group_list=group_list,
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
hidden_states = torch.cat(hidden_states, dim=0)
|
| 281 |
+
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
| 282 |
+
|
| 283 |
+
w2 = w2.transpose(1, 2)
|
| 284 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 285 |
+
x=[hidden_states],
|
| 286 |
+
weight=[w2],
|
| 287 |
+
split_item=2,
|
| 288 |
+
group_list_type=group_list_type,
|
| 289 |
+
group_type=0,
|
| 290 |
+
group_list=group_list,
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
hidden_states = torch.cat(hidden_states, dim=0)
|
| 294 |
+
return hidden_states
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def fused_experts_with_all2all(
|
| 298 |
+
hidden_states: torch.Tensor,
|
| 299 |
+
w1: torch.Tensor,
|
| 300 |
+
w2: torch.Tensor,
|
| 301 |
+
topk_weights: torch.Tensor,
|
| 302 |
+
topk_ids: torch.Tensor,
|
| 303 |
+
top_k: int,
|
| 304 |
+
expert_map: torch.Tensor = None,
|
| 305 |
+
ep_group: GroupCoordinator = None,
|
| 306 |
+
):
|
| 307 |
+
original_shape = hidden_states.shape
|
| 308 |
+
if len(original_shape) == 3:
|
| 309 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 310 |
+
|
| 311 |
+
num_tokens, _ = hidden_states.shape
|
| 312 |
+
num_experts = w1.shape[0]
|
| 313 |
+
device = hidden_states.device
|
| 314 |
+
|
| 315 |
+
if expert_map is not None:
|
| 316 |
+
global_num_experts = len(expert_map)
|
| 317 |
+
local_num_experts = global_num_experts // ep_group.world_size
|
| 318 |
+
row_idx_len = num_tokens * top_k
|
| 319 |
+
row_idx = (torch.arange(0,
|
| 320 |
+
row_idx_len,
|
| 321 |
+
dtype=torch.int32,
|
| 322 |
+
device=device).view(top_k, -1).permute(
|
| 323 |
+
1, 0).contiguous())
|
| 324 |
+
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 325 |
+
hidden_states,
|
| 326 |
+
row_idx=row_idx,
|
| 327 |
+
expert_idx=topk_ids,
|
| 328 |
+
active_num=num_tokens)
|
| 329 |
+
|
| 330 |
+
global_expert_tokens = torch.bincount(expanded_expert_idx,
|
| 331 |
+
minlength=global_num_experts)
|
| 332 |
+
scatter_sizes = global_expert_tokens.view(ep_group.world_size,
|
| 333 |
+
-1).sum(-1)
|
| 334 |
+
|
| 335 |
+
gather_sizes = torch.empty_like(scatter_sizes)
|
| 336 |
+
dist.all_to_all_single(gather_sizes,
|
| 337 |
+
scatter_sizes,
|
| 338 |
+
group=ep_group.device_group)
|
| 339 |
+
scatter_size_list = scatter_sizes.cpu().tolist()
|
| 340 |
+
gather_size_list = gather_sizes.cpu().tolist()
|
| 341 |
+
|
| 342 |
+
expanded_expert_idx = expanded_expert_idx % local_num_experts
|
| 343 |
+
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
| 344 |
+
scatter_size_list,
|
| 345 |
+
gather_size_list)
|
| 346 |
+
local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
|
| 347 |
+
scatter_size_list,
|
| 348 |
+
gather_size_list)
|
| 349 |
+
|
| 350 |
+
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
|
| 351 |
+
|
| 352 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 353 |
+
sorted_local_expert_idx, local_num_experts).to(torch.int64)
|
| 354 |
+
|
| 355 |
+
hidden_states = hidden_states[sorted_idx]
|
| 356 |
+
else:
|
| 357 |
+
row_idx_len = num_tokens * top_k
|
| 358 |
+
row_idx = torch.arange(0,
|
| 359 |
+
row_idx_len,
|
| 360 |
+
dtype=torch.int32,
|
| 361 |
+
device=topk_weights.device).view(
|
| 362 |
+
top_k, -1).permute(1, 0).contiguous()
|
| 363 |
+
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 364 |
+
hidden_states,
|
| 365 |
+
row_idx=row_idx,
|
| 366 |
+
expert_idx=topk_ids,
|
| 367 |
+
active_num=num_tokens)
|
| 368 |
+
|
| 369 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 370 |
+
expanded_expert_idx, num_experts)
|
| 371 |
+
expert_tokens = expert_tokens.to(torch.int64)
|
| 372 |
+
|
| 373 |
+
w1 = w1.transpose(1, 2)
|
| 374 |
+
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
| 375 |
+
x=[hidden_states],
|
| 376 |
+
weight=[w1],
|
| 377 |
+
split_item=2,
|
| 378 |
+
group_list_type=0,
|
| 379 |
+
group_type=0,
|
| 380 |
+
group_list=expert_tokens,
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
# TODO: Remove this in the future.
|
| 384 |
+
hidden_states = torch.cat(gate_up_out_list, dim=0)
|
| 385 |
+
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
| 386 |
+
|
| 387 |
+
w2 = w2.transpose(1, 2)
|
| 388 |
+
down_out_list = torch_npu.npu_grouped_matmul(
|
| 389 |
+
x=[hidden_states],
|
| 390 |
+
weight=[w2],
|
| 391 |
+
split_item=2,
|
| 392 |
+
group_list_type=0,
|
| 393 |
+
group_type=0,
|
| 394 |
+
group_list=expert_tokens,
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
hidden_states = torch.cat(down_out_list, dim=0)
|
| 398 |
+
|
| 399 |
+
if expert_map is not None:
|
| 400 |
+
resorted_idx = torch.argsort(sorted_idx)
|
| 401 |
+
hidden_states = hidden_states[resorted_idx]
|
| 402 |
+
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
| 403 |
+
gather_size_list,
|
| 404 |
+
scatter_size_list)
|
| 405 |
+
|
| 406 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 407 |
+
hidden_states,
|
| 408 |
+
skip1=None,
|
| 409 |
+
skip2=None,
|
| 410 |
+
bias=None,
|
| 411 |
+
scales=topk_weights,
|
| 412 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 413 |
+
export_for_source_row=topk_ids,
|
| 414 |
+
)
|
| 415 |
+
else:
|
| 416 |
+
# TODO: Reorder device memory 2 times here, replace the current
|
| 417 |
+
# implementation here when suitable operators become available.
|
| 418 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 419 |
+
hidden_states,
|
| 420 |
+
skip1=None,
|
| 421 |
+
skip2=None,
|
| 422 |
+
bias=None,
|
| 423 |
+
scales=topk_weights,
|
| 424 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 425 |
+
export_for_source_row=topk_ids,
|
| 426 |
+
)
|
| 427 |
+
if len(original_shape) == 3:
|
| 428 |
+
final_hidden_states = final_hidden_states.view(original_shape)
|
| 429 |
+
return final_hidden_states
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
# currently expert parallelism implemented with all2all
|
| 433 |
+
# is under-optimized.
|
| 434 |
+
def fused_experts_with_all2all_buffer(
|
| 435 |
+
hidden_states: torch.Tensor,
|
| 436 |
+
w1: torch.Tensor,
|
| 437 |
+
w2: torch.Tensor,
|
| 438 |
+
topk_weights: torch.Tensor,
|
| 439 |
+
topk_ids: torch.Tensor,
|
| 440 |
+
top_k: int,
|
| 441 |
+
max_model_len: int,
|
| 442 |
+
global_batch_size: int,
|
| 443 |
+
expert_map: torch.Tensor = None,
|
| 444 |
+
ep_group: GroupCoordinator = None,
|
| 445 |
+
):
|
| 446 |
+
original_shape = hidden_states.shape
|
| 447 |
+
if len(original_shape) == 3:
|
| 448 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 449 |
+
|
| 450 |
+
num_tokens, _ = hidden_states.shape
|
| 451 |
+
device = hidden_states.device
|
| 452 |
+
|
| 453 |
+
global_num_experts = len(expert_map)
|
| 454 |
+
local_num_experts = global_num_experts // ep_group.world_size
|
| 455 |
+
row_idx_len = num_tokens * top_k
|
| 456 |
+
row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
|
| 457 |
+
device=device).view(top_k,
|
| 458 |
+
-1).permute(1, 0).contiguous())
|
| 459 |
+
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 460 |
+
hidden_states,
|
| 461 |
+
row_idx=row_idx,
|
| 462 |
+
expert_idx=topk_ids,
|
| 463 |
+
active_num=num_tokens)
|
| 464 |
+
|
| 465 |
+
max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) *
|
| 466 |
+
max_model_len // ep_group.world_size +
|
| 467 |
+
1) * top_k * 2
|
| 468 |
+
expert_idx_buffer_scatter, unpad_indices = process_topk_ids(
|
| 469 |
+
expanded_expert_idx, global_num_experts, ep_group.world_size,
|
| 470 |
+
max_row_per_ep_rank, num_tokens, top_k)
|
| 471 |
+
hidden_states_pad_idx = torch.zeros(
|
| 472 |
+
expert_idx_buffer_scatter.shape,
|
| 473 |
+
dtype=expert_idx_buffer_scatter.dtype,
|
| 474 |
+
device=expert_idx_buffer_scatter.device)
|
| 475 |
+
non_pad_len = torch.sum((expert_idx_buffer_scatter
|
| 476 |
+
!= global_num_experts).to(torch.int32))
|
| 477 |
+
hidden_states_pad_idx[expert_idx_buffer_scatter !=
|
| 478 |
+
global_num_experts] = torch.arange(
|
| 479 |
+
non_pad_len,
|
| 480 |
+
dtype=expert_idx_buffer_scatter.dtype,
|
| 481 |
+
device=hidden_states.device)
|
| 482 |
+
|
| 483 |
+
hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx]
|
| 484 |
+
expert_idx_buffer_gather = torch.empty_like(
|
| 485 |
+
expert_idx_buffer_scatter,
|
| 486 |
+
dtype=expert_idx_buffer_scatter.dtype,
|
| 487 |
+
device=expert_idx_buffer_scatter.device)
|
| 488 |
+
hidden_states_buffer_gather = torch.empty_like(
|
| 489 |
+
hidden_states_buffer_scatter,
|
| 490 |
+
dtype=hidden_states_buffer_scatter.dtype,
|
| 491 |
+
device=hidden_states_buffer_scatter.device)
|
| 492 |
+
dist.all_to_all_single(expert_idx_buffer_gather,
|
| 493 |
+
expert_idx_buffer_scatter,
|
| 494 |
+
group=ep_group.device_group)
|
| 495 |
+
dist.all_to_all_single(hidden_states_buffer_gather,
|
| 496 |
+
hidden_states_buffer_scatter,
|
| 497 |
+
group=ep_group.device_group)
|
| 498 |
+
mask = expert_idx_buffer_gather != global_num_experts
|
| 499 |
+
local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * (
|
| 500 |
+
global_num_experts // ep_group.world_size)
|
| 501 |
+
hidden_states = hidden_states_buffer_gather[mask]
|
| 502 |
+
idx_type = local_expert_idx.dtype
|
| 503 |
+
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float())
|
| 504 |
+
sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type)
|
| 505 |
+
|
| 506 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 507 |
+
sorted_local_expert_idx, local_num_experts).to(torch.int64)
|
| 508 |
+
hidden_states = hidden_states[sorted_idx]
|
| 509 |
+
group_list_type = 0
|
| 510 |
+
|
| 511 |
+
hidden_states_wrapper = [hidden_states]
|
| 512 |
+
del hidden_states
|
| 513 |
+
|
| 514 |
+
hidden_states = apply_mlp(hidden_states_wrapper,
|
| 515 |
+
w1,
|
| 516 |
+
w2,
|
| 517 |
+
expert_tokens,
|
| 518 |
+
group_list_type=group_list_type)
|
| 519 |
+
|
| 520 |
+
resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype)
|
| 521 |
+
hidden_states = hidden_states[resorted_idx]
|
| 522 |
+
hidden_states_scatter = torch.zeros(
|
| 523 |
+
(mask.shape[0], hidden_states.shape[1]),
|
| 524 |
+
dtype=hidden_states.dtype,
|
| 525 |
+
device=hidden_states.device)
|
| 526 |
+
hidden_states_scatter[mask] = hidden_states
|
| 527 |
+
hidden_states_gatter = torch.empty_like(
|
| 528 |
+
hidden_states_scatter,
|
| 529 |
+
dtype=hidden_states_scatter.dtype,
|
| 530 |
+
device=hidden_states_scatter.device)
|
| 531 |
+
dist.all_to_all_single(hidden_states_gatter,
|
| 532 |
+
hidden_states_scatter,
|
| 533 |
+
group=ep_group.device_group)
|
| 534 |
+
hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter !=
|
| 535 |
+
global_num_experts]
|
| 536 |
+
if hidden_states_gatter.shape[0] != row_idx_len:
|
| 537 |
+
hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]),
|
| 538 |
+
dtype=hidden_states.dtype,
|
| 539 |
+
device=hidden_states.device)
|
| 540 |
+
hidden_states[unpad_indices != -1] = hidden_states_gatter
|
| 541 |
+
else:
|
| 542 |
+
# TODO: Reorder device memory 2 times here, replace the current
|
| 543 |
+
hidden_states = hidden_states_gatter
|
| 544 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 545 |
+
hidden_states,
|
| 546 |
+
skip1=None,
|
| 547 |
+
skip2=None,
|
| 548 |
+
bias=None,
|
| 549 |
+
scales=topk_weights,
|
| 550 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 551 |
+
export_for_source_row=topk_ids,
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
if len(original_shape) == 3:
|
| 555 |
+
final_hidden_states = final_hidden_states.view(original_shape)
|
| 556 |
+
return final_hidden_states
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def fused_experts_moge(
|
| 560 |
+
hidden_states: torch.Tensor,
|
| 561 |
+
w1: torch.Tensor,
|
| 562 |
+
w2: torch.Tensor,
|
| 563 |
+
topk_weights: torch.Tensor,
|
| 564 |
+
topk_ids: torch.Tensor,
|
| 565 |
+
top_k: int,
|
| 566 |
+
global_num_experts: int,
|
| 567 |
+
expert_map: torch.Tensor = None,
|
| 568 |
+
apply_router_weight_on_input: bool = False,
|
| 569 |
+
) -> torch.Tensor:
|
| 570 |
+
"""
|
| 571 |
+
|
| 572 |
+
Args:
|
| 573 |
+
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
| 574 |
+
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
| 575 |
+
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
| 576 |
+
topk_weights: Routing weights of shape (num_tokens, top_k).
|
| 577 |
+
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
| 578 |
+
top_k: Number of experts to select.
|
| 579 |
+
expert_map: Expert mapping of shape (num_experts,).
|
| 580 |
+
|
| 581 |
+
Returns:
|
| 582 |
+
hidden_states: Hidden states after routing.
|
| 583 |
+
"""
|
| 584 |
+
ep_size = get_ep_group().world_size
|
| 585 |
+
local_num_experts = global_num_experts // ep_size
|
| 586 |
+
local_num_group = top_k // ep_size
|
| 587 |
+
|
| 588 |
+
if apply_router_weight_on_input:
|
| 589 |
+
assert (topk_weights.dim() == 2
|
| 590 |
+
), "`topk_weights` should be in shape (num_tokens, topk)"
|
| 591 |
+
_, topk = topk_weights.shape
|
| 592 |
+
assert (
|
| 593 |
+
topk == 1
|
| 594 |
+
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
| 595 |
+
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
|
| 596 |
+
|
| 597 |
+
bsz, _ = hidden_states.shape
|
| 598 |
+
flatten_topk_ids = topk_ids.view(-1)
|
| 599 |
+
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
| 600 |
+
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
|
| 601 |
+
sorted_hidden_states = hidden_states.index_select(
|
| 602 |
+
0, sorted_topk_ids // local_num_group)
|
| 603 |
+
|
| 604 |
+
experts_id = torch.arange(0,
|
| 605 |
+
local_num_experts,
|
| 606 |
+
dtype=topk_ids.dtype,
|
| 607 |
+
device=topk_ids.device)
|
| 608 |
+
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
| 609 |
+
torch.float32).sum(0)
|
| 610 |
+
topk_scales = topk_weights.view(-1).index_select(
|
| 611 |
+
0, sorted_topk_ids).unsqueeze(-1)
|
| 612 |
+
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
| 613 |
+
|
| 614 |
+
w1 = w1.transpose(1, 2)
|
| 615 |
+
gate_up_out = torch_npu.npu_grouped_matmul(
|
| 616 |
+
x=[sorted_hidden_states],
|
| 617 |
+
weight=[w1],
|
| 618 |
+
split_item=2,
|
| 619 |
+
group_list_type=0,
|
| 620 |
+
group_type=0,
|
| 621 |
+
group_list=group_list,
|
| 622 |
+
)[0]
|
| 623 |
+
|
| 624 |
+
if is_310p():
|
| 625 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
| 626 |
+
torch.float16)
|
| 627 |
+
else:
|
| 628 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
| 629 |
+
gate_up_out *= topk_scales
|
| 630 |
+
|
| 631 |
+
w2 = w2.transpose(1, 2)
|
| 632 |
+
down_out_list = torch_npu.npu_grouped_matmul(
|
| 633 |
+
x=[gate_up_out],
|
| 634 |
+
weight=[w2],
|
| 635 |
+
split_item=2,
|
| 636 |
+
group_list_type=0,
|
| 637 |
+
group_type=0,
|
| 638 |
+
group_list=group_list,
|
| 639 |
+
)[0]
|
| 640 |
+
|
| 641 |
+
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
|
| 642 |
+
unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
|
| 643 |
+
final_hidden_states = unsorted_hidden_states.reshape(
|
| 644 |
+
bsz, top_k // ep_size, -1).sum(1)
|
| 645 |
+
|
| 646 |
+
return final_hidden_states
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
def fused_experts(
|
| 650 |
+
hidden_states: torch.Tensor,
|
| 651 |
+
w1: torch.Tensor,
|
| 652 |
+
w2: torch.Tensor,
|
| 653 |
+
topk_weights: torch.Tensor,
|
| 654 |
+
topk_ids: torch.Tensor,
|
| 655 |
+
top_k: int,
|
| 656 |
+
expert_map: torch.Tensor = None,
|
| 657 |
+
apply_router_weight_on_input: bool = False,
|
| 658 |
+
max_num_tokens: Optional[int] = None,
|
| 659 |
+
) -> torch.Tensor:
|
| 660 |
+
"""
|
| 661 |
+
Fused experts with top-k routing.
|
| 662 |
+
|
| 663 |
+
Args:
|
| 664 |
+
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
| 665 |
+
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
| 666 |
+
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
| 667 |
+
topk_weights: Routing weights of shape (num_tokens, top_k).
|
| 668 |
+
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
| 669 |
+
top_k: Number of experts to select.
|
| 670 |
+
expert_map: Expert mapping of shape (num_experts,).
|
| 671 |
+
|
| 672 |
+
Returns:
|
| 673 |
+
hidden_states: Hidden states after routing.
|
| 674 |
+
"""
|
| 675 |
+
"""
|
| 676 |
+
# Check constraints.
|
| 677 |
+
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
| 678 |
+
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
| 679 |
+
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
| 680 |
+
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
| 681 |
+
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
| 682 |
+
"""
|
| 683 |
+
# if torch.distributed.get_rank() == 0:
|
| 684 |
+
# print(w1.shape)
|
| 685 |
+
# print(hidden_states.shape)
|
| 686 |
+
|
| 687 |
+
original_shape = hidden_states.shape
|
| 688 |
+
# assert len(original_shape) == 2
|
| 689 |
+
|
| 690 |
+
num_tokens = hidden_states.shape[:-1].numel()
|
| 691 |
+
num_experts = w1.shape[0]
|
| 692 |
+
dtype = hidden_states.dtype
|
| 693 |
+
device = hidden_states.device
|
| 694 |
+
# assert dtype in [torch.float32, torch.float16, torch.bfloat16
|
| 695 |
+
# ], "Only float32, float16, and bfloat16 are supported"
|
| 696 |
+
|
| 697 |
+
if apply_router_weight_on_input:
|
| 698 |
+
assert (topk_weights.dim() == 2
|
| 699 |
+
), "`topk_weights` should be in shape (num_tokens, topk)"
|
| 700 |
+
_, topk = topk_weights.shape
|
| 701 |
+
assert (
|
| 702 |
+
topk == 1
|
| 703 |
+
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
| 704 |
+
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
|
| 705 |
+
|
| 706 |
+
if expert_map is not None:
|
| 707 |
+
# Generate token indices and flatten
|
| 708 |
+
token_indices = (torch.arange(num_tokens,
|
| 709 |
+
device=device,
|
| 710 |
+
dtype=torch.int64).unsqueeze(1).expand(
|
| 711 |
+
-1, top_k).reshape(-1))
|
| 712 |
+
|
| 713 |
+
# Flatten token-to-expert mappings and map to local experts
|
| 714 |
+
weights_flat = topk_weights.view(-1)
|
| 715 |
+
experts_flat = topk_ids.view(-1)
|
| 716 |
+
local_experts_flat = expert_map[experts_flat]
|
| 717 |
+
|
| 718 |
+
# Filter valid token-expert pairs
|
| 719 |
+
mask = local_experts_flat != -1
|
| 720 |
+
filtered_weights = torch.where(
|
| 721 |
+
mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
|
| 722 |
+
filtered_experts = torch.where(
|
| 723 |
+
mask, local_experts_flat,
|
| 724 |
+
torch.full_like(local_experts_flat,
|
| 725 |
+
num_experts)).to(topk_ids.dtype)
|
| 726 |
+
|
| 727 |
+
# Sort by local expert IDs
|
| 728 |
+
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
|
| 729 |
+
sorted_token_indices = token_indices[sort_indices]
|
| 730 |
+
sorted_weights = filtered_weights[sort_indices]
|
| 731 |
+
|
| 732 |
+
# Compute token counts with minlength of num_experts
|
| 733 |
+
# This is equivalent to but faster than:
|
| 734 |
+
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
|
| 735 |
+
token_counts = torch.zeros(num_experts + 1,
|
| 736 |
+
device=device,
|
| 737 |
+
dtype=torch.int64)
|
| 738 |
+
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
|
| 739 |
+
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
|
| 740 |
+
token_counts = token_counts[:num_experts]
|
| 741 |
+
expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64)
|
| 742 |
+
|
| 743 |
+
# Rearrange hidden_states
|
| 744 |
+
sorted_hidden_states = hidden_states[sorted_token_indices]
|
| 745 |
+
else:
|
| 746 |
+
row_idx_len = num_tokens * top_k
|
| 747 |
+
row_idx = (torch.arange(0,
|
| 748 |
+
row_idx_len,
|
| 749 |
+
dtype=torch.int32,
|
| 750 |
+
device=device).view(top_k, -1).permute(
|
| 751 |
+
1, 0).contiguous())
|
| 752 |
+
active_num = max_num_tokens if max_num_tokens is not None else num_tokens
|
| 753 |
+
sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 754 |
+
hidden_states,
|
| 755 |
+
row_idx=row_idx,
|
| 756 |
+
expert_idx=topk_ids,
|
| 757 |
+
active_num=active_num)
|
| 758 |
+
|
| 759 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 760 |
+
expanded_expert_idx, num_experts)
|
| 761 |
+
expert_tokens = expert_tokens.to(torch.int64)
|
| 762 |
+
|
| 763 |
+
w1 = w1.transpose(1, 2)
|
| 764 |
+
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
| 765 |
+
x=[sorted_hidden_states],
|
| 766 |
+
weight=[w1],
|
| 767 |
+
split_item=2,
|
| 768 |
+
group_list_type=0,
|
| 769 |
+
group_type=0,
|
| 770 |
+
group_list=expert_tokens,
|
| 771 |
+
)
|
| 772 |
+
|
| 773 |
+
# TODO: Remove this in the future.
|
| 774 |
+
gate_up_out = torch.cat(gate_up_out_list, dim=0)
|
| 775 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
| 776 |
+
|
| 777 |
+
w2 = w2.transpose(1, 2)
|
| 778 |
+
down_out_list = torch_npu.npu_grouped_matmul(
|
| 779 |
+
x=[gate_up_out],
|
| 780 |
+
weight=[w2],
|
| 781 |
+
split_item=2,
|
| 782 |
+
group_list_type=0,
|
| 783 |
+
group_type=0,
|
| 784 |
+
group_list=expert_tokens,
|
| 785 |
+
)
|
| 786 |
+
|
| 787 |
+
down_out_list = torch.cat(down_out_list, dim=0)
|
| 788 |
+
|
| 789 |
+
if expert_map is not None:
|
| 790 |
+
weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
|
| 791 |
+
|
| 792 |
+
final_hidden_states = torch.zeros(*original_shape,
|
| 793 |
+
device=hidden_states.device,
|
| 794 |
+
dtype=dtype)
|
| 795 |
+
|
| 796 |
+
# TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
|
| 797 |
+
# This created multiple NaN and index_add_ will mix them up which harms accuracy
|
| 798 |
+
# remove this mask and filter after it being fixed
|
| 799 |
+
num_valid_tokens = mask.sum()
|
| 800 |
+
valid_token_mask = torch.arange(
|
| 801 |
+
0, sorted_token_indices.shape[0],
|
| 802 |
+
device=device).unsqueeze(1) < num_valid_tokens
|
| 803 |
+
valid_output = torch.where(
|
| 804 |
+
valid_token_mask, weighted_down_out,
|
| 805 |
+
torch.zeros_like(weighted_down_out)).to(dtype)
|
| 806 |
+
final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
|
| 807 |
+
else:
|
| 808 |
+
scales = torch.ones_like(
|
| 809 |
+
topk_weights) if apply_router_weight_on_input else topk_weights
|
| 810 |
+
# TODO: Reorder device memory 2 times here, replace the current
|
| 811 |
+
# implementation here when suitable operators become available.
|
| 812 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 813 |
+
down_out_list,
|
| 814 |
+
skip1=None,
|
| 815 |
+
skip2=None,
|
| 816 |
+
bias=None,
|
| 817 |
+
scales=scales,
|
| 818 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 819 |
+
export_for_source_row=topk_ids,
|
| 820 |
+
)
|
| 821 |
+
|
| 822 |
+
return final_hidden_states
|
| 823 |
+
|
| 824 |
+
|
| 825 |
+
def fused_experts_allgather_ep(
|
| 826 |
+
hidden_states: torch.Tensor,
|
| 827 |
+
w1: torch.Tensor,
|
| 828 |
+
w2: torch.Tensor,
|
| 829 |
+
topk_weights: torch.Tensor,
|
| 830 |
+
topk_ids: torch.Tensor,
|
| 831 |
+
is_prefill: bool
|
| 832 |
+
):
|
| 833 |
+
local_rank = torch.distributed.get_rank(group=get_ep_group().device_group)
|
| 834 |
+
num_experts_per_ep = w1.shape[0]
|
| 835 |
+
local_expert_indices_offset = local_rank * num_experts_per_ep
|
| 836 |
+
global_local_mask = (topk_ids >= local_expert_indices_offset) & \
|
| 837 |
+
(topk_ids <= local_expert_indices_offset + num_experts_per_ep - 1)
|
| 838 |
+
non_global_local_mask = (~global_local_mask).to(torch.int32)
|
| 839 |
+
global_local_mask = global_local_mask.to(torch.int32)
|
| 840 |
+
row_idx = torch.arange(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32).view(
|
| 841 |
+
-1, topk_ids.shape[0]).transpose(0, 1).contiguous()
|
| 842 |
+
|
| 843 |
+
topk_ids -= local_expert_indices_offset
|
| 844 |
+
local_topk_ids_mask_with_max = topk_ids * global_local_mask + non_global_local_mask * num_experts_per_ep
|
| 845 |
+
sorted_tokens, expanded_src_to_dst_row, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 846 |
+
x=hidden_states,
|
| 847 |
+
row_idx=row_idx,
|
| 848 |
+
expert_idx=local_topk_ids_mask_with_max,
|
| 849 |
+
active_num=topk_ids.shape[0]*topk_ids.shape[1]
|
| 850 |
+
)
|
| 851 |
+
if expanded_expert_idx.shape[0] > 8192:
|
| 852 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(expanded_expert_idx, num_experts_per_ep + 1)
|
| 853 |
+
expert_tokens = expert_tokens[:-1]
|
| 854 |
+
else:
|
| 855 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(expanded_expert_idx, num_experts_per_ep)
|
| 856 |
+
expert_tokens = expert_tokens.to(torch.int64)
|
| 857 |
+
|
| 858 |
+
w1 = w1.transpose(1, 2)
|
| 859 |
+
gate_up_out = torch_npu.npu_grouped_matmul(
|
| 860 |
+
x=[sorted_tokens],
|
| 861 |
+
weight=[w1],
|
| 862 |
+
group_list=expert_tokens,
|
| 863 |
+
split_item=3,
|
| 864 |
+
group_type=0
|
| 865 |
+
)[0]
|
| 866 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
| 867 |
+
|
| 868 |
+
w2 = w2.transpose(1, 2)
|
| 869 |
+
down_out = torch_npu.npu_grouped_matmul(
|
| 870 |
+
x=[gate_up_out],
|
| 871 |
+
weight=[w2],
|
| 872 |
+
group_list=expert_tokens,
|
| 873 |
+
split_item=3,
|
| 874 |
+
group_type=0
|
| 875 |
+
)[0]
|
| 876 |
+
|
| 877 |
+
if is_prefill:
|
| 878 |
+
down_out[expert_tokens[-1]:] = 0
|
| 879 |
+
else:
|
| 880 |
+
sorted_tokens_mask = expanded_expert_idx != num_experts_per_ep
|
| 881 |
+
down_out *= sorted_tokens_mask.unsqueeze(1)
|
| 882 |
+
|
| 883 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 884 |
+
expanded_permuted_rows=down_out,
|
| 885 |
+
skip1=None,
|
| 886 |
+
skip2=None,
|
| 887 |
+
bias=None,
|
| 888 |
+
scales=topk_weights.to(down_out.dtype),
|
| 889 |
+
expanded_src_to_dst_row=expanded_src_to_dst_row,
|
| 890 |
+
export_for_source_row=topk_ids
|
| 891 |
+
)
|
| 892 |
+
return final_hidden_states
|
| 893 |
+
|
| 894 |
+
|
| 895 |
+
def select_gating_top_k_softmax_experts(
|
| 896 |
+
hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int,
|
| 897 |
+
renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
|
| 898 |
+
"""
|
| 899 |
+
Select top-k experts based on router logits.
|
| 900 |
+
only supports float16、bfloat16、float32
|
| 901 |
+
|
| 902 |
+
Args:
|
| 903 |
+
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
| 904 |
+
router_logits: Router logits of shape (num_tokens, num_experts).
|
| 905 |
+
top_k: Number of experts to select.
|
| 906 |
+
renormalize: Whether to renormalize the routing weights.
|
| 907 |
+
|
| 908 |
+
Returns:
|
| 909 |
+
topk_weights: Routing weights of shape (num_tokens, top_k).
|
| 910 |
+
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
| 911 |
+
|
| 912 |
+
Raises:
|
| 913 |
+
ValueError: If an unsupported scoring function is provided.
|
| 914 |
+
"""
|
| 915 |
+
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
|
| 916 |
+
router_logits, None, k=top_k)
|
| 917 |
+
|
| 918 |
+
# # Required by npu_moe_init_routing
|
| 919 |
+
# topk_weights = topk_weights.to(hidden_states.dtype)
|
| 920 |
+
# topk_ids = topk_ids.to(torch.int32)
|
| 921 |
+
|
| 922 |
+
if renormalize:
|
| 923 |
+
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
| 924 |
+
|
| 925 |
+
return topk_weights, topk_ids
|
| 926 |
+
|
| 927 |
+
|
| 928 |
+
def native_grouped_topk(
|
| 929 |
+
topk_weights: torch.Tensor,
|
| 930 |
+
num_expert_group: Optional[int],
|
| 931 |
+
topk_group: Optional[int],
|
| 932 |
+
):
|
| 933 |
+
topk_group = 0 if topk_group is None else topk_group
|
| 934 |
+
num_expert_group = 0 if num_expert_group is None else num_expert_group
|
| 935 |
+
|
| 936 |
+
num_token = topk_weights.shape[0]
|
| 937 |
+
grouped_weights = topk_weights.view(num_token, num_expert_group,
|
| 938 |
+
-1).max(dim=-1).values
|
| 939 |
+
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
|
| 940 |
+
k=topk_group,
|
| 941 |
+
dim=-1,
|
| 942 |
+
sorted=False)[1]
|
| 943 |
+
topk_group_mask = torch.zeros_like(grouped_weights)
|
| 944 |
+
topk_group_mask.scatter_(1, topk_group_indices, 1)
|
| 945 |
+
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
|
| 946 |
+
num_token, num_expert_group,
|
| 947 |
+
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
|
| 948 |
+
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
|
| 949 |
+
|
| 950 |
+
return topk_weights
|
| 951 |
+
|
| 952 |
+
|
| 953 |
+
def select_experts(
|
| 954 |
+
hidden_states: torch.Tensor,
|
| 955 |
+
router_logits: torch.Tensor,
|
| 956 |
+
top_k: int,
|
| 957 |
+
use_grouped_topk: bool,
|
| 958 |
+
renormalize: bool,
|
| 959 |
+
topk_group: Optional[int] = None,
|
| 960 |
+
num_expert_group: Optional[int] = None,
|
| 961 |
+
custom_routing_function: Optional[Callable] = None,
|
| 962 |
+
scoring_func: str = "softmax",
|
| 963 |
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
| 964 |
+
global_num_experts: Optional[torch.Tensor] = None
|
| 965 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 966 |
+
"""
|
| 967 |
+
Select top-k experts based on router logits.
|
| 968 |
+
|
| 969 |
+
Args:
|
| 970 |
+
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
| 971 |
+
router_logits: Router logits of shape (num_tokens, num_experts).
|
| 972 |
+
top_k: Number of experts to select.
|
| 973 |
+
use_grouped_topk: Whether to group experts before selecting top-k.
|
| 974 |
+
renormalize: Whether to renormalize the routing weights.
|
| 975 |
+
topk_group: Number of expert groups to select from.
|
| 976 |
+
num_expert_group: Number of experts in each group.
|
| 977 |
+
custom_routing_function: Custom routing function.
|
| 978 |
+
scoring_func: Scoring function to use.
|
| 979 |
+
e_score_correction_bias: Correction bias to apply to expert scores.
|
| 980 |
+
|
| 981 |
+
Returns:
|
| 982 |
+
topk_weights: Routing weights of shape (num_tokens, top_k).
|
| 983 |
+
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
| 984 |
+
|
| 985 |
+
Raises:
|
| 986 |
+
ValueError: If an unsupported scoring function is provided.
|
| 987 |
+
"""
|
| 988 |
+
|
| 989 |
+
if scoring_func == "softmax":
|
| 990 |
+
# NOTE: vLLM use dtype=torch.float here
|
| 991 |
+
topk_weights = router_logits.softmax(dim=-1)
|
| 992 |
+
elif scoring_func == "sigmoid":
|
| 993 |
+
topk_weights = router_logits.sigmoid()
|
| 994 |
+
else:
|
| 995 |
+
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
| 996 |
+
|
| 997 |
+
if use_grouped_topk:
|
| 998 |
+
assert topk_group is not None
|
| 999 |
+
assert num_expert_group is not None
|
| 1000 |
+
|
| 1001 |
+
if e_score_correction_bias is not None:
|
| 1002 |
+
# Store original scores before applying correction bias. We use biased
|
| 1003 |
+
# scores for expert selection but original scores for routing weights
|
| 1004 |
+
original_weights = topk_weights
|
| 1005 |
+
topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
|
| 1006 |
+
|
| 1007 |
+
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
|
| 1008 |
+
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
|
| 1009 |
+
topk_weights = native_grouped_topk(topk_weights, num_expert_group,
|
| 1010 |
+
topk_group)
|
| 1011 |
+
# TODO bfloat16 is not supported in torch.topk with ge graph.
|
| 1012 |
+
if e_score_correction_bias is not None:
|
| 1013 |
+
topk_ids = torch.topk(topk_weights.to(torch.float32),
|
| 1014 |
+
k=top_k,
|
| 1015 |
+
dim=-1,
|
| 1016 |
+
sorted=False)[1]
|
| 1017 |
+
# Use original unbiased scores for the routing weights
|
| 1018 |
+
topk_weights = original_weights.gather(1, topk_ids)
|
| 1019 |
+
else:
|
| 1020 |
+
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
|
| 1021 |
+
k=top_k,
|
| 1022 |
+
dim=-1,
|
| 1023 |
+
sorted=False)
|
| 1024 |
+
elif custom_routing_function is None:
|
| 1025 |
+
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
|
| 1026 |
+
else:
|
| 1027 |
+
topk_weights, topk_ids = custom_routing_function(
|
| 1028 |
+
hidden_states=hidden_states,
|
| 1029 |
+
gating_output=router_logits,
|
| 1030 |
+
topk=top_k,
|
| 1031 |
+
renormalize=renormalize,
|
| 1032 |
+
global_num_experts=global_num_experts)
|
| 1033 |
+
# Required by npu_moe_init_routing
|
| 1034 |
+
topk_ids = topk_ids.to(torch.int32)
|
| 1035 |
+
return topk_weights, topk_ids
|
| 1036 |
+
|
| 1037 |
+
# Required by npu_moe_init_routing
|
| 1038 |
+
topk_ids = topk_ids.to(torch.int32)
|
| 1039 |
+
|
| 1040 |
+
if renormalize:
|
| 1041 |
+
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
| 1042 |
+
|
| 1043 |
+
return topk_weights, topk_ids
|
| 1044 |
+
|
| 1045 |
+
|
| 1046 |
+
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
| 1047 |
+
|
| 1048 |
+
def __init__(self, moe: FusedMoEConfig = None):
|
| 1049 |
+
|
| 1050 |
+
super().__init__(moe=moe)
|
| 1051 |
+
vllm_config = get_current_vllm_config()
|
| 1052 |
+
|
| 1053 |
+
self.ep_group = get_ep_group()
|
| 1054 |
+
self.ep_size = self.ep_group.world_size
|
| 1055 |
+
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
| 1056 |
+
self.local_batch_size = self.global_batch_size // self.ep_size
|
| 1057 |
+
self.max_model_len = vllm_config.model_config.max_model_len
|
| 1058 |
+
|
| 1059 |
+
ascend_config = get_ascend_config()
|
| 1060 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 1061 |
+
|
| 1062 |
+
try:
|
| 1063 |
+
device_group = self.ep_group.device_group
|
| 1064 |
+
# TODO: Try local_rank = ep_group.rank_in_group
|
| 1065 |
+
local_rank = torch.distributed.get_rank(group=device_group)
|
| 1066 |
+
backend = device_group._get_backend(torch.device("npu"))
|
| 1067 |
+
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
| 1068 |
+
local_rank)
|
| 1069 |
+
except AttributeError:
|
| 1070 |
+
self.moe_all_to_all_group_name = None
|
| 1071 |
+
|
| 1072 |
+
def process_weights_after_loading(self, layer):
|
| 1073 |
+
super(UnquantizedFusedMoEMethod,
|
| 1074 |
+
self).process_weights_after_loading(layer)
|
| 1075 |
+
layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
| 1076 |
+
layer.w13_weight.data),
|
| 1077 |
+
requires_grad=False)
|
| 1078 |
+
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
| 1079 |
+
layer.w2_weight.data),
|
| 1080 |
+
requires_grad=False)
|
| 1081 |
+
|
| 1082 |
+
def apply(
|
| 1083 |
+
self,
|
| 1084 |
+
layer: torch.nn.Module,
|
| 1085 |
+
x: torch.Tensor,
|
| 1086 |
+
router_logits: torch.Tensor,
|
| 1087 |
+
top_k: int,
|
| 1088 |
+
renormalize: bool,
|
| 1089 |
+
use_grouped_topk: bool = False,
|
| 1090 |
+
global_num_experts: int = -1,
|
| 1091 |
+
expert_map: Optional[torch.Tensor] = None,
|
| 1092 |
+
topk_group: Optional[int] = None,
|
| 1093 |
+
num_expert_group: Optional[int] = None,
|
| 1094 |
+
custom_routing_function: Optional[Callable] = None,
|
| 1095 |
+
scoring_func: str = "softmax",
|
| 1096 |
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
| 1097 |
+
is_prefill: bool = False,
|
| 1098 |
+
enable_force_load_balance: bool = False,
|
| 1099 |
+
shared_experts: Optional[Any] = None,
|
| 1100 |
+
**kwargs,
|
| 1101 |
+
) -> torch.Tensor:
|
| 1102 |
+
use_grouped_topk = (topk_group > 1 or num_expert_group > 1)
|
| 1103 |
+
is_deepseek_v3_r1 = global_num_experts == 256
|
| 1104 |
+
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
| 1105 |
+
if use_grouped_topk and is_deepseek_v3_r1:
|
| 1106 |
+
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
| 1107 |
+
router_logits,
|
| 1108 |
+
k=top_k, # topk当前写8
|
| 1109 |
+
bias=e_score_correction_bias,
|
| 1110 |
+
k_group=topk_group, # fix: 4
|
| 1111 |
+
group_count=num_expert_group, # fix 8
|
| 1112 |
+
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
|
| 1113 |
+
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
| 1114 |
+
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
| 1115 |
+
# out_flag=False, # todo new api; 第三个输出是否输出
|
| 1116 |
+
# y2_flag=False, # old api; 第三个输出是否输出
|
| 1117 |
+
routed_scaling_factor=1,
|
| 1118 |
+
eps=float(1e-20))
|
| 1119 |
+
elif use_grouped_topk and SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
|
| 1120 |
+
topk_weights, topk_ids = select_gating_top_k_softmax_experts(
|
| 1121 |
+
hidden_states=x,
|
| 1122 |
+
router_logits=router_logits,
|
| 1123 |
+
top_k=top_k,
|
| 1124 |
+
renormalize=renormalize)
|
| 1125 |
+
else:
|
| 1126 |
+
topk_weights, topk_ids = select_experts(
|
| 1127 |
+
hidden_states=x,
|
| 1128 |
+
router_logits=router_logits,
|
| 1129 |
+
top_k=top_k,
|
| 1130 |
+
use_grouped_topk=use_grouped_topk,
|
| 1131 |
+
renormalize=renormalize,
|
| 1132 |
+
topk_group=topk_group,
|
| 1133 |
+
num_expert_group=num_expert_group,
|
| 1134 |
+
custom_routing_function=custom_routing_function,
|
| 1135 |
+
scoring_func=scoring_func,
|
| 1136 |
+
e_score_correction_bias=e_score_correction_bias,
|
| 1137 |
+
)
|
| 1138 |
+
|
| 1139 |
+
topk_weights = topk_weights.to(x.dtype)
|
| 1140 |
+
# this is a naive implementation for experts load balance so as
|
| 1141 |
+
# to avoid accumulating too much tokens on a single rank.
|
| 1142 |
+
# currently it is only activated when doing profile runs.
|
| 1143 |
+
if enable_force_load_balance:
|
| 1144 |
+
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
| 1145 |
+
|
| 1146 |
+
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
|
| 1147 |
+
is_prefill, is_deepseek_v3_r1)
|
| 1148 |
+
if fused_moe_state == FusedMoEState.MC2:
|
| 1149 |
+
return fused_experts_with_mc2(
|
| 1150 |
+
hidden_states=x,
|
| 1151 |
+
w1=layer.w13_weight,
|
| 1152 |
+
w2=layer.w2_weight,
|
| 1153 |
+
topk_weights=topk_weights,
|
| 1154 |
+
topk_ids=topk_ids,
|
| 1155 |
+
top_k=top_k,
|
| 1156 |
+
expert_map=expert_map,
|
| 1157 |
+
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
| 1158 |
+
shared_experts=shared_experts)
|
| 1159 |
+
elif fused_moe_state == FusedMoEState.AllGatherEP:
|
| 1160 |
+
return fused_experts_allgather_ep(
|
| 1161 |
+
hidden_states=x,
|
| 1162 |
+
w1=layer.w13_weight,
|
| 1163 |
+
w2=layer.w2_weight,
|
| 1164 |
+
topk_weights=topk_weights,
|
| 1165 |
+
topk_ids=topk_ids,
|
| 1166 |
+
is_prefill=is_prefill)
|
| 1167 |
+
elif fused_moe_state in [
|
| 1168 |
+
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
|
| 1169 |
+
]:
|
| 1170 |
+
return fused_experts(hidden_states=x,
|
| 1171 |
+
w1=layer.w13_weight,
|
| 1172 |
+
w2=layer.w2_weight,
|
| 1173 |
+
topk_weights=topk_weights,
|
| 1174 |
+
topk_ids=topk_ids,
|
| 1175 |
+
top_k=top_k,
|
| 1176 |
+
expert_map=expert_map)
|
| 1177 |
+
elif MOE_ALL2ALL_BUFFER:
|
| 1178 |
+
return fused_experts_with_all2all_buffer(
|
| 1179 |
+
hidden_states=x,
|
| 1180 |
+
w1=layer.w13_weight,
|
| 1181 |
+
w2=layer.w2_weight,
|
| 1182 |
+
topk_weights=topk_weights,
|
| 1183 |
+
topk_ids=topk_ids,
|
| 1184 |
+
top_k=top_k,
|
| 1185 |
+
max_model_len=self.max_model_len,
|
| 1186 |
+
global_batch_size=self.global_batch_size,
|
| 1187 |
+
expert_map=expert_map,
|
| 1188 |
+
ep_group=get_ep_group())
|
| 1189 |
+
else:
|
| 1190 |
+
return fused_experts_with_all2all(hidden_states=x,
|
| 1191 |
+
w1=layer.w13_weight,
|
| 1192 |
+
w2=layer.w2_weight,
|
| 1193 |
+
topk_weights=topk_weights,
|
| 1194 |
+
topk_ids=topk_ids,
|
| 1195 |
+
top_k=top_k,
|
| 1196 |
+
expert_map=expert_map,
|
| 1197 |
+
ep_group=get_ep_group())
|
| 1198 |
+
|
| 1199 |
+
|
| 1200 |
+
class AscendFusedMoE(FusedMoE):
|
| 1201 |
+
|
| 1202 |
+
# The moe_counter parameter is required during the initialization of EPLB
|
| 1203 |
+
# to identify the current layer index within the MOE model.
|
| 1204 |
+
moe_counter = -1
|
| 1205 |
+
|
| 1206 |
+
def __init__(
|
| 1207 |
+
self,
|
| 1208 |
+
num_experts: int, # Global number of experts
|
| 1209 |
+
top_k: int,
|
| 1210 |
+
hidden_size: int,
|
| 1211 |
+
intermediate_size: int,
|
| 1212 |
+
params_dtype: Optional[torch.dtype] = None,
|
| 1213 |
+
reduce_results: bool = False,
|
| 1214 |
+
renormalize: bool = True,
|
| 1215 |
+
use_grouped_topk: bool = False,
|
| 1216 |
+
num_expert_group: Optional[int] = None,
|
| 1217 |
+
topk_group: Optional[int] = None,
|
| 1218 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 1219 |
+
tp_size: Optional[int] = None,
|
| 1220 |
+
ep_size: Optional[int] = None,
|
| 1221 |
+
dp_size: Optional[int] = None,
|
| 1222 |
+
prefix: str = "",
|
| 1223 |
+
custom_routing_function: Optional[Callable] = None,
|
| 1224 |
+
scoring_func: str = "softmax",
|
| 1225 |
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
| 1226 |
+
activation: str = "silu",
|
| 1227 |
+
apply_router_weight_on_input: bool = False,
|
| 1228 |
+
):
|
| 1229 |
+
# TODO: This could not initialize FusedMoE baseclass,
|
| 1230 |
+
# fixme and make __init__() of AscendFusedMoE more clear
|
| 1231 |
+
super(FusedMoE, self).__init__()
|
| 1232 |
+
|
| 1233 |
+
AscendFusedMoE.moe_counter += 1
|
| 1234 |
+
self.moe_instance_id = AscendFusedMoE.moe_counter
|
| 1235 |
+
|
| 1236 |
+
if params_dtype is None:
|
| 1237 |
+
params_dtype = torch.get_default_dtype()
|
| 1238 |
+
|
| 1239 |
+
vllm_config = get_current_vllm_config()
|
| 1240 |
+
|
| 1241 |
+
self.moe_parallel_config = FusedMoEParallelConfig.make(
|
| 1242 |
+
tp_size_=(tp_size if tp_size is not None else
|
| 1243 |
+
get_tensor_model_parallel_world_size()),
|
| 1244 |
+
dp_size_=(dp_size
|
| 1245 |
+
if dp_size is not None else get_dp_group().world_size),
|
| 1246 |
+
vllm_parallel_config=vllm_config.parallel_config)
|
| 1247 |
+
|
| 1248 |
+
self.top_k = top_k
|
| 1249 |
+
self.num_experts = num_experts
|
| 1250 |
+
self.global_num_experts = num_experts
|
| 1251 |
+
assert intermediate_size % self.tp_size == 0
|
| 1252 |
+
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
| 1253 |
+
self.reduce_results = reduce_results
|
| 1254 |
+
self.renormalize = renormalize
|
| 1255 |
+
self.use_grouped_topk = use_grouped_topk
|
| 1256 |
+
if self.use_grouped_topk:
|
| 1257 |
+
assert num_expert_group is not None and topk_group is not None
|
| 1258 |
+
self.num_expert_group = num_expert_group
|
| 1259 |
+
self.topk_group = topk_group
|
| 1260 |
+
self.custom_routing_function = custom_routing_function
|
| 1261 |
+
self.scoring_func = scoring_func
|
| 1262 |
+
self.e_score_correction_bias = e_score_correction_bias
|
| 1263 |
+
self.expert_map = None
|
| 1264 |
+
self.activation = activation
|
| 1265 |
+
self.log2phy = None
|
| 1266 |
+
self.global_redundant_expert_num = 0
|
| 1267 |
+
|
| 1268 |
+
is_deepseek_v3_r1 = self.global_num_experts == 256
|
| 1269 |
+
self.rm_router_logits = get_rm_router_logits_state(
|
| 1270 |
+
self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1)
|
| 1271 |
+
self.all_reduce_merge = get_all_reduce_merge_state(
|
| 1272 |
+
self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
|
| 1273 |
+
|
| 1274 |
+
ascend_config = get_ascend_config()
|
| 1275 |
+
expert_map_path = ascend_config.expert_map_path
|
| 1276 |
+
if expert_map_path and os.path.exists(expert_map_path):
|
| 1277 |
+
# moe expert load balance
|
| 1278 |
+
expert_load_balancer = ExpertLoadBalancer(expert_map_path,
|
| 1279 |
+
self.global_num_experts)
|
| 1280 |
+
self.local_num_experts, self.expert_map = \
|
| 1281 |
+
expert_load_balancer.get_rank_placement_map(
|
| 1282 |
+
self.moe_instance_id,
|
| 1283 |
+
get_ep_group().rank_in_group)
|
| 1284 |
+
self.log2phy = expert_load_balancer.get_rank_log2phy_map(
|
| 1285 |
+
self.moe_instance_id,
|
| 1286 |
+
get_ep_group().rank_in_group)
|
| 1287 |
+
self.global_redundant_expert_num = \
|
| 1288 |
+
expert_load_balancer.get_global_redundant_expert_num()
|
| 1289 |
+
else:
|
| 1290 |
+
# Create a tensor of size num_experts filled with -1
|
| 1291 |
+
self.local_num_experts, self.expert_map = determine_expert_map(
|
| 1292 |
+
self.ep_size,
|
| 1293 |
+
get_ep_group().rank_in_group, self.global_num_experts)
|
| 1294 |
+
|
| 1295 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 1296 |
+
self.enable_multistream_moe = \
|
| 1297 |
+
ascend_config.torchair_graph_config.enable_multistream_moe
|
| 1298 |
+
|
| 1299 |
+
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
| 1300 |
+
raise ValueError("Only softmax scoring function is supported for "
|
| 1301 |
+
"non-grouped topk.")
|
| 1302 |
+
moe = FusedMoEConfig.make(
|
| 1303 |
+
num_experts=self.global_num_experts,
|
| 1304 |
+
experts_per_token=top_k,
|
| 1305 |
+
hidden_dim=hidden_size,
|
| 1306 |
+
num_local_experts=self.local_num_experts,
|
| 1307 |
+
moe_parallel_config=self.moe_parallel_config,
|
| 1308 |
+
# TODO (bnell): this needs to be fixed for quantized types.
|
| 1309 |
+
in_dtype=params_dtype,
|
| 1310 |
+
quant_config=quant_config)
|
| 1311 |
+
|
| 1312 |
+
if quant_config is None:
|
| 1313 |
+
self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
|
| 1314 |
+
else:
|
| 1315 |
+
self.quant_method = quant_config.get_quant_method(self, prefix)
|
| 1316 |
+
|
| 1317 |
+
assert self.quant_method is not None
|
| 1318 |
+
|
| 1319 |
+
local_num_experts = torch.sum(self.expert_map != -1) \
|
| 1320 |
+
if self.expert_map is not None else num_experts
|
| 1321 |
+
|
| 1322 |
+
moe_quant_params = {
|
| 1323 |
+
"num_experts": local_num_experts,
|
| 1324 |
+
"hidden_size": hidden_size,
|
| 1325 |
+
"intermediate_size_per_partition":
|
| 1326 |
+
self.intermediate_size_per_partition,
|
| 1327 |
+
"params_dtype": params_dtype,
|
| 1328 |
+
"weight_loader": self.weight_loader,
|
| 1329 |
+
}
|
| 1330 |
+
# need full intermediate size pre-sharding for WNA16 act order
|
| 1331 |
+
if (self.quant_method.__class__.__name__
|
| 1332 |
+
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
|
| 1333 |
+
moe_quant_params["intermediate_size_full"] = intermediate_size
|
| 1334 |
+
|
| 1335 |
+
self.ep_group = get_ep_group()
|
| 1336 |
+
# NOTE: self.tp_group is not expert_tp_group
|
| 1337 |
+
self.tp_group = get_tp_group().device_group
|
| 1338 |
+
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
| 1339 |
+
|
| 1340 |
+
def naive_multicast(self, x: torch.Tensor,
|
| 1341 |
+
cu_tokens_across_dp_cpu: torch.Tensor):
|
| 1342 |
+
assert (len(x.shape) == 2)
|
| 1343 |
+
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
|
| 1344 |
+
device=x.device,
|
| 1345 |
+
dtype=x.dtype)
|
| 1346 |
+
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
| 1347 |
+
self.dp_rank - 1]
|
| 1348 |
+
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
| 1349 |
+
buffer[start:end, :].copy_(x)
|
| 1350 |
+
for idx in range(self.dp_size):
|
| 1351 |
+
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
|
| 1352 |
+
end = cu_tokens_across_dp_cpu[idx]
|
| 1353 |
+
get_dp_group().broadcast(buffer[start:end, :], idx)
|
| 1354 |
+
return buffer
|
| 1355 |
+
|
| 1356 |
+
def forward(self,
|
| 1357 |
+
hidden_states: torch.Tensor,
|
| 1358 |
+
router_logits: torch.Tensor,
|
| 1359 |
+
is_prefill: bool,
|
| 1360 |
+
enable_force_load_balance: bool = False,
|
| 1361 |
+
top_k: Optional[int] = None,
|
| 1362 |
+
shared_experts: Optional[Any] = None,
|
| 1363 |
+
gate=None,
|
| 1364 |
+
replace_allreduce: bool = False):
|
| 1365 |
+
|
| 1366 |
+
assert self.quant_method is not None
|
| 1367 |
+
|
| 1368 |
+
if top_k:
|
| 1369 |
+
real_top_k = top_k
|
| 1370 |
+
else:
|
| 1371 |
+
real_top_k = self.top_k
|
| 1372 |
+
|
| 1373 |
+
num_tokens, hidden_size = hidden_states.shape
|
| 1374 |
+
is_deepseek_v3_r1 = self.global_num_experts == 256
|
| 1375 |
+
|
| 1376 |
+
fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
|
| 1377 |
+
is_prefill, is_deepseek_v3_r1)
|
| 1378 |
+
if shared_experts:
|
| 1379 |
+
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
|
| 1380 |
+
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
| 1381 |
+
shared_hidden_states = shared_experts(hidden_states)
|
| 1382 |
+
|
| 1383 |
+
tp_size = get_tensor_model_parallel_world_size()
|
| 1384 |
+
if (tp_size > 1 and fused_moe_state not in [
|
| 1385 |
+
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
| 1386 |
+
FusedMoEState.NaiveMulticast
|
| 1387 |
+
] and not replace_allreduce):
|
| 1388 |
+
if num_tokens < tp_size:
|
| 1389 |
+
hidden_states = nn.functional.pad(
|
| 1390 |
+
hidden_states, (0, 0, 0, tp_size - num_tokens))
|
| 1391 |
+
router_logits = nn.functional.pad(
|
| 1392 |
+
router_logits, (0, 0, 0, tp_size - num_tokens))
|
| 1393 |
+
chunk_hidden_states = torch.tensor_split(hidden_states,
|
| 1394 |
+
tp_size,
|
| 1395 |
+
dim=0)
|
| 1396 |
+
chunk_router_logits = torch.tensor_split(router_logits,
|
| 1397 |
+
tp_size,
|
| 1398 |
+
dim=0)
|
| 1399 |
+
tp_rank = get_tensor_model_parallel_rank()
|
| 1400 |
+
hidden_states = chunk_hidden_states[tp_rank]
|
| 1401 |
+
router_logits = chunk_router_logits[tp_rank]
|
| 1402 |
+
|
| 1403 |
+
if self.dp_size > 1:
|
| 1404 |
+
if fused_moe_state in (FusedMoEState.AllGather, FusedMoEState.AllGatherEP):
|
| 1405 |
+
# NOTE: When in torchair graph, it has been padded in model_runner_v1
|
| 1406 |
+
if not self.torchair_graph_enabled or is_prefill:
|
| 1407 |
+
attn_metadata = get_forward_context().attn_metadata
|
| 1408 |
+
if attn_metadata is not None:
|
| 1409 |
+
max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
|
| 1410 |
+
if num_tokens < max_num_tokens_across_dp:
|
| 1411 |
+
hidden_states = nn.functional.pad(
|
| 1412 |
+
hidden_states,
|
| 1413 |
+
(0, 0, 0,
|
| 1414 |
+
max_num_tokens_across_dp - num_tokens))
|
| 1415 |
+
if not self.rm_router_logits:
|
| 1416 |
+
router_logits = nn.functional.pad(
|
| 1417 |
+
router_logits,
|
| 1418 |
+
(0, 0, 0,
|
| 1419 |
+
max_num_tokens_across_dp - num_tokens))
|
| 1420 |
+
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
| 1421 |
+
if self.rm_router_logits:
|
| 1422 |
+
router_logits, _ = gate(hidden_states.float())
|
| 1423 |
+
else:
|
| 1424 |
+
router_logits = get_dp_group().all_gather(router_logits, 0)
|
| 1425 |
+
|
| 1426 |
+
elif fused_moe_state == FusedMoEState.NaiveMulticast:
|
| 1427 |
+
cu_tokens_across_dp_cpu = get_forward_context(
|
| 1428 |
+
).dp_metadata.cu_tokens_across_dp_cpu
|
| 1429 |
+
hidden_states = self.naive_multicast(hidden_states,
|
| 1430 |
+
cu_tokens_across_dp_cpu)
|
| 1431 |
+
if self.rm_router_logits:
|
| 1432 |
+
router_logits, _ = gate(hidden_states.float())
|
| 1433 |
+
else:
|
| 1434 |
+
router_logits = self.naive_multicast(
|
| 1435 |
+
router_logits, cu_tokens_across_dp_cpu)
|
| 1436 |
+
|
| 1437 |
+
# Matrix multiply.
|
| 1438 |
+
e_hidden_states = self.quant_method.apply(
|
| 1439 |
+
layer=self,
|
| 1440 |
+
x=hidden_states,
|
| 1441 |
+
router_logits=router_logits,
|
| 1442 |
+
top_k=real_top_k,
|
| 1443 |
+
renormalize=self.renormalize,
|
| 1444 |
+
use_grouped_topk=self.use_grouped_topk,
|
| 1445 |
+
global_num_experts=self.global_num_experts,
|
| 1446 |
+
expert_map=self.expert_map,
|
| 1447 |
+
topk_group=self.topk_group,
|
| 1448 |
+
num_expert_group=self.num_expert_group,
|
| 1449 |
+
custom_routing_function=self.custom_routing_function,
|
| 1450 |
+
scoring_func=self.scoring_func,
|
| 1451 |
+
e_score_correction_bias=self.e_score_correction_bias,
|
| 1452 |
+
is_prefill=is_prefill,
|
| 1453 |
+
enable_force_load_balance=enable_force_load_balance,
|
| 1454 |
+
log2phy=self.log2phy,
|
| 1455 |
+
global_redundant_expert_num=self.global_redundant_expert_num,
|
| 1456 |
+
shared_experts=shared_experts if self.torchair_graph_enabled
|
| 1457 |
+
and self.enable_multistream_moe and not is_prefill else None,
|
| 1458 |
+
)
|
| 1459 |
+
|
| 1460 |
+
if shared_experts:
|
| 1461 |
+
if isinstance(e_hidden_states, tuple):
|
| 1462 |
+
e_hidden_states, shared_hidden_states = e_hidden_states
|
| 1463 |
+
|
| 1464 |
+
if (tp_size > 1 and fused_moe_state not in [
|
| 1465 |
+
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
| 1466 |
+
FusedMoEState.NaiveMulticast
|
| 1467 |
+
] and not replace_allreduce):
|
| 1468 |
+
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
| 1469 |
+
self.tp_group)
|
| 1470 |
+
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
| 1471 |
+
if num_tokens < tp_size:
|
| 1472 |
+
final_hidden_states = final_hidden_states[:num_tokens]
|
| 1473 |
+
dispose_tensor(e_hidden_states)
|
| 1474 |
+
elif self.dp_size > 1:
|
| 1475 |
+
if fused_moe_state == FusedMoEState.NaiveMulticast:
|
| 1476 |
+
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
| 1477 |
+
self.dp_rank - 1]
|
| 1478 |
+
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
| 1479 |
+
final_hidden_states = get_dp_group().all_reduce(
|
| 1480 |
+
e_hidden_states)
|
| 1481 |
+
final_hidden_states = final_hidden_states[start:end, :]
|
| 1482 |
+
dispose_tensor(e_hidden_states)
|
| 1483 |
+
elif fused_moe_state in (FusedMoEState.AllGather, FusedMoEState.AllGatherEP):
|
| 1484 |
+
final_hidden_states = data_parallel_reduce_scatter(
|
| 1485 |
+
e_hidden_states, dim=0)
|
| 1486 |
+
final_hidden_states = final_hidden_states[:num_tokens]
|
| 1487 |
+
dispose_tensor(e_hidden_states)
|
| 1488 |
+
else:
|
| 1489 |
+
final_hidden_states = e_hidden_states
|
| 1490 |
+
|
| 1491 |
+
if tp_size > 1 and not self.all_reduce_merge and fused_moe_state in [
|
| 1492 |
+
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
| 1493 |
+
FusedMoEState.NaiveMulticast
|
| 1494 |
+
]:
|
| 1495 |
+
final_hidden_states = tensor_model_parallel_all_reduce(
|
| 1496 |
+
final_hidden_states)
|
| 1497 |
+
|
| 1498 |
+
if shared_experts:
|
| 1499 |
+
return final_hidden_states, shared_hidden_states
|
| 1500 |
+
else:
|
| 1501 |
+
return final_hidden_states
|
| 1502 |
+
|
| 1503 |
+
# ----------------------------------------- TBO-related --------------------------------------------
|
| 1504 |
+
|
| 1505 |
+
def _forward_ms_fused_moe_comp(
|
| 1506 |
+
self,
|
| 1507 |
+
hidden_states: torch.Tensor,
|
| 1508 |
+
router_logits: torch.Tensor,
|
| 1509 |
+
is_prefill: bool,
|
| 1510 |
+
real_top_k,
|
| 1511 |
+
enable_force_load_balance: bool = False,
|
| 1512 |
+
):
|
| 1513 |
+
hidden_states = self.quant_method.apply(
|
| 1514 |
+
layer=self,
|
| 1515 |
+
x=hidden_states,
|
| 1516 |
+
router_logits=router_logits,
|
| 1517 |
+
top_k=real_top_k,
|
| 1518 |
+
renormalize=self.renormalize,
|
| 1519 |
+
use_grouped_topk=self.use_grouped_topk,
|
| 1520 |
+
global_num_experts=self.global_num_experts,
|
| 1521 |
+
expert_map=self.expert_map,
|
| 1522 |
+
topk_group=self.topk_group,
|
| 1523 |
+
num_expert_group=self.num_expert_group,
|
| 1524 |
+
custom_routing_function=self.custom_routing_function,
|
| 1525 |
+
scoring_func=self.scoring_func,
|
| 1526 |
+
e_score_correction_bias=self.e_score_correction_bias,
|
| 1527 |
+
is_prefill=is_prefill,
|
| 1528 |
+
enable_force_load_balance=enable_force_load_balance)
|
| 1529 |
+
|
| 1530 |
+
return hidden_states
|
inference/vllm_ascend/patch/worker/patch_common/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# This file is a part of the vllm-ascend project.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
|
| 18 |
+
# patch_utils should be the first import, because it will be used by other
|
| 19 |
+
# patch files.
|
| 20 |
+
import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
|
| 21 |
+
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
|
| 22 |
+
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
|
| 23 |
+
import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
|
| 24 |
+
import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa
|
| 25 |
+
import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa
|
| 26 |
+
import vllm_ascend.patch.worker.patch_common.patch_config # noqa
|
| 27 |
+
import vllm_ascend.patch.worker.patch_common.patch_parsers # noqa
|
inference/vllm_ascend/patch/worker/patch_common/patch_config.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# This file is a part of the vllm-ascend project.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
from vllm.config import ModelConfig
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_attr_by_names(src_config, attrs, default_value):
|
| 21 |
+
for attr in attrs:
|
| 22 |
+
value = getattr(src_config, attr, 0)
|
| 23 |
+
if value > 0:
|
| 24 |
+
return value
|
| 25 |
+
return default_value
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _verify_with_expert_parallelism(self) -> None:
|
| 29 |
+
num_expert_names = [
|
| 30 |
+
"moe_num_experts", # Dbrx
|
| 31 |
+
"num_experts", # Jamba
|
| 32 |
+
"n_routed_experts", # DeepSeek
|
| 33 |
+
"num_local_experts", # Mixtral
|
| 34 |
+
"num_routed_experts", # Pangu
|
| 35 |
+
]
|
| 36 |
+
num_experts = 0
|
| 37 |
+
for name in num_expert_names:
|
| 38 |
+
num_experts = getattr(self.hf_text_config, name, 0)
|
| 39 |
+
if num_experts > 0:
|
| 40 |
+
break
|
| 41 |
+
if num_experts < 1:
|
| 42 |
+
raise ValueError(
|
| 43 |
+
"Number of experts in the model must be greater than 0 "
|
| 44 |
+
"when expert parallelism is enabled.")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@property
|
| 48 |
+
def is_deepseek_mla(self) -> bool:
|
| 49 |
+
kv_lora_dim_names = ['attention_kv_lora_dim', 'kv_lora_rank']
|
| 50 |
+
kv_lora_dim = get_attr_by_names(self.hf_text_config, kv_lora_dim_names, None)
|
| 51 |
+
if not hasattr(self.hf_text_config, "model_type"):
|
| 52 |
+
return False
|
| 53 |
+
elif self.hf_text_config.model_type in \
|
| 54 |
+
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'pangu_ultra_moe'):
|
| 55 |
+
return kv_lora_dim is not None
|
| 56 |
+
elif self.hf_text_config.model_type == 'eagle':
|
| 57 |
+
# if the model is an EAGLE module, check for the
|
| 58 |
+
# underlying architecture
|
| 59 |
+
return self.hf_text_config.model.model_type in \
|
| 60 |
+
('deepseek_v2', 'deepseek_v3', 'pangu_ultra_moe') \
|
| 61 |
+
and kv_lora_dim is not None
|
| 62 |
+
return False
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_head_size(self) -> int:
|
| 66 |
+
if self.is_deepseek_mla:
|
| 67 |
+
qk_rope_dim_names = ['attention_qk_rope_dim', 'qk_rope_head_dim']
|
| 68 |
+
kv_lora_dim_names = ['attention_kv_lora_dim', 'kv_lora_rank']
|
| 69 |
+
qk_rope_dim = get_attr_by_names(self.hf_text_config, qk_rope_dim_names, 0)
|
| 70 |
+
kv_lora_dim = get_attr_by_names(self.hf_text_config, kv_lora_dim_names, 0)
|
| 71 |
+
if self.use_mla:
|
| 72 |
+
return kv_lora_dim + qk_rope_dim
|
| 73 |
+
else:
|
| 74 |
+
qk_dim_names = ['attention_qk_dim', 'qk_nope_head_dim']
|
| 75 |
+
qk_dim = get_attr_by_names(self.hf_text_config, qk_dim_names, 0)
|
| 76 |
+
if qk_rope_dim and qk_dim:
|
| 77 |
+
return qk_rope_dim + qk_dim
|
| 78 |
+
if hasattr(self.hf_text_config,
|
| 79 |
+
"model_type") and (self.hf_text_config.model_type
|
| 80 |
+
== "zamba2"):
|
| 81 |
+
return self.hf_text_config.attention_head_dim
|
| 82 |
+
|
| 83 |
+
if self.is_attention_free:
|
| 84 |
+
return 0
|
| 85 |
+
|
| 86 |
+
# NOTE: Some configs may set head_dim=None in the config
|
| 87 |
+
if getattr(self.hf_text_config, "head_dim", None) is not None:
|
| 88 |
+
return self.hf_text_config.head_dim
|
| 89 |
+
|
| 90 |
+
# FIXME(woosuk): This may not be true for all models.
|
| 91 |
+
return (self.hf_text_config.hidden_size //
|
| 92 |
+
self.hf_text_config.num_attention_heads)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
ModelConfig._verify_with_expert_parallelism = _verify_with_expert_parallelism
|
| 96 |
+
ModelConfig.is_deepseek_mla = is_deepseek_mla
|
| 97 |
+
ModelConfig.get_head_size = get_head_size
|
inference/vllm_ascend/patch/worker/patch_common/patch_parsers.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# This file is a part of the vllm-ascend project.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
from vllm.entrypoints.openai import tool_parsers
|
| 20 |
+
from vllm_ascend.entrypoints.openai.tool_parsers import PanguToolParser
|
| 21 |
+
tool_parsers.__all__.append("PanguToolParser")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
from vllm import reasoning
|
| 25 |
+
from vllm_ascend.entrypoints.openai.reasoning_parsers import PanguReasoningParser
|
| 26 |
+
reasoning.__all__.append("PanguReasoningParser")
|
inference/vllm_ascend/patch/worker/patch_common/patch_sampler.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
# This file is a part of the vllm-ascend project.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
#
|
| 18 |
+
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch_npu
|
| 23 |
+
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
|
| 24 |
+
from vllm.v1.sample.sampler import Sampler, _SAMPLING_EPS
|
| 25 |
+
from vllm.v1.sample.metadata import SamplingMetadata
|
| 26 |
+
from vllm_ascend import envs
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def apply_top_k_top_p(
|
| 30 |
+
logits: torch.Tensor,
|
| 31 |
+
k: torch.Tensor,
|
| 32 |
+
p: torch.Tensor,
|
| 33 |
+
) -> torch.Tensor:
|
| 34 |
+
if p is not None and k is not None:
|
| 35 |
+
# npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
|
| 36 |
+
return torch_npu.npu_top_k_top_p(logits, p, k)
|
| 37 |
+
|
| 38 |
+
probs = logits.softmax(dim=-1)
|
| 39 |
+
probs_sort, _ = probs.sort(dim=-1, descending=False)
|
| 40 |
+
|
| 41 |
+
if k is not None:
|
| 42 |
+
top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, )
|
| 43 |
+
top_k_count = top_k_count.unsqueeze(dim=1)
|
| 44 |
+
top_k_cutoff = probs_sort.gather(-1, top_k_count)
|
| 45 |
+
|
| 46 |
+
# Make sure the no top-k rows are no-op.
|
| 47 |
+
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
|
| 48 |
+
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
|
| 49 |
+
|
| 50 |
+
elements_to_discard = probs < top_k_cutoff
|
| 51 |
+
logits.masked_fill_(elements_to_discard, -float("inf"))
|
| 52 |
+
|
| 53 |
+
if p is not None:
|
| 54 |
+
cumprob = torch.cumsum(probs_sort, dim=-1)
|
| 55 |
+
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
|
| 56 |
+
top_p_mask[:, -1] = False # at least one
|
| 57 |
+
|
| 58 |
+
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
|
| 59 |
+
top_p_cutoff = probs_sort.gather(-1, top_p_count)
|
| 60 |
+
elements_to_discard = probs < top_p_cutoff
|
| 61 |
+
logits.masked_fill_(elements_to_discard, -float("inf"))
|
| 62 |
+
|
| 63 |
+
return logits
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def topk_topp_forward_native(
|
| 67 |
+
self,
|
| 68 |
+
logits: torch.Tensor,
|
| 69 |
+
generators: dict[int, torch.Generator],
|
| 70 |
+
k: Optional[torch.Tensor],
|
| 71 |
+
p: Optional[torch.Tensor],
|
| 72 |
+
) -> torch.Tensor:
|
| 73 |
+
"""
|
| 74 |
+
PyTorch-native implementation of top-k and top-p sampling.
|
| 75 |
+
|
| 76 |
+
The logits tensor may be updated in-place.
|
| 77 |
+
"""
|
| 78 |
+
logits = apply_top_k_top_p(logits, k, p)
|
| 79 |
+
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
| 80 |
+
return random_sample(probs, generators)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def apply_top_n_sigma(
|
| 84 |
+
logits: torch.Tensor,
|
| 85 |
+
sampling_metadata: SamplingMetadata,
|
| 86 |
+
):
|
| 87 |
+
if sampling_metadata.no_top_n_sigma:
|
| 88 |
+
return logits
|
| 89 |
+
|
| 90 |
+
top_n_sigma = sampling_metadata.top_n_sigma[:, None]
|
| 91 |
+
top_n_sigma_mask = (top_n_sigma != -1)
|
| 92 |
+
filter_value = -3.4028e+38
|
| 93 |
+
max_vals, _ = logits.max(dim=-1, keepdim=True)
|
| 94 |
+
std_vals = logits.std(dim=-1, keepdim=True)
|
| 95 |
+
threshold = max_vals - top_n_sigma * std_vals
|
| 96 |
+
threshold[~top_n_sigma_mask] = filter_value
|
| 97 |
+
mask = (logits < threshold)
|
| 98 |
+
logits = torch.where(mask, filter_value, logits)
|
| 99 |
+
return logits
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def sample(
|
| 103 |
+
self,
|
| 104 |
+
logits: torch.Tensor,
|
| 105 |
+
sampling_metadata: SamplingMetadata,
|
| 106 |
+
) -> torch.Tensor:
|
| 107 |
+
"""Sample logits based on sampling metadata.
|
| 108 |
+
|
| 109 |
+
The various logits processing functions called in this method
|
| 110 |
+
may update the logits tensor in-place.
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
assert not (sampling_metadata.all_greedy
|
| 114 |
+
and sampling_metadata.all_random)
|
| 115 |
+
if sampling_metadata.all_random:
|
| 116 |
+
greedy_sampled = None
|
| 117 |
+
else:
|
| 118 |
+
greedy_sampled = self.greedy_sample(logits)
|
| 119 |
+
if sampling_metadata.all_greedy:
|
| 120 |
+
return greedy_sampled
|
| 121 |
+
|
| 122 |
+
assert sampling_metadata.temperature is not None
|
| 123 |
+
|
| 124 |
+
# Apply temperature.
|
| 125 |
+
logits = self.apply_temperature(logits, sampling_metadata.temperature)
|
| 126 |
+
|
| 127 |
+
# Apply logits processors that only apply to random sampling
|
| 128 |
+
# (argmax invariant)
|
| 129 |
+
for processor in sampling_metadata.logitsprocs.argmax_invariant:
|
| 130 |
+
logits = processor.apply(logits)
|
| 131 |
+
|
| 132 |
+
# Apply top_n_sigma
|
| 133 |
+
logits = apply_top_n_sigma(logits, sampling_metadata)
|
| 134 |
+
|
| 135 |
+
# Apply top_k and/or top_p.
|
| 136 |
+
random_sampled = self.topk_topp_sampler(
|
| 137 |
+
logits,
|
| 138 |
+
sampling_metadata.generators,
|
| 139 |
+
sampling_metadata.top_k,
|
| 140 |
+
sampling_metadata.top_p,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
if greedy_sampled is None:
|
| 144 |
+
return random_sampled
|
| 145 |
+
|
| 146 |
+
sampled = torch.where(
|
| 147 |
+
sampling_metadata.temperature < _SAMPLING_EPS,
|
| 148 |
+
greedy_sampled,
|
| 149 |
+
random_sampled,
|
| 150 |
+
out=greedy_sampled, # Reuse tensor
|
| 151 |
+
)
|
| 152 |
+
return sampled
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
|
| 156 |
+
TopKTopPSampler.forward_native = topk_topp_forward_native
|
| 157 |
+
|
| 158 |
+
if envs.VLLM_ASCEND_ENABLE_TOP_N_SIGMA:
|
| 159 |
+
Sampler.sample = sample
|
inference/vllm_ascend/quantization/w8a8.py
ADDED
|
@@ -0,0 +1,757 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# This file is a part of the vllm-ascend project.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
|
| 18 |
+
from typing import Any, Callable, Dict, Optional
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch_npu
|
| 22 |
+
from vllm.attention.backends.abstract import AttentionType
|
| 23 |
+
|
| 24 |
+
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
| 25 |
+
from vllm_ascend.distributed.parallel_state import get_ep_group
|
| 26 |
+
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def quant_per_tensor(in_tensor: torch.Tensor,
|
| 30 |
+
input_scale: torch.Tensor,
|
| 31 |
+
input_offset: torch.Tensor,
|
| 32 |
+
function=False):
|
| 33 |
+
return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
|
| 34 |
+
torch.qint8, -1, function)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class AscendW8A8LinearMethod:
|
| 38 |
+
"""Linear method for Ascend W8A8.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
w_sym: whether the linear weight is symmetrically quantized.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(self) -> None:
|
| 45 |
+
# aclnn quant matmul requires to transpose matrix B, set to true by default.
|
| 46 |
+
self.transpose_weight = not is_310p()
|
| 47 |
+
|
| 48 |
+
@staticmethod
|
| 49 |
+
def get_weight(
|
| 50 |
+
input_size: int,
|
| 51 |
+
output_size: int,
|
| 52 |
+
params_dtype: torch.dtype = torch.bfloat16,
|
| 53 |
+
) -> Dict[str, Any]:
|
| 54 |
+
params_dict = {
|
| 55 |
+
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
|
| 56 |
+
}
|
| 57 |
+
return params_dict
|
| 58 |
+
|
| 59 |
+
@staticmethod
|
| 60 |
+
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 61 |
+
params_dict = {}
|
| 62 |
+
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
|
| 63 |
+
params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
|
| 64 |
+
return params_dict
|
| 65 |
+
|
| 66 |
+
@staticmethod
|
| 67 |
+
def get_perchannel_param(
|
| 68 |
+
output_size: int,
|
| 69 |
+
params_dtype: torch.dtype,
|
| 70 |
+
) -> Dict[str, Any]:
|
| 71 |
+
params_dict = {}
|
| 72 |
+
params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
|
| 73 |
+
if params_dtype == torch.bfloat16:
|
| 74 |
+
params_dict["deq_scale"] = torch.empty(output_size,
|
| 75 |
+
dtype=torch.float32)
|
| 76 |
+
elif params_dtype == torch.float16:
|
| 77 |
+
params_dict["deq_scale"] = torch.empty(output_size,
|
| 78 |
+
dtype=torch.int64)
|
| 79 |
+
params_dict["weight_scale"] = torch.empty(output_size,
|
| 80 |
+
1,
|
| 81 |
+
dtype=params_dtype)
|
| 82 |
+
params_dict["weight_offset"] = torch.empty(output_size,
|
| 83 |
+
1,
|
| 84 |
+
dtype=params_dtype)
|
| 85 |
+
return params_dict
|
| 86 |
+
|
| 87 |
+
@staticmethod
|
| 88 |
+
def apply(
|
| 89 |
+
layer: torch.nn.Module,
|
| 90 |
+
x: torch.Tensor,
|
| 91 |
+
bias: Optional[torch.Tensor] = None,
|
| 92 |
+
tp_rank: Optional[int] = 0,
|
| 93 |
+
) -> torch.Tensor:
|
| 94 |
+
original_dtype = x.dtype
|
| 95 |
+
if original_dtype != torch.int8:
|
| 96 |
+
x = quant_per_tensor(x, layer.aclnn_input_scale,
|
| 97 |
+
layer.aclnn_input_offset)
|
| 98 |
+
quant_bias = layer.quant_bias if tp_rank == 0 else None
|
| 99 |
+
if is_310p():
|
| 100 |
+
# On 300I Duo platform, we need transpose again if
|
| 101 |
+
# using nz. This transpose can be skipped in torchair.
|
| 102 |
+
output = torch_npu.npu_quant_matmul(
|
| 103 |
+
x,
|
| 104 |
+
layer.weight.data.transpose(1, 0),
|
| 105 |
+
layer.deq_scale,
|
| 106 |
+
bias=quant_bias,
|
| 107 |
+
output_dtype=original_dtype,
|
| 108 |
+
)
|
| 109 |
+
else:
|
| 110 |
+
output = torch_npu.npu_quant_matmul(
|
| 111 |
+
x,
|
| 112 |
+
layer.weight,
|
| 113 |
+
layer.deq_scale,
|
| 114 |
+
bias=quant_bias,
|
| 115 |
+
output_dtype=original_dtype,
|
| 116 |
+
)
|
| 117 |
+
return output
|
| 118 |
+
|
| 119 |
+
def process_weights_after_loading(self, layer):
|
| 120 |
+
expanding_factor = layer.weight.data.shape[1]
|
| 121 |
+
layer.aclnn_input_scale = 1 / torch.nn.Parameter(
|
| 122 |
+
layer.input_scale.data.repeat(expanding_factor),
|
| 123 |
+
requires_grad=False)
|
| 124 |
+
layer.aclnn_input_offset = torch.nn.Parameter(
|
| 125 |
+
layer.input_offset.data.repeat(expanding_factor),
|
| 126 |
+
requires_grad=False).to(layer.aclnn_input_scale.dtype)
|
| 127 |
+
if self.transpose_weight:
|
| 128 |
+
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
| 129 |
+
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
|
| 130 |
+
ACL_FORMAT_FRACTAL_NZ)
|
| 131 |
+
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
| 132 |
+
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class AscendW8A8FusedMoEMethod:
|
| 136 |
+
"""FusedMoe method for Ascend W8A8.
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
def __init__(self):
|
| 140 |
+
self.transpose_weight = True
|
| 141 |
+
|
| 142 |
+
@staticmethod
|
| 143 |
+
def get_weight(num_experts: int, intermediate_size_per_partition: int,
|
| 144 |
+
hidden_sizes: int,
|
| 145 |
+
params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 146 |
+
param_dict = {}
|
| 147 |
+
param_dict["w13_weight"] = torch.empty(num_experts,
|
| 148 |
+
2 *
|
| 149 |
+
intermediate_size_per_partition,
|
| 150 |
+
hidden_sizes,
|
| 151 |
+
dtype=torch.int8,
|
| 152 |
+
requires_grad=False)
|
| 153 |
+
param_dict["w2_weight"] = torch.empty(num_experts,
|
| 154 |
+
hidden_sizes,
|
| 155 |
+
intermediate_size_per_partition,
|
| 156 |
+
dtype=torch.int8,
|
| 157 |
+
requires_grad=False)
|
| 158 |
+
return param_dict
|
| 159 |
+
|
| 160 |
+
@staticmethod
|
| 161 |
+
def get_dynamic_quant_param(num_experts: int,
|
| 162 |
+
intermediate_size_per_partition: int,
|
| 163 |
+
hidden_sizes: int,
|
| 164 |
+
params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 165 |
+
param_dict = {}
|
| 166 |
+
param_dict["w13_weight_scale"] = torch.empty(
|
| 167 |
+
num_experts,
|
| 168 |
+
2 * intermediate_size_per_partition,
|
| 169 |
+
1,
|
| 170 |
+
dtype=torch.float32)
|
| 171 |
+
param_dict["w13_weight_offset"] = torch.empty(
|
| 172 |
+
num_experts,
|
| 173 |
+
2 * intermediate_size_per_partition,
|
| 174 |
+
1,
|
| 175 |
+
dtype=torch.float16)
|
| 176 |
+
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
| 177 |
+
hidden_sizes,
|
| 178 |
+
1,
|
| 179 |
+
dtype=torch.float32)
|
| 180 |
+
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
| 181 |
+
hidden_sizes,
|
| 182 |
+
1,
|
| 183 |
+
dtype=torch.float16)
|
| 184 |
+
param_dict["w2_deq_scale"] = torch.empty(num_experts,
|
| 185 |
+
hidden_sizes,
|
| 186 |
+
dtype=torch.float32)
|
| 187 |
+
param_dict["w13_deq_scale"] = torch.empty(
|
| 188 |
+
num_experts,
|
| 189 |
+
2 * intermediate_size_per_partition,
|
| 190 |
+
dtype=torch.float32)
|
| 191 |
+
param_dict["w2_input_scale"] = torch.empty(num_experts,
|
| 192 |
+
1,
|
| 193 |
+
dtype=torch.float32)
|
| 194 |
+
param_dict["w13_input_scale"] = torch.empty(num_experts,
|
| 195 |
+
1,
|
| 196 |
+
dtype=torch.float32)
|
| 197 |
+
param_dict["w2_input_offset"] = torch.empty(num_experts,
|
| 198 |
+
1,
|
| 199 |
+
dtype=torch.int8)
|
| 200 |
+
param_dict["w13_input_offset"] = torch.empty(num_experts,
|
| 201 |
+
1,
|
| 202 |
+
dtype=torch.int8)
|
| 203 |
+
param_dict["quant_bias"] = torch.empty(num_experts,
|
| 204 |
+
hidden_sizes,
|
| 205 |
+
dtype=torch.int32)
|
| 206 |
+
|
| 207 |
+
return param_dict
|
| 208 |
+
|
| 209 |
+
def apply(
|
| 210 |
+
self,
|
| 211 |
+
layer: torch.nn.Module,
|
| 212 |
+
x: torch.Tensor,
|
| 213 |
+
router_logits: torch.Tensor,
|
| 214 |
+
top_k: int,
|
| 215 |
+
renormalize: bool,
|
| 216 |
+
use_grouped_topk: bool = False,
|
| 217 |
+
global_num_experts: int = -1,
|
| 218 |
+
expert_map: Optional[torch.Tensor] = None,
|
| 219 |
+
topk_group: Optional[int] = None,
|
| 220 |
+
num_expert_group: Optional[int] = None,
|
| 221 |
+
custom_routing_function: Optional[Callable] = None,
|
| 222 |
+
scoring_func: str = "softmax",
|
| 223 |
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
| 224 |
+
is_prefill: bool = True,
|
| 225 |
+
enable_force_load_balance: bool = False,
|
| 226 |
+
log2phy: torch.Tensor = None,
|
| 227 |
+
global_redundant_expert_num: int = 0,
|
| 228 |
+
shared_experts: Optional[Any] = None,
|
| 229 |
+
**kwargs,
|
| 230 |
+
) -> torch.Tensor:
|
| 231 |
+
assert router_logits.shape[
|
| 232 |
+
1] == global_num_experts, "Number of global experts mismatch"
|
| 233 |
+
|
| 234 |
+
topk_weights, topk_ids = select_experts(
|
| 235 |
+
hidden_states=x,
|
| 236 |
+
router_logits=router_logits,
|
| 237 |
+
top_k=top_k,
|
| 238 |
+
use_grouped_topk=use_grouped_topk,
|
| 239 |
+
renormalize=renormalize,
|
| 240 |
+
topk_group=topk_group,
|
| 241 |
+
num_expert_group=num_expert_group,
|
| 242 |
+
custom_routing_function=custom_routing_function,
|
| 243 |
+
scoring_func=scoring_func,
|
| 244 |
+
e_score_correction_bias=e_score_correction_bias,
|
| 245 |
+
global_num_experts=global_num_experts,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
if is_310p():
|
| 249 |
+
return fused_experts_310p(hidden_states=x,
|
| 250 |
+
w1=layer.w13_weight,
|
| 251 |
+
w1_scale=layer.w13_weight_scale,
|
| 252 |
+
w1_input_scale=layer.w13_input_scale,
|
| 253 |
+
w2=layer.w2_weight,
|
| 254 |
+
w2_scale=layer.w2_weight_scale,
|
| 255 |
+
w2_input_scale=layer.w2_input_scale,
|
| 256 |
+
topk_weights=topk_weights,
|
| 257 |
+
topk_ids=topk_ids,
|
| 258 |
+
top_k=top_k,
|
| 259 |
+
global_num_experts=global_num_experts,
|
| 260 |
+
expert_map=expert_map)
|
| 261 |
+
return fused_experts(hidden_states=x,
|
| 262 |
+
w1=layer.w13_weight,
|
| 263 |
+
w1_scale=layer.w13_weight_scale,
|
| 264 |
+
w1_input_scale=layer.w13_input_scale,
|
| 265 |
+
w1_input_offset=layer.w13_input_offset,
|
| 266 |
+
w2=layer.w2_weight,
|
| 267 |
+
w2_scale=layer.w2_weight_scale,
|
| 268 |
+
w2_input_scale=layer.w2_input_scale,
|
| 269 |
+
w2_input_offset=layer.w2_input_offset,
|
| 270 |
+
topk_weights=topk_weights,
|
| 271 |
+
topk_ids=topk_ids,
|
| 272 |
+
top_k=top_k,
|
| 273 |
+
global_num_experts=global_num_experts,
|
| 274 |
+
expert_map=expert_map)
|
| 275 |
+
|
| 276 |
+
def process_weights_after_loading(self, layer):
|
| 277 |
+
if not is_310p():
|
| 278 |
+
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
| 279 |
+
1, 2).contiguous()
|
| 280 |
+
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
| 281 |
+
1, 2).contiguous()
|
| 282 |
+
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
| 283 |
+
layer.w13_weight_scale.data.shape[0], -1)
|
| 284 |
+
|
| 285 |
+
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
|
| 286 |
+
layer.w13_weight_offset.data.shape[0], -1)
|
| 287 |
+
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
|
| 288 |
+
layer.w2_weight_scale.data.shape[0], -1)
|
| 289 |
+
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
| 290 |
+
layer.w2_weight_offset.data.shape[0], -1)
|
| 291 |
+
expanding_factor_w13 = layer.w13_weight.data.shape[1]
|
| 292 |
+
expanding_factor_w2 = layer.w2_weight.data.shape[1]
|
| 293 |
+
|
| 294 |
+
if is_310p():
|
| 295 |
+
layer.w13_input_scale.data = torch.nn.Parameter(
|
| 296 |
+
layer.w13_input_scale.data.max())
|
| 297 |
+
layer.w2_input_scale.data = torch.nn.Parameter(
|
| 298 |
+
layer.w2_input_scale.data.max())
|
| 299 |
+
else:
|
| 300 |
+
layer.w13_input_scale.data = torch.nn.Parameter(
|
| 301 |
+
layer.w13_input_scale.data.repeat(1,
|
| 302 |
+
expanding_factor_w13)[0:1])
|
| 303 |
+
layer.w2_input_scale.data = torch.nn.Parameter(
|
| 304 |
+
layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
|
| 305 |
+
|
| 306 |
+
layer.w13_input_offset.data = torch.nn.Parameter(
|
| 307 |
+
layer.w13_input_scale.data.repeat(1, expanding_factor_w13)[0:1])
|
| 308 |
+
layer.w2_input_offset.data = torch.nn.Parameter(
|
| 309 |
+
layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
|
| 310 |
+
|
| 311 |
+
# converting ACL_FORMAT_FRACTAL_NZ.
|
| 312 |
+
# npu_quant_grouped_matmul_dequant in eager mode does not accept
|
| 313 |
+
# ACL_FORMAT_FRACTAL_NZ.
|
| 314 |
+
if not is_310p():
|
| 315 |
+
layer.w13_weight.data = torch_npu.npu_format_cast(
|
| 316 |
+
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
|
| 317 |
+
layer.w2_weight.data = torch_npu.npu_format_cast(
|
| 318 |
+
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class AscendC8KVCacheMethod:
|
| 322 |
+
|
| 323 |
+
def __init__(self) -> None:
|
| 324 |
+
self.antiquant_scale_comb = None
|
| 325 |
+
|
| 326 |
+
@staticmethod
|
| 327 |
+
def create_weights(layer) -> None:
|
| 328 |
+
param_dict = {} # num_kv_heads * head_size
|
| 329 |
+
param_dict["key_antiquant_scale"] = torch.empty(layer.num_kv_heads *
|
| 330 |
+
layer.head_size,
|
| 331 |
+
dtype=torch.float16,
|
| 332 |
+
requires_grad=False)
|
| 333 |
+
param_dict["value_antiquant_scale"] = torch.empty(layer.num_kv_heads *
|
| 334 |
+
layer.head_size,
|
| 335 |
+
dtype=torch.float16,
|
| 336 |
+
requires_grad=False)
|
| 337 |
+
for weight_name, weight_param in param_dict.items():
|
| 338 |
+
param = torch.nn.Parameter(weight_param, requires_grad=False)
|
| 339 |
+
layer.register_parameter(weight_name, param)
|
| 340 |
+
|
| 341 |
+
def process_weights_after_loading(self, layer):
|
| 342 |
+
self.antiquant_scale_comb = torch.cat(
|
| 343 |
+
(layer.key_antiquant_scale.data.unsqueeze(0),
|
| 344 |
+
layer.value_antiquant_scale.data.unsqueeze(0)),
|
| 345 |
+
dim=0).to(torch.float16).contiguous()
|
| 346 |
+
|
| 347 |
+
def apply(self, layer, query, key, value, kv_cache, attn_metadata,
|
| 348 |
+
attn_type, scale, output) -> torch.Tensor:
|
| 349 |
+
num_tokens = query.shape[0]
|
| 350 |
+
if attn_metadata is None:
|
| 351 |
+
return output.view(num_tokens, layer.num_heads * layer.head_size)
|
| 352 |
+
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
| 353 |
+
if attn_type != AttentionType.DECODER:
|
| 354 |
+
raise NotImplementedError("Encoder self-attention and "
|
| 355 |
+
"encoder/decoder cross-attention "
|
| 356 |
+
"are not implemented for "
|
| 357 |
+
"PallasAttentionBackendImpl")
|
| 358 |
+
|
| 359 |
+
# C8
|
| 360 |
+
quant_key = quant_per_tensor(
|
| 361 |
+
key.view(-1, layer.num_kv_heads * layer.head_size),
|
| 362 |
+
layer.key_antiquant_scale.data.view(-1), None, True)
|
| 363 |
+
quant_value = quant_per_tensor(
|
| 364 |
+
value.view(-1, layer.num_kv_heads * layer.head_size),
|
| 365 |
+
layer.value_antiquant_scale.data.view(-1), None, True)
|
| 366 |
+
|
| 367 |
+
# View q k v to BSH.
|
| 368 |
+
query = query.view(-1, layer.num_heads, layer.head_size)
|
| 369 |
+
key = key.view(-1, layer.num_kv_heads, layer.head_size)
|
| 370 |
+
value = value.view(-1, layer.num_kv_heads, layer.head_size)
|
| 371 |
+
# TODO: Remove this contiguous in the future.
|
| 372 |
+
value = value.contiguous()
|
| 373 |
+
|
| 374 |
+
if kv_cache[0].numel() > 0:
|
| 375 |
+
# if key_cache is None:
|
| 376 |
+
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
| 377 |
+
slots = attn_metadata.slot_mapping
|
| 378 |
+
|
| 379 |
+
block_size = key_cache.shape[1]
|
| 380 |
+
slots_indices = slots.reshape(-1, 1)
|
| 381 |
+
block_indices = slots_indices // block_size
|
| 382 |
+
slots_indices = slots_indices % block_size
|
| 383 |
+
indices = torch.cat((block_indices, slots_indices), dim=1)
|
| 384 |
+
|
| 385 |
+
# C8
|
| 386 |
+
torch_npu.npu_scatter_nd_update_(key_cache, indices, quant_key)
|
| 387 |
+
torch_npu.npu_scatter_nd_update_(value_cache, indices, quant_value)
|
| 388 |
+
|
| 389 |
+
# V0-Style scheduler situation.
|
| 390 |
+
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
| 391 |
+
assert attn_metadata is not None
|
| 392 |
+
assert attn_metadata.attn_mask is not None
|
| 393 |
+
mask = attn_metadata.attn_mask
|
| 394 |
+
torch_npu._npu_flash_attention(query=query,
|
| 395 |
+
key=key,
|
| 396 |
+
value=value,
|
| 397 |
+
mask=mask,
|
| 398 |
+
seq_len=attn_metadata.seq_lens,
|
| 399 |
+
scale_value=scale,
|
| 400 |
+
num_heads=layer.num_heads,
|
| 401 |
+
num_kv_heads=layer.num_kv_heads,
|
| 402 |
+
out=output.reshape(query.shape))
|
| 403 |
+
|
| 404 |
+
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
| 405 |
+
raise NotImplementedError("kv cache int8 are not "
|
| 406 |
+
"implemented for "
|
| 407 |
+
"PrefillCacheHit")
|
| 408 |
+
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: # changed attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
| 409 |
+
if hasattr(attn_metadata, "decode"):
|
| 410 |
+
# torch_air
|
| 411 |
+
decode_meta = attn_metadata.decode
|
| 412 |
+
seq_lens = decode_meta.seq_lens_list
|
| 413 |
+
else:
|
| 414 |
+
seq_lens = attn_metadata.seq_lens
|
| 415 |
+
block_size = key_cache.shape[1]
|
| 416 |
+
query = query.view(num_tokens, 1, layer.num_heads *
|
| 417 |
+
layer.head_size).contiguous() # changed
|
| 418 |
+
|
| 419 |
+
# [num_blocks, block_size, N, D] --> [num_blocks, N, block_size, D]
|
| 420 |
+
key = key_cache
|
| 421 |
+
value = value_cache
|
| 422 |
+
|
| 423 |
+
output = torch_npu.npu_incre_flash_attention(
|
| 424 |
+
query,
|
| 425 |
+
key,
|
| 426 |
+
value,
|
| 427 |
+
num_key_value_heads=layer.num_kv_heads,
|
| 428 |
+
num_heads=layer.num_heads,
|
| 429 |
+
actual_seq_lengths=seq_lens,
|
| 430 |
+
scale_value=scale,
|
| 431 |
+
input_layout='BSH',
|
| 432 |
+
block_size=block_size,
|
| 433 |
+
block_table=attn_metadata.block_tables,
|
| 434 |
+
antiquant_scale=self.antiquant_scale_comb,
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
# Normal V1 situation.
|
| 438 |
+
else:
|
| 439 |
+
raise NotImplementedError("kv cache int8 are not "
|
| 440 |
+
"implemented for "
|
| 441 |
+
"other case")
|
| 442 |
+
return output
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def fused_experts_310p(
|
| 446 |
+
hidden_states: torch.Tensor,
|
| 447 |
+
w1: torch.Tensor,
|
| 448 |
+
w1_scale: torch.Tensor,
|
| 449 |
+
w1_input_scale: torch.Tensor,
|
| 450 |
+
w2: torch.Tensor,
|
| 451 |
+
w2_scale: torch.Tensor,
|
| 452 |
+
w2_input_scale: torch.Tensor,
|
| 453 |
+
topk_weights: torch.Tensor,
|
| 454 |
+
topk_ids: torch.Tensor,
|
| 455 |
+
top_k: int,
|
| 456 |
+
global_num_experts: int,
|
| 457 |
+
expert_map: torch.Tensor = None,
|
| 458 |
+
) -> torch.Tensor:
|
| 459 |
+
ep_size = get_ep_group().world_size
|
| 460 |
+
local_num_experts = global_num_experts // ep_size
|
| 461 |
+
local_num_group = top_k // ep_size
|
| 462 |
+
|
| 463 |
+
bsz, _ = hidden_states.shape
|
| 464 |
+
flatten_topk_ids = topk_ids.view(-1)
|
| 465 |
+
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
| 466 |
+
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
|
| 467 |
+
sorted_hidden_states = hidden_states.index_select(
|
| 468 |
+
0, sorted_topk_ids // local_num_group)
|
| 469 |
+
|
| 470 |
+
experts_id = torch.arange(0,
|
| 471 |
+
local_num_experts,
|
| 472 |
+
dtype=topk_ids.dtype,
|
| 473 |
+
device=topk_ids.device)
|
| 474 |
+
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
| 475 |
+
torch.float32).sum(0)
|
| 476 |
+
topk_scales = topk_weights.view(-1).index_select(
|
| 477 |
+
0, sorted_topk_ids).unsqueeze(-1)
|
| 478 |
+
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
| 479 |
+
|
| 480 |
+
gate_up_out = torch_npu.npu_quant_grouped_matmul_dequant(
|
| 481 |
+
x=sorted_hidden_states,
|
| 482 |
+
quantized_weight=w1,
|
| 483 |
+
weight_scale=w1_scale,
|
| 484 |
+
group_list=group_list,
|
| 485 |
+
x_scale=w1_input_scale,
|
| 486 |
+
quant_mode="pertensor")
|
| 487 |
+
|
| 488 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
| 489 |
+
torch.float16)
|
| 490 |
+
gate_up_out *= topk_scales
|
| 491 |
+
|
| 492 |
+
down_out = torch_npu.npu_quant_grouped_matmul_dequant(
|
| 493 |
+
x=gate_up_out,
|
| 494 |
+
quantized_weight=w2,
|
| 495 |
+
weight_scale=w2_scale,
|
| 496 |
+
group_list=group_list,
|
| 497 |
+
x_scale=w2_input_scale,
|
| 498 |
+
quant_mode="pertensor")
|
| 499 |
+
|
| 500 |
+
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
|
| 501 |
+
unsorted_hidden_states = down_out.index_select(0, unsorted_topk_ids)
|
| 502 |
+
final_hidden_states = unsorted_hidden_states.reshape(
|
| 503 |
+
bsz, top_k // ep_size, -1).sum(1)
|
| 504 |
+
|
| 505 |
+
return final_hidden_states
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
def fused_experts(
|
| 509 |
+
hidden_states: torch.Tensor,
|
| 510 |
+
w1: torch.Tensor,
|
| 511 |
+
w1_scale: torch.Tensor,
|
| 512 |
+
w1_input_scale: torch.Tensor,
|
| 513 |
+
w1_input_offset: torch.Tensor,
|
| 514 |
+
w2: torch.Tensor,
|
| 515 |
+
w2_scale: torch.Tensor,
|
| 516 |
+
w2_input_scale: torch.Tensor,
|
| 517 |
+
w2_input_offset: torch.Tensor,
|
| 518 |
+
topk_weights: torch.Tensor,
|
| 519 |
+
topk_ids: torch.Tensor,
|
| 520 |
+
top_k: int,
|
| 521 |
+
global_num_experts: int,
|
| 522 |
+
expert_map: torch.Tensor = None,
|
| 523 |
+
) -> torch.Tensor:
|
| 524 |
+
"""
|
| 525 |
+
Fused experts with top-k routing.
|
| 526 |
+
|
| 527 |
+
Args:
|
| 528 |
+
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
| 529 |
+
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
| 530 |
+
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
| 531 |
+
topk_weights: Routing weights of shape (num_tokens, top_k).
|
| 532 |
+
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
| 533 |
+
top_k: Number of experts to select.
|
| 534 |
+
expert_map: Expert mapping of shape (num_experts,).
|
| 535 |
+
|
| 536 |
+
Returns:
|
| 537 |
+
hidden_states: Hidden states after routing.
|
| 538 |
+
"""
|
| 539 |
+
"""
|
| 540 |
+
# Check constraints.
|
| 541 |
+
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
| 542 |
+
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
| 543 |
+
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
| 544 |
+
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
| 545 |
+
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
| 546 |
+
"""
|
| 547 |
+
|
| 548 |
+
original_dtype = hidden_states.dtype
|
| 549 |
+
ep_size = get_ep_group().world_size
|
| 550 |
+
local_num_experts = global_num_experts // ep_size
|
| 551 |
+
w1_input_scale, _ = w1_input_scale.max(0)
|
| 552 |
+
quant_sorted_hidden_states = quant_per_tensor(
|
| 553 |
+
hidden_states,
|
| 554 |
+
w1_input_scale,
|
| 555 |
+
None,
|
| 556 |
+
True,
|
| 557 |
+
)
|
| 558 |
+
if expert_map is not None:
|
| 559 |
+
expanded_x, expanded_row_idx, expert_token_count, expanded_scale = torch_npu.npu_moe_init_routing_v2(
|
| 560 |
+
quant_sorted_hidden_states,
|
| 561 |
+
topk_ids,
|
| 562 |
+
scale=None,
|
| 563 |
+
active_num=topk_ids.numel(),
|
| 564 |
+
expert_capacity=-1,
|
| 565 |
+
expert_num=local_num_experts,
|
| 566 |
+
drop_pad_mode=0,
|
| 567 |
+
expert_tokens_num_type=1,
|
| 568 |
+
expert_tokens_num_flag=True,
|
| 569 |
+
quant_mode=-1,
|
| 570 |
+
active_expert_range=[0, local_num_experts],
|
| 571 |
+
row_idx_type=0,
|
| 572 |
+
)
|
| 573 |
+
|
| 574 |
+
else:
|
| 575 |
+
raise NotImplementedError(
|
| 576 |
+
"The quantified version of MOE class models "
|
| 577 |
+
"currently does not support tensor parallelism")
|
| 578 |
+
if expanded_x.dtype != w1.dtype:
|
| 579 |
+
w1_input_scale, _ = w1_input_scale.max(0)
|
| 580 |
+
quant_sorted_hidden_states = quant_per_tensor(
|
| 581 |
+
expanded_x,
|
| 582 |
+
w1_input_scale,
|
| 583 |
+
None,
|
| 584 |
+
True,
|
| 585 |
+
)
|
| 586 |
+
else:
|
| 587 |
+
quant_sorted_hidden_states = expanded_x
|
| 588 |
+
gate_up_out = torch_npu.npu_grouped_matmul(
|
| 589 |
+
x=[quant_sorted_hidden_states],
|
| 590 |
+
weight=[w1],
|
| 591 |
+
scale=[w1_scale * w1_input_scale[0]],
|
| 592 |
+
split_item=2,
|
| 593 |
+
group_list_type=1,
|
| 594 |
+
group_type=0,
|
| 595 |
+
group_list=expert_token_count,
|
| 596 |
+
output_dtype=original_dtype,
|
| 597 |
+
)[0]
|
| 598 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
| 599 |
+
|
| 600 |
+
if gate_up_out.dtype != w2.dtype:
|
| 601 |
+
w2_input_scale, _ = w2_input_scale.max(0)
|
| 602 |
+
quant_gate_up_out = quant_per_tensor(
|
| 603 |
+
gate_up_out,
|
| 604 |
+
w2_input_scale,
|
| 605 |
+
None,
|
| 606 |
+
True,
|
| 607 |
+
)
|
| 608 |
+
else:
|
| 609 |
+
quant_gate_up_out = gate_up_out
|
| 610 |
+
|
| 611 |
+
down_out = torch_npu.npu_grouped_matmul(
|
| 612 |
+
x=[quant_gate_up_out],
|
| 613 |
+
weight=[w2],
|
| 614 |
+
scale=[w2_scale * w2_input_scale[0]],
|
| 615 |
+
split_item=2,
|
| 616 |
+
group_list_type=1,
|
| 617 |
+
group_type=0,
|
| 618 |
+
group_list=expert_token_count,
|
| 619 |
+
output_dtype=original_dtype,
|
| 620 |
+
)[0]
|
| 621 |
+
|
| 622 |
+
if expert_map is not None:
|
| 623 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 624 |
+
down_out,
|
| 625 |
+
skip1=None,
|
| 626 |
+
skip2=None,
|
| 627 |
+
bias=None,
|
| 628 |
+
scales=topk_weights.to(down_out.dtype),
|
| 629 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 630 |
+
export_for_source_row=topk_ids,
|
| 631 |
+
drop_pad_mode=2,
|
| 632 |
+
)
|
| 633 |
+
else:
|
| 634 |
+
raise NotImplementedError(
|
| 635 |
+
"The quantified version of MOE class models "
|
| 636 |
+
"currently does not support tensor parallelism")
|
| 637 |
+
|
| 638 |
+
return final_hidden_states
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
def select_experts(
|
| 642 |
+
hidden_states: torch.Tensor,
|
| 643 |
+
router_logits: torch.Tensor,
|
| 644 |
+
top_k: int,
|
| 645 |
+
use_grouped_topk: bool,
|
| 646 |
+
renormalize: bool,
|
| 647 |
+
topk_group: Optional[int] = None,
|
| 648 |
+
num_expert_group: Optional[int] = None,
|
| 649 |
+
custom_routing_function: Optional[Callable] = None,
|
| 650 |
+
scoring_func: str = "softmax",
|
| 651 |
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
| 652 |
+
global_num_experts=-1,
|
| 653 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 654 |
+
"""
|
| 655 |
+
Select top-k experts based on router logits.
|
| 656 |
+
|
| 657 |
+
Args:
|
| 658 |
+
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
| 659 |
+
router_logits: Router logits of shape (num_tokens, num_experts).
|
| 660 |
+
top_k: Number of experts to select.
|
| 661 |
+
use_grouped_topk: Whether to group experts before selecting top-k.
|
| 662 |
+
renormalize: Whether to renormalize the routing weights.
|
| 663 |
+
topk_group: Number of expert groups to select from.
|
| 664 |
+
num_expert_group: Number of experts in each group.
|
| 665 |
+
custom_routing_function: Custom routing function.
|
| 666 |
+
scoring_func: Scoring function to use.
|
| 667 |
+
e_score_correction_bias: Correction bias to apply to expert scores.
|
| 668 |
+
|
| 669 |
+
Returns:
|
| 670 |
+
topk_weights: Routing weights of shape (num_tokens, top_k).
|
| 671 |
+
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
| 672 |
+
|
| 673 |
+
Raises:
|
| 674 |
+
ValueError: If an unsupported scoring function is provided.
|
| 675 |
+
"""
|
| 676 |
+
|
| 677 |
+
if scoring_func == "softmax":
|
| 678 |
+
# NOTE: vLLM use dtype=torch.float here
|
| 679 |
+
topk_weights = router_logits.softmax(dim=-1)
|
| 680 |
+
elif scoring_func == "sigmoid":
|
| 681 |
+
topk_weights = router_logits.sigmoid()
|
| 682 |
+
else:
|
| 683 |
+
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
| 684 |
+
|
| 685 |
+
if use_grouped_topk:
|
| 686 |
+
assert topk_group is not None
|
| 687 |
+
assert num_expert_group is not None
|
| 688 |
+
|
| 689 |
+
if e_score_correction_bias is not None:
|
| 690 |
+
# Store original scores before applying correction bias. We use biased
|
| 691 |
+
# scores for expert selection but original scores for routing weights
|
| 692 |
+
original_weights = topk_weights
|
| 693 |
+
topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
|
| 694 |
+
|
| 695 |
+
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
|
| 696 |
+
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
|
| 697 |
+
topk_weights = native_grouped_topk(topk_weights, num_expert_group,
|
| 698 |
+
topk_group)
|
| 699 |
+
# TODO bfloat16 is not supported in torch.topk with ge graph.
|
| 700 |
+
if e_score_correction_bias is not None:
|
| 701 |
+
topk_ids = torch.topk(topk_weights.to(torch.float32),
|
| 702 |
+
k=top_k,
|
| 703 |
+
dim=-1,
|
| 704 |
+
sorted=False)[1]
|
| 705 |
+
# Use original unbiased scores for the routing weights
|
| 706 |
+
topk_weights = original_weights.gather(1, topk_ids)
|
| 707 |
+
else:
|
| 708 |
+
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
|
| 709 |
+
k=top_k,
|
| 710 |
+
dim=-1,
|
| 711 |
+
sorted=False)
|
| 712 |
+
elif custom_routing_function is None:
|
| 713 |
+
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
|
| 714 |
+
else:
|
| 715 |
+
topk_weights, topk_ids = custom_routing_function(
|
| 716 |
+
hidden_states=hidden_states,
|
| 717 |
+
gating_output=router_logits,
|
| 718 |
+
topk=top_k,
|
| 719 |
+
renormalize=renormalize,
|
| 720 |
+
global_num_experts=global_num_experts,
|
| 721 |
+
)
|
| 722 |
+
# Required by npu_moe_init_routing
|
| 723 |
+
topk_ids = topk_ids.to(torch.int32)
|
| 724 |
+
return topk_weights, topk_ids
|
| 725 |
+
|
| 726 |
+
# Required by npu_moe_init_routing
|
| 727 |
+
topk_ids = topk_ids.to(torch.int32)
|
| 728 |
+
|
| 729 |
+
if renormalize:
|
| 730 |
+
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
| 731 |
+
|
| 732 |
+
return topk_weights, topk_ids
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
def native_grouped_topk(
|
| 736 |
+
topk_weights: torch.Tensor,
|
| 737 |
+
num_expert_group: Optional[int],
|
| 738 |
+
topk_group: Optional[int],
|
| 739 |
+
):
|
| 740 |
+
topk_group = 0 if topk_group is None else topk_group
|
| 741 |
+
num_expert_group = 0 if num_expert_group is None else num_expert_group
|
| 742 |
+
|
| 743 |
+
num_token = topk_weights.shape[0]
|
| 744 |
+
grouped_weights = topk_weights.view(num_token, num_expert_group,
|
| 745 |
+
-1).max(dim=-1).values
|
| 746 |
+
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
|
| 747 |
+
k=topk_group,
|
| 748 |
+
dim=-1,
|
| 749 |
+
sorted=False)[1]
|
| 750 |
+
topk_group_mask = torch.zeros_like(grouped_weights)
|
| 751 |
+
topk_group_mask.scatter_(1, topk_group_indices, 1)
|
| 752 |
+
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
|
| 753 |
+
num_token, num_expert_group,
|
| 754 |
+
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
|
| 755 |
+
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
|
| 756 |
+
|
| 757 |
+
return topk_weights
|
inference/vllm_ascend/quantization/w8a8_dynamic.py
ADDED
|
@@ -0,0 +1,831 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# This file is a part of the vllm-ascend project.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
|
| 18 |
+
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.distributed as dist
|
| 22 |
+
import torch_npu
|
| 23 |
+
from vllm.distributed import GroupCoordinator
|
| 24 |
+
|
| 25 |
+
import vllm_ascend.envs as envs
|
| 26 |
+
from vllm_ascend.ascend_config import get_ascend_config
|
| 27 |
+
from vllm_ascend.distributed.parallel_state import get_ep_group
|
| 28 |
+
from vllm_ascend.ops.fused_moe import select_experts
|
| 29 |
+
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, FusedMoEState,
|
| 30 |
+
dispose_tensor, get_fused_moe_state,
|
| 31 |
+
npu_stream_switch, npu_wait_tensor)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def apply_mlp(hidden_states: torch.Tensor,
|
| 35 |
+
w1: torch.Tensor,
|
| 36 |
+
w1_scale: torch.Tensor,
|
| 37 |
+
w2: torch.Tensor,
|
| 38 |
+
w2_scale: torch.Tensor,
|
| 39 |
+
group_list: torch.Tensor,
|
| 40 |
+
dynamic_scale: torch.Tensor = None,
|
| 41 |
+
group_list_type: int = 1) -> torch.Tensor:
|
| 42 |
+
"""
|
| 43 |
+
apply MLP: gate_up_proj -> swiglu -> down_proj
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
hidden_states: input hidden states with shape (num_tokens, hidden_size).
|
| 47 |
+
w1: expert weights1 with shape
|
| 48 |
+
(num_experts, hidden_size, intermediate_size * 2)
|
| 49 |
+
w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
|
| 50 |
+
w2: expert weights2 with shape
|
| 51 |
+
(num_experts, intermediate_size, hidden_size)
|
| 52 |
+
w2_scale: weights2 scale with shape (num_experts, hidden_size)
|
| 53 |
+
group_list: number of tokens for each expert, follow cumsum mode, and
|
| 54 |
+
with shape (num_experts).
|
| 55 |
+
transpose_weight:
|
| 56 |
+
w1: (num_experts, intermediate_size * 2, hidden_size) ->
|
| 57 |
+
(num_experts, hidden_size, intermediate_size * 2)
|
| 58 |
+
w2: (num_experts, hidden_size, intermediate_size) ->
|
| 59 |
+
(num_experts, intermediate_size, hidden_size)
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
hidden_states: output hidden states after MLP.
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
if dynamic_scale is None:
|
| 66 |
+
unquantized_hidden_states = hidden_states
|
| 67 |
+
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
| 68 |
+
hidden_states)
|
| 69 |
+
# Dispose the original unquantized hidden states
|
| 70 |
+
# to save npu memory because they're no longer used.
|
| 71 |
+
dispose_tensor(unquantized_hidden_states)
|
| 72 |
+
else:
|
| 73 |
+
pertoken_scale = dynamic_scale
|
| 74 |
+
|
| 75 |
+
# gmm1: gate_up_proj
|
| 76 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 77 |
+
x=[hidden_states],
|
| 78 |
+
weight=[w1],
|
| 79 |
+
scale=[w1_scale],
|
| 80 |
+
per_token_scale=[pertoken_scale],
|
| 81 |
+
split_item=2,
|
| 82 |
+
group_list_type=group_list_type,
|
| 83 |
+
group_type=0,
|
| 84 |
+
group_list=group_list,
|
| 85 |
+
output_dtype=w2_scale.dtype)[0]
|
| 86 |
+
|
| 87 |
+
# act_fn: swiglu
|
| 88 |
+
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
| 89 |
+
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
| 90 |
+
hidden_states)
|
| 91 |
+
|
| 92 |
+
# gmm2: down_proj
|
| 93 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 94 |
+
x=[hidden_states],
|
| 95 |
+
weight=[w2],
|
| 96 |
+
scale=[w2_scale],
|
| 97 |
+
per_token_scale=[swiglu_out_scale],
|
| 98 |
+
split_item=2,
|
| 99 |
+
group_list_type=group_list_type,
|
| 100 |
+
group_type=0,
|
| 101 |
+
group_list=group_list,
|
| 102 |
+
output_dtype=w2_scale.dtype)[0]
|
| 103 |
+
|
| 104 |
+
return hidden_states
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def fused_experts_with_mc2(
|
| 108 |
+
hidden_states: torch.Tensor,
|
| 109 |
+
w1: torch.Tensor,
|
| 110 |
+
w2: torch.Tensor,
|
| 111 |
+
w1_scale: torch.Tensor,
|
| 112 |
+
w2_scale: torch.Tensor,
|
| 113 |
+
topk_weights: torch.Tensor,
|
| 114 |
+
topk_ids: torch.Tensor,
|
| 115 |
+
top_k: int,
|
| 116 |
+
expert_map: torch.Tensor = None,
|
| 117 |
+
moe_all_to_all_group_name: str = "",
|
| 118 |
+
log2phy: torch.Tensor = None,
|
| 119 |
+
global_redundant_expert_num: int = 0,
|
| 120 |
+
shared_experts: Optional[Any] = None,
|
| 121 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 122 |
+
if log2phy is not None:
|
| 123 |
+
topk_ids = log2phy[topk_ids]
|
| 124 |
+
global_bs = 0
|
| 125 |
+
moe_expert_num = len(expert_map) + global_redundant_expert_num
|
| 126 |
+
# hidden_states = hidden_states.bfloat16()
|
| 127 |
+
kwargs_mc2 = {
|
| 128 |
+
"x": hidden_states,
|
| 129 |
+
"expert_ids": topk_ids,
|
| 130 |
+
"expert_shard_type": 0,
|
| 131 |
+
"shared_expert_rank_num": 0,
|
| 132 |
+
"moe_expert_num": moe_expert_num,
|
| 133 |
+
"global_bs": global_bs,
|
| 134 |
+
"expert_scales": topk_weights.to(torch.float32),
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
rank = torch.distributed.get_rank()
|
| 138 |
+
|
| 139 |
+
quant_mode = 2
|
| 140 |
+
ep_group = get_ep_group().device_group
|
| 141 |
+
local_rank = torch.distributed.get_rank(group=ep_group)
|
| 142 |
+
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
|
| 143 |
+
|
| 144 |
+
world_size = torch.distributed.get_world_size()
|
| 145 |
+
tp_size = world_size // all_to_all_group_size
|
| 146 |
+
tp_rank = rank % tp_size
|
| 147 |
+
|
| 148 |
+
stage1_kwargs = {
|
| 149 |
+
"scales": None,
|
| 150 |
+
"quant_mode": quant_mode,
|
| 151 |
+
"group_ep": moe_all_to_all_group_name,
|
| 152 |
+
"ep_world_size": all_to_all_group_size,
|
| 153 |
+
"ep_rank_id": local_rank,
|
| 154 |
+
# "group_tp": self.moe_rs_group_name,
|
| 155 |
+
"group_tp": moe_all_to_all_group_name,
|
| 156 |
+
"tp_world_size": tp_size,
|
| 157 |
+
"tp_rank_id": tp_rank,
|
| 158 |
+
}
|
| 159 |
+
kwargs_mc2.update(stage1_kwargs)
|
| 160 |
+
|
| 161 |
+
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
| 162 |
+
# comm_stream.wait_stream(torch.npu.current_stream())
|
| 163 |
+
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts, _, expand_scales = output[
|
| 164 |
+
0:7]
|
| 165 |
+
|
| 166 |
+
if shared_experts is not None:
|
| 167 |
+
with npu_stream_switch("moe_secondary", 0):
|
| 168 |
+
npu_wait_tensor(hidden_states, topk_weights)
|
| 169 |
+
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
|
| 170 |
+
npu_wait_tensor(shared_gate_up[0], expand_x)
|
| 171 |
+
shared_act = shared_experts.act_fn(shared_gate_up)
|
| 172 |
+
|
| 173 |
+
# `expand_x` will be disposed in the `apply_mlp` function
|
| 174 |
+
down_out_list = apply_mlp(expand_x,
|
| 175 |
+
w1,
|
| 176 |
+
w1_scale,
|
| 177 |
+
w2,
|
| 178 |
+
w2_scale,
|
| 179 |
+
expert_token_nums,
|
| 180 |
+
dynamic_scale=dynamic_scale)
|
| 181 |
+
|
| 182 |
+
# moeCombine
|
| 183 |
+
kwargs_mc2 = {
|
| 184 |
+
"expand_x": down_out_list,
|
| 185 |
+
"expert_ids": topk_ids,
|
| 186 |
+
"expand_idx": expand_idx,
|
| 187 |
+
"expert_scales": topk_weights.to(torch.float32),
|
| 188 |
+
"expert_shard_type": 0,
|
| 189 |
+
"shared_expert_rank_num": 0,
|
| 190 |
+
"moe_expert_num": moe_expert_num,
|
| 191 |
+
"global_bs": 0,
|
| 192 |
+
"expand_scales": expand_scales,
|
| 193 |
+
}
|
| 194 |
+
tp_recv_counts = torch.empty(1,
|
| 195 |
+
dtype=torch.int32,
|
| 196 |
+
device=hidden_states.device)
|
| 197 |
+
stage3_kwargs = {
|
| 198 |
+
"ep_send_counts": ep_recv_counts,
|
| 199 |
+
"group_ep": moe_all_to_all_group_name,
|
| 200 |
+
"ep_world_size": all_to_all_group_size,
|
| 201 |
+
"ep_rank_id": local_rank,
|
| 202 |
+
"tp_send_counts": tp_recv_counts,
|
| 203 |
+
# "group_tp": self.moe_rs_group_name,
|
| 204 |
+
"group_tp": moe_all_to_all_group_name,
|
| 205 |
+
"tp_world_size": tp_size,
|
| 206 |
+
"tp_rank_id": tp_rank,
|
| 207 |
+
}
|
| 208 |
+
kwargs_mc2.update(stage3_kwargs)
|
| 209 |
+
|
| 210 |
+
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
| 211 |
+
|
| 212 |
+
if shared_experts is None:
|
| 213 |
+
return hidden_states
|
| 214 |
+
else:
|
| 215 |
+
with npu_stream_switch("moe_secondary", 0):
|
| 216 |
+
npu_wait_tensor(shared_act[0], down_out_list)
|
| 217 |
+
shared_output, _ = shared_experts.down_proj(shared_act)
|
| 218 |
+
return hidden_states, shared_output
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# currently expert parallelism implemented with all2all
|
| 222 |
+
# is under-optimized.
|
| 223 |
+
def fused_experts_with_all2all(
|
| 224 |
+
hidden_states: torch.Tensor,
|
| 225 |
+
w1: torch.Tensor,
|
| 226 |
+
w1_scale: torch.Tensor,
|
| 227 |
+
w2: torch.Tensor,
|
| 228 |
+
w2_scale: torch.Tensor,
|
| 229 |
+
topk_weights: torch.Tensor,
|
| 230 |
+
topk_ids: torch.Tensor,
|
| 231 |
+
top_k: int,
|
| 232 |
+
expert_map: torch.Tensor = None,
|
| 233 |
+
ep_group: GroupCoordinator = None,
|
| 234 |
+
log2phy: torch.Tensor = None,
|
| 235 |
+
global_redundant_expert_num: int = 0,
|
| 236 |
+
):
|
| 237 |
+
if log2phy is not None:
|
| 238 |
+
topk_ids = log2phy[topk_ids]
|
| 239 |
+
original_shape = hidden_states.shape
|
| 240 |
+
if len(original_shape) == 3:
|
| 241 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 242 |
+
|
| 243 |
+
num_tokens, _ = hidden_states.shape
|
| 244 |
+
num_experts = w1.shape[0]
|
| 245 |
+
device = hidden_states.device
|
| 246 |
+
|
| 247 |
+
if expert_map is not None:
|
| 248 |
+
global_num_experts = len(expert_map) + global_redundant_expert_num
|
| 249 |
+
local_num_experts = global_num_experts // ep_group.world_size
|
| 250 |
+
row_idx_len = num_tokens * top_k
|
| 251 |
+
row_idx = (torch.arange(0,
|
| 252 |
+
row_idx_len,
|
| 253 |
+
dtype=torch.int32,
|
| 254 |
+
device=device).view(top_k, -1).permute(
|
| 255 |
+
1, 0).contiguous())
|
| 256 |
+
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 257 |
+
hidden_states,
|
| 258 |
+
row_idx=row_idx,
|
| 259 |
+
expert_idx=topk_ids,
|
| 260 |
+
active_num=num_tokens)
|
| 261 |
+
|
| 262 |
+
global_expert_tokens = torch.bincount(expanded_expert_idx,
|
| 263 |
+
minlength=global_num_experts)
|
| 264 |
+
scatter_sizes = global_expert_tokens.view(ep_group.world_size,
|
| 265 |
+
-1).sum(-1)
|
| 266 |
+
|
| 267 |
+
gather_sizes = torch.empty_like(scatter_sizes)
|
| 268 |
+
dist.all_to_all_single(gather_sizes,
|
| 269 |
+
scatter_sizes,
|
| 270 |
+
group=ep_group.device_group)
|
| 271 |
+
scatter_size_list = scatter_sizes.cpu().tolist()
|
| 272 |
+
gather_size_list = gather_sizes.cpu().tolist()
|
| 273 |
+
|
| 274 |
+
expanded_expert_idx = expanded_expert_idx % local_num_experts
|
| 275 |
+
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
| 276 |
+
scatter_size_list,
|
| 277 |
+
gather_size_list)
|
| 278 |
+
local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
|
| 279 |
+
scatter_size_list,
|
| 280 |
+
gather_size_list)
|
| 281 |
+
|
| 282 |
+
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
|
| 283 |
+
|
| 284 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 285 |
+
sorted_local_expert_idx, local_num_experts).to(torch.int64)
|
| 286 |
+
|
| 287 |
+
hidden_states = hidden_states[sorted_idx]
|
| 288 |
+
group_list_type = 0
|
| 289 |
+
else:
|
| 290 |
+
row_idx_len = num_tokens * top_k
|
| 291 |
+
row_idx = torch.arange(0,
|
| 292 |
+
row_idx_len,
|
| 293 |
+
dtype=torch.int32,
|
| 294 |
+
device=topk_weights.device).view(
|
| 295 |
+
top_k, -1).permute(1, 0).contiguous()
|
| 296 |
+
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 297 |
+
hidden_states,
|
| 298 |
+
row_idx=row_idx,
|
| 299 |
+
expert_idx=topk_ids,
|
| 300 |
+
active_num=num_tokens)
|
| 301 |
+
|
| 302 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 303 |
+
expanded_expert_idx, num_experts)
|
| 304 |
+
expert_tokens = expert_tokens.to(torch.int64)
|
| 305 |
+
group_list_type = 0
|
| 306 |
+
|
| 307 |
+
# `hidden_states` will be disposed in the `apply_mlp` function
|
| 308 |
+
hidden_states = apply_mlp(
|
| 309 |
+
hidden_states,
|
| 310 |
+
w1,
|
| 311 |
+
w1_scale, #17
|
| 312 |
+
w2,
|
| 313 |
+
w2_scale,
|
| 314 |
+
expert_tokens, #16
|
| 315 |
+
group_list_type=group_list_type)
|
| 316 |
+
|
| 317 |
+
if expert_map is not None:
|
| 318 |
+
resorted_idx = torch.argsort(sorted_idx)
|
| 319 |
+
hidden_states = hidden_states[resorted_idx]
|
| 320 |
+
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
| 321 |
+
gather_size_list,
|
| 322 |
+
scatter_size_list)
|
| 323 |
+
|
| 324 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 325 |
+
hidden_states,
|
| 326 |
+
skip1=None,
|
| 327 |
+
skip2=None,
|
| 328 |
+
bias=None,
|
| 329 |
+
scales=topk_weights,
|
| 330 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 331 |
+
export_for_source_row=topk_ids,
|
| 332 |
+
)
|
| 333 |
+
else:
|
| 334 |
+
# TODO: Reorder device memory 2 times here, replace the current
|
| 335 |
+
# implementation here when suitable operators become available.
|
| 336 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 337 |
+
hidden_states,
|
| 338 |
+
skip1=None,
|
| 339 |
+
skip2=None,
|
| 340 |
+
bias=None,
|
| 341 |
+
scales=topk_weights,
|
| 342 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 343 |
+
export_for_source_row=topk_ids,
|
| 344 |
+
)
|
| 345 |
+
if len(original_shape) == 3:
|
| 346 |
+
final_hidden_states = final_hidden_states.view(original_shape)
|
| 347 |
+
return final_hidden_states
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def fused_experts_with_allgather(hidden_states: torch.Tensor,
|
| 351 |
+
w1: torch.Tensor,
|
| 352 |
+
w1_scale: torch.Tensor,
|
| 353 |
+
w2: torch.Tensor,
|
| 354 |
+
w2_scale: torch.Tensor,
|
| 355 |
+
topk_weights: torch.Tensor,
|
| 356 |
+
topk_ids: torch.Tensor,
|
| 357 |
+
top_k: int,
|
| 358 |
+
expert_map: torch.Tensor = None):
|
| 359 |
+
original_shape = hidden_states.shape
|
| 360 |
+
if len(original_shape) == 3:
|
| 361 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 362 |
+
num_tokens = hidden_states.shape[0]
|
| 363 |
+
batch_size, hidden_size = hidden_states.shape
|
| 364 |
+
|
| 365 |
+
ep_group = get_ep_group().device_group
|
| 366 |
+
ep_rank = torch.distributed.get_rank(group=ep_group)
|
| 367 |
+
ep_size = torch.distributed.get_world_size(ep_group)
|
| 368 |
+
|
| 369 |
+
global_num_experts = len(expert_map)
|
| 370 |
+
local_num_experts = global_num_experts // ep_size
|
| 371 |
+
|
| 372 |
+
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
| 373 |
+
|
| 374 |
+
hidden_states, expanded_x_idx, expert_tokens, pertoken_scale = torch_npu.npu_moe_init_routing_v2(
|
| 375 |
+
hidden_states,
|
| 376 |
+
topk_ids,
|
| 377 |
+
scale=pertoken_scale,
|
| 378 |
+
offset=None,
|
| 379 |
+
active_num=num_tokens * top_k,
|
| 380 |
+
expert_num=global_num_experts,
|
| 381 |
+
expert_tokens_num_type=1,
|
| 382 |
+
expert_tokens_num_flag=True,
|
| 383 |
+
active_expert_range=[
|
| 384 |
+
ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts
|
| 385 |
+
],
|
| 386 |
+
quant_mode=-1,
|
| 387 |
+
row_idx_type=0)
|
| 388 |
+
group_list_type = 1
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 392 |
+
x=[hidden_states],
|
| 393 |
+
weight=[w1],
|
| 394 |
+
split_item=3,
|
| 395 |
+
group_list_type=group_list_type,
|
| 396 |
+
group_type=0,
|
| 397 |
+
group_list=expert_tokens,
|
| 398 |
+
output_dtype=torch.int32)[0]
|
| 399 |
+
|
| 400 |
+
# act_fn: swiglu
|
| 401 |
+
hidden_states, pertoken_scale = torch_npu.npu_dequant_swiglu_quant(
|
| 402 |
+
x=hidden_states,
|
| 403 |
+
weight_scale=w1_scale.to(torch.float32),
|
| 404 |
+
activation_scale=pertoken_scale,
|
| 405 |
+
bias=None,
|
| 406 |
+
quant_scale=None,
|
| 407 |
+
quant_offset=None,
|
| 408 |
+
group_index=expert_tokens,
|
| 409 |
+
activate_left=True,
|
| 410 |
+
quant_mode=1,
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 414 |
+
x=[hidden_states],
|
| 415 |
+
weight=[w2],
|
| 416 |
+
scale=[w2_scale.to(torch.bfloat16)],
|
| 417 |
+
per_token_scale=[pertoken_scale.view(-1)],
|
| 418 |
+
split_item=3,
|
| 419 |
+
group_list_type=group_list_type,
|
| 420 |
+
group_type=0,
|
| 421 |
+
group_list=expert_tokens,
|
| 422 |
+
output_dtype=torch.bfloat16)[0]
|
| 423 |
+
|
| 424 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 425 |
+
expanded_permuted_rows=hidden_states.unsqueeze(1),
|
| 426 |
+
skip1=None,
|
| 427 |
+
skip2=None,
|
| 428 |
+
bias=None,
|
| 429 |
+
scales=topk_weights.to(torch.bfloat16),
|
| 430 |
+
expanded_src_to_dst_row=expanded_x_idx.to(torch.int32),
|
| 431 |
+
export_for_source_row=topk_ids,
|
| 432 |
+
drop_pad_mode=3
|
| 433 |
+
).to(torch.bfloat16)
|
| 434 |
+
|
| 435 |
+
if len(original_shape) == 3:
|
| 436 |
+
final_hidden_states = final_hidden_states.view(original_shape)
|
| 437 |
+
|
| 438 |
+
return final_hidden_states
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def fused_experts(hidden_states: torch.Tensor,
|
| 442 |
+
w1: torch.Tensor,
|
| 443 |
+
w1_scale: torch.Tensor,
|
| 444 |
+
w2: torch.Tensor,
|
| 445 |
+
w2_scale: torch.Tensor,
|
| 446 |
+
topk_weights: torch.Tensor,
|
| 447 |
+
topk_ids: torch.Tensor,
|
| 448 |
+
top_k: int,
|
| 449 |
+
expert_map: torch.Tensor = None):
|
| 450 |
+
original_shape = hidden_states.shape
|
| 451 |
+
if len(original_shape) == 3:
|
| 452 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 453 |
+
|
| 454 |
+
num_tokens, _ = hidden_states.shape
|
| 455 |
+
num_experts = w1.shape[0]
|
| 456 |
+
dtype = hidden_states.dtype
|
| 457 |
+
device = hidden_states.device
|
| 458 |
+
|
| 459 |
+
if expert_map is not None:
|
| 460 |
+
# Generate token indices and flatten
|
| 461 |
+
token_indices = (torch.arange(num_tokens,
|
| 462 |
+
device=device,
|
| 463 |
+
dtype=torch.int64).unsqueeze(1).expand(
|
| 464 |
+
-1, top_k).reshape(-1))
|
| 465 |
+
|
| 466 |
+
# Flatten token-to-expert mappings and map to local experts
|
| 467 |
+
weights_flat = topk_weights.view(-1)
|
| 468 |
+
experts_flat = topk_ids.view(-1)
|
| 469 |
+
local_experts_flat = expert_map[experts_flat]
|
| 470 |
+
|
| 471 |
+
# Filter valid token-expert pairs
|
| 472 |
+
mask = local_experts_flat != -1
|
| 473 |
+
filtered_weights = torch.where(
|
| 474 |
+
mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
|
| 475 |
+
filtered_experts = torch.where(
|
| 476 |
+
mask, local_experts_flat,
|
| 477 |
+
torch.full_like(local_experts_flat,
|
| 478 |
+
num_experts)).to(topk_ids.dtype)
|
| 479 |
+
|
| 480 |
+
# Sort by local expert IDs
|
| 481 |
+
sort_indices = torch.argsort(filtered_experts)
|
| 482 |
+
sorted_token_indices = token_indices[sort_indices]
|
| 483 |
+
sorted_weights = filtered_weights[sort_indices]
|
| 484 |
+
|
| 485 |
+
# Compute token counts with minlength of num_experts
|
| 486 |
+
# This is equivalent to but faster than:
|
| 487 |
+
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
|
| 488 |
+
token_counts = torch.zeros(num_experts + 1,
|
| 489 |
+
device=device,
|
| 490 |
+
dtype=torch.int64)
|
| 491 |
+
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
|
| 492 |
+
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
|
| 493 |
+
expert_tokens = token_counts[:num_experts]
|
| 494 |
+
# Rearrange hidden_states
|
| 495 |
+
hidden_states = hidden_states[sorted_token_indices]
|
| 496 |
+
group_list_type = 1
|
| 497 |
+
else:
|
| 498 |
+
row_idx_len = num_tokens * top_k
|
| 499 |
+
row_idx = torch.arange(0,
|
| 500 |
+
row_idx_len,
|
| 501 |
+
dtype=torch.int32,
|
| 502 |
+
device=topk_weights.device).view(
|
| 503 |
+
top_k, -1).permute(1, 0).contiguous()
|
| 504 |
+
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 505 |
+
hidden_states,
|
| 506 |
+
row_idx=row_idx,
|
| 507 |
+
expert_idx=topk_ids,
|
| 508 |
+
active_num=num_tokens)
|
| 509 |
+
|
| 510 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 511 |
+
expanded_expert_idx, num_experts)
|
| 512 |
+
expert_tokens = expert_tokens.to(torch.int64)
|
| 513 |
+
group_list_type = 0
|
| 514 |
+
|
| 515 |
+
# `hidden_states` will be disposed in the `apply_mlp` function
|
| 516 |
+
hidden_states = apply_mlp(hidden_states,
|
| 517 |
+
w1,
|
| 518 |
+
w1_scale,
|
| 519 |
+
w2,
|
| 520 |
+
w2_scale,
|
| 521 |
+
expert_tokens,
|
| 522 |
+
group_list_type=group_list_type)
|
| 523 |
+
|
| 524 |
+
if expert_map is not None:
|
| 525 |
+
hidden_states.mul_(sorted_weights.unsqueeze(1))
|
| 526 |
+
final_hidden_states = torch.zeros(*original_shape,
|
| 527 |
+
device=device,
|
| 528 |
+
dtype=dtype)
|
| 529 |
+
|
| 530 |
+
num_valid_tokens = mask.sum()
|
| 531 |
+
valid_token_mask = torch.arange(
|
| 532 |
+
0, sorted_token_indices.shape[0],
|
| 533 |
+
device=device).unsqueeze(1) < num_valid_tokens
|
| 534 |
+
hidden_states = hidden_states.masked_fill_(~valid_token_mask,
|
| 535 |
+
0).to(dtype)
|
| 536 |
+
final_hidden_states.index_add_(0, sorted_token_indices, hidden_states)
|
| 537 |
+
else:
|
| 538 |
+
# TODO: Reorder device memory 2 times here, replace the current
|
| 539 |
+
# implementation here when suitable operators become available.
|
| 540 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 541 |
+
hidden_states,
|
| 542 |
+
skip1=None,
|
| 543 |
+
skip2=None,
|
| 544 |
+
bias=None,
|
| 545 |
+
scales=topk_weights,
|
| 546 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 547 |
+
export_for_source_row=topk_ids,
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
if len(original_shape) == 3:
|
| 551 |
+
final_hidden_states = final_hidden_states.view(original_shape)
|
| 552 |
+
return final_hidden_states
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
class AscendW8A8DynamicLinearMethod:
|
| 556 |
+
"""Linear method for Ascend W8A8_DYNAMIC.
|
| 557 |
+
"""
|
| 558 |
+
|
| 559 |
+
def __init__(self):
|
| 560 |
+
self.transpose_weight = True
|
| 561 |
+
|
| 562 |
+
@staticmethod
|
| 563 |
+
def get_weight(input_size: int, output_size: int,
|
| 564 |
+
params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 565 |
+
params_dict = {
|
| 566 |
+
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
|
| 567 |
+
}
|
| 568 |
+
return params_dict
|
| 569 |
+
|
| 570 |
+
@staticmethod
|
| 571 |
+
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 572 |
+
return {}
|
| 573 |
+
|
| 574 |
+
@staticmethod
|
| 575 |
+
def get_perchannel_param(
|
| 576 |
+
output_size: int,
|
| 577 |
+
params_dtype: torch.dtype,
|
| 578 |
+
) -> Dict[str, Any]:
|
| 579 |
+
params_dict = {}
|
| 580 |
+
params_dict["weight_scale"] = torch.empty(output_size,
|
| 581 |
+
1,
|
| 582 |
+
dtype=params_dtype)
|
| 583 |
+
params_dict["weight_offset"] = torch.empty(output_size,
|
| 584 |
+
1,
|
| 585 |
+
dtype=params_dtype)
|
| 586 |
+
return params_dict
|
| 587 |
+
|
| 588 |
+
@staticmethod
|
| 589 |
+
def apply(
|
| 590 |
+
layer: torch.nn.Module,
|
| 591 |
+
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
| 592 |
+
bias: Optional[torch.Tensor] = None,
|
| 593 |
+
tp_rank: Optional[int] = 0,
|
| 594 |
+
) -> torch.Tensor:
|
| 595 |
+
config = getattr(layer, "_ascend_quant_config", {})
|
| 596 |
+
if not isinstance(x, tuple):
|
| 597 |
+
output_dtype = config.get("output_dtype", x.dtype)
|
| 598 |
+
quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
|
| 599 |
+
else:
|
| 600 |
+
assert "output_dtype" in config.keys(), (
|
| 601 |
+
f"DynamicLinearMethod needs explicitly specified `output_dtype`"
|
| 602 |
+
f"for pre-quantized input, got config [{config}]")
|
| 603 |
+
output_dtype = config["output_dtype"]
|
| 604 |
+
quantized_x, dynamic_scale = x
|
| 605 |
+
pertoken_scale = (dynamic_scale
|
| 606 |
+
if config.get("pertoken_scale", True) else None)
|
| 607 |
+
|
| 608 |
+
output = torch_npu.npu_quant_matmul(
|
| 609 |
+
quantized_x,
|
| 610 |
+
layer.weight,
|
| 611 |
+
layer.weight_scale,
|
| 612 |
+
pertoken_scale=pertoken_scale,
|
| 613 |
+
bias=bias,
|
| 614 |
+
output_dtype=output_dtype,
|
| 615 |
+
)
|
| 616 |
+
return ((output, dynamic_scale)
|
| 617 |
+
if config.get("return_scale", False) else output)
|
| 618 |
+
|
| 619 |
+
def process_weights_after_loading(self, layer):
|
| 620 |
+
if self.transpose_weight:
|
| 621 |
+
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
| 622 |
+
# cast quantized weight tensors in NZ format (29) for higher inference speed
|
| 623 |
+
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
|
| 624 |
+
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
| 625 |
+
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
| 626 |
+
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
class AscendW8A8DynamicFusedMoEMethod:
|
| 630 |
+
"""FusedMoe method for Ascend W8A8_DYNAMIC.
|
| 631 |
+
"""
|
| 632 |
+
|
| 633 |
+
def __init__(self):
|
| 634 |
+
self.transpose_weight = True
|
| 635 |
+
|
| 636 |
+
self.ep_group = get_ep_group()
|
| 637 |
+
|
| 638 |
+
ascend_config = get_ascend_config()
|
| 639 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 640 |
+
|
| 641 |
+
try:
|
| 642 |
+
device_group = self.ep_group.device_group
|
| 643 |
+
# TODO: Try local_rank = ep_group.rank_in_group
|
| 644 |
+
local_rank = torch.distributed.get_rank(group=device_group)
|
| 645 |
+
backend = device_group._get_backend(torch.device("npu"))
|
| 646 |
+
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
| 647 |
+
local_rank)
|
| 648 |
+
except AttributeError:
|
| 649 |
+
self.moe_all_to_all_group_name = ""
|
| 650 |
+
|
| 651 |
+
@staticmethod
|
| 652 |
+
def get_weight(num_experts: int, intermediate_size_per_partition: int,
|
| 653 |
+
hidden_sizes: int,
|
| 654 |
+
params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 655 |
+
param_dict = {}
|
| 656 |
+
param_dict["w13_weight"] = torch.empty(num_experts,
|
| 657 |
+
2 *
|
| 658 |
+
intermediate_size_per_partition,
|
| 659 |
+
hidden_sizes,
|
| 660 |
+
dtype=torch.int8)
|
| 661 |
+
param_dict["w2_weight"] = torch.empty(num_experts,
|
| 662 |
+
hidden_sizes,
|
| 663 |
+
intermediate_size_per_partition,
|
| 664 |
+
dtype=torch.int8)
|
| 665 |
+
return param_dict
|
| 666 |
+
|
| 667 |
+
@staticmethod
|
| 668 |
+
def get_dynamic_quant_param(num_experts: int,
|
| 669 |
+
intermediate_size_per_partition: int,
|
| 670 |
+
hidden_sizes: int,
|
| 671 |
+
params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 672 |
+
param_dict = {}
|
| 673 |
+
param_dict["w13_weight_scale"] = torch.empty(
|
| 674 |
+
num_experts,
|
| 675 |
+
2 * intermediate_size_per_partition,
|
| 676 |
+
1,
|
| 677 |
+
dtype=params_dtype)
|
| 678 |
+
param_dict["w13_weight_offset"] = torch.empty(
|
| 679 |
+
num_experts,
|
| 680 |
+
2 * intermediate_size_per_partition,
|
| 681 |
+
1,
|
| 682 |
+
dtype=params_dtype)
|
| 683 |
+
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
| 684 |
+
hidden_sizes,
|
| 685 |
+
1,
|
| 686 |
+
dtype=params_dtype)
|
| 687 |
+
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
| 688 |
+
hidden_sizes,
|
| 689 |
+
1,
|
| 690 |
+
dtype=params_dtype)
|
| 691 |
+
return param_dict
|
| 692 |
+
|
| 693 |
+
def apply(
|
| 694 |
+
self,
|
| 695 |
+
layer: torch.nn.Module,
|
| 696 |
+
x: torch.Tensor,
|
| 697 |
+
router_logits: torch.Tensor,
|
| 698 |
+
top_k: int,
|
| 699 |
+
renormalize: bool,
|
| 700 |
+
use_grouped_topk: bool = False,
|
| 701 |
+
global_num_experts: int = -1,
|
| 702 |
+
expert_map: Optional[torch.Tensor] = None,
|
| 703 |
+
topk_group: Optional[int] = None,
|
| 704 |
+
num_expert_group: Optional[int] = None,
|
| 705 |
+
custom_routing_function: Optional[Callable] = None,
|
| 706 |
+
scoring_func: str = "softmax",
|
| 707 |
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
| 708 |
+
is_prefill: bool = True,
|
| 709 |
+
enable_force_load_balance: bool = True,
|
| 710 |
+
log2phy: torch.Tensor = None,
|
| 711 |
+
global_redundant_expert_num: int = 0,
|
| 712 |
+
shared_experts: Optional[Any] = None,
|
| 713 |
+
**kwargs,
|
| 714 |
+
) -> torch.Tensor:
|
| 715 |
+
assert router_logits.shape[
|
| 716 |
+
1] == global_num_experts, "Number of global experts mismatch"
|
| 717 |
+
|
| 718 |
+
is_deepseek_v3_r1 = global_num_experts == 256
|
| 719 |
+
use_grouped_topk = (topk_group > 1 or num_expert_group > 1)
|
| 720 |
+
|
| 721 |
+
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
| 722 |
+
if use_grouped_topk and is_deepseek_v3_r1:
|
| 723 |
+
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
| 724 |
+
router_logits,
|
| 725 |
+
k=top_k, # topk当前写8
|
| 726 |
+
bias=e_score_correction_bias,
|
| 727 |
+
k_group=topk_group, # fix: 4
|
| 728 |
+
group_count=num_expert_group, # fix 8
|
| 729 |
+
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
|
| 730 |
+
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
| 731 |
+
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
| 732 |
+
# out_flag=False, # todo new api; 第三个输出是否输出
|
| 733 |
+
# y2_flag=False, # old api; 第三个输出是否输出
|
| 734 |
+
routed_scaling_factor=1,
|
| 735 |
+
eps=float(1e-20))
|
| 736 |
+
else:
|
| 737 |
+
topk_weights, topk_ids = select_experts(
|
| 738 |
+
hidden_states=x,
|
| 739 |
+
router_logits=router_logits,
|
| 740 |
+
top_k=top_k,
|
| 741 |
+
use_grouped_topk=use_grouped_topk,
|
| 742 |
+
renormalize=renormalize,
|
| 743 |
+
topk_group=topk_group,
|
| 744 |
+
num_expert_group=num_expert_group,
|
| 745 |
+
custom_routing_function=custom_routing_function,
|
| 746 |
+
scoring_func=scoring_func,
|
| 747 |
+
e_score_correction_bias=e_score_correction_bias,
|
| 748 |
+
)
|
| 749 |
+
|
| 750 |
+
# this is a naive implementation for experts load balance so as
|
| 751 |
+
# to avoid accumulating too much tokens on a single rank.
|
| 752 |
+
# currently it is only activated when doing profile runs.
|
| 753 |
+
if enable_force_load_balance:
|
| 754 |
+
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
| 755 |
+
|
| 756 |
+
topk_weights = topk_weights.to(x.dtype)
|
| 757 |
+
|
| 758 |
+
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
|
| 759 |
+
is_prefill, is_deepseek_v3_r1)
|
| 760 |
+
if fused_moe_state == FusedMoEState.AllGatherEP:
|
| 761 |
+
return fused_experts_with_allgather(
|
| 762 |
+
hidden_states=x,
|
| 763 |
+
w1=layer.w13_weight,
|
| 764 |
+
w1_scale=layer.w13_weight_scale,
|
| 765 |
+
w2=layer.w2_weight,
|
| 766 |
+
w2_scale=layer.w2_weight_scale,
|
| 767 |
+
topk_weights=topk_weights,
|
| 768 |
+
topk_ids=topk_ids,
|
| 769 |
+
top_k=top_k,
|
| 770 |
+
expert_map=expert_map)
|
| 771 |
+
elif fused_moe_state == FusedMoEState.MC2:
|
| 772 |
+
return fused_experts_with_mc2(
|
| 773 |
+
hidden_states=x,
|
| 774 |
+
w1=layer.w13_weight,
|
| 775 |
+
w2=layer.w2_weight,
|
| 776 |
+
w1_scale=layer.w13_weight_scale,
|
| 777 |
+
w2_scale=layer.w2_weight_scale,
|
| 778 |
+
topk_weights=topk_weights,
|
| 779 |
+
topk_ids=topk_ids,
|
| 780 |
+
top_k=top_k,
|
| 781 |
+
expert_map=expert_map,
|
| 782 |
+
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
| 783 |
+
log2phy=log2phy,
|
| 784 |
+
global_redundant_expert_num=global_redundant_expert_num,
|
| 785 |
+
shared_experts=shared_experts)
|
| 786 |
+
elif fused_moe_state in [
|
| 787 |
+
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
|
| 788 |
+
]:
|
| 789 |
+
return fused_experts(hidden_states=x,
|
| 790 |
+
w1=layer.w13_weight,
|
| 791 |
+
w1_scale=layer.w13_weight_scale,
|
| 792 |
+
w2=layer.w2_weight,
|
| 793 |
+
w2_scale=layer.w2_weight_scale,
|
| 794 |
+
topk_weights=topk_weights,
|
| 795 |
+
topk_ids=topk_ids,
|
| 796 |
+
top_k=top_k,
|
| 797 |
+
expert_map=expert_map)
|
| 798 |
+
else:
|
| 799 |
+
# The current implementation of deepseek moe splits hidden_states
|
| 800 |
+
# according to tp_size before they are feed into fused_moe module.
|
| 801 |
+
# Therefore, all2all is needed no matter how dp/tp is set so as to
|
| 802 |
+
# dispatch/combine tokens.
|
| 803 |
+
return fused_experts_with_all2all(
|
| 804 |
+
hidden_states=x,
|
| 805 |
+
w1=layer.w13_weight,
|
| 806 |
+
w1_scale=layer.w13_weight_scale,
|
| 807 |
+
w2=layer.w2_weight,
|
| 808 |
+
w2_scale=layer.w2_weight_scale,
|
| 809 |
+
topk_weights=topk_weights,
|
| 810 |
+
topk_ids=topk_ids,
|
| 811 |
+
top_k=top_k,
|
| 812 |
+
expert_map=expert_map,
|
| 813 |
+
ep_group=self.ep_group,
|
| 814 |
+
log2phy=log2phy,
|
| 815 |
+
global_redundant_expert_num=global_redundant_expert_num,
|
| 816 |
+
)
|
| 817 |
+
|
| 818 |
+
def process_weights_after_loading(self, layer):
|
| 819 |
+
if self.transpose_weight:
|
| 820 |
+
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
| 821 |
+
1, 2).contiguous()
|
| 822 |
+
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
| 823 |
+
1, 2).contiguous()
|
| 824 |
+
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
| 825 |
+
layer.w13_weight_scale.data.shape[0], -1)
|
| 826 |
+
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
|
| 827 |
+
layer.w13_weight_offset.data.shape[0], -1)
|
| 828 |
+
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
|
| 829 |
+
layer.w2_weight_scale.data.shape[0], -1)
|
| 830 |
+
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
| 831 |
+
layer.w2_weight_offset.data.shape[0], -1)
|
inference/vllm_ascend/utils.py
ADDED
|
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# Copyright 2023 The vLLM team.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
# This file is a part of the vllm-ascend project.
|
| 17 |
+
# Adapted from vllm-project/vllm/vllm/worker/worker.py
|
| 18 |
+
#
|
| 19 |
+
|
| 20 |
+
import atexit
|
| 21 |
+
import fcntl
|
| 22 |
+
import math
|
| 23 |
+
import os
|
| 24 |
+
import shutil
|
| 25 |
+
from contextlib import contextmanager, nullcontext
|
| 26 |
+
from enum import Enum
|
| 27 |
+
from threading import Lock
|
| 28 |
+
from typing import TYPE_CHECKING, List, Tuple
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch_npu # noqa: F401 # noqa: F401
|
| 32 |
+
from packaging.version import InvalidVersion, Version
|
| 33 |
+
from torch_npu.npu.streams import Event
|
| 34 |
+
from vllm.logger import logger
|
| 35 |
+
|
| 36 |
+
import vllm_ascend.envs as envs
|
| 37 |
+
from vllm_ascend.ascend_config import get_ascend_config
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
# Recent release of torchair has moved these ops to `.scope`.
|
| 41 |
+
from torchair.scope import npu_stream_switch as _npu_stream_switch
|
| 42 |
+
from torchair.scope import npu_wait_tensor as _npu_wait_tensor
|
| 43 |
+
except ImportError:
|
| 44 |
+
from torchair.ops import NpuStreamSwitch as _npu_stream_switch
|
| 45 |
+
from torchair.ops import npu_wait_tensor as _npu_wait_tensor
|
| 46 |
+
|
| 47 |
+
if TYPE_CHECKING:
|
| 48 |
+
from vllm.config import VllmConfig
|
| 49 |
+
else:
|
| 50 |
+
VllmConfig = None
|
| 51 |
+
|
| 52 |
+
# NOTE: Currently, we can only capture 1920 graphs at most,
|
| 53 |
+
# due to the limitation of ACL graph. This number is bounded by
|
| 54 |
+
# the number of streams, which is 2048, we save 128 streams
|
| 55 |
+
# as a buffer.
|
| 56 |
+
# Maximum number of graphs that can be captured by ACL Graph
|
| 57 |
+
MAX_CAPTURE_SIZE = 1920
|
| 58 |
+
|
| 59 |
+
ASCEND_QUATIZATION_METHOD = "ascend"
|
| 60 |
+
SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"]
|
| 61 |
+
|
| 62 |
+
ACL_FORMAT_FRACTAL_ND = 2
|
| 63 |
+
ACL_FORMAT_FRACTAL_NZ = 29
|
| 64 |
+
|
| 65 |
+
_CUSTOM_OP_ENABLED = None
|
| 66 |
+
_IS_310P = None
|
| 67 |
+
_SLEEP_MODE_ENABLED = None
|
| 68 |
+
_CURRENT_STREAM = None
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def is_310p():
|
| 72 |
+
global _IS_310P
|
| 73 |
+
if _IS_310P is None:
|
| 74 |
+
from vllm_ascend import _build_info # type: ignore
|
| 75 |
+
_IS_310P = _build_info.__soc_version__.lower().startswith("ascend310p")
|
| 76 |
+
return _IS_310P
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def sleep_mode_enabled():
|
| 80 |
+
global _SLEEP_MODE_ENABLED
|
| 81 |
+
if _SLEEP_MODE_ENABLED is None:
|
| 82 |
+
from vllm_ascend import _build_info # type: ignore
|
| 83 |
+
_SLEEP_MODE_ENABLED = _build_info.__sleep_mode_enabled__
|
| 84 |
+
return _SLEEP_MODE_ENABLED
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _round_up(x: int, align: int):
|
| 88 |
+
# round up x to align, for example, if align is 16, x will be rounded up to 16, 32, 48, etc.
|
| 89 |
+
# input: 15, 16 -> output: 16
|
| 90 |
+
# input: 17, 16 -> output: 32
|
| 91 |
+
# input: 30, 16 -> output: 32
|
| 92 |
+
# input: 33, 16 -> output: 48
|
| 93 |
+
# ...
|
| 94 |
+
return (x + align - 1) // align * align
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _custom_pad(x, pad_dims):
|
| 98 |
+
# pad the input tensor to the shape of pad_dims
|
| 99 |
+
# input: (13, 30), pad_dims: [0, 2, 0, 3]
|
| 100 |
+
# output: (16, 32)
|
| 101 |
+
return torch.nn.functional.pad(x, pad_dims)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _custom_reshape(x, target_shape):
|
| 105 |
+
# reshape the input tensor to the shape of target_shape
|
| 106 |
+
# input: (16, 32), target_shape: [1, 16, 2, 16]
|
| 107 |
+
# output: (1, 16, 2, 16)
|
| 108 |
+
return x.reshape(target_shape)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _custom_transpose(x, dim1, dim2):
|
| 112 |
+
# transpose the input tensor
|
| 113 |
+
# input: (1, 16, 2, 16), dim1: 1, dim2: 2
|
| 114 |
+
# output: (1, 2, 16, 16)
|
| 115 |
+
return x.transpose(dim1, dim2)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def nd_to_nz_2d(in_tensor: torch.Tensor) -> torch.Tensor:
|
| 119 |
+
# in_tensor: (13, 30)
|
| 120 |
+
aux_dims = [1, 0, 0, 16]
|
| 121 |
+
# aux_dims[1]: 16
|
| 122 |
+
aux_dims[1] = _round_up(in_tensor.size(0), 16)
|
| 123 |
+
# aux_dims[2]: 2
|
| 124 |
+
aux_dims[2] = _round_up(in_tensor.size(1), 16) // 16
|
| 125 |
+
|
| 126 |
+
# after: aux_dims: [1, 16, 2, 16]
|
| 127 |
+
|
| 128 |
+
pad_dims = [0, 0, 0, 0]
|
| 129 |
+
# pad_dims[1]: 2
|
| 130 |
+
pad_dims[1] = _round_up(in_tensor.size(1), 16) - in_tensor.size(1)
|
| 131 |
+
# pad_dims[3]: 3
|
| 132 |
+
pad_dims[3] = _round_up(in_tensor.size(0), 16) - in_tensor.size(0)
|
| 133 |
+
|
| 134 |
+
# after: pad_dims: [0, 2, 0, 3]
|
| 135 |
+
|
| 136 |
+
# return: (1, 2, 16, 16)
|
| 137 |
+
return _custom_transpose(
|
| 138 |
+
_custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1,
|
| 139 |
+
2).contiguous()
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
|
| 143 |
+
num_tokens = mask_tensor.shape[0]
|
| 144 |
+
max_seq_len = mask_tensor.shape[1]
|
| 145 |
+
|
| 146 |
+
tokens_pad = (num_tokens + 15) // 16 * 16
|
| 147 |
+
max_seq_len_pad = (max_seq_len + 15) // 16 * 16
|
| 148 |
+
|
| 149 |
+
mask_tensor_pad = \
|
| 150 |
+
torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device)
|
| 151 |
+
mask_tensor_pad[0][:num_tokens, :max_seq_len] = mask_tensor
|
| 152 |
+
mask = mask_tensor_pad.reshape(
|
| 153 |
+
(1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)
|
| 154 |
+
return mask
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def aligned_16(tensor: torch.Tensor):
|
| 158 |
+
"""Aligned tensor for 310P"""
|
| 159 |
+
|
| 160 |
+
# Get the size of the current 0th dimension
|
| 161 |
+
n = tensor.size(0)
|
| 162 |
+
|
| 163 |
+
# Calculate the aligned size
|
| 164 |
+
n_aligned = ((n + 15) // 16) * 16
|
| 165 |
+
|
| 166 |
+
# If already aligned, return the original tensor
|
| 167 |
+
if n == n_aligned:
|
| 168 |
+
return tensor
|
| 169 |
+
|
| 170 |
+
# Create a new tensor with shape (n_aligned, H, W) and fill it with zeros
|
| 171 |
+
new_tensor = torch.zeros(n_aligned,
|
| 172 |
+
*tensor.shape[1:],
|
| 173 |
+
dtype=tensor.dtype,
|
| 174 |
+
device=tensor.device)
|
| 175 |
+
|
| 176 |
+
# Copy the original tensor to the first N positions of the new tensor
|
| 177 |
+
new_tensor[:n] = tensor
|
| 178 |
+
|
| 179 |
+
return new_tensor
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def maybe_converting_weight_acl_format(model, format=ACL_FORMAT_FRACTAL_NZ):
|
| 183 |
+
# currently, there are some operations which do not support ACL_FORMAT_FRACTAL_NZ
|
| 184 |
+
# in eager mode but support it in torchair graph mode. since ACL_FORMAT_FRACTAL_NZ
|
| 185 |
+
# is much more preferred than ACL_FORMAT_FRACTAL_ND on 300I Duo, we add this
|
| 186 |
+
# conversion when using torchair graph mode on 300I Duo platform.
|
| 187 |
+
# TODO: we will remove this conversion if npu_quant_grouped_matmul_dequant
|
| 188 |
+
# accepts weight format of ACL_FORMAT_FRACTAL_NZ in eager mode.
|
| 189 |
+
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
| 190 |
+
|
| 191 |
+
use_torchair = get_ascend_config().torchair_graph_config.enabled
|
| 192 |
+
if not is_310p() or not use_torchair:
|
| 193 |
+
return
|
| 194 |
+
for module in model.modules():
|
| 195 |
+
if isinstance(module, FusedMoE):
|
| 196 |
+
if torch_npu.get_npu_format(module.w13_weight.data) == format:
|
| 197 |
+
return
|
| 198 |
+
module.w13_weight.data = torch_npu.npu_format_cast(
|
| 199 |
+
module.w13_weight.data, format)
|
| 200 |
+
module.w2_weight.data = torch_npu.npu_format_cast(
|
| 201 |
+
module.w2_weight.data, format)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def try_register_lib(lib_name: str, lib_info: str = ""):
|
| 205 |
+
import importlib
|
| 206 |
+
import importlib.util
|
| 207 |
+
try:
|
| 208 |
+
module_spec = importlib.util.find_spec(lib_name)
|
| 209 |
+
if module_spec is not None:
|
| 210 |
+
importlib.import_module(lib_name)
|
| 211 |
+
if lib_info:
|
| 212 |
+
logger.info(lib_info)
|
| 213 |
+
except Exception:
|
| 214 |
+
pass
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def enable_custom_op():
|
| 218 |
+
"""
|
| 219 |
+
Enable lazy init for vllm_ascend_C to avoid early initialization of CANN's RTS component.
|
| 220 |
+
Ensure that ASCEND_RT_VISIBLE_DEVICES can be dynamically modified before torch.npu.set_device().
|
| 221 |
+
"""
|
| 222 |
+
global _CUSTOM_OP_ENABLED
|
| 223 |
+
if _CUSTOM_OP_ENABLED is not None:
|
| 224 |
+
return _CUSTOM_OP_ENABLED
|
| 225 |
+
try:
|
| 226 |
+
# register custom ops into torch_library here
|
| 227 |
+
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
|
| 228 |
+
_CUSTOM_OP_ENABLED = True
|
| 229 |
+
except ImportError:
|
| 230 |
+
_CUSTOM_OP_ENABLED = False
|
| 231 |
+
logger.warning(
|
| 232 |
+
"Warning: Failed to register custom ops, all custom ops will be disabled"
|
| 233 |
+
)
|
| 234 |
+
return _CUSTOM_OP_ENABLED
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def find_hccl_library() -> str:
|
| 238 |
+
"""
|
| 239 |
+
We either use the library file specified by the `HCCL_SO_PATH`
|
| 240 |
+
environment variable, or we find the library file brought by PyTorch.
|
| 241 |
+
After importing `torch`, `libhccl.so` can be
|
| 242 |
+
found by `ctypes` automatically.
|
| 243 |
+
"""
|
| 244 |
+
so_file = envs.HCCL_SO_PATH
|
| 245 |
+
|
| 246 |
+
# manually load the hccl library
|
| 247 |
+
if so_file:
|
| 248 |
+
logger.info("Found hccl from environment variable HCCL_SO_PATH=%s",
|
| 249 |
+
so_file)
|
| 250 |
+
else:
|
| 251 |
+
if torch.version.cann is not None:
|
| 252 |
+
so_file = "libhccl.so"
|
| 253 |
+
else:
|
| 254 |
+
raise ValueError("HCCL only supports Ascend NPU backends.")
|
| 255 |
+
logger.info("Found hccl from library %s", so_file)
|
| 256 |
+
return so_file
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def current_stream() -> torch.npu.Stream:
|
| 260 |
+
"""
|
| 261 |
+
replace `torch.npu.current_stream()` with `vllm.utils.current_stream()`.
|
| 262 |
+
it turns out that `torch.npu.current_stream()` is quite expensive,
|
| 263 |
+
as it will construct a new stream object at each call.
|
| 264 |
+
here we patch `torch.npu.set_stream` to keep track of the current stream
|
| 265 |
+
directly, so that we can avoid calling `torch.npu.current_stream()`.
|
| 266 |
+
|
| 267 |
+
"""
|
| 268 |
+
global _CURRENT_STREAM
|
| 269 |
+
if _CURRENT_STREAM is None:
|
| 270 |
+
# when this function is called before any stream is set,
|
| 271 |
+
# we return the default stream.
|
| 272 |
+
_CURRENT_STREAM = torch.npu.current_stream()
|
| 273 |
+
return _CURRENT_STREAM
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def adapt_patch(is_global_patch: bool = False):
|
| 277 |
+
if is_global_patch:
|
| 278 |
+
from vllm_ascend.patch import platform # noqa: F401
|
| 279 |
+
else:
|
| 280 |
+
from vllm_ascend.patch import worker # noqa: F401
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def vllm_version_is(target_vllm_version: str):
|
| 284 |
+
if envs.VLLM_VERSION is not None:
|
| 285 |
+
vllm_version = envs.VLLM_VERSION
|
| 286 |
+
else:
|
| 287 |
+
import vllm
|
| 288 |
+
vllm_version = vllm.__version__
|
| 289 |
+
try:
|
| 290 |
+
return Version(vllm_version) == Version(target_vllm_version)
|
| 291 |
+
except InvalidVersion:
|
| 292 |
+
raise ValueError(
|
| 293 |
+
f"Invalid vllm version {vllm_version} found. A dev version of vllm "
|
| 294 |
+
"is installed probably. Set the environment variable VLLM_VERSION "
|
| 295 |
+
"to control it by hand. And please make sure the value follows the "
|
| 296 |
+
"format of x.y.z.")
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
| 300 |
+
"""Update ACL graph capture sizes based on hardware limitations"""
|
| 301 |
+
# Store original configuration and temporarily clear it
|
| 302 |
+
compilation_config = vllm_config.compilation_config
|
| 303 |
+
original_sizes, compilation_config.cudagraph_capture_sizes = \
|
| 304 |
+
compilation_config.cudagraph_capture_sizes, None
|
| 305 |
+
|
| 306 |
+
# Calculate parallel configuration factor
|
| 307 |
+
num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
|
| 308 |
+
parallel_config = vllm_config.parallel_config
|
| 309 |
+
|
| 310 |
+
# TODO: Find out whether we need to take into account the pp_size
|
| 311 |
+
parallel_factor = 1 + sum(size > 1 for size in [
|
| 312 |
+
parallel_config.data_parallel_size_local,
|
| 313 |
+
parallel_config.tensor_parallel_size,
|
| 314 |
+
parallel_config.expert_parallel_size,
|
| 315 |
+
parallel_config.expert_tensor_parallel_size,
|
| 316 |
+
])
|
| 317 |
+
|
| 318 |
+
# Calculate maximum supported batch sizes considering model architecture
|
| 319 |
+
max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
|
| 320 |
+
(num_hidden_layers + 1) / parallel_factor)
|
| 321 |
+
logger.info("Calculated maximum supported batch sizes for ACL graph: %s",
|
| 322 |
+
max_num_batch_sizes)
|
| 323 |
+
|
| 324 |
+
# If original sizes exceed maximum, sample a representative subset
|
| 325 |
+
if max_num_batch_sizes < len(original_sizes):
|
| 326 |
+
# Sample uniformly from original sizes
|
| 327 |
+
step = (len(original_sizes) - 1) / (max_num_batch_sizes - 1)
|
| 328 |
+
indices = [round(i * step) for i in range(max_num_batch_sizes)]
|
| 329 |
+
|
| 330 |
+
# Ensure first and last elements are preserved
|
| 331 |
+
indices[0], indices[-1] = 0, len(original_sizes) - 1
|
| 332 |
+
|
| 333 |
+
sampled_sizes = [original_sizes[i] for i in indices]
|
| 334 |
+
compilation_config.init_with_cudagraph_sizes(sampled_sizes)
|
| 335 |
+
|
| 336 |
+
logger.info(
|
| 337 |
+
"Adjusted ACL graph batch sizes for %s model (layers: %d): %d → %d sizes",
|
| 338 |
+
vllm_config.model_config.architectures[0],
|
| 339 |
+
num_hidden_layers,
|
| 340 |
+
len(original_sizes),
|
| 341 |
+
len(compilation_config.
|
| 342 |
+
cudagraph_capture_sizes # type: ignore[arg-type]
|
| 343 |
+
))
|
| 344 |
+
else:
|
| 345 |
+
# No adjustment needed
|
| 346 |
+
compilation_config.cudagraph_capture_sizes = original_sizes
|
| 347 |
+
logger.info(
|
| 348 |
+
"No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes",
|
| 349 |
+
vllm_config.model_config.architectures[0], num_hidden_layers,
|
| 350 |
+
len(original_sizes))
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
# TODO(wxy): Move to ops module
|
| 354 |
+
def dispose_tensor(x: torch.Tensor):
|
| 355 |
+
x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
class ProfileExecuteDuration:
|
| 359 |
+
_instance = None
|
| 360 |
+
_observations: List[Tuple[str, Event, Event]] = []
|
| 361 |
+
_lock = Lock()
|
| 362 |
+
|
| 363 |
+
def __new__(cls):
|
| 364 |
+
with cls._lock:
|
| 365 |
+
if cls._instance is None:
|
| 366 |
+
cls._instance = super().__new__(cls)
|
| 367 |
+
atexit.register(cls._instance.destroy)
|
| 368 |
+
return cls._instance
|
| 369 |
+
|
| 370 |
+
def destroy(self):
|
| 371 |
+
with self._lock:
|
| 372 |
+
self._observations.clear()
|
| 373 |
+
|
| 374 |
+
@contextmanager
|
| 375 |
+
def capture_async(self, duration_tag: str):
|
| 376 |
+
if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
|
| 377 |
+
yield
|
| 378 |
+
return
|
| 379 |
+
|
| 380 |
+
observe_start = Event(enable_timing=True)
|
| 381 |
+
observe_start.record()
|
| 382 |
+
try:
|
| 383 |
+
yield
|
| 384 |
+
finally:
|
| 385 |
+
observe_end = Event(enable_timing=True)
|
| 386 |
+
observe_end.record()
|
| 387 |
+
with self._lock:
|
| 388 |
+
self._observations.append(
|
| 389 |
+
(duration_tag, observe_start, observe_end))
|
| 390 |
+
|
| 391 |
+
def pop_captured_sync(self) -> dict:
|
| 392 |
+
"""Pop and synchronize all events in the observation list"""
|
| 393 |
+
durations: dict[str, float] = {}
|
| 394 |
+
if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
|
| 395 |
+
return durations
|
| 396 |
+
|
| 397 |
+
while self._observations:
|
| 398 |
+
with self._lock:
|
| 399 |
+
tag, observe_start, observe_end = self._observations.pop()
|
| 400 |
+
observe_end.synchronize()
|
| 401 |
+
durations[tag] = observe_start.elapsed_time(observe_end)
|
| 402 |
+
|
| 403 |
+
return durations
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
# TODO(wxy): Move to ops module
|
| 407 |
+
def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
|
| 408 |
+
return _npu_stream_switch(tag, priority) if enabled else nullcontext()
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
# TODO(wxy): Move to ops module
|
| 412 |
+
def npu_wait_tensor(self: torch.Tensor,
|
| 413 |
+
dependency: torch.Tensor,
|
| 414 |
+
*,
|
| 415 |
+
enabled: bool = True):
|
| 416 |
+
return _npu_wait_tensor(self, dependency) if enabled else self
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
# TODO(wxy): Move to ops module
|
| 420 |
+
def npu_prefetch(input: torch.Tensor,
|
| 421 |
+
dependency: torch.Tensor,
|
| 422 |
+
max_size: int = 0,
|
| 423 |
+
*,
|
| 424 |
+
enabled: bool = True):
|
| 425 |
+
if not enabled:
|
| 426 |
+
return
|
| 427 |
+
input_size = input.element_size() * input.numel()
|
| 428 |
+
if max_size <= 0 or max_size > input_size:
|
| 429 |
+
max_size = input_size
|
| 430 |
+
torch_npu.npu_prefetch(input, dependency, max_size)
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
# TODO(zzzzwwjj): move this into forward_context
|
| 434 |
+
class FusedMoEState(Enum):
|
| 435 |
+
AllGather = 0
|
| 436 |
+
All2All = 1
|
| 437 |
+
MC2 = 2
|
| 438 |
+
AllGatherEP = 3
|
| 439 |
+
NaiveMulticast = 4
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
# TODO(ttanzhiqiang): rm_router_logits
|
| 443 |
+
# dp>1 will trigger
|
| 444 |
+
# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
|
| 445 |
+
def get_rm_router_logits_state(ep_size: int, dp_size: int,
|
| 446 |
+
is_deepseek_v3_r1: bool):
|
| 447 |
+
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
| 448 |
+
# only supports deepseek v3/r1
|
| 449 |
+
if dp_size > 1:
|
| 450 |
+
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
| 451 |
+
and is_deepseek_v3_r1):
|
| 452 |
+
return True
|
| 453 |
+
elif ep_size == 1 and is_deepseek_v3_r1:
|
| 454 |
+
return True
|
| 455 |
+
return False
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
# TODO(ttanzhiqiang): all_reduce merge
|
| 459 |
+
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
| 460 |
+
# Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
|
| 461 |
+
def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
|
| 462 |
+
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
| 463 |
+
# only supports deepseek v3/r1
|
| 464 |
+
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
| 465 |
+
and is_deepseek_v3_r1):
|
| 466 |
+
return True
|
| 467 |
+
elif ep_size == 1 and is_deepseek_v3_r1:
|
| 468 |
+
return True
|
| 469 |
+
return False
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
# TODO(zzzzwwjj): add soc_version to choose branch
|
| 473 |
+
def get_fused_moe_state(ep_size: int, with_prefill: bool,
|
| 474 |
+
is_deepseek_v3_r1: bool):
|
| 475 |
+
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
| 476 |
+
# only supports deepseek v3/r1
|
| 477 |
+
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
| 478 |
+
and is_deepseek_v3_r1 and not with_prefill):
|
| 479 |
+
return FusedMoEState.AllGatherEP
|
| 480 |
+
elif ep_size == 1:
|
| 481 |
+
if with_prefill:
|
| 482 |
+
return FusedMoEState.NaiveMulticast
|
| 483 |
+
else:
|
| 484 |
+
return FusedMoEState.AllGather
|
| 485 |
+
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
|
| 486 |
+
elif ep_size < 16 or with_prefill:
|
| 487 |
+
return FusedMoEState.All2All
|
| 488 |
+
else:
|
| 489 |
+
return FusedMoEState.MC2
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
|
| 493 |
+
KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
|
| 494 |
+
TORCHAIR_CACHE_PATH_NAME = ".torchair_cache"
|
| 495 |
+
TORCHAIR_CACHE_DIR = os.getenv(
|
| 496 |
+
'TORCHAIR_CACHE_HOME', os.path.join(os.getcwd(), TORCHAIR_CACHE_PATH_NAME))
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def get_torchair_current_work_dir(file_name=None):
|
| 500 |
+
if file_name is None:
|
| 501 |
+
return TORCHAIR_CACHE_DIR
|
| 502 |
+
return os.path.join(TORCHAIR_CACHE_DIR, file_name)
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
def check_torchair_cache_exist():
|
| 506 |
+
res = False
|
| 507 |
+
torch_air_abs_path = get_torchair_current_work_dir()
|
| 508 |
+
if os.path.exists(torch_air_abs_path):
|
| 509 |
+
file_list = os.listdir(torch_air_abs_path)
|
| 510 |
+
if len(file_list) != 0:
|
| 511 |
+
res = True
|
| 512 |
+
return res
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def check_kv_cache_bytes_cache_exist():
|
| 516 |
+
res = False
|
| 517 |
+
kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
|
| 518 |
+
KV_CACHE_BYTES_CACHE_PATH_NAME)
|
| 519 |
+
if os.path.exists(kv_cache_bytes_cache_abs_path):
|
| 520 |
+
file_list = os.listdir(kv_cache_bytes_cache_abs_path)
|
| 521 |
+
if len(file_list) != 0:
|
| 522 |
+
res = True
|
| 523 |
+
return res
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
def read_kv_cache_bytes_from_file(rank) -> int:
|
| 527 |
+
kv_cache_bytes = -1
|
| 528 |
+
kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
|
| 529 |
+
KV_CACHE_BYTES_CACHE_PATH_NAME)
|
| 530 |
+
kv_cache_bytes_file = os.path.join(
|
| 531 |
+
kv_cache_bytes_cache_abs_path,
|
| 532 |
+
f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
|
| 533 |
+
with open(kv_cache_bytes_file, "r", encoding="utf-8") as f:
|
| 534 |
+
with file_lock(f, fcntl.LOCK_SH):
|
| 535 |
+
kv_cache_bytes = int(f.readline())
|
| 536 |
+
return kv_cache_bytes
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
@contextmanager
|
| 540 |
+
def file_lock(file_descriptor, lock_type):
|
| 541 |
+
fcntl.flock(file_descriptor, lock_type)
|
| 542 |
+
try:
|
| 543 |
+
yield
|
| 544 |
+
finally:
|
| 545 |
+
fcntl.flock(file_descriptor, fcntl.LOCK_UN)
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def write_kv_cache_bytes_to_file(rank, kv_cache_bytes):
|
| 549 |
+
kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
|
| 550 |
+
KV_CACHE_BYTES_CACHE_PATH_NAME)
|
| 551 |
+
os.makedirs(kv_cache_bytes_cache_abs_path, exist_ok=True)
|
| 552 |
+
kv_cache_bytes_file = os.path.join(
|
| 553 |
+
kv_cache_bytes_cache_abs_path,
|
| 554 |
+
f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
|
| 555 |
+
with open(kv_cache_bytes_file, "w", encoding="utf-8") as f:
|
| 556 |
+
with file_lock(f, fcntl.LOCK_EX):
|
| 557 |
+
f.write(f"{kv_cache_bytes}")
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def delete_torchair_cache_file():
|
| 561 |
+
torch_air_abs_path = get_torchair_current_work_dir()
|
| 562 |
+
if os.path.exists(torch_air_abs_path):
|
| 563 |
+
shutil.rmtree(torch_air_abs_path)
|
inference/vllm_ascend/worker/model_runner_v1.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
inference/vllm_ascend/worker/npu_input_batch.py
ADDED
|
@@ -0,0 +1,796 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# Copyright 2023 The vLLM team.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
# This file is a part of the vllm-ascend project.
|
| 17 |
+
# Adapted from vllm-project/vllm/vllm/worker/gpu_input_batch.py
|
| 18 |
+
#
|
| 19 |
+
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import Optional, cast, Union
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
from vllm.lora.request import LoRARequest
|
| 26 |
+
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
| 27 |
+
from vllm.pooling_params import PoolingParams
|
| 28 |
+
from vllm.sampling_params import SamplingParams, SamplingType
|
| 29 |
+
from vllm.utils import swap_dict_values
|
| 30 |
+
from vllm.v1.outputs import LogprobsTensors
|
| 31 |
+
from vllm.v1.sample.logits_processor import init_builtin_logitsprocs
|
| 32 |
+
from vllm.v1.sample.metadata import SamplingMetadata
|
| 33 |
+
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
|
| 34 |
+
from vllm.v1.utils import copy_slice
|
| 35 |
+
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
| 36 |
+
|
| 37 |
+
from vllm_ascend.pool.metadata import PoolingMetadata
|
| 38 |
+
|
| 39 |
+
_SAMPLING_EPS = 1e-5
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class CachedRequestState:
|
| 44 |
+
|
| 45 |
+
req_id: str
|
| 46 |
+
prompt_token_ids: list[int]
|
| 47 |
+
mm_inputs: list[MultiModalKwargs]
|
| 48 |
+
mm_positions: list[PlaceholderRange]
|
| 49 |
+
sampling_params: Optional[SamplingParams]
|
| 50 |
+
pooling_params: Optional[PoolingParams]
|
| 51 |
+
generator: Optional[torch.Generator]
|
| 52 |
+
|
| 53 |
+
block_ids: tuple[list[int], ...]
|
| 54 |
+
num_computed_tokens: int
|
| 55 |
+
output_token_ids: list[int]
|
| 56 |
+
|
| 57 |
+
mrope_positions: Optional[torch.Tensor] = None
|
| 58 |
+
mrope_position_delta: Optional[int] = None
|
| 59 |
+
|
| 60 |
+
lora_request: Optional[LoRARequest] = None
|
| 61 |
+
|
| 62 |
+
def __post_init__(self):
|
| 63 |
+
self.num_prompt_tokens = len(self.prompt_token_ids)
|
| 64 |
+
|
| 65 |
+
@property
|
| 66 |
+
def num_tokens(self) -> int:
|
| 67 |
+
return self.num_prompt_tokens + len(self.output_token_ids)
|
| 68 |
+
|
| 69 |
+
def get_token_id(self, idx: int) -> int:
|
| 70 |
+
if idx < self.num_prompt_tokens:
|
| 71 |
+
return self.prompt_token_ids[idx]
|
| 72 |
+
else:
|
| 73 |
+
return self.output_token_ids[idx - self.num_prompt_tokens]
|
| 74 |
+
|
| 75 |
+
@dataclass
|
| 76 |
+
class SamplingMetadataTopNSigma(SamplingMetadata):
|
| 77 |
+
top_n_sigma: torch.Tensor
|
| 78 |
+
no_top_n_sigma: bool
|
| 79 |
+
|
| 80 |
+
class InputBatch:
|
| 81 |
+
|
| 82 |
+
def __init__(
|
| 83 |
+
self,
|
| 84 |
+
max_num_reqs: int,
|
| 85 |
+
max_model_len: int,
|
| 86 |
+
max_num_batched_tokens: int,
|
| 87 |
+
device: torch.device,
|
| 88 |
+
pin_memory: bool,
|
| 89 |
+
vocab_size: int,
|
| 90 |
+
block_sizes: list[int], # The block_size of each kv cache group
|
| 91 |
+
logits_processing_needs_token_ids: bool = False,
|
| 92 |
+
is_spec_decode: bool = False,
|
| 93 |
+
):
|
| 94 |
+
self.is_spec_decode = is_spec_decode
|
| 95 |
+
self.max_num_reqs = max_num_reqs
|
| 96 |
+
self.max_model_len = max_model_len
|
| 97 |
+
self.max_num_batched_tokens = max_num_batched_tokens
|
| 98 |
+
self.device = device
|
| 99 |
+
self.pin_memory = pin_memory
|
| 100 |
+
self.vocab_size = vocab_size
|
| 101 |
+
self.logits_processing_needs_token_ids = (
|
| 102 |
+
logits_processing_needs_token_ids)
|
| 103 |
+
|
| 104 |
+
self._req_ids: list[Optional[str]] = []
|
| 105 |
+
self.req_id_to_index: dict[str, int] = {}
|
| 106 |
+
|
| 107 |
+
# TODO(woosuk): This buffer could be too large if max_model_len is big.
|
| 108 |
+
# Find a way to reduce the CPU memory usage.
|
| 109 |
+
# This buffer is not directly transferred to the NPU, so it does not
|
| 110 |
+
# need to be pinned.
|
| 111 |
+
self.token_ids_cpu_tensor = torch.zeros(
|
| 112 |
+
(max_num_reqs, max_model_len),
|
| 113 |
+
device="cpu",
|
| 114 |
+
dtype=torch.int32,
|
| 115 |
+
pin_memory=False,
|
| 116 |
+
)
|
| 117 |
+
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
| 118 |
+
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
| 119 |
+
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
|
| 120 |
+
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
| 121 |
+
self.num_computed_tokens_cpu_tensor = torch.zeros(
|
| 122 |
+
(max_num_reqs, ),
|
| 123 |
+
device="cpu",
|
| 124 |
+
dtype=torch.int32,
|
| 125 |
+
pin_memory=pin_memory,
|
| 126 |
+
)
|
| 127 |
+
self.num_computed_tokens_cpu = \
|
| 128 |
+
self.num_computed_tokens_cpu_tensor.numpy()
|
| 129 |
+
|
| 130 |
+
# Block table.
|
| 131 |
+
self.block_table = MultiGroupBlockTable(
|
| 132 |
+
max_num_reqs=max_num_reqs,
|
| 133 |
+
max_model_len=max_model_len,
|
| 134 |
+
max_num_batched_tokens=max_num_batched_tokens,
|
| 135 |
+
pin_memory=pin_memory,
|
| 136 |
+
device=device,
|
| 137 |
+
block_sizes=block_sizes,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
# Sampling-related.
|
| 141 |
+
self.temperature = torch.empty((max_num_reqs, ),
|
| 142 |
+
dtype=torch.float32,
|
| 143 |
+
device=device)
|
| 144 |
+
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 145 |
+
dtype=torch.float32,
|
| 146 |
+
device="cpu",
|
| 147 |
+
pin_memory=pin_memory)
|
| 148 |
+
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
|
| 149 |
+
self.greedy_reqs: set[str] = set()
|
| 150 |
+
self.random_reqs: set[str] = set()
|
| 151 |
+
|
| 152 |
+
self.top_p = torch.empty((max_num_reqs, ),
|
| 153 |
+
dtype=torch.float32,
|
| 154 |
+
device=device)
|
| 155 |
+
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 156 |
+
dtype=torch.float32,
|
| 157 |
+
device="cpu",
|
| 158 |
+
pin_memory=pin_memory)
|
| 159 |
+
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
|
| 160 |
+
self.top_p_reqs: set[str] = set()
|
| 161 |
+
|
| 162 |
+
self.top_k = torch.empty((max_num_reqs, ),
|
| 163 |
+
dtype=torch.int32,
|
| 164 |
+
device=device)
|
| 165 |
+
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 166 |
+
dtype=torch.int32,
|
| 167 |
+
device="cpu",
|
| 168 |
+
pin_memory=pin_memory)
|
| 169 |
+
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
|
| 170 |
+
self.top_k_reqs: set[str] = set()
|
| 171 |
+
|
| 172 |
+
# IDs of requests which do not support spec decoding
|
| 173 |
+
self.spec_decode_unsupported_reqs: set[str] = set()
|
| 174 |
+
|
| 175 |
+
self.min_p = torch.empty((max_num_reqs, ),
|
| 176 |
+
dtype=torch.float32,
|
| 177 |
+
device=device)
|
| 178 |
+
self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 179 |
+
dtype=torch.float32,
|
| 180 |
+
device="cpu",
|
| 181 |
+
pin_memory=pin_memory)
|
| 182 |
+
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
|
| 183 |
+
self.min_p_reqs: set[str] = set()
|
| 184 |
+
|
| 185 |
+
# topnsigma penalty
|
| 186 |
+
self.top_n_sigma = torch.empty((max_num_reqs, ),
|
| 187 |
+
dtype=torch.float,
|
| 188 |
+
device=device)
|
| 189 |
+
self.top_n_sigma_cpu_tensor = torch.empty(
|
| 190 |
+
(max_num_reqs, ),
|
| 191 |
+
dtype=torch.float,
|
| 192 |
+
device="cpu",
|
| 193 |
+
pin_memory=pin_memory)
|
| 194 |
+
self.top_n_sigma_cpu = \
|
| 195 |
+
self.top_n_sigma_cpu_tensor.numpy()
|
| 196 |
+
self.top_n_sigma_reqs: set[str] = set()
|
| 197 |
+
|
| 198 |
+
# Frequency penalty related data structures
|
| 199 |
+
self.frequency_penalties = torch.empty((max_num_reqs, ),
|
| 200 |
+
dtype=torch.float,
|
| 201 |
+
device=device)
|
| 202 |
+
self.frequency_penalties_cpu_tensor = torch.empty(
|
| 203 |
+
(max_num_reqs, ),
|
| 204 |
+
dtype=torch.float,
|
| 205 |
+
device="cpu",
|
| 206 |
+
pin_memory=pin_memory)
|
| 207 |
+
self.frequency_penalties_cpu = \
|
| 208 |
+
self.frequency_penalties_cpu_tensor.numpy()
|
| 209 |
+
self.frequency_penalties_reqs: set[str] = set()
|
| 210 |
+
|
| 211 |
+
# Presence penalty related data structures
|
| 212 |
+
self.presence_penalties = torch.empty((max_num_reqs, ),
|
| 213 |
+
dtype=torch.float,
|
| 214 |
+
device=device)
|
| 215 |
+
self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 216 |
+
dtype=torch.float,
|
| 217 |
+
device="cpu",
|
| 218 |
+
pin_memory=pin_memory)
|
| 219 |
+
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
|
| 220 |
+
)
|
| 221 |
+
self.presence_penalties_reqs: set[str] = set()
|
| 222 |
+
|
| 223 |
+
# Repetition penalty related data structures
|
| 224 |
+
self.repetition_penalties = torch.empty((max_num_reqs, ),
|
| 225 |
+
dtype=torch.float,
|
| 226 |
+
device=device)
|
| 227 |
+
self.repetition_penalties_cpu_tensor = torch.empty(
|
| 228 |
+
(max_num_reqs, ),
|
| 229 |
+
dtype=torch.float,
|
| 230 |
+
device="cpu",
|
| 231 |
+
pin_memory=pin_memory)
|
| 232 |
+
self.repetition_penalties_cpu = \
|
| 233 |
+
self.repetition_penalties_cpu_tensor.numpy()
|
| 234 |
+
self.repetition_penalties_reqs: set[str] = set()
|
| 235 |
+
|
| 236 |
+
# req_index -> (min_tokens, stop_token_ids)
|
| 237 |
+
self.min_tokens: dict[int, tuple[int, set[int]]] = {}
|
| 238 |
+
|
| 239 |
+
# lora related
|
| 240 |
+
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
|
| 241 |
+
dtype=np.int32)
|
| 242 |
+
self.lora_id_to_request_ids: dict[int, set[str]] = {}
|
| 243 |
+
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
|
| 244 |
+
|
| 245 |
+
# req_index -> generator
|
| 246 |
+
# NOTE(woosuk): The indices of the requests that do not have their own
|
| 247 |
+
# generator should not be included in the dictionary.
|
| 248 |
+
self.generators: dict[int, torch.Generator] = {}
|
| 249 |
+
|
| 250 |
+
self.num_logprobs: dict[str, int] = {}
|
| 251 |
+
# NOTE(rob): num_prompt_logprobs only includes reqs
|
| 252 |
+
# that are currently in the prefill phase.
|
| 253 |
+
self.num_prompt_logprobs: dict[str, int] = {}
|
| 254 |
+
|
| 255 |
+
# To accumulate prompt logprobs tensor chunks across prefill steps.
|
| 256 |
+
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
|
| 257 |
+
|
| 258 |
+
self.logit_bias: list[Optional[dict[int,
|
| 259 |
+
float]]] = [None] * max_num_reqs
|
| 260 |
+
self.has_allowed_token_ids: set[str] = set()
|
| 261 |
+
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
|
| 262 |
+
# the value is False. Since we use masked_fill_ to set -inf.
|
| 263 |
+
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
|
| 264 |
+
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
|
| 265 |
+
|
| 266 |
+
# req_index -> bad_words_token_ids
|
| 267 |
+
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
|
| 268 |
+
|
| 269 |
+
self.req_output_token_ids: list[Optional[list[int]]] = []
|
| 270 |
+
|
| 271 |
+
# Define logits processors.
|
| 272 |
+
# TODO(andy): logits processor list should be extensible via engine
|
| 273 |
+
# constructor argument; for now the list is fixed.
|
| 274 |
+
self.logitsprocs = init_builtin_logitsprocs(
|
| 275 |
+
pin_memory_available=pin_memory,
|
| 276 |
+
max_num_reqs=max_num_reqs + 1,
|
| 277 |
+
device=device)
|
| 278 |
+
|
| 279 |
+
# This is updated each time the batch constituents change.
|
| 280 |
+
self.sampling_metadata = self._make_sampling_metadata()
|
| 281 |
+
|
| 282 |
+
self.pooling_params: dict[str, PoolingParams] = {}
|
| 283 |
+
|
| 284 |
+
@property
|
| 285 |
+
def req_ids(self) -> list[str]:
|
| 286 |
+
# None elements should only be present transiently
|
| 287 |
+
# while performing state updates to the batch.
|
| 288 |
+
return cast(list[str], self._req_ids)
|
| 289 |
+
|
| 290 |
+
def add_request(
|
| 291 |
+
self,
|
| 292 |
+
request: "CachedRequestState",
|
| 293 |
+
req_index: Optional[int] = None,
|
| 294 |
+
) -> None:
|
| 295 |
+
if req_index is None:
|
| 296 |
+
req_index = self.num_reqs
|
| 297 |
+
assert req_index < self.max_num_reqs
|
| 298 |
+
|
| 299 |
+
req_id = request.req_id
|
| 300 |
+
if req_index == len(self._req_ids):
|
| 301 |
+
self._req_ids.append(req_id)
|
| 302 |
+
self.req_output_token_ids.append(request.output_token_ids)
|
| 303 |
+
else:
|
| 304 |
+
self._req_ids[req_index] = req_id
|
| 305 |
+
self.req_output_token_ids[req_index] = request.output_token_ids
|
| 306 |
+
|
| 307 |
+
self.req_id_to_index[req_id] = req_index
|
| 308 |
+
|
| 309 |
+
# Copy the prompt token ids and output token ids.
|
| 310 |
+
num_prompt_tokens = len(request.prompt_token_ids)
|
| 311 |
+
self.num_prompt_tokens[req_index] = num_prompt_tokens
|
| 312 |
+
self.token_ids_cpu[
|
| 313 |
+
req_index, :num_prompt_tokens] = request.prompt_token_ids
|
| 314 |
+
start_idx = num_prompt_tokens
|
| 315 |
+
end_idx = start_idx + len(request.output_token_ids)
|
| 316 |
+
self.token_ids_cpu[req_index,
|
| 317 |
+
start_idx:end_idx] = request.output_token_ids
|
| 318 |
+
# Number of token ids in token_ids_cpu.
|
| 319 |
+
# NOTE(woosuk): This may include spec decode tokens.
|
| 320 |
+
self.num_tokens[req_index] = request.num_tokens
|
| 321 |
+
# Number of tokens without spec decode tokens.
|
| 322 |
+
self.num_tokens_no_spec[req_index] = request.num_tokens
|
| 323 |
+
|
| 324 |
+
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
| 325 |
+
self.block_table.add_row(request.block_ids, req_index)
|
| 326 |
+
|
| 327 |
+
if sampling_params := request.sampling_params:
|
| 328 |
+
if self.is_spec_decode and is_spec_decode_unsupported(
|
| 329 |
+
sampling_params):
|
| 330 |
+
self.spec_decode_unsupported_reqs.add(req_id)
|
| 331 |
+
if sampling_params.sampling_type == SamplingType.GREEDY:
|
| 332 |
+
# Avoid later division by zero.
|
| 333 |
+
self.temperature_cpu[req_index] = -1.0
|
| 334 |
+
self.greedy_reqs.add(req_id)
|
| 335 |
+
else:
|
| 336 |
+
self.temperature_cpu[req_index] = sampling_params.temperature
|
| 337 |
+
self.random_reqs.add(req_id)
|
| 338 |
+
|
| 339 |
+
self.top_p_cpu[req_index] = sampling_params.top_p
|
| 340 |
+
if sampling_params.top_p < 1:
|
| 341 |
+
self.top_p_reqs.add(req_id)
|
| 342 |
+
top_k = sampling_params.top_k
|
| 343 |
+
if 0 < top_k < self.vocab_size:
|
| 344 |
+
self.top_k_reqs.add(req_id)
|
| 345 |
+
else:
|
| 346 |
+
top_k = self.vocab_size
|
| 347 |
+
self.top_k_cpu[req_index] = top_k
|
| 348 |
+
self.min_p_cpu[req_index] = sampling_params.min_p
|
| 349 |
+
self.frequency_penalties_cpu[
|
| 350 |
+
req_index] = sampling_params.frequency_penalty
|
| 351 |
+
if sampling_params.min_p > _SAMPLING_EPS:
|
| 352 |
+
self.min_p_reqs.add(req_id)
|
| 353 |
+
if sampling_params.frequency_penalty != 0.0:
|
| 354 |
+
self.frequency_penalties_reqs.add(req_id)
|
| 355 |
+
self.presence_penalties_cpu[
|
| 356 |
+
req_index] = sampling_params.presence_penalty
|
| 357 |
+
if sampling_params.presence_penalty != 0.0:
|
| 358 |
+
self.presence_penalties_reqs.add(req_id)
|
| 359 |
+
self.repetition_penalties_cpu[
|
| 360 |
+
req_index] = sampling_params.repetition_penalty
|
| 361 |
+
if sampling_params.repetition_penalty != 1.0:
|
| 362 |
+
self.repetition_penalties_reqs.add(req_id)
|
| 363 |
+
if sampling_params.min_tokens:
|
| 364 |
+
self.min_tokens[req_index] = (
|
| 365 |
+
sampling_params.min_tokens,
|
| 366 |
+
sampling_params.all_stop_token_ids)
|
| 367 |
+
|
| 368 |
+
if sampling_params.extra_args and "top_n_sigma" in sampling_params.extra_args:
|
| 369 |
+
self.top_n_sigma_cpu[
|
| 370 |
+
req_index] = sampling_params.extra_args["top_n_sigma"]
|
| 371 |
+
self.top_n_sigma_reqs.add(req_id)
|
| 372 |
+
else:
|
| 373 |
+
self.top_n_sigma_cpu[req_index] = -1
|
| 374 |
+
|
| 375 |
+
# NOTE(woosuk): self.generators should not include the requests that
|
| 376 |
+
# do not have their own generator.
|
| 377 |
+
if request.generator is not None:
|
| 378 |
+
self.generators[req_index] = request.generator
|
| 379 |
+
|
| 380 |
+
if sampling_params.logprobs is not None:
|
| 381 |
+
self.num_logprobs[req_id] = sampling_params.logprobs
|
| 382 |
+
if sampling_params.prompt_logprobs is not None:
|
| 383 |
+
self.num_prompt_logprobs[
|
| 384 |
+
req_id] = sampling_params.prompt_logprobs
|
| 385 |
+
if sampling_params.logit_bias is not None:
|
| 386 |
+
self.logit_bias[req_index] = sampling_params.logit_bias
|
| 387 |
+
|
| 388 |
+
if sampling_params.allowed_token_ids:
|
| 389 |
+
self.has_allowed_token_ids.add(req_id)
|
| 390 |
+
if self.allowed_token_ids_mask_cpu_tensor is None:
|
| 391 |
+
# Lazy allocation for this tensor, which can be large.
|
| 392 |
+
# False means we don't fill with -inf.
|
| 393 |
+
self.allowed_token_ids_mask = torch.zeros(
|
| 394 |
+
self.max_num_reqs,
|
| 395 |
+
self.vocab_size,
|
| 396 |
+
dtype=torch.bool,
|
| 397 |
+
device=self.device)
|
| 398 |
+
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
|
| 399 |
+
self.max_num_reqs,
|
| 400 |
+
self.vocab_size,
|
| 401 |
+
dtype=torch.bool,
|
| 402 |
+
device="cpu")
|
| 403 |
+
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
|
| 404 |
+
# False means we don't fill with -inf.
|
| 405 |
+
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
| 406 |
+
sampling_params.allowed_token_ids] = False
|
| 407 |
+
|
| 408 |
+
if sampling_params.bad_words_token_ids:
|
| 409 |
+
self.bad_words_token_ids[
|
| 410 |
+
req_index] = sampling_params.bad_words_token_ids
|
| 411 |
+
else:
|
| 412 |
+
assert request.pooling_params is not None
|
| 413 |
+
self.pooling_params[req_id] = request.pooling_params
|
| 414 |
+
|
| 415 |
+
# Add request lora ID
|
| 416 |
+
if request.lora_request:
|
| 417 |
+
lora_id = request.lora_request.lora_int_id
|
| 418 |
+
if lora_id not in self.lora_id_to_request_ids:
|
| 419 |
+
self.lora_id_to_request_ids[lora_id] = set()
|
| 420 |
+
|
| 421 |
+
self.request_lora_mapping[req_index] = lora_id
|
| 422 |
+
self.lora_id_to_request_ids[lora_id].add(request.req_id)
|
| 423 |
+
self.lora_id_to_lora_request[lora_id] = request.lora_request
|
| 424 |
+
else:
|
| 425 |
+
# No LoRA
|
| 426 |
+
self.request_lora_mapping[req_index] = 0
|
| 427 |
+
|
| 428 |
+
def remove_request(self, req_id: str) -> Optional[int]:
|
| 429 |
+
"""This method must always be followed by a call to condense()."""
|
| 430 |
+
|
| 431 |
+
req_index = self.req_id_to_index.pop(req_id, None)
|
| 432 |
+
if req_index is None:
|
| 433 |
+
return None
|
| 434 |
+
self._req_ids[req_index] = None
|
| 435 |
+
self.req_output_token_ids[req_index] = None
|
| 436 |
+
|
| 437 |
+
self.greedy_reqs.discard(req_id)
|
| 438 |
+
self.random_reqs.discard(req_id)
|
| 439 |
+
self.top_p_reqs.discard(req_id)
|
| 440 |
+
self.top_k_reqs.discard(req_id)
|
| 441 |
+
self.min_p_reqs.discard(req_id)
|
| 442 |
+
self.min_tokens.pop(req_index, None)
|
| 443 |
+
self.frequency_penalties_reqs.discard(req_id)
|
| 444 |
+
self.presence_penalties_reqs.discard(req_id)
|
| 445 |
+
self.repetition_penalties_reqs.discard(req_id)
|
| 446 |
+
self.spec_decode_unsupported_reqs.discard(req_id)
|
| 447 |
+
self.top_n_sigma_reqs.discard(req_id)
|
| 448 |
+
self.generators.pop(req_index, None)
|
| 449 |
+
self.num_logprobs.pop(req_id, None)
|
| 450 |
+
self.num_prompt_logprobs.pop(req_id, None)
|
| 451 |
+
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
|
| 452 |
+
|
| 453 |
+
# LoRA
|
| 454 |
+
lora_id = self.request_lora_mapping[req_index]
|
| 455 |
+
if lora_id != 0:
|
| 456 |
+
self.lora_id_to_request_ids[lora_id].discard(req_id)
|
| 457 |
+
if len(self.lora_id_to_request_ids[lora_id]) == 0:
|
| 458 |
+
self.lora_id_to_request_ids.pop(lora_id)
|
| 459 |
+
self.lora_id_to_lora_request.pop(lora_id)
|
| 460 |
+
self.request_lora_mapping[req_index] = 0
|
| 461 |
+
|
| 462 |
+
self.logit_bias[req_index] = None
|
| 463 |
+
self.has_allowed_token_ids.discard(req_id)
|
| 464 |
+
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
| 465 |
+
# False means we don't fill with -inf.
|
| 466 |
+
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
|
| 467 |
+
self.bad_words_token_ids.pop(req_index, None)
|
| 468 |
+
self.pooling_params.pop(req_id, None)
|
| 469 |
+
return req_index
|
| 470 |
+
|
| 471 |
+
def swap_states(self, i1: int, i2: int) -> None:
|
| 472 |
+
old_id_i1 = self._req_ids[i1]
|
| 473 |
+
old_id_i2 = self._req_ids[i2]
|
| 474 |
+
self._req_ids[i1], self._req_ids[i2] =\
|
| 475 |
+
self._req_ids[i2], self._req_ids[i1] # noqa
|
| 476 |
+
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
|
| 477 |
+
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
|
| 478 |
+
assert old_id_i1 is not None and old_id_i2 is not None
|
| 479 |
+
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
|
| 480 |
+
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
|
| 481 |
+
self.num_tokens[i1], self.num_tokens[i2] =\
|
| 482 |
+
self.num_tokens[i2], self.num_tokens[i1]
|
| 483 |
+
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
|
| 484 |
+
self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
|
| 485 |
+
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
|
| 486 |
+
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
|
| 487 |
+
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
|
| 488 |
+
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
|
| 489 |
+
self.temperature_cpu[i1], self.temperature_cpu[i2] =\
|
| 490 |
+
self.temperature_cpu[i2], self.temperature_cpu[i1]
|
| 491 |
+
self.top_p_cpu[i1], self.top_p_cpu[i2] =\
|
| 492 |
+
self.top_p_cpu[i2], self.top_p_cpu[i1]
|
| 493 |
+
self.top_k_cpu[i1], self.top_k_cpu[i2] =\
|
| 494 |
+
self.top_k_cpu[i2], self.top_k_cpu[i1]
|
| 495 |
+
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
|
| 496 |
+
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
|
| 497 |
+
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
|
| 498 |
+
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
|
| 499 |
+
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
|
| 500 |
+
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
|
| 501 |
+
self.min_p_cpu[i1], self.min_p_cpu[i2] =\
|
| 502 |
+
self.min_p_cpu[i2], self.min_p_cpu[i1]
|
| 503 |
+
self.top_n_sigma_cpu[i1], self.top_n_sigma_cpu[i2] =\
|
| 504 |
+
self.top_n_sigma_cpu[i2], self.top_n_sigma_cpu[i1]
|
| 505 |
+
|
| 506 |
+
# NOTE: the following is unsafe
|
| 507 |
+
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
|
| 508 |
+
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
|
| 509 |
+
# instead, we need to temporiarily copy the data for one of the indices
|
| 510 |
+
# TODO(lucas): optimize this by only copying valid indices
|
| 511 |
+
tmp = self.token_ids_cpu[i1, ...].copy()
|
| 512 |
+
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
|
| 513 |
+
self.token_ids_cpu[i2, ...] = tmp
|
| 514 |
+
|
| 515 |
+
swap_dict_values(self.generators, i1, i2)
|
| 516 |
+
swap_dict_values(self.min_tokens, i1, i2)
|
| 517 |
+
swap_dict_values(self.bad_words_token_ids, i1, i2)
|
| 518 |
+
|
| 519 |
+
self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
|
| 520 |
+
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
|
| 521 |
+
self.logit_bias[i1], self.logit_bias[i2] =\
|
| 522 |
+
self.logit_bias[i2], self.logit_bias[i1]
|
| 523 |
+
|
| 524 |
+
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
| 525 |
+
self.allowed_token_ids_mask_cpu_tensor[i1], \
|
| 526 |
+
self.allowed_token_ids_mask_cpu_tensor[i2] =\
|
| 527 |
+
self.allowed_token_ids_mask_cpu_tensor[i2], \
|
| 528 |
+
self.allowed_token_ids_mask_cpu_tensor[i1]
|
| 529 |
+
self.block_table.swap_row(i1, i2)
|
| 530 |
+
|
| 531 |
+
def condense(self, empty_req_indices: list[int]) -> None:
|
| 532 |
+
"""Move non-empty requests down into lower, empty indices.
|
| 533 |
+
|
| 534 |
+
Args:
|
| 535 |
+
empty_req_indices: empty batch indices, sorted descending.
|
| 536 |
+
"""
|
| 537 |
+
num_reqs = self.num_reqs
|
| 538 |
+
if num_reqs == 0:
|
| 539 |
+
# The batched states are empty.
|
| 540 |
+
self._req_ids.clear()
|
| 541 |
+
self.req_output_token_ids.clear()
|
| 542 |
+
return
|
| 543 |
+
|
| 544 |
+
# NOTE(woosuk): This function assumes that the empty_req_indices
|
| 545 |
+
# is sorted in descending order.
|
| 546 |
+
last_req_index = num_reqs + len(empty_req_indices) - 1
|
| 547 |
+
while empty_req_indices:
|
| 548 |
+
# Find the largest non-empty index.
|
| 549 |
+
while last_req_index in empty_req_indices:
|
| 550 |
+
last_req_index -= 1
|
| 551 |
+
|
| 552 |
+
# Find the smallest empty index.
|
| 553 |
+
empty_index = empty_req_indices.pop()
|
| 554 |
+
if empty_index >= last_req_index:
|
| 555 |
+
break
|
| 556 |
+
|
| 557 |
+
# Swap the states.
|
| 558 |
+
req_id = self._req_ids[last_req_index]
|
| 559 |
+
output_token_ids = self.req_output_token_ids[last_req_index]
|
| 560 |
+
assert req_id is not None
|
| 561 |
+
self._req_ids[empty_index] = req_id
|
| 562 |
+
self._req_ids[last_req_index] = None
|
| 563 |
+
self.req_output_token_ids[empty_index] = output_token_ids
|
| 564 |
+
self.req_output_token_ids[last_req_index] = None
|
| 565 |
+
self.req_id_to_index[req_id] = empty_index
|
| 566 |
+
|
| 567 |
+
num_tokens = self.num_tokens[last_req_index]
|
| 568 |
+
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
| 569 |
+
last_req_index, :num_tokens]
|
| 570 |
+
self.num_tokens[empty_index] = num_tokens
|
| 571 |
+
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
|
| 572 |
+
last_req_index]
|
| 573 |
+
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
|
| 574 |
+
last_req_index]
|
| 575 |
+
self.num_computed_tokens_cpu[
|
| 576 |
+
empty_index] = self.num_computed_tokens_cpu[last_req_index]
|
| 577 |
+
self.block_table.move_row(last_req_index, empty_index)
|
| 578 |
+
self.temperature_cpu[empty_index] = self.temperature_cpu[
|
| 579 |
+
last_req_index]
|
| 580 |
+
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
|
| 581 |
+
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
|
| 582 |
+
self.frequency_penalties_cpu[
|
| 583 |
+
empty_index] = self.frequency_penalties_cpu[last_req_index]
|
| 584 |
+
self.presence_penalties_cpu[
|
| 585 |
+
empty_index] = self.presence_penalties_cpu[last_req_index]
|
| 586 |
+
self.repetition_penalties_cpu[
|
| 587 |
+
empty_index] = self.repetition_penalties_cpu[last_req_index]
|
| 588 |
+
self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
|
| 589 |
+
self.top_n_sigma_cpu[
|
| 590 |
+
empty_index] = self.top_n_sigma_cpu[last_req_index]
|
| 591 |
+
generator = self.generators.pop(last_req_index, None)
|
| 592 |
+
if generator is not None:
|
| 593 |
+
self.generators[empty_index] = generator
|
| 594 |
+
|
| 595 |
+
min_token = self.min_tokens.pop(last_req_index, None)
|
| 596 |
+
if min_token is not None:
|
| 597 |
+
self.min_tokens[empty_index] = min_token
|
| 598 |
+
|
| 599 |
+
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
|
| 600 |
+
last_req_index]
|
| 601 |
+
|
| 602 |
+
self.logit_bias[empty_index] = self.logit_bias[last_req_index]
|
| 603 |
+
|
| 604 |
+
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
| 605 |
+
self.allowed_token_ids_mask_cpu_tensor[
|
| 606 |
+
empty_index] = self.allowed_token_ids_mask_cpu_tensor[
|
| 607 |
+
last_req_index]
|
| 608 |
+
|
| 609 |
+
bad_words_token_ids = self.bad_words_token_ids.pop(
|
| 610 |
+
last_req_index, None)
|
| 611 |
+
if bad_words_token_ids is not None:
|
| 612 |
+
self.bad_words_token_ids[empty_index] = bad_words_token_ids
|
| 613 |
+
# Decrement last_req_index since it is now empty.
|
| 614 |
+
last_req_index -= 1
|
| 615 |
+
|
| 616 |
+
# Trim lists to the batch size.
|
| 617 |
+
del self._req_ids[self.num_reqs:]
|
| 618 |
+
del self.req_output_token_ids[self.num_reqs:]
|
| 619 |
+
|
| 620 |
+
def refresh_sampling_metadata(self):
|
| 621 |
+
self.sampling_metadata = self._make_sampling_metadata()
|
| 622 |
+
|
| 623 |
+
def _make_sampling_metadata(self) -> Union[SamplingMetadata, SamplingMetadataTopNSigma]:
|
| 624 |
+
num_reqs = self.num_reqs
|
| 625 |
+
if not self.all_greedy:
|
| 626 |
+
temperature = copy_slice(self.temperature_cpu_tensor,
|
| 627 |
+
self.temperature, num_reqs)
|
| 628 |
+
else:
|
| 629 |
+
temperature = None
|
| 630 |
+
if not self.no_top_p:
|
| 631 |
+
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
|
| 632 |
+
if not self.no_top_k:
|
| 633 |
+
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
|
| 634 |
+
if not self.no_min_p:
|
| 635 |
+
copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs)
|
| 636 |
+
|
| 637 |
+
if not self.no_penalties:
|
| 638 |
+
# Since syncing these tensors is expensive only copy them
|
| 639 |
+
# if necessary i.e. if there are requests which require
|
| 640 |
+
# penalties to be applied during sampling.
|
| 641 |
+
copy_slice(self.frequency_penalties_cpu_tensor,
|
| 642 |
+
self.frequency_penalties, num_reqs)
|
| 643 |
+
copy_slice(self.presence_penalties_cpu_tensor,
|
| 644 |
+
self.presence_penalties, num_reqs)
|
| 645 |
+
copy_slice(self.repetition_penalties_cpu_tensor,
|
| 646 |
+
self.repetition_penalties, num_reqs)
|
| 647 |
+
|
| 648 |
+
if not self.no_top_n_sigma:
|
| 649 |
+
copy_slice(self.top_n_sigma_cpu_tensor,
|
| 650 |
+
self.top_n_sigma, num_reqs)
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
needs_prompt_token_ids = (not self.no_penalties or
|
| 654 |
+
(self.num_reqs > 0
|
| 655 |
+
and self.logits_processing_needs_token_ids))
|
| 656 |
+
if needs_prompt_token_ids:
|
| 657 |
+
# The prompt tokens are used only for applying penalties or
|
| 658 |
+
# step pooling during the sampling/pooling process.
|
| 659 |
+
# Hence copy these tensors only when there are requests which
|
| 660 |
+
# need penalties/step_pooler to be applied.
|
| 661 |
+
prompt_token_ids = self._make_prompt_token_ids_tensor()
|
| 662 |
+
else:
|
| 663 |
+
prompt_token_ids = None
|
| 664 |
+
|
| 665 |
+
allowed_token_ids_mask: Optional[torch.Tensor] = None
|
| 666 |
+
if not self.no_allowed_token_ids:
|
| 667 |
+
assert self.allowed_token_ids_mask is not None
|
| 668 |
+
copy_slice(self.allowed_token_ids_mask_cpu_tensor,
|
| 669 |
+
self.allowed_token_ids_mask, num_reqs)
|
| 670 |
+
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
|
| 671 |
+
|
| 672 |
+
return SamplingMetadataTopNSigma(
|
| 673 |
+
temperature=temperature,
|
| 674 |
+
all_greedy=self.all_greedy,
|
| 675 |
+
all_random=self.all_random,
|
| 676 |
+
top_p=None if self.no_top_p else self.top_p[:num_reqs],
|
| 677 |
+
top_k=None if self.no_top_k else self.top_k[:num_reqs],
|
| 678 |
+
generators=self.generators,
|
| 679 |
+
max_num_logprobs=self.max_num_logprobs,
|
| 680 |
+
prompt_token_ids=prompt_token_ids,
|
| 681 |
+
frequency_penalties=self.frequency_penalties[:num_reqs],
|
| 682 |
+
presence_penalties=self.presence_penalties[:num_reqs],
|
| 683 |
+
repetition_penalties=self.repetition_penalties[:num_reqs],
|
| 684 |
+
top_n_sigma=self.top_n_sigma[:num_reqs],
|
| 685 |
+
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
|
| 686 |
+
no_penalties=self.no_penalties,
|
| 687 |
+
no_top_n_sigma=self.no_top_n_sigma,
|
| 688 |
+
allowed_token_ids_mask=allowed_token_ids_mask,
|
| 689 |
+
bad_words_token_ids=self.bad_words_token_ids,
|
| 690 |
+
logitsprocs=self.logitsprocs,
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
@property
|
| 694 |
+
def pooling_metadata(self) -> PoolingMetadata:
|
| 695 |
+
if len(self.pooling_params) == 0:
|
| 696 |
+
pooling_params = []
|
| 697 |
+
else:
|
| 698 |
+
# Note, for now this assumes that all request in the batch
|
| 699 |
+
# are either sampling or pooling requests
|
| 700 |
+
assert len(self.req_ids) == len(self.pooling_params)
|
| 701 |
+
pooling_params = [
|
| 702 |
+
self.pooling_params[req_id] for req_id in self.req_ids
|
| 703 |
+
]
|
| 704 |
+
|
| 705 |
+
return PoolingMetadata(
|
| 706 |
+
prompt_lens=torch.from_numpy(
|
| 707 |
+
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
|
| 708 |
+
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
| 709 |
+
pooling_params=pooling_params,
|
| 710 |
+
)
|
| 711 |
+
|
| 712 |
+
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
| 713 |
+
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
|
| 714 |
+
prompt_token_ids_cpu_tensor = torch.empty(
|
| 715 |
+
(self.num_reqs, max_prompt_len),
|
| 716 |
+
device="cpu",
|
| 717 |
+
dtype=torch.int64,
|
| 718 |
+
pin_memory=self.pin_memory,
|
| 719 |
+
)
|
| 720 |
+
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
|
| 721 |
+
prompt_token_ids[:] = self.token_ids_cpu[:self.
|
| 722 |
+
num_reqs, :max_prompt_len]
|
| 723 |
+
# Use the value of vocab_size as a pad since we don't have a
|
| 724 |
+
# token_id of this value.
|
| 725 |
+
for i in range(self.num_reqs):
|
| 726 |
+
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
|
| 727 |
+
return prompt_token_ids_cpu_tensor.to(device=self.device,
|
| 728 |
+
non_blocking=True)
|
| 729 |
+
|
| 730 |
+
def make_lora_inputs(
|
| 731 |
+
self, num_scheduled_tokens: np.ndarray
|
| 732 |
+
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
|
| 733 |
+
"""
|
| 734 |
+
Given the num_scheduled_tokens for each request in the batch, return
|
| 735 |
+
datastructures used to activate the current LoRAs.
|
| 736 |
+
Returns:
|
| 737 |
+
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
|
| 738 |
+
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
|
| 739 |
+
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
|
| 740 |
+
where, token_lora_mapping[i] is the LoRA id to use for ith token.
|
| 741 |
+
3. lora_requests: Set of relevant LoRA requests.
|
| 742 |
+
"""
|
| 743 |
+
|
| 744 |
+
req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
|
| 745 |
+
prompt_lora_mapping = tuple(req_lora_mapping)
|
| 746 |
+
token_lora_mapping = tuple(
|
| 747 |
+
req_lora_mapping.repeat(num_scheduled_tokens))
|
| 748 |
+
active_lora_requests: set[LoRARequest] = set(
|
| 749 |
+
self.lora_id_to_lora_request.values())
|
| 750 |
+
|
| 751 |
+
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
| 752 |
+
|
| 753 |
+
@property
|
| 754 |
+
def num_reqs(self) -> int:
|
| 755 |
+
return len(self.req_id_to_index)
|
| 756 |
+
|
| 757 |
+
@property
|
| 758 |
+
def all_greedy(self) -> bool:
|
| 759 |
+
return len(self.random_reqs) == 0
|
| 760 |
+
|
| 761 |
+
@property
|
| 762 |
+
def all_random(self) -> bool:
|
| 763 |
+
return len(self.greedy_reqs) == 0
|
| 764 |
+
|
| 765 |
+
@property
|
| 766 |
+
def no_top_p(self) -> bool:
|
| 767 |
+
return len(self.top_p_reqs) == 0
|
| 768 |
+
|
| 769 |
+
@property
|
| 770 |
+
def no_top_k(self) -> bool:
|
| 771 |
+
return len(self.top_k_reqs) == 0
|
| 772 |
+
|
| 773 |
+
@property
|
| 774 |
+
def no_min_p(self) -> bool:
|
| 775 |
+
return len(self.min_p_reqs) == 0
|
| 776 |
+
|
| 777 |
+
@property
|
| 778 |
+
def no_penalties(self) -> bool:
|
| 779 |
+
return (len(self.presence_penalties_reqs) == 0
|
| 780 |
+
and len(self.frequency_penalties_reqs) == 0
|
| 781 |
+
and len(self.repetition_penalties_reqs) == 0)
|
| 782 |
+
@property
|
| 783 |
+
def no_top_n_sigma(self) -> bool:
|
| 784 |
+
return len(self.top_n_sigma_reqs) == 0
|
| 785 |
+
|
| 786 |
+
@property
|
| 787 |
+
def max_num_logprobs(self) -> Optional[int]:
|
| 788 |
+
return max(self.num_logprobs.values()) if self.num_logprobs else None
|
| 789 |
+
|
| 790 |
+
@property
|
| 791 |
+
def no_prompt_logprob(self) -> bool:
|
| 792 |
+
return not self.num_prompt_logprobs
|
| 793 |
+
|
| 794 |
+
@property
|
| 795 |
+
def no_allowed_token_ids(self) -> bool:
|
| 796 |
+
return len(self.has_allowed_token_ids) == 0
|
inference/vllm_ascend_for_openpangu_embedded_1b.md
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Deployment Guide of openPangu Embedded 1B Based on [vllm-ascend](https://github.com/vllm-project/vllm-ascend)
|
| 2 |
+
|
| 3 |
+
### Deployment Environment Description
|
| 4 |
+
|
| 5 |
+
The Atlas 800T A2 (64 GB) supports the deployment of Pangu Embedded 1B (bf16) with a single card. The vllm-ascend community image v0.9.1-dev is used and needs to be pulled on multiple nodes.
|
| 6 |
+
```bash
|
| 7 |
+
docker pull quay.io/ascend/vllm-ascend:v0.9.1-dev
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
### Docker Boot and Inference Code
|
| 11 |
+
|
| 12 |
+
Perform the following operations on all nodes.
|
| 13 |
+
|
| 14 |
+
Run the following command to start the docker:
|
| 15 |
+
```bash
|
| 16 |
+
# Update the vllm-ascend image
|
| 17 |
+
export IMAGE=quay.io/ascend/vllm-ascend:v0.9.1-dev # Use correct image id
|
| 18 |
+
export NAME=vllm-ascend # Custom docker name
|
| 19 |
+
|
| 20 |
+
# Run the container using the defined variables
|
| 21 |
+
# Note if you are running bridge network with docker, Please expose available ports for multiple nodes communication in advance
|
| 22 |
+
# To prevent device interference from other docker containers, add the argument "--privileged"
|
| 23 |
+
docker run --rm \
|
| 24 |
+
--name $NAME \
|
| 25 |
+
--network host \
|
| 26 |
+
--device /dev/davinci0 \
|
| 27 |
+
--device /dev/davinci1 \
|
| 28 |
+
--device /dev/davinci2 \
|
| 29 |
+
--device /dev/davinci3 \
|
| 30 |
+
--device /dev/davinci4 \
|
| 31 |
+
--device /dev/davinci5 \
|
| 32 |
+
--device /dev/davinci6 \
|
| 33 |
+
--device /dev/davinci7 \
|
| 34 |
+
--device /dev/davinci_manager \
|
| 35 |
+
--device /dev/devmm_svm \
|
| 36 |
+
--device /dev/hisi_hdc \
|
| 37 |
+
-v /usr/local/dcmi:/usr/local/dcmi \
|
| 38 |
+
-v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
|
| 39 |
+
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
|
| 40 |
+
-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
|
| 41 |
+
-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
|
| 42 |
+
-v /etc/ascend_install.info:/etc/ascend_install.info \
|
| 43 |
+
-v /mnt/sfs_turbo/.cache:/root/.cache \
|
| 44 |
+
-it $IMAGE bash
|
| 45 |
+
```
|
| 46 |
+
If not inside the container, enter the container as the root user:
|
| 47 |
+
```
|
| 48 |
+
docker exec -itu root $NAME /bin/bash
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
Download vllm (v0.9.2) to replace the built-in vllm code of the image.
|
| 52 |
+
```bash
|
| 53 |
+
pip install --no-deps vllm==0.9.2 pybase64==1.4.1
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
Download [vllm-ascend (v0.9.2rc1)](https://github.com/vllm-project/vllm-ascend/releases/tag/v0.9.2rc1) and replace the built-in vllm-ascend code in the image (/vllm-workspace/vllm-ascend/). For example, download [Source code (tar.gz)](https://github.com/vllm-project/vllm-ascend/archive/refs/tags/v0.9.2rc1.tar.gz) from Assets to get v0.9.2rc1.tar.gz, then extract and replace:
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
tar -zxvf vllm-ascend-0.9.2rc1.tar.gz -C /vllm-workspace/vllm-ascend/ --strip-components=1
|
| 60 |
+
export PYTHONPATH=/vllm-workspace/vllm-ascend/:${PYTHONPATH}
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
Use the Pangu model-adapted vllm-ascend code from the current repository to replace parts of the code in `/vllm-workspace/vllm-ascend/vllm_ascend/`:
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
yes | cp -r inference/vllm_ascend/* /vllm-workspace/vllm-ascend/vllm_ascend/
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### openPangu Embedded Inference
|
| 70 |
+
|
| 71 |
+
Perform the following operations on all nodes.
|
| 72 |
+
|
| 73 |
+
Configuration:
|
| 74 |
+
```bash
|
| 75 |
+
export VLLM_USE_V1=1
|
| 76 |
+
# Specifying HOST=127.0.0.1 (localhost) means the server can only be accessed from the master device.
|
| 77 |
+
# Specifying HOST=0.0.0.0 allows the vLLM server to be accessed from other devices on the same network or even from the internet, provided proper network configuration (e.g., firewall rules, port forwarding) is in place.
|
| 78 |
+
HOST=xxx.xxx.xxx.xxx
|
| 79 |
+
PORT=8080
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
openPangu Embedded 1B running command:
|
| 83 |
+
```bash
|
| 84 |
+
export ASCEND_RT_VISIBLE_DEVICES=0
|
| 85 |
+
LOCAL_CKPT_DIR=/root/.cache/pangu_embedded_1b # The pangu_embedded_1b bf16 weight
|
| 86 |
+
SERVED_MODEL_NAME=pangu_embedded_1b
|
| 87 |
+
|
| 88 |
+
vllm serve $LOCAL_CKPT_DIR \
|
| 89 |
+
--served-model-name $SERVED_MODEL_NAME \
|
| 90 |
+
--tensor-parallel-size 1 \
|
| 91 |
+
--trust-remote-code \
|
| 92 |
+
--host $HOST \
|
| 93 |
+
--port $PORT \
|
| 94 |
+
--max-num-seqs 32 \
|
| 95 |
+
--max-model-len 32768 \
|
| 96 |
+
--max-num-batched-tokens 4096 \
|
| 97 |
+
--tokenizer-mode "slow" \
|
| 98 |
+
--dtype bfloat16 \
|
| 99 |
+
--distributed-executor-backend mp \
|
| 100 |
+
--gpu-memory-utilization 0.93 \
|
| 101 |
+
--no-enable-prefix-caching \
|
| 102 |
+
--no-enable-chunked-prefill \
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
### Test Request
|
| 106 |
+
|
| 107 |
+
After server launched, send test request from master node or other nodes:
|
| 108 |
+
|
| 109 |
+
```bash
|
| 110 |
+
MASTER_NODE_IP=xxx.xxx.xxx.xxx # server node ip
|
| 111 |
+
curl http://${MASTER_NODE_IP}:${PORT}/v1/chat/completions \
|
| 112 |
+
-H "Content-Type: application/json" \
|
| 113 |
+
-d '{
|
| 114 |
+
"model": "'$SERVED_MODEL_NAME'",
|
| 115 |
+
"messages": [
|
| 116 |
+
{
|
| 117 |
+
"role": "user",
|
| 118 |
+
"content": "Who are you?"
|
| 119 |
+
}
|
| 120 |
+
],
|
| 121 |
+
"max_tokens": 512,
|
| 122 |
+
"temperature": 0
|
| 123 |
+
}'
|
| 124 |
+
```
|
inference/vllm_ascend_for_openpangu_embedded_1b.zh.md
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## openPangu Embedded 1B在[vllm-ascend](https://github.com/vllm-project/vllm-ascend)部署指导文档
|
| 2 |
+
|
| 3 |
+
### 部署环境说明
|
| 4 |
+
|
| 5 |
+
Atlas 800T A2(64GB)单卡可以部署openPangu Embedded 1B(bf16),选用vllm-ascend社区镜像v0.9.1-dev。
|
| 6 |
+
```bash
|
| 7 |
+
docker pull quay.io/ascend/vllm-ascend:v0.9.1-dev
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
### 镜像启动和推理代码适配
|
| 11 |
+
|
| 12 |
+
以下操作需在每个节点都执行。
|
| 13 |
+
|
| 14 |
+
启动镜像:
|
| 15 |
+
```bash
|
| 16 |
+
# Update the vllm-ascend image
|
| 17 |
+
export IMAGE=quay.io/ascend/vllm-ascend:v0.9.1-dev # Use correct image id
|
| 18 |
+
export NAME=vllm-ascend # Custom docker name
|
| 19 |
+
|
| 20 |
+
# Run the container using the defined variables
|
| 21 |
+
# Note if you are running bridge network with docker, Please expose available ports for multiple nodes communication in advance
|
| 22 |
+
# To prevent device interference from other docker containers, add the argument "--privileged"
|
| 23 |
+
docker run --rm \
|
| 24 |
+
--name $NAME \
|
| 25 |
+
--network host \
|
| 26 |
+
--device /dev/davinci0 \
|
| 27 |
+
--device /dev/davinci1 \
|
| 28 |
+
--device /dev/davinci2 \
|
| 29 |
+
--device /dev/davinci3 \
|
| 30 |
+
--device /dev/davinci4 \
|
| 31 |
+
--device /dev/davinci5 \
|
| 32 |
+
--device /dev/davinci6 \
|
| 33 |
+
--device /dev/davinci7 \
|
| 34 |
+
--device /dev/davinci_manager \
|
| 35 |
+
--device /dev/devmm_svm \
|
| 36 |
+
--device /dev/hisi_hdc \
|
| 37 |
+
-v /usr/local/dcmi:/usr/local/dcmi \
|
| 38 |
+
-v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
|
| 39 |
+
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
|
| 40 |
+
-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
|
| 41 |
+
-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
|
| 42 |
+
-v /etc/ascend_install.info:/etc/ascend_install.info \
|
| 43 |
+
-v /mnt/sfs_turbo/.cache:/root/.cache \
|
| 44 |
+
-it $IMAGE bash
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
如果未进入容器,需以root用户进入容器:
|
| 48 |
+
```
|
| 49 |
+
docker exec -itu root $NAME /bin/bash
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
下载vllm (v0.9.2),替换镜像内置的vllm代码。
|
| 53 |
+
```bash
|
| 54 |
+
pip install --no-deps vllm==0.9.2 pybase64==1.4.1
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
下载[vllm-ascend (v0.9.2rc1)](https://github.com/vllm-project/vllm-ascend/releases/tag/v0.9.2rc1),替换镜像内置的vllm-ascend代码(`/vllm-workspace/vllm-ascend/`)。例如下载Assets中的[Source code
|
| 58 |
+
(tar.gz)](https://github.com/vllm-project/vllm-ascend/archive/refs/tags/v0.9.2rc1.tar.gz)得到v0.9.2rc1.tar.gz,然后解压并替换:
|
| 59 |
+
```bash
|
| 60 |
+
tar -zxvf vllm-ascend-0.9.2rc1.tar.gz -C /vllm-workspace/vllm-ascend/ --strip-components=1
|
| 61 |
+
export PYTHONPATH=/vllm-workspace/vllm-ascend/:${PYTHONPATH}
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
使用当前代码仓中适配盘古模型的vllm-ascend代码替换`/vllm-workspace/vllm-ascend/vllm_ascend/`中的部分代码。
|
| 65 |
+
```bash
|
| 66 |
+
yes | cp -r inference/vllm_ascend/* /vllm-workspace/vllm-ascend/vllm_ascend/
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### Pangu Embedded 推理
|
| 70 |
+
|
| 71 |
+
以下操作需在每个节点都执行。
|
| 72 |
+
|
| 73 |
+
配置:
|
| 74 |
+
```bash
|
| 75 |
+
export VLLM_USE_V1=1
|
| 76 |
+
# Specifying HOST=127.0.0.1 (localhost) means the server can only be accessed from the master device.
|
| 77 |
+
# Specifying HOST=0.0.0.0 allows the vLLM server to be accessed from other devices on the same network or even from the internet, provided proper network configuration (e.g., firewall rules, port forwarding) is in place.
|
| 78 |
+
HOST=xxx.xxx.xxx.xxx
|
| 79 |
+
PORT=8080
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
Pangu Embedded 1B 运行命令:
|
| 83 |
+
```bash
|
| 84 |
+
export ASCEND_RT_VISIBLE_DEVICES=0
|
| 85 |
+
LOCAL_CKPT_DIR=/root/.cache/pangu_embedded_1b # The pangu_embedded_1b bf16 weight
|
| 86 |
+
SERVED_MODEL_NAME=pangu_embedded_1b
|
| 87 |
+
|
| 88 |
+
vllm serve $LOCAL_CKPT_DIR \
|
| 89 |
+
--served-model-name $SERVED_MODEL_NAME \
|
| 90 |
+
--tensor-parallel-size 1 \
|
| 91 |
+
--trust-remote-code \
|
| 92 |
+
--host $HOST \
|
| 93 |
+
--port $PORT \
|
| 94 |
+
--max-num-seqs 32 \
|
| 95 |
+
--max-model-len 32768 \
|
| 96 |
+
--max-num-batched-tokens 4096 \
|
| 97 |
+
--tokenizer-mode "slow" \
|
| 98 |
+
--dtype bfloat16 \
|
| 99 |
+
--distributed-executor-backend mp \
|
| 100 |
+
--gpu-memory-utilization 0.93 \
|
| 101 |
+
--no-enable-prefix-caching \
|
| 102 |
+
--no-enable-chunked-prefill \
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
### 发请求测试
|
| 106 |
+
|
| 107 |
+
服务启动后,在主节点或者其他节点向主节点发送测试请求:
|
| 108 |
+
|
| 109 |
+
```bash
|
| 110 |
+
MASTER_NODE_IP=xxx.xxx.xxx.xxx # server node ip
|
| 111 |
+
curl http://${MASTER_NODE_IP}:${PORT}/v1/chat/completions \
|
| 112 |
+
-H "Content-Type: application/json" \
|
| 113 |
+
-d '{
|
| 114 |
+
"model": "'$SERVED_MODEL_NAME'",
|
| 115 |
+
"messages": [
|
| 116 |
+
{
|
| 117 |
+
"role": "user",
|
| 118 |
+
"content": "Who are you?"
|
| 119 |
+
}
|
| 120 |
+
],
|
| 121 |
+
"max_tokens": 512,
|
| 122 |
+
"temperature": 0
|
| 123 |
+
}'
|
| 124 |
+
```
|