ITcoder committed on
Commit
7492f69
·
verified ·
1 Parent(s): 50b4557

Upload folder using huggingface_hub

Browse files
Qwen3-0.6B/config.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_qwen3.Qwen3Config",
9
+ "AutoModelForCausalLM": "modeling_qwen3.Qwen3ForCausalLM"
10
+ },
11
+ "bos_token_id": 151643,
12
+ "dtype": "bfloat16",
13
+ "elementwise_attn_output_gate": false,
14
+ "eos_token_id": 151645,
15
+ "ffn_output_gate": true,
16
+ "head_dim": 128,
17
+ "headwise_attn_output_gate": false,
18
+ "hidden_act": "silu",
19
+ "hidden_size": 1024,
20
+ "initializer_range": 0.02,
21
+ "intermediate_size": 3072,
22
+ "max_position_embeddings": 40960,
23
+ "max_window_layers": 28,
24
+ "model_type": "qwen3",
25
+ "num_attention_heads": 16,
26
+ "num_hidden_layers": 28,
27
+ "num_key_value_heads": 8,
28
+ "qkv_bias": false,
29
+ "rms_norm_eps": 1e-06,
30
+ "rope_scaling": null,
31
+ "rope_theta": 1000000,
32
+ "sliding_window": null,
33
+ "tie_word_embeddings": true,
34
+ "torch_dtype": "bfloat16",
35
+ "transformers_version": "4.52.3",
36
+ "use_cache": true,
37
+ "use_qk_norm": true,
38
+ "use_sliding_window": false,
39
+ "vocab_size": 151936,
40
+ "pad_token_id": 151645
41
+ }
Qwen3-0.6B/configuration_qwen3.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ """Qwen3 model configuration"""
23
+
24
+ from transformers.configuration_utils import PretrainedConfig
25
+ from transformers.modeling_rope_utils import rope_config_validation
26
+ from transformers.utils import logging
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
class Qwen3Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3Model`]. It is used to instantiate a
    Qwen3 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    Qwen3-8B-beta [Qwen/Qwen3-8B-beta](https://huggingface.co/Qwen/Qwen3-8B-beta).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the Qwen3 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Qwen3Model`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 22016):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If `None` is passed, it falls back to
            `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 128):
            Dimension of each attention head. It is stored independently of `hidden_size // num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. The value is only kept when `use_sliding_window=True`;
            otherwise `config.sliding_window` is stored as `None`.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use
            full attention.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        use_qk_norm (`bool`, *optional*, defaults to `True`):
            Whether query and key in attention use norm
        elementwise_attn_output_gate (`bool`, *optional*, defaults to `False`):
            When enabled, the attention query projection is widened to also emit one gate value per query element
            (presumably used to gate the attention output -- confirm against the modeling code).
        headwise_attn_output_gate (`bool`, *optional*, defaults to `False`):
            When enabled, the attention query projection is widened to also emit one gate score per attention head
            (presumably used to gate the attention output -- confirm against the modeling code). Takes precedence
            over `elementwise_attn_output_gate` when both are set.
        ffn_output_gate (`bool`, *optional*, defaults to `False`):
            Stored on the config as-is; presumably toggles an output gate on the feed-forward block -- confirm
            against the modeling code.
        qkv_bias (`bool`, *optional*):
            Bias flag read by the attention projection layers as `config.qkv_bias`. When left as `None` it falls
            back to `attention_bias`, so a configuration built purely from defaults still exposes the attribute.

    ```python
    >>> from transformers import Qwen3Model, Qwen3Config

    >>> # Initializing a Qwen3 style configuration
    >>> configuration = Qwen3Config()

    >>> # Initializing a model from the Qwen3-8B style configuration
    >>> model = Qwen3Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen3"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Qwen3`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_bias=False,
        attention_dropout=0.0,
        use_qk_norm=True,
        elementwise_attn_output_gate=False,
        headwise_attn_output_gate=False,
        ffn_output_gate=False,
        qkv_bias=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.use_sliding_window = use_sliding_window
        # The window size is only meaningful when sliding-window attention is enabled.
        self.sliding_window = sliding_window if use_sliding_window else None
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.use_qk_norm = use_qk_norm

        # The attention layers read `config.qkv_bias`; previously it only existed when the
        # loaded JSON happened to carry the key. Default it from `attention_bias` so that a
        # config created from defaults can still build the model.
        self.qkv_bias = attention_bias if qkv_bias is None else qkv_bias

        self.headwise_attn_output_gate = headwise_attn_output_gate
        self.elementwise_attn_output_gate = elementwise_attn_output_gate
        self.ffn_output_gate = ffn_output_gate

        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
Qwen3-0.6B/generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "temperature": 0.6,
10
+ "top_k": 20,
11
+ "top_p": 0.95,
12
+ "transformers_version": "4.52.3"
13
+ }
Qwen3-0.6B/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d110d55212ef518aa4e8f62994ec8849488cb75f72788d1240fd943934c3d11
3
+ size 1192198184
Qwen3-0.6B/modeling_qwen3.py ADDED
@@ -0,0 +1,1572 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """PyTorch Qwen3 model."""
21
+
22
+ import math
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import torch
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+
30
+ from transformers.activations import ACT2FN
31
+ from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
32
+ from transformers.generation import GenerationMixin
33
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
34
+ from transformers.modeling_outputs import (
35
+ BaseModelOutputWithPast,
36
+ CausalLMOutputWithPast,
37
+ QuestionAnsweringModelOutput,
38
+ SequenceClassifierOutputWithPast,
39
+ TokenClassifierOutput,
40
+ )
41
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
42
+ from transformers.modeling_utils import PreTrainedModel
43
+ from transformers.utils import (
44
+ add_code_sample_docstrings,
45
+ add_start_docstrings,
46
+ add_start_docstrings_to_model_forward,
47
+ is_flash_attn_2_available,
48
+ is_flash_attn_greater_or_equal_2_10,
49
+ logging,
50
+ replace_return_docstrings,
51
+ )
52
+
53
+ # Support both relative and absolute imports
54
+ try:
55
+ from .configuration_qwen3 import Qwen3Config
56
+ except ImportError:
57
+ from configuration_qwen3 import Qwen3Config
58
+
59
+ if is_flash_attn_2_available():
60
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
61
+
62
+ logger = logging.get_logger(__name__)
63
+
64
+ _CHECKPOINT_FOR_DOC = "Qwen/Qwen3-8B"
65
+ _CONFIG_FOR_DOC = "Qwen3Config"
66
+
67
+ # ============================ Basic components (normalization, positional encoding, feed-forward network) ============================
68
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen3
69
class Qwen3RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen3RMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: Size of the last dimension; the learned scale has this many entries.
            eps: Constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        # Scale starts at one, so a freshly built layer only normalizes.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize `hidden_states` in float32, cast back to the input dtype, then apply the scale."""
        orig_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(orig_dtype)

    def extra_repr(self):
        """Return the scale shape and epsilon for the module's printed representation."""
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"
87
+
88
+
89
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen3
90
class Qwen3RotaryEmbedding(nn.Module):
    """Rotary position embedding: produces the cos/sin tables for the given position ids."""

    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config: Optional[Qwen3Config] = None,
    ):
        """
        Build the inverse-frequency buffer, either from `config` or from the legacy arguments.

        Args:
            dim, max_position_embeddings, base, scaling_factor, rope_type: Legacy parameters, only
                used when `config` is `None`; they are forwarded to the RoPE init function.
            device: Device passed to the RoPE init function.
            config: Model configuration. When given, the RoPE type and the maximum sequence
                length are read from it.
        """
        super().__init__()
        # TODO (joao): remove the legacy branch below, only used for BC
        self.rope_kwargs = {}
        if config is not None:
            scaling_cfg = config.rope_scaling
            if scaling_cfg is None:
                self.rope_type = "default"
            else:
                # BC: "rope_type" was originally "type"
                self.rope_type = scaling_cfg.get("rope_type", scaling_cfg.get("type"))
            initial_len = config.max_position_embeddings
        else:
            logger.warning_once(
                "`Qwen3RotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.46"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            initial_len = max_position_embeddings

        self.max_seq_len_cached = initial_len
        self.original_max_seq_len = initial_len

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so the dynamic variant can restore the initial frequencies later.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        needed_len = torch.max(position_ids) + 1

        # Case 1: the sequence outgrew the cache, so rebuild the frequencies for the new length.
        if needed_len > self.max_seq_len_cached:
            grown_inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=needed_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", grown_inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = needed_len

        # Case 2: back within the original range after having grown, so restore the initial frequencies.
        if needed_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """
        Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`.

        `x` is only consulted for its device and dtype.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        batch_size = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        target_dtype = x.dtype
        return cos.to(dtype=target_dtype), sin.to(dtype=target_dtype)
175
+
176
+
177
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
178
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is split at its midpoint into (first, second) and the result is
    the concatenation (-second, first) along that dimension.
    """
    midpoint = x.shape[-1] // 2
    first, second = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-second, first), dim=-1)
183
+
184
+
185
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
186
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `q` and `k`.
            For example, with `cos`/`sin` of shape [batch_size, seq_len, head_dim], use `unsqueeze_dim=1` when
            `q` and `k` are [batch_size, heads, seq_len, head_dim], and `unsqueeze_dim=2` when they are
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same rotation formula for queries and keys.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
211
+
212
+ # FFN
213
+ # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen3
214
class Qwen3MLP(nn.Module):
    """Gated feed-forward network: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""

    def __init__(self, config):
        """Create the three bias-free projections and look up the activation named by `config.hidden_act`."""
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        """Apply the gated projection to `hidden_state` and project the result back down."""
        gate = self.act_fn(self.gate_proj(hidden_state))
        up = self.up_proj(hidden_state)
        return self.down_proj(gate * up)
226
+
227
+
228
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
229
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)

    When `n_rep` is 1 the input tensor is returned as-is.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
239
+
240
+ # ============================ Attention mechanism ============================
241
+
242
+ class Qwen3Attention(nn.Module):
243
+ """
244
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
245
+ and "Generating Long Sequences with Sparse Transformers".
246
+ """
247
+
248
+ def __init__(self, config: Qwen3Config, layer_idx: Optional[int] = None):
249
+ super().__init__()
250
+ self.config = config
251
+ self.layer_idx = layer_idx
252
+ if layer_idx is None:
253
+ logger.warning_once(
254
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
255
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
256
+ "when creating this class."
257
+ )
258
+
259
+ self.hidden_size = config.hidden_size
260
+ self.num_heads = config.num_attention_heads
261
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
262
+ self.num_key_value_heads = config.num_key_value_heads
263
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
264
+ self.max_position_embeddings = config.max_position_embeddings
265
+ self.rope_theta = config.rope_theta
266
+ self.is_causal = True
267
+ self.attention_dropout = config.attention_dropout
268
+ self.use_qk_norm = config.use_qk_norm
269
+ self.headwise_attn_output_gate = config.headwise_attn_output_gate
270
+ self.elementwise_attn_output_gate = config.elementwise_attn_output_gate
271
+
272
+ # if (self.head_dim * self.num_heads) != self.hidden_size:
273
+ # raise ValueError(
274
+ # f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
275
+ # f" and `num_heads`: {self.num_heads})."
276
+ # )
277
+ if self.headwise_attn_output_gate: # headwise_attention
278
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim + self.num_heads, bias=config.qkv_bias)
279
+ elif self.elementwise_attn_output_gate: # elementwise_attention
280
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim * 2, bias=config.qkv_bias)
281
+ else: # full_attention
282
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.qkv_bias)
283
+
284
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
285
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
286
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.qkv_bias)
287
+ if self.use_qk_norm:
288
+ self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
289
+ self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
290
+
291
+ self.rotary_emb = Qwen3RotaryEmbedding(config=self.config)
292
+
293
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Eager (matmul + softmax) attention with an optional sigmoid gate on the attention output.

    When `headwise_attn_output_gate` or `elementwise_attn_output_gate` is set, the output of `q_proj`
    carries the gate logits next to the queries: it is viewed per key/value head and split into a query
    part and a gate part. The headwise flag wins if both are set.

    Args:
        hidden_states: input whose `.size()` unpacks into three dimensions (batch, sequence, features).
        attention_mask: optional additive mask; it is sliced on its last axis to the key length and added
            to the raw scores.
        position_ids: accepted for interface compatibility; not read by this method.
        past_key_value: optional cache; its `update` is called with this layer's keys/values.
        output_attentions: when truthy, the attention probabilities (after dropout) are returned.
        use_cache: accepted for interface compatibility; not read by this method.
        cache_position: forwarded to the cache in `cache_kwargs`.
        position_embeddings: `(cos, sin)` pair. It is unpacked unconditionally, so it must be provided.

    Returns:
        `(attn_output, attn_weights or None, past_key_value)`.

    Raises:
        ValueError: if the attention output is not `(batch, num_heads, seq_len, head_dim)`.
    """
    batch, seq_len, _ = hidden_states.size()
    head_dim = self.head_dim
    groups = self.num_key_value_groups

    q_proj_out = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states).view(batch, seq_len, -1, head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(batch, seq_len, -1, head_dim).transpose(1, 2)

    gate_score = None
    if self.headwise_attn_output_gate or self.elementwise_attn_output_gate:
        # Gate width per query head: a single scalar (headwise) or one value per channel (elementwise).
        gate_dim = 1 if self.headwise_attn_output_gate else head_dim
        # Each key/value-head slice holds `groups` query heads followed by their gate logits,
        # e.g. head_dim=128 and groups=2 with the headwise gate gives 256 query values + 2 gate scalars.
        per_kv_head = q_proj_out.view(batch, seq_len, self.num_key_value_heads, -1)
        q_part, gate_part = torch.split(per_kv_head, [head_dim * groups, gate_dim * groups], dim=-1)
        gate_score = gate_part.reshape(batch, seq_len, -1, gate_dim)
        query_states = q_part.reshape(batch, seq_len, -1, head_dim).transpose(1, 2)
    else:
        query_states = q_proj_out.view(batch, seq_len, -1, head_dim).transpose(1, 2)

    if self.use_qk_norm:
        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # sin/cos/cache_position are the RoPE-specific extras handed to the cache.
        key_states, value_states = past_key_value.update(
            key_states,
            value_states,
            self.layer_idx,
            {"sin": sin, "cos": cos, "cache_position": cache_position},
        )

    # Repeat k/v heads if n_kv_heads < n_heads.
    key_states = repeat_kv(key_states, groups)
    value_states = repeat_kv(value_states, groups)

    scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        # No matter the mask length, just slice it to the key length.
        scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]

    # Softmax is computed in fp32, then cast back to the query dtype.
    attn_weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    expected_shape = (batch, self.num_heads, seq_len, head_dim)
    if attn_output.size() != expected_shape:
        raise ValueError(f"`attn_output` should be of size {expected_shape}, but is {attn_output.size()}")

    attn_output = attn_output.transpose(1, 2).contiguous()
    if gate_score is not None:
        attn_output = attn_output * torch.sigmoid(gate_score)
    attn_output = self.o_proj(attn_output.reshape(batch, seq_len, -1))

    return attn_output, (attn_weights if output_attentions else None), past_key_value
372
+
373
+
374
class Qwen3FlashAttention2(Qwen3Attention):  # flash attention
    """
    Qwen3 flash attention module, following the Qwen3 attention module. This module inherits from
    `Qwen3Attention` as the weights of the module stay untouched. The only required change is in the forward
    pass, where it needs to correctly call the public API of flash attention and deal with padding tokens in
    case the input contains any of them.

    A sliding window is passed to the kernel only when `config.use_sliding_window` is set,
    `config.sliding_window` is not None and `layer_idx >= config.max_window_layers`.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right
        # alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this
        # difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1)
        # produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
    ):
        """
        Run attention through `_flash_attention_forward`.

        Args:
            hidden_states: input whose `.size()` unpacks into three dimensions (batch, sequence, features).
            attention_mask: forwarded unchanged to `_flash_attention_forward`.
            position_ids: forwarded to `_flash_attention_forward`.
            past_key_value: optional cache; its `update` is called with this layer's keys/values.
            output_attentions: accepted for interface compatibility. This code path never computes an
                attention-weight tensor, so the second return value is always `None`.
            use_cache: accepted for interface compatibility; not read by this method.
            cache_position: forwarded to the cache in `cache_kwargs`.
            position_embeddings: `(cos, sin)` pair. It is unpacked unconditionally, so it must be provided.

        Returns:
            `(attn_output, None, past_key_value)`.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        if self.headwise_attn_output_gate:
            # q_proj output carries, per key/value head, the grouped queries followed by one gate
            # scalar per query head.
            query_states = query_states.view(bsz, q_len, self.num_key_value_heads, -1)
            query_states, gate_score = torch.split(
                query_states,
                [self.head_dim * self.num_key_value_groups, self.num_key_value_groups],
                dim=-1,
            )
            gate_score = gate_score.reshape(bsz, q_len, -1, 1)
            query_states = query_states.reshape(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        elif self.elementwise_attn_output_gate:
            # Same layout, but the gate has `head_dim` values per query head.
            query_states = query_states.view(bsz, q_len, self.num_key_value_heads, -1)
            query_states, gate_score = torch.split(
                query_states,
                [self.head_dim * self.num_key_value_groups, self.head_dim * self.num_key_value_groups],
                dim=-1,
            )
            gate_score = gate_score.reshape(bsz, q_len, -1, self.head_dim)
            query_states = query_states.reshape(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        else:
            query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        if self.use_qk_norm:
            query_states = self.q_norm(query_states)
            key_states = self.k_norm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Unlike the eager path, k/v heads are NOT expanded with `repeat_kv` here: they are handed to
        # `_flash_attention_forward` with `num_key_value_heads` heads.
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the layout passed to Flash Attention (sequence axis before the head axis).
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        ):
            sliding_window = self.config.sliding_window
        else:
            sliding_window = None
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=sliding_window,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        if self.headwise_attn_output_gate or self.elementwise_attn_output_gate:
            attn_output = attn_output * torch.sigmoid(gate_score)

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        # No attention-weight tensor exists on this path. The previous code only bound `attn_weights`
        # when `output_attentions` was falsy, so `output_attentions=True` raised UnboundLocalError at
        # the return statement. Always report `None` instead.
        return attn_output, None, past_key_value
498
+
499
+
500
+
501
class Qwen3SdpaAttention(Qwen3Attention):  # sdpa attention
    """
    Qwen3 attention built on `torch.nn.functional.scaled_dot_product_attention`.

    The parameters are those of `Qwen3Attention`; only the forward pass differs. When
    `output_attentions` is requested, the call is delegated to the parent (eager) implementation.
    """

    # Adapted from Qwen3Attention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Compute gated/un-gated attention with SDPA.

        Returns `(attn_output, None, past_key_value)` on the SDPA path, or whatever the parent
        `forward` returns when `output_attentions` is truthy. `position_embeddings` is unpacked
        unconditionally as `(cos, sin)`, so it must be provided.
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "Qwen3Model is using Qwen3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        batch, seq_len, _ = hidden_states.size()
        head_dim = self.head_dim
        groups = self.num_key_value_groups

        q_proj_out = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states).view(batch, seq_len, -1, head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(batch, seq_len, -1, head_dim).transpose(1, 2)

        gate_score = None
        if self.headwise_attn_output_gate or self.elementwise_attn_output_gate:
            # Headwise gating takes precedence: one scalar per query head; otherwise one value per channel.
            gate_dim = 1 if self.headwise_attn_output_gate else head_dim
            per_kv_head = q_proj_out.view(batch, seq_len, self.num_key_value_heads, -1)
            q_part, gate_part = torch.split(per_kv_head, [head_dim * groups, gate_dim * groups], dim=-1)
            gate_score = gate_part.reshape(batch, seq_len, -1, gate_dim)
            query_states = q_part.reshape(batch, seq_len, -1, head_dim).transpose(1, 2)
        else:
            query_states = q_proj_out.view(batch, seq_len, -1, head_dim).transpose(1, 2)

        if self.use_qk_norm:
            query_states = self.q_norm(query_states)
            key_states = self.k_norm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin/cos/cache_position are the RoPE-specific extras handed to the cache.
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                self.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        key_states = repeat_kv(key_states, groups)
        value_states = repeat_kv(value_states, groups)

        # No matter the mask length, just slice it to the key length.
        causal_mask = None if attention_mask is None else attention_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with
        # custom attn_mask. Reference: https://github.com/pytorch/pytorch/issues/112577.
        if attention_mask is not None and query_states.device.type == "cuda":
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # Computed as a separate statement (not inline in the SDPA call) to keep torch.compile's dynamic
        # shapes / full-graph options working. `seq_len > 1` matches AttentionMaskConverter.to_causal_4d,
        # which does not create a causal mask when q_len == 1.
        is_causal = causal_mask is None and seq_len > 1

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        if gate_score is not None:
            attn_output = attn_output * torch.sigmoid(gate_score)

        attn_output = self.o_proj(attn_output.view(batch, seq_len, self.num_heads * head_dim))

        return attn_output, None, past_key_value
608
+
609
# Lookup table from `config._attn_implementation` to the attention module class a decoder layer builds.
QWEN3_ATTENTION_CLASSES = dict(
    eager=Qwen3Attention,
    flash_attention_2=Qwen3FlashAttention2,
    sdpa=Qwen3SdpaAttention,
)
614
+
615
+ # ============================ Decoder layer ============================
616
+
617
class Qwen3DecoderLayer(nn.Module):
    """
    One pre-norm decoder block: RMSNorm -> self-attention -> residual add, then
    RMSNorm -> MLP -> (optionally gated) residual add.

    When `config.ffn_output_gate` is truthy, the MLP branch is scaled by a per-token scalar gate
    computed from the residual stream: `2 * sigmoid(ffn_gate(residual))`, clamped to `[0.7, 1.3]`.
    The residual path itself is never scaled.
    """

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        attn_impl = config._attn_implementation
        if config.sliding_window and attn_impl != "flash_attention_2":
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{attn_impl}`; "
                "unexpected results may be encountered."
            )

        # Self-attention, feed-forward network and the two normalisation layers.
        self.self_attn = QWEN3_ATTENTION_CLASSES[attn_impl](config, layer_idx)
        self.mlp = Qwen3MLP(config)
        self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Optional FFN output gate (see class docstring).
        self.ffn_output_gate = config.ffn_output_gate
        if self.ffn_output_gate:
            self.ffn_gate = nn.Linear(config.hidden_size, 1, bias=True)
            # Only tag the module here; `_init_weights` applies the dedicated init (weight=0, bias=0.0).
            # Per the original author's note, initialising directly in `__init__` has no effect when the
            # parameters are still meta tensors (device_map loading).
            self.ffn_gate._is_ffn_gate = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask, forwarded unchanged to the
                self-attention module.
            position_ids (`torch.LongTensor`, *optional*): forwarded unchanged to the self-attention module.
            past_key_value (*optional*): cached past key and value projection states, forwarded to the
                self-attention module.
            output_attentions (`bool`, *optional*):
                Whether or not to append the attention weights to the returned tuple.
            use_cache (`bool`, *optional*):
                Whether or not to append the key/value cache returned by the attention module to the
                returned tuple.
            cache_position (`torch.LongTensor`, *optional*): forwarded to the self-attention module.
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings, forwarded to the
                self-attention module.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
                into the model.

        Returns:
            `(hidden_states,)`, followed by the attention weights if `output_attentions` and by the
            key/value cache if `use_cache`, in that order.
        """
        # --- self-attention sub-block -------------------------------------------------------------
        skip = hidden_states
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
        hidden_states = skip + attn_out

        # --- feed-forward sub-block ---------------------------------------------------------------
        skip = hidden_states
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_states))
        if self.ffn_output_gate:
            # The gate is computed from the residual stream and has a trailing dimension of 1
            # (`ffn_gate` is Linear(hidden_size, 1)), so it broadcasts over the hidden dimension.
            # The [0.7, 1.3] bounds follow the original author's note about the effective range
            # found in "Exp06".
            gate = (2.0 * torch.sigmoid(self.ffn_gate(skip))).clamp(min=0.7, max=1.3)
            ffn_out = gate * ffn_out
        hidden_states = skip + ffn_out

        collected = [hidden_states]
        if output_attentions:
            collected.append(self_attn_weights)
        if use_cache:
            collected.append(present_key_value)
        return tuple(collected)
717
+
718
+
719
+ QWEN3_START_DOCSTRING = r"""
720
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
721
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
722
+ etc.)
723
+
724
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
725
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
726
+ and behavior.
727
+
728
+ Parameters:
729
+ config ([`Qwen3Config`]):
730
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
731
+ load the weights associated with the model, only the configuration. Check out the
732
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
733
+ """
734
+
735
+
736
# Base class holding the Hugging Face loading/capability flags and the weight-initialisation
# policy used by the Qwen3 model classes in this file.
@add_start_docstrings(
    "The bare Qwen3 Model outputting raw hidden-states without any specific head on top.",
    QWEN3_START_DOCSTRING,
)
class Qwen3PreTrainedModel(PreTrainedModel):
    config_class = Qwen3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        """
        Initialise `module` in place.

        - `nn.Embedding`: normal(0, `config.initializer_range`); the `padding_idx` row, if any, is zeroed.
        - `nn.Linear` tagged with `_is_ffn_gate`: weight and bias are set to zero.
        - any other `nn.Linear`: weight normal(0, `config.initializer_range`), bias (if present) zeroed.
        - every other module type is left untouched.
        """
        std = self.config.initializer_range

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if not isinstance(module, nn.Linear):
            return

        if getattr(module, '_is_ffn_gate', False):
            # Dedicated init for the FFN output gate: with weight = 0 and bias = 0 the linear output
            # is 0, and 2 * sigmoid(0) = 1, so the gate starts at exactly 1 and the layer initially
            # behaves like the un-gated model. Per the original author's note this must live in
            # `_init_weights`: in `__init__` the parameters may still be meta tensors under
            # device_map loading, where initialisation has no effect.
            nn.init.zeros_(module.weight)
            nn.init.constant_(module.bias, 0.0)
            return

        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
771
+
772
+
773
+ QWEN3_INPUTS_DOCSTRING = r"""
774
+ Args:
775
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
776
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
777
+ it.
778
+
779
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
780
+ [`PreTrainedTokenizer.__call__`] for details.
781
+
782
+ [What are input IDs?](../glossary#input-ids)
783
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
784
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
785
+
786
+ - 1 for tokens that are **not masked**,
787
+ - 0 for tokens that are **masked**.
788
+
789
+ [What are attention masks?](../glossary#attention-mask)
790
+
791
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
792
+ [`PreTrainedTokenizer.__call__`] for details.
793
+
794
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
795
+ `past_key_values`).
796
+
797
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
798
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
799
+ information on the default strategy.
800
+
801
+ - 1 indicates the head is **not masked**,
802
+ - 0 indicates the head is **masked**.
803
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
804
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
805
+ config.n_positions - 1]`.
806
+
807
+ [What are position IDs?](../glossary#position-ids)
808
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
809
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
810
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
811
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
812
+
813
+ Two formats are allowed:
814
+ - a [`~cache_utils.Cache`] instance, see our
815
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
816
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
817
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
818
+ cache format.
819
+
820
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
821
+ legacy cache format will be returned.
822
+
823
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
824
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
825
+ of shape `(batch_size, sequence_length)`.
826
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
827
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
828
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
829
+ model's internal embedding lookup matrix.
830
+ use_cache (`bool`, *optional*):
831
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
832
+ `past_key_values`).
833
+ output_attentions (`bool`, *optional*):
834
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
835
+ tensors for more detail.
836
+ output_hidden_states (`bool`, *optional*):
837
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
838
+ more detail.
839
+ return_dict (`bool`, *optional*):
840
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
841
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
842
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
843
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
844
+ the complete sequence length.
845
+ """
846
+
847
+ # ============================ Full model ============================
848
+ @add_start_docstrings(
849
+ "The bare Qwen3 Model outputting raw hidden-states without any specific head on top.",
850
+ QWEN3_START_DOCSTRING,
851
+ )
852
+ class Qwen3Model(Qwen3PreTrainedModel):
853
+ """
854
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen3DecoderLayer`]
855
+
856
+ Args:
857
+ config: Qwen3Config
858
+ """
859
+
860
def __init__(self, config: Qwen3Config):
    """Build the token embedding, the decoder stack, the final norm and the shared rotary embedding."""
    super().__init__(config)
    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size

    # Token embedding table.
    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
    # Stack of `config.num_hidden_layers` decoder layers, each told its own index.
    self.layers = nn.ModuleList(
        Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
    )
    # Attention implementation selected in the config.
    self._attn_implementation = config._attn_implementation
    # Final normalisation applied after the last decoder layer.
    self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    # Rotary position embedding, computed once per forward pass and shared by all layers.
    self.rotary_emb = Qwen3RotaryEmbedding(config=config)

    self.gradient_checkpointing = False
    # Initialize weights and apply final processing
    self.post_init()
876
+
877
def get_input_embeddings(self):
    """Return the module stored in `self.embed_tokens`."""
    return self.embed_tokens
879
+
880
def set_input_embeddings(self, value):
    """Store `value` as `self.embed_tokens`."""
    self.embed_tokens = value
882
+
883
@add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    # Flow: resolve flags from the config -> validate inputs -> normalise the cache object ->
    # embed tokens -> derive cache/position indices and the causal mask -> run every decoder layer ->
    # final norm -> package the outputs.
    cfg = self.config
    if output_attentions is None:
        output_attentions = cfg.output_attentions
    if output_hidden_states is None:
        output_hidden_states = cfg.output_hidden_states
    if use_cache is None:
        use_cache = cfg.use_cache
    if return_dict is None:
        return_dict = cfg.use_return_dict

    # Exactly one of the two input forms must be given (both missing or both present is an error).
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    checkpointing = self.gradient_checkpointing and self.training
    if checkpointing and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
        )
        use_cache = False

    # kept for BC (non `Cache` `past_key_values` inputs)
    return_legacy_cache = bool(use_cache) and not isinstance(past_key_values, Cache)
    if return_legacy_cache:
        if past_key_values is None:
            past_key_values = DynamicCache()
        else:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
            )

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if cache_position is None:
        # Positions continue from however many tokens the cache already holds.
        seen = 0 if past_key_values is None else past_key_values.get_seq_length()
        cache_position = torch.arange(seen, seen + inputs_embeds.shape[1], device=inputs_embeds.device)
    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    causal_mask = self._update_causal_mask(
        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
    )

    hidden_states = inputs_embeds

    # create position embeddings to be shared across the decoder layers
    position_embeddings = self.rotary_emb(hidden_states, position_ids)

    # decoder layers
    hidden_trace = []
    attn_trace = []
    next_decoder_cache = None
    # A layer returns (hidden, [attn_weights], [cache]); the cache slot shifts when weights are present.
    cache_slot = 2 if output_attentions else 1

    for decoder_layer in self.layers:
        if output_hidden_states:
            hidden_trace.append(hidden_states)

        if checkpointing:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
                position_embeddings,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[cache_slot]
        if output_attentions:
            attn_trace.append(layer_outputs[1])

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        hidden_trace.append(hidden_states)

    all_hidden_states = tuple(hidden_trace) if output_hidden_states else None
    all_self_attns = tuple(attn_trace) if output_attentions else None

    next_cache = next_decoder_cache if use_cache else None
    if return_legacy_cache:
        next_cache = next_cache.to_legacy_cache()

    if return_dict:
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
    return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1008
+
1009
# Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
def _update_causal_mask(
    self,
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
    output_attentions: bool,
):
    """
    Select the attention mask handed to the decoder layers for the configured attention backend.

    Args:
        attention_mask: Mask supplied by the caller, or `None`.
        input_tensor: Decoder input; its dtype, device, batch size (`shape[0]`) and sequence length
            (`shape[1]`) are used to build the mask.
        cache_position: Forwarded to `_prepare_4d_causal_attention_mask_with_cache_position`.
        past_key_values: Cache object, or `None` when no cache is used.
        output_attentions: Whether attention weights were requested.

    Returns:
        - `flash_attention_2`: `attention_mask` itself when it contains a `0.0` entry, otherwise `None`.
        - `sdpa` without attention outputs and without a `StaticCache`/`SlidingWindowCache`: `None` when
          `AttentionMaskConverter._ignore_causal_mask_sdpa` reports that the mask can be dropped.
        - Otherwise the mask built by `_prepare_4d_causal_attention_mask_with_cache_position`; for `sdpa`
          with a CUDA `attention_mask` it is additionally passed through
          `AttentionMaskConverter._unmask_unattended`.
    """
    attn_impl = self.config._attn_implementation

    if attn_impl == "flash_attention_2":
        has_masked_entry = attention_mask is not None and 0.0 in attention_mask
        return attention_mask if has_masked_entry else None

    # For SDPA, when possible, rely on its `is_causal` argument instead of its `attn_mask` argument so that it
    # can dispatch to Flash Attention 2. This is not compatible with static cache, as SDPA will fail to infer
    # the attention mask.
    if past_key_values is None:
        past_seen_tokens = 0
    else:
        past_seen_tokens = past_key_values.get_seq_length()
    fixed_size_cache = isinstance(past_key_values, (StaticCache, SlidingWindowCache))

    # When attention outputs are requested, the sdpa implementation's forward calls the eager implementation.
    sdpa_without_attn_outputs = attn_impl == "sdpa" and not output_attentions

    if sdpa_without_attn_outputs and not fixed_size_cache:
        if AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask,
            inputs_embeds=input_tensor,
            past_key_values_length=past_seen_tokens,
            sliding_window=self.config.sliding_window,
            is_training=self.training,
        ):
            return None

    dtype, device = input_tensor.dtype, input_tensor.device
    min_dtype = torch.finfo(dtype).min
    sequence_length = input_tensor.shape[1]

    if fixed_size_cache:
        # SlidingWindowCache or StaticCache: the mask spans the full cache.
        target_length = past_key_values.get_max_cache_shape()
    elif isinstance(attention_mask, torch.Tensor):
        # DynamicCache or no cache, with a mask to take the length from.
        target_length = attention_mask.shape[-1]
    else:
        # DynamicCache or no cache, without a mask.
        target_length = past_seen_tokens + sequence_length + 1

    # In case the provided attention mask is 2D, a causal 4D mask is generated here.
    causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=sequence_length,
        target_length=target_length,
        dtype=dtype,
        device=device,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
        config=self.config,
        past_key_values=past_key_values,
    )

    if sdpa_without_attn_outputs and attention_mask is not None and attention_mask.device.type == "cuda":
        # Attend to all tokens in fully masked rows of the causal mask, for example the relevant first rows
        # when using left padding. This is required by the memory-efficient attention path of
        # F.scaled_dot_product_attention. Details: https://github.com/pytorch/pytorch/issues/110213
        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

    return causal_mask
1084
+
1085
@staticmethod
# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen3
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    config: Qwen3Config,
    past_key_values: Cache,
):
    """
    Build a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`. A mask that is already 4D is returned unchanged.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)`, a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`, or `None`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static
            cache, to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.
        config (`Qwen3Config`):
            The model's configuration class; only `sliding_window` is read here.
        past_key_values (`Cache`):
            The cache class that is being used currently to generate.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # A 4D mask is assumed to come already in inverted form and needs no inversion or slicing.
        return attention_mask

    min_dtype = torch.finfo(dtype).min
    key_positions = torch.arange(target_length, device=device)
    query_positions = cache_position.reshape(-1, 1)

    # True where a key lies after the query position and therefore must not be attended to.
    blocked = key_positions > query_positions
    if config.sliding_window is not None:
        # With a sliding window, keys further back than the window length are masked out as well.
        # The check is needed to verify whether the current checkpoint was trained with sliding window or not.
        if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
            outside_window = key_positions <= (query_positions - config.sliding_window)
            blocked.bitwise_or_(outside_window)

    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
    causal_mask *= blocked  # keeps `min_dtype` on blocked entries, zero elsewhere
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

    if attention_mask is None:
        return causal_mask

    causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
    if attention_mask.shape[-1] > target_length:
        attention_mask = attention_mask[:, :target_length]
    mask_length = attention_mask.shape[-1]

    covered = causal_mask[:, :, :, :mask_length]
    # Entries that are causally visible (0 in the causal mask) but 0 in the 2D mask sum to 0: those are padding.
    padding_mask = (covered + attention_mask[:, None, None, :]) == 0
    causal_mask[:, :, :, :mask_length] = covered.masked_fill(padding_mask, min_dtype)
    return causal_mask
1152
+
1153
+
1154
class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):  # causal language model
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config):
        super().__init__(config)
        # Decoder backbone.
        self.model = Qwen3Model(config)
        self.vocab_size = config.vocab_size
        # Output projection from hidden states to vocabulary logits (next-token prediction).
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the backbone."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the backbone."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the output projection (`lm_head`)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the output projection (`lm_head`)."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the decoder backbone."""
        self.model = decoder

    def get_decoder(self):
        """Return the decoder backbone."""
        return self.model

    @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Qwen3ForCausalLM

        >>> model = Qwen3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        # Fall back to the configuration for every flag the caller left unset.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        decoder_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        last_hidden = decoder_outputs[0]
        # Only project the positions that are needed; with `num_logits_to_keep == 0` the slice `-0:` keeps all.
        logits = self.lm_head(last_hidden[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        if not return_dict:
            tuple_output = (logits,) + decoder_outputs[1:]
            if loss is None:
                return tuple_output
            return (loss,) + tuple_output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=decoder_outputs.past_key_values,
            hidden_states=decoder_outputs.hidden_states,
            attentions=decoder_outputs.attentions,
        )
1273
+
1274
+ #============================ Task-specific models ============================
1275
@add_start_docstrings(
    """
    The Qwen3 Model transformer with a sequence classification head on top (linear layer).

    [`Qwen3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    QWEN3_START_DOCSTRING,
)
class Qwen3ForSequenceClassification(Qwen3PreTrainedModel):  # sequence (text) classification
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Qwen3Model(config)
        # Classification head applied to every position; one position per row is selected in `forward`.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the backbone."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the backbone."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        pad_token_id = self.config.pad_token_id
        if pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if pad_token_id is None or input_ids is None:
            # Without a pad token, or with `inputs_embeds` only, padding cannot be located:
            # pool the last position of every row.
            last_non_pad_token = -1
        else:
            # Pool the rightmost token that differs from `pad_token_id`. Searching for the first pad
            # occurrence instead would pick a wrong position whenever the pad id also appears inside the
            # sequence (e.g. when the pad token doubles as the end-of-turn/eos token). Taking the rightmost
            # non-pad token handles left padding, right padding and interior pad ids alike.
            non_pad_mask = (input_ids != pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                # Infer the problem type once from the head size and label dtype and store it on the config.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
1396
+
1397
+
1398
@add_start_docstrings(
    """
    The Qwen3 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    QWEN3_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen3, LLAMA->QWEN3
class Qwen3ForTokenClassification(Qwen3PreTrainedModel):  # token classification (sequence labelling)
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Qwen3Model(config)

        # Dropout rate: `classifier_dropout` if set, else `hidden_dropout` if set, else 0.1.
        dropout_rate = getattr(config, "classifier_dropout", None)
        if dropout_rate is None:
            dropout_rate = getattr(config, "hidden_dropout", None)
        if dropout_rate is None:
            dropout_rate = 0.1
        self.dropout = nn.Dropout(dropout_rate)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the backbone."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the backbone."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor`, *optional*):
            Labels handed to `self.loss_function` together with the per-token logits; a loss is only computed when
            they are provided. Presumably one class index in `[0, ..., config.num_labels - 1]` per token, i.e. of
            shape `(batch_size, sequence_length)` (the loss function itself is defined outside this class).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Dropout on the hidden states, then one score vector per position.
        logits = self.score(self.dropout(backbone_outputs[0]))

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        if not return_dict:
            # The first two entries of the backbone tuple are dropped here.
            tuple_output = (logits,) + backbone_outputs[2:]
            if loss is None:
                return tuple_output
            return (loss,) + tuple_output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_outputs.hidden_states,
            attentions=backbone_outputs.attentions,
        )
1485
+
1486
+
1487
@add_start_docstrings(
    """
    The Qwen3 Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    QWEN3_START_DOCSTRING,
)
# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen3, MISTRAL->QWEN3
class Qwen3ForQuestionAnswering(Qwen3PreTrainedModel):  # extractive question answering
    base_model_prefix = "model"

    # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Qwen3
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3Model(config)
        # Two outputs per position: span-start logit and span-end logit.
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the backbone."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the backbone."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        span_logits = self.qa_outputs(backbone_outputs[0])
        # Split the two output channels into separate start/end tensors without the trailing size-1 dim.
        start_logits, end_logits = (
            channel.squeeze(-1).contiguous() for channel in span_logits.split(1, dim=-1)
        )

        # The loss is only computed when both span boundaries are provided.
        loss = None
        if start_positions is not None and end_positions is not None:
            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

        if not return_dict:
            # The first two entries of the backbone tuple are dropped here.
            tuple_output = (start_logits, end_logits) + backbone_outputs[2:]
            if loss is None:
                return tuple_output
            return (loss,) + tuple_output

        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=backbone_outputs.hidden_states,
            attentions=backbone_outputs.attentions,
        )