Commit ·
d5ccf82
0
Parent(s):
Super-squash branch 'main' using huggingface_hub
Browse filesCo-authored-by: MaziyarPanahi <MaziyarPanahi@users.noreply.huggingface.co>
Co-authored-by: Crystalcareai <Crystalcareai@users.noreply.huggingface.co>
- .gitattributes +36 -0
- README.md +311 -0
- __init__.py +4 -0
- chat_template.jinja +159 -0
- config.json +108 -0
- configuration_afmoe.py +133 -0
- generation_config.json +9 -0
- model-00001-of-00031.safetensors +3 -0
- model-00002-of-00031.safetensors +3 -0
- model-00003-of-00031.safetensors +3 -0
- model-00004-of-00031.safetensors +3 -0
- model-00005-of-00031.safetensors +3 -0
- model-00006-of-00031.safetensors +3 -0
- model-00007-of-00031.safetensors +3 -0
- model-00008-of-00031.safetensors +3 -0
- model-00009-of-00031.safetensors +3 -0
- model-00010-of-00031.safetensors +3 -0
- model-00011-of-00031.safetensors +3 -0
- model-00012-of-00031.safetensors +3 -0
- model-00013-of-00031.safetensors +3 -0
- model-00014-of-00031.safetensors +3 -0
- model-00015-of-00031.safetensors +3 -0
- model-00016-of-00031.safetensors +3 -0
- model-00017-of-00031.safetensors +3 -0
- model-00018-of-00031.safetensors +3 -0
- model-00019-of-00031.safetensors +3 -0
- model-00020-of-00031.safetensors +3 -0
- model-00021-of-00031.safetensors +3 -0
- model-00022-of-00031.safetensors +3 -0
- model-00023-of-00031.safetensors +3 -0
- model-00024-of-00031.safetensors +3 -0
- model-00025-of-00031.safetensors +3 -0
- model-00026-of-00031.safetensors +3 -0
- model-00027-of-00031.safetensors +3 -0
- model-00028-of-00031.safetensors +3 -0
- model-00029-of-00031.safetensors +3 -0
- model-00030-of-00031.safetensors +3 -0
- model-00031-of-00031.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_afmoe.py +680 -0
- special_tokens_map.json +23 -0
- tokenizer.json +3 -0
- tokenizer_config.json +271 -0
.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
- es
|
| 6 |
+
- fr
|
| 7 |
+
- de
|
| 8 |
+
- it
|
| 9 |
+
- pt
|
| 10 |
+
- ru
|
| 11 |
+
- ar
|
| 12 |
+
- hi
|
| 13 |
+
- ko
|
| 14 |
+
- zh
|
| 15 |
+
library_name: transformers
|
| 16 |
+
base_model:
|
| 17 |
+
- arcee-ai/Trinity-Large-Base
|
| 18 |
+
arxiv:
|
| 19 |
+
- 2602.17004
|
| 20 |
+
tags:
|
| 21 |
+
- reasoning
|
| 22 |
+
- agentic
|
| 23 |
+
- tool-calling
|
| 24 |
+
- thinking
|
| 25 |
+
---
|
| 26 |
+
<!-- markdownlint-disable first-line-h1 -->
|
| 27 |
+
<!-- markdownlint-disable html -->
|
| 28 |
+
<!-- markdownlint-disable no-duplicate-header -->
|
| 29 |
+
|
| 30 |
+
<div align="center">
|
| 31 |
+
<picture>
|
| 32 |
+
<img
|
| 33 |
+
src="https://cdn-uploads.huggingface.co/production/uploads/6435718aaaef013d1aec3b8b/i-v1KyAMOW_mgVGeic9WJ.png"
|
| 34 |
+
alt="Arcee Trinity Large Thinking"
|
| 35 |
+
style="max-width: 100%; height: auto;"
|
| 36 |
+
>
|
| 37 |
+
</picture>
|
| 38 |
+
</div>
|
| 39 |
+
<hr>
|
| 40 |
+
|
| 41 |
+
# Trinity-Large-Thinking
|
| 42 |
+
|
| 43 |
+
## Introduction
|
| 44 |
+
|
| 45 |
+
Trinity-Large-Thinking is a reasoning-optimized variant of Arcee AI's Trinity-Large family — a 398B-parameter sparse Mixture-of-Experts (MoE) model with approximately 13B active parameters per token. Built on Trinity-Large-Base and post-trained with extended chain-of-thought reasoning and agentic RL, Trinity-Large-Thinking delivers state-of-the-art performance on agentic benchmarks while maintaining strong general capabilities.
|
| 46 |
+
|
| 47 |
+
Trinity-Large-Thinking generates explicit reasoning traces wrapped in `<think>...</think>` blocks before producing its final response. This thinking process is critical to the model's performance — **thinking tokens must be kept in context** for multi-turn conversations and agentic loops to function correctly.
|
| 48 |
+
|
| 49 |
+
Try it at [chat.arcee.ai](https://chat.arcee.ai/)
|
| 50 |
+
|
| 51 |
+
More details on the training of Trinity Large are available in the [technical report](https://arxiv.org/abs/2602.17004).
|
| 52 |
+
|
| 53 |
+
## Key Highlights
|
| 54 |
+
|
| 55 |
+
- **Agentic-first design**: Purpose-built for tool calling, multi-step planning, and agent workflows
|
| 56 |
+
- **State-of-the-art agentic performance**: 94.7% on τ²-Bench (Telecom), 91.9% on PinchBench, 98.2% on LiveCodeBench
|
| 57 |
+
- **Native reasoning traces**: Extended chain-of-thought via `<think>...</think>` blocks
|
| 58 |
+
- **Compatible with major agent frameworks**: Works out of the box with [OpenClaw](https://github.com/openclaw) and [Hermes Agent](https://github.com/NousResearch/hermes-agent)
|
| 59 |
+
- **Ready to use on [OpenRouter](https://openrouter.ai/)**: No setup required — full reasoning and tool calling support via API
|
| 60 |
+
|
| 61 |
+
## Model Variants
|
| 62 |
+
|
| 63 |
+
The Trinity Large family consists of four checkpoints:
|
| 64 |
+
|
| 65 |
+
- **Trinity-Large-Thinking** (this release): Reasoning-optimized, agentic post-training with extended chain-of-thought
|
| 66 |
+
- **[Trinity-Large-Preview](https://huggingface.co/arcee-ai/Trinity-Large-Preview)**: Lightly post-trained, chat-ready instruct model (no `reasoning_content`)
|
| 67 |
+
- **[Trinity-Large-TrueBase](https://huggingface.co/arcee-ai/Trinity-Large-TrueBase)**: 10T-token pre-anneal pretraining checkpoint
|
| 68 |
+
- **[Trinity-Large-Base](https://huggingface.co/arcee-ai/Trinity-Large-Base)**: Full 17T-token pretrained foundation model with mid-training anneals
|
| 69 |
+
|
| 70 |
+
## Architecture
|
| 71 |
+
|
| 72 |
+
Trinity-Large-Thinking shares the same sparse MoE architecture as Trinity-Large-Preview.
|
| 73 |
+
|
| 74 |
+
| Hyperparameter | Value |
|
| 75 |
+
|:---|:---:|
|
| 76 |
+
| Total parameters | ~398B |
|
| 77 |
+
| Active parameters per token | ~13B |
|
| 78 |
+
| Experts | 256 (1 shared) |
|
| 79 |
+
| Active experts | 4 |
|
| 80 |
+
| Routing strategy | 4-of-256 (1.56% sparsity) |
|
| 81 |
+
| Dense layers | 6 |
|
| 82 |
+
| Pretraining context length | 8,192 |
|
| 83 |
+
| Context length after extension | 512k |
|
| 84 |
+
| Architecture | Sparse MoE (AfmoeForCausalLM) |
|
| 85 |
+
|
| 86 |
+
## Benchmarks
|
| 87 |
+
|
| 88 |
+
| Benchmark | Trinity-Large-Thinking | Opus-4.6 | GLM-5 | MiniMax-M2.7 | Kimi-K2.5 |
|
| 89 |
+
|---|---:|---:|---:|---:|---:|
|
| 90 |
+
| IFBench | 52.3 | 53.1 | 72.3 | **75.7** | 70.2 |
|
| 91 |
+
| GPQA-Diamond | 76.3 | **89.2** | 81.6 | 86.2 | 86.9 |
|
| 92 |
+
| Tau2-Airline | **88.0** | 82.0 | 80.5 | 80.0 | 80.0 |
|
| 93 |
+
| Tau2-Telecom | 94.7 | 92.1 | **98.2** | 84.8 | 95.9 |
|
| 94 |
+
| PinchBench | 91.9 | **93.3** | 86.4 | 89.8 | 84.8 |
|
| 95 |
+
| AIME25 | 96.3 | **99.8** | 93.3 | 80.0 | 96.3 |
|
| 96 |
+
| BCFLv4 | 70.1 | **77.0** | 70.8 | 70.6 | 68.3 |
|
| 97 |
+
| MMLU-Pro | 83.4 | **89.1** | 85.8 | 80.8 | 87.1 |
|
| 98 |
+
| SWE-bench Verified* | 63.2 | **75.6** | 72.8 | 75.4 | 70.8 |
|
| 99 |
+
|
| 100 |
+
\*All models were evaluated in mini-swe-agent-v2.
|
| 101 |
+
|
| 102 |
+
## Thinking-in-Context: Important Usage Note
|
| 103 |
+
|
| 104 |
+
Trinity-Large-Thinking produces reasoning traces inside `<think>...</think>` blocks before generating its final response.
|
| 105 |
+
|
| 106 |
+
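Because the model relies on its prior reasoning, these thinking tokens must be kept in context. In practice, this means: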
|
| 107 |
+
|
| 108 |
+
1. **Multi-turn conversations**: When building chat applications, include the full assistant response (thinking + answer) in the conversation history for subsequent turns.
|
| 109 |
+
2. **Agentic loops**: When using Trinity-Large-Thinking as the backbone of an agent (OpenClaw, Hermes Agent, or custom), ensure your tool-calling loop preserves `<think>` blocks in the message history between steps.
|
| 110 |
+
3. **Context window management**: The 512k extended context window accommodates long reasoning chains across many agentic steps. If you must truncate history, prefer removing older turns entirely rather than stripping thinking tokens from recent turns.
|
| 111 |
+
|
| 112 |
+
### How thinking works
|
| 113 |
+
|
| 114 |
+
The model reasons internally before producing its response. When served via vLLM, the reasoning is separated into a dedicated `reasoning_content` field in the API response:
|
| 115 |
+
|
| 116 |
+
// API response structure
|
| 117 |
+
{
|
| 118 |
+
"message": {
|
| 119 |
+
"role": "assistant",
|
| 120 |
+
"reasoning_content": "The user wants flight information. I need to determine the date for next Tuesday, search for flights SFO → JFK, and filter by price < $300.",
|
| 121 |
+
"content": "\n",
|
| 122 |
+
"tool_calls": [{
|
| 123 |
+
"function": {
|
| 124 |
+
"name": "search_flights",
|
| 125 |
+
"arguments": "{\"origin\": \"SFO\", \"destination\": \"JFK\", \"date\": \"2026-04-07\", \"max_price\": 300}"
|
| 126 |
+
}
|
| 127 |
+
}]
|
| 128 |
+
}
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
When building multi-turn agentic loops, include the `reasoning_content` back in the conversation history (re-wrapped in `<think>...</think>` tags within the assistant message) so the model retains its prior reasoning chain.
|
| 132 |
+
|
| 133 |
+
## Training Configuration
|
| 134 |
+
|
| 135 |
+
### Pretraining
|
| 136 |
+
|
| 137 |
+
- Training tokens: 17 trillion
|
| 138 |
+
- Data partner: [Datology](https://www.datologyai.com/)
|
| 139 |
+
|
| 140 |
+
### Post-training
|
| 141 |
+
|
| 142 |
+
- Instruction tuning and agentic RL with extended chain-of-thought
|
| 143 |
+
- Trained on tool-calling trajectories, multi-step agent tasks, and reasoning chains
|
| 144 |
+
|
| 145 |
+
### Infrastructure
|
| 146 |
+
|
| 147 |
+
- Hardware: 2,048 NVIDIA B300 GPUs
|
| 148 |
+
- Parallelism: HSDP + Expert Parallelism
|
| 149 |
+
- Compute partner: [Prime Intellect](https://www.primeintellect.ai/)
|
| 150 |
+
|
| 151 |
+
## Usage
|
| 152 |
+
|
| 153 |
+
### Running our model
|
| 154 |
+
|
| 155 |
+
- [vLLM](#vllm) (recommended for agentic deployments)
|
| 156 |
+
- [Transformers](#transformers)
|
| 157 |
+
- [API](#api)
|
| 158 |
+
|
| 159 |
+
### vLLM
|
| 160 |
+
|
| 161 |
+
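Supported in vLLM 0.11.1+. For agentic use with both reasoning and tool calling (note: vLLM deprecated `--enable-reasoning` in favor of `--reasoning-parser` alone, so if your vLLM version rejects that flag, simply omit it):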
|
| 162 |
+
|
| 163 |
+
vllm serve arcee-ai/Trinity-Large-Thinking \
|
| 164 |
+
--dtype bfloat16 \
|
| 165 |
+
--enable-reasoning \
|
| 166 |
+
--reasoning-parser deepseek_r1 \
|
| 167 |
+
--enable-auto-tool-choice \
|
| 168 |
+
--tool-call-parser qwen3_coder
|
| 169 |
+
|
| 170 |
+
This configuration:
|
| 171 |
+
- `--reasoning-parser deepseek_r1` — Parses `<think>...</think>` reasoning blocks and exposes them via the `reasoning_content` field in the API response
|
| 172 |
+
- `--tool-call-parser qwen3_coder` — Parses structured tool calls from the model output into the OpenAI-compatible `tool_calls` array
|
| 173 |
+
|
| 174 |
+
**Extracting reasoning content from the API response:**
|
| 175 |
+
|
| 176 |
+
```python
|
| 177 |
+
from openai import OpenAI
|
| 178 |
+
|
| 179 |
+
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
|
| 180 |
+
|
| 181 |
+
response = client.chat.completions.create(
|
| 182 |
+
model="arcee-ai/Trinity-Large-Thinking",
|
| 183 |
+
messages=[
|
| 184 |
+
{"role": "user", "content": "What's the weather like in Paris?"}
|
| 185 |
+
],
|
| 186 |
+
tools=[ # your tool definitions here
|
| 187 |
+
{
|
| 188 |
+
"type": "function",
|
| 189 |
+
"function": {
|
| 190 |
+
"name": "get_weather",
|
| 191 |
+
"description": "Get current weather for a location",
|
| 192 |
+
"parameters": {
|
| 193 |
+
"type": "object",
|
| 194 |
+
"properties": {
|
| 195 |
+
"location": {"type": "string"}
|
| 196 |
+
},
|
| 197 |
+
"required": ["location"]
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
}
|
| 201 |
+
],
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
# Access reasoning (thinking) content
|
| 205 |
+
reasoning = response.choices[0].message.reasoning_content
|
| 206 |
+
|
| 207 |
+
# Access final response or tool calls
|
| 208 |
+
content = response.choices[0].message.content
|
| 209 |
+
tool_calls = response.choices[0].message.tool_calls
|
| 210 |
+
```
|
| 211 |
+
|
| 212 |
+
**Note on thinking-in-context with vLLM**: When building multi-turn agentic loops, include both `reasoning_content` and `content` in the conversation history you send back to the model. The reasoning content should be re-wrapped in `<think>...</think>` tags within the assistant message.
|
| 213 |
+
|
| 214 |
+
### Transformers
|
| 215 |
+
|
| 216 |
+
Use the `main` transformers branch or pass `trust_remote_code=True` with a released version.
|
| 217 |
+
|
| 218 |
+
```python
|
| 219 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 220 |
+
import torch
|
| 221 |
+
|
| 222 |
+
model_id = "arcee-ai/Trinity-Large-Thinking"
|
| 223 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 224 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 225 |
+
model_id,
|
| 226 |
+
torch_dtype=torch.bfloat16,
|
| 227 |
+
device_map="auto",
|
| 228 |
+
trust_remote_code=True
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
messages = [
|
| 232 |
+
{"role": "user", "content": "Who are you?"},
|
| 233 |
+
]
|
| 234 |
+
|
| 235 |
+
input_ids = tokenizer.apply_chat_template(
|
| 236 |
+
messages,
|
| 237 |
+
add_generation_prompt=True,
|
| 238 |
+
return_tensors="pt"
|
| 239 |
+
).to(model.device)
|
| 240 |
+
|
| 241 |
+
outputs = model.generate(
|
| 242 |
+
input_ids,
|
| 243 |
+
max_new_tokens=4096,
|
| 244 |
+
do_sample=True,
|
| 245 |
+
temperature=0.6,
|
| 246 |
+
top_k=50,
|
| 247 |
+
top_p=0.95
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 251 |
+
print(response)
|
| 252 |
+
```
|
| 253 |
+
|
| 254 |
+
### API
|
| 255 |
+
|
| 256 |
+
Available on OpenRouter:
|
| 257 |
+
|
| 258 |
+
curl -X POST "https://openrouter.ai/api/v1/chat/completions" \
|
| 259 |
+
-H "Authorization: Bearer $OPENROUTER_API_KEY" \
|
| 260 |
+
-H "Content-Type: application/json" \
|
| 261 |
+
-d '{
|
| 262 |
+
"model": "arcee-ai/trinity-large-thinking",
|
| 263 |
+
"messages": [
|
| 264 |
+
{
|
| 265 |
+
"role": "user",
|
| 266 |
+
"content": "What are some fun things to do in New York?"
|
| 267 |
+
}
|
| 268 |
+
]
|
| 269 |
+
}'
|
| 270 |
+
|
| 271 |
+
## Agentic Use Cases
|
| 272 |
+
|
| 273 |
+
Trinity-Large-Thinking is optimized for deployment as the reasoning backbone of AI agent systems. It has been evaluated and performs excellently with:
|
| 274 |
+
|
| 275 |
+
### OpenClaw
|
| 276 |
+
|
| 277 |
+
Trinity-Large-Thinking works as a drop-in brain for OpenClaw agents. Its native tool-calling format is compatible with OpenClaw's execution loop, and the extended reasoning enables reliable multi-step task completion — from email triage to code generation to meeting scheduling. Our 91.9% PinchBench score reflects real-world OpenClaw task performance.
|
| 278 |
+
|
| 279 |
+
### Hermes Agent
|
| 280 |
+
|
| 281 |
+
Compatible with the Hermes Agent framework from Nous Research. Trinity-Large-Thinking's reasoning traces pair naturally with Hermes's skill-learning loop — the model's explicit chain-of-thought makes skill extraction more reliable, and its strong tool-calling capabilities integrate directly via the Hermes tool-use protocol.
|
| 282 |
+
|
| 283 |
+
### Custom Agent Loops
|
| 284 |
+
|
| 285 |
+
For custom implementations, the key integration pattern is:
|
| 286 |
+
|
| 287 |
+
1. Send the user message with tool definitions
|
| 288 |
+
2. Receive the response with `<think>` reasoning + tool calls
|
| 289 |
+
3. Execute the tool calls
|
| 290 |
+
4. Append the **full** assistant response (thinking + content + tool calls) and tool results to the message history
|
| 291 |
+
5. Send the updated history back for the next step
|
| 292 |
+
6. Repeat until the model produces a final response without tool calls
|
| 293 |
+
|
| 294 |
+
## License
|
| 295 |
+
|
| 296 |
+
Trinity-Large-Thinking is released under the Apache License, Version 2.0.
|
| 297 |
+
|
| 298 |
+
## Citation
|
| 299 |
+
|
| 300 |
+
If you use this model, please cite:
|
| 301 |
+
|
| 302 |
+
@misc{singh2026arceetrinity,
|
| 303 |
+
title = {Arcee Trinity Large Technical Report},
|
| 304 |
+
author = {Varun Singh and Lucas Krauss and Sami Jaghouar and Matej Sirovatka and Charles Goddard and Fares Obied and Jack Min Ong and Jannik Straube and Fern and Aria Harley and Conner Stewart and Colin Kealty and Maziyar Panahi and Simon Kirsten and Anushka Deshpande and Anneketh Vij and Arthur Bresnu and Pranav Veldurthi and Raghav Ravishankar and Hardik Bishnoi and DatologyAI Team and Arcee AI Team and Prime Intellect Team and Mark McQuade and Johannes Hagemann and Lucas Atkins},
|
| 305 |
+
year = {2026},
|
| 306 |
+
eprint = {2602.17004},
|
| 307 |
+
archivePrefix= {arXiv},
|
| 308 |
+
primaryClass = {cs.LG},
|
| 309 |
+
doi = {10.48550/arXiv.2602.17004},
|
| 310 |
+
url = {https://arxiv.org/abs/2602.17004}
|
| 311 |
+
}
|
__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .configuration_afmoe import AfmoeConfig
|
| 2 |
+
from .modeling_afmoe import AfmoeForCausalLM
|
| 3 |
+
|
| 4 |
+
__all__ = ["AfmoeConfig", "AfmoeForCausalLM"]
|
chat_template.jinja
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<|begin_of_text|>{%- macro render_extra_keys(json_dict, handled_keys) -%}
|
| 2 |
+
{%- if json_dict is mapping %}
|
| 3 |
+
{%- for json_key in json_dict if json_key not in handled_keys %}
|
| 4 |
+
{%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %}
|
| 5 |
+
{{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }}
|
| 6 |
+
{%- else %}
|
| 7 |
+
{{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}
|
| 8 |
+
{%- endif %}
|
| 9 |
+
{%- endfor %}
|
| 10 |
+
{%- endif %}
|
| 11 |
+
{%- endmacro -%}
|
| 12 |
+
|
| 13 |
+
{%- macro render_tool_call(raw_tool_call) -%}
|
| 14 |
+
{%- if raw_tool_call.function is defined and raw_tool_call.function is mapping %}
|
| 15 |
+
{%- set tool_call = raw_tool_call.function %}
|
| 16 |
+
{%- else %}
|
| 17 |
+
{%- set tool_call = raw_tool_call %}
|
| 18 |
+
{%- endif %}
|
| 19 |
+
{{- '<tool_call>\n<function=' + (tool_call.name | default('') | string) + '>\n' }}
|
| 20 |
+
{%- if tool_call.arguments is defined and tool_call.arguments is mapping %}
|
| 21 |
+
{%- for args_name, args_value in tool_call.arguments.items() %}
|
| 22 |
+
{{- '<parameter=' + (args_name | string) + '>\n' }}
|
| 23 |
+
{%- if args_value is mapping or (args_value is sequence and args_value is not string) %}
|
| 24 |
+
{{- args_value | tojson | safe }}
|
| 25 |
+
{%- else %}
|
| 26 |
+
{{- args_value | string }}
|
| 27 |
+
{%- endif %}
|
| 28 |
+
{{- '\n</parameter>\n' }}
|
| 29 |
+
{%- endfor %}
|
| 30 |
+
{%- endif %}
|
| 31 |
+
{{- '</function>\n</tool_call>' }}
|
| 32 |
+
{%- endmacro -%}
|
| 33 |
+
|
| 34 |
+
{%- set system_message = none %}
|
| 35 |
+
{%- if messages and messages[0]["role"] == "system" %}
|
| 36 |
+
{%- set system_message = messages[0]["content"] %}
|
| 37 |
+
{%- set loop_messages = messages[1:] %}
|
| 38 |
+
{%- else %}
|
| 39 |
+
{%- set loop_messages = messages %}
|
| 40 |
+
{%- endif %}
|
| 41 |
+
|
| 42 |
+
{%- if not tools is defined %}
|
| 43 |
+
{%- set tools = [] %}
|
| 44 |
+
{%- endif %}
|
| 45 |
+
{%- set has_tools = tools is iterable and tools is not string and tools | length > 0 %}
|
| 46 |
+
|
| 47 |
+
{%- if system_message is not none or has_tools %}
|
| 48 |
+
{{- '<|im_start|>system\n' }}
|
| 49 |
+
{%- if system_message is not none %}
|
| 50 |
+
{{- system_message }}
|
| 51 |
+
{%- else %}
|
| 52 |
+
{{- "You are Trinity Large, a helpful assistant developed by Arcee AI, that can interact with a computer to solve tasks." }}
|
| 53 |
+
{%- endif %}
|
| 54 |
+
{%- if has_tools %}
|
| 55 |
+
{{- "\n\n# Tools\n\nYou have access to the following functions:\n\n<tools>" }}
|
| 56 |
+
{%- for tool in tools %}
|
| 57 |
+
{%- if tool.function is defined and tool.function is mapping %}
|
| 58 |
+
{%- set tool = tool.function %}
|
| 59 |
+
{%- endif %}
|
| 60 |
+
{{- '\n<function>\n<name>' ~ (tool.name | default('') | string) ~ '</name>' }}
|
| 61 |
+
{%- if tool.description is defined and tool.description is not none %}
|
| 62 |
+
{{- '\n<description>' ~ (tool.description | string | trim) ~ '</description>' }}
|
| 63 |
+
{%- endif %}
|
| 64 |
+
{{- '\n<parameters>' }}
|
| 65 |
+
{%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}
|
| 66 |
+
{%- for param_name, param_fields in tool.parameters.properties.items() %}
|
| 67 |
+
{{- '\n<parameter>\n<name>' ~ (param_name | string) ~ '</name>' }}
|
| 68 |
+
{%- if param_fields is mapping and param_fields.type is defined and param_fields.type is not none %}
|
| 69 |
+
{{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
|
| 70 |
+
{%- endif %}
|
| 71 |
+
{%- if param_fields is mapping and param_fields.description is defined and param_fields.description is not none %}
|
| 72 |
+
{{- '\n<description>' ~ (param_fields.description | string | trim) ~ '</description>' }}
|
| 73 |
+
{%- endif %}
|
| 74 |
+
{%- if param_fields is mapping %}
|
| 75 |
+
{%- set handled_keys = ['name', 'type', 'description'] %}
|
| 76 |
+
{{- render_extra_keys(param_fields, handled_keys) }}
|
| 77 |
+
{%- endif %}
|
| 78 |
+
{{- '\n</parameter>' }}
|
| 79 |
+
{%- endfor %}
|
| 80 |
+
{%- endif %}
|
| 81 |
+
{%- if tool.parameters is defined %}
|
| 82 |
+
{%- set handled_keys = ['type', 'properties'] %}
|
| 83 |
+
{{- render_extra_keys(tool.parameters, handled_keys) }}
|
| 84 |
+
{%- endif %}
|
| 85 |
+
{{- '\n</parameters>' }}
|
| 86 |
+
{%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}
|
| 87 |
+
{{- render_extra_keys(tool, handled_keys) }}
|
| 88 |
+
{{- '\n</function>' }}
|
| 89 |
+
{%- endfor %}
|
| 90 |
+
{{- "\n</tools>" }}
|
| 91 |
+
{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
|
| 92 |
+
{%- endif %}
|
| 93 |
+
{{- '<|im_end|>\n' }}
|
| 94 |
+
{%- endif %}
|
| 95 |
+
|
| 96 |
+
{%- for message in loop_messages %}
|
| 97 |
+
{%- set role = message.role | default('') %}
|
| 98 |
+
{%- if role == "assistant" %}
|
| 99 |
+
{%- set content_str = '' if message.content is none else (message.content | string) %}
|
| 100 |
+
{%- set trimmed_content = content_str | trim %}
|
| 101 |
+
|
| 102 |
+
{%- set has_reasoning_content = message.reasoning_content is defined %}
|
| 103 |
+
{%- set has_reasoning = has_reasoning_content or (message.reasoning is defined) %}
|
| 104 |
+
|
| 105 |
+
{%- if has_reasoning_content %}
|
| 106 |
+
{%- set reasoning_value = message.reasoning_content %}
|
| 107 |
+
{%- elif message.reasoning is defined %}
|
| 108 |
+
{%- set reasoning_value = message.reasoning %}
|
| 109 |
+
{%- else %}
|
| 110 |
+
{%- set reasoning_value = none %}
|
| 111 |
+
{%- endif %}
|
| 112 |
+
|
| 113 |
+
{%- set has_tool_calls = message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls is not string and message.tool_calls | length > 0 %}
|
| 114 |
+
|
| 115 |
+
{{- '<|im_start|>assistant\n' }}
|
| 116 |
+
{%- if has_reasoning %}
|
| 117 |
+
{%- if reasoning_value %}
|
| 118 |
+
{{- '<think>' + (reasoning_value | string | trim) + '</think>' }}
|
| 119 |
+
{%- else %}
|
| 120 |
+
{{- '<think></think>' }}
|
| 121 |
+
{%- endif %}
|
| 122 |
+
{%- if trimmed_content %}
|
| 123 |
+
{{- '\n' + trimmed_content }}
|
| 124 |
+
{%- endif %}
|
| 125 |
+
{%- elif has_tool_calls %}
|
| 126 |
+
{%- if trimmed_content %}
|
| 127 |
+
{{- trimmed_content }}
|
| 128 |
+
{%- endif %}
|
| 129 |
+
{%- else %}
|
| 130 |
+
{{- content_str }}
|
| 131 |
+
{%- endif %}
|
| 132 |
+
|
| 133 |
+
{%- if has_tool_calls %}
|
| 134 |
+
{%- for tool_call in message.tool_calls %}
|
| 135 |
+
{%- set separator = '\n' if ((loop.first and (has_reasoning or trimmed_content)) or (not loop.first)) else '' -%}
|
| 136 |
+
{{- separator + render_tool_call(tool_call) }}
|
| 137 |
+
{%- endfor %}
|
| 138 |
+
{%- endif %}
|
| 139 |
+
{{- '<|im_end|>\n' }}
|
| 140 |
+
{%- elif role == "tool" or role == "observation" or role == "function" %}
|
| 141 |
+
{%- if loop.first or loop.previtem.role not in ["tool", "observation", "function"] %}
|
| 142 |
+
{{- '<|im_start|>user\n' }}
|
| 143 |
+
{%- endif %}
|
| 144 |
+
{{- '<tool_response>\n' }}
|
| 145 |
+
{{- '' if message.content is none else (message.content | string) }}
|
| 146 |
+
{{- '\n</tool_response>\n' }}
|
| 147 |
+
{%- if loop.last or loop.nextitem.role not in ["tool", "observation", "function"] %}
|
| 148 |
+
{{- '<|im_end|>\n' }}
|
| 149 |
+
{%- endif %}
|
| 150 |
+
{%- else %}
|
| 151 |
+
{{- '<|im_start|>' + (role | string) }}
|
| 152 |
+
{{- '\n' + ('' if message.content is none else (message.content | string)) }}
|
| 153 |
+
{{- '<|im_end|>\n' }}
|
| 154 |
+
{%- endif %}
|
| 155 |
+
{%- endfor %}
|
| 156 |
+
|
| 157 |
+
{%- if add_generation_prompt %}
|
| 158 |
+
{{- '<|im_start|>assistant\n<think>' }}
|
| 159 |
+
{%- endif %}
|
config.json
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"AfmoeForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_dropout": 0.0,
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_afmoe.AfmoeConfig",
|
| 8 |
+
"AutoModel": "modeling_afmoe.AfmoeModel",
|
| 9 |
+
"AutoModelForCausalLM": "modeling_afmoe.AfmoeForCausalLM"
|
| 10 |
+
},
|
| 11 |
+
"dtype": "bfloat16",
|
| 12 |
+
"global_attn_every_n_layers": 4,
|
| 13 |
+
"head_dim": 128,
|
| 14 |
+
"hidden_act": "silu",
|
| 15 |
+
"hidden_size": 3072,
|
| 16 |
+
"initializer_range": 0.02,
|
| 17 |
+
"intermediate_size": 12288,
|
| 18 |
+
"layer_types": [
|
| 19 |
+
"sliding_attention",
|
| 20 |
+
"sliding_attention",
|
| 21 |
+
"sliding_attention",
|
| 22 |
+
"full_attention",
|
| 23 |
+
"sliding_attention",
|
| 24 |
+
"sliding_attention",
|
| 25 |
+
"sliding_attention",
|
| 26 |
+
"full_attention",
|
| 27 |
+
"sliding_attention",
|
| 28 |
+
"sliding_attention",
|
| 29 |
+
"sliding_attention",
|
| 30 |
+
"full_attention",
|
| 31 |
+
"sliding_attention",
|
| 32 |
+
"sliding_attention",
|
| 33 |
+
"sliding_attention",
|
| 34 |
+
"full_attention",
|
| 35 |
+
"sliding_attention",
|
| 36 |
+
"sliding_attention",
|
| 37 |
+
"sliding_attention",
|
| 38 |
+
"full_attention",
|
| 39 |
+
"sliding_attention",
|
| 40 |
+
"sliding_attention",
|
| 41 |
+
"sliding_attention",
|
| 42 |
+
"full_attention",
|
| 43 |
+
"sliding_attention",
|
| 44 |
+
"sliding_attention",
|
| 45 |
+
"sliding_attention",
|
| 46 |
+
"full_attention",
|
| 47 |
+
"sliding_attention",
|
| 48 |
+
"sliding_attention",
|
| 49 |
+
"sliding_attention",
|
| 50 |
+
"full_attention",
|
| 51 |
+
"sliding_attention",
|
| 52 |
+
"sliding_attention",
|
| 53 |
+
"sliding_attention",
|
| 54 |
+
"full_attention",
|
| 55 |
+
"sliding_attention",
|
| 56 |
+
"sliding_attention",
|
| 57 |
+
"sliding_attention",
|
| 58 |
+
"full_attention",
|
| 59 |
+
"sliding_attention",
|
| 60 |
+
"sliding_attention",
|
| 61 |
+
"sliding_attention",
|
| 62 |
+
"full_attention",
|
| 63 |
+
"sliding_attention",
|
| 64 |
+
"sliding_attention",
|
| 65 |
+
"sliding_attention",
|
| 66 |
+
"full_attention",
|
| 67 |
+
"sliding_attention",
|
| 68 |
+
"sliding_attention",
|
| 69 |
+
"sliding_attention",
|
| 70 |
+
"full_attention",
|
| 71 |
+
"sliding_attention",
|
| 72 |
+
"sliding_attention",
|
| 73 |
+
"sliding_attention",
|
| 74 |
+
"full_attention",
|
| 75 |
+
"sliding_attention",
|
| 76 |
+
"sliding_attention",
|
| 77 |
+
"sliding_attention",
|
| 78 |
+
"full_attention"
|
| 79 |
+
],
|
| 80 |
+
"load_balance_coeff": 0.00005,
|
| 81 |
+
"max_position_embeddings": 262144,
|
| 82 |
+
"model_type": "afmoe",
|
| 83 |
+
"moe_intermediate_size": 3072,
|
| 84 |
+
"mup_enabled": true,
|
| 85 |
+
"n_group": 1,
|
| 86 |
+
"num_attention_heads": 48,
|
| 87 |
+
"num_dense_layers": 6,
|
| 88 |
+
"num_expert_groups": 1,
|
| 89 |
+
"num_experts": 256,
|
| 90 |
+
"num_experts_per_tok": 4,
|
| 91 |
+
"num_hidden_layers": 60,
|
| 92 |
+
"num_key_value_heads": 8,
|
| 93 |
+
"num_limited_groups": 1,
|
| 94 |
+
"num_shared_experts": 1,
|
| 95 |
+
"rms_norm_eps": 1e-05,
|
| 96 |
+
"rope_scaling": null,
|
| 97 |
+
"rope_theta": 10000,
|
| 98 |
+
"route_norm": true,
|
| 99 |
+
"route_scale": 2.448,
|
| 100 |
+
"score_func": "sigmoid",
|
| 101 |
+
"sliding_window": 4096,
|
| 102 |
+
"tie_word_embeddings": false,
|
| 103 |
+
"topk_group": 1,
|
| 104 |
+
"transformers_version": "4.57.1",
|
| 105 |
+
"use_cache": true,
|
| 106 |
+
"use_grouped_mm": true,
|
| 107 |
+
"vocab_size": 200192
|
| 108 |
+
}
|
configuration_afmoe.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 16 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
| 17 |
+
from transformers.configuration_utils import layer_type_validation
|
| 18 |
+
from transformers.utils import logging
|
| 19 |
+
|
| 20 |
+
logger = logging.get_logger(__name__)
|
| 21 |
+
|
| 22 |
+
class AfmoeConfig(PretrainedConfig):
    """
    Configuration for AFMoE models (`model_type = "afmoe"`).

    Every constructor argument is stored on the instance under the same name. `tie_word_embeddings` and any
    extra keyword arguments are forwarded to `PretrainedConfig.__init__`.

    Args:
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of layers. Also the length of the `layer_types` list derived when `layer_types` is `None`.
        vocab_size (`int`, *optional*, defaults to 200192):
            Vocabulary size.
        hidden_size (`int`, *optional*, defaults to 2048):
            Width of the hidden states (input/output width of the MLP and router projections).
        intermediate_size (`int`, *optional*, defaults to 6144):
            Default intermediate width of `AfmoeMLP` when no explicit width is given.
        moe_intermediate_size (`int`, *optional*, defaults to 1408):
            Intermediate width of each routed expert; the shared-expert MLP uses this value multiplied by
            `num_shared_experts`.
        num_dense_layers (`int`, *optional*, defaults to 1):
            Number of dense (non-MoE) layers. NOTE(review): which layers are dense is decided by modeling code
            outside this class -- confirm there.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention (query) heads.
        num_key_value_heads (`int`, *optional*):
            Number of key/value heads. Falls back to `num_attention_heads` when `None`.
        head_dim (`int`, *optional*, defaults to 128):
            Per-head dimension.
        hidden_act (`str`, *optional*, defaults to `"silu"`):
            Key looked up in `ACT2FN` to pick the MLP activation.
        max_position_embeddings (`int`, *optional*, defaults to 16384):
            Maximum sequence length; used as the initial cached length of the rotary embedding.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Stored as-is. Presumably the std used for weight initialisation -- confirm against the model code.
        rms_norm_eps (`float`, *optional*, defaults to 1e-5):
            Epsilon intended for the RMS normalisation layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Stored as-is; conventionally toggles key/value caching.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Forwarded to `PretrainedConfig`.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            RoPE base period.
        rope_scaling (`dict`, *optional*):
            RoPE scaling configuration. If it contains the legacy key `"type"`, the value is copied to
            `"rope_type"` (the dict is modified in place) before validation.
        num_experts (`int`, *optional*, defaults to 64):
            Number of routed experts.
        num_experts_per_tok (`int`, *optional*, defaults to 6):
            Number of routed experts selected per token (router top-k).
        num_shared_experts (`int`, *optional*, defaults to 2):
            Number of shared experts; `0` disables the shared-expert MLP.
        num_expert_groups (`int`, *optional*, defaults to 1):
            Stored as-is. NOTE(review): not consumed by the code visible alongside this class.
        num_limited_groups (`int`, *optional*, defaults to 1):
            Stored as-is. NOTE(review): not consumed by the code visible alongside this class.
        score_func (`str`, *optional*, defaults to `"sigmoid"`):
            Router scoring function. `"sigmoid"` selects sigmoid scoring; any other value selects softmax.
        route_norm (`bool`, *optional*, defaults to `True`):
            Whether the selected sigmoid scores are renormalised to sum to one per token.
        route_scale (`float`, *optional*, defaults to 1.0):
            Multiplier applied to the selected routing scores.
        global_attn_every_n_layers (`int`, *optional*, defaults to 4):
            When `layer_types` is `None`, every n-th layer (1-based) is `"full_attention"` and all others are
            `"sliding_attention"`. Must be non-zero in that case (it is used as a modulus).
        sliding_window (`int`, *optional*, defaults to 1024):
            Window size intended for the sliding-attention layers.
        mup_enabled (`bool`, *optional*, defaults to `False`):
            Stored as-is. NOTE(review): muP behaviour is implemented outside this class -- confirm there.
        layer_types (`list[str]`, *optional*):
            Explicit per-layer attention types. Derived from `global_attn_every_n_layers` when `None`, and
            always passed through `layer_type_validation`.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability intended for the attention weights.
        n_group (`int`, *optional*, defaults to 1):
            Number of groups for routed experts.
        topk_group (`int`, *optional*, defaults to 1):
            Number of selected groups for each token (for each token, ensuring the selected experts are only
            within `topk_group` groups).
    """

    model_type = "afmoe"
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        num_hidden_layers: int = 32,
        vocab_size: int = 200192,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        moe_intermediate_size=1408,
        num_dense_layers=1,
        num_attention_heads=16,
        num_key_value_heads=None,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=16384,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        num_experts=64,
        num_experts_per_tok=6,
        num_shared_experts=2,
        num_expert_groups=1,
        num_limited_groups=1,
        score_func="sigmoid",
        route_norm=True,
        route_scale=1.0,
        global_attn_every_n_layers=4,
        sliding_window=1024,
        mup_enabled=False,
        layer_types=None,
        attention_dropout: float = 0.0,
        n_group: int = 1,
        topk_group: int = 1,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_dense_layers = num_dense_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # MoE specific
        self.moe_intermediate_size = moe_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts = num_experts
        self.num_shared_experts = num_shared_experts
        self.num_expert_groups = num_expert_groups
        self.num_limited_groups = num_limited_groups
        self.score_func = score_func
        self.route_norm = route_norm
        self.route_scale = route_scale

        # Attention specific
        self.attention_dropout = attention_dropout
        self.global_attn_every_n_layers = global_attn_every_n_layers
        self.sliding_window = sliding_window
        self.layer_types = layer_types
        if self.layer_types is None:
            # Layers whose 1-based index is a multiple of `global_attn_every_n_layers` use full attention;
            # every other layer uses sliding-window attention.
            self.layer_types = [
                "sliding_attention" if bool((i + 1) % global_attn_every_n_layers) else "full_attention"
                for i in range(self.num_hidden_layers)
            ]
        layer_type_validation(self.layer_types)

        # muP specific
        self.mup_enabled = mup_enabled

        # Default to one key/value head per attention head when no explicit count is given.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads

        # Validate rope configs
        # BC: accept the legacy "type" key by mirroring it to "rope_type" before validation.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
__all__ = ["AfmoeConfig"]
|
generation_config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 0,
|
| 4 |
+
"eos_token_id": 3,
|
| 5 |
+
"pad_token_id": 12,
|
| 6 |
+
"transformers_version": "4.57.3",
|
| 7 |
+
"temperature": 0.8,
|
| 8 |
+
"top_p": 0.8
|
| 9 |
+
}
|
model-00001-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:853c70c5c46ebef7fff4e96c2eda7b0a31f8415eab7e62f984aa4180aaed2ab9
|
| 3 |
+
size 2459965736
|
model-00002-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:33a07dd8ba9a4065c05de7e5ca2b3fb117b8847e5e926a73dac81f77bcf04493
|
| 3 |
+
size 704696408
|
model-00003-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8c4d625226b66f0f4560b2299b339bb1cab0faab7f269be57526d549fc6287f2
|
| 3 |
+
size 704696408
|
model-00004-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2c47e8a110787e454a6ee52ffead8903ffd9a42804a6bfd65448a8864fcb949e
|
| 3 |
+
size 704696408
|
model-00005-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:290a75fc94da91c359c345e8d760d342e38a0c3b8bdfc138bfd5ff0a6e395bbd
|
| 3 |
+
size 29359329168
|
model-00006-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:29e9b301281696ff948a03c6cf1570fb221daeb0efa4e7793bcb024ef4e5c7f1
|
| 3 |
+
size 29359329168
|
model-00007-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f9e5f8039d277c892b5901bf39b8ec195652fad2ce561142660e70a2118c1de3
|
| 3 |
+
size 29359330736
|
model-00008-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3b3cf1e2b37d0f37efc818753f7a6b5a3a6a34d5d40799f7861d6ae7f2f50cdb
|
| 3 |
+
size 29359330736
|
model-00009-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d5619671ed44cdb8e8e0e014825b0166364c5fda1cdb16e9df8bc30f369592d4
|
| 3 |
+
size 29359330736
|
model-00010-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:60d3ca7fbd2f5c0e9fdc947a22d88ce92e28a6b5a870a68618eab5160673955c
|
| 3 |
+
size 29359330736
|
model-00011-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:604d743dc7de64413191bf367d9724d042b5a760f073e480fe8de07f4cbbe875
|
| 3 |
+
size 29359330736
|
model-00012-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:07bad9551c1a538461a5e47c5660950bfcd1a06fb5763664c10dc295024d186f
|
| 3 |
+
size 29359330736
|
model-00013-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d7347d5299a1af7e96c705d75082c3d2f6912f07019281cc472565040f4ac401
|
| 3 |
+
size 29359330736
|
model-00014-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dcf043b318b569fff0a1bdfa4d4868170cd08b2c0138e48516ef87b4eadb9d2b
|
| 3 |
+
size 29359330736
|
model-00015-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:513c803339e049cd378333a0463d42517b4dc9614d14cf3ad030cc769f4caab0
|
| 3 |
+
size 29359330736
|
model-00016-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:55c3326f636d40910ab9f5d5e3e4295b695ab6b920fafb55833dde0c15fbe218
|
| 3 |
+
size 29359330736
|
model-00017-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ff887d5ed74cb31778e7ffca3c68994c12d58aca5498daf8f73ec11315bd871a
|
| 3 |
+
size 29359330736
|
model-00018-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2ad6ef2e4ce52c22a75a88e2e1701bc9c00381f5193759bd1db48e08865c93f8
|
| 3 |
+
size 29359330736
|
model-00019-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4860ab9597c2cd8b4239762045de948a4022e748eb19af6caf917b919c9635ac
|
| 3 |
+
size 29359330736
|
model-00020-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7bf756f75c191ddf476814e2c61d77640f6fa4d78880200f56e9fc6d8b31ca2f
|
| 3 |
+
size 29359330736
|
model-00021-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f05bb8b1018a485d0081465586ae362127e731b07f452c79c648f87cb2cddfaa
|
| 3 |
+
size 29359330736
|
model-00022-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7d498f7ba5144e18c0685058a077aeb169a56b4bdf9a5bfe8df38650f868c44c
|
| 3 |
+
size 29359330736
|
model-00023-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d7edaa4f5e19ce9941b404dd184a2f2c6c1f7b231b3fed9d92a57d442843ac4a
|
| 3 |
+
size 29359330736
|
model-00024-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:79ce79cf55200e28dd3dd1ba232aab986cae96619e12da880e8da08506e17dea
|
| 3 |
+
size 29359330736
|
model-00025-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4b101d2e8b1d76865e863fa3b8854839e18e3e98c194d4844e8d99a10d36b1a2
|
| 3 |
+
size 29359330736
|
model-00026-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ab10a3fee4a0c6f6da9e697d64b4cd63557018f5a398118d04b02b734b93de6c
|
| 3 |
+
size 29359330736
|
model-00027-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fee6d0d7d3abb8148c5870feab0949be5264e4d6b45a46c1f68e5b195448dcf3
|
| 3 |
+
size 29359330736
|
model-00028-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1c56fffdea546c522d14cb0afd038b6e33f04dffb479fdeef04c3fdaa03c4520
|
| 3 |
+
size 29359330736
|
model-00029-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:476eaed8aa9d5f86ffeac3778ead337513513b45be6cbfbf496074f2743af91f
|
| 3 |
+
size 29359330736
|
model-00030-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:90b37fb9b43e31963a0df04408be4c285e75fe2a4c8aa57a62097a6a2721aee4
|
| 3 |
+
size 29359330736
|
model-00031-of-00031.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c2fa868e56a5ffe5797cddc5db0931f2bb6fae313338935e4cea19f589eb88d1
|
| 3 |
+
size 29359330736
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modeling_afmoe.py
ADDED
|
@@ -0,0 +1,680 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
from transformers.activations import ACT2FN
|
| 8 |
+
from transformers.generation import GenerationMixin
|
| 9 |
+
from transformers.modeling_outputs import (
|
| 10 |
+
MoeCausalLMOutputWithPast,
|
| 11 |
+
MoeModelOutputWithPast,
|
| 12 |
+
)
|
| 13 |
+
from transformers.modeling_utils import PreTrainedModel, ALL_ATTENTION_FUNCTIONS
|
| 14 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 15 |
+
from transformers.masking_utils import (
|
| 16 |
+
create_causal_mask,
|
| 17 |
+
create_sliding_window_causal_mask,
|
| 18 |
+
)
|
| 19 |
+
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 20 |
+
from transformers.processing_utils import Unpack
|
| 21 |
+
from transformers.utils import TransformersKwargs
|
| 22 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 23 |
+
from transformers.integrations import use_kernel_forward_from_hub
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
try:
    from .configuration_afmoe import AfmoeConfig
except ImportError:
    # Fallback for when this file is executed/loaded as a top-level module rather than as part of a
    # package, in which case the relative import above raises ImportError. Only ImportError is caught so
    # that unrelated failures (and KeyboardInterrupt/SystemExit) are not silently swallowed.
    from configuration_afmoe import AfmoeConfig
|
| 30 |
+
|
| 31 |
+
class AfmoeRotaryEmbedding(nn.Module):
    """Computes the rotary position embedding (RoPE) cos/sin tables for a batch of position ids."""

    def __init__(self, config: AfmoeConfig, device=None):
        """
        Args:
            config: Model configuration. `rope_scaling` (if set) selects the RoPE variant and
                `max_position_embeddings` is used as the initial cached sequence length.
            device: Device passed to the RoPE init function when building `inv_freq`.
        """
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: the buffer follows the module across devices but is not saved in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so that dynamic RoPE variants can restore the initial frequencies later.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            # This .to() is needed if the model has been moved to a device after being initialized (because
            # the buffer is automatically moved, but not the original copy)
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """
        Build the cos/sin tables for `position_ids`.

        Args:
            x: Tensor used only for its device and dtype.
            position_ids: 2-D tensor of positions, indexed here as (batch, seq_len).

        Returns:
            Tuple `(cos, sin)`, each cast to `x.dtype`, whose last dimension is twice the length of `inv_freq`.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate the frequencies so the table spans both halves used by `rotate_half`.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    Splits the last dimension of `x` at its midpoint into `(x1, x2)` and returns `(-x2, x1)` concatenated
    along that same dimension.
    """
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    # Insert the broadcast (heads) axis so cos/sin line up with q and k.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)

    When `n_rep == 1` the input tensor itself is returned (no copy).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a repeat axis next to the head axis, broadcast it without copying, then fold it into the head axis.
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 139 |
+
|
| 140 |
+
@use_kernel_forward_from_hub("RMSNorm")
class AfmoeRMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learned per-channel scale and no bias."""

    def __init__(self, hidden_size: int, eps: float):
        """
        AfmoeRMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: Size of the last (normalised) dimension; the scale weight has this many entries.
            eps: Constant added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise `hidden_states` over its last dimension and scale by `self.weight`."""
        input_dtype = hidden_states.dtype
        # The statistics are computed in float32 and the result is cast back to the input dtype
        # before the learned scale is applied.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        """Show the weight shape and epsilon in the module's repr."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Plain (matmul + softmax) attention.

    Args:
        module: Calling attention module; `num_key_value_groups` and `training` are read from it.
        query: Query tensor; the last two dims are used as (query_len, head_dim).
        key: Key tensor with `num_key_value_heads` heads on dim 1 (see `repeat_kv`).
        value: Value tensor laid out like `key`.
        attention_mask: Optional 4-D additive mask; its last dim is cropped to the key length.
        scaling: Factor applied to the raw attention scores.
        dropout: Dropout probability on the attention weights (active only when `module.training`).
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        Tuple `(attn_output, attn_weights)`; `attn_output` has dims 1 and 2 transposed relative to `query`
        and is contiguous.
    """
    # Expand the key/value heads so that they match the number of query heads.
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # Softmax in float32, then cast back to the query dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
        query.dtype
    )
    attn_weights = nn.functional.dropout(
        attn_weights, p=dropout, training=module.training
    )
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class AfmoeMLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, all projections bias-free."""

    def __init__(self, config, intermediate_size=None):
        """
        Args:
            config: Model configuration; `hidden_size`, `intermediate_size` and `hidden_act` are read.
            intermediate_size: Width of the inner projection. Falls back to `config.intermediate_size`
                when it is `None` (or any other falsy value such as `0`).
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size or config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated MLP to `x` (last dimension `hidden_size`); the output has the same shape."""
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class AfmoeTokenChoiceRouter(nn.Module):
    """Token-choice top-K router for MoE routing."""

    def __init__(self, config):
        """
        Args:
            config: Model configuration; `num_experts_per_tok`, `num_experts`, `score_func`, `route_norm`,
                `route_scale` and `hidden_size` are read.
        """
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_experts
        self.score_func = config.score_func
        self.route_norm = config.route_norm
        self.route_scale = config.route_scale
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

    # `Optional[...]` rather than `X | None`: the annotation is evaluated at definition time and the PEP 604
    # form fails on Python < 3.10; `Optional` is also what the rest of this file uses.
    def forward(self, hidden_states, expert_bias: Optional[torch.Tensor]):
        """
        Pick the top-k experts for every token.

        Args:
            hidden_states: 3-D tensor whose last dimension is the hidden size; the first two dimensions are
                flattened into a single token axis.
            expert_bias: Optional per-expert bias added to the scores for expert *selection* only; the
                returned weights are taken from the unbiased scores.

        Returns:
            Tuple `(top_scores, selected_experts)`, both of shape `(num_tokens, top_k)`: the float32 routing
            weights and the indices of the chosen experts.
        """
        _, _, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        scores = self.gate(hidden_states)

        # Apply scoring function in float32 for stability
        if self.score_func == "sigmoid":
            scores = torch.sigmoid(scores.to(torch.float32))
        else:
            scores = F.softmax(scores.to(torch.float32), dim=-1)

        if expert_bias is not None:
            # The bias only influences which experts are picked, not the weights they receive.
            _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
            top_scores = scores.gather(dim=1, index=selected_experts)
        else:
            top_scores, selected_experts = torch.topk(scores, k=self.top_k, dim=1)

        # Normalize weights if using sigmoid
        if self.score_func == "sigmoid" and self.route_norm:
            # The tiny constant guards against division by zero.
            denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
            top_scores = top_scores / denominator

        top_scores = top_scores * self.route_scale
        return top_scores, selected_experts
|
| 245 |
+
|
| 246 |
+
class AfmoeMoE(nn.Module):
    """Mixture-of-experts feed-forward block.

    Combines a token-choice router, ``config.num_experts`` routed expert MLPs
    and, when ``config.num_shared_experts > 0``, one wider shared MLP whose
    output is added for every token.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.router = AfmoeTokenChoiceRouter(config)

        # The shared experts are fused into a single MLP whose width is the
        # per-expert width times the number of shared experts.
        if config.num_shared_experts > 0:
            shared_width = config.moe_intermediate_size * config.num_shared_experts
            self.shared_experts = AfmoeMLP(config, shared_width)
        else:
            self.shared_experts = None

        routed_experts = []
        for _ in range(config.num_experts):
            routed_experts.append(
                AfmoeMLP(config, intermediate_size=config.moe_intermediate_size)
            )
        self.experts = nn.ModuleList(routed_experts)

        # Frozen per-expert selection bias handed to the router on every call.
        self.expert_bias = nn.Parameter(
            torch.zeros(config.num_experts, dtype=torch.float32),
            requires_grad=False,
        )

    def forward(self, hidden_states):
        """Route each token through its selected experts and merge the results.

        Args:
            hidden_states: Tensor of shape ``(batch, seq_len, hidden_dim)``.

        Returns:
            Tensor of the same shape: shared-expert output (or zeros when no
            shared experts exist) plus the score-weighted routed outputs.
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape
        flat_tokens = hidden_states.view(-1, hidden_dim)

        # Routing decisions: one row per token, one column per selected expert.
        top_scores, selected_experts = self.router(hidden_states, self.expert_bias)

        if self.shared_experts is None:
            shared_output = torch.zeros_like(flat_tokens)
        else:
            shared_output = self.shared_experts(flat_tokens)

        # Stable sort of the flattened (token, slot) assignments groups all
        # slots that go to the same expert next to each other.
        flat_assignments = selected_experts.view(-1)
        slot_order = torch.argsort(flat_assignments, stable=True)
        sorted_weights = top_scores.view(-1)[slot_order]
        sorted_expert_ids = flat_assignments[slot_order]
        # Flattened slot index -> row of the originating token.
        # NOTE(review): assumes the router's top_k equals
        # config.num_experts_per_tok - confirm against the router's __init__.
        source_rows = slot_order // self.config.num_experts_per_tok

        gather_index = source_rows.unsqueeze(-1).expand(-1, hidden_dim)
        routed_input = torch.gather(flat_tokens, dim=0, index=gather_index)

        routed_output = torch.zeros_like(routed_input)
        for expert_id in range(self.config.num_experts):
            rows = sorted_expert_ids == expert_id
            if not rows.any():
                continue
            routed_output[rows] = self.experts[expert_id](routed_input[rows])

        # Apply the routing weights in float32, then return to the input dtype.
        weighted = routed_output.to(torch.float32) * sorted_weights.unsqueeze(-1)
        routed_output = weighted.to(hidden_states.dtype)

        # Accumulate every slot's contribution onto its token's row.
        combined = shared_output.scatter_add(
            dim=0, index=gather_index, src=routed_output
        )
        return combined.view(batch_size, seq_len, hidden_dim)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
class AfmoeAttention(nn.Module):
    """Multi-headed attention with local/global pattern and gating.

    Layers whose ``config.layer_types`` entry is ``"sliding_attention"`` are
    treated as local: they use ``config.sliding_window`` and are the only
    layers that apply rotary position embeddings. Queries and keys are
    RMS-normalised per head, and the attention output is multiplied by a
    sigmoid gate computed from the layer input before the output projection.
    """

    def __init__(self, config: AfmoeConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Fall back to the derived size when ``head_dim`` is missing *or* None;
        # a None value would otherwise break the scaling computation below.
        self.head_dim = getattr(config, "head_dim", None) or (
            config.hidden_size // config.num_attention_heads
        )
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
        self.sliding_window = config.sliding_window if self.is_local_attention else None

        self.q_proj = nn.Linear(
            config.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, config.hidden_size, bias=False
        )

        self.q_norm = AfmoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = AfmoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)

        # Produces the element-wise output gate (same width as the heads).
        self.gate_proj = nn.Linear(
            config.hidden_size, self.num_heads * self.head_dim, bias=False
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run gated self-attention over ``hidden_states``.

        Args:
            hidden_states: Layer input; the last dimension is the hidden size.
            position_embeddings: ``(cos, sin)`` pair, used on local layers only.
            attention_mask: Mask forwarded unchanged to the attention backend.
            past_key_value: Optional cache updated in place with this layer's
                keys and values.
            cache_position: Positions passed to the cache update.
            **kwargs: Extra arguments forwarded to the attention backend.

        Returns:
            The output-projected attention result, shaped like the input.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape)
        key_states = self.k_proj(hidden_states).view(hidden_shape)
        value_states = self.v_proj(hidden_states).view(hidden_shape)
        gate_states = self.gate_proj(hidden_states)

        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # Rotary embeddings are applied on sliding-window (local) layers only.
        if self.is_local_attention:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[
                self.config._attn_implementation
            ]

        output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        # ``reshape`` tolerates backends that hand back a non-contiguous tensor.
        output = output.reshape(*input_shape, -1).contiguous()
        output = output * torch.sigmoid(gate_states)
        return self.o_proj(output)
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class AfmoeDecoderLayer(GradientCheckpointingLayer):
    """One decoder block: gated attention followed by a dense or MoE FFN.

    Both sub-layers are wrapped in two RMSNorms each: one applied before the
    sub-layer and one applied to its output before the residual addition.
    Layers with ``layer_idx >= config.num_dense_layers`` use ``AfmoeMoE``;
    earlier layers use a dense ``AfmoeMLP``.
    """

    def __init__(self, config: AfmoeConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx

        self.self_attn = AfmoeAttention(config=config, layer_idx=layer_idx)
        # Read by the model to pick the matching mask for this layer.
        self.attention_type = config.layer_types[layer_idx]

        # Dual normalization for attention
        self.input_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Dual normalization for FFN
        self.pre_mlp_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_mlp_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # MoE or dense FFN
        self.moe_enabled = layer_idx >= config.num_dense_layers
        if self.moe_enabled:
            self.mlp = AfmoeMoE(config)
        else:
            self.mlp = AfmoeMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        """Apply attention and feed-forward sub-layers with residual connections.

        Args:
            hidden_states: Input activations for this layer.
            attention_mask: Mask forwarded to the attention module.
            position_ids: Forwarded to the attention module.
            past_key_value: Optional key/value cache forwarded to attention.
            use_cache: Forwarded to the attention module.
            cache_position: Forwarded to the attention module.
            position_embeddings: ``(cos, sin)`` pair forwarded to attention.
            **kwargs: Extra arguments forwarded to the attention module.

        Returns:
            The layer output, with the same shape as ``hidden_states``.
        """
        residual = hidden_states

        # Self Attention with dual normalization
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # FFN with dual normalization
        residual = hidden_states
        hidden_states = self.pre_mlp_layernorm(hidden_states)

        # ``self.mlp`` is either the MoE block or the dense MLP (chosen in
        # ``__init__``); both take the hidden states as their only argument.
        hidden_states = self.mlp(hidden_states)

        hidden_states = self.post_mlp_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
class AfmoePreTrainedModel(PreTrainedModel):
    """Common base class for the Afmoe models defined in this file.

    Holds only class-level settings; it defines no methods of its own.
    """

    config_class = AfmoeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    # Attention-backend capability flags.
    _supports_sdpa = True
    _supports_attention_backend = True

    _no_split_modules = ["AfmoeDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # Attribute names of the RMSNorm sub-modules used by the decoder layer,
    # the attention module and the final model norm.
    _keep_in_fp32_modules = [
        "input_layernorm",
        "post_attention_layernorm",
        "pre_mlp_layernorm",
        "post_mlp_layernorm",
        "q_norm",
        "k_norm",
        "norm",
    ]
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
class AfmoeModel(AfmoePreTrainedModel):
    """Decoder-only backbone: token embedding, decoder layers and a final norm."""

    _no_split_modules = ["AfmoeDecoderLayer"]

    def __init__(self, config: AfmoeConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                AfmoeDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = AfmoeRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        """Run the decoder stack.

        Exactly one of ``input_ids`` and ``inputs_embeds`` must be given.

        Args:
            input_ids: Token ids to embed.
            attention_mask: Either a regular mask, or an already-prepared dict
                with ``"full_attention"`` / ``"sliding_attention"`` entries.
            position_ids: Positions for the rotary embedding; derived from
                ``cache_position`` when omitted.
            past_key_values: Key/value cache; a ``DynamicCache`` is created
                when ``use_cache`` is truthy and none is supplied.
            inputs_embeds: Pre-computed embeddings used instead of ``input_ids``.
            use_cache: Whether to create a cache when none is supplied.
            cache_position: Absolute positions of the current tokens; derived
                from the cache length when omitted.
            **kwargs: Extra arguments forwarded to every decoder layer.

        Returns:
            ``MoeModelOutputWithPast`` carrying the final normalised hidden
            states and the (possibly updated) cache.

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are provided.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
            }
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        hidden_states = inputs_embeds

        # Apply muP input scaling if enabled
        if self.config.mup_enabled:
            hidden_states = hidden_states * (self.config.hidden_size**0.5)

        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                # Each layer gets the mask matching its own attention type.
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
class AfmoeForCausalLM(AfmoePreTrainedModel, GenerationMixin):
    """``AfmoeModel`` backbone with a linear language-modelling head."""

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = AfmoeModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the language-modelling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modelling head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the backbone model."""
        self.model = decoder

    def get_decoder(self):
        """Return the backbone model."""
        return self.model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        token_type_ids: Optional[torch.Tensor] = None,  # will be ignored
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
        """Compute logits (and optionally the loss) for the given inputs.

        Exactly one of ``input_ids`` and ``inputs_embeds`` must be given; the
        check itself is performed by the backbone.

        Args:
            input_ids: Token ids.
            attention_mask: Forwarded to the backbone.
            position_ids: Forwarded to the backbone.
            past_key_values: Key/value cache forwarded to the backbone.
            inputs_embeds: Pre-computed embeddings used instead of ``input_ids``.
            labels: When given, a loss is computed with ``self.loss_function``.
            use_cache: Forwarded to the backbone.
            cache_position: Forwarded to the backbone.
            logits_to_keep: If an int, logits are computed for the last
                ``logits_to_keep`` positions only (``0`` keeps all positions);
                otherwise it is used directly as the sequence index.
            token_type_ids: Accepted for interface compatibility and ignored.
            **kwargs: Forwarded to the backbone and to the loss function.

        Returns:
            ``MoeCausalLMOutputWithPast`` with ``loss`` (``None`` without
            labels), ``logits`` and the backbone's cache and auxiliary outputs.
        """
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int)
            else logits_to_keep
        )
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        return MoeCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
# Names exported by a star-import of this module.
__all__ = ["AfmoeForCausalLM", "AfmoeModel", "AfmoePreTrainedModel"]
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<|begin_of_text|>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "<|im_end|>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": {
|
| 17 |
+
"content": "<|pad|>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
}
|
| 23 |
+
}
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c5a93d847b4d3a1da95e9527c30ec10144f63a823e9feec98570274980754897
|
| 3 |
+
size 14614487
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"add_prefix_space": null,
|
| 5 |
+
"added_tokens_decoder": {
|
| 6 |
+
"0": {
|
| 7 |
+
"content": "<|begin_of_text|>",
|
| 8 |
+
"lstrip": false,
|
| 9 |
+
"normalized": false,
|
| 10 |
+
"rstrip": false,
|
| 11 |
+
"single_word": false,
|
| 12 |
+
"special": true
|
| 13 |
+
},
|
| 14 |
+
"1": {
|
| 15 |
+
"content": "<|end_of_text|>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": false,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false,
|
| 20 |
+
"special": true
|
| 21 |
+
},
|
| 22 |
+
"2": {
|
| 23 |
+
"content": "<|im_start|>",
|
| 24 |
+
"lstrip": false,
|
| 25 |
+
"normalized": false,
|
| 26 |
+
"rstrip": false,
|
| 27 |
+
"single_word": false,
|
| 28 |
+
"special": true
|
| 29 |
+
},
|
| 30 |
+
"3": {
|
| 31 |
+
"content": "<|im_end|>",
|
| 32 |
+
"lstrip": false,
|
| 33 |
+
"normalized": false,
|
| 34 |
+
"rstrip": false,
|
| 35 |
+
"single_word": false,
|
| 36 |
+
"special": true
|
| 37 |
+
},
|
| 38 |
+
"4": {
|
| 39 |
+
"content": "<name>",
|
| 40 |
+
"lstrip": false,
|
| 41 |
+
"normalized": false,
|
| 42 |
+
"rstrip": false,
|
| 43 |
+
"single_word": false,
|
| 44 |
+
"special": false
|
| 45 |
+
},
|
| 46 |
+
"5": {
|
| 47 |
+
"content": "</name>",
|
| 48 |
+
"lstrip": false,
|
| 49 |
+
"normalized": false,
|
| 50 |
+
"rstrip": false,
|
| 51 |
+
"single_word": false,
|
| 52 |
+
"special": false
|
| 53 |
+
},
|
| 54 |
+
"6": {
|
| 55 |
+
"content": "<description>",
|
| 56 |
+
"lstrip": false,
|
| 57 |
+
"normalized": false,
|
| 58 |
+
"rstrip": false,
|
| 59 |
+
"single_word": false,
|
| 60 |
+
"special": false
|
| 61 |
+
},
|
| 62 |
+
"7": {
|
| 63 |
+
"content": "</description>",
|
| 64 |
+
"lstrip": false,
|
| 65 |
+
"normalized": false,
|
| 66 |
+
"rstrip": false,
|
| 67 |
+
"single_word": false,
|
| 68 |
+
"special": false
|
| 69 |
+
},
|
| 70 |
+
"8": {
|
| 71 |
+
"content": "<parameters>",
|
| 72 |
+
"lstrip": false,
|
| 73 |
+
"normalized": false,
|
| 74 |
+
"rstrip": false,
|
| 75 |
+
"single_word": false,
|
| 76 |
+
"special": false
|
| 77 |
+
},
|
| 78 |
+
"9": {
|
| 79 |
+
"content": "</parameters>",
|
| 80 |
+
"lstrip": false,
|
| 81 |
+
"normalized": false,
|
| 82 |
+
"rstrip": false,
|
| 83 |
+
"single_word": false,
|
| 84 |
+
"special": false
|
| 85 |
+
},
|
| 86 |
+
"10": {
|
| 87 |
+
"content": "<type>",
|
| 88 |
+
"lstrip": false,
|
| 89 |
+
"normalized": false,
|
| 90 |
+
"rstrip": false,
|
| 91 |
+
"single_word": false,
|
| 92 |
+
"special": false
|
| 93 |
+
},
|
| 94 |
+
"11": {
|
| 95 |
+
"content": "</type>",
|
| 96 |
+
"lstrip": false,
|
| 97 |
+
"normalized": false,
|
| 98 |
+
"rstrip": false,
|
| 99 |
+
"single_word": false,
|
| 100 |
+
"special": false
|
| 101 |
+
},
|
| 102 |
+
"12": {
|
| 103 |
+
"content": "<|pad|>",
|
| 104 |
+
"lstrip": false,
|
| 105 |
+
"normalized": false,
|
| 106 |
+
"rstrip": false,
|
| 107 |
+
"single_word": false,
|
| 108 |
+
"special": true
|
| 109 |
+
},
|
| 110 |
+
"13": {
|
| 111 |
+
"content": "<think>",
|
| 112 |
+
"lstrip": false,
|
| 113 |
+
"normalized": false,
|
| 114 |
+
"rstrip": false,
|
| 115 |
+
"single_word": false,
|
| 116 |
+
"special": false
|
| 117 |
+
},
|
| 118 |
+
"14": {
|
| 119 |
+
"content": "</think>",
|
| 120 |
+
"lstrip": false,
|
| 121 |
+
"normalized": false,
|
| 122 |
+
"rstrip": false,
|
| 123 |
+
"single_word": false,
|
| 124 |
+
"special": false
|
| 125 |
+
},
|
| 126 |
+
"15": {
|
| 127 |
+
"content": "<tools>",
|
| 128 |
+
"lstrip": false,
|
| 129 |
+
"normalized": false,
|
| 130 |
+
"rstrip": false,
|
| 131 |
+
"single_word": false,
|
| 132 |
+
"special": false
|
| 133 |
+
},
|
| 134 |
+
"16": {
|
| 135 |
+
"content": "</tools>",
|
| 136 |
+
"lstrip": false,
|
| 137 |
+
"normalized": false,
|
| 138 |
+
"rstrip": false,
|
| 139 |
+
"single_word": false,
|
| 140 |
+
"special": false
|
| 141 |
+
},
|
| 142 |
+
"17": {
|
| 143 |
+
"content": "<tool_call>",
|
| 144 |
+
"lstrip": false,
|
| 145 |
+
"normalized": false,
|
| 146 |
+
"rstrip": false,
|
| 147 |
+
"single_word": false,
|
| 148 |
+
"special": false
|
| 149 |
+
},
|
| 150 |
+
"18": {
|
| 151 |
+
"content": "</tool_call>",
|
| 152 |
+
"lstrip": false,
|
| 153 |
+
"normalized": false,
|
| 154 |
+
"rstrip": false,
|
| 155 |
+
"single_word": false,
|
| 156 |
+
"special": false
|
| 157 |
+
},
|
| 158 |
+
"19": {
|
| 159 |
+
"content": "<tool_response>",
|
| 160 |
+
"lstrip": false,
|
| 161 |
+
"normalized": false,
|
| 162 |
+
"rstrip": false,
|
| 163 |
+
"single_word": false,
|
| 164 |
+
"special": false
|
| 165 |
+
},
|
| 166 |
+
"20": {
|
| 167 |
+
"content": "</tool_response>",
|
| 168 |
+
"lstrip": false,
|
| 169 |
+
"normalized": false,
|
| 170 |
+
"rstrip": false,
|
| 171 |
+
"single_word": false,
|
| 172 |
+
"special": false
|
| 173 |
+
},
|
| 174 |
+
"21": {
|
| 175 |
+
"content": "<properties>",
|
| 176 |
+
"lstrip": false,
|
| 177 |
+
"normalized": false,
|
| 178 |
+
"rstrip": false,
|
| 179 |
+
"single_word": false,
|
| 180 |
+
"special": false
|
| 181 |
+
},
|
| 182 |
+
"22": {
|
| 183 |
+
"content": "</properties>",
|
| 184 |
+
"lstrip": false,
|
| 185 |
+
"normalized": false,
|
| 186 |
+
"rstrip": false,
|
| 187 |
+
"single_word": false,
|
| 188 |
+
"special": false
|
| 189 |
+
},
|
| 190 |
+
"23": {
|
| 191 |
+
"content": "<required>",
|
| 192 |
+
"lstrip": false,
|
| 193 |
+
"normalized": false,
|
| 194 |
+
"rstrip": false,
|
| 195 |
+
"single_word": false,
|
| 196 |
+
"special": false
|
| 197 |
+
},
|
| 198 |
+
"24": {
|
| 199 |
+
"content": "</required>",
|
| 200 |
+
"lstrip": false,
|
| 201 |
+
"normalized": false,
|
| 202 |
+
"rstrip": false,
|
| 203 |
+
"single_word": false,
|
| 204 |
+
"special": false
|
| 205 |
+
},
|
| 206 |
+
"25": {
|
| 207 |
+
"content": "<parameter>",
|
| 208 |
+
"lstrip": false,
|
| 209 |
+
"normalized": false,
|
| 210 |
+
"rstrip": false,
|
| 211 |
+
"single_word": false,
|
| 212 |
+
"special": false
|
| 213 |
+
},
|
| 214 |
+
"26": {
|
| 215 |
+
"content": "</parameter>",
|
| 216 |
+
"lstrip": false,
|
| 217 |
+
"normalized": false,
|
| 218 |
+
"rstrip": false,
|
| 219 |
+
"single_word": false,
|
| 220 |
+
"special": false
|
| 221 |
+
},
|
| 222 |
+
"27": {
|
| 223 |
+
"content": "<function>",
|
| 224 |
+
"lstrip": false,
|
| 225 |
+
"normalized": false,
|
| 226 |
+
"rstrip": false,
|
| 227 |
+
"single_word": false,
|
| 228 |
+
"special": false
|
| 229 |
+
},
|
| 230 |
+
"28": {
|
| 231 |
+
"content": "</function>",
|
| 232 |
+
"lstrip": false,
|
| 233 |
+
"normalized": false,
|
| 234 |
+
"rstrip": false,
|
| 235 |
+
"single_word": false,
|
| 236 |
+
"special": false
|
| 237 |
+
},
|
| 238 |
+
"29": {
|
| 239 |
+
"content": "<function=",
|
| 240 |
+
"lstrip": false,
|
| 241 |
+
"normalized": false,
|
| 242 |
+
"rstrip": false,
|
| 243 |
+
"single_word": false,
|
| 244 |
+
"special": false
|
| 245 |
+
},
|
| 246 |
+
"30": {
|
| 247 |
+
"content": "<parameter=",
|
| 248 |
+
"lstrip": false,
|
| 249 |
+
"normalized": false,
|
| 250 |
+
"rstrip": false,
|
| 251 |
+
"single_word": false,
|
| 252 |
+
"special": false
|
| 253 |
+
},
|
| 254 |
+
"31": {
|
| 255 |
+
"content": "<|reserved_special_18|>",
|
| 256 |
+
"lstrip": false,
|
| 257 |
+
"normalized": false,
|
| 258 |
+
"rstrip": false,
|
| 259 |
+
"single_word": false,
|
| 260 |
+
"special": true
|
| 261 |
+
}
|
| 262 |
+
},
|
| 263 |
+
"bos_token": "<|begin_of_text|>",
|
| 264 |
+
"clean_up_tokenization_spaces": false,
|
| 265 |
+
"eos_token": "<|im_end|>",
|
| 266 |
+
"extra_special_tokens": {},
|
| 267 |
+
"model_max_length": 65536,
|
| 268 |
+
"pad_token": "<|pad|>",
|
| 269 |
+
"tokenizer_class": "PreTrainedTokenizerFast",
|
| 270 |
+
"use_default_system_prompt": false
|
| 271 |
+
}
|