anneketh-vij MaziyarPanahi Crystalcareai commited on
Commit
d5ccf82
·
0 Parent(s):

Super-squash branch 'main' using huggingface_hub

Browse files

Co-authored-by: MaziyarPanahi <MaziyarPanahi@users.noreply.huggingface.co>
Co-authored-by: Crystalcareai <Crystalcareai@users.noreply.huggingface.co>

Files changed (43) hide show
  1. .gitattributes +36 -0
  2. README.md +311 -0
  3. __init__.py +4 -0
  4. chat_template.jinja +159 -0
  5. config.json +108 -0
  6. configuration_afmoe.py +133 -0
  7. generation_config.json +9 -0
  8. model-00001-of-00031.safetensors +3 -0
  9. model-00002-of-00031.safetensors +3 -0
  10. model-00003-of-00031.safetensors +3 -0
  11. model-00004-of-00031.safetensors +3 -0
  12. model-00005-of-00031.safetensors +3 -0
  13. model-00006-of-00031.safetensors +3 -0
  14. model-00007-of-00031.safetensors +3 -0
  15. model-00008-of-00031.safetensors +3 -0
  16. model-00009-of-00031.safetensors +3 -0
  17. model-00010-of-00031.safetensors +3 -0
  18. model-00011-of-00031.safetensors +3 -0
  19. model-00012-of-00031.safetensors +3 -0
  20. model-00013-of-00031.safetensors +3 -0
  21. model-00014-of-00031.safetensors +3 -0
  22. model-00015-of-00031.safetensors +3 -0
  23. model-00016-of-00031.safetensors +3 -0
  24. model-00017-of-00031.safetensors +3 -0
  25. model-00018-of-00031.safetensors +3 -0
  26. model-00019-of-00031.safetensors +3 -0
  27. model-00020-of-00031.safetensors +3 -0
  28. model-00021-of-00031.safetensors +3 -0
  29. model-00022-of-00031.safetensors +3 -0
  30. model-00023-of-00031.safetensors +3 -0
  31. model-00024-of-00031.safetensors +3 -0
  32. model-00025-of-00031.safetensors +3 -0
  33. model-00026-of-00031.safetensors +3 -0
  34. model-00027-of-00031.safetensors +3 -0
  35. model-00028-of-00031.safetensors +3 -0
  36. model-00029-of-00031.safetensors +3 -0
  37. model-00030-of-00031.safetensors +3 -0
  38. model-00031-of-00031.safetensors +3 -0
  39. model.safetensors.index.json +0 -0
  40. modeling_afmoe.py +680 -0
  41. special_tokens_map.json +23 -0
  42. tokenizer.json +3 -0
  43. tokenizer_config.json +271 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ - es
6
+ - fr
7
+ - de
8
+ - it
9
+ - pt
10
+ - ru
11
+ - ar
12
+ - hi
13
+ - ko
14
+ - zh
15
+ library_name: transformers
16
+ base_model:
17
+ - arcee-ai/Trinity-Large-Base
18
+ arxiv:
19
+ - 2602.17004
20
+ tags:
21
+ - reasoning
22
+ - agentic
23
+ - tool-calling
24
+ - thinking
25
+ ---
26
+ <!-- markdownlint-disable first-line-h1 -->
27
+ <!-- markdownlint-disable html -->
28
+ <!-- markdownlint-disable no-duplicate-header -->
29
+
30
+ <div align="center">
31
+ <picture>
32
+ <img
33
+ src="https://cdn-uploads.huggingface.co/production/uploads/6435718aaaef013d1aec3b8b/i-v1KyAMOW_mgVGeic9WJ.png"
34
+ alt="Arcee Trinity Large Thinking"
35
+ style="max-width: 100%; height: auto;"
36
+ >
37
+ </picture>
38
+ </div>
39
+ <hr>
40
+
41
+ # Trinity-Large-Thinking
42
+
43
+ ## Introduction
44
+
45
+ Trinity-Large-Thinking is a reasoning-optimized variant of Arcee AI's Trinity-Large family — a 398B-parameter sparse Mixture-of-Experts (MoE) model with approximately 13B active parameters per token. Built on Trinity-Large-Base and post-trained with extended chain-of-thought reasoning and agentic RL, Trinity-Large-Thinking delivers state-of-the-art performance on agentic benchmarks while maintaining strong general capabilities.
46
+
47
+ Trinity-Large-Thinking generates explicit reasoning traces wrapped in `<think>...</think>` blocks before producing its final response. This thinking process is critical to the model's performance — **thinking tokens must be kept in context** for multi-turn conversations and agentic loops to function correctly.
48
+
49
+ Try it at [chat.arcee.ai](http://chat.arcee.ai/)
50
+
51
+ More details on the training of Trinity Large are available in the [technical report](https://arxiv.org/abs/2602.17004).
52
+
53
+ ## Key Highlights
54
+
55
+ - **Agentic-first design**: Purpose-built for tool calling, multi-step planning, and agent workflows
56
+ - **State-of-the-art agentic performance**: 94.7% on τ²-Bench, 91.9% on PinchBench, 98.2% on LiveCodeBench
57
+ - **Native reasoning traces**: Extended chain-of-thought via `<think>...</think>` blocks
58
+ - **Compatible with major agent frameworks**: Works out of the box with [OpenClaw](https://github.com/openclaw) and [Hermes Agent](https://github.com/NousResearch/hermes-agent)
59
+ - **Ready to use on [OpenRouter](https://openrouter.ai/)**: No setup required — full reasoning and tool calling support via API
60
+
61
+ ## Model Variants
62
+
63
+ The Trinity Large family consists of four checkpoints:
64
+
65
+ - **Trinity-Large-Thinking** (this release): Reasoning-optimized, agentic post-training with extended chain-of-thought
66
+ - **[Trinity-Large-Preview](https://huggingface.co/arcee-ai/Trinity-Large-Preview)**: Lightly post-trained, chat-ready instruct model (no reasoning_content).
67
+ - **[Trinity-Large-TrueBase](https://huggingface.co/arcee-ai/Trinity-Large-TrueBase)**: 10T-token pre-anneal pretraining checkpoint
68
+ - **[Trinity-Large-Base](https://huggingface.co/arcee-ai/Trinity-Large-Base)**: Full 17T-token pretrained foundation model with mid-training anneals
69
+
70
+ ## Architecture
71
+
72
+ Trinity-Large-Thinking shares the same sparse MoE architecture as Trinity-Large-Preview.
73
+
74
+ | Hyperparameter | Value |
75
+ |:---|:---:|
76
+ | Total parameters | ~398B |
77
+ | Active parameters per token | ~13B |
78
+ | Experts | 256 (1 shared) |
79
+ | Active experts | 4 |
80
+ | Routing strategy | 4-of-256 (1.56% sparsity) |
81
+ | Dense layers | 6 |
82
+ | Pretraining context length | 8,192 |
83
+ | Context length after extension | 512k |
84
+ | Architecture | Sparse MoE (AfmoeForCausalLM) |
85
+
86
+ ## Benchmarks
87
+
88
+ | Benchmark | Trinity-Large-Thinking | Opus-4.6 | GLM-5 | MiniMax-M2.7 | Kimi-K2.5 |
89
+ |---|---:|---:|---:|---:|---:|
90
+ | IFBench | 52.3 | 53.1 | 72.3 | **75.7** | 70.2 |
91
+ | GPQA-Diamond | 76.3 | **89.2** | 81.6 | 86.2 | 86.9 |
92
+ | Tau2-Airline | **88.0** | 82.0 | 80.5 | 80.0 | 80.0 |
93
+ | Tau2-Telecom | 94.7 | 92.1 | **98.2** | 84.8 | 95.9 |
94
+ | PinchBench | 91.9 | **93.3** | 86.4 | 89.8 | 84.8 |
95
+ | AIME25 | 96.3 | **99.8** | 93.3 | 80.0 | 96.3 |
96
+ | BCFLv4 | 70.1 | **77.0** | 70.8 | 70.6 | 68.3 |
97
+ | MMLU-Pro | 83.4 | **89.1** | 85.8 | 80.8 | 87.1 |
98
+ | SWE-bench Verified* | 63.2 | **75.6** | 72.8 | 75.4 | 70.8 |
99
+
100
+ *All models evaluated in mini-swe-agent-v2
101
+
102
+ ## Thinking-in-Context: Important Usage Note
103
+
104
+ Trinity-Large-Thinking produces reasoning traces inside `<think>...</think>` blocks before generating its final response.
105
+
106
+ This means:
107
+
108
+ 1. **Multi-turn conversations**: When building chat applications, include the full assistant response (thinking + answer) in the conversation history for subsequent turns.
109
+ 2. **Agentic loops**: When using Trinity-Large-Thinking as the backbone of an agent (OpenClaw, Hermes Agent, or custom), ensure your tool-calling loop preserves `<think>` blocks in the message history between steps.
110
+ 3. **Context window management**: The 512k extended context window accommodates long reasoning chains across many agentic steps. If you must truncate history, prefer removing older turns entirely rather than stripping thinking tokens from recent turns.
111
+
112
+ ### How thinking works
113
+
114
+ The model reasons internally before producing its response. When served via vLLM, the reasoning is separated into a dedicated `reasoning_content` field in the API response:
115
+
116
+ // API response structure
117
+ {
118
+ "message": {
119
+ "role": "assistant",
120
+ "reasoning_content": "The user wants flight information. I need to determine the date for next Tuesday, search for flights SFO → JFK, and filter by price < $300.",
121
+ "content": "\n",
122
+ "tool_calls": [{
123
+ "function": {
124
+ "name": "search_flights",
125
+ "arguments": "{\"origin\": \"SFO\", \"destination\": \"JFK\", \"date\": \"2026-04-07\", \"max_price\": 300}"
126
+ }
127
+ }]
128
+ }
129
+ }
130
+
131
+ When building multi-turn agentic loops, include the `reasoning_content` back in the conversation history (re-wrapped in `<think>...</think>` tags within the assistant message) so the model retains its prior reasoning chain.
132
+
133
+ ## Training Configuration
134
+
135
+ ### Pretraining
136
+
137
+ - Training tokens: 17 trillion
138
+ - Data partner: [Datology](https://www.datologyai.com/)
139
+
140
+ ### Posttraining
141
+
142
+ - Instruction tuning and agentic RL with extended chain-of-thought
143
+ - Trained on tool-calling trajectories, multi-step agent tasks, and reasoning chains
144
+
145
+ ### Infrastructure
146
+
147
+ - Hardware: 2,048 NVIDIA B300 GPUs
148
+ - Parallelism: HSDP + Expert Parallelism
149
+ - Compute partner: [Prime Intellect](https://www.primeintellect.ai/)
150
+
151
+ ## Usage
152
+
153
+ ### Running our model
154
+
155
+ - [vLLM](#vllm) (recommended for agentic deployments)
156
+ - [Transformers](#transformers)
157
+ - [API](#api)
158
+
159
+ ### vLLM
160
+
161
+ Supported in vLLM 0.11.1+. For agentic use with both reasoning and tool calling:
162
+
163
+ vllm serve arcee-ai/Trinity-Large-Thinking \
164
+ --dtype bfloat16 \
165
+ --enable-reasoning \
166
+ --reasoning-parser deepseek_r1 \
167
+ --enable-auto-tool-choice \
168
+ --tool-call-parser qwen3_coder
169
+
170
+ This configuration:
171
+ - `--reasoning-parser deepseek_r1` — Parses `<think>...</think>` reasoning blocks and exposes them via the `reasoning_content` field in the API response
172
+ - `--tool-call-parser qwen3_coder` — Parses structured tool calls from the model output into the OpenAI-compatible `tool_calls` array
173
+
174
+ **Extracting reasoning content from the API response:**
175
+
176
+ ```python
177
+ from openai import OpenAI
178
+
179
+ client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
180
+
181
+ response = client.chat.completions.create(
182
+ model="arcee-ai/Trinity-Large-Thinking",
183
+ messages=[
184
+ {"role": "user", "content": "What's the weather like in Paris?"}
185
+ ],
186
+ tools=[ # your tool definitions here
187
+ {
188
+ "type": "function",
189
+ "function": {
190
+ "name": "get_weather",
191
+ "description": "Get current weather for a location",
192
+ "parameters": {
193
+ "type": "object",
194
+ "properties": {
195
+ "location": {"type": "string"}
196
+ },
197
+ "required": ["location"]
198
+ }
199
+ }
200
+ }
201
+ ],
202
+ )
203
+
204
+ # Access reasoning (thinking) content
205
+ reasoning = response.choices[0].message.reasoning_content
206
+
207
+ # Access final response or tool calls
208
+ content = response.choices[0].message.content
209
+ tool_calls = response.choices[0].message.tool_calls
210
+ ```
211
+
212
+ **Note on thinking-in-context with vLLM**: When building multi-turn agentic loops, include both `reasoning_content` and `content` in the conversation history you send back to the model. The reasoning content should be re-wrapped in `<think>...</think>` tags within the assistant message.
213
+
214
+ ### Transformers
215
+
216
+ Use the `main` transformers branch or pass `trust_remote_code=True` with a released version.
217
+
218
+ ```python
219
+ from transformers import AutoTokenizer, AutoModelForCausalLM
220
+ import torch
221
+
222
+ model_id = "arcee-ai/Trinity-Large-Thinking"
223
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
224
+ model = AutoModelForCausalLM.from_pretrained(
225
+ model_id,
226
+ torch_dtype=torch.bfloat16,
227
+ device_map="auto",
228
+ trust_remote_code=True
229
+ )
230
+
231
+ messages = [
232
+ {"role": "user", "content": "Who are you?"},
233
+ ]
234
+
235
+ input_ids = tokenizer.apply_chat_template(
236
+ messages,
237
+ add_generation_prompt=True,
238
+ return_tensors="pt"
239
+ ).to(model.device)
240
+
241
+ outputs = model.generate(
242
+ input_ids,
243
+ max_new_tokens=4096,
244
+ do_sample=True,
245
+ temperature=0.6,
246
+ top_k=50,
247
+ top_p=0.95
248
+ )
249
+
250
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
251
+ print(response)
252
+ ```
253
+
254
+ ### API
255
+
256
+ Available on OpenRouter:
257
+
258
+ curl -X POST "https://openrouter.ai/v1/chat/completions" \
259
+ -H "Authorization: Bearer $OPENROUTER_API_KEY" \
260
+ -H "Content-Type: application/json" \
261
+ -d '{
262
+ "model": "arcee-ai/trinity-large-thinking",
263
+ "messages": [
264
+ {
265
+ "role": "user",
266
+ "content": "What are some fun things to do in New York?"
267
+ }
268
+ ]
269
+ }'
270
+
271
+ ## Agentic Use Cases
272
+
273
+ Trinity-Large-Thinking is optimized for deployment as the reasoning backbone of AI agent systems. It has been evaluated and performs excellently with:
274
+
275
+ ### OpenClaw
276
+
277
+ Trinity-Large-Thinking works as a drop-in brain for OpenClaw agents. Its native tool-calling format is compatible with OpenClaw's execution loop, and the extended reasoning enables reliable multi-step task completion — from email triage to code generation to meeting scheduling. Our 91.9% PinchBench score reflects real-world OpenClaw task performance.
278
+
279
+ ### Hermes Agent
280
+
281
+ Compatible with the Hermes Agent framework from Nous Research. Trinity-Large-Thinking's reasoning traces pair naturally with Hermes's skill-learning loop — the model's explicit chain-of-thought makes skill extraction more reliable, and its strong tool-calling capabilities integrate directly via the Hermes tool-use protocol.
282
+
283
+ ### Custom Agent Loops
284
+
285
+ For custom implementations, the key integration pattern is:
286
+
287
+ 1. Send the user message with tool definitions
288
+ 2. Receive the response with `<think>` reasoning + tool calls
289
+ 3. Execute the tool calls
290
+ 4. Append the **full** assistant response (thinking + content + tool calls) and tool results to the message history
291
+ 5. Send the updated history back for the next step
292
+ 6. Repeat until the model produces a final response without tool calls
293
+
294
+ ## License
295
+
296
+ Trinity-Large-Thinking is released under the Apache License, Version 2.0.
297
+
298
+ ## Citation
299
+
300
+ If you use this model, please cite:
301
+
302
+ @misc{singh2026arceetrinity,
303
+ title = {Arcee Trinity Large Technical Report},
304
+ author = {Varun Singh and Lucas Krauss and Sami Jaghouar and Matej Sirovatka and Charles Goddard and Fares Obied and Jack Min Ong and Jannik Straube and Fern and Aria Harley and Conner Stewart and Colin Kealty and Maziyar Panahi and Simon Kirsten and Anushka Deshpande and Anneketh Vij and Arthur Bresnu and Pranav Veldurthi and Raghav Ravishankar and Hardik Bishnoi and DatologyAI Team and Arcee AI Team and Prime Intellect Team and Mark McQuade and Johannes Hagemann and Lucas Atkins},
305
+ year = {2026},
306
+ eprint = {2602.17004},
307
+ archivePrefix= {arXiv},
308
+ primaryClass = {cs.LG},
309
+ doi = {10.48550/arXiv.2602.17004},
310
+ url = {https://arxiv.org/abs/2602.17004}
311
+ }
__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .configuration_afmoe import AfmoeConfig
2
+ from .modeling_afmoe import AfmoeForCausalLM
3
+
4
+ __all__ = ["AfmoeConfig", "AfmoeForCausalLM"]
chat_template.jinja ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <|begin_of_text|>{%- macro render_extra_keys(json_dict, handled_keys) -%}
2
+ {%- if json_dict is mapping %}
3
+ {%- for json_key in json_dict if json_key not in handled_keys %}
4
+ {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %}
5
+ {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }}
6
+ {%- else %}
7
+ {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}
8
+ {%- endif %}
9
+ {%- endfor %}
10
+ {%- endif %}
11
+ {%- endmacro -%}
12
+
13
+ {%- macro render_tool_call(raw_tool_call) -%}
14
+ {%- if raw_tool_call.function is defined and raw_tool_call.function is mapping %}
15
+ {%- set tool_call = raw_tool_call.function %}
16
+ {%- else %}
17
+ {%- set tool_call = raw_tool_call %}
18
+ {%- endif %}
19
+ {{- '<tool_call>\n<function=' + (tool_call.name | default('') | string) + '>\n' }}
20
+ {%- if tool_call.arguments is defined and tool_call.arguments is mapping %}
21
+ {%- for args_name, args_value in tool_call.arguments.items() %}
22
+ {{- '<parameter=' + (args_name | string) + '>\n' }}
23
+ {%- if args_value is mapping or (args_value is sequence and args_value is not string) %}
24
+ {{- args_value | tojson | safe }}
25
+ {%- else %}
26
+ {{- args_value | string }}
27
+ {%- endif %}
28
+ {{- '\n</parameter>\n' }}
29
+ {%- endfor %}
30
+ {%- endif %}
31
+ {{- '</function>\n</tool_call>' }}
32
+ {%- endmacro -%}
33
+
34
+ {%- set system_message = none %}
35
+ {%- if messages and messages[0]["role"] == "system" %}
36
+ {%- set system_message = messages[0]["content"] %}
37
+ {%- set loop_messages = messages[1:] %}
38
+ {%- else %}
39
+ {%- set loop_messages = messages %}
40
+ {%- endif %}
41
+
42
+ {%- if not tools is defined %}
43
+ {%- set tools = [] %}
44
+ {%- endif %}
45
+ {%- set has_tools = tools is iterable and tools is not string and tools | length > 0 %}
46
+
47
+ {%- if system_message is not none or has_tools %}
48
+ {{- '<|im_start|>system\n' }}
49
+ {%- if system_message is not none %}
50
+ {{- system_message }}
51
+ {%- else %}
52
+ {{- "You are Trinity Large, a helpful assistant developed by Arcee AI, that can interact with a computer to solve tasks." }}
53
+ {%- endif %}
54
+ {%- if has_tools %}
55
+ {{- "\n\n# Tools\n\nYou have access to the following functions:\n\n<tools>" }}
56
+ {%- for tool in tools %}
57
+ {%- if tool.function is defined and tool.function is mapping %}
58
+ {%- set tool = tool.function %}
59
+ {%- endif %}
60
+ {{- '\n<function>\n<name>' ~ (tool.name | default('') | string) ~ '</name>' }}
61
+ {%- if tool.description is defined and tool.description is not none %}
62
+ {{- '\n<description>' ~ (tool.description | string | trim) ~ '</description>' }}
63
+ {%- endif %}
64
+ {{- '\n<parameters>' }}
65
+ {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}
66
+ {%- for param_name, param_fields in tool.parameters.properties.items() %}
67
+ {{- '\n<parameter>\n<name>' ~ (param_name | string) ~ '</name>' }}
68
+ {%- if param_fields is mapping and param_fields.type is defined and param_fields.type is not none %}
69
+ {{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
70
+ {%- endif %}
71
+ {%- if param_fields is mapping and param_fields.description is defined and param_fields.description is not none %}
72
+ {{- '\n<description>' ~ (param_fields.description | string | trim) ~ '</description>' }}
73
+ {%- endif %}
74
+ {%- if param_fields is mapping %}
75
+ {%- set handled_keys = ['name', 'type', 'description'] %}
76
+ {{- render_extra_keys(param_fields, handled_keys) }}
77
+ {%- endif %}
78
+ {{- '\n</parameter>' }}
79
+ {%- endfor %}
80
+ {%- endif %}
81
+ {%- if tool.parameters is defined %}
82
+ {%- set handled_keys = ['type', 'properties'] %}
83
+ {{- render_extra_keys(tool.parameters, handled_keys) }}
84
+ {%- endif %}
85
+ {{- '\n</parameters>' }}
86
+ {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}
87
+ {{- render_extra_keys(tool, handled_keys) }}
88
+ {{- '\n</function>' }}
89
+ {%- endfor %}
90
+ {{- "\n</tools>" }}
91
+ {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
92
+ {%- endif %}
93
+ {{- '<|im_end|>\n' }}
94
+ {%- endif %}
95
+
96
+ {%- for message in loop_messages %}
97
+ {%- set role = message.role | default('') %}
98
+ {%- if role == "assistant" %}
99
+ {%- set content_str = '' if message.content is none else (message.content | string) %}
100
+ {%- set trimmed_content = content_str | trim %}
101
+
102
+ {%- set has_reasoning_content = message.reasoning_content is defined %}
103
+ {%- set has_reasoning = has_reasoning_content or (message.reasoning is defined) %}
104
+
105
+ {%- if has_reasoning_content %}
106
+ {%- set reasoning_value = message.reasoning_content %}
107
+ {%- elif message.reasoning is defined %}
108
+ {%- set reasoning_value = message.reasoning %}
109
+ {%- else %}
110
+ {%- set reasoning_value = none %}
111
+ {%- endif %}
112
+
113
+ {%- set has_tool_calls = message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls is not string and message.tool_calls | length > 0 %}
114
+
115
+ {{- '<|im_start|>assistant\n' }}
116
+ {%- if has_reasoning %}
117
+ {%- if reasoning_value %}
118
+ {{- '<think>' + (reasoning_value | string | trim) + '</think>' }}
119
+ {%- else %}
120
+ {{- '<think></think>' }}
121
+ {%- endif %}
122
+ {%- if trimmed_content %}
123
+ {{- '\n' + trimmed_content }}
124
+ {%- endif %}
125
+ {%- elif has_tool_calls %}
126
+ {%- if trimmed_content %}
127
+ {{- trimmed_content }}
128
+ {%- endif %}
129
+ {%- else %}
130
+ {{- content_str }}
131
+ {%- endif %}
132
+
133
+ {%- if has_tool_calls %}
134
+ {%- for tool_call in message.tool_calls %}
135
+ {%- set separator = '\n' if ((loop.first and (has_reasoning or trimmed_content)) or (not loop.first)) else '' -%}
136
+ {{- separator + render_tool_call(tool_call) }}
137
+ {%- endfor %}
138
+ {%- endif %}
139
+ {{- '<|im_end|>\n' }}
140
+ {%- elif role == "tool" or role == "observation" or role == "function" %}
141
+ {%- if loop.first or loop.previtem.role not in ["tool", "observation", "function"] %}
142
+ {{- '<|im_start|>user\n' }}
143
+ {%- endif %}
144
+ {{- '<tool_response>\n' }}
145
+ {{- '' if message.content is none else (message.content | string) }}
146
+ {{- '\n</tool_response>\n' }}
147
+ {%- if loop.last or loop.nextitem.role not in ["tool", "observation", "function"] %}
148
+ {{- '<|im_end|>\n' }}
149
+ {%- endif %}
150
+ {%- else %}
151
+ {{- '<|im_start|>' + (role | string) }}
152
+ {{- '\n' + ('' if message.content is none else (message.content | string)) }}
153
+ {{- '<|im_end|>\n' }}
154
+ {%- endif %}
155
+ {%- endfor %}
156
+
157
+ {%- if add_generation_prompt %}
158
+ {{- '<|im_start|>assistant\n<think>' }}
159
+ {%- endif %}
config.json ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "AfmoeForCausalLM"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_afmoe.AfmoeConfig",
8
+ "AutoModel": "modeling_afmoe.AfmoeModel",
9
+ "AutoModelForCausalLM": "modeling_afmoe.AfmoeForCausalLM"
10
+ },
11
+ "dtype": "bfloat16",
12
+ "global_attn_every_n_layers": 4,
13
+ "head_dim": 128,
14
+ "hidden_act": "silu",
15
+ "hidden_size": 3072,
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": 12288,
18
+ "layer_types": [
19
+ "sliding_attention",
20
+ "sliding_attention",
21
+ "sliding_attention",
22
+ "full_attention",
23
+ "sliding_attention",
24
+ "sliding_attention",
25
+ "sliding_attention",
26
+ "full_attention",
27
+ "sliding_attention",
28
+ "sliding_attention",
29
+ "sliding_attention",
30
+ "full_attention",
31
+ "sliding_attention",
32
+ "sliding_attention",
33
+ "sliding_attention",
34
+ "full_attention",
35
+ "sliding_attention",
36
+ "sliding_attention",
37
+ "sliding_attention",
38
+ "full_attention",
39
+ "sliding_attention",
40
+ "sliding_attention",
41
+ "sliding_attention",
42
+ "full_attention",
43
+ "sliding_attention",
44
+ "sliding_attention",
45
+ "sliding_attention",
46
+ "full_attention",
47
+ "sliding_attention",
48
+ "sliding_attention",
49
+ "sliding_attention",
50
+ "full_attention",
51
+ "sliding_attention",
52
+ "sliding_attention",
53
+ "sliding_attention",
54
+ "full_attention",
55
+ "sliding_attention",
56
+ "sliding_attention",
57
+ "sliding_attention",
58
+ "full_attention",
59
+ "sliding_attention",
60
+ "sliding_attention",
61
+ "sliding_attention",
62
+ "full_attention",
63
+ "sliding_attention",
64
+ "sliding_attention",
65
+ "sliding_attention",
66
+ "full_attention",
67
+ "sliding_attention",
68
+ "sliding_attention",
69
+ "sliding_attention",
70
+ "full_attention",
71
+ "sliding_attention",
72
+ "sliding_attention",
73
+ "sliding_attention",
74
+ "full_attention",
75
+ "sliding_attention",
76
+ "sliding_attention",
77
+ "sliding_attention",
78
+ "full_attention"
79
+ ],
80
+ "load_balance_coeff": 0.00005,
81
+ "max_position_embeddings": 262144,
82
+ "model_type": "afmoe",
83
+ "moe_intermediate_size": 3072,
84
+ "mup_enabled": true,
85
+ "n_group": 1,
86
+ "num_attention_heads": 48,
87
+ "num_dense_layers": 6,
88
+ "num_expert_groups": 1,
89
+ "num_experts": 256,
90
+ "num_experts_per_tok": 4,
91
+ "num_hidden_layers": 60,
92
+ "num_key_value_heads": 8,
93
+ "num_limited_groups": 1,
94
+ "num_shared_experts": 1,
95
+ "rms_norm_eps": 1e-05,
96
+ "rope_scaling": null,
97
+ "rope_theta": 10000,
98
+ "route_norm": true,
99
+ "route_scale": 2.448,
100
+ "score_func": "sigmoid",
101
+ "sliding_window": 4096,
102
+ "tie_word_embeddings": false,
103
+ "topk_group": 1,
104
+ "transformers_version": "4.57.1",
105
+ "use_cache": true,
106
+ "use_grouped_mm": true,
107
+ "vocab_size": 200192
108
+ }
configuration_afmoe.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from transformers.configuration_utils import PretrainedConfig
16
+ from transformers.modeling_rope_utils import rope_config_validation
17
+ from transformers.configuration_utils import layer_type_validation
18
+ from transformers.utils import logging
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+ class AfmoeConfig(PretrainedConfig):
23
+ """
24
+ n_group (`int`, *optional*, defaults to 1):
25
+ Number of groups for routed experts.
26
+ topk_group (`int`, *optional*, defaults to 1):
27
+ Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
28
+ """
29
+ model_type = "afmoe"
30
+ base_model_pp_plan = {
31
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
32
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
33
+ "norm": (["hidden_states"], ["hidden_states"]),
34
+ }
35
+
36
+ def __init__(
37
+ self,
38
+ num_hidden_layers: int = 32,
39
+ vocab_size: int = 200192,
40
+ hidden_size: int = 2048,
41
+ intermediate_size: int = 6144,
42
+ moe_intermediate_size=1408,
43
+ num_dense_layers=1,
44
+ num_attention_heads=16,
45
+ num_key_value_heads=None,
46
+ head_dim=128,
47
+ hidden_act="silu",
48
+ max_position_embeddings=16384,
49
+ initializer_range=0.02,
50
+ rms_norm_eps=1e-5,
51
+ use_cache=True,
52
+ tie_word_embeddings=False,
53
+ rope_theta=10000.0,
54
+ rope_scaling=None,
55
+ num_experts=64,
56
+ num_experts_per_tok=6,
57
+ num_shared_experts=2,
58
+ num_expert_groups=1,
59
+ num_limited_groups=1,
60
+ score_func="sigmoid",
61
+ route_norm=True,
62
+ route_scale=1.0,
63
+ global_attn_every_n_layers=4,
64
+ sliding_window=1024,
65
+ mup_enabled=False,
66
+ layer_types=None,
67
+ attention_dropout: float = 0.0,
68
+ n_group: int = 1,
69
+ topk_group: int = 1,
70
+ **kwargs,
71
+ ):
72
+ self.vocab_size = vocab_size
73
+ self.max_position_embeddings = max_position_embeddings
74
+ self.hidden_size = hidden_size
75
+ self.intermediate_size = intermediate_size
76
+ self.num_hidden_layers = num_hidden_layers
77
+ self.num_dense_layers = num_dense_layers
78
+ self.num_attention_heads = num_attention_heads
79
+ self.head_dim = head_dim
80
+ self.hidden_act = hidden_act
81
+ self.initializer_range = initializer_range
82
+ self.rms_norm_eps = rms_norm_eps
83
+ self.use_cache = use_cache
84
+ self.rope_theta = rope_theta
85
+ self.rope_scaling = rope_scaling
86
+
87
+
88
+ # MoE specific
89
+ self.moe_intermediate_size = moe_intermediate_size
90
+ self.num_experts_per_tok = num_experts_per_tok
91
+ self.n_group = n_group
92
+ self.topk_group = topk_group
93
+ self.num_experts = num_experts
94
+ self.num_shared_experts = num_shared_experts
95
+ self.num_expert_groups = num_expert_groups
96
+ self.num_limited_groups = num_limited_groups
97
+ self.score_func = score_func
98
+ self.route_norm = route_norm
99
+ self.route_scale = route_scale
100
+
101
+
102
+ # Attention specific
103
+ self.attention_dropout = attention_dropout
104
+ self.global_attn_every_n_layers = global_attn_every_n_layers
105
+ self.sliding_window = sliding_window
106
+ self.layer_types = layer_types
107
+ if self.layer_types is None:
108
+ self.layer_types = [
109
+ "sliding_attention" if bool((i + 1) % global_attn_every_n_layers) else "full_attention" for i in range(self.num_hidden_layers)
110
+ ]
111
+ layer_type_validation(self.layer_types)
112
+
113
+ # muP specific
114
+ self.mup_enabled = mup_enabled
115
+
116
+ if num_key_value_heads is None:
117
+ num_key_value_heads = num_attention_heads
118
+
119
+ self.num_key_value_heads = num_key_value_heads
120
+
121
+
122
+ # Validate rope configs
123
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
124
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
125
+ rope_config_validation(self)
126
+
127
+ super().__init__(
128
+ tie_word_embeddings=tie_word_embeddings,
129
+ **kwargs,
130
+ )
131
+
132
+
133
+ __all__ = ["AfmoeConfig"]
generation_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 0,
4
+ "eos_token_id": 3,
5
+ "pad_token_id": 12,
6
+ "transformers_version": "4.57.3",
7
+ "temperature": 0.8,
8
+ "top_p": 0.8
9
+ }
model-00001-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:853c70c5c46ebef7fff4e96c2eda7b0a31f8415eab7e62f984aa4180aaed2ab9
3
+ size 2459965736
model-00002-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33a07dd8ba9a4065c05de7e5ca2b3fb117b8847e5e926a73dac81f77bcf04493
3
+ size 704696408
model-00003-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c4d625226b66f0f4560b2299b339bb1cab0faab7f269be57526d549fc6287f2
3
+ size 704696408
model-00004-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c47e8a110787e454a6ee52ffead8903ffd9a42804a6bfd65448a8864fcb949e
3
+ size 704696408
model-00005-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:290a75fc94da91c359c345e8d760d342e38a0c3b8bdfc138bfd5ff0a6e395bbd
3
+ size 29359329168
model-00006-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29e9b301281696ff948a03c6cf1570fb221daeb0efa4e7793bcb024ef4e5c7f1
3
+ size 29359329168
model-00007-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9e5f8039d277c892b5901bf39b8ec195652fad2ce561142660e70a2118c1de3
3
+ size 29359330736
model-00008-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b3cf1e2b37d0f37efc818753f7a6b5a3a6a34d5d40799f7861d6ae7f2f50cdb
3
+ size 29359330736
model-00009-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5619671ed44cdb8e8e0e014825b0166364c5fda1cdb16e9df8bc30f369592d4
3
+ size 29359330736
model-00010-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60d3ca7fbd2f5c0e9fdc947a22d88ce92e28a6b5a870a68618eab5160673955c
3
+ size 29359330736
model-00011-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:604d743dc7de64413191bf367d9724d042b5a760f073e480fe8de07f4cbbe875
3
+ size 29359330736
model-00012-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07bad9551c1a538461a5e47c5660950bfcd1a06fb5763664c10dc295024d186f
3
+ size 29359330736
model-00013-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7347d5299a1af7e96c705d75082c3d2f6912f07019281cc472565040f4ac401
3
+ size 29359330736
model-00014-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dcf043b318b569fff0a1bdfa4d4868170cd08b2c0138e48516ef87b4eadb9d2b
3
+ size 29359330736
model-00015-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:513c803339e049cd378333a0463d42517b4dc9614d14cf3ad030cc769f4caab0
3
+ size 29359330736
model-00016-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55c3326f636d40910ab9f5d5e3e4295b695ab6b920fafb55833dde0c15fbe218
3
+ size 29359330736
model-00017-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff887d5ed74cb31778e7ffca3c68994c12d58aca5498daf8f73ec11315bd871a
3
+ size 29359330736
model-00018-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ad6ef2e4ce52c22a75a88e2e1701bc9c00381f5193759bd1db48e08865c93f8
3
+ size 29359330736
model-00019-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4860ab9597c2cd8b4239762045de948a4022e748eb19af6caf917b919c9635ac
3
+ size 29359330736
model-00020-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7bf756f75c191ddf476814e2c61d77640f6fa4d78880200f56e9fc6d8b31ca2f
3
+ size 29359330736
model-00021-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f05bb8b1018a485d0081465586ae362127e731b07f452c79c648f87cb2cddfaa
3
+ size 29359330736
model-00022-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d498f7ba5144e18c0685058a077aeb169a56b4bdf9a5bfe8df38650f868c44c
3
+ size 29359330736
model-00023-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7edaa4f5e19ce9941b404dd184a2f2c6c1f7b231b3fed9d92a57d442843ac4a
3
+ size 29359330736
model-00024-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79ce79cf55200e28dd3dd1ba232aab986cae96619e12da880e8da08506e17dea
3
+ size 29359330736
model-00025-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b101d2e8b1d76865e863fa3b8854839e18e3e98c194d4844e8d99a10d36b1a2
3
+ size 29359330736
model-00026-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab10a3fee4a0c6f6da9e697d64b4cd63557018f5a398118d04b02b734b93de6c
3
+ size 29359330736
model-00027-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fee6d0d7d3abb8148c5870feab0949be5264e4d6b45a46c1f68e5b195448dcf3
3
+ size 29359330736
model-00028-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c56fffdea546c522d14cb0afd038b6e33f04dffb479fdeef04c3fdaa03c4520
3
+ size 29359330736
model-00029-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:476eaed8aa9d5f86ffeac3778ead337513513b45be6cbfbf496074f2743af91f
3
+ size 29359330736
model-00030-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90b37fb9b43e31963a0df04408be4c285e75fe2a4c8aa57a62097a6a2721aee4
3
+ size 29359330736
model-00031-of-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2fa868e56a5ffe5797cddc5db0931f2bb6fae313338935e4cea19f589eb88d1
3
+ size 29359330736
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_afmoe.py ADDED
@@ -0,0 +1,680 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+
7
+ from transformers.activations import ACT2FN
8
+ from transformers.generation import GenerationMixin
9
+ from transformers.modeling_outputs import (
10
+ MoeCausalLMOutputWithPast,
11
+ MoeModelOutputWithPast,
12
+ )
13
+ from transformers.modeling_utils import PreTrainedModel, ALL_ATTENTION_FUNCTIONS
14
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
15
+ from transformers.masking_utils import (
16
+ create_causal_mask,
17
+ create_sliding_window_causal_mask,
18
+ )
19
+ from transformers.modeling_layers import GradientCheckpointingLayer
20
+ from transformers.processing_utils import Unpack
21
+ from transformers.utils import TransformersKwargs
22
+ from transformers.cache_utils import Cache, DynamicCache
23
+ from transformers.integrations import use_kernel_forward_from_hub
24
+
25
+
26
+ try:
27
+ from .configuration_afmoe import AfmoeConfig
28
+ except:
29
+ from configuration_afmoe import AfmoeConfig
30
+
31
+ class AfmoeRotaryEmbedding(nn.Module):
32
+
33
+ def __init__(self, config: AfmoeConfig, device=None):
34
+ super().__init__()
35
+ # BC: "rope_type" was originally "type"
36
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
37
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
38
+ else:
39
+ self.rope_type = "default"
40
+ self.max_seq_len_cached = config.max_position_embeddings
41
+ self.original_max_seq_len = config.max_position_embeddings
42
+
43
+ self.config = config
44
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
45
+
46
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
47
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
48
+ self.original_inv_freq = self.inv_freq
49
+
50
+ def _dynamic_frequency_update(self, position_ids, device):
51
+ """
52
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
53
+ 1 - growing beyond the cached sequence length (allow scaling)
54
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
55
+ """
56
+ seq_len = torch.max(position_ids) + 1
57
+ if seq_len > self.max_seq_len_cached: # growth
58
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
59
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
60
+ self.max_seq_len_cached = seq_len
61
+
62
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
63
+ # This .to() is needed if the model has been moved to a device after being initialized (because
64
+ # the buffer is automatically moved, but not the original copy)
65
+ self.original_inv_freq = self.original_inv_freq.to(device)
66
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
67
+ self.max_seq_len_cached = self.original_max_seq_len
68
+
69
+ @torch.no_grad()
70
+ def forward(self, x, position_ids):
71
+ if "dynamic" in self.rope_type:
72
+ self._dynamic_frequency_update(position_ids, device=x.device)
73
+
74
+ # Core RoPE block
75
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
76
+ position_ids_expanded = position_ids[:, None, :].float()
77
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
78
+ device_type = x.device.type
79
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
80
+ with torch.autocast(device_type=device_type, enabled=False):
81
+ freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
82
+ emb = torch.cat((freqs, freqs), dim=-1)
83
+ cos = emb.cos()
84
+ sin = emb.sin()
85
+
86
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
87
+ cos = cos * self.attention_scaling
88
+ sin = sin * self.attention_scaling
89
+
90
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
91
+
92
+
93
+ def rotate_half(x):
94
+ """Rotates half the hidden dims of the input."""
95
+ x1 = x[..., : x.shape[-1] // 2]
96
+ x2 = x[..., x.shape[-1] // 2 :]
97
+ return torch.cat((-x2, x1), dim=-1)
98
+
99
+
100
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
101
+ """Applies Rotary Position Embedding to the query and key tensors.
102
+
103
+ Args:
104
+ q (`torch.Tensor`): The query tensor.
105
+ k (`torch.Tensor`): The key tensor.
106
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
107
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
108
+ position_ids (`torch.Tensor`, *optional*):
109
+ Deprecated and unused.
110
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
111
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
112
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
113
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
114
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
115
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
116
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
117
+ Returns:
118
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
119
+ """
120
+ cos = cos.unsqueeze(unsqueeze_dim)
121
+ sin = sin.unsqueeze(unsqueeze_dim)
122
+ q_embed = (q * cos) + (rotate_half(q) * sin)
123
+ k_embed = (k * cos) + (rotate_half(k) * sin)
124
+ return q_embed, k_embed
125
+
126
+
127
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
128
+ """
129
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
130
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
131
+ """
132
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
133
+ if n_rep == 1:
134
+ return hidden_states
135
+ hidden_states = hidden_states[:, :, None, :, :].expand(
136
+ batch, num_key_value_heads, n_rep, slen, head_dim
137
+ )
138
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
139
+
140
+ @use_kernel_forward_from_hub("RMSNorm")
141
+ class AfmoeRMSNorm(nn.Module):
142
+ def __init__(self, hidden_size: int, eps: float):
143
+ """
144
+ AfmoeRMSNorm is equivalent to T5LayerNorm
145
+ """
146
+ super().__init__()
147
+ self.weight = nn.Parameter(torch.ones(hidden_size))
148
+ self.variance_epsilon = eps
149
+
150
+ def forward(self, hidden_states):
151
+ input_dtype = hidden_states.dtype
152
+ hidden_states = hidden_states.to(torch.float32)
153
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
154
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
155
+ return self.weight * hidden_states.to(input_dtype)
156
+
157
+ def extra_repr(self):
158
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
159
+
160
+
161
+
162
+ def eager_attention_forward(
163
+ module: nn.Module,
164
+ query: torch.Tensor,
165
+ key: torch.Tensor,
166
+ value: torch.Tensor,
167
+ attention_mask: Optional[torch.Tensor],
168
+ scaling: float,
169
+ dropout: float = 0.0,
170
+ **kwargs,
171
+ ):
172
+ key_states = repeat_kv(key, module.num_key_value_groups)
173
+ value_states = repeat_kv(value, module.num_key_value_groups)
174
+
175
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
176
+ if attention_mask is not None:
177
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
178
+ attn_weights = attn_weights + causal_mask
179
+
180
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
181
+ query.dtype
182
+ )
183
+ attn_weights = nn.functional.dropout(
184
+ attn_weights, p=dropout, training=module.training
185
+ )
186
+ attn_output = torch.matmul(attn_weights, value_states)
187
+ attn_output = attn_output.transpose(1, 2).contiguous()
188
+
189
+ return attn_output, attn_weights
190
+
191
+
192
+ class AfmoeMLP(nn.Module):
193
+ def __init__(self, config, intermediate_size=None):
194
+ super().__init__()
195
+ self.config = config
196
+ self.hidden_size = config.hidden_size
197
+ self.intermediate_size = intermediate_size or config.intermediate_size
198
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
199
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
200
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
201
+ self.act_fn = ACT2FN[config.hidden_act]
202
+
203
+ def forward(self, x):
204
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
205
+
206
+
207
+ class AfmoeTokenChoiceRouter(nn.Module):
208
+ """Token-choice top-K router for MoE routing."""
209
+
210
+ def __init__(self, config):
211
+ super().__init__()
212
+ self.config = config
213
+ self.top_k = config.num_experts_per_tok
214
+ self.num_experts = config.num_experts
215
+ self.score_func = config.score_func
216
+ self.route_norm = config.route_norm
217
+ self.route_scale = config.route_scale
218
+ self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
219
+
220
+ def forward(self, hidden_states, expert_bias: torch.Tensor | None):
221
+ _, _, hidden_dim = hidden_states.shape
222
+ hidden_states = hidden_states.view(-1, hidden_dim)
223
+
224
+ scores = self.gate(hidden_states)
225
+
226
+ # Apply scoring function in float32 for stability
227
+ if self.score_func == "sigmoid":
228
+ scores = torch.sigmoid(scores.to(torch.float32))
229
+ else:
230
+ scores = F.softmax(scores.to(torch.float32), dim=-1)
231
+
232
+ if expert_bias is not None:
233
+ _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
234
+ top_scores = scores.gather(dim=1, index=selected_experts)
235
+ else:
236
+ top_scores, selected_experts = torch.topk(scores, k=self.top_k, dim=1)
237
+
238
+ # Normalize weights if using sigmoid
239
+ if self.score_func == "sigmoid" and self.route_norm:
240
+ denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
241
+ top_scores = top_scores / denominator
242
+
243
+ top_scores = top_scores * self.route_scale
244
+ return top_scores, selected_experts
245
+
246
+ class AfmoeMoE(nn.Module):
247
+ def __init__(self, config):
248
+ super().__init__()
249
+ self.config = config
250
+ self.router = AfmoeTokenChoiceRouter(config)
251
+
252
+ self.shared_experts = None
253
+ if config.num_shared_experts > 0:
254
+ self.shared_experts = AfmoeMLP(
255
+ config, config.moe_intermediate_size * config.num_shared_experts
256
+ )
257
+ self.experts = nn.ModuleList(
258
+ [AfmoeMLP(
259
+ config, intermediate_size=config.moe_intermediate_size
260
+ ) for _ in range(config.num_experts)]
261
+ )
262
+ self.expert_bias = nn.Parameter(torch.zeros(config.num_experts, dtype=torch.float32), requires_grad=False)
263
+
264
+
265
+ def forward(self, hidden_states):
266
+ batch_size, seq_len, hidden_dim = hidden_states.shape
267
+ hidden_states_flat = hidden_states.view(-1, hidden_dim)
268
+
269
+ # Get routing decisions
270
+ top_scores, selected_experts = self.router(hidden_states, self.expert_bias)
271
+
272
+ # Process through shared experts
273
+ if self.shared_experts is not None:
274
+ shared_output = self.shared_experts(hidden_states_flat)
275
+ else:
276
+ shared_output = torch.zeros_like(hidden_states_flat)
277
+
278
+ # Reorder tokens by expert for efficient processing
279
+ token_indices_sorted = torch.argsort(selected_experts.view(-1), stable=True)
280
+ top_scores_sorted = top_scores.view(-1)[token_indices_sorted]
281
+ token_to_expert = selected_experts.view(-1)[token_indices_sorted]
282
+ token_indices_sorted = token_indices_sorted // self.config.num_experts_per_tok
283
+
284
+ # Gather input tokens
285
+ token_indices_expanded = token_indices_sorted.unsqueeze(-1).expand(
286
+ -1, hidden_dim
287
+ )
288
+ routed_input = torch.gather(
289
+ hidden_states_flat, dim=0, index=token_indices_expanded
290
+ )
291
+
292
+ routed_output = torch.zeros_like(routed_input)
293
+ for expert_id in range(self.config.num_experts):
294
+ mask = token_to_expert == expert_id
295
+ if mask.any():
296
+ expert_input = routed_input[mask]
297
+ expert_out = self.experts[expert_id](expert_input)
298
+ routed_output[mask] = expert_out
299
+
300
+ routed_output = (
301
+ routed_output.to(torch.float32) * top_scores_sorted.unsqueeze(-1)
302
+ ).to(hidden_states.dtype)
303
+
304
+ # Scatter back to original positions
305
+ output = shared_output.scatter_add(
306
+ dim=0, index=token_indices_expanded, src=routed_output
307
+ )
308
+
309
+ return output.view(batch_size, seq_len, hidden_dim)
310
+
311
+
312
+ class AfmoeAttention(nn.Module):
313
+ """Multi-headed attention with local/global pattern and gating."""
314
+
315
+ def __init__(self, config: AfmoeConfig, layer_idx: int):
316
+ super().__init__()
317
+ self.config = config
318
+ self.layer_idx = layer_idx
319
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
320
+ self.num_heads = config.num_attention_heads
321
+ self.num_key_value_heads = config.num_key_value_heads
322
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
323
+
324
+ self.scaling = self.head_dim**-0.5
325
+ self.attention_dropout = config.attention_dropout
326
+ self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
327
+ self.sliding_window = config.sliding_window if self.is_local_attention else None
328
+
329
+ self.q_proj = nn.Linear(
330
+ config.hidden_size, self.num_heads * self.head_dim, bias=False
331
+ )
332
+ self.k_proj = nn.Linear(
333
+ config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
334
+ )
335
+ self.v_proj = nn.Linear(
336
+ config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
337
+ )
338
+ self.o_proj = nn.Linear(
339
+ self.num_heads * self.head_dim, config.hidden_size, bias=False
340
+ )
341
+
342
+ self.q_norm = AfmoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
343
+ self.k_norm = AfmoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
344
+
345
+ self.gate_proj = nn.Linear(
346
+ config.hidden_size, self.num_heads * self.head_dim, bias=False
347
+ )
348
+
349
+ def forward(
350
+ self,
351
+ hidden_states: torch.Tensor,
352
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
353
+ attention_mask: Optional[torch.Tensor],
354
+ past_key_value: Optional[Cache] = None,
355
+ cache_position: Optional[torch.LongTensor] = None,
356
+ **kwargs: Unpack[TransformersKwargs],
357
+ ) -> torch.Tensor:
358
+
359
+ input_shape = hidden_states.shape[:-1]
360
+ hidden_shape = (*input_shape, -1, self.head_dim)
361
+
362
+ query_states = self.q_proj(hidden_states).view(hidden_shape)
363
+ key_states = self.k_proj(hidden_states).view(hidden_shape)
364
+ value_states = self.v_proj(hidden_states).view(hidden_shape)
365
+ gate_states = self.gate_proj(hidden_states)
366
+
367
+ query_states = self.q_norm(query_states)
368
+ key_states = self.k_norm(key_states)
369
+
370
+ query_states = query_states.transpose(1, 2)
371
+ key_states = key_states.transpose(1, 2)
372
+ value_states = value_states.transpose(1, 2)
373
+
374
+ if self.is_local_attention:
375
+ cos, sin = position_embeddings
376
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
377
+
378
+ if past_key_value is not None:
379
+ cache_kwargs = {"cache_position": cache_position}
380
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
381
+
382
+ attention_interface: Callable = eager_attention_forward
383
+ if self.config._attn_implementation != "eager":
384
+ attention_interface = ALL_ATTENTION_FUNCTIONS[
385
+ self.config._attn_implementation
386
+ ]
387
+
388
+ output, _ = attention_interface(
389
+ self,
390
+ query_states,
391
+ key_states,
392
+ value_states,
393
+ attention_mask=attention_mask,
394
+ dropout=0.0 if not self.training else self.attention_dropout,
395
+ scaling=self.scaling,
396
+ sliding_window=self.sliding_window,
397
+ **kwargs,
398
+ )
399
+
400
+ output = output.view(*input_shape, -1).contiguous()
401
+ output = output * F.sigmoid(gate_states)
402
+ return self.o_proj(output)
403
+
404
+
405
+ class AfmoeDecoderLayer(GradientCheckpointingLayer):
406
+ def __init__(self, config: AfmoeConfig, layer_idx: int):
407
+ super().__init__()
408
+ self.hidden_size = config.hidden_size
409
+ self.layer_idx = layer_idx
410
+
411
+ self.self_attn = AfmoeAttention(config=config, layer_idx=layer_idx)
412
+ self.attention_type = config.layer_types[layer_idx]
413
+
414
+ # Dual normalization for attention
415
+ self.input_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
416
+ self.post_attention_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
417
+
418
+ # Dual normalization for FFN
419
+ self.pre_mlp_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
420
+ self.post_mlp_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
421
+
422
+ # MoE or dense FFN
423
+ self.moe_enabled = layer_idx >= config.num_dense_layers
424
+ if self.moe_enabled:
425
+ self.mlp = AfmoeMoE(config)
426
+ else:
427
+ self.mlp = AfmoeMLP(config)
428
+
429
+ def forward(
430
+ self,
431
+ hidden_states: torch.Tensor,
432
+ attention_mask: Optional[torch.Tensor] = None,
433
+ position_ids: Optional[torch.LongTensor] = None,
434
+ past_key_value: Optional[Cache] = None,
435
+ use_cache: Optional[bool] = None,
436
+ cache_position: Optional[torch.LongTensor] = None,
437
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
438
+ **kwargs: Unpack[TransformersKwargs],
439
+ ) -> torch.FloatTensor:
440
+ residual = hidden_states
441
+
442
+ # Self Attention with dual normalization
443
+ hidden_states = self.input_layernorm(hidden_states)
444
+ hidden_states = self.self_attn(
445
+ hidden_states=hidden_states,
446
+ attention_mask=attention_mask,
447
+ position_ids=position_ids,
448
+ past_key_value=past_key_value,
449
+ use_cache=use_cache,
450
+ cache_position=cache_position,
451
+ position_embeddings=position_embeddings,
452
+ **kwargs,
453
+ )
454
+ hidden_states = self.post_attention_layernorm(hidden_states)
455
+ hidden_states = residual + hidden_states
456
+
457
+ # FFN with dual normalization
458
+ residual = hidden_states
459
+ hidden_states = self.pre_mlp_layernorm(hidden_states)
460
+
461
+ if self.moe_enabled:
462
+ hidden_states = self.mlp(hidden_states)
463
+ else:
464
+ hidden_states = self.mlp(hidden_states)
465
+
466
+ hidden_states = self.post_mlp_layernorm(hidden_states)
467
+ hidden_states = residual + hidden_states
468
+ return hidden_states
469
+
470
+
471
+ class AfmoePreTrainedModel(PreTrainedModel):
472
+ config_class = AfmoeConfig
473
+ base_model_prefix = "model"
474
+ _no_split_modules = ["AfmoeDecoderLayer"]
475
+ _skip_keys_device_placement = ["past_key_values"]
476
+ _keep_in_fp32_modules = [
477
+ "input_layernorm",
478
+ "post_attention_layernorm",
479
+ "pre_mlp_layernorm",
480
+ "post_mlp_layernorm",
481
+ "q_norm",
482
+ "k_norm",
483
+ "norm",
484
+ ]
485
+ _supports_sdpa = True
486
+ _supports_attention_backend = True
487
+ supports_gradient_checkpointing = True
488
+
489
+
490
+ class AfmoeModel(AfmoePreTrainedModel):
491
+ _no_split_modules = ["AfmoeDecoderLayer"]
492
+
493
+ def __init__(self, config: AfmoeConfig):
494
+ super().__init__(config)
495
+ self.padding_idx = config.pad_token_id
496
+ self.vocab_size = config.vocab_size
497
+
498
+ self.embed_tokens = nn.Embedding(
499
+ config.vocab_size, config.hidden_size, self.padding_idx
500
+ )
501
+ self.layers = nn.ModuleList(
502
+ [
503
+ AfmoeDecoderLayer(config, layer_idx)
504
+ for layer_idx in range(config.num_hidden_layers)
505
+ ]
506
+ )
507
+ self.norm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
508
+ self.rotary_emb = AfmoeRotaryEmbedding(config=config)
509
+ self.gradient_checkpointing = False
510
+
511
+ self.post_init()
512
+
513
+ def get_input_embeddings(self):
514
+ return self.embed_tokens
515
+
516
+ def set_input_embeddings(self, value):
517
+ self.embed_tokens = value
518
+
519
+
520
+ def forward(
521
+ self,
522
+ input_ids: torch.LongTensor,
523
+ attention_mask: Optional[torch.Tensor] = None,
524
+ position_ids: Optional[torch.LongTensor] = None,
525
+ past_key_values: Optional[list[torch.FloatTensor]] = None,
526
+ inputs_embeds: Optional[torch.FloatTensor] = None,
527
+ use_cache: Optional[bool] = None,
528
+ cache_position: Optional[torch.LongTensor] = None,
529
+ **kwargs: Unpack[TransformersKwargs],
530
+ ) -> MoeModelOutputWithPast:
531
+ if (input_ids is None) ^ (inputs_embeds is not None):
532
+ raise ValueError(
533
+ "You must specify exactly one of input_ids or inputs_embeds"
534
+ )
535
+
536
+ if use_cache and past_key_values is None:
537
+ past_key_values = DynamicCache()
538
+
539
+ if inputs_embeds is None:
540
+ inputs_embeds = self.embed_tokens(input_ids)
541
+
542
+ if cache_position is None:
543
+ past_seen_tokens = (
544
+ past_key_values.get_seq_length() if past_key_values is not None else 0
545
+ )
546
+ cache_position = torch.arange(
547
+ past_seen_tokens,
548
+ past_seen_tokens + inputs_embeds.shape[1],
549
+ device=inputs_embeds.device,
550
+ )
551
+ if position_ids is None:
552
+ position_ids = cache_position.unsqueeze(0)
553
+
554
+ # It may already have been prepared by e.g. `generate`
555
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
556
+ mask_kwargs = {
557
+ "config": self.config,
558
+ "input_embeds": inputs_embeds,
559
+ "attention_mask": attention_mask,
560
+ "cache_position": cache_position,
561
+ "past_key_values": past_key_values,
562
+ }
563
+ causal_mask_mapping = {
564
+ "full_attention": create_causal_mask(**mask_kwargs),
565
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
566
+ }
567
+
568
+ hidden_states = inputs_embeds
569
+
570
+ # Apply muP input scaling if enabled
571
+ if self.config.mup_enabled:
572
+ hidden_states = hidden_states * (self.config.hidden_size**0.5)
573
+
574
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
575
+
576
+ for decoder_layer in self.layers:
577
+ hidden_states = decoder_layer(
578
+ hidden_states,
579
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
580
+ position_ids=position_ids,
581
+ past_key_value=past_key_values,
582
+ use_cache=use_cache,
583
+ cache_position=cache_position,
584
+ position_embeddings=position_embeddings,
585
+ **kwargs,
586
+ )
587
+
588
+ hidden_states = self.norm(hidden_states)
589
+ return MoeModelOutputWithPast(
590
+ last_hidden_state=hidden_states,
591
+ past_key_values=past_key_values,
592
+ )
593
+
594
+
595
+ class AfmoeForCausalLM(AfmoePreTrainedModel, GenerationMixin):
596
+ _tied_weights_keys = ["lm_head.weight"]
597
+ _tp_plan = {"lm_head": "colwise_rep"}
598
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
599
+
600
+ def __init__(self, config):
601
+ super().__init__(config)
602
+ self.model = AfmoeModel(config)
603
+ self.vocab_size = config.vocab_size
604
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
605
+
606
+ # Initialize weights and apply final processing
607
+ self.post_init()
608
+
609
+ def get_input_embeddings(self):
610
+ return self.model.embed_tokens
611
+
612
+ def set_input_embeddings(self, value):
613
+ self.model.embed_tokens = value
614
+
615
+ def get_output_embeddings(self):
616
+ return self.lm_head
617
+
618
+ def set_output_embeddings(self, new_embeddings):
619
+ self.lm_head = new_embeddings
620
+
621
+ def set_decoder(self, decoder):
622
+ self.model = decoder
623
+
624
+ def get_decoder(self):
625
+ return self.model
626
+
627
+ def forward(
628
+ self,
629
+ input_ids: torch.LongTensor,
630
+ attention_mask: Optional[torch.Tensor] = None,
631
+ position_ids: Optional[torch.LongTensor] = None,
632
+ past_key_values: Optional[Cache] = None,
633
+ inputs_embeds: Optional[torch.FloatTensor] = None,
634
+ labels: Optional[torch.LongTensor] = None,
635
+ use_cache: Optional[bool] = None,
636
+ cache_position: Optional[torch.LongTensor] = None,
637
+ logits_to_keep: Union[int, torch.Tensor] = 0,
638
+ token_type_ids: Optional[torch.Tensor] = None, # will be ignored
639
+ **kwargs: Unpack[TransformersKwargs],
640
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
641
+ outputs: MoeModelOutputWithPast = self.model(
642
+ input_ids=input_ids,
643
+ attention_mask=attention_mask,
644
+ position_ids=position_ids,
645
+ past_key_values=past_key_values,
646
+ inputs_embeds=inputs_embeds,
647
+ use_cache=use_cache,
648
+ cache_position=cache_position,
649
+ **kwargs,
650
+ )
651
+
652
+ hidden_states = outputs.last_hidden_state
653
+ # Only compute necessary logits
654
+ slice_indices = (
655
+ slice(-logits_to_keep, None)
656
+ if isinstance(logits_to_keep, int)
657
+ else logits_to_keep
658
+ )
659
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
660
+
661
+ loss = None
662
+ if labels is not None:
663
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
664
+
665
+
666
+ return MoeCausalLMOutputWithPast(
667
+ loss=loss,
668
+ logits=logits,
669
+ past_key_values=outputs.past_key_values,
670
+ hidden_states=outputs.hidden_states,
671
+ attentions=outputs.attentions,
672
+ router_logits=outputs.router_logits,
673
+ )
674
+
675
+
676
+ __all__ = [
677
+ "AfmoeForCausalLM",
678
+ "AfmoeModel",
679
+ "AfmoePreTrainedModel",
680
+ ]
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|begin_of_text|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|im_end|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|pad|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5a93d847b4d3a1da95e9527c30ec10144f63a823e9feec98570274980754897
3
+ size 14614487
tokenizer_config.json ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": null,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<|begin_of_text|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<|end_of_text|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "<|im_start|>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ },
30
+ "3": {
31
+ "content": "<|im_end|>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": true
37
+ },
38
+ "4": {
39
+ "content": "<name>",
40
+ "lstrip": false,
41
+ "normalized": false,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "5": {
47
+ "content": "</name>",
48
+ "lstrip": false,
49
+ "normalized": false,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "6": {
55
+ "content": "<description>",
56
+ "lstrip": false,
57
+ "normalized": false,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "7": {
63
+ "content": "</description>",
64
+ "lstrip": false,
65
+ "normalized": false,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "8": {
71
+ "content": "<parameters>",
72
+ "lstrip": false,
73
+ "normalized": false,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "9": {
79
+ "content": "</parameters>",
80
+ "lstrip": false,
81
+ "normalized": false,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "10": {
87
+ "content": "<type>",
88
+ "lstrip": false,
89
+ "normalized": false,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "11": {
95
+ "content": "</type>",
96
+ "lstrip": false,
97
+ "normalized": false,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "12": {
103
+ "content": "<|pad|>",
104
+ "lstrip": false,
105
+ "normalized": false,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": true
109
+ },
110
+ "13": {
111
+ "content": "<think>",
112
+ "lstrip": false,
113
+ "normalized": false,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "14": {
119
+ "content": "</think>",
120
+ "lstrip": false,
121
+ "normalized": false,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "15": {
127
+ "content": "<tools>",
128
+ "lstrip": false,
129
+ "normalized": false,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "16": {
135
+ "content": "</tools>",
136
+ "lstrip": false,
137
+ "normalized": false,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "17": {
143
+ "content": "<tool_call>",
144
+ "lstrip": false,
145
+ "normalized": false,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "18": {
151
+ "content": "</tool_call>",
152
+ "lstrip": false,
153
+ "normalized": false,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "19": {
159
+ "content": "<tool_response>",
160
+ "lstrip": false,
161
+ "normalized": false,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "20": {
167
+ "content": "</tool_response>",
168
+ "lstrip": false,
169
+ "normalized": false,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "21": {
175
+ "content": "<properties>",
176
+ "lstrip": false,
177
+ "normalized": false,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "22": {
183
+ "content": "</properties>",
184
+ "lstrip": false,
185
+ "normalized": false,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "23": {
191
+ "content": "<required>",
192
+ "lstrip": false,
193
+ "normalized": false,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "24": {
199
+ "content": "</required>",
200
+ "lstrip": false,
201
+ "normalized": false,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ },
206
+ "25": {
207
+ "content": "<parameter>",
208
+ "lstrip": false,
209
+ "normalized": false,
210
+ "rstrip": false,
211
+ "single_word": false,
212
+ "special": false
213
+ },
214
+ "26": {
215
+ "content": "</parameter>",
216
+ "lstrip": false,
217
+ "normalized": false,
218
+ "rstrip": false,
219
+ "single_word": false,
220
+ "special": false
221
+ },
222
+ "27": {
223
+ "content": "<function>",
224
+ "lstrip": false,
225
+ "normalized": false,
226
+ "rstrip": false,
227
+ "single_word": false,
228
+ "special": false
229
+ },
230
+ "28": {
231
+ "content": "</function>",
232
+ "lstrip": false,
233
+ "normalized": false,
234
+ "rstrip": false,
235
+ "single_word": false,
236
+ "special": false
237
+ },
238
+ "29": {
239
+ "content": "<function=",
240
+ "lstrip": false,
241
+ "normalized": false,
242
+ "rstrip": false,
243
+ "single_word": false,
244
+ "special": false
245
+ },
246
+ "30": {
247
+ "content": "<parameter=",
248
+ "lstrip": false,
249
+ "normalized": false,
250
+ "rstrip": false,
251
+ "single_word": false,
252
+ "special": false
253
+ },
254
+ "31": {
255
+ "content": "<|reserved_special_18|>",
256
+ "lstrip": false,
257
+ "normalized": false,
258
+ "rstrip": false,
259
+ "single_word": false,
260
+ "special": true
261
+ }
262
+ },
263
+ "bos_token": "<|begin_of_text|>",
264
+ "clean_up_tokenization_spaces": false,
265
+ "eos_token": "<|im_end|>",
266
+ "extra_special_tokens": {},
267
+ "model_max_length": 65536,
268
+ "pad_token": "<|pad|>",
269
+ "tokenizer_class": "PreTrainedTokenizerFast",
270
+ "use_default_system_prompt": false
271
+ }