Upload folder using huggingface_hub
Browse files- .gitattributes +3 -0
- LICENSE +63 -0
- NOTICE +9 -0
- README.md +243 -6
- __pycache__/configuration_superlinear_exp.cpython-312.pyc +0 -0
- __pycache__/configuration_superlinear_exp.cpython-313.pyc +0 -0
- __pycache__/modeling_superlinear_exp.cpython-312.pyc +3 -0
- __pycache__/modeling_superlinear_exp.cpython-313.pyc +3 -0
- __pycache__/moe.cpython-313.pyc +0 -0
- chat_template.jinja +204 -0
- config.json +80 -0
- configuration_superlinear_exp.py +341 -0
- generation_config.json +11 -0
- model-00001-of-00016.safetensors +3 -0
- model-00002-of-00016.safetensors +3 -0
- model-00003-of-00016.safetensors +3 -0
- model-00004-of-00016.safetensors +3 -0
- model-00005-of-00016.safetensors +3 -0
- model-00006-of-00016.safetensors +3 -0
- model-00007-of-00016.safetensors +3 -0
- model-00008-of-00016.safetensors +3 -0
- model-00009-of-00016.safetensors +3 -0
- model-00010-of-00016.safetensors +3 -0
- model-00011-of-00016.safetensors +3 -0
- model-00012-of-00016.safetensors +3 -0
- model-00013-of-00016.safetensors +3 -0
- model-00014-of-00016.safetensors +3 -0
- model-00015-of-00016.safetensors +3 -0
- model-00016-of-00016.safetensors +3 -0
- model.safetensors.index.json +414 -0
- modeling_superlinear_exp.py +0 -0
- moe.py +890 -0
- special_tokens_map.json +30 -0
- tokenizer.json +3 -0
- tokenizer_config.json +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
__pycache__/modeling_superlinear_exp.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
__pycache__/modeling_superlinear_exp.cpython-313.pyc filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
NVIDIA Open Model License Agreement
|
| 2 |
+
|
| 3 |
+
Last Modified: October 24, 2025
|
| 4 |
+
|
| 5 |
+
This NVIDIA Open Model License Agreement (the "Agreement") is a legal agreement between the Legal Entity You represent, or if no entity is identified, You and NVIDIA Corporation and its Affiliates ("NVIDIA") and governs Your use of the Models that NVIDIA provides to You under this Agreement. NVIDIA and You are each a "party" and collectively the "parties."
|
| 6 |
+
|
| 7 |
+
NVIDIA models released under this Agreement are intended to be used permissively and enable the further development of AI technologies. Subject to the terms of this Agreement, NVIDIA confirms that:
|
| 8 |
+
|
| 9 |
+
- Models are commercially usable.
|
| 10 |
+
- You are free to create and distribute Derivative Models.
|
| 11 |
+
- NVIDIA does not claim ownership to any outputs generated using the Models or Derivative Models.
|
| 12 |
+
|
| 13 |
+
By using, reproducing, modifying, distributing, performing or displaying any portion or element of the Model or Derivative Model, or otherwise accepting the terms of this Agreement, you agree to be bound by this Agreement.
|
| 14 |
+
|
| 15 |
+
Definitions. The following definitions apply to this Agreement:
|
| 16 |
+
|
| 17 |
+
"Derivative Model" means all (a) modifications to the Model, (b) works based on the Model, and (c) any other derivative works of the Model. An output is not a Derivative Model.
|
| 18 |
+
|
| 19 |
+
"Legal Entity" means the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (a) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (b) ownership of fifty percent (50%) or more of the outstanding shares, or (c) beneficial ownership of such entity.
|
| 20 |
+
|
| 21 |
+
"Model" means the machine learning model, software, checkpoints, learnt weights, algorithms, parameters, configuration files and documentation shared under this Agreement.
|
| 22 |
+
|
| 23 |
+
"NVIDIA Cosmos Model" means a multimodal Model shared under this Agreement
|
| 24 |
+
|
| 25 |
+
"Special-Purpose Model" means a Model that is only competent in a narrow set of purpose-specific tasks and should not be used for unintended or general-purpose applications
|
| 26 |
+
|
| 27 |
+
"You" or "Your" means an individual or Legal Entity exercising permissions granted by this Agreement.
|
| 28 |
+
|
| 29 |
+
Conditions for Use, License Grant, AI Ethics and IP Ownership.
|
| 30 |
+
|
| 31 |
+
2.1 Conditions for Use. The Model and any Derivative Model are subject to additional terms as described in Section 2 and Section 3 of this Agreement and govern Your use. If You institute copyright or patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model or a Derivative Model constitutes direct or contributory copyright or patent infringement, then any licenses granted to You under this Agreement for that Model or Derivative Model will terminate as of the date such litigation is filed. If You bypass, disable, reduce the efficacy of, or circumvent any technical limitation, safety guardrail or associated safety guardrail hyperparameter, encryption, security, digital rights management, or authentication mechanism (collectively "Guardrail") contained in the Model without a substantially similar Guardrail appropriate for your use case, your rights under this Agreement will automatically terminate. NVIDIA may indicate in relevant documentation that a Model is a Special-Purpose Model. NVIDIA may update this Agreement to comply with legal and regulatory requirements at any time and You agree to either comply with any updated license or cease Your copying, use, and distribution of the Model and any Derivative Model.
|
| 32 |
+
|
| 33 |
+
2.2 License Grant. The rights granted herein are explicitly conditioned on Your full compliance with the terms of this Agreement. Subject to the terms and conditions of this Agreement, NVIDIA hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, revocable (as stated in Section 2.1) license to publicly perform, publicly display, reproduce, use, create derivative works of, make, have made, sell, offer for sale, distribute (through multiple tiers of distribution) and import the Model.
|
| 34 |
+
|
| 35 |
+
2.3 AI Ethics. Use of the Models under the Agreement must be consistent with NVIDIA's Trustworthy AI terms found at https://www.nvidia.com/en-us/agreements/trustworthy-ai/terms/.
|
| 36 |
+
|
| 37 |
+
2.4 NVIDIA owns the Model and any Derivative Models created by NVIDIA. Subject to NVIDIA's underlying ownership rights in the Model or its Derivative Models, You are and will be the owner of Your Derivative Models. NVIDIA claims no ownership rights in outputs. You are responsible for outputs and their subsequent uses. Except as expressly granted in this Agreement, (a) NVIDIA reserves all rights, interests and remedies in connection with the Model and (b) no other license or right is granted to you by implication, estoppel or otherwise.
|
| 38 |
+
|
| 39 |
+
Redistribution. You may reproduce and distribute copies of the Model or Derivative Models thereof in any medium, with or without modifications, provided that You meet the following conditions:
|
| 40 |
+
|
| 41 |
+
3.1 If you distribute the Model, You must give any other recipients of the Model a copy of this Agreement and include the following attribution notice within a "Notice" text file with such copies: "Licensed by NVIDIA Corporation under the NVIDIA Open Model License";
|
| 42 |
+
|
| 43 |
+
3.2 If you distribute or make available a NVIDIA Cosmos Model, or a product or service (including an AI model) that contains or uses a NVIDIA Cosmos Model, use a NVIDIA Cosmos Model to create a Derivative Model, or use a NVIDIA Cosmos Model or its outputs to create, train, fine tune, or otherwise improve an AI model, you will include "Built on NVIDIA Cosmos" on a related website, user interface, blogpost, about page, or product documentation; and
|
| 44 |
+
|
| 45 |
+
3.3 You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Models as a whole, provided Your use, reproduction, and distribution of the Model otherwise complies with the conditions stated in this Agreement.
|
| 46 |
+
|
| 47 |
+
Separate Components. The Models may include or be distributed with components provided with separate legal notices or terms that accompany the components, such as an Open Source Software License or other third-party license. The components are subject to the applicable other licenses, including any proprietary notices, disclaimers, requirements and extended use rights; except that this Agreement will prevail regarding the use of third-party Open Source Software License, unless a third-party Open Source Software License requires its license terms to prevail. "Open Source Software License" means any software, data or documentation subject to any license identified as an open source license by the Open Source Initiative (https://opensource.org), Free Software Foundation (https://www.fsf.org) or other similar open source organization or listed by the Software Package Data Exchange (SPDX) Workgroup under the Linux Foundation (https://www.spdx.org).
|
| 48 |
+
|
| 49 |
+
Trademarks. This Agreement does not grant permission to use the trade names, trademarks, service marks, or product names of NVIDIA, except as required for reasonable and customary use in describing the origin of the Model and reproducing the content of the "Notice" text file.
|
| 50 |
+
|
| 51 |
+
Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, NVIDIA provides the Model on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for reviewing Model documentation, including any Special-Purpose Model limitations, and determining the appropriateness of using or redistributing the Model, Derivative Models and outputs. You assume any risks associated with Your exercise of permissions under this Agreement.
|
| 52 |
+
|
| 53 |
+
Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, will NVIDIA be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use the Model, Derivative Models or outputs (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if NVIDIA has been advised of the possibility of such damages.
|
| 54 |
+
|
| 55 |
+
Indemnity. You will indemnify and hold harmless NVIDIA from and against any claim by any third party arising out of or related to your use or distribution of the Model, Derivative Models or outputs.
|
| 56 |
+
|
| 57 |
+
Feedback. NVIDIA appreciates your feedback, and You agree that NVIDIA may use it without restriction or compensation to You.
|
| 58 |
+
|
| 59 |
+
Governing Law. This Agreement will be governed in all respects by the laws of the United States and the laws of the State of Delaware, without regard to conflict of laws principles or the United Nations Convention on Contracts for the International Sale of Goods. The state and federal courts residing in Santa Clara County, California will have exclusive jurisdiction over any dispute or claim arising out of or related to this Agreement, and the parties irrevocably consent to personal jurisdiction and venue in those courts; except that, either party may apply for injunctive remedies or an equivalent type of urgent legal relief in any jurisdiction.
|
| 60 |
+
|
| 61 |
+
Trade and Compliance. You agree to comply with all applicable export, import, trade and economic sanctions laws and regulations, as amended, including without limitation U.S. Export Administration Regulations and Office of Foreign Assets Control regulations. These laws include restrictions on destinations, end-users and end-use.
|
| 62 |
+
|
| 63 |
+
Version Release Date: October 24, 2025
|
NOTICE
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Licensed by NVIDIA Corporation under the NVIDIA Open Model License
|
| 2 |
+
|
| 3 |
+
This model (superlinear-exp-v0.1) is a Derivative Model based on NVIDIA Nemotron-3-Nano-30B-A3B-BF16.
|
| 4 |
+
|
| 5 |
+
Upstream model: https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
|
| 6 |
+
|
| 7 |
+
Modifications by: Concavity AI (Yufeng Huang)
|
| 8 |
+
- Replaced standard attention layers with Superlinear attention layers
|
| 9 |
+
- Fine-tuned on long-context retrieval tasks
|
README.md
CHANGED
|
@@ -1,6 +1,243 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: other
|
| 3 |
-
license_name: nvidia-open-model-license
|
| 4 |
-
license_link:
|
| 5 |
-
|
| 6 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: other
|
| 3 |
+
license_name: nvidia-open-model-license
|
| 4 |
+
license_link: LICENSE
|
| 5 |
+
library_name: transformers
|
| 6 |
+
pipeline_tag: text-generation
|
| 7 |
+
tags:
|
| 8 |
+
- long-context
|
| 9 |
+
- superlinear-attention
|
| 10 |
+
- subquadratic
|
| 11 |
+
- causal-lm
|
| 12 |
+
base_model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# Superlinear-Exp-v0.1
|
| 16 |
+
|
| 17 |
+
**Superlinear Multi-Step Attention** — a subquadratic attention mechanism that preserves random context access (structural non-exclusion) for extremely long sequences.
|
| 18 |
+
|
| 19 |
+
This is an experimental release demonstrating the Superlinear attention architecture integrated into a modified [nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) hybrid model.
|
| 20 |
+
|
| 21 |
+
> **WARNING (Security):** This model requires `trust_remote_code=True`, which executes Python code from this repository. Review the code before running in sensitive environments.
|
| 22 |
+
|
| 23 |
+
## Model Description
|
| 24 |
+
|
| 25 |
+
Superlinear attention reformulates causal self-attention as a multi-step search problem:
|
| 26 |
+
|
| 27 |
+
1. **Accumulation** — Efficiently processes the sequence and produces per-position representatives (via Mamba-2 layers in the hybrid architecture).
|
| 28 |
+
2. **Span Search** — Scores a sublinear number of candidate spans using learned routing, then selects top-k spans per query.
|
| 29 |
+
3. **Span Attention** — Computes standard token-level attention within the selected contiguous spans.
|
| 30 |
+
4. **Combination** — Produces outputs using softmax-weighted gating over span attention outputs.
|
| 31 |
+
|
| 32 |
+
In the baseline N=2 implementation, both span-search and span-attention scale as **O(L^(3/2))**, enabling practical inference at multi-million-token context lengths where dense attention becomes prohibitive.
|
| 33 |
+
|
| 34 |
+
**Key property:** *Random context access* (structural non-exclusion) — any eligible token position can be selected by the content-dependent routing mechanism; no fixed sparsity pattern permanently excludes positions.
|
| 35 |
+
|
| 36 |
+
## Quickstart
|
| 37 |
+
|
| 38 |
+
```python
|
| 39 |
+
import torch
|
| 40 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 41 |
+
|
| 42 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 43 |
+
"concavity-ai/superlinear-exp-v0.1",
|
| 44 |
+
trust_remote_code=True
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 48 |
+
"concavity-ai/superlinear-exp-v0.1",
|
| 49 |
+
torch_dtype=torch.float16,
|
| 50 |
+
device_map="cuda",
|
| 51 |
+
trust_remote_code=True,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
messages = [{"role": "user", "content": "Explain the Transformer architecture."}]
|
| 55 |
+
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda")
|
| 56 |
+
|
| 57 |
+
output = model.generate(inputs, max_new_tokens=1000, do_sample=True, temperature=0.1, top_p=0.99)
|
| 58 |
+
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
## Dependencies
|
| 62 |
+
|
| 63 |
+
This model uses custom Python code (`trust_remote_code=True`) and CUDA extensions.
|
| 64 |
+
|
| 65 |
+
### Recommended: follow the Superlinear repo install
|
| 66 |
+
|
| 67 |
+
The simplest supported path is to use the installation flow from the Superlinear repo (it pins a known-good CUDA toolchain and builds `mamba-ssm[causal-conv1d]` from source to avoid wheel/ABI mismatches):
|
| 68 |
+
|
| 69 |
+
https://github.com/concavity-ai/superlinear#installation
|
| 70 |
+
|
| 71 |
+
Copy/paste one-liner (from the Superlinear repo root):
|
| 72 |
+
|
| 73 |
+
```bash
|
| 74 |
+
conda env create -f environment.yml \
|
| 75 |
+
&& conda run -n superlinear pip install torch --index-url https://download.pytorch.org/whl/cu128 \
|
| 76 |
+
&& conda run -n superlinear pip install -e ".[server,model]" \
|
| 77 |
+
&& conda run -n superlinear bash -lc 'CUDA_HOME="$CONDA_PREFIX" pip install "mamba-ssm[causal-conv1d]" --no-build-isolation --no-cache-dir --no-binary :all:'
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
### Optional: pip-only (if `mamba-ssm` already works in your env)
|
| 81 |
+
|
| 82 |
+
If `python -c "import mamba_ssm, causal_conv1d"` already succeeds in the environment you’ll run inference in, you already have a working PyTorch/CUDA pairing for the extension in that environment — you should not need to reinstall PyTorch.
|
| 83 |
+
|
| 84 |
+
Install the remaining Python deps + Superlinear:
|
| 85 |
+
|
| 86 |
+
```bash
|
| 87 |
+
pip install -U "transformers<5" accelerate safetensors
|
| 88 |
+
pip install -U vllm triton
|
| 89 |
+
|
| 90 |
+
# Superlinear kernels (span-attention)
|
| 91 |
+
pip install -U git+https://github.com/concavity-ai/superlinear.git
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
### Building `mamba-ssm` from source (only if needed)
|
| 95 |
+
|
| 96 |
+
If you must build `mamba-ssm[causal-conv1d]` yourself, you need a CUDA toolkit with `nvcc` and `CUDA_HOME` pointing at it (example: `/usr/local/cuda`):
|
| 97 |
+
|
| 98 |
+
```bash
|
| 99 |
+
CUDA_HOME=/usr/local/cuda \
|
| 100 |
+
pip install -U "mamba-ssm[causal-conv1d]" \
|
| 101 |
+
--no-build-isolation --no-cache-dir --no-binary :all:
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
## Recommended Inference Settings
|
| 105 |
+
|
| 106 |
+
For long-context inference with the Superlinear attention mechanism, use the following configuration:
|
| 107 |
+
|
| 108 |
+
```python
|
| 109 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 110 |
+
"concavity-ai/superlinear-exp-v0.1",
|
| 111 |
+
# Attention implementation
|
| 112 |
+
_attn_implementation='block-span-gqa',
|
| 113 |
+
decode_kernel='staged-gqa',
|
| 114 |
+
|
| 115 |
+
# Performance optimizations
|
| 116 |
+
enable_cuda_graph=True,
|
| 117 |
+
enable_shared_fused_moe=True,
|
| 118 |
+
|
| 119 |
+
# Superlinear attention hyperparameters
|
| 120 |
+
span_attention_sw_index=65, # Local window boundary index
|
| 121 |
+
span_attention_num_spans=3, # Top-k spans per query
|
| 122 |
+
span_attention_backward_factor=3, # Backward span extent multiplier
|
| 123 |
+
span_attention_forward_factor=1, # Forward span extent multiplier
|
| 124 |
+
span_attention_search_power=0.55, # Search exponent (controls anchor budget)
|
| 125 |
+
span_attention_span_power=0.55, # Span exponent (controls span scale)
|
| 126 |
+
|
| 127 |
+
torch_dtype=torch.float16,
|
| 128 |
+
device_map="cuda",
|
| 129 |
+
trust_remote_code=True,
|
| 130 |
+
)
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
### Hyperparameter Notes
|
| 134 |
+
|
| 135 |
+
| Parameter | Description | Typical Value |
|
| 136 |
+
|-----------|-------------|---------------|
|
| 137 |
+
| `span_attention_num_spans` | Number of routed spans selected per query (top-k) | 2 or 3 |
|
| 138 |
+
| `span_attention_backward_factor` | Backward extent of each span relative to base scale | 2–4 |
|
| 139 |
+
| `span_attention_forward_factor` | Forward extent of each span relative to base scale | 0–2 |
|
| 140 |
+
| `span_attention_search_power` | Exponent controlling the number of candidate anchors | 0.5–0.667 |
|
| 141 |
+
| `span_attention_span_power` | Exponent controlling span length scaling | 0.5–0.667 |
|
| 142 |
+
|
| 143 |
+
**Sliding window length from `span_attention_sw_index`:** Internally, the kernels compute the sliding-window length as:
|
| 144 |
+
|
| 145 |
+
```text
|
| 146 |
+
window_len = floor((sw_index + 1) ** (1 / search_power)) - 1
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
We parameterize the local sliding window using `sw_index` (a stride/stripe index) rather than specifying `window_len` directly. This keeps the sliding-window boundary aligned with the same index space used by span search, so span-search begins immediately after the sliding-window region and avoids gaps between local attention and routed spans.
|
| 150 |
+
|
| 151 |
+
Example: with `search_power=0.55` and `sw_index=65`,
|
| 152 |
+
|
| 153 |
+
```text
|
| 154 |
+
window_len = floor(66 ** (1 / 0.55)) - 1 = 2032
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
## Hardware Requirements
|
| 158 |
+
|
| 159 |
+
- **GPU:** NVIDIA GPU with sufficient VRAM (tested on B200 180GB)
|
| 160 |
+
- **KV Cache:** ~6GB per million tokens (model-dependent)
|
| 161 |
+
- **Precision:** FP16 recommended
|
| 162 |
+
|
| 163 |
+
### Measured Throughput (Single B200, Batch=1)
|
| 164 |
+
|
| 165 |
+
| Context Length | Prefill (tok/s) | Decode (tok/s) |
|
| 166 |
+
|----------------|-----------------|----------------|
|
| 167 |
+
| 1M tokens | ~20,202 | ~109 |
|
| 168 |
+
| 10M tokens | ~5,576 | ~76 |
|
| 169 |
+
|
| 170 |
+
*Your results may vary depending on hardware, batch size, and configuration.*
|
| 171 |
+
|
| 172 |
+
## Intended Use
|
| 173 |
+
|
| 174 |
+
This is an **architecture-and-systems feasibility study** release. It demonstrates that:
|
| 175 |
+
|
| 176 |
+
1. The Superlinear attention mechanism is structurally random-context-access-preserving
|
| 177 |
+
2. The architecture achieves asymptotically subquadratic attention complexity
|
| 178 |
+
3. The resulting irregular span pattern can be implemented with practical performance at very long context lengths
|
| 179 |
+
|
| 180 |
+
### Limitations
|
| 181 |
+
|
| 182 |
+
- **Not a comprehensive quality study:** We do not present extensive ablations or claim state-of-the-art accuracy on benchmarks.
|
| 183 |
+
- **Limited evaluation:** Initial validation focused on NIAH (Needle In A Haystack) retrieval task and throughput measurements.
|
| 184 |
+
- **Experimental:** This release is intended for research and experimentation, not production use.
|
| 185 |
+
- **Memory:** Full KV cache must be retained for random context access; memory usage scales with context length.
|
| 186 |
+
|
| 187 |
+
## What's in This Repository
|
| 188 |
+
|
| 189 |
+
```
|
| 190 |
+
├── config.json # Model configuration
|
| 191 |
+
├── generation_config.json # Default generation settings
|
| 192 |
+
├── tokenizer.json # Tokenizer
|
| 193 |
+
├── tokenizer_config.json
|
| 194 |
+
├── special_tokens_map.json
|
| 195 |
+
├── chat_template.jinja # Chat template
|
| 196 |
+
├── configuration_superlinear_exp.py # Custom config class
|
| 197 |
+
├── modeling_superlinear_exp.py # Custom model implementation
|
| 198 |
+
├── moe.py # MoE components
|
| 199 |
+
├── model-*.safetensors # Model weights (16 shards)
|
| 200 |
+
├── model.safetensors.index.json # Weight index
|
| 201 |
+
├── LICENSE # NVIDIA Open Model License
|
| 202 |
+
├── NOTICE # Required attribution
|
| 203 |
+
└── README.md # This file
|
| 204 |
+
```
|
| 205 |
+
|
| 206 |
+
## License
|
| 207 |
+
|
| 208 |
+
### Model Weights
|
| 209 |
+
|
| 210 |
+
This model is a derivative of [nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) and is distributed under the **NVIDIA Open Model License Agreement**.
|
| 211 |
+
|
| 212 |
+
See [LICENSE](LICENSE) for the full license text.
|
| 213 |
+
|
| 214 |
+
Use of this model must be consistent with [NVIDIA's Trustworthy AI terms](https://www.nvidia.com/en-us/agreements/trustworthy-ai/terms/).
|
| 215 |
+
|
| 216 |
+
### Code
|
| 217 |
+
|
| 218 |
+
The modeling code in this repository is provided for loading and running the model. For the broader Superlinear project codebase, see [github.com/concavity-ai/superlinear](https://github.com/concavity-ai/superlinear) (Apache-2.0).
|
| 219 |
+
|
| 220 |
+
## Attribution
|
| 221 |
+
|
| 222 |
+
**Upstream Model:**
|
| 223 |
+
- NVIDIA Nemotron-3-Nano-30B-A3B ([nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16))
|
| 224 |
+
|
| 225 |
+
**Paper:**
|
| 226 |
+
```bibtex
|
| 227 |
+
@article{huang2026superlinear,
|
| 228 |
+
title={Superlinear Multi-Step Attention},
|
| 229 |
+
author={Huang, Yufeng},
|
| 230 |
+
journal={arXiv preprint arXiv:2601.18401},
|
| 231 |
+
year={2026}
|
| 232 |
+
}
|
| 233 |
+
```
|
| 234 |
+
|
| 235 |
+
## Patent Notice
|
| 236 |
+
|
| 237 |
+
Patent applications have been filed related to aspects of the methods described in this work.
|
| 238 |
+
|
| 239 |
+
## Contact
|
| 240 |
+
|
| 241 |
+
- Author: Yufeng Huang
|
| 242 |
+
- Email: yufeng@concavity.ai
|
| 243 |
+
- Organization: Concavity AI
|
__pycache__/configuration_superlinear_exp.cpython-312.pyc
ADDED
|
Binary file (14.9 kB). View file
|
|
|
__pycache__/configuration_superlinear_exp.cpython-313.pyc
ADDED
|
Binary file (14.9 kB). View file
|
|
|
__pycache__/modeling_superlinear_exp.cpython-312.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:468f989a95d55eb2d38d742756830863d7bffb25e39fb9c304a200b4d9c63b2d
|
| 3 |
+
size 164921
|
__pycache__/modeling_superlinear_exp.cpython-313.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2dc47238f394aca2f6c7176e3fa059bc975b660170d0d7fc6e385e0d6f69750d
|
| 3 |
+
size 168645
|
__pycache__/moe.cpython-313.pyc
ADDED
|
Binary file (35.1 kB). View file
|
|
|
chat_template.jinja
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{% macro render_extra_keys(json_dict, handled_keys) %}
|
| 2 |
+
{%- if json_dict is mapping %}
|
| 3 |
+
{%- for json_key in json_dict if json_key not in handled_keys %}
|
| 4 |
+
{%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %}
|
| 5 |
+
{{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }}
|
| 6 |
+
{%- else %}
|
| 7 |
+
{{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}
|
| 8 |
+
{%- endif %}
|
| 9 |
+
{%- endfor %}
|
| 10 |
+
{%- endif %}
|
| 11 |
+
{% endmacro %}
|
| 12 |
+
{%- set enable_thinking = enable_thinking if enable_thinking is defined else True %}
|
| 13 |
+
{%- set truncate_history_thinking = truncate_history_thinking if truncate_history_thinking is defined else True %}
|
| 14 |
+
|
| 15 |
+
{%- set ns = namespace(last_user_idx = -1) %}
|
| 16 |
+
{%- set loop_messages = messages %}
|
| 17 |
+
{%- for m in loop_messages %}
|
| 18 |
+
{%- if m["role"] == "user" %}
|
| 19 |
+
{%- set ns.last_user_idx = loop.index0 %}
|
| 20 |
+
{%- endif %}
|
| 21 |
+
{%- endfor %}
|
| 22 |
+
|
| 23 |
+
{%- if messages[0]["role"] == "system" %}
|
| 24 |
+
{%- set system_message = messages[0]["content"] %}
|
| 25 |
+
{%- set loop_messages = messages[1:] %}
|
| 26 |
+
{%- else %}
|
| 27 |
+
{%- set system_message = "" %}
|
| 28 |
+
{%- set loop_messages = messages %}
|
| 29 |
+
{%- endif %}
|
| 30 |
+
{%- if not tools is defined %}
|
| 31 |
+
{%- set tools = [] %}
|
| 32 |
+
{%- endif %}
|
| 33 |
+
{# Recompute last_user_idx relative to loop_messages after handling system #}
|
| 34 |
+
{%- set ns = namespace(last_user_idx = -1) %}
|
| 35 |
+
{%- for m in loop_messages %}
|
| 36 |
+
{%- if m["role"] == "user" %}
|
| 37 |
+
{%- set ns.last_user_idx = loop.index0 %}
|
| 38 |
+
{%- endif %}
|
| 39 |
+
{%- endfor %}
|
| 40 |
+
{%- if system_message is defined %}
|
| 41 |
+
{{- "<|im_start|>system\n" + system_message }}
|
| 42 |
+
{%- else %}
|
| 43 |
+
{%- if tools is iterable and tools | length > 0 %}
|
| 44 |
+
{{- "<|im_start|>system\n" }}
|
| 45 |
+
{%- endif %}
|
| 46 |
+
{%- endif %}
|
| 47 |
+
{%- if tools is iterable and tools | length > 0 %}
|
| 48 |
+
{%- if system_message is defined and system_message | length > 0 %}
|
| 49 |
+
{{- "\n\n" }}
|
| 50 |
+
{%- endif %}
|
| 51 |
+
{{- "# Tools\n\nYou have access to the following functions:\n\n" }}
|
| 52 |
+
{{- "<tools>" }}
|
| 53 |
+
{%- for tool in tools %}
|
| 54 |
+
{%- if tool.function is defined %}
|
| 55 |
+
{%- set tool = tool.function %}
|
| 56 |
+
{%- endif %}
|
| 57 |
+
{{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }}
|
| 58 |
+
{%- if tool.description is defined %}
|
| 59 |
+
{{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }}
|
| 60 |
+
{%- endif %}
|
| 61 |
+
{{- '\n<parameters>' }}
|
| 62 |
+
{%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}
|
| 63 |
+
{%- for param_name, param_fields in tool.parameters.properties|items %}
|
| 64 |
+
{{- '\n<parameter>' }}
|
| 65 |
+
{{- '\n<name>' ~ param_name ~ '</name>' }}
|
| 66 |
+
{%- if param_fields.type is defined %}
|
| 67 |
+
{{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
|
| 68 |
+
{%- endif %}
|
| 69 |
+
{%- if param_fields.description is defined %}
|
| 70 |
+
{{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}
|
| 71 |
+
{%- endif %}
|
| 72 |
+
{%- if param_fields.enum is defined %}
|
| 73 |
+
{{- '\n<enum>' ~ (param_fields.enum | tojson | safe) ~ '</enum>' }}
|
| 74 |
+
{%- endif %}
|
| 75 |
+
{%- set handled_keys = ['name', 'type', 'description', 'enum'] %}
|
| 76 |
+
{{- render_extra_keys(param_fields, handled_keys) }}
|
| 77 |
+
{{- '\n</parameter>' }}
|
| 78 |
+
{%- endfor %}
|
| 79 |
+
{%- endif %}
|
| 80 |
+
{% set handled_keys = ['type', 'properties', 'required'] %}
|
| 81 |
+
{{- render_extra_keys(tool.parameters, handled_keys) }}
|
| 82 |
+
{%- if tool.parameters is defined and tool.parameters.required is defined %}
|
| 83 |
+
{{- '\n<required>' ~ (tool.parameters.required | tojson | safe) ~ '</required>' }}
|
| 84 |
+
{%- endif %}
|
| 85 |
+
{{- '\n</parameters>' }}
|
| 86 |
+
{%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}
|
| 87 |
+
{{- render_extra_keys(tool, handled_keys) }}
|
| 88 |
+
{{- '\n</function>' }}
|
| 89 |
+
{%- endfor %}
|
| 90 |
+
{{- "\n</tools>" }}
|
| 91 |
+
|
| 92 |
+
{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
|
| 93 |
+
{%- endif %}
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
{%- if system_message is defined %}
|
| 97 |
+
{{- '<|im_end|>\n' }}
|
| 98 |
+
{%- else %}
|
| 99 |
+
{%- if tools is iterable and tools | length > 0 %}
|
| 100 |
+
{{- '<|im_end|>\n' }}
|
| 101 |
+
{%- endif %}
|
| 102 |
+
{%- endif %}
|
| 103 |
+
|
| 104 |
+
{%- for message in loop_messages %}
|
| 105 |
+
{%- if message.role == "assistant" %}
|
| 106 |
+
{# Add reasoning content in to content field for unified processing below. #}
|
| 107 |
+
{%- if message.reasoning_content is defined and message.reasoning_content is string and message.reasoning_content | trim | length > 0 %}
|
| 108 |
+
{%- set content = "<think>\n" ~ message.reasoning_content ~ "\n</think>\n" ~ (message.content | default('', true)) %}
|
| 109 |
+
{%- else %}
|
| 110 |
+
{%- set content = message.content | default('', true) %}
|
| 111 |
+
{%- if content is string -%}
|
| 112 |
+
{# Allow downstream logic to to take care of broken thought, only handle coherent reasoning here. #}
|
| 113 |
+
{%- if '<think>' not in content and '</think>' not in content -%}
|
| 114 |
+
{%- set content = "<think></think>" ~ content -%}
|
| 115 |
+
{%- endif -%}
|
| 116 |
+
{%- else -%}
|
| 117 |
+
{%- set content = content -%}
|
| 118 |
+
{%- endif -%}
|
| 119 |
+
{%- endif %}
|
| 120 |
+
{%- if message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}
|
| 121 |
+
{# Assistant message has tool calls. #}
|
| 122 |
+
{{- '<|im_start|>assistant\n' }}
|
| 123 |
+
{%- set include_content = not (truncate_history_thinking and loop.index0 < ns.last_user_idx) %}
|
| 124 |
+
{%- if content is string and content | trim | length > 0 %}
|
| 125 |
+
{%- if include_content %}
|
| 126 |
+
{{- (content | trim) ~ '\n' -}}
|
| 127 |
+
{%- else %}
|
| 128 |
+
{%- set c = (content | string) %}
|
| 129 |
+
{%- if '</think>' in c %}
|
| 130 |
+
{# Keep only content after the last closing think. Also generation prompt causes this. #}
|
| 131 |
+
{%- set c = c.split('</think>')[-1] %}
|
| 132 |
+
{%- elif '<think>' in c %}
|
| 133 |
+
{# If <think> was opened but never closed, drop the trailing think segment #}
|
| 134 |
+
{%- set c = c.split('<think>')[0] %}
|
| 135 |
+
{%- endif %}
|
| 136 |
+
{%- set c = "<think></think>" ~ c | trim %}
|
| 137 |
+
{%- if c | length > 0 %}
|
| 138 |
+
{{- c ~ '\n' -}}
|
| 139 |
+
{%- endif %}
|
| 140 |
+
{%- endif %}
|
| 141 |
+
{%- else %}
|
| 142 |
+
{{- "<think></think>" -}}
|
| 143 |
+
{%- endif %}
|
| 144 |
+
{%- for tool_call in message.tool_calls %}
|
| 145 |
+
{%- if tool_call.function is defined %}
|
| 146 |
+
{%- set tool_call = tool_call.function %}
|
| 147 |
+
{%- endif %}
|
| 148 |
+
{{- '<tool_call>\n<function=' ~ tool_call.name ~ '>\n' -}}
|
| 149 |
+
{%- if tool_call.arguments is defined %}
|
| 150 |
+
{%- for args_name, args_value in tool_call.arguments|items %}
|
| 151 |
+
{{- '<parameter=' ~ args_name ~ '>\n' -}}
|
| 152 |
+
{%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
|
| 153 |
+
{{- args_value ~ '\n</parameter>\n' -}}
|
| 154 |
+
{%- endfor %}
|
| 155 |
+
{%- endif %}
|
| 156 |
+
{{- '</function>\n</tool_call>\n' -}}
|
| 157 |
+
{%- endfor %}
|
| 158 |
+
{{- '<|im_end|>\n' }}
|
| 159 |
+
{%- else %}
|
| 160 |
+
{# Assistant message doesn't have tool calls. #}
|
| 161 |
+
{%- if not (truncate_history_thinking and loop.index0 < ns.last_user_idx) %}
|
| 162 |
+
{{- '<|im_start|>assistant\n' ~ (content | default('', true) | string | trim) ~ '<|im_end|>\n' }}
|
| 163 |
+
{%- else %}
|
| 164 |
+
{%- set c = (content | default('', true) | string) %}
|
| 165 |
+
{%- if '<think>' in c and '</think>' in c %}
|
| 166 |
+
{%- set c = "<think></think>" ~ c.split('</think>')[-1] %}
|
| 167 |
+
{%- endif %}
|
| 168 |
+
{%- set c = c | trim %}
|
| 169 |
+
{%- if c | length > 0 %}
|
| 170 |
+
{{- '<|im_start|>assistant\n' ~ c ~ '<|im_end|>\n' }}
|
| 171 |
+
{%- else %}
|
| 172 |
+
{{- '<|im_start|>assistant\n<|im_end|>\n' }}
|
| 173 |
+
{%- endif %}
|
| 174 |
+
{%- endif %}
|
| 175 |
+
{%- endif %}
|
| 176 |
+
{%- elif message.role == "user" or message.role == "system" %}
|
| 177 |
+
{{- '<|im_start|>' + message.role + '\n' }}
|
| 178 |
+
{%- set content = message.content | string %}
|
| 179 |
+
{{- content }}
|
| 180 |
+
{{- '<|im_end|>\n' }}
|
| 181 |
+
{%- elif message.role == "tool" %}
|
| 182 |
+
{%- if loop.previtem and loop.previtem.role != "tool" %}
|
| 183 |
+
{{- '<|im_start|>user\n' }}
|
| 184 |
+
{%- endif %}
|
| 185 |
+
{{- '<tool_response>\n' }}
|
| 186 |
+
{{- message.content }}
|
| 187 |
+
{{- '\n</tool_response>\n' }}
|
| 188 |
+
{%- if not loop.last and loop.nextitem.role != "tool" %}
|
| 189 |
+
{{- '<|im_end|>\n' }}
|
| 190 |
+
{%- elif loop.last %}
|
| 191 |
+
{{- '<|im_end|>\n' }}
|
| 192 |
+
{%- endif %}
|
| 193 |
+
{%- else %}
|
| 194 |
+
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }}
|
| 195 |
+
{%- endif %}
|
| 196 |
+
{%- endfor %}
|
| 197 |
+
|
| 198 |
+
{%- if add_generation_prompt %}
|
| 199 |
+
{%- if enable_thinking %}
|
| 200 |
+
{{- '<|im_start|>assistant\n<think>\n' }}
|
| 201 |
+
{%- else %}
|
| 202 |
+
{{- '<|im_start|>assistant\n<think></think>' }}
|
| 203 |
+
{%- endif %}
|
| 204 |
+
{%- endif %}
|
config.json
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"SuperlinearExpForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_bias": false,
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"auto_map": {
|
| 8 |
+
"AutoConfig": "configuration_superlinear_exp.SuperlinearExpConfig",
|
| 9 |
+
"AutoModel": "modeling_superlinear_exp.SuperlinearExpForCausalLM",
|
| 10 |
+
"AutoModelForCausalLM": "modeling_superlinear_exp.SuperlinearExpForCausalLM"
|
| 11 |
+
},
|
| 12 |
+
"bos_token_id": 1,
|
| 13 |
+
"chunk_size": 128,
|
| 14 |
+
"conv_kernel": 4,
|
| 15 |
+
"decode_kernel": "staged-gqa",
|
| 16 |
+
"enable_cuda_graph": true,
|
| 17 |
+
"enable_shared_fused_moe": true,
|
| 18 |
+
"eos_token_id": 2,
|
| 19 |
+
"expand": 2,
|
| 20 |
+
"head_dim": 128,
|
| 21 |
+
"hidden_dropout": 0.0,
|
| 22 |
+
"hidden_size": 2688,
|
| 23 |
+
"hybrid_override_pattern": "MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME",
|
| 24 |
+
"initializer_range": 0.02,
|
| 25 |
+
"intermediate_size": 1856,
|
| 26 |
+
"layer_norm_epsilon": 1e-05,
|
| 27 |
+
"mamba_head_dim": 64,
|
| 28 |
+
"mamba_hidden_act": "silu",
|
| 29 |
+
"mamba_num_heads": 64,
|
| 30 |
+
"mamba_proj_bias": false,
|
| 31 |
+
"mamba_ssm_cache_dtype": "float32",
|
| 32 |
+
"max_position_embeddings": 262144,
|
| 33 |
+
"mlp_bias": false,
|
| 34 |
+
"mlp_hidden_act": "relu2",
|
| 35 |
+
"model_type": "superlinear-exp",
|
| 36 |
+
"moe_intermediate_size": 1856,
|
| 37 |
+
"moe_shared_expert_intermediate_size": 3712,
|
| 38 |
+
"n_group": 1,
|
| 39 |
+
"n_groups": 8,
|
| 40 |
+
"n_routed_experts": 128,
|
| 41 |
+
"n_shared_experts": 1,
|
| 42 |
+
"norm_eps": 1e-05,
|
| 43 |
+
"norm_topk_prob": true,
|
| 44 |
+
"num_attention_heads": 32,
|
| 45 |
+
"num_experts_per_tok": 6,
|
| 46 |
+
"num_hidden_layers": 52,
|
| 47 |
+
"num_key_value_heads": 2,
|
| 48 |
+
"num_logits_to_keep": 1,
|
| 49 |
+
"pad_token_id": 0,
|
| 50 |
+
"partial_rotary_factor": 1.0,
|
| 51 |
+
"rescale_prenorm_residual": true,
|
| 52 |
+
"residual_in_fp32": false,
|
| 53 |
+
"rope_theta": 10000,
|
| 54 |
+
"routed_scaling_factor": 2.5,
|
| 55 |
+
"sliding_window": null,
|
| 56 |
+
"span_attention_backward_factor": 3.0,
|
| 57 |
+
"span_attention_forward_factor": 1.0,
|
| 58 |
+
"span_attention_inv_search_power_int": null,
|
| 59 |
+
"span_attention_num_spans": 3,
|
| 60 |
+
"span_attention_search_power": 0.55,
|
| 61 |
+
"span_attention_span_power": 0.55,
|
| 62 |
+
"span_attention_sw_index": 65,
|
| 63 |
+
"ssm_state_size": 128,
|
| 64 |
+
"tie_word_embeddings": false,
|
| 65 |
+
"time_step_floor": 0.0001,
|
| 66 |
+
"time_step_limit": [
|
| 67 |
+
0.0,
|
| 68 |
+
Infinity
|
| 69 |
+
],
|
| 70 |
+
"time_step_max": 0.1,
|
| 71 |
+
"time_step_min": 0.001,
|
| 72 |
+
"topk_group": 1,
|
| 73 |
+
"torch_dtype": "bfloat16",
|
| 74 |
+
"transformers_version": "4.55.4",
|
| 75 |
+
"use_bias": false,
|
| 76 |
+
"use_cache": true,
|
| 77 |
+
"use_conv_bias": true,
|
| 78 |
+
"use_mamba_kernels": true,
|
| 79 |
+
"vocab_size": 131072
|
| 80 |
+
}
|
configuration_superlinear_exp.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""SuperlinearExp model configuration"""
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
import re
|
| 20 |
+
|
| 21 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 22 |
+
from transformers.utils import logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class SuperlinearExpConfig(PretrainedConfig):
|
| 29 |
+
r"""
|
| 30 |
+
This is the configuration class to store the configuration of a [`SuperlinearExpModel`]. It is used to instantiate a
|
| 31 |
+
SuperlinearExp model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 32 |
+
with the defaults will yield a similar configuration to that of the SuperlinearExp-v0.1 model.
|
| 33 |
+
[todo](todo)
|
| 34 |
+
|
| 35 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 36 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
vocab_size (`int`, *optional*, defaults to 131072):
|
| 41 |
+
Vocabulary size of the SuperlinearExp model. Defines the number of different tokens that can be represented by the
|
| 42 |
+
`inputs_ids` passed when calling [`SuperlinearExpModel`]
|
| 43 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 44 |
+
Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
|
| 45 |
+
model has a output word embedding layer.
|
| 46 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
| 47 |
+
Dimension of the hidden representations.
|
| 48 |
+
intermediate_size (`int`, *optional*, defaults to 21504):
|
| 49 |
+
Dimension of the MLP representations.
|
| 50 |
+
num_hidden_layers (`int`, *optional*, defaults to 52):
|
| 51 |
+
Number of hidden layers in the Transformer encoder.
|
| 52 |
+
hybrid_override_pattern (`str`, *optional*, defaults to `"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"`):
|
| 53 |
+
The pattern of the hybrid model. The pattern is a string of characters where each character represents M: Mamba2, *: Attention, -: MLP
|
| 54 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 55 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 56 |
+
head_dim (`int`, *optional*, defaults to 128):
|
| 57 |
+
Dimension of each attention head.
|
| 58 |
+
num_key_value_heads (`int`, *optional*, defaults to 8):
|
| 59 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 60 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 61 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used.
|
| 62 |
+
mlp_hidden_act (`str`, *optional*, defaults to "relu2"):
|
| 63 |
+
The non-linear activation function in the MLP layers.
|
| 64 |
+
attention_bias (`bool`, *optional*, defaults to `False`):
|
| 65 |
+
Whether to use bias in attention layers.
|
| 66 |
+
mlp_bias (`bool`, *optional*, defaults to `False`):
|
| 67 |
+
Whether to use bias in MLP layers.
|
| 68 |
+
use_bias (`bool`, *optional*, defaults to `False`):
|
| 69 |
+
Whether to use bias in the model.
|
| 70 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 71 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 72 |
+
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
|
| 73 |
+
The epsilon used by the layer normalization layers.
|
| 74 |
+
residual_in_fp32 (`bool`, *optional*, defaults to `False`):
|
| 75 |
+
Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model.
|
| 76 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 77 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 78 |
+
relevant if `config.is_decoder=True`.
|
| 79 |
+
num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
|
| 80 |
+
Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
|
| 81 |
+
integer value, only last `num_logits_to_keep` logits will be calculated.
|
| 82 |
+
pad_token_id (`int`, *optional*, defaults to 0):
|
| 83 |
+
The id of the padding token.
|
| 84 |
+
bos_token_id (`int`, *optional*, defaults to 1):
|
| 85 |
+
The id of the "beginning-of-sequence" token.
|
| 86 |
+
eos_token_id (`int`, *optional*, defaults to 2):
|
| 87 |
+
The id of the "end-of-sequence" token.
|
| 88 |
+
sliding_window (`int`, *optional*, defaults to None):
|
| 89 |
+
Sliding window attention window size.
|
| 90 |
+
max_position_embeddings (`int`, *optional*, defaults to 4096):
|
| 91 |
+
The maximum sequence length that this model might ever be used with.
|
| 92 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 93 |
+
The dropout ratio for the attention probabilities.
|
| 94 |
+
hidden_dropout (`float`, *optional*, defaults to 0.0):
|
| 95 |
+
The dropout ratio for the hidden states.
|
| 96 |
+
use_mamba_kernels (`bool`, *optional*, defaults to `True`):
|
| 97 |
+
Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
|
| 98 |
+
`causal-conv1d` are installed, and the mamba modules are running on a CUDA device.
|
| 99 |
+
ssm_state_size (`int`, *optional*, defaults to 128):
|
| 100 |
+
The dimension of the mamba state space latents.
|
| 101 |
+
mamba_num_heads (`int`, *optional*, defaults to 128):
|
| 102 |
+
Number of heads in Mamba layers.
|
| 103 |
+
mamba_n_groups (`int`, *optional*, defaults to 8):
|
| 104 |
+
Number of groups in Mamba layers.
|
| 105 |
+
mamba_head_dim (`int`, *optional*, defaults to 64):
|
| 106 |
+
Dimension of each Mamba head.
|
| 107 |
+
mamba_d_conv (`int`, *optional*, defaults to 4):
|
| 108 |
+
The size of the mamba convolution kernel.
|
| 109 |
+
mamba_expand (`int`, *optional*, defaults to 2):
|
| 110 |
+
Expanding factor used to determine the mamba intermediate size.
|
| 111 |
+
mamba_hidden_act (`str`, *optional*, defaults to "silu"):
|
| 112 |
+
The non-linear activation function in the Mamba layers.
|
| 113 |
+
mamba_dt_min (`float`, *optional*, defaults to 0.001):
|
| 114 |
+
Minimum value for the time step in Mamba.
|
| 115 |
+
mamba_dt_max (`float`, *optional*, defaults to 0.1):
|
| 116 |
+
Maximum value for the time step in Mamba.
|
| 117 |
+
mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float("inf"))):
|
| 118 |
+
Limits for the time step in Mamba.
|
| 119 |
+
mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4):
|
| 120 |
+
Floor value for time step initialization in Mamba.
|
| 121 |
+
mamba_conv_bias (`bool`, *optional*, defaults to `True`):
|
| 122 |
+
Whether to use bias in the convolution layer of the mamba mixer block.
|
| 123 |
+
mamba_proj_bias (`bool`, *optional*, defaults to `False`):
|
| 124 |
+
Whether to use bias in the input and output projections of the mamba mixer block.
|
| 125 |
+
mamba_chunk_size (`int`, *optional*, defaults to 256):
|
| 126 |
+
Size of chunks for Mamba processing.
|
| 127 |
+
rescale_prenorm_residual (`bool`, *optional*, defaults to `True`):
|
| 128 |
+
Whether to rescale the pre-normalization residual connections.
|
| 129 |
+
"""
|
| 130 |
+
|
| 131 |
+
model_type = "superlinear-exp"
|
| 132 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 133 |
+
|
| 134 |
+
def __init__(
|
| 135 |
+
self,
|
| 136 |
+
vocab_size=131072,
|
| 137 |
+
tie_word_embeddings=False,
|
| 138 |
+
hidden_size=4096,
|
| 139 |
+
intermediate_size=21504,
|
| 140 |
+
num_hidden_layers=52,
|
| 141 |
+
hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
|
| 142 |
+
num_attention_heads=32,
|
| 143 |
+
head_dim=128,
|
| 144 |
+
num_key_value_heads=8, # nemo: num_query_groups
|
| 145 |
+
mlp_hidden_act="relu2",
|
| 146 |
+
attention_bias=False,
|
| 147 |
+
mlp_bias=False,
|
| 148 |
+
use_bias=False,
|
| 149 |
+
initializer_range=0.02, # nemo: init_method_std
|
| 150 |
+
layer_norm_epsilon=1e-5, # nemo: layernorm_epsilon
|
| 151 |
+
residual_in_fp32=False, # Megatron Core default value
|
| 152 |
+
use_cache=True,
|
| 153 |
+
num_logits_to_keep=1,
|
| 154 |
+
pad_token_id=0,
|
| 155 |
+
bos_token_id=1,
|
| 156 |
+
eos_token_id=2,
|
| 157 |
+
sliding_window=None,
|
| 158 |
+
max_position_embeddings=4096,
|
| 159 |
+
attention_dropout=0.0,
|
| 160 |
+
hidden_dropout=0.0, # * ADDED
|
| 161 |
+
use_mamba_kernels=True,
|
| 162 |
+
ssm_state_size=128, # mamba_state_size
|
| 163 |
+
mamba_num_heads=128,
|
| 164 |
+
mamba_n_groups=8, # nemo: mamba_ssm_ngroups = num_heads
|
| 165 |
+
mamba_head_dim=64,
|
| 166 |
+
mamba_d_conv=4,
|
| 167 |
+
mamba_expand=2,
|
| 168 |
+
mamba_hidden_act="silu",
|
| 169 |
+
mamba_dt_min=0.001,
|
| 170 |
+
mamba_dt_max=0.1,
|
| 171 |
+
mamba_dt_limit=(0.0, float("inf")),
|
| 172 |
+
mamba_dt_init_floor=1e-4,
|
| 173 |
+
mamba_conv_bias=True,
|
| 174 |
+
mamba_proj_bias=False,
|
| 175 |
+
mamba_chunk_size=128,
|
| 176 |
+
rescale_prenorm_residual=True,
|
| 177 |
+
span_attention_sw_index=65,
|
| 178 |
+
span_attention_num_spans=3,
|
| 179 |
+
span_attention_backward_factor: float = 3.0,
|
| 180 |
+
span_attention_forward_factor: float = 1.0,
|
| 181 |
+
span_attention_span_power: float = 0.55,
|
| 182 |
+
span_attention_search_power: float | None = 0.55,
|
| 183 |
+
span_attention_inv_search_power_int: int | None = None,
|
| 184 |
+
decode_kernel="staged-gqa",
|
| 185 |
+
enable_cuda_graph: bool = True,
|
| 186 |
+
enable_shared_fused_moe: bool = True,
|
| 187 |
+
n_routed_experts=8,
|
| 188 |
+
n_shared_experts=1,
|
| 189 |
+
moe_intermediate_size=7688,
|
| 190 |
+
moe_shared_expert_intermediate_size=7688,
|
| 191 |
+
num_experts_per_tok=2,
|
| 192 |
+
routed_scaling_factor=1.0,
|
| 193 |
+
n_group=1,
|
| 194 |
+
topk_group=1,
|
| 195 |
+
norm_topk_prob=True,
|
| 196 |
+
**kwargs,
|
| 197 |
+
):
|
| 198 |
+
self.vocab_size = vocab_size
|
| 199 |
+
self.tie_word_embeddings = tie_word_embeddings
|
| 200 |
+
self.hidden_size = hidden_size
|
| 201 |
+
self.intermediate_size = intermediate_size
|
| 202 |
+
self.num_hidden_layers = num_hidden_layers
|
| 203 |
+
self.hybrid_override_pattern = hybrid_override_pattern
|
| 204 |
+
self.num_attention_heads = num_attention_heads
|
| 205 |
+
self.head_dim = head_dim
|
| 206 |
+
self.sliding_window = sliding_window
|
| 207 |
+
self.max_position_embeddings = max_position_embeddings
|
| 208 |
+
self.attention_dropout = attention_dropout
|
| 209 |
+
self.hidden_dropout = hidden_dropout
|
| 210 |
+
|
| 211 |
+
# Span attention configuration
|
| 212 |
+
self.span_attention_sw_index = span_attention_sw_index
|
| 213 |
+
self.span_attention_num_spans = span_attention_num_spans
|
| 214 |
+
self.span_attention_backward_factor = float(span_attention_backward_factor)
|
| 215 |
+
self.span_attention_forward_factor = float(span_attention_forward_factor)
|
| 216 |
+
self.span_attention_span_power = float(span_attention_span_power)
|
| 217 |
+
if not math.isfinite(self.span_attention_backward_factor) or self.span_attention_backward_factor < 0.0:
|
| 218 |
+
raise ValueError(
|
| 219 |
+
"span_attention_backward_factor must be finite and >= 0 "
|
| 220 |
+
f"(got {self.span_attention_backward_factor})"
|
| 221 |
+
)
|
| 222 |
+
if not math.isfinite(self.span_attention_forward_factor) or self.span_attention_forward_factor < 0.0:
|
| 223 |
+
raise ValueError(
|
| 224 |
+
"span_attention_forward_factor must be finite and >= 0 "
|
| 225 |
+
f"(got {self.span_attention_forward_factor})"
|
| 226 |
+
)
|
| 227 |
+
if (self.span_attention_backward_factor + self.span_attention_forward_factor) <= 0.0:
|
| 228 |
+
raise ValueError(
|
| 229 |
+
"span_attention_backward_factor + span_attention_forward_factor must be > 0 "
|
| 230 |
+
f"(got {self.span_attention_backward_factor + self.span_attention_forward_factor})"
|
| 231 |
+
)
|
| 232 |
+
if not math.isfinite(self.span_attention_span_power) or not (0.0 < self.span_attention_span_power < 1.0):
|
| 233 |
+
raise ValueError(
|
| 234 |
+
"span_attention_span_power must be finite and in (0, 1) "
|
| 235 |
+
f"(got {self.span_attention_span_power})"
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
# Stripe power parameters (search stripes + sliding window width).
|
| 239 |
+
if (span_attention_inv_search_power_int is None) == (span_attention_search_power is None):
|
| 240 |
+
raise ValueError(
|
| 241 |
+
"Provide exactly one of span_attention_inv_search_power_int or span_attention_search_power"
|
| 242 |
+
)
|
| 243 |
+
if span_attention_inv_search_power_int is not None:
|
| 244 |
+
inv_n = int(span_attention_inv_search_power_int)
|
| 245 |
+
if inv_n not in (2, 3, 4, 5, 6):
|
| 246 |
+
raise ValueError(
|
| 247 |
+
"span_attention_inv_search_power_int must be one of (2, 3, 4, 5, 6) "
|
| 248 |
+
f"(got {span_attention_inv_search_power_int})"
|
| 249 |
+
)
|
| 250 |
+
self.span_attention_inv_search_power_int = inv_n
|
| 251 |
+
self.span_attention_search_power = None
|
| 252 |
+
derived_p = 1.0 / float(inv_n)
|
| 253 |
+
else:
|
| 254 |
+
p = float(span_attention_search_power)
|
| 255 |
+
if not math.isfinite(p) or not (0.0 < p < 1.0):
|
| 256 |
+
raise ValueError(
|
| 257 |
+
"span_attention_search_power must be finite and in (0, 1) "
|
| 258 |
+
f"(got {span_attention_search_power})"
|
| 259 |
+
)
|
| 260 |
+
self.span_attention_inv_search_power_int = None
|
| 261 |
+
self.span_attention_search_power = p
|
| 262 |
+
derived_p = p
|
| 263 |
+
|
| 264 |
+
# Critical coverage constraint (see 47.1_generalized_stripe_power_design.md):
|
| 265 |
+
# span_power >= 1 - search_power.
|
| 266 |
+
if self.span_attention_span_power + derived_p < 1.0:
|
| 267 |
+
raise ValueError(
|
| 268 |
+
"span_attention_span_power must satisfy span_power >= 1 - search_power "
|
| 269 |
+
f"(got span_power={self.span_attention_span_power}, search_power={derived_p})"
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
if decode_kernel not in (None, "staged", "staged-gqa"):
|
| 273 |
+
raise ValueError(
|
| 274 |
+
f"Invalid decode_kernel={decode_kernel!r}; expected one of None, 'staged', 'staged-gqa'."
|
| 275 |
+
)
|
| 276 |
+
self.decode_kernel = decode_kernel
|
| 277 |
+
self.enable_cuda_graph = enable_cuda_graph
|
| 278 |
+
self.enable_shared_fused_moe = enable_shared_fused_moe
|
| 279 |
+
|
| 280 |
+
# Validate hybrid_override_pattern
|
| 281 |
+
# M: Mamba2, *: Attention, -: MLP
|
| 282 |
+
assert len(self.hybrid_override_pattern) == self.num_hidden_layers, "hybrid_override_pattern must have the same length as num_hidden_layers"
|
| 283 |
+
assert re.match(r"^[*-M]+$", self.hybrid_override_pattern), "hybrid_override_pattern must only contain characters 'M', '*', or '-'"
|
| 284 |
+
|
| 285 |
+
# for backward compatibility
|
| 286 |
+
if num_key_value_heads is None:
|
| 287 |
+
num_key_value_heads = num_attention_heads
|
| 288 |
+
|
| 289 |
+
self.num_key_value_heads = num_key_value_heads
|
| 290 |
+
self.mlp_hidden_act = mlp_hidden_act
|
| 291 |
+
self.attention_bias = attention_bias
|
| 292 |
+
self.mlp_bias = mlp_bias
|
| 293 |
+
self.use_bias = use_bias
|
| 294 |
+
self.initializer_range = initializer_range
|
| 295 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
| 296 |
+
self.residual_in_fp32 = residual_in_fp32
|
| 297 |
+
|
| 298 |
+
self.use_cache = use_cache
|
| 299 |
+
self.num_logits_to_keep = num_logits_to_keep
|
| 300 |
+
|
| 301 |
+
self.use_mamba_kernels = use_mamba_kernels
|
| 302 |
+
self.n_groups = mamba_n_groups
|
| 303 |
+
self.mamba_head_dim = mamba_head_dim
|
| 304 |
+
self.ssm_state_size = ssm_state_size
|
| 305 |
+
self.mamba_num_heads = mamba_num_heads
|
| 306 |
+
self.conv_kernel = mamba_d_conv
|
| 307 |
+
self.expand = mamba_expand
|
| 308 |
+
self.mamba_hidden_act = mamba_hidden_act
|
| 309 |
+
self.time_step_min = mamba_dt_min
|
| 310 |
+
self.time_step_max = mamba_dt_max
|
| 311 |
+
self.time_step_limit = mamba_dt_limit
|
| 312 |
+
self.time_step_floor = mamba_dt_init_floor
|
| 313 |
+
self.use_conv_bias = mamba_conv_bias
|
| 314 |
+
self.mamba_proj_bias = mamba_proj_bias
|
| 315 |
+
self.chunk_size = mamba_chunk_size
|
| 316 |
+
self.rescale_prenorm_residual = rescale_prenorm_residual
|
| 317 |
+
self.n_routed_experts = n_routed_experts
|
| 318 |
+
self.n_shared_experts = n_shared_experts
|
| 319 |
+
self.moe_intermediate_size = moe_intermediate_size
|
| 320 |
+
self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size
|
| 321 |
+
self.num_experts_per_tok = num_experts_per_tok
|
| 322 |
+
self.routed_scaling_factor = routed_scaling_factor
|
| 323 |
+
self.n_group = n_group
|
| 324 |
+
self.topk_group = topk_group
|
| 325 |
+
self.norm_topk_prob = norm_topk_prob
|
| 326 |
+
|
| 327 |
+
super().__init__(
|
| 328 |
+
pad_token_id=pad_token_id,
|
| 329 |
+
bos_token_id=bos_token_id,
|
| 330 |
+
eos_token_id=eos_token_id,
|
| 331 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 332 |
+
**kwargs,
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
@property
|
| 336 |
+
def layers_block_type(self):
|
| 337 |
+
return [
|
| 338 |
+
"mamba" if self.hybrid_override_pattern[i] == "M" else
|
| 339 |
+
"attention" if self.hybrid_override_pattern[i] == "*" else
|
| 340 |
+
"mlp" if self.hybrid_override_pattern[i] == "-" else "moe"
|
| 341 |
+
for i in range(self.num_hidden_layers)]
|
generation_config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"do_sample": true,
|
| 5 |
+
"eos_token_id": [
|
| 6 |
+
2,
|
| 7 |
+
11
|
| 8 |
+
],
|
| 9 |
+
"pad_token_id": 0,
|
| 10 |
+
"transformers_version": "4.57.3"
|
| 11 |
+
}
|
model-00001-of-00016.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:97abe5c9503b8f60e502d509898c658fb8c366b32308463e6fc2cd22d9533973
|
| 3 |
+
size 4654236816
|
model-00002-of-00016.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:593e019ec60819249448c91e647abff72d9171a212d1664c0957dd9b2cb94bad
|
| 3 |
+
size 4136509104
|
model-00003-of-00016.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:595c430f59919f62e1c3ba70596e853f004bb20710cbbe5b679ba8b22ef4c720
|
| 3 |
+
size 3949593672
|
model-00004-of-00016.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4fb6e9c083f14e67e0a8c6f1c90dd73d77fd0dadbc3a76203ead3f5e3b67cf91
|
| 3 |
+
size 4213999864
|
model-00005-of-00016.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:40bc0337b6b2f619e9641278ed97058cc099ca001d4d3ff80de0d97a886b5411
|
| 3 |
+
size 3949593672
|
model-00006-of-00016.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6104484169ae2ffa3b919187a176e762414dfe48a6d858461db935278441ffcd
|
| 3 |
+
size 4136509104
|
model-00007-of-00016.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d0bb006d8d803aa869ad584a8c6e5926b398f0c04ddc1740d842d3d6a369f8ee
|
| 3 |
+
size 3872102896
|
model-00008-of-00016.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:485cf181874c6b9cb817407bab34b8334e5515329323f9683d7e93b25679c331
|
| 3 |
+
size 4136509096
|
model-00009-of-00016.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2b49942cf27c63bec56f5484b113c50442735a3c5885890b43f54cb0509ee555
|
| 3 |
+
size 3949593672
|
model-00010-of-00016.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4cd5c8dfe3bba15d056458fca0413511f576c3fae10b9f2daae88d247b17db1b
|
| 3 |
+
size 4145181008
|
model-00011-of-00016.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:71e77335c4a7af5d8859a2108039a4f30bebcaf278ebd4b56f9dd8f403495abb
|
| 3 |
+
size 4018412536
|
model-00012-of-00016.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ad603897bb6b7c951bd69e9f89009521909163c8a6195c4fdab00b545fd99c03
|
| 3 |
+
size 4067690240
|
model-00013-of-00016.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e79a65b74d9ad6261748131aa86a7d84833bdf5d7609816219e4c7d5bd4c6ec0
|
| 3 |
+
size 3949593672
|
model-00014-of-00016.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8f97e980088587a0b273d4709aad56b9e33b31e6887b483941f67fa06c0a2ee4
|
| 3 |
+
size 4059018320
|
model-00015-of-00016.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:684690b0a3770d79e2466e26fc5d75930572dea5e3b1e42e3aeaff49bcbb5566
|
| 3 |
+
size 3949593656
|
model-00016-of-00016.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:96d2eda64709c79a4d25ebc2d96d8abe118bbda0150f7874add27cf9695db7f2
|
| 3 |
+
size 2099910896
|
model.safetensors.index.json
ADDED
|
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_size": 63288001152
|
| 4 |
+
},
|
| 5 |
+
"weight_map": {
|
| 6 |
+
"backbone.embeddings.weight": "model-00001-of-00016.safetensors",
|
| 7 |
+
"backbone.layers.0.mixer.A_log": "model-00001-of-00016.safetensors",
|
| 8 |
+
"backbone.layers.0.mixer.D": "model-00001-of-00016.safetensors",
|
| 9 |
+
"backbone.layers.0.mixer.conv1d.bias": "model-00001-of-00016.safetensors",
|
| 10 |
+
"backbone.layers.0.mixer.conv1d.weight": "model-00001-of-00016.safetensors",
|
| 11 |
+
"backbone.layers.0.mixer.dt_bias": "model-00001-of-00016.safetensors",
|
| 12 |
+
"backbone.layers.0.mixer.in_proj.weight": "model-00001-of-00016.safetensors",
|
| 13 |
+
"backbone.layers.0.mixer.norm.weight": "model-00001-of-00016.safetensors",
|
| 14 |
+
"backbone.layers.0.mixer.out_proj.weight": "model-00001-of-00016.safetensors",
|
| 15 |
+
"backbone.layers.0.norm.weight": "model-00001-of-00016.safetensors",
|
| 16 |
+
"backbone.layers.1.mixer.experts.down_proj.weight": "model-00001-of-00016.safetensors",
|
| 17 |
+
"backbone.layers.1.mixer.experts.up_proj.weight": "model-00001-of-00016.safetensors",
|
| 18 |
+
"backbone.layers.1.mixer.gate.e_score_correction_bias": "model-00001-of-00016.safetensors",
|
| 19 |
+
"backbone.layers.1.mixer.gate.weight": "model-00001-of-00016.safetensors",
|
| 20 |
+
"backbone.layers.1.mixer.shared_experts.down_proj.weight": "model-00001-of-00016.safetensors",
|
| 21 |
+
"backbone.layers.1.mixer.shared_experts.up_proj.weight": "model-00001-of-00016.safetensors",
|
| 22 |
+
"backbone.layers.1.norm.weight": "model-00001-of-00016.safetensors",
|
| 23 |
+
"backbone.layers.10.mixer.experts.down_proj.weight": "model-00001-of-00016.safetensors",
|
| 24 |
+
"backbone.layers.10.mixer.experts.up_proj.weight": "model-00002-of-00016.safetensors",
|
| 25 |
+
"backbone.layers.10.mixer.gate.e_score_correction_bias": "model-00002-of-00016.safetensors",
|
| 26 |
+
"backbone.layers.10.mixer.gate.weight": "model-00002-of-00016.safetensors",
|
| 27 |
+
"backbone.layers.10.mixer.shared_experts.down_proj.weight": "model-00002-of-00016.safetensors",
|
| 28 |
+
"backbone.layers.10.mixer.shared_experts.up_proj.weight": "model-00002-of-00016.safetensors",
|
| 29 |
+
"backbone.layers.10.norm.weight": "model-00002-of-00016.safetensors",
|
| 30 |
+
"backbone.layers.11.mixer.A_log": "model-00002-of-00016.safetensors",
|
| 31 |
+
"backbone.layers.11.mixer.D": "model-00002-of-00016.safetensors",
|
| 32 |
+
"backbone.layers.11.mixer.conv1d.bias": "model-00002-of-00016.safetensors",
|
| 33 |
+
"backbone.layers.11.mixer.conv1d.weight": "model-00002-of-00016.safetensors",
|
| 34 |
+
"backbone.layers.11.mixer.dt_bias": "model-00002-of-00016.safetensors",
|
| 35 |
+
"backbone.layers.11.mixer.in_proj.weight": "model-00002-of-00016.safetensors",
|
| 36 |
+
"backbone.layers.11.mixer.norm.weight": "model-00002-of-00016.safetensors",
|
| 37 |
+
"backbone.layers.11.mixer.out_proj.weight": "model-00002-of-00016.safetensors",
|
| 38 |
+
"backbone.layers.11.norm.weight": "model-00002-of-00016.safetensors",
|
| 39 |
+
"backbone.layers.12.mixer.k_proj.weight": "model-00002-of-00016.safetensors",
|
| 40 |
+
"backbone.layers.12.mixer.o_proj.weight": "model-00002-of-00016.safetensors",
|
| 41 |
+
"backbone.layers.12.mixer.q_proj.weight": "model-00002-of-00016.safetensors",
|
| 42 |
+
"backbone.layers.12.mixer.s_proj.weight": "model-00002-of-00016.safetensors",
|
| 43 |
+
"backbone.layers.12.mixer.v_proj.weight": "model-00002-of-00016.safetensors",
|
| 44 |
+
"backbone.layers.12.norm.weight": "model-00002-of-00016.safetensors",
|
| 45 |
+
"backbone.layers.13.mixer.experts.down_proj.weight": "model-00002-of-00016.safetensors",
|
| 46 |
+
"backbone.layers.13.mixer.experts.up_proj.weight": "model-00002-of-00016.safetensors",
|
| 47 |
+
"backbone.layers.13.mixer.gate.e_score_correction_bias": "model-00002-of-00016.safetensors",
|
| 48 |
+
"backbone.layers.13.mixer.gate.weight": "model-00002-of-00016.safetensors",
|
| 49 |
+
"backbone.layers.13.mixer.shared_experts.down_proj.weight": "model-00002-of-00016.safetensors",
|
| 50 |
+
"backbone.layers.13.mixer.shared_experts.up_proj.weight": "model-00002-of-00016.safetensors",
|
| 51 |
+
"backbone.layers.13.norm.weight": "model-00002-of-00016.safetensors",
|
| 52 |
+
"backbone.layers.14.mixer.A_log": "model-00002-of-00016.safetensors",
|
| 53 |
+
"backbone.layers.14.mixer.D": "model-00002-of-00016.safetensors",
|
| 54 |
+
"backbone.layers.14.mixer.conv1d.bias": "model-00002-of-00016.safetensors",
|
| 55 |
+
"backbone.layers.14.mixer.conv1d.weight": "model-00002-of-00016.safetensors",
|
| 56 |
+
"backbone.layers.14.mixer.dt_bias": "model-00002-of-00016.safetensors",
|
| 57 |
+
"backbone.layers.14.mixer.in_proj.weight": "model-00002-of-00016.safetensors",
|
| 58 |
+
"backbone.layers.14.mixer.norm.weight": "model-00002-of-00016.safetensors",
|
| 59 |
+
"backbone.layers.14.mixer.out_proj.weight": "model-00002-of-00016.safetensors",
|
| 60 |
+
"backbone.layers.14.norm.weight": "model-00002-of-00016.safetensors",
|
| 61 |
+
"backbone.layers.15.mixer.experts.down_proj.weight": "model-00003-of-00016.safetensors",
|
| 62 |
+
"backbone.layers.15.mixer.experts.up_proj.weight": "model-00003-of-00016.safetensors",
|
| 63 |
+
"backbone.layers.15.mixer.gate.e_score_correction_bias": "model-00003-of-00016.safetensors",
|
| 64 |
+
"backbone.layers.15.mixer.gate.weight": "model-00003-of-00016.safetensors",
|
| 65 |
+
"backbone.layers.15.mixer.shared_experts.down_proj.weight": "model-00003-of-00016.safetensors",
|
| 66 |
+
"backbone.layers.15.mixer.shared_experts.up_proj.weight": "model-00003-of-00016.safetensors",
|
| 67 |
+
"backbone.layers.15.norm.weight": "model-00003-of-00016.safetensors",
|
| 68 |
+
"backbone.layers.16.mixer.A_log": "model-00003-of-00016.safetensors",
|
| 69 |
+
"backbone.layers.16.mixer.D": "model-00003-of-00016.safetensors",
|
| 70 |
+
"backbone.layers.16.mixer.conv1d.bias": "model-00003-of-00016.safetensors",
|
| 71 |
+
"backbone.layers.16.mixer.conv1d.weight": "model-00003-of-00016.safetensors",
|
| 72 |
+
"backbone.layers.16.mixer.dt_bias": "model-00003-of-00016.safetensors",
|
| 73 |
+
"backbone.layers.16.mixer.in_proj.weight": "model-00003-of-00016.safetensors",
|
| 74 |
+
"backbone.layers.16.mixer.norm.weight": "model-00003-of-00016.safetensors",
|
| 75 |
+
"backbone.layers.16.mixer.out_proj.weight": "model-00003-of-00016.safetensors",
|
| 76 |
+
"backbone.layers.16.norm.weight": "model-00003-of-00016.safetensors",
|
| 77 |
+
"backbone.layers.17.mixer.experts.down_proj.weight": "model-00003-of-00016.safetensors",
|
| 78 |
+
"backbone.layers.17.mixer.experts.up_proj.weight": "model-00004-of-00016.safetensors",
|
| 79 |
+
"backbone.layers.17.mixer.gate.e_score_correction_bias": "model-00004-of-00016.safetensors",
|
| 80 |
+
"backbone.layers.17.mixer.gate.weight": "model-00004-of-00016.safetensors",
|
| 81 |
+
"backbone.layers.17.mixer.shared_experts.down_proj.weight": "model-00004-of-00016.safetensors",
|
| 82 |
+
"backbone.layers.17.mixer.shared_experts.up_proj.weight": "model-00004-of-00016.safetensors",
|
| 83 |
+
"backbone.layers.17.norm.weight": "model-00004-of-00016.safetensors",
|
| 84 |
+
"backbone.layers.18.mixer.A_log": "model-00004-of-00016.safetensors",
|
| 85 |
+
"backbone.layers.18.mixer.D": "model-00004-of-00016.safetensors",
|
| 86 |
+
"backbone.layers.18.mixer.conv1d.bias": "model-00004-of-00016.safetensors",
|
| 87 |
+
"backbone.layers.18.mixer.conv1d.weight": "model-00004-of-00016.safetensors",
|
| 88 |
+
"backbone.layers.18.mixer.dt_bias": "model-00004-of-00016.safetensors",
|
| 89 |
+
"backbone.layers.18.mixer.in_proj.weight": "model-00004-of-00016.safetensors",
|
| 90 |
+
"backbone.layers.18.mixer.norm.weight": "model-00004-of-00016.safetensors",
|
| 91 |
+
"backbone.layers.18.mixer.out_proj.weight": "model-00004-of-00016.safetensors",
|
| 92 |
+
"backbone.layers.18.norm.weight": "model-00004-of-00016.safetensors",
|
| 93 |
+
"backbone.layers.19.mixer.k_proj.weight": "model-00004-of-00016.safetensors",
|
| 94 |
+
"backbone.layers.19.mixer.o_proj.weight": "model-00004-of-00016.safetensors",
|
| 95 |
+
"backbone.layers.19.mixer.q_proj.weight": "model-00004-of-00016.safetensors",
|
| 96 |
+
"backbone.layers.19.mixer.s_proj.weight": "model-00004-of-00016.safetensors",
|
| 97 |
+
"backbone.layers.19.mixer.v_proj.weight": "model-00004-of-00016.safetensors",
|
| 98 |
+
"backbone.layers.19.norm.weight": "model-00004-of-00016.safetensors",
|
| 99 |
+
"backbone.layers.2.mixer.A_log": "model-00004-of-00016.safetensors",
|
| 100 |
+
"backbone.layers.2.mixer.D": "model-00004-of-00016.safetensors",
|
| 101 |
+
"backbone.layers.2.mixer.conv1d.bias": "model-00004-of-00016.safetensors",
|
| 102 |
+
"backbone.layers.2.mixer.conv1d.weight": "model-00004-of-00016.safetensors",
|
| 103 |
+
"backbone.layers.2.mixer.dt_bias": "model-00004-of-00016.safetensors",
|
| 104 |
+
"backbone.layers.2.mixer.in_proj.weight": "model-00004-of-00016.safetensors",
|
| 105 |
+
"backbone.layers.2.mixer.norm.weight": "model-00004-of-00016.safetensors",
|
| 106 |
+
"backbone.layers.2.mixer.out_proj.weight": "model-00004-of-00016.safetensors",
|
| 107 |
+
"backbone.layers.2.norm.weight": "model-00004-of-00016.safetensors",
|
| 108 |
+
"backbone.layers.20.mixer.experts.down_proj.weight": "model-00004-of-00016.safetensors",
|
| 109 |
+
"backbone.layers.20.mixer.experts.up_proj.weight": "model-00004-of-00016.safetensors",
|
| 110 |
+
"backbone.layers.20.mixer.gate.e_score_correction_bias": "model-00004-of-00016.safetensors",
|
| 111 |
+
"backbone.layers.20.mixer.gate.weight": "model-00004-of-00016.safetensors",
|
| 112 |
+
"backbone.layers.20.mixer.shared_experts.down_proj.weight": "model-00004-of-00016.safetensors",
|
| 113 |
+
"backbone.layers.20.mixer.shared_experts.up_proj.weight": "model-00004-of-00016.safetensors",
|
| 114 |
+
"backbone.layers.20.norm.weight": "model-00004-of-00016.safetensors",
|
| 115 |
+
"backbone.layers.21.mixer.A_log": "model-00004-of-00016.safetensors",
|
| 116 |
+
"backbone.layers.21.mixer.D": "model-00004-of-00016.safetensors",
|
| 117 |
+
"backbone.layers.21.mixer.conv1d.bias": "model-00004-of-00016.safetensors",
|
| 118 |
+
"backbone.layers.21.mixer.conv1d.weight": "model-00004-of-00016.safetensors",
|
| 119 |
+
"backbone.layers.21.mixer.dt_bias": "model-00004-of-00016.safetensors",
|
| 120 |
+
"backbone.layers.21.mixer.in_proj.weight": "model-00004-of-00016.safetensors",
|
| 121 |
+
"backbone.layers.21.mixer.norm.weight": "model-00004-of-00016.safetensors",
|
| 122 |
+
"backbone.layers.21.mixer.out_proj.weight": "model-00004-of-00016.safetensors",
|
| 123 |
+
"backbone.layers.21.norm.weight": "model-00004-of-00016.safetensors",
|
| 124 |
+
"backbone.layers.22.mixer.experts.down_proj.weight": "model-00005-of-00016.safetensors",
|
| 125 |
+
"backbone.layers.22.mixer.experts.up_proj.weight": "model-00005-of-00016.safetensors",
|
| 126 |
+
"backbone.layers.22.mixer.gate.e_score_correction_bias": "model-00005-of-00016.safetensors",
|
| 127 |
+
"backbone.layers.22.mixer.gate.weight": "model-00005-of-00016.safetensors",
|
| 128 |
+
"backbone.layers.22.mixer.shared_experts.down_proj.weight": "model-00005-of-00016.safetensors",
|
| 129 |
+
"backbone.layers.22.mixer.shared_experts.up_proj.weight": "model-00005-of-00016.safetensors",
|
| 130 |
+
"backbone.layers.22.norm.weight": "model-00005-of-00016.safetensors",
|
| 131 |
+
"backbone.layers.23.mixer.A_log": "model-00005-of-00016.safetensors",
|
| 132 |
+
"backbone.layers.23.mixer.D": "model-00005-of-00016.safetensors",
|
| 133 |
+
"backbone.layers.23.mixer.conv1d.bias": "model-00005-of-00016.safetensors",
|
| 134 |
+
"backbone.layers.23.mixer.conv1d.weight": "model-00005-of-00016.safetensors",
|
| 135 |
+
"backbone.layers.23.mixer.dt_bias": "model-00005-of-00016.safetensors",
|
| 136 |
+
"backbone.layers.23.mixer.in_proj.weight": "model-00005-of-00016.safetensors",
|
| 137 |
+
"backbone.layers.23.mixer.norm.weight": "model-00005-of-00016.safetensors",
|
| 138 |
+
"backbone.layers.23.mixer.out_proj.weight": "model-00005-of-00016.safetensors",
|
| 139 |
+
"backbone.layers.23.norm.weight": "model-00005-of-00016.safetensors",
|
| 140 |
+
"backbone.layers.24.mixer.experts.down_proj.weight": "model-00005-of-00016.safetensors",
|
| 141 |
+
"backbone.layers.24.mixer.experts.up_proj.weight": "model-00006-of-00016.safetensors",
|
| 142 |
+
"backbone.layers.24.mixer.gate.e_score_correction_bias": "model-00006-of-00016.safetensors",
|
| 143 |
+
"backbone.layers.24.mixer.gate.weight": "model-00006-of-00016.safetensors",
|
| 144 |
+
"backbone.layers.24.mixer.shared_experts.down_proj.weight": "model-00006-of-00016.safetensors",
|
| 145 |
+
"backbone.layers.24.mixer.shared_experts.up_proj.weight": "model-00006-of-00016.safetensors",
|
| 146 |
+
"backbone.layers.24.norm.weight": "model-00006-of-00016.safetensors",
|
| 147 |
+
"backbone.layers.25.mixer.A_log": "model-00006-of-00016.safetensors",
|
| 148 |
+
"backbone.layers.25.mixer.D": "model-00006-of-00016.safetensors",
|
| 149 |
+
"backbone.layers.25.mixer.conv1d.bias": "model-00006-of-00016.safetensors",
|
| 150 |
+
"backbone.layers.25.mixer.conv1d.weight": "model-00006-of-00016.safetensors",
|
| 151 |
+
"backbone.layers.25.mixer.dt_bias": "model-00006-of-00016.safetensors",
|
| 152 |
+
"backbone.layers.25.mixer.in_proj.weight": "model-00006-of-00016.safetensors",
|
| 153 |
+
"backbone.layers.25.mixer.norm.weight": "model-00006-of-00016.safetensors",
|
| 154 |
+
"backbone.layers.25.mixer.out_proj.weight": "model-00006-of-00016.safetensors",
|
| 155 |
+
"backbone.layers.25.norm.weight": "model-00006-of-00016.safetensors",
|
| 156 |
+
"backbone.layers.26.mixer.k_proj.weight": "model-00006-of-00016.safetensors",
|
| 157 |
+
"backbone.layers.26.mixer.o_proj.weight": "model-00006-of-00016.safetensors",
|
| 158 |
+
"backbone.layers.26.mixer.q_proj.weight": "model-00006-of-00016.safetensors",
|
| 159 |
+
"backbone.layers.26.mixer.s_proj.weight": "model-00006-of-00016.safetensors",
|
| 160 |
+
"backbone.layers.26.mixer.v_proj.weight": "model-00006-of-00016.safetensors",
|
| 161 |
+
"backbone.layers.26.norm.weight": "model-00006-of-00016.safetensors",
|
| 162 |
+
"backbone.layers.27.mixer.experts.down_proj.weight": "model-00006-of-00016.safetensors",
|
| 163 |
+
"backbone.layers.27.mixer.experts.up_proj.weight": "model-00006-of-00016.safetensors",
|
| 164 |
+
"backbone.layers.27.mixer.gate.e_score_correction_bias": "model-00006-of-00016.safetensors",
|
| 165 |
+
"backbone.layers.27.mixer.gate.weight": "model-00006-of-00016.safetensors",
|
| 166 |
+
"backbone.layers.27.mixer.shared_experts.down_proj.weight": "model-00006-of-00016.safetensors",
|
| 167 |
+
"backbone.layers.27.mixer.shared_experts.up_proj.weight": "model-00006-of-00016.safetensors",
|
| 168 |
+
"backbone.layers.27.norm.weight": "model-00006-of-00016.safetensors",
|
| 169 |
+
"backbone.layers.28.mixer.A_log": "model-00006-of-00016.safetensors",
|
| 170 |
+
"backbone.layers.28.mixer.D": "model-00006-of-00016.safetensors",
|
| 171 |
+
"backbone.layers.28.mixer.conv1d.bias": "model-00006-of-00016.safetensors",
|
| 172 |
+
"backbone.layers.28.mixer.conv1d.weight": "model-00006-of-00016.safetensors",
|
| 173 |
+
"backbone.layers.28.mixer.dt_bias": "model-00006-of-00016.safetensors",
|
| 174 |
+
"backbone.layers.28.mixer.in_proj.weight": "model-00006-of-00016.safetensors",
|
| 175 |
+
"backbone.layers.28.mixer.norm.weight": "model-00006-of-00016.safetensors",
|
| 176 |
+
"backbone.layers.28.mixer.out_proj.weight": "model-00006-of-00016.safetensors",
|
| 177 |
+
"backbone.layers.28.norm.weight": "model-00006-of-00016.safetensors",
|
| 178 |
+
"backbone.layers.29.mixer.experts.down_proj.weight": "model-00007-of-00016.safetensors",
|
| 179 |
+
"backbone.layers.29.mixer.experts.up_proj.weight": "model-00007-of-00016.safetensors",
|
| 180 |
+
"backbone.layers.29.mixer.gate.e_score_correction_bias": "model-00007-of-00016.safetensors",
|
| 181 |
+
"backbone.layers.29.mixer.gate.weight": "model-00007-of-00016.safetensors",
|
| 182 |
+
"backbone.layers.29.mixer.shared_experts.down_proj.weight": "model-00007-of-00016.safetensors",
|
| 183 |
+
"backbone.layers.29.mixer.shared_experts.up_proj.weight": "model-00007-of-00016.safetensors",
|
| 184 |
+
"backbone.layers.29.norm.weight": "model-00007-of-00016.safetensors",
|
| 185 |
+
"backbone.layers.3.mixer.experts.down_proj.weight": "model-00007-of-00016.safetensors",
|
| 186 |
+
"backbone.layers.3.mixer.experts.up_proj.weight": "model-00008-of-00016.safetensors",
|
| 187 |
+
"backbone.layers.3.mixer.gate.e_score_correction_bias": "model-00008-of-00016.safetensors",
|
| 188 |
+
"backbone.layers.3.mixer.gate.weight": "model-00008-of-00016.safetensors",
|
| 189 |
+
"backbone.layers.3.mixer.shared_experts.down_proj.weight": "model-00008-of-00016.safetensors",
|
| 190 |
+
"backbone.layers.3.mixer.shared_experts.up_proj.weight": "model-00008-of-00016.safetensors",
|
| 191 |
+
"backbone.layers.3.norm.weight": "model-00008-of-00016.safetensors",
|
| 192 |
+
"backbone.layers.30.mixer.A_log": "model-00008-of-00016.safetensors",
|
| 193 |
+
"backbone.layers.30.mixer.D": "model-00008-of-00016.safetensors",
|
| 194 |
+
"backbone.layers.30.mixer.conv1d.bias": "model-00008-of-00016.safetensors",
|
| 195 |
+
"backbone.layers.30.mixer.conv1d.weight": "model-00008-of-00016.safetensors",
|
| 196 |
+
"backbone.layers.30.mixer.dt_bias": "model-00008-of-00016.safetensors",
|
| 197 |
+
"backbone.layers.30.mixer.in_proj.weight": "model-00008-of-00016.safetensors",
|
| 198 |
+
"backbone.layers.30.mixer.norm.weight": "model-00008-of-00016.safetensors",
|
| 199 |
+
"backbone.layers.30.mixer.out_proj.weight": "model-00008-of-00016.safetensors",
|
| 200 |
+
"backbone.layers.30.norm.weight": "model-00008-of-00016.safetensors",
|
| 201 |
+
"backbone.layers.31.mixer.experts.down_proj.weight": "model-00008-of-00016.safetensors",
|
| 202 |
+
"backbone.layers.31.mixer.experts.up_proj.weight": "model-00008-of-00016.safetensors",
|
| 203 |
+
"backbone.layers.31.mixer.gate.e_score_correction_bias": "model-00008-of-00016.safetensors",
|
| 204 |
+
"backbone.layers.31.mixer.gate.weight": "model-00008-of-00016.safetensors",
|
| 205 |
+
"backbone.layers.31.mixer.shared_experts.down_proj.weight": "model-00008-of-00016.safetensors",
|
| 206 |
+
"backbone.layers.31.mixer.shared_experts.up_proj.weight": "model-00008-of-00016.safetensors",
|
| 207 |
+
"backbone.layers.31.norm.weight": "model-00008-of-00016.safetensors",
|
| 208 |
+
"backbone.layers.32.mixer.A_log": "model-00008-of-00016.safetensors",
|
| 209 |
+
"backbone.layers.32.mixer.D": "model-00008-of-00016.safetensors",
|
| 210 |
+
"backbone.layers.32.mixer.conv1d.bias": "model-00008-of-00016.safetensors",
|
| 211 |
+
"backbone.layers.32.mixer.conv1d.weight": "model-00008-of-00016.safetensors",
|
| 212 |
+
"backbone.layers.32.mixer.dt_bias": "model-00008-of-00016.safetensors",
|
| 213 |
+
"backbone.layers.32.mixer.in_proj.weight": "model-00008-of-00016.safetensors",
|
| 214 |
+
"backbone.layers.32.mixer.norm.weight": "model-00008-of-00016.safetensors",
|
| 215 |
+
"backbone.layers.32.mixer.out_proj.weight": "model-00008-of-00016.safetensors",
|
| 216 |
+
"backbone.layers.32.norm.weight": "model-00008-of-00016.safetensors",
|
| 217 |
+
"backbone.layers.33.mixer.k_proj.weight": "model-00008-of-00016.safetensors",
|
| 218 |
+
"backbone.layers.33.mixer.o_proj.weight": "model-00008-of-00016.safetensors",
|
| 219 |
+
"backbone.layers.33.mixer.q_proj.weight": "model-00008-of-00016.safetensors",
|
| 220 |
+
"backbone.layers.33.mixer.s_proj.weight": "model-00008-of-00016.safetensors",
|
| 221 |
+
"backbone.layers.33.mixer.v_proj.weight": "model-00008-of-00016.safetensors",
|
| 222 |
+
"backbone.layers.33.norm.weight": "model-00008-of-00016.safetensors",
|
| 223 |
+
"backbone.layers.34.mixer.experts.down_proj.weight": "model-00009-of-00016.safetensors",
|
| 224 |
+
"backbone.layers.34.mixer.experts.up_proj.weight": "model-00009-of-00016.safetensors",
|
| 225 |
+
"backbone.layers.34.mixer.gate.e_score_correction_bias": "model-00009-of-00016.safetensors",
|
| 226 |
+
"backbone.layers.34.mixer.gate.weight": "model-00009-of-00016.safetensors",
|
| 227 |
+
"backbone.layers.34.mixer.shared_experts.down_proj.weight": "model-00009-of-00016.safetensors",
|
| 228 |
+
"backbone.layers.34.mixer.shared_experts.up_proj.weight": "model-00009-of-00016.safetensors",
|
| 229 |
+
"backbone.layers.34.norm.weight": "model-00009-of-00016.safetensors",
|
| 230 |
+
"backbone.layers.35.mixer.A_log": "model-00009-of-00016.safetensors",
|
| 231 |
+
"backbone.layers.35.mixer.D": "model-00009-of-00016.safetensors",
|
| 232 |
+
"backbone.layers.35.mixer.conv1d.bias": "model-00009-of-00016.safetensors",
|
| 233 |
+
"backbone.layers.35.mixer.conv1d.weight": "model-00009-of-00016.safetensors",
|
| 234 |
+
"backbone.layers.35.mixer.dt_bias": "model-00009-of-00016.safetensors",
|
| 235 |
+
"backbone.layers.35.mixer.in_proj.weight": "model-00009-of-00016.safetensors",
|
| 236 |
+
"backbone.layers.35.mixer.norm.weight": "model-00009-of-00016.safetensors",
|
| 237 |
+
"backbone.layers.35.mixer.out_proj.weight": "model-00009-of-00016.safetensors",
|
| 238 |
+
"backbone.layers.35.norm.weight": "model-00009-of-00016.safetensors",
|
| 239 |
+
"backbone.layers.36.mixer.experts.down_proj.weight": "model-00009-of-00016.safetensors",
|
| 240 |
+
"backbone.layers.36.mixer.experts.up_proj.weight": "model-00010-of-00016.safetensors",
|
| 241 |
+
"backbone.layers.36.mixer.gate.e_score_correction_bias": "model-00010-of-00016.safetensors",
|
| 242 |
+
"backbone.layers.36.mixer.gate.weight": "model-00010-of-00016.safetensors",
|
| 243 |
+
"backbone.layers.36.mixer.shared_experts.down_proj.weight": "model-00010-of-00016.safetensors",
|
| 244 |
+
"backbone.layers.36.mixer.shared_experts.up_proj.weight": "model-00010-of-00016.safetensors",
|
| 245 |
+
"backbone.layers.36.norm.weight": "model-00010-of-00016.safetensors",
|
| 246 |
+
"backbone.layers.37.mixer.A_log": "model-00010-of-00016.safetensors",
|
| 247 |
+
"backbone.layers.37.mixer.D": "model-00010-of-00016.safetensors",
|
| 248 |
+
"backbone.layers.37.mixer.conv1d.bias": "model-00010-of-00016.safetensors",
|
| 249 |
+
"backbone.layers.37.mixer.conv1d.weight": "model-00010-of-00016.safetensors",
|
| 250 |
+
"backbone.layers.37.mixer.dt_bias": "model-00010-of-00016.safetensors",
|
| 251 |
+
"backbone.layers.37.mixer.in_proj.weight": "model-00010-of-00016.safetensors",
|
| 252 |
+
"backbone.layers.37.mixer.norm.weight": "model-00010-of-00016.safetensors",
|
| 253 |
+
"backbone.layers.37.mixer.out_proj.weight": "model-00010-of-00016.safetensors",
|
| 254 |
+
"backbone.layers.37.norm.weight": "model-00010-of-00016.safetensors",
|
| 255 |
+
"backbone.layers.38.mixer.experts.down_proj.weight": "model-00010-of-00016.safetensors",
|
| 256 |
+
"backbone.layers.38.mixer.experts.up_proj.weight": "model-00010-of-00016.safetensors",
|
| 257 |
+
"backbone.layers.38.mixer.gate.e_score_correction_bias": "model-00010-of-00016.safetensors",
|
| 258 |
+
"backbone.layers.38.mixer.gate.weight": "model-00010-of-00016.safetensors",
|
| 259 |
+
"backbone.layers.38.mixer.shared_experts.down_proj.weight": "model-00010-of-00016.safetensors",
|
| 260 |
+
"backbone.layers.38.mixer.shared_experts.up_proj.weight": "model-00010-of-00016.safetensors",
|
| 261 |
+
"backbone.layers.38.norm.weight": "model-00010-of-00016.safetensors",
|
| 262 |
+
"backbone.layers.39.mixer.A_log": "model-00010-of-00016.safetensors",
|
| 263 |
+
"backbone.layers.39.mixer.D": "model-00010-of-00016.safetensors",
|
| 264 |
+
"backbone.layers.39.mixer.conv1d.bias": "model-00010-of-00016.safetensors",
|
| 265 |
+
"backbone.layers.39.mixer.conv1d.weight": "model-00010-of-00016.safetensors",
|
| 266 |
+
"backbone.layers.39.mixer.dt_bias": "model-00010-of-00016.safetensors",
|
| 267 |
+
"backbone.layers.39.mixer.in_proj.weight": "model-00010-of-00016.safetensors",
|
| 268 |
+
"backbone.layers.39.mixer.norm.weight": "model-00010-of-00016.safetensors",
|
| 269 |
+
"backbone.layers.39.mixer.out_proj.weight": "model-00010-of-00016.safetensors",
|
| 270 |
+
"backbone.layers.39.norm.weight": "model-00010-of-00016.safetensors",
|
| 271 |
+
"backbone.layers.4.mixer.A_log": "model-00010-of-00016.safetensors",
|
| 272 |
+
"backbone.layers.4.mixer.D": "model-00010-of-00016.safetensors",
|
| 273 |
+
"backbone.layers.4.mixer.conv1d.bias": "model-00010-of-00016.safetensors",
|
| 274 |
+
"backbone.layers.4.mixer.conv1d.weight": "model-00010-of-00016.safetensors",
|
| 275 |
+
"backbone.layers.4.mixer.dt_bias": "model-00010-of-00016.safetensors",
|
| 276 |
+
"backbone.layers.4.mixer.in_proj.weight": "model-00010-of-00016.safetensors",
|
| 277 |
+
"backbone.layers.4.mixer.norm.weight": "model-00010-of-00016.safetensors",
|
| 278 |
+
"backbone.layers.4.mixer.out_proj.weight": "model-00010-of-00016.safetensors",
|
| 279 |
+
"backbone.layers.4.norm.weight": "model-00010-of-00016.safetensors",
|
| 280 |
+
"backbone.layers.40.mixer.experts.down_proj.weight": "model-00011-of-00016.safetensors",
|
| 281 |
+
"backbone.layers.40.mixer.experts.up_proj.weight": "model-00011-of-00016.safetensors",
|
| 282 |
+
"backbone.layers.40.mixer.gate.e_score_correction_bias": "model-00011-of-00016.safetensors",
|
| 283 |
+
"backbone.layers.40.mixer.gate.weight": "model-00011-of-00016.safetensors",
|
| 284 |
+
"backbone.layers.40.mixer.shared_experts.down_proj.weight": "model-00011-of-00016.safetensors",
|
| 285 |
+
"backbone.layers.40.mixer.shared_experts.up_proj.weight": "model-00011-of-00016.safetensors",
|
| 286 |
+
"backbone.layers.40.norm.weight": "model-00011-of-00016.safetensors",
|
| 287 |
+
"backbone.layers.41.mixer.A_log": "model-00011-of-00016.safetensors",
|
| 288 |
+
"backbone.layers.41.mixer.D": "model-00011-of-00016.safetensors",
|
| 289 |
+
"backbone.layers.41.mixer.conv1d.bias": "model-00011-of-00016.safetensors",
|
| 290 |
+
"backbone.layers.41.mixer.conv1d.weight": "model-00011-of-00016.safetensors",
|
| 291 |
+
"backbone.layers.41.mixer.dt_bias": "model-00011-of-00016.safetensors",
|
| 292 |
+
"backbone.layers.41.mixer.in_proj.weight": "model-00011-of-00016.safetensors",
|
| 293 |
+
"backbone.layers.41.mixer.norm.weight": "model-00011-of-00016.safetensors",
|
| 294 |
+
"backbone.layers.41.mixer.out_proj.weight": "model-00011-of-00016.safetensors",
|
| 295 |
+
"backbone.layers.41.norm.weight": "model-00011-of-00016.safetensors",
|
| 296 |
+
"backbone.layers.42.mixer.k_proj.weight": "model-00011-of-00016.safetensors",
|
| 297 |
+
"backbone.layers.42.mixer.o_proj.weight": "model-00011-of-00016.safetensors",
|
| 298 |
+
"backbone.layers.42.mixer.q_proj.weight": "model-00011-of-00016.safetensors",
|
| 299 |
+
"backbone.layers.42.mixer.s_proj.weight": "model-00011-of-00016.safetensors",
|
| 300 |
+
"backbone.layers.42.mixer.v_proj.weight": "model-00011-of-00016.safetensors",
|
| 301 |
+
"backbone.layers.42.norm.weight": "model-00011-of-00016.safetensors",
|
| 302 |
+
"backbone.layers.43.mixer.experts.down_proj.weight": "model-00011-of-00016.safetensors",
|
| 303 |
+
"backbone.layers.43.mixer.experts.up_proj.weight": "model-00012-of-00016.safetensors",
|
| 304 |
+
"backbone.layers.43.mixer.gate.e_score_correction_bias": "model-00012-of-00016.safetensors",
|
| 305 |
+
"backbone.layers.43.mixer.gate.weight": "model-00012-of-00016.safetensors",
|
| 306 |
+
"backbone.layers.43.mixer.shared_experts.down_proj.weight": "model-00012-of-00016.safetensors",
|
| 307 |
+
"backbone.layers.43.mixer.shared_experts.up_proj.weight": "model-00012-of-00016.safetensors",
|
| 308 |
+
"backbone.layers.43.norm.weight": "model-00012-of-00016.safetensors",
|
| 309 |
+
"backbone.layers.44.mixer.A_log": "model-00012-of-00016.safetensors",
|
| 310 |
+
"backbone.layers.44.mixer.D": "model-00012-of-00016.safetensors",
|
| 311 |
+
"backbone.layers.44.mixer.conv1d.bias": "model-00012-of-00016.safetensors",
|
| 312 |
+
"backbone.layers.44.mixer.conv1d.weight": "model-00012-of-00016.safetensors",
|
| 313 |
+
"backbone.layers.44.mixer.dt_bias": "model-00012-of-00016.safetensors",
|
| 314 |
+
"backbone.layers.44.mixer.in_proj.weight": "model-00012-of-00016.safetensors",
|
| 315 |
+
"backbone.layers.44.mixer.norm.weight": "model-00012-of-00016.safetensors",
|
| 316 |
+
"backbone.layers.44.mixer.out_proj.weight": "model-00012-of-00016.safetensors",
|
| 317 |
+
"backbone.layers.44.norm.weight": "model-00012-of-00016.safetensors",
|
| 318 |
+
"backbone.layers.45.mixer.experts.down_proj.weight": "model-00012-of-00016.safetensors",
|
| 319 |
+
"backbone.layers.45.mixer.experts.up_proj.weight": "model-00012-of-00016.safetensors",
|
| 320 |
+
"backbone.layers.45.mixer.gate.e_score_correction_bias": "model-00012-of-00016.safetensors",
|
| 321 |
+
"backbone.layers.45.mixer.gate.weight": "model-00012-of-00016.safetensors",
|
| 322 |
+
"backbone.layers.45.mixer.shared_experts.down_proj.weight": "model-00012-of-00016.safetensors",
|
| 323 |
+
"backbone.layers.45.mixer.shared_experts.up_proj.weight": "model-00012-of-00016.safetensors",
|
| 324 |
+
"backbone.layers.45.norm.weight": "model-00012-of-00016.safetensors",
|
| 325 |
+
"backbone.layers.46.mixer.A_log": "model-00012-of-00016.safetensors",
|
| 326 |
+
"backbone.layers.46.mixer.D": "model-00012-of-00016.safetensors",
|
| 327 |
+
"backbone.layers.46.mixer.conv1d.bias": "model-00012-of-00016.safetensors",
|
| 328 |
+
"backbone.layers.46.mixer.conv1d.weight": "model-00012-of-00016.safetensors",
|
| 329 |
+
"backbone.layers.46.mixer.dt_bias": "model-00012-of-00016.safetensors",
|
| 330 |
+
"backbone.layers.46.mixer.in_proj.weight": "model-00012-of-00016.safetensors",
|
| 331 |
+
"backbone.layers.46.mixer.norm.weight": "model-00012-of-00016.safetensors",
|
| 332 |
+
"backbone.layers.46.mixer.out_proj.weight": "model-00012-of-00016.safetensors",
|
| 333 |
+
"backbone.layers.46.norm.weight": "model-00012-of-00016.safetensors",
|
| 334 |
+
"backbone.layers.47.mixer.experts.down_proj.weight": "model-00013-of-00016.safetensors",
|
| 335 |
+
"backbone.layers.47.mixer.experts.up_proj.weight": "model-00013-of-00016.safetensors",
|
| 336 |
+
"backbone.layers.47.mixer.gate.e_score_correction_bias": "model-00013-of-00016.safetensors",
|
| 337 |
+
"backbone.layers.47.mixer.gate.weight": "model-00013-of-00016.safetensors",
|
| 338 |
+
"backbone.layers.47.mixer.shared_experts.down_proj.weight": "model-00013-of-00016.safetensors",
|
| 339 |
+
"backbone.layers.47.mixer.shared_experts.up_proj.weight": "model-00013-of-00016.safetensors",
|
| 340 |
+
"backbone.layers.47.norm.weight": "model-00013-of-00016.safetensors",
|
| 341 |
+
"backbone.layers.48.mixer.A_log": "model-00013-of-00016.safetensors",
|
| 342 |
+
"backbone.layers.48.mixer.D": "model-00013-of-00016.safetensors",
|
| 343 |
+
"backbone.layers.48.mixer.conv1d.bias": "model-00013-of-00016.safetensors",
|
| 344 |
+
"backbone.layers.48.mixer.conv1d.weight": "model-00013-of-00016.safetensors",
|
| 345 |
+
"backbone.layers.48.mixer.dt_bias": "model-00013-of-00016.safetensors",
|
| 346 |
+
"backbone.layers.48.mixer.in_proj.weight": "model-00013-of-00016.safetensors",
|
| 347 |
+
"backbone.layers.48.mixer.norm.weight": "model-00013-of-00016.safetensors",
|
| 348 |
+
"backbone.layers.48.mixer.out_proj.weight": "model-00013-of-00016.safetensors",
|
| 349 |
+
"backbone.layers.48.norm.weight": "model-00013-of-00016.safetensors",
|
| 350 |
+
"backbone.layers.49.mixer.experts.down_proj.weight": "model-00013-of-00016.safetensors",
|
| 351 |
+
"backbone.layers.49.mixer.experts.up_proj.weight": "model-00014-of-00016.safetensors",
|
| 352 |
+
"backbone.layers.49.mixer.gate.e_score_correction_bias": "model-00014-of-00016.safetensors",
|
| 353 |
+
"backbone.layers.49.mixer.gate.weight": "model-00014-of-00016.safetensors",
|
| 354 |
+
"backbone.layers.49.mixer.shared_experts.down_proj.weight": "model-00014-of-00016.safetensors",
|
| 355 |
+
"backbone.layers.49.mixer.shared_experts.up_proj.weight": "model-00014-of-00016.safetensors",
|
| 356 |
+
"backbone.layers.49.norm.weight": "model-00014-of-00016.safetensors",
|
| 357 |
+
"backbone.layers.5.mixer.k_proj.weight": "model-00014-of-00016.safetensors",
|
| 358 |
+
"backbone.layers.5.mixer.o_proj.weight": "model-00014-of-00016.safetensors",
|
| 359 |
+
"backbone.layers.5.mixer.q_proj.weight": "model-00014-of-00016.safetensors",
|
| 360 |
+
"backbone.layers.5.mixer.s_proj.weight": "model-00014-of-00016.safetensors",
|
| 361 |
+
"backbone.layers.5.mixer.v_proj.weight": "model-00014-of-00016.safetensors",
|
| 362 |
+
"backbone.layers.5.norm.weight": "model-00014-of-00016.safetensors",
|
| 363 |
+
"backbone.layers.50.mixer.A_log": "model-00014-of-00016.safetensors",
|
| 364 |
+
"backbone.layers.50.mixer.D": "model-00014-of-00016.safetensors",
|
| 365 |
+
"backbone.layers.50.mixer.conv1d.bias": "model-00014-of-00016.safetensors",
|
| 366 |
+
"backbone.layers.50.mixer.conv1d.weight": "model-00014-of-00016.safetensors",
|
| 367 |
+
"backbone.layers.50.mixer.dt_bias": "model-00014-of-00016.safetensors",
|
| 368 |
+
"backbone.layers.50.mixer.in_proj.weight": "model-00014-of-00016.safetensors",
|
| 369 |
+
"backbone.layers.50.mixer.norm.weight": "model-00014-of-00016.safetensors",
|
| 370 |
+
"backbone.layers.50.mixer.out_proj.weight": "model-00014-of-00016.safetensors",
|
| 371 |
+
"backbone.layers.50.norm.weight": "model-00014-of-00016.safetensors",
|
| 372 |
+
"backbone.layers.51.mixer.experts.down_proj.weight": "model-00014-of-00016.safetensors",
|
| 373 |
+
"backbone.layers.51.mixer.experts.up_proj.weight": "model-00014-of-00016.safetensors",
|
| 374 |
+
"backbone.layers.51.mixer.gate.e_score_correction_bias": "model-00014-of-00016.safetensors",
|
| 375 |
+
"backbone.layers.51.mixer.gate.weight": "model-00014-of-00016.safetensors",
|
| 376 |
+
"backbone.layers.51.mixer.shared_experts.down_proj.weight": "model-00014-of-00016.safetensors",
|
| 377 |
+
"backbone.layers.51.mixer.shared_experts.up_proj.weight": "model-00014-of-00016.safetensors",
|
| 378 |
+
"backbone.layers.51.norm.weight": "model-00014-of-00016.safetensors",
|
| 379 |
+
"backbone.layers.6.mixer.experts.down_proj.weight": "model-00015-of-00016.safetensors",
|
| 380 |
+
"backbone.layers.6.mixer.experts.up_proj.weight": "model-00015-of-00016.safetensors",
|
| 381 |
+
"backbone.layers.6.mixer.gate.e_score_correction_bias": "model-00015-of-00016.safetensors",
|
| 382 |
+
"backbone.layers.6.mixer.gate.weight": "model-00015-of-00016.safetensors",
|
| 383 |
+
"backbone.layers.6.mixer.shared_experts.down_proj.weight": "model-00015-of-00016.safetensors",
|
| 384 |
+
"backbone.layers.6.mixer.shared_experts.up_proj.weight": "model-00015-of-00016.safetensors",
|
| 385 |
+
"backbone.layers.6.norm.weight": "model-00015-of-00016.safetensors",
|
| 386 |
+
"backbone.layers.7.mixer.A_log": "model-00015-of-00016.safetensors",
|
| 387 |
+
"backbone.layers.7.mixer.D": "model-00015-of-00016.safetensors",
|
| 388 |
+
"backbone.layers.7.mixer.conv1d.bias": "model-00015-of-00016.safetensors",
|
| 389 |
+
"backbone.layers.7.mixer.conv1d.weight": "model-00015-of-00016.safetensors",
|
| 390 |
+
"backbone.layers.7.mixer.dt_bias": "model-00015-of-00016.safetensors",
|
| 391 |
+
"backbone.layers.7.mixer.in_proj.weight": "model-00015-of-00016.safetensors",
|
| 392 |
+
"backbone.layers.7.mixer.norm.weight": "model-00015-of-00016.safetensors",
|
| 393 |
+
"backbone.layers.7.mixer.out_proj.weight": "model-00015-of-00016.safetensors",
|
| 394 |
+
"backbone.layers.7.norm.weight": "model-00015-of-00016.safetensors",
|
| 395 |
+
"backbone.layers.8.mixer.experts.down_proj.weight": "model-00015-of-00016.safetensors",
|
| 396 |
+
"backbone.layers.8.mixer.experts.up_proj.weight": "model-00016-of-00016.safetensors",
|
| 397 |
+
"backbone.layers.8.mixer.gate.e_score_correction_bias": "model-00016-of-00016.safetensors",
|
| 398 |
+
"backbone.layers.8.mixer.gate.weight": "model-00016-of-00016.safetensors",
|
| 399 |
+
"backbone.layers.8.mixer.shared_experts.down_proj.weight": "model-00016-of-00016.safetensors",
|
| 400 |
+
"backbone.layers.8.mixer.shared_experts.up_proj.weight": "model-00016-of-00016.safetensors",
|
| 401 |
+
"backbone.layers.8.norm.weight": "model-00016-of-00016.safetensors",
|
| 402 |
+
"backbone.layers.9.mixer.A_log": "model-00016-of-00016.safetensors",
|
| 403 |
+
"backbone.layers.9.mixer.D": "model-00016-of-00016.safetensors",
|
| 404 |
+
"backbone.layers.9.mixer.conv1d.bias": "model-00016-of-00016.safetensors",
|
| 405 |
+
"backbone.layers.9.mixer.conv1d.weight": "model-00016-of-00016.safetensors",
|
| 406 |
+
"backbone.layers.9.mixer.dt_bias": "model-00016-of-00016.safetensors",
|
| 407 |
+
"backbone.layers.9.mixer.in_proj.weight": "model-00016-of-00016.safetensors",
|
| 408 |
+
"backbone.layers.9.mixer.norm.weight": "model-00016-of-00016.safetensors",
|
| 409 |
+
"backbone.layers.9.mixer.out_proj.weight": "model-00016-of-00016.safetensors",
|
| 410 |
+
"backbone.layers.9.norm.weight": "model-00016-of-00016.safetensors",
|
| 411 |
+
"backbone.norm_f.weight": "model-00016-of-00016.safetensors",
|
| 412 |
+
"lm_head.weight": "model-00016-of-00016.safetensors"
|
| 413 |
+
}
|
| 414 |
+
}
|
modeling_superlinear_exp.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
moe.py
ADDED
|
@@ -0,0 +1,890 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import functools
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
try: # pragma: no cover
|
| 12 |
+
import triton
|
| 13 |
+
import triton.language as tl
|
| 14 |
+
except Exception: # pragma: no cover
|
| 15 |
+
triton = None
|
| 16 |
+
tl = None
|
| 17 |
+
|
| 18 |
+
# Eagerly import vllm._moe_C to ensure fused MoE kernels are available
|
| 19 |
+
# This must happen before shared_fused_moe_is_available() is called
|
| 20 |
+
try: # pragma: no cover
|
| 21 |
+
import vllm._moe_C # noqa: F401
|
| 22 |
+
except Exception: # pragma: no cover
|
| 23 |
+
pass
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _cdiv(a: int, b: int) -> int:
|
| 27 |
+
return (a + b - 1) // b
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _round_up(a: int, b: int) -> int:
|
| 31 |
+
return _cdiv(a, b) * b
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@functools.lru_cache(maxsize=1)
|
| 35 |
+
def _ensure_moe_kernels_loaded() -> bool:
|
| 36 |
+
if hasattr(torch.ops, "_moe_C"):
|
| 37 |
+
return True
|
| 38 |
+
try: # pragma: no cover
|
| 39 |
+
# vLLM installs the MoE kernels under this module name.
|
| 40 |
+
import vllm._moe_C # noqa: F401
|
| 41 |
+
return hasattr(torch.ops, "_moe_C")
|
| 42 |
+
except Exception:
|
| 43 |
+
return False
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@functools.lru_cache(maxsize=1)
|
| 47 |
+
def shared_fused_moe_is_available() -> bool:
|
| 48 |
+
if triton is None or tl is None:
|
| 49 |
+
return False
|
| 50 |
+
if not _ensure_moe_kernels_loaded():
|
| 51 |
+
return False
|
| 52 |
+
return all(
|
| 53 |
+
hasattr(torch.ops._moe_C, attr)
|
| 54 |
+
for attr in ("moe_align_block_size", "moe_sum")
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _moe_align_block_size(
|
| 59 |
+
topk_ids: torch.Tensor,
|
| 60 |
+
block_size: int,
|
| 61 |
+
num_experts: int,
|
| 62 |
+
*,
|
| 63 |
+
pad_sorted_ids: bool = False,
|
| 64 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 65 |
+
"""Compute (sorted_token_ids, expert_ids, num_tokens_post_padded) using vLLM's `_moe_C` kernels."""
|
| 66 |
+
if not shared_fused_moe_is_available():
|
| 67 |
+
raise RuntimeError(
|
| 68 |
+
"Shared fused MoE is not available (missing triton and/or vllm._moe_C)."
|
| 69 |
+
)
|
| 70 |
+
if topk_ids.dtype != torch.int32:
|
| 71 |
+
topk_ids = topk_ids.to(torch.int32)
|
| 72 |
+
|
| 73 |
+
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
| 74 |
+
if pad_sorted_ids:
|
| 75 |
+
max_num_tokens_padded = _round_up(max_num_tokens_padded, block_size)
|
| 76 |
+
if topk_ids.numel() < num_experts:
|
| 77 |
+
max_num_tokens_padded = min(topk_ids.numel() * block_size, max_num_tokens_padded)
|
| 78 |
+
|
| 79 |
+
sorted_token_ids = torch.empty(
|
| 80 |
+
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
| 81 |
+
)
|
| 82 |
+
max_num_m_blocks = _cdiv(max_num_tokens_padded, block_size)
|
| 83 |
+
expert_ids = torch.empty(
|
| 84 |
+
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
| 85 |
+
)
|
| 86 |
+
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=topk_ids.device)
|
| 87 |
+
|
| 88 |
+
torch.ops._moe_C.moe_align_block_size(
|
| 89 |
+
topk_ids,
|
| 90 |
+
num_experts,
|
| 91 |
+
block_size,
|
| 92 |
+
sorted_token_ids,
|
| 93 |
+
expert_ids,
|
| 94 |
+
num_tokens_post_pad,
|
| 95 |
+
None, # maybe_expert_map (added in newer vllm versions)
|
| 96 |
+
)
|
| 97 |
+
return sorted_token_ids, expert_ids, num_tokens_post_pad
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _get_default_config(
|
| 101 |
+
M: int,
|
| 102 |
+
E: int,
|
| 103 |
+
N: int,
|
| 104 |
+
K: int,
|
| 105 |
+
topk: int,
|
| 106 |
+
) -> dict[str, int]:
|
| 107 |
+
# Heuristic default configs adapted from vLLM.
|
| 108 |
+
if M <= E:
|
| 109 |
+
return {
|
| 110 |
+
"BLOCK_SIZE_M": 16,
|
| 111 |
+
"BLOCK_SIZE_N": 32,
|
| 112 |
+
"BLOCK_SIZE_K": 64,
|
| 113 |
+
"GROUP_SIZE_M": 1,
|
| 114 |
+
}
|
| 115 |
+
return {
|
| 116 |
+
"BLOCK_SIZE_M": 64,
|
| 117 |
+
"BLOCK_SIZE_N": 64,
|
| 118 |
+
"BLOCK_SIZE_K": 32,
|
| 119 |
+
"GROUP_SIZE_M": 8,
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
if triton is not None and tl is not None:
|
| 124 |
+
|
| 125 |
+
@triton.jit
|
| 126 |
+
def _write_zeros_to_output(
|
| 127 |
+
c_ptr,
|
| 128 |
+
stride_cm,
|
| 129 |
+
stride_cn,
|
| 130 |
+
pid_n,
|
| 131 |
+
N,
|
| 132 |
+
offs_token,
|
| 133 |
+
token_mask,
|
| 134 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 135 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 136 |
+
compute_type: tl.constexpr,
|
| 137 |
+
):
|
| 138 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
|
| 139 |
+
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 140 |
+
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
|
| 141 |
+
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
| 142 |
+
tl.store(c_ptrs, accumulator, mask=c_mask)
|
| 143 |
+
|
| 144 |
+
@triton.jit
|
| 145 |
+
def _fused_moe_kernel(
|
| 146 |
+
# Pointers to matrices
|
| 147 |
+
a_ptr,
|
| 148 |
+
b_ptr,
|
| 149 |
+
c_ptr,
|
| 150 |
+
topk_weights_ptr,
|
| 151 |
+
sorted_token_ids_ptr,
|
| 152 |
+
expert_ids_ptr,
|
| 153 |
+
num_tokens_post_padded_ptr,
|
| 154 |
+
# Matrix dimensions
|
| 155 |
+
N,
|
| 156 |
+
K,
|
| 157 |
+
EM,
|
| 158 |
+
num_valid_tokens,
|
| 159 |
+
# Strides
|
| 160 |
+
stride_am,
|
| 161 |
+
stride_ak,
|
| 162 |
+
stride_be,
|
| 163 |
+
stride_bk,
|
| 164 |
+
stride_bn,
|
| 165 |
+
stride_cm,
|
| 166 |
+
stride_cn,
|
| 167 |
+
# Meta-parameters
|
| 168 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 169 |
+
BLOCK_SIZE_N: tl.constexpr,
|
| 170 |
+
BLOCK_SIZE_K: tl.constexpr,
|
| 171 |
+
GROUP_SIZE_M: tl.constexpr,
|
| 172 |
+
MUL_ROUTED_WEIGHT: tl.constexpr,
|
| 173 |
+
top_k: tl.constexpr,
|
| 174 |
+
compute_type: tl.constexpr,
|
| 175 |
+
):
|
| 176 |
+
# Grouped ordering to promote L2 data reuse.
|
| 177 |
+
pid = tl.program_id(axis=0)
|
| 178 |
+
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
|
| 179 |
+
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
| 180 |
+
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
| 181 |
+
group_id = pid // num_pid_in_group
|
| 182 |
+
first_pid_m = group_id * GROUP_SIZE_M
|
| 183 |
+
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
| 184 |
+
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
|
| 185 |
+
pid_n = (pid % num_pid_in_group) // group_size_m
|
| 186 |
+
|
| 187 |
+
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
|
| 188 |
+
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
|
| 189 |
+
return
|
| 190 |
+
|
| 191 |
+
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
| 192 |
+
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
|
| 193 |
+
token_mask = offs_token < num_valid_tokens
|
| 194 |
+
|
| 195 |
+
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
|
| 196 |
+
if off_experts == -1:
|
| 197 |
+
_write_zeros_to_output(
|
| 198 |
+
c_ptr,
|
| 199 |
+
stride_cm,
|
| 200 |
+
stride_cn,
|
| 201 |
+
pid_n,
|
| 202 |
+
N,
|
| 203 |
+
offs_token,
|
| 204 |
+
token_mask,
|
| 205 |
+
BLOCK_SIZE_M,
|
| 206 |
+
BLOCK_SIZE_N,
|
| 207 |
+
compute_type,
|
| 208 |
+
)
|
| 209 |
+
return
|
| 210 |
+
|
| 211 |
+
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
|
| 212 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
| 213 |
+
a_ptrs = a_ptr + (
|
| 214 |
+
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
|
| 215 |
+
)
|
| 216 |
+
b_ptrs = (
|
| 217 |
+
b_ptr
|
| 218 |
+
+ off_experts * stride_be
|
| 219 |
+
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
| 223 |
+
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
| 224 |
+
a = tl.load(
|
| 225 |
+
a_ptrs,
|
| 226 |
+
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
| 227 |
+
other=0.0,
|
| 228 |
+
)
|
| 229 |
+
b = tl.load(
|
| 230 |
+
b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0
|
| 231 |
+
)
|
| 232 |
+
# Disable TF32 for numerical precision (matches PyTorch's default behavior)
|
| 233 |
+
accumulator += tl.dot(a, b, allow_tf32=False)
|
| 234 |
+
a_ptrs += BLOCK_SIZE_K * stride_ak
|
| 235 |
+
b_ptrs += BLOCK_SIZE_K * stride_bk
|
| 236 |
+
|
| 237 |
+
if MUL_ROUTED_WEIGHT:
|
| 238 |
+
moe_weight = tl.load(
|
| 239 |
+
topk_weights_ptr + offs_token, mask=token_mask, other=0
|
| 240 |
+
)
|
| 241 |
+
accumulator = accumulator * moe_weight[:, None]
|
| 242 |
+
accumulator = accumulator.to(compute_type)
|
| 243 |
+
|
| 244 |
+
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
| 245 |
+
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
|
| 246 |
+
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
| 247 |
+
tl.store(c_ptrs, accumulator, mask=c_mask)
|
| 248 |
+
|
| 249 |
+
else: # pragma: no cover
|
| 250 |
+
|
| 251 |
+
def _fused_moe_kernel(*args, **kwargs): # type: ignore[no-redef]
|
| 252 |
+
raise RuntimeError("Triton is not available; cannot use fused MoE.")
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def _invoke_fused_moe_kernel(
|
| 256 |
+
*,
|
| 257 |
+
A: torch.Tensor,
|
| 258 |
+
B: torch.Tensor,
|
| 259 |
+
C: torch.Tensor,
|
| 260 |
+
topk_weights: torch.Tensor | None,
|
| 261 |
+
sorted_token_ids: torch.Tensor,
|
| 262 |
+
expert_ids: torch.Tensor,
|
| 263 |
+
num_tokens_post_padded: torch.Tensor,
|
| 264 |
+
mul_routed_weight: bool,
|
| 265 |
+
top_k: int,
|
| 266 |
+
config: dict[str, Any],
|
| 267 |
+
compute_type: Any,
|
| 268 |
+
) -> None:
|
| 269 |
+
assert triton is not None and tl is not None
|
| 270 |
+
assert topk_weights is not None or not mul_routed_weight
|
| 271 |
+
assert sorted_token_ids.stride(0) == 1
|
| 272 |
+
|
| 273 |
+
M = A.size(0)
|
| 274 |
+
num_tokens = M * top_k
|
| 275 |
+
EM = sorted_token_ids.size(0)
|
| 276 |
+
grid = lambda META: (
|
| 277 |
+
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
config = dict(config)
|
| 281 |
+
_fused_moe_kernel[grid](
|
| 282 |
+
A,
|
| 283 |
+
B,
|
| 284 |
+
C,
|
| 285 |
+
topk_weights,
|
| 286 |
+
sorted_token_ids,
|
| 287 |
+
expert_ids,
|
| 288 |
+
num_tokens_post_padded,
|
| 289 |
+
B.size(1),
|
| 290 |
+
B.size(2),
|
| 291 |
+
EM,
|
| 292 |
+
num_tokens,
|
| 293 |
+
A.stride(0),
|
| 294 |
+
A.stride(1),
|
| 295 |
+
B.stride(0),
|
| 296 |
+
B.stride(2),
|
| 297 |
+
B.stride(1),
|
| 298 |
+
C.stride(1),
|
| 299 |
+
C.stride(2),
|
| 300 |
+
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
| 301 |
+
top_k=top_k,
|
| 302 |
+
compute_type=compute_type,
|
| 303 |
+
**config,
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def fused_experts_moe(
|
| 308 |
+
hidden_states: torch.Tensor,
|
| 309 |
+
w1: torch.Tensor,
|
| 310 |
+
w2: torch.Tensor,
|
| 311 |
+
topk_weights: torch.Tensor,
|
| 312 |
+
topk_ids: torch.Tensor,
|
| 313 |
+
*,
|
| 314 |
+
activation: str,
|
| 315 |
+
inplace: bool = False,
|
| 316 |
+
apply_router_weight_on_input: bool = False,
|
| 317 |
+
) -> torch.Tensor:
|
| 318 |
+
"""
|
| 319 |
+
Fused MoE expert compute (2-layer MLP) using Triton grouped GEMMs + vLLM's `_moe_C` align/sum kernels.
|
| 320 |
+
|
| 321 |
+
This function is intentionally minimal: it supports *non-gated* activations
|
| 322 |
+
(`*_no_mul`), which is the SuperlinearExp MoE case here.
|
| 323 |
+
"""
|
| 324 |
+
if torch.is_grad_enabled() and any(
|
| 325 |
+
t.requires_grad for t in (hidden_states, w1, w2, topk_weights)
|
| 326 |
+
):
|
| 327 |
+
return _FusedExpertsMoE.apply(
|
| 328 |
+
hidden_states,
|
| 329 |
+
w1,
|
| 330 |
+
w2,
|
| 331 |
+
topk_weights,
|
| 332 |
+
topk_ids,
|
| 333 |
+
activation,
|
| 334 |
+
apply_router_weight_on_input,
|
| 335 |
+
)
|
| 336 |
+
if hidden_states.numel() == 0:
|
| 337 |
+
return hidden_states
|
| 338 |
+
if hidden_states.dim() != 2:
|
| 339 |
+
raise ValueError(f"Expected [tokens, hidden], got {tuple(hidden_states.shape)}")
|
| 340 |
+
return _fused_experts_moe_forward(
|
| 341 |
+
hidden_states,
|
| 342 |
+
w1,
|
| 343 |
+
w2,
|
| 344 |
+
topk_weights,
|
| 345 |
+
topk_ids,
|
| 346 |
+
activation=activation,
|
| 347 |
+
inplace=inplace,
|
| 348 |
+
apply_router_weight_on_input=apply_router_weight_on_input,
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def _activation_forward(x: torch.Tensor, activation: str) -> torch.Tensor:
|
| 353 |
+
if activation == "relu2_no_mul":
|
| 354 |
+
return torch.square(F.relu(x))
|
| 355 |
+
if activation == "silu_no_mul":
|
| 356 |
+
return F.silu(x)
|
| 357 |
+
if activation == "gelu_no_mul":
|
| 358 |
+
return F.gelu(x)
|
| 359 |
+
raise ValueError(f"Unsupported fused MoE activation: {activation}")
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def _activation_backward(x_fp32: torch.Tensor, activation: str) -> torch.Tensor:
|
| 363 |
+
if activation == "relu2_no_mul":
|
| 364 |
+
return (x_fp32 > 0).to(x_fp32.dtype) * (2.0 * x_fp32)
|
| 365 |
+
if activation == "silu_no_mul":
|
| 366 |
+
sig = torch.sigmoid(x_fp32)
|
| 367 |
+
return sig * (1.0 + x_fp32 * (1.0 - sig))
|
| 368 |
+
if activation == "gelu_no_mul":
|
| 369 |
+
inv_sqrt2 = 1.0 / math.sqrt(2.0)
|
| 370 |
+
inv_sqrt2pi = 1.0 / math.sqrt(2.0 * math.pi)
|
| 371 |
+
cdf = 0.5 * (1.0 + torch.erf(x_fp32 * inv_sqrt2))
|
| 372 |
+
pdf = torch.exp(-0.5 * x_fp32 * x_fp32) * inv_sqrt2pi
|
| 373 |
+
return cdf + x_fp32 * pdf
|
| 374 |
+
raise ValueError(f"Unsupported fused MoE activation: {activation}")
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def _eager_experts_moe_forward(
|
| 378 |
+
hidden_states: torch.Tensor,
|
| 379 |
+
w1: torch.Tensor,
|
| 380 |
+
w2: torch.Tensor,
|
| 381 |
+
topk_weights: torch.Tensor,
|
| 382 |
+
topk_ids: torch.Tensor,
|
| 383 |
+
*,
|
| 384 |
+
activation: str,
|
| 385 |
+
apply_router_weight_on_input: bool,
|
| 386 |
+
) -> torch.Tensor:
|
| 387 |
+
if hidden_states.numel() == 0:
|
| 388 |
+
return hidden_states
|
| 389 |
+
if hidden_states.dim() != 2:
|
| 390 |
+
raise ValueError(f"Expected [tokens, hidden], got {tuple(hidden_states.shape)}")
|
| 391 |
+
|
| 392 |
+
num_tokens, hidden_size = hidden_states.shape
|
| 393 |
+
num_experts, intermediate_size, hidden_size_w1 = w1.shape
|
| 394 |
+
num_experts_w2, hidden_size_w2, intermediate_size_w2 = w2.shape
|
| 395 |
+
if hidden_size_w1 != hidden_size:
|
| 396 |
+
raise ValueError(f"Hidden size mismatch: {hidden_size} != {hidden_size_w1} (w1 in_features)")
|
| 397 |
+
if num_experts_w2 != num_experts:
|
| 398 |
+
raise ValueError(f"Expert count mismatch: {num_experts} != {num_experts_w2} (w2)")
|
| 399 |
+
if hidden_size_w2 != hidden_size:
|
| 400 |
+
raise ValueError(f"Hidden size mismatch: {hidden_size} != {hidden_size_w2} (w2 out_features)")
|
| 401 |
+
if intermediate_size_w2 != intermediate_size:
|
| 402 |
+
raise ValueError(f"Intermediate size mismatch: {intermediate_size} != {intermediate_size_w2} (w2 in_features)")
|
| 403 |
+
if topk_ids.shape != topk_weights.shape:
|
| 404 |
+
raise ValueError("topk_ids/topk_weights shape mismatch")
|
| 405 |
+
|
| 406 |
+
topk = topk_ids.size(1)
|
| 407 |
+
out = torch.zeros((num_tokens, hidden_size), device=hidden_states.device, dtype=torch.float32)
|
| 408 |
+
|
| 409 |
+
CHUNK_SIZE = int(os.getenv("FUSED_MOE_CHUNK_SIZE", str(16 * 1024)))
|
| 410 |
+
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
|
| 411 |
+
begin = chunk * CHUNK_SIZE
|
| 412 |
+
end = min((chunk + 1) * CHUNK_SIZE, num_tokens)
|
| 413 |
+
x = hidden_states[begin:end]
|
| 414 |
+
if x.numel() == 0:
|
| 415 |
+
break
|
| 416 |
+
|
| 417 |
+
m = x.size(0)
|
| 418 |
+
topk_ids_chunk = topk_ids[begin:end].reshape(-1).to(torch.long)
|
| 419 |
+
topk_weights_chunk = topk_weights[begin:end].reshape(-1).to(torch.float32)
|
| 420 |
+
|
| 421 |
+
token_ids = (
|
| 422 |
+
torch.arange(m, device=x.device, dtype=torch.long)
|
| 423 |
+
.repeat_interleave(topk)
|
| 424 |
+
)
|
| 425 |
+
k_ids = torch.arange(topk, device=x.device, dtype=torch.long).repeat(m)
|
| 426 |
+
|
| 427 |
+
sort_order = torch.argsort(topk_ids_chunk)
|
| 428 |
+
expert_ids_sorted = topk_ids_chunk[sort_order]
|
| 429 |
+
token_ids_sorted = token_ids[sort_order]
|
| 430 |
+
k_ids_sorted = k_ids[sort_order]
|
| 431 |
+
weights_sorted = topk_weights_chunk[sort_order]
|
| 432 |
+
|
| 433 |
+
unique_experts, counts = torch.unique_consecutive(expert_ids_sorted, return_counts=True)
|
| 434 |
+
out_chunk = torch.zeros((m, hidden_size), device=x.device, dtype=torch.float32)
|
| 435 |
+
|
| 436 |
+
offset = 0
|
| 437 |
+
for expert_idx, count in zip(unique_experts.tolist(), counts.tolist()):
|
| 438 |
+
tokens = token_ids_sorted[offset : offset + count]
|
| 439 |
+
weights = weights_sorted[offset : offset + count]
|
| 440 |
+
offset += count
|
| 441 |
+
if count == 0:
|
| 442 |
+
continue
|
| 443 |
+
|
| 444 |
+
x_e = x.index_select(0, tokens)
|
| 445 |
+
u0 = F.linear(x_e, w1[expert_idx])
|
| 446 |
+
if apply_router_weight_on_input:
|
| 447 |
+
u = u0 * weights.to(u0.dtype).unsqueeze(-1)
|
| 448 |
+
else:
|
| 449 |
+
u = u0
|
| 450 |
+
a = _activation_forward(u, activation)
|
| 451 |
+
v = F.linear(a, w2[expert_idx])
|
| 452 |
+
if not apply_router_weight_on_input:
|
| 453 |
+
v = v * weights.to(v.dtype).unsqueeze(-1)
|
| 454 |
+
out_chunk.index_add_(0, tokens, v.to(torch.float32))
|
| 455 |
+
|
| 456 |
+
out[begin:end] = out_chunk
|
| 457 |
+
|
| 458 |
+
return out.to(hidden_states.dtype)
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def _fused_experts_moe_forward(
|
| 462 |
+
hidden_states: torch.Tensor,
|
| 463 |
+
w1: torch.Tensor,
|
| 464 |
+
w2: torch.Tensor,
|
| 465 |
+
topk_weights: torch.Tensor,
|
| 466 |
+
topk_ids: torch.Tensor,
|
| 467 |
+
*,
|
| 468 |
+
activation: str,
|
| 469 |
+
inplace: bool = False,
|
| 470 |
+
apply_router_weight_on_input: bool = False,
|
| 471 |
+
) -> torch.Tensor:
|
| 472 |
+
if not shared_fused_moe_is_available():
|
| 473 |
+
return _eager_experts_moe_forward(
|
| 474 |
+
hidden_states,
|
| 475 |
+
w1,
|
| 476 |
+
w2,
|
| 477 |
+
topk_weights,
|
| 478 |
+
topk_ids,
|
| 479 |
+
activation=activation,
|
| 480 |
+
apply_router_weight_on_input=apply_router_weight_on_input,
|
| 481 |
+
)
|
| 482 |
+
if not hidden_states.is_cuda:
|
| 483 |
+
return _eager_experts_moe_forward(
|
| 484 |
+
hidden_states,
|
| 485 |
+
w1,
|
| 486 |
+
w2,
|
| 487 |
+
topk_weights,
|
| 488 |
+
topk_ids,
|
| 489 |
+
activation=activation,
|
| 490 |
+
apply_router_weight_on_input=apply_router_weight_on_input,
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
# Constraints similar to vLLM's fused kernels.
|
| 494 |
+
if not hidden_states.is_contiguous():
|
| 495 |
+
hidden_states = hidden_states.contiguous()
|
| 496 |
+
if w1.stride(-1) != 1 or w2.stride(-1) != 1:
|
| 497 |
+
raise ValueError("Expert weights must be contiguous in the last dimension.")
|
| 498 |
+
|
| 499 |
+
# Shapes.
|
| 500 |
+
num_tokens = hidden_states.size(0)
|
| 501 |
+
num_experts, n1, k1 = w1.size()
|
| 502 |
+
_, k2, n2 = w2.size()
|
| 503 |
+
if hidden_states.size(1) != k1:
|
| 504 |
+
raise ValueError(
|
| 505 |
+
f"Hidden size mismatch: {hidden_states.size(1)} != {k1} (w1 in_features)"
|
| 506 |
+
)
|
| 507 |
+
if n2 != n1:
|
| 508 |
+
raise ValueError(f"Intermediate size mismatch: {n2} != {n1}")
|
| 509 |
+
if topk_ids.shape != topk_weights.shape:
|
| 510 |
+
raise ValueError("topk_ids/topk_weights shape mismatch")
|
| 511 |
+
|
| 512 |
+
topk = topk_ids.size(1)
|
| 513 |
+
CHUNK_SIZE = int(os.getenv("FUSED_MOE_CHUNK_SIZE", str(16 * 1024)))
|
| 514 |
+
M = min(num_tokens, CHUNK_SIZE)
|
| 515 |
+
|
| 516 |
+
config = _get_default_config(M=M, E=num_experts, N=n1, K=k1, topk=topk)
|
| 517 |
+
|
| 518 |
+
if hidden_states.dtype == torch.bfloat16:
|
| 519 |
+
compute_type = tl.bfloat16
|
| 520 |
+
elif hidden_states.dtype == torch.float16:
|
| 521 |
+
compute_type = tl.float16
|
| 522 |
+
elif hidden_states.dtype == torch.float32:
|
| 523 |
+
compute_type = tl.float32
|
| 524 |
+
else:
|
| 525 |
+
raise ValueError(f"Unsupported dtype: {hidden_states.dtype}")
|
| 526 |
+
|
| 527 |
+
# Accumulate in float32 for numerical precision (matches eager path behavior).
|
| 528 |
+
# The output will be converted back to the original dtype at the end.
|
| 529 |
+
original_dtype = hidden_states.dtype
|
| 530 |
+
out = torch.zeros(
|
| 531 |
+
(num_tokens, hidden_states.size(1)),
|
| 532 |
+
device=hidden_states.device,
|
| 533 |
+
dtype=torch.float32,
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
# Cache buffers sized to the largest chunk.
|
| 537 |
+
# IMPORTANT: up_out and down_out must NOT overlap in memory!
|
| 538 |
+
# The down projection kernel reads from up_out while writing to down_out.
|
| 539 |
+
# If they share memory, the kernel will corrupt its input as it writes output.
|
| 540 |
+
up_out = torch.empty(
|
| 541 |
+
(M, topk, n1), device=hidden_states.device, dtype=hidden_states.dtype
|
| 542 |
+
)
|
| 543 |
+
down_out = torch.empty(
|
| 544 |
+
(M, topk, k2), device=hidden_states.device, dtype=hidden_states.dtype
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
|
| 548 |
+
begin = chunk * CHUNK_SIZE
|
| 549 |
+
end = min((chunk + 1) * CHUNK_SIZE, num_tokens)
|
| 550 |
+
curr_hidden = hidden_states[begin:end]
|
| 551 |
+
tokens_in_chunk = curr_hidden.size(0)
|
| 552 |
+
if tokens_in_chunk == 0:
|
| 553 |
+
break
|
| 554 |
+
|
| 555 |
+
if tokens_in_chunk != M:
|
| 556 |
+
up_out = up_out[:tokens_in_chunk]
|
| 557 |
+
down_out = down_out[:tokens_in_chunk]
|
| 558 |
+
config = _get_default_config(M=tokens_in_chunk, E=num_experts, N=n1, K=k1, topk=topk)
|
| 559 |
+
|
| 560 |
+
curr_topk_ids = topk_ids[begin:end].to(torch.int32).contiguous()
|
| 561 |
+
curr_topk_weights = topk_weights[begin:end].to(torch.float32).contiguous()
|
| 562 |
+
|
| 563 |
+
sorted_token_ids, expert_ids, num_tokens_post_padded = _moe_align_block_size(
|
| 564 |
+
curr_topk_ids,
|
| 565 |
+
config["BLOCK_SIZE_M"],
|
| 566 |
+
num_experts,
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
# 1) Up projection: [tokens, hidden] -> [tokens * topk, intermediate]
|
| 570 |
+
_invoke_fused_moe_kernel(
|
| 571 |
+
A=curr_hidden,
|
| 572 |
+
B=w1,
|
| 573 |
+
C=up_out,
|
| 574 |
+
topk_weights=curr_topk_weights if apply_router_weight_on_input else None,
|
| 575 |
+
sorted_token_ids=sorted_token_ids,
|
| 576 |
+
expert_ids=expert_ids,
|
| 577 |
+
num_tokens_post_padded=num_tokens_post_padded,
|
| 578 |
+
mul_routed_weight=apply_router_weight_on_input,
|
| 579 |
+
top_k=topk,
|
| 580 |
+
config=config,
|
| 581 |
+
compute_type=compute_type,
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
# 2) Activation (in-place on up_out to avoid allocating an extra buffer).
|
| 585 |
+
if activation == "relu2_no_mul":
|
| 586 |
+
x = up_out.view(-1, n1)
|
| 587 |
+
x.relu_()
|
| 588 |
+
x.square_()
|
| 589 |
+
elif activation == "silu_no_mul":
|
| 590 |
+
x = up_out.view(-1, n1)
|
| 591 |
+
x.copy_(F.silu(x))
|
| 592 |
+
elif activation == "gelu_no_mul":
|
| 593 |
+
x = up_out.view(-1, n1)
|
| 594 |
+
x.copy_(F.gelu(x))
|
| 595 |
+
else:
|
| 596 |
+
raise ValueError(f"Unsupported fused MoE activation: {activation}")
|
| 597 |
+
|
| 598 |
+
# 3) Down projection: [tokens * topk, intermediate] -> [tokens * topk, hidden]
|
| 599 |
+
_invoke_fused_moe_kernel(
|
| 600 |
+
A=up_out.view(-1, n1),
|
| 601 |
+
B=w2,
|
| 602 |
+
C=down_out,
|
| 603 |
+
topk_weights=None if apply_router_weight_on_input else curr_topk_weights,
|
| 604 |
+
sorted_token_ids=sorted_token_ids,
|
| 605 |
+
expert_ids=expert_ids,
|
| 606 |
+
num_tokens_post_padded=num_tokens_post_padded,
|
| 607 |
+
mul_routed_weight=not apply_router_weight_on_input,
|
| 608 |
+
top_k=1,
|
| 609 |
+
config=config,
|
| 610 |
+
compute_type=compute_type,
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
# Convert down_out to float32 to match out's dtype for moe_sum
|
| 614 |
+
torch.ops._moe_C.moe_sum(
|
| 615 |
+
down_out.view(*down_out.size()).to(torch.float32),
|
| 616 |
+
out[begin:end],
|
| 617 |
+
)
|
| 618 |
+
|
| 619 |
+
# Convert back to original dtype after accumulation in float32
|
| 620 |
+
return out.to(original_dtype)
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
class _FusedExpertsMoE(torch.autograd.Function):
|
| 624 |
+
@staticmethod
|
| 625 |
+
def forward(
|
| 626 |
+
ctx,
|
| 627 |
+
hidden_states: torch.Tensor,
|
| 628 |
+
w1: torch.Tensor,
|
| 629 |
+
w2: torch.Tensor,
|
| 630 |
+
topk_weights: torch.Tensor,
|
| 631 |
+
topk_ids: torch.Tensor,
|
| 632 |
+
activation: str,
|
| 633 |
+
apply_router_weight_on_input: bool,
|
| 634 |
+
) -> torch.Tensor:
|
| 635 |
+
ctx.activation = activation
|
| 636 |
+
ctx.apply_router_weight_on_input = apply_router_weight_on_input
|
| 637 |
+
ctx.save_for_backward(hidden_states, w1, w2, topk_weights, topk_ids)
|
| 638 |
+
return _fused_experts_moe_forward(
|
| 639 |
+
hidden_states,
|
| 640 |
+
w1,
|
| 641 |
+
w2,
|
| 642 |
+
topk_weights,
|
| 643 |
+
topk_ids,
|
| 644 |
+
activation=activation,
|
| 645 |
+
inplace=False,
|
| 646 |
+
apply_router_weight_on_input=apply_router_weight_on_input,
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
@staticmethod
|
| 650 |
+
def backward(ctx, grad_out: torch.Tensor):
|
| 651 |
+
(
|
| 652 |
+
hidden_states,
|
| 653 |
+
w1,
|
| 654 |
+
w2,
|
| 655 |
+
topk_weights,
|
| 656 |
+
topk_ids,
|
| 657 |
+
) = ctx.saved_tensors
|
| 658 |
+
activation: str = ctx.activation
|
| 659 |
+
apply_router_weight_on_input: bool = ctx.apply_router_weight_on_input
|
| 660 |
+
|
| 661 |
+
need_hidden, need_w1, need_w2, need_topk_w = ctx.needs_input_grad[:4]
|
| 662 |
+
|
| 663 |
+
grad_hidden = torch.zeros_like(hidden_states) if need_hidden else None
|
| 664 |
+
grad_w1 = torch.zeros_like(w1) if need_w1 else None
|
| 665 |
+
grad_w2 = torch.zeros_like(w2) if need_w2 else None
|
| 666 |
+
grad_topk_weights = torch.zeros_like(topk_weights) if need_topk_w else None
|
| 667 |
+
|
| 668 |
+
if hidden_states.numel() == 0:
|
| 669 |
+
return grad_hidden, grad_w1, grad_w2, grad_topk_weights, None, None, None
|
| 670 |
+
|
| 671 |
+
num_tokens = hidden_states.size(0)
|
| 672 |
+
topk = topk_ids.size(1)
|
| 673 |
+
num_experts = w1.size(0)
|
| 674 |
+
|
| 675 |
+
CHUNK_SIZE = int(os.getenv("FUSED_MOE_CHUNK_SIZE", str(16 * 1024)))
|
| 676 |
+
max_padded_tokens_per_expert = int(
|
| 677 |
+
os.getenv("FUSED_MOE_BACKWARD_MAX_PADDED_TOKENS_PER_EXPERT", "2048")
|
| 678 |
+
)
|
| 679 |
+
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
|
| 680 |
+
begin = chunk * CHUNK_SIZE
|
| 681 |
+
end = min((chunk + 1) * CHUNK_SIZE, num_tokens)
|
| 682 |
+
x = hidden_states[begin:end]
|
| 683 |
+
if x.numel() == 0:
|
| 684 |
+
break
|
| 685 |
+
|
| 686 |
+
m = x.size(0)
|
| 687 |
+
token_ids = (
|
| 688 |
+
torch.arange(m, device=x.device, dtype=torch.long)
|
| 689 |
+
.repeat_interleave(topk)
|
| 690 |
+
)
|
| 691 |
+
k_ids = torch.arange(topk, device=x.device, dtype=torch.long).repeat(m)
|
| 692 |
+
expert_ids = topk_ids[begin:end].reshape(-1).to(torch.long)
|
| 693 |
+
weights_fp32 = topk_weights[begin:end].reshape(-1).to(torch.float32)
|
| 694 |
+
|
| 695 |
+
sort_order = torch.argsort(expert_ids)
|
| 696 |
+
expert_ids_sorted = expert_ids[sort_order]
|
| 697 |
+
token_ids_sorted = token_ids[sort_order]
|
| 698 |
+
k_ids_sorted = k_ids[sort_order]
|
| 699 |
+
weights_sorted = weights_fp32[sort_order]
|
| 700 |
+
|
| 701 |
+
counts_per_expert = torch.bincount(expert_ids_sorted, minlength=num_experts)
|
| 702 |
+
max_count = int(counts_per_expert.max().item())
|
| 703 |
+
|
| 704 |
+
use_vectorized = (
|
| 705 |
+
max_count > 0 and max_count <= max_padded_tokens_per_expert
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
if use_vectorized:
|
| 709 |
+
hidden_size = x.size(1)
|
| 710 |
+
offsets = torch.cumsum(counts_per_expert, 0) - counts_per_expert
|
| 711 |
+
pos_in_expert = torch.arange(
|
| 712 |
+
expert_ids_sorted.numel(), device=x.device, dtype=torch.long
|
| 713 |
+
) - offsets[expert_ids_sorted]
|
| 714 |
+
flat = expert_ids_sorted * max_count + pos_in_expert
|
| 715 |
+
|
| 716 |
+
x_pad = torch.zeros(
|
| 717 |
+
(num_experts, max_count, hidden_size),
|
| 718 |
+
device=x.device,
|
| 719 |
+
dtype=x.dtype,
|
| 720 |
+
)
|
| 721 |
+
x_pad.view(num_experts * max_count, hidden_size)[flat] = x[
|
| 722 |
+
token_ids_sorted
|
| 723 |
+
]
|
| 724 |
+
|
| 725 |
+
gy_pad = torch.zeros(
|
| 726 |
+
(num_experts, max_count, hidden_size),
|
| 727 |
+
device=x.device,
|
| 728 |
+
dtype=torch.float32,
|
| 729 |
+
)
|
| 730 |
+
gy_pad.view(num_experts * max_count, hidden_size)[flat] = grad_out[
|
| 731 |
+
begin:end
|
| 732 |
+
][token_ids_sorted].to(torch.float32)
|
| 733 |
+
|
| 734 |
+
w_pad = torch.zeros(
|
| 735 |
+
(num_experts, max_count),
|
| 736 |
+
device=x.device,
|
| 737 |
+
dtype=torch.float32,
|
| 738 |
+
)
|
| 739 |
+
w_pad.view(num_experts * max_count)[flat] = weights_sorted
|
| 740 |
+
|
| 741 |
+
u0 = torch.einsum("emh,eih->emi", x_pad, w1)
|
| 742 |
+
if apply_router_weight_on_input:
|
| 743 |
+
u = u0 * w_pad.to(u0.dtype).unsqueeze(-1)
|
| 744 |
+
else:
|
| 745 |
+
u = u0
|
| 746 |
+
a = _activation_forward(u, activation)
|
| 747 |
+
|
| 748 |
+
tmp = torch.einsum("emh,ehi->emi", gy_pad.to(a.dtype), w2)
|
| 749 |
+
tmp_fp32 = tmp.to(torch.float32)
|
| 750 |
+
|
| 751 |
+
if need_w2:
|
| 752 |
+
if apply_router_weight_on_input:
|
| 753 |
+
grad_v = gy_pad.to(a.dtype)
|
| 754 |
+
else:
|
| 755 |
+
grad_v = (gy_pad * w_pad.unsqueeze(-1)).to(a.dtype)
|
| 756 |
+
grad_w2_chunk = torch.einsum("emh,emi->ehi", grad_v, a)
|
| 757 |
+
assert grad_w2 is not None
|
| 758 |
+
grad_w2.add_(grad_w2_chunk.to(grad_w2.dtype))
|
| 759 |
+
|
| 760 |
+
gA_fp32 = tmp_fp32
|
| 761 |
+
if not apply_router_weight_on_input:
|
| 762 |
+
gA_fp32 = gA_fp32 * w_pad.unsqueeze(-1)
|
| 763 |
+
|
| 764 |
+
du_fp32 = _activation_backward(u.to(torch.float32), activation)
|
| 765 |
+
gU_fp32 = gA_fp32 * du_fp32
|
| 766 |
+
|
| 767 |
+
if apply_router_weight_on_input:
|
| 768 |
+
if need_topk_w:
|
| 769 |
+
grad_w_fp32 = torch.sum(
|
| 770 |
+
gU_fp32 * u0.to(torch.float32),
|
| 771 |
+
dim=-1,
|
| 772 |
+
)
|
| 773 |
+
gU0_fp32 = gU_fp32 * w_pad.unsqueeze(-1)
|
| 774 |
+
else:
|
| 775 |
+
if need_topk_w:
|
| 776 |
+
grad_w_fp32 = torch.sum(
|
| 777 |
+
a.to(torch.float32) * tmp_fp32,
|
| 778 |
+
dim=-1,
|
| 779 |
+
)
|
| 780 |
+
gU0_fp32 = gU_fp32
|
| 781 |
+
|
| 782 |
+
gU0 = gU0_fp32.to(x.dtype)
|
| 783 |
+
|
| 784 |
+
if need_w1:
|
| 785 |
+
grad_w1_chunk = torch.einsum("emi,emh->eih", gU0, x_pad)
|
| 786 |
+
assert grad_w1 is not None
|
| 787 |
+
grad_w1.add_(grad_w1_chunk.to(grad_w1.dtype))
|
| 788 |
+
|
| 789 |
+
if need_hidden:
|
| 790 |
+
grad_x_pad = torch.einsum("emi,eih->emh", gU0, w1)
|
| 791 |
+
grad_x_assign = grad_x_pad.view(
|
| 792 |
+
num_experts * max_count, hidden_size
|
| 793 |
+
)[flat]
|
| 794 |
+
grad_x_chunk = torch.zeros(
|
| 795 |
+
(m, hidden_size), device=x.device, dtype=x.dtype
|
| 796 |
+
)
|
| 797 |
+
grad_x_chunk.index_add_(0, token_ids_sorted, grad_x_assign)
|
| 798 |
+
assert grad_hidden is not None
|
| 799 |
+
grad_hidden[begin:end].copy_(grad_x_chunk)
|
| 800 |
+
|
| 801 |
+
if need_topk_w:
|
| 802 |
+
grad_w_assign = grad_w_fp32.view(num_experts * max_count)[flat]
|
| 803 |
+
grad_topk_chunk = torch.zeros(
|
| 804 |
+
(m, topk), device=x.device, dtype=topk_weights.dtype
|
| 805 |
+
)
|
| 806 |
+
grad_topk_chunk[token_ids_sorted, k_ids_sorted] = grad_w_assign.to(
|
| 807 |
+
grad_topk_chunk.dtype
|
| 808 |
+
)
|
| 809 |
+
assert grad_topk_weights is not None
|
| 810 |
+
grad_topk_weights[begin:end].copy_(grad_topk_chunk)
|
| 811 |
+
|
| 812 |
+
continue
|
| 813 |
+
|
| 814 |
+
unique_experts, counts = torch.unique_consecutive(
|
| 815 |
+
expert_ids_sorted, return_counts=True
|
| 816 |
+
)
|
| 817 |
+
|
| 818 |
+
grad_x_chunk = torch.zeros((m, x.size(1)), device=x.device, dtype=x.dtype) if need_hidden else None
|
| 819 |
+
grad_topk_chunk = torch.zeros((m, topk), device=x.device, dtype=topk_weights.dtype) if need_topk_w else None
|
| 820 |
+
|
| 821 |
+
offset = 0
|
| 822 |
+
for expert_idx, count in zip(unique_experts.tolist(), counts.tolist()):
|
| 823 |
+
tokens = token_ids_sorted[offset : offset + count]
|
| 824 |
+
ks = k_ids_sorted[offset : offset + count]
|
| 825 |
+
w = weights_sorted[offset : offset + count]
|
| 826 |
+
offset += count
|
| 827 |
+
if count == 0:
|
| 828 |
+
continue
|
| 829 |
+
|
| 830 |
+
x_e = x.index_select(0, tokens)
|
| 831 |
+
w1_e = w1[expert_idx]
|
| 832 |
+
w2_e = w2[expert_idx]
|
| 833 |
+
|
| 834 |
+
u0 = F.linear(x_e, w1_e)
|
| 835 |
+
if apply_router_weight_on_input:
|
| 836 |
+
u = u0 * w.to(u0.dtype).unsqueeze(-1)
|
| 837 |
+
else:
|
| 838 |
+
u = u0
|
| 839 |
+
a = _activation_forward(u, activation)
|
| 840 |
+
|
| 841 |
+
grad_y_fp32 = grad_out[begin:end].index_select(0, tokens).to(torch.float32)
|
| 842 |
+
if apply_router_weight_on_input:
|
| 843 |
+
grad_v_fp32 = grad_y_fp32
|
| 844 |
+
else:
|
| 845 |
+
grad_v_fp32 = grad_y_fp32 * w.unsqueeze(-1)
|
| 846 |
+
|
| 847 |
+
grad_v = grad_v_fp32.to(a.dtype)
|
| 848 |
+
|
| 849 |
+
if need_w2:
|
| 850 |
+
grad_w2_e = torch.matmul(grad_v.transpose(0, 1), a)
|
| 851 |
+
assert grad_w2 is not None
|
| 852 |
+
grad_w2[expert_idx].add_(grad_w2_e.to(grad_w2.dtype))
|
| 853 |
+
|
| 854 |
+
gA = torch.matmul(grad_v, w2_e)
|
| 855 |
+
du_fp32 = _activation_backward(u.to(torch.float32), activation)
|
| 856 |
+
gU_fp32 = gA.to(torch.float32) * du_fp32
|
| 857 |
+
|
| 858 |
+
if apply_router_weight_on_input:
|
| 859 |
+
grad_w_fp32 = torch.sum(gU_fp32 * u0.to(torch.float32), dim=-1)
|
| 860 |
+
gU0_fp32 = gU_fp32 * w.unsqueeze(-1)
|
| 861 |
+
else:
|
| 862 |
+
gy_for_w = grad_y_fp32.to(a.dtype)
|
| 863 |
+
tmp = torch.matmul(gy_for_w, w2_e).to(torch.float32)
|
| 864 |
+
grad_w_fp32 = torch.sum(a.to(torch.float32) * tmp, dim=-1)
|
| 865 |
+
gU0_fp32 = gU_fp32
|
| 866 |
+
|
| 867 |
+
if need_topk_w:
|
| 868 |
+
assert grad_topk_chunk is not None
|
| 869 |
+
grad_topk_chunk[tokens, ks] = grad_w_fp32.to(grad_topk_chunk.dtype)
|
| 870 |
+
|
| 871 |
+
gU0 = gU0_fp32.to(x_e.dtype)
|
| 872 |
+
|
| 873 |
+
if need_w1:
|
| 874 |
+
grad_w1_e = torch.matmul(gU0.transpose(0, 1), x_e)
|
| 875 |
+
assert grad_w1 is not None
|
| 876 |
+
grad_w1[expert_idx].add_(grad_w1_e.to(grad_w1.dtype))
|
| 877 |
+
|
| 878 |
+
if need_hidden:
|
| 879 |
+
assert grad_x_chunk is not None
|
| 880 |
+
grad_x_e = torch.matmul(gU0, w1_e)
|
| 881 |
+
grad_x_chunk.index_add_(0, tokens, grad_x_e)
|
| 882 |
+
|
| 883 |
+
if need_hidden:
|
| 884 |
+
assert grad_hidden is not None and grad_x_chunk is not None
|
| 885 |
+
grad_hidden[begin:end].copy_(grad_x_chunk)
|
| 886 |
+
if need_topk_w:
|
| 887 |
+
assert grad_topk_weights is not None and grad_topk_chunk is not None
|
| 888 |
+
grad_topk_weights[begin:end].copy_(grad_topk_chunk)
|
| 889 |
+
|
| 890 |
+
return grad_hidden, grad_w1, grad_w2, grad_topk_weights, None, None, None
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "<|im_end|>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": {
|
| 17 |
+
"content": "<|im_end|>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"unk_token": {
|
| 24 |
+
"content": "<unk>",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
}
|
| 30 |
+
}
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:623c34567aebb18582765289fbe23d901c62704d6518d71866e0e58db892b5b7
|
| 3 |
+
size 17077484
|
tokenizer_config.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|