| from typing import Callable, Optional, Union |
|
|
| import torch |
| from torch import nn |
|
|
| import copy |
| from transformers.activations import ACT2FN |
| from transformers.cache_utils import Cache, DynamicCache |
| from transformers.generation import GenerationMixin |
| from transformers.integrations import use_kernel_forward_from_hub |
| from transformers.masking_utils import create_causal_mask |
| from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
| from transformers.modeling_layers import ( |
| GenericForQuestionAnswering, |
| GenericForSequenceClassification, |
| GenericForTokenClassification, |
| GradientCheckpointingLayer, |
| ) |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
| from transformers.processing_utils import Unpack |
| from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple |
| from transformers.utils.deprecation import deprecate_kwarg |
| from transformers.utils.generic import check_model_inputs |
| from transformers import Qwen3Config |
|
|
|
|
def create_block_causal_mask(index: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    """
    Build an additive block-wise causal attention mask.

    Position ``i`` may attend to position ``j`` when either ``index[i] == index[j]``
    (both belong to the same block) or ``j <= i`` (ordinary causal attention).

    Args:
        index (`torch.Tensor`): 1-D tensor of length ``L`` holding one block id per position.
        dtype (`torch.dtype`, *optional*): floating dtype of the returned mask. Defaults to
            the current torch default dtype, which is what the mask used before this
            parameter existed.

    Returns:
        `torch.Tensor` of shape ``(1, 1, L, L)`` on ``index.device``: ``0.0`` where attention
        is allowed and ``-inf`` where it is masked.

    Raises:
        ValueError: if ``index`` is not 1-D.
    """
    if index.dim() != 1:
        raise ValueError(f"`index` must be a 1-D tensor of block ids, got shape {tuple(index.shape)}")

    # Broadcasting a row view against a column view gives the (L, L) pairwise comparison.
    same_block = index.unsqueeze(0) == index.unsqueeze(1)
    positions = torch.arange(index.size(0), device=index.device)
    causal = positions.unsqueeze(0) <= positions.unsqueeze(1)
    allowed = same_block | causal

    # Create the mask directly on the target device instead of mixing in CPU scalar tensors.
    mask = torch.zeros(allowed.shape, dtype=dtype, device=index.device)
    mask.masked_fill_(~allowed, float("-inf"))
    return mask[None, None, :, :]
|
|
|
|
def visualize_mask(mask: torch.Tensor, i: int = 0, j: int = 12) -> None:
    """
    Print a square window of an additive attention mask as rows of 0/1.

    Args:
        mask (`torch.Tensor`): additive mask of shape ``(1, 1, L, L)``; entries equal to
            ``0`` mean "may attend", anything else (e.g. ``-inf``) means "masked".
        i (`int`): first row/column of the window (inclusive).
        j (`int`): last row/column of the window (exclusive).

    Each printed row is a space-separated list with ``1`` for allowed positions and ``0``
    for masked ones.
    """
    # The previous version computed the 0/1 map and then discarded it, slicing the raw
    # 4-D mask on its first two axes and casting -inf to int. Slice the last two axes of
    # the single (batch, head) entry and binarize that instead.
    window = (mask[0, 0, i:j, i:j] == 0).int().cpu()
    for row in window.tolist():
        print(" ".join(map(str, row)))
|
|
|
|
@use_kernel_forward_from_hub("RMSNorm")
class Qwen3RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned per-channel scale."""

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        Qwen3RMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: size of the normalized (last) dimension; the scale is initialized to ones.
            eps: constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize in float32, cast back to the input dtype, then apply the learned scale."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        # The cast back happens before the multiplication by the weight.
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        """Show the weight shape and epsilon in the module repr."""
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"
|
|
|
|
class Qwen3MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``, all bias-free."""

    def __init__(self, config):
        """
        Args:
            config: object providing `hidden_size`, `intermediate_size` and `hidden_act`
                (a key of `ACT2FN`).
        """
        super().__init__()
        self.config = config
        hidden, inner = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        # Projections are created in the same order as before so parameter registration is unchanged.
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated projection to `x` and map the result back to `hidden_size`."""
        gate = self.act_fn(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
|
|
|
|
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is split at ``x.shape[-1] // 2``; the result is the negated second
    part followed by the first part.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into `cos` and `sin` so they broadcast against `q` and `k`.
            With `cos`/`sin` of shape [batch_size, seq_len, head_dim], use 1 when `q` and `k`
            are [batch_size, heads, seq_len, head_dim] and 2 when they are
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same formula for queries and keys: x * cos + rotate_half(x) * sin.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis.

    Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep): the input goes from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
    When `n_rep` is 1 the input tensor itself is returned.
    """
    # Unpack first so a tensor of the wrong rank is rejected even when n_rep == 1.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads.
    expanded = hidden_states.unsqueeze(2).expand(-1, -1, n_rep, -1, -1)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain matmul/softmax attention.

    Args:
        module: attention module; `num_key_value_groups` and `training` are read from it.
        query: query states; the matmuls below treat axis 1 as heads and axis 2 as positions.
        key: key states, repeated `module.num_key_value_groups` times along the head axis.
        value: value states, repeated the same way.
        attention_mask: optional additive mask; its last axis is truncated to the key length
            before being added to the scores.
        scaling: factor applied to the raw query-key scores.
        dropout: dropout probability applied to the attention probabilities.
        **kwargs: accepted for interface compatibility and ignored here.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has axes 1 and 2 swapped back and is
        made contiguous, and `attn_weights` are the post-dropout probabilities.
    """
    n_groups = module.num_key_value_groups
    keys = repeat_kv(key, n_groups)
    values = repeat_kv(value, n_groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    # Softmax runs in float32 and is cast back to the query dtype afterwards.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return context, probs
|
|
|
|
class Qwen3RotaryEmbedding(nn.Module):
    """Produces the rotary cos/sin tables for given position ids."""

    inv_freq: torch.Tensor  # declared for type checkers; the buffer is registered in __init__

    def __init__(self, config: Qwen3Config, device=None):
        """
        Args:
            config: model configuration; `rope_scaling` (when it is a dict) selects the RoPE
                variant, and `max_position_embeddings` seeds the cached/original lengths.
            device: forwarded to the RoPE init function selected from `ROPE_INIT_FUNCTIONS`.
        """
        super().__init__()

        rope_scaling = getattr(config, "rope_scaling", None)
        if isinstance(rope_scaling, dict):
            # "rope_type" is preferred; "type" is consulted only as a fallback key.
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: the buffer is not written to the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """
        Args:
            x: tensor used only for its device and dtype.
            position_ids: position ids; indexed as `[:, None, :]`, so expected to be 2-D.

        Returns:
            `(cos, sin)` cast to `x.dtype`; the frequency block is duplicated along the last
            axis and both tables are multiplied by `self.attention_scaling`.
        """
        n_rows = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(n_rows, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        raw_device = x.device.type
        autocast_device = raw_device if isinstance(raw_device, str) and raw_device != "mps" else "cpu"
        # Autocast is disabled so the angles are computed in float32.
        with torch.autocast(device_type=autocast_device, enabled=False):
            angles = (inv_freq.float() @ positions.float()).transpose(1, 2)
            table = torch.cat((angles, angles), dim=-1)
            cos = table.cos() * self.attention_scaling
            sin = table.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
class Qwen3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, with three rotary streams.

    Queries and keys are each the concatenation of:

    * a `head_dim`-wide part from `q_proj` / `k_proj`, rotated with `rotary_emb` at `indexes[0]`;
    * two `head_dim // 2`-wide parts from `q_proj_hw` / `k_proj_hw`, rotated with
      `rotary_emb_hw` at `indexes[1]` and `indexes[2]` respectively.

    The concatenated query/key vectors are therefore `2 * head_dim` wide per head, while values
    stay `head_dim` wide.

    Besides the stock fields, `config` must provide `rope_theta_hw` and
    `max_position_embeddings_hw`, which parameterize `rotary_emb_hw`.
    """

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        # NOTE(review): scaling uses head_dim although the concatenated q/k are 2 * head_dim wide;
        # kept as is because changing it would change the model's numerics.
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.q_proj_hw = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )

        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj_hw = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )

        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

        # One norm per rotary stream; the two hw streams are half as wide as the main one.
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.q_norm_h = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps)
        self.q_norm_w = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps)

        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm_h = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps)
        self.k_norm_w = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps)

        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None

        self.rotary_emb = Qwen3RotaryEmbedding(config=config)

        # The hw rotary embedding works on a private copy of the config so the shared config
        # object is never mutated. Use self.head_dim (which already carries the fallback
        # computed above) instead of reading config.head_dim directly.
        hw_config = copy.deepcopy(config)
        hw_config.head_dim = self.head_dim // 2
        hw_config.rope_theta = config.rope_theta_hw
        hw_config.max_position_embeddings = config.max_position_embeddings_hw
        self.rotary_emb_hw = Qwen3RotaryEmbedding(config=hw_config)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        indexes: Optional[torch.LongTensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            hidden_states: input activations; all leading axes are kept and the last one is projected.
            indexes: position ids for the three rotary streams; rows 0, 1 and 2 are each
                unsqueezed to a batch of one and fed to the rotary embeddings. Required.
            attention_mask: optional additive mask passed to the attention function.
            past_key_values: optional cache; when given, it is updated with this layer's
                key/value states and the returned (accumulated) states are attended over.
            cache_position: accepted for interface compatibility; not used here.
            **kwargs: forwarded to the attention function.

        Returns:
            `(attn_output, attn_weights)` with `attn_output` already passed through `o_proj`.

        Raises:
            ValueError: if `indexes` is None.
        """
        assert self.config._attn_implementation == "eager", (
            "Qwen3Attention only supports attn_implementation='eager', "
            f"got {self.config._attn_implementation!r}"
        )
        if indexes is None:
            raise ValueError("`indexes` must be provided: rows 0, 1 and 2 are used as rotary position ids")

        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Main stream: full head_dim per head. hw stream: split into two head_dim // 2 halves.
        query_states_t = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        query_states_h, query_states_w = (
            self.q_proj_hw(hidden_states).view(hidden_shape).transpose(1, 2).chunk(2, dim=-1)
        )
        query_states_h, query_states_w = self.q_norm_h(query_states_h), self.q_norm_w(query_states_w)

        key_states_t = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states_h, key_states_w = (
            self.k_proj_hw(hidden_states).view(hidden_shape).transpose(1, 2).chunk(2, dim=-1)
        )
        key_states_h, key_states_w = self.k_norm_h(key_states_h), self.k_norm_w(key_states_w)

        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Each stream is rotated with its own row of `indexes`.
        cos_t, sin_t = self.rotary_emb(hidden_states, indexes[0].unsqueeze(0))
        query_states_t, key_states_t = apply_rotary_pos_emb(query_states_t, key_states_t, cos_t, sin_t)

        cos_h, sin_h = self.rotary_emb_hw(hidden_states, indexes[1].unsqueeze(0))
        query_states_h, key_states_h = apply_rotary_pos_emb(query_states_h, key_states_h, cos_h, sin_h)

        cos_w, sin_w = self.rotary_emb_hw(hidden_states, indexes[2].unsqueeze(0))
        query_states_w, key_states_w = apply_rotary_pos_emb(query_states_w, key_states_w, cos_w, sin_w)

        query_states = torch.cat([query_states_t, query_states_h, query_states_w], dim=-1)
        key_states = torch.cat([key_states_t, key_states_h, key_states_w], dim=-1)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx, cache_kwargs=None
            )

        # Only reachable with a non-eager implementation when asserts are disabled (python -O).
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
|
|
|
class Qwen3DecoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer block: self-attention and MLP, each wrapped in a residual connection."""

    def __init__(self, config: Qwen3Config, layer_idx: int):
        """
        Args:
            config: model configuration shared by the sub-modules.
            layer_idx: index of this layer; selects its entry in `config.layer_types`.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)

        self.mlp = Qwen3MLP(config)
        norm_eps = config.rms_norm_eps
        self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=norm_eps)
        self.attention_type = config.layer_types[layer_idx]

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        indexes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Run the attention sub-block followed by the MLP sub-block.

        Every argument other than `hidden_states` is handed to `self.self_attn` unchanged;
        the attention weights it returns are discarded.

        Returns:
            The hidden states after both residual updates.
        """
        # Attention sub-block: normalize, attend, add back onto the un-normalized input.
        attn_output, _ = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            indexes=indexes,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block with the same pre-norm / residual pattern.
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_output
|
|
|
|
@auto_docstring
class Qwen3PreTrainedModel(PreTrainedModel):
    config: Qwen3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    # Qwen3Attention.forward asserts `_attn_implementation == "eager"`, so advertising the
    # fused backends was wrong: with them enabled the library could select a non-eager
    # backend at load time and the first forward pass would fail the assertion.
    # Declaring them unsupported makes eager the backend that gets selected.
    _supports_flash_attn = False
    _supports_sdpa = False
    _supports_flex_attn = False

    # NOTE(review): Qwen3Model.forward mutates Python-side state (`current_index`);
    # confirm that full-graph compilation is really supported before relying on this flag.
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Module classes whose outputs are recorded as `hidden_states` / `attentions`.
    _can_record_outputs = {
        "hidden_states": Qwen3DecoderLayer,
        "attentions": Qwen3Attention,
    }
|
|
|
|
@auto_docstring
class Qwen3Model(Qwen3PreTrainedModel):
    # The model is stateful across forward calls: `current_index` holds the last value used
    # for row 0 of `indexes`. A call with `inputs_embeds` sets it to `indexes[0].max()`; a
    # call with `input_ids` increments it by one and uses it as the position of that call.

    def __init__(self, config: Qwen3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types
        # Always a plain Python int; -1 means "no position consumed yet".
        self.current_index = -1

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        indexes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Callers are required to pass these explicitly; the fallbacks further down only
        # come into play when asserts are disabled (python -O).
        assert position_ids is not None, "`position_ids` must be provided"
        assert cache_position is not None, "`cache_position` must be provided"
        assert past_key_values is not None, "`past_key_values` must be provided"

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        if isinstance(attention_mask, dict):
            # The old message claimed the mask was *not* a dict, the opposite of this condition.
            raise NotImplementedError("Passing `attention_mask` as a dict of pre-built masks is not supported")

        if input_ids is not None:
            # Token path: standard causal mask. Any `indexes` passed by the caller is replaced:
            # row 0 is the next running index and rows 1 and 2 are zero.
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            self.current_index += 1
            indexes = torch.tensor(
                [[self.current_index], [0], [0]], dtype=torch.long, device=input_ids.device
            )
        else:
            # Embedding path: the caller supplies `indexes`, and positions sharing a value in
            # row 0 may attend to each other regardless of order.
            if indexes is None:
                raise ValueError("`indexes` must be provided when calling the model with `inputs_embeds`")
            causal_mask_mapping = {
                "full_attention": create_block_causal_mask(indexes[0]),
            }
            # Keep the running index a Python int so later increments and tensor
            # construction do not depend on a stored tensor.
            self.current_index = int(indexes[0].max())

        hidden_states = inputs_embeds

        # NOTE: only a "full_attention" mask is built, so a layer whose attention_type is
        # anything else raises KeyError on the lookup below.
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                indexes=indexes,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
|
|
|
|
@auto_docstring
class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
    # `Qwen3Model` backbone followed by a bias-free linear head projecting hidden states to
    # vocabulary logits.
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        indexes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Qwen3ForCausalLM

        >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # `indexes` and every other backbone argument are passed straight through.
        backbone_out: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            indexes=indexes,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        final_states = backbone_out.last_hidden_state

        # An int keeps the last `logits_to_keep` positions (0 gives slice(0, None), i.e. all
        # of them); any other value is used directly as an index along the sequence axis.
        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.lm_head(final_states[:, keep, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=backbone_out.past_key_values,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
|
|
|
|
class Qwen3ForSequenceClassification(GenericForSequenceClassification, Qwen3PreTrainedModel):
    """Sequence-classification variant; all behavior is inherited from the two base classes."""
|
|
|
|
class Qwen3ForTokenClassification(GenericForTokenClassification, Qwen3PreTrainedModel):
    """Token-classification variant; all behavior is inherited from the two base classes."""
|
|
|
|
class Qwen3ForQuestionAnswering(GenericForQuestionAnswering, Qwen3PreTrainedModel):
    """Question-answering variant; behavior is inherited from the two base classes."""

    # Overrides the "model" prefix declared on Qwen3PreTrainedModel.
    # NOTE(review): presumably kept for compatibility with existing checkpoints; confirm.
    base_model_prefix = "transformer"
|
|
|
|
# Names exported by `from <module> import *`: the base class, the bare model, and the task heads.
__all__ = [
    "Qwen3ForCausalLM",
    "Qwen3ForQuestionAnswering",
    "Qwen3PreTrainedModel",
    "Qwen3Model",
    "Qwen3ForSequenceClassification",
    "Qwen3ForTokenClassification",
]