Gökdeniz Gülmez commited on
Commit
c4ce38b
·
1 Parent(s): 70b9ed8
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.vscode/settings.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "python-envs.defaultEnvManager": "ms-python.python:conda",
3
+ "python-envs.defaultPackageManager": "ms-python.python:conda",
4
+ "python-envs.pythonProjects": []
5
+ }
__pycache__/configuration_minimax.cpython-313.pyc ADDED
Binary file (10.1 kB). View file
 
__pycache__/modular_minimax.cpython-313.pyc ADDED
Binary file (28.5 kB). View file
 
config.json CHANGED
@@ -3,15 +3,15 @@
3
  "MiniMaxText01ForCausalLM"
4
  ],
5
  "attention_dropout": 0.0,
6
- "attn_type_list": [
7
- 0,
8
- 1,
9
- 0,
10
- 1
11
  ],
12
  "auto_map": {
13
- "AutoConfig": "configuration_minimax_text_01.MiniMaxText01Config",
14
- "AutoModelForCausalLM": "modeling_minimax_text_01.MiniMaxText01ForCausalLM"
15
  },
16
  "bos_token_id": null,
17
  "eos_token_id": 200020,
@@ -27,7 +27,7 @@
27
  "layernorm_mlp_alpha": 3.5565588200778455,
28
  "layernorm_mlp_beta": 1.0,
29
  "max_position_embeddings": 1024,
30
- "model_type": "minimax_text_01",
31
  "num_attention_heads": 4,
32
  "num_experts_per_tok": 1,
33
  "num_hidden_layers": 4,
 
3
  "MiniMaxText01ForCausalLM"
4
  ],
5
  "attention_dropout": 0.0,
6
+ "layer_types": [
7
+ "linear_attention",
8
+ "full_attention",
9
+ "linear_attention",
10
+ "full_attention"
11
  ],
12
  "auto_map": {
13
+ "AutoConfig": "configuration_minimax.MiniMaxConfig",
14
+ "AutoModelForCausalLM": "modeling_minimax.MiniMaxForCausalLM"
15
  },
16
  "bos_token_id": null,
17
  "eos_token_id": 200020,
 
27
  "layernorm_mlp_alpha": 3.5565588200778455,
28
  "layernorm_mlp_beta": 1.0,
29
  "max_position_embeddings": 1024,
30
+ "model_type": "minimax",
31
  "num_attention_heads": 4,
32
  "num_experts_per_tok": 1,
33
  "num_hidden_layers": 4,
configuration_minimax_text_01.py → configuration_minimax.py RENAMED
@@ -1,26 +1,48 @@
1
- """ MiniMaxText01 model configuration"""
2
-
3
- from transformers.configuration_utils import PretrainedConfig
4
- from transformers.utils import logging
5
-
6
-
7
- logger = logging.get_logger(__name__)
8
-
9
-
10
- class MiniMaxText01Config(PretrainedConfig):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  r"""
12
- This is the configuration class to store the configuration of a [`MiniMaxText01Model`]. It is used to instantiate an
13
- MiniMaxText01 model according to the specified arguments, defining the model architecture. Instantiating a configuration
14
- with the defaults will yield a similar configuration to that of the MiniMaxText01.
15
 
16
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
17
- documentation from [`PretrainedConfig`] for more information.
 
 
18
 
19
 
20
  Args:
21
  vocab_size (`int`, *optional*, defaults to 32000):
22
- Vocabulary size of the MiniMaxText01 model. Defines the number of different tokens that can be represented by the
23
- `inputs_ids` passed when calling [`MiniMaxText01Model`]
24
  hidden_size (`int`, *optional*, defaults to 4096):
25
  Dimension of the hidden representations.
26
  intermediate_size (`int`, *optional*, defaults to 14336):
@@ -32,14 +54,16 @@ class MiniMaxText01Config(PretrainedConfig):
32
  num_key_value_heads (`int`, *optional*, defaults to 8):
33
  This is the number of key_value heads that should be used to implement Grouped Query Attention. If
34
  `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
35
- `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
36
  converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
37
- by meanpooling all the original heads within that group. For more details checkout [this
38
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
 
 
39
  hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
40
  The non-linear activation function (function or string) in the decoder.
41
  max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
42
- The maximum sequence length that this model might ever be used with. MiniMaxText01's sliding window attention
43
  allows sequence of up to 4096*32 tokens.
44
  initializer_range (`float`, *optional*, defaults to 0.02):
45
  The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
@@ -68,28 +92,63 @@ class MiniMaxText01Config(PretrainedConfig):
68
  num_local_experts (`int`, *optional*, defaults to 8):
69
  Number of experts per Sparse MLP layer.
70
  output_router_logits (`bool`, *optional*, defaults to `False`):
71
- Whether or not the router logits should be returned by the model. Enabeling this will also
72
  allow the model to output the auxiliary loss. See [here]() for more details
73
  router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
74
  The aux loss factor for the total loss.
75
  router_jitter_noise (`float`, *optional*, defaults to 0.0):
76
  Amount of noise to add to the router.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  ```python
79
- >>> from transformers import MiniMaxText01Model, MiniMaxText01Config
80
 
81
- >>> # Initializing a MiniMaxText01 style configuration
82
- >>> configuration = MiniMaxText01Config()
83
 
84
- >>> # Initializing a model from the MiniMaxText01 style configuration
85
- >>> model = MiniMaxText01Model(configuration)
86
 
87
  >>> # Accessing the model configuration
88
  >>> configuration = model.config
89
  ```"""
90
 
91
- model_type = "MiniMaxText01"
92
  keys_to_ignore_at_inference = ["past_key_values"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  def __init__(
95
  self,
@@ -99,14 +158,15 @@ class MiniMaxText01Config(PretrainedConfig):
99
  num_hidden_layers=32,
100
  num_attention_heads=32,
101
  num_key_value_heads=8,
 
102
  hidden_act="silu",
103
  max_position_embeddings=4096 * 32,
104
  initializer_range=0.02,
105
  rms_norm_eps=1e-5,
106
  use_cache=True,
107
  pad_token_id=None,
108
- bos_token_id=None,
109
- eos_token_id=None,
110
  tie_word_embeddings=False,
111
  rope_theta=1e6,
112
  sliding_window=None,
@@ -116,8 +176,23 @@ class MiniMaxText01Config(PretrainedConfig):
116
  output_router_logits=False,
117
  router_aux_loss_coef=0.001,
118
  router_jitter_noise=0.0,
 
 
 
 
 
 
 
 
119
  **kwargs,
120
  ):
 
 
 
 
 
 
 
121
  self.vocab_size = vocab_size
122
  self.max_position_embeddings = max_position_embeddings
123
  self.hidden_size = hidden_size
@@ -137,16 +212,27 @@ class MiniMaxText01Config(PretrainedConfig):
137
  self.use_cache = use_cache
138
  self.rope_theta = rope_theta
139
  self.attention_dropout = attention_dropout
 
140
 
141
  self.num_experts_per_tok = num_experts_per_tok
142
  self.num_local_experts = num_local_experts
143
  self.output_router_logits = output_router_logits
144
  self.router_aux_loss_coef = router_aux_loss_coef
145
  self.router_jitter_noise = router_jitter_noise
146
- super().__init__(
147
- pad_token_id=pad_token_id,
148
- bos_token_id=bos_token_id,
149
- eos_token_id=eos_token_id,
150
- tie_word_embeddings=tie_word_embeddings,
151
- **kwargs,
152
- )
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/minimax/modular_minimax.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_minimax.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
9
+ #
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ from transformers import PretrainedConfig
23
+ try:
24
+ from transformers.configuration_utils import layer_type_validation
25
+ except ImportError:
26
+ # fallback for older transformers versions where layer_type_validation is unavailable
27
+ def layer_type_validation(x):
28
+ return x
29
+
30
+ class MiniMaxConfig(PretrainedConfig):
31
  r"""
32
+ This is the configuration class to store the configuration of a [`MiniMaxModel`]. It is used to instantiate an
33
+ MiniMax model according to the specified arguments, defining the model architecture. Instantiating a configuration
34
+ with the defaults will yield a similar configuration to that of the MiniMax.
35
 
36
+ [MiniMaxAI/MiniMax-Text-01-hf](https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf)
37
+
38
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
39
+ documentation from [`PreTrainedConfig`] for more information.
40
 
41
 
42
  Args:
43
  vocab_size (`int`, *optional*, defaults to 32000):
44
+ Vocabulary size of the MiniMax model. Defines the number of different tokens that can be represented by the
45
+ `inputs_ids` passed when calling [`MiniMaxModel`]
46
  hidden_size (`int`, *optional*, defaults to 4096):
47
  Dimension of the hidden representations.
48
  intermediate_size (`int`, *optional*, defaults to 14336):
 
54
  num_key_value_heads (`int`, *optional*, defaults to 8):
55
  This is the number of key_value heads that should be used to implement Grouped Query Attention. If
56
  `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
57
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
58
  converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
59
+ by meanpooling all the original heads within that group. For more details, check out [this
60
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
61
+ head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
62
+ The attention head dimension.
63
  hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
64
  The non-linear activation function (function or string) in the decoder.
65
  max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
66
+ The maximum sequence length that this model might ever be used with. MiniMax's sliding window attention
67
  allows sequence of up to 4096*32 tokens.
68
  initializer_range (`float`, *optional*, defaults to 0.02):
69
  The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
 
92
  num_local_experts (`int`, *optional*, defaults to 8):
93
  Number of experts per Sparse MLP layer.
94
  output_router_logits (`bool`, *optional*, defaults to `False`):
95
+ Whether or not the router logits should be returned by the model. Enabling this will also
96
  allow the model to output the auxiliary loss. See [here]() for more details
97
  router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
98
  The aux loss factor for the total loss.
99
  router_jitter_noise (`float`, *optional*, defaults to 0.0):
100
  Amount of noise to add to the router.
101
+ layer_types (`list`, *optional*):
102
+ Attention pattern for each layer.
103
+ block_size (`int`, *optional*, defaults to 256):
104
+ The length of each attention block, determining how queries, keys, and values
105
+ are grouped and processed for intra- and inter-block attention.
106
+ full_attn_alpha_factor (`float`, *optional*, defaults to 1):
107
+ Weight for residual value in residual connection after normal attention.
108
+ full_attn_beta_factor (`float`, *optional*, defaults to 1):
109
+ Weight for hidden state value in residual connection after normal attention.
110
+ linear_attn_alpha_factor (`float`, *optional*, defaults to 1):
111
+ Weight for residual value in residual connection after lightning attention.
112
+ linear_attn_beta_factor (`float`, *optional*, defaults to 1):
113
+ Weight for hidden state value in residual connection after lightning attention.
114
+ mlp_alpha_factor (`float`, *optional*, defaults to 1):
115
+ Weight for residual value in residual connection after MLP.
116
+ mlp_beta_factor (`float`, *optional*, defaults to 1):
117
+ Weight for hidden state value in residual connection after MLP.
118
 
119
  ```python
120
+ >>> from transformers import MiniMaxModel, MiniMaxConfig
121
 
122
+ >>> # Initializing a MiniMax style configuration
123
+ >>> configuration = MiniMaxConfig()
124
 
125
+ >>> # Initializing a model from the MiniMax style configuration
126
+ >>> model = MiniMaxModel(configuration)
127
 
128
  >>> # Accessing the model configuration
129
  >>> configuration = model.config
130
  ```"""
131
 
132
+ model_type = "minimax"
133
  keys_to_ignore_at_inference = ["past_key_values"]
134
+ base_model_tp_plan = {
135
+ "layers.*.self_attn.q_proj": "colwise",
136
+ "layers.*.self_attn.k_proj": "colwise",
137
+ "layers.*.self_attn.v_proj": "colwise",
138
+ "layers.*.self_attn.o_proj": "rowwise",
139
+ "layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts
140
+ "layers.*.block_sparse_moe.experts.*.w1": "colwise",
141
+ "layers.*.block_sparse_moe.experts.*.w2": "rowwise",
142
+ "layers.*.block_sparse_moe.experts.*.w3": "colwise",
143
+ }
144
+ base_model_pp_plan = {
145
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
146
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
147
+ "norm": (["hidden_states"], ["hidden_states"]),
148
+ }
149
+ attribute_map = {
150
+ "num_experts": "num_local_experts",
151
+ }
152
 
153
  def __init__(
154
  self,
 
158
  num_hidden_layers=32,
159
  num_attention_heads=32,
160
  num_key_value_heads=8,
161
+ head_dim=None,
162
  hidden_act="silu",
163
  max_position_embeddings=4096 * 32,
164
  initializer_range=0.02,
165
  rms_norm_eps=1e-5,
166
  use_cache=True,
167
  pad_token_id=None,
168
+ bos_token_id=1,
169
+ eos_token_id=2,
170
  tie_word_embeddings=False,
171
  rope_theta=1e6,
172
  sliding_window=None,
 
176
  output_router_logits=False,
177
  router_aux_loss_coef=0.001,
178
  router_jitter_noise=0.0,
179
+ layer_types=None,
180
+ block_size=256,
181
+ full_attn_alpha_factor=1,
182
+ full_attn_beta_factor=1,
183
+ linear_attn_alpha_factor=1,
184
+ linear_attn_beta_factor=1,
185
+ mlp_alpha_factor=1,
186
+ mlp_beta_factor=1,
187
  **kwargs,
188
  ):
189
+ super().__init__(
190
+ pad_token_id=pad_token_id,
191
+ bos_token_id=bos_token_id,
192
+ eos_token_id=eos_token_id,
193
+ tie_word_embeddings=tie_word_embeddings,
194
+ **kwargs,
195
+ )
196
  self.vocab_size = vocab_size
197
  self.max_position_embeddings = max_position_embeddings
198
  self.hidden_size = hidden_size
 
212
  self.use_cache = use_cache
213
  self.rope_theta = rope_theta
214
  self.attention_dropout = attention_dropout
215
+ self.head_dim = head_dim
216
 
217
  self.num_experts_per_tok = num_experts_per_tok
218
  self.num_local_experts = num_local_experts
219
  self.output_router_logits = output_router_logits
220
  self.router_aux_loss_coef = router_aux_loss_coef
221
  self.router_jitter_noise = router_jitter_noise
222
+ self.layer_types = layer_types
223
+ self.block_size = block_size
224
+ self.full_attn_alpha_factor = full_attn_alpha_factor
225
+ self.full_attn_beta_factor = full_attn_beta_factor
226
+ self.linear_attn_alpha_factor = linear_attn_alpha_factor
227
+ self.linear_attn_beta_factor = linear_attn_beta_factor
228
+ self.mlp_alpha_factor = mlp_alpha_factor
229
+ self.mlp_beta_factor = mlp_beta_factor
230
+
231
+ if self.layer_types is None:
232
+ self.layer_types = [
233
+ "full_attention" if bool((i + 1) % 2) else "linear_attention" for i in range(self.num_hidden_layers)
234
+ ]
235
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
236
+
237
+
238
+ __all__ = ["MiniMaxConfig"]
create.py CHANGED
@@ -4,23 +4,23 @@ from safetensors.torch import save_file
4
  import json
5
 
6
  # Add the directory containing your modeling.py and configuration.py to the Python path
7
- model_dir = "/Users/gokdenizgulmez/Desktop/mlx-lm/mlx_lm/MiniMiniMax01Text"
8
  sys.path.append(model_dir)
9
 
10
  # Import your custom model and configuration classes
11
- from modeling_minimax_text_01 import MiniMaxText01ForCausalLM
12
- from configuration_minimax_text_01 import MiniMaxText01Config
13
 
14
  # Load the configuration
15
- config_path = os.path.join(model_dir, "onfig.json")
16
  with open(config_path, 'r') as f:
17
  config_dict = json.load(f)
18
 
19
  # Create the configuration object
20
- config = MiniMaxText01Config(**config_dict)
21
 
22
  # Create the model
23
- small_model = MiniMaxText01ForCausalLM(config)
24
 
25
  # Print parameter count to verify
26
  param_count = sum(p.numel() for p in small_model.parameters())
 
4
  import json
5
 
6
  # Add the directory containing your modeling.py and configuration.py to the Python path
7
+ model_dir = "/Users/Goekdeniz.Guelmez@computacenter.com/Library/CloudStorage/OneDrive-COMPUTACENTER/Desktop/MiniMax01Text-Dev"
8
  sys.path.append(model_dir)
9
 
10
  # Import your custom model and configuration classes
11
+ from modular_minimax import MiniMaxForCausalLM
12
+ from configuration_minimax import MiniMaxConfig
13
 
14
  # Load the configuration
15
+ config_path = os.path.join(model_dir, "config.json")
16
  with open(config_path, 'r') as f:
17
  config_dict = json.load(f)
18
 
19
  # Create the configuration object
20
+ config = MiniMaxConfig(**config_dict)
21
 
22
  # Create the model
23
+ small_model = MiniMaxForCausalLM(config)
24
 
25
  # Print parameter count to verify
26
  param_count = sum(p.numel() for p in small_model.parameters())
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c67af0b37ff2ae7e3fee3c1e8bf04781f7068593773213c2cf9d7856b48d2e7a
3
- size 424436000
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c7797546b0e092d6b3236d1f9826af42bfe293592590e2032aafe77ba8592a4
3
+ size 423910680
modeling_minimax.py ADDED
@@ -0,0 +1,887 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/minimax/modular_minimax.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_minimax.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
9
+ #
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+ from typing import Callable, Optional, Union
24
+
25
+ import torch
26
+ import torch.nn.functional as F
27
+ from torch import nn
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.cache_utils import Cache, DynamicCache
31
+ from transformers.generation import GenerationMixin
32
+ from transformers.integrations import use_kernel_forward_from_hub
33
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
34
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
35
+ from transformers.modeling_layers import (
36
+ GenericForQuestionAnswering,
37
+ GenericForSequenceClassification,
38
+ GenericForTokenClassification,
39
+ GradientCheckpointingLayer,
40
+ )
41
+ from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
42
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
43
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
44
+ from transformers.processing_utils import Unpack
45
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
46
+ from transformers.utils.generic import OutputRecorder, check_model_inputs
47
+ from .configuration_minimax import MiniMaxConfig
48
+
49
+
50
+ @use_kernel_forward_from_hub("RMSNorm")
51
+ class MiniMaxRMSNorm(nn.Module):
52
+ def __init__(self, hidden_size, eps=1e-6):
53
+ """
54
+ MiniMaxRMSNorm is equivalent to T5LayerNorm
55
+ """
56
+ super().__init__()
57
+ self.weight = nn.Parameter(torch.ones(hidden_size))
58
+ self.variance_epsilon = eps
59
+
60
+ def forward(self, hidden_states):
61
+ input_dtype = hidden_states.dtype
62
+ hidden_states = hidden_states.to(torch.float32)
63
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
64
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
65
+ return self.weight * hidden_states.to(input_dtype)
66
+
67
+ def extra_repr(self):
68
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
69
+
70
+
71
+ class MiniMaxCache(DynamicCache):
72
+ def __init__(self):
73
+ super().__init__()
74
+ self.linear_cache: list[torch.Tensor] = []
75
+
76
+ def set_linear_cache(self, layer_idx, linear_cache):
77
+ # There may be skipped layers, fill them with empty lists
78
+ for _ in range(len(self.linear_cache), layer_idx + 1):
79
+ self.linear_cache.append([])
80
+ self.linear_cache[layer_idx] = linear_cache
81
+
82
+ def get_linear_cache(self, layer_idx: int):
83
+ if layer_idx < len(self):
84
+ return self.linear_cache[layer_idx]
85
+ return None
86
+
87
+ def __len__(self):
88
+ return max(super().__len__(), len(self.linear_cache))
89
+
90
+ def __getitem__(self, layer_idx: int):
91
+ if layer_idx < len(self.linear_cache) and self.linear_cache[layer_idx] != []:
92
+ return (self.linear_cache[layer_idx],)
93
+ return super().__getitem__(layer_idx)
94
+
95
+ def __iter__(self):
96
+ for layer_idx in range(len(self)):
97
+ yield self[layer_idx]
98
+
99
+ def batch_repeat_interleave(self, repeats: int):
100
+ for layer_idx in range(len(self)):
101
+ if self.linear_cache[layer_idx] != []:
102
+ self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0)
103
+ else:
104
+ self.layers[layer_idx].batch_repeat_interleave(repeats)
105
+
106
+ def batch_select_indices(self, indices: torch.Tensor):
107
+ for layer_idx in range(len(self)):
108
+ if self.linear_cache[layer_idx] != []:
109
+ self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...]
110
+ else:
111
+ self.layers[layer_idx].batch_select_indices(indices)
112
+
113
+ def crop(self, max_length: int):
114
+ raise RuntimeError("MiniMaxCache does not support `crop` method")
115
+
116
+
117
+ class MiniMaxLightningAttention(nn.Module):
118
+ def __init__(self, config: MiniMaxConfig, layer_idx: int):
119
+ super().__init__()
120
+ self.layer_idx = layer_idx
121
+ self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
122
+ self.num_attention_heads = config.num_attention_heads
123
+ self.num_hidden_layers = config.num_hidden_layers
124
+ self.block_size = config.block_size
125
+
126
+ self.act_fn = ACT2FN[config.hidden_act]
127
+ self.norm = MiniMaxRMSNorm(self.head_dim * self.num_attention_heads)
128
+ self.qkv_proj = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim * 3, bias=False)
129
+ self.out_proj = nn.Linear(self.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
130
+ self.output_gate = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
131
+
132
+ slope_rate = self.get_slope_rate()
133
+ query_decay, key_decay, diagonal_decay = self.decay_factors(slope_rate)
134
+
135
+ self.register_buffer("slope_rate", slope_rate)
136
+ self.register_buffer("query_decay", query_decay)
137
+ self.register_buffer("key_decay", key_decay)
138
+ self.register_buffer("diagonal_decay", diagonal_decay)
139
+
140
+ def get_slope_rate(self):
141
+ base = 1 / (2 ** (8 / self.num_attention_heads))
142
+ exponent = torch.arange(self.num_attention_heads) + 1
143
+ factor = 1 - self.layer_idx / (self.num_hidden_layers - 1 + 1e-5) + 1e-5
144
+
145
+ rate = base**exponent
146
+ rate = rate * factor
147
+ rate = rate[:, None, None]
148
+
149
+ return rate
150
+
151
+ def decay_factors(self, slope_rate):
152
+ block_size_range = torch.arange(self.block_size) + 1
153
+
154
+ query_decay = torch.exp(-slope_rate * block_size_range[:, None])
155
+ key_decay = torch.exp(-slope_rate * (self.block_size - block_size_range[:, None]))
156
+
157
+ diagonal_decay = block_size_range[:, None] - block_size_range[None, :]
158
+ diagonal_decay = diagonal_decay[None, None, :, :]
159
+ diagonal_decay = slope_rate * diagonal_decay
160
+ diagonal_decay = torch.where(diagonal_decay >= 0, -diagonal_decay, float("-inf"))
161
+ diagonal_decay = torch.exp(diagonal_decay)
162
+
163
+ return query_decay, key_decay, diagonal_decay
164
+
165
+ def forward(
166
+ self,
167
+ hidden_states: torch.Tensor,
168
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
169
+ attention_mask: Optional[torch.Tensor],
170
+ past_key_values: Optional[Cache] = None,
171
+ cache_position: Optional[torch.LongTensor] = None,
172
+ **kwargs: Unpack[FlashAttentionKwargs],
173
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
174
+ batch_size, seq_len, hidden_size = hidden_states.shape
175
+ num_blocks = (seq_len + self.block_size - 1) // self.block_size
176
+
177
+ qkv_states = self.act_fn(self.qkv_proj(hidden_states))
178
+ qkv_states = qkv_states.reshape(batch_size, seq_len, self.num_attention_heads, 3 * self.head_dim)
179
+
180
+ query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=3)
181
+
182
+ query_states = query_states.transpose(1, 2)
183
+ key_states = key_states.transpose(1, 2)
184
+ value_states = value_states.transpose(1, 2)
185
+
186
+ # calculated (K.T @ V) and saved as cache
187
+ attn_weights_inter = None
188
+ if past_key_values is not None:
189
+ attn_weights_inter = past_key_values.get_linear_cache(self.layer_idx)
190
+
191
+ if attn_weights_inter is None:
192
+ attn_weights_inter = torch.zeros(batch_size, self.num_attention_heads, self.head_dim, self.head_dim).to(
193
+ value_states
194
+ )
195
+
196
+ # apply attention_mask
197
+ if attention_mask is not None:
198
+ attention_mask = attention_mask.to(dtype=torch.bool) # Ensure it's a boolean tensor
199
+ value_states = value_states.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(-1), 0)
200
+
201
+ attn_output = []
202
+ for i in range(num_blocks):
203
+ start_idx = i * self.block_size
204
+ end_idx = min(start_idx + self.block_size, seq_len)
205
+ current_block_size = end_idx - start_idx
206
+
207
+ current_query_states = query_states[:, :, start_idx:end_idx]
208
+ current_key_states = key_states[:, :, start_idx:end_idx]
209
+ current_value_states = value_states[:, :, start_idx:end_idx]
210
+
211
+ current_query_decay = self.query_decay[:, :current_block_size]
212
+ current_key_decay = self.key_decay[:, -current_block_size:]
213
+ current_diagonal_decay = self.diagonal_decay[:, :, :current_block_size, :current_block_size]
214
+ block_decay = torch.exp(-self.slope_rate * current_block_size)
215
+
216
+ # intra: ( Q @ K.T ) @ V -> QK * V
217
+ attn_weights_intra = torch.matmul(current_query_states, current_key_states.transpose(-1, -2))
218
+ attn_output_intra = torch.matmul(attn_weights_intra * current_diagonal_decay, current_value_states)
219
+
220
+ # inter: Q @ ( K.T @ V ) -> Q * KV
221
+ attn_output_inter = torch.matmul(current_query_states * current_query_decay, attn_weights_inter)
222
+
223
+ # final attention output
224
+ current_attn_output = attn_output_inter + attn_output_intra
225
+ attn_output.append(current_attn_output)
226
+
227
+ # calculate attn_weights_inter for next block or cache
228
+ next_attn_weights_inter = torch.matmul(
229
+ (current_key_states * current_key_decay).transpose(-1, -2), current_value_states
230
+ )
231
+ attn_weights_inter = attn_weights_inter * block_decay + next_attn_weights_inter
232
+
233
+ else:
234
+ ratio = torch.exp(-self.slope_rate)
235
+ attn_output = []
236
+ for i in range(seq_len):
237
+ current_query_states = query_states[:, :, i : i + 1]
238
+ current_key_states = key_states[:, :, i : i + 1]
239
+ current_value_states = value_states[:, :, i : i + 1]
240
+
241
+ current_attn_weights_inter = torch.matmul(current_key_states.transpose(-1, -2), current_value_states)
242
+ attn_weights_inter = ratio * attn_weights_inter + current_attn_weights_inter
243
+ current_attn_output = torch.matmul(current_query_states, attn_weights_inter)
244
+
245
+ attn_output.append(current_attn_output)
246
+
247
+ # concatenate attention outputs over all blocks
248
+ attn_output = torch.cat(attn_output, dim=-2)
249
+
250
+ # final output projection
251
+ attn_output = attn_output.transpose(1, 2)
252
+ attn_output = attn_output.reshape(batch_size, seq_len, self.num_attention_heads * self.head_dim)
253
+ attn_output = self.norm(attn_output)
254
+ attn_output = F.sigmoid(self.output_gate(hidden_states)) * attn_output
255
+ attn_output = self.out_proj(attn_output)
256
+
257
+ # update cache
258
+ if past_key_values is not None:
259
+ past_key_values.set_linear_cache(self.layer_idx, attn_weights_inter)
260
+
261
+ return attn_output, attn_weights_inter
262
+
263
+
264
def rotate_half(x):
    """Rotate half the hidden dims of the input (RoPE helper).

    Splits the last dimension in two halves and returns `[-second, first]`,
    which is the 90-degree rotation used by rotary position embeddings.
    """
    half = x.shape[-1] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]
    return torch.cat((-second_half, first_half), dim=-1)
269
+
270
+
271
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they
            broadcast against q/k. Use 1 for tensors shaped
            [batch_size, heads, seq_len, head_dim] and 2 for
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    # Same rotation applied to both tensors: t -> t*cos + rotate_half(t)*sin
    rotated_q, rotated_k = ((t * cos) + (rotate_half(t) * sin) for t in (q, k))
    return rotated_q, rotated_k
296
+
297
+
298
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). Expands
    grouped KV heads from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim) without copying until the
    final reshape.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        # Nothing to repeat — hand the tensor back untouched.
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
308
+
309
+
310
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (non-fused) scaled dot-product attention with grouped-KV support.

    Returns the attention output transposed back to (batch, seq, heads, head_dim)
    together with the (post-dropout) attention probabilities.
    """
    # Expand grouped KV heads so they line up one-to-one with the query heads.
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # The mask is an additive bias; crop it to the current key length.
        scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]

    # Softmax in float32 for numerical stability, then cast back.
    attn_weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
    return attn_output, attn_weights
334
+
335
+
336
class MiniMaxAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: MiniMaxConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Fall back to hidden_size // num_heads when the config has no explicit head_dim.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        query_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, query_width, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=False)
        self.o_proj = nn.Linear(query_width, config.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, self.head_dim)

        # Project, split into heads, and move heads in front of the sequence axis.
        query_states = self.q_proj(hidden_states).view(per_head_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(per_head_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(per_head_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Select the configured attention kernel; eager is the fallback.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=getattr(self.config, "sliding_window", None),  # main diff with Llama
            **kwargs,
        )

        # Merge heads back into the hidden dimension and project out.
        attn_output = attn_output.reshape(*leading_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
396
+
397
+
398
class MiniMaxMLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward network used as a single expert."""

    def __init__(self, config: MiniMaxConfig):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size

        # w1 = gate projection, w3 = up projection, w2 = down projection.
        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        gated = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        return self.w2(gated)
414
+
415
+
416
class MiniMaxExperts(nn.ModuleList):
    """
    ModuleList of experts.
    """

    def __init__(self, config: MiniMaxConfig):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_local_experts
        self.extend(MiniMaxMLP(config) for _ in range(self.num_experts))

    def forward(
        self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch_size * sequence_length, hidden_dim)
            selected_experts: (batch_size * sequence_length, top_k)
            routing_weights: (batch_size * sequence_length, top_k)
        Returns:
            (batch_size * sequence_length, hidden_dim)
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        # one_hot + permute -> (num_experts, top_k, num_tokens): each expert row
        # marks which tokens (and in which top-k slot) it must process.
        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)

        # Only visit experts that received at least one token.
        hit_experts = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in hit_experts:
            slot_idx, token_idx = torch.where(expert_mask[expert_idx].squeeze(0))
            expert_input = hidden_states[None, token_idx].reshape(-1, hidden_states.shape[-1])
            weighted_output = self[expert_idx](expert_input) * top_k_weights[token_idx, slot_idx, None]
            # Scatter-add each expert's contribution back to its tokens.
            final_hidden_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
        return final_hidden_states
449
+
450
+
451
class MiniMaxSparseMoeBlock(nn.Module):
    """Sparse Mixture-of-Experts block: routes each token to its top-k experts.

    Fix over the original: the training-time jitter noise is now applied
    out-of-place. The previous in-place ``hidden_states *= ...`` mutated the
    caller's tensor, which in ``MiniMaxDecoderLayer`` aliases the residual
    saved immediately before this block — corrupting the residual path when
    ``router_jitter_noise > 0`` during training.
    """

    def __init__(self, config):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.jitter_noise = config.router_jitter_noise
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.experts = MiniMaxExperts(config)

    def route_tokens_to_experts(self, router_logits):
        """Return (top_k_index, top_k_weights) for each token.

        Softmax is computed in float32 for numerical stability; the selected
        top-k weights are renormalized so they sum to 1 per token.
        """
        routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1)
        top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1)
        top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
        return top_k_index, top_k_weights.to(router_logits.dtype)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        if self.training and self.jitter_noise > 0:
            # Out-of-place multiply: never mutate the caller's tensor (it may
            # alias the residual saved by the decoder layer).
            noise = torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
            hidden_states = hidden_states * noise
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        router_logits = self.gate(hidden_states)
        top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
        hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype))
        hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return hidden_states
475
+
476
+
477
class MiniMaxDecoderLayer(GradientCheckpointingLayer):
    """Single MiniMax decoder block.

    Pre-norm (lightning or full) attention followed by a pre-norm sparse-MoE
    MLP; each sub-block uses a weighted residual connection of the form
    ``alpha * residual + beta * output``, with per-sub-block factors from the
    config.

    Fix over the original: the attention module is now built exactly once.
    The previous code unconditionally constructed a ``MiniMaxAttention`` and
    then immediately replaced it inside the ``if``/``else``, allocating (and
    initializing) a full set of attention weights that was thrown away.
    """

    def __init__(self, config: MiniMaxConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.block_sparse_moe = MiniMaxSparseMoeBlock(config)
        self.input_layernorm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.layer_idx = layer_idx
        self.layer_type = config.layer_types[layer_idx]
        self.mlp_alpha_factor = config.mlp_alpha_factor
        self.mlp_beta_factor = config.mlp_beta_factor

        # Build the attention module once, based on the per-layer type.
        if self.layer_type == "linear_attention":
            self.self_attn = MiniMaxLightningAttention(config, layer_idx)
            self.attn_alpha_factor = config.linear_attn_alpha_factor
            self.attn_beta_factor = config.linear_attn_beta_factor
        else:
            self.self_attn = MiniMaxAttention(config, layer_idx)
            self.attn_alpha_factor = config.full_attn_alpha_factor
            self.attn_beta_factor = config.full_attn_beta_factor

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # NOTE: the residual is taken *after* the input layernorm (MiniMax
        # architecture choice — differs from the usual pre-norm transformer).
        hidden_states = self.input_layernorm(hidden_states)
        residual = hidden_states
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual * self.attn_alpha_factor + hidden_states * self.attn_beta_factor

        # MoE MLP sub-block with its own weighted residual.
        hidden_states = self.post_attention_layernorm(hidden_states)
        residual = hidden_states
        hidden_states = self.block_sparse_moe(hidden_states)
        hidden_states = residual * self.mlp_alpha_factor + hidden_states * self.mlp_beta_factor

        return hidden_states
532
+
533
+
534
@auto_docstring
class MiniMaxPreTrainedModel(PreTrainedModel):
    config: MiniMaxConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Keep each decoder layer on a single device when sharding the model.
    _no_split_modules = ["MiniMaxDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    # Supported attention backends.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    # NOTE(review): fullgraph compile is disabled — presumably because of the
    # custom per-layer-type cache handling; confirm before re-enabling.
    _can_compile_fullgraph = False
    _supports_attention_backend = True
    # Hooks used by `check_model_inputs` to record intermediate outputs.
    _can_record_outputs = {
        "router_logits": OutputRecorder(nn.Linear, layer_name="block_sparse_moe.gate", index=0),
        "hidden_states": MiniMaxDecoderLayer,
        "attentions": [MiniMaxAttention, MiniMaxLightningAttention],
    }
551
+
552
+
553
class MiniMaxRotaryEmbedding(nn.Module):
    """Computes the RoPE cos/sin tables for a batch of position ids."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: MiniMaxConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so dynamic-rope updates can restore the original frequencies.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        batch = position_ids.shape[0]
        freq_bands = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            angles = (freq_bands.float() @ positions.float()).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
587
+
588
+
589
@auto_docstring
class MiniMaxModel(MiniMaxPreTrainedModel):
    """Bare MiniMax decoder stack: embeddings, decoder layers, final norm."""

    def __init__(self, config: MiniMaxConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            MiniMaxDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        )
        self.norm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = MiniMaxRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[MiniMaxCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        # Exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # MiniMax requires its own cache class (linear-attention state + KV cache).
        if use_cache and past_key_values is None:
            past_key_values = MiniMaxCache()
        elif use_cache and not isinstance(past_key_values, MiniMaxCache):
            raise ValueError(
                f"MiniMax uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}."
            )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
        causal_mask = mask_function(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers:
            if decoder_layer.layer_type == "full_attention":
                layer_mask = causal_mask
            else:
                # lightning attention uses original attention_mask, and uses it only for the first step
                layer_mask = attention_mask

            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=layer_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
679
+
680
+
681
def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits:
            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts].
        num_experts:
            Number of experts
        top_k:
            The number of experts to route per-token, can be also interpreted as the `top-k` routing
            parameter.
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None.

    Returns:
        The auxiliary loss.
    """
    # Router logits were not collected (or not in the expected tuple form): no loss.
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0

    # (The original wrapped the lines below in a second `isinstance(gate_logits,
    # tuple)` check, which is always true after the guard above.)
    compute_device = gate_logits[0].device
    concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts
761
+
762
+
763
@auto_docstring
class MiniMaxForCausalLM(MiniMaxPreTrainedModel, GenerationMixin):
    """MiniMax decoder with a language-modeling head and MoE auxiliary loss."""

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = MiniMaxModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Router bookkeeping for the auxiliary load-balancing loss.
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MiniMaxForCausalLM

        >>> model = MiniMaxForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        if output_router_logits is None:
            output_router_logits = self.config.output_router_logits

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                # make sure the auxiliary loss resides on the same device as the LM loss
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
866
+
867
+
868
class MiniMaxForSequenceClassification(GenericForSequenceClassification, MiniMaxPreTrainedModel):
    """MiniMax backbone with the generic sequence-classification head."""
870
+
871
+
872
class MiniMaxForTokenClassification(GenericForTokenClassification, MiniMaxPreTrainedModel):
    """MiniMax backbone with the generic token-classification head."""
874
+
875
+
876
class MiniMaxForQuestionAnswering(GenericForQuestionAnswering, MiniMaxPreTrainedModel):
    """MiniMax backbone with the generic extractive question-answering head."""
878
+
879
+
880
# Public API of this module.
__all__ = [
    "MiniMaxPreTrainedModel",
    "MiniMaxModel",
    "MiniMaxForCausalLM",
    "MiniMaxForSequenceClassification",
    "MiniMaxForTokenClassification",
    "MiniMaxForQuestionAnswering",
]
modeling_minimax_text_01.py DELETED
@@ -1,1701 +0,0 @@
1
- """ PyTorch MiniMaxText01 model."""
2
- import inspect
3
- import math
4
- import warnings
5
- from typing import List, Optional, Tuple, Union
6
- import os
7
- import copy
8
- import torch
9
- import torch.nn.functional as F
10
- import torch.utils.checkpoint
11
- from torch import nn
12
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
13
- from einops import rearrange, repeat
14
- from transformers.activations import ACT2FN
15
- from transformers.cache_utils import Cache, DynamicCache
16
- from transformers.modeling_attn_mask_utils import (
17
- _prepare_4d_causal_attention_mask,
18
- )
19
- from transformers.modeling_outputs import (
20
- MoeCausalLMOutputWithPast,
21
- MoeModelOutputWithPast,
22
- SequenceClassifierOutputWithPast,
23
- )
24
- from transformers.modeling_utils import PreTrainedModel
25
- from transformers.utils import (
26
- add_start_docstrings,
27
- add_start_docstrings_to_model_forward,
28
- is_flash_attn_2_available,
29
- is_flash_attn_greater_or_equal_2_10,
30
- logging,
31
- replace_return_docstrings,
32
- )
33
- from transformers.utils.import_utils import is_torch_fx_available
34
- from .configuration_minimax_text_01 import MiniMaxText01Config
35
-
36
- if is_flash_attn_2_available():
37
- from flash_attn import flash_attn_func, flash_attn_varlen_func
38
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
39
-
40
- _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
41
-
42
- # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
43
- # It means that the function will not be traced through and simply appear as a node in the graph.
44
- if is_torch_fx_available():
45
- _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
46
-
47
- use_triton = eval(os.environ.get("use_triton", default="False"))
48
- debug = eval(os.environ.get("debug", default="False"))
49
- do_eval = eval(os.environ.get("do_eval", default="False"))
50
- eval_and_not_generate = eval(os.environ.get("eval_and_not_generate", default="False"))
51
- BLOCK = 256
52
-
53
- logger = logging.get_logger(__name__)
54
-
55
- _CONFIG_FOR_DOC = "MiniMaxText01Config"
56
-
57
-
58
- def get_activation_fn(activation):
59
- if debug:
60
- logger.info(f"activation: {activation}")
61
- if activation == "gelu":
62
- return F.gelu
63
- elif activation == "relu":
64
- return F.relu
65
- elif activation == "elu":
66
- return F.elu
67
- elif activation == "sigmoid":
68
- return F.sigmoid
69
- elif activation == "exp":
70
-
71
- def f(x):
72
- with torch.no_grad():
73
- x_max = torch.max(x, dim=-1, keepdims=True).values
74
- y = torch.exp(x - x_max)
75
-
76
- return y
77
-
78
- return f
79
- elif activation == "leak":
80
- return F.leaky_relu
81
- elif activation == "1+elu":
82
-
83
- def f(x):
84
- return 1 + F.elu(x)
85
-
86
- return f
87
- elif activation == "2+elu":
88
-
89
- def f(x):
90
- return 2 + F.elu(x)
91
-
92
- return f
93
- elif activation == "silu" or activation == "swish":
94
- return F.silu
95
- elif activation == "sine":
96
- return torch.sin
97
- else:
98
- logger.info(
99
- f"activation: does not support {activation}, use Identity!!!")
100
- return lambda x: x
101
-
102
-
103
- def load_balancing_loss_func(
104
- gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2,
105
- attention_mask: Optional[torch.Tensor] = None
106
- ) -> float:
107
- r"""
108
- Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
109
-
110
- See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
111
- function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
112
- experts is too unbalanced.
113
-
114
- Args:
115
- gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
116
- Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
117
- shape [batch_size X sequence_length, num_experts].
118
- attention_mask (`torch.Tensor`, None):
119
- The attention_mask used in forward function
120
- shape [batch_size X sequence_length] if not None.
121
- num_experts (`int`, *optional*):
122
- Number of experts
123
-
124
- Returns:
125
- The auxiliary loss.
126
- """
127
- if gate_logits is None or not isinstance(gate_logits, tuple):
128
- return 0
129
-
130
- if isinstance(gate_logits, tuple):
131
- compute_device = gate_logits[0].device
132
- concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
133
-
134
- routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
135
-
136
- _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
137
-
138
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
139
-
140
- if attention_mask is None:
141
- # Compute the percentage of tokens routed to each experts
142
- tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
143
-
144
- # Compute the average probability of routing to these experts
145
- router_prob_per_expert = torch.mean(routing_weights, dim=0)
146
- else:
147
- batch_size, sequence_length = attention_mask.shape
148
- num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
149
-
150
- # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
151
- expert_attention_mask = (
152
- attention_mask[None, :, :, None, None]
153
- .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
154
- .reshape(-1, top_k, num_experts)
155
- .to(compute_device)
156
- )
157
-
158
- # Compute the percentage of tokens routed to each experts
159
- tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
160
- expert_attention_mask, dim=0
161
- )
162
-
163
- # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
164
- router_per_expert_attention_mask = (
165
- attention_mask[None, :, :, None]
166
- .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
167
- .reshape(-1, num_experts)
168
- .to(compute_device)
169
- )
170
-
171
- # Compute the average probability of routing to these experts
172
- router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
173
- router_per_expert_attention_mask, dim=0
174
- )
175
-
176
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
177
- return overall_loss * num_experts
178
-
179
-
180
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
181
- def _get_unpad_data(attention_mask):
182
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
183
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
184
- max_seqlen_in_batch = seqlens_in_batch.max().item()
185
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
186
- return (
187
- indices,
188
- cu_seqlens,
189
- max_seqlen_in_batch,
190
- )
191
-
192
-
193
- class GLU(nn.Module):
194
-
195
- def __init__(self, d1, d2, bias=False):
196
- super().__init__()
197
-
198
- self.l1 = nn.Linear(d1, d2, bias=bias)
199
- self.l2 = nn.Linear(d1, d2, bias=bias)
200
- self.l3 = nn.Linear(d2, d1, bias=bias)
201
-
202
- def forward(self, x):
203
- o1 = self.l1(x)
204
- o2 = self.l2(x)
205
- output = o1 * o2
206
- output = self.l3(output)
207
- return output
208
-
209
-
210
- class MiniMaxText01LightningAttention(nn.Module):
211
- def __init__(self, config: MiniMaxText01Config, layer_idx: Optional[int] = None):
212
- super().__init__()
213
- bias = False
214
- self.hidden_size = config.hidden_size
215
- self.num_heads = config.num_attention_heads
216
- self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads)
217
-
218
- self.out_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=bias)
219
- self.act = get_activation_fn(config.hidden_act)
220
- self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads)
221
-
222
- self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias)
223
- self.output_gate = nn.Linear(self.hidden_size, self.head_dim * self.num_heads, bias=bias)
224
-
225
- # for inference only
226
- self.offset = 0
227
- self.layer_idx = layer_idx
228
-
229
- def forward(
230
- self,
231
- hidden_states,
232
- attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m)
233
- output_attentions: bool = False,
234
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
235
- use_cache: bool = False,
236
- slope_rate: Optional[torch.Tensor] = None,
237
- **kwargs
238
- ):
239
- if (not self.training) and (not do_eval):
240
- return self.inference(
241
- hidden_states,
242
- attn_mask,
243
- output_attentions,
244
- past_key_value,
245
- use_cache,
246
- slope_rate,
247
- )
248
-
249
- def inference(
250
- self,
251
- x,
252
- attn_mask: Optional[torch.Tensor] = None, # (b, n)
253
- output_attentions: bool = False,
254
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
255
- use_cache: bool = False,
256
- slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
257
- ):
258
- # x: b n d
259
- b, n, d = x.shape
260
- # linear map
261
- qkv = self.act(self.qkv_proj(x))
262
- new_shape = qkv.size()[:-1] + (self.num_heads, -1)
263
- qkv = qkv.view(*new_shape)
264
- q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3)
265
- q = q.transpose(1, 2)
266
- k = k.transpose(1, 2)
267
- v = v.transpose(1, 2)
268
-
269
- if past_key_value is None:
270
- self.offset = q.shape[-2]
271
- else:
272
- self.offset += 1
273
-
274
- # for align with metaseq
275
- ratio = torch.exp(-slope_rate)
276
-
277
- # only use for the first time
278
- if past_key_value is None:
279
- slope_rate = slope_rate.to(torch.float32)
280
- if attn_mask is not None:
281
- v = v.masked_fill((1 - attn_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0)
282
- NUM_BLOCK = (n + BLOCK - 1) // BLOCK
283
- b, h, n, d = q.shape
284
- e = v.shape[-1]
285
- # other
286
- array = torch.arange(BLOCK).to(q) + 1
287
- q_decay = torch.exp(-slope_rate * array.reshape(-1, 1))
288
- k_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1)))
289
- index = array[:, None] - array[None, :]
290
- s_index = slope_rate * index[
291
- None,
292
- None,
293
- ]
294
- s_index = torch.where(index >= 0, -s_index, float("-inf"))
295
- diag_decay = torch.exp(s_index)
296
-
297
- kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device)
298
- output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
299
- for i in range(NUM_BLOCK):
300
- si = i * BLOCK
301
- ei = min(si + BLOCK, n)
302
- m = ei - si
303
- qi = q[:, :, si:ei].contiguous()
304
- ki = k[:, :, si:ei].contiguous()
305
- vi = v[:, :, si:ei].contiguous()
306
- qkv_none_diag = torch.matmul(qi * q_decay[:, :m], kv).to(torch.float32)
307
-
308
- # diag
309
- qk = torch.matmul(qi, ki.transpose(-1, -2)).to(torch.float32) * diag_decay[:, :, :m, :m]
310
- qkv_diag = torch.matmul(qk, vi.to(torch.float32))
311
- block_decay = torch.exp(-slope_rate * m)
312
- output[:, :, si:ei] = qkv_none_diag + qkv_diag
313
- kv = block_decay * kv + torch.matmul((ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi)
314
-
315
- else:
316
- kv = past_key_value
317
- output = []
318
- for i in range(n):
319
- kv = ratio * kv + torch.einsum(
320
- "... n d, ... n e -> ... d e",
321
- k[:, :, i:i + 1],
322
- v[:, :, i:i + 1],
323
- )
324
- qkv = torch.einsum("... n e, ... e d -> ... n d", q[:, :, i:i + 1], kv.to(q.dtype))
325
- output.append(qkv)
326
- output = torch.concat(output, dim=-2)
327
- # reshape
328
- output = rearrange(output, "b h n d -> b n (h d)")
329
- # normalize
330
- output = self.norm(output)
331
- # gate
332
- output = F.sigmoid(self.output_gate(x)) * output
333
- # outproj
334
- output = self.out_proj(output)
335
-
336
- attn_weights = None
337
-
338
- return output, attn_weights, kv
339
-
340
-
341
- # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxText01
342
- class MiniMaxText01RMSNorm(nn.Module):
343
- def __init__(self, hidden_size, eps=1e-6):
344
- """
345
- MiniMaxText01RMSNorm is equivalent to T5LayerNorm
346
- """
347
- super().__init__()
348
- self.weight = nn.Parameter(torch.ones(hidden_size))
349
- self.variance_epsilon = eps
350
-
351
- def forward(self, hidden_states):
352
- input_dtype = hidden_states.dtype
353
- hidden_states = hidden_states.to(torch.float32)
354
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
355
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
356
- return self.weight * hidden_states.to(input_dtype)
357
-
358
-
359
- # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->MiniMaxText01
360
- class MiniMaxText01RotaryEmbedding(nn.Module):
361
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
362
- super().__init__()
363
-
364
- self.dim = dim
365
- self.max_position_embeddings = max_position_embeddings
366
- self.base = base
367
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
368
- self.register_buffer("inv_freq", inv_freq, persistent=False)
369
-
370
- # Build here to make `torch.jit.trace` work.
371
- self._set_cos_sin_cache(
372
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
373
- )
374
-
375
- def _set_cos_sin_cache(self, seq_len, device, dtype):
376
- self.max_seq_len_cached = seq_len
377
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
378
-
379
- freqs = torch.outer(t, self.inv_freq)
380
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
381
- emb = torch.cat((freqs, freqs), dim=-1)
382
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
383
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
384
-
385
- def forward(self, x, seq_len=None):
386
- # x: [bs, num_attention_heads, seq_len, head_size]
387
- if seq_len > self.max_seq_len_cached:
388
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
389
-
390
- return (
391
- self.cos_cached[:seq_len].to(dtype=torch.float32),
392
- self.sin_cached[:seq_len].to(dtype=torch.float32),
393
- )
394
-
395
-
396
- # Copied from transformers.models.llama.modeling_llama.rotate_half
397
- def rotate_half(x):
398
- """Rotates half the hidden dims of the input."""
399
- x1 = x[..., : x.shape[-1] // 2]
400
- x2 = x[..., x.shape[-1] // 2:]
401
- return torch.cat((-x2, x1), dim=-1)
402
-
403
-
404
- # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
405
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
406
- """Applies Rotary Position Embedding to the query and key tensors.
407
-
408
- Args:
409
- q (`torch.Tensor`): The query tensor.
410
- k (`torch.Tensor`): The key tensor.
411
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
412
- sin (`torch.Tensor`): The sine part of the rotary embedding.
413
- position_ids (`torch.Tensor`):
414
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
415
- used to pass offsetted position ids when working with a KV-cache.
416
- unsqueeze_dim (`int`, *optional*, defaults to 1):
417
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
418
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
419
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
420
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
421
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
422
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
423
- Returns:
424
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
425
- """
426
- dtype = q.dtype
427
- rot_dim = cos.shape[-1]
428
- q_, q_pass = q[..., :rot_dim], q[..., rot_dim:]
429
- k_, k_pass = k[..., :rot_dim], k[..., rot_dim:]
430
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
431
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
432
- q_embed = (q_ * cos) + (rotate_half(q_) * sin)
433
- k_embed = (k_ * cos) + (rotate_half(k_) * sin)
434
- return torch.cat((q_embed, q_pass), dim=-1).to(dtype), torch.cat((k_embed, k_pass), dim=-1).to(dtype)
435
-
436
-
437
- # Copied from transformers.models.llama.modeling_llama.repeat_kv
438
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
439
- """
440
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
441
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
442
- """
443
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
444
- if n_rep == 1:
445
- return hidden_states
446
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
447
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
448
-
449
-
450
- # Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->MiniMaxText01
451
- class MiniMaxText01Attention(nn.Module):
452
- """
453
- Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
454
- and "Generating Long Sequences with Sparse Transformers".
455
- """
456
-
457
- def __init__(self, config: MiniMaxText01Config, layer_idx: Optional[int] = None):
458
- super().__init__()
459
- self.config = config
460
- self.layer_idx = layer_idx
461
- if layer_idx is None:
462
- logger.warning_once(
463
- f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
464
- "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
465
- "when creating this class."
466
- )
467
-
468
- self.hidden_size = config.hidden_size
469
- self.num_heads = config.num_attention_heads
470
- self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads)
471
- self.num_key_value_heads = config.num_key_value_heads
472
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
473
- self.max_position_embeddings = config.max_position_embeddings
474
- self.rope_theta = config.rope_theta
475
- self.is_causal = True
476
- self.attention_dropout = config.attention_dropout
477
-
478
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
479
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
480
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
481
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
482
- self.rotary_dim = getattr(config, 'rotary_dim', self.head_dim)
483
-
484
- self.rotary_emb = MiniMaxText01RotaryEmbedding(
485
- self.rotary_dim,
486
- max_position_embeddings=self.max_position_embeddings,
487
- base=self.rope_theta,
488
- )
489
-
490
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
491
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
492
-
493
- def forward(
494
- self,
495
- hidden_states: torch.Tensor,
496
- attention_mask: Optional[torch.Tensor] = None,
497
- position_ids: Optional[torch.LongTensor] = None,
498
- past_key_value: Optional[Cache] = None,
499
- output_attentions: bool = False,
500
- use_cache: bool = False,
501
- **kwargs,
502
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
503
- if "padding_mask" in kwargs:
504
- warnings.warn(
505
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
506
- )
507
- bsz, q_len, _ = hidden_states.size()
508
-
509
- query_states = self.q_proj(hidden_states)
510
- key_states = self.k_proj(hidden_states)
511
- value_states = self.v_proj(hidden_states)
512
-
513
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
514
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
515
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
516
-
517
- kv_seq_len = key_states.shape[-2]
518
- if past_key_value is not None:
519
- if self.layer_idx is None:
520
- raise ValueError(
521
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
522
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
523
- "with a layer index."
524
- )
525
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
526
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
527
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
528
-
529
- if past_key_value is not None:
530
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
531
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
532
-
533
- # repeat k/v heads if n_kv_heads < n_heads
534
- key_states = repeat_kv(key_states, self.num_key_value_groups)
535
- value_states = repeat_kv(value_states, self.num_key_value_groups)
536
-
537
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
538
-
539
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
540
- raise ValueError(
541
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
542
- f" {attn_weights.size()}"
543
- )
544
-
545
- if attention_mask is not None:
546
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
547
- raise ValueError(
548
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
549
- )
550
-
551
- attn_weights = attn_weights + attention_mask
552
-
553
- # upcast attention to fp32
554
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
555
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
556
- attn_output = torch.matmul(attn_weights, value_states)
557
-
558
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
559
- raise ValueError(
560
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
561
- f" {attn_output.size()}"
562
- )
563
-
564
- attn_output = attn_output.transpose(1, 2).contiguous()
565
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
566
-
567
- attn_output = self.o_proj(attn_output)
568
-
569
- if not output_attentions:
570
- attn_weights = None
571
-
572
- return attn_output, attn_weights, past_key_value
573
-
574
-
575
- # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->MiniMaxText01
576
- class MiniMaxText01FlashAttention2(MiniMaxText01Attention):
577
- """
578
- MiniMaxText01 flash attention module. This module inherits from `MiniMaxText01Attention` as the weights of the module stays
579
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
580
- flash attention and deal with padding tokens in case the input contains any of them.
581
- """
582
-
583
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
584
- def __init__(self, *args, **kwargs):
585
- super().__init__(*args, **kwargs)
586
-
587
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
588
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
589
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
590
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
591
-
592
- def forward(
593
- self,
594
- hidden_states: torch.Tensor,
595
- attention_mask: Optional[torch.Tensor] = None,
596
- position_ids: Optional[torch.LongTensor] = None,
597
- past_key_value: Optional[Union[Cache, Tuple[torch.Tensor]]] = None,
598
- output_attentions: bool = False,
599
- use_cache: bool = False,
600
- **kwargs,
601
- ):
602
- if "padding_mask" in kwargs:
603
- warnings.warn(
604
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
605
- )
606
-
607
- # overwrite attention_mask with padding_mask
608
- attention_mask = kwargs.pop("padding_mask")
609
- bsz, q_len, _ = hidden_states.size()
610
-
611
- query_states = self.q_proj(hidden_states)
612
- key_states = self.k_proj(hidden_states)
613
- value_states = self.v_proj(hidden_states)
614
-
615
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
616
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
617
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
618
-
619
- kv_seq_len = key_states.shape[-2]
620
- if past_key_value is not None:
621
- kv_seq_len += past_key_value[0].shape[-3]
622
-
623
- # Because the input can be padded, the absolute sequence length depends on the max position id.
624
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
625
- cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
626
-
627
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
628
-
629
- use_sliding_windows = (
630
- _flash_supports_window_size
631
- and getattr(self.config, "sliding_window", None) is not None
632
- and kv_seq_len > self.config.sliding_window
633
- )
634
-
635
- if not _flash_supports_window_size:
636
- logger.warning_once(
637
- "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
638
- " make sure to upgrade flash-attn library."
639
- )
640
-
641
- dropout_rate = 0.0 if not self.training else self.attention_dropout
642
-
643
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
644
- # therefore the input hidden states gets silently casted in float32. Hence, we need
645
- # cast them back in float16 just to be sure everything works as expected.
646
- input_dtype = query_states.dtype
647
- if input_dtype == torch.float32:
648
- if torch.is_autocast_enabled():
649
- target_dtype = torch.get_autocast_gpu_dtype()
650
- # Handle the case where the model is quantized
651
- elif hasattr(self.config, "_pre_quantization_dtype"):
652
- target_dtype = self.config._pre_quantization_dtype
653
- else:
654
- target_dtype = self.q_proj.weight.dtype
655
-
656
- logger.warning_once(
657
- f"The input hidden states seems to be silently casted in float32, this might be related to"
658
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
659
- f" {target_dtype}."
660
- )
661
-
662
- query_states = query_states.to(target_dtype)
663
- key_states = key_states.to(target_dtype)
664
- value_states = value_states.to(target_dtype)
665
-
666
- # Reshape to the expected shape for Flash Attention
667
- query_states = query_states.transpose(1, 2)
668
- key_states = key_states.transpose(1, 2)
669
- value_states = value_states.transpose(1, 2)
670
-
671
- if past_key_value is not None:
672
- # reuse k, v, for evaluation only
673
- key_states = torch.cat([past_key_value[0], key_states], dim=-3)
674
- value_states = torch.cat([past_key_value[1], value_states], dim=-3)
675
-
676
- past_key_value = (key_states, value_states) if use_cache else None
677
-
678
- attn_output = self._flash_attention_forward(
679
- query_states,
680
- key_states,
681
- value_states,
682
- attention_mask,
683
- q_len,
684
- dropout=dropout_rate,
685
- use_sliding_windows=use_sliding_windows,
686
- )
687
-
688
- attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
689
- attn_output = self.o_proj(attn_output)
690
-
691
- if not output_attentions:
692
- attn_weights = None
693
-
694
- return attn_output, attn_weights, past_key_value
695
-
696
- def _flash_attention_forward(
697
- self,
698
- query_states,
699
- key_states,
700
- value_states,
701
- attention_mask,
702
- query_length,
703
- dropout=0.0,
704
- softmax_scale=None,
705
- use_sliding_windows=False,
706
- ):
707
- """
708
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
709
- first unpad the input, then computes the attention scores and pad the final attention scores.
710
-
711
- Args:
712
- query_states (`torch.Tensor`):
713
- Input query states to be passed to Flash Attention API
714
- key_states (`torch.Tensor`):
715
- Input key states to be passed to Flash Attention API
716
- value_states (`torch.Tensor`):
717
- Input value states to be passed to Flash Attention API
718
- attention_mask (`torch.Tensor`):
719
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
720
- position of padding tokens and 1 for the position of non-padding tokens.
721
- dropout (`float`):
722
- Attention dropout
723
- softmax_scale (`float`, *optional*):
724
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
725
- use_sliding_windows (`bool`, *optional*):
726
- Whether to activate sliding window attention.
727
- """
728
- if not self._flash_attn_uses_top_left_mask:
729
- causal = self.is_causal
730
- else:
731
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
732
- causal = self.is_causal and query_length != 1
733
-
734
- # Contains at least one padding token in the sequence
735
- if attention_mask is not None:
736
- batch_size = query_states.shape[0]
737
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
738
- query_states, key_states, value_states, attention_mask, query_length
739
- )
740
-
741
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
742
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
743
-
744
- if not use_sliding_windows:
745
- attn_output_unpad = flash_attn_varlen_func(
746
- query_states,
747
- key_states,
748
- value_states,
749
- cu_seqlens_q=cu_seqlens_q,
750
- cu_seqlens_k=cu_seqlens_k,
751
- max_seqlen_q=max_seqlen_in_batch_q,
752
- max_seqlen_k=max_seqlen_in_batch_k,
753
- dropout_p=dropout,
754
- softmax_scale=softmax_scale,
755
- causal=causal,
756
- )
757
- else:
758
- attn_output_unpad = flash_attn_varlen_func(
759
- query_states,
760
- key_states,
761
- value_states,
762
- cu_seqlens_q=cu_seqlens_q,
763
- cu_seqlens_k=cu_seqlens_k,
764
- max_seqlen_q=max_seqlen_in_batch_q,
765
- max_seqlen_k=max_seqlen_in_batch_k,
766
- dropout_p=dropout,
767
- softmax_scale=softmax_scale,
768
- causal=causal,
769
- window_size=(self.config.sliding_window, self.config.sliding_window),
770
- )
771
-
772
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
773
- else:
774
- if not use_sliding_windows:
775
- attn_output = flash_attn_func(
776
- query_states,
777
- key_states,
778
- value_states,
779
- dropout,
780
- softmax_scale=softmax_scale,
781
- causal=causal,
782
- )
783
- else:
784
- attn_output = flash_attn_func(
785
- query_states,
786
- key_states,
787
- value_states,
788
- dropout,
789
- softmax_scale=softmax_scale,
790
- causal=causal,
791
- window_size=(self.config.sliding_window, self.config.sliding_window),
792
- )
793
-
794
- return attn_output
795
-
796
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
    """
    Strip padding tokens from q/k/v so they can be fed to the flash-attn
    varlen kernels (`flash_attn_varlen_func`).

    Args:
        query_layer: `(batch, query_length, num_heads, head_dim)` queries.
        key_layer / value_layer: `(batch, kv_seq_len, num_heads, head_dim)`.
        attention_mask: `(batch, mask_len)` padding mask (1 = keep, 0 = pad).
        query_length: number of query positions in this forward call.

    Returns:
        Tuple of unpadded `(query, key, value)` tensors flattened over the
        batch dimension, the query gather indices, the `(cu_seqlens_q,
        cu_seqlens_k)` cumulative-length tensors, and the
        `(max_seqlen_q, max_seqlen_k)` per-batch maxima — exactly the inputs
        `flash_attn_varlen_func` expects.
    """
    batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape

    # On the first iteration we need to properly re-create the padding mask
    # by slicing it on the proper place: the mask may cover cached positions
    # beyond the current kv window, so keep only its last kv_seq_len columns.
    if kv_seq_len != attention_mask.shape[-1]:
        attention_mask_num_tokens = attention_mask.shape[-1]
        attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len:]

    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

    # Flatten (batch, seq) into one axis, then gather only the non-pad rows.
    key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
    value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)

    if query_length == kv_seq_len:
        # Prefill: queries share the key layout, so reuse the key indices.
        query_layer = index_first_axis(
            query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
        )
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif query_length == 1:
        # Single-token decode: every sequence contributes exactly one query.
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1, dtype=torch.int32, device=query_layer.device
        )  # There is a memcpy here, that is very bad.
        indices_q = cu_seqlens_q[:-1]
        query_layer = query_layer.squeeze(1)
    else:
        # Chunked decode. The -q_len: slice assumes left padding.
        attention_mask = attention_mask[:, -query_length:]
        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

    return (
        query_layer,
        key_layer,
        value_layer,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )
837
-
838
-
839
class MiniMaxText01MLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block: down(act(gate(x)) * up(x))."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Projection creation order is kept stable so parameter init under a
        # fixed seed matches checkpoints.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # Gate branch is passed through the activation, then modulates the
        # linear "up" branch before projecting back to the hidden size.
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
853
-
854
-
855
class MiniMaxText01BlockSparseTop2MLP(nn.Module):
    """Single expert FFN used by the sparse MoE block: w2(act(w1(x)) * w3(x))."""

    def __init__(self, config: MiniMaxText01Config):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size

        # w1/w3 expand to the FFN dimension, w2 projects back down.
        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        # Gated expansion followed by the down-projection, as one expression.
        return self.w2(self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states))
871
-
872
-
873
# Backward-compatibility shim: the class name was originally published with a
# typo ("BLock"); keep it importable but warn once and delegate everything to
# the correctly-named base class.
class MiniMaxText01BLockSparseTop2MLP(MiniMaxText01BlockSparseTop2MLP):
    def __init__(self, *args, **kwargs):
        logger.warning_once(
            "MiniMaxText01BLockSparseTop2MLP is deprecated by MiniMaxText01BlockSparseTop2MLP and will be removed in v4.40."
        )
        super().__init__(*args, **kwargs)
879
-
880
-
881
class MiniMaxText01SparseMoeBlock(nn.Module):
    """
    Sparse mixture-of-experts block.

    This implementation is strictly equivalent to standard MoE with full
    capacity (no dropped tokens). It's faster since it formulates MoE
    operations in terms of block-sparse operations to accommodate imbalanced
    assignments of tokens to experts, whereas standard MoE either (1) drops
    tokens at the cost of reduced performance or (2) sets the capacity factor
    to the number of experts and thus wastes computation and memory on padding.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        # gating: routes each token to `top_k` of `num_experts` experts
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        self.experts = nn.ModuleList([MiniMaxText01BlockSparseTop2MLP(config) for _ in range(self.num_experts)])

        # Jitter parameters (multiplicative input noise applied during training)
        self.jitter_noise = config.router_jitter_noise

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route tokens through experts; returns (mixed hidden states, raw router logits)."""
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        if self.training and self.jitter_noise > 0:
            # NOTE(review): `*=` mutates the input tensor in place when jitter
            # is enabled — presumably safe because callers pass a fresh
            # activation, but confirm against the decoder layer.
            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
        # Flatten (batch, seq) so routing operates per token.
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        # Softmax in float32 for numerical stability, then keep only top-k
        # experts and renormalize their weights to sum to 1 per token.
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be solicited
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            # idx: which top-k slot picked this expert; top_x: which tokens.
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
947
-
948
-
949
class MiniMaxText01DecoderLayer(nn.Module):
    """
    One decoder layer: attention (lightning/linear or flash/full, chosen from
    `config.attention_type`) followed by a sparse-MoE FFN, with weighted
    residual connections (alpha * residual + beta * branch) instead of the
    plain `residual + branch` form.
    """

    def __init__(self, config: MiniMaxText01Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        self.self_attn = self.build_attn(config, layer_idx)

        self.layer_idx = layer_idx

        self.block_sparse_moe = MiniMaxText01SparseMoeBlock(config)
        self.input_layernorm = MiniMaxText01RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MiniMaxText01RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Residual-branch scaling. attention_type == 0 selects the linear
        # (lightning) attention coefficients, otherwise the full-attention ones.
        self.postnorm = getattr(config, 'postnorm', False)
        self.layernorm_attention_alpha = getattr(config, 'layernorm_linear_attention_alpha', 1) \
            if config.attention_type == 0 else getattr(config, 'layernorm_full_attention_alpha', 1)
        self.layernorm_attention_beta = getattr(config, 'layernorm_linear_attention_beta', 1) \
            if config.attention_type == 0 else getattr(config, 'layernorm_full_attention_beta', 1)
        self.layernorm_mlp_alpha = getattr(config, 'layernorm_mlp_alpha', 1)
        self.layernorm_mlp_beta = getattr(config, 'layernorm_mlp_beta', 1)

        # Optional dense MLP shared across tokens, mixed with the MoE output
        # through a learned sigmoid coefficient.
        shared_intermediate = getattr(config, 'shared_intermediate_size', 0)
        self.shared_moe = False
        if shared_intermediate > 0:
            self.shared_moe = True
            self.shared_mlp = MiniMaxText01MLP(config)
            self.coefficient = torch.nn.Linear(self.hidden_size, 1, bias=False)

    def build_attn(self, config, layer_idx):
        # attention_type 0 -> linear (lightning) attention, else flash attention.
        if config.attention_type == 0:
            Attention_module = MiniMaxText01LightningAttention
        else:
            Attention_module = MiniMaxText01FlashAttention2

        return Attention_module(
            config,
            layer_idx
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        slope_rate: Optional[float] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_router_logits (`bool`, *optional*):
                Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
                should not be returned during inference.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            slope_rate (`float`, *optional*): per-layer decay slope forwarded to lightning attention.
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)
        # In post-norm mode the residual is taken *after* the layernorm.
        if self.postnorm:
            residual = hidden_states

        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            position_ids=position_ids,
            attn_mask=attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            slope_rate=slope_rate,
        )

        # Weighted residual: alpha scales the skip path, beta the attention output.
        hidden_states = residual * self.layernorm_attention_alpha \
            + hidden_states * self.layernorm_attention_beta

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        if self.postnorm:
            residual = hidden_states

        moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
        if self.shared_moe:
            # Mix MoE and shared-MLP outputs; the coefficient is computed in
            # float32 and squashed through a sigmoid.
            output_mlp = self.shared_mlp(hidden_states)
            weight_fp32 = self.coefficient.weight.float()
            coef = hidden_states.to(torch.float32) @ weight_fp32.T
            coef = torch.nn.functional.sigmoid(coef).to(hidden_states.dtype)
            hidden_states = moe_hidden_states * (1 - coef) + output_mlp * coef
        else:
            hidden_states = moe_hidden_states

        hidden_states = residual * self.layernorm_mlp_alpha \
            + hidden_states * self.layernorm_mlp_beta

        # Output tuple order: hidden_states[, attn_weights][, present_kv][, router_logits]
        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        if output_router_logits:
            outputs += (router_logits,)

        return outputs
1072
-
1073
-
1074
# Shared class-level docstring, injected into the model classes below via the
# `@add_start_docstrings` decorator.
MIXTRAL_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MiniMaxText01Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
1089
-
1090
-
1091
@add_start_docstrings(
    "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.",
    MIXTRAL_START_DOCSTRING,
)
# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->MiniMaxText01
class MiniMaxText01PreTrainedModel(PreTrainedModel):
    """Base class wiring config, weight init, and loading-time capabilities."""

    config_class = MiniMaxText01Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Keep each decoder layer on a single device when sharding with accelerate.
    _no_split_modules = ["MiniMaxText01DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for Linear/Embedding weights, zero bias/padding row."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
1116
-
1117
# Shared `forward` argument documentation, injected via
# `@add_start_docstrings_to_model_forward` on the model forward methods below.
MIXTRAL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        output_router_logits (`bool`, *optional*):
            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
            should not be returned during inference.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
1182
-
1183
-
1184
@add_start_docstrings(
    "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.",
    MIXTRAL_START_DOCSTRING,
)
# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->MiniMaxText01
class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniMaxText01DecoderLayer`]
    using either lightning (linear) attention or full attention, selected per layer by `config.attn_type_list`
    (0 -> linear, 1 -> full).

    Args:
        config: MiniMaxText01Config
    """

    def __init__(self, config: MiniMaxText01Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.attn_type_list = config.attn_type_list
        config_copy = copy.deepcopy(config)

        # Each layer gets its own config copy carrying the per-layer attention
        # choice, so the layer constructor can branch on `attention_type`.
        self.layers = nn.ModuleList([])
        for i in range(config.num_hidden_layers):
            _config = copy.deepcopy(config)
            if self.attn_type_list[i] == 0:
                _config._attn_implementation = 'linear_attention'
                _config.attention_type = 0
            else:
                _config._attn_implementation = config_copy._attn_implementation
                _config.attention_type = 1
            self.layers.append(MiniMaxText01DecoderLayer(_config, i))

        self._attn_implementation = config_copy._attn_implementation
        self.norm = MiniMaxText01RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Per-head decay slopes consumed by the lightning-attention layers.
        self.slopes = self._build_slope_tensor(config.num_attention_heads)
        # mask cache for linear attention, built lazily
        self._linear_attn_mask = torch.empty(0)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @staticmethod
    def _build_slope_tensor(n_attention_heads: int):
        """Build the `(heads, 1, 1)` ALiBi-style slope tensor for `n_attention_heads` heads."""

        def get_slopes(n):

            def get_slopes_power_of_2(n):
                start = 2 ** (-(2 ** -(math.log2(n) - 3)))
                ratio = start
                return [start * ratio ** i for i in range(n)]

            if math.log2(n).is_integer():
                return get_slopes_power_of_2(
                    n)  # In the paper, we only train models that have 2^a heads for some a. This function has
            else:  # some good properties that only occur when the input is a power of 2. To maintain that even
                closest_power_of_2 = 2 ** math.floor(
                    math.log2(n))  # when the number of heads is not a power of 2, we use this workaround.
                return (get_slopes_power_of_2(closest_power_of_2)
                        + get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2])

        # h, 1, 1
        slopes = torch.tensor(get_slopes(n_attention_heads), dtype=torch.float32).reshape(n_attention_heads, 1, 1)

        return slopes

    # Ignore copy
    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MoeModelOutputWithPast]:
        """Run the decoder stack; returns a `MoeModelOutputWithPast` (or tuple if `return_dict=False`)."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
            default_device = input_ids.device
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
            default_device = inputs_embeds.device
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        past_key_values_length = 0

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Linear-attention layers keep a recurrent state rather than a kv
        # sequence, so the past length is read from the first full-attention
        # layer's cache entry.
        seq_length_with_past = seq_length
        if past_key_values is not None:
            for idx in range(len(past_key_values)):
                if self.attn_type_list[idx] == 1:
                    past_key_values_length = past_key_values[idx][0].shape[-3]
                    seq_length_with_past = seq_length_with_past + past_key_values_length
                    break

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right'"
                    " this may lead to unexpected behaviour for Flash Attention version of MiniMaxText01. Make sure to "
                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                )
        slope_rates = [self.slopes.to(default_device) for _ in range(len(self.layers))]
        hidden_states = inputs_embeds
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = () if use_cache else None

        # Guard the depth normalizer so a single-layer config does not divide
        # by zero (for >1 layers this is identical to len(self.layers) - 1).
        depth_denom = max(len(self.layers) - 1, 1)
        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (past_key_values[idx] if past_key_values is not None else None)
            attn_mask = attention_mask
            # Decay the slope with depth: deeper layers get smaller rates.
            slope_rate = slope_rates[idx]
            slope_rate = slope_rate * (1 - idx / depth_denom + 1e-5)
            if self.gradient_checkpointing and self.training:
                # Fix: pass the same per-layer arguments as the direct call
                # below — the layer expects its own `past_key_value` entry (not
                # the whole list) and the depth-scaled `slope_rate`.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attn_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    output_router_logits,
                    use_cache,
                    slope_rate,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attn_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    output_router_logits=output_router_logits,
                    use_cache=use_cache,
                    slope_rate=slope_rate
                )

            hidden_states = layer_outputs[0]

            # Layer output order: hidden[, attn][, present][, router_logits].
            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            if output_router_logits:
                all_router_logits += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                if v is not None
            )
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
        )
1402
-
1403
-
1404
- class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel):
1405
- _tied_weights_keys = ["lm_head.weight"]
1406
-
1407
def __init__(self, config):
    """Wrap the base decoder with a tied LM head and MoE aux-loss settings."""
    super().__init__(config)
    self.model = MiniMaxText01Model(config)
    self.vocab_size = config.vocab_size
    # Untied projection to vocabulary logits (tying handled via _tied_weights_keys).
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
    # Router load-balancing loss configuration, used in forward().
    self.router_aux_loss_coef = config.router_aux_loss_coef
    self.num_experts = config.num_local_experts
    self.num_experts_per_tok = config.num_experts_per_tok
    # Initialize weights and apply final processing
    self.post_init()
1417
-
1418
# Standard transformers accessor plumbing: expose the inner model's embedding
# table, the LM head, and the decoder itself for resize/tie utilities.
def get_input_embeddings(self):
    return self.model.embed_tokens

def set_input_embeddings(self, value):
    self.model.embed_tokens = value

def get_output_embeddings(self):
    return self.lm_head

def set_output_embeddings(self, new_embeddings):
    self.lm_head = new_embeddings

def set_decoder(self, decoder):
    self.model = decoder

def get_decoder(self):
    return self.model
1435
-
1436
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
# Ignore copy
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, MiniMaxText01ForCausalLM

    >>> model = MiniMaxText01ForCausalLM.from_pretrained(PATH_TO_WEIGHTS)
    >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_WEIGHTS)

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_router_logits = (
        output_router_logits if output_router_logits is not None else self.config.output_router_logits
    )

    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]
    # Compute logits in float32 for a numerically stable loss/softmax.
    logits = self.lm_head(hidden_states)
    logits = logits.float()

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

    aux_loss = None
    if output_router_logits:
        # Router load-balancing auxiliary loss over all layers' router logits.
        aux_loss = load_balancing_loss_func(
            outputs.router_logits if return_dict else outputs[-1],
            self.num_experts,
            self.num_experts_per_tok,
            attention_mask,
        )
        if labels is not None:
            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device

    if not return_dict:
        output = (logits,) + outputs[1:]
        if output_router_logits:
            output = (aux_loss,) + output
        return (loss,) + output if loss is not None else output

    # Fix: removed the unconditional `torch.cuda.empty_cache()` that ran on
    # every forward pass — it flushes the CUDA caching allocator each step,
    # which severely degrades throughput and does not change the outputs.
    return MoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        router_logits=outputs.router_logits,
    )
1545
-
1546
- def prepare_inputs_for_generation(
1547
- self,
1548
- input_ids,
1549
- past_key_values=None,
1550
- attention_mask=None,
1551
- inputs_embeds=None,
1552
- **kwargs,
1553
- ):
1554
- if past_key_values:
1555
- input_ids = input_ids[:, -1:]
1556
-
1557
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1558
- if inputs_embeds is not None and past_key_values is None:
1559
- model_inputs = {"inputs_embeds": inputs_embeds}
1560
- else:
1561
- model_inputs = {"input_ids": input_ids}
1562
-
1563
- model_inputs.update({
1564
- "past_key_values": past_key_values,
1565
- "use_cache": kwargs.get("use_cache"),
1566
- "attention_mask": attention_mask,
1567
- })
1568
- return model_inputs
1569
-
1570
- @staticmethod
1571
- def _reorder_cache(past_key_values, beam_idx):
1572
- reordered_past = ()
1573
- for layer_past in past_key_values:
1574
- reordered_past += (
1575
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1576
- )
1577
- return reordered_past
1578
-
1579
-
1580
- @add_start_docstrings(
1581
- """
1582
- The MiniMaxText01 Model transformer with a sequence classification head on top (linear layer).
1583
-
1584
- [`MiniMaxText01ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1585
- (e.g. GPT-2) do.
1586
-
1587
- Since it does classification on the last token, it requires to know the position of the last token. If a
1588
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1589
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1590
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1591
- each row of the batch).
1592
- """,
1593
- MIXTRAL_START_DOCSTRING,
1594
- )
1595
- # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->MiniMaxText01, LLAMA->MIXTRAL
1596
- class MiniMaxText01ForSequenceClassification(MiniMaxText01PreTrainedModel):
1597
- def __init__(self, config):
1598
- super().__init__(config)
1599
- self.num_labels = config.num_labels
1600
- self.model = MiniMaxText01Model(config)
1601
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1602
-
1603
- # Initialize weights and apply final processing
1604
- self.post_init()
1605
-
1606
- def get_input_embeddings(self):
1607
- return self.model.embed_tokens
1608
-
1609
- def set_input_embeddings(self, value):
1610
- self.model.embed_tokens = value
1611
-
1612
- @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
1613
- def forward(
1614
- self,
1615
- input_ids: torch.LongTensor = None,
1616
- attention_mask: Optional[torch.Tensor] = None,
1617
- position_ids: Optional[torch.LongTensor] = None,
1618
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1619
- inputs_embeds: Optional[torch.FloatTensor] = None,
1620
- labels: Optional[torch.LongTensor] = None,
1621
- use_cache: Optional[bool] = None,
1622
- output_attentions: Optional[bool] = None,
1623
- output_hidden_states: Optional[bool] = None,
1624
- return_dict: Optional[bool] = None,
1625
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1626
- r"""
1627
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1628
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1629
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1630
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1631
- """
1632
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1633
-
1634
- transformer_outputs = self.model(
1635
- input_ids,
1636
- attention_mask=attention_mask,
1637
- position_ids=position_ids,
1638
- past_key_values=past_key_values,
1639
- inputs_embeds=inputs_embeds,
1640
- use_cache=use_cache,
1641
- output_attentions=output_attentions,
1642
- output_hidden_states=output_hidden_states,
1643
- return_dict=return_dict,
1644
- )
1645
- hidden_states = transformer_outputs[0]
1646
- logits = self.score(hidden_states)
1647
-
1648
- if input_ids is not None:
1649
- batch_size = input_ids.shape[0]
1650
- else:
1651
- batch_size = inputs_embeds.shape[0]
1652
-
1653
- if self.config.pad_token_id is None and batch_size != 1:
1654
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1655
- if self.config.pad_token_id is None:
1656
- sequence_lengths = -1
1657
- else:
1658
- if input_ids is not None:
1659
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1660
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1661
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
1662
- sequence_lengths = sequence_lengths.to(logits.device)
1663
- else:
1664
- sequence_lengths = -1
1665
-
1666
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1667
-
1668
- loss = None
1669
- if labels is not None:
1670
- labels = labels.to(logits.device)
1671
- if self.config.problem_type is None:
1672
- if self.num_labels == 1:
1673
- self.config.problem_type = "regression"
1674
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1675
- self.config.problem_type = "single_label_classification"
1676
- else:
1677
- self.config.problem_type = "multi_label_classification"
1678
-
1679
- if self.config.problem_type == "regression":
1680
- loss_fct = MSELoss()
1681
- if self.num_labels == 1:
1682
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1683
- else:
1684
- loss = loss_fct(pooled_logits, labels)
1685
- elif self.config.problem_type == "single_label_classification":
1686
- loss_fct = CrossEntropyLoss()
1687
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1688
- elif self.config.problem_type == "multi_label_classification":
1689
- loss_fct = BCEWithLogitsLoss()
1690
- loss = loss_fct(pooled_logits, labels)
1691
- if not return_dict:
1692
- output = (pooled_logits,) + transformer_outputs[1:]
1693
- return ((loss,) + output) if loss is not None else output
1694
-
1695
- return SequenceClassifierOutputWithPast(
1696
- loss=loss,
1697
- logits=pooled_logits,
1698
- past_key_values=transformer_outputs.past_key_values,
1699
- hidden_states=transformer_outputs.hidden_states,
1700
- attentions=transformer_outputs.attentions,
1701
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modular_minimax.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
3
+ #
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch MiniMax model."""
17
+
18
+ from typing import Optional
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from torch import nn
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.cache_utils import Cache, DynamicCache
26
+ from transformers.configuration_utils import layer_type_validation
27
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
28
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
29
+ from transformers.modeling_layers import GradientCheckpointingLayer
30
+ from transformers.modeling_outputs import MoeModelOutputWithPast
31
+ from transformers.processing_utils import Unpack
32
+ from transformers.utils import TransformersKwargs, logging
33
+ from transformers.utils.generic import OutputRecorder, check_model_inputs
34
+ from transformers.models.mixtral.configuration_mixtral import MixtralConfig
35
+ from transformers.models.mixtral.modeling_mixtral import (
36
+ MixtralAttention,
37
+ MixtralDecoderLayer,
38
+ MixtralForCausalLM,
39
+ MixtralModel,
40
+ MixtralPreTrainedModel,
41
+ MixtralRMSNorm,
42
+ MixtralSparseMoeBlock,
43
+ )
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+
49
+ class MiniMaxConfig(MixtralConfig):
50
+ r"""
51
+ This is the configuration class to store the configuration of a [`MiniMaxModel`]. It is used to instantiate an
52
+ MiniMax model according to the specified arguments, defining the model architecture. Instantiating a configuration
53
+ with the defaults will yield a similar configuration to that of the MiniMax.
54
+
55
+ [MiniMaxAI/MiniMax-Text-01-hf](https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf)
56
+
57
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
58
+ documentation from [`PreTrainedConfig`] for more information.
59
+
60
+
61
+ Args:
62
+ vocab_size (`int`, *optional*, defaults to 32000):
63
+ Vocabulary size of the MiniMax model. Defines the number of different tokens that can be represented by the
64
+ `inputs_ids` passed when calling [`MiniMaxModel`]
65
+ hidden_size (`int`, *optional*, defaults to 4096):
66
+ Dimension of the hidden representations.
67
+ intermediate_size (`int`, *optional*, defaults to 14336):
68
+ Dimension of the MLP representations.
69
+ num_hidden_layers (`int`, *optional*, defaults to 32):
70
+ Number of hidden layers in the Transformer encoder.
71
+ num_attention_heads (`int`, *optional*, defaults to 32):
72
+ Number of attention heads for each attention layer in the Transformer encoder.
73
+ num_key_value_heads (`int`, *optional*, defaults to 8):
74
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
75
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
76
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
77
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
78
+ by meanpooling all the original heads within that group. For more details, check out [this
79
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
80
+ head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
81
+ The attention head dimension.
82
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
83
+ The non-linear activation function (function or string) in the decoder.
84
+ max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
85
+ The maximum sequence length that this model might ever be used with. MiniMax's sliding window attention
86
+ allows sequence of up to 4096*32 tokens.
87
+ initializer_range (`float`, *optional*, defaults to 0.02):
88
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
89
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
90
+ The epsilon used by the rms normalization layers.
91
+ use_cache (`bool`, *optional*, defaults to `True`):
92
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
93
+ relevant if `config.is_decoder=True`.
94
+ pad_token_id (`int`, *optional*):
95
+ The id of the padding token.
96
+ bos_token_id (`int`, *optional*, defaults to 1):
97
+ The id of the "beginning-of-sequence" token.
98
+ eos_token_id (`int`, *optional*, defaults to 2):
99
+ The id of the "end-of-sequence" token.
100
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
101
+ Whether the model's input and output word embeddings should be tied.
102
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
103
+ The base period of the RoPE embeddings.
104
+ sliding_window (`int`, *optional*):
105
+ Sliding window attention window size. If not specified, will default to `4096`.
106
+ attention_dropout (`float`, *optional*, defaults to 0.0):
107
+ The dropout ratio for the attention probabilities.
108
+ num_experts_per_tok (`int`, *optional*, defaults to 2):
109
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
110
+ parameter
111
+ num_local_experts (`int`, *optional*, defaults to 8):
112
+ Number of experts per Sparse MLP layer.
113
+ output_router_logits (`bool`, *optional*, defaults to `False`):
114
+ Whether or not the router logits should be returned by the model. Enabling this will also
115
+ allow the model to output the auxiliary loss. See [here]() for more details
116
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
117
+ The aux loss factor for the total loss.
118
+ router_jitter_noise (`float`, *optional*, defaults to 0.0):
119
+ Amount of noise to add to the router.
120
+ layer_types (`list`, *optional*):
121
+ Attention pattern for each layer.
122
+ block_size (`int`, *optional*, defaults to 256):
123
+ The length of each attention block, determining how queries, keys, and values
124
+ are grouped and processed for intra- and inter-block attention.
125
+ full_attn_alpha_factor (`float`, *optional*, defaults to 1):
126
+ Weight for residual value in residual connection after normal attention.
127
+ full_attn_beta_factor (`float`, *optional*, defaults to 1):
128
+ Weight for hidden state value in residual connection after normal attention.
129
+ linear_attn_alpha_factor (`float`, *optional*, defaults to 1):
130
+ Weight for residual value in residual connection after lightning attention.
131
+ linear_attn_beta_factor (`float`, *optional*, defaults to 1):
132
+ Weight for hidden state value in residual connection after lightning attention.
133
+ mlp_alpha_factor (`float`, *optional*, defaults to 1):
134
+ Weight for residual value in residual connection after MLP.
135
+ mlp_beta_factor (`float`, *optional*, defaults to 1):
136
+ Weight for hidden state value in residual connection after MLP.
137
+
138
+ ```python
139
+ >>> from transformers import MiniMaxModel, MiniMaxConfig
140
+
141
+ >>> # Initializing a MiniMax style configuration
142
+ >>> configuration = MiniMaxConfig()
143
+
144
+ >>> # Initializing a model from the MiniMax style configuration
145
+ >>> model = MiniMaxModel(configuration)
146
+
147
+ >>> # Accessing the model configuration
148
+ >>> configuration = model.config
149
+ ```"""
150
+
151
+ def __init__(
152
+ self,
153
+ layer_types=None,
154
+ block_size=256,
155
+ full_attn_alpha_factor=1,
156
+ full_attn_beta_factor=1,
157
+ linear_attn_alpha_factor=1,
158
+ linear_attn_beta_factor=1,
159
+ mlp_alpha_factor=1,
160
+ mlp_beta_factor=1,
161
+ **super_kwargs,
162
+ ):
163
+ super().__init__(**super_kwargs)
164
+ self.layer_types = layer_types
165
+ self.block_size = block_size
166
+ self.full_attn_alpha_factor = full_attn_alpha_factor
167
+ self.full_attn_beta_factor = full_attn_beta_factor
168
+ self.linear_attn_alpha_factor = linear_attn_alpha_factor
169
+ self.linear_attn_beta_factor = linear_attn_beta_factor
170
+ self.mlp_alpha_factor = mlp_alpha_factor
171
+ self.mlp_beta_factor = mlp_beta_factor
172
+
173
+ if self.layer_types is None:
174
+ self.layer_types = [
175
+ "full_attention" if bool((i + 1) % 2) else "linear_attention" for i in range(self.num_hidden_layers)
176
+ ]
177
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
178
+
179
+
180
+ class MiniMaxRMSNorm(MixtralRMSNorm):
181
+ pass
182
+
183
+
184
+ class MiniMaxCache(DynamicCache):
185
+ def __init__(self):
186
+ super().__init__()
187
+ self.linear_cache: list[torch.Tensor] = []
188
+
189
+ def set_linear_cache(self, layer_idx, linear_cache):
190
+ # There may be skipped layers, fill them with empty lists
191
+ for _ in range(len(self.linear_cache), layer_idx + 1):
192
+ self.linear_cache.append([])
193
+ self.linear_cache[layer_idx] = linear_cache
194
+
195
+ def get_linear_cache(self, layer_idx: int):
196
+ if layer_idx < len(self):
197
+ return self.linear_cache[layer_idx]
198
+ return None
199
+
200
+ def __len__(self):
201
+ return max(super().__len__(), len(self.linear_cache))
202
+
203
+ def __getitem__(self, layer_idx: int):
204
+ if layer_idx < len(self.linear_cache) and self.linear_cache[layer_idx] != []:
205
+ return (self.linear_cache[layer_idx],)
206
+ return super().__getitem__(layer_idx)
207
+
208
+ def __iter__(self):
209
+ for layer_idx in range(len(self)):
210
+ yield self[layer_idx]
211
+
212
+ def batch_repeat_interleave(self, repeats: int):
213
+ for layer_idx in range(len(self)):
214
+ if self.linear_cache[layer_idx] != []:
215
+ self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0)
216
+ else:
217
+ self.layers[layer_idx].batch_repeat_interleave(repeats)
218
+
219
+ def batch_select_indices(self, indices: torch.Tensor):
220
+ for layer_idx in range(len(self)):
221
+ if self.linear_cache[layer_idx] != []:
222
+ self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...]
223
+ else:
224
+ self.layers[layer_idx].batch_select_indices(indices)
225
+
226
+ def crop(self, max_length: int):
227
+ raise RuntimeError("MiniMaxCache doesnot support `crop` method")
228
+
229
+
230
+ class MiniMaxLightningAttention(nn.Module):
231
+ def __init__(self, config: MiniMaxConfig, layer_idx: int):
232
+ super().__init__()
233
+ self.layer_idx = layer_idx
234
+ self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
235
+ self.num_attention_heads = config.num_attention_heads
236
+ self.num_hidden_layers = config.num_hidden_layers
237
+ self.block_size = config.block_size
238
+
239
+ self.act_fn = ACT2FN[config.hidden_act]
240
+ self.norm = MiniMaxRMSNorm(self.head_dim * self.num_attention_heads)
241
+ self.qkv_proj = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim * 3, bias=False)
242
+ self.out_proj = nn.Linear(self.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
243
+ self.output_gate = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
244
+
245
+ slope_rate = self.get_slope_rate()
246
+ query_decay, key_decay, diagonal_decay = self.decay_factors(slope_rate)
247
+
248
+ self.register_buffer("slope_rate", slope_rate)
249
+ self.register_buffer("query_decay", query_decay)
250
+ self.register_buffer("key_decay", key_decay)
251
+ self.register_buffer("diagonal_decay", diagonal_decay)
252
+
253
+ def get_slope_rate(self):
254
+ base = 1 / (2 ** (8 / self.num_attention_heads))
255
+ exponent = torch.arange(self.num_attention_heads) + 1
256
+ factor = 1 - self.layer_idx / (self.num_hidden_layers - 1 + 1e-5) + 1e-5
257
+
258
+ rate = base**exponent
259
+ rate = rate * factor
260
+ rate = rate[:, None, None]
261
+
262
+ return rate
263
+
264
+ def decay_factors(self, slope_rate):
265
+ block_size_range = torch.arange(self.block_size) + 1
266
+
267
+ query_decay = torch.exp(-slope_rate * block_size_range[:, None])
268
+ key_decay = torch.exp(-slope_rate * (self.block_size - block_size_range[:, None]))
269
+
270
+ diagonal_decay = block_size_range[:, None] - block_size_range[None, :]
271
+ diagonal_decay = diagonal_decay[None, None, :, :]
272
+ diagonal_decay = slope_rate * diagonal_decay
273
+ diagonal_decay = torch.where(diagonal_decay >= 0, -diagonal_decay, float("-inf"))
274
+ diagonal_decay = torch.exp(diagonal_decay)
275
+
276
+ return query_decay, key_decay, diagonal_decay
277
+
278
+ def forward(
279
+ self,
280
+ hidden_states: torch.Tensor,
281
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
282
+ attention_mask: Optional[torch.Tensor],
283
+ past_key_values: Optional[Cache] = None,
284
+ cache_position: Optional[torch.LongTensor] = None,
285
+ **kwargs: Unpack[FlashAttentionKwargs],
286
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
287
+ batch_size, seq_len, hidden_size = hidden_states.shape
288
+ num_blocks = (seq_len + self.block_size - 1) // self.block_size
289
+
290
+ qkv_states = self.act_fn(self.qkv_proj(hidden_states))
291
+ qkv_states = qkv_states.reshape(batch_size, seq_len, self.num_attention_heads, 3 * self.head_dim)
292
+
293
+ query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=3)
294
+
295
+ query_states = query_states.transpose(1, 2)
296
+ key_states = key_states.transpose(1, 2)
297
+ value_states = value_states.transpose(1, 2)
298
+
299
+ # calculated (K.T @ V) and saved as cache
300
+ attn_weights_inter = None
301
+ if past_key_values is not None:
302
+ attn_weights_inter = past_key_values.get_linear_cache(self.layer_idx)
303
+
304
+ if attn_weights_inter is None:
305
+ attn_weights_inter = torch.zeros(batch_size, self.num_attention_heads, self.head_dim, self.head_dim).to(
306
+ value_states
307
+ )
308
+
309
+ # apply attention_mask
310
+ if attention_mask is not None:
311
+ attention_mask = attention_mask.to(dtype=torch.bool) # Ensure it's a boolean tensor
312
+ value_states = value_states.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(-1), 0)
313
+
314
+ attn_output = []
315
+ for i in range(num_blocks):
316
+ start_idx = i * self.block_size
317
+ end_idx = min(start_idx + self.block_size, seq_len)
318
+ current_block_size = end_idx - start_idx
319
+
320
+ current_query_states = query_states[:, :, start_idx:end_idx]
321
+ current_key_states = key_states[:, :, start_idx:end_idx]
322
+ current_value_states = value_states[:, :, start_idx:end_idx]
323
+
324
+ current_query_decay = self.query_decay[:, :current_block_size]
325
+ current_key_decay = self.key_decay[:, -current_block_size:]
326
+ current_diagonal_decay = self.diagonal_decay[:, :, :current_block_size, :current_block_size]
327
+ block_decay = torch.exp(-self.slope_rate * current_block_size)
328
+
329
+ # intra: ( Q @ K.T ) @ V -> QK * V
330
+ attn_weights_intra = torch.matmul(current_query_states, current_key_states.transpose(-1, -2))
331
+ attn_output_intra = torch.matmul(attn_weights_intra * current_diagonal_decay, current_value_states)
332
+
333
+ # inter: Q @ ( K.T @ V ) -> Q * KV
334
+ attn_output_inter = torch.matmul(current_query_states * current_query_decay, attn_weights_inter)
335
+
336
+ # final attention output
337
+ current_attn_output = attn_output_inter + attn_output_intra
338
+ attn_output.append(current_attn_output)
339
+
340
+ # calculate attn_weights_inter for next block or cache
341
+ next_attn_weights_inter = torch.matmul(
342
+ (current_key_states * current_key_decay).transpose(-1, -2), current_value_states
343
+ )
344
+ attn_weights_inter = attn_weights_inter * block_decay + next_attn_weights_inter
345
+
346
+ else:
347
+ ratio = torch.exp(-self.slope_rate)
348
+ attn_output = []
349
+ for i in range(seq_len):
350
+ current_query_states = query_states[:, :, i : i + 1]
351
+ current_key_states = key_states[:, :, i : i + 1]
352
+ current_value_states = value_states[:, :, i : i + 1]
353
+
354
+ current_attn_weights_inter = torch.matmul(current_key_states.transpose(-1, -2), current_value_states)
355
+ attn_weights_inter = ratio * attn_weights_inter + current_attn_weights_inter
356
+ current_attn_output = torch.matmul(current_query_states, attn_weights_inter)
357
+
358
+ attn_output.append(current_attn_output)
359
+
360
+ # concatenate attention outputs over all blocks
361
+ attn_output = torch.cat(attn_output, dim=-2)
362
+
363
+ # final output projection
364
+ attn_output = attn_output.transpose(1, 2)
365
+ attn_output = attn_output.reshape(batch_size, seq_len, self.num_attention_heads * self.head_dim)
366
+ attn_output = self.norm(attn_output)
367
+ attn_output = F.sigmoid(self.output_gate(hidden_states)) * attn_output
368
+ attn_output = self.out_proj(attn_output)
369
+
370
+ # update cache
371
+ if past_key_values is not None:
372
+ past_key_values.set_linear_cache(self.layer_idx, attn_weights_inter)
373
+
374
+ return attn_output, attn_weights_inter
375
+
376
+
377
+ class MiniMaxAttention(MixtralAttention):
378
+ pass
379
+
380
+
381
+ class MiniMaxSparseMoeBlock(MixtralSparseMoeBlock):
382
+ pass
383
+
384
+
385
class MiniMaxDecoderLayer(MixtralDecoderLayer, GradientCheckpointingLayer):
    """MiniMax decoder layer that alternates linear (lightning) and full softmax attention.

    The attention flavor is selected per layer by ``config.layer_types[layer_idx]``:
    ``"linear_attention"`` uses :class:`MiniMaxLightningAttention`, anything else uses
    the standard :class:`MiniMaxAttention`. Residual connections are weighted sums
    (``alpha`` scales the residual branch, ``beta`` scales the sublayer output) rather
    than plain addition, with separate factors for the attention and MLP sublayers.
    """

    def __init__(self, config: MiniMaxConfig, layer_idx: int):
        """
        Args:
            config: Model configuration providing the per-layer type list and the
                alpha/beta residual scaling factors.
            layer_idx: Index of this layer; used both to pick the layer type and to
                address this layer's slot in the cache.
        """
        super().__init__(config, layer_idx)

        self.layer_idx = layer_idx
        self.layer_type = config.layer_types[layer_idx]
        self.mlp_alpha_factor = config.mlp_alpha_factor
        self.mlp_beta_factor = config.mlp_beta_factor

        if self.layer_type == "linear_attention":
            self.self_attn = MiniMaxLightningAttention(config, layer_idx)
            self.attn_alpha_factor = config.linear_attn_alpha_factor
            self.attn_beta_factor = config.linear_attn_beta_factor
        else:
            self.self_attn = MiniMaxAttention(config, layer_idx)
            self.attn_alpha_factor = config.full_attn_alpha_factor
            self.attn_beta_factor = config.full_attn_beta_factor

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        """Run one decoder layer (attention sublayer, then MoE MLP sublayer).

        Args:
            hidden_states: Input activations of shape ``(batch, seq_len, hidden)``
                (shape assumed from the usual transformers layout — TODO confirm).
            position_embeddings: Precomputed rotary ``(cos, sin)`` pair shared
                across layers.
            attention_mask: Causal mask for full attention, or the raw attention
                mask for lightning-attention layers (routing done by the caller).
            past_key_values: Cache object updated in place by the attention module.

        Returns:
            The transformed hidden states. NOTE: the original annotation declared a
            tuple here, but only the hidden-states tensor is returned — attention
            weights are discarded below.
        """
        # Pre-norm, then use the *normalized* states as the residual as well —
        # unlike standard pre-norm, which keeps the un-normalized input as residual.
        hidden_states = self.input_layernorm(hidden_states)
        residual = hidden_states
        # Attention weights (second element) are intentionally dropped.
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        # Weighted residual connection: alpha * residual + beta * sublayer output.
        hidden_states = residual * self.attn_alpha_factor + hidden_states * self.attn_beta_factor
        hidden_states = self.post_attention_layernorm(hidden_states)
        residual = hidden_states
        hidden_states = self.block_sparse_moe(hidden_states)
        hidden_states = residual * self.mlp_alpha_factor + hidden_states * self.mlp_beta_factor

        return hidden_states
433
+
434
+
435
class MiniMaxPreTrainedModel(MixtralPreTrainedModel):
    """Base pretrained-model class for MiniMax; configures output recording hooks."""

    # Fullgraph torch.compile is disabled — presumably because of the custom
    # linear-attention cache handling; NOTE(review): confirm against upstream.
    _can_compile_fullgraph = False
    # Maps requested output names to the submodules whose outputs populate them:
    # router logits come from the MoE gate Linear, per-layer hidden states from
    # each decoder layer, attentions from either attention flavor.
    _can_record_outputs = {
        "router_logits": OutputRecorder(nn.Linear, layer_name="block_sparse_moe.gate", index=0),
        "hidden_states": MiniMaxDecoderLayer,
        "attentions": [MiniMaxAttention, MiniMaxLightningAttention],
    }
442
+
443
+
444
class MiniMaxModel(MixtralModel):
    """MiniMax backbone: embedding, hybrid-attention decoder stack, final norm."""

    @check_model_inputs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[MiniMaxCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        """Run the decoder stack.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be provided. When
        caching is enabled, only the model-specific :class:`MiniMaxCache` is
        accepted, since lightning-attention layers store a linear-attention state
        that generic caches cannot hold.

        Returns:
            ``MoeModelOutputWithPast`` with the final hidden states and the
            (possibly newly created) cache.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # MiniMax requires its own cache type; reject any foreign cache object.
        if use_cache and past_key_values is None:
            past_key_values = MiniMaxCache()
        elif use_cache and not isinstance(past_key_values, MiniMaxCache):
            raise ValueError(
                f"MiniMax uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}."
            )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Derive absolute positions of the new tokens from the cache length.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Build the causal mask once; sliding-window variant if configured.
        mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
        causal_mask = mask_function(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers:
            if decoder_layer.layer_type == "full_attention":
                input_attention_mask = causal_mask
            else:
                # lightning attention uses original attention_mask, and uses it only for the first step
                input_attention_mask = attention_mask

            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=input_attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
517
+
518
+
519
class MiniMaxForCausalLM(MixtralForCausalLM):
    """Causal-LM head on top of the MiniMax backbone; delegates to Mixtral's forward."""

    def forward(self, **super_kwargs):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MiniMaxForCausalLM

        >>> model = MiniMaxForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Behavior is identical to the parent; overridden only to attach the
        # MiniMax-specific docstring above for the generated documentation.
        return super().forward(**super_kwargs)
544
+