| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """PyTorch Qwen3 model with shared expert support.""" |
|
|
| from typing import List, Optional, Union |
|
|
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
|
|
| from transformers.modeling_outputs import ( |
| MoeCausalLMOutputWithPast, |
| MoeModelOutputWithPast, |
| ) |
| from transformers.activations import ACT2FN |
| from transformers.utils import logging |
| from transformers.models.mixtral.modeling_mixtral import ( |
| load_balancing_loss_func, |
| ) |
| from transformers.models.qwen3_moe.modeling_qwen3_moe import ( |
| Qwen3MoeMLP, |
| Qwen3MoeRMSNorm, |
| Qwen3MoeAttention, |
| Qwen3MoeDecoderLayer, |
| Qwen3MoeModel, |
| Qwen3MoeForCausalLM, |
| ) |
| from .configuration_qwen3_shared_moe import Qwen3SharedMoeConfig |
|
|
| import scattermoe |
|
|
|
|
# Module-level logger obtained from transformers' logging utilities.
# NOTE(review): not referenced by any code visible in this file section.
logger = logging.get_logger(__name__)
|
|
|
|
class Qwen3SharedMoeSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts feed-forward block with an optional shared expert.

    A bias-free linear router (``self.gate``) scores every routed expert. Each
    token is dispatched to its top ``config.num_experts_per_tok`` experts, which
    are executed together by ``scattermoe.mlp.GLUMLP`` (``self.moe_mlp``). When
    ``config.shared_expert_intermediate_size`` is not ``None``, a dense
    ``Qwen3MoeMLP`` (``self.shared_expert``) is also applied to every token. Its
    output is added, without any gating, to the routed-expert output.
    """

    def __init__(self, config: Qwen3SharedMoeConfig):
        super().__init__()
        self.config = config

        # Router: one logit per routed expert, no bias.
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

        # Optional dense expert applied to all tokens; None disables it.
        shared_size = config.shared_expert_intermediate_size
        self.shared_expert = (
            None
            if shared_size is None
            else Qwen3MoeMLP(config, intermediate_size=shared_size)
        )

        # All routed experts fused into a single ScatterMoE gated-linear MLP.
        self.moe_mlp = scattermoe.mlp.GLUMLP(
            input_size=config.hidden_size,
            hidden_size=config.moe_intermediate_size,
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            activation=ACT2FN[config.hidden_act],
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Mix each token through its selected experts, plus the shared expert.

        Args:
            hidden_states: 3-D tensor, unpacked as ``(batch, seq_len, hidden)``.

        Returns:
            Despite the annotation, returns a tuple ``(output, router_logits)``.
            ``output`` has the same shape as ``hidden_states``. ``router_logits``
            holds the raw router scores with shape ``(batch * seq_len, num_experts)``.
        """
        batch, seq_len, width = hidden_states.shape
        tokens = hidden_states.view(-1, width)

        router_logits = self.gate(tokens)

        # Softmax in float32 for stability, then keep the k best experts per token.
        probs = F.softmax(router_logits, dim=-1, dtype=torch.float)
        top_weights, top_experts = torch.topk(
            probs, self.config.num_experts_per_tok, dim=-1
        )
        if self.config.norm_topk_prob:
            # Rescale so each token's selected expert weights sum to 1.
            top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
        top_weights = top_weights.to(tokens.dtype)

        out = self.moe_mlp(tokens, top_weights, top_experts)
        if self.shared_expert is not None:
            out = out + self.shared_expert(tokens)

        return out.reshape(batch, seq_len, width), router_logits
|
|
|
|
class Qwen3SharedMoeDecoderLayer(Qwen3MoeDecoderLayer, nn.Module):
    """Qwen3-MoE decoder layer whose sparse MLP is ``Qwen3SharedMoeSparseMoeBlock``.

    No ``forward`` is defined here, so it is inherited from
    ``Qwen3MoeDecoderLayer``. Only construction differs: after the parent
    initialiser runs, the attention, the MLP and both RMSNorms are (re)built so
    that sparse layers use the shared-expert / ScatterMoE block.
    """

    def __init__(self, config: Qwen3SharedMoeConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # NOTE(review): the parent initialiser presumably already creates these
        # submodules (including a stock MoE block); the assignments below
        # replace them.
        self.hidden_size = config.hidden_size

        self.self_attn = Qwen3MoeAttention(config, layer_idx)

        # Use the sparse MoE block only when this layer is not listed in
        # `mlp_only_layers`, experts exist, and (layer_idx + 1) is a multiple
        # of `decoder_sparse_step`. Otherwise use a dense MLP.
        use_sparse_block = (
            layer_idx not in config.mlp_only_layers
            and config.num_experts > 0
            and (layer_idx + 1) % config.decoder_sparse_step == 0
        )
        self.mlp = (
            Qwen3SharedMoeSparseMoeBlock(config)
            if use_sparse_block
            else Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)
        )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = Qwen3MoeRMSNorm(
            config.hidden_size, eps=norm_eps
        )
|
|
|
|
class Qwen3SharedMoeModel(Qwen3MoeModel):
    """Qwen3-MoE backbone whose decoder stack uses ``Qwen3SharedMoeDecoderLayer``.

    Everything except the ``layers`` ModuleList comes from ``Qwen3MoeModel``.
    """

    config_class = Qwen3SharedMoeConfig

    def __init__(self, config: Qwen3SharedMoeConfig):
        super().__init__(config)
        # Replace the decoder stack with shared-expert layers, one per
        # `config.num_hidden_layers`.
        self.layers = nn.ModuleList(
            [
                Qwen3SharedMoeDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        # transformers models are expected to call `post_init()` once all
        # submodules exist. The layers above are created after the parent
        # constructor returned, so run it again. Otherwise the replacement
        # modules keep PyTorch's default init instead of the model's own
        # weight initialisation.
        self.post_init()
|
|
|
|
class Qwen3SharedMoeForCausalLM(Qwen3MoeForCausalLM):
    """Causal language model head on top of ``Qwen3SharedMoeModel``.

    Reuses ``Qwen3MoeForCausalLM`` but swaps the backbone for the shared-expert
    variant and redefines ``forward``.
    """

    config_class = Qwen3SharedMoeConfig

    def __init__(self, config):
        super().__init__(config)
        # Swap in the shared-expert backbone.
        self.model = Qwen3SharedMoeModel(config)
        self.num_experts = config.num_experts
        # transformers models call `post_init()` as the last step of
        # `__init__`. Because `self.model` is replaced after the parent
        # constructor returned, re-run it so weight initialisation and
        # input/output embedding tying (when `tie_word_embeddings` is set)
        # apply to the new backbone rather than the discarded one.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:
            `MoeCausalLMOutputWithPast` with `logits`, the language-modeling `loss` (only when `labels` is given),
            the load-balancing `aux_loss` (only when router logits are requested; it is also added to `loss`,
            scaled by `router_aux_loss_coef`, when `labels` is given), and the backbone's `past_key_values`,
            `hidden_states`, `attentions` and `router_logits`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer

        >>> model = Qwen3SharedMoeForCausalLM.from_pretrained("path/to/qwen3-shared-moe-checkpoint")
        >>> tokenizer = AutoTokenizer.from_pretrained("path/to/qwen3-shared-moe-checkpoint")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> text = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        ```"""
        # Fall back to config defaults for unspecified output flags.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_router_logits = (
            output_router_logits
            if output_router_logits is not None
            else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )

        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # An int keeps the last `logits_to_keep` positions (0 -> slice(0, None),
        # i.e. all). A tensor is used directly as sequence indices.
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int)
            else logits_to_keep
        )
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            # NOTE(review): `num_experts_per_tok` and `router_aux_loss_coef`
            # are not set in this class; presumably the parent constructor
            # defines them — confirm against the installed transformers.
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                # Aux loss may live on another device under model parallelism.
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
|
|
|
|