hif1000 committed
Commit 06ea793 · verified · Parent(s): 5519a8a

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ __pycache__/modeling_superlinear_exp.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
+ __pycache__/modeling_superlinear_exp.cpython-313.pyc filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,63 @@
+ NVIDIA Open Model License Agreement
+
+ Last Modified: October 24, 2025
+
+ This NVIDIA Open Model License Agreement (the "Agreement") is a legal agreement between the Legal Entity You represent, or if no entity is identified, You and NVIDIA Corporation and its Affiliates ("NVIDIA") and governs Your use of the Models that NVIDIA provides to You under this Agreement. NVIDIA and You are each a "party" and collectively the "parties."
+
+ NVIDIA models released under this Agreement are intended to be used permissively and enable the further development of AI technologies. Subject to the terms of this Agreement, NVIDIA confirms that:
+
+ - Models are commercially usable.
+ - You are free to create and distribute Derivative Models.
+ - NVIDIA does not claim ownership to any outputs generated using the Models or Derivative Models.
+
+ By using, reproducing, modifying, distributing, performing or displaying any portion or element of the Model or Derivative Model, or otherwise accepting the terms of this Agreement, you agree to be bound by this Agreement.
+
+ 1. Definitions. The following definitions apply to this Agreement:
+
+ "Derivative Model" means all (a) modifications to the Model, (b) works based on the Model, and (c) any other derivative works of the Model. An output is not a Derivative Model.
+
+ "Legal Entity" means the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (a) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (b) ownership of fifty percent (50%) or more of the outstanding shares, or (c) beneficial ownership of such entity.
+
+ "Model" means the machine learning model, software, checkpoints, learnt weights, algorithms, parameters, configuration files and documentation shared under this Agreement.
+
+ "NVIDIA Cosmos Model" means a multimodal Model shared under this Agreement.
+
+ "Special-Purpose Model" means a Model that is only competent in a narrow set of purpose-specific tasks and should not be used for unintended or general-purpose applications.
+
+ "You" or "Your" means an individual or Legal Entity exercising permissions granted by this Agreement.
+
+ 2. Conditions for Use, License Grant, AI Ethics and IP Ownership.
+
+ 2.1 Conditions for Use. The Model and any Derivative Model are subject to additional terms as described in Section 2 and Section 3 of this Agreement and govern Your use. If You institute copyright or patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model or a Derivative Model constitutes direct or contributory copyright or patent infringement, then any licenses granted to You under this Agreement for that Model or Derivative Model will terminate as of the date such litigation is filed. If You bypass, disable, reduce the efficacy of, or circumvent any technical limitation, safety guardrail or associated safety guardrail hyperparameter, encryption, security, digital rights management, or authentication mechanism (collectively "Guardrail") contained in the Model without a substantially similar Guardrail appropriate for your use case, your rights under this Agreement will automatically terminate. NVIDIA may indicate in relevant documentation that a Model is a Special-Purpose Model. NVIDIA may update this Agreement to comply with legal and regulatory requirements at any time and You agree to either comply with any updated license or cease Your copying, use, and distribution of the Model and any Derivative Model.
+
+ 2.2 License Grant. The rights granted herein are explicitly conditioned on Your full compliance with the terms of this Agreement. Subject to the terms and conditions of this Agreement, NVIDIA hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, revocable (as stated in Section 2.1) license to publicly perform, publicly display, reproduce, use, create derivative works of, make, have made, sell, offer for sale, distribute (through multiple tiers of distribution) and import the Model.
+
+ 2.3 AI Ethics. Use of the Models under the Agreement must be consistent with NVIDIA's Trustworthy AI terms found at https://www.nvidia.com/en-us/agreements/trustworthy-ai/terms/.
+
+ 2.4 NVIDIA owns the Model and any Derivative Models created by NVIDIA. Subject to NVIDIA's underlying ownership rights in the Model or its Derivative Models, You are and will be the owner of Your Derivative Models. NVIDIA claims no ownership rights in outputs. You are responsible for outputs and their subsequent uses. Except as expressly granted in this Agreement, (a) NVIDIA reserves all rights, interests and remedies in connection with the Model and (b) no other license or right is granted to you by implication, estoppel or otherwise.
+
+ 3. Redistribution. You may reproduce and distribute copies of the Model or Derivative Models thereof in any medium, with or without modifications, provided that You meet the following conditions:
+
+ 3.1 If you distribute the Model, You must give any other recipients of the Model a copy of this Agreement and include the following attribution notice within a "Notice" text file with such copies: "Licensed by NVIDIA Corporation under the NVIDIA Open Model License";
+
+ 3.2 If you distribute or make available a NVIDIA Cosmos Model, or a product or service (including an AI model) that contains or uses a NVIDIA Cosmos Model, use a NVIDIA Cosmos Model to create a Derivative Model, or use a NVIDIA Cosmos Model or its outputs to create, train, fine tune, or otherwise improve an AI model, you will include "Built on NVIDIA Cosmos" on a related website, user interface, blogpost, about page, or product documentation; and
+
+ 3.3 You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Models as a whole, provided Your use, reproduction, and distribution of the Model otherwise complies with the conditions stated in this Agreement.
+
+ 4. Separate Components. The Models may include or be distributed with components provided with separate legal notices or terms that accompany the components, such as an Open Source Software License or other third-party license. The components are subject to the applicable other licenses, including any proprietary notices, disclaimers, requirements and extended use rights; except that this Agreement will prevail regarding the use of third-party Open Source Software License, unless a third-party Open Source Software License requires its license terms to prevail. "Open Source Software License" means any software, data or documentation subject to any license identified as an open source license by the Open Source Initiative (https://opensource.org), Free Software Foundation (https://www.fsf.org) or other similar open source organization or listed by the Software Package Data Exchange (SPDX) Workgroup under the Linux Foundation (https://www.spdx.org).
+
+ 5. Trademarks. This Agreement does not grant permission to use the trade names, trademarks, service marks, or product names of NVIDIA, except as required for reasonable and customary use in describing the origin of the Model and reproducing the content of the "Notice" text file.
+
+ 6. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, NVIDIA provides the Model on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for reviewing Model documentation, including any Special-Purpose Model limitations, and determining the appropriateness of using or redistributing the Model, Derivative Models and outputs. You assume any risks associated with Your exercise of permissions under this Agreement.
+
+ 7. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, will NVIDIA be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use the Model, Derivative Models or outputs (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if NVIDIA has been advised of the possibility of such damages.
+
+ 8. Indemnity. You will indemnify and hold harmless NVIDIA from and against any claim by any third party arising out of or related to your use or distribution of the Model, Derivative Models or outputs.
+
+ 9. Feedback. NVIDIA appreciates your feedback, and You agree that NVIDIA may use it without restriction or compensation to You.
+
+ 10. Governing Law. This Agreement will be governed in all respects by the laws of the United States and the laws of the State of Delaware, without regard to conflict of laws principles or the United Nations Convention on Contracts for the International Sale of Goods. The state and federal courts residing in Santa Clara County, California will have exclusive jurisdiction over any dispute or claim arising out of or related to this Agreement, and the parties irrevocably consent to personal jurisdiction and venue in those courts; except that, either party may apply for injunctive remedies or an equivalent type of urgent legal relief in any jurisdiction.
+
+ 11. Trade and Compliance. You agree to comply with all applicable export, import, trade and economic sanctions laws and regulations, as amended, including without limitation U.S. Export Administration Regulations and Office of Foreign Assets Control regulations. These laws include restrictions on destinations, end-users and end-use.
+
+ Version Release Date: October 24, 2025
NOTICE ADDED
@@ -0,0 +1,9 @@
+ Licensed by NVIDIA Corporation under the NVIDIA Open Model License
+
+ This model (superlinear-exp-v0.1) is a Derivative Model based on NVIDIA Nemotron-3-Nano-30B-A3B-BF16.
+
+ Upstream model: https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
+
+ Modifications by: Concavity AI (Yufeng Huang)
+ - Replaced standard attention layers with Superlinear attention layers
+ - Fine-tuned on long-context retrieval tasks
README.md CHANGED
@@ -1,6 +1,243 @@
- ---
- license: other
- license_name: nvidia-open-model-license
- license_link: >-
-   https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license
- ---
+ ---
+ license: other
+ license_name: nvidia-open-model-license
+ license_link: LICENSE
+ library_name: transformers
+ pipeline_tag: text-generation
+ tags:
+ - long-context
+ - superlinear-attention
+ - subquadratic
+ - causal-lm
+ base_model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
+ ---
+
+ # Superlinear-Exp-v0.1
+
+ **Superlinear Multi-Step Attention**: a subquadratic attention mechanism that preserves random context access (structural non-exclusion) for extremely long sequences.
+
+ This is an experimental release demonstrating the Superlinear attention architecture integrated into a modified [nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) hybrid model.
+
+ > **WARNING (Security):** This model requires `trust_remote_code=True`, which executes Python code from this repository. Review the code before running in sensitive environments.
+
+ ## Model Description
+
+ Superlinear attention reformulates causal self-attention as a multi-step search problem:
+
+ 1. **Accumulation**: efficiently processes the sequence and produces per-position representatives (via Mamba-2 layers in the hybrid architecture).
+ 2. **Span Search**: scores a sublinear number of candidate spans using learned routing, then selects the top-k spans per query.
+ 3. **Span Attention**: computes standard token-level attention within the selected contiguous spans.
+ 4. **Combination**: produces outputs using softmax-weighted gating over the span attention outputs.
+
+ In the baseline N=2 implementation, both span search and span attention scale as **O(L^(3/2))**, enabling practical inference at multi-million-token context lengths where dense attention becomes prohibitive.
+
+ **Key property:** *random context access* (structural non-exclusion): any eligible token position can be selected by the content-dependent routing mechanism; no fixed sparsity pattern permanently excludes positions.
+
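The span-search and span-attention steps can be sketched in a few lines of NumPy. This is an illustrative toy, not the model's kernels: the real mechanism uses learned routing scores, whereas this sketch scores each candidate span by its first key, and the span size, anchor grid, and top-k value are placeholder numbers.

```python
import numpy as np

rng = np.random.default_rng(0)
L, d, top_k = 256, 32, 3
q = rng.standard_normal(d)        # a single query vector
k = rng.standard_normal((L, d))   # keys for the whole context
v = rng.standard_normal((L, d))   # values

# Span search: score ~sqrt(L) candidate anchors instead of all L positions.
span_len = int(np.sqrt(L))                 # 16 tokens per span
anchors = np.arange(0, L, span_len)        # one anchor per candidate span
scores = k[anchors] @ q                    # toy stand-in for learned routing scores
top_spans = anchors[np.argsort(scores)[-top_k:]]

# Span attention: dense softmax attention, but only inside the selected spans.
idx = np.concatenate([np.arange(s, s + span_len) for s in top_spans])
logits = k[idx] @ q / np.sqrt(d)
w = np.exp(logits - logits.max())
w /= w.sum()
out = w @ v[idx]                           # attends over 48 of 256 positions
```

Per query this touches O(sqrt(L)) anchors plus O(k·sqrt(L)) span tokens, so over L queries the total work is O(L^(3/2)); at L = 1M that is on the order of 10^9 score evaluations versus 10^12 for dense attention. Every position remains reachable because the top-k selection depends on content, not on a fixed sparsity pattern.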
+ ## Quickstart
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     "concavity-ai/superlinear-exp-v0.1",
+     trust_remote_code=True,
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "concavity-ai/superlinear-exp-v0.1",
+     torch_dtype=torch.float16,
+     device_map="cuda",
+     trust_remote_code=True,
+ )
+
+ messages = [{"role": "user", "content": "Explain the Transformer architecture."}]
+ inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda")
+
+ output = model.generate(inputs, max_new_tokens=1000, do_sample=True, temperature=0.1, top_p=0.99)
+ print(tokenizer.decode(output[0], skip_special_tokens=True))
+ ```
+
+ ## Dependencies
+
+ This model uses custom Python code (`trust_remote_code=True`) and CUDA extensions.
+
+ ### Recommended: follow the Superlinear repo install
+
+ The simplest supported path is to use the installation flow from the Superlinear repo (it pins a known-good CUDA toolchain and builds `mamba-ssm[causal-conv1d]` from source to avoid wheel/ABI mismatches):
+
+ https://github.com/concavity-ai/superlinear#installation
+
+ Copy/paste one-liner (from the Superlinear repo root):
+
+ ```bash
+ conda env create -f environment.yml \
+   && conda run -n superlinear pip install torch --index-url https://download.pytorch.org/whl/cu128 \
+   && conda run -n superlinear pip install -e ".[server,model]" \
+   && conda run -n superlinear bash -lc 'CUDA_HOME="$CONDA_PREFIX" pip install "mamba-ssm[causal-conv1d]" --no-build-isolation --no-cache-dir --no-binary :all:'
+ ```
+
+ ### Optional: pip-only (if `mamba-ssm` already works in your env)
+
+ If `python -c "import mamba_ssm, causal_conv1d"` already succeeds in the environment you'll run inference in, that environment already has a working PyTorch/CUDA pairing for the extension, so you should not need to reinstall PyTorch.
+
+ Install the remaining Python dependencies plus Superlinear:
+
+ ```bash
+ pip install -U "transformers<5" accelerate safetensors
+ pip install -U vllm triton
+
+ # Superlinear kernels (span-attention)
+ pip install -U git+https://github.com/concavity-ai/superlinear.git
+ ```
+
+ ### Building `mamba-ssm` from source (only if needed)
+
+ If you must build `mamba-ssm[causal-conv1d]` yourself, you need a CUDA toolkit with `nvcc` and `CUDA_HOME` pointing at it (example: `/usr/local/cuda`):
+
+ ```bash
+ CUDA_HOME=/usr/local/cuda \
+   pip install -U "mamba-ssm[causal-conv1d]" \
+   --no-build-isolation --no-cache-dir --no-binary :all:
+ ```
+
+ ## Recommended Inference Settings
+
+ For long-context inference with the Superlinear attention mechanism, use the following configuration:
+
+ ```python
+ model = AutoModelForCausalLM.from_pretrained(
+     "concavity-ai/superlinear-exp-v0.1",
+     # Attention implementation
+     _attn_implementation='block-span-gqa',
+     decode_kernel='staged-gqa',
+
+     # Performance optimizations
+     enable_cuda_graph=True,
+     enable_shared_fused_moe=True,
+
+     # Superlinear attention hyperparameters
+     span_attention_sw_index=65,         # Local window boundary index
+     span_attention_num_spans=3,         # Top-k spans per query
+     span_attention_backward_factor=3,   # Backward span extent multiplier
+     span_attention_forward_factor=1,    # Forward span extent multiplier
+     span_attention_search_power=0.55,   # Search exponent (controls anchor budget)
+     span_attention_span_power=0.55,     # Span exponent (controls span scale)
+
+     torch_dtype=torch.float16,
+     device_map="cuda",
+     trust_remote_code=True,
+ )
+ ```
+
+ ### Hyperparameter Notes
+
+ | Parameter | Description | Typical Value |
+ |-----------|-------------|---------------|
+ | `span_attention_num_spans` | Number of routed spans selected per query (top-k) | 2 or 3 |
+ | `span_attention_backward_factor` | Backward extent of each span relative to base scale | 2–4 |
+ | `span_attention_forward_factor` | Forward extent of each span relative to base scale | 0–2 |
+ | `span_attention_search_power` | Exponent controlling the number of candidate anchors | 0.5–0.667 |
+ | `span_attention_span_power` | Exponent controlling span length scaling | 0.5–0.667 |
+
+ **Sliding window length from `span_attention_sw_index`:** internally, the kernels compute the sliding-window length as:
+
+ ```text
+ window_len = floor((sw_index + 1) ** (1 / search_power)) - 1
+ ```
+
+ We parameterize the local sliding window using `sw_index` (a stride/stripe index) rather than specifying `window_len` directly. This keeps the sliding-window boundary aligned with the same index space used by span search, so span search begins immediately after the sliding-window region and no gap opens between local attention and the routed spans.
+
+ Example: with `search_power=0.55` and `sw_index=65`,
+
+ ```text
+ window_len = floor(66 ** (1 / 0.55)) - 1 = 2032
+ ```
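As a quick check, evaluating the same formula in Python reproduces the value above:

```python
import math

def window_len(sw_index: int, search_power: float) -> int:
    # window_len = floor((sw_index + 1) ** (1 / search_power)) - 1
    return math.floor((sw_index + 1) ** (1.0 / search_power)) - 1

print(window_len(65, 0.55))  # 2032
```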
+
+ ## Hardware Requirements
+
+ - **GPU:** NVIDIA GPU with sufficient VRAM (tested on a B200 with 180 GB)
+ - **KV Cache:** ~6 GB per million tokens (model-dependent)
+ - **Precision:** FP16 recommended
+
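A back-of-envelope sizing estimate, using the ~6 GB/M-token KV figure above plus FP16 weights for a ~30B-parameter model (both are rough approximations, not measured numbers):

```python
params_b = 30                # ~30B parameters
weights_gb = params_b * 2    # FP16: 2 bytes per parameter -> ~60 GB
kv_gb_per_m = 6              # KV cache: ~6 GB per million tokens
context_m = 10               # 10M-token context
total_gb = weights_gb + kv_gb_per_m * context_m
print(total_gb)  # 120 -> fits in a single 180 GB B200
```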
+ ### Measured Throughput (Single B200, Batch=1)
+
+ | Context Length | Prefill (tok/s) | Decode (tok/s) |
+ |----------------|-----------------|----------------|
+ | 1M tokens | ~20,202 | ~109 |
+ | 10M tokens | ~5,576 | ~76 |
+
+ *Your results may vary depending on hardware, batch size, and configuration.*
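At these rates, ingesting a full context takes minutes, not seconds. A rough estimate from the measured 10M-token prefill rate above (arithmetic only, on the table's numbers):

```python
tokens = 10_000_000
prefill_tok_s = 5576               # measured 10M-token prefill rate
minutes = tokens / prefill_tok_s / 60
print(round(minutes))  # ~30 minutes to prefill a 10M-token context
```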
+
+ ## Intended Use
+
+ This is an **architecture-and-systems feasibility study** release. It demonstrates that:
+
+ 1. The Superlinear attention mechanism is structurally random-context-access-preserving.
+ 2. The architecture achieves asymptotically subquadratic attention complexity.
+ 3. The resulting irregular span pattern can be implemented with practical performance at very long context lengths.
+
+ ### Limitations
+
+ - **Not a comprehensive quality study:** We do not present extensive ablations or claim state-of-the-art accuracy on benchmarks.
+ - **Limited evaluation:** Initial validation focused on the NIAH (Needle In A Haystack) retrieval task and throughput measurements.
+ - **Experimental:** This release is intended for research and experimentation, not production use.
+ - **Memory:** The full KV cache must be retained for random context access; memory usage scales with context length.
+
+ ## What's in This Repository
+
+ ```
+ ├── config.json                       # Model configuration
+ ├── generation_config.json            # Default generation settings
+ ├── tokenizer.json                    # Tokenizer
+ ├── tokenizer_config.json
+ ├── special_tokens_map.json
+ ├── chat_template.jinja               # Chat template
+ ├── configuration_superlinear_exp.py  # Custom config class
+ ├── modeling_superlinear_exp.py       # Custom model implementation
+ ├── moe.py                            # MoE components
+ ├── model-*.safetensors               # Model weights (16 shards)
+ ├── model.safetensors.index.json      # Weight index
+ ├── LICENSE                           # NVIDIA Open Model License
+ ├── NOTICE                            # Required attribution
+ └── README.md                         # This file
+ ```
+
+ ## License
+
+ ### Model Weights
+
+ This model is a derivative of [nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) and is distributed under the **NVIDIA Open Model License Agreement**.
+
+ See [LICENSE](LICENSE) for the full license text.
+
+ Use of this model must be consistent with [NVIDIA's Trustworthy AI terms](https://www.nvidia.com/en-us/agreements/trustworthy-ai/terms/).
+
+ ### Code
+
+ The modeling code in this repository is provided for loading and running the model. For the broader Superlinear project codebase, see [github.com/concavity-ai/superlinear](https://github.com/concavity-ai/superlinear) (Apache-2.0).
+
+ ## Attribution
+
+ **Upstream Model:**
+ - NVIDIA Nemotron-3-Nano-30B-A3B ([nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16))
+
+ **Paper:**
+ ```bibtex
+ @article{huang2026superlinear,
+   title={Superlinear Multi-Step Attention},
+   author={Huang, Yufeng},
+   journal={arXiv preprint arXiv:2601.18401},
+   year={2026}
+ }
+ ```
+
+ ## Patent Notice
+
+ Patent applications have been filed related to aspects of the methods described in this work.
+
+ ## Contact
+
+ - Author: Yufeng Huang
+ - Email: yufeng@concavity.ai
+ - Organization: Concavity AI
__pycache__/configuration_superlinear_exp.cpython-312.pyc ADDED
Binary file (14.9 kB).
 
__pycache__/configuration_superlinear_exp.cpython-313.pyc ADDED
Binary file (14.9 kB).
 
__pycache__/modeling_superlinear_exp.cpython-312.pyc ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:468f989a95d55eb2d38d742756830863d7bffb25e39fb9c304a200b4d9c63b2d
+ size 164921
__pycache__/modeling_superlinear_exp.cpython-313.pyc ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2dc47238f394aca2f6c7176e3fa059bc975b660170d0d7fc6e385e0d6f69750d
+ size 168645
__pycache__/moe.cpython-313.pyc ADDED
Binary file (35.1 kB).
 
chat_template.jinja ADDED
+ {% macro render_extra_keys(json_dict, handled_keys) %}
+ {%- if json_dict is mapping %}
+ {%- for json_key in json_dict if json_key not in handled_keys %}
+ {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %}
+ {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }}
+ {%- else %}
+ {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}
+ {%- endif %}
+ {%- endfor %}
+ {%- endif %}
+ {% endmacro %}
+ {%- set enable_thinking = enable_thinking if enable_thinking is defined else True %}
+ {%- set truncate_history_thinking = truncate_history_thinking if truncate_history_thinking is defined else True %}
+
+ {%- set ns = namespace(last_user_idx = -1) %}
+ {%- set loop_messages = messages %}
+ {%- for m in loop_messages %}
+ {%- if m["role"] == "user" %}
+ {%- set ns.last_user_idx = loop.index0 %}
+ {%- endif %}
+ {%- endfor %}
+
+ {%- if messages[0]["role"] == "system" %}
+ {%- set system_message = messages[0]["content"] %}
+ {%- set loop_messages = messages[1:] %}
+ {%- else %}
+ {%- set system_message = "" %}
+ {%- set loop_messages = messages %}
+ {%- endif %}
+ {%- if not tools is defined %}
+ {%- set tools = [] %}
+ {%- endif %}
+ {# Recompute last_user_idx relative to loop_messages after handling system #}
+ {%- set ns = namespace(last_user_idx = -1) %}
+ {%- for m in loop_messages %}
+ {%- if m["role"] == "user" %}
+ {%- set ns.last_user_idx = loop.index0 %}
+ {%- endif %}
+ {%- endfor %}
+ {%- if system_message is defined %}
+ {{- "<|im_start|>system\n" + system_message }}
+ {%- else %}
+ {%- if tools is iterable and tools | length > 0 %}
+ {{- "<|im_start|>system\n" }}
+ {%- endif %}
+ {%- endif %}
+ {%- if tools is iterable and tools | length > 0 %}
+ {%- if system_message is defined and system_message | length > 0 %}
+ {{- "\n\n" }}
+ {%- endif %}
+ {{- "# Tools\n\nYou have access to the following functions:\n\n" }}
+ {{- "<tools>" }}
+ {%- for tool in tools %}
+ {%- if tool.function is defined %}
+ {%- set tool = tool.function %}
+ {%- endif %}
+ {{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }}
+ {%- if tool.description is defined %}
+ {{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }}
+ {%- endif %}
+ {{- '\n<parameters>' }}
+ {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}
+ {%- for param_name, param_fields in tool.parameters.properties|items %}
+ {{- '\n<parameter>' }}
+ {{- '\n<name>' ~ param_name ~ '</name>' }}
+ {%- if param_fields.type is defined %}
+ {{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
+ {%- endif %}
+ {%- if param_fields.description is defined %}
+ {{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}
+ {%- endif %}
+ {%- if param_fields.enum is defined %}
+ {{- '\n<enum>' ~ (param_fields.enum | tojson | safe) ~ '</enum>' }}
+ {%- endif %}
+ {%- set handled_keys = ['name', 'type', 'description', 'enum'] %}
+ {{- render_extra_keys(param_fields, handled_keys) }}
+ {{- '\n</parameter>' }}
+ {%- endfor %}
+ {%- endif %}
+ {% set handled_keys = ['type', 'properties', 'required'] %}
+ {{- render_extra_keys(tool.parameters, handled_keys) }}
+ {%- if tool.parameters is defined and tool.parameters.required is defined %}
+ {{- '\n<required>' ~ (tool.parameters.required | tojson | safe) ~ '</required>' }}
+ {%- endif %}
+ {{- '\n</parameters>' }}
+ {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}
+ {{- render_extra_keys(tool, handled_keys) }}
+ {{- '\n</function>' }}
+ {%- endfor %}
+ {{- "\n</tools>" }}
+
+ {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
+ {%- endif %}
+
+
+ {%- if system_message is defined %}
+ {{- '<|im_end|>\n' }}
+ {%- else %}
+ {%- if tools is iterable and tools | length > 0 %}
+ {{- '<|im_end|>\n' }}
+ {%- endif %}
+ {%- endif %}
+
+ {%- for message in loop_messages %}
+ {%- if message.role == "assistant" %}
+ {# Add reasoning content into the content field for unified processing below. #}
+ {%- if message.reasoning_content is defined and message.reasoning_content is string and message.reasoning_content | trim | length > 0 %}
+ {%- set content = "<think>\n" ~ message.reasoning_content ~ "\n</think>\n" ~ (message.content | default('', true)) %}
+ {%- else %}
+ {%- set content = message.content | default('', true) %}
+ {%- if content is string -%}
+ {# Allow downstream logic to take care of broken thought; only handle coherent reasoning here. #}
+ {%- if '<think>' not in content and '</think>' not in content -%}
+ {%- set content = "<think></think>" ~ content -%}
+ {%- endif -%}
+ {%- else -%}
+ {%- set content = content -%}
+ {%- endif -%}
+ {%- endif %}
+ {%- if message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}
+ {# Assistant message has tool calls. #}
+ {{- '<|im_start|>assistant\n' }}
+ {%- set include_content = not (truncate_history_thinking and loop.index0 < ns.last_user_idx) %}
+ {%- if content is string and content | trim | length > 0 %}
+ {%- if include_content %}
+ {{- (content | trim) ~ '\n' -}}
+ {%- else %}
+ {%- set c = (content | string) %}
+ {%- if '</think>' in c %}
+ {# Keep only content after the last closing think. Also generation prompt causes this. #}
+ {%- set c = c.split('</think>')[-1] %}
+ {%- elif '<think>' in c %}
+ {# If <think> was opened but never closed, drop the trailing think segment #}
+ {%- set c = c.split('<think>')[0] %}
+ {%- endif %}
+ {%- set c = "<think></think>" ~ c | trim %}
+ {%- if c | length > 0 %}
+ {{- c ~ '\n' -}}
+ {%- endif %}
+ {%- endif %}
+ {%- else %}
+ {{- "<think></think>" -}}
+ {%- endif %}
+ {%- for tool_call in message.tool_calls %}
+ {%- if tool_call.function is defined %}
+ {%- set tool_call = tool_call.function %}
+ {%- endif %}
+ {{- '<tool_call>\n<function=' ~ tool_call.name ~ '>\n' -}}
+ {%- if tool_call.arguments is defined %}
+ {%- for args_name, args_value in tool_call.arguments|items %}
+ {{- '<parameter=' ~ args_name ~ '>\n' -}}
+ {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
+ {{- args_value ~ '\n</parameter>\n' -}}
+ {%- endfor %}
+ {%- endif %}
+ {{- '</function>\n</tool_call>\n' -}}
+ {%- endfor %}
+ {{- '<|im_end|>\n' }}
+ {%- else %}
+ {# Assistant message doesn't have tool calls. #}
+ {%- if not (truncate_history_thinking and loop.index0 < ns.last_user_idx) %}
+ {{- '<|im_start|>assistant\n' ~ (content | default('', true) | string | trim) ~ '<|im_end|>\n' }}
+ {%- else %}
+ {%- set c = (content | default('', true) | string) %}
+ {%- if '<think>' in c and '</think>' in c %}
+ {%- set c = "<think></think>" ~ c.split('</think>')[-1] %}
+ {%- endif %}
+ {%- set c = c | trim %}
+ {%- if c | length > 0 %}
+ {{- '<|im_start|>assistant\n' ~ c ~ '<|im_end|>\n' }}
+ {%- else %}
+ {{- '<|im_start|>assistant\n<|im_end|>\n' }}
+ {%- endif %}
+ {%- endif %}
+ {%- endif %}
+ {%- elif message.role == "user" or message.role == "system" %}
+ {{- '<|im_start|>' + message.role + '\n' }}
+ {%- set content = message.content | string %}
+ {{- content }}
+ {{- '<|im_end|>\n' }}
+ {%- elif message.role == "tool" %}
+ {%- if loop.previtem and loop.previtem.role != "tool" %}
+ {{- '<|im_start|>user\n' }}
+ {%- endif %}
+ {{- '<tool_response>\n' }}
+ {{- message.content }}
+ {{- '\n</tool_response>\n' }}
+ {%- if not loop.last and loop.nextitem.role != "tool" %}
+ {{- '<|im_end|>\n' }}
+ {%- elif loop.last %}
+ {{- '<|im_end|>\n' }}
+ {%- endif %}
+ {%- else %}
+ {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }}
+ {%- endif %}
+ {%- endfor %}
+
+ {%- if add_generation_prompt %}
199
+ {%- if enable_thinking %}
200
+ {{- '<|im_start|>assistant\n<think>\n' }}
201
+ {%- else %}
202
+ {{- '<|im_start|>assistant\n<think></think>' }}
203
+ {%- endif %}
204
+ {%- endif %}
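The history branch of the template above repeatedly applies one rule: for past assistant turns, drop the reasoning inside `<think>…</think>` and keep only what follows, re-prefixing an empty think block. A plain-Python sketch of that rule (my re-expression for illustration, not a file from this repo):

```python
def strip_history_thinking(content: str) -> str:
    """Mirror the template's history-truncation rule for one assistant turn:
    keep only text after the last closing </think>, drop an unclosed <think>
    segment, and re-prefix an empty think block."""
    if '</think>' in content:
        # Keep only content after the last closing think tag.
        content = content.split('</think>')[-1]
    elif '<think>' in content:
        # <think> was opened but never closed: drop the trailing segment.
        content = content.split('<think>')[0]
    return '<think></think>' + content.strip()
```

This keeps the rendered history short while preserving the final answers of earlier turns.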
config.json ADDED
@@ -0,0 +1,80 @@
+ {
+ "architectures": [
+ "SuperlinearExpForCausalLM"
+ ],
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "auto_map": {
+ "AutoConfig": "configuration_superlinear_exp.SuperlinearExpConfig",
+ "AutoModel": "modeling_superlinear_exp.SuperlinearExpForCausalLM",
+ "AutoModelForCausalLM": "modeling_superlinear_exp.SuperlinearExpForCausalLM"
+ },
+ "bos_token_id": 1,
+ "chunk_size": 128,
+ "conv_kernel": 4,
+ "decode_kernel": "staged-gqa",
+ "enable_cuda_graph": true,
+ "enable_shared_fused_moe": true,
+ "eos_token_id": 2,
+ "expand": 2,
+ "head_dim": 128,
+ "hidden_dropout": 0.0,
+ "hidden_size": 2688,
+ "hybrid_override_pattern": "MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME",
+ "initializer_range": 0.02,
+ "intermediate_size": 1856,
+ "layer_norm_epsilon": 1e-05,
+ "mamba_head_dim": 64,
+ "mamba_hidden_act": "silu",
+ "mamba_num_heads": 64,
+ "mamba_proj_bias": false,
+ "mamba_ssm_cache_dtype": "float32",
+ "max_position_embeddings": 262144,
+ "mlp_bias": false,
+ "mlp_hidden_act": "relu2",
+ "model_type": "superlinear-exp",
+ "moe_intermediate_size": 1856,
+ "moe_shared_expert_intermediate_size": 3712,
+ "n_group": 1,
+ "n_groups": 8,
+ "n_routed_experts": 128,
+ "n_shared_experts": 1,
+ "norm_eps": 1e-05,
+ "norm_topk_prob": true,
+ "num_attention_heads": 32,
+ "num_experts_per_tok": 6,
+ "num_hidden_layers": 52,
+ "num_key_value_heads": 2,
+ "num_logits_to_keep": 1,
+ "pad_token_id": 0,
+ "partial_rotary_factor": 1.0,
+ "rescale_prenorm_residual": true,
+ "residual_in_fp32": false,
+ "rope_theta": 10000,
+ "routed_scaling_factor": 2.5,
+ "sliding_window": null,
+ "span_attention_backward_factor": 3.0,
+ "span_attention_forward_factor": 1.0,
+ "span_attention_inv_search_power_int": null,
+ "span_attention_num_spans": 3,
+ "span_attention_search_power": 0.55,
+ "span_attention_span_power": 0.55,
+ "span_attention_sw_index": 65,
+ "ssm_state_size": 128,
+ "tie_word_embeddings": false,
+ "time_step_floor": 0.0001,
+ "time_step_limit": [
+ 0.0,
+ Infinity
+ ],
+ "time_step_max": 0.1,
+ "time_step_min": 0.001,
+ "topk_group": 1,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.55.4",
+ "use_bias": false,
+ "use_conv_bias": true,
+ "use_mamba_kernels": true,
+ "vocab_size": 131072
+ }
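The `hybrid_override_pattern` string in this config encodes one character per layer, and the configuration class below decodes it layer by layer (in its `layers_block_type` property: `M` is Mamba2, `*` is attention, `-` is MLP, anything else, here `E`, is MoE). A small standalone sketch of that decoding, applied to the pattern above:

```python
# Decode the hybrid_override_pattern from config.json into per-layer block
# types, mirroring the layers_block_type property of the configuration class.
from collections import Counter

def layers_block_type(pattern: str) -> list[str]:
    mapping = {"M": "mamba", "*": "attention", "-": "mlp"}
    # Any character outside the mapping (e.g. 'E') is treated as an MoE layer.
    return [mapping.get(ch, "moe") for ch in pattern]

pattern = "MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME"
blocks = layers_block_type(pattern)
print(Counter(blocks))
```

The pattern is 52 characters long, matching `num_hidden_layers`, and alternates Mamba2 and MoE layers with attention interleaved every few layers.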
configuration_superlinear_exp.py ADDED
@@ -0,0 +1,341 @@
+ # coding=utf-8
+ # Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """SuperlinearExp model configuration"""
+
+ import math
+ import re
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class SuperlinearExpConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`SuperlinearExpModel`]. It is used to instantiate a
+ SuperlinearExp model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the SuperlinearExp-v0.1 model.
+ [todo](todo)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 131072):
+ Vocabulary size of the SuperlinearExp model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`SuperlinearExpModel`]
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
+ model has an output word embedding layer.
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 21504):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 52):
+ Number of hidden layers in the Transformer encoder.
+ hybrid_override_pattern (`str`, *optional*, defaults to `"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"`):
+ The pattern of the hybrid model. The pattern is a string with one character per layer: M: Mamba2, *: Attention, -: MLP
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ head_dim (`int`, *optional*, defaults to 128):
+ Dimension of each attention head.
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used.
+ mlp_hidden_act (`str`, *optional*, defaults to "relu2"):
+ The non-linear activation function in the MLP layers.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in attention layers.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in MLP layers.
+ use_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in the model.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ residual_in_fp32 (`bool`, *optional*, defaults to `False`):
+ Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
+ Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
+ integer value, only last `num_logits_to_keep` logits will be calculated.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ The id of the padding token.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ The id of the "beginning-of-sequence" token.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the "end-of-sequence" token.
+ sliding_window (`int`, *optional*, defaults to None):
+ Sliding window attention window size.
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
+ The maximum sequence length that this model might ever be used with.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ hidden_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the hidden states.
+ use_mamba_kernels (`bool`, *optional*, defaults to `True`):
+ Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
+ `causal-conv1d` are installed, and the mamba modules are running on a CUDA device.
+ ssm_state_size (`int`, *optional*, defaults to 128):
+ The dimension of the mamba state space latents.
+ mamba_num_heads (`int`, *optional*, defaults to 128):
+ Number of heads in Mamba layers.
+ mamba_n_groups (`int`, *optional*, defaults to 8):
+ Number of groups in Mamba layers.
+ mamba_head_dim (`int`, *optional*, defaults to 64):
+ Dimension of each Mamba head.
+ mamba_d_conv (`int`, *optional*, defaults to 4):
+ The size of the mamba convolution kernel.
+ mamba_expand (`int`, *optional*, defaults to 2):
+ Expanding factor used to determine the mamba intermediate size.
+ mamba_hidden_act (`str`, *optional*, defaults to "silu"):
+ The non-linear activation function in the Mamba layers.
+ mamba_dt_min (`float`, *optional*, defaults to 0.001):
+ Minimum value for the time step in Mamba.
+ mamba_dt_max (`float`, *optional*, defaults to 0.1):
+ Maximum value for the time step in Mamba.
+ mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float("inf"))):
+ Limits for the time step in Mamba.
+ mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4):
+ Floor value for time step initialization in Mamba.
+ mamba_conv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to use bias in the convolution layer of the mamba mixer block.
+ mamba_proj_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias in the input and output projections of the mamba mixer block.
+ mamba_chunk_size (`int`, *optional*, defaults to 128):
+ Size of chunks for Mamba processing.
+ rescale_prenorm_residual (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the pre-normalization residual connections.
+ """
+
+ model_type = "superlinear-exp"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=131072,
+ tie_word_embeddings=False,
+ hidden_size=4096,
+ intermediate_size=21504,
+ num_hidden_layers=52,
+ hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
+ num_attention_heads=32,
+ head_dim=128,
+ num_key_value_heads=8, # nemo: num_query_groups
+ mlp_hidden_act="relu2",
+ attention_bias=False,
+ mlp_bias=False,
+ use_bias=False,
+ initializer_range=0.02, # nemo: init_method_std
+ layer_norm_epsilon=1e-5, # nemo: layernorm_epsilon
+ residual_in_fp32=False, # Megatron Core default value
+ use_cache=True,
+ num_logits_to_keep=1,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ sliding_window=None,
+ max_position_embeddings=4096,
+ attention_dropout=0.0,
+ hidden_dropout=0.0, # * ADDED
+ use_mamba_kernels=True,
+ ssm_state_size=128, # mamba_state_size
+ mamba_num_heads=128,
+ mamba_n_groups=8, # nemo: mamba_ssm_ngroups = num_heads
+ mamba_head_dim=64,
+ mamba_d_conv=4,
+ mamba_expand=2,
+ mamba_hidden_act="silu",
+ mamba_dt_min=0.001,
+ mamba_dt_max=0.1,
+ mamba_dt_limit=(0.0, float("inf")),
+ mamba_dt_init_floor=1e-4,
+ mamba_conv_bias=True,
+ mamba_proj_bias=False,
+ mamba_chunk_size=128,
+ rescale_prenorm_residual=True,
+ span_attention_sw_index=65,
+ span_attention_num_spans=3,
+ span_attention_backward_factor: float = 3.0,
+ span_attention_forward_factor: float = 1.0,
+ span_attention_span_power: float = 0.55,
+ span_attention_search_power: float | None = 0.55,
+ span_attention_inv_search_power_int: int | None = None,
+ decode_kernel="staged-gqa",
+ enable_cuda_graph: bool = True,
+ enable_shared_fused_moe: bool = True,
+ n_routed_experts=8,
+ n_shared_experts=1,
+ moe_intermediate_size=7688,
+ moe_shared_expert_intermediate_size=7688,
+ num_experts_per_tok=2,
+ routed_scaling_factor=1.0,
+ n_group=1,
+ topk_group=1,
+ norm_topk_prob=True,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.tie_word_embeddings = tie_word_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.hybrid_override_pattern = hybrid_override_pattern
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = head_dim
+ self.sliding_window = sliding_window
+ self.max_position_embeddings = max_position_embeddings
+ self.attention_dropout = attention_dropout
+ self.hidden_dropout = hidden_dropout
+
+ # Span attention configuration
+ self.span_attention_sw_index = span_attention_sw_index
+ self.span_attention_num_spans = span_attention_num_spans
+ self.span_attention_backward_factor = float(span_attention_backward_factor)
+ self.span_attention_forward_factor = float(span_attention_forward_factor)
+ self.span_attention_span_power = float(span_attention_span_power)
+ if not math.isfinite(self.span_attention_backward_factor) or self.span_attention_backward_factor < 0.0:
+ raise ValueError(
+ "span_attention_backward_factor must be finite and >= 0 "
+ f"(got {self.span_attention_backward_factor})"
+ )
+ if not math.isfinite(self.span_attention_forward_factor) or self.span_attention_forward_factor < 0.0:
+ raise ValueError(
+ "span_attention_forward_factor must be finite and >= 0 "
+ f"(got {self.span_attention_forward_factor})"
+ )
+ if (self.span_attention_backward_factor + self.span_attention_forward_factor) <= 0.0:
+ raise ValueError(
+ "span_attention_backward_factor + span_attention_forward_factor must be > 0 "
+ f"(got {self.span_attention_backward_factor + self.span_attention_forward_factor})"
+ )
+ if not math.isfinite(self.span_attention_span_power) or not (0.0 < self.span_attention_span_power < 1.0):
+ raise ValueError(
+ "span_attention_span_power must be finite and in (0, 1) "
+ f"(got {self.span_attention_span_power})"
+ )
+
+ # Stripe power parameters (search stripes + sliding window width).
+ if (span_attention_inv_search_power_int is None) == (span_attention_search_power is None):
+ raise ValueError(
+ "Provide exactly one of span_attention_inv_search_power_int or span_attention_search_power"
+ )
+ if span_attention_inv_search_power_int is not None:
+ inv_n = int(span_attention_inv_search_power_int)
+ if inv_n not in (2, 3, 4, 5, 6):
+ raise ValueError(
+ "span_attention_inv_search_power_int must be one of (2, 3, 4, 5, 6) "
+ f"(got {span_attention_inv_search_power_int})"
+ )
+ self.span_attention_inv_search_power_int = inv_n
+ self.span_attention_search_power = None
+ derived_p = 1.0 / float(inv_n)
+ else:
+ p = float(span_attention_search_power)
+ if not math.isfinite(p) or not (0.0 < p < 1.0):
+ raise ValueError(
+ "span_attention_search_power must be finite and in (0, 1) "
+ f"(got {span_attention_search_power})"
+ )
+ self.span_attention_inv_search_power_int = None
+ self.span_attention_search_power = p
+ derived_p = p
+
+ # Critical coverage constraint (see 47.1_generalized_stripe_power_design.md):
+ # span_power >= 1 - search_power.
+ if self.span_attention_span_power + derived_p < 1.0:
+ raise ValueError(
+ "span_attention_span_power must satisfy span_power >= 1 - search_power "
+ f"(got span_power={self.span_attention_span_power}, search_power={derived_p})"
+ )
+
+ if decode_kernel not in (None, "staged", "staged-gqa"):
+ raise ValueError(
+ f"Invalid decode_kernel={decode_kernel!r}; expected one of None, 'staged', 'staged-gqa'."
+ )
+ self.decode_kernel = decode_kernel
+ self.enable_cuda_graph = enable_cuda_graph
+ self.enable_shared_fused_moe = enable_shared_fused_moe
+
+ # Validate hybrid_override_pattern
+ # M: Mamba2, *: Attention, -: MLP, E: MoE
+ assert len(self.hybrid_override_pattern) == self.num_hidden_layers, "hybrid_override_pattern must have the same length as num_hidden_layers"
+ assert re.match(r"^[M*E-]+$", self.hybrid_override_pattern), "hybrid_override_pattern must only contain characters 'M', '*', '-', or 'E'"
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.mlp_hidden_act = mlp_hidden_act
+ self.attention_bias = attention_bias
+ self.mlp_bias = mlp_bias
+ self.use_bias = use_bias
+ self.initializer_range = initializer_range
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.residual_in_fp32 = residual_in_fp32
+
+ self.use_cache = use_cache
+ self.num_logits_to_keep = num_logits_to_keep
+
+ self.use_mamba_kernels = use_mamba_kernels
+ self.n_groups = mamba_n_groups
+ self.mamba_head_dim = mamba_head_dim
+ self.ssm_state_size = ssm_state_size
+ self.mamba_num_heads = mamba_num_heads
+ self.conv_kernel = mamba_d_conv
+ self.expand = mamba_expand
+ self.mamba_hidden_act = mamba_hidden_act
+ self.time_step_min = mamba_dt_min
+ self.time_step_max = mamba_dt_max
+ self.time_step_limit = mamba_dt_limit
+ self.time_step_floor = mamba_dt_init_floor
+ self.use_conv_bias = mamba_conv_bias
+ self.mamba_proj_bias = mamba_proj_bias
+ self.chunk_size = mamba_chunk_size
+ self.rescale_prenorm_residual = rescale_prenorm_residual
+ self.n_routed_experts = n_routed_experts
+ self.n_shared_experts = n_shared_experts
+ self.moe_intermediate_size = moe_intermediate_size
+ self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size
+ self.num_experts_per_tok = num_experts_per_tok
+ self.routed_scaling_factor = routed_scaling_factor
+ self.n_group = n_group
+ self.topk_group = topk_group
+ self.norm_topk_prob = norm_topk_prob
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ @property
+ def layers_block_type(self):
+ return [
+ "mamba" if self.hybrid_override_pattern[i] == "M" else
+ "attention" if self.hybrid_override_pattern[i] == "*" else
+ "mlp" if self.hybrid_override_pattern[i] == "-" else "moe"
+ for i in range(self.num_hidden_layers)]
generation_config.json ADDED
@@ -0,0 +1,11 @@
+ {
+ "_from_model_config": true,
+ "bos_token_id": 1,
+ "do_sample": true,
+ "eos_token_id": [
+ 2,
+ 11
+ ],
+ "pad_token_id": 0,
+ "transformers_version": "4.57.3"
+ }
model-00001-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:97abe5c9503b8f60e502d509898c658fb8c366b32308463e6fc2cd22d9533973
+ size 4654236816
model-00002-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:593e019ec60819249448c91e647abff72d9171a212d1664c0957dd9b2cb94bad
+ size 4136509104
model-00003-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:595c430f59919f62e1c3ba70596e853f004bb20710cbbe5b679ba8b22ef4c720
+ size 3949593672
model-00004-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4fb6e9c083f14e67e0a8c6f1c90dd73d77fd0dadbc3a76203ead3f5e3b67cf91
+ size 4213999864
model-00005-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:40bc0337b6b2f619e9641278ed97058cc099ca001d4d3ff80de0d97a886b5411
+ size 3949593672
model-00006-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6104484169ae2ffa3b919187a176e762414dfe48a6d858461db935278441ffcd
+ size 4136509104
model-00007-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d0bb006d8d803aa869ad584a8c6e5926b398f0c04ddc1740d842d3d6a369f8ee
+ size 3872102896
model-00008-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:485cf181874c6b9cb817407bab34b8334e5515329323f9683d7e93b25679c331
+ size 4136509096
model-00009-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2b49942cf27c63bec56f5484b113c50442735a3c5885890b43f54cb0509ee555
+ size 3949593672
model-00010-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4cd5c8dfe3bba15d056458fca0413511f576c3fae10b9f2daae88d247b17db1b
+ size 4145181008
model-00011-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:71e77335c4a7af5d8859a2108039a4f30bebcaf278ebd4b56f9dd8f403495abb
+ size 4018412536
model-00012-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ad603897bb6b7c951bd69e9f89009521909163c8a6195c4fdab00b545fd99c03
+ size 4067690240
model-00013-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e79a65b74d9ad6261748131aa86a7d84833bdf5d7609816219e4c7d5bd4c6ec0
+ size 3949593672
model-00014-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8f97e980088587a0b273d4709aad56b9e33b31e6887b483941f67fa06c0a2ee4
+ size 4059018320
model-00015-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:684690b0a3770d79e2466e26fc5d75930572dea5e3b1e42e3aeaff49bcbb5566
+ size 3949593656
model-00016-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:96d2eda64709c79a4d25ebc2d96d8abe118bbda0150f7874add27cf9695db7f2
+ size 2099910896
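The LFS pointers above record each shard's on-disk byte size. Summing them should land slightly above the index's `"total_size"` of 63288001152 bytes, since each safetensors shard also carries a small JSON header in addition to tensor data. A quick check (sizes copied from the pointers above):

```python
# Sum the 16 shard sizes from the LFS pointers and compare against the
# tensor-data total recorded in model.safetensors.index.json.
shard_sizes = [
    4654236816, 4136509104, 3949593672, 4213999864,
    3949593672, 4136509104, 3872102896, 4136509096,
    3949593672, 4145181008, 4018412536, 4067690240,
    3949593672, 4059018320, 3949593656, 2099910896,
]
total_bytes = sum(shard_sizes)
print(f"{total_bytes / 1e9:.1f} GB across {len(shard_sizes)} shards")
```

The small surplus over `total_size` is the per-shard header overhead.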
model.safetensors.index.json ADDED
@@ -0,0 +1,414 @@
+ {
+ "metadata": {
+ "total_size": 63288001152
+ },
+ "weight_map": {
+ "backbone.embeddings.weight": "model-00001-of-00016.safetensors",
+ "backbone.layers.0.mixer.A_log": "model-00001-of-00016.safetensors",
+ "backbone.layers.0.mixer.D": "model-00001-of-00016.safetensors",
+ "backbone.layers.0.mixer.conv1d.bias": "model-00001-of-00016.safetensors",
+ "backbone.layers.0.mixer.conv1d.weight": "model-00001-of-00016.safetensors",
+ "backbone.layers.0.mixer.dt_bias": "model-00001-of-00016.safetensors",
+ "backbone.layers.0.mixer.in_proj.weight": "model-00001-of-00016.safetensors",
+ "backbone.layers.0.mixer.norm.weight": "model-00001-of-00016.safetensors",
+ "backbone.layers.0.mixer.out_proj.weight": "model-00001-of-00016.safetensors",
+ "backbone.layers.0.norm.weight": "model-00001-of-00016.safetensors",
+ "backbone.layers.1.mixer.experts.down_proj.weight": "model-00001-of-00016.safetensors",
+ "backbone.layers.1.mixer.experts.up_proj.weight": "model-00001-of-00016.safetensors",
+ "backbone.layers.1.mixer.gate.e_score_correction_bias": "model-00001-of-00016.safetensors",
+ "backbone.layers.1.mixer.gate.weight": "model-00001-of-00016.safetensors",
+ "backbone.layers.1.mixer.shared_experts.down_proj.weight": "model-00001-of-00016.safetensors",
+ "backbone.layers.1.mixer.shared_experts.up_proj.weight": "model-00001-of-00016.safetensors",
+ "backbone.layers.1.norm.weight": "model-00001-of-00016.safetensors",
+ "backbone.layers.10.mixer.experts.down_proj.weight": "model-00001-of-00016.safetensors",
+ "backbone.layers.10.mixer.experts.up_proj.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.10.mixer.gate.e_score_correction_bias": "model-00002-of-00016.safetensors",
+ "backbone.layers.10.mixer.gate.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.10.mixer.shared_experts.down_proj.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.10.mixer.shared_experts.up_proj.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.10.norm.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.11.mixer.A_log": "model-00002-of-00016.safetensors",
+ "backbone.layers.11.mixer.D": "model-00002-of-00016.safetensors",
+ "backbone.layers.11.mixer.conv1d.bias": "model-00002-of-00016.safetensors",
+ "backbone.layers.11.mixer.conv1d.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.11.mixer.dt_bias": "model-00002-of-00016.safetensors",
+ "backbone.layers.11.mixer.in_proj.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.11.mixer.norm.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.11.mixer.out_proj.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.11.norm.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.12.mixer.k_proj.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.12.mixer.o_proj.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.12.mixer.q_proj.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.12.mixer.s_proj.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.12.mixer.v_proj.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.12.norm.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.13.mixer.experts.down_proj.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.13.mixer.experts.up_proj.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.13.mixer.gate.e_score_correction_bias": "model-00002-of-00016.safetensors",
+ "backbone.layers.13.mixer.gate.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.13.mixer.shared_experts.down_proj.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.13.mixer.shared_experts.up_proj.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.13.norm.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.14.mixer.A_log": "model-00002-of-00016.safetensors",
+ "backbone.layers.14.mixer.D": "model-00002-of-00016.safetensors",
+ "backbone.layers.14.mixer.conv1d.bias": "model-00002-of-00016.safetensors",
+ "backbone.layers.14.mixer.conv1d.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.14.mixer.dt_bias": "model-00002-of-00016.safetensors",
+ "backbone.layers.14.mixer.in_proj.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.14.mixer.norm.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.14.mixer.out_proj.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.14.norm.weight": "model-00002-of-00016.safetensors",
+ "backbone.layers.15.mixer.experts.down_proj.weight": "model-00003-of-00016.safetensors",
+ "backbone.layers.15.mixer.experts.up_proj.weight": "model-00003-of-00016.safetensors",
+ "backbone.layers.15.mixer.gate.e_score_correction_bias": "model-00003-of-00016.safetensors",
+ "backbone.layers.15.mixer.gate.weight": "model-00003-of-00016.safetensors",
+ "backbone.layers.15.mixer.shared_experts.down_proj.weight": "model-00003-of-00016.safetensors",
+ "backbone.layers.15.mixer.shared_experts.up_proj.weight": "model-00003-of-00016.safetensors",
+ "backbone.layers.15.norm.weight": "model-00003-of-00016.safetensors",
+ "backbone.layers.16.mixer.A_log": "model-00003-of-00016.safetensors",
+ "backbone.layers.16.mixer.D": "model-00003-of-00016.safetensors",
+ "backbone.layers.16.mixer.conv1d.bias": "model-00003-of-00016.safetensors",
+ "backbone.layers.16.mixer.conv1d.weight": "model-00003-of-00016.safetensors",
+ "backbone.layers.16.mixer.dt_bias": "model-00003-of-00016.safetensors",
+ "backbone.layers.16.mixer.in_proj.weight": "model-00003-of-00016.safetensors",
+ "backbone.layers.16.mixer.norm.weight": "model-00003-of-00016.safetensors",
+ "backbone.layers.16.mixer.out_proj.weight": "model-00003-of-00016.safetensors",
+ "backbone.layers.16.norm.weight": "model-00003-of-00016.safetensors",
+ "backbone.layers.17.mixer.experts.down_proj.weight": "model-00003-of-00016.safetensors",
+ "backbone.layers.17.mixer.experts.up_proj.weight": "model-00004-of-00016.safetensors",
+ "backbone.layers.17.mixer.gate.e_score_correction_bias": "model-00004-of-00016.safetensors",
+ "backbone.layers.17.mixer.gate.weight": "model-00004-of-00016.safetensors",
+ "backbone.layers.17.mixer.shared_experts.down_proj.weight": "model-00004-of-00016.safetensors",
+ "backbone.layers.17.mixer.shared_experts.up_proj.weight": "model-00004-of-00016.safetensors",
+ "backbone.layers.17.norm.weight": "model-00004-of-00016.safetensors",
+ "backbone.layers.18.mixer.A_log": "model-00004-of-00016.safetensors",
+ "backbone.layers.18.mixer.D": "model-00004-of-00016.safetensors",
+ "backbone.layers.18.mixer.conv1d.bias": "model-00004-of-00016.safetensors",
+ "backbone.layers.18.mixer.conv1d.weight": "model-00004-of-00016.safetensors",
+ "backbone.layers.18.mixer.dt_bias": "model-00004-of-00016.safetensors",
+ "backbone.layers.18.mixer.in_proj.weight": "model-00004-of-00016.safetensors",
+ "backbone.layers.18.mixer.norm.weight": "model-00004-of-00016.safetensors",
+ "backbone.layers.18.mixer.out_proj.weight": "model-00004-of-00016.safetensors",
+ "backbone.layers.18.norm.weight": "model-00004-of-00016.safetensors",
+ "backbone.layers.19.mixer.k_proj.weight": "model-00004-of-00016.safetensors",
+ "backbone.layers.19.mixer.o_proj.weight": "model-00004-of-00016.safetensors",
+ "backbone.layers.19.mixer.q_proj.weight": "model-00004-of-00016.safetensors",
+ "backbone.layers.19.mixer.s_proj.weight": "model-00004-of-00016.safetensors",
97
+ "backbone.layers.19.mixer.v_proj.weight": "model-00004-of-00016.safetensors",
98
+ "backbone.layers.19.norm.weight": "model-00004-of-00016.safetensors",
99
+ "backbone.layers.2.mixer.A_log": "model-00004-of-00016.safetensors",
100
+ "backbone.layers.2.mixer.D": "model-00004-of-00016.safetensors",
101
+ "backbone.layers.2.mixer.conv1d.bias": "model-00004-of-00016.safetensors",
102
+ "backbone.layers.2.mixer.conv1d.weight": "model-00004-of-00016.safetensors",
103
+ "backbone.layers.2.mixer.dt_bias": "model-00004-of-00016.safetensors",
104
+ "backbone.layers.2.mixer.in_proj.weight": "model-00004-of-00016.safetensors",
105
+ "backbone.layers.2.mixer.norm.weight": "model-00004-of-00016.safetensors",
106
+ "backbone.layers.2.mixer.out_proj.weight": "model-00004-of-00016.safetensors",
107
+ "backbone.layers.2.norm.weight": "model-00004-of-00016.safetensors",
108
+ "backbone.layers.20.mixer.experts.down_proj.weight": "model-00004-of-00016.safetensors",
109
+ "backbone.layers.20.mixer.experts.up_proj.weight": "model-00004-of-00016.safetensors",
110
+ "backbone.layers.20.mixer.gate.e_score_correction_bias": "model-00004-of-00016.safetensors",
111
+ "backbone.layers.20.mixer.gate.weight": "model-00004-of-00016.safetensors",
112
+ "backbone.layers.20.mixer.shared_experts.down_proj.weight": "model-00004-of-00016.safetensors",
113
+ "backbone.layers.20.mixer.shared_experts.up_proj.weight": "model-00004-of-00016.safetensors",
114
+ "backbone.layers.20.norm.weight": "model-00004-of-00016.safetensors",
115
+ "backbone.layers.21.mixer.A_log": "model-00004-of-00016.safetensors",
116
+ "backbone.layers.21.mixer.D": "model-00004-of-00016.safetensors",
117
+ "backbone.layers.21.mixer.conv1d.bias": "model-00004-of-00016.safetensors",
118
+ "backbone.layers.21.mixer.conv1d.weight": "model-00004-of-00016.safetensors",
119
+ "backbone.layers.21.mixer.dt_bias": "model-00004-of-00016.safetensors",
120
+ "backbone.layers.21.mixer.in_proj.weight": "model-00004-of-00016.safetensors",
121
+ "backbone.layers.21.mixer.norm.weight": "model-00004-of-00016.safetensors",
122
+ "backbone.layers.21.mixer.out_proj.weight": "model-00004-of-00016.safetensors",
123
+ "backbone.layers.21.norm.weight": "model-00004-of-00016.safetensors",
124
+ "backbone.layers.22.mixer.experts.down_proj.weight": "model-00005-of-00016.safetensors",
125
+ "backbone.layers.22.mixer.experts.up_proj.weight": "model-00005-of-00016.safetensors",
126
+ "backbone.layers.22.mixer.gate.e_score_correction_bias": "model-00005-of-00016.safetensors",
127
+ "backbone.layers.22.mixer.gate.weight": "model-00005-of-00016.safetensors",
128
+ "backbone.layers.22.mixer.shared_experts.down_proj.weight": "model-00005-of-00016.safetensors",
129
+ "backbone.layers.22.mixer.shared_experts.up_proj.weight": "model-00005-of-00016.safetensors",
130
+ "backbone.layers.22.norm.weight": "model-00005-of-00016.safetensors",
131
+ "backbone.layers.23.mixer.A_log": "model-00005-of-00016.safetensors",
132
+ "backbone.layers.23.mixer.D": "model-00005-of-00016.safetensors",
133
+ "backbone.layers.23.mixer.conv1d.bias": "model-00005-of-00016.safetensors",
134
+ "backbone.layers.23.mixer.conv1d.weight": "model-00005-of-00016.safetensors",
135
+ "backbone.layers.23.mixer.dt_bias": "model-00005-of-00016.safetensors",
136
+ "backbone.layers.23.mixer.in_proj.weight": "model-00005-of-00016.safetensors",
137
+ "backbone.layers.23.mixer.norm.weight": "model-00005-of-00016.safetensors",
138
+ "backbone.layers.23.mixer.out_proj.weight": "model-00005-of-00016.safetensors",
139
+ "backbone.layers.23.norm.weight": "model-00005-of-00016.safetensors",
140
+ "backbone.layers.24.mixer.experts.down_proj.weight": "model-00005-of-00016.safetensors",
141
+ "backbone.layers.24.mixer.experts.up_proj.weight": "model-00006-of-00016.safetensors",
142
+ "backbone.layers.24.mixer.gate.e_score_correction_bias": "model-00006-of-00016.safetensors",
143
+ "backbone.layers.24.mixer.gate.weight": "model-00006-of-00016.safetensors",
144
+ "backbone.layers.24.mixer.shared_experts.down_proj.weight": "model-00006-of-00016.safetensors",
145
+ "backbone.layers.24.mixer.shared_experts.up_proj.weight": "model-00006-of-00016.safetensors",
146
+ "backbone.layers.24.norm.weight": "model-00006-of-00016.safetensors",
147
+ "backbone.layers.25.mixer.A_log": "model-00006-of-00016.safetensors",
148
+ "backbone.layers.25.mixer.D": "model-00006-of-00016.safetensors",
149
+ "backbone.layers.25.mixer.conv1d.bias": "model-00006-of-00016.safetensors",
150
+ "backbone.layers.25.mixer.conv1d.weight": "model-00006-of-00016.safetensors",
151
+ "backbone.layers.25.mixer.dt_bias": "model-00006-of-00016.safetensors",
152
+ "backbone.layers.25.mixer.in_proj.weight": "model-00006-of-00016.safetensors",
153
+ "backbone.layers.25.mixer.norm.weight": "model-00006-of-00016.safetensors",
154
+ "backbone.layers.25.mixer.out_proj.weight": "model-00006-of-00016.safetensors",
155
+ "backbone.layers.25.norm.weight": "model-00006-of-00016.safetensors",
156
+ "backbone.layers.26.mixer.k_proj.weight": "model-00006-of-00016.safetensors",
157
+ "backbone.layers.26.mixer.o_proj.weight": "model-00006-of-00016.safetensors",
158
+ "backbone.layers.26.mixer.q_proj.weight": "model-00006-of-00016.safetensors",
159
+ "backbone.layers.26.mixer.s_proj.weight": "model-00006-of-00016.safetensors",
160
+ "backbone.layers.26.mixer.v_proj.weight": "model-00006-of-00016.safetensors",
161
+ "backbone.layers.26.norm.weight": "model-00006-of-00016.safetensors",
162
+ "backbone.layers.27.mixer.experts.down_proj.weight": "model-00006-of-00016.safetensors",
163
+ "backbone.layers.27.mixer.experts.up_proj.weight": "model-00006-of-00016.safetensors",
164
+ "backbone.layers.27.mixer.gate.e_score_correction_bias": "model-00006-of-00016.safetensors",
165
+ "backbone.layers.27.mixer.gate.weight": "model-00006-of-00016.safetensors",
166
+ "backbone.layers.27.mixer.shared_experts.down_proj.weight": "model-00006-of-00016.safetensors",
167
+ "backbone.layers.27.mixer.shared_experts.up_proj.weight": "model-00006-of-00016.safetensors",
168
+ "backbone.layers.27.norm.weight": "model-00006-of-00016.safetensors",
169
+ "backbone.layers.28.mixer.A_log": "model-00006-of-00016.safetensors",
170
+ "backbone.layers.28.mixer.D": "model-00006-of-00016.safetensors",
171
+ "backbone.layers.28.mixer.conv1d.bias": "model-00006-of-00016.safetensors",
172
+ "backbone.layers.28.mixer.conv1d.weight": "model-00006-of-00016.safetensors",
173
+ "backbone.layers.28.mixer.dt_bias": "model-00006-of-00016.safetensors",
174
+ "backbone.layers.28.mixer.in_proj.weight": "model-00006-of-00016.safetensors",
175
+ "backbone.layers.28.mixer.norm.weight": "model-00006-of-00016.safetensors",
176
+ "backbone.layers.28.mixer.out_proj.weight": "model-00006-of-00016.safetensors",
177
+ "backbone.layers.28.norm.weight": "model-00006-of-00016.safetensors",
178
+ "backbone.layers.29.mixer.experts.down_proj.weight": "model-00007-of-00016.safetensors",
179
+ "backbone.layers.29.mixer.experts.up_proj.weight": "model-00007-of-00016.safetensors",
180
+ "backbone.layers.29.mixer.gate.e_score_correction_bias": "model-00007-of-00016.safetensors",
181
+ "backbone.layers.29.mixer.gate.weight": "model-00007-of-00016.safetensors",
182
+ "backbone.layers.29.mixer.shared_experts.down_proj.weight": "model-00007-of-00016.safetensors",
183
+ "backbone.layers.29.mixer.shared_experts.up_proj.weight": "model-00007-of-00016.safetensors",
184
+ "backbone.layers.29.norm.weight": "model-00007-of-00016.safetensors",
185
+ "backbone.layers.3.mixer.experts.down_proj.weight": "model-00007-of-00016.safetensors",
186
+ "backbone.layers.3.mixer.experts.up_proj.weight": "model-00008-of-00016.safetensors",
187
+ "backbone.layers.3.mixer.gate.e_score_correction_bias": "model-00008-of-00016.safetensors",
188
+ "backbone.layers.3.mixer.gate.weight": "model-00008-of-00016.safetensors",
189
+ "backbone.layers.3.mixer.shared_experts.down_proj.weight": "model-00008-of-00016.safetensors",
190
+ "backbone.layers.3.mixer.shared_experts.up_proj.weight": "model-00008-of-00016.safetensors",
191
+ "backbone.layers.3.norm.weight": "model-00008-of-00016.safetensors",
192
+ "backbone.layers.30.mixer.A_log": "model-00008-of-00016.safetensors",
193
+ "backbone.layers.30.mixer.D": "model-00008-of-00016.safetensors",
194
+ "backbone.layers.30.mixer.conv1d.bias": "model-00008-of-00016.safetensors",
195
+ "backbone.layers.30.mixer.conv1d.weight": "model-00008-of-00016.safetensors",
196
+ "backbone.layers.30.mixer.dt_bias": "model-00008-of-00016.safetensors",
197
+ "backbone.layers.30.mixer.in_proj.weight": "model-00008-of-00016.safetensors",
198
+ "backbone.layers.30.mixer.norm.weight": "model-00008-of-00016.safetensors",
199
+ "backbone.layers.30.mixer.out_proj.weight": "model-00008-of-00016.safetensors",
200
+ "backbone.layers.30.norm.weight": "model-00008-of-00016.safetensors",
201
+ "backbone.layers.31.mixer.experts.down_proj.weight": "model-00008-of-00016.safetensors",
202
+ "backbone.layers.31.mixer.experts.up_proj.weight": "model-00008-of-00016.safetensors",
203
+ "backbone.layers.31.mixer.gate.e_score_correction_bias": "model-00008-of-00016.safetensors",
204
+ "backbone.layers.31.mixer.gate.weight": "model-00008-of-00016.safetensors",
205
+ "backbone.layers.31.mixer.shared_experts.down_proj.weight": "model-00008-of-00016.safetensors",
206
+ "backbone.layers.31.mixer.shared_experts.up_proj.weight": "model-00008-of-00016.safetensors",
207
+ "backbone.layers.31.norm.weight": "model-00008-of-00016.safetensors",
208
+ "backbone.layers.32.mixer.A_log": "model-00008-of-00016.safetensors",
209
+ "backbone.layers.32.mixer.D": "model-00008-of-00016.safetensors",
210
+ "backbone.layers.32.mixer.conv1d.bias": "model-00008-of-00016.safetensors",
211
+ "backbone.layers.32.mixer.conv1d.weight": "model-00008-of-00016.safetensors",
212
+ "backbone.layers.32.mixer.dt_bias": "model-00008-of-00016.safetensors",
213
+ "backbone.layers.32.mixer.in_proj.weight": "model-00008-of-00016.safetensors",
214
+ "backbone.layers.32.mixer.norm.weight": "model-00008-of-00016.safetensors",
215
+ "backbone.layers.32.mixer.out_proj.weight": "model-00008-of-00016.safetensors",
216
+ "backbone.layers.32.norm.weight": "model-00008-of-00016.safetensors",
217
+ "backbone.layers.33.mixer.k_proj.weight": "model-00008-of-00016.safetensors",
218
+ "backbone.layers.33.mixer.o_proj.weight": "model-00008-of-00016.safetensors",
219
+ "backbone.layers.33.mixer.q_proj.weight": "model-00008-of-00016.safetensors",
220
+ "backbone.layers.33.mixer.s_proj.weight": "model-00008-of-00016.safetensors",
221
+ "backbone.layers.33.mixer.v_proj.weight": "model-00008-of-00016.safetensors",
222
+ "backbone.layers.33.norm.weight": "model-00008-of-00016.safetensors",
223
+ "backbone.layers.34.mixer.experts.down_proj.weight": "model-00009-of-00016.safetensors",
224
+ "backbone.layers.34.mixer.experts.up_proj.weight": "model-00009-of-00016.safetensors",
225
+ "backbone.layers.34.mixer.gate.e_score_correction_bias": "model-00009-of-00016.safetensors",
226
+ "backbone.layers.34.mixer.gate.weight": "model-00009-of-00016.safetensors",
227
+ "backbone.layers.34.mixer.shared_experts.down_proj.weight": "model-00009-of-00016.safetensors",
228
+ "backbone.layers.34.mixer.shared_experts.up_proj.weight": "model-00009-of-00016.safetensors",
229
+ "backbone.layers.34.norm.weight": "model-00009-of-00016.safetensors",
230
+ "backbone.layers.35.mixer.A_log": "model-00009-of-00016.safetensors",
231
+ "backbone.layers.35.mixer.D": "model-00009-of-00016.safetensors",
232
+ "backbone.layers.35.mixer.conv1d.bias": "model-00009-of-00016.safetensors",
233
+ "backbone.layers.35.mixer.conv1d.weight": "model-00009-of-00016.safetensors",
234
+ "backbone.layers.35.mixer.dt_bias": "model-00009-of-00016.safetensors",
235
+ "backbone.layers.35.mixer.in_proj.weight": "model-00009-of-00016.safetensors",
236
+ "backbone.layers.35.mixer.norm.weight": "model-00009-of-00016.safetensors",
237
+ "backbone.layers.35.mixer.out_proj.weight": "model-00009-of-00016.safetensors",
238
+ "backbone.layers.35.norm.weight": "model-00009-of-00016.safetensors",
239
+ "backbone.layers.36.mixer.experts.down_proj.weight": "model-00009-of-00016.safetensors",
240
+ "backbone.layers.36.mixer.experts.up_proj.weight": "model-00010-of-00016.safetensors",
241
+ "backbone.layers.36.mixer.gate.e_score_correction_bias": "model-00010-of-00016.safetensors",
242
+ "backbone.layers.36.mixer.gate.weight": "model-00010-of-00016.safetensors",
243
+ "backbone.layers.36.mixer.shared_experts.down_proj.weight": "model-00010-of-00016.safetensors",
244
+ "backbone.layers.36.mixer.shared_experts.up_proj.weight": "model-00010-of-00016.safetensors",
245
+ "backbone.layers.36.norm.weight": "model-00010-of-00016.safetensors",
246
+ "backbone.layers.37.mixer.A_log": "model-00010-of-00016.safetensors",
247
+ "backbone.layers.37.mixer.D": "model-00010-of-00016.safetensors",
248
+ "backbone.layers.37.mixer.conv1d.bias": "model-00010-of-00016.safetensors",
249
+ "backbone.layers.37.mixer.conv1d.weight": "model-00010-of-00016.safetensors",
250
+ "backbone.layers.37.mixer.dt_bias": "model-00010-of-00016.safetensors",
251
+ "backbone.layers.37.mixer.in_proj.weight": "model-00010-of-00016.safetensors",
252
+ "backbone.layers.37.mixer.norm.weight": "model-00010-of-00016.safetensors",
253
+ "backbone.layers.37.mixer.out_proj.weight": "model-00010-of-00016.safetensors",
254
+ "backbone.layers.37.norm.weight": "model-00010-of-00016.safetensors",
255
+ "backbone.layers.38.mixer.experts.down_proj.weight": "model-00010-of-00016.safetensors",
256
+ "backbone.layers.38.mixer.experts.up_proj.weight": "model-00010-of-00016.safetensors",
257
+ "backbone.layers.38.mixer.gate.e_score_correction_bias": "model-00010-of-00016.safetensors",
258
+ "backbone.layers.38.mixer.gate.weight": "model-00010-of-00016.safetensors",
259
+ "backbone.layers.38.mixer.shared_experts.down_proj.weight": "model-00010-of-00016.safetensors",
260
+ "backbone.layers.38.mixer.shared_experts.up_proj.weight": "model-00010-of-00016.safetensors",
261
+ "backbone.layers.38.norm.weight": "model-00010-of-00016.safetensors",
262
+ "backbone.layers.39.mixer.A_log": "model-00010-of-00016.safetensors",
263
+ "backbone.layers.39.mixer.D": "model-00010-of-00016.safetensors",
264
+ "backbone.layers.39.mixer.conv1d.bias": "model-00010-of-00016.safetensors",
265
+ "backbone.layers.39.mixer.conv1d.weight": "model-00010-of-00016.safetensors",
266
+ "backbone.layers.39.mixer.dt_bias": "model-00010-of-00016.safetensors",
267
+ "backbone.layers.39.mixer.in_proj.weight": "model-00010-of-00016.safetensors",
268
+ "backbone.layers.39.mixer.norm.weight": "model-00010-of-00016.safetensors",
269
+ "backbone.layers.39.mixer.out_proj.weight": "model-00010-of-00016.safetensors",
270
+ "backbone.layers.39.norm.weight": "model-00010-of-00016.safetensors",
271
+ "backbone.layers.4.mixer.A_log": "model-00010-of-00016.safetensors",
272
+ "backbone.layers.4.mixer.D": "model-00010-of-00016.safetensors",
273
+ "backbone.layers.4.mixer.conv1d.bias": "model-00010-of-00016.safetensors",
274
+ "backbone.layers.4.mixer.conv1d.weight": "model-00010-of-00016.safetensors",
275
+ "backbone.layers.4.mixer.dt_bias": "model-00010-of-00016.safetensors",
276
+ "backbone.layers.4.mixer.in_proj.weight": "model-00010-of-00016.safetensors",
277
+ "backbone.layers.4.mixer.norm.weight": "model-00010-of-00016.safetensors",
278
+ "backbone.layers.4.mixer.out_proj.weight": "model-00010-of-00016.safetensors",
279
+ "backbone.layers.4.norm.weight": "model-00010-of-00016.safetensors",
280
+ "backbone.layers.40.mixer.experts.down_proj.weight": "model-00011-of-00016.safetensors",
281
+ "backbone.layers.40.mixer.experts.up_proj.weight": "model-00011-of-00016.safetensors",
282
+ "backbone.layers.40.mixer.gate.e_score_correction_bias": "model-00011-of-00016.safetensors",
283
+ "backbone.layers.40.mixer.gate.weight": "model-00011-of-00016.safetensors",
284
+ "backbone.layers.40.mixer.shared_experts.down_proj.weight": "model-00011-of-00016.safetensors",
285
+ "backbone.layers.40.mixer.shared_experts.up_proj.weight": "model-00011-of-00016.safetensors",
286
+ "backbone.layers.40.norm.weight": "model-00011-of-00016.safetensors",
287
+ "backbone.layers.41.mixer.A_log": "model-00011-of-00016.safetensors",
288
+ "backbone.layers.41.mixer.D": "model-00011-of-00016.safetensors",
289
+ "backbone.layers.41.mixer.conv1d.bias": "model-00011-of-00016.safetensors",
290
+ "backbone.layers.41.mixer.conv1d.weight": "model-00011-of-00016.safetensors",
291
+ "backbone.layers.41.mixer.dt_bias": "model-00011-of-00016.safetensors",
292
+ "backbone.layers.41.mixer.in_proj.weight": "model-00011-of-00016.safetensors",
293
+ "backbone.layers.41.mixer.norm.weight": "model-00011-of-00016.safetensors",
294
+ "backbone.layers.41.mixer.out_proj.weight": "model-00011-of-00016.safetensors",
295
+ "backbone.layers.41.norm.weight": "model-00011-of-00016.safetensors",
296
+ "backbone.layers.42.mixer.k_proj.weight": "model-00011-of-00016.safetensors",
297
+ "backbone.layers.42.mixer.o_proj.weight": "model-00011-of-00016.safetensors",
298
+ "backbone.layers.42.mixer.q_proj.weight": "model-00011-of-00016.safetensors",
299
+ "backbone.layers.42.mixer.s_proj.weight": "model-00011-of-00016.safetensors",
300
+ "backbone.layers.42.mixer.v_proj.weight": "model-00011-of-00016.safetensors",
301
+ "backbone.layers.42.norm.weight": "model-00011-of-00016.safetensors",
302
+ "backbone.layers.43.mixer.experts.down_proj.weight": "model-00011-of-00016.safetensors",
303
+ "backbone.layers.43.mixer.experts.up_proj.weight": "model-00012-of-00016.safetensors",
304
+ "backbone.layers.43.mixer.gate.e_score_correction_bias": "model-00012-of-00016.safetensors",
305
+ "backbone.layers.43.mixer.gate.weight": "model-00012-of-00016.safetensors",
306
+ "backbone.layers.43.mixer.shared_experts.down_proj.weight": "model-00012-of-00016.safetensors",
307
+ "backbone.layers.43.mixer.shared_experts.up_proj.weight": "model-00012-of-00016.safetensors",
308
+ "backbone.layers.43.norm.weight": "model-00012-of-00016.safetensors",
309
+ "backbone.layers.44.mixer.A_log": "model-00012-of-00016.safetensors",
310
+ "backbone.layers.44.mixer.D": "model-00012-of-00016.safetensors",
311
+ "backbone.layers.44.mixer.conv1d.bias": "model-00012-of-00016.safetensors",
312
+ "backbone.layers.44.mixer.conv1d.weight": "model-00012-of-00016.safetensors",
313
+ "backbone.layers.44.mixer.dt_bias": "model-00012-of-00016.safetensors",
314
+ "backbone.layers.44.mixer.in_proj.weight": "model-00012-of-00016.safetensors",
315
+ "backbone.layers.44.mixer.norm.weight": "model-00012-of-00016.safetensors",
316
+ "backbone.layers.44.mixer.out_proj.weight": "model-00012-of-00016.safetensors",
317
+ "backbone.layers.44.norm.weight": "model-00012-of-00016.safetensors",
318
+ "backbone.layers.45.mixer.experts.down_proj.weight": "model-00012-of-00016.safetensors",
319
+ "backbone.layers.45.mixer.experts.up_proj.weight": "model-00012-of-00016.safetensors",
320
+ "backbone.layers.45.mixer.gate.e_score_correction_bias": "model-00012-of-00016.safetensors",
321
+ "backbone.layers.45.mixer.gate.weight": "model-00012-of-00016.safetensors",
322
+ "backbone.layers.45.mixer.shared_experts.down_proj.weight": "model-00012-of-00016.safetensors",
323
+ "backbone.layers.45.mixer.shared_experts.up_proj.weight": "model-00012-of-00016.safetensors",
324
+ "backbone.layers.45.norm.weight": "model-00012-of-00016.safetensors",
325
+ "backbone.layers.46.mixer.A_log": "model-00012-of-00016.safetensors",
326
+ "backbone.layers.46.mixer.D": "model-00012-of-00016.safetensors",
327
+ "backbone.layers.46.mixer.conv1d.bias": "model-00012-of-00016.safetensors",
328
+ "backbone.layers.46.mixer.conv1d.weight": "model-00012-of-00016.safetensors",
329
+ "backbone.layers.46.mixer.dt_bias": "model-00012-of-00016.safetensors",
330
+ "backbone.layers.46.mixer.in_proj.weight": "model-00012-of-00016.safetensors",
331
+ "backbone.layers.46.mixer.norm.weight": "model-00012-of-00016.safetensors",
332
+ "backbone.layers.46.mixer.out_proj.weight": "model-00012-of-00016.safetensors",
333
+ "backbone.layers.46.norm.weight": "model-00012-of-00016.safetensors",
334
+ "backbone.layers.47.mixer.experts.down_proj.weight": "model-00013-of-00016.safetensors",
335
+ "backbone.layers.47.mixer.experts.up_proj.weight": "model-00013-of-00016.safetensors",
336
+ "backbone.layers.47.mixer.gate.e_score_correction_bias": "model-00013-of-00016.safetensors",
337
+ "backbone.layers.47.mixer.gate.weight": "model-00013-of-00016.safetensors",
338
+ "backbone.layers.47.mixer.shared_experts.down_proj.weight": "model-00013-of-00016.safetensors",
339
+ "backbone.layers.47.mixer.shared_experts.up_proj.weight": "model-00013-of-00016.safetensors",
340
+ "backbone.layers.47.norm.weight": "model-00013-of-00016.safetensors",
341
+ "backbone.layers.48.mixer.A_log": "model-00013-of-00016.safetensors",
342
+ "backbone.layers.48.mixer.D": "model-00013-of-00016.safetensors",
343
+ "backbone.layers.48.mixer.conv1d.bias": "model-00013-of-00016.safetensors",
344
+ "backbone.layers.48.mixer.conv1d.weight": "model-00013-of-00016.safetensors",
345
+ "backbone.layers.48.mixer.dt_bias": "model-00013-of-00016.safetensors",
346
+ "backbone.layers.48.mixer.in_proj.weight": "model-00013-of-00016.safetensors",
347
+ "backbone.layers.48.mixer.norm.weight": "model-00013-of-00016.safetensors",
348
+ "backbone.layers.48.mixer.out_proj.weight": "model-00013-of-00016.safetensors",
349
+ "backbone.layers.48.norm.weight": "model-00013-of-00016.safetensors",
350
+ "backbone.layers.49.mixer.experts.down_proj.weight": "model-00013-of-00016.safetensors",
351
+ "backbone.layers.49.mixer.experts.up_proj.weight": "model-00014-of-00016.safetensors",
352
+ "backbone.layers.49.mixer.gate.e_score_correction_bias": "model-00014-of-00016.safetensors",
353
+ "backbone.layers.49.mixer.gate.weight": "model-00014-of-00016.safetensors",
354
+ "backbone.layers.49.mixer.shared_experts.down_proj.weight": "model-00014-of-00016.safetensors",
355
+ "backbone.layers.49.mixer.shared_experts.up_proj.weight": "model-00014-of-00016.safetensors",
356
+ "backbone.layers.49.norm.weight": "model-00014-of-00016.safetensors",
357
+ "backbone.layers.5.mixer.k_proj.weight": "model-00014-of-00016.safetensors",
358
+ "backbone.layers.5.mixer.o_proj.weight": "model-00014-of-00016.safetensors",
359
+ "backbone.layers.5.mixer.q_proj.weight": "model-00014-of-00016.safetensors",
360
+ "backbone.layers.5.mixer.s_proj.weight": "model-00014-of-00016.safetensors",
361
+ "backbone.layers.5.mixer.v_proj.weight": "model-00014-of-00016.safetensors",
362
+ "backbone.layers.5.norm.weight": "model-00014-of-00016.safetensors",
363
+ "backbone.layers.50.mixer.A_log": "model-00014-of-00016.safetensors",
364
+ "backbone.layers.50.mixer.D": "model-00014-of-00016.safetensors",
365
+ "backbone.layers.50.mixer.conv1d.bias": "model-00014-of-00016.safetensors",
366
+ "backbone.layers.50.mixer.conv1d.weight": "model-00014-of-00016.safetensors",
367
+ "backbone.layers.50.mixer.dt_bias": "model-00014-of-00016.safetensors",
368
+ "backbone.layers.50.mixer.in_proj.weight": "model-00014-of-00016.safetensors",
369
+ "backbone.layers.50.mixer.norm.weight": "model-00014-of-00016.safetensors",
370
+ "backbone.layers.50.mixer.out_proj.weight": "model-00014-of-00016.safetensors",
371
+ "backbone.layers.50.norm.weight": "model-00014-of-00016.safetensors",
372
+ "backbone.layers.51.mixer.experts.down_proj.weight": "model-00014-of-00016.safetensors",
373
+ "backbone.layers.51.mixer.experts.up_proj.weight": "model-00014-of-00016.safetensors",
374
+ "backbone.layers.51.mixer.gate.e_score_correction_bias": "model-00014-of-00016.safetensors",
375
+ "backbone.layers.51.mixer.gate.weight": "model-00014-of-00016.safetensors",
376
+ "backbone.layers.51.mixer.shared_experts.down_proj.weight": "model-00014-of-00016.safetensors",
377
+ "backbone.layers.51.mixer.shared_experts.up_proj.weight": "model-00014-of-00016.safetensors",
378
+ "backbone.layers.51.norm.weight": "model-00014-of-00016.safetensors",
379
+ "backbone.layers.6.mixer.experts.down_proj.weight": "model-00015-of-00016.safetensors",
380
+ "backbone.layers.6.mixer.experts.up_proj.weight": "model-00015-of-00016.safetensors",
381
+ "backbone.layers.6.mixer.gate.e_score_correction_bias": "model-00015-of-00016.safetensors",
382
+ "backbone.layers.6.mixer.gate.weight": "model-00015-of-00016.safetensors",
383
+ "backbone.layers.6.mixer.shared_experts.down_proj.weight": "model-00015-of-00016.safetensors",
384
+ "backbone.layers.6.mixer.shared_experts.up_proj.weight": "model-00015-of-00016.safetensors",
385
+ "backbone.layers.6.norm.weight": "model-00015-of-00016.safetensors",
386
+ "backbone.layers.7.mixer.A_log": "model-00015-of-00016.safetensors",
387
+ "backbone.layers.7.mixer.D": "model-00015-of-00016.safetensors",
388
+ "backbone.layers.7.mixer.conv1d.bias": "model-00015-of-00016.safetensors",
389
+ "backbone.layers.7.mixer.conv1d.weight": "model-00015-of-00016.safetensors",
390
+ "backbone.layers.7.mixer.dt_bias": "model-00015-of-00016.safetensors",
391
+ "backbone.layers.7.mixer.in_proj.weight": "model-00015-of-00016.safetensors",
392
+ "backbone.layers.7.mixer.norm.weight": "model-00015-of-00016.safetensors",
393
+ "backbone.layers.7.mixer.out_proj.weight": "model-00015-of-00016.safetensors",
394
+ "backbone.layers.7.norm.weight": "model-00015-of-00016.safetensors",
395
+ "backbone.layers.8.mixer.experts.down_proj.weight": "model-00015-of-00016.safetensors",
396
+ "backbone.layers.8.mixer.experts.up_proj.weight": "model-00016-of-00016.safetensors",
397
+ "backbone.layers.8.mixer.gate.e_score_correction_bias": "model-00016-of-00016.safetensors",
398
+ "backbone.layers.8.mixer.gate.weight": "model-00016-of-00016.safetensors",
399
+ "backbone.layers.8.mixer.shared_experts.down_proj.weight": "model-00016-of-00016.safetensors",
400
+ "backbone.layers.8.mixer.shared_experts.up_proj.weight": "model-00016-of-00016.safetensors",
401
+ "backbone.layers.8.norm.weight": "model-00016-of-00016.safetensors",
402
+ "backbone.layers.9.mixer.A_log": "model-00016-of-00016.safetensors",
403
+ "backbone.layers.9.mixer.D": "model-00016-of-00016.safetensors",
404
+ "backbone.layers.9.mixer.conv1d.bias": "model-00016-of-00016.safetensors",
405
+ "backbone.layers.9.mixer.conv1d.weight": "model-00016-of-00016.safetensors",
406
+ "backbone.layers.9.mixer.dt_bias": "model-00016-of-00016.safetensors",
407
+ "backbone.layers.9.mixer.in_proj.weight": "model-00016-of-00016.safetensors",
408
+ "backbone.layers.9.mixer.norm.weight": "model-00016-of-00016.safetensors",
409
+ "backbone.layers.9.mixer.out_proj.weight": "model-00016-of-00016.safetensors",
410
+ "backbone.layers.9.norm.weight": "model-00016-of-00016.safetensors",
411
+ "backbone.norm_f.weight": "model-00016-of-00016.safetensors",
412
+ "lm_head.weight": "model-00016-of-00016.safetensors"
413
+ }
414
+ }
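For context, the mapping above is the `weight_map` of a sharded safetensors index (`model.safetensors.index.json`): each parameter name points at the shard file that stores it, so a loader only opens the shards it needs. A minimal, hypothetical sketch (not code from this repository) of resolving a parameter to its shard:

```python
import json

# Tiny excerpt of a weight_map like the one above, inlined for illustration.
index_json = """{
  "weight_map": {
    "backbone.norm_f.weight": "model-00016-of-00016.safetensors",
    "lm_head.weight": "model-00016-of-00016.safetensors"
  }
}"""


def shard_for(param_name: str, weight_map: dict) -> str:
    """Return the shard filename that stores `param_name`."""
    if param_name not in weight_map:
        raise KeyError(f"{param_name!r} is not listed in the index")
    return weight_map[param_name]


weight_map = json.loads(index_json)["weight_map"]
print(shard_for("lm_head.weight", weight_map))  # model-00016-of-00016.safetensors
```

A real loader would group the requested parameters by shard and memory-map each shard file once, rather than resolving names one at a time.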
modeling_superlinear_exp.py ADDED
The diff for this file is too large to render. See raw diff
 
moe.py ADDED
@@ -0,0 +1,890 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import math
5
+ import os
6
+ from typing import Any
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+ try: # pragma: no cover
12
+ import triton
13
+ import triton.language as tl
14
+ except Exception: # pragma: no cover
15
+ triton = None
16
+ tl = None
17
+
18
+ # Eagerly import vllm._moe_C to ensure fused MoE kernels are available
19
+ # This must happen before shared_fused_moe_is_available() is called
20
+ try: # pragma: no cover
21
+ import vllm._moe_C # noqa: F401
22
+ except Exception: # pragma: no cover
23
+ pass
24
+
25
+
26
+ def _cdiv(a: int, b: int) -> int:
27
+ return (a + b - 1) // b
28
+
29
+
30
+ def _round_up(a: int, b: int) -> int:
31
+ return _cdiv(a, b) * b
32
+
+
+ @functools.lru_cache(maxsize=1)
+ def _ensure_moe_kernels_loaded() -> bool:
+     if hasattr(torch.ops, "_moe_C"):
+         return True
+     try:  # pragma: no cover
+         # vLLM installs the MoE kernels under this module name.
+         import vllm._moe_C  # noqa: F401
+         return hasattr(torch.ops, "_moe_C")
+     except Exception:
+         return False
+
+
+ @functools.lru_cache(maxsize=1)
+ def shared_fused_moe_is_available() -> bool:
+     if triton is None or tl is None:
+         return False
+     if not _ensure_moe_kernels_loaded():
+         return False
+     return all(
+         hasattr(torch.ops._moe_C, attr)
+         for attr in ("moe_align_block_size", "moe_sum")
+     )
+
+
+ def _moe_align_block_size(
+     topk_ids: torch.Tensor,
+     block_size: int,
+     num_experts: int,
+     *,
+     pad_sorted_ids: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     """Compute (sorted_token_ids, expert_ids, num_tokens_post_padded) using vLLM's `_moe_C` kernels."""
+     if not shared_fused_moe_is_available():
+         raise RuntimeError(
+             "Shared fused MoE is not available (missing triton and/or vllm._moe_C)."
+         )
+     if topk_ids.dtype != torch.int32:
+         topk_ids = topk_ids.to(torch.int32)
+
+     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+     if pad_sorted_ids:
+         max_num_tokens_padded = _round_up(max_num_tokens_padded, block_size)
+     if topk_ids.numel() < num_experts:
+         max_num_tokens_padded = min(topk_ids.numel() * block_size, max_num_tokens_padded)
+
+     sorted_token_ids = torch.empty(
+         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+     )
+     max_num_m_blocks = _cdiv(max_num_tokens_padded, block_size)
+     expert_ids = torch.empty(
+         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+     )
+     num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=topk_ids.device)
+
+     torch.ops._moe_C.moe_align_block_size(
+         topk_ids,
+         num_experts,
+         block_size,
+         sorted_token_ids,
+         expert_ids,
+         num_tokens_post_pad,
+         None,  # maybe_expert_map (added in newer vllm versions)
+     )
+     return sorted_token_ids, expert_ids, num_tokens_post_pad
+
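The layout `_moe_align_block_size` produces can be modelled in plain Python: flattened (token, k) slots are grouped by expert, and each expert's group is padded to a multiple of `block_size` with a sentinel index equal to the total slot count (padded slots are later masked out by `token_mask` in the GEMM kernel). This is a toy model of the assumed semantics, not the vLLM CUDA implementation:

```python
def align_block_size(topk_ids, block_size, num_experts):
    """Toy model: group flattened (token, k) slot indices by expert and pad
    each group to a multiple of block_size with the sentinel len(flat)."""
    flat = [(e, t) for t, row in enumerate(topk_ids) for e in row]
    sentinel = len(flat)  # out-of-range index; masked out downstream
    sorted_ids, expert_ids = [], []
    for e in range(num_experts):
        slots = [idx for idx, (ee, _) in enumerate(flat) if ee == e]
        if not slots:
            continue
        padded = slots + [sentinel] * (-len(slots) % block_size)
        sorted_ids.extend(padded)
        # One expert id per M-block of the padded layout.
        expert_ids.extend([e] * (len(padded) // block_size))
    return sorted_ids, expert_ids, len(sorted_ids)

# Token 0 routes to experts 0 and 1; token 1 routes to expert 0 twice.
print(align_block_size([[0, 1], [0, 0]], 4, 2))
# ([0, 2, 3, 4, 1, 4, 4, 4], [0, 1], 8)
```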
+
+ def _get_default_config(
+     M: int,
+     E: int,
+     N: int,
+     K: int,
+     topk: int,
+ ) -> dict[str, int]:
+     # Heuristic default configs adapted from vLLM.
+     if M <= E:
+         return {
+             "BLOCK_SIZE_M": 16,
+             "BLOCK_SIZE_N": 32,
+             "BLOCK_SIZE_K": 64,
+             "GROUP_SIZE_M": 1,
+         }
+     return {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 32,
+         "GROUP_SIZE_M": 8,
+     }
+
+
+ if triton is not None and tl is not None:
+
+     @triton.jit
+     def _write_zeros_to_output(
+         c_ptr,
+         stride_cm,
+         stride_cn,
+         pid_n,
+         N,
+         offs_token,
+         token_mask,
+         BLOCK_SIZE_M: tl.constexpr,
+         BLOCK_SIZE_N: tl.constexpr,
+         compute_type: tl.constexpr,
+     ):
+         accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
+         offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+         c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+         c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+         tl.store(c_ptrs, accumulator, mask=c_mask)
+
+     @triton.jit
+     def _fused_moe_kernel(
+         # Pointers to matrices
+         a_ptr,
+         b_ptr,
+         c_ptr,
+         topk_weights_ptr,
+         sorted_token_ids_ptr,
+         expert_ids_ptr,
+         num_tokens_post_padded_ptr,
+         # Matrix dimensions
+         N,
+         K,
+         EM,
+         num_valid_tokens,
+         # Strides
+         stride_am,
+         stride_ak,
+         stride_be,
+         stride_bk,
+         stride_bn,
+         stride_cm,
+         stride_cn,
+         # Meta-parameters
+         BLOCK_SIZE_M: tl.constexpr,
+         BLOCK_SIZE_N: tl.constexpr,
+         BLOCK_SIZE_K: tl.constexpr,
+         GROUP_SIZE_M: tl.constexpr,
+         MUL_ROUTED_WEIGHT: tl.constexpr,
+         top_k: tl.constexpr,
+         compute_type: tl.constexpr,
+     ):
+         # Grouped ordering to promote L2 data reuse.
+         pid = tl.program_id(axis=0)
+         num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
+         num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+         num_pid_in_group = GROUP_SIZE_M * num_pid_n
+         group_id = pid // num_pid_in_group
+         first_pid_m = group_id * GROUP_SIZE_M
+         group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+         pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+         pid_n = (pid % num_pid_in_group) // group_size_m
+
+         num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
+         if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+             return
+
+         offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+         offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+         token_mask = offs_token < num_valid_tokens
+
+         off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+         if off_experts == -1:
+             _write_zeros_to_output(
+                 c_ptr,
+                 stride_cm,
+                 stride_cn,
+                 pid_n,
+                 N,
+                 offs_token,
+                 token_mask,
+                 BLOCK_SIZE_M,
+                 BLOCK_SIZE_N,
+                 compute_type,
+             )
+             return
+
+         offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
+         offs_k = tl.arange(0, BLOCK_SIZE_K)
+         a_ptrs = a_ptr + (
+             offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+         )
+         b_ptrs = (
+             b_ptr
+             + off_experts * stride_be
+             + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+         )
+
+         accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+         for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+             a = tl.load(
+                 a_ptrs,
+                 mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                 other=0.0,
+             )
+             b = tl.load(
+                 b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0
+             )
+             # Disable TF32 for numerical precision (matches PyTorch's default behavior).
+             accumulator += tl.dot(a, b, allow_tf32=False)
+             a_ptrs += BLOCK_SIZE_K * stride_ak
+             b_ptrs += BLOCK_SIZE_K * stride_bk
+
+         if MUL_ROUTED_WEIGHT:
+             moe_weight = tl.load(
+                 topk_weights_ptr + offs_token, mask=token_mask, other=0
+             )
+             accumulator = accumulator * moe_weight[:, None]
+         accumulator = accumulator.to(compute_type)
+
+         offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+         c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+         c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+         tl.store(c_ptrs, accumulator, mask=c_mask)
+
+ else:  # pragma: no cover
+
+     def _fused_moe_kernel(*args, **kwargs):  # type: ignore[no-redef]
+         raise RuntimeError("Triton is not available; cannot use fused MoE.")
+
+
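The grouped program-id mapping at the top of `_fused_moe_kernel` (the `GROUP_SIZE_M` arithmetic) can be reproduced in plain Python to see the launch order it induces: consecutive `pid`s sweep down a group of `GROUP_SIZE_M` M-blocks before moving across N, which keeps the same B tiles hot in L2. This sketch only mirrors the index arithmetic, not the GEMM:

```python
def grouped_order(num_pid_m, num_pid_n, group_size_m):
    # Mirrors the pid -> (pid_m, pid_n) mapping used in _fused_moe_kernel.
    order = []
    num_pid_in_group = group_size_m * num_pid_n
    for pid in range(num_pid_m * num_pid_n):
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * group_size_m
        size_m = min(num_pid_m - first_pid_m, group_size_m)
        pid_m = first_pid_m + ((pid % num_pid_in_group) % size_m)
        pid_n = (pid % num_pid_in_group) // size_m
        order.append((pid_m, pid_n))
    return order

# With 4 M-blocks, 2 N-blocks, groups of 2: the first group visits
# M-blocks {0, 1} for every N-block before touching M-blocks {2, 3}.
print(grouped_order(4, 2, 2))
# [(0, 0), (1, 0), (0, 1), (1, 1), (2, 0), (3, 0), (2, 1), (3, 1)]
```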
+ def _invoke_fused_moe_kernel(
+     *,
+     A: torch.Tensor,
+     B: torch.Tensor,
+     C: torch.Tensor,
+     topk_weights: torch.Tensor | None,
+     sorted_token_ids: torch.Tensor,
+     expert_ids: torch.Tensor,
+     num_tokens_post_padded: torch.Tensor,
+     mul_routed_weight: bool,
+     top_k: int,
+     config: dict[str, Any],
+     compute_type: Any,
+ ) -> None:
+     assert triton is not None and tl is not None
+     assert topk_weights is not None or not mul_routed_weight
+     assert sorted_token_ids.stride(0) == 1
+
+     M = A.size(0)
+     num_tokens = M * top_k
+     EM = sorted_token_ids.size(0)
+     grid = lambda META: (
+         triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
+     )
+
+     config = dict(config)
+     _fused_moe_kernel[grid](
+         A,
+         B,
+         C,
+         topk_weights,
+         sorted_token_ids,
+         expert_ids,
+         num_tokens_post_padded,
+         B.size(1),
+         B.size(2),
+         EM,
+         num_tokens,
+         A.stride(0),
+         A.stride(1),
+         B.stride(0),
+         B.stride(2),
+         B.stride(1),
+         C.stride(1),
+         C.stride(2),
+         MUL_ROUTED_WEIGHT=mul_routed_weight,
+         top_k=top_k,
+         compute_type=compute_type,
+         **config,
+     )
+
+
+ def fused_experts_moe(
+     hidden_states: torch.Tensor,
+     w1: torch.Tensor,
+     w2: torch.Tensor,
+     topk_weights: torch.Tensor,
+     topk_ids: torch.Tensor,
+     *,
+     activation: str,
+     inplace: bool = False,
+     apply_router_weight_on_input: bool = False,
+ ) -> torch.Tensor:
+     """
+     Fused MoE expert compute (2-layer MLP) using Triton grouped GEMMs + vLLM's `_moe_C` align/sum kernels.
+
+     This function is intentionally minimal: it supports *non-gated* activations
+     (`*_no_mul`), which is the SuperlinearExp MoE case here.
+     """
+     if torch.is_grad_enabled() and any(
+         t.requires_grad for t in (hidden_states, w1, w2, topk_weights)
+     ):
+         return _FusedExpertsMoE.apply(
+             hidden_states,
+             w1,
+             w2,
+             topk_weights,
+             topk_ids,
+             activation,
+             apply_router_weight_on_input,
+         )
+     if hidden_states.numel() == 0:
+         return hidden_states
+     if hidden_states.dim() != 2:
+         raise ValueError(f"Expected [tokens, hidden], got {tuple(hidden_states.shape)}")
+     return _fused_experts_moe_forward(
+         hidden_states,
+         w1,
+         w2,
+         topk_weights,
+         topk_ids,
+         activation=activation,
+         inplace=inplace,
+         apply_router_weight_on_input=apply_router_weight_on_input,
+     )
+
+
+ def _activation_forward(x: torch.Tensor, activation: str) -> torch.Tensor:
+     if activation == "relu2_no_mul":
+         return torch.square(F.relu(x))
+     if activation == "silu_no_mul":
+         return F.silu(x)
+     if activation == "gelu_no_mul":
+         return F.gelu(x)
+     raise ValueError(f"Unsupported fused MoE activation: {activation}")
+
+
+ def _activation_backward(x_fp32: torch.Tensor, activation: str) -> torch.Tensor:
+     if activation == "relu2_no_mul":
+         return (x_fp32 > 0).to(x_fp32.dtype) * (2.0 * x_fp32)
+     if activation == "silu_no_mul":
+         sig = torch.sigmoid(x_fp32)
+         return sig * (1.0 + x_fp32 * (1.0 - sig))
+     if activation == "gelu_no_mul":
+         inv_sqrt2 = 1.0 / math.sqrt(2.0)
+         inv_sqrt2pi = 1.0 / math.sqrt(2.0 * math.pi)
+         cdf = 0.5 * (1.0 + torch.erf(x_fp32 * inv_sqrt2))
+         pdf = torch.exp(-0.5 * x_fp32 * x_fp32) * inv_sqrt2pi
+         return cdf + x_fp32 * pdf
+     raise ValueError(f"Unsupported fused MoE activation: {activation}")
+
+
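The closed-form derivatives in `_activation_backward` can be spot-checked against a central finite difference. For the SiLU case, d/dx [x·σ(x)] = σ(x)·(1 + x·(1 − σ(x))); the same check works scalar-by-scalar in plain Python:

```python
import math

def silu(x):
    # SiLU / swish: x * sigmoid(x).
    return x / (1.0 + math.exp(-x))

def silu_grad(x):
    # Closed form matching _activation_backward("silu_no_mul").
    sig = 1.0 / (1.0 + math.exp(-x))
    return sig * (1.0 + x * (1.0 - sig))

h = 1e-6
for x in (-2.0, -0.5, 0.0, 0.7, 3.0):
    fd = (silu(x + h) - silu(x - h)) / (2 * h)  # central difference
    assert abs(fd - silu_grad(x)) < 1e-6
print("silu derivative matches finite difference")
```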
+ def _eager_experts_moe_forward(
+     hidden_states: torch.Tensor,
+     w1: torch.Tensor,
+     w2: torch.Tensor,
+     topk_weights: torch.Tensor,
+     topk_ids: torch.Tensor,
+     *,
+     activation: str,
+     apply_router_weight_on_input: bool,
+ ) -> torch.Tensor:
+     if hidden_states.numel() == 0:
+         return hidden_states
+     if hidden_states.dim() != 2:
+         raise ValueError(f"Expected [tokens, hidden], got {tuple(hidden_states.shape)}")
+
+     num_tokens, hidden_size = hidden_states.shape
+     num_experts, intermediate_size, hidden_size_w1 = w1.shape
+     num_experts_w2, hidden_size_w2, intermediate_size_w2 = w2.shape
+     if hidden_size_w1 != hidden_size:
+         raise ValueError(f"Hidden size mismatch: {hidden_size} != {hidden_size_w1} (w1 in_features)")
+     if num_experts_w2 != num_experts:
+         raise ValueError(f"Expert count mismatch: {num_experts} != {num_experts_w2} (w2)")
+     if hidden_size_w2 != hidden_size:
+         raise ValueError(f"Hidden size mismatch: {hidden_size} != {hidden_size_w2} (w2 out_features)")
+     if intermediate_size_w2 != intermediate_size:
+         raise ValueError(f"Intermediate size mismatch: {intermediate_size} != {intermediate_size_w2} (w2 in_features)")
+     if topk_ids.shape != topk_weights.shape:
+         raise ValueError("topk_ids/topk_weights shape mismatch")
+
+     topk = topk_ids.size(1)
+     out = torch.zeros((num_tokens, hidden_size), device=hidden_states.device, dtype=torch.float32)
+
+     CHUNK_SIZE = int(os.getenv("FUSED_MOE_CHUNK_SIZE", str(16 * 1024)))
+     for chunk in range((num_tokens // CHUNK_SIZE) + 1):
+         begin = chunk * CHUNK_SIZE
+         end = min((chunk + 1) * CHUNK_SIZE, num_tokens)
+         x = hidden_states[begin:end]
+         if x.numel() == 0:
+             break
+
+         m = x.size(0)
+         topk_ids_chunk = topk_ids[begin:end].reshape(-1).to(torch.long)
+         topk_weights_chunk = topk_weights[begin:end].reshape(-1).to(torch.float32)
+
+         token_ids = (
+             torch.arange(m, device=x.device, dtype=torch.long)
+             .repeat_interleave(topk)
+         )
+         k_ids = torch.arange(topk, device=x.device, dtype=torch.long).repeat(m)
+
+         sort_order = torch.argsort(topk_ids_chunk)
+         expert_ids_sorted = topk_ids_chunk[sort_order]
+         token_ids_sorted = token_ids[sort_order]
+         k_ids_sorted = k_ids[sort_order]
+         weights_sorted = topk_weights_chunk[sort_order]
+
+         unique_experts, counts = torch.unique_consecutive(expert_ids_sorted, return_counts=True)
+         out_chunk = torch.zeros((m, hidden_size), device=x.device, dtype=torch.float32)
+
+         offset = 0
+         for expert_idx, count in zip(unique_experts.tolist(), counts.tolist()):
+             tokens = token_ids_sorted[offset : offset + count]
+             weights = weights_sorted[offset : offset + count]
+             offset += count
+             if count == 0:
+                 continue
+
+             x_e = x.index_select(0, tokens)
+             u0 = F.linear(x_e, w1[expert_idx])
+             if apply_router_weight_on_input:
+                 u = u0 * weights.to(u0.dtype).unsqueeze(-1)
+             else:
+                 u = u0
+             a = _activation_forward(u, activation)
+             v = F.linear(a, w2[expert_idx])
+             if not apply_router_weight_on_input:
+                 v = v * weights.to(v.dtype).unsqueeze(-1)
+             out_chunk.index_add_(0, tokens, v.to(torch.float32))
+
+         out[begin:end] = out_chunk
+
+     return out.to(hidden_states.dtype)
+
+
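The `range(num_tokens // CHUNK_SIZE + 1)` loop used here covers every token exactly once, including a ragged final chunk, with the `if x.numel() == 0: break` guard handling the case where `num_tokens` is an exact multiple of the chunk size. A stand-alone model of just the chunk boundaries:

```python
def chunk_bounds(num_tokens, chunk_size):
    # Reproduces the chunking pattern of the MoE forward/backward loops.
    bounds = []
    for chunk in range(num_tokens // chunk_size + 1):
        begin = chunk * chunk_size
        end = min((chunk + 1) * chunk_size, num_tokens)
        if begin >= end:  # empty trailing chunk when evenly divisible
            break
        bounds.append((begin, end))
    return bounds

print(chunk_bounds(10, 4))  # [(0, 4), (4, 8), (8, 10)]
print(chunk_bounds(8, 4))   # [(0, 4), (4, 8)]
```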
+ def _fused_experts_moe_forward(
+     hidden_states: torch.Tensor,
+     w1: torch.Tensor,
+     w2: torch.Tensor,
+     topk_weights: torch.Tensor,
+     topk_ids: torch.Tensor,
+     *,
+     activation: str,
+     inplace: bool = False,
+     apply_router_weight_on_input: bool = False,
+ ) -> torch.Tensor:
+     if not shared_fused_moe_is_available():
+         return _eager_experts_moe_forward(
+             hidden_states,
+             w1,
+             w2,
+             topk_weights,
+             topk_ids,
+             activation=activation,
+             apply_router_weight_on_input=apply_router_weight_on_input,
+         )
+     if not hidden_states.is_cuda:
+         return _eager_experts_moe_forward(
+             hidden_states,
+             w1,
+             w2,
+             topk_weights,
+             topk_ids,
+             activation=activation,
+             apply_router_weight_on_input=apply_router_weight_on_input,
+         )
+
+     # Constraints similar to vLLM's fused kernels.
+     if not hidden_states.is_contiguous():
+         hidden_states = hidden_states.contiguous()
+     if w1.stride(-1) != 1 or w2.stride(-1) != 1:
+         raise ValueError("Expert weights must be contiguous in the last dimension.")
+
+     # Shapes.
+     num_tokens = hidden_states.size(0)
+     num_experts, n1, k1 = w1.size()
+     _, k2, n2 = w2.size()
+     if hidden_states.size(1) != k1:
+         raise ValueError(
+             f"Hidden size mismatch: {hidden_states.size(1)} != {k1} (w1 in_features)"
+         )
+     if n2 != n1:
+         raise ValueError(f"Intermediate size mismatch: {n2} != {n1}")
+     if topk_ids.shape != topk_weights.shape:
+         raise ValueError("topk_ids/topk_weights shape mismatch")
+
+     topk = topk_ids.size(1)
+     CHUNK_SIZE = int(os.getenv("FUSED_MOE_CHUNK_SIZE", str(16 * 1024)))
+     M = min(num_tokens, CHUNK_SIZE)
+
+     config = _get_default_config(M=M, E=num_experts, N=n1, K=k1, topk=topk)
+
+     if hidden_states.dtype == torch.bfloat16:
+         compute_type = tl.bfloat16
+     elif hidden_states.dtype == torch.float16:
+         compute_type = tl.float16
+     elif hidden_states.dtype == torch.float32:
+         compute_type = tl.float32
+     else:
+         raise ValueError(f"Unsupported dtype: {hidden_states.dtype}")
+
+     # Accumulate in float32 for numerical precision (matches eager path behavior).
+     # The output will be converted back to the original dtype at the end.
+     original_dtype = hidden_states.dtype
+     out = torch.zeros(
+         (num_tokens, hidden_states.size(1)),
+         device=hidden_states.device,
+         dtype=torch.float32,
+     )
+
+     # Cache buffers sized to the largest chunk.
+     # IMPORTANT: up_out and down_out must NOT overlap in memory!
+     # The down projection kernel reads from up_out while writing to down_out.
+     # If they share memory, the kernel will corrupt its input as it writes output.
+     up_out = torch.empty(
+         (M, topk, n1), device=hidden_states.device, dtype=hidden_states.dtype
+     )
+     down_out = torch.empty(
+         (M, topk, k2), device=hidden_states.device, dtype=hidden_states.dtype
+     )
+
+     for chunk in range((num_tokens // CHUNK_SIZE) + 1):
+         begin = chunk * CHUNK_SIZE
+         end = min((chunk + 1) * CHUNK_SIZE, num_tokens)
+         curr_hidden = hidden_states[begin:end]
+         tokens_in_chunk = curr_hidden.size(0)
+         if tokens_in_chunk == 0:
+             break
+
+         if tokens_in_chunk != M:
+             up_out = up_out[:tokens_in_chunk]
+             down_out = down_out[:tokens_in_chunk]
+             config = _get_default_config(M=tokens_in_chunk, E=num_experts, N=n1, K=k1, topk=topk)
+
+         curr_topk_ids = topk_ids[begin:end].to(torch.int32).contiguous()
+         curr_topk_weights = topk_weights[begin:end].to(torch.float32).contiguous()
+
+         sorted_token_ids, expert_ids, num_tokens_post_padded = _moe_align_block_size(
+             curr_topk_ids,
+             config["BLOCK_SIZE_M"],
+             num_experts,
+         )
+
+         # 1) Up projection: [tokens, hidden] -> [tokens * topk, intermediate]
+         _invoke_fused_moe_kernel(
+             A=curr_hidden,
+             B=w1,
+             C=up_out,
+             topk_weights=curr_topk_weights if apply_router_weight_on_input else None,
+             sorted_token_ids=sorted_token_ids,
+             expert_ids=expert_ids,
+             num_tokens_post_padded=num_tokens_post_padded,
+             mul_routed_weight=apply_router_weight_on_input,
+             top_k=topk,
+             config=config,
+             compute_type=compute_type,
+         )
+
+         # 2) Activation (in-place on up_out to avoid allocating an extra buffer).
+         if activation == "relu2_no_mul":
+             x = up_out.view(-1, n1)
+             x.relu_()
+             x.square_()
+         elif activation == "silu_no_mul":
+             x = up_out.view(-1, n1)
+             x.copy_(F.silu(x))
+         elif activation == "gelu_no_mul":
+             x = up_out.view(-1, n1)
+             x.copy_(F.gelu(x))
+         else:
+             raise ValueError(f"Unsupported fused MoE activation: {activation}")
+
+         # 3) Down projection: [tokens * topk, intermediate] -> [tokens * topk, hidden]
+         _invoke_fused_moe_kernel(
+             A=up_out.view(-1, n1),
+             B=w2,
+             C=down_out,
+             topk_weights=None if apply_router_weight_on_input else curr_topk_weights,
+             sorted_token_ids=sorted_token_ids,
+             expert_ids=expert_ids,
+             num_tokens_post_padded=num_tokens_post_padded,
+             mul_routed_weight=not apply_router_weight_on_input,
+             top_k=1,
+             config=config,
+             compute_type=compute_type,
+         )
+
+         # Convert down_out to float32 to match out's dtype for moe_sum.
+         torch.ops._moe_C.moe_sum(
+             down_out.to(torch.float32),
+             out[begin:end],
+         )
+
+     # Convert back to original dtype after accumulation in float32.
+     return out.to(original_dtype)
+
+
+ class _FusedExpertsMoE(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         ctx,
+         hidden_states: torch.Tensor,
+         w1: torch.Tensor,
+         w2: torch.Tensor,
+         topk_weights: torch.Tensor,
+         topk_ids: torch.Tensor,
+         activation: str,
+         apply_router_weight_on_input: bool,
+     ) -> torch.Tensor:
+         ctx.activation = activation
+         ctx.apply_router_weight_on_input = apply_router_weight_on_input
+         ctx.save_for_backward(hidden_states, w1, w2, topk_weights, topk_ids)
+         return _fused_experts_moe_forward(
+             hidden_states,
+             w1,
+             w2,
+             topk_weights,
+             topk_ids,
+             activation=activation,
+             inplace=False,
+             apply_router_weight_on_input=apply_router_weight_on_input,
+         )
+
+     @staticmethod
+     def backward(ctx, grad_out: torch.Tensor):
+         (
+             hidden_states,
+             w1,
+             w2,
+             topk_weights,
+             topk_ids,
+         ) = ctx.saved_tensors
+         activation: str = ctx.activation
+         apply_router_weight_on_input: bool = ctx.apply_router_weight_on_input
+
+         need_hidden, need_w1, need_w2, need_topk_w = ctx.needs_input_grad[:4]
+
+         grad_hidden = torch.zeros_like(hidden_states) if need_hidden else None
+         grad_w1 = torch.zeros_like(w1) if need_w1 else None
+         grad_w2 = torch.zeros_like(w2) if need_w2 else None
+         grad_topk_weights = torch.zeros_like(topk_weights) if need_topk_w else None
+
+         if hidden_states.numel() == 0:
+             return grad_hidden, grad_w1, grad_w2, grad_topk_weights, None, None, None
+
+         num_tokens = hidden_states.size(0)
+         topk = topk_ids.size(1)
+         num_experts = w1.size(0)
+
+         CHUNK_SIZE = int(os.getenv("FUSED_MOE_CHUNK_SIZE", str(16 * 1024)))
+         max_padded_tokens_per_expert = int(
+             os.getenv("FUSED_MOE_BACKWARD_MAX_PADDED_TOKENS_PER_EXPERT", "2048")
+         )
+         for chunk in range((num_tokens // CHUNK_SIZE) + 1):
+             begin = chunk * CHUNK_SIZE
+             end = min((chunk + 1) * CHUNK_SIZE, num_tokens)
+             x = hidden_states[begin:end]
+             if x.numel() == 0:
+                 break
+
+             m = x.size(0)
+             token_ids = (
+                 torch.arange(m, device=x.device, dtype=torch.long)
+                 .repeat_interleave(topk)
+             )
+             k_ids = torch.arange(topk, device=x.device, dtype=torch.long).repeat(m)
+             expert_ids = topk_ids[begin:end].reshape(-1).to(torch.long)
+             weights_fp32 = topk_weights[begin:end].reshape(-1).to(torch.float32)
+
+             sort_order = torch.argsort(expert_ids)
+             expert_ids_sorted = expert_ids[sort_order]
+             token_ids_sorted = token_ids[sort_order]
+             k_ids_sorted = k_ids[sort_order]
+             weights_sorted = weights_fp32[sort_order]
+
+             counts_per_expert = torch.bincount(expert_ids_sorted, minlength=num_experts)
+             max_count = int(counts_per_expert.max().item())
+
+             use_vectorized = (
+                 max_count > 0 and max_count <= max_padded_tokens_per_expert
+             )
+
+             if use_vectorized:
+                 hidden_size = x.size(1)
+                 offsets = torch.cumsum(counts_per_expert, 0) - counts_per_expert
+                 pos_in_expert = torch.arange(
+                     expert_ids_sorted.numel(), device=x.device, dtype=torch.long
+                 ) - offsets[expert_ids_sorted]
+                 flat = expert_ids_sorted * max_count + pos_in_expert
+
+                 x_pad = torch.zeros(
+                     (num_experts, max_count, hidden_size),
+                     device=x.device,
+                     dtype=x.dtype,
+                 )
+                 x_pad.view(num_experts * max_count, hidden_size)[flat] = x[
+                     token_ids_sorted
+                 ]
+
+                 gy_pad = torch.zeros(
+                     (num_experts, max_count, hidden_size),
+                     device=x.device,
+                     dtype=torch.float32,
+                 )
+                 gy_pad.view(num_experts * max_count, hidden_size)[flat] = grad_out[
+                     begin:end
+                 ][token_ids_sorted].to(torch.float32)
+
+                 w_pad = torch.zeros(
+                     (num_experts, max_count),
+                     device=x.device,
+                     dtype=torch.float32,
+                 )
+                 w_pad.view(num_experts * max_count)[flat] = weights_sorted
+
+                 u0 = torch.einsum("emh,eih->emi", x_pad, w1)
+                 if apply_router_weight_on_input:
+                     u = u0 * w_pad.to(u0.dtype).unsqueeze(-1)
+                 else:
+                     u = u0
+                 a = _activation_forward(u, activation)
+
+                 tmp = torch.einsum("emh,ehi->emi", gy_pad.to(a.dtype), w2)
+                 tmp_fp32 = tmp.to(torch.float32)
+
+                 if need_w2:
+                     if apply_router_weight_on_input:
+                         grad_v = gy_pad.to(a.dtype)
+                     else:
+                         grad_v = (gy_pad * w_pad.unsqueeze(-1)).to(a.dtype)
+                     grad_w2_chunk = torch.einsum("emh,emi->ehi", grad_v, a)
+                     assert grad_w2 is not None
+                     grad_w2.add_(grad_w2_chunk.to(grad_w2.dtype))
+
+                 gA_fp32 = tmp_fp32
+                 if not apply_router_weight_on_input:
+                     gA_fp32 = gA_fp32 * w_pad.unsqueeze(-1)
+
+                 du_fp32 = _activation_backward(u.to(torch.float32), activation)
+                 gU_fp32 = gA_fp32 * du_fp32
+
+                 if apply_router_weight_on_input:
+                     if need_topk_w:
+                         grad_w_fp32 = torch.sum(
+                             gU_fp32 * u0.to(torch.float32),
+                             dim=-1,
+                         )
+                     gU0_fp32 = gU_fp32 * w_pad.unsqueeze(-1)
+                 else:
+                     if need_topk_w:
+                         grad_w_fp32 = torch.sum(
+                             a.to(torch.float32) * tmp_fp32,
+                             dim=-1,
+                         )
+                     gU0_fp32 = gU_fp32
+
+                 gU0 = gU0_fp32.to(x.dtype)
+
+                 if need_w1:
+                     grad_w1_chunk = torch.einsum("emi,emh->eih", gU0, x_pad)
+                     assert grad_w1 is not None
+                     grad_w1.add_(grad_w1_chunk.to(grad_w1.dtype))
+
+                 if need_hidden:
+                     grad_x_pad = torch.einsum("emi,eih->emh", gU0, w1)
+                     grad_x_assign = grad_x_pad.view(
+                         num_experts * max_count, hidden_size
+                     )[flat]
+                     grad_x_chunk = torch.zeros(
+                         (m, hidden_size), device=x.device, dtype=x.dtype
+                     )
+                     grad_x_chunk.index_add_(0, token_ids_sorted, grad_x_assign)
+                     assert grad_hidden is not None
+                     grad_hidden[begin:end].copy_(grad_x_chunk)
+
+                 if need_topk_w:
+                     grad_w_assign = grad_w_fp32.view(num_experts * max_count)[flat]
+                     grad_topk_chunk = torch.zeros(
+                         (m, topk), device=x.device, dtype=topk_weights.dtype
+                     )
+                     grad_topk_chunk[token_ids_sorted, k_ids_sorted] = grad_w_assign.to(
+                         grad_topk_chunk.dtype
+                     )
+                     assert grad_topk_weights is not None
+                     grad_topk_weights[begin:end].copy_(grad_topk_chunk)
+
+                 continue
+
+             unique_experts, counts = torch.unique_consecutive(
+                 expert_ids_sorted, return_counts=True
+             )
+
+             grad_x_chunk = torch.zeros((m, x.size(1)), device=x.device, dtype=x.dtype) if need_hidden else None
+             grad_topk_chunk = torch.zeros((m, topk), device=x.device, dtype=topk_weights.dtype) if need_topk_w else None
+
+             offset = 0
+             for expert_idx, count in zip(unique_experts.tolist(), counts.tolist()):
+                 tokens = token_ids_sorted[offset : offset + count]
+                 ks = k_ids_sorted[offset : offset + count]
+                 w = weights_sorted[offset : offset + count]
+                 offset += count
+                 if count == 0:
+                     continue
+
+                 x_e = x.index_select(0, tokens)
+                 w1_e = w1[expert_idx]
+                 w2_e = w2[expert_idx]
+
+                 u0 = F.linear(x_e, w1_e)
+                 if apply_router_weight_on_input:
+                     u = u0 * w.to(u0.dtype).unsqueeze(-1)
+                 else:
+                     u = u0
+                 a = _activation_forward(u, activation)
+
+                 grad_y_fp32 = grad_out[begin:end].index_select(0, tokens).to(torch.float32)
+                 if apply_router_weight_on_input:
+                     grad_v_fp32 = grad_y_fp32
+                 else:
+                     grad_v_fp32 = grad_y_fp32 * w.unsqueeze(-1)
+
+                 grad_v = grad_v_fp32.to(a.dtype)
+
+                 if need_w2:
+                     grad_w2_e = torch.matmul(grad_v.transpose(0, 1), a)
+                     assert grad_w2 is not None
+                     grad_w2[expert_idx].add_(grad_w2_e.to(grad_w2.dtype))
+
+                 gA = torch.matmul(grad_v, w2_e)
+                 du_fp32 = _activation_backward(u.to(torch.float32), activation)
+                 gU_fp32 = gA.to(torch.float32) * du_fp32
+
+                 if apply_router_weight_on_input:
+                     grad_w_fp32 = torch.sum(gU_fp32 * u0.to(torch.float32), dim=-1)
+                     gU0_fp32 = gU_fp32 * w.unsqueeze(-1)
+                 else:
+                     gy_for_w = grad_y_fp32.to(a.dtype)
+                     tmp = torch.matmul(gy_for_w, w2_e).to(torch.float32)
+                     grad_w_fp32 = torch.sum(a.to(torch.float32) * tmp, dim=-1)
+                     gU0_fp32 = gU_fp32
+
+                 if need_topk_w:
+                     assert grad_topk_chunk is not None
+                     grad_topk_chunk[tokens, ks] = grad_w_fp32.to(grad_topk_chunk.dtype)
+
+                 gU0 = gU0_fp32.to(x_e.dtype)
+
+                 if need_w1:
+                     grad_w1_e = torch.matmul(gU0.transpose(0, 1), x_e)
+                     assert grad_w1 is not None
+                     grad_w1[expert_idx].add_(grad_w1_e.to(grad_w1.dtype))
+
+                 if need_hidden:
+                     assert grad_x_chunk is not None
+                     grad_x_e = torch.matmul(gU0, w1_e)
+                     grad_x_chunk.index_add_(0, tokens, grad_x_e)
+
+             if need_hidden:
+                 assert grad_hidden is not None and grad_x_chunk is not None
+                 grad_hidden[begin:end].copy_(grad_x_chunk)
+             if need_topk_w:
+                 assert grad_topk_weights is not None and grad_topk_chunk is not None
+                 grad_topk_weights[begin:end].copy_(grad_topk_chunk)
+
+         return grad_hidden, grad_w1, grad_w2, grad_topk_weights, None, None, None
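The vectorized backward path scatters each sorted (token, k) assignment into a dense `[num_experts, max_count]` padded layout via `flat = expert * max_count + pos_in_expert`, where `pos_in_expert` is the assignment's rank within its expert's contiguous run. The index arithmetic can be sketched without torch (a model of the assumed layout, not the tensor code itself):

```python
def padded_flat_indices(expert_ids_sorted, num_experts, max_count):
    # expert_ids_sorted must be sorted by expert (as after torch.argsort).
    # Each assignment lands at expert * max_count + its rank within that expert.
    counts = [0] * num_experts
    flat = []
    for e in expert_ids_sorted:
        flat.append(e * max_count + counts[e])
        counts[e] += 1
    return flat

# Experts 0 and 2 receive 2 and 3 assignments; expert 1 gets only padding rows.
print(padded_flat_indices([0, 0, 2, 2, 2], num_experts=3, max_count=3))
# [0, 1, 6, 7, 8]
```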
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "bos_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "<|im_end|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|im_end|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:623c34567aebb18582765289fbe23d901c62704d6518d71866e0e58db892b5b7
+ size 17077484
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff