maxoul committed on
Commit
1231748
·
verified ·
1 Parent(s): 57bdcb8

Create modeling_qwen3_bidir.py

Browse files
Files changed (1) hide show
  1. modeling_qwen3_bidir.py +960 -0
modeling_qwen3_bidir.py ADDED
@@ -0,0 +1,960 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ###
2
+ # Adapted from https://github.com/huggingface/transformers/blob/v4.51.2/src/transformers/models/qwen3/modeling_qwen3.py
3
+ ###
4
+
5
+ from functools import partial
6
+ from typing import Callable, Optional, Tuple, Union
7
+
8
+ import torch
9
+ from torch import nn
10
+
11
+ from transformers.activations import ACT2FN
12
+ from transformers.cache_utils import (
13
+ Cache,
14
+ DynamicCache,
15
+ SlidingWindowCache,
16
+ StaticCache,
17
+ )
18
+ from transformers.generation import GenerationMixin
19
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
20
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
21
+ from transformers.modeling_outputs import (
22
+ BaseModelOutputWithPast,
23
+ CausalLMOutputWithPast,
24
+ )
25
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
26
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
27
+ from transformers.processing_utils import Unpack
28
+ from transformers.utils import (
29
+ LossKwargs,
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ can_return_tuple,
33
+ logging,
34
+ replace_return_docstrings,
35
+ )
36
+ from transformers.utils.deprecation import deprecate_kwarg
37
+ from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ _CHECKPOINT_FOR_DOC = "Qwen/Qwen3-8B"
43
+ _CONFIG_FOR_DOC = "Qwen3Config"
44
+
45
+
46
class Qwen3RMSNorm(nn.Module):
    """Root-mean-square layer norm (equivalent to T5LayerNorm): scales by a
    learned per-channel weight; no mean-centering and no bias."""

    def __init__(self, hidden_size, eps=1e-6):
        """Create a learnable scale of size `hidden_size`; `eps` guards the rsqrt."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalize in float32 for numerical stability, then cast back to the
        # caller's dtype before applying the learned scale.
        orig_dtype = hidden_states.dtype
        upcast = hidden_states.float()
        mean_square = upcast.pow(2).mean(dim=-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(orig_dtype)

    def extra_repr(self):
        # Shown in the module's repr: weight shape and epsilon.
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
64
+
65
+
66
class Qwen3MLP(nn.Module):
    """Gated feed-forward block (SwiGLU-style): down_proj(act(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # All projections are bias-free, matching the Qwen3 reference implementation.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # Gate branch is passed through the activation, then modulates the up branch.
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
80
+
81
+
82
def rotate_half(x):
    """Rotates half the hidden dims of the input: (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
87
+
88
+
89
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*): Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos`/`sin` are unsqueezed so that they
            broadcast against q/k. Use 1 for `[batch, heads, seq, head_dim]`
            layouts and 2 for `[batch, seq, heads, head_dim]` layouts.

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # Same as rotate_half: swap halves of the last dim, negating the second.
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    q_embed = (q * cos) + (_rotate(q) * sin)
    k_embed = (k * cos) + (_rotate(k) * sin)
    return q_embed, k_embed
113
+
114
+
115
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep): expands
    (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim) for grouped-query attention.
    """
    batch, n_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        # No grouping: each query head already has its own KV head.
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(
        batch, n_kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(batch, n_kv_heads * n_rep, seq_len, head_dim)
127
+
128
+
129
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Plain softmax attention in PyTorch ops.

    Expands grouped KV heads, computes scaled QK^T scores, adds the (additive)
    mask if given, softmaxes in float32, applies dropout, and returns the
    attention output transposed to (batch, seq, heads, head_dim) along with the
    attention probabilities.
    """
    n_rep = module.num_key_value_groups

    def _expand_kv(states: torch.Tensor) -> torch.Tensor:
        # Broadcast KV heads up to the number of query heads (GQA); same as repeat_kv.
        batch, n_kv, seq_len, head_dim = states.shape
        if n_rep == 1:
            return states
        return (
            states[:, :, None, :, :]
            .expand(batch, n_kv, n_rep, seq_len, head_dim)
            .reshape(batch, n_kv * n_rep, seq_len, head_dim)
        )

    key_states = _expand_kv(key)
    value_states = _expand_kv(value)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Mask is additive (large negative = masked); slice to the key length.
        scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]

    # Softmax in float32 for stability, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    attn_output = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
    return attn_output, probs
157
+
158
+
159
class Qwen3BidirAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Bidirectional variant of Qwen3 attention: `is_causal` is set to False, so
    (given a non-causal attention mask from the caller) every token may attend
    to every other token.
    """

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx  # addresses this layer's slot in the KV cache
        # Fall back to hidden_size // num_heads when the config does not pin head_dim.
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        # GQA: number of query heads sharing one key/value head.
        self.num_key_value_groups = (
            config.num_attention_heads // config.num_key_value_heads
        )
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False  # bidirectional attention: no causal masking here

        self.q_proj = nn.Linear(
            config.hidden_size,
            config.num_attention_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.k_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )
        self.q_norm = Qwen3RMSNorm(
            self.head_dim, eps=config.rms_norm_eps
        )  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3RMSNorm(
            self.head_dim, eps=config.rms_norm_eps
        )  # thus post q_norm does not need reshape
        # Sliding-window attention applies only when enabled in the config AND this
        # layer index is at or beyond max_window_layers; otherwise disable it.
        self.sliding_window = config.sliding_window
        if not (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        ):
            self.sliding_window = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Project, QK-normalize, RoPE-rotate, and attend.

        Returns `(attn_output, attn_weights)`; `attn_weights` may be None for
        non-eager backends.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # QK-norm is applied per head (on head_dim) before heads move to dim 1.
        query_states = self.q_norm(
            self.q_proj(hidden_states).view(hidden_shape)
        ).transpose(1, 2)
        key_states = self.k_norm(
            self.k_proj(hidden_states).view(hidden_shape)
        ).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # Dispatch to the configured attention backend; eager is the fallback,
        # and the forced path for sdpa + output_attentions (SDPA cannot return weights).
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get(
                "output_attentions", False
            ):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[
                    self.config._attn_implementation
                ]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        # Merge heads back to (batch, seq, heads * head_dim), then project out.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
271
+
272
+
273
class Qwen3BidirDecoderLayer(nn.Module):
    """Pre-norm transformer layer: bidirectional self-attention followed by a
    gated MLP, each wrapped in a residual connection."""

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen3BidirAttention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen3MLP(config)
        self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        # diff with Llama is this warning: sliding window only works with FA2.
        if (
            config.sliding_window and config._attn_implementation != "flash_attention_2"
        ):
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """Run one attention + MLP sub-block pair.

        Returns a 1-tuple `(hidden_states,)`, extended with the attention
        weights when `output_attentions` is set.
        """
        # --- attention sub-block (pre-norm + residual) ---
        attn_input = self.input_layernorm(hidden_states)
        attn_out, self_attn_weights = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # --- feed-forward sub-block (pre-norm + residual) ---
        hidden_states = hidden_states + self.mlp(
            self.post_attention_layernorm(hidden_states)
        )

        if output_attentions:
            return (hidden_states, self_attn_weights)
        return (hidden_states,)
336
+
337
+
338
class Qwen3RotaryEmbedding(nn.Module):
    """Produces the RoPE cos/sin tables for a batch of position ids."""

    def __init__(self, config: Qwen3Config, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get(
                "rope_type", config.rope_scaling.get("type")
            )
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        # inv_freq is a non-persistent buffer (not saved in the state dict).
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so dynamic-RoPE updates can restore the original frequencies.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return `(cos, sin)` in x's dtype; `x` is used only for device/dtype."""
        inv_freq_expanded = (
            self.inv_freq[None, :, None]
            .float()
            .expand(position_ids.shape[0], -1, 1)
            .to(x.device)
        )
        position_ids_expanded = position_ids[:, None, :].float()

        # NOTE(review): device_type falls back to "cpu" on mps, presumably because
        # torch.autocast does not accept "mps" here — mirrors upstream Qwen3; confirm.
        device_type = (
            x.device.type
            if isinstance(x.device.type, str) and x.device.type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
383
+
384
+
385
# Shared class-level documentation injected via `add_start_docstrings` on the model classes.
QWEN3_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)
    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.
    Parameters:
        config ([`Qwen3Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
398
+
399
+
400
@add_start_docstrings(
    "The bare Qwen3 Model outputting raw hidden-states without any specific head on top.",
    QWEN3_START_DOCSTRING,
)
class Qwen3PreTrainedModel(PreTrainedModel):
    """Base class wiring the bidirectional Qwen3 modules into the HF
    `PreTrainedModel` machinery (weight init, capability flags, device placement)."""

    config_class = Qwen3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Fix: the decoder layer defined in this file is `Qwen3BidirDecoderLayer`; there is
    # no `Qwen3DecoderLayer` here. Listing a nonexistent module name meant
    # `device_map="auto"` could split a decoder layer (residual block) across devices.
    _no_split_modules = ["Qwen3BidirDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights as N(0, initializer_range); zero
        Linear biases and the padding-token embedding row."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
428
+
429
+
430
# Forward-signature documentation injected via `add_start_docstrings_to_model_forward`.
QWEN3_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.
            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)
            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).
            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.
            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
484
+
485
+
486
+ @add_start_docstrings(
487
+ "The bare Qwen3 Model outputting raw hidden-states without any specific head on top.",
488
+ QWEN3_START_DOCSTRING,
489
+ )
490
+ class Qwen3BidirModel(Qwen3PreTrainedModel):
491
+ """
492
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen3DecoderLayer`]
493
+ Args:
494
+ config: Qwen3Config
495
+ """
496
+
497
    def __init__(self, config: Qwen3Config):
        """Build token embeddings, the stack of bidirectional decoder layers,
        the final RMS norm, and the shared rotary embedding."""
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                Qwen3BidirDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # One rotary embedding is shared across all layers.
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()
517
+
518
    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens
520
+
521
    def set_input_embeddings(self, value):
        """Replace the token embedding module (e.g. after resizing the vocabulary)."""
        self.embed_tokens = value
523
+
524
    @can_return_tuple
    @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        """Embed the inputs, run them through every decoder layer, apply the
        final norm, and return a `BaseModelOutputWithPast`."""
        # Resolve unset flags from the config defaults.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError(
                "The `past_key_values` should be either a `Cache` object or `None`."
            )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        # cache_position tracks absolute positions of the current tokens,
        # continuing from however many tokens the cache has already seen.
        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # partial() lets the non-tensor flash-attn kwargs through the
                # checkpointing wrapper, which only forwards positional tensors.
                layer_outputs = self._gradient_checkpointing_func(
                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
650
+
651
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """Build the 4D attention mask for the configured backend, or return
        None when no explicit mask is needed.

        NOTE(review): despite this model using bidirectional attention, this
        helper still delegates to
        `_prepare_4d_causal_attention_mask_with_cache_position` for the
        sdpa/eager paths — confirm the produced mask matches the intended
        (non-causal) semantics.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and past_key_values is not None:
                # Rows whose mask is all zeros carry no tokens; skip them when
                # checking for right padding (FA2 generation requires left padding).
                valid_rows = attention_mask.sum(dim=1) > 0

                if valid_rows.any():
                    # Only check right-padding on non-empty rows
                    right_padded_rows = attention_mask[valid_rows, -1] == 0
                    is_padding_right = right_padded_rows.any().item()
                    if is_padding_right:
                        raise ValueError(
                            "You are attempting to perform batched generation with padding_side='right'. "
                            "This may lead to unexpected behaviour for Flash Attention version of Qwen3. "
                            "Make sure to call `tokenizer.padding_side = 'left'` before tokenizing the input."
                        )
            # FA2 consumes the 2D padding mask directly; all-ones (or absent)
            # means dense attention, signalled by returning None.
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        using_static_cache = isinstance(past_key_values, StaticCache)
        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if (
            self.config._attn_implementation == "sdpa"
            and not (using_static_cache or using_sliding_window_cache)
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                sliding_window=self.config.sliding_window,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # SlidingWindowCache or StaticCache
        if using_sliding_window_cache or using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        # DynamicCache or no cache
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype
            )

        return causal_mask
751
+
752
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    config: Qwen3Config,
    past_key_values: Cache,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    The returned mask is additive: allowed positions hold `0` and disallowed positions hold the most negative
    representable value of `dtype`, so it can be summed onto attention scores before the softmax.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache,
            to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`torch.Tensor`):
            Batch size.
        config (`Qwen3Config`):
            The model's configuration class
        past_key_values (`Cache`):
            The cache class that is being used currently to generate
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        # Start from a fully-masked (sequence_length, target_length) matrix; positions are then selectively
        # zeroed out (i.e. unmasked) by multiplying with the boolean attend mask below.
        min_dtype = torch.finfo(dtype).min
        causal_mask = torch.full(
            (sequence_length, target_length),
            fill_value=min_dtype,
            dtype=dtype,
            device=device,
        )
        # True (-> keep masked) for key positions strictly in the future of each query's cache position.
        diagonal_attend_mask = torch.arange(
            target_length, device=device
        ) > cache_position.reshape(-1, 1)
        if config.sliding_window is not None:
            # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
            # the check is needed to verify is current checkpoint was trained with sliding window or not
            if (
                not isinstance(past_key_values, SlidingWindowCache)
                or sequence_length > target_length
            ):
                # Also mask keys that fall more than `sliding_window` positions behind the query.
                sliding_attend_mask = torch.arange(
                    target_length, device=device
                ) <= (cache_position.reshape(-1, 1) - config.sliding_window)
                diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
        # Multiplying by the boolean mask keeps `min_dtype` where masked and writes 0 where attendable.
        causal_mask *= diagonal_attend_mask
        # Broadcast to (batch_size, 1, sequence_length, target_length).
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = (
                causal_mask.clone()
            )  # copy to contiguous memory for in-place edit
            if attention_mask.shape[-1] > target_length:
                attention_mask = attention_mask[:, :target_length]
            mask_length = attention_mask.shape[-1]
            # A position is padding iff both the causal mask (0 = attendable) and the 2D mask (0 = padded)
            # sum to exactly 0 there; re-mask those positions with `min_dtype`.
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
                :, None, None, :
            ].to(causal_mask.device)
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[
                :, :, :, :mask_length
            ].masked_fill(padding_mask, min_dtype)
    return causal_mask
829
+
830
+
831
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
    """Typed keyword arguments accepted by the causal-LM forward pass.

    Combines the flash-attention kwargs with the loss kwargs so a single
    ``**kwargs: Unpack[KwargsForCausalLM]`` annotation covers both key sets.
    """
832
+
833
+
834
class Qwen3BidirForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
    """Qwen3 language-modeling head on top of the bidirectional `Qwen3BidirModel` backbone."""

    # Keys tied to the input embeddings when weight tying is enabled in the config.
    _tied_weights_keys = ["lm_head.weight"]
    # Tensor-parallel / pipeline-parallel layout hints for the LM head.
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        """Build the backbone and the (bias-free) vocabulary projection, then run `post_init`."""
        super().__init__(config)
        self.model = Qwen3BidirModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head projection."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head projection."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Swap in a different backbone model."""
        self.model = decoder

    def get_decoder(self):
        """Return the backbone model."""
        return self.model

    @can_return_tuple
    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).
        Returns:
        Example:
        ```python
        >>> from transformers import AutoTokenizer, Qwen3ForCausalLM
        >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")
        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Fall back to config-level flags when a per-call value was not supplied.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # Note: an int of 0 yields slice(0, None), i.e. all positions; a positive int keeps the last N positions.
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int)
            else logits_to_keep
        )
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # `loss_function` is provided by the base class; it shifts labels internally for next-token prediction.
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )