Upload folder using huggingface_hub

#1 · opened by exlaw
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
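For context: a rule like the added line is what `git lfs track "tokenizer.json"` appends to `.gitattributes` (general Git LFS behavior, noted for readers; not part of the PR itself).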
LICENSE.txt ADDED
@@ -0,0 +1,283 @@
Tencent is pleased to support the community by making WeDLM-8B-Instruct available.

Copyright (C) 2025 Tencent. All rights reserved. WeDLM-8B-Instruct IS NOT INTENDED FOR USE WITHIN THE EUROPEAN UNION.

WeDLM-8B-Instruct is licensed under the License Terms of WeDLM-8B-Instruct except for the third-party components listed below, which are licensed under different terms. WeDLM-8B-Instruct does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of the original licenses of these third-party components and must ensure that the usage of the third-party components adheres to all relevant laws and regulations.

For the avoidance of doubt, WeDLM-8B-Instruct refers to the inference-enabling code, parameters and weights made publicly available by Tencent in accordance with the License Terms of WeDLM-8B-Instruct in this repository.

License Terms of WeDLM-8B-Instruct:
--------------------------------------------------------------------

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
0. Additional Territorial Limitation

*WeDLM-8B-Instruct IS NOT INTENDED FOR USE WITHIN THE EUROPEAN UNION.*
IN THE EVENT OF ANY CONFLICT, THIS CLAUSE SHALL PREVAIL.

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

You must give any other recipients of the Work or Derivative Works a copy of this License; and

You must cause any modified files to carry prominent notices stating that You changed the files; and

You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

The WeDLM-8B-Instruct model was fine-tuned by Tencent with the assistance of the following open models.

Open Models Licensed under the Apache-2.0 License:
--------------------------------------------------------------------
1. Qwen/Qwen3-8B
Copyright 2024 Alibaba Cloud
--------------------------------------------------------------------
Terms of the Apache-2.0 License:
--------------------------------------------------------------------

Apache License

Version 2.0, January 2004

http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

The code of this project is built on, and with the aid of, the following open source projects. Credits are given to these projects.

Open Source Software Licensed under the MIT License:
--------------------------------------------------------------------
1. nano-vllm
Copyright (c) 2025 Xingkai Yu

Terms of the MIT License:
--------------------------------------------------------------------
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==================================================
End of the Attribution Notice of this project.
Readme.md ADDED
@@ -0,0 +1,137 @@
---
license: apache-2.0
language:
- en
- zh
base_model: tencent/WeDLM-8B
pipeline_tag: text-generation
tags:
- language model
- parallel-decoding
---

# WeDLM-8B-Instruct ⭐

**WeDLM-8B-Instruct** is our flagship instruction-tuned diffusion language model. It performs parallel decoding under standard causal attention and is fine-tuned from [WeDLM-8B](https://huggingface.co/tencent/WeDLM-8B).

**Highlights:**
- 🚀 3-6× faster than vLLM-optimized Qwen3-8B on math reasoning tasks
- 📈 Outperforms the Qwen3-8B-Instruct baseline on most benchmarks
- ✅ Natively compatible with KV caching (FlashAttention, PagedAttention, CUDA Graphs)

For the base (pretrained) version, see [WeDLM-8B](https://huggingface.co/tencent/WeDLM-8B).

📄 Paper (Coming Soon) | 🌐 [Project Page](https://wedlm.github.io) | 💻 [GitHub](https://github.com/tencent/WeDLM)

## Model Details

| Attribute | Value |
|:----------|:------|
| Base Model | [WeDLM-8B](https://huggingface.co/tencent/WeDLM-8B) |
| Parameters | 8B |
| Context Length | 32,768 |

## Quick Start (Recommended)

For **fast inference**, use the `wedlm` engine:

```bash
pip install git+https://github.com/tencent/WeDLM.git
```

```python
from transformers import AutoTokenizer
from wedlm import LLM, SamplingParams

llm = LLM(model="tencent/WeDLM-8B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("tencent/WeDLM-8B-Instruct", trust_remote_code=True)

prompt = "Solve step by step: A store sells apples for $2 each and oranges for $3 each. Tom bought 5 apples and 4 oranges. How much did he spend?"
messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

outputs = llm.generate([text], SamplingParams(temperature=0.0, max_tokens=512))
print(outputs[0]["text"])
```

### Multi-turn Conversation

```python
messages = [
    {"role": "user", "content": "What is the derivative of x^2?"},
    {"role": "assistant", "content": "The derivative of x² is 2x."},
    {"role": "user", "content": "What about x^3?"}
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
outputs = llm.generate([text], SamplingParams(temperature=0.0, max_tokens=256))
```

### Batch Inference

```python
prompts = [
    "Explain quantum entanglement simply.",
    "Write a Python function to check if a number is prime.",
    "What are the main causes of climate change?"
]
messages_batch = [[{"role": "user", "content": p}] for p in prompts]
texts = [tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True) for m in messages_batch]

outputs = llm.generate(texts, SamplingParams(temperature=0.3, max_tokens=512))
for i, output in enumerate(outputs):
    print(f"=== Response {i+1} ===\n{output['text']}\n")
```

## HuggingFace Transformers

For **training** or simple forward passes:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("tencent/WeDLM-8B-Instruct", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "tencent/WeDLM-8B-Instruct",
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto"
)

messages = [{"role": "user", "content": "Hello!"}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)
outputs = model(**inputs)
```

> ⚠️ **Note:** The HuggingFace interface is for training/forward-pass convenience. For optimized inference throughput, use the `wedlm` engine above.
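Continuing the snippet above, a minimal sketch of what the forward pass exposes. This assumes the custom class returns a standard causal-LM output with a `logits` field; WeDLM's diffusion-specific outputs may differ:

```python
import torch

with torch.no_grad():
    outputs = model(**inputs)

# Assumption: `logits` has shape [batch, seq_len, vocab_size].
logits = outputs.logits
next_id = int(torch.argmax(logits[0, -1]))
print(tokenizer.decode([next_id]))
```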
## Performance

### Generation Quality

| Benchmark | Qwen3-8B-Instruct | WeDLM-8B-Instruct |
|:----------|:-----------------:|:-----------------:|
| ARC-C (0-shot) | 91.47 | **92.92** |
| GSM8K (3-shot) | 89.91 | **92.27** |
| MATH (4-shot) | **69.60** | 64.80 |
| HumanEval (4-shot) | 71.95 | **80.49** |
| MMLU (5-shot) | 71.52 | **75.14** |
| GPQA-Diamond (5-shot) | 41.41 | **44.95** |
| **Average** | 75.12 | **77.53** |

### Inference Speed

Speedup varies by task characteristics (measured against vLLM-optimized Qwen3-8B-Instruct):

| Scenario | Speedup | Notes |
|:---------|:-------:|:------|
| Math Reasoning (GSM8K) | 3-6× | Structured, predictable output |
| Code Generation | 2-3× | Deterministic syntax |
| Open-ended QA | 1.5-2× | Higher entropy limits parallelism |

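To make the entropy point concrete, here is a toy illustration (not the actual WeDLM decoding rule) of why confident, low-entropy predictions let more tokens be committed per forward pass:

```python
import torch

def accepted_prefix_len(logits: torch.Tensor, max_entropy: float = 1.0) -> int:
    """Toy rule: given logits for K parallel draft positions [K, vocab],
    commit the longest prefix whose predictive entropy stays below a threshold."""
    probs = torch.softmax(logits.float(), dim=-1)
    entropy = -(probs * probs.clamp_min(1e-9).log()).sum(dim=-1)  # [K], in nats
    k = 0
    for h in entropy.tolist():
        if h > max_entropy:
            break
        k += 1
    return k  # structured output -> low entropy -> more tokens per step
```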
## Citation (Coming soon)

## License

Apache 2.0
__init__.py ADDED
@@ -0,0 +1,23 @@
# Copyright 2024 The WeDLM Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .configuration_wedlm import WeDLMConfig
from .modeling_wedlm import WeDLMForCausalLM, WeDLMModel, WeDLMPreTrainedModel

__all__ = [
    "WeDLMConfig",
    "WeDLMPreTrainedModel",
    "WeDLMModel",
    "WeDLMForCausalLM",
]
added_tokens.json ADDED
@@ -0,0 +1,24 @@
{
  "</tool_call>": 151658,
  "<tool_call>": 151657,
  "<|box_end|>": 151649,
  "<|box_start|>": 151648,
  "<|endoftext|>": 151643,
  "<|file_sep|>": 151664,
  "<|fim_middle|>": 151660,
  "<|fim_pad|>": 151662,
  "<|fim_prefix|>": 151659,
  "<|fim_suffix|>": 151661,
  "<|im_end|>": 151645,
  "<|im_start|>": 151644,
  "<|image_pad|>": 151655,
  "<|object_ref_end|>": 151647,
  "<|object_ref_start|>": 151646,
  "<|quad_end|>": 151651,
  "<|quad_start|>": 151650,
  "<|repo_name|>": 151663,
  "<|video_pad|>": 151656,
  "<|vision_end|>": 151653,
  "<|vision_pad|>": 151654,
  "<|vision_start|>": 151652
}
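A quick sanity check that these added tokens resolve to the ids listed above (a sketch; requires downloading the tokenizer):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tencent/WeDLM-8B-Instruct", trust_remote_code=True)
assert tokenizer.convert_tokens_to_ids("<|im_start|>") == 151644
assert tokenizer.convert_tokens_to_ids("<|im_end|>") == 151645
assert tokenizer.convert_tokens_to_ids("<|endoftext|>") == 151643
```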
chat_template.jinja ADDED
@@ -0,0 +1,54 @@
{%- if tools %}
    {{- '<|im_start|>system\n' }}
    {%- if messages[0]['role'] == 'system' %}
        {{- messages[0]['content'] }}
    {%- else %}
        {{- 'You are a helpful assistant.' }}
    {%- endif %}
    {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
    {%- for tool in tools %}
        {{- "\n" }}
        {{- tool | tojson }}
    {%- endfor %}
    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
    {%- if messages[0]['role'] == 'system' %}
        {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
    {%- else %}
        {{- '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}
    {%- endif %}
{%- endif %}
{%- for message in messages %}
    {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
    {%- elif message.role == "assistant" %}
        {{- '<|im_start|>' + message.role }}
        {%- if message.content %}
            {{- '\n' + message.content }}
        {%- endif %}
        {%- for tool_call in message.tool_calls %}
            {%- if tool_call.function is defined %}
                {%- set tool_call = tool_call.function %}
            {%- endif %}
            {{- '\n<tool_call>\n{"name": "' }}
            {{- tool_call.name }}
            {{- '", "arguments": ' }}
            {{- tool_call.arguments | tojson }}
            {{- '}\n</tool_call>' }}
        {%- endfor %}
        {{- '<|im_end|>\n' }}
    {%- elif message.role == "tool" %}
        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
            {{- '<|im_start|>user' }}
        {%- endif %}
        {{- '\n<tool_response>\n' }}
        {{- message.content }}
        {{- '\n</tool_response>' }}
        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
            {{- '<|im_end|>\n' }}
        {%- endif %}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
{%- endif %}
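To see what this template emits, it can be rendered directly through the tokenizer (recent transformers versions accept a `tools` argument). The `get_weather` tool below is a hypothetical example, not part of the model:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tencent/WeDLM-8B-Instruct", trust_remote_code=True)

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool for illustration
        "description": "Get the current weather for a city",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    },
}]
messages = [{"role": "user", "content": "What's the weather in Paris?"}]
print(tokenizer.apply_chat_template(messages, tools=tools, tokenize=False, add_generation_prompt=True))
# Expected shape: a <|im_start|>system block carrying the <tools>...</tools> signatures,
# the user turn, then an opened <|im_start|>assistant block.
```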
config.json ADDED
@@ -0,0 +1,74 @@
{
  "architectures": [
    "WeDLMForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "auto_map": {
    "AutoConfig": "configuration_wedlm.WeDLMConfig",
    "AutoModelForCausalLM": "modeling_wedlm.WeDLMForCausalLM"
  },
  "dtype": "bfloat16",
  "eos_token_id": 151643,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 12288,
  "layer_types": [
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention"
  ],
  "mask_token_id": null,
  "max_position_embeddings": 16384,
  "max_window_layers": 36,
  "model_type": "wedlm",
  "num_attention_heads": 32,
  "num_hidden_layers": 36,
  "num_key_value_heads": 8,
  "pad_token_id": 151643,
  "qk_norm": true,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "transformers_version": "4.56.1",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}
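Because of the `auto_map` entries, the custom classes resolve through the `Auto*` API when `trust_remote_code=True`. A minimal sketch of inspecting the config (network access assumed):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("tencent/WeDLM-8B-Instruct", trust_remote_code=True)
print(config.model_type, config.num_hidden_layers, config.num_key_value_heads)  # wedlm 36 8
```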
configuration_wedlm.py ADDED
@@ -0,0 +1,163 @@
# coding=utf-8
# Copyright 2024 The WeDLM team and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""WeDLM model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging


logger = logging.get_logger(__name__)


class WeDLMConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`WeDLMModel`]. It is used to instantiate a
    WeDLM model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the WeDLM model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`WeDLMModel`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 22016):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers using full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        attention_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in QKV projections. Set to `True` for Qwen2.5 compatibility,
            `False` for Qwen3 compatibility.
        qk_norm (`bool`, *optional*, defaults to `False`):
            Whether to use QK normalization. Set to `True` for Qwen3 compatibility.
        head_dim (`int`, *optional*):
            The dimension of each attention head. If not specified, defaults to hidden_size // num_attention_heads.
    """

    model_type = "wedlm"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        attention_bias=True,
        qk_norm=False,
        head_dim=None,
        mask_token_id=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window if self.use_sliding_window else None
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias
        self.qk_norm = qk_norm
        self.mask_token_id = mask_token_id

        if head_dim is None:
            self.head_dim = hidden_size // num_attention_heads
        else:
            self.head_dim = head_dim

        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        # Generate layer_types based on sliding window configuration
        self.layer_types = [
            "sliding_attention"
            if self.sliding_window is not None and i >= self.max_window_layers
            else "full_attention"
            for i in range(self.num_hidden_layers)
        ]

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


__all__ = ["WeDLMConfig"]
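As a usage sketch, the shipped `config.json` values can be reproduced through this constructor (values taken from the config above; the import assumes the file is run next to a local clone):

```python
from configuration_wedlm import WeDLMConfig

config = WeDLMConfig(
    hidden_size=4096,
    intermediate_size=12288,
    num_hidden_layers=36,
    num_attention_heads=32,
    num_key_value_heads=8,
    max_position_embeddings=16384,
    rope_theta=1000000.0,
    attention_bias=False,
    qk_norm=True,
    head_dim=128,
)
# use_sliding_window defaults to False, so every layer is tagged full attention:
assert config.layer_types == ["full_attention"] * 36
```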
generation_config.json ADDED
@@ -0,0 +1,7 @@
{
  "bos_token_id": 151643,
  "eos_token_id": 151643,
  "max_new_tokens": 2048,
  "transformers_version": "4.56.1",
  "trust_remote_code": true
}
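These defaults are picked up automatically at generation time; a sketch of inspecting them:

```python
from transformers import GenerationConfig

gen_config = GenerationConfig.from_pretrained("tencent/WeDLM-8B-Instruct")
print(gen_config.max_new_tokens)  # 2048
print(gen_config.eos_token_id)    # 151643
```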
merges.txt ADDED
The diff for this file is too large to render.
 
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e492ba4572ceb74c7a81602912d123dedd47606cd239ece9cbd435fba54c60da
size 4902257696
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d5aad5b5fe4fe669d1ff02b92c7393fa6f8c410c2ac5d6b1eb4dad60596383b4
size 4915960368
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0ef815cb007a0c482fdae4cadbb6a8ca23f28c686254a29b4b49aa5be3dbb656
size 4983068496
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e1f0f4b4341e80e8822d696d4b0b7c9dec34881454dd067adf143ca3867e6c86
size 1580230264
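The four files above are Git LFS pointer stubs; the real shards resolve through the index that follows. A sketch of fetching and inspecting the mapping (network access assumed):

```python
import json
import os

from huggingface_hub import snapshot_download

local_dir = snapshot_download("tencent/WeDLM-8B-Instruct", allow_patterns=["*.json"])
with open(os.path.join(local_dir, "model.safetensors.index.json")) as f:
    index = json.load(f)

print(index["metadata"]["total_parameters"])  # 8190735360
print(index["weight_map"]["lm_head.weight"])  # model-00004-of-00004.safetensors
```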
model.safetensors.index.json ADDED
@@ -0,0 +1,407 @@
{
  "metadata": {
    "total_parameters": 8190735360,
    "total_size": 16381470720
  },
  "weight_map": {
    "lm_head.weight": "model-00004-of-00004.safetensors",
    "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.10.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.10.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.18.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.18.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.18.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.19.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.19.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.19.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.19.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.19.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.19.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.2.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
    "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.2.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
    "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.20.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.20.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.20.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.20.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.20.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.20.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.20.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.20.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.21.input_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.21.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.21.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.21.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
    "model.layers.21.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.21.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.21.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.21.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
    "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.22.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.22.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.22.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
    "model.layers.22.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.22.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
    "model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
    "model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
187
+ "model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
188
+ "model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
189
+ "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
190
+ "model.layers.23.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
191
+ "model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
192
+ "model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
193
+ "model.layers.23.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
194
+ "model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
195
+ "model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
196
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
197
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
198
+ "model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
199
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
200
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
201
+ "model.layers.24.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
202
+ "model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
203
+ "model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
204
+ "model.layers.24.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
205
+ "model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
206
+ "model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
207
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
208
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
209
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
210
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
211
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
212
+ "model.layers.25.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
213
+ "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
214
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
215
+ "model.layers.25.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
216
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
217
+ "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
218
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
219
+ "model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
220
+ "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
221
+ "model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
222
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
223
+ "model.layers.26.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
224
+ "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
225
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
226
+ "model.layers.26.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
227
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
228
+ "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
229
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
230
+ "model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
231
+ "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
232
+ "model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
233
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
234
+ "model.layers.27.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
235
+ "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
236
+ "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
237
+ "model.layers.27.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
238
+ "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
239
+ "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
240
+ "model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
241
+ "model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
242
+ "model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
243
+ "model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
244
+ "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
245
+ "model.layers.28.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
246
+ "model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
247
+ "model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
248
+ "model.layers.28.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
249
+ "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
250
+ "model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
251
+ "model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
252
+ "model.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
253
+ "model.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
254
+ "model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
255
+ "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
256
+ "model.layers.29.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
257
+ "model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
258
+ "model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
259
+ "model.layers.29.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
260
+ "model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
261
+ "model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
262
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
263
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
264
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
265
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
266
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
267
+ "model.layers.3.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
268
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
269
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
270
+ "model.layers.3.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
271
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
272
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
273
+ "model.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors",
274
+ "model.layers.30.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
275
+ "model.layers.30.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
276
+ "model.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
277
+ "model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
278
+ "model.layers.30.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
279
+ "model.layers.30.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
280
+ "model.layers.30.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
281
+ "model.layers.30.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
282
+ "model.layers.30.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
283
+ "model.layers.30.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
284
+ "model.layers.31.input_layernorm.weight": "model-00003-of-00004.safetensors",
285
+ "model.layers.31.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
286
+ "model.layers.31.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
287
+ "model.layers.31.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
288
+ "model.layers.31.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
289
+ "model.layers.31.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
290
+ "model.layers.31.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
291
+ "model.layers.31.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
292
+ "model.layers.31.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
293
+ "model.layers.31.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
294
+ "model.layers.31.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
295
+ "model.layers.32.input_layernorm.weight": "model-00003-of-00004.safetensors",
296
+ "model.layers.32.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
297
+ "model.layers.32.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
298
+ "model.layers.32.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
299
+ "model.layers.32.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
300
+ "model.layers.32.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
301
+ "model.layers.32.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
302
+ "model.layers.32.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
303
+ "model.layers.32.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
304
+ "model.layers.32.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
305
+ "model.layers.32.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
306
+ "model.layers.33.input_layernorm.weight": "model-00003-of-00004.safetensors",
307
+ "model.layers.33.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
308
+ "model.layers.33.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
309
+ "model.layers.33.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
310
+ "model.layers.33.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
311
+ "model.layers.33.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
312
+ "model.layers.33.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
313
+ "model.layers.33.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
314
+ "model.layers.33.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
315
+ "model.layers.33.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
316
+ "model.layers.33.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
317
+ "model.layers.34.input_layernorm.weight": "model-00003-of-00004.safetensors",
318
+ "model.layers.34.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
319
+ "model.layers.34.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
320
+ "model.layers.34.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
321
+ "model.layers.34.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
322
+ "model.layers.34.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
323
+ "model.layers.34.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
324
+ "model.layers.34.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
325
+ "model.layers.34.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
326
+ "model.layers.34.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
327
+ "model.layers.34.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
328
+ "model.layers.35.input_layernorm.weight": "model-00004-of-00004.safetensors",
329
+ "model.layers.35.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
330
+ "model.layers.35.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
331
+ "model.layers.35.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
332
+ "model.layers.35.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
333
+ "model.layers.35.self_attn.k_norm.weight": "model-00004-of-00004.safetensors",
334
+ "model.layers.35.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
335
+ "model.layers.35.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
336
+ "model.layers.35.self_attn.q_norm.weight": "model-00004-of-00004.safetensors",
337
+ "model.layers.35.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
338
+ "model.layers.35.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
339
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
340
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
341
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
342
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
343
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
344
+ "model.layers.4.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
345
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
346
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
347
+ "model.layers.4.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
348
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
349
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
350
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
351
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
352
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
353
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
354
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
355
+ "model.layers.5.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
356
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
357
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
358
+ "model.layers.5.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
359
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
360
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
361
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
362
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
363
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
364
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
365
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
366
+ "model.layers.6.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
367
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
368
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
369
+ "model.layers.6.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
370
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
371
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
372
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
373
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
374
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
375
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
376
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
377
+ "model.layers.7.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
378
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
379
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
380
+ "model.layers.7.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
381
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
382
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
383
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00004.safetensors",
384
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
385
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
386
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
387
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
388
+ "model.layers.8.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
389
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
390
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
391
+ "model.layers.8.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
392
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
393
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
394
+ "model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
395
+ "model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
396
+ "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
397
+ "model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
398
+ "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
399
+ "model.layers.9.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
400
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
401
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
402
+ "model.layers.9.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
403
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
404
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
405
+ "model.norm.weight": "model-00004-of-00004.safetensors"
406
+ }
407
+ }
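The `weight_map` above is what tells a loader which of the four shards holds each tensor, so only the relevant file needs to be opened. A minimal sketch of resolving one tensor by hand (it assumes the shard files sit next to the index file; `from_pretrained` performs this resolution automatically):

```python
import json
from safetensors import safe_open

# The index maps every parameter name to the shard file that stores it.
with open("model.safetensors.index.json") as f:
    index = json.load(f)

name = "model.layers.16.self_attn.k_proj.weight"
shard = index["weight_map"][name]  # -> "model-00002-of-00004.safetensors"

# Open only that shard and fetch the single tensor lazily.
with safe_open(shard, framework="pt", device="cpu") as f:
    tensor = f.get_tensor(name)
print(tensor.shape)
```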
modeling_wedlm.py ADDED
@@ -0,0 +1,1004 @@
+ # coding=utf-8
+ # Copyright 2024 The WeDLM team and the HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """PyTorch WeDLM model."""
+
+ from typing import Optional, Tuple, Union, Dict, List, Callable
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from transformers import PreTrainedModel, GenerationMixin
+ from transformers.activations import ACT2FN
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+ from transformers.processing_utils import Unpack
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
+ from transformers.utils.generic import check_model_inputs
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
+ from transformers.modeling_layers import GradientCheckpointingLayer
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+
+ # Import attention-related utilities
+ try:
+     from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+ except ImportError:
+     FlashAttentionKwargs = dict
+
+ try:
+     from transformers.integrations.flash_attention import ALL_ATTENTION_FUNCTIONS
+ except ImportError:
+     try:
+         from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+     except ImportError:
+         ALL_ATTENTION_FUNCTIONS = {}
+
+ from .configuration_wedlm import WeDLMConfig
+
+ import logging
+
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.DEBUG)
+
+
+ # ============================================================================
+ # Core Components (self-contained, no Qwen2 dependency)
+ # ============================================================================
+
+ class WeDLMMLP(nn.Module):
+     """WeDLM MLP module with SwiGLU activation."""
+
+     def __init__(self, config: WeDLMConfig):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+         return down_proj
+
+
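The forward pass above is the standard SwiGLU dataflow: `down_proj(act(gate_proj(x)) * up_proj(x))`. A toy sanity check with made-up sizes, assuming `hidden_act="silu"` as in Qwen-style configs:

```python
import torch
import torch.nn.functional as F

hidden, intermediate = 8, 16  # illustrative sizes, not the model's real dimensions
x = torch.randn(2, 4, hidden)
gate = torch.nn.Linear(hidden, intermediate, bias=False)
up = torch.nn.Linear(hidden, intermediate, bias=False)
down = torch.nn.Linear(intermediate, hidden, bias=False)

out = down(F.silu(gate(x)) * up(x))  # same dataflow as WeDLMMLP.forward
print(out.shape)  # torch.Size([2, 4, 8])
```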
+ class WeDLMRMSNorm(nn.Module):
+     """WeDLM RMSNorm, equivalent to T5LayerNorm."""
+
+     def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         input_dtype = hidden_states.dtype
+         hidden_states = hidden_states.to(torch.float32)
+         variance = hidden_states.pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+         return self.weight * hidden_states.to(input_dtype)
+
+     def extra_repr(self) -> str:
+         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+ class WeDLMRotaryEmbedding(nn.Module):
+     """WeDLM Rotary Position Embedding."""
+
+     def __init__(self, config: WeDLMConfig, device=None):
+         super().__init__()
+         # Determine rope_type from config
+         if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default"))
+         else:
+             self.rope_type = "default"
+
+         self.max_seq_len_cached = config.max_position_embeddings
+         self.original_max_seq_len = config.max_position_embeddings
+         self.config = config
+
+         # Get initialization function
+         if self.rope_type == "default":
+             inv_freq, self.attention_scaling = self._compute_default_rope_parameters(config, device)
+         else:
+             rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+             inv_freq, self.attention_scaling = rope_init_fn(config, device)
+
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self.original_inv_freq = self.inv_freq
+
+     @staticmethod
+     def _compute_default_rope_parameters(
+         config: WeDLMConfig,
+         device: Optional[torch.device] = None,
+     ) -> Tuple[torch.Tensor, float]:
+         """
+         Computes the inverse frequencies for default RoPE.
+
+         Args:
+             config: Model configuration
+             device: Device to place the tensors on
+
+         Returns:
+             Tuple of (inv_freq tensor, attention_scaling factor)
+         """
+         base = config.rope_theta
+         dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+
+         # Compute the inverse frequencies
+         inv_freq = 1.0 / (
+             base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+         )
+         attention_factor = 1.0
+         return inv_freq, attention_factor
+
+     @torch.no_grad()
+     def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Compute rotary position embeddings.
+
+         Args:
+             x: Input tensor, used for dtype and device
+             position_ids: Position indices
+
+         Returns:
+             Tuple of (cos, sin) tensors
+         """
+         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+         position_ids_expanded = position_ids[:, None, :].float()
+
+         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+
+         # Force float32 computation for numerical stability
+         with torch.amp.autocast(device_type=device_type, enabled=False):
+             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+             emb = torch.cat((freqs, freqs), dim=-1)
+             cos = emb.cos() * self.attention_scaling
+             sin = emb.sin() * self.attention_scaling
+
+         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
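To make the shapes concrete: for each position p and inverse frequency f_i, the module returns cos(p·f_i) and sin(p·f_i), duplicated across the two halves of the head dimension. A standalone sketch with toy values (head_dim=8 and base=10000 are illustrative, not the model's real config):

```python
import torch

head_dim, base = 8, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim/2,)
position_ids = torch.arange(5)[None, :]                                      # (1, seq)

freqs = position_ids[..., None].float() * inv_freq  # (1, seq, head_dim/2)
emb = torch.cat((freqs, freqs), dim=-1)             # duplicate halves -> (1, seq, head_dim)
cos, sin = emb.cos(), emb.sin()
print(cos.shape)  # torch.Size([1, 5, 8])
```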
+ # ============================================================================
+ # Attention Utilities
+ # ============================================================================
+
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     cos: torch.Tensor,
+     sin: torch.Tensor,
+     position_ids: Optional[torch.Tensor] = None,
+     unsqueeze_dim: int = 1
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Applies Rotary Position Embedding to the query and key tensors."""
+     cos = cos.unsqueeze(unsqueeze_dim)
+     sin = sin.unsqueeze(unsqueeze_dim)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
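`rotate_half` pairs dimension `i` with dimension `i + d/2`; combined with the duplicated cos/sin halves above, `(q * cos) + (rotate_half(q) * sin)` is the real-valued form of RoPE's complex rotation. A two-line check of the rotation itself:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
x1, x2 = x[:2], x[2:]
print(torch.cat((-x2, x1)))  # tensor([-3., -4.,  1.,  2.]) -- what rotate_half(x) returns
```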
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     Repeats key/value heads to match the number of query heads (for GQA).
+
+     Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
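For grouped-query attention the checkpoint stores fewer KV heads than query heads, and `repeat_kv` broadcasts each KV head across its group before the score computation. A shape check with illustrative sizes (2 KV heads serving 8 query heads):

```python
import torch

batch, kv_heads, seq, head_dim, n_rep = 1, 2, 5, 4, 4
kv = torch.randn(batch, kv_heads, seq, head_dim)

expanded = kv[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq, head_dim)
out = expanded.reshape(batch, kv_heads * n_rep, seq, head_dim)
print(out.shape)  # torch.Size([1, 8, 5, 4]) -- same as torch.repeat_interleave(kv, n_rep, dim=1)
```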
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: Optional[torch.Tensor],
+     scaling: float,
+     dropout: float = 0.0,
+     **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Eager (standard) attention implementation."""
+     key_states = repeat_kv(key, module.num_key_value_groups)
+     value_states = repeat_kv(value, module.num_key_value_groups)
+
+     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+
+     if attention_mask is not None:
+         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+         attn_weights = attn_weights + causal_mask
+
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+     attn_output = torch.matmul(attn_weights, value_states)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights
+
+
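When `_attn_implementation` is set to `"sdpa"` or `"flash_attention_2"`, the dispatch in `WeDLMAttention.forward` below swaps this eager path for a fused kernel. The eager math is numerically the same as PyTorch's built-in operator; a quick equivalence check (toy shapes, no mask or dropout; the `scale=` argument needs PyTorch >= 2.1):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 5, 4)  # (batch, heads, seq, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)
scaling = q.shape[-1] ** -0.5

eager = torch.softmax(torch.matmul(q, k.transpose(2, 3)) * scaling, dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v, scale=scaling)
print(torch.allclose(eager, fused, atol=1e-5))  # True
```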
+ # ============================================================================
+ # Attention Layer
+ # ============================================================================
+
+ class WeDLMAttention(nn.Module):
+     """
+     WeDLM Attention module.
+
+     Supports both:
+     - Qwen2.5 style: with QKV bias, no QK Norm
+     - Qwen3 style: configurable QKV bias, with QK Norm
+     """
+
+     def __init__(self, config: WeDLMConfig, layer_idx: int):
+         super().__init__()
+         self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
+         self.config = config
+         self.layer_idx = layer_idx
+         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+         self.scaling = self.head_dim ** -0.5
+         self.attention_dropout = config.attention_dropout
+         self.is_causal = True
+
+         # Support configurable attention_bias (Qwen2.5: True, Qwen3: False by default)
+         attention_bias = getattr(config, "attention_bias", True)
+
+         self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=attention_bias)
+         self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=attention_bias)
+         self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=attention_bias)
+         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+
+         # Support optional QK Norm (Qwen3 feature)
+         self.qk_norm = getattr(config, "qk_norm", False)
+         if self.qk_norm:
+             self.q_norm = WeDLMRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+             self.k_norm = WeDLMRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+         attention_mask: Optional[torch.Tensor],
+         past_key_values: Optional[Cache] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         input_shape = hidden_states.shape[:-1]
+         hidden_shape = (*input_shape, -1, self.head_dim)
+
+         if self.qk_norm:
+             # Qwen3 style: apply norm after projection, before transpose
+             query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+             key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+         else:
+             # Qwen2 style: no norm
+             query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+             key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+         cos, sin = position_embeddings
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+         if past_key_values is not None:
+             # sin and cos are specific to RoPE models; cache_position needed for the static cache
+             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+         # Select attention implementation
+         attention_interface: Callable = eager_attention_forward
+         if self.config._attn_implementation != "eager" and self.config._attn_implementation in ALL_ATTENTION_FUNCTIONS:
+             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+         attn_output, attn_weights = attention_interface(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             dropout=0.0 if not self.training else self.attention_dropout,
+             scaling=self.scaling,
+             sliding_window=self.sliding_window,
+             **kwargs,
+         )
+
+         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+         attn_output = self.o_proj(attn_output)
+         return attn_output, attn_weights
+
+
+ # ============================================================================
+ # Decoder Layer
+ # ============================================================================
+
+ class WeDLMDecoderLayer(GradientCheckpointingLayer):
+     """WeDLM Decoder Layer with pre-norm architecture."""
+
+     def __init__(self, config: WeDLMConfig, layer_idx: int):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+
+         self.self_attn = WeDLMAttention(config=config, layer_idx=layer_idx)
+         self.mlp = WeDLMMLP(config)
+         self.input_layernorm = WeDLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = WeDLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.attention_type = config.layer_types[layer_idx]
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+         cache_position: Optional[torch.LongTensor] = None,
+         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+         """
+         Args:
+             hidden_states: Input tensor of shape `(batch, seq_len, embed_dim)`
+             attention_mask: Attention mask of size `(batch, sequence_length)`
+             position_ids: Position indices
+             past_key_values: Cached past key and value projection states
+             output_attentions: Whether to return attention weights
+             use_cache: Whether to use KV cache
+             cache_position: Position in the cache
+             position_embeddings: Tuple of (cos, sin) for rotary embeddings
+         """
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+
+         # Self Attention
+         hidden_states, self_attn_weights = self.self_attn(
+             hidden_states=hidden_states,
+             position_embeddings=position_embeddings,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             cache_position=cache_position,
+             **kwargs,
+         )
+         hidden_states = residual + hidden_states
+
+         # Feed Forward
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (self_attn_weights,)
+
+         return outputs
+
+
+ # ============================================================================
+ # Model Classes
+ # ============================================================================
+
+ @auto_docstring
+ class WeDLMPreTrainedModel(PreTrainedModel):
+     """Base class for WeDLM models."""
+
+     config_class = WeDLMConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["WeDLMDecoderLayer"]
+     _skip_keys_device_placement = ["past_key_values"]
+     _supports_flash_attn = True
+     _supports_sdpa = True
+     _supports_flex_attn = True
+     _can_compile_fullgraph = True
+     _supports_attention_backend = True
+     _can_record_outputs = {
+         "hidden_states": WeDLMDecoderLayer,
+         "attentions": WeDLMAttention,
+     }
+
+
+ @auto_docstring
+ class WeDLMModel(WeDLMPreTrainedModel):
+     """
+     WeDLM base model outputting raw hidden states.
+     """
+
+     def __init__(self, config: WeDLMConfig):
+         super().__init__(config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+         self.layers = nn.ModuleList(
+             [WeDLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+         )
+         self.norm = WeDLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.rotary_emb = WeDLMRotaryEmbedding(config=config)
+         self.gradient_checkpointing = False
+         self.has_sliding_layers = "sliding_attention" in self.config.layer_types
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.embed_tokens = value
+
+     @check_model_inputs
+     @auto_docstring
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+
+         if use_cache and past_key_values is None:
+             past_key_values = DynamicCache(config=self.config)
+
+         if cache_position is None:
+             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+             cache_position = torch.arange(
+                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+             )
+
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+
+         # Prepare attention masks
+         if not isinstance(causal_mask_mapping := attention_mask, dict):
+             mask_kwargs = {
+                 "config": self.config,
+                 "input_embeds": inputs_embeds,
+                 "attention_mask": attention_mask,
+                 "cache_position": cache_position,
+                 "past_key_values": past_key_values,
+                 "position_ids": position_ids,
+             }
+             causal_mask_mapping = {
+                 "full_attention": create_causal_mask(**mask_kwargs),
+             }
+             if self.has_sliding_layers:
+                 causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
+
+         hidden_states = inputs_embeds
+
+         # Create position embeddings to be shared across the decoder layers
+         position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+         # Decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+
+         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             layer_outputs = decoder_layer(
+                 hidden_states,
+                 attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                 position_ids=position_ids,
+                 past_key_values=past_key_values,
+                 output_attentions=output_attentions,
+                 use_cache=use_cache,
+                 cache_position=cache_position,
+                 position_embeddings=position_embeddings,
+                 **kwargs,
+             )
+
+             hidden_states = layer_outputs[0]
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         hidden_states = self.norm(hidden_states)
+
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         if not return_dict:
+             return tuple(v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None)
+
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=past_key_values if use_cache else None,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+
+ @auto_docstring
+ class WeDLMForCausalLM(WeDLMPreTrainedModel, GenerationMixin):
+     """
+     WeDLM Model for Causal Language Modeling with WeDLM block decoding support.
+     """
+     _tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(self, config: WeDLMConfig):
+         super().__init__(config)
+         self.model = WeDLMModel(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     def _efficient_reorder_sequence(
+         self,
+         tokens: torch.Tensor,
+         mask_indices: torch.Tensor,
+         position_ids: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Helper function to reorder sequence by moving MASK parts to the end.
+         """
+         reordered_tokens = torch.cat((tokens[~mask_indices], tokens[mask_indices]))
+         reordered_position_ids = torch.cat((position_ids[~mask_indices], position_ids[mask_indices]))
+         return reordered_tokens, reordered_position_ids
+
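This helper is the heart of the block decoder: unresolved MASK tokens are physically moved to the tail of the sequence while their original position ids travel with them, so the rotary embeddings still see the true positions. A toy trace of the same operations (99 stands in for the MASK id):

```python
import torch

tokens       = torch.tensor([11, 99, 12, 99, 13])
position_ids = torch.arange(5)
is_mask      = tokens == 99

reordered_tokens = torch.cat((tokens[~is_mask], tokens[is_mask]))
reordered_pos    = torch.cat((position_ids[~is_mask], position_ids[is_mask]))
print(reordered_tokens)  # tensor([11, 12, 13, 99, 99])
print(reordered_pos)     # tensor([0, 2, 4, 1, 3]) -- the masks keep positions 1 and 3
```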
+     @torch.no_grad()
+     def _generate_one_block(
+         self,
+         prefix_ids: torch.Tensor,
+         prefix_position_ids: torch.Tensor,
+         block_size: int,
+         mask_token_id: int,
+         confidence_threshold: float = 0.0,
+         temperature: float = 1.0,
+         top_p: float = 1.0,
+         top_k: int = 0,
+     ) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
+         """
+         Generate one block of content based on the given prefix.
+
+         Args:
+             prefix_ids: Current sequence token IDs
+             prefix_position_ids: Position IDs for current sequence
+             block_size: Number of tokens to generate in this block
+             mask_token_id: Token ID for MASK token
+             confidence_threshold: Minimum confidence to accept a prediction
+             temperature: Sampling temperature
+             top_p: Nucleus sampling parameter (unused currently)
+             top_k: Top-k sampling parameter (unused currently)
+
+         Returns:
+             Tuple of (updated_ids, updated_position_ids, block_statistics)
+         """
+         device = prefix_ids.device
+
+         # 1. Append a block of MASK tokens after the current prefix
+         mask_tensor = torch.full((block_size,), mask_token_id, dtype=torch.long, device=device)
+         current_ids = torch.cat([prefix_ids, mask_tensor])
+
+         # Create position encodings for the newly added MASKs
+         start_pos = prefix_position_ids[-1].item() + 1 if len(prefix_position_ids) > 0 else 0
+         mask_position_ids = torch.arange(start_pos, start_pos + block_size, dtype=torch.long, device=device)
+         original_position_ids = torch.cat([prefix_position_ids, mask_position_ids])
+
+         # Mark which positions are MASK
+         is_mask = (current_ids == mask_token_id)
+
+         # Statistics
+         block_stats = {
+             'steps': 0,
+             'tokens_generated': 0,
+             'tokens_per_step': [],
+             'max_confidences': [],
+         }
+
+         # 2. WeDLM iteration within the block
+         for step in range(block_size):
+             if not is_mask.any():
+                 break
+
+             block_stats['steps'] += 1
+
+             # 2.1 Reorder sequence
+             reordered_ids, reordered_position_ids = self._efficient_reorder_sequence(
+                 current_ids, is_mask, original_position_ids
+             )
+
+             # 2.2 Prepare input
+             input_ids = reordered_ids.unsqueeze(0)
+             position_ids = reordered_position_ids.unsqueeze(0)
+
+             seq_len = input_ids.shape[1]
+             attention_mask = torch.ones((1, seq_len), dtype=torch.long, device=device)
+
+             # 2.3 Model forward pass
+             outputs = self.model(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 position_ids=position_ids,
+                 use_cache=False,
+                 return_dict=True,
+             )
+
+             hidden_states = outputs.last_hidden_state
+             logits = self.lm_head(hidden_states)
+
+             # 2.4 Get logits for MASK positions
+             num_non_mask = (~is_mask).sum().item()
+             mask_logits = logits[0, num_non_mask:]
+
+             if mask_logits.size(0) == 0:
+                 break
+
+             mask_logits = mask_logits / temperature
+             probs = F.softmax(mask_logits, dim=-1)
+             max_probs, predicted_ids = probs.max(dim=-1)
+
+             block_stats['max_confidences'].append(max_probs.max().item())
+
+             # 2.5 Select positions to fill
+             if confidence_threshold > 0.0:
+                 above_threshold_mask = max_probs >= confidence_threshold
+
+                 if above_threshold_mask.any():
+                     indices_to_fill = above_threshold_mask.nonzero(as_tuple=True)[0]
+                     num_tokens_this_step = len(indices_to_fill)
+                 else:
+                     best_idx = max_probs.argmax()
+                     indices_to_fill = best_idx.unsqueeze(0)
+                     num_tokens_this_step = 1
+             else:
+                 best_idx = max_probs.argmax()
+                 indices_to_fill = best_idx.unsqueeze(0)
+                 num_tokens_this_step = 1
+
+             block_stats['tokens_per_step'].append(num_tokens_this_step)
+             block_stats['tokens_generated'] += num_tokens_this_step
+
+             # 2.6 Update all selected positions
+             for idx in indices_to_fill:
+                 idx_item = idx.item()
+                 best_token_id = predicted_ids[idx_item].item()
+
+                 best_pos_in_reordered = num_non_mask + idx_item
+                 original_pos_value = reordered_position_ids[best_pos_in_reordered].item()
+                 original_pos_in_seq = (original_position_ids == original_pos_value).nonzero(as_tuple=True)[0].item()
+
+                 current_ids[original_pos_in_seq] = best_token_id
+                 is_mask[original_pos_in_seq] = False
+
+         return current_ids, original_position_ids, block_stats
+
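Step 2.5 above is what makes decoding adaptive: with a positive threshold, every masked position whose top-1 probability clears it is committed in a single forward pass; if none clears it (or the threshold is 0), only the single most confident position is filled. The same selection logic on toy probabilities:

```python
import torch

max_probs = torch.tensor([0.95, 0.40, 0.88, 0.10])
confidence_threshold = 0.8

above = max_probs >= confidence_threshold
if confidence_threshold > 0.0 and above.any():
    indices_to_fill = above.nonzero(as_tuple=True)[0]  # tensor([0, 2]): two tokens this step
else:
    indices_to_fill = max_probs.argmax().unsqueeze(0)  # fall back to the single best position
print(indices_to_fill)
```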
+     @torch.no_grad()
+     def generate_wedlm(
+         self,
+         input_ids: torch.LongTensor,
+         max_new_tokens: int,
+         block_size: int,
+         mask_token_id: Optional[int] = None,
+         confidence_threshold: float = 0.0,
+         temperature: float = 1.0,
+         top_p: float = 1.0,
+         top_k: int = 0,
+         pad_token_id: Optional[int] = None,
+         return_stats: bool = True,
+         **kwargs
+     ) -> Union[torch.LongTensor, Dict]:
+         """
+         Generate text using WeDLM block decoding mode.
+
+         Args:
+             input_ids: Input token IDs of shape (batch_size, seq_len)
+             max_new_tokens: Maximum number of new tokens to generate
+             block_size: Number of tokens to generate per block
+             mask_token_id: Token ID for MASK token
+             confidence_threshold: Minimum confidence to accept predictions (0.0-1.0)
+             temperature: Sampling temperature
+             top_p: Nucleus sampling parameter
+             top_k: Top-k sampling parameter
+             pad_token_id: Token ID for padding
+             return_stats: Whether to return generation statistics
+
+         Returns:
+             If return_stats=False: Generated token sequences
+             If return_stats=True: Dict with 'sequences' and 'stats'
+         """
+         if mask_token_id is None:
+             mask_token_id = getattr(self.config, "mask_token_id", None)
+             if mask_token_id is None:
+                 raise ValueError("mask_token_id must be provided or set in config")
+
+         if pad_token_id is None:
+             pad_token_id = self.config.pad_token_id
+
+         if not 0.0 <= confidence_threshold <= 1.0:
+             raise ValueError(f"confidence_threshold must be between 0 and 1, got {confidence_threshold}")
+
+         batch_size = input_ids.shape[0]
+         device = input_ids.device
+
+         num_blocks = (max_new_tokens + block_size - 1) // block_size
+
+         logger.info(
+             f"Starting WeDLM generation: max_new_tokens={max_new_tokens}, block_size={block_size}, "
+             f"confidence_threshold={confidence_threshold}, num_blocks={num_blocks}"
+         )
+
+         all_generated = []
+         all_sample_stats = []
+
+         for batch_idx in range(batch_size):
+             sample_ids = input_ids[batch_idx]
+             if pad_token_id is not None:
+                 pad_mask = (sample_ids != pad_token_id)
+                 if pad_mask.any():
+                     valid_length = pad_mask.sum().item()
+                     prefix_ids = sample_ids[:valid_length]
+                 else:
+                     prefix_ids = sample_ids
+             else:
+                 prefix_ids = sample_ids
+
+             prefix_length = prefix_ids.shape[0]
+             current_position_ids = torch.arange(prefix_length, dtype=torch.long, device=device)
+
+             current_ids = prefix_ids.clone()
+
+             sample_stats = {
+                 'input_length': prefix_length,
+                 'total_steps': 0,
+                 'total_tokens_generated': 0,
+                 'blocks': [],
+             }
+
+             for block_idx in range(num_blocks):
+                 remaining_tokens = max_new_tokens - block_idx * block_size
+                 current_block_size = min(block_size, remaining_tokens)
+
+                 logger.debug(
+                     f"Batch {batch_idx}, Block {block_idx}/{num_blocks}: "
+                     f"generating {current_block_size} tokens"
+                 )
+
+                 current_ids, current_position_ids, block_stats = self._generate_one_block(
+                     prefix_ids=current_ids,
+                     prefix_position_ids=current_position_ids,
+                     block_size=current_block_size,
+                     mask_token_id=mask_token_id,
+                     confidence_threshold=confidence_threshold,
+                     temperature=temperature,
+                     top_p=top_p,
+                     top_k=top_k,
+                 )
+
+                 sample_stats['total_steps'] += block_stats['steps']
+                 sample_stats['total_tokens_generated'] += block_stats['tokens_generated']
+                 sample_stats['blocks'].append(block_stats)
+
+             sample_stats['actual_tokens_generated'] = len(current_ids) - prefix_length
+             sample_stats['output_length'] = len(current_ids)
+
+             all_generated.append(current_ids)
+             all_sample_stats.append(sample_stats)
+
+         max_length = max(seq.shape[0] for seq in all_generated)
+         padded_sequences = []
+
+         for seq in all_generated:
+             if seq.shape[0] < max_length:
+                 padding = torch.full(
+                     (max_length - seq.shape[0],),
+                     pad_token_id if pad_token_id is not None else 0,
+                     dtype=torch.long,
+                     device=device
+                 )
+                 seq = torch.cat([seq, padding])
+             padded_sequences.append(seq)
+
+         result_sequences = torch.stack(padded_sequences, dim=0)
+
+         total_steps = sum(s['total_steps'] for s in all_sample_stats)
+         total_tokens = sum(s['total_tokens_generated'] for s in all_sample_stats)
+         avg_tokens_per_step = total_tokens / total_steps if total_steps > 0 else 0
+
+         logger.info(
+             f"WeDLM generation completed: "
+             f"total_steps={total_steps}, "
+             f"total_tokens_generated={total_tokens}, "
+             f"avg_tokens_per_step={avg_tokens_per_step:.2f}"
+         )
+
+         if not return_stats:
+             return result_sequences
+
+         return {
+             'sequences': result_sequences,
+             'stats': {
+                 'total_steps': total_steps,
+                 'total_tokens_generated': total_tokens,
+                 'average_tokens_per_step': avg_tokens_per_step,
+                 'efficiency_ratio': total_tokens / total_steps if total_steps > 0 else 0,
+                 'per_sample_stats': all_sample_stats,
+                 'config': {
+                     'batch_size': batch_size,
+                     'max_new_tokens': max_new_tokens,
+                     'block_size': block_size,
+                     'confidence_threshold': confidence_threshold,
+                     'temperature': temperature,
+                 }
+             }
+         }
+
893
+ @can_return_tuple
894
+ @auto_docstring
895
+ def forward(
896
+ self,
897
+ input_ids: Optional[torch.LongTensor] = None,
898
+ attention_mask: Optional[torch.Tensor] = None,
899
+ position_ids: Optional[torch.LongTensor] = None,
900
+ past_key_values: Optional[Cache] = None,
901
+ inputs_embeds: Optional[torch.FloatTensor] = None,
902
+ labels: Optional[torch.LongTensor] = None,
903
+ use_cache: Optional[bool] = None,
904
+ output_attentions: Optional[bool] = None,
905
+ output_hidden_states: Optional[bool] = None,
906
+ return_dict: Optional[bool] = None,
907
+ cache_position: Optional[torch.LongTensor] = None,
908
+ logits_to_keep: Union[int, torch.Tensor] = 0,
909
+ **kwargs: Unpack[TransformersKwargs],
910
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
911
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
912
+ output_hidden_states = (
913
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
914
+ )
915
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
916
+
917
+ outputs = self.model(
918
+ input_ids=input_ids,
919
+ attention_mask=attention_mask,
920
+ position_ids=position_ids,
921
+ past_key_values=past_key_values,
922
+ inputs_embeds=inputs_embeds,
923
+ use_cache=use_cache,
924
+ output_attentions=output_attentions,
925
+ output_hidden_states=output_hidden_states,
926
+ return_dict=return_dict,
927
+ cache_position=cache_position,
928
+ **kwargs,
929
+ )
930
+
931
+ hidden_states = outputs[0]
932
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
933
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
934
+
935
+ loss = None
936
+ if labels is not None:
937
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
938
+
939
+ if not return_dict:
940
+ output = (logits,) + outputs[1:]
941
+ return (loss,) + output if loss is not None else output
942
+
943
+ return CausalLMOutputWithPast(
944
+ loss=loss,
945
+ logits=logits,
946
+ past_key_values=outputs.past_key_values,
947
+ hidden_states=outputs.hidden_states,
948
+ attentions=outputs.attentions,
949
+ )
950
+
951
+ def prepare_inputs_for_generation(
952
+ self,
953
+ input_ids,
954
+ past_key_values=None,
955
+ attention_mask=None,
956
+ inputs_embeds=None,
957
+ cache_position=None,
958
+ position_ids=None,
959
+ use_cache=True,
960
+ **kwargs
961
+ ):
962
+ if past_key_values is not None:
963
+ if inputs_embeds is not None:
964
+ input_ids = input_ids[:, -cache_position.shape[0]:]
965
+ elif input_ids.shape[1] != cache_position.shape[0]:
966
+ input_ids = input_ids[:, cache_position]
967
+
968
+ if attention_mask is not None and position_ids is None:
969
+ position_ids = attention_mask.long().cumsum(-1) - 1
970
+ position_ids.masked_fill_(attention_mask == 0, 1)
971
+ if past_key_values:
972
+ position_ids = position_ids[:, -input_ids.shape[1]:]
973
+
974
+ if inputs_embeds is not None and cache_position[0] == 0:
975
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
976
+ else:
977
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
978
+
979
+ if isinstance(past_key_values, DynamicCache) and attention_mask.ndim == 2:
980
+ model_inputs["cache_position"] = cache_position
981
+ model_inputs["past_key_values"] = past_key_values
982
+ model_inputs["use_cache"] = use_cache
983
+ model_inputs["position_ids"] = position_ids
984
+ model_inputs["attention_mask"] = attention_mask
985
+ return model_inputs
986
+
987
+ model_inputs.update(
988
+ {
989
+ "position_ids": position_ids,
990
+ "cache_position": cache_position,
991
+ "past_key_values": past_key_values,
992
+ "use_cache": use_cache,
993
+ "attention_mask": attention_mask,
994
+ }
995
+ )
996
+ return model_inputs
997
+
998
+
999
+ __all__ = [
1000
+ "WeDLMConfig",
1001
+ "WeDLMPreTrainedModel",
1002
+ "WeDLMModel",
1003
+ "WeDLMForCausalLM",
1004
+ ]
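
For orientation, here is a minimal usage sketch of the blockwise generation method added above; it is not part of the diff. The method name `diffusion_generate` is hypothetical (its `def` line falls outside this excerpt) and the repo id is illustrative; the parameters (`max_new_tokens`, `block_size`, `confidence_threshold`, `return_stats`) are taken from the code above.

# Usage sketch, assuming the blockwise method above is exposed as
# `diffusion_generate` (hypothetical name) and an illustrative repo id.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "tencent/WeDLM-8B-Instruct"  # illustrative repo id
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"
)

input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What is a diffusion language model?"}],
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

out = model.diffusion_generate(     # hypothetical name; see the method above
    input_ids=input_ids,
    max_new_tokens=256,
    block_size=32,                  # tokens denoised per block
    confidence_threshold=0.9,       # unmask a token once its confidence passes this
    temperature=0.2,
    return_stats=True,              # also return step/efficiency statistics
)
print(tokenizer.decode(out["sequences"][0], skip_special_tokens=True))
print("avg tokens/step:", out["stats"]["average_tokens_per_step"])
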
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "eos_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
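
Note that special_tokens_map.json points both eos_token and pad_token at <|endoftext|>, which is consistent with the generation code above (prompt trimming and right-padding both key off pad_token_id). A quick sanity-check sketch, with an illustrative repo id:

# Sanity-check sketch (not part of the diff); the repo id is illustrative.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("tencent/WeDLM-8B-Instruct")
assert tok.eos_token == tok.pad_token == "<|endoftext|>"
print(tok.pad_token_id)  # expected 151643, per added_tokens_decoder further below
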
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9c5ae00e602b8860cbd784ba82a8aa14e8feecec692e7076590d014d7b7fdafa
+ size 11421896
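
tokenizer.json is tracked by Git LFS, so the diff records only the pointer file: spec version, content hash, and byte size. A plain clone without git-lfs installed yields this three-line stub rather than the ~11 MB tokenizer; below is a sketch of fetching the resolved file directly (repo id illustrative).

# Sketch: download the LFS-resolved tokenizer.json via huggingface_hub.
from huggingface_hub import hf_hub_download

path = hf_hub_download(repo_id="tencent/WeDLM-8B-Instruct", filename="tokenizer.json")
print(path)  # local cache path of the actual ~11 MB JSON file
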
tokenizer_config.json ADDED
@@ -0,0 +1,207 @@
+ {
+   "add_bos_token": false,
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "151643": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151644": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151645": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151646": {
+       "content": "<|object_ref_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151647": {
+       "content": "<|object_ref_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151648": {
+       "content": "<|box_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151649": {
+       "content": "<|box_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151650": {
+       "content": "<|quad_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151651": {
+       "content": "<|quad_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151652": {
+       "content": "<|vision_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151653": {
+       "content": "<|vision_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151654": {
+       "content": "<|vision_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151655": {
+       "content": "<|image_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151656": {
+       "content": "<|video_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151657": {
+       "content": "<tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151658": {
+       "content": "</tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151659": {
+       "content": "<|fim_prefix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151660": {
+       "content": "<|fim_middle|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151661": {
+       "content": "<|fim_suffix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151662": {
+       "content": "<|fim_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151663": {
+       "content": "<|repo_name|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151664": {
+       "content": "<|file_sep|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     }
+   },
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "bos_token": null,
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|endoftext|>",
+   "errors": "replace",
+   "extra_special_tokens": {},
+   "model_max_length": 131072,
+   "pad_token": "<|endoftext|>",
+   "split_special_tokens": false,
+   "tokenizer_class": "Qwen2Tokenizer",
+   "unk_token": null
+ }
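
The added_tokens_decoder above registers ChatML-style <|im_start|>/<|im_end|> markers as special tokens. No chat_template key appears in the config as shown, so the sketch below assembles a ChatML prompt by hand; this format is an assumption based on those markers, and tokenizer.apply_chat_template should be preferred if the shipped tokenizer defines a template.

# Sketch: hand-assembled ChatML-style prompt (format assumed from the
# <|im_start|>/<|im_end|> special tokens; repo id illustrative).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("tencent/WeDLM-8B-Instruct")
prompt = (
    "<|im_start|>user\n"
    "Summarize block-parallel decoding in one sentence.<|im_end|>\n"
    "<|im_start|>assistant\n"
)
input_ids = tok(prompt, return_tensors="pt").input_ids
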
vocab.json ADDED
The diff for this file is too large to render. See raw diff