JoyboyGo committed on
Commit
1bd19d1
·
verified ·
1 Parent(s): 7944479

Upload folder using huggingface_hub

Browse files
YuLan-Mini-Nanbeige-Distill/config.json ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3NextForCausalLM"
4
+ ],
5
+ "attention_bias": true,
6
+ "attention_dropout": 0.0,
7
+ "attn_output_gate": false,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_qwen3_next.Qwen3NextConfig",
10
+ "AutoModel": "modeling_qwen3_next.Qwen3NextForCausalLM",
11
+ "AutoModelForCausalLM": "modeling_qwen3_next.Qwen3NextForCausalLM"
12
+ },
13
+ "bos_token_id": 1,
14
+ "decoder_sparse_step": 1,
15
+ "dtype": "float32",
16
+ "enable_qk_norm": false,
17
+ "eos_token_id": 2,
18
+ "full_attention_interval": 0,
19
+ "head_dim": 64,
20
+ "hidden_act": "silu",
21
+ "hidden_size": 1920,
22
+ "initializer_range": 0.02,
23
+ "intermediate_size": 4800,
24
+ "layer_types": ["linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "full_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "full_attention", "full_attention", "linear_attention", "full_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "full_attention", "linear_attention", "full_attention", "full_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention", "linear_attention"],
25
+ "linear_conv_kernel_dim": 4,
26
+ "linear_key_head_dim": 64,
27
+ "linear_num_key_heads": 8,
28
+ "linear_num_value_heads": 32,
29
+ "linear_value_head_dim": 64,
30
+ "max_position_embeddings": 32768,
31
+ "mlp_only_layers": [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55],
32
+ "num_experts_per_tok": 2,
33
+ "num_experts": 0,
34
+ "model_type": "qwen3_next",
35
+ "moe_intermediate_size": 0,
36
+ "norm_topk_prob": true,
37
+ "num_attention_heads": 30,
38
+ "num_hidden_layers": 56,
39
+ "num_key_value_heads": 6,
40
+ "output_router_logits": false,
41
+ "partial_rotary_factor": 1.0,
42
+ "rms_norm_eps": 1e-06,
43
+ "rope_scaling": null,
44
+ "rope_theta": 490000,
45
+ "router_aux_loss_coef": 0.001,
46
+ "router_bias": false,
47
+ "moe_router_score_function": "softmax",
48
+ "shared_expert_intermediate_size": 0,
49
+ "use_shared_expert_gate": true,
50
+ "tie_word_embeddings": false,
51
+ "transformers_version": "4.57.1",
52
+ "use_cache": true,
53
+ "use_sliding_window": false,
54
+ "ffn_token_shift": null,
55
+ "ffn_intermediate_token_shift": null,
56
+ "attn_token_shift": null,
57
+ "attn_q_token_shift": null,
58
+ "attn_k_token_shift": null,
59
+ "attn_v_token_shift": null,
60
+ "token_shift_conv_size": 4,
61
+ "token_shift_conv_init": "default",
62
+ "attn_position_embedding_type": "rope",
63
+ "rnn_position_embedding_type": "nope",
64
+ "attn_logits_scaling": null,
65
+ "vocab_size": 99000
66
+ }
YuLan-Mini-Nanbeige-Distill/configuration_qwen3_next.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Qwen3-Next model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig, layer_type_validation
18
+ from transformers.modeling_rope_utils import rope_config_validation
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
class Qwen3NextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a
    Qwen3-Next model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of
    Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the model. Defines the number of different tokens that can be represented by the
            `inputs_ids`.
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 5632):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 48):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 2):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `2`.
        hidden_act (`str`, *optional*, defaults to `"silu"`):
            The non-linear activation function in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        partial_rotary_factor (`float`, *optional*, defaults to 0.25):
            Percentage of the query and keys which will have rotary embedding.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        head_dim (`int`, *optional*, defaults to 256):
            Projection weights dimension in multi-head attention.
        linear_conv_kernel_dim (`int`, *optional*, defaults to 4):
            Kernel size of the convolution used in linear attention layers.
        linear_key_head_dim (`int`, *optional*, defaults to 128):
            Dimension of each key head in linear attention.
        linear_value_head_dim (`int`, *optional*, defaults to 128):
            Dimension of each value head in linear attention.
        linear_num_key_heads (`int`, *optional*, defaults to 16):
            Number of key heads used in linear attention layers.
        linear_num_value_heads (`int`, *optional*, defaults to 32):
            Number of value heads used in linear attention layers.
        decoder_sparse_step (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer.
        moe_intermediate_size (`int`, *optional*, defaults to 512):
            Intermediate size of the routed expert.
        shared_expert_intermediate_size (`int`, *optional*, defaults to 512):
            Intermediate size of the shared expert.
        num_experts_per_tok (`int`, *optional*, defaults to 10):
            Number of selected experts.
        num_experts (`int`, *optional*, defaults to 512):
            Number of routed experts.
        norm_topk_prob (`bool`, *optional*, defaults to `True`):
            Whether to normalize the topk probabilities.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.
        mlp_only_layers (`list[int]`, *optional*, defaults to `[]`):
            Indicate which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock
            The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
            Passing `None` (the signature default) is equivalent to passing an empty list.
        layer_types (`list[str]`, *optional*):
            Types of each layer (attention or linear). When omitted, the list is generated from the
            `full_attention_interval` keyword argument (default 4): every `full_attention_interval`-th layer is
            `"full_attention"` and all other layers are `"linear_attention"`.
        enable_qk_norm (`bool`, *optional*, defaults to `False`):
            Whether to apply L2 normalization to the query and key embeddings.
        router_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the router logits.
        attn_output_gate (`bool`, *optional*, defaults to `False`):
            Stored on the config unchanged. Presumably toggles an output gate in the attention layers -- confirm
            against the modeling code.
        moe_router_score_function (`str`, *optional*, defaults to `"softmax"`):
            The score function used in the MoE router.
        ffn_token_shift (`str`, *optional*): Token shift before FFN/MoE. `None`, `"cat"`, or `"conv"`.
        ffn_intermediate_token_shift (`str`, *optional*): Token shift in MLP before down_proj. `None`, `"cat"`, or `"conv"`.
        attn_token_shift (`str`, *optional*): Token shift before attention. `None`, `"cat"`, or `"conv"`.
        attn_q_token_shift (`str`, *optional*): Token shift on query after projection. `None`, `"cat"`, or `"conv"`.
        attn_k_token_shift (`str`, *optional*): Token shift on key after projection. `None`, `"cat"`, or `"conv"`.
        attn_v_token_shift (`str`, *optional*): Token shift on value after projection. `None`, `"cat"`, or `"conv"`.
        token_shift_conv_size (`int`, *optional*, defaults to 4): Kernel size for token-shift Conv1d.
        token_shift_conv_init (`str`, *optional*, defaults to `"default"`): Init for token-shift Conv1d; `"identity"` for causal identity.
        attn_position_embedding_type (`str`, *optional*, defaults to `"rope"`):
            Position embedding type for the attention layers. Must be `"rope"` or `"nope"`.
        rnn_position_embedding_type (`str`, *optional*, defaults to `"nope"`):
            Position embedding type for the linear/RNN (GDN) layers. Must be `"rope"` or `"nope"`.
        attn_logits_scaling (`float` or `str`, *optional*):
            Optional logits scaling for length extrapolation (attention only): `None`, a float, or `"log"` /
            `"log <a>"`.

    Raises:
        ValueError: If `attn_position_embedding_type` or `rnn_position_embedding_type` is not `"rope"` or `"nope"`,
            or if `layer_types` is omitted while `full_attention_interval` is zero/empty.

    ```python
    >>> from transformers import Qwen3NextModel, Qwen3NextConfig

    >>> # Initializing a Qwen3Next style configuration
    >>> configuration = Qwen3NextConfig()

    >>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration
    >>> model = Qwen3NextModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "qwen3_next"
    keys_to_ignore_at_inference = ["past_key_values"]

    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.experts.*.gate_proj": "colwise",
        "layers.*.mlp.experts.*.up_proj": "colwise",
        "layers.*.mlp.experts.*.down_proj": "rowwise",
        "layers.*.mlp.shared_experts.gate_proj": "colwise",
        "layers.*.mlp.shared_experts.up_proj": "colwise",
        "layers.*.mlp.shared_experts.down_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=2048,
        intermediate_size=5632,
        num_hidden_layers=48,
        num_attention_heads=16,
        num_key_value_heads=2,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        partial_rotary_factor=0.25,
        attention_bias=False,
        attention_dropout=0.0,
        head_dim=256,
        linear_conv_kernel_dim=4,
        linear_key_head_dim=128,
        linear_value_head_dim=128,
        linear_num_key_heads=16,
        linear_num_value_heads=32,
        decoder_sparse_step=1,
        moe_intermediate_size=512,
        shared_expert_intermediate_size=512,
        num_experts_per_tok=10,
        num_experts=512,
        norm_topk_prob=True,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        mlp_only_layers=None,
        layer_types=None,
        enable_qk_norm=False,  # @o2iginal
        router_bias=False,  # @o2iginal
        attn_output_gate=False,  # @o2iginal
        moe_router_score_function="softmax",  # @xcx
        # Cannon layer / token shifting (align with Megatron) @o2iginal
        ffn_token_shift=None,
        ffn_intermediate_token_shift=None,
        attn_token_shift=None,
        attn_q_token_shift=None,
        attn_k_token_shift=None,
        attn_v_token_shift=None,
        token_shift_conv_size=4,
        token_shift_conv_init="default",
        # Separate RoPE for attention vs linear/RNN (GDN): "rope" or "nope"
        attn_position_embedding_type="rope",
        rnn_position_embedding_type="nope",
        # Optional logits scaling for length extrapolation (attention only): None, float, or "log" / "log <a>"
        attn_logits_scaling=None,
        **kwargs,
    ):
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # Validate with explicit raises: `assert` statements are stripped when
        # Python runs with -O, which would silently accept invalid values.
        if attn_position_embedding_type not in ("rope", "nope"):
            raise ValueError(
                f"attn_position_embedding_type must be 'rope' or 'nope', got {attn_position_embedding_type}"
            )
        if rnn_position_embedding_type not in ("rope", "nope"):
            raise ValueError(
                f"rnn_position_embedding_type must be 'rope' or 'nope', got {rnn_position_embedding_type}"
            )
        self.attn_position_embedding_type = attn_position_embedding_type
        self.rnn_position_embedding_type = rnn_position_embedding_type
        self.attn_logits_scaling = attn_logits_scaling

        # Core transformer / full-attention arguments
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.partial_rotary_factor = partial_rotary_factor
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.head_dim = head_dim
        # Must run after the rope-related attributes above have been set.
        rope_config_validation(self)

        self.layer_types = layer_types
        if self.layer_types is None:
            interval_pattern = kwargs.get("full_attention_interval", 4)
            # A zero interval would otherwise surface as a bare ZeroDivisionError
            # from the modulo below; fail with an actionable message instead.
            if not interval_pattern:
                raise ValueError(
                    "full_attention_interval must be non-zero when layer_types is not provided, "
                    f"got {interval_pattern!r}; pass an explicit layer_types list instead"
                )
            # Every `interval_pattern`-th layer (1-based) is full attention.
            self.layer_types = [
                "linear_attention" if bool((i + 1) % interval_pattern) else "full_attention"
                for i in range(self.num_hidden_layers)
            ]
        layer_type_validation(self.layer_types)

        # linear attention part
        self.linear_conv_kernel_dim = linear_conv_kernel_dim
        self.linear_key_head_dim = linear_key_head_dim
        self.linear_value_head_dim = linear_value_head_dim
        self.linear_num_key_heads = linear_num_key_heads
        self.linear_num_value_heads = linear_num_value_heads

        # MoE arguments
        self.decoder_sparse_step = decoder_sparse_step
        self.moe_intermediate_size = moe_intermediate_size
        self.shared_expert_intermediate_size = shared_expert_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.norm_topk_prob = norm_topk_prob
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        # Fresh list per instance: a `[]` default in the signature would be
        # shared (and mutable) across every config created without this argument.
        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
        self.enable_qk_norm = enable_qk_norm
        self.router_bias = router_bias
        self.attn_output_gate = attn_output_gate
        self.moe_router_score_function = moe_router_score_function

        # Token shifting (cannon layer): None | "cat" | "conv"
        self.ffn_token_shift = ffn_token_shift
        self.ffn_intermediate_token_shift = ffn_intermediate_token_shift
        self.attn_token_shift = attn_token_shift
        self.attn_q_token_shift = attn_q_token_shift
        self.attn_k_token_shift = attn_k_token_shift
        self.attn_v_token_shift = attn_v_token_shift
        self.token_shift_conv_size = token_shift_conv_size
        self.token_shift_conv_init = token_shift_conv_init
328
+
329
+ __all__ = ["Qwen3NextConfig"]
YuLan-Mini-Nanbeige-Distill/hf2mcore.log ADDED
@@ -0,0 +1,1291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torchrun --nproc_per_node 1 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 41921 /workspace/lvzhihao/PostTrain/YuLan-Pretrain/scripts/distributed_checkpoints_convertor/impl/convert.py --tokenizer-type HuggingFaceTokenizer --tokenizer-model /tmp/tmp.FZZhIF5Vmh --hf-dir /tmp/tmp.FZZhIF5Vmh --mcore2hf --use-gpu --bf16 --normalization RMSNorm --swiglu --disable-bias-linear --seq-length 1 --max-position-embeddings 490000 --attention-backend auto --position-embedding-type rope --kv-channels 64 --group-query-attention --add-qkv-bias --num-layers 56 --hidden-size 1920 --ffn-hidden-size 4800 --num-attention-heads 30 --untie-embeddings-and-output-weights --rotary-base 490000 --rotary-percent 1.00 --num-query-groups 6 --normalization RMSNorm --norm-epsilon 1e-6 --linear-attention-type gated_delta_net --linear-attention-freq [1,1,1,1,1,1,1,1,1,1,1,1,0,1,1,1,1,1,1,1,0,0,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,1,0,0,1,1,1,1,1,1] --linear-conv-kernel-dim 4 --linear-key-head-dim 64 --linear-value-head-dim 64 --linear-num-key-heads 8 --linear-num-value-heads 32 --micro-batch-size 1 --global-batch-size 1024 --train-iters 500000 --weight-decay 0.1 --adam-beta1 0.9 --adam-beta2 0.95 --init-method-std 0.006 --clip-grad 1.0 --lr 2.0e-5 --lr-decay-style cosine --min-lr 6.0e-6 --lr-warmup-fraction .001 --lr-decay-iters 430000 --bf16 --tensor-model-parallel-size 1 --pipeline-model-parallel-size 1 --expert-tensor-parallel-size 1 --expert-model-parallel-size 1 --log-interval 100 --save-interval 10000 --eval-interval 1000 --eval-iters 10 --model-type GPT --load-dir /capacity/userdata/vc0e4b0o65t5/lvzhihao/PostTrain/YuLan-Pretrain/outputs/yulan_mini_sft/run_sl16384_tp1_pp1_cp2/checkpoint/yulan-gdn-sft-1b-sl16384-lr1e-5-gbs64-mb1-tp1-pp1-cp2 --save-dir /capacity/userdata/vc0e4b0o65t5/lvzhihao/PostTrain/YuLan-Pretrain/outputs/yulan_mini_sft/run_sl16384_tp1_pp1_cp2/checkpoint/yulan-gdn-sft-1b-sl16384-lr1e-5-gbs64-mb1-tp1-pp1-cp2/iter_2340-hf 
--dist-ckpt-optim-fully-reshardable --skip-train --use-cpu-initialization --padded-vocab-size 99000 --no-load-optim --no-load-rng --logging-level 1 --attention-backend auto --synchronizer mcore_gdn_moe --pretrain-script mcore_gdn_moe.model_provider --debug --max-shard-size 20GB
2
+ W0416 17:58:28.424000 71892 .venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
3
+ W0416 17:58:28.424000 71892 .venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
4
+ fused_indices_to_multihot has reached end of life. Please migrate to a non-experimental function.
5
+ /workspace/lvzhihao/PostTrain/YuLan-Pretrain/.venv/lib/python3.11/site-packages/modelopt/torch/utils/import_utils.py:32: UserWarning: Failed to import diffusers plugin due to: ImportError('Requires Flash-Attention version >=2.7.1,<=2.8.2 but got 2.8.3.'). You may ignore this warning if you do not need this plugin.
6
+ warnings.warn(
7
+ INFO 04-16 17:58:48 [__init__.py:216] Automatically detected platform cuda.
8
+ /workspace/lvzhihao/PostTrain/YuLan-Pretrain/.venv/lib/python3.11/site-packages/modelopt/torch/__init__.py:36: UserWarning: transformers version 4.57.1 is incompatible with nvidia-modelopt and may cause issues. Please install recommended version with `pip install nvidia-modelopt[hf]` if working with HF models.
9
+ _warnings.warn(
10
+ Warning: Pai-Megatron-Patch arguments not available, some arguments may not be recognized
11
+ using world size: 1, data-parallel size: 1, context-parallel size: 1, hierarchical context-parallel sizes: None, tensor-model-parallel size: 1, pipeline-model-parallel size: 1
12
+ Number of virtual stages per pipeline stage: None
13
+ accumulate and all-reduce gradients in fp32 for bfloat16 data type.
14
+ using torch.bfloat16 for parameters ...
15
+ ------------------------ arguments ------------------------
16
+ account_for_embedding_in_pipeline_split ......... False
17
+ account_for_loss_in_pipeline_split .............. False
18
+ accumulate_allreduce_grads_in_fp32 .............. True
19
+ activation_func_clamp_value ..................... None
20
+ adam_beta1 ...................................... 0.9
21
+ adam_beta2 ...................................... 0.95
22
+ adam_eps ........................................ 1e-08
23
+ adamw_lr_mup_scaler ............................. False
24
+ add_bias_linear ................................. False
25
+ add_position_embedding .......................... True
26
+ add_qkv_bias .................................... True
27
+ adlr_autoresume ................................. False
28
+ adlr_autoresume_interval ........................ 1000
29
+ align_grad_reduce ............................... True
30
+ align_param_gather .............................. False
31
+ allow_ambiguous_pad_tokens ...................... False
32
+ app_tag_run_name ................................ None
33
+ app_tag_run_version ............................. 0.0.0
34
+ apply_layernorm_1p .............................. False
35
+ apply_query_key_layer_scaling ................... False
36
+ apply_residual_connection_post_layernorm ........ False
37
+ apply_rope_fusion ............................... True
38
+ async_save ...................................... None
39
+ async_tensor_model_parallel_allreduce ........... True
40
+ attention_backend ............................... AttnBackend.auto
41
+ attention_dropout ............................... 0.1
42
+ attention_output_gate ........................... False
43
+ attention_softmax_in_fp32 ....................... False
44
+ attn_k_token_shift .............................. None
45
+ attn_output_gate ................................ None
46
+ attn_output_gate_rand_init ...................... False
47
+ attn_q_token_shift .............................. None
48
+ attn_token_shift ................................ None
49
+ attn_v_token_shift .............................. None
50
+ auto_detect_ckpt_format ......................... False
51
+ auto_generate_cu_seqlens ........................ False
52
+ auto_model ...................................... AutoModelForCausalLM
53
+ barrier_with_L1_time ............................ True
54
+ benchmark_eval .................................. False
55
+ benchmark_global_batch .......................... None
56
+ benchmark_interval .............................. None
57
+ benchmark_micro_batch ........................... None
58
+ benchmark_sequence_length ....................... None
59
+ benchmark_tasks ................................. None
60
+ bert_binary_head ................................ True
61
+ bert_embedder_type .............................. megatron
62
+ bert_load ....................................... None
63
+ bf16 ............................................ True
64
+ bias_dropout_fusion ............................. True
65
+ bias_gelu_fusion ................................ False
66
+ bias_swiglu_fusion .............................. True
67
+ biencoder_projection_dim ........................ 0
68
+ biencoder_shared_query_context_model ............ False
69
+ block_data_path ................................. None
70
+ cache_mla_latents ............................... False
71
+ calc_ft_timeouts ................................ False
72
+ calculate_per_token_loss ........................ False
73
+ check_for_large_grads ........................... False
74
+ check_for_nan_in_loss_and_grad .................. True
75
+ check_for_spiky_loss ............................ False
76
+ check_weight_hash_across_dp_replicas_interval ... None
77
+ ckpt_assume_constant_structure .................. False
78
+ ckpt_convert_format ............................. None
79
+ ckpt_convert_save ............................... None
80
+ ckpt_convert_update_legacy_dist_opt_format ...... False
81
+ ckpt_format ..................................... torch_dist
82
+ ckpt_fully_parallel_load ........................ False
83
+ ckpt_fully_parallel_save ........................ True
84
+ ckpt_fully_parallel_save_deprecated ............. False
85
+ ckpt_step ....................................... None
86
+ classes_fraction ................................ 1.0
87
+ clip_grad ....................................... 1.0
88
+ clone_scatter_output_in_embedding ............... True
89
+ config_logger_dir ...............................
90
+ consumed_train_samples .......................... 0
91
+ consumed_valid_samples .......................... 0
92
+ context_parallel_size ........................... 1
93
+ cp_comm_type .................................... ['p2p']
94
+ create_attention_mask_in_dataloader ............. True
95
+ cross_entropy_fusion_impl ....................... native
96
+ cross_entropy_loss_fusion ....................... False
97
+ cuda_graph_impl ................................. none
98
+ cuda_graph_scope ................................ []
99
+ cuda_graph_warmup_steps ......................... 3
100
+ data_args_path .................................. None
101
+ data_cache_path ................................. None
102
+ data_parallel_random_init ....................... False
103
+ data_parallel_sharding_strategy ................. no_shard
104
+ data_parallel_size .............................. 1
105
+ data_path ....................................... None
106
+ data_per_class_fraction ......................... 1.0
107
+ data_sharding ................................... True
108
+ dataloader_type ................................. single
109
+ ddp_average_in_collective ....................... False
110
+ ddp_bucket_size ................................. None
111
+ ddp_num_buckets ................................. None
112
+ ddp_pad_buckets_for_high_nccl_busbw ............. False
113
+ debug ........................................... True
114
+ decode_only_cuda_graphs ......................... False
115
+ decoder_first_pipeline_num_layers ............... None
116
+ decoder_last_pipeline_num_layers ................ None
117
+ decoder_num_layers .............................. None
118
+ decoder_seq_length .............................. None
119
+ decoupled_lr .................................... None
120
+ decoupled_min_lr ................................ None
121
+ decrease_batch_size_if_needed ................... False
122
+ defer_embedding_wgrad_compute ................... False
123
+ delay_wgrad_compute ............................. False
124
+ deprecated_use_mcore_models ..................... False
125
+ deterministic_mode .............................. False
126
+ dino_bottleneck_size ............................ 256
127
+ dino_freeze_last_layer .......................... 1
128
+ dino_head_hidden_size ........................... 2048
129
+ dino_local_crops_number ......................... 10
130
+ dino_local_img_size ............................. 96
131
+ dino_norm_last_layer ............................ False
132
+ dino_teacher_temp ............................... 0.07
133
+ dino_warmup_teacher_temp ........................ 0.04
134
+ dino_warmup_teacher_temp_epochs ................. 30
135
+ disable_attn_output_gate ........................ False
136
+ disable_bf16_reduced_precision_matmul ........... False
137
+ disable_chunked_prefill ......................... False
138
+ disable_explicit_attention_mask ................. False
139
+ disable_mamba_mem_eff_path ...................... False
140
+ disable_straggler_on_startup .................... False
141
+ disable_symmetric_registration .................. False
142
+ dist_ckpt_format_deprecated ..................... None
143
+ dist_ckpt_optim_fully_reshardable ............... True
144
+ dist_ckpt_save_pre_mcore_014 .................... False
145
+ dist_ckpt_strictness ............................ assume_ok_unexpected
146
+ distrib_optim_fully_reshardable_mem_efficient ... False
147
+ distribute_saved_activations .................... False
148
+ distributed_backend ............................. nccl
149
+ distributed_timeout_minutes ..................... 10
150
+ distributed_timeout_seconds_after_init .......... None
151
+ document_packing_algorithm ...................... random
152
+ dryrun .......................................... False
153
+ dump_param_to_param_group_map ................... None
154
+ emb_deviation_loss_coeff ........................ 0
155
+ emb_deviation_type .............................. None
156
+ embedding_init_method_std ....................... None
157
+ embedding_path .................................. None
158
+ empty_unused_memory_level ....................... 0
159
+ enable_cuda_graph ............................... False
160
+ enable_debug_logging ............................ False
161
+ enable_experimental ............................. False
162
+ enable_ft_package ............................... False
163
+ enable_full_sharding_in_hsdp .................... False
164
+ enable_gloo_process_groups ...................... True
165
+ enable_msc ...................................... True
166
+ enable_one_logger ............................... True
167
+ encoder_num_layers .............................. 56
168
+ encoder_seq_length .............................. 1
169
+ end_weight_decay ................................ 0.1
170
+ eod_mask_loss ................................... False
171
+ error_injection_rate ............................ 0
172
+ error_injection_type ............................ transient_error
173
+ eval_interval ................................... 1000
174
+ eval_iters ...................................... 10
175
+ evidence_data_path .............................. None
176
+ exit_duration_in_mins ........................... None
177
+ exit_interval ................................... None
178
+ exit_on_missing_checkpoint ...................... False
179
+ exit_signal_handler ............................. False
180
+ exp_avg_dtype ................................... torch.float32
181
+ exp_avg_sq_dtype ................................ torch.float32
182
+ expert_model_parallel_size ...................... 1
183
+ expert_tensor_parallel_size ..................... 1
184
+ external_cuda_graph ............................. False
185
+ ffn_hidden_size ................................. 4800
186
+ ffn_intermediate_token_shift .................... None
187
+ ffn_token_shift ................................. None
188
+ fine_grained_activation_offloading .............. False
189
+ finetune ........................................ False
190
+ first_last_layers_bf16 .......................... False
191
+ flash_decode .................................... False
192
+ fp16 ............................................ False
193
+ fp16_lm_cross_entropy ........................... False
194
+ fp32_residual_connection ........................ False
195
+ fp4 ............................................. None
196
+ fp4_param ....................................... False
197
+ fp4_recipe ...................................... nvfp4
198
+ fp8 ............................................. None
199
+ fp8_amax_compute_algo ........................... most_recent
200
+ fp8_amax_history_len ............................ 1
201
+ fp8_interval .................................... 1
202
+ fp8_margin ...................................... 0
203
+ fp8_param_gather ................................ False
204
+ fp8_recipe ...................................... delayed
205
+ fp8_wgrad ....................................... True
206
+ freeze_layernorm_weight ......................... False
207
+ freeze_non_mamba ................................ False
208
+ fsdp_double_buffer .............................. False
209
+ full_validation ................................. False
210
+ gdn_cp_impl ..................................... cp2hp
211
+ geglu ........................................... False
212
+ global_batch_size ............................... 1024
213
+ glu_linear_offset ............................... 0.0
214
+ grad_reduce_in_bf16 ............................. False
215
+ gradient_accumulation_fusion .................... True
216
+ gradient_reduce_div_fusion ...................... True
217
+ group_query_attention ........................... True
218
+ grpo_clamp_eps_lower ............................ 0.01
219
+ grpo_clamp_eps_upper ............................ 0.01
220
+ grpo_default_temperature ........................ 1.0
221
+ grpo_default_top_p .............................. 0
222
+ grpo_entropy_term_weight ........................ 0.0
223
+ grpo_filter_groups_with_same_reward ............. False
224
+ grpo_group_size ................................. 2
225
+ grpo_iterations ................................. 2
226
+ grpo_kl_beta .................................... 0.001
227
+ grpo_prompts_per_step ........................... 32
228
+ head_lr_mult .................................... 1.0
229
+ heterogeneous_layers_config_encoded_json ........ None
230
+ heterogeneous_layers_config_path ................ None
231
+ hf_dir .......................................... /tmp/tmp.FZZhIF5Vmh
232
+ hidden_dropout .................................. 0.1
233
+ hidden_size ..................................... 1920
234
+ hierarchical_context_parallel_sizes ............. None
235
+ high_priority_stream_groups ..................... []
236
+ hybrid_attention_ratio .......................... 0.0
237
+ hybrid_context_parallel ......................... False
238
+ hybrid_mlp_ratio ................................ 0.0
239
+ hybrid_override_pattern ......................... None
240
+ hysteresis ...................................... 2
241
+ ict_head_size ................................... None
242
+ ict_load ........................................ None
243
+ img_h ........................................... 224
244
+ img_w ........................................... 224
245
+ increase_log_level_interval ..................... 1000
246
+ increase_log_level_iters ........................ 5
247
+ indexer_batch_size .............................. 128
248
+ indexer_log_interval ............................ 1000
249
+ inference_batch_times_seqlen_threshold .......... -1
250
+ inference_dynamic_batching ...................... False
251
+ inference_dynamic_batching_block_size ........... 256
252
+ inference_dynamic_batching_buffer_guaranteed_fraction 0.2
253
+ inference_dynamic_batching_buffer_overflow_factor None
254
+ inference_dynamic_batching_buffer_size_gb ....... 40.0
255
+ inference_dynamic_batching_max_requests_override None
256
+ inference_dynamic_batching_max_tokens_override .. None
257
+ inference_dynamic_batching_num_cuda_graphs ...... 16
258
+ inference_dynamic_batching_track_paused_request_events False
259
+ inference_dynamic_batching_unified_memory_level . 0
260
+ inference_max_batch_size ........................ 8
261
+ inference_max_seq_length ........................ 2560
262
+ inference_rng_tracker ........................... False
263
+ init_method_std ................................. 0.006
264
+ init_method_xavier_uniform ...................... False
265
+ init_model_with_meta_device ..................... False
266
+ initial_loss_scale .............................. 4294967296
267
+ inprocess_active_world_size ..................... 1
268
+ inprocess_barrier_timeout ....................... 120
269
+ inprocess_completion_timeout .................... 120
270
+ inprocess_empty_cuda_cache ...................... False
271
+ inprocess_granularity ........................... node
272
+ inprocess_hard_timeout .......................... 90
273
+ inprocess_heartbeat_interval .................... 30
274
+ inprocess_heartbeat_timeout ..................... 60
275
+ inprocess_last_call_wait ........................ 1
276
+ inprocess_max_iterations ........................ None
277
+ inprocess_monitor_process_interval .............. 1.0
278
+ inprocess_monitor_thread_interval ............... 1.0
279
+ inprocess_progress_watchdog_interval ............ 1.0
280
+ inprocess_restart ............................... False
281
+ inprocess_soft_timeout .......................... 60
282
+ inprocess_termination_grace_time ................ 1
283
+ is_hybrid_model ................................. False
284
+ iter_per_epoch .................................. 1250
285
+ iterations_to_skip .............................. []
286
+ keep_fp8_transpose_cache ........................ False
287
+ kitchen_config_file ............................. None
288
+ kitchen_recipe_number ........................... None
289
+ kv_channels ..................................... 64
290
+ kv_lora_rank .................................... 32
291
+ langrl_env_config ............................... None
292
+ langrl_external_server .......................... False
293
+ langrl_inference_server_conversation_template ... None
294
+ langrl_inference_server_type .................... inplace_megatron
295
+ lazy_mpu_init ................................... None
296
+ legacy_tokenizer ................................ False
297
+ linear_attention_freq ........................... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1]
298
+ linear_attention_type ........................... gated_delta_net
299
+ linear_conv_kernel_dim .......................... 4
300
+ linear_key_head_dim ............................. 64
301
+ linear_num_key_heads ............................ 8
302
+ linear_num_value_heads .......................... 32
303
+ linear_value_head_dim ........................... 64
304
+ load ............................................ None
305
+ load_complemental_dataset ....................... None
306
+ load_dir ........................................ /capacity/userdata/vc0e4b0o65t5/lvzhihao/PostTrain/YuLan-Pretrain/outputs/yulan_mini_sft/run_sl16384_tp1_pp1_cp2/checkpoint/yulan-gdn-sft-1b-sl16384-lr1e-5-gbs64-mb1-tp1-pp1-cp2
307
+ load_main_params_from_ckpt ...................... None
308
+ local_rank ...................................... 0
309
+ log_energy ...................................... False
310
+ log_hidden_states ............................... []
311
+ log_interval .................................... 100
312
+ log_loss_scale_to_tensorboard ................... True
313
+ log_memory_to_tensorboard ....................... False
314
+ log_num_zeros_in_grad ........................... False
315
+ log_params ...................................... []
316
+ log_params_norm ................................. False
317
+ log_per_module_grad_rms ......................... False
318
+ log_per_module_update_rms ....................... False
319
+ log_progress .................................... False
320
+ log_straggler ................................... False
321
+ log_throughput .................................. False
322
+ log_timers_to_tensorboard ....................... False
323
+ log_validation_ppl_to_tensorboard ............... False
324
+ log_world_size_to_tensorboard ................... False
325
+ logging_level ................................... 1
326
+ loss_scale ...................................... None
327
+ loss_scale_window ............................... 1000
328
+ lr .............................................. 2e-05
329
+ lr_decay_iters .................................. 430000
330
+ lr_decay_samples ................................ None
331
+ lr_decay_style .................................. cosine
332
+ lr_warmup_fraction .............................. 0.001
333
+ lr_warmup_init .................................. 0.0
334
+ lr_warmup_iters ................................. 0
335
+ lr_warmup_samples ............................... 0
336
+ lr_wsd_decay_iters .............................. None
337
+ lr_wsd_decay_samples ............................ None
338
+ lr_wsd_decay_style .............................. exponential
339
+ main_grads_dtype ................................ torch.float32
340
+ main_params_dtype ............................... torch.float32
341
+ make_vocab_size_divisible_by .................... 128
342
+ mamba_disable_cp ................................ False
343
+ mamba_expand .................................... 2
344
+ mamba_head_dim .................................. 64
345
+ mamba_num_groups ................................ 8
346
+ mamba_num_heads ................................. None
347
+ mamba_state_dim ................................. 128
348
+ manual_gc ....................................... False
349
+ manual_gc_eval .................................. True
350
+ manual_gc_interval .............................. 0
351
+ mask_factor ..................................... 1.0
352
+ mask_prob ....................................... 0.15
353
+ mask_type ....................................... random
354
+ masked_softmax_fusion ........................... True
355
+ max_position_embeddings ......................... 490000
356
+ max_seqlen_per_cp_rank .......................... None
357
+ max_shard_size .................................. 20GB
358
+ max_tokens_to_oom ............................... 12000
359
+ mcore2hf ........................................ True
360
+ memory_snapshot_path ............................ None
361
+ merge_file ...................................... None
362
+ micro_batch_size ................................ 1
363
+ microbatch_group_size_per_vp_stage .............. None
364
+ mid_level_dataset_surplus ....................... 0.005
365
+ min_loss_scale .................................. 1.0
366
+ min_lr .......................................... 6e-06
367
+ min_offloaded_tensor_size ....................... 1048576
368
+ mlp_chunks_for_prefill .......................... 1
369
+ mmap_bin_files .................................. True
370
+ mock_data ....................................... False
371
+ model_type ...................................... GPT
372
+ moe_apply_probs_on_input ........................ False
373
+ moe_aux_loss_coeff .............................. 0.0
374
+ moe_deepep_num_sms .............................. 20
375
+ moe_enable_deepep ............................... False
376
+ moe_expert_capacity_factor ...................... None
377
+ moe_extended_tp ................................. False
378
+ moe_ffn_hidden_size ............................. None
379
+ moe_flex_dispatcher_backend ..................... deepep
380
+ moe_grouped_gemm ................................ False
381
+ moe_hybridep_num_sms ............................ 16
382
+ moe_input_jitter_eps ............................ None
383
+ moe_layer_freq .................................. 1
384
+ moe_layer_recompute ............................. False
385
+ moe_pad_expert_input_to_capacity ................ False
386
+ moe_pad_experts_for_cuda_graph_inference ........ False
387
+ moe_per_layer_logging ........................... False
388
+ moe_permute_fusion .............................. False
389
+ moe_router_bias_update_method ................... sign
390
+ moe_router_bias_update_rate ..................... 0.001
391
+ moe_router_dtype ................................ None
392
+ moe_router_enable_expert_bias ................... False
393
+ moe_router_force_load_balancing ................. False
394
+ moe_router_fusion ............................... False
395
+ moe_router_group_topk ........................... None
396
+ moe_router_load_balancing_type .................. aux_loss
397
+ moe_router_num_groups ........................... None
398
+ moe_router_padding_for_fp8 ...................... False
399
+ moe_router_padding_for_quantization ............. False
400
+ moe_router_pre_softmax .......................... False
401
+ moe_router_score_function ....................... softmax
402
+ moe_router_topk ................................. 2
403
+ moe_router_topk_scaling_factor .................. None
404
+ moe_shared_expert_gate .......................... False
405
+ moe_shared_expert_intermediate_size ............. None
406
+ moe_shared_expert_overlap ....................... False
407
+ moe_token_dispatcher_type ....................... allgather
408
+ moe_token_drop_policy ........................... probs
409
+ moe_upcycling_granularity ....................... 1
410
+ moe_use_legacy_grouped_gemm ..................... False
411
+ moe_use_upcycling ............................... False
412
+ moe_z_loss_coeff ................................ None
413
+ mrope_section ................................... None
414
+ mscale .......................................... 1.0
415
+ mscale_all_dim .................................. 0.0
416
+ mtp_linear_attention_type ....................... None
417
+ mtp_loss_scaling_factor ......................... 0.1
418
+ mtp_num_layers .................................. None
419
+ multi_latent_attention .......................... False
420
+ multiple_validation_sets ........................ False
421
+ muon_ball_momentum .............................. 0.9
422
+ muon_ball_msign_steps ........................... 5
423
+ muon_ball_power_iteration_steps ................. 10
424
+ muon_ball_qkv_split_mode ........................ component
425
+ muon_ball_radius_mode ........................... spectral_mup
426
+ muon_ball_retract_alpha ......................... 0.05
427
+ muon_ball_retract_mode .......................... hard
428
+ muon_ball_scale_mode ............................ spectral_mup
429
+ muon_ball_split_fc1 ............................. True
430
+ muon_ball_split_moe_experts ..................... True
431
+ muon_ball_split_qkv ............................. True
432
+ muon_ball_use_nesterov .......................... True
433
+ muon_extra_scale_factor ......................... 1.0
434
+ muon_fp32_matmul_prec ........................... medium
435
+ muon_momentum ................................... 0.9
436
+ muon_num_ns_steps ............................... 5
437
+ muon_qkv_split_mode ............................. component
438
+ muon_scale_mode ................................. spectral_mup
439
+ muon_scale_vectorized_mode ...................... full
440
+ muon_split_fc1 .................................. True
441
+ muon_split_moe_experts .......................... True
442
+ muon_split_qkv .................................. True
443
+ muon_tp_mode .................................... blockwise
444
+ muon_use_nesterov ............................... False
445
+ muon_vectorize .................................. []
446
+ muon_vectorize_attn_dim ......................... hidden_size
447
+ nccl_all_reduce_for_prefill ..................... False
448
+ nccl_communicator_config_path ................... None
449
+ nccl_ub ......................................... False
450
+ no_load_optim ................................... True
451
+ no_load_rng ..................................... True
452
+ no_load_scheduler ............................... None
453
+ no_persist_layer_norm ........................... False
454
+ no_rope_freq .................................... None
455
+ no_save_optim ................................... None
456
+ no_save_rng ..................................... None
457
+ no_save_step_one ................................ None
458
+ no_weight_decay_cond_type ....................... None
459
+ non_persistent_ckpt_type ........................ None
460
+ non_persistent_global_ckpt_dir .................. None
461
+ non_persistent_local_ckpt_algo .................. fully_parallel
462
+ non_persistent_local_ckpt_dir ................... None
463
+ non_persistent_save_interval .................... None
464
+ norm_epsilon .................................... 1e-06
465
+ normalization ................................... RMSNorm
466
+ num_attention_heads ............................. 30
467
+ num_channels .................................... 3
468
+ num_classes ..................................... 1000
469
+ num_dataset_builder_threads ..................... 1
470
+ num_distributed_optimizer_instances ............. 1
471
+ num_experts ..................................... None
472
+ num_hf_saver .................................... None
473
+ num_layers ...................................... 56
474
+ num_layers_at_end_in_bf16 ....................... 1
475
+ num_layers_at_start_in_bf16 ..................... 1
476
+ num_layers_per_virtual_pipeline_stage ........... None
477
+ num_query_groups ................................ 6
478
+ num_virtual_stages_per_pipeline_rank ............ None
479
+ num_workers ..................................... 2
480
+ object_storage_cache_path ....................... None
481
+ offload_modules ................................. []
482
+ one_logger_async ................................ False
483
+ one_logger_project .............................. megatron-lm
484
+ one_logger_run_name ............................. None
485
+ onnx_safe ....................................... None
486
+ openai_gelu ..................................... False
487
+ optimizer ....................................... adam
488
+ optimizer_cpu_offload ........................... False
489
+ optimizer_offload_fraction ...................... 1.0
490
+ output_bert_embeddings .......................... False
491
+ overlap_cpu_optimizer_d2h_h2d ................... False
492
+ overlap_grad_reduce ............................. False
493
+ overlap_moe_expert_parallel_comm ................ False
494
+ overlap_p2p_comm ................................ False
495
+ overlap_p2p_comm_warmup_flush ................... False
496
+ overlap_param_gather ............................ False
497
+ overlap_param_gather_with_optimizer_step ........ False
498
+ override_hf_eod_token_id ........................ None
499
+ override_opt_param_scheduler .................... False
500
+ padded_vocab_size ............................... 99000
501
+ params_dtype .................................... torch.bfloat16
502
+ patch_dim ....................................... 16
503
+ per_split_data_args_path ........................ None
504
+ perform_initialization .......................... True
505
+ perform_rl_step ................................. False
506
+ pin_cpu_grads ................................... True
507
+ pin_cpu_params .................................. True
508
+ pipeline_model_parallel_comm_backend ............ None
509
+ pipeline_model_parallel_layout .................. None
510
+ pipeline_model_parallel_size .................... 1
511
+ position_embedding_type ......................... rope
512
+ pretrain_script ................................. mcore_gdn_moe.model_provider
513
+ pretrained_checkpoint ........................... None
514
+ profile ......................................... False
515
+ profile_ranks ................................... [0]
516
+ profile_step_end ................................ 12
517
+ profile_step_start .............................. 10
518
+ q_lora_rank ..................................... None
519
+ qk_head_dim ..................................... 128
520
+ qk_l2_norm ...................................... False
521
+ qk_layernorm .................................... False
522
+ qk_pos_emb_head_dim ............................. 64
523
+ query_in_block_prob ............................. 0.1
524
+ quick_geglu ..................................... False
525
+ rampup_batch_size ............................... None
526
+ rank ............................................ 0
527
+ recompute_granularity ........................... None
528
+ recompute_method ................................ None
529
+ recompute_modules ............................... None
530
+ recompute_num_layers ............................ None
531
+ record_memory_history ........................... False
532
+ relative_attention_max_distance ................. 128
533
+ relative_attention_num_buckets .................. 32
534
+ reparam_checkpoint .............................. None
535
+ reparam_fallback_value .......................... None
536
+ reparam_keys .................................... None
537
+ replication ..................................... False
538
+ replication_factor .............................. 2
539
+ replication_jump ................................ None
540
+ rerun_mode ...................................... validate_results
541
+ reset_attention_mask ............................ False
542
+ reset_iteration_on_load ......................... False
543
+ reset_iteration_one_to_zero ..................... False
544
+ reset_position_ids .............................. False
545
+ reset_scheduler_steps_on_load ................... False
546
+ result_rejected_tracker_filename ................ None
547
+ retriever_report_topk_accuracies ................ []
548
+ retriever_score_scaling ......................... False
549
+ retriever_seq_length ............................ 256
550
+ retro_add_retriever ............................. False
551
+ retro_attention_gate ............................ 1
552
+ retro_cyclic_train_iters ........................ None
553
+ retro_encoder_attention_dropout ................. 0.1
554
+ retro_encoder_hidden_dropout .................... 0.1
555
+ retro_encoder_layers ............................ 2
556
+ retro_num_neighbors ............................. 2
557
+ retro_num_retrieved_chunks ...................... 2
558
+ retro_project_dir ............................... None
559
+ retro_verify_neighbor_count ..................... True
560
+ reuse_grad_buf_for_mxfp8_param_ag ............... False
561
+ reweight_loss_by_sample ......................... False
562
+ rl_calculate_intra_group_similarity ............. False
563
+ rl_importance_sampling_truncation_coef .......... None
564
+ rl_inference_logprobs_is_correction ............. False
565
+ rl_offload_kv_cache_during_training ............. False
566
+ rl_offload_optimizer_during_inference ........... False
567
+ rl_partial_rollouts ............................. False
568
+ rl_prompts_per_eval ............................. 32
569
+ rl_remove_kv_cache_during_training .............. False
570
+ rl_reset_cuda_graphs ............................ False
571
+ rl_sequence_packing_algo ........................ fifo
572
+ rl_sequence_packing_bin_size .................... 8192
573
+ rl_use_sequence_packing ......................... False
574
+ rope_scaling_factor ............................. 8.0
575
+ rope_type ....................................... None
576
+ rotary_base ..................................... 490000
577
+ rotary_interleaved .............................. False
578
+ rotary_percent .................................. 1.0
579
+ rotary_scaling_factor ........................... 1.0
580
+ rotary_seq_len_interpolation_factor ............. None
581
+ run_workload_inspector_server ................... False
582
+ sample_rate ..................................... 1.0
583
+ save ............................................ None
584
+ save_after_load ................................. False
585
+ save_dir ........................................ /capacity/userdata/vc0e4b0o65t5/lvzhihao/PostTrain/YuLan-Pretrain/outputs/yulan_mini_sft/run_sl16384_tp1_pp1_cp2/checkpoint/yulan-gdn-sft-1b-sl16384-lr1e-5-gbs64-mb1-tp1-pp1-cp2/iter_2340-hf
586
+ save_interval ................................... 10000
587
+ save_retain_interval ............................ None
588
+ scatter_gather_tensors_in_pipeline .............. True
589
+ seed ............................................ 1234
590
+ seq_length ...................................... 1
591
+ sequence_parallel ............................... False
592
+ sft ............................................. False
593
+ sft_tokenizer_prompt_format ..................... nemotron-h-aligned
594
+ sgd_momentum .................................... 0.9
595
+ sharp_enabled_group ............................. None
596
+ short_seq_prob .................................. 0.1
597
+ skip_train ...................................... True
598
+ skipped_train_samples ........................... 0
599
+ softmax_type .................................... vanilla
600
+ spec ............................................ None
601
+ spectral_ball_momentum .......................... 0.9
602
+ spectral_ball_msign_steps ....................... 8
603
+ spectral_ball_power_iteration_steps ............. 20
604
+ spectral_ball_qkv_split_mode .................... component
605
+ spectral_ball_radius_mode ....................... spectral_mup
606
+ spectral_ball_retract_alpha ..................... 0.05
607
+ spectral_ball_retract_mode ...................... hard
608
+ spectral_ball_scale_mode ........................ spectral_mup
609
+ spectral_ball_solver ............................ bisection
610
+ spectral_ball_solver_max_iterations ............. 20
611
+ spectral_ball_solver_tolerance_f ................ 1e-08
612
+ spectral_ball_split_fc1 ......................... True
613
+ spectral_ball_split_moe_experts ................. True
614
+ spectral_ball_split_qkv ......................... True
615
+ spectral_ball_use_nesterov ...................... True
616
+ spectral_mup_init ............................... False
617
+ split ........................................... None
618
+ split_expert_init ............................... True
619
+ split_fc1_init .................................. True
620
+ split_qkv_init .................................. True
621
+ split_qkv_init_mode ............................. group
622
+ sqreglu ......................................... False
623
+ squared_relu .................................... False
624
+ start_samples ................................... None
625
+ start_weight_decay .............................. 0.1
626
+ straggler_ctrlr_port ............................ 65535
627
+ straggler_minmax_count .......................... 1
628
+ strict_fsdp_dtensor_load ........................ True
629
+ suggested_communication_unit_size ............... None
630
+ swanlab_exp_name ................................
631
+ swanlab_project .................................
632
+ swanlab_save_dir ................................
633
+ swanlab_workspace ...............................
634
+ swiglu .......................................... True
635
+ swin_backbone_type .............................. tiny
636
+ symmetric_ar_type ............................... None
637
+ synchronizer .................................... mcore_gdn_moe
638
+ target_ckpt_format .............................. torch_dist
639
+ te_rng_tracker .................................. False
640
+ tensor_model_parallel_size ...................... 1
641
+ tensorboard_dir ................................. None
642
+ tensorboard_log_interval ........................ 1
643
+ tensorboard_queue_size .......................... 1000
644
+ test_data_path .................................. None
645
+ test_mode ....................................... False
646
+ tiktoken_num_special_tokens ..................... 1000
647
+ tiktoken_pattern ................................ None
648
+ tiktoken_special_tokens ......................... None
649
+ timing_log_level ................................ 0
650
+ timing_log_option ............................... minmax
651
+ titles_data_path ................................ None
652
+ token_shift_conv_init ........................... default
653
+ token_shift_conv_size ........................... 4
654
+ tokenizer_metadata .............................. None
655
+ tokenizer_model ................................. /tmp/tmp.FZZhIF5Vmh
656
+ tokenizer_type .................................. HuggingFaceTokenizer
657
+ torch_fsdp2_reshard_after_forward ............... True
658
+ tp_comm_bootstrap_backend ....................... nccl
659
+ tp_comm_bulk_dgrad .............................. True
660
+ tp_comm_bulk_wgrad .............................. True
661
+ tp_comm_overlap ................................. False
662
+ tp_comm_overlap_ag .............................. True
663
+ tp_comm_overlap_cfg ............................. None
664
+ tp_comm_overlap_rs .............................. True
665
+ tp_comm_overlap_rs_dgrad ........................ False
666
+ tp_comm_split_ag ................................ True
667
+ tp_comm_split_rs ................................ True
668
+ train_data_path ................................. None
669
+ train_iters ..................................... 500000
670
+ train_samples ................................... None
671
+ train_sync_interval ............................. None
672
+ transformer_impl ................................ transformer_engine
673
+ transformer_pipeline_model_parallel_size ........ 1
674
+ trust_remote_code ............................... False
675
+ untie_embeddings_and_output_weights ............. True
676
+ use_checkpoint_args ............................. False
677
+ use_checkpoint_opt_param_scheduler .............. False
678
+ use_cpu_initialization .......................... True
679
+ use_dist_ckpt ................................... True
680
+ use_dist_ckpt_deprecated ........................ False
681
+ use_distributed_optimizer ....................... False
682
+ use_flash_attn .................................. False
683
+ use_fused_weighted_squared_relu ................. False
684
+ use_gpu ......................................... True
685
+ use_legacy_models ............................... False
686
+ use_megatron_fsdp ............................... False
687
+ use_mp_args_from_checkpoint_args ................ False
688
+ use_one_sent_docs ............................... False
689
+ use_persistent_ckpt_worker ...................... False
690
+ use_precision_aware_optimizer ................... False
691
+ use_pytorch_profiler ............................ False
692
+ use_ring_exchange_p2p ........................... False
693
+ use_rope_scaling ................................ False
694
+ use_rotary_position_embeddings .................. False
695
+ use_sharp ....................................... False
696
+ use_te_activation_func .......................... False
697
+ use_tokenizer_model_from_checkpoint_args ........ True
698
+ use_torch_fsdp2 ................................. False
699
+ use_torch_optimizer_for_cpu_offload ............. False
700
+ use_tp_pp_dp_mapping ............................ False
701
+ v_head_dim ...................................... 128
702
+ valid_data_path ................................. None
703
+ variable_seq_lengths ............................ False
704
+ virtual_pipeline_model_parallel_size ............ None
705
+ vision_backbone_type ............................ vit
706
+ vision_pretraining .............................. False
707
+ vision_pretraining_type ......................... classify
708
+ vocab_extra_ids ................................. 0
709
+ vocab_file ...................................... None
710
+ vocab_size ...................................... None
711
+ wandb_entity ....................................
712
+ wandb_exp_name ..................................
713
+ wandb_project ...................................
714
+ wandb_save_dir ..................................
715
+ weight_decay .................................... 0.1
716
+ weight_decay_incr_style ......................... constant
717
+ wgrad_deferral_limit ............................ 0
718
+ window_attn_skip_freq ........................... None
719
+ window_size ..................................... None
720
+ word_embedding_dropout_prob ..................... 0.0
721
+ world_size ...................................... 1
722
+ yaml_cfg ........................................ None
723
+ -------------------- end of arguments ---------------------
724
+ INFO:megatron.core.num_microbatches_calculator:setting number of microbatches to constant 1024
725
+ > building HuggingFaceTokenizer tokenizer ...
726
+ You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.
727
+ WARNING: one_logger package is required to enable e2e metrics tracking. please go to https://confluence.nvidia.com/display/MLWFO/Package+Repositories for details to install it
728
+ INFO:megatron.training.initialize:Setting logging level to 1
729
+ WARNING:megatron.core.rerun_state_machine:RerunStateMachine initialized in mode RerunMode.VALIDATE_RESULTS
730
+ > initializing torch distributed ...
731
+ [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
732
+ [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
733
+ [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
734
+ > initialized tensor model parallel with size 1
735
+ > initialized pipeline model parallel with size 1
736
+ > setting random seeds to 1234 ...
737
+ > compiling dataset index builder ...
738
+ make: Entering directory '/capacity/userdata/vc0e4b0o65t5/lvzhihao/PostTrain/YuLan-Pretrain/megatron/core/datasets'
739
+ make: Nothing to be done for 'default'.
740
+ make: Leaving directory '/capacity/userdata/vc0e4b0o65t5/lvzhihao/PostTrain/YuLan-Pretrain/megatron/core/datasets'
741
+ >>> done with dataset index builder. Compilation time: 0.063 seconds
742
+ WARNING: constraints for invoking optimized fused softmax kernel are not met. We default back to unfused kernel invocations.
743
+ > compiling and loading fused kernels ...
744
+ /workspace/lvzhihao/PostTrain/YuLan-Pretrain/.venv/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py:4807: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
745
+ warnings.warn( # warn only once
746
+ [rank0]:[W416 17:58:49.023645313 ProcessGroupNCCL.cpp:5023] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
747
+ >>> done with compiling and loading fused kernels. Compilation time: 0.506 seconds
748
+ WORLD_SIZE: 1, RANK: 0, LOCAL_RANK: 0
749
+ building GPT model ...
750
+ `torch_dtype` is deprecated! Use `dtype` instead!
751
+ `torch_dtype` is deprecated! Use `dtype` instead!
752
+ INFO:transformers_modules.tmp_dot_FZZhIF5Vmh.modeling_qwen3_next:[Qwen3Next custom] attn_position_embedding_type=rope, rnn_position_embedding_type=nope, attn_logits_scaling=None
753
+ Qwen3NextForCausalLM(
754
+ (model): Qwen3NextModel(
755
+ (embed_tokens): Embedding(99000, 1920)
756
+ (layers): ModuleList(
757
+ (0-11): 12 x Qwen3NextDecoderLayer(
758
+ (linear_attn): Qwen3NextGatedDeltaNet(
759
+ (act): SiLUActivation()
760
+ (conv1d): Conv1d(3072, 3072, kernel_size=(4,), stride=(1,), padding=(3,), groups=3072, bias=False)
761
+ (in_proj_qkvz): Linear(in_features=1920, out_features=5120, bias=False)
762
+ (in_proj_ba): Linear(in_features=1920, out_features=64, bias=False)
763
+ (norm): FusedRMSNormGated(64, eps=1e-06, activation=silu)
764
+ (out_proj): Linear(in_features=2048, out_features=1920, bias=False)
765
+ )
766
+ (mlp): Qwen3NextMLP(
767
+ (gate_proj): Linear(in_features=1920, out_features=4800, bias=False)
768
+ (up_proj): Linear(in_features=1920, out_features=4800, bias=False)
769
+ (down_proj): Linear(in_features=4800, out_features=1920, bias=False)
770
+ (act_fn): SiLUActivation()
771
+ )
772
+ (input_layernorm): LlamaRMSNorm()
773
+ (post_attention_layernorm): LlamaRMSNorm()
774
+ )
775
+ (12): Qwen3NextDecoderLayer(
776
+ (self_attn): Qwen3NextAttention(
777
+ (q_proj): Linear(in_features=1920, out_features=1920, bias=True)
778
+ (k_proj): Linear(in_features=1920, out_features=384, bias=True)
779
+ (v_proj): Linear(in_features=1920, out_features=384, bias=True)
780
+ (o_proj): Linear(in_features=1920, out_features=1920, bias=False)
781
+ )
782
+ (mlp): Qwen3NextMLP(
783
+ (gate_proj): Linear(in_features=1920, out_features=4800, bias=False)
784
+ (up_proj): Linear(in_features=1920, out_features=4800, bias=False)
785
+ (down_proj): Linear(in_features=4800, out_features=1920, bias=False)
786
+ (act_fn): SiLUActivation()
787
+ )
788
+ (input_layernorm): LlamaRMSNorm()
789
+ (post_attention_layernorm): LlamaRMSNorm()
790
+ )
791
+ (13-19): 7 x Qwen3NextDecoderLayer(
792
+ (linear_attn): Qwen3NextGatedDeltaNet(
793
+ (act): SiLUActivation()
794
+ (conv1d): Conv1d(3072, 3072, kernel_size=(4,), stride=(1,), padding=(3,), groups=3072, bias=False)
795
+ (in_proj_qkvz): Linear(in_features=1920, out_features=5120, bias=False)
796
+ (in_proj_ba): Linear(in_features=1920, out_features=64, bias=False)
797
+ (norm): FusedRMSNormGated(64, eps=1e-06, activation=silu)
798
+ (out_proj): Linear(in_features=2048, out_features=1920, bias=False)
799
+ )
800
+ (mlp): Qwen3NextMLP(
801
+ (gate_proj): Linear(in_features=1920, out_features=4800, bias=False)
802
+ (up_proj): Linear(in_features=1920, out_features=4800, bias=False)
803
+ (down_proj): Linear(in_features=4800, out_features=1920, bias=False)
804
+ (act_fn): SiLUActivation()
805
+ )
806
+ (input_layernorm): LlamaRMSNorm()
807
+ (post_attention_layernorm): LlamaRMSNorm()
808
+ )
809
+ (20-21): 2 x Qwen3NextDecoderLayer(
810
+ (self_attn): Qwen3NextAttention(
811
+ (q_proj): Linear(in_features=1920, out_features=1920, bias=True)
812
+ (k_proj): Linear(in_features=1920, out_features=384, bias=True)
813
+ (v_proj): Linear(in_features=1920, out_features=384, bias=True)
814
+ (o_proj): Linear(in_features=1920, out_features=1920, bias=False)
815
+ )
816
+ (mlp): Qwen3NextMLP(
817
+ (gate_proj): Linear(in_features=1920, out_features=4800, bias=False)
818
+ (up_proj): Linear(in_features=1920, out_features=4800, bias=False)
819
+ (down_proj): Linear(in_features=4800, out_features=1920, bias=False)
820
+ (act_fn): SiLUActivation()
821
+ )
822
+ (input_layernorm): LlamaRMSNorm()
823
+ (post_attention_layernorm): LlamaRMSNorm()
824
+ )
825
+ (22): Qwen3NextDecoderLayer(
826
+ (linear_attn): Qwen3NextGatedDeltaNet(
827
+ (act): SiLUActivation()
828
+ (conv1d): Conv1d(3072, 3072, kernel_size=(4,), stride=(1,), padding=(3,), groups=3072, bias=False)
829
+ (in_proj_qkvz): Linear(in_features=1920, out_features=5120, bias=False)
830
+ (in_proj_ba): Linear(in_features=1920, out_features=64, bias=False)
831
+ (norm): FusedRMSNormGated(64, eps=1e-06, activation=silu)
832
+ (out_proj): Linear(in_features=2048, out_features=1920, bias=False)
833
+ )
834
+ (mlp): Qwen3NextMLP(
835
+ (gate_proj): Linear(in_features=1920, out_features=4800, bias=False)
836
+ (up_proj): Linear(in_features=1920, out_features=4800, bias=False)
837
+ (down_proj): Linear(in_features=4800, out_features=1920, bias=False)
838
+ (act_fn): SiLUActivation()
839
+ )
840
+ (input_layernorm): LlamaRMSNorm()
841
+ (post_attention_layernorm): LlamaRMSNorm()
842
+ )
843
+ (23): Qwen3NextDecoderLayer(
844
+ (self_attn): Qwen3NextAttention(
845
+ (q_proj): Linear(in_features=1920, out_features=1920, bias=True)
846
+ (k_proj): Linear(in_features=1920, out_features=384, bias=True)
847
+ (v_proj): Linear(in_features=1920, out_features=384, bias=True)
848
+ (o_proj): Linear(in_features=1920, out_features=1920, bias=False)
849
+ )
850
+ (mlp): Qwen3NextMLP(
851
+ (gate_proj): Linear(in_features=1920, out_features=4800, bias=False)
852
+ (up_proj): Linear(in_features=1920, out_features=4800, bias=False)
853
+ (down_proj): Linear(in_features=4800, out_features=1920, bias=False)
854
+ (act_fn): SiLUActivation()
855
+ )
856
+ (input_layernorm): LlamaRMSNorm()
857
+ (post_attention_layernorm): LlamaRMSNorm()
858
+ )
859
+ (24-45): 22 x Qwen3NextDecoderLayer(
860
+ (linear_attn): Qwen3NextGatedDeltaNet(
861
+ (act): SiLUActivation()
862
+ (conv1d): Conv1d(3072, 3072, kernel_size=(4,), stride=(1,), padding=(3,), groups=3072, bias=False)
863
+ (in_proj_qkvz): Linear(in_features=1920, out_features=5120, bias=False)
864
+ (in_proj_ba): Linear(in_features=1920, out_features=64, bias=False)
865
+ (norm): FusedRMSNormGated(64, eps=1e-06, activation=silu)
866
+ (out_proj): Linear(in_features=2048, out_features=1920, bias=False)
867
+ )
868
+ (mlp): Qwen3NextMLP(
869
+ (gate_proj): Linear(in_features=1920, out_features=4800, bias=False)
870
+ (up_proj): Linear(in_features=1920, out_features=4800, bias=False)
871
+ (down_proj): Linear(in_features=4800, out_features=1920, bias=False)
872
+ (act_fn): SiLUActivation()
873
+ )
874
+ (input_layernorm): LlamaRMSNorm()
875
+ (post_attention_layernorm): LlamaRMSNorm()
876
+ )
877
+ (46): Qwen3NextDecoderLayer(
878
+ (self_attn): Qwen3NextAttention(
879
+ (q_proj): Linear(in_features=1920, out_features=1920, bias=True)
880
+ (k_proj): Linear(in_features=1920, out_features=384, bias=True)
881
+ (v_proj): Linear(in_features=1920, out_features=384, bias=True)
882
+ (o_proj): Linear(in_features=1920, out_features=1920, bias=False)
883
+ )
884
+ (mlp): Qwen3NextMLP(
885
+ (gate_proj): Linear(in_features=1920, out_features=4800, bias=False)
886
+ (up_proj): Linear(in_features=1920, out_features=4800, bias=False)
887
+ (down_proj): Linear(in_features=4800, out_features=1920, bias=False)
888
+ (act_fn): SiLUActivation()
889
+ )
890
+ (input_layernorm): LlamaRMSNorm()
891
+ (post_attention_layernorm): LlamaRMSNorm()
892
+ )
893
+ (47): Qwen3NextDecoderLayer(
894
+ (linear_attn): Qwen3NextGatedDeltaNet(
895
+ (act): SiLUActivation()
896
+ (conv1d): Conv1d(3072, 3072, kernel_size=(4,), stride=(1,), padding=(3,), groups=3072, bias=False)
897
+ (in_proj_qkvz): Linear(in_features=1920, out_features=5120, bias=False)
898
+ (in_proj_ba): Linear(in_features=1920, out_features=64, bias=False)
899
+ (norm): FusedRMSNormGated(64, eps=1e-06, activation=silu)
900
+ (out_proj): Linear(in_features=2048, out_features=1920, bias=False)
901
+ )
902
+ (mlp): Qwen3NextMLP(
903
+ (gate_proj): Linear(in_features=1920, out_features=4800, bias=False)
904
+ (up_proj): Linear(in_features=1920, out_features=4800, bias=False)
905
+ (down_proj): Linear(in_features=4800, out_features=1920, bias=False)
906
+ (act_fn): SiLUActivation()
907
+ )
908
+ (input_layernorm): LlamaRMSNorm()
909
+ (post_attention_layernorm): LlamaRMSNorm()
910
+ )
911
+ (48-49): 2 x Qwen3NextDecoderLayer(
912
+ (self_attn): Qwen3NextAttention(
913
+ (q_proj): Linear(in_features=1920, out_features=1920, bias=True)
914
+ (k_proj): Linear(in_features=1920, out_features=384, bias=True)
915
+ (v_proj): Linear(in_features=1920, out_features=384, bias=True)
916
+ (o_proj): Linear(in_features=1920, out_features=1920, bias=False)
917
+ )
918
+ (mlp): Qwen3NextMLP(
919
+ (gate_proj): Linear(in_features=1920, out_features=4800, bias=False)
920
+ (up_proj): Linear(in_features=1920, out_features=4800, bias=False)
921
+ (down_proj): Linear(in_features=4800, out_features=1920, bias=False)
922
+ (act_fn): SiLUActivation()
923
+ )
924
+ (input_layernorm): LlamaRMSNorm()
925
+ (post_attention_layernorm): LlamaRMSNorm()
926
+ )
927
+ (50-55): 6 x Qwen3NextDecoderLayer(
928
+ (linear_attn): Qwen3NextGatedDeltaNet(
929
+ (act): SiLUActivation()
930
+ (conv1d): Conv1d(3072, 3072, kernel_size=(4,), stride=(1,), padding=(3,), groups=3072, bias=False)
931
+ (in_proj_qkvz): Linear(in_features=1920, out_features=5120, bias=False)
932
+ (in_proj_ba): Linear(in_features=1920, out_features=64, bias=False)
933
+ (norm): FusedRMSNormGated(64, eps=1e-06, activation=silu)
934
+ (out_proj): Linear(in_features=2048, out_features=1920, bias=False)
935
+ )
936
+ (mlp): Qwen3NextMLP(
937
+ (gate_proj): Linear(in_features=1920, out_features=4800, bias=False)
938
+ (up_proj): Linear(in_features=1920, out_features=4800, bias=False)
939
+ (down_proj): Linear(in_features=4800, out_features=1920, bias=False)
940
+ (act_fn): SiLUActivation()
941
+ )
942
+ (input_layernorm): LlamaRMSNorm()
943
+ (post_attention_layernorm): LlamaRMSNorm()
944
+ )
945
+ )
946
+ (norm): LlamaRMSNorm()
947
+ (rotary_emb): Qwen3NextRotaryEmbedding()
948
+ )
949
+ (lm_head): Linear(in_features=1920, out_features=99000, bias=False)
950
+ )
951
+ GPTModel(
952
+ (embedding): LanguageModelEmbedding(
953
+ (word_embeddings): VocabParallelEmbedding()
954
+ (embedding_dropout): Dropout(p=0.1, inplace=False)
955
+ )
956
+ (rotary_pos_emb): RotaryEmbedding()
957
+ (decoder): TransformerBlock(
958
+ (layers): ModuleList(
959
+ (0-11): 12 x TransformerLayer(
960
+ (input_layernorm): IdentityOp()
961
+ (self_attention): GatedDeltaNet(
962
+ (in_proj): TELayerNormColumnParallelLinear(in_features=1920, out_features=5184, bias=False, TP=1)
963
+ (conv1d): Conv1d(3072, 3072, kernel_size=(4,), stride=(1,), padding=(3,), groups=3072, bias=False)
964
+ (out_norm): RMSNorm()
965
+ (out_proj): TERowParallelLinear(in_features=2048, out_features=1920, bias=False, TP=1)
966
+ )
967
+ (pre_cross_attn_layernorm): IdentityOp()
968
+ (cross_attention): IdentityOp()
969
+ (cross_attn_bda): IdentityFuncOp()
970
+ (pre_mlp_layernorm): IdentityOp()
971
+ (mlp): MLP(
972
+ (linear_fc1): TELayerNormColumnParallelLinear(in_features=1920, out_features=9600, bias=False, TP=1)
973
+ (linear_fc2): TERowParallelLinear(in_features=4800, out_features=1920, bias=False, TP=1)
974
+ )
975
+ )
976
+ (12): TransformerLayer(
977
+ (input_layernorm): IdentityOp()
978
+ (self_attention): SelfAttention(
979
+ (core_attention): TEDotProductAttention(
980
+ (flash_attention): FlashAttention()
981
+ (fused_attention): FusedAttention()
982
+ (unfused_attention): UnfusedDotProductAttention(
983
+ (scale_mask_softmax): FusedScaleMaskSoftmax()
984
+ (attention_dropout): Dropout(p=0.1, inplace=False)
985
+ )
986
+ )
987
+ (linear_proj): TERowParallelLinear(in_features=1920, out_features=1920, bias=False, TP=1)
988
+ (linear_qkv): TELayerNormColumnParallelLinear(in_features=1920, out_features=2688, bias=True, TP=1)
989
+ (q_layernorm): IdentityOp()
990
+ (k_layernorm): IdentityOp()
991
+ )
992
+ (pre_cross_attn_layernorm): IdentityOp()
993
+ (cross_attention): IdentityOp()
994
+ (cross_attn_bda): IdentityFuncOp()
995
+ (pre_mlp_layernorm): IdentityOp()
996
+ (mlp): MLP(
997
+ (linear_fc1): TELayerNormColumnParallelLinear(in_features=1920, out_features=9600, bias=False, TP=1)
998
+ (linear_fc2): TERowParallelLinear(in_features=4800, out_features=1920, bias=False, TP=1)
999
+ )
1000
+ )
1001
+ (13-19): 7 x TransformerLayer(
1002
+ (input_layernorm): IdentityOp()
1003
+ (self_attention): GatedDeltaNet(
1004
+ (in_proj): TELayerNormColumnParallelLinear(in_features=1920, out_features=5184, bias=False, TP=1)
1005
+ (conv1d): Conv1d(3072, 3072, kernel_size=(4,), stride=(1,), padding=(3,), groups=3072, bias=False)
1006
+ (out_norm): RMSNorm()
1007
+ (out_proj): TERowParallelLinear(in_features=2048, out_features=1920, bias=False, TP=1)
1008
+ )
1009
+ (pre_cross_attn_layernorm): IdentityOp()
1010
+ (cross_attention): IdentityOp()
1011
+ (cross_attn_bda): IdentityFuncOp()
1012
+ (pre_mlp_layernorm): IdentityOp()
1013
+ (mlp): MLP(
1014
+ (linear_fc1): TELayerNormColumnParallelLinear(in_features=1920, out_features=9600, bias=False, TP=1)
1015
+ (linear_fc2): TERowParallelLinear(in_features=4800, out_features=1920, bias=False, TP=1)
1016
+ )
1017
+ )
1018
+ (20-21): 2 x TransformerLayer(
1019
+ (input_layernorm): IdentityOp()
1020
+ (self_attention): SelfAttention(
1021
+ (core_attention): TEDotProductAttention(
1022
+ (flash_attention): FlashAttention()
1023
+ (fused_attention): FusedAttention()
1024
+ (unfused_attention): UnfusedDotProductAttention(
1025
+ (scale_mask_softmax): FusedScaleMaskSoftmax()
1026
+ (attention_dropout): Dropout(p=0.1, inplace=False)
1027
+ )
1028
+ )
1029
+ (linear_proj): TERowParallelLinear(in_features=1920, out_features=1920, bias=False, TP=1)
1030
+ (linear_qkv): TELayerNormColumnParallelLinear(in_features=1920, out_features=2688, bias=True, TP=1)
1031
+ (q_layernorm): IdentityOp()
1032
+ (k_layernorm): IdentityOp()
1033
+ )
1034
+ (pre_cross_attn_layernorm): IdentityOp()
1035
+ (cross_attention): IdentityOp()
1036
+ (cross_attn_bda): IdentityFuncOp()
1037
+ (pre_mlp_layernorm): IdentityOp()
1038
+ (mlp): MLP(
1039
+ (linear_fc1): TELayerNormColumnParallelLinear(in_features=1920, out_features=9600, bias=False, TP=1)
1040
+ (linear_fc2): TERowParallelLinear(in_features=4800, out_features=1920, bias=False, TP=1)
1041
+ )
1042
+ )
1043
+ (22): TransformerLayer(
1044
+ (input_layernorm): IdentityOp()
1045
+ (self_attention): GatedDeltaNet(
1046
+ (in_proj): TELayerNormColumnParallelLinear(in_features=1920, out_features=5184, bias=False, TP=1)
1047
+ (conv1d): Conv1d(3072, 3072, kernel_size=(4,), stride=(1,), padding=(3,), groups=3072, bias=False)
1048
+ (out_norm): RMSNorm()
1049
+ (out_proj): TERowParallelLinear(in_features=2048, out_features=1920, bias=False, TP=1)
1050
+ )
1051
+ (pre_cross_attn_layernorm): IdentityOp()
1052
+ (cross_attention): IdentityOp()
1053
+ (cross_attn_bda): IdentityFuncOp()
1054
+ (pre_mlp_layernorm): IdentityOp()
1055
+ (mlp): MLP(
1056
+ (linear_fc1): TELayerNormColumnParallelLinear(in_features=1920, out_features=9600, bias=False, TP=1)
1057
+ (linear_fc2): TERowParallelLinear(in_features=4800, out_features=1920, bias=False, TP=1)
1058
+ )
1059
+ )
1060
+ (23): TransformerLayer(
1061
+ (input_layernorm): IdentityOp()
1062
+ (self_attention): SelfAttention(
1063
+ (core_attention): TEDotProductAttention(
1064
+ (flash_attention): FlashAttention()
1065
+ (fused_attention): FusedAttention()
1066
+ (unfused_attention): UnfusedDotProductAttention(
1067
+ (scale_mask_softmax): FusedScaleMaskSoftmax()
1068
+ (attention_dropout): Dropout(p=0.1, inplace=False)
1069
+ )
1070
+ )
1071
+ (linear_proj): TERowParallelLinear(in_features=1920, out_features=1920, bias=False, TP=1)
1072
+ (linear_qkv): TELayerNormColumnParallelLinear(in_features=1920, out_features=2688, bias=True, TP=1)
1073
+ (q_layernorm): IdentityOp()
1074
+ (k_layernorm): IdentityOp()
1075
+ )
1076
+ (pre_cross_attn_layernorm): IdentityOp()
1077
+ (cross_attention): IdentityOp()
1078
+ (cross_attn_bda): IdentityFuncOp()
1079
+ (pre_mlp_layernorm): IdentityOp()
1080
+ (mlp): MLP(
1081
+ (linear_fc1): TELayerNormColumnParallelLinear(in_features=1920, out_features=9600, bias=False, TP=1)
1082
+ (linear_fc2): TERowParallelLinear(in_features=4800, out_features=1920, bias=False, TP=1)
1083
+ )
1084
+ )
1085
+ (24-45): 22 x TransformerLayer(
1086
+ (input_layernorm): IdentityOp()
1087
+ (self_attention): GatedDeltaNet(
1088
+ (in_proj): TELayerNormColumnParallelLinear(in_features=1920, out_features=5184, bias=False, TP=1)
1089
+ (conv1d): Conv1d(3072, 3072, kernel_size=(4,), stride=(1,), padding=(3,), groups=3072, bias=False)
1090
+ (out_norm): RMSNorm()
1091
+ (out_proj): TERowParallelLinear(in_features=2048, out_features=1920, bias=False, TP=1)
1092
+ )
1093
+ (pre_cross_attn_layernorm): IdentityOp()
1094
+ (cross_attention): IdentityOp()
1095
+ (cross_attn_bda): IdentityFuncOp()
1096
+ (pre_mlp_layernorm): IdentityOp()
1097
+ (mlp): MLP(
1098
+ (linear_fc1): TELayerNormColumnParallelLinear(in_features=1920, out_features=9600, bias=False, TP=1)
1099
+ (linear_fc2): TERowParallelLinear(in_features=4800, out_features=1920, bias=False, TP=1)
1100
+ )
1101
+ )
1102
+ (46): TransformerLayer(
1103
+ (input_layernorm): IdentityOp()
1104
+ (self_attention): SelfAttention(
1105
+ (core_attention): TEDotProductAttention(
1106
+ (flash_attention): FlashAttention()
1107
+ (fused_attention): FusedAttention()
1108
+ (unfused_attention): UnfusedDotProductAttention(
1109
+ (scale_mask_softmax): FusedScaleMaskSoftmax()
1110
+ (attention_dropout): Dropout(p=0.1, inplace=False)
1111
+ )
1112
+ )
1113
+ (linear_proj): TERowParallelLinear(in_features=1920, out_features=1920, bias=False, TP=1)
1114
+ (linear_qkv): TELayerNormColumnParallelLinear(in_features=1920, out_features=2688, bias=True, TP=1)
1115
+ (q_layernorm): IdentityOp()
1116
+ (k_layernorm): IdentityOp()
1117
+ )
1118
+ (pre_cross_attn_layernorm): IdentityOp()
1119
+ (cross_attention): IdentityOp()
1120
+ (cross_attn_bda): IdentityFuncOp()
1121
+ (pre_mlp_layernorm): IdentityOp()
1122
+ (mlp): MLP(
1123
+ (linear_fc1): TELayerNormColumnParallelLinear(in_features=1920, out_features=9600, bias=False, TP=1)
1124
+ (linear_fc2): TERowParallelLinear(in_features=4800, out_features=1920, bias=False, TP=1)
1125
+ )
1126
+ )
1127
+ (47): TransformerLayer(
1128
+ (input_layernorm): IdentityOp()
1129
+ (self_attention): GatedDeltaNet(
1130
+ (in_proj): TELayerNormColumnParallelLinear(in_features=1920, out_features=5184, bias=False, TP=1)
1131
+ (conv1d): Conv1d(3072, 3072, kernel_size=(4,), stride=(1,), padding=(3,), groups=3072, bias=False)
1132
+ (out_norm): RMSNorm()
1133
+ (out_proj): TERowParallelLinear(in_features=2048, out_features=1920, bias=False, TP=1)
1134
+ )
1135
+ (pre_cross_attn_layernorm): IdentityOp()
1136
+ (cross_attention): IdentityOp()
1137
+ (cross_attn_bda): IdentityFuncOp()
1138
+ (pre_mlp_layernorm): IdentityOp()
1139
+ (mlp): MLP(
1140
+ (linear_fc1): TELayerNormColumnParallelLinear(in_features=1920, out_features=9600, bias=False, TP=1)
1141
+ (linear_fc2): TERowParallelLinear(in_features=4800, out_features=1920, bias=False, TP=1)
1142
+ )
1143
+ )
1144
+ (48-49): 2 x TransformerLayer(
1145
+ (input_layernorm): IdentityOp()
1146
+ (self_attention): SelfAttention(
1147
+ (core_attention): TEDotProductAttention(
1148
+ (flash_attention): FlashAttention()
1149
+ (fused_attention): FusedAttention()
1150
+ (unfused_attention): UnfusedDotProductAttention(
1151
+ (scale_mask_softmax): FusedScaleMaskSoftmax()
1152
+ (attention_dropout): Dropout(p=0.1, inplace=False)
1153
+ )
1154
+ )
1155
+ (linear_proj): TERowParallelLinear(in_features=1920, out_features=1920, bias=False, TP=1)
1156
+ (linear_qkv): TELayerNormColumnParallelLinear(in_features=1920, out_features=2688, bias=True, TP=1)
1157
+ (q_layernorm): IdentityOp()
1158
+ (k_layernorm): IdentityOp()
1159
+ )
1160
+ (pre_cross_attn_layernorm): IdentityOp()
1161
+ (cross_attention): IdentityOp()
1162
+ (cross_attn_bda): IdentityFuncOp()
1163
+ (pre_mlp_layernorm): IdentityOp()
1164
+ (mlp): MLP(
1165
+ (linear_fc1): TELayerNormColumnParallelLinear(in_features=1920, out_features=9600, bias=False, TP=1)
1166
+ (linear_fc2): TERowParallelLinear(in_features=4800, out_features=1920, bias=False, TP=1)
1167
+ )
1168
+ )
1169
+ (50-55): 6 x TransformerLayer(
1170
+ (input_layernorm): IdentityOp()
1171
+ (self_attention): GatedDeltaNet(
1172
+ (in_proj): TELayerNormColumnParallelLinear(in_features=1920, out_features=5184, bias=False, TP=1)
1173
+ (conv1d): Conv1d(3072, 3072, kernel_size=(4,), stride=(1,), padding=(3,), groups=3072, bias=False)
1174
+ (out_norm): RMSNorm()
1175
+ (out_proj): TERowParallelLinear(in_features=2048, out_features=1920, bias=False, TP=1)
1176
+ )
1177
+ (pre_cross_attn_layernorm): IdentityOp()
1178
+ (cross_attention): IdentityOp()
1179
+ (cross_attn_bda): IdentityFuncOp()
1180
+ (pre_mlp_layernorm): IdentityOp()
1181
+ (mlp): MLP(
1182
+ (linear_fc1): TELayerNormColumnParallelLinear(in_features=1920, out_features=9600, bias=False, TP=1)
1183
+ (linear_fc2): TERowParallelLinear(in_features=4800, out_features=1920, bias=False, TP=1)
1184
+ )
1185
+ )
1186
+ )
1187
+ (final_layernorm): RMSNorm()
1188
+ )
1189
+ (output_layer): ColumnParallelLinear(in_features=1920, out_features=99000, bias=False, TP=1)
1190
+ )
1191
+ /workspace/lvzhihao/PostTrain/YuLan-Pretrain/megatron/core/dist_checkpointing/strategies/common.py:89: UserWarning: Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the`weights_only` argument was not explicitly passed to `torch.load`, forcing weights_only=False.
1192
+ return torch.load(load_path, map_location='cpu')
1193
+ sharded_state_dict metadata loaded from the checkpoint: {'singleton_local_shards': True, 'distrib_optim_sharding_type': 'fully_reshardable', 'distrib_optim_fully_reshardable_mem_efficient': False, 'chained_optim_avoid_prefix': True}
1194
+ Job sharding has changed: Rerun state will be ignored
1195
+ loading distributed checkpoint from /capacity/userdata/vc0e4b0o65t5/lvzhihao/PostTrain/YuLan-Pretrain/outputs/yulan_mini_sft/run_sl16384_tp1_pp1_cp2/checkpoint/yulan-gdn-sft-1b-sl16384-lr1e-5-gbs64-mb1-tp1-pp1-cp2 at iteration 2340
1196
+ /workspace/lvzhihao/PostTrain/YuLan-Pretrain/megatron/core/dist_checkpointing/strategies/torch.py:956: FutureWarning: `load_state_dict` is deprecated and will be removed in future versions. Please use `load` instead.
1197
+ checkpoint.load_state_dict(
1198
+ checkpoint version 3.0
1199
+ successfully loaded checkpoint from /capacity/userdata/vc0e4b0o65t5/lvzhihao/PostTrain/YuLan-Pretrain/outputs/yulan_mini_sft/run_sl16384_tp1_pp1_cp2/checkpoint/yulan-gdn-sft-1b-sl16384-lr1e-5-gbs64-mb1-tp1-pp1-cp2 [ t 1/1, p 1/1 ] at iteration 2340
1200
+ INFO:root:Converting layer 0 is_gdn=True is_not_moe=True
1201
+ INFO:root:Converting layer 1 is_gdn=True is_not_moe=True
1202
+ INFO:root:Converting layer 2 is_gdn=True is_not_moe=True
1203
+ INFO:root:Converting layer 3 is_gdn=True is_not_moe=True
1204
+ INFO:root:Converting layer 4 is_gdn=True is_not_moe=True
1205
+ INFO:root:Converting layer 5 is_gdn=True is_not_moe=True
1206
+ INFO:root:Converting layer 6 is_gdn=True is_not_moe=True
1207
+ INFO:root:Converting layer 7 is_gdn=True is_not_moe=True
1208
+ INFO:root:Converting layer 8 is_gdn=True is_not_moe=True
1209
+ INFO:root:Converting layer 9 is_gdn=True is_not_moe=True
1210
+ INFO:root:Converting layer 10 is_gdn=True is_not_moe=True
1211
+ INFO:root:Converting layer 11 is_gdn=True is_not_moe=True
1212
+ INFO:root:Converting layer 12 is_gdn=False is_not_moe=True
1213
+ INFO:root:[DEBUG] Layer 12: args.attention_output_gate=False
1214
+ INFO:root:[DEBUG] set_gated_selfattn_state: args.attention_output_gate=False
1215
+ INFO:root:[DEBUG] set_gated_selfattn_state: attention_output_gate=False, linear_layer=TELayerNormColumnParallelLinear
1216
+ INFO:root:Converting layer 13 is_gdn=True is_not_moe=True
1217
+ INFO:root:Converting layer 14 is_gdn=True is_not_moe=True
1218
+ INFO:root:Converting layer 15 is_gdn=True is_not_moe=True
1219
+ INFO:root:Converting layer 16 is_gdn=True is_not_moe=True
1220
+ INFO:root:Converting layer 17 is_gdn=True is_not_moe=True
1221
+ INFO:root:Converting layer 18 is_gdn=True is_not_moe=True
1222
+ INFO:root:Converting layer 19 is_gdn=True is_not_moe=True
1223
+ INFO:root:Converting layer 20 is_gdn=False is_not_moe=True
1224
+ INFO:root:[DEBUG] Layer 20: args.attention_output_gate=False
1225
+ INFO:root:[DEBUG] set_gated_selfattn_state: args.attention_output_gate=False
1226
+ INFO:root:[DEBUG] set_gated_selfattn_state: attention_output_gate=False, linear_layer=TELayerNormColumnParallelLinear
1227
+ INFO:root:Converting layer 21 is_gdn=False is_not_moe=True
1228
+ INFO:root:[DEBUG] Layer 21: args.attention_output_gate=False
1229
+ INFO:root:[DEBUG] set_gated_selfattn_state: args.attention_output_gate=False
1230
+ INFO:root:[DEBUG] set_gated_selfattn_state: attention_output_gate=False, linear_layer=TELayerNormColumnParallelLinear
1231
+ INFO:root:Converting layer 22 is_gdn=True is_not_moe=True
1232
+ INFO:root:Converting layer 23 is_gdn=False is_not_moe=True
1233
+ INFO:root:[DEBUG] Layer 23: args.attention_output_gate=False
1234
+ INFO:root:[DEBUG] set_gated_selfattn_state: args.attention_output_gate=False
1235
+ INFO:root:[DEBUG] set_gated_selfattn_state: attention_output_gate=False, linear_layer=TELayerNormColumnParallelLinear
1236
+ INFO:root:Converting layer 24 is_gdn=True is_not_moe=True
1237
+ INFO:root:Converting layer 25 is_gdn=True is_not_moe=True
1238
+ INFO:root:Converting layer 26 is_gdn=True is_not_moe=True
1239
+ INFO:root:Converting layer 27 is_gdn=True is_not_moe=True
1240
+ INFO:root:Converting layer 28 is_gdn=True is_not_moe=True
1241
+ INFO:root:Converting layer 29 is_gdn=True is_not_moe=True
1242
+ INFO:root:Converting layer 30 is_gdn=True is_not_moe=True
1243
+ INFO:root:Converting layer 31 is_gdn=True is_not_moe=True
1244
+ INFO:root:Converting layer 32 is_gdn=True is_not_moe=True
1245
+ INFO:root:Converting layer 33 is_gdn=True is_not_moe=True
1246
+ INFO:root:Converting layer 34 is_gdn=True is_not_moe=True
1247
+ INFO:root:Converting layer 35 is_gdn=True is_not_moe=True
1248
+ INFO:root:Converting layer 36 is_gdn=True is_not_moe=True
1249
+ INFO:root:Converting layer 37 is_gdn=True is_not_moe=True
1250
+ INFO:root:Converting layer 38 is_gdn=True is_not_moe=True
1251
+ INFO:root:Converting layer 39 is_gdn=True is_not_moe=True
1252
+ INFO:root:Converting layer 40 is_gdn=True is_not_moe=True
1253
+ INFO:root:Converting layer 41 is_gdn=True is_not_moe=True
1254
+ INFO:root:Converting layer 42 is_gdn=True is_not_moe=True
1255
+ INFO:root:Converting layer 43 is_gdn=True is_not_moe=True
1256
+ INFO:root:Converting layer 44 is_gdn=True is_not_moe=True
1257
+ INFO:root:Converting layer 45 is_gdn=True is_not_moe=True
1258
+ INFO:root:Converting layer 46 is_gdn=False is_not_moe=True
1259
+ INFO:root:[DEBUG] Layer 46: args.attention_output_gate=False
1260
+ INFO:root:[DEBUG] set_gated_selfattn_state: args.attention_output_gate=False
1261
+ INFO:root:[DEBUG] set_gated_selfattn_state: attention_output_gate=False, linear_layer=TELayerNormColumnParallelLinear
1262
+ INFO:root:Converting layer 47 is_gdn=True is_not_moe=True
1263
+ INFO:root:Converting layer 48 is_gdn=False is_not_moe=True
1264
+ INFO:root:[DEBUG] Layer 48: args.attention_output_gate=False
1265
+ INFO:root:[DEBUG] set_gated_selfattn_state: args.attention_output_gate=False
1266
+ INFO:root:[DEBUG] set_gated_selfattn_state: attention_output_gate=False, linear_layer=TELayerNormColumnParallelLinear
1267
+ INFO:root:Converting layer 49 is_gdn=False is_not_moe=True
1268
+ INFO:root:[DEBUG] Layer 49: args.attention_output_gate=False
1269
+ INFO:root:[DEBUG] set_gated_selfattn_state: args.attention_output_gate=False
1270
+ INFO:root:[DEBUG] set_gated_selfattn_state: attention_output_gate=False, linear_layer=TELayerNormColumnParallelLinear
1271
+ INFO:root:Converting layer 50 is_gdn=True is_not_moe=True
1272
+ INFO:root:Converting layer 51 is_gdn=True is_not_moe=True
1273
+ INFO:root:Converting layer 52 is_gdn=True is_not_moe=True
1274
+ INFO:root:Converting layer 53 is_gdn=True is_not_moe=True
1275
+ INFO:root:Converting layer 54 is_gdn=True is_not_moe=True
1276
+ INFO:root:Converting layer 55 is_gdn=True is_not_moe=True
1277
+ DEBUG:root:[RANK 0] 0 send op & 0 recv op.
1278
+ INFO:root:[Iters 0 RANK 0] starts synchronizing parameters with other ranks...
1279
+ INFO:root:[Iters 0 RANK 0] finishes synchronizing
1280
+ [Iters 0 RANK 0] model.safetensors is saved.
1281
+ DEBUG:root:[Iters 0 RANK 0] joined
1282
+ Conversion finished in 25.8255398273468 seconds.
1283
+ DEBUG:filelock:Attempting to acquire lock 140691634004240 on /root/.triton/autotune/Fp16Matmul_2d_kernel.pickle.lock
1284
+ DEBUG:filelock:Lock 140691634004240 acquired on /root/.triton/autotune/Fp16Matmul_2d_kernel.pickle.lock
1285
+ DEBUG:filelock:Attempting to release lock 140691634004240 on /root/.triton/autotune/Fp16Matmul_2d_kernel.pickle.lock
1286
+ DEBUG:filelock:Lock 140691634004240 released on /root/.triton/autotune/Fp16Matmul_2d_kernel.pickle.lock
1287
+ DEBUG:filelock:Attempting to acquire lock 140691634004560 on /root/.triton/autotune/Fp16Matmul_4d_kernel.pickle.lock
1288
+ DEBUG:filelock:Lock 140691634004560 acquired on /root/.triton/autotune/Fp16Matmul_4d_kernel.pickle.lock
1289
+ DEBUG:filelock:Attempting to release lock 140691634004560 on /root/.triton/autotune/Fp16Matmul_4d_kernel.pickle.lock
1290
+ DEBUG:filelock:Lock 140691634004560 released on /root/.triton/autotune/Fp16Matmul_4d_kernel.pickle.lock
1291
+ [rank0]:[W416 17:59:16.552855020 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
YuLan-Mini-Nanbeige-Distill/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44938572364bf8fd69d242c80835739972d73d311ef23ba535bef86de5934087
3
+ size 5343303296
YuLan-Mini-Nanbeige-Distill/modeling_qwen3_next.py ADDED
@@ -0,0 +1,1561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/qwen3_next/modular_qwen3_next.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_qwen3_next.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ import math
23
+ from typing import Any, Callable, Optional, Union
24
+
25
+ import torch
26
+ import torch.nn.functional as F
27
+ from torch import nn
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.cache_utils import Cache
31
+ from transformers.generation import GenerationMixin
32
+ from transformers.masking_utils import create_causal_mask
33
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
34
+ from transformers.modeling_layers import (
35
+ GradientCheckpointingLayer,
36
+ )
37
+ from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
38
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
39
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
40
+ from transformers.processing_utils import Unpack
41
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
42
+ from transformers.utils.deprecation import deprecate_kwarg
43
+ from transformers.utils.generic import OutputRecorder, check_model_inputs
44
+ from transformers.utils.import_utils import (
45
+ is_causal_conv1d_available,
46
+ is_flash_linear_attention_available,
47
+ )
48
+ try:
49
+ from configuration_qwen3_next import Qwen3NextConfig
50
+ except ImportError:
51
+ from .configuration_qwen3_next import Qwen3NextConfig
52
+
53
+ if is_causal_conv1d_available():
54
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
55
+ else:
56
+ causal_conv1d_update, causal_conv1d_fn = None, None
57
+
58
+ if is_flash_linear_attention_available():
59
+ from fla.modules import FusedRMSNormGated
60
+ from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
61
+ else:
62
+ chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None
63
+ FusedRMSNormGated = None
64
+
65
+ logger = logging.get_logger(__name__)
66
+
67
+ # ANSI colors for console (custom features log)
68
+ _CG = "\033[92m" # green
69
+ _CY = "\033[93m" # yellow
70
+ _CC = "\033[96m" # cyan
71
+ _CR = "\033[0m" # reset
72
+
73
+
74
def _log_custom_features(config: Qwen3NextConfig) -> None:
    """Emit one colourised INFO line summarising the custom position/scaling options.

    Each option is read from ``config`` with ``getattr`` so that configs which
    do not define it fall back to the defaults used below ("rope", "nope" and
    ``None`` respectively).
    """
    # (option name, effective value) in the order they appear in the log line.
    settings = (
        ("attn_position_embedding_type", getattr(config, "attn_position_embedding_type", "rope")),
        ("rnn_position_embedding_type", getattr(config, "rnn_position_embedding_type", "nope")),
        ("attn_logits_scaling", getattr(config, "attn_logits_scaling", None)),
    )
    # Name in cyan, value in yellow, each followed by a colour reset.
    rendered = ", ".join(f"{_CC}{key}{_CR}={_CY}{value}{_CR}" for key, value in settings)
    logger.info(f"{_CG}[Qwen3Next custom]{_CR} " + rendered)
86
+
87
+
88
class Qwen3NextRMSNormGated(nn.Module):
    """RMS normalisation over the last dimension, optionally followed by a SiLU gate.

    The input is normalised in float32, scaled by the learnable ``weight``
    (initialised to ones) and, when a ``gate`` tensor is supplied, multiplied
    by ``silu(gate)``. The result is cast back to the input dtype.

    Args:
        hidden_size: Size of the last dimension; also the length of ``weight``.
        eps: Constant added to the mean square before the reciprocal square root.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, hidden_size, eps=1e-6, **kwargs):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states, gate=None):
        """Normalise ``hidden_states`` and apply the SiLU gate when one is given.

        Args:
            hidden_states: Tensor normalised along its last dimension.
            gate: Optional tensor broadcastable to ``hidden_states``. When
                ``None`` the gating step is skipped and the plain normalised,
                weighted tensor is returned.

        Returns:
            Tensor with the dtype of the ``hidden_states`` argument.
        """
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # Norm before gate
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = self.weight * hidden_states.to(input_dtype)
        # The signature advertises `gate` as optional, so do not dereference it blindly.
        if gate is not None:
            hidden_states = hidden_states * F.silu(gate.to(torch.float32))

        return hidden_states.to(input_dtype)
104
+
105
+
106
class Qwen3NextDynamicCache:
    """
    A dynamic cache that holds both the attention cache (which grows along a seq_len dimension) and the
    linear attention cache.

    The cache keeps four parallel lists with one slot per hidden layer: `key_cache` and `value_cache` for
    full-attention layers, and `conv_states` and `recurrent_states` for linear-attention (gated deltanet) layers.
    Every slot starts as `None` and is filled lazily, so slots belonging to the other layer kind simply stay `None`.

    For attention layers, `key_cache` and `value_cache` entries are concatenated along dimension 2 on every
    `update` call and the sequence length is read from dimension -2, i.e. the layout is expected to be
    `(batch_size, num_heads, seq_len, head_dim)`.
    `conv_states` and `recurrent_states` are not written by this class; they are assigned from outside
    (presumably by the linear-attention layers -- confirm against the layer implementation).
    """

    is_compileable = False

    def __init__(self, config: "Qwen3NextConfig"):
        super().__init__()
        self.layer_types = config.layer_types
        # Indices of the layers that own a key/value cache.
        self.transformer_layers = [
            i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention"
        ]
        # Index of the last linear-attention layer, or None when the model has none.
        linear_layers = [i for i, layer_type in enumerate(self.layer_types) if layer_type == "linear_attention"]
        self.last_linear_layer = linear_layers[-1] if linear_layers else None

        # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference
        self.conv_states = [None for _ in range(config.num_hidden_layers)]
        self.recurrent_states = [None for _ in range(config.num_hidden_layers)]
        self.key_cache = [None for _ in range(config.num_hidden_layers)]
        self.value_cache = [None for _ in range(config.num_hidden_layers)]

    def __len__(self):
        """Number of entries in `layer_types`."""
        return len(self.layer_types)

    def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the `(key, value)` cache pair stored for `layer_idx` (either may be `None`)."""
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Append new key/value states for `layer_idx` and return the full cached tensors.

        The first call for a layer stores the tensors as-is; later calls concatenate along dimension 2.
        `cache_kwargs` is accepted but not used.
        """
        if self.key_cache[layer_idx] is None:
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            if self.key_cache[layer_idx] is not None:
                device = self.key_cache[layer_idx].device
                beam_idx = beam_idx.to(device)
                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx)

            if self.conv_states[layer_idx] is not None:
                device = self.conv_states[layer_idx].device
                beam_idx = beam_idx.to(device)
                self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx)

            # Guarded separately: a conv state may exist before the recurrent state has been written.
            if self.recurrent_states[layer_idx] is not None:
                device = self.recurrent_states[layer_idx].device
                beam_idx = beam_idx.to(device)
                self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx)

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # Without any full-attention layer there is no key/value cache to measure.
        if not self.transformer_layers:
            return 0
        # take any layer that contains cache and not empty tensor
        if layer_idx not in self.transformer_layers:
            layer_idx = self.transformer_layers[0]
        if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
        """
        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
        the given layer at `layer_idx`.
        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.
        """
        kv_offset = 0
        query_length = cache_position.shape[0]
        past_seen_tokens = self.get_seq_length(layer_idx)
        kv_length = query_length + past_seen_tokens
        return kv_length, kv_offset

    @property
    def has_previous_state(self):
        """We have a previous state if the last linear (conv) layer was already updated."""
        if self.last_linear_layer is None:
            return False
        return self.conv_states[self.last_linear_layer] is not None
197
+
198
+
199
class Qwen3NextRotaryEmbedding(nn.Module):
    """Produces the cos/sin tables consumed by rotary position embedding.

    The inverse frequencies and the attention scaling factor come from the
    initialiser registered in ``ROPE_INIT_FUNCTIONS`` under the configured
    rope type.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Qwen3NextConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        rope_scaling = getattr(config, "rope_scaling", None)
        if isinstance(rope_scaling, dict):
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        else:
            self.rope_type = "default"

        max_positions = config.max_position_embeddings
        self.max_seq_len_cached = max_positions
        self.original_max_seq_len = max_positions

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        initial_inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: the buffer is left out of the state dict.
        self.register_buffer("inv_freq", initial_inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to the dtype of ``x``.

        ``x`` is only consulted for its device and dtype.
        """
        batch = position_ids.shape[0]
        freq_column = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        position_row = position_ids[:, None, :].float()

        dev_kind = x.device.type
        if not (isinstance(dev_kind, str) and dev_kind != "mps"):
            dev_kind = "cpu"

        with torch.autocast(device_type=dev_kind, enabled=False):  # Force float32
            angles = (freq_column.float() @ position_row.float()).transpose(1, 2)
            doubled = torch.cat((angles, angles), dim=-1)
            cos = doubled.cos() * self.attention_scaling
            sin = doubled.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
233
+
234
+
235
class Qwen3NextRMSNorm(nn.Module):
    """RMS normalisation whose learnable weight is stored as an offset from one.

    ``weight`` starts at zero and the effective scale is ``1 + weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Divide ``x`` by the root mean square of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalise in float32, scale by ``1 + weight`` and cast back to the type of ``x``."""
        normalized = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Qwen3Next is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        scale = 1.0 + self.weight.float()
        return (normalized * scale).type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
253
+
254
+
255
class LlamaRMSNorm(nn.Module):  # Copy from Llama
    """RMS normalisation with a multiplicative weight initialised to ones."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise in float32, cast back to the input dtype, then apply ``weight``."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)
267
+
268
+
269
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the half that moves to the front."""
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)
274
+
275
+
276
+ # Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
277
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate the leading channels of the query and key tensors with RoPE.

    Only the first ``cos.shape[-1]`` channels of the last dimension are
    rotated; any remaining channels are passed through unchanged and appended
    after the rotated ones.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into `cos` and `sin` so they broadcast against `q` and `k`. Use 1 when
            `q`/`k` are laid out as [batch_size, heads, seq_len, head_dim] and 2 when they are laid out as
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]

    def _rotate(states):
        # Split into the part that receives the rotation and the pass-through tail.
        active, tail = states[..., :rotary_dim], states[..., rotary_dim:]
        # Half-rotation: second half negated and moved in front of the first half.
        split = active.shape[-1] // 2
        swapped = torch.cat((-active[..., split:], active[..., :split]), dim=-1)
        rotated = (active * cos) + (swapped * sin)
        return torch.cat([rotated, tail], dim=-1)

    return _rotate(q), _rotate(k)
315
+
316
+
317
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times, matching torch.repeat_interleave(x, dim=1, repeats=n_rep).
    Input is (batch, num_key_value_heads, seqlen, head_dim); output is
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). With `n_rep == 1` the input is returned as-is.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the head axis.
    tiled = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return tiled.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
327
+
328
+
329
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Scaled dot-product attention written with plain matmul/softmax.

    Keys and values are first repeated ``module.num_key_value_groups`` times
    along the head axis so they line up with the query heads. The optional
    additive ``attention_mask`` is cropped on its last axis to the key length
    before being added to the scores. Softmax runs in float32 and is cast
    back to the query dtype; dropout is active only when ``module.training``.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_output`` has had its
        dimensions 1 and 2 swapped and been made contiguous.
    """
    group_count = module.num_key_value_groups
    expanded_keys = repeat_kv(key, group_count)
    expanded_values = repeat_kv(value, group_count)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        key_length = expanded_keys.shape[-2]
        scores = scores + attention_mask[:, :, :, :key_length]

    probabilities = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probabilities = nn.functional.dropout(probabilities, p=dropout, training=module.training)
    context = torch.matmul(probabilities, expanded_values).transpose(1, 2).contiguous()

    return context, probabilities
353
+
354
+
355
+ class Qwen3NextAttention(nn.Module):
356
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
357
+
358
+ def __init__(self, config: Qwen3NextConfig, layer_idx: int):
359
+ super().__init__()
360
+ self.config = config
361
+ self.layer_idx = layer_idx
362
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
363
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
364
+ self.scaling = self.head_dim**-0.5
365
+ self.attention_dropout = config.attention_dropout
366
+ self.is_causal = True
367
+ self.attn_output_gate = config.attn_output_gate
368
+ self.q_proj = nn.Linear(
369
+ config.hidden_size, config.num_attention_heads * self.head_dim * (1 + self.attn_output_gate), bias=config.attention_bias
370
+ )
371
+ self.k_proj = nn.Linear(
372
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
373
+ )
374
+ self.v_proj = nn.Linear(
375
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
376
+ )
377
+ self.o_proj = nn.Linear(
378
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
379
+ )
380
+ self.enable_qk_norm = config.enable_qk_norm
381
+ if self.enable_qk_norm:
382
+ self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
383
+ self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
384
+ else:
385
+ self.q_norm = None
386
+ self.k_norm = None
387
+
388
+ # Separate RoPE for attention: "rope" or "nope" (no RoPE)
389
+ self.attn_position_embedding_type = getattr(config, "attn_position_embedding_type", "rope")
390
+ # Optional logits scaling for length extrapolation: None, float, or "log" / "log <a>"
391
+ self.attn_logits_scaling = getattr(config, "attn_logits_scaling", None)
392
+
393
+ # Token shift on Q/K/V after projection (cannon layer, conv mode); per-head conv (head_dim, 1, kernel_size)
394
+ kernel_size = getattr(config, "token_shift_conv_size", 4)
395
+ self.attn_q_token_shift = getattr(config, "attn_q_token_shift", None)
396
+ self.attn_k_token_shift = getattr(config, "attn_k_token_shift", None)
397
+ self.attn_v_token_shift = getattr(config, "attn_v_token_shift", None)
398
+ if self.attn_q_token_shift == "conv":
399
+ self.q_token_shift_conv = nn.Conv1d(
400
+ self.head_dim,
401
+ self.head_dim,
402
+ kernel_size=kernel_size,
403
+ padding=0,
404
+ groups=self.head_dim,
405
+ bias=False,
406
+ )
407
+ else:
408
+ self.q_token_shift_conv = None
409
+ if self.attn_k_token_shift == "conv":
410
+ self.k_token_shift_conv = nn.Conv1d(
411
+ self.head_dim,
412
+ self.head_dim,
413
+ kernel_size=kernel_size,
414
+ padding=0,
415
+ groups=self.head_dim,
416
+ bias=False,
417
+ )
418
+ else:
419
+ self.k_token_shift_conv = None
420
+ if self.attn_v_token_shift == "conv":
421
+ self.v_token_shift_conv = nn.Conv1d(
422
+ self.head_dim,
423
+ self.head_dim,
424
+ kernel_size=kernel_size,
425
+ padding=0,
426
+ groups=self.head_dim,
427
+ bias=False,
428
+ )
429
+ else:
430
+ self.v_token_shift_conv = None
431
+
432
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_values: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Project `hidden_states` to Q/K/V, optionally normalise / token-shift / rotate them, and run attention.

    Args:
        hidden_states: input whose last axis is the model width; the `transpose(1, 2)` calls below
            treat it as `[batch, seq, hidden]`.
        position_embeddings: `(cos, sin)` pair; applied only when `attn_position_embedding_type == "rope"`.
        attention_mask: forwarded unchanged to the attention implementation.
        past_key_values: optional cache; when given, K/V are passed through its `update` method.
        cache_position: forwarded to the cache inside `cache_kwargs`.
        **kwargs: forwarded to the attention implementation; `position_ids` is also read from here
            for the string ("log") form of `attn_logits_scaling`.

    Returns:
        `(attn_output, attn_weights)` where `attn_output` has been passed through `o_proj`.

    Raises:
        TypeError: if `attn_logits_scaling` is neither `None`, a number, nor a string.
    """
    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    def _conv_token_shift(states: torch.Tensor, conv) -> torch.Tensor:
        # states: [batch, heads, seq, head_dim]. The sequence axis must be moved last *before*
        # flattening to [batch*heads, head_dim, seq]; a bare reshape would only reinterpret memory
        # and make the conv mix head_dim and sequence positions.
        bsz, n_heads, seq, dim = states.shape
        states_bcl = states.transpose(2, 3).reshape(bsz * n_heads, dim, seq).contiguous()
        states_bcl = apply_causal_depthwise_conv1d_bcl(states_bcl, conv.weight)
        return states_bcl.reshape(bsz, n_heads, dim, seq).transpose(2, 3).contiguous()

    if self.attn_output_gate:
        # q_proj emits query and gate interleaved per head: split the doubled head width in two.
        query_states, gate = torch.chunk(
            self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
        )
        gate = gate.reshape(*input_shape, -1)
    else:
        query_states = self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim)
        gate = None

    if self.enable_qk_norm:
        query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
    else:
        query_states = query_states.view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    # Token shift on Q/K/V (cannon layer, conv mode), applied per head along the sequence axis.
    if self.attn_q_token_shift == "conv" and self.q_token_shift_conv is not None:
        query_states = _conv_token_shift(query_states, self.q_token_shift_conv)
    if self.attn_k_token_shift == "conv" and self.k_token_shift_conv is not None:
        key_states = _conv_token_shift(key_states, self.k_token_shift_conv)
    if self.attn_v_token_shift == "conv" and self.v_token_shift_conv is not None:
        value_states = _conv_token_shift(value_states, self.v_token_shift_conv)

    cos, sin = position_embeddings
    if self.attn_position_embedding_type == "rope":
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    # when "nope", do not apply RoPE

    if past_key_values is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # Optional logits scaling for length extrapolation (before attention, same as FLA).
    # Scaling the queries scales the attention logits by the same factor.
    if self.attn_logits_scaling is not None:
        if isinstance(self.attn_logits_scaling, (int, float)):
            query_states = query_states * float(self.attn_logits_scaling)
        elif isinstance(self.attn_logits_scaling, str):
            position_ids = kwargs.get("position_ids")
            if position_ids is None:
                batch_size, _, seq_len, _ = query_states.shape
                position_ids = torch.arange(
                    seq_len, device=query_states.device, dtype=torch.long
                ).unsqueeze(0).expand(batch_size, -1)
                logger.warning_once(
                    "attn_logits_scaling uses position-dependent scaling but position_ids was not passed; "
                    "using arange(0, seq_len). Pass position_ids for correct behavior with padding."
                )
            # "log" uses the default base; "log <a>" overrides it with the number after the space.
            parts = self.attn_logits_scaling.split()
            a = float(parts[1]) if len(parts) > 1 else 362.0
            position_ids_f = position_ids.to(device=query_states.device, dtype=torch.float32)
            scale = (torch.log(position_ids_f + a) / math.log(a)).to(query_states.dtype)
            # query_states: (B, H, T, D); scale: (B, T) -> broadcast
            query_states = query_states * scale.unsqueeze(1).unsqueeze(-1)
        else:
            raise TypeError(
                f"attn_logits_scaling must be float, str or None, got {type(self.attn_logits_scaling)}"
            )

    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        **kwargs,
    )

    # Merge the head axes back into one feature axis, then apply the optional sigmoid output gate.
    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    if self.attn_output_gate:
        attn_output = attn_output * torch.sigmoid(gate)

    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights
539
+
540
+
541
def apply_mask_to_padding_states(hidden_states, attention_mask):
    """
    Zero out the hidden states of padding tokens, see https://github.com/state-spaces/mamba/issues/66

    The input is returned unchanged when no mask is given, or when the mask has at most one row
    or at most one column. Otherwise the mask is broadcast over the last axis, multiplied in, and
    the result is cast back to the input dtype.
    """
    if attention_mask is None:
        return hidden_states

    mask_cols = attention_mask.shape[1]
    mask_rows = attention_mask.shape[0]
    if mask_cols <= 1 or mask_rows <= 1:
        return hidden_states

    original_dtype = hidden_states.dtype
    masked = hidden_states * attention_mask[:, :, None]
    return masked.to(original_dtype)
550
+
551
+
552
# True only when every one of the four kernel entry points is truthy (i.e. none is missing/None).
is_fast_path_available = all(
    kernel
    for kernel in (
        causal_conv1d_fn,
        causal_conv1d_update,
        chunk_gated_delta_rule,
        fused_recurrent_gated_delta_rule,
    )
)
555
+
556
+
557
def torch_causal_conv1d_update(
    hidden_states,
    conv_state,
    weight,
    bias=None,
    activation=None,
):
    """
    Pure-PyTorch depthwise conv step that consumes and refreshes a rolling `conv_state`.

    `conv_state` is prepended to `hidden_states` along the last axis, the state is overwritten
    in place with the trailing columns of that window, and a grouped conv (one group per channel)
    is run over the window. Only the outputs for the new positions are kept.

    `activation` is accepted but not consulted: SiLU is always applied. The result is cast back
    to the dtype of `hidden_states`.
    """
    _, num_channels, new_len = hidden_states.shape
    history_len = conv_state.shape[-1]

    window = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
    # Refresh the caller's state buffer in place with the most recent `history_len` columns.
    conv_state.copy_(window[:, :, -history_len:])

    conv_out = F.conv1d(window, weight.unsqueeze(1), bias, padding=0, groups=num_channels)
    activated = F.silu(conv_out[:, :, -new_len:])
    return activated.to(hidden_states.dtype)
573
+
574
+
575
def apply_causal_depthwise_conv1d_bcl(
    x_bcl: torch.Tensor,
    weight_c1w: torch.Tensor,
) -> torch.Tensor:
    """Apply causal depthwise Conv1d for token shifting. Aligns with Megatron cannon layer.

    Args:
        x_bcl: input of shape [B, C, L].
        weight_c1w: depthwise kernel of shape [C, 1, W].

    Returns:
        Tensor of shape [B, C, L]; position `t` only depends on inputs at positions `<= t`.
    """
    # The fused kernel only handles CUDA tensors; on any other device fall through to the
    # PyTorch path even when the package is installed.
    if causal_conv1d_fn is not None and x_bcl.is_cuda:
        return causal_conv1d_fn(
            x=x_bcl,
            weight=weight_c1w.squeeze(1),
            bias=None,
            activation=None,
        )
    kernel_width = weight_c1w.shape[-1]
    # Left-pad by W-1 so the unpadded conv below yields exactly L causal outputs.
    x_pad = F.pad(x_bcl, (kernel_width - 1, 0))
    return F.conv1d(
        x_pad,
        weight_c1w,
        bias=None,
        stride=1,
        padding=0,
        groups=weight_c1w.shape[0],
    )
601
+
602
+
603
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
    """
    Rescale `x` to (approximately) unit L2 norm along `dim`.

    `eps` is added to the sum of squares before the square root. Intended to align with the
    l2norm implementation in the FLA library.
    """
    sum_of_squares = (x * x).sum(dim=dim, keepdim=True)
    reciprocal_norm = 1 / torch.sqrt(sum_of_squares + eps)
    return x * reciprocal_norm
607
+
608
+
609
def torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    """
    Pure-PyTorch chunked gated delta rule.

    All inputs have axes 1 and 2 swapped and are promoted to float32 on entry; the output is
    swapped back and cast to the original `query` dtype on exit. The axis that ends up at
    position 2 is padded to a multiple of `chunk_size`, processed chunk by chunk, and trimmed
    back to its original length.

    Args:
        query, key, value: 4-D tensors (they are unpacked into four sizes after the transpose).
        g: per-position log-decay; it is cumulatively summed and exponentiated below.
        beta: per-position scale applied to both `key` and `value`.
        chunk_size: length of each processed chunk.
        initial_state: optional starting recurrent state; zeros are used when `None`.
        output_final_state: when falsy, the returned state is `None`.
        use_qk_l2norm_in_kernel: when truthy, L2-normalise `query` and `key` over the last axis.

    Returns:
        `(core_attn_out, last_recurrent_state)`.
    """
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]

    # NOTE(review): these names are taken *after* the transpose above. With the
    # (batch, seq, heads, dim) inputs that Qwen3NextGatedDeltaNet.forward passes, `sequence_length`
    # actually holds the head count and `num_heads` holds the sequence length — confirm before renaming.
    batch_size, sequence_length, num_heads, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    # Pad axis 2 up to the next multiple of chunk_size (no padding when already aligned).
    pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad_size))
    key = F.pad(key, (0, 0, 0, pad_size))
    value = F.pad(value, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    tot_heads = num_heads + pad_size
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale

    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)
    # reshape to chunks: axis 2 is split into (num_chunks, chunk_size)
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    # Upper triangle including the diagonal: those entries are zeroed in the intra-chunk matrix.
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)

    # chunk decay: cumulative log-decay within each chunk, then pairwise decay factors (lower triangle only)
    g = g.cumsum(dim=-1)
    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
    # Row-by-row in-place update: row i is rewritten using the already-updated rows < i, so the
    # iteration order matters. The clones avoid reading from memory that is being written.
    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    value = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
    last_recurrent_state = (
        torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    core_attn_out = torch.zeros_like(value)
    # Strictly-upper triangle this time: the diagonal is kept for the per-chunk attention below.
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)

    # for each chunk: combine the contribution of the carried state with the intra-chunk term,
    # then advance the carried state to the end of the chunk
    for i in range(0, tot_heads // chunk_size):
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        v_new = v_i - v_prime
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        core_attn_out[:, :, i] = attn_inter + attn @ v_new
        last_recurrent_state = (
            last_recurrent_state * g[:, :, i, -1, None, None].exp()
            + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
        )

    if not output_final_state:
        last_recurrent_state = None
    # Merge (num_chunks, chunk_size) back into one axis and drop the padded tail.
    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
    core_attn_out = core_attn_out[:, :, :num_heads]
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state
687
+
688
+
689
def torch_recurrent_gated_delta_rule(
    query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
):
    """
    Pure-PyTorch step-by-step gated delta rule.

    Inputs have axes 1 and 2 swapped and are promoted to float32 on entry. The loop then walks
    axis 2 one index at a time: it decays the running state by `exp(g)`, writes the
    `beta`-scaled prediction error for the current key/value into it, and reads it out with the
    scaled query. The output is swapped back and cast to the original `query` dtype.

    Returns:
        `(core_attn_out, last_recurrent_state)`; the state is `None` unless `output_final_state`
        is truthy.
    """
    out_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)

    def _swap_and_upcast(t):
        return t.transpose(1, 2).contiguous().to(torch.float32)

    query = _swap_and_upcast(query)
    key = _swap_and_upcast(key)
    value = _swap_and_upcast(value)
    beta = _swap_and_upcast(beta)
    g = _swap_and_upcast(g)

    # Sizes are read after the swap; the loop below iterates over axis 2.
    size0, size1, num_steps, key_dim = key.shape
    value_dim = value.shape[-1]
    query = query * (1 / (query.shape[-1] ** 0.5))

    outputs = torch.zeros(size0, size1, num_steps, value_dim).to(value)
    if initial_state is None:
        state = torch.zeros(size0, size1, key_dim, value_dim).to(value)
    else:
        state = initial_state.to(value)

    for step in range(num_steps):
        q_step = query[:, :, step]
        k_step = key[:, :, step]
        v_step = value[:, :, step]
        decay = g[:, :, step].exp().unsqueeze(-1).unsqueeze(-1)
        write_strength = beta[:, :, step].unsqueeze(-1)

        state = state * decay
        # What the decayed state currently recalls for this key.
        recalled = (state * k_step.unsqueeze(-1)).sum(dim=-2)
        correction = (v_step - recalled) * write_strength
        state = state + k_step.unsqueeze(-1) * correction.unsqueeze(-2)
        outputs[:, :, step] = (state * q_step.unsqueeze(-1)).sum(dim=-2)

    final_state = state if output_final_state else None
    outputs = outputs.transpose(1, 2).contiguous().to(out_dtype)
    return outputs, final_state
729
+
730
+
731
class Qwen3NextGatedDeltaNet(nn.Module):
    """
    Linear-attention token mixer built on the gated delta rule.

    The input is projected to query/key/value/z and to the two per-head scalars `b` and `a`;
    query/key/value go through a causal depthwise conv, then through either the chunked rule
    (prefill / no usable cache) or the recurrent rule (single-token decoding with a cache).
    The result is normalised with a `z`-gated norm and projected back to `hidden_size`.
    """

    def __init__(self, config: Qwen3NextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads

        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = layer_idx

        # Optional RoPE for linear/RNN path: "rope" or "nope" (same as FLA rnn_position_embedding_type)
        self.rnn_position_embedding_type = getattr(config, "rnn_position_embedding_type", "nope")
        rope_theta = getattr(config, "rope_theta", 10000.0)
        if self.rnn_position_embedding_type == "rope":
            inv_freq = 1.0 / (
                rope_theta ** (torch.arange(0, self.head_k_dim, 2, dtype=torch.float32) / self.head_k_dim)
            )
            # Non-persistent: recomputed from the config, never stored in checkpoints.
            self.register_buffer("_inv_freq", inv_freq, persistent=False)
        else:
            self._inv_freq = None
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.layer_norm_epsilon = config.rms_norm_eps

        # QKV: one depthwise conv over the concatenated query/key/value channels
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=False,
            kernel_size=self.conv_kernel_size,
            groups=self.conv_dim,
            padding=self.conv_kernel_size - 1,
        )

        # projection of the input hidden states
        projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
        projection_size_ba = self.num_v_heads * 2
        self.in_proj_qkvz = nn.Linear(self.hidden_size, projection_size_qkvz, bias=False)
        self.in_proj_ba = nn.Linear(self.hidden_size, projection_size_ba, bias=False)

        # time step projection (discretization)
        # instantiate once and copy inv_dt in init_weights of PretrainedModel
        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))

        A = torch.empty(self.num_v_heads).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A))

        # Use the fused gated norm when it is available, otherwise the local implementation.
        self.norm = (
            Qwen3NextRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)
            if FusedRMSNormGated is None
            else FusedRMSNormGated(
                self.head_v_dim,
                eps=self.layer_norm_epsilon,
                activation=self.activation,
                device=torch.cuda.current_device(),
                # `torch.get_default_dtype` is the real API; `torch.get_current_dtype` does not exist
                # and raised AttributeError whenever `config.dtype` was None.
                dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
            )
        )

        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

        # Each kernel falls back to its torch implementation when the fused one is missing.
        self.causal_conv1d_fn = causal_conv1d_fn
        self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update
        self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
        self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because one of the required library is not installed. Falling back to "
                "torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and"
                " https://github.com/Dao-AILab/causal-conv1d"
            )

    def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
        """
        Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`.

        The fused projections are laid out per key head: each key head's slot holds its query,
        its key, and the value / z (resp. b / a) entries of the value heads grouped under it.

        Returns:
            `(query, key, value, z, b, a)`; `value` and `z` are reshaped to one entry per value
            head with width `head_v_dim`, `b` and `a` to `num_v_heads` scalars per position.
        """
        # (b, s, d_model) -> (b, s, num_k_heads, 2 * head_k_dim + 2 * head_v_dim * num_v_heads // num_k_heads)
        new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
            self.num_k_heads,
            2 * self.head_k_dim + 2 * self.head_v_dim * self.num_v_heads // self.num_k_heads,
        )
        new_tensor_shape_ba = mixed_ba.size()[:-1] + (self.num_k_heads, 2 * self.num_v_heads // self.num_k_heads)

        mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
        mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
        split_arg_list_qkvz = [
            self.head_k_dim,
            self.head_k_dim,
            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
        ]
        split_arg_list_ba = [self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads]
        query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3)
        b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3)
        # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
        value = value.reshape(value.size(0), value.size(1), -1, self.head_v_dim)
        z = z.reshape(z.size(0), z.size(1), -1, self.head_v_dim)
        b = b.reshape(b.size(0), b.size(1), self.num_v_heads)
        a = a.reshape(a.size(0), a.size(1), self.num_v_heads)
        return query, key, value, z, b, a

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Optional[Qwen3NextDynamicCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        """
        Mix tokens with the gated delta rule.

        Args:
            hidden_states: `[batch, seq, hidden]` input (it is unpacked into exactly three sizes).
            cache_params: optional cache holding per-layer `conv_states` and `recurrent_states`;
                both are updated for this layer when a cache is given.
            cache_position: only checked for presence when deciding on the single-token path.
            attention_mask: used to zero padded positions before any projection.
            position_ids: positions for the optional RoPE; RoPE is skipped when this is `None`.

        Returns:
            Tensor with the same leading shape as `hidden_states` and width `hidden_size`.
        """
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)

        # Set up dimensions for reshapes later
        batch_size, seq_len, _ = hidden_states.shape

        # Single-token decoding on top of an already-populated cache.
        use_precomputed_states = (
            cache_params is not None
            and cache_params.has_previous_state
            and seq_len == 1
            and cache_position is not None
        )

        # getting projected states from cache if it exists
        if cache_params is not None:
            conv_state = cache_params.conv_states[self.layer_idx]
            recurrent_state = cache_params.recurrent_states[self.layer_idx]

        projected_states_qkvz = self.in_proj_qkvz(hidden_states)
        projected_states_ba = self.in_proj_ba(hidden_states)
        query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
        query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value))

        # Channels-first layout for the conv: [batch, conv_dim, seq]
        mixed_qkv = torch.cat((query, key, value), dim=-1)
        mixed_qkv = mixed_qkv.transpose(1, 2)

        if use_precomputed_states:
            # 2. Convolution sequence transformation
            # NOTE: the conv state is updated in `causal_conv1d_update`
            mixed_qkv = self.causal_conv1d_update(
                mixed_qkv,
                conv_state,
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            )
        else:
            if cache_params is not None:
                # Keep the last `conv_kernel_size` columns (left-padding with zeros, or cropping
                # via a negative pad, as needed) as the conv state for later decoding steps.
                conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
                cache_params.conv_states[self.layer_idx] = conv_state
            if self.causal_conv1d_fn is not None:
                mixed_qkv = self.causal_conv1d_fn(
                    x=mixed_qkv,
                    weight=self.conv1d.weight.squeeze(1),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                    seq_idx=None,
                )
            else:
                # The conv pads both sides; keeping the first `seq_len` outputs makes it causal.
                mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])

        mixed_qkv = mixed_qkv.transpose(1, 2)
        query, key, value = torch.split(
            mixed_qkv,
            [
                self.key_dim,
                self.key_dim,
                self.value_dim,
            ],
            dim=-1,
        )
        query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim)
        key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim)
        value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim)

        # Optional RoPE for linear/RNN (same as FLA GatedDeltaNet when rnn_position_embedding_type=="rope")
        if self._inv_freq is not None and position_ids is not None:
            # query, key: (batch, seq, num_k_heads, head_k_dim); need cos, sin (batch, seq, head_k_dim)
            inv_freq = self._inv_freq.to(query.device)
            freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]
            emb = torch.cat([freqs, freqs], dim=-1)
            cos = emb.cos().to(query.dtype)
            sin = emb.sin().to(query.dtype)
            query, key = apply_rotary_pos_emb(query, key, cos, sin, unsqueeze_dim=2)

        beta = b.sigmoid()
        # If the model is loaded in fp16, without the .float() here, A might be -inf
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
        # Repeat each key head so that query/key have one head per value head.
        if self.num_v_heads // self.num_k_heads > 1:
            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

        if not use_precomputed_states:
            core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
                query,
                key,
                value,
                g=g,
                beta=beta,
                initial_state=None,
                output_final_state=cache_params is not None,
                use_qk_l2norm_in_kernel=True,
            )

        else:
            core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
                query,
                key,
                value,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=cache_params is not None,
                use_qk_l2norm_in_kernel=True,
            )
        # Update cache
        if cache_params is not None:
            cache_params.recurrent_states[self.layer_idx] = last_recurrent_state

        # Gated norm over the per-head value dimension, then merge heads back into one feature axis.
        z_shape_og = z.shape
        core_attn_out = self.norm(core_attn_out, z)

        core_attn_out = core_attn_out.reshape(z_shape_og)
        core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1)

        output = self.out_proj(core_attn_out)
        return output
968
+
969
+
970
class Qwen3NextMLP(nn.Module):
    """
    Gated MLP (`down_proj(act_fn(gate_proj(x)) * up_proj(x))`) with two optional extras:

    * causal depthwise conv "token shifts" (cannon layer, conv mode) on the input and/or on the
      intermediate activation, enabled by `ffn_token_shift` / `ffn_intermediate_token_shift`;
    * a per-token scale applied to the intermediate activation, before `down_proj`.
    """

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

        # Token shifting (cannon layer): conv mode adds depthwise Conv1d
        self.ffn_token_shift = getattr(config, "ffn_token_shift", None)
        self.ffn_intermediate_token_shift = getattr(config, "ffn_intermediate_token_shift", None)
        kernel_size = getattr(config, "token_shift_conv_size", 4)
        if self.ffn_token_shift == "conv":
            self.token_shift_conv = nn.Conv1d(
                self.hidden_size,
                self.hidden_size,
                kernel_size=kernel_size,
                padding=0,
                groups=self.hidden_size,
                bias=False,
            )
        else:
            self.token_shift_conv = None
        if self.ffn_intermediate_token_shift == "conv":
            self.intermediate_token_shift_conv = nn.Conv1d(
                self.intermediate_size,
                self.intermediate_size,
                kernel_size=kernel_size,
                padding=0,
                groups=self.intermediate_size,
                bias=False,
            )
        else:
            self.intermediate_token_shift_conv = None

    def forward(self, x, per_token_scale=None):
        """
        Args:
            x: `[batch, seq, hidden]`, or a flattened `[tokens, hidden]` batch as passed by the MoE block.
            per_token_scale: optional scale with the leading shape of `x`; it is broadcast over
                the intermediate width and applied before `down_proj`.

        Returns:
            Tensor with the same leading shape as `x` and width `hidden_size`.
        """
        # A token shift runs along the sequence axis. Flattened [tokens, hidden] input (how the MoE
        # block calls its experts) has no such axis, and `transpose(1, 2)` would raise on it, so the
        # shift is skipped there; the MoE block applies its own shift before flattening.
        has_sequence_axis = x.dim() > 2

        # Token shift at MLP entry (conv mode; cat mode is stateless, not implemented here for loading conv checkpoints)
        if has_sequence_axis and self.ffn_token_shift == "conv" and self.token_shift_conv is not None:
            # x: [batch, seq, hidden] -> [batch, hidden, seq]
            x_bcl = x.transpose(1, 2).contiguous()
            x_bcl = apply_causal_depthwise_conv1d_bcl(x_bcl, self.token_shift_conv.weight)
            x = x_bcl.transpose(1, 2).contiguous()

        # Compute gate and up projections
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        # Apply activation: act_fn(gate) * up
        intermediate = self.act_fn(gate) * up
        # Apply per_token_scale if provided (to align with Megatron's behavior)
        if per_token_scale is not None:
            intermediate = intermediate * per_token_scale.unsqueeze(-1)

        # Intermediate token shift before down_proj (conv mode)
        if (
            has_sequence_axis
            and self.ffn_intermediate_token_shift == "conv"
            and self.intermediate_token_shift_conv is not None
        ):
            inter_bcl = intermediate.transpose(1, 2).contiguous()
            inter_bcl = apply_causal_depthwise_conv1d_bcl(inter_bcl, self.intermediate_token_shift_conv.weight)
            intermediate = inter_bcl.transpose(1, 2).contiguous()

        # Apply down projection
        down_proj = self.down_proj(intermediate)
        return down_proj
1034
+
1035
+
1036
class Qwen3NextSparseMoeBlock(nn.Module):
    """
    Top-k routed mixture of `Qwen3NextMLP` experts, plus an optional sigmoid-gated shared expert.

    Routing scores come from a linear gate followed by sigmoid or softmax. Each selected expert
    receives its routing weight as `per_token_scale`, so the weight is applied inside the expert
    (after the activation, before the down projection) rather than on the expert output.
    """

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        self.score_func = config.moe_router_score_function

        # Entry token shift (cannon layer); driven by the same `ffn_token_shift` setting as the dense MLP.
        self.ffn_token_shift = getattr(config, "ffn_token_shift", None)
        conv_width = getattr(config, "token_shift_conv_size", 4)
        self.token_shift_conv = (
            nn.Conv1d(
                config.hidden_size,
                config.hidden_size,
                kernel_size=conv_width,
                padding=0,
                groups=config.hidden_size,
                bias=False,
            )
            if self.ffn_token_shift == "conv"
            else None
        )

        # Router and routed experts.
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=config.router_bias)
        self.experts = nn.ModuleList(
            [Qwen3NextMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
        )

        # The shared expert and its scalar gate exist only for a positive shared intermediate size.
        has_shared_expert = config.shared_expert_intermediate_size > 0
        self.shared_expert = (
            Qwen3NextMLP(config, intermediate_size=config.shared_expert_intermediate_size)
            if has_shared_expert
            else None
        )
        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) if has_shared_expert else None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Route every token to its top-k experts and sum the expert outputs.

        Returns a pair: the mixed hidden states with the input's `(batch, seq, hidden)` shape, and
        the raw router logits of shape `(batch * seq, num_experts)`.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape

        # Token shift at MoE entry (cannon layer, conv mode): done on [batch, hidden, seq].
        if self.ffn_token_shift == "conv" and self.token_shift_conv is not None:
            shifted = apply_causal_depthwise_conv1d_bcl(
                hidden_states.transpose(1, 2).contiguous(), self.token_shift_conv.weight
            )
            hidden_states = shifted.transpose(1, 2).contiguous()

        # From here on tokens are flattened: (batch * sequence_length, hidden_dim)
        hidden_states = hidden_states.view(-1, hidden_dim)
        router_logits = self.gate(hidden_states)

        # Scores are computed in float32 and cast back after the top-k selection.
        if self.score_func == "sigmoid":
            scores = torch.sigmoid(router_logits.to(torch.float32))
        elif self.score_func == "softmax":
            scores = F.softmax(router_logits.to(torch.float32), dim=-1)
        else:
            raise NotImplementedError(f"Unknown score function {self.score_func}")
        routing_weights, selected_experts = torch.topk(scores, self.top_k, dim=-1)
        if self.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)

        mixed_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # expert_mask[e, k, t] is 1 when token t picked expert e as its k-th choice.
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Only visit experts that were selected by at least one token.
        active_experts = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero()
        for expert_idx in active_experts:
            expert = self.experts[expert_idx]
            choice_slot, token_idx = torch.where(expert_mask[expert_idx].squeeze(0))

            expert_input = hidden_states[None, token_idx].reshape(-1, hidden_dim)
            # One routing weight per token handled by this expert.
            expert_weights = routing_weights[token_idx, choice_slot]
            # To align with Megatron, the weight is applied inside the expert via `per_token_scale`.
            expert_output = expert(expert_input, per_token_scale=expert_weights)

            # Scatter-add the expert output back to the rows of the tokens it handled.
            mixed_states.index_add_(0, token_idx, expert_output.to(hidden_states.dtype))

        if self.shared_expert is not None:
            shared_output = self.shared_expert(hidden_states)
            if self.shared_expert_gate is not None:
                shared_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
            mixed_states = mixed_states + shared_output

        mixed_states = mixed_states.reshape(batch_size, sequence_length, hidden_dim)
        return mixed_states, router_logits
1134
+
1135
+
1136
class Qwen3NextDecoderLayer(GradientCheckpointingLayer):
    """
    One pre-norm decoder block: a token mixer followed by a feed-forward block, each wrapped
    in a residual connection.

    The token mixer is chosen per layer from `config.layer_types[layer_idx]`
    (`"linear_attention"` -> `Qwen3NextGatedDeltaNet`, `"full_attention"` -> `Qwen3NextAttention`).
    The feed-forward block is either a sparse MoE block or a dense MLP.
    """

    def __init__(self, config: Qwen3NextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # --- token mixer -------------------------------------------------
        self.layer_type = config.layer_types[layer_idx]
        if self.layer_type == "full_attention":
            self.self_attn = Qwen3NextAttention(config, layer_idx)
        elif self.layer_type == "linear_attention":
            self.linear_attn = Qwen3NextGatedDeltaNet(config, layer_idx)

        # --- feed-forward ------------------------------------------------
        # The conditions are evaluated left to right, so the modulo is only
        # reached when experts are actually configured.
        use_sparse_moe = (
            layer_idx not in config.mlp_only_layers
            and config.num_experts > 0
            and (layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if use_sparse_moe:
            self.mlp = Qwen3NextSparseMoeBlock(config)
        else:
            self.mlp = Qwen3NextMLP(config, intermediate_size=config.intermediate_size)

        # --- normalisation -----------------------------------------------
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # --- optional token shift applied before the token mixer ----------
        # Enabled only when `config.attn_token_shift == "conv"`; the conv is depthwise
        # (groups == channels), bias-free and unpadded.
        self.attn_token_shift = getattr(config, "attn_token_shift", None)
        if self.attn_token_shift == "conv":
            self.attn_token_shift_conv = nn.Conv1d(
                config.hidden_size,
                config.hidden_size,
                kernel_size=getattr(config, "token_shift_conv_size", 4),
                padding=0,
                groups=config.hidden_size,
                bias=False,
            )
        else:
            self.attn_token_shift_conv = None

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[tuple[torch.Tensor]] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.FloatTensor:
        """
        Run the token mixer and the feed-forward block, each with a residual connection.

        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`):
                Cosine and sine positional embeddings; only forwarded to the full-attention mixer.
            attention_mask (`torch.FloatTensor`, *optional*):
                Mask forwarded unchanged to whichever token mixer this layer owns.
            position_ids (`torch.LongTensor`, *optional*): forwarded to the token mixer.
            past_key_values (*optional*):
                Cache object; passed as `cache_params` to the linear-attention mixer and as
                `past_key_values` to the full-attention mixer.
            cache_position (`torch.LongTensor`, *optional*): forwarded to the token mixer.
            kwargs (`dict`, *optional*):
                Extra keyword arguments; forwarded to the full-attention mixer only.

        Returns:
            `torch.FloatTensor`: the layer output (router logits of MoE blocks are discarded here).
        """
        mixer_input = self.input_layernorm(hidden_states)

        # Optional token shift: [batch, seq, hidden] -> [batch, hidden, seq] -> conv -> back.
        if self.attn_token_shift_conv is not None and self.attn_token_shift == "conv":
            channels_first = mixer_input.transpose(1, 2).contiguous()
            channels_first = apply_causal_depthwise_conv1d_bcl(channels_first, self.attn_token_shift_conv.weight)
            mixer_input = channels_first.transpose(1, 2).contiguous()

        # Token mixer. For any other layer type the normalised input passes through as-is.
        mixer_output = mixer_input
        if self.layer_type == "linear_attention":
            mixer_output = self.linear_attn(
                hidden_states=mixer_input,
                cache_params=past_key_values,
                cache_position=cache_position,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
        elif self.layer_type == "full_attention":
            mixer_output, _ = self.self_attn(
                hidden_states=mixer_input,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = hidden_states + mixer_output

        # Feed-forward; MoE blocks return a tuple whose second element is dropped here.
        ffn_output = self.mlp(self.post_attention_layernorm(hidden_states))
        if isinstance(ffn_output, tuple):
            ffn_output, _ = ffn_output

        return hidden_states + ffn_output
1252
+
1253
+
1254
class Qwen3NextPreTrainedModel(PreTrainedModel):
    # Shared base class: declares framework capability flags and the custom
    # weight initialisation for gated delta-net layers.
    config: Qwen3NextConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3NextDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    # Checkpoint keys starting with "mtp" are ignored when loading.
    _keys_to_ignore_on_load_unexpected = [r"^mtp.*"]
    _can_record_outputs = {
        "router_logits": OutputRecorder(Qwen3NextSparseMoeBlock, index=1),
        "hidden_states": Qwen3NextDecoderLayer,
        "attentions": Qwen3NextAttention,
    }
    _is_stateful = True

    def _init_weights(self, module):
        """Apply the default initialisation, then the gated delta-net specific one."""
        super()._init_weights(module)
        if not isinstance(module, Qwen3NextGatedDeltaNet):
            return
        # dt_bias starts at 1; A_log is the log of a U(0, 16) sample.
        module.dt_bias.data.fill_(1.0)
        module.A_log.data.uniform_(0, 16).log_()
1275
+
1276
+
1277
class Qwen3NextModel(Qwen3NextPreTrainedModel):
    # Decoder stack: token embedding -> decoder layers -> final RMS norm.

    def __init__(self, config: Qwen3NextConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        self.layers = nn.ModuleList(
            [Qwen3NextDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3NextRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()
        _log_custom_features(config)

    # @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Lazily create the hybrid cache when caching is requested.
        if use_cache and past_key_values is None:
            past_key_values = Qwen3NextDynamicCache(config=self.config)

        if cache_position is None:
            start = 0 if past_key_values is None else past_key_values.get_seq_length()
            cache_position = torch.arange(start, start + inputs_embeds.shape[1], device=inputs_embeds.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Full-attention layers get the causal mask, linear-attention layers the (optional) padding mask.
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )
        linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position)

        hidden_states = inputs_embeds

        # Rotary embeddings are computed once and shared by every decoder layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if decoder_layer.layer_type == "linear_attention":
                layer_mask = linear_attn_mask
            else:
                layer_mask = causal_mask

            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=layer_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )

    def _update_linear_attn_mask(self, attention_mask, cache_position):
        """
        Return the mask handed to linear-attention layers, or `None` when it is not needed.

        NOTE: Left-padding is used for linear attention mask.
        The mask is dropped (no state zeroing required) when
            1. the forward is cached (`cache_position[0] > 0`), or
            2. every input position is attended to (mask is all ones).
        """
        if cache_position[0] > 0:
            return None
        if attention_mask is not None and torch.all(attention_mask == 1):
            return None
        return attention_mask
1368
+
1369
+
1370
def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits:
            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts]. Anything that is not a non-empty
            tuple (including `None` and a bare tensor) yields a loss of `0`.
        num_experts:
            Number of experts. When `None`, it is inferred from the last dimension of the gate logits.
        top_k:
            The number of experts to route per-token, can be also interpreted as the `top-k` routing
            parameter.
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None.

    Returns:
        The auxiliary loss (a scalar tensor), or the integer `0` when there are no gate logits.
    """
    # No usable router logits (None, bare tensor, or empty tuple): nothing to balance.
    if not isinstance(gate_logits, tuple) or len(gate_logits) == 0:
        return 0

    # Layers may live on different devices; gather everything on the first layer's device.
    compute_device = gate_logits[0].device
    concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

    if num_experts is None:
        num_experts = concatenated_gate_logits.shape[-1]

    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    # [tokens, top_k, num_experts] one-hot of the experts each token was routed to
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        # The concatenated logits stack one [batch * seq] slab per layer.
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts
1450
+
1451
+
1452
+ @auto_docstring
1453
class Qwen3NextForCausalLM(Qwen3NextPreTrainedModel, GenerationMixin):
    # Causal language-modelling head on top of `Qwen3NextModel`: projects the final hidden
    # states to vocabulary logits and, when requested, adds the router load-balancing loss.
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3NextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Qwen3NextDynamicCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Qwen3NextForCausalLM

        >>> model = Qwen3NextForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        # Fall back to the config default when the caller did not decide explicitly.
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                # `load_balancing_loss_func` returns the plain int 0 when no router logits were
                # recorded (e.g. dense layers only); an int has no `.to`, so only move tensors.
                if isinstance(aux_loss, torch.Tensor):
                    aux_loss = aux_loss.to(loss.device)  # make sure to reside in the same device
                loss += self.router_aux_loss_coef * aux_loss

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
1555
+
1556
+
1557
# Names exported by `from <module> import *`.
__all__ = [
    "Qwen3NextForCausalLM",
    "Qwen3NextModel",
    "Qwen3NextPreTrainedModel",
]
YuLan-Mini-Nanbeige-Distill/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<pad>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<unk>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
YuLan-Mini-Nanbeige-Distill/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
YuLan-Mini-Nanbeige-Distill/tokenizer_config.json ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": null,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "<|endoftext|>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ },
30
+ "102": {
31
+ "content": "<pad>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": true
37
+ },
38
+ "103": {
39
+ "content": "<reasoning_step>",
40
+ "lstrip": false,
41
+ "normalized": false,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": true
45
+ },
46
+ "104": {
47
+ "content": "<|im_start|>",
48
+ "lstrip": false,
49
+ "normalized": false,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": true
53
+ },
54
+ "105": {
55
+ "content": "<|im_end|>",
56
+ "lstrip": false,
57
+ "normalized": false,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": true
61
+ },
62
+ "106": {
63
+ "content": "<|object_ref_start|>",
64
+ "lstrip": false,
65
+ "normalized": false,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": true
69
+ },
70
+ "107": {
71
+ "content": "<|object_ref_end|>",
72
+ "lstrip": false,
73
+ "normalized": false,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": true
77
+ },
78
+ "108": {
79
+ "content": "<|box_start|>",
80
+ "lstrip": false,
81
+ "normalized": false,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": true
85
+ },
86
+ "109": {
87
+ "content": "<|box_end|>",
88
+ "lstrip": false,
89
+ "normalized": false,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": true
93
+ },
94
+ "110": {
95
+ "content": "<|quad_start|>",
96
+ "lstrip": false,
97
+ "normalized": false,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": true
101
+ },
102
+ "111": {
103
+ "content": "<|quad_end|>",
104
+ "lstrip": false,
105
+ "normalized": false,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": true
109
+ },
110
+ "112": {
111
+ "content": "<|vision_start|>",
112
+ "lstrip": false,
113
+ "normalized": false,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": true
117
+ },
118
+ "113": {
119
+ "content": "<|vision_end|>",
120
+ "lstrip": false,
121
+ "normalized": false,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": true
125
+ },
126
+ "114": {
127
+ "content": "<|vision_pad|>",
128
+ "lstrip": false,
129
+ "normalized": false,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": true
133
+ },
134
+ "115": {
135
+ "content": "<|image_pad|>",
136
+ "lstrip": false,
137
+ "normalized": false,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": true
141
+ },
142
+ "116": {
143
+ "content": "<|video_pad|>",
144
+ "lstrip": false,
145
+ "normalized": false,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": true
149
+ },
150
+ "117": {
151
+ "content": "<tool_call>",
152
+ "lstrip": false,
153
+ "normalized": false,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "118": {
159
+ "content": "</tool_call>",
160
+ "lstrip": false,
161
+ "normalized": false,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "119": {
167
+ "content": "<|fim_prefix|>",
168
+ "lstrip": false,
169
+ "normalized": false,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "120": {
175
+ "content": "<|fim_middle|>",
176
+ "lstrip": false,
177
+ "normalized": false,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "121": {
183
+ "content": "<|fim_suffix|>",
184
+ "lstrip": false,
185
+ "normalized": false,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "122": {
191
+ "content": "<|fim_pad|>",
192
+ "lstrip": false,
193
+ "normalized": false,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "123": {
199
+ "content": "<|repo_name|>",
200
+ "lstrip": false,
201
+ "normalized": false,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ },
206
+ "124": {
207
+ "content": "<|file_sep|>",
208
+ "lstrip": false,
209
+ "normalized": false,
210
+ "rstrip": false,
211
+ "single_word": false,
212
+ "special": false
213
+ },
214
+ "125": {
215
+ "content": "<tool_response>",
216
+ "lstrip": false,
217
+ "normalized": false,
218
+ "rstrip": false,
219
+ "single_word": false,
220
+ "special": false
221
+ },
222
+ "126": {
223
+ "content": "</tool_response>",
224
+ "lstrip": false,
225
+ "normalized": false,
226
+ "rstrip": false,
227
+ "single_word": false,
228
+ "special": false
229
+ },
230
+ "127": {
231
+ "content": "<think>",
232
+ "lstrip": false,
233
+ "normalized": false,
234
+ "rstrip": false,
235
+ "single_word": false,
236
+ "special": false
237
+ },
238
+ "128": {
239
+ "content": "</think>",
240
+ "lstrip": false,
241
+ "normalized": false,
242
+ "rstrip": false,
243
+ "single_word": false,
244
+ "special": false
245
+ },
246
+ "1071": {
247
+ "content": "<|sequence|>",
248
+ "lstrip": false,
249
+ "normalized": false,
250
+ "rstrip": false,
251
+ "single_word": false,
252
+ "special": true
253
+ },
254
+ "1072": {
255
+ "content": "<|/sequence|>",
256
+ "lstrip": false,
257
+ "normalized": false,
258
+ "rstrip": false,
259
+ "single_word": false,
260
+ "special": true
261
+ },
262
+ "1073": {
263
+ "content": "<|identity|>",
264
+ "lstrip": false,
265
+ "normalized": false,
266
+ "rstrip": false,
267
+ "single_word": false,
268
+ "special": true
269
+ },
270
+ "1074": {
271
+ "content": "<|/identity|>",
272
+ "lstrip": false,
273
+ "normalized": false,
274
+ "rstrip": false,
275
+ "single_word": false,
276
+ "special": true
277
+ },
278
+ "1075": {
279
+ "content": "<|tail0|>",
280
+ "lstrip": false,
281
+ "normalized": false,
282
+ "rstrip": false,
283
+ "single_word": false,
284
+ "special": true
285
+ },
286
+ "1076": {
287
+ "content": "<|tail1|>",
288
+ "lstrip": false,
289
+ "normalized": false,
290
+ "rstrip": false,
291
+ "single_word": false,
292
+ "special": true
293
+ },
294
+ "1077": {
295
+ "content": "<|tail2|>",
296
+ "lstrip": false,
297
+ "normalized": false,
298
+ "rstrip": false,
299
+ "single_word": false,
300
+ "special": true
301
+ },
302
+ "1078": {
303
+ "content": "<|tail3|>",
304
+ "lstrip": false,
305
+ "normalized": false,
306
+ "rstrip": false,
307
+ "single_word": false,
308
+ "special": true
309
+ },
310
+ "1079": {
311
+ "content": "<|tail4|>",
312
+ "lstrip": false,
313
+ "normalized": false,
314
+ "rstrip": false,
315
+ "single_word": false,
316
+ "special": true
317
+ },
318
+ "1080": {
319
+ "content": "<|head0|>",
320
+ "lstrip": false,
321
+ "normalized": false,
322
+ "rstrip": false,
323
+ "single_word": false,
324
+ "special": true
325
+ },
326
+ "1081": {
327
+ "content": "<|head1|>",
328
+ "lstrip": false,
329
+ "normalized": false,
330
+ "rstrip": false,
331
+ "single_word": false,
332
+ "special": true
333
+ },
334
+ "1082": {
335
+ "content": "<|head2|>",
336
+ "lstrip": false,
337
+ "normalized": false,
338
+ "rstrip": false,
339
+ "single_word": false,
340
+ "special": true
341
+ },
342
+ "1083": {
343
+ "content": "<|head3|>",
344
+ "lstrip": false,
345
+ "normalized": false,
346
+ "rstrip": false,
347
+ "single_word": false,
348
+ "special": true
349
+ },
350
+ "1084": {
351
+ "content": "<|head4|>",
352
+ "lstrip": false,
353
+ "normalized": false,
354
+ "rstrip": false,
355
+ "single_word": false,
356
+ "special": true
357
+ },
358
+ "1085": {
359
+ "content": "<|chunk_id|>",
360
+ "lstrip": false,
361
+ "normalized": false,
362
+ "rstrip": false,
363
+ "single_word": false,
364
+ "special": true
365
+ },
366
+ "1086": {
367
+ "content": "<|/chunk_id|>",
368
+ "lstrip": false,
369
+ "normalized": false,
370
+ "rstrip": false,
371
+ "single_word": false,
372
+ "special": true
373
+ },
374
+ "1087": {
375
+ "content": "<|last_chunk_id|>",
376
+ "lstrip": false,
377
+ "normalized": false,
378
+ "rstrip": false,
379
+ "single_word": false,
380
+ "special": true
381
+ },
382
+ "1088": {
383
+ "content": "<|/last_chunk_id|>",
384
+ "lstrip": false,
385
+ "normalized": false,
386
+ "rstrip": false,
387
+ "single_word": false,
388
+ "special": true
389
+ }
390
+ },
391
+ "additional_special_tokens": [
392
+ "<|im_start|>",
393
+ "<|im_end|>",
394
+ "<|object_ref_start|>",
395
+ "<|object_ref_end|>",
396
+ "<|box_start|>",
397
+ "<|box_end|>",
398
+ "<|quad_start|>",
399
+ "<|quad_end|>",
400
+ "<|vision_start|>",
401
+ "<|vision_end|>",
402
+ "<|vision_pad|>",
403
+ "<|image_pad|>",
404
+ "<|video_pad|>",
405
+ "<|sequence|>",
406
+ "<|/sequence|>",
407
+ "<|identity|>",
408
+ "<|/identity|>",
409
+ "<|tail0|>",
410
+ "<|tail1|>",
411
+ "<|tail2|>",
412
+ "<|tail3|>",
413
+ "<|tail4|>",
414
+ "<|head0|>",
415
+ "<|head1|>",
416
+ "<|head2|>",
417
+ "<|head3|>",
418
+ "<|head4|>",
419
+ "<|chunk_id|>",
420
+ "<|/chunk_id|>",
421
+ "<|last_chunk_id|>",
422
+ "<|/last_chunk_id|>"
423
+ ],
424
+ "bos_token": null,
425
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = 
content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n<think>\\n' }}\n{%- endif %}",
426
+ "clean_up_tokenization_spaces": false,
427
+ "eos_token": "<|im_end|>",
428
+ "model_max_length": 32768,
429
+ "pad_token": "<|endoftext|>",
430
+ "padding_side": "right",
431
+ "sp_model_kwargs": {},
432
+ "spaces_between_special_tokens": false,
433
+ "tokenizer_class": "LlamaTokenizerFast",
434
+ "unk_token": "<unk>",
435
+ "use_default_system_prompt": false
436
+ }