xuebi
commited on
Commit
·
134ae7c
1
Parent(s):
9ba43d5
update configs
Browse files- .gitattributes +1 -0
- README.md +61 -78
- chat_template.jinja +1 -1
- config.json +113 -0
- configuration_minimax_m2.py +200 -0
- figures/bench_11.png +3 -0
- generation_config.json +9 -0
- modeling_minimax_m2.py +706 -0
.gitattributes
CHANGED
|
@@ -41,3 +41,4 @@ figures/bench_5.png filter=lfs diff=lfs merge=lfs -text
|
|
| 41 |
figures/bench_6.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
figures/bench_8.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
figures/rl_1.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 41 |
figures/bench_6.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
figures/bench_8.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
figures/rl_1.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
figures/bench_11.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
pipeline_tag: text-generation
|
| 3 |
license: other
|
| 4 |
license_name: modified-mit
|
| 5 |
-
license_link: https://github.com/MiniMax-AI/MiniMax-M2.
|
| 6 |
library_name: transformers
|
| 7 |
---
|
| 8 |
|
|
@@ -67,145 +67,101 @@ library_name: transformers
|
|
| 67 |
</a>
|
| 68 |
</div>
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
|
| 74 |
-
|
| 75 |
|
| 76 |
-
|
| 77 |
|
|
|
|
| 78 |
|
|
|
|
| 79 |
|
| 80 |
## Coding
|
| 81 |
-
In programming evaluations, MiniMax M2.5 saw substantial improvements compared to previous models, reaching SOTA levels. The performance of M2.5 in multilingual tasks is especially pronounced.
|
| 82 |
|
| 83 |
-
|
| 84 |
|
| 85 |
-
| Benchmark | M2.5 | M2.1 | Opus 4.5 | Opus 4.6 | Gemini 3 Pro | GPT-5.2 |
|
| 86 |
-
|---|---|---|---|---|---|---|
|
| 87 |
-
| SWE-Bench Verified | 80.2 | 74.0 | 80.9 | 80.8 | 78 | 80.0 |
|
| 88 |
-
| SWE-Bench Pro | 55.4 | 49.7 | 56.9 | 55.4 | 54.1 | 55.6 |
|
| 89 |
-
| Terminal Bench 2 | 51.7 | 47.9 | 53.4 | 55.1 | 54.0 | 54.0 |
|
| 90 |
-
| Multi-SWE-Bench | 51.3 | 47.2 | 50.0 | 50.3 | 42.7 | - |
|
| 91 |
-
| SWE-Bench Multilingual | 74.1 | 71.9 | 77.5 | 77.8 | 65.0 | 72.0 |
|
| 92 |
-
<p align="center">
|
| 93 |
-
<img width="100%" src="figures/bench_1.png">
|
| 94 |
-
</p>
|
| 95 |
<p align="center">
|
| 96 |
<img width="100%" src="figures/bench_2.png">
|
| 97 |
</p>
|
| 98 |
-
A significant improvement from previous generations is M2.5's ability to think and plan like an architect. The Spec-writing tendency of the model emerged during training: before writing any code, M2.5 actively decomposes and plans the features, structure, and UI design of the project from the perspective of an experience software architect.
|
| 99 |
|
| 100 |
-
|
|
|
|
|
|
|
| 101 |
|
| 102 |
To evaluate these capabilities, we also upgraded the VIBE benchmark to a more complex and challenging Pro version, significantly increasing task complexity, domain coverage, and evaluation accuracy. Overall, M2.5 performs on par with Opus 4.5.
|
| 103 |
|
| 104 |
-
| Benchmark | M2.5 | M2.1 | Opus 4.5 | Opus 4.6 | Gemini 3 Pro |
|
| 105 |
-
|---|---|---|---|---|---|
|
| 106 |
-
| VIBE-Pro (AVG) | 54.2 | 42.4 | 55.2 | 55.6 | 36.9 |
|
| 107 |
-
| Web Subset | 36.9 | 31.9 | 37.8 | 40.7 | 28.5 |
|
| 108 |
-
| Simulation Subset | 81.4 | 73.1 | 78.8 | 81.2 | 67.7 |
|
| 109 |
-
| Android Subset | 50.6 | 36.0 | 58.4 | 54.9 | 33.1 |
|
| 110 |
-
| iOS Subset | 47.9 | 28.8 | 45.7 | 45.7 | 18.1 |
|
| 111 |
-
<p align="center">
|
| 112 |
-
<img width="100%" src="figures/bench_3.png">
|
| 113 |
-
</p>
|
| 114 |
<p align="center">
|
| 115 |
<img width="100%" src="figures/bench_4.png">
|
| 116 |
</p>
|
| 117 |
|
| 118 |
-
|
|
|
|
|
|
|
| 119 |
|
| 120 |
## Search and Tool calling
|
| 121 |
|
| 122 |
-
| Benchmark | M2.5 | M2.1 | Opus 4.5 | Opus 4.6 | Gemini 3 Pro | GPT-5.2 |
|
| 123 |
-
|---|---|---|---|---|---|---|
|
| 124 |
-
| BrowseComp w/ ctx | 76.3 | 62.0 | 67.8 | 84.0 | 59.2 | 65.8 |
|
| 125 |
-
| Wide Search | 70.3 | 63.2 | 76.2 | 79.4 | 57.0 | - |
|
| 126 |
-
| RISE | 50.2 | 34.0 | 50.5 | 62.5 | 36.8 | 50.0 |
|
| 127 |
-
| BFCL multi-turn | 76.8 | 37.4 | 68.0 | 63.3 | 61.0 | - |
|
| 128 |
-
| Tau^2 Telecom | 97.8 | 87.0 | 98.2 | 99.3 | 98.0 | 98.7 |
|
| 129 |
-
|
| 130 |
-
<p align="center">
|
| 131 |
-
<img width="100%" src="figures/bench_5.png">
|
| 132 |
-
</p>
|
| 133 |
<p align="center">
|
| 134 |
<img width="100%" src="figures/bench_6.png">
|
| 135 |
</p>
|
| 136 |
|
| 137 |
-
Effective tool calling and search are prerequisites for a model's ability to autonomously handle complex tasks. In evaluations on benchmarks such as BrowseComp and Wide Search, M2.5 achieved industry-leading performance. At the same time, the model's generalization has also improved — M2.5 demonstrates more stable performance when facing unfamiliar scaffolding environments.
|
| 138 |
|
| 139 |
In research tasks performed by professional human experts, using a search engine is only a small part of the process; most of the work involves deep exploration across information-dense webpages. To address this, we built RISE (Realistic Interactive Search Evaluation) to measure a model's search capabilities on real-world professional tasks. The results show that M2.5 excels at expert-level search tasks in real-world settings.
|
| 140 |
|
| 141 |
-
Compared to its predecessors, M2.5 also demonstrates much better decision-making when handling agentic tasks: it has learned to solve problems with more precise search rounds and better token efficiency. For example, across multiple agentic tasks including BrowseComp, Wide Search, and RISE, M2.5 achieved better results with fewer rounds,
|
| 142 |
-
|
| 143 |
-
|
| 144 |
|
| 145 |
## Office work
|
| 146 |
-
M2.5 was trained to produce truly deliverable outputs in office scenarios. To this end, we engaged in thorough collaboration with senior professionals in fields such as finance, law, and social sciences. They designed requirements, provided feedback, participated in defining standards, and directly contributed to data construction, bringing the tacit knowledge of their industries into the model's training pipeline. Based on this foundation, M2.5 has achieved significant capability improvements in high-value workspace scenarios such as Word, PowerPoint, and Excel financial modeling. On the evaluation side, we built an internal Cowork Agent evaluation framework (GDPval-MM) that assesses the delivery quality and trajectory professionalism of models through pairwise comparisons, while also monitoring token costs across the entire workflow to estimate the model's real-world productivity gains. In comparisons against mainstream models, it achieved an average win rate of 59.0%.
|
| 147 |
|
| 148 |
-
|
| 149 |
-
|---|---|---|---|---|---|---|
|
| 150 |
-
| GDPval-MM | 59.0 | 24.6 | 61.1 | 73.5 | 28.1 | 54.5 |
|
| 151 |
-
| MEWC | 74.4 | 55.6 | 82.1 | 89.8 | 78.7 | 41.3 |
|
| 152 |
-
| Finance Modeling | 21.6 | 17.3 | 30.1 | 33.2 | 15.0 | 20.0 |
|
| 153 |
|
| 154 |
-
<p align="center">
|
| 155 |
-
<img width="100%" src="figures/bench_7.png">
|
| 156 |
-
</p>
|
| 157 |
<p align="center">
|
| 158 |
<img width="100%" src="figures/bench_8.png">
|
| 159 |
</p>
|
| 160 |
-
<p align="center">
|
| 161 |
-
<img width="100%" src="figures/bench_9.png">
|
| 162 |
-
</p>
|
| 163 |
-
|
| 164 |
|
| 165 |
## Efficiency
|
|
|
|
| 166 |
Because the real world is full of deadlines and time constraints, task completion speed is a practical necessity. The time it takes a model to complete a task depends on its task decomposition effectiveness, token efficiency, and inference speed. M2.5 is served natively at a rate of 100 tokens per second, which is nearly twice that of other frontier models. Further, our reinforcement learning setup incentivizes the model to reason efficiently and break down tasks optimally. Due to these three factors, M2.5 delivers a significant time savings in complex task completion.
|
| 167 |
|
| 168 |
For example, when running SWE-Bench Verified, M2.5 consumed an average of 3.52 million tokens per task. In comparison, M2.1 consumed 3.72M tokens. Meanwhile, thanks to improvements in capabilities such as parallel tool calling, the end-to-end runtime decreased from an average of 31.3 minutes to 22.8 minutes, representing a 37% speed improvement. This runtime is on par with Claude Opus 4.6's 22.9 minutes, while the total cost per task is only 10% that of Claude Opus 4.6.
|
| 169 |
|
| 170 |
-
|
| 171 |
-
|
| 172 |
## Cost
|
| 173 |
|
| 174 |
Our goal in designing the M2-series of foundation models is to power complex agents without having to worry about cost. We believe that M2.5 is close to realizing this goal. We’re releasing two versions of the model, M2.5 and M2.5-Lightning, that are identical in capability but differ in speed. M2.5-Lightning has a steady throughput of 100 tokens per second, which is two times faster than other frontier models, and costs $0.3 per million input tokens and $2.4 per million output tokens. M2.5, which has a throughput of 50 tokens per second, costs half that. Both model versions support caching. Based on output price, the cost of M2.5 is one-tenth to one-twentieth that of Opus, Gemini 3 Pro, and GPT-5.
|
| 175 |
|
| 176 |
At a rate of 100 output tokens per second, running M2.5 continuously for an hour costs $1. At a rate of 50 TPS, the price drops to $0.3. To put that into perspective, you can have four M2.5 instances running continuously for an entire year for $10,000. We believe that M2.5 provides virtually limitless possibilities for the development and operation of agents in the economy. For the M2-series, the only problem that remains is how to continually push the frontier of model capability.
|
| 177 |
|
| 178 |
-
|
| 179 |
-
|
| 180 |
## Improvement Rate
|
| 181 |
|
| 182 |
-
Over the three and a half months from late October to now, we have successively released M2, M2.1, and M2.5, with the pace of model improvement exceeding our original expectations. For instance, in the highly-regarded SWE-Bench Verified benchmark, the rate of progress of the M2
|
| 183 |
|
| 184 |
<p align="center">
|
| 185 |
<img width="100%" src="figures/bench_10.png">
|
| 186 |
</p>
|
| 187 |
|
| 188 |
-
|
| 189 |
## RL Scaling
|
| 190 |
|
| 191 |
One of the key drivers of the aforementioned developments is the scaling of reinforcement learning. As we train our models, we also benefit from their abilities. Most of the tasks and workspaces that we perform in our company have been made into training environments for RL. To date, there are already hundreds of thousands of such environments. At the same time, we did plenty of work on our agentic RL framework, algorithms, reward signals, and infrastructure engineering to support the continued scaling of our RL training.
|
| 192 |
|
| 193 |
-
|
| 194 |
|
| 195 |
We designed an agent-native RL framework in-house, called Forge, which introduces an intermediary layer that fully decouples the underlying training-inference engine from the agent, supporting the integration of arbitrary agents and enabling us to optimize the model's generalization across agent scaffolds and tools. To improve system throughput, we optimized asynchronous scheduling strategies to balance system throughput against sample off-policyness, and designed a tree-structured merging strategy for training samples, achieving approximately 40x training speedup.
|
| 196 |
|
| 197 |
<p align="center">
|
| 198 |
-
<img width="
|
| 199 |
</p>
|
| 200 |
|
| 201 |
-
|
| 202 |
|
| 203 |
On the algorithm side, we continued using the CISPO algorithm we proposed at the beginning of last year to ensure the stability of MoE models during large-scale training. To address the credit assignment challenge posed by long contexts in agent rollouts, we introduced a process reward mechanism for end-to-end monitoring of generation quality. Furthermore, to deeply align with user experience, we evaluated task completion time through agent trajectories, achieving an optimal trade-off between model intelligence and response speed.
|
|
|
|
| 204 |
<p align="center">
|
| 205 |
-
<img width="
|
| 206 |
</p>
|
| 207 |
-
We will release a more comprehensive introduction to RL scaling soon in a technical blogpost.
|
| 208 |
|
|
|
|
| 209 |
|
| 210 |
## MiniMax Agent: M2.5 as a Professional Employee
|
| 211 |
|
|
@@ -221,23 +177,50 @@ To date, users have built over 10,000 Experts on MiniMax Agent, and this number
|
|
| 221 |
|
| 222 |
MiniMax itself has been among the first to benefit from M2.5's capabilities. Throughout the company's daily operations, 30% of overall tasks are autonomously completed by M2.5, spanning functions including R&D, product, sales, HR, and finance — and the penetration rate continues to rise. Performance in coding scenarios has been particularly notable, with M2.5-generated code accounting for 80% of newly committed code.
|
| 223 |
|
| 224 |
-
|
| 225 |
|
|
|
|
| 226 |
|
|
|
|
| 227 |
|
| 228 |
-
|
| 229 |
|
| 230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
-
|
| 233 |
|
| 234 |
-
|
| 235 |
|
| 236 |
-
|
| 237 |
|
|
|
|
| 238 |
|
| 239 |
|
| 240 |
-
##
|
| 241 |
|
| 242 |
Further benchmark results of M2.5:
|
| 243 |
|
|
@@ -260,4 +243,4 @@ Evaluation methods:
|
|
| 260 |
> - GDPval-MM: Internal benchmark. Based on the open-source GDPval test set, using a custom agentic evaluation framework where an LLM-as-a-judge performs pairwise win/tie/loss judgments on complete trajectories. Average token cost per task is calculated based on each vendor's official API pricing (without caching).
|
| 261 |
> - MEWC: Internal benchmark. Built on MEWC (Microsoft Excel World Championship), comprising 179 problems from the main and other regional divisions of Excel esports competitions from 2021–2026. It evaluates the model's ability to understand competition Excel spreadsheets and use Excel tools to complete problems. Scores are calculated by comparing output and answer cell values one by one.
|
| 262 |
> - Finance Modeling: Internal benchmark. Primarily contains financial modeling problems constructed by industry experts, involving end-to-end research and analysis tasks performed via Excel tools. Each problem is scored using expert-designed rubrics. Final results are averaged over 3 runs.
|
| 263 |
-
> - AIME25 ~ AA-LCR: Obtained through internal testing based on the public evaluation sets and evaluation methods covered by the Artificial Analysis Intelligence Index leaderboard.
|
|
|
|
| 2 |
pipeline_tag: text-generation
|
| 3 |
license: other
|
| 4 |
license_name: modified-mit
|
| 5 |
+
license_link: https://github.com/MiniMax-AI/MiniMax-M2.5/blob/main/LICENSE
|
| 6 |
library_name: transformers
|
| 7 |
---
|
| 8 |
|
|
|
|
| 67 |
</a>
|
| 68 |
</div>
|
| 69 |
|
| 70 |
+
<p align="center">
|
| 71 |
+
<img width="100%" src="figures/bench_11.png">
|
| 72 |
+
</p>
|
| 73 |
|
| 74 |
+
Today we're introducing our latest model, **MiniMax-M2.5**.
|
| 75 |
|
| 76 |
+
Extensively trained with reinforcement learning in hundreds of thousands of complex real-world environments, M2.5 is **SOTA in coding, agentic tool use and search, office work, and a range of other economically valuable tasks**, boasting scores of **80.2% in SWE-Bench Verified, 51.3% in Multi-SWE-Bench, and 76.3% in BrowseComp** (with context management).
|
| 77 |
|
| 78 |
+
Trained to reason efficiently and decompose tasks optimally, M2.5 exhibits tremendous speed in performing complicated agentic tasks, completing the SWE-Bench Verified evaluation **37% faster** than M2.1, matching the speed of **Claude Opus 4.6**.
|
| 79 |
|
| 80 |
+
M2.5 is the first frontier model where users do not need to worry about cost, delivering on the promise of intelligence too cheap to meter. **It costs just $1 to run the model continuously for an hour at a rate of 100 tokens per second.** At 50 tokens per second, the cost drops to $0.30. We hope that the speed and cost effectiveness of M2.5 enable innovative new agentic applications.
|
| 81 |
|
| 82 |
## Coding
|
|
|
|
| 83 |
|
| 84 |
+
In programming evaluations, MiniMax-M2.5 saw substantial improvements compared to previous generations, reaching SOTA levels. The performance of M2.5 in multilingual tasks is especially pronounced.
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
<p align="center">
|
| 87 |
<img width="100%" src="figures/bench_2.png">
|
| 88 |
</p>
|
|
|
|
| 89 |
|
| 90 |
+
A significant improvement from previous generations is M2.5's ability to think and plan like an architect. The Spec-writing tendency of the model emerged during training: before writing any code, M2.5 actively decomposes and plans the features, structure, and UI design of the project from the perspective of an experienced software architect.
|
| 91 |
+
|
| 92 |
+
M2.5 was trained on over 10 languages (including Go, C, C++, TypeScript, Rust, Kotlin, Python, Java, JavaScript, PHP, Lua, Dart, and Ruby) across more than 200,000 real-world environments. Going far beyond bug-fixing, M2.5 delivers reliable performance across the entire development lifecycle of complex systems: from 0-to-1 system design and environment setup, to 1-to-10 system development, to 10-to-90 feature iteration, and finally 90-to-100 comprehensive code review and system testing. It covers full-stack projects spanning multiple platforms including Web, Android, iOS, and Windows, encompassing server-side APIs, business logic, databases, and more, not just frontend webpage demos.
|
| 93 |
|
| 94 |
To evaluate these capabilities, we also upgraded the VIBE benchmark to a more complex and challenging Pro version, significantly increasing task complexity, domain coverage, and evaluation accuracy. Overall, M2.5 performs on par with Opus 4.5.
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
<p align="center">
|
| 97 |
<img width="100%" src="figures/bench_4.png">
|
| 98 |
</p>
|
| 99 |
|
| 100 |
+
We focused on the model's ability to generalize across out-of-distribution harnesses. We tested performance on the SWE-Bench Verified evaluation set using different coding agent harnesses.
|
| 101 |
+
- On Droid: 79.7(M2.5) > 78.9(Opus 4.6)
|
| 102 |
+
- On OpenCode: 76.1(M2.5) > 75.9(Opus 4.6)
|
| 103 |
|
| 104 |
## Search and Tool calling
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
<p align="center">
|
| 107 |
<img width="100%" src="figures/bench_6.png">
|
| 108 |
</p>
|
| 109 |
|
| 110 |
+
Effective tool calling and search are prerequisites for a model's ability to autonomously handle more complex tasks. In evaluations on benchmarks such as BrowseComp and Wide Search, M2.5 achieved industry-leading performance. At the same time, the model's generalization has also improved — M2.5 demonstrates more stable performance when facing unfamiliar scaffolding environments.
|
| 111 |
|
| 112 |
In research tasks performed by professional human experts, using a search engine is only a small part of the process; most of the work involves deep exploration across information-dense webpages. To address this, we built RISE (Realistic Interactive Search Evaluation) to measure a model's search capabilities on real-world professional tasks. The results show that M2.5 excels at expert-level search tasks in real-world settings.
|
| 113 |
|
| 114 |
+
Compared to its predecessors, M2.5 also demonstrates much better decision-making when handling agentic tasks: it has learned to solve problems with more precise search rounds and better token efficiency. For example, across multiple agentic tasks including BrowseComp, Wide Search, and RISE, M2.5 achieved better results with fewer rounds, using approximately 20% fewer rounds compared to M2.1. This indicates that the model is no longer just getting the answer right, but is also reasoning towards results in more efficient paths.
|
|
|
|
|
|
|
| 115 |
|
| 116 |
## Office work
|
|
|
|
| 117 |
|
| 118 |
+
M2.5 was trained to produce truly deliverable outputs in office scenarios. To this end, we engaged in thorough collaboration with senior professionals in fields such as finance, law, and social sciences. They designed requirements, provided feedback, participated in defining standards, and directly contributed to data construction, bringing the tacit knowledge of their industries into the model's training pipeline. Based on this foundation, M2.5 has achieved significant capability improvements in high-value workspace scenarios such as Word, PowerPoint, and Excel financial modeling. On the evaluation side, we built an internal Cowork Agent evaluation framework (GDPval-MM) that assesses both the quality of the deliverable and the professionalism of the agent's trajectory through pairwise comparisons, while also monitoring token costs across the entire workflow to estimate the model's real-world productivity gains. In comparisons against other mainstream models, it achieved an average win rate of 59.0%.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
|
|
|
|
|
|
|
|
|
| 120 |
<p align="center">
|
| 121 |
<img width="100%" src="figures/bench_8.png">
|
| 122 |
</p>
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
## Efficiency
|
| 125 |
+
|
| 126 |
Because the real world is full of deadlines and time constraints, task completion speed is a practical necessity. The time it takes a model to complete a task depends on its task decomposition effectiveness, token efficiency, and inference speed. M2.5 is served natively at a rate of 100 tokens per second, which is nearly twice that of other frontier models. Further, our reinforcement learning setup incentivizes the model to reason efficiently and break down tasks optimally. Due to these three factors, M2.5 delivers a significant time savings in complex task completion.
|
| 127 |
|
| 128 |
For example, when running SWE-Bench Verified, M2.5 consumed an average of 3.52 million tokens per task. In comparison, M2.1 consumed 3.72M tokens. Meanwhile, thanks to improvements in capabilities such as parallel tool calling, the end-to-end runtime decreased from an average of 31.3 minutes to 22.8 minutes, representing a 37% speed improvement. This runtime is on par with Claude Opus 4.6's 22.9 minutes, while the total cost per task is only 10% that of Claude Opus 4.6.
|
| 129 |
|
|
|
|
|
|
|
| 130 |
## Cost
|
| 131 |
|
| 132 |
Our goal in designing the M2-series of foundation models is to power complex agents without having to worry about cost. We believe that M2.5 is close to realizing this goal. We’re releasing two versions of the model, M2.5 and M2.5-Lightning, that are identical in capability but differ in speed. M2.5-Lightning has a steady throughput of 100 tokens per second, which is two times faster than other frontier models, and costs $0.3 per million input tokens and $2.4 per million output tokens. M2.5, which has a throughput of 50 tokens per second, costs half that. Both model versions support caching. Based on output price, the cost of M2.5 is one-tenth to one-twentieth that of Opus, Gemini 3 Pro, and GPT-5.
|
| 133 |
|
| 134 |
At a rate of 100 output tokens per second, running M2.5 continuously for an hour costs $1. At a rate of 50 TPS, the price drops to $0.3. To put that into perspective, you can have four M2.5 instances running continuously for an entire year for $10,000. We believe that M2.5 provides virtually limitless possibilities for the development and operation of agents in the economy. For the M2-series, the only problem that remains is how to continually push the frontier of model capability.
|
| 135 |
|
|
|
|
|
|
|
| 136 |
## Improvement Rate
|
| 137 |
|
| 138 |
+
Over the three and a half months from late October to now, we have successively released M2, M2.1, and M2.5, with the pace of model improvement exceeding our original expectations. For instance, in the highly-regarded SWE-Bench Verified benchmark, the rate of progress of the M2-series has been significantly faster than that of peers such as the Claude, GPT, and Gemini model families.
|
| 139 |
|
| 140 |
<p align="center">
|
| 141 |
<img width="100%" src="figures/bench_10.png">
|
| 142 |
</p>
|
| 143 |
|
|
|
|
| 144 |
## RL Scaling
|
| 145 |
|
| 146 |
One of the key drivers of the aforementioned developments is the scaling of reinforcement learning. As we train our models, we also benefit from their abilities. Most of the tasks and workspaces that we perform in our company have been made into training environments for RL. To date, there are already hundreds of thousands of such environments. At the same time, we did plenty of work on our agentic RL framework, algorithms, reward signals, and infrastructure engineering to support the continued scaling of our RL training.
|
| 147 |
|
| 148 |
+
### Forge –– Agent-Native RL Framework
|
| 149 |
|
| 150 |
We designed an agent-native RL framework in-house, called Forge, which introduces an intermediary layer that fully decouples the underlying training-inference engine from the agent, supporting the integration of arbitrary agents and enabling us to optimize the model's generalization across agent scaffolds and tools. To improve system throughput, we optimized asynchronous scheduling strategies to balance system throughput against sample off-policyness, and designed a tree-structured merging strategy for training samples, achieving approximately 40x training speedup.
|
| 151 |
|
| 152 |
<p align="center">
|
| 153 |
+
<img width="60%" src="figures/rl_1.png">
|
| 154 |
</p>
|
| 155 |
|
| 156 |
+
### Agentic RL Algorithm and Reward Design
|
| 157 |
|
| 158 |
On the algorithm side, we continued using the CISPO algorithm we proposed at the beginning of last year to ensure the stability of MoE models during large-scale training. To address the credit assignment challenge posed by long contexts in agent rollouts, we introduced a process reward mechanism for end-to-end monitoring of generation quality. Furthermore, to deeply align with user experience, we evaluated task completion time through agent trajectories, achieving an optimal trade-off between model intelligence and response speed.
|
| 159 |
+
|
| 160 |
<p align="center">
|
| 161 |
+
<img width="60%" src="figures/rl_2.png">
|
| 162 |
</p>
|
|
|
|
| 163 |
|
| 164 |
+
We will release a more comprehensive introduction to RL scaling soon in a separate technical blogpost.
|
| 165 |
|
| 166 |
## MiniMax Agent: M2.5 as a Professional Employee
|
| 167 |
|
|
|
|
| 177 |
|
| 178 |
MiniMax itself has been among the first to benefit from M2.5's capabilities. Throughout the company's daily operations, 30% of overall tasks are autonomously completed by M2.5, spanning functions including R&D, product, sales, HR, and finance — and the penetration rate continues to rise. Performance in coding scenarios has been particularly notable, with M2.5-generated code accounting for 80% of newly committed code.
|
| 179 |
|
| 180 |
+
## How to Use
|
| 181 |
|
| 182 |
+
MiniMax Agent: https://agent.minimax.io/
|
| 183 |
|
| 184 |
+
MiniMax API Platform: https://platform.minimax.io/
|
| 185 |
|
| 186 |
+
MiniMax Coding Plan: https://platform.minimax.io/subscribe/coding-plan
|
| 187 |
|
| 188 |
+
## Local Deployment Guide
|
| 189 |
+
|
| 190 |
+
Download the model from HuggingFace repository: https://huggingface.co/MiniMaxAI/MiniMax-M2.5
|
| 191 |
+
|
| 192 |
+
We recommend using the following inference frameworks (listed alphabetically) to serve the model:
|
| 193 |
+
|
| 194 |
+
### SGLang
|
| 195 |
+
|
| 196 |
+
We recommend using [SGLang](https://docs.sglang.io/) to serve MiniMax-M2.5. Please refer to our [SGLang Deployment Guide](./docs/sglang_deploy_guide.md).
|
| 197 |
+
|
| 198 |
+
### vLLM
|
| 199 |
+
|
| 200 |
+
We recommend using [vLLM](https://github.com/vllm-project/vllm) to serve MiniMax-M2.5. Please refer to our [vLLM Deployment Guide](./docs/vllm_deploy_guide.md).
|
| 201 |
+
|
| 202 |
+
### Transformers
|
| 203 |
+
|
| 204 |
+
We recommend using [Transformers](https://github.com/huggingface/transformers) to serve MiniMax-M2.5. Please refer to our [Transformers Deployment Guide](./docs/transformers_deploy_guide.md).
|
| 205 |
+
|
| 206 |
+
### Inference Parameters
|
| 207 |
+
|
| 208 |
+
We recommend using the following parameters for best performance: `temperature=1.0`, `top_p = 0.95`, `top_k = 40`. Default system prompt:
|
| 209 |
+
|
| 210 |
+
```
|
| 211 |
+
You are a helpful assistant. Your name is MiniMax-M2.5 and is built by MiniMax.
|
| 212 |
+
```
|
| 213 |
|
| 214 |
+
## Tool Calling Guide
|
| 215 |
|
| 216 |
+
Please refer to our [Tool Calling Guide](./docs/tool_calling_guide.md).
|
| 217 |
|
| 218 |
+
## Contact Us
|
| 219 |
|
| 220 |
+
Contact us at [model@minimax.io](mailto:model@minimax.io).
|
| 221 |
|
| 222 |
|
| 223 |
+
## Appendix
|
| 224 |
|
| 225 |
Further benchmark results of M2.5:
|
| 226 |
|
|
|
|
| 243 |
> - GDPval-MM: Internal benchmark. Based on the open-source GDPval test set, using a custom agentic evaluation framework where an LLM-as-a-judge performs pairwise win/tie/loss judgments on complete trajectories. Average token cost per task is calculated based on each vendor's official API pricing (without caching).
|
| 244 |
> - MEWC: Internal benchmark. Built on MEWC (Microsoft Excel World Championship), comprising 179 problems from the main and other regional divisions of Excel esports competitions from 2021–2026. It evaluates the model's ability to understand competition Excel spreadsheets and use Excel tools to complete problems. Scores are calculated by comparing output and answer cell values one by one.
|
| 245 |
> - Finance Modeling: Internal benchmark. Primarily contains financial modeling problems constructed by industry experts, involving end-to-end research and analysis tasks performed via Excel tools. Each problem is scored using expert-designed rubrics. Final results are averaged over 3 runs.
|
| 246 |
+
> - AIME25 ~ AA-LCR: Obtained through internal testing based on the public evaluation sets and evaluation methods covered by the Artificial Analysis Intelligence Index leaderboard.
|
chat_template.jinja
CHANGED
|
@@ -28,7 +28,7 @@
|
|
| 28 |
{{- visible_text(system_message.content) }}
|
| 29 |
{%- else -%}
|
| 30 |
{%- if model_identity is not defined -%}
|
| 31 |
-
{%- set model_identity = "You are a helpful assistant." -%}
|
| 32 |
{%- endif -%}
|
| 33 |
{{- model_identity }}
|
| 34 |
{%- endif -%}
|
|
|
|
| 28 |
{{- visible_text(system_message.content) }}
|
| 29 |
{%- else -%}
|
| 30 |
{%- if model_identity is not defined -%}
|
| 31 |
+
{%- set model_identity = "You are a helpful assistant. Your name is MiniMax-M2.5 and is built by MiniMax." -%}
|
| 32 |
{%- endif -%}
|
| 33 |
{{- model_identity }}
|
| 34 |
{%- endif -%}
|
config.json
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"MiniMaxM2ForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attn_type_list": [
|
| 6 |
+
1,
|
| 7 |
+
1,
|
| 8 |
+
1,
|
| 9 |
+
1,
|
| 10 |
+
1,
|
| 11 |
+
1,
|
| 12 |
+
1,
|
| 13 |
+
1,
|
| 14 |
+
1,
|
| 15 |
+
1,
|
| 16 |
+
1,
|
| 17 |
+
1,
|
| 18 |
+
1,
|
| 19 |
+
1,
|
| 20 |
+
1,
|
| 21 |
+
1,
|
| 22 |
+
1,
|
| 23 |
+
1,
|
| 24 |
+
1,
|
| 25 |
+
1,
|
| 26 |
+
1,
|
| 27 |
+
1,
|
| 28 |
+
1,
|
| 29 |
+
1,
|
| 30 |
+
1,
|
| 31 |
+
1,
|
| 32 |
+
1,
|
| 33 |
+
1,
|
| 34 |
+
1,
|
| 35 |
+
1,
|
| 36 |
+
1,
|
| 37 |
+
1,
|
| 38 |
+
1,
|
| 39 |
+
1,
|
| 40 |
+
1,
|
| 41 |
+
1,
|
| 42 |
+
1,
|
| 43 |
+
1,
|
| 44 |
+
1,
|
| 45 |
+
1,
|
| 46 |
+
1,
|
| 47 |
+
1,
|
| 48 |
+
1,
|
| 49 |
+
1,
|
| 50 |
+
1,
|
| 51 |
+
1,
|
| 52 |
+
1,
|
| 53 |
+
1,
|
| 54 |
+
1,
|
| 55 |
+
1,
|
| 56 |
+
1,
|
| 57 |
+
1,
|
| 58 |
+
1,
|
| 59 |
+
1,
|
| 60 |
+
1,
|
| 61 |
+
1,
|
| 62 |
+
1,
|
| 63 |
+
1,
|
| 64 |
+
1,
|
| 65 |
+
1,
|
| 66 |
+
1,
|
| 67 |
+
1
|
| 68 |
+
],
|
| 69 |
+
"auto_map": {
|
| 70 |
+
"AutoConfig": "configuration_minimax_m2.MiniMaxM2Config",
|
| 71 |
+
"AutoModelForCausalLM": "modeling_minimax_m2.MiniMaxM2ForCausalLM"
|
| 72 |
+
},
|
| 73 |
+
"head_dim": 128,
|
| 74 |
+
"hidden_act": "silu",
|
| 75 |
+
"hidden_size": 3072,
|
| 76 |
+
"intermediate_size": 1536,
|
| 77 |
+
"max_position_embeddings": 196608,
|
| 78 |
+
"model_type": "minimax_m2",
|
| 79 |
+
"mtp_transformer_layers": 1,
|
| 80 |
+
"num_attention_heads": 48,
|
| 81 |
+
"num_experts_per_tok": 8,
|
| 82 |
+
"num_hidden_layers": 62,
|
| 83 |
+
"num_key_value_heads": 8,
|
| 84 |
+
"num_local_experts": 256,
|
| 85 |
+
"num_mtp_modules": 3,
|
| 86 |
+
"qk_norm_type": "per_layer",
|
| 87 |
+
"quantization_config": {
|
| 88 |
+
"activation_scheme": "dynamic",
|
| 89 |
+
"fmt": "float8_e4m3fn",
|
| 90 |
+
"quant_method": "fp8",
|
| 91 |
+
"weight_block_size": [
|
| 92 |
+
128,
|
| 93 |
+
128
|
| 94 |
+
],
|
| 95 |
+
"modules_to_not_convert": [
|
| 96 |
+
"gate",
|
| 97 |
+
"e_score_correction_bias",
|
| 98 |
+
"lm_head"
|
| 99 |
+
]
|
| 100 |
+
},
|
| 101 |
+
"rms_norm_eps": 1e-06,
|
| 102 |
+
"rope_theta": 5000000,
|
| 103 |
+
"rotary_dim": 64,
|
| 104 |
+
"scoring_func": "sigmoid",
|
| 105 |
+
"shared_intermediate_size": 0,
|
| 106 |
+
"tie_word_embeddings": false,
|
| 107 |
+
"transformers_version": "4.46.1",
|
| 108 |
+
"use_cache": true,
|
| 109 |
+
"use_mtp": true,
|
| 110 |
+
"use_qk_norm": true,
|
| 111 |
+
"use_routing_bias": true,
|
| 112 |
+
"vocab_size": 200064
|
| 113 |
+
}
|
configuration_minimax_m2.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 2 |
+
# This file was automatically generated from src/transformers/models/minimax_m2/modular_minimax_m2.py.
|
| 3 |
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
| 4 |
+
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
+
# modular_minimax_m2.py file directly. One of our CI enforces this.
|
| 6 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
# coding=utf-8
|
| 8 |
+
# Copyright 2025 the HuggingFace Team. All rights reserved.
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class MiniMaxM2Config(PretrainedConfig):
|
| 27 |
+
r"""
|
| 28 |
+
This is the configuration class to store the configuration of a [`MiniMaxM2Model`]. It is used to instantiate an
|
| 29 |
+
MiniMaxM2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 30 |
+
with the defaults will yield a similar configuration to that of the MiniMaxM2-7B-v0.1 or MiniMaxM2-7B-Instruct-v0.1.
|
| 31 |
+
|
| 32 |
+
[minimax_m2ai/MiniMaxM2-8x7B](https://huggingface.co/minimax_m2ai/MiniMaxM2-8x7B)
|
| 33 |
+
[minimax_m2ai/MiniMaxM2-7B-Instruct-v0.1](https://huggingface.co/minimax_m2ai/MiniMaxM2-7B-Instruct-v0.1)
|
| 34 |
+
|
| 35 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 36 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
| 41 |
+
Vocabulary size of the MiniMaxM2 model. Defines the number of different tokens that can be represented by the
|
| 42 |
+
`inputs_ids` passed when calling [`MiniMaxM2Model`]
|
| 43 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
| 44 |
+
Dimension of the hidden representations.
|
| 45 |
+
intermediate_size (`int`, *optional*, defaults to 14336):
|
| 46 |
+
Dimension of the MLP representations.
|
| 47 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
| 48 |
+
Number of hidden layers in the Transformer encoder.
|
| 49 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 50 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 51 |
+
num_key_value_heads (`int`, *optional*, defaults to 8):
|
| 52 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 53 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 54 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 55 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
| 56 |
+
by meanpooling all the original heads within that group. For more details, check out [this
|
| 57 |
+
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
|
| 58 |
+
head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
|
| 59 |
+
The attention head dimension.
|
| 60 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 61 |
+
The non-linear activation function (function or string) in the decoder.
|
| 62 |
+
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
|
| 63 |
+
The maximum sequence length that this model might ever be used with. MiniMaxM2's sliding window attention
|
| 64 |
+
allows sequence of up to 4096*32 tokens.
|
| 65 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 66 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 67 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 68 |
+
The epsilon used by the rms normalization layers.
|
| 69 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 70 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 71 |
+
relevant if `config.is_decoder=True`.
|
| 72 |
+
pad_token_id (`int`, *optional*):
|
| 73 |
+
The id of the padding token.
|
| 74 |
+
bos_token_id (`int`, *optional*, defaults to 1):
|
| 75 |
+
The id of the "beginning-of-sequence" token.
|
| 76 |
+
eos_token_id (`int`, *optional*, defaults to 2):
|
| 77 |
+
The id of the "end-of-sequence" token.
|
| 78 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 79 |
+
Whether the model's input and output word embeddings should be tied.
|
| 80 |
+
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
| 81 |
+
The base period of the RoPE embeddings.
|
| 82 |
+
sliding_window (`int`, *optional*):
|
| 83 |
+
Sliding window attention window size. If not specified, will default to `4096`.
|
| 84 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 85 |
+
The dropout ratio for the attention probabilities.
|
| 86 |
+
num_experts_per_tok (`int`, *optional*, defaults to 2):
|
| 87 |
+
The number of experts to route per-token, can be also interpreted as the `top-k` routing
|
| 88 |
+
parameter
|
| 89 |
+
num_local_experts (`int`, *optional*, defaults to 8):
|
| 90 |
+
Number of experts per Sparse MLP layer.
|
| 91 |
+
output_router_logits (`bool`, *optional*, defaults to `False`):
|
| 92 |
+
Whether or not the router logits should be returned by the model. Enabling this will also
|
| 93 |
+
allow the model to output the auxiliary loss. See [here]() for more details
|
| 94 |
+
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
|
| 95 |
+
The aux loss factor for the total loss.
|
| 96 |
+
router_jitter_noise (`float`, *optional*, defaults to 0.0):
|
| 97 |
+
Amount of noise to add to the router.
|
| 98 |
+
|
| 99 |
+
```python
|
| 100 |
+
>>> from transformers import MiniMaxM2Model, MiniMaxM2Config
|
| 101 |
+
|
| 102 |
+
>>> # Initializing a MiniMaxM2 7B style configuration
|
| 103 |
+
>>> configuration = MiniMaxM2Config()
|
| 104 |
+
|
| 105 |
+
>>> # Initializing a model from the MiniMaxM2 7B style configuration
|
| 106 |
+
>>> model = MiniMaxM2Model(configuration)
|
| 107 |
+
|
| 108 |
+
>>> # Accessing the model configuration
|
| 109 |
+
>>> configuration = model.config
|
| 110 |
+
```"""
|
| 111 |
+
|
| 112 |
+
model_type = "minimax_m2"
|
| 113 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 114 |
+
base_model_tp_plan = {
|
| 115 |
+
"layers.*.self_attn.q_proj": "colwise",
|
| 116 |
+
"layers.*.self_attn.k_proj": "colwise",
|
| 117 |
+
"layers.*.self_attn.v_proj": "colwise",
|
| 118 |
+
"layers.*.self_attn.o_proj": "rowwise",
|
| 119 |
+
"layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts
|
| 120 |
+
"layers.*.block_sparse_moe.experts.*.w1": "colwise",
|
| 121 |
+
"layers.*.block_sparse_moe.experts.*.w2": "rowwise",
|
| 122 |
+
"layers.*.block_sparse_moe.experts.*.w3": "colwise",
|
| 123 |
+
}
|
| 124 |
+
base_model_pp_plan = {
|
| 125 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 126 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 127 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
def __init__(
|
| 131 |
+
self,
|
| 132 |
+
vocab_size=32000,
|
| 133 |
+
hidden_size=4096,
|
| 134 |
+
intermediate_size=14336,
|
| 135 |
+
num_hidden_layers=32,
|
| 136 |
+
num_attention_heads=32,
|
| 137 |
+
num_key_value_heads=8,
|
| 138 |
+
head_dim=None,
|
| 139 |
+
hidden_act="silu",
|
| 140 |
+
max_position_embeddings=4096 * 32,
|
| 141 |
+
initializer_range=0.02,
|
| 142 |
+
rms_norm_eps=1e-5,
|
| 143 |
+
use_cache=True,
|
| 144 |
+
pad_token_id=None,
|
| 145 |
+
bos_token_id=1,
|
| 146 |
+
eos_token_id=2,
|
| 147 |
+
tie_word_embeddings=False,
|
| 148 |
+
rope_theta=1e6,
|
| 149 |
+
sliding_window=None,
|
| 150 |
+
attention_dropout=0.0,
|
| 151 |
+
num_experts_per_tok=2,
|
| 152 |
+
num_local_experts=8,
|
| 153 |
+
output_router_logits=False,
|
| 154 |
+
router_aux_loss_coef=0.001,
|
| 155 |
+
router_jitter_noise=0.0,
|
| 156 |
+
**kwargs,
|
| 157 |
+
):
|
| 158 |
+
self.vocab_size = vocab_size
|
| 159 |
+
self.max_position_embeddings = max_position_embeddings
|
| 160 |
+
self.hidden_size = hidden_size
|
| 161 |
+
self.intermediate_size = intermediate_size
|
| 162 |
+
self.num_hidden_layers = num_hidden_layers
|
| 163 |
+
self.num_attention_heads = num_attention_heads
|
| 164 |
+
self.sliding_window = sliding_window
|
| 165 |
+
|
| 166 |
+
# for backward compatibility
|
| 167 |
+
if num_key_value_heads is None:
|
| 168 |
+
num_key_value_heads = num_attention_heads
|
| 169 |
+
|
| 170 |
+
self.num_key_value_heads = num_key_value_heads
|
| 171 |
+
self.hidden_act = hidden_act
|
| 172 |
+
self.initializer_range = initializer_range
|
| 173 |
+
self.rms_norm_eps = rms_norm_eps
|
| 174 |
+
self.use_cache = use_cache
|
| 175 |
+
self.rope_theta = rope_theta
|
| 176 |
+
self.attention_dropout = attention_dropout
|
| 177 |
+
self.head_dim = head_dim
|
| 178 |
+
|
| 179 |
+
self.num_experts_per_tok = num_experts_per_tok
|
| 180 |
+
self.num_local_experts = num_local_experts
|
| 181 |
+
self.output_router_logits = output_router_logits
|
| 182 |
+
self.router_aux_loss_coef = router_aux_loss_coef
|
| 183 |
+
self.router_jitter_noise = router_jitter_noise
|
| 184 |
+
|
| 185 |
+
self.use_qk_norm = kwargs.pop("use_qk_norm", False)
|
| 186 |
+
self.rotary_dim = kwargs.pop("rotary_dim", self.head_dim)
|
| 187 |
+
self.partial_rotary_factor = kwargs.pop("partial_rotary_factor", 1)
|
| 188 |
+
if self.head_dim is not None:
|
| 189 |
+
self.partial_rotary_factor = self.rotary_dim / self.head_dim
|
| 190 |
+
|
| 191 |
+
super().__init__(
|
| 192 |
+
pad_token_id=pad_token_id,
|
| 193 |
+
bos_token_id=bos_token_id,
|
| 194 |
+
eos_token_id=eos_token_id,
|
| 195 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 196 |
+
**kwargs,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
__all__ = ["MiniMaxM2Config"]
|
figures/bench_11.png
ADDED
|
Git LFS Details
|
generation_config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 200019,
|
| 3 |
+
"do_sample": true,
|
| 4 |
+
"eos_token_id": 200020,
|
| 5 |
+
"temperature": 1.0,
|
| 6 |
+
"top_p": 0.95,
|
| 7 |
+
"top_k": 40,
|
| 8 |
+
"transformers_version": "4.46.1"
|
| 9 |
+
}
|
modeling_minimax_m2.py
ADDED
|
@@ -0,0 +1,706 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 2 |
+
# This file was automatically generated from src/transformers/models/minimax_m2/modular_minimax_m2.py.
|
| 3 |
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
| 4 |
+
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
+
# modular_minimax_m2.py file directly. One of our CI enforces this.
|
| 6 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
# coding=utf-8
|
| 8 |
+
# Copyright 2025 the HuggingFace Team. All rights reserved.
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
from collections.abc import Callable
|
| 24 |
+
from typing import Optional, Union, Unpack
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
from torch import nn
|
| 28 |
+
|
| 29 |
+
from transformers.activations import ACT2FN
|
| 30 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 31 |
+
from transformers.generation import GenerationMixin
|
| 32 |
+
from transformers.integrations import use_kernel_forward_from_hub
|
| 33 |
+
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
| 34 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 35 |
+
from transformers.modeling_layers import (
|
| 36 |
+
GenericForQuestionAnswering,
|
| 37 |
+
GenericForSequenceClassification,
|
| 38 |
+
GenericForTokenClassification,
|
| 39 |
+
GradientCheckpointingLayer,
|
| 40 |
+
)
|
| 41 |
+
from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
|
| 42 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 43 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 44 |
+
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
|
| 45 |
+
from transformers.utils.deprecation import deprecate_kwarg
|
| 46 |
+
from transformers.utils.generic import OutputRecorder, check_model_inputs
|
| 47 |
+
from .configuration_minimax_m2 import MiniMaxM2Config
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class MiniMaxM2MLP(nn.Module):
|
| 51 |
+
def __init__(self, config: MiniMaxM2Config):
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.ffn_dim = config.intermediate_size
|
| 54 |
+
self.hidden_dim = config.hidden_size
|
| 55 |
+
|
| 56 |
+
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
|
| 57 |
+
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
|
| 58 |
+
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
|
| 59 |
+
|
| 60 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 61 |
+
|
| 62 |
+
def forward(self, hidden_states):
|
| 63 |
+
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
|
| 64 |
+
current_hidden_states = self.w2(current_hidden_states)
|
| 65 |
+
return current_hidden_states
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class MiniMaxM2Experts(nn.ModuleList):
|
| 69 |
+
"""
|
| 70 |
+
ModuleList of experts.
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
def __init__(self, config: MiniMaxM2Config):
|
| 74 |
+
super().__init__()
|
| 75 |
+
self.top_k = config.num_experts_per_tok
|
| 76 |
+
self.num_experts = config.num_local_experts
|
| 77 |
+
for _ in range(self.num_experts):
|
| 78 |
+
self.append(MiniMaxM2MLP(config))
|
| 79 |
+
|
| 80 |
+
def forward(
|
| 81 |
+
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
|
| 82 |
+
) -> torch.Tensor:
|
| 83 |
+
"""
|
| 84 |
+
Args:
|
| 85 |
+
hidden_states: (batch_size * sequence_length, hidden_dim)
|
| 86 |
+
selected_experts: (batch_size * sequence_length, top_k)
|
| 87 |
+
routing_weights: (batch_size * sequence_length, top_k)
|
| 88 |
+
Returns:
|
| 89 |
+
(batch_size * sequence_length, hidden_dim)
|
| 90 |
+
"""
|
| 91 |
+
final_hidden_states = torch.zeros_like(hidden_states)
|
| 92 |
+
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
|
| 93 |
+
|
| 94 |
+
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
| 95 |
+
for expert_idx in expert_hit:
|
| 96 |
+
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
|
| 97 |
+
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
|
| 98 |
+
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
|
| 99 |
+
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
| 100 |
+
return final_hidden_states
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class MiniMaxM2SparseMoeBlock(nn.Module):
|
| 104 |
+
def __init__(self, config):
|
| 105 |
+
super().__init__()
|
| 106 |
+
self.top_k = config.num_experts_per_tok
|
| 107 |
+
self.jitter_noise = config.router_jitter_noise
|
| 108 |
+
self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
|
| 109 |
+
self.experts = MiniMaxM2Experts(config)
|
| 110 |
+
self.register_buffer("e_score_correction_bias", torch.zeros(config.num_local_experts))
|
| 111 |
+
|
| 112 |
+
def route_tokens_to_experts(self, router_logits):
|
| 113 |
+
routing_weights = torch.nn.functional.sigmoid(router_logits.float())
|
| 114 |
+
scores_for_choice = routing_weights + self.e_score_correction_bias
|
| 115 |
+
_, top_k_index = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False)
|
| 116 |
+
top_k_weights = routing_weights.gather(1, top_k_index)
|
| 117 |
+
top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
|
| 118 |
+
return top_k_index, top_k_weights.to(router_logits.dtype)
|
| 119 |
+
|
| 120 |
+
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 121 |
+
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
| 122 |
+
if self.training and self.jitter_noise > 0:
|
| 123 |
+
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
|
| 124 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 125 |
+
router_logits = self.gate(hidden_states)
|
| 126 |
+
top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
|
| 127 |
+
hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype))
|
| 128 |
+
hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
| 129 |
+
return hidden_states, router_logits
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@use_kernel_forward_from_hub("RMSNorm")
|
| 133 |
+
class MiniMaxM2RMSNorm(nn.Module):
|
| 134 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 135 |
+
"""
|
| 136 |
+
MiniMaxM2RMSNorm is equivalent to T5LayerNorm
|
| 137 |
+
"""
|
| 138 |
+
super().__init__()
|
| 139 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 140 |
+
self.variance_epsilon = eps
|
| 141 |
+
|
| 142 |
+
def forward(self, hidden_states):
|
| 143 |
+
input_dtype = hidden_states.dtype
|
| 144 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 145 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 146 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 147 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 148 |
+
|
| 149 |
+
def extra_repr(self):
|
| 150 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 154 |
+
"""
|
| 155 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 156 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 157 |
+
"""
|
| 158 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 159 |
+
if n_rep == 1:
|
| 160 |
+
return hidden_states
|
| 161 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 162 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def eager_attention_forward(
|
| 166 |
+
module: nn.Module,
|
| 167 |
+
query: torch.Tensor,
|
| 168 |
+
key: torch.Tensor,
|
| 169 |
+
value: torch.Tensor,
|
| 170 |
+
attention_mask: Optional[torch.Tensor],
|
| 171 |
+
scaling: float,
|
| 172 |
+
dropout: float = 0.0,
|
| 173 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 174 |
+
):
|
| 175 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 176 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 177 |
+
|
| 178 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 179 |
+
if attention_mask is not None:
|
| 180 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 181 |
+
attn_weights = attn_weights + causal_mask
|
| 182 |
+
|
| 183 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 184 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 185 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 186 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 187 |
+
|
| 188 |
+
return attn_output, attn_weights
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def rotate_half(x):
|
| 192 |
+
"""Rotates half the hidden dims of the input."""
|
| 193 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 194 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 195 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 199 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 200 |
+
|
| 201 |
+
Args:
|
| 202 |
+
q (`torch.Tensor`): The query tensor.
|
| 203 |
+
k (`torch.Tensor`): The key tensor.
|
| 204 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 205 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 206 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 207 |
+
Deprecated and unused.
|
| 208 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 209 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 210 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 211 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 212 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 213 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 214 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 215 |
+
Returns:
|
| 216 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 217 |
+
"""
|
| 218 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 219 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 220 |
+
|
| 221 |
+
# Keep half or full tensor for later concatenation
|
| 222 |
+
rotary_dim = cos.shape[-1]
|
| 223 |
+
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
|
| 224 |
+
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
|
| 225 |
+
|
| 226 |
+
# Apply rotary embeddings on the first half or full tensor
|
| 227 |
+
q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
|
| 228 |
+
k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
|
| 229 |
+
|
| 230 |
+
# Concatenate back to full shape
|
| 231 |
+
q_embed = torch.cat([q_embed, q_pass], dim=-1)
|
| 232 |
+
k_embed = torch.cat([k_embed, k_pass], dim=-1)
|
| 233 |
+
return q_embed, k_embed
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class MiniMaxM2Attention(nn.Module):
|
| 237 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 238 |
+
|
| 239 |
+
def __init__(self, config: MiniMaxM2Config, layer_idx: int):
|
| 240 |
+
super().__init__()
|
| 241 |
+
self.config = config
|
| 242 |
+
self.layer_idx = layer_idx
|
| 243 |
+
self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
| 244 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 245 |
+
self.scaling = self.head_dim**-0.5
|
| 246 |
+
self.attention_dropout = config.attention_dropout
|
| 247 |
+
self.is_causal = True
|
| 248 |
+
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
|
| 249 |
+
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
|
| 250 |
+
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
|
| 251 |
+
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
|
| 252 |
+
|
| 253 |
+
self.use_qk_norm = config.use_qk_norm
|
| 254 |
+
if self.use_qk_norm:
|
| 255 |
+
self.q_norm = MiniMaxM2RMSNorm(self.head_dim * config.num_attention_heads, eps=config.rms_norm_eps)
|
| 256 |
+
self.k_norm = MiniMaxM2RMSNorm(self.head_dim * config.num_key_value_heads, eps=config.rms_norm_eps)
|
| 257 |
+
|
| 258 |
+
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
|
| 259 |
+
def forward(
|
| 260 |
+
self,
|
| 261 |
+
hidden_states: torch.Tensor,
|
| 262 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 263 |
+
attention_mask: Optional[torch.Tensor],
|
| 264 |
+
past_key_values: Optional[Cache] = None,
|
| 265 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 266 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 267 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 268 |
+
input_shape = hidden_states.shape[:-1]
|
| 269 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 270 |
+
|
| 271 |
+
query_states = self.q_proj(hidden_states)
|
| 272 |
+
key_states = self.k_proj(hidden_states)
|
| 273 |
+
value_states = self.v_proj(hidden_states)
|
| 274 |
+
|
| 275 |
+
if self.use_qk_norm: # main diff from Llama
|
| 276 |
+
query_states = self.q_norm(query_states)
|
| 277 |
+
key_states = self.k_norm(key_states)
|
| 278 |
+
|
| 279 |
+
key_states = key_states.view(hidden_shape)
|
| 280 |
+
query_states = query_states.view(hidden_shape)
|
| 281 |
+
value_states = value_states.view(hidden_shape)
|
| 282 |
+
|
| 283 |
+
query_states = query_states.transpose(1, 2)
|
| 284 |
+
key_states = key_states.transpose(1, 2)
|
| 285 |
+
value_states = value_states.transpose(1, 2)
|
| 286 |
+
|
| 287 |
+
cos, sin = position_embeddings
|
| 288 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 289 |
+
|
| 290 |
+
if past_key_values is not None:
|
| 291 |
+
# sin and cos are specific to RoPE models; position_ids needed for the static cache
|
| 292 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 293 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 294 |
+
|
| 295 |
+
attention_interface: Callable = eager_attention_forward
|
| 296 |
+
if self.config._attn_implementation != "eager":
|
| 297 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 298 |
+
|
| 299 |
+
attn_output, attn_weights = attention_interface(
|
| 300 |
+
self,
|
| 301 |
+
query_states,
|
| 302 |
+
key_states,
|
| 303 |
+
value_states,
|
| 304 |
+
attention_mask,
|
| 305 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 306 |
+
scaling=self.scaling,
|
| 307 |
+
**kwargs,
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 311 |
+
attn_output = self.o_proj(attn_output)
|
| 312 |
+
return attn_output, attn_weights
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
class MiniMaxM2DecoderLayer(GradientCheckpointingLayer):
|
| 316 |
+
def __init__(self, config: MiniMaxM2Config, layer_idx: int):
|
| 317 |
+
super().__init__()
|
| 318 |
+
self.hidden_size = config.hidden_size
|
| 319 |
+
|
| 320 |
+
self.self_attn = MiniMaxM2Attention(config, layer_idx)
|
| 321 |
+
|
| 322 |
+
self.block_sparse_moe = MiniMaxM2SparseMoeBlock(config)
|
| 323 |
+
self.input_layernorm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 324 |
+
self.post_attention_layernorm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 325 |
+
|
| 326 |
+
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
|
| 327 |
+
def forward(
|
| 328 |
+
self,
|
| 329 |
+
hidden_states: torch.Tensor,
|
| 330 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 331 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 332 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 333 |
+
past_key_values: Optional[Cache] = None,
|
| 334 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 335 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 336 |
+
) -> torch.FloatTensor:
|
| 337 |
+
residual = hidden_states
|
| 338 |
+
|
| 339 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 340 |
+
|
| 341 |
+
# Self Attention
|
| 342 |
+
hidden_states, _ = self.self_attn(
|
| 343 |
+
hidden_states=hidden_states,
|
| 344 |
+
position_embeddings=position_embeddings,
|
| 345 |
+
attention_mask=attention_mask,
|
| 346 |
+
position_ids=position_ids,
|
| 347 |
+
past_key_values=past_key_values,
|
| 348 |
+
cache_position=cache_position,
|
| 349 |
+
**kwargs,
|
| 350 |
+
)
|
| 351 |
+
hidden_states = residual + hidden_states
|
| 352 |
+
|
| 353 |
+
# Fully Connected
|
| 354 |
+
residual = hidden_states
|
| 355 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 356 |
+
hidden_states, _ = self.block_sparse_moe(hidden_states)
|
| 357 |
+
hidden_states = residual + hidden_states
|
| 358 |
+
|
| 359 |
+
return hidden_states
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
class MiniMaxM2RotaryEmbedding(nn.Module):
|
| 363 |
+
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 364 |
+
|
| 365 |
+
def __init__(self, config: MiniMaxM2Config, device=None):
|
| 366 |
+
super().__init__()
|
| 367 |
+
# BC: "rope_type" was originally "type"
|
| 368 |
+
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
|
| 369 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 370 |
+
else:
|
| 371 |
+
self.rope_type = "default"
|
| 372 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 373 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 374 |
+
|
| 375 |
+
self.config = config
|
| 376 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 377 |
+
|
| 378 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 379 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 380 |
+
self.original_inv_freq = self.inv_freq
|
| 381 |
+
|
| 382 |
+
@torch.no_grad()
|
| 383 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 384 |
+
def forward(self, x, position_ids):
|
| 385 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 386 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 387 |
+
|
| 388 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 389 |
+
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 390 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 391 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 392 |
+
cos = emb.cos() * self.attention_scaling
|
| 393 |
+
sin = emb.sin() * self.attention_scaling
|
| 394 |
+
|
| 395 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
@auto_docstring
|
| 399 |
+
class MiniMaxM2PreTrainedModel(PreTrainedModel):
|
| 400 |
+
config: MiniMaxM2Config
|
| 401 |
+
base_model_prefix = "model"
|
| 402 |
+
supports_gradient_checkpointing = True
|
| 403 |
+
_no_split_modules = ["MiniMaxM2DecoderLayer"]
|
| 404 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 405 |
+
_supports_flash_attn = True
|
| 406 |
+
_supports_sdpa = True
|
| 407 |
+
_supports_flex_attn = True
|
| 408 |
+
_can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
|
| 409 |
+
_supports_attention_backend = True
|
| 410 |
+
_can_record_outputs = {
|
| 411 |
+
"router_logits": OutputRecorder(MiniMaxM2SparseMoeBlock, index=1),
|
| 412 |
+
"hidden_states": MiniMaxM2DecoderLayer,
|
| 413 |
+
"attentions": MiniMaxM2Attention,
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
@auto_docstring
|
| 418 |
+
class MiniMaxM2Model(MiniMaxM2PreTrainedModel):
|
| 419 |
+
def __init__(self, config: MiniMaxM2Config):
|
| 420 |
+
super().__init__(config)
|
| 421 |
+
self.padding_idx = config.pad_token_id
|
| 422 |
+
self.vocab_size = config.vocab_size
|
| 423 |
+
|
| 424 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 425 |
+
self.layers = nn.ModuleList(
|
| 426 |
+
[MiniMaxM2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 427 |
+
)
|
| 428 |
+
self.norm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 429 |
+
self.rotary_emb = MiniMaxM2RotaryEmbedding(config=config)
|
| 430 |
+
self.gradient_checkpointing = False
|
| 431 |
+
|
| 432 |
+
# Initialize weights and apply final processing
|
| 433 |
+
self.post_init()
|
| 434 |
+
|
| 435 |
+
@check_model_inputs
|
| 436 |
+
@auto_docstring
|
| 437 |
+
def forward(
|
| 438 |
+
self,
|
| 439 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 440 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 441 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 442 |
+
past_key_values: Optional[Cache] = None,
|
| 443 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 444 |
+
use_cache: Optional[bool] = None,
|
| 445 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 446 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 447 |
+
) -> MoeModelOutputWithPast:
|
| 448 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 449 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 450 |
+
|
| 451 |
+
if use_cache and past_key_values is None:
|
| 452 |
+
past_key_values = DynamicCache(config=self.config)
|
| 453 |
+
|
| 454 |
+
if inputs_embeds is None:
|
| 455 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 456 |
+
|
| 457 |
+
if cache_position is None:
|
| 458 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 459 |
+
cache_position = torch.arange(
|
| 460 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 461 |
+
)
|
| 462 |
+
if position_ids is None:
|
| 463 |
+
position_ids = cache_position.unsqueeze(0)
|
| 464 |
+
|
| 465 |
+
mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
|
| 466 |
+
causal_mask = mask_function(
|
| 467 |
+
config=self.config,
|
| 468 |
+
input_embeds=inputs_embeds,
|
| 469 |
+
attention_mask=attention_mask,
|
| 470 |
+
cache_position=cache_position,
|
| 471 |
+
past_key_values=past_key_values,
|
| 472 |
+
position_ids=position_ids,
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
hidden_states = inputs_embeds
|
| 476 |
+
|
| 477 |
+
# create position embeddings to be shared across the decoder layers
|
| 478 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 479 |
+
|
| 480 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 481 |
+
hidden_states = decoder_layer(
|
| 482 |
+
hidden_states,
|
| 483 |
+
position_embeddings=position_embeddings,
|
| 484 |
+
attention_mask=causal_mask,
|
| 485 |
+
position_ids=position_ids,
|
| 486 |
+
past_key_values=past_key_values,
|
| 487 |
+
use_cache=use_cache,
|
| 488 |
+
cache_position=cache_position,
|
| 489 |
+
**kwargs,
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
hidden_states = self.norm(hidden_states)
|
| 493 |
+
|
| 494 |
+
return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
|
| 495 |
+
last_hidden_state=hidden_states,
|
| 496 |
+
past_key_values=past_key_values,
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def load_balancing_loss_func(
|
| 501 |
+
gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
|
| 502 |
+
num_experts: Optional[int] = None,
|
| 503 |
+
top_k=2,
|
| 504 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 505 |
+
) -> Union[torch.Tensor, int]:
|
| 506 |
+
r"""
|
| 507 |
+
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
|
| 508 |
+
|
| 509 |
+
See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
|
| 510 |
+
function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
|
| 511 |
+
experts is too unbalanced.
|
| 512 |
+
|
| 513 |
+
Args:
|
| 514 |
+
gate_logits:
|
| 515 |
+
Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
|
| 516 |
+
shape [batch_size X sequence_length, num_experts].
|
| 517 |
+
num_experts:
|
| 518 |
+
Number of experts
|
| 519 |
+
top_k:
|
| 520 |
+
The number of experts to route per-token, can be also interpreted as the `top-k` routing
|
| 521 |
+
parameter.
|
| 522 |
+
attention_mask (`torch.Tensor`, *optional*):
|
| 523 |
+
The attention_mask used in forward function
|
| 524 |
+
shape [batch_size X sequence_length] if not None.
|
| 525 |
+
|
| 526 |
+
Returns:
|
| 527 |
+
The auxiliary loss.
|
| 528 |
+
"""
|
| 529 |
+
if gate_logits is None or not isinstance(gate_logits, tuple):
|
| 530 |
+
return 0
|
| 531 |
+
|
| 532 |
+
if isinstance(gate_logits, tuple):
|
| 533 |
+
compute_device = gate_logits[0].device
|
| 534 |
+
concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
|
| 535 |
+
|
| 536 |
+
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
|
| 537 |
+
|
| 538 |
+
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
| 539 |
+
|
| 540 |
+
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
|
| 541 |
+
|
| 542 |
+
if attention_mask is None:
|
| 543 |
+
# Compute the percentage of tokens routed to each experts
|
| 544 |
+
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
|
| 545 |
+
|
| 546 |
+
# Compute the average probability of routing to these experts
|
| 547 |
+
router_prob_per_expert = torch.mean(routing_weights, dim=0)
|
| 548 |
+
else:
|
| 549 |
+
batch_size, sequence_length = attention_mask.shape
|
| 550 |
+
num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
|
| 551 |
+
|
| 552 |
+
# Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
|
| 553 |
+
expert_attention_mask = (
|
| 554 |
+
attention_mask[None, :, :, None, None]
|
| 555 |
+
.expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
|
| 556 |
+
.reshape(-1, top_k, num_experts)
|
| 557 |
+
.to(compute_device)
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
# Compute the percentage of tokens routed to each experts
|
| 561 |
+
tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
|
| 562 |
+
expert_attention_mask, dim=0
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
|
| 566 |
+
router_per_expert_attention_mask = (
|
| 567 |
+
attention_mask[None, :, :, None]
|
| 568 |
+
.expand((num_hidden_layers, batch_size, sequence_length, num_experts))
|
| 569 |
+
.reshape(-1, num_experts)
|
| 570 |
+
.to(compute_device)
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
# Compute the average probability of routing to these experts
|
| 574 |
+
router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
|
| 575 |
+
router_per_expert_attention_mask, dim=0
|
| 576 |
+
)
|
| 577 |
+
|
| 578 |
+
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
|
| 579 |
+
return overall_loss * num_experts
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
@auto_docstring
|
| 583 |
+
class MiniMaxM2ForCausalLM(MiniMaxM2PreTrainedModel, GenerationMixin):
|
| 584 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 585 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 586 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 587 |
+
|
| 588 |
+
def __init__(self, config):
|
| 589 |
+
super().__init__(config)
|
| 590 |
+
self.model = MiniMaxM2Model(config)
|
| 591 |
+
self.vocab_size = config.vocab_size
|
| 592 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 593 |
+
self.router_aux_loss_coef = config.router_aux_loss_coef
|
| 594 |
+
self.num_experts = config.num_local_experts
|
| 595 |
+
self.num_experts_per_tok = config.num_experts_per_tok
|
| 596 |
+
|
| 597 |
+
# Initialize weights and apply final processing
|
| 598 |
+
self.post_init()
|
| 599 |
+
|
| 600 |
+
@can_return_tuple
|
| 601 |
+
@auto_docstring
|
| 602 |
+
def forward(
|
| 603 |
+
self,
|
| 604 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 605 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 606 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 607 |
+
past_key_values: Optional[Cache] = None,
|
| 608 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 609 |
+
labels: Optional[torch.LongTensor] = None,
|
| 610 |
+
use_cache: Optional[bool] = None,
|
| 611 |
+
output_router_logits: Optional[bool] = None,
|
| 612 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 613 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 614 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 615 |
+
) -> MoeCausalLMOutputWithPast:
|
| 616 |
+
r"""
|
| 617 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 618 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 619 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 620 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 621 |
+
|
| 622 |
+
Example:
|
| 623 |
+
|
| 624 |
+
```python
|
| 625 |
+
>>> from transformers import AutoTokenizer, MiniMaxM2ForCausalLM
|
| 626 |
+
|
| 627 |
+
>>> model = MiniMaxM2ForCausalLM.from_pretrained("mistralai/MiniMaxM2-8x7B-v0.1")
|
| 628 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/MiniMaxM2-8x7B-v0.1")
|
| 629 |
+
|
| 630 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 631 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 632 |
+
|
| 633 |
+
>>> # Generate
|
| 634 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 635 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 636 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 637 |
+
```"""
|
| 638 |
+
|
| 639 |
+
output_router_logits = (
|
| 640 |
+
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
| 641 |
+
)
|
| 642 |
+
|
| 643 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 644 |
+
outputs: MoeModelOutputWithPast = self.model(
|
| 645 |
+
input_ids=input_ids,
|
| 646 |
+
attention_mask=attention_mask,
|
| 647 |
+
position_ids=position_ids,
|
| 648 |
+
past_key_values=past_key_values,
|
| 649 |
+
inputs_embeds=inputs_embeds,
|
| 650 |
+
use_cache=use_cache,
|
| 651 |
+
output_router_logits=output_router_logits,
|
| 652 |
+
cache_position=cache_position,
|
| 653 |
+
**kwargs,
|
| 654 |
+
)
|
| 655 |
+
|
| 656 |
+
hidden_states = outputs.last_hidden_state
|
| 657 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 658 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 659 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 660 |
+
|
| 661 |
+
loss = None
|
| 662 |
+
if labels is not None:
|
| 663 |
+
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
|
| 664 |
+
|
| 665 |
+
aux_loss = None
|
| 666 |
+
if output_router_logits:
|
| 667 |
+
aux_loss = load_balancing_loss_func(
|
| 668 |
+
outputs.router_logits,
|
| 669 |
+
self.num_experts,
|
| 670 |
+
self.num_experts_per_tok,
|
| 671 |
+
attention_mask,
|
| 672 |
+
)
|
| 673 |
+
if labels is not None:
|
| 674 |
+
loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
|
| 675 |
+
|
| 676 |
+
return MoeCausalLMOutputWithPast(
|
| 677 |
+
loss=loss,
|
| 678 |
+
aux_loss=aux_loss,
|
| 679 |
+
logits=logits,
|
| 680 |
+
past_key_values=outputs.past_key_values,
|
| 681 |
+
hidden_states=outputs.hidden_states,
|
| 682 |
+
attentions=outputs.attentions,
|
| 683 |
+
router_logits=outputs.router_logits,
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
class MiniMaxM2ForSequenceClassification(GenericForSequenceClassification, MiniMaxM2PreTrainedModel):
|
| 688 |
+
pass
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
class MiniMaxM2ForTokenClassification(GenericForTokenClassification, MiniMaxM2PreTrainedModel):
|
| 692 |
+
pass
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
class MiniMaxM2ForQuestionAnswering(GenericForQuestionAnswering, MiniMaxM2PreTrainedModel):
|
| 696 |
+
pass
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
__all__ = [
|
| 700 |
+
"MiniMaxM2ForCausalLM",
|
| 701 |
+
"MiniMaxM2ForQuestionAnswering",
|
| 702 |
+
"MiniMaxM2Model",
|
| 703 |
+
"MiniMaxM2PreTrainedModel",
|
| 704 |
+
"MiniMaxM2ForSequenceClassification",
|
| 705 |
+
"MiniMaxM2ForTokenClassification",
|
| 706 |
+
]
|