| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Callable, Optional, Tuple, Union |
|
|
| import torch |
| from torch import nn |
|
|
| from transformers.activations import ACT2FN |
| from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache |
| from transformers.generation import GenerationMixin |
| from transformers.integrations import use_kernel_forward_from_hub |
| from transformers.modeling_attn_mask_utils import AttentionMaskConverter |
| from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
| from transformers.modeling_layers import GradientCheckpointingLayer |
| from transformers.modeling_outputs import ( |
| BaseModelOutputWithPast, |
| CausalLMOutputWithPast, |
| QuestionAnsweringModelOutput, |
| SequenceClassifierOutputWithPast, |
| TokenClassifierOutput, |
| ) |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
| from transformers.processing_utils import Unpack |
| from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging |
| from .configuration_qwen3 import Qwen3Config |
| from .diffusion_cache_utils import DiffusionDynamicCache |
|
|
| if is_torch_flex_attn_available(): |
| from torch.nn.attention.flex_attention import BlockMask |
|
|
| from transformers.integrations.flex_attention import make_flex_block_causal_mask |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
@use_kernel_forward_from_hub("RMSNorm")
class Qwen3RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        # Small constant added to the mean square before taking the reciprocal root.
        self.variance_epsilon = eps
        # Per-channel gain, initialised to one.
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalise `hidden_states` by its RMS over the last axis and scale by `weight`."""
        incoming_dtype = hidden_states.dtype
        # The statistics are computed in float32 and cast back afterwards.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalised = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(incoming_dtype)

    def extra_repr(self):
        """Return the weight shape and epsilon for the module's printed representation."""
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"
|
|
|
|
class Qwen3MLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        model_width = config.hidden_size
        inner_width = config.intermediate_size
        self.hidden_size = model_width
        self.intermediate_size = inner_width
        # All three projections are bias-free.
        self.gate_proj = nn.Linear(model_width, inner_width, bias=False)
        self.up_proj = nn.Linear(model_width, inner_width, bias=False)
        self.down_proj = nn.Linear(inner_width, model_width, bias=False)
        # Activation is looked up by the name stored in the config.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated projection to `x` and project back to the model width."""
        gate = self.act_fn(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
|
|
|
|
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    The last axis is split at `x.shape[-1] // 2`; the result is
    `concat(-second_half, first_half)` along that axis.
    """
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate the query and key tensors with the given rotary cos/sin tables.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into `cos` and `sin` so they broadcast against `q` and `k`. With `cos`/`sin` shaped
            [batch_size, seq_len, head_dim], use 1 when `q`/`k` are [batch_size, heads, seq_len, head_dim] and 2
            when they are [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` holding the rotated query and the rotated key, in that order.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same formula for queries and keys: x * cos + rotate_half(x) * sin.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis, matching
    torch.repeat_interleave(x, dim=1, repeats=n_rep): (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). With `n_rep == 1` the input is returned as is.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads.
    widened = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return widened.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Plain matmul/softmax attention.

    Keys and values are first repeated `module.num_key_value_groups` times along the head axis. An optional
    additive `attention_mask` is cropped on its last axis to the key length before being added to the scores.
    Returns `(attn_output, attn_weights)`; `attn_output` has its axes 1 and 2 transposed and is contiguous.
    Extra keyword arguments are accepted and ignored.
    """
    groups = module.num_key_value_groups
    key_states = repeat_kv(key, groups)
    value_states = repeat_kv(value, groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        kv_len = key_states.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]

    # Softmax runs in float32 and is cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = torch.matmul(probs, value_states).transpose(1, 2).contiguous()

    return context, probs
|
|
|
|
class Qwen3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        num_heads = config.num_attention_heads
        num_kv_heads = config.num_key_value_heads
        # An explicit `head_dim` on the config wins over hidden_size // num_attention_heads.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // num_heads)
        self.num_key_value_groups = num_heads // num_kv_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        with_bias = config.attention_bias
        query_width = num_heads * self.head_dim
        kv_width = num_kv_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, query_width, bias=with_bias)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=with_bias)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=with_bias)
        self.o_proj = nn.Linear(query_width, config.hidden_size, bias=with_bias)
        # Queries and keys are RMS-normalised per head (over head_dim).
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        # The window is kept only when sliding-window use is enabled, a window is
        # configured, and this layer index is at or beyond `max_window_layers`.
        self.sliding_window = config.sliding_window
        window_applies = (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        )
        if not window_applies:
            self.sliding_window = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Project, normalise, rotate and attend; returns `(attn_output, attn_weights)`."""
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, self.head_dim)

        def _split_heads(projected):
            # Unflatten the feature axis into (heads, head_dim).
            return projected.view(per_head_shape)

        query_states = self.q_norm(_split_heads(self.q_proj(hidden_states))).transpose(1, 2)
        key_states = self.k_norm(_split_heads(self.k_proj(hidden_states))).transpose(1, 2)
        value_states = _split_heads(self.v_proj(hidden_states)).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # The cache receives the rotary tables and positions alongside the new states.
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        backend = self.config._attn_implementation
        attention_interface: Callable
        if backend == "eager":
            attention_interface = eager_attention_forward
        elif backend == "sdpa" and kwargs.get("output_attentions", False):
            logger.warning_once(
                "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            attention_interface = eager_attention_forward
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[backend]

        dropout_p = self.attention_dropout if self.training else 0.0
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=dropout_p,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        # Merge the head axes back into one feature axis before the output projection.
        merged = attn_output.reshape(*leading_shape, -1).contiguous()
        return self.o_proj(merged), attn_weights
|
|
|
|
class Qwen3DecoderLayer(GradientCheckpointingLayer):
    """One decoder block: pre-norm self-attention and pre-norm MLP, each with a residual add."""

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen3MLP(config)
        self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Warn once when a sliding window is configured with a backend other than flash_attention_2.
        window_without_flash = config.sliding_window and config._attn_implementation != "flash_attention_2"
        if window_without_flash:
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run the block; returns `(hidden_states,)` plus the attention weights when requested."""
        # Self-attention sub-block.
        attn_input = self.input_layernorm(hidden_states)
        attn_output, self_attn_weights = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        after_attn = hidden_states + attn_output

        # Feed-forward sub-block.
        mlp_output = self.mlp(self.post_attention_layernorm(after_attn))
        block_output = after_attn + mlp_output

        if output_attentions:
            return (block_output, self_attn_weights)
        return (block_output,)
|
|
|
|
@auto_docstring
class Qwen3PreTrainedModel(PreTrainedModel):
    # Shared base class: binds the config class and declares the capability flags below.
    config_class = Qwen3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        # RMSNorm gains are reset to one.
        if isinstance(module, Qwen3RMSNorm):
            module.weight.data.fill_(1.0)
            return

        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        # Linear and embedding weights are drawn from N(0, initializer_range).
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            # The padding row of an embedding is zeroed.
            module.weight.data[module.padding_idx].zero_()
|
|
|
|
class Qwen3RotaryEmbedding(nn.Module):
    """Produces the rotary cos/sin tables for a batch of position ids."""

    def __init__(self, config: Qwen3Config, device=None):
        super().__init__()
        # The rope type comes from `rope_scaling` ("rope_type", falling back to "type");
        # without a scaling dict it is "default".
        scaling_cfg = getattr(config, "rope_scaling", None)
        if scaling_cfg is None:
            self.rope_type = "default"
        else:
            self.rope_type = scaling_cfg.get("rope_type", scaling_cfg.get("type"))
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent buffer: it is not written to the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Return `(cos, sin)` in the dtype of `x`; `x` is only read for its device and dtype."""
        batch = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        raw_device_type = x.device.type
        # "mps" (or a non-string device type) falls back to "cpu" for the autocast context.
        autocast_device = raw_device_type if isinstance(raw_device_type, str) and raw_device_type != "mps" else "cpu"
        with torch.autocast(device_type=autocast_device, enabled=False):
            # Autocast is disabled so the angles are computed in float32.
            angles = (inv_freq.float() @ positions.float()).transpose(1, 2)
            table = torch.cat((angles, angles), dim=-1)
            cos = table.cos() * self.attention_scaling
            sin = table.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
@auto_docstring
class Qwen3Model(Qwen3PreTrainedModel):
    # Decoder stack: token embedding -> `num_hidden_layers` decoder layers -> final RMSNorm.
    def __init__(self, config: Qwen3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        decoder_stack = [Qwen3DecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_stack)
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        # Unset flags fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # Only `Cache` instances (or None) are accepted as a cache.
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # A fresh diffusion cache is created when caching is requested without one.
        if use_cache and past_key_values is None:
            past_key_values = DiffusionDynamicCache()

        if cache_position is None:
            already_seen = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                already_seen, already_seen + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # The rotary tables are computed once and shared by every layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        collected_hidden = () if output_hidden_states else None
        collected_attn = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                collected_hidden += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                collected_attn += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # The normalised output of the last layer is recorded as well.
        if output_hidden_states:
            collected_hidden += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=collected_hidden,
            attentions=collected_attn,
        )

    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """Return the mask object the configured attention backend should receive (possibly None)."""
        backend = self.config._attn_implementation

        if backend == "flash_attention_2":
            if attention_mask is not None and past_key_values is not None:
                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
                if is_padding_right:
                    raise ValueError(
                        "You are attempting to perform batched generation with padding_side='right'"
                        " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to "
                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                    )
            # The 2D mask is forwarded only when it contains at least one zero.
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        if backend == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        already_seen = past_key_values.get_seq_length() if past_key_values is not None else 0
        static_cache_in_use = isinstance(past_key_values, StaticCache)
        sliding_cache_in_use = isinstance(past_key_values, SlidingWindowCache)

        # With SDPA, a dynamic cache and no attention output requested, the explicit
        # mask is dropped whenever the converter says it can be ignored.
        if (
            backend == "sdpa"
            and not static_cache_in_use
            and not sliding_cache_in_use
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=already_seen,
                sliding_window=self.config.sliding_window,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]

        if sliding_cache_in_use or static_cache_in_use:
            # Fixed-size caches dictate the key length.
            target_length = past_key_values.get_max_cache_shape()
        elif isinstance(attention_mask, torch.Tensor):
            target_length = attention_mask.shape[-1]
        else:
            target_length = already_seen + sequence_length + 1

        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )

        needs_unmask = (
            backend == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        )
        if needs_unmask:
            # Post-process the mask with the converter's `_unmask_unattended` helper.
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        config: Qwen3Config,
        past_key_values: Cache,
    ):
        """
        Build an additive causal mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of
        shape `(batch_size, key_value_length)`. A mask that is already 4D is returned unchanged.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The number of key positions (columns) of the generated mask.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
            config (`Qwen3Config`):
                The model's configuration class
            past_key_values (`Cache`):
                The cache class that is being used currently to generate
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # Already 4D: used as given.
            return attention_mask

        min_dtype = torch.finfo(dtype).min
        mask_device = cache_position.device
        causal_mask = torch.full(
            (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=mask_device
        )
        query_positions = cache_position.reshape(-1, 1)
        # True marks key positions strictly after the query position (these stay blocked).
        blocked = torch.arange(target_length, device=mask_device) > query_positions

        text_config = config.get_text_config()
        if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
            if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                # Keys at or before `query position - sliding_window` are blocked as well.
                outside_window = torch.arange(target_length, device=mask_device) <= (
                    query_positions - text_config.sliding_window
                )
                blocked.bitwise_or_(outside_window)
        causal_mask *= blocked
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

        if attention_mask is not None:
            # Clone first: the expanded view is written to in place below.
            causal_mask = causal_mask.clone()
            if attention_mask.shape[-1] > target_length:
                attention_mask = attention_mask[:, :target_length]
            mask_length = attention_mask.shape[-1]
            combined = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                causal_mask.device
            )
            # A sum of zero means the position was open causally but has a 0 in the 2D mask.
            is_padding = combined == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                is_padding, min_dtype
            )
        return causal_mask
|
|
|
|
# Keyword-argument bundle accepted by the causal-LM forward pass: the union of the
# flash-attention kwargs and the loss kwargs, with no extra fields of its own.
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
    pass
|
|
|
|
def forward_process(input_ids, attention_mask, prompt_lengths, mask_token_id, eps=1e-3):
    """Build a noised training batch for the slot-wise AR + masked-diffusion objective.

    For every sample the answer (the tokens between the prompt and the right padding) is cut
    into slots of one randomly chosen size (4, 8, 16 or 32). Each slot is masked independently
    with probability `p_mask = (1 - eps) * t + eps`, `t ~ U(0, 1)` drawn once per sample:

    * unmasked slots keep their tokens, are placed first in a random order, and are labelled
      with the next token inside the same slot (the last token of a slot is ignored);
    * masked slots follow in their original order, have every token replaced by
      `mask_token_id`, and are labelled with the original tokens.

    Position ids travel with the tokens, so each token keeps its original position index.
    Prompt and padding positions are labelled -100.

    Unlike the earlier version, a sample whose answer is empty is kept (all labels -100) rather
    than dropped, so the outputs always have `batch_size` rows and stay aligned with
    `attention_mask`. A prompt length larger than the sample's real length is clamped.

    Args:
        input_ids: integer tensor `(batch_size, seq_length)`; padding is expected on the right.
        attention_mask: tensor `(batch_size, seq_length)` whose row sums give the real lengths.
        prompt_lengths: tensor holding one prompt length per sample, e.g. `(batch_size, 1)` or
            `(batch_size,)`.
        mask_token_id: token id written into masked slots.
        eps: lower bound of the per-sample masking probability.

    Returns:
        Tuple `(input_ids, labels, masked_indices, p_mask, answer_lengths, position_ids)`, all of
        shape `(batch_size, seq_length)` on the device of `input_ids`. `masked_indices` is
        boolean; `p_mask` and `answer_lengths` repeat the per-sample value along the sequence.

    Raises:
        ValueError: if `prompt_lengths` does not contain exactly one value per sample.
    """
    import random

    IGNORE_INDEX = -100
    slot_size_set = [4, 8, 16, 32]
    device = input_ids.device

    batch_size, seq_length = input_ids.shape
    if prompt_lengths.numel() != batch_size:
        raise ValueError(
            f"prompt_lengths must hold one value per sample: got {prompt_lengths.numel()} values "
            f"for a batch of {batch_size}"
        )

    token_rows = input_ids.tolist()
    prompt_lens = prompt_lengths.reshape(-1).tolist()
    total_lens = attention_mask.sum(dim=1).tolist()

    out_input_ids = []
    out_labels = []
    out_masked = []
    out_position_ids = []
    out_p_masks = []
    out_answer_lengths = []

    for row in range(batch_size):
        slot_size = random.choice(slot_size_set)
        total_len = int(total_lens[row])
        # Clamp so a truncated sample cannot push the prompt into the padding region.
        prompt_len = max(0, min(int(prompt_lens[row]), total_len))
        pad_len = seq_length - total_len

        tokens = token_rows[row]
        prompt_ids = tokens[:prompt_len]
        answer_ids = tokens[prompt_len:total_len]
        answer_positions = list(range(prompt_len, total_len))

        # Cut the answer (and its positions) into consecutive slots; the last may be shorter.
        slot_starts = range(0, len(answer_ids), slot_size)
        answer_slots = [answer_ids[s:s + slot_size] for s in slot_starts]
        position_slots = [answer_positions[s:s + slot_size] for s in slot_starts]

        # One noise level per sample, one Bernoulli draw per slot.
        t = random.random()
        p_mask = (1 - eps) * t + eps
        slot_is_masked = [random.random() < p_mask for _ in answer_slots]

        clean_slots = [k for k, is_masked in enumerate(slot_is_masked) if not is_masked]
        noised_slots = [k for k, is_masked in enumerate(slot_is_masked) if is_masked]
        # Only the clean slots are permuted; masked slots keep their original order.
        random.shuffle(clean_slots)

        answer_tokens = []
        answer_labels = []
        answer_masked = []
        answer_pos = []

        for k in clean_slots:
            content = answer_slots[k]
            answer_tokens.extend(content)
            # Next-token target inside the slot; the last token has nothing to predict.
            answer_labels.extend(content[1:] + [IGNORE_INDEX])
            answer_masked.extend([False] * len(content))
            answer_pos.extend(position_slots[k])

        for k in noised_slots:
            content = answer_slots[k]
            answer_tokens.extend([mask_token_id] * len(content))
            # Masked positions are trained to recover the original token.
            answer_labels.extend(content)
            answer_masked.extend([True] * len(content))
            answer_pos.extend(position_slots[k])

        out_input_ids.append(prompt_ids + answer_tokens + tokens[total_len:])
        out_labels.append([IGNORE_INDEX] * prompt_len + answer_labels + [IGNORE_INDEX] * pad_len)
        out_masked.append([False] * prompt_len + answer_masked + [False] * pad_len)
        out_position_ids.append(list(range(prompt_len)) + answer_pos + list(range(total_len, seq_length)))
        out_p_masks.append(p_mask)
        out_answer_lengths.append(len(answer_ids))

    def _as_grid(rows, dtype):
        # Every row has exactly `seq_length` entries by construction.
        return torch.tensor(rows, dtype=dtype, device=device).reshape(batch_size, seq_length)

    def _per_sample(values, dtype):
        # Broadcast one scalar per sample along the sequence axis.
        return torch.tensor(values, dtype=dtype, device=device).view(-1, 1).repeat(1, seq_length)

    return (
        _as_grid(out_input_ids, torch.long),
        _as_grid(out_labels, torch.long),
        _as_grid(out_masked, torch.bool),
        _per_sample(out_p_masks, torch.float32),
        _per_sample(out_answer_lengths, torch.long),
        _as_grid(out_position_ids, torch.long),
    )
|
|
@auto_docstring
class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
    # Qwen3 backbone with an LM head. When `labels` is given, the batch is re-noised by
    # `forward_process` and the loss is the sum of an autoregressive term (unmasked
    # slots) and a masked-diffusion term (masked slots).
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Token id written into masked answer slots during training.
        self.mask_token_id = config.mask_token_id

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        prompt_lengths: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
        prompt_lengths (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
            Number of prompt tokens at the start of each sequence. Required when `labels` is passed; only the
            tokens after the prompt are noised and contribute to the loss.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Qwen3ForCausalLM

        >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if labels is not None:
            if input_ids is None or attention_mask is None or prompt_lengths is None:
                raise ValueError(
                    "Training with `labels` requires `input_ids`, `attention_mask` and `prompt_lengths`."
                )
            # NOTE: the incoming `labels` only switches training on; the targets, the
            # noised inputs and the position ids are all rebuilt by `forward_process`.
            input_ids, labels, masked_indices, p_mask, answer_lengths, position_ids = forward_process(
                input_ids, attention_mask, prompt_lengths, self.mask_token_id
            )

        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # An int keeps the last `logits_to_keep` positions (0 keeps all); a tensor is used as an index.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            flat_logits = logits.float().view(-1, self.config.vocab_size)
            loss_device = flat_logits.device
            # Every per-token tensor is moved to the logits device, not just the labels.
            flat_labels = labels.view(-1).to(loss_device)
            flat_masked = masked_indices.view(-1).to(loss_device)
            flat_p_mask = p_mask.view(-1).to(loss_device)
            flat_answer_len = answer_lengths.view(-1).to(loss_device)

            # Autoregressive term over non-masked positions. Summing and dividing by the
            # clamped target count equals the mean when targets exist, and yields 0
            # instead of NaN when every such label is -100 (e.g. all slots masked).
            ar_positions = ~flat_masked
            ar_targets = flat_labels[ar_positions]
            ar_target_count = (ar_targets != -100).sum().clamp(min=1)
            ar_loss = (
                nn.functional.cross_entropy(
                    flat_logits[ar_positions], ar_targets, ignore_index=-100, reduction="sum"
                )
                / ar_target_count
            )

            # Masked-diffusion term: per-token loss weighted by 1 / p_mask, normalised by
            # the sample's answer length and averaged over the batch.
            mdm_token_loss = (
                nn.functional.cross_entropy(
                    flat_logits[flat_masked], flat_labels[flat_masked], ignore_index=-100, reduction="none"
                )
                / flat_p_mask[flat_masked]
            )
            mdm_loss = torch.sum(mdm_token_loss / flat_answer_len[flat_masked]) / input_ids.shape[0]

            loss = ar_loss + mdm_loss

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
@auto_docstring(
    custom_intro="""
    The Qwen3 Model transformer with a sequence classification head on top (linear layer).

    [`Qwen3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """
)
class Qwen3ForSequenceClassification(Qwen3PreTrainedModel):
    # Wraps a `Qwen3Model` backbone with a bias-free linear head and keeps, for every
    # sequence, the head output at one selected token position.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Qwen3Model(config)
        # The head scores every position; `forward` picks one position per row afterwards.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Called last, once every submodule has been created.
        self.post_init()

    def get_input_embeddings(self):
        # The token embedding table is owned by the wrapped backbone.
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        # Swap the backbone's token embedding table in place.
        self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        backbone_out: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        # One score vector for every position of every sequence.
        per_token_logits = self.score(backbone_out.last_hidden_state)
        target_device = per_token_logits.device

        # Batch size is read from whichever of the two input forms was supplied.
        batch_source = input_ids if input_ids is not None else inputs_embeds
        num_rows = batch_source.shape[0]

        pad_id = self.config.pad_token_id
        if pad_id is None:
            # Without a pad id there is no way to locate the real last token of each row.
            if num_rows != 1:
                raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
            pooled_position = -1
        elif input_ids is None:
            # Padding cannot be detected from embeddings alone: fall back to the final position.
            pooled_position = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )
        else:
            # Zero out the index of every pad position, then take the argmax: the result is the
            # right-most non-pad index in each row, whichever side the padding is on.
            keep_mask = (input_ids != pad_id).to(target_device, torch.int32)
            positions = torch.arange(input_ids.shape[-1], device=target_device, dtype=torch.int32)
            pooled_position = (positions * keep_mask).argmax(-1)

        row_selector = torch.arange(num_rows, device=target_device)
        pooled_logits = per_token_logits[row_selector, pooled_position]

        loss = (
            None
            if labels is None
            else self.loss_function(
                logits=per_token_logits, labels=labels, pooled_logits=pooled_logits, config=self.config
            )
        )

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=backbone_out.past_key_values,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
|
|
|
|
@auto_docstring
class Qwen3ForTokenClassification(Qwen3PreTrainedModel):
    # Wraps a `Qwen3Model` backbone with dropout and a linear head applied to every token
    # position, yielding one `num_labels`-sized logit vector per token.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Qwen3Model(config)
        # Dropout probability, in order of precedence: `config.classifier_dropout`,
        # then `config.hidden_dropout`, then a fixed fallback of 0.1.
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Called last, once every submodule has been created.
        self.post_init()

    def get_input_embeddings(self):
        # The token embedding table is owned by the wrapped backbone.
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        # Swap the backbone's token embedding table in place.
        self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`.
        """
        outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        # Dropout is applied to the backbone output before the per-token linear head.
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.score(sequence_output)

        # The loss is only computed when per-token labels are supplied.
        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
@auto_docstring
class Qwen3ForQuestionAnswering(Qwen3PreTrainedModel):
    # Wraps a `Qwen3Model` backbone (stored as `transformer`) with a two-output linear head;
    # the two outputs per token are returned as start logits and end logits.

    # The backbone attribute is named `transformer` on this class, not `model`.
    base_model_prefix = "transformer"

    def __init__(self, config):
        super().__init__(config)
        self.transformer = Qwen3Model(config)
        # Two scores per token: one for the span start, one for the span end.
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Called last, once every submodule has been created.
        self.post_init()

    def get_input_embeddings(self):
        # The token embedding table is owned by the wrapped backbone.
        return self.transformer.embed_tokens

    def set_input_embeddings(self, value):
        # Swap the backbone's token embedding table in place.
        self.transformer.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> QuestionAnsweringModelOutput:
        backbone_out: BaseModelOutputWithPast = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # Split the two-wide head output into its start and end channels, dropping the
        # trailing singleton dimension left behind by `split`.
        span_logits = self.qa_outputs(backbone_out.last_hidden_state)
        start_logits, end_logits = (
            channel.squeeze(-1).contiguous() for channel in span_logits.split(1, dim=-1)
        )

        # A loss is produced only when both span boundaries are provided; extra keyword
        # arguments are forwarded to the loss function untouched.
        targets_given = start_positions is not None and end_positions is not None
        loss = (
            self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
            if targets_given
            else None
        )

        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
|
|
|
|
# Names exported by `from <this module> import *`; kept as a plain literal list.
__all__ = [
    "Qwen3ForCausalLM",
    "Qwen3ForQuestionAnswering",
    "Qwen3Model",
    "Qwen3PreTrainedModel",
    "Qwen3ForSequenceClassification",
    "Qwen3ForTokenClassification",
]
|
|