Mirror DeepSeek source tree (code/config only, no weights)
Browse files- .gitattributes +1 -0
- LICENSE +21 -0
- README.md +154 -0
- assets/chat_template.jinja +90 -0
- config.json +66 -0
- generation_config.json +9 -0
- inference/README.md +14 -0
- inference/config_671B_v3.2.json +26 -0
- inference/convert.py +100 -0
- inference/generate.py +186 -0
- inference/kernel.py +274 -0
- inference/model.py +923 -0
- inference/requirements.txt +5 -0
- model.safetensors.index.json +0 -0
- tokenizer.json +0 -0
- tokenizer_config.json +35 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/cost.png filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 DeepSeek
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
library_name: transformers
|
| 4 |
+
base_model:
|
| 5 |
+
- deepseek-ai/DeepSeek-V3.2-Exp-Base
|
| 6 |
+
base_model_relation: finetune
|
| 7 |
+
---
|
| 8 |
+
# DeepSeek-V3.2-Exp
|
| 9 |
+
|
| 10 |
+
<!-- markdownlint-disable first-line-h1 -->
|
| 11 |
+
<!-- markdownlint-disable html -->
|
| 12 |
+
<!-- markdownlint-disable no-duplicate-header -->
|
| 13 |
+
|
| 14 |
+
<div align="center">
|
| 15 |
+
<img src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/logo.svg?raw=true" width="60%" alt="DeepSeek-V3" />
|
| 16 |
+
</div>
|
| 17 |
+
<hr>
|
| 18 |
+
<div align="center" style="line-height: 1;">
|
| 19 |
+
<a href="https://www.deepseek.com/" target="_blank" style="margin: 2px;">
|
| 20 |
+
<img alt="Homepage" src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/badge.svg?raw=true" style="display: inline-block; vertical-align: middle;"/>
|
| 21 |
+
</a>
|
| 22 |
+
<a href="https://chat.deepseek.com/" target="_blank" style="margin: 2px;">
|
| 23 |
+
<img alt="Chat" src="https://img.shields.io/badge/🤖%20Chat-DeepSeek%20V3-536af5?color=536af5&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
|
| 24 |
+
</a>
|
| 25 |
+
<a href="https://huggingface.co/deepseek-ai" target="_blank" style="margin: 2px;">
|
| 26 |
+
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
|
| 27 |
+
</a>
|
| 28 |
+
</div>
|
| 29 |
+
<div align="center" style="line-height: 1;">
|
| 30 |
+
<a href="https://discord.gg/Tc7c45Zzu5" target="_blank" style="margin: 2px;">
|
| 31 |
+
<img alt="Discord" src="https://img.shields.io/badge/Discord-DeepSeek%20AI-7289da?logo=discord&logoColor=white&color=7289da" style="display: inline-block; vertical-align: middle;"/>
|
| 32 |
+
</a>
|
| 33 |
+
<a href="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/qr.jpeg?raw=true" target="_blank" style="margin: 2px;">
|
| 34 |
+
<img alt="Wechat" src="https://img.shields.io/badge/WeChat-DeepSeek%20AI-brightgreen?logo=wechat&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
|
| 35 |
+
</a>
|
| 36 |
+
<a href="https://twitter.com/deepseek_ai" target="_blank" style="margin: 2px;">
|
| 37 |
+
<img alt="Twitter Follow" src="https://img.shields.io/badge/Twitter-deepseek_ai-white?logo=x&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
|
| 38 |
+
</a>
|
| 39 |
+
</div>
|
| 40 |
+
<div align="center" style="line-height: 1;">
|
| 41 |
+
<a href="LICENSE" style="margin: 2px;">
|
| 42 |
+
<img alt="License" src="https://img.shields.io/badge/License-MIT-f5de53?&color=f5de53" style="display: inline-block; vertical-align: middle;"/>
|
| 43 |
+
</a>
|
| 44 |
+
</div>
|
| 45 |
+
|
| 46 |
+
## Introduction
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
We are excited to announce the official release of DeepSeek-V3.2-Exp, an experimental version of our model. As an intermediate step toward our next-generation architecture, V3.2-Exp builds upon V3.1-Terminus by introducing DeepSeek Sparse Attention—a sparse attention mechanism designed to explore and validate optimizations for training and inference efficiency in long-context scenarios.
|
| 50 |
+
|
| 51 |
+
This experimental release represents our ongoing research into more efficient transformer architectures, particularly focusing on improving computational efficiency when processing extended text sequences.
|
| 52 |
+
|
| 53 |
+
<div align="center">
|
| 54 |
+
<img src="assets/cost.png" >
|
| 55 |
+
</div>
|
| 56 |
+
|
| 57 |
+
- DeepSeek Sparse Attention (DSA) achieves fine-grained sparse attention for the first time, delivering substantial improvements in long-context training and inference efficiency while maintaining virtually identical model output quality.
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
- To rigorously evaluate the impact of introducing sparse attention, we deliberately aligned the training configurations of DeepSeek-V3.2-Exp with V3.1-Terminus. Across public benchmarks in various domains, DeepSeek-V3.2-Exp demonstrates performance on par with V3.1-Terminus.
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
| Benchmark | DeepSeek-V3.1-Terminus | DeepSeek-V3.2-Exp |
|
| 64 |
+
| :--- | :---: | :---: |
|
| 65 |
+
| **Reasoning Mode w/o Tool Use** | | |
|
| 66 |
+
| MMLU-Pro | 85.0 | 85.0 |
|
| 67 |
+
| GPQA-Diamond | 80.7 | 79.9 |
|
| 68 |
+
| Humanity's Last Exam | 21.7 | 19.8 |
|
| 69 |
+
| LiveCodeBench | 74.9 | 74.1 |
|
| 70 |
+
| AIME 2025 | 88.4 | 89.3 |
|
| 71 |
+
| HMMT 2025 | 86.1 | 83.6 |
|
| 72 |
+
| Codeforces | 2046 | 2121 |
|
| 73 |
+
| Aider-Polyglot | 76.1 | 74.5 |
|
| 74 |
+
| **Agentic Tool Use** | | |
|
| 75 |
+
| BrowseComp | 38.5 | 40.1 |
|
| 76 |
+
| BrowseComp-zh | 45.0 | 47.9 |
|
| 77 |
+
| SimpleQA | 96.8 | 97.1 |
|
| 78 |
+
| SWE Verified | 68.4 | 67.8 |
|
| 79 |
+
| SWE-bench Multilingual | 57.8 | 57.9 |
|
| 80 |
+
| Terminal-bench | 36.7 | 37.7 |
|
| 81 |
+
|
| 82 |
+
## Update
|
| 83 |
+
|
| 84 |
+
- 2025.11.17: **We have identified that previous versions of the inference demo code contained an implementation discrepancy in Rotary Position Embedding (RoPE) within the indexer module, potentially leading to degraded model performance.** Specifically, the input tensor to RoPE in the indexer module requires a non-interleaved layout, whereas RoPE in the MLA module expects an interleaved layout. This issue has now been resolved. Please refer to the updated version of the inference demo code and take note of this implementation detail.
|
| 85 |
+
|
| 86 |
+
## How to Run Locally
|
| 87 |
+
|
| 88 |
+
### HuggingFace
|
| 89 |
+
|
| 90 |
+
We provide an updated inference demo code in the [inference](https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp/tree/main/inference) folder to help the community quickly get started with our model and understand its architectural details.
|
| 91 |
+
|
| 92 |
+
First convert huggingface model weights to the the format required by our inference demo. Set `MP` to match your available GPU count:
|
| 93 |
+
```bash
|
| 94 |
+
cd inference
|
| 95 |
+
export EXPERTS=256
|
| 96 |
+
python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP}
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
Launch the interactive chat interface and start exploring DeepSeek's capabilities:
|
| 100 |
+
```bash
|
| 101 |
+
export CONFIG=config_671B_v3.2.json
|
| 102 |
+
torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
### SGLang
|
| 106 |
+
|
| 107 |
+
#### Installation with Docker
|
| 108 |
+
|
| 109 |
+
```
|
| 110 |
+
# H200
|
| 111 |
+
docker pull lmsysorg/sglang:dsv32
|
| 112 |
+
|
| 113 |
+
# MI350
|
| 114 |
+
docker pull lmsysorg/sglang:dsv32-rocm
|
| 115 |
+
|
| 116 |
+
# NPUs
|
| 117 |
+
docker pull lmsysorg/sglang:dsv32-a2
|
| 118 |
+
docker pull lmsysorg/sglang:dsv32-a3
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
#### Launch Command
|
| 122 |
+
```bash
|
| 123 |
+
python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp 8 --enable-dp-attention
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
### vLLM
|
| 127 |
+
|
| 128 |
+
vLLM provides day-0 support of DeepSeek-V3.2-Exp. See the [recipes](https://docs.vllm.ai/projects/recipes/en/latest/DeepSeek/DeepSeek-V3_2-Exp.html) for up-to-date details.
|
| 129 |
+
|
| 130 |
+
## Open-Source Kernels
|
| 131 |
+
|
| 132 |
+
For TileLang kernels with **better readability and research-purpose design**, please refer to [TileLang](https://github.com/tile-ai/tilelang/tree/main/examples/deepseek_v32).
|
| 133 |
+
|
| 134 |
+
For **high-performance CUDA kernels**, indexer logit kernels (including paged versions) are available in [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM/pull/200). Sparse attention kernels are released in [FlashMLA](https://github.com/deepseek-ai/FlashMLA/pull/98).
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
## License
|
| 139 |
+
|
| 140 |
+
This repository and the model weights are licensed under the [MIT License](LICENSE).
|
| 141 |
+
|
| 142 |
+
## Citation
|
| 143 |
+
|
| 144 |
+
```
|
| 145 |
+
@misc{deepseekai2024deepseekv32,
|
| 146 |
+
title={DeepSeek-V3.2-Exp: Boosting Long-Context Efficiency with DeepSeek Sparse Attention},
|
| 147 |
+
author={DeepSeek-AI},
|
| 148 |
+
year={2025},
|
| 149 |
+
}
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
## Contact
|
| 153 |
+
|
| 154 |
+
If you have any questions, please raise an issue or contact us at [service@deepseek.com](service@deepseek.com).
|
assets/chat_template.jinja
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{% if not add_generation_prompt is defined %}
|
| 2 |
+
{% set add_generation_prompt = false %}
|
| 3 |
+
{% endif %}
|
| 4 |
+
{% if not thinking is defined %}
|
| 5 |
+
{% set thinking = false %}
|
| 6 |
+
{% endif %}
|
| 7 |
+
{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false, is_only_sys=false, is_prefix=false) %}
|
| 8 |
+
{%- for message in messages %}
|
| 9 |
+
{%- if message['role'] == 'system' %}
|
| 10 |
+
{%- if ns.is_first_sp %}
|
| 11 |
+
{% set ns.system_prompt = ns.system_prompt + message['content'] %}
|
| 12 |
+
{% set ns.is_first_sp = false %}
|
| 13 |
+
{%- else %}
|
| 14 |
+
{% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}
|
| 15 |
+
{%- endif %}
|
| 16 |
+
{% set ns.is_only_sys = true %}
|
| 17 |
+
{%- endif %}
|
| 18 |
+
{%- endfor %}
|
| 19 |
+
{{ bos_token }}{{ ns.system_prompt }}
|
| 20 |
+
{%- for message in messages %}
|
| 21 |
+
{%- if message['role'] == 'user' %}
|
| 22 |
+
{%- set ns.is_tool = false -%}
|
| 23 |
+
{%- set ns.is_first = false -%}
|
| 24 |
+
{%- set ns.is_last_user = true -%}
|
| 25 |
+
{{'<|User|>' + message['content']}}
|
| 26 |
+
{%- endif %}
|
| 27 |
+
{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}
|
| 28 |
+
{%- if ns.is_last_user or ns.is_only_sys %}
|
| 29 |
+
{{'<|Assistant|></think>'}}
|
| 30 |
+
{%- endif %}
|
| 31 |
+
{%- set ns.is_last_user = false -%}
|
| 32 |
+
{%- set ns.is_first = false %}
|
| 33 |
+
{%- set ns.is_tool = false -%}
|
| 34 |
+
{%- for tool in message['tool_calls'] %}
|
| 35 |
+
{%- if not ns.is_first %}
|
| 36 |
+
{%- if message['content'] is none %}
|
| 37 |
+
{{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}
|
| 38 |
+
{%- else %}
|
| 39 |
+
{{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}
|
| 40 |
+
{%- endif %}
|
| 41 |
+
{%- set ns.is_first = true -%}
|
| 42 |
+
{%- else %}
|
| 43 |
+
{{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}
|
| 44 |
+
{%- endif %}
|
| 45 |
+
{%- endfor %}
|
| 46 |
+
{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}
|
| 47 |
+
{%- endif %}
|
| 48 |
+
{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %}
|
| 49 |
+
{%- if ns.is_last_user %}
|
| 50 |
+
{{'<|Assistant|>'}}
|
| 51 |
+
{%- if message['prefix'] is defined and message['prefix'] and thinking %}
|
| 52 |
+
{{'<think>'}}
|
| 53 |
+
{%- else %}
|
| 54 |
+
{{'</think>'}}
|
| 55 |
+
{%- endif %}
|
| 56 |
+
{%- endif %}
|
| 57 |
+
{%- if message['prefix'] is defined and message['prefix'] %}
|
| 58 |
+
{%- set ns.is_prefix = true -%}
|
| 59 |
+
{%- endif %}
|
| 60 |
+
{%- set ns.is_last_user = false -%}
|
| 61 |
+
{%- if ns.is_tool %}
|
| 62 |
+
{{message['content'] + '<|end▁of▁sentence|>'}}
|
| 63 |
+
{%- set ns.is_tool = false -%}
|
| 64 |
+
{%- else %}
|
| 65 |
+
{%- set content = message['content'] -%}
|
| 66 |
+
{%- if '</think>' in content %}
|
| 67 |
+
{%- set content = content.split('</think>', 1)[1] -%}
|
| 68 |
+
{%- endif %}
|
| 69 |
+
{{content + '<|end▁of▁sentence|>'}}
|
| 70 |
+
{%- endif %}
|
| 71 |
+
{%- endif %}
|
| 72 |
+
{%- if message['role'] == 'tool' %}
|
| 73 |
+
{%- set ns.is_last_user = false -%}
|
| 74 |
+
{%- set ns.is_tool = true -%}
|
| 75 |
+
{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
|
| 76 |
+
{%- endif %}
|
| 77 |
+
{%- if message['role'] != 'system' %}
|
| 78 |
+
{% set ns.is_only_sys = false %}
|
| 79 |
+
{%- endif %}
|
| 80 |
+
{%- endfor -%}
|
| 81 |
+
{% if add_generation_prompt and not ns.is_tool%}
|
| 82 |
+
{% if ns.is_last_user or ns.is_only_sys or not ns.is_prefix %}
|
| 83 |
+
{{'<|Assistant|>'}}
|
| 84 |
+
{%- if not thinking %}
|
| 85 |
+
{{'</think>'}}
|
| 86 |
+
{%- else %}
|
| 87 |
+
{{'<think>'}}
|
| 88 |
+
{%- endif %}
|
| 89 |
+
{% endif %}
|
| 90 |
+
{% endif %}
|
config.json
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"DeepseekV32ForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_bias": false,
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"bos_token_id": 0,
|
| 8 |
+
"eos_token_id": 1,
|
| 9 |
+
"ep_size": 1,
|
| 10 |
+
"first_k_dense_replace": 3,
|
| 11 |
+
"hidden_act": "silu",
|
| 12 |
+
"hidden_size": 7168,
|
| 13 |
+
"index_head_dim": 128,
|
| 14 |
+
"index_n_heads": 64,
|
| 15 |
+
"index_topk": 2048,
|
| 16 |
+
"initializer_range": 0.02,
|
| 17 |
+
"intermediate_size": 18432,
|
| 18 |
+
"kv_lora_rank": 512,
|
| 19 |
+
"max_position_embeddings": 163840,
|
| 20 |
+
"model_type": "deepseek_v32",
|
| 21 |
+
"moe_intermediate_size": 2048,
|
| 22 |
+
"moe_layer_freq": 1,
|
| 23 |
+
"n_group": 8,
|
| 24 |
+
"n_routed_experts": 256,
|
| 25 |
+
"n_shared_experts": 1,
|
| 26 |
+
"norm_topk_prob": true,
|
| 27 |
+
"num_attention_heads": 128,
|
| 28 |
+
"num_experts_per_tok": 8,
|
| 29 |
+
"num_hidden_layers": 61,
|
| 30 |
+
"num_key_value_heads": 128,
|
| 31 |
+
"num_nextn_predict_layers": 1,
|
| 32 |
+
"q_lora_rank": 1536,
|
| 33 |
+
"qk_nope_head_dim": 128,
|
| 34 |
+
"qk_rope_head_dim": 64,
|
| 35 |
+
"quantization_config": {
|
| 36 |
+
"activation_scheme": "dynamic",
|
| 37 |
+
"fmt": "e4m3",
|
| 38 |
+
"quant_method": "fp8",
|
| 39 |
+
"scale_fmt": "ue8m0",
|
| 40 |
+
"weight_block_size": [
|
| 41 |
+
128,
|
| 42 |
+
128
|
| 43 |
+
]
|
| 44 |
+
},
|
| 45 |
+
"rms_norm_eps": 1e-06,
|
| 46 |
+
"rope_scaling": {
|
| 47 |
+
"beta_fast": 32,
|
| 48 |
+
"beta_slow": 1,
|
| 49 |
+
"factor": 40,
|
| 50 |
+
"mscale": 1.0,
|
| 51 |
+
"mscale_all_dim": 1.0,
|
| 52 |
+
"original_max_position_embeddings": 4096,
|
| 53 |
+
"type": "yarn"
|
| 54 |
+
},
|
| 55 |
+
"rope_theta": 10000,
|
| 56 |
+
"routed_scaling_factor": 2.5,
|
| 57 |
+
"scoring_func": "sigmoid",
|
| 58 |
+
"tie_word_embeddings": false,
|
| 59 |
+
"topk_group": 4,
|
| 60 |
+
"topk_method": "noaux_tc",
|
| 61 |
+
"torch_dtype": "bfloat16",
|
| 62 |
+
"transformers_version": "4.44.2",
|
| 63 |
+
"use_cache": true,
|
| 64 |
+
"v_head_dim": 128,
|
| 65 |
+
"vocab_size": 129280
|
| 66 |
+
}
|
generation_config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 0,
|
| 4 |
+
"eos_token_id": 1,
|
| 5 |
+
"do_sample": true,
|
| 6 |
+
"temperature": 0.6,
|
| 7 |
+
"top_p": 0.95,
|
| 8 |
+
"transformers_version": "4.46.3"
|
| 9 |
+
}
|
inference/README.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DeepSeek V3.2
|
| 2 |
+
|
| 3 |
+
First convert huggingface model weights to the the format required by our inference demo. Set `MP` to match your available GPU count:
|
| 4 |
+
```bash
|
| 5 |
+
cd inference
|
| 6 |
+
export EXPERTS=256
|
| 7 |
+
python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP}
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
Launch the interactive chat interface and start exploring DeepSeek's capabilities:
|
| 11 |
+
```bash
|
| 12 |
+
export CONFIG=config_671B_v3.2.json
|
| 13 |
+
torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive
|
| 14 |
+
```
|
inference/config_671B_v3.2.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vocab_size": 129280,
|
| 3 |
+
"dim": 7168,
|
| 4 |
+
"inter_dim": 18432,
|
| 5 |
+
"moe_inter_dim": 2048,
|
| 6 |
+
"n_layers": 61,
|
| 7 |
+
"n_dense_layers": 3,
|
| 8 |
+
"n_heads": 128,
|
| 9 |
+
"n_routed_experts": 256,
|
| 10 |
+
"n_shared_experts": 1,
|
| 11 |
+
"n_activated_experts": 8,
|
| 12 |
+
"n_expert_groups": 8,
|
| 13 |
+
"n_limited_groups": 4,
|
| 14 |
+
"route_scale": 2.5,
|
| 15 |
+
"score_func": "sigmoid",
|
| 16 |
+
"q_lora_rank": 1536,
|
| 17 |
+
"kv_lora_rank": 512,
|
| 18 |
+
"qk_nope_head_dim": 128,
|
| 19 |
+
"qk_rope_head_dim": 64,
|
| 20 |
+
"v_head_dim": 128,
|
| 21 |
+
"dtype": "fp8",
|
| 22 |
+
"scale_fmt": "ue8m0",
|
| 23 |
+
"index_n_heads": 64,
|
| 24 |
+
"index_head_dim": 128,
|
| 25 |
+
"index_topk": 2048
|
| 26 |
+
}
|
inference/convert.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
from argparse import ArgumentParser
|
| 4 |
+
from glob import glob
|
| 5 |
+
from tqdm import tqdm, trange
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from safetensors.torch import safe_open, save_file
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
mapping = {
|
| 12 |
+
"embed_tokens": ("embed", 0),
|
| 13 |
+
"input_layernorm": ("attn_norm", None),
|
| 14 |
+
"post_attention_layernorm": ("ffn_norm", None),
|
| 15 |
+
"q_proj": ("wq", 0),
|
| 16 |
+
"q_a_proj": ("wq_a", None),
|
| 17 |
+
"q_a_layernorm": ("q_norm", None),
|
| 18 |
+
"q_b_proj": ("wq_b", 0),
|
| 19 |
+
"kv_a_proj_with_mqa": ("wkv_a", None),
|
| 20 |
+
"kv_a_layernorm": ("kv_norm", None),
|
| 21 |
+
"kv_b_proj": ("wkv_b", 0),
|
| 22 |
+
"o_proj": ("wo", 1),
|
| 23 |
+
"gate": ("gate", None),
|
| 24 |
+
"gate_proj": ("w1", 0),
|
| 25 |
+
"down_proj": ("w2", 1),
|
| 26 |
+
"up_proj": ("w3", 0),
|
| 27 |
+
"norm": ("norm", None),
|
| 28 |
+
"lm_head": ("head", 0),
|
| 29 |
+
"scale": ("scale", None),
|
| 30 |
+
"wq_b": ("wq_b", None),
|
| 31 |
+
"wk": ("wk", None),
|
| 32 |
+
"k_norm": ("k_norm", None),
|
| 33 |
+
"weights_proj": ("weights_proj", None),
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def main(hf_ckpt_path, save_path, n_experts, mp):
|
| 38 |
+
"""
|
| 39 |
+
Converts and saves model checkpoint files into a specified format.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
|
| 43 |
+
save_path (str): Path to the directory where the converted checkpoint files will be saved.
|
| 44 |
+
n_experts (int): Total number of experts in the model.
|
| 45 |
+
mp (int): Model parallelism factor.
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
None
|
| 49 |
+
"""
|
| 50 |
+
torch.set_num_threads(8)
|
| 51 |
+
n_local_experts = n_experts // mp
|
| 52 |
+
state_dicts = [{} for _ in range(mp)]
|
| 53 |
+
|
| 54 |
+
for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
|
| 55 |
+
with safe_open(file_path, framework="pt", device="cpu") as f:
|
| 56 |
+
for name in f.keys():
|
| 57 |
+
if "model.layers.61" in name:
|
| 58 |
+
continue
|
| 59 |
+
param: torch.Tensor = f.get_tensor(name)
|
| 60 |
+
if name.startswith("model."):
|
| 61 |
+
name = name[len("model."):]
|
| 62 |
+
name = name.replace("self_attn", "attn")
|
| 63 |
+
name = name.replace("mlp", "ffn")
|
| 64 |
+
name = name.replace("weight_scale_inv", "scale")
|
| 65 |
+
name = name.replace("e_score_correction_bias", "bias")
|
| 66 |
+
key = name.split(".")[-2]
|
| 67 |
+
assert key in mapping, f"Key {key} not found in mapping"
|
| 68 |
+
new_key, dim = mapping[key]
|
| 69 |
+
name = name.replace(key, new_key)
|
| 70 |
+
for i in range(mp):
|
| 71 |
+
new_param = param
|
| 72 |
+
if "experts" in name and "shared_experts" not in name:
|
| 73 |
+
idx = int(name.split(".")[-3])
|
| 74 |
+
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
|
| 75 |
+
continue
|
| 76 |
+
elif dim is not None:
|
| 77 |
+
assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
|
| 78 |
+
shard_size = param.size(dim) // mp
|
| 79 |
+
new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
|
| 80 |
+
state_dicts[i][name] = new_param
|
| 81 |
+
|
| 82 |
+
os.makedirs(save_path, exist_ok=True)
|
| 83 |
+
|
| 84 |
+
for i in trange(mp):
|
| 85 |
+
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
|
| 86 |
+
|
| 87 |
+
for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
|
| 88 |
+
new_file_path = os.path.join(save_path, os.path.basename(file_path))
|
| 89 |
+
shutil.copyfile(file_path, new_file_path)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
if __name__ == "__main__":
|
| 93 |
+
parser = ArgumentParser()
|
| 94 |
+
parser.add_argument("--hf-ckpt-path", type=str, required=True)
|
| 95 |
+
parser.add_argument("--save-path", type=str, required=True)
|
| 96 |
+
parser.add_argument("--n-experts", type=int, required=True)
|
| 97 |
+
parser.add_argument("--model-parallel", type=int, required=True)
|
| 98 |
+
args = parser.parse_args()
|
| 99 |
+
assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
|
| 100 |
+
main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
|
inference/generate.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from argparse import ArgumentParser
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.distributed as dist
|
| 8 |
+
from transformers import AutoTokenizer
|
| 9 |
+
from safetensors.torch import load_model
|
| 10 |
+
|
| 11 |
+
from model import Transformer, ModelArgs
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def sample(logits, temperature: float = 1.0):
|
| 15 |
+
"""
|
| 16 |
+
Samples a token from the logits using temperature scaling.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
logits (torch.Tensor): The logits tensor for token predictions.
|
| 20 |
+
temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
|
| 21 |
+
|
| 22 |
+
Returns:
|
| 23 |
+
torch.Tensor: The sampled token.
|
| 24 |
+
"""
|
| 25 |
+
logits = logits / max(temperature, 1e-5)
|
| 26 |
+
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
|
| 27 |
+
return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@torch.inference_mode()
|
| 31 |
+
def generate(
|
| 32 |
+
model: Transformer,
|
| 33 |
+
prompt_tokens: List[List[int]],
|
| 34 |
+
max_new_tokens: int,
|
| 35 |
+
eos_id: int,
|
| 36 |
+
temperature: float = 1.0
|
| 37 |
+
) -> List[List[int]]:
|
| 38 |
+
"""
|
| 39 |
+
Generates new tokens based on the given prompt tokens using the specified model.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
model (Transformer): The transformer model used for token generation.
|
| 43 |
+
prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
|
| 44 |
+
max_new_tokens (int): The maximum number of new tokens to generate.
|
| 45 |
+
eos_id (int): The end-of-sequence token ID.
|
| 46 |
+
temperature (float, optional): The temperature value for sampling. Defaults to 1.0.
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
List[List[int]]: A list of lists containing the generated tokens for each sequence.
|
| 50 |
+
"""
|
| 51 |
+
prompt_lens = [len(t) for t in prompt_tokens]
|
| 52 |
+
assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
|
| 53 |
+
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
|
| 54 |
+
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
|
| 55 |
+
for i, t in enumerate(prompt_tokens):
|
| 56 |
+
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
|
| 57 |
+
prev_pos = 0
|
| 58 |
+
finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
|
| 59 |
+
prompt_mask = tokens != -1
|
| 60 |
+
for cur_pos in range(min(prompt_lens), total_len):
|
| 61 |
+
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
| 62 |
+
if temperature > 0:
|
| 63 |
+
next_token = sample(logits, temperature)
|
| 64 |
+
else:
|
| 65 |
+
next_token = logits.argmax(dim=-1)
|
| 66 |
+
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
|
| 67 |
+
tokens[:, cur_pos] = next_token
|
| 68 |
+
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
|
| 69 |
+
prev_pos = cur_pos
|
| 70 |
+
if finished.all():
|
| 71 |
+
break
|
| 72 |
+
completion_tokens = []
|
| 73 |
+
for i, toks in enumerate(tokens.tolist()):
|
| 74 |
+
toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
|
| 75 |
+
if eos_id in toks:
|
| 76 |
+
toks = toks[:toks.index(eos_id)]
|
| 77 |
+
completion_tokens.append(toks)
|
| 78 |
+
return completion_tokens
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def main(
|
| 82 |
+
ckpt_path: str,
|
| 83 |
+
config: str,
|
| 84 |
+
input_file: str = "",
|
| 85 |
+
interactive: bool = True,
|
| 86 |
+
max_new_tokens: int = 100,
|
| 87 |
+
temperature: float = 1.0,
|
| 88 |
+
) -> None:
|
| 89 |
+
"""
|
| 90 |
+
Main function to load the model and perform interactive or batch text generation.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
ckpt_path (str): Path to the model checkpoint directory.
|
| 94 |
+
config (str): Path to the model configuration file.
|
| 95 |
+
input_file (str, optional): Path to a file containing input prompts. Defaults to "".
|
| 96 |
+
interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
|
| 97 |
+
max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
|
| 98 |
+
temperature (float, optional): Temperature for sampling. Defaults to 1.0.
|
| 99 |
+
"""
|
| 100 |
+
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
| 101 |
+
rank = int(os.getenv("RANK", "0"))
|
| 102 |
+
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
| 103 |
+
if world_size > 1:
|
| 104 |
+
dist.init_process_group("nccl")
|
| 105 |
+
global print
|
| 106 |
+
if rank != 0:
|
| 107 |
+
print = lambda *_, **__: None
|
| 108 |
+
torch.cuda.set_device(local_rank)
|
| 109 |
+
torch.set_default_dtype(torch.bfloat16)
|
| 110 |
+
torch.set_num_threads(8)
|
| 111 |
+
torch.manual_seed(33377335)
|
| 112 |
+
with open(config) as f:
|
| 113 |
+
args = ModelArgs(**json.load(f))
|
| 114 |
+
print(args)
|
| 115 |
+
with torch.device("cuda"):
|
| 116 |
+
model = Transformer(args)
|
| 117 |
+
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
|
| 118 |
+
print("load model")
|
| 119 |
+
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
|
| 120 |
+
print("I'm DeepSeek 👋")
|
| 121 |
+
|
| 122 |
+
if interactive:
|
| 123 |
+
messages = []
|
| 124 |
+
while True:
|
| 125 |
+
if world_size == 1:
|
| 126 |
+
prompt = input(">>> ")
|
| 127 |
+
elif rank == 0:
|
| 128 |
+
prompt = input(">>> ")
|
| 129 |
+
objects = [prompt]
|
| 130 |
+
dist.broadcast_object_list(objects, 0)
|
| 131 |
+
else:
|
| 132 |
+
objects = [None]
|
| 133 |
+
dist.broadcast_object_list(objects, 0)
|
| 134 |
+
prompt = objects[0]
|
| 135 |
+
if prompt == "/exit":
|
| 136 |
+
break
|
| 137 |
+
elif prompt == "/clear":
|
| 138 |
+
messages.clear()
|
| 139 |
+
continue
|
| 140 |
+
messages.append({"role": "user", "content": prompt})
|
| 141 |
+
prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
| 142 |
+
completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
|
| 143 |
+
completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
|
| 144 |
+
print(completion)
|
| 145 |
+
messages.append({"role": "assistant", "content": completion})
|
| 146 |
+
else:
|
| 147 |
+
with open(input_file) as f:
|
| 148 |
+
prompts = f.read().split("\n\n")
|
| 149 |
+
assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"
|
| 150 |
+
prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
|
| 151 |
+
completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
|
| 152 |
+
completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
|
| 153 |
+
for prompt, completion in zip(prompts, completions):
|
| 154 |
+
print("Prompt:", prompt)
|
| 155 |
+
print("Completion:", completion)
|
| 156 |
+
print()
|
| 157 |
+
|
| 158 |
+
if world_size > 1:
|
| 159 |
+
dist.destroy_process_group()
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
if __name__ == "__main__":
|
| 163 |
+
"""
|
| 164 |
+
Command-line interface for distributed text generation.
|
| 165 |
+
|
| 166 |
+
Arguments:
|
| 167 |
+
--ckpt-path (str): Path to the model checkpoint directory.
|
| 168 |
+
--config (str): Path to the model configuration file.
|
| 169 |
+
--input-file (str, optional): File containing prompts for batch processing.
|
| 170 |
+
--interactive (bool, optional): Enable interactive mode for generating text.
|
| 171 |
+
--max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
|
| 172 |
+
--temperature (float, optional): Temperature for sampling. Defaults to 0.2.
|
| 173 |
+
|
| 174 |
+
Raises:
|
| 175 |
+
AssertionError: If neither input-file nor interactive mode is specified.
|
| 176 |
+
"""
|
| 177 |
+
parser = ArgumentParser()
|
| 178 |
+
parser.add_argument("--ckpt-path", type=str, required=True)
|
| 179 |
+
parser.add_argument("--config", type=str, required=True)
|
| 180 |
+
parser.add_argument("--input-file", type=str, default="")
|
| 181 |
+
parser.add_argument("--interactive", action="store_true")
|
| 182 |
+
parser.add_argument("--max-new-tokens", type=int, default=200)
|
| 183 |
+
parser.add_argument("--temperature", type=float, default=0.6)
|
| 184 |
+
args = parser.parse_args()
|
| 185 |
+
assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
|
| 186 |
+
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
|
inference/kernel.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import tilelang
|
| 3 |
+
import tilelang.language as T
|
| 4 |
+
from typing import Tuple, Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
tilelang.set_log_level("WARNING")
|
| 8 |
+
|
| 9 |
+
pass_configs = {
|
| 10 |
+
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
|
| 11 |
+
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
|
| 12 |
+
tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
FP8 = "float8_e4m3"
|
| 16 |
+
BF16 = "bfloat16"
|
| 17 |
+
FP32 = "float32"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def fast_log2_ceil(x):
|
| 21 |
+
bits_x = T.reinterpret("uint32", x)
|
| 22 |
+
exp_x = (bits_x >> 23) & 0xFF
|
| 23 |
+
man_bits = bits_x & ((1 << 23) - 1)
|
| 24 |
+
return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def fast_pow2(x):
|
| 28 |
+
bits_x = (x + 127) << 23
|
| 29 |
+
return T.reinterpret("float32", bits_x)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def fast_round_scale(amax, fp8_max_inv):
|
| 33 |
+
return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@tilelang.jit(pass_configs=pass_configs)
|
| 37 |
+
def act_quant_kernel(
|
| 38 |
+
N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False
|
| 39 |
+
):
|
| 40 |
+
M = T.symbolic("M")
|
| 41 |
+
fp8_min = -448.0
|
| 42 |
+
fp8_max = 448.0
|
| 43 |
+
fp8_max_inv = 1 / fp8_max
|
| 44 |
+
num_stages = 0 if round_scale else 2
|
| 45 |
+
blk_m = 32
|
| 46 |
+
group_size = 128
|
| 47 |
+
|
| 48 |
+
@T.prim_func
|
| 49 |
+
def act_quant_kernel_(
|
| 50 |
+
X: T.Tensor[(M, N), in_dtype],
|
| 51 |
+
Y: T.Tensor[(M, N), out_dtype],
|
| 52 |
+
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
|
| 53 |
+
):
|
| 54 |
+
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
|
| 55 |
+
pid_m,
|
| 56 |
+
pid_n,
|
| 57 |
+
):
|
| 58 |
+
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
|
| 59 |
+
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
|
| 60 |
+
amax_local = T.alloc_fragment((blk_m,), scale_dtype)
|
| 61 |
+
s_local = T.alloc_fragment((blk_m,), scale_dtype)
|
| 62 |
+
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
|
| 63 |
+
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
|
| 64 |
+
|
| 65 |
+
for _ in T.Pipelined(1, num_stages=num_stages):
|
| 66 |
+
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
|
| 67 |
+
T.copy(x_shared, x_local)
|
| 68 |
+
T.reduce_absmax(x_local, amax_local, dim=1)
|
| 69 |
+
for i in T.Parallel(blk_m):
|
| 70 |
+
amax_local[i] = T.max(amax_local[i], 1e-4)
|
| 71 |
+
if round_scale:
|
| 72 |
+
s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
|
| 73 |
+
else:
|
| 74 |
+
s_local[i] = amax_local[i] * fp8_max_inv
|
| 75 |
+
for i, j in T.Parallel(blk_m, group_size):
|
| 76 |
+
y_local[i, j] = T.clamp(
|
| 77 |
+
x_local[i, j] / s_local[i], fp8_min, fp8_max
|
| 78 |
+
)
|
| 79 |
+
for i in T.Parallel(blk_m):
|
| 80 |
+
S[pid_m * blk_m + i, pid_n] = s_local[i]
|
| 81 |
+
T.copy(y_local, y_shared)
|
| 82 |
+
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
|
| 83 |
+
|
| 84 |
+
return act_quant_kernel_
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def act_quant(
|
| 88 |
+
x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
|
| 89 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 90 |
+
"""
|
| 91 |
+
Quantizes the input tensor `x` using block-wise quantization.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
|
| 95 |
+
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
|
| 96 |
+
scale_fmt (Optional[str], optional): The format of the scale. Default is None.
|
| 97 |
+
Returns:
|
| 98 |
+
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
| 99 |
+
- The quantized tensor with dtype `torch.float8_e4m3fn`.
|
| 100 |
+
- A tensor of scaling factors with dtype `torch.float32`.
|
| 101 |
+
"""
|
| 102 |
+
assert x.is_contiguous(), "Input tensor must be contiguous"
|
| 103 |
+
assert x.size(-1) % block_size == 0, (
|
| 104 |
+
f"Last dimension size must be divisible by block_size (block_size={block_size})"
|
| 105 |
+
)
|
| 106 |
+
N = x.size(-1)
|
| 107 |
+
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
|
| 108 |
+
s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
|
| 109 |
+
kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)
|
| 110 |
+
kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
|
| 111 |
+
return y, s
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
@tilelang.jit(pass_configs=pass_configs)
|
| 115 |
+
def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"):
|
| 116 |
+
assert out_dtype in [BF16, "float32"]
|
| 117 |
+
|
| 118 |
+
M = T.symbolic("M")
|
| 119 |
+
group_size = 128
|
| 120 |
+
block_M = 32
|
| 121 |
+
block_N = 128
|
| 122 |
+
block_K = 128
|
| 123 |
+
|
| 124 |
+
@T.prim_func
|
| 125 |
+
def fp8_gemm_kernel_(
|
| 126 |
+
A: T.Tensor[(M, K), FP8],
|
| 127 |
+
B: T.Tensor[(N, K), FP8],
|
| 128 |
+
C: T.Tensor[(M, N), out_dtype],
|
| 129 |
+
scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32],
|
| 130 |
+
scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), FP32],
|
| 131 |
+
):
|
| 132 |
+
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
|
| 133 |
+
bx,
|
| 134 |
+
by,
|
| 135 |
+
):
|
| 136 |
+
A_shared = T.alloc_shared((block_M, block_K), FP8)
|
| 137 |
+
B_shared = T.alloc_shared((block_N, block_K), FP8)
|
| 138 |
+
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
|
| 139 |
+
Scale_C_shared = T.alloc_shared((block_M), FP32)
|
| 140 |
+
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
|
| 141 |
+
C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
|
| 142 |
+
|
| 143 |
+
# Improve L2 Cache
|
| 144 |
+
T.use_swizzle(panel_size=10)
|
| 145 |
+
|
| 146 |
+
T.clear(C_local)
|
| 147 |
+
T.clear(C_local_accum)
|
| 148 |
+
K_iters = T.ceildiv(K, block_K)
|
| 149 |
+
for k in T.Pipelined(K_iters, num_stages=4):
|
| 150 |
+
# Load A into shared memory
|
| 151 |
+
T.copy(A[by * block_M, k * block_K], A_shared)
|
| 152 |
+
# Load B into shared memory
|
| 153 |
+
T.copy(B[bx * block_N, k * block_K], B_shared)
|
| 154 |
+
# Load scale into shared memory
|
| 155 |
+
Scale_B = scales_b[bx * block_N // group_size, k]
|
| 156 |
+
for i in T.Parallel(block_M):
|
| 157 |
+
Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
|
| 158 |
+
|
| 159 |
+
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
|
| 160 |
+
# Promote to enable 2xAcc
|
| 161 |
+
for i, j in T.Parallel(block_M, block_N):
|
| 162 |
+
C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
|
| 163 |
+
T.clear(C_local)
|
| 164 |
+
# TMA store
|
| 165 |
+
T.copy(C_local_accum, C_shared)
|
| 166 |
+
T.copy(C_shared, C[by * block_M, bx * block_N])
|
| 167 |
+
|
| 168 |
+
return fp8_gemm_kernel_
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def fp8_gemm(
|
| 172 |
+
a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor
|
| 173 |
+
) -> torch.Tensor:
|
| 174 |
+
"""
|
| 175 |
+
Perform a matrix multiplication using FP8 precision.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
a (torch.Tensor): The first input matrix, must be contiguous.
|
| 179 |
+
a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
|
| 180 |
+
b (torch.Tensor): The second input matrix, must be contiguous.
|
| 181 |
+
b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
torch.Tensor: The result of the matrix multiplication.
|
| 185 |
+
"""
|
| 186 |
+
assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
|
| 187 |
+
assert a_s.is_contiguous() and b_s.is_contiguous(), (
|
| 188 |
+
"Scaling factor tensors must be contiguous"
|
| 189 |
+
)
|
| 190 |
+
K = a.size(-1)
|
| 191 |
+
M = a.numel() // K
|
| 192 |
+
N = b.size(0)
|
| 193 |
+
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
|
| 194 |
+
kernel = fp8_gemm_kernel(N, K)
|
| 195 |
+
kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
|
| 196 |
+
return c
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
|
| 200 |
+
def fp8_index_kernel(h: int, d: int):
|
| 201 |
+
b = T.symbolic("b")
|
| 202 |
+
m = T.symbolic("m")
|
| 203 |
+
n = T.symbolic("n")
|
| 204 |
+
|
| 205 |
+
blk_n1 = 512
|
| 206 |
+
blk_n2 = 128
|
| 207 |
+
|
| 208 |
+
@T.prim_func
|
| 209 |
+
def fp8_index_kernel_(
|
| 210 |
+
q: T.Tensor[(b, m, h, d), FP8],
|
| 211 |
+
q_s: T.Tensor[(b, m, h), FP32],
|
| 212 |
+
k: T.Tensor[(b, n, d), FP8],
|
| 213 |
+
k_s: T.Tensor[(b, n), FP32],
|
| 214 |
+
o: T.Tensor[(b, m, n), FP32],
|
| 215 |
+
) -> None:
|
| 216 |
+
with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
|
| 217 |
+
q_smem = T.alloc_shared((h, d), FP8)
|
| 218 |
+
T.copy(q[i_b, i_m, 0, 0], q_smem)
|
| 219 |
+
|
| 220 |
+
q_s_frag = T.alloc_fragment(h, FP32)
|
| 221 |
+
T.copy(q_s[i_b, i_m, 0], q_s_frag)
|
| 222 |
+
|
| 223 |
+
for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):
|
| 224 |
+
k_smem = T.alloc_shared((blk_n2, d), FP8)
|
| 225 |
+
T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
|
| 226 |
+
|
| 227 |
+
k_s_frag = T.alloc_fragment(blk_n2, FP32)
|
| 228 |
+
T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
|
| 229 |
+
|
| 230 |
+
logits = T.alloc_fragment((blk_n2, h), FP32)
|
| 231 |
+
T.gemm(
|
| 232 |
+
k_smem,
|
| 233 |
+
q_smem,
|
| 234 |
+
logits,
|
| 235 |
+
transpose_A=False,
|
| 236 |
+
transpose_B=True,
|
| 237 |
+
clear_accum=True,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
for i_h, i3_n in T.Parallel(h, blk_n2):
|
| 241 |
+
logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]
|
| 242 |
+
|
| 243 |
+
logits_sum = T.alloc_fragment(blk_n2, FP32)
|
| 244 |
+
T.reduce_sum(logits, logits_sum, dim=1)
|
| 245 |
+
|
| 246 |
+
for i3_n in T.Parallel(blk_n2):
|
| 247 |
+
logits_sum[i3_n] *= k_s_frag[i3_n]
|
| 248 |
+
|
| 249 |
+
T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])
|
| 250 |
+
|
| 251 |
+
return fp8_index_kernel_
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def fp8_index(
|
| 255 |
+
q: torch.Tensor,
|
| 256 |
+
q_s: torch.Tensor,
|
| 257 |
+
k: torch.Tensor,
|
| 258 |
+
k_s: torch.Tensor,
|
| 259 |
+
) -> torch.Tensor:
|
| 260 |
+
"""
|
| 261 |
+
Perform index score using FP8 precision.
|
| 262 |
+
|
| 263 |
+
Args:
|
| 264 |
+
q (torch.Tensor): The Q tensor, must be contiguous.
|
| 265 |
+
q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
|
| 266 |
+
k (torch.Tensor): The K tensor, must be contiguous.
|
| 267 |
+
k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.
|
| 268 |
+
|
| 269 |
+
fp8 q @ fp8 k -> fp32 logits
|
| 270 |
+
relu(fp32 logits) * q_s (weights) -> fp32 logits
|
| 271 |
+
fp32 logits -> fp32 logits_sum
|
| 272 |
+
fp32 logits_sum * k_s (e8m0) -> fp32 index_score
|
| 273 |
+
"""
|
| 274 |
+
return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)
|
inference/model.py
ADDED
|
@@ -0,0 +1,923 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Tuple, Optional, Literal
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
|
| 10 |
+
from kernel import act_quant, fp8_gemm, fp8_index
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
world_size = 1
|
| 14 |
+
rank = 0
|
| 15 |
+
block_size = 128
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class ModelArgs:
|
| 19 |
+
"""
|
| 20 |
+
Data class for defining model arguments and hyperparameters.
|
| 21 |
+
|
| 22 |
+
Attributes:
|
| 23 |
+
max_batch_size (int): Maximum batch size.
|
| 24 |
+
max_seq_len (int): Maximum sequence length.
|
| 25 |
+
dtype (Literal["bf16", "fp8"]): Data type for computations.
|
| 26 |
+
scale_fmt (Optional[str]): Format for quantization scale.
|
| 27 |
+
vocab_size (int): Vocabulary size.
|
| 28 |
+
dim (int): Model dimension.
|
| 29 |
+
inter_dim (int): Intermediate dimension for MLP layers.
|
| 30 |
+
moe_inter_dim (int): Intermediate dimension for MoE layers.
|
| 31 |
+
n_layers (int): Number of transformer layers.
|
| 32 |
+
n_dense_layers (int): Number of dense layers in the model.
|
| 33 |
+
n_heads (int): Number of attention heads.
|
| 34 |
+
n_routed_experts (int): Number of routed experts for MoE layers.
|
| 35 |
+
n_shared_experts (int): Number of shared experts for MoE layers.
|
| 36 |
+
n_activated_experts (int): Number of activated experts in MoE layers.
|
| 37 |
+
n_expert_groups (int): Number of expert groups.
|
| 38 |
+
n_limited_groups (int): Number of limited groups for MoE routing.
|
| 39 |
+
score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
|
| 40 |
+
route_scale (float): Scaling factor for routing scores.
|
| 41 |
+
q_lora_rank (int): LoRA rank for query projections.
|
| 42 |
+
kv_lora_rank (int): LoRA rank for key-value projections.
|
| 43 |
+
qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
|
| 44 |
+
qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
|
| 45 |
+
v_head_dim (int): Dimension for value projections.
|
| 46 |
+
original_seq_len (int): Original sequence length.
|
| 47 |
+
rope_theta (float): Base for rotary positional encoding.
|
| 48 |
+
rope_factor (float): Scaling factor for extended sequence lengths.
|
| 49 |
+
beta_fast (int): Fast beta correction factor.
|
| 50 |
+
beta_slow (int): Slow beta correction factor.
|
| 51 |
+
mscale (float): Scaling factor for extended attention.
|
| 52 |
+
index_head_dim (int): Dimension for index head.
|
| 53 |
+
index_topk (int): Top-k for index head.
|
| 54 |
+
"""
|
| 55 |
+
max_batch_size: int = 8
|
| 56 |
+
max_seq_len: int = 4096 * 4
|
| 57 |
+
dtype: Literal["bf16", "fp8"] = "bf16"
|
| 58 |
+
scale_fmt: Optional[str] = None
|
| 59 |
+
vocab_size: int = 102400
|
| 60 |
+
dim: int = 2048
|
| 61 |
+
inter_dim: int = 10944
|
| 62 |
+
moe_inter_dim: int = 1408
|
| 63 |
+
n_layers: int = 27
|
| 64 |
+
n_dense_layers: int = 1
|
| 65 |
+
n_heads: int = 16
|
| 66 |
+
# moe
|
| 67 |
+
n_routed_experts: int = 64
|
| 68 |
+
n_shared_experts: int = 2
|
| 69 |
+
n_activated_experts: int = 6
|
| 70 |
+
n_expert_groups: int = 1
|
| 71 |
+
n_limited_groups: int = 1
|
| 72 |
+
score_func: Literal["softmax", "sigmoid"] = "softmax"
|
| 73 |
+
route_scale: float = 1.
|
| 74 |
+
# mla
|
| 75 |
+
q_lora_rank: int = 0
|
| 76 |
+
kv_lora_rank: int = 512
|
| 77 |
+
qk_nope_head_dim: int = 128
|
| 78 |
+
qk_rope_head_dim: int = 64
|
| 79 |
+
v_head_dim: int = 128
|
| 80 |
+
# yarn
|
| 81 |
+
original_seq_len: int = 4096
|
| 82 |
+
rope_theta: float = 10000.0
|
| 83 |
+
rope_factor: float = 40
|
| 84 |
+
beta_fast: int = 32
|
| 85 |
+
beta_slow: int = 1
|
| 86 |
+
mscale: float = 1.
|
| 87 |
+
# index
|
| 88 |
+
index_n_heads: int = 64
|
| 89 |
+
index_head_dim: int = 128
|
| 90 |
+
index_topk: int = 2048
|
| 91 |
+
|
| 92 |
+
class ParallelEmbedding(nn.Module):
|
| 93 |
+
"""
|
| 94 |
+
Embedding layer with parallelism support across distributed processes.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
vocab_size (int): Vocabulary size.
|
| 98 |
+
dim (int): Embedding dimension.
|
| 99 |
+
"""
|
| 100 |
+
def __init__(self, vocab_size: int, dim: int):
|
| 101 |
+
super().__init__()
|
| 102 |
+
self.vocab_size = vocab_size
|
| 103 |
+
self.dim = dim
|
| 104 |
+
assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
|
| 105 |
+
self.part_vocab_size = (vocab_size // world_size)
|
| 106 |
+
self.vocab_start_idx = rank * self.part_vocab_size
|
| 107 |
+
self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
|
| 108 |
+
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
|
| 109 |
+
|
| 110 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 111 |
+
"""
|
| 112 |
+
Forward pass for parallel embedding layer.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
x (torch.Tensor): Input tensor containing token indices.
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
torch.Tensor: Embedded representations.
|
| 119 |
+
|
| 120 |
+
Raises:
|
| 121 |
+
ValueError: If `world_size` is not defined.
|
| 122 |
+
"""
|
| 123 |
+
if world_size > 1:
|
| 124 |
+
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
|
| 125 |
+
x = x - self.vocab_start_idx
|
| 126 |
+
x[mask] = 0
|
| 127 |
+
y = F.embedding(x, self.weight)
|
| 128 |
+
if world_size > 1:
|
| 129 |
+
y[mask] = 0
|
| 130 |
+
dist.all_reduce(y)
|
| 131 |
+
return y
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None,
|
| 135 |
+
scale_fmt: Optional[str] = None) -> torch.Tensor:
|
| 136 |
+
"""
|
| 137 |
+
Applies a linear transformation to the incoming data: y = xA^T + b.
|
| 138 |
+
This function supports specialized implementations based on quantization
|
| 139 |
+
and tensor formats.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
x (torch.Tensor): The input tensor.
|
| 143 |
+
weight (torch.Tensor): The weight tensor. It may be quantized and
|
| 144 |
+
requires dequantization for certain cases.
|
| 145 |
+
bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
|
| 146 |
+
scale_fmt (Optional[str]): The format of scaling factors.
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
torch.Tensor: The result of the linear transformation, which may involve
|
| 150 |
+
quantization-aware computations depending on the input parameters.
|
| 151 |
+
|
| 152 |
+
Notes:
|
| 153 |
+
- If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version
|
| 154 |
+
is used for computation.
|
| 155 |
+
- For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
|
| 156 |
+
"""
|
| 157 |
+
assert bias is None
|
| 158 |
+
|
| 159 |
+
if weight.dtype != torch.float8_e4m3fn:
|
| 160 |
+
return F.linear(x, weight)
|
| 161 |
+
else:
|
| 162 |
+
x, scale = act_quant(x, block_size, scale_fmt)
|
| 163 |
+
return fp8_gemm(x, scale, weight, weight.scale)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class Linear(nn.Module):
|
| 167 |
+
"""
|
| 168 |
+
Custom linear layer with support for quantized weights and optional bias.
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
in_features (int): Number of input features.
|
| 172 |
+
out_features (int): Number of output features.
|
| 173 |
+
bias (bool): Whether to include a bias term. Defaults to False.
|
| 174 |
+
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
|
| 175 |
+
"""
|
| 176 |
+
dtype = torch.bfloat16
|
| 177 |
+
scale_fmt: Optional[str] = None
|
| 178 |
+
|
| 179 |
+
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
|
| 180 |
+
super().__init__()
|
| 181 |
+
self.in_features = in_features
|
| 182 |
+
self.out_features = out_features
|
| 183 |
+
self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
|
| 184 |
+
if self.weight.element_size() == 1:
|
| 185 |
+
scale_out_features = (out_features + block_size - 1) // block_size
|
| 186 |
+
scale_in_features = (in_features + block_size - 1) // block_size
|
| 187 |
+
self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
|
| 188 |
+
else:
|
| 189 |
+
self.register_parameter("scale", None)
|
| 190 |
+
if bias:
|
| 191 |
+
self.bias = nn.Parameter(torch.empty(out_features))
|
| 192 |
+
else:
|
| 193 |
+
self.register_parameter("bias", None)
|
| 194 |
+
|
| 195 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 196 |
+
"""
|
| 197 |
+
Forward pass for the custom linear layer.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
x (torch.Tensor): Input tensor.
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
torch.Tensor: Transformed tensor after linear computation.
|
| 204 |
+
"""
|
| 205 |
+
return linear(x, self.weight, self.bias, self.scale_fmt)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class ColumnParallelLinear(Linear):
|
| 209 |
+
"""
|
| 210 |
+
Linear layer with column parallelism, splitting output features across distributed processes.
|
| 211 |
+
|
| 212 |
+
Args:
|
| 213 |
+
in_features (int): Number of input features.
|
| 214 |
+
out_features (int): Total number of output features.
|
| 215 |
+
bias (bool): Whether to include a bias term. Defaults to False.
|
| 216 |
+
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
|
| 217 |
+
"""
|
| 218 |
+
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
|
| 219 |
+
assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
|
| 220 |
+
self.part_out_features = out_features // world_size
|
| 221 |
+
super().__init__(in_features, self.part_out_features, bias, dtype)
|
| 222 |
+
|
| 223 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 224 |
+
"""
|
| 225 |
+
Forward pass for column parallel linear layer.
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
x (torch.Tensor): Input tensor.
|
| 229 |
+
|
| 230 |
+
Returns:
|
| 231 |
+
torch.Tensor: Transformed tensor with column-parallel computation.
|
| 232 |
+
"""
|
| 233 |
+
y = linear(x, self.weight, self.bias, self.scale_fmt)
|
| 234 |
+
return y
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class RowParallelLinear(Linear):
|
| 238 |
+
"""
|
| 239 |
+
Linear layer with row parallelism, splitting input features across distributed processes.
|
| 240 |
+
|
| 241 |
+
Args:
|
| 242 |
+
in_features (int): Total number of input features.
|
| 243 |
+
out_features (int): Number of output features.
|
| 244 |
+
bias (bool): Whether to include a bias term. Defaults to False.
|
| 245 |
+
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
|
| 246 |
+
"""
|
| 247 |
+
def __init__(self, in_features: int, out_features: int, bias: bool = False, reduce_output = True, dtype = None):
|
| 248 |
+
assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
|
| 249 |
+
self.part_in_features = in_features // world_size
|
| 250 |
+
self.reduce_output = reduce_output
|
| 251 |
+
super().__init__(self.part_in_features, out_features, bias, dtype)
|
| 252 |
+
|
| 253 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 254 |
+
"""
|
| 255 |
+
Forward pass for row parallel linear layer.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
x (torch.Tensor): Input tensor.
|
| 259 |
+
|
| 260 |
+
Returns:
|
| 261 |
+
torch.Tensor: Transformed tensor with row-parallel computation.
|
| 262 |
+
"""
|
| 263 |
+
y = linear(x, self.weight, None, self.scale_fmt)
|
| 264 |
+
if self.reduce_output and world_size > 1:
|
| 265 |
+
y = y.float()
|
| 266 |
+
dist.all_reduce(y)
|
| 267 |
+
if self.bias is not None:
|
| 268 |
+
y += self.bias
|
| 269 |
+
return y.type_as(x)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class RMSNorm(nn.Module):
|
| 273 |
+
"""
|
| 274 |
+
Root Mean Square Layer Normalization (RMSNorm).
|
| 275 |
+
|
| 276 |
+
Args:
|
| 277 |
+
dim (int): Dimension of the input tensor.
|
| 278 |
+
eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
|
| 279 |
+
"""
|
| 280 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 281 |
+
super().__init__()
|
| 282 |
+
self.dim = dim
|
| 283 |
+
self.eps = eps
|
| 284 |
+
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
|
| 285 |
+
|
| 286 |
+
def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
|
| 287 |
+
"""
|
| 288 |
+
Forward pass for RMSNorm.
|
| 289 |
+
|
| 290 |
+
Args:
|
| 291 |
+
x (torch.Tensor): Input tensor.
|
| 292 |
+
|
| 293 |
+
Returns:
|
| 294 |
+
torch.Tensor: Normalized tensor with the same shape as input.
|
| 295 |
+
"""
|
| 296 |
+
dtype = x.dtype
|
| 297 |
+
if residual is None:
|
| 298 |
+
x = x.float()
|
| 299 |
+
var = x.pow(2).mean(-1, keepdim=True)
|
| 300 |
+
x = x * torch.rsqrt(var + self.eps)
|
| 301 |
+
return (self.weight * x).to(dtype)
|
| 302 |
+
else:
|
| 303 |
+
x = residual = x.float() + residual.float()
|
| 304 |
+
var = x.pow(2).mean(-1, keepdim=True)
|
| 305 |
+
x = x * torch.rsqrt(var + self.eps)
|
| 306 |
+
return (self.weight * x).to(dtype), residual.to(dtype)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class LayerNorm(nn.Module):
|
| 310 |
+
"""
|
| 311 |
+
Layer Normalization.
|
| 312 |
+
"""
|
| 313 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 314 |
+
super().__init__()
|
| 315 |
+
self.dim = dim
|
| 316 |
+
self.eps = eps
|
| 317 |
+
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
|
| 318 |
+
self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
|
| 319 |
+
|
| 320 |
+
def forward(self, x: torch.Tensor):
|
| 321 |
+
return F.layer_norm(x.float(), (self.dim,), self.weight, self.bias, self.eps).type_as(x)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
|
| 325 |
+
"""
|
| 326 |
+
Precomputes frequency-based complex exponential values for rotary positional embeddings.
|
| 327 |
+
|
| 328 |
+
Args:
|
| 329 |
+
args (ModelArgs): Model arguments containing positional embedding parameters.
|
| 330 |
+
|
| 331 |
+
Returns:
|
| 332 |
+
torch.Tensor: Precomputed complex exponential values for positional embeddings.
|
| 333 |
+
"""
|
| 334 |
+
dim = args.qk_rope_head_dim
|
| 335 |
+
seqlen = args.max_seq_len
|
| 336 |
+
beta_fast = args.beta_fast
|
| 337 |
+
beta_slow = args.beta_slow
|
| 338 |
+
base = args.rope_theta
|
| 339 |
+
factor = args.rope_factor
|
| 340 |
+
|
| 341 |
+
def find_correction_dim(num_rotations, dim, base, max_seq_len):
|
| 342 |
+
"""
|
| 343 |
+
Computes the correction dimension for a given number of rotations in the rotary positional embedding.
|
| 344 |
+
|
| 345 |
+
Args:
|
| 346 |
+
num_rotations (float): Number of rotations to compute the correction for.
|
| 347 |
+
dim (int): Dimensionality of the embedding space.
|
| 348 |
+
base (float): Base value for the exponential computation.
|
| 349 |
+
max_seq_len (int): Maximum sequence length.
|
| 350 |
+
|
| 351 |
+
Returns:
|
| 352 |
+
float: The correction dimension based on the input parameters.
|
| 353 |
+
"""
|
| 354 |
+
return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
|
| 355 |
+
|
| 356 |
+
def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
|
| 357 |
+
"""
|
| 358 |
+
Computes the range of correction dimensions for rotary positional embeddings.
|
| 359 |
+
|
| 360 |
+
Args:
|
| 361 |
+
low_rot (float): Lower bound for the number of rotations.
|
| 362 |
+
high_rot (float): Upper bound for the number of rotations.
|
| 363 |
+
dim (int): Dimensionality of the embedding space.
|
| 364 |
+
base (float): Base value for the exponential computation.
|
| 365 |
+
max_seq_len (int): Maximum sequence length.
|
| 366 |
+
|
| 367 |
+
Returns:
|
| 368 |
+
Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
|
| 369 |
+
"""
|
| 370 |
+
low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
|
| 371 |
+
high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
|
| 372 |
+
return max(low, 0), min(high, dim-1)
|
| 373 |
+
|
| 374 |
+
def linear_ramp_factor(min, max, dim):
|
| 375 |
+
"""
|
| 376 |
+
Computes a linear ramp function used to smooth values between a minimum and maximum range.
|
| 377 |
+
|
| 378 |
+
Args:
|
| 379 |
+
min (float): Minimum value for the ramp function.
|
| 380 |
+
max (float): Maximum value for the ramp function.
|
| 381 |
+
dim (int): Dimensionality of the ramp tensor.
|
| 382 |
+
|
| 383 |
+
Returns:
|
| 384 |
+
torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
|
| 385 |
+
clamped to the range [0, 1].
|
| 386 |
+
"""
|
| 387 |
+
if min == max:
|
| 388 |
+
max += 0.001
|
| 389 |
+
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
|
| 390 |
+
ramp_func = torch.clamp(linear_func, 0, 1)
|
| 391 |
+
return ramp_func
|
| 392 |
+
|
| 393 |
+
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
|
| 394 |
+
if seqlen > args.original_seq_len:
|
| 395 |
+
low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
|
| 396 |
+
smooth = 1 - linear_ramp_factor(low, high, dim // 2)
|
| 397 |
+
freqs = freqs / factor * (1 - smooth) + freqs * smooth
|
| 398 |
+
|
| 399 |
+
t = torch.arange(seqlen)
|
| 400 |
+
freqs = torch.outer(t, freqs)
|
| 401 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
| 402 |
+
return freqs_cis
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, interleaved: bool = True) -> torch.Tensor:
|
| 406 |
+
"""
|
| 407 |
+
Applies rotary positional embeddings to the input tensor.
|
| 408 |
+
|
| 409 |
+
Args:
|
| 410 |
+
x (torch.Tensor): Input tensor with positional embeddings to be applied.
|
| 411 |
+
freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
|
| 412 |
+
|
| 413 |
+
Returns:
|
| 414 |
+
torch.Tensor: Tensor with rotary embeddings applied.
|
| 415 |
+
"""
|
| 416 |
+
dtype = x.dtype
|
| 417 |
+
shape = x.shape
|
| 418 |
+
if not interleaved:
|
| 419 |
+
x = x.view(*shape[:-1], 2, -1).transpose(-1, -2).contiguous()
|
| 420 |
+
x = torch.view_as_complex(x.float().view(*shape[:-1], -1, 2))
|
| 421 |
+
freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
|
| 422 |
+
y = torch.view_as_real(x * freqs_cis).flatten(3)
|
| 423 |
+
if not interleaved:
|
| 424 |
+
y = torch.cat([y[..., 0::2], y[..., 1::2]], dim=-1)
|
| 425 |
+
return y.to(dtype)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
|
| 429 |
+
assert x.dtype == torch.bfloat16
|
| 430 |
+
from fast_hadamard_transform import hadamard_transform
|
| 431 |
+
hidden_size = x.size(-1)
|
| 432 |
+
return hadamard_transform(x, scale=hidden_size ** -0.5)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
class Indexer(torch.nn.Module):
|
| 436 |
+
def __init__(self, args: ModelArgs):
|
| 437 |
+
super().__init__()
|
| 438 |
+
self.dim: int = args.dim
|
| 439 |
+
self.n_heads: int = args.index_n_heads
|
| 440 |
+
self.n_local_heads = args.index_n_heads // world_size
|
| 441 |
+
self.head_dim: int = args.index_head_dim
|
| 442 |
+
self.rope_head_dim: int = args.qk_rope_head_dim
|
| 443 |
+
self.index_topk: int = args.index_topk
|
| 444 |
+
self.q_lora_rank: int = args.q_lora_rank
|
| 445 |
+
self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim)
|
| 446 |
+
self.wk = Linear(self.dim, self.head_dim)
|
| 447 |
+
self.k_norm = LayerNorm(self.head_dim)
|
| 448 |
+
# weights_proj in the checkpoint is stored in bf16, while the parameters here are stored in fp32 for convenient.
|
| 449 |
+
self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.float32)
|
| 450 |
+
self.softmax_scale = self.head_dim ** -0.5
|
| 451 |
+
self.scale_fmt = args.scale_fmt
|
| 452 |
+
|
| 453 |
+
self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.head_dim, dtype=torch.float8_e4m3fn), persistent=False)
|
| 454 |
+
self.register_buffer("k_scale_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.head_dim // block_size, dtype=torch.float32), persistent=False)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
|
| 458 |
+
bsz, seqlen, _ = x.size()
|
| 459 |
+
end_pos = start_pos + seqlen
|
| 460 |
+
q = self.wq_b(qr)
|
| 461 |
+
q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
|
| 462 |
+
q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
|
| 463 |
+
# rope in indexer is not interleaved
|
| 464 |
+
q_pe = apply_rotary_emb(q_pe, freqs_cis, False)
|
| 465 |
+
q = torch.cat([q_pe, q_nope], dim=-1)
|
| 466 |
+
k = self.wk(x)
|
| 467 |
+
k = self.k_norm(k)
|
| 468 |
+
k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
|
| 469 |
+
# rope in indexer is not interleaved
|
| 470 |
+
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis, False).squeeze(2)
|
| 471 |
+
k = torch.cat([k_pe, k_nope], dim=-1)
|
| 472 |
+
q = rotate_activation(q)
|
| 473 |
+
k = rotate_activation(k)
|
| 474 |
+
q_fp8, q_scale = act_quant(q, block_size, self.scale_fmt)
|
| 475 |
+
k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
|
| 476 |
+
self.k_cache[:bsz, start_pos:end_pos] = k_fp8
|
| 477 |
+
self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
|
| 478 |
+
weights = self.weights_proj(x.float()) * self.n_heads ** -0.5
|
| 479 |
+
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
|
| 480 |
+
index_score = fp8_index(q_fp8.contiguous(), weights, self.k_cache[:bsz, :end_pos].contiguous(), self.k_scale_cache[:bsz, :end_pos].contiguous())
|
| 481 |
+
if mask is not None:
|
| 482 |
+
index_score += mask
|
| 483 |
+
topk_indices = index_score.topk(min(self.index_topk, end_pos), dim=-1)[1]
|
| 484 |
+
topk_indices_ = topk_indices.clone()
|
| 485 |
+
dist.broadcast(topk_indices_, src=0)
|
| 486 |
+
assert torch.all(topk_indices == topk_indices_), f"{topk_indices=} {topk_indices_=}"
|
| 487 |
+
return topk_indices
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
def weight_dequant(weight, scale):
|
| 491 |
+
shape = weight.shape
|
| 492 |
+
assert weight.dim() == 2
|
| 493 |
+
weight = weight.view(shape[0] // block_size, block_size, shape[1] // block_size, block_size).transpose(1, 2).contiguous().view(-1, block_size * block_size)
|
| 494 |
+
weight = (weight.float() * scale.view(-1, 1).float()).to(torch.get_default_dtype()).view(shape[0] // block_size, shape[1] // block_size, block_size, block_size).transpose(1, 2).contiguous().view(shape)
|
| 495 |
+
return weight
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
class MLA(nn.Module):
|
| 499 |
+
"""
|
| 500 |
+
Multi-Head Latent Attention (MLA) Layer.
|
| 501 |
+
|
| 502 |
+
Attributes:
|
| 503 |
+
dim (int): Dimensionality of the input features.
|
| 504 |
+
n_heads (int): Number of attention heads.
|
| 505 |
+
n_local_heads (int): Number of local attention heads for distributed systems.
|
| 506 |
+
q_lora_rank (int): Rank for low-rank query projection.
|
| 507 |
+
kv_lora_rank (int): Rank for low-rank key/value projection.
|
| 508 |
+
qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
|
| 509 |
+
qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
|
| 510 |
+
qk_head_dim (int): Total dimensionality of query/key projections.
|
| 511 |
+
v_head_dim (int): Dimensionality of value projections.
|
| 512 |
+
softmax_scale (float): Scaling factor for softmax in attention computation.
|
| 513 |
+
"""
|
| 514 |
+
def __init__(self, args: ModelArgs):
|
| 515 |
+
super().__init__()
|
| 516 |
+
self.dim = args.dim
|
| 517 |
+
self.n_heads = args.n_heads
|
| 518 |
+
self.n_local_heads = args.n_heads // world_size
|
| 519 |
+
self.q_lora_rank = args.q_lora_rank
|
| 520 |
+
self.kv_lora_rank = args.kv_lora_rank
|
| 521 |
+
self.qk_nope_head_dim = args.qk_nope_head_dim
|
| 522 |
+
self.qk_rope_head_dim = args.qk_rope_head_dim
|
| 523 |
+
self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
|
| 524 |
+
self.v_head_dim = args.v_head_dim
|
| 525 |
+
|
| 526 |
+
self.wq_a = Linear(self.dim, self.q_lora_rank)
|
| 527 |
+
self.q_norm = RMSNorm(self.q_lora_rank)
|
| 528 |
+
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
|
| 529 |
+
self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
|
| 530 |
+
self.kv_norm = RMSNorm(self.kv_lora_rank)
|
| 531 |
+
self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
|
| 532 |
+
self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
|
| 533 |
+
self.softmax_scale = self.qk_head_dim ** -0.5
|
| 534 |
+
self.scale_fmt = args.scale_fmt
|
| 535 |
+
if args.max_seq_len > args.original_seq_len:
|
| 536 |
+
mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
|
| 537 |
+
self.softmax_scale = self.softmax_scale * mscale * mscale
|
| 538 |
+
|
| 539 |
+
self.indexer = Indexer(args)
|
| 540 |
+
|
| 541 |
+
self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
|
| 542 |
+
self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
|
| 543 |
+
self.dequant_wkv_b = None
|
| 544 |
+
|
| 545 |
+
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
|
| 546 |
+
"""
|
| 547 |
+
Forward pass for the Multi-Head Latent Attention (MLA) Layer.
|
| 548 |
+
|
| 549 |
+
Args:
|
| 550 |
+
x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
|
| 551 |
+
start_pos (int): Starting position in the sequence for caching.
|
| 552 |
+
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
|
| 553 |
+
mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
|
| 554 |
+
|
| 555 |
+
Returns:
|
| 556 |
+
torch.Tensor: Output tensor with the same shape as the input.
|
| 557 |
+
"""
|
| 558 |
+
bsz, seqlen, _ = x.size()
|
| 559 |
+
end_pos = start_pos + seqlen
|
| 560 |
+
qr = self.q_norm(self.wq_a(x))
|
| 561 |
+
q = self.wq_b(qr)
|
| 562 |
+
q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
|
| 563 |
+
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
| 564 |
+
q_pe = apply_rotary_emb(q_pe, freqs_cis)
|
| 565 |
+
kv = self.wkv_a(x)
|
| 566 |
+
kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
| 567 |
+
kv = self.kv_norm(kv)
|
| 568 |
+
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
|
| 569 |
+
# we use fp8 kv cache in actual deployment, so here we simulate the precision by casting kv to fp8 and then back to bf16.
|
| 570 |
+
kv_fp8, kv_scale = act_quant(kv, block_size, self.scale_fmt)
|
| 571 |
+
kv = (kv_fp8.view(-1, block_size).float() * kv_scale.view(-1, 1)).to(kv.dtype).view_as(kv)
|
| 572 |
+
self.kv_cache[:bsz, start_pos:end_pos] = kv
|
| 573 |
+
self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
|
| 574 |
+
if mask is not None: # MHA prefill
|
| 575 |
+
q = torch.cat([q_nope, q_pe], dim=-1)
|
| 576 |
+
kv = self.wkv_b(kv)
|
| 577 |
+
kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
|
| 578 |
+
k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
| 579 |
+
k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
|
| 580 |
+
scores = torch.einsum("bshd,bthd->bsht", q, k).mul_(self.softmax_scale)
|
| 581 |
+
|
| 582 |
+
# indexer
|
| 583 |
+
topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
|
| 584 |
+
index_mask = torch.full((bsz, seqlen, seqlen), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
|
| 585 |
+
index_mask += mask
|
| 586 |
+
scores += index_mask.unsqueeze(2)
|
| 587 |
+
|
| 588 |
+
scores = scores.softmax(dim=-1)
|
| 589 |
+
x = torch.einsum("bsht,bthd->bshd", scores, v)
|
| 590 |
+
else: # MQA decode
|
| 591 |
+
if self.dequant_wkv_b is None and self.wkv_b.scale is not None:
|
| 592 |
+
self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
|
| 593 |
+
wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b
|
| 594 |
+
wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
|
| 595 |
+
q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
|
| 596 |
+
scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
|
| 597 |
+
torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
|
| 598 |
+
|
| 599 |
+
# indexer
|
| 600 |
+
topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
|
| 601 |
+
index_mask = torch.full((bsz, 1, end_pos), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
|
| 602 |
+
scores += index_mask.unsqueeze(2)
|
| 603 |
+
|
| 604 |
+
scores = scores.softmax(dim=-1)
|
| 605 |
+
x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
|
| 606 |
+
x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
|
| 607 |
+
x = self.wo(x.flatten(2))
|
| 608 |
+
return x
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
class MLP(nn.Module):
|
| 612 |
+
"""
|
| 613 |
+
Multi-Layer Perceptron (MLP) used as a feed-forward layer.
|
| 614 |
+
|
| 615 |
+
Attributes:
|
| 616 |
+
w1 (nn.Module): Linear layer for input-to-hidden transformation.
|
| 617 |
+
w2 (nn.Module): Linear layer for hidden-to-output transformation.
|
| 618 |
+
w3 (nn.Module): Additional linear layer for feature transformation.
|
| 619 |
+
"""
|
| 620 |
+
def __init__(self, dim: int, inter_dim: int, reduce_output: bool = True):
|
| 621 |
+
"""
|
| 622 |
+
Initializes the MLP layer.
|
| 623 |
+
|
| 624 |
+
Args:
|
| 625 |
+
dim (int): Input and output dimensionality.
|
| 626 |
+
inter_dim (int): Hidden layer dimensionality.
|
| 627 |
+
"""
|
| 628 |
+
super().__init__()
|
| 629 |
+
self.w1 = ColumnParallelLinear(dim, inter_dim)
|
| 630 |
+
self.w2 = RowParallelLinear(inter_dim, dim, reduce_output=reduce_output)
|
| 631 |
+
self.w3 = ColumnParallelLinear(dim, inter_dim)
|
| 632 |
+
|
| 633 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 634 |
+
"""
|
| 635 |
+
Forward pass for the MLP layer.
|
| 636 |
+
|
| 637 |
+
Args:
|
| 638 |
+
x (torch.Tensor): Input tensor.
|
| 639 |
+
|
| 640 |
+
Returns:
|
| 641 |
+
torch.Tensor: Output tensor after MLP computation.
|
| 642 |
+
"""
|
| 643 |
+
return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
class Gate(nn.Module):
|
| 647 |
+
"""
|
| 648 |
+
Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.
|
| 649 |
+
|
| 650 |
+
Attributes:
|
| 651 |
+
dim (int): Dimensionality of input features.
|
| 652 |
+
topk (int): Number of top experts activated for each input.
|
| 653 |
+
n_groups (int): Number of groups for routing.
|
| 654 |
+
topk_groups (int): Number of groups to route inputs to.
|
| 655 |
+
score_func (str): Scoring function ('softmax' or 'sigmoid').
|
| 656 |
+
route_scale (float): Scaling factor for routing weights.
|
| 657 |
+
weight (torch.nn.Parameter): Learnable weights for the gate.
|
| 658 |
+
bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
|
| 659 |
+
"""
|
| 660 |
+
def __init__(self, args: ModelArgs):
|
| 661 |
+
"""
|
| 662 |
+
Initializes the Gate module.
|
| 663 |
+
|
| 664 |
+
Args:
|
| 665 |
+
args (ModelArgs): Model arguments containing gating parameters.
|
| 666 |
+
"""
|
| 667 |
+
super().__init__()
|
| 668 |
+
self.dim = args.dim
|
| 669 |
+
self.topk = args.n_activated_experts
|
| 670 |
+
self.n_groups = args.n_expert_groups
|
| 671 |
+
self.topk_groups = args.n_limited_groups
|
| 672 |
+
self.score_func = args.score_func
|
| 673 |
+
self.route_scale = args.route_scale
|
| 674 |
+
self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
|
| 675 |
+
self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) if self.dim == 7168 else None
|
| 676 |
+
|
| 677 |
+
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 678 |
+
"""
|
| 679 |
+
Forward pass for the gating mechanism.
|
| 680 |
+
|
| 681 |
+
Args:
|
| 682 |
+
x (torch.Tensor): Input tensor.
|
| 683 |
+
|
| 684 |
+
Returns:
|
| 685 |
+
Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
|
| 686 |
+
"""
|
| 687 |
+
scores = linear(x.float(), self.weight.float())
|
| 688 |
+
if self.score_func == "softmax":
|
| 689 |
+
scores = scores.softmax(dim=-1)
|
| 690 |
+
else:
|
| 691 |
+
scores = scores.sigmoid()
|
| 692 |
+
original_scores = scores
|
| 693 |
+
if self.bias is not None:
|
| 694 |
+
scores = scores + self.bias
|
| 695 |
+
if self.n_groups > 1:
|
| 696 |
+
scores = scores.view(x.size(0), self.n_groups, -1)
|
| 697 |
+
if self.bias is None:
|
| 698 |
+
group_scores = scores.amax(dim=-1)
|
| 699 |
+
else:
|
| 700 |
+
group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
|
| 701 |
+
indices = group_scores.topk(self.topk_groups, dim=-1)[1]
|
| 702 |
+
mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
|
| 703 |
+
scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
|
| 704 |
+
indices = scores.topk(self.topk, dim=-1)[1]
|
| 705 |
+
weights = original_scores.gather(1, indices)
|
| 706 |
+
if self.score_func == "sigmoid":
|
| 707 |
+
weights /= weights.sum(dim=-1, keepdim=True)
|
| 708 |
+
weights *= self.route_scale
|
| 709 |
+
return weights, indices
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
class Expert(nn.Module):
|
| 713 |
+
"""
|
| 714 |
+
Expert layer for Mixture-of-Experts (MoE) models.
|
| 715 |
+
|
| 716 |
+
Attributes:
|
| 717 |
+
w1 (nn.Module): Linear layer for input-to-hidden transformation.
|
| 718 |
+
w2 (nn.Module): Linear layer for hidden-to-output transformation.
|
| 719 |
+
w3 (nn.Module): Additional linear layer for feature transformation.
|
| 720 |
+
"""
|
| 721 |
+
def __init__(self, dim: int, inter_dim: int):
|
| 722 |
+
"""
|
| 723 |
+
Initializes the Expert layer.
|
| 724 |
+
|
| 725 |
+
Args:
|
| 726 |
+
dim (int): Input and output dimensionality.
|
| 727 |
+
inter_dim (int): Hidden layer dimensionality.
|
| 728 |
+
"""
|
| 729 |
+
super().__init__()
|
| 730 |
+
self.w1 = Linear(dim, inter_dim)
|
| 731 |
+
self.w2 = Linear(inter_dim, dim)
|
| 732 |
+
self.w3 = Linear(dim, inter_dim)
|
| 733 |
+
|
| 734 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 735 |
+
"""
|
| 736 |
+
Forward pass for the Expert layer.
|
| 737 |
+
|
| 738 |
+
Args:
|
| 739 |
+
x (torch.Tensor): Input tensor.
|
| 740 |
+
|
| 741 |
+
Returns:
|
| 742 |
+
torch.Tensor: Output tensor after expert computation.
|
| 743 |
+
"""
|
| 744 |
+
return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
class MoE(nn.Module):
|
| 748 |
+
"""
|
| 749 |
+
Mixture-of-Experts (MoE) module.
|
| 750 |
+
|
| 751 |
+
Attributes:
|
| 752 |
+
dim (int): Dimensionality of input features.
|
| 753 |
+
n_routed_experts (int): Total number of experts in the model.
|
| 754 |
+
n_local_experts (int): Number of experts handled locally in distributed systems.
|
| 755 |
+
n_activated_experts (int): Number of experts activated for each input.
|
| 756 |
+
gate (nn.Module): Gating mechanism to route inputs to experts.
|
| 757 |
+
experts (nn.ModuleList): List of expert modules.
|
| 758 |
+
shared_experts (nn.Module): Shared experts applied to all inputs.
|
| 759 |
+
"""
|
| 760 |
+
def __init__(self, args: ModelArgs):
|
| 761 |
+
"""
|
| 762 |
+
Initializes the MoE module.
|
| 763 |
+
|
| 764 |
+
Args:
|
| 765 |
+
args (ModelArgs): Model arguments containing MoE parameters.
|
| 766 |
+
"""
|
| 767 |
+
super().__init__()
|
| 768 |
+
self.dim = args.dim
|
| 769 |
+
assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
|
| 770 |
+
self.n_routed_experts = args.n_routed_experts
|
| 771 |
+
self.n_local_experts = args.n_routed_experts // world_size
|
| 772 |
+
self.n_activated_experts = args.n_activated_experts
|
| 773 |
+
self.experts_start_idx = rank * self.n_local_experts
|
| 774 |
+
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
|
| 775 |
+
self.gate = Gate(args)
|
| 776 |
+
self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
|
| 777 |
+
for i in range(self.n_routed_experts)])
|
| 778 |
+
self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim, reduce_output=False)
|
| 779 |
+
|
| 780 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 781 |
+
"""
|
| 782 |
+
Forward pass for the MoE module.
|
| 783 |
+
|
| 784 |
+
Args:
|
| 785 |
+
x (torch.Tensor): Input tensor.
|
| 786 |
+
|
| 787 |
+
Returns:
|
| 788 |
+
torch.Tensor: Output tensor after expert routing and computation.
|
| 789 |
+
"""
|
| 790 |
+
shape = x.size()
|
| 791 |
+
x = x.view(-1, self.dim)
|
| 792 |
+
weights, indices = self.gate(x)
|
| 793 |
+
y = torch.zeros_like(x, dtype=torch.float32)
|
| 794 |
+
counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
|
| 795 |
+
for i in range(self.experts_start_idx, self.experts_end_idx):
|
| 796 |
+
if counts[i] == 0:
|
| 797 |
+
continue
|
| 798 |
+
expert = self.experts[i]
|
| 799 |
+
idx, top = torch.where(indices == i)
|
| 800 |
+
y[idx] += expert(x[idx]) * weights[idx, top, None]
|
| 801 |
+
y += self.shared_experts(x)
|
| 802 |
+
if world_size > 1:
|
| 803 |
+
dist.all_reduce(y)
|
| 804 |
+
return y.type_as(x).view(shape)
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
class Block(nn.Module):
|
| 808 |
+
"""
|
| 809 |
+
Transformer block combining attention and feed-forward layers.
|
| 810 |
+
|
| 811 |
+
Attributes:
|
| 812 |
+
attn (nn.Module): Attention layer (MLA).
|
| 813 |
+
ffn (nn.Module): Feed-forward network (MLP or MoE).
|
| 814 |
+
attn_norm (nn.Module): Layer normalization for attention.
|
| 815 |
+
ffn_norm (nn.Module): Layer normalization for feed-forward network.
|
| 816 |
+
"""
|
| 817 |
+
def __init__(self, layer_id: int, args: ModelArgs):
|
| 818 |
+
"""
|
| 819 |
+
Initializes the Transformer block.
|
| 820 |
+
|
| 821 |
+
Args:
|
| 822 |
+
layer_id (int): Layer index in the transformer.
|
| 823 |
+
args (ModelArgs): Model arguments containing block parameters.
|
| 824 |
+
"""
|
| 825 |
+
super().__init__()
|
| 826 |
+
self.attn = MLA(args)
|
| 827 |
+
self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
|
| 828 |
+
self.attn_norm = RMSNorm(args.dim)
|
| 829 |
+
self.ffn_norm = RMSNorm(args.dim)
|
| 830 |
+
|
| 831 |
+
def forward(self, x: torch.Tensor, residual: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
|
| 832 |
+
"""
|
| 833 |
+
Forward pass for the Transformer block.
|
| 834 |
+
|
| 835 |
+
Args:
|
| 836 |
+
x (torch.Tensor): Input tensor.
|
| 837 |
+
start_pos (int): Starting position in the sequence.
|
| 838 |
+
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
|
| 839 |
+
mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
|
| 840 |
+
|
| 841 |
+
Returns:
|
| 842 |
+
torch.Tensor: Output tensor after block computation.
|
| 843 |
+
"""
|
| 844 |
+
if residual is None:
|
| 845 |
+
x, residual = self.attn_norm(x), x
|
| 846 |
+
else:
|
| 847 |
+
x, residual = self.attn_norm(x, residual)
|
| 848 |
+
x = self.attn(x, start_pos, freqs_cis, mask)
|
| 849 |
+
x, residual = self.ffn_norm(x, residual)
|
| 850 |
+
x = self.ffn(x)
|
| 851 |
+
return x, residual
|
| 852 |
+
|
| 853 |
+
|
| 854 |
+
class Transformer(nn.Module):
|
| 855 |
+
"""
|
| 856 |
+
Transformer model with positional embeddings, multiple layers, and output projection.
|
| 857 |
+
|
| 858 |
+
Attributes:
|
| 859 |
+
max_seq_len (int): Maximum sequence length for the transformer.
|
| 860 |
+
embed (nn.Module): Embedding layer for input tokens.
|
| 861 |
+
layers (torch.nn.ModuleList): List of transformer blocks.
|
| 862 |
+
norm (nn.Module): Layer normalization applied after all blocks.
|
| 863 |
+
head (nn.Module): Output projection layer mapping to vocabulary size.
|
| 864 |
+
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
|
| 865 |
+
"""
|
| 866 |
+
def __init__(self, args: ModelArgs):
|
| 867 |
+
"""
|
| 868 |
+
Initializes the Transformer model.
|
| 869 |
+
|
| 870 |
+
Args:
|
| 871 |
+
args (ModelArgs): Model arguments containing transformer parameters.
|
| 872 |
+
"""
|
| 873 |
+
global world_size, rank
|
| 874 |
+
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
| 875 |
+
rank = dist.get_rank() if dist.is_initialized() else 0
|
| 876 |
+
Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
|
| 877 |
+
Linear.scale_fmt = args.scale_fmt
|
| 878 |
+
super().__init__()
|
| 879 |
+
self.max_seq_len = args.max_seq_len
|
| 880 |
+
self.embed = ParallelEmbedding(args.vocab_size, args.dim)
|
| 881 |
+
self.layers = torch.nn.ModuleList()
|
| 882 |
+
for layer_id in range(args.n_layers):
|
| 883 |
+
self.layers.append(Block(layer_id, args))
|
| 884 |
+
self.norm = RMSNorm(args.dim)
|
| 885 |
+
# lm_head in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for easier computation of logits later.
|
| 886 |
+
self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.float32)
|
| 887 |
+
self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
|
| 888 |
+
|
| 889 |
+
@torch.inference_mode()
|
| 890 |
+
def forward(self, tokens: torch.Tensor, start_pos: int = 0):
|
| 891 |
+
"""
|
| 892 |
+
Forward pass for the Transformer model.
|
| 893 |
+
|
| 894 |
+
Args:
|
| 895 |
+
tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
|
| 896 |
+
start_pos (int, optional): Starting position in the sequence for rotary embeddings. Defaults to 0.
|
| 897 |
+
|
| 898 |
+
Returns:
|
| 899 |
+
torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
|
| 900 |
+
"""
|
| 901 |
+
seqlen = tokens.size(1)
|
| 902 |
+
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
|
| 903 |
+
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) if seqlen > 1 else None
|
| 904 |
+
h, residual = self.embed(tokens), None
|
| 905 |
+
for layer in self.layers:
|
| 906 |
+
h, residual = layer(h, residual, start_pos, freqs_cis, mask)
|
| 907 |
+
h, _ = self.norm(h, residual)
|
| 908 |
+
logits = self.head(h[:, -1].float())
|
| 909 |
+
if world_size > 1:
|
| 910 |
+
all_logits = [torch.empty_like(logits) for _ in range(world_size)]
|
| 911 |
+
dist.all_gather(all_logits, logits)
|
| 912 |
+
logits = torch.cat(all_logits, dim=-1)
|
| 913 |
+
return logits
|
| 914 |
+
|
| 915 |
+
|
| 916 |
+
if __name__ == "__main__":
|
| 917 |
+
torch.set_default_dtype(torch.bfloat16)
|
| 918 |
+
torch.set_default_device("cuda")
|
| 919 |
+
torch.manual_seed(0)
|
| 920 |
+
args = ModelArgs()
|
| 921 |
+
x = torch.randint(0, args.vocab_size, (2, 128))
|
| 922 |
+
model = Transformer(args)
|
| 923 |
+
print(model(x).size())
|
inference/requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
transformers
|
| 3 |
+
safetensors
|
| 4 |
+
fast_hadamard_transform
|
| 5 |
+
tilelang==0.1.6
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": true,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"bos_token": {
|
| 5 |
+
"__type": "AddedToken",
|
| 6 |
+
"content": "<|begin▁of▁sentence|>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": true,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false
|
| 11 |
+
},
|
| 12 |
+
"clean_up_tokenization_spaces": false,
|
| 13 |
+
"eos_token": {
|
| 14 |
+
"__type": "AddedToken",
|
| 15 |
+
"content": "<|end▁of▁sentence|>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": true,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false
|
| 20 |
+
},
|
| 21 |
+
"legacy": true,
|
| 22 |
+
"model_max_length": 131072,
|
| 23 |
+
"pad_token": {
|
| 24 |
+
"__type": "AddedToken",
|
| 25 |
+
"content": "<|end▁of▁sentence|>",
|
| 26 |
+
"lstrip": false,
|
| 27 |
+
"normalized": true,
|
| 28 |
+
"rstrip": false,
|
| 29 |
+
"single_word": false
|
| 30 |
+
},
|
| 31 |
+
"sp_model_kwargs": {},
|
| 32 |
+
"unk_token": null,
|
| 33 |
+
"tokenizer_class": "LlamaTokenizerFast",
|
| 34 |
+
"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% if not thinking is defined %}{% set thinking = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false, is_only_sys=false, is_prefix=false) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}{%- endif %}{% set ns.is_only_sys = true %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{%- set ns.is_first = false -%}{%- set ns.is_last_user = true -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}{%- if ns.is_last_user or ns.is_only_sys %}{{'<|Assistant|></think>'}}{%- endif %}{%- set ns.is_last_user = false -%}{%- set ns.is_first = false %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}{%- else %}{{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %}{%- if ns.is_last_user %}{{'<|Assistant|>'}}{%- if message['prefix'] is defined and message['prefix'] and thinking %}{{'<think>'}}{%- else %}{{'</think>'}}{%- endif %}{%- endif %}{%- if message['prefix'] is defined and message['prefix'] %}{%- set ns.is_prefix = true -%}{%- endif %}{%- set ns.is_last_user = false -%}{%- if ns.is_tool %}{{message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{%- set content = message['content'] -%}{%- if '</think>' in content %}{%- set content = content.split('</think>', 1)[1] -%}{%- endif %}{{content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_last_user = false -%}{%- set ns.is_tool = true -%}{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- if message['role'] != 'system' %}{% set ns.is_only_sys = false %}{%- endif %}{%- endfor -%}{% if add_generation_prompt and not ns.is_tool%}{% if ns.is_last_user or ns.is_only_sys or not ns.is_prefix %}{{'<|Assistant|>'}}{%- if not thinking %}{{'</think>'}}{%- else %}{{'<think>'}}{%- endif %}{% endif %}{% endif %}"
|
| 35 |
+
}
|