| """ PyTorch JetMoE model.""" |
|
|
| from typing import List, Optional, Tuple, Union |
| import warnings, math |
|
|
| import torch |
| import torch.utils.checkpoint |
| from torch import nn |
| from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss |
| from torch.nn import functional as F |
|
|
| from transformers.modeling_outputs import ( |
| BaseModelOutputWithPast, |
| CausalLMOutputWithPast, |
| SequenceClassifierOutputWithPast, |
| dataclass |
| ) |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import ( |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| is_flash_attn_2_available, |
| is_flash_attn_greater_or_equal_2_10, |
| replace_return_docstrings, |
| logging |
| ) |
| from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa |
| from transformers.cache_utils import Cache, DynamicCache |
| from .configuration_jetmoe import JetMoEConfig |
| import scattermoe |
|
|
| try: |
| if is_flash_attn_2_available(): |
| from flash_attn import flash_attn_func, flash_attn_varlen_func |
| from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input |
| except ImportError: |
| |
| |
| pass |
|
|
| logger = logging.get_logger(__name__) |
|
|
| _CHECKPOINT_FOR_DOC = "jetmoe" |
| _CONFIG_FOR_DOC = "JetMoEConfig" |
|
|
class top_k_gating(nn.Module):
    """
    Linear router that selects the `top_k` highest-scoring experts for each input row.

    In training mode, `forward` additionally stores an auxiliary loss on `self.loss`
    (a switch-style load-balancing term plus a z-loss on the router logits). In eval
    mode `self.loss` is set to the integer `0`.
    """

    def __init__(
        self,
        input_size,
        num_experts,
        top_k,
    ):
        """
        Initialize the top-k gating mechanism.

        Args:
            input_size (int): Size of the input.
            num_experts (int): Number of experts.
            top_k (int): Number of top experts to select; must not exceed `num_experts`.
        """
        super().__init__()

        self.num_experts = num_experts
        self.input_size = input_size
        assert top_k <= num_experts
        self.top_k = top_k

        # Bias-free linear scorer: one logit per expert.
        self.layer = nn.Linear(input_size, num_experts, bias=False)

    def extra_repr(self):
        """
        Return extra representation string for the module.
        """
        return 'k={}, num_experts={}'.format(
            self.top_k, self.num_experts)

    def compute_aux_loss(self, probs, logits, gates):
        """
        Calculate the auxiliary loss for one routing decision.

        Args:
            probs (torch.Tensor): Softmax over all expert logits, one row per input row.
            logits (torch.Tensor): Raw router logits, same shape as `probs`.
            gates (torch.Tensor): Dense gate matrix that is non-zero only at the selected experts.

        Returns:
            torch.Tensor: `switchloss + 0.1 * zloss` as a scalar tensor.
        """
        count = logits.size(0)
        probs = probs.sum(0)
        freq = (gates > 0).float().sum(0)
        # logsumexp is mathematically identical to log(sum(exp(.))) but does not overflow to
        # inf when a logit is large (exp(100) already exceeds the float32 range).
        lsesq = (torch.logsumexp(logits, dim=-1) ** 2).sum()

        # Load balancing: product of the (L1-normalised) routed probability mass and the
        # (L1-normalised) selection frequency per expert, scaled by the number of experts.
        switchloss = self.num_experts * (
            F.normalize(probs, p=1, dim=0) *
            F.normalize(freq, p=1, dim=0)
        ).sum()
        # z-loss: mean squared log-partition of the router logits.
        zloss = lsesq / count
        loss = switchloss + 0.1 * zloss

        return loss

    def forward(self, x):
        """
        Compute the top-k gating for the input.

        See paper: https://arxiv.org/abs/1701.06538.

        Args:
            x (torch.Tensor): Input tensor with shape [batch_size, input_size].

        Returns:
            torch.Tensor: Top-k expert indices with shape [batch_size, top_k].
            torch.Tensor: Top-k gating values (softmax over the selected logits only),
                same shape as the indices and cast to the dtype of `x`.

        Side effects:
            Sets `self.loss` to the auxiliary loss when training, else to `0`.
        """
        # Routing is always computed in float32, independent of the dtype of `x`.
        logits = self.layer(x).float()
        top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1)
        top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(x)

        if self.training:
            probs = torch.softmax(logits, dim=1)
            # Scatter the selected gates into a dense [batch_size, num_experts] matrix.
            zeros = torch.zeros_like(probs, dtype=top_k_gates.dtype)
            gates = zeros.scatter(1, top_k_indices, top_k_gates)
            self.loss = self.compute_aux_loss(probs, logits, gates)
        else:
            self.loss = 0

        return top_k_indices, top_k_gates
|
|
class MoE(nn.Module):
    """
    A sparsely gated mixture-of-experts layer with feed-forward experts.

    The layer can be used in two ways:

    * `forward(x)` runs router -> input projection -> activation -> output projection.
    * `map(x)` followed by `reduce(y)` runs only the input projection, lets the caller
      transform the per-expert activations, then runs the output projection. `reduce`
      reuses the routing state cached on the module by the preceding `map` call.

    Args:
        input_size: integer - size of the input
        hidden_size: integer - size of each expert's hidden layer
        num_experts: an integer - number of experts
        top_k: an integer - how many experts to use for each token (clamped to `num_experts`)
        bias: a boolean - whether to create a learned output bias (added by `forward`)
        activation: callable applied to the experts' hidden activations; required by `forward`
        glu: a boolean - whether the input projection is doubled and used as a gated linear unit
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        num_experts,
        top_k,
        bias=True,
        activation=None,
        glu=True,
    ):
        super().__init__()

        self.num_experts = num_experts
        self.input_size = input_size
        self.glu = glu
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(input_size))
            torch.nn.init.zeros_(self.bias)
        else:
            self.bias = None
        self.input_linear = scattermoe.parallel_experts.ParallelExperts(num_experts, input_size, hidden_size * 2 if glu else hidden_size)
        self.output_linear = scattermoe.parallel_experts.ParallelExperts(num_experts, hidden_size, input_size)
        self.top_k = min(top_k, self.num_experts)
        self.activation = activation

        # Hand the router the clamped value so it always agrees with `self.top_k`,
        # which the per-expert loops and kernel calls below rely on.
        self.router = top_k_gating(
            input_size=input_size,
            num_experts=num_experts,
            top_k=self.top_k,
        )

    def extra_repr(self):
        return 'k={}, e={}'.format(
            self.top_k, self.num_experts)

    def get_aux_loss_and_clear(self):
        """
        Get the auxiliary loss stored by the most recent routing call and clear it.

        Returns:
            The value of `self.router.loss`, or `0` if the router has not been run yet.
        """
        # The router lives on `self.router` and publishes its loss via the `loss` attribute.
        loss = getattr(self.router, "loss", 0)
        self.router.loss = 0
        return loss

    def compute_gate(self, x):
        """
        Route the flattened input and cache the dispatch indices on the module.

        Args:
            x (Tensor): Input of shape [tokens, input_size].

        Returns:
            The router's auxiliary loss for this call.
        """
        top_k_indices, self.top_k_gates = self.router(x)

        # Index bookkeeping for the scattermoe kernels; no gradient flows through it.
        with torch.no_grad():
            self.sorted_expert_idxs, self.sorted_scattered_idxs = scattermoe.kernels.ops.flatten_and_sort(top_k_indices)
            self.padded_block_idxs, self.expert_offsets = scattermoe.kernels.ops.padded_block_indices(self.sorted_expert_idxs, self.num_experts)

        return self.router.loss

    def batch_forward(self, x):
        """
        Forward pass of the mixture of experts layer for more than one token.

        Args:
            x (Tensor): Input tensor of shape [bsz, length, emb_size].

        Returns:
            Tensor: Output tensor of shape [bsz, length, input_size].
            The router's auxiliary loss.
        """
        bsz, length, emb_size = x.size()
        x = x.reshape(-1, emb_size)

        loss = self.compute_gate(x)

        h = self.input_linear(
            x, self.top_k,
            self.sorted_expert_idxs, self.sorted_scattered_idxs,
            self.padded_block_idxs, self.expert_offsets,
            grouped_out=True
        )

        if self.glu:
            h, g = h.chunk(2, dim=-1)
            h = self.activation(h) * g
        else:
            h = self.activation(h)

        y = self.output_linear(
            h, 1,
            self.sorted_expert_idxs, self.sorted_scattered_idxs,
            self.padded_block_idxs, self.expert_offsets,
            grouped_in=True,
            gates=self.top_k_gates,
        )

        y = y.view(bsz, length, self.input_size)
        if self.bias is not None:
            y = y + self.bias
        return y, loss

    def single_forward(self, x):
        """
        Forward pass for exactly one token, using plain `F.linear` per selected expert.

        Args:
            x (Tensor): Input tensor of shape [bsz, length, emb_size] with bsz * length == 1.

        Returns:
            Tensor: Output tensor of shape [bsz, length, input_size].
            The router's auxiliary loss.
        """
        bsz, length, emb_size = x.size()

        x = x.reshape(1, self.input_size)
        top_k_indices, top_k_gates = self.router(x)
        loss = self.router.loss

        y_list = []
        for i in range(self.top_k):
            expert_idx = top_k_indices[0, i]

            h = F.linear(x, self.input_linear.weight[expert_idx])
            if self.glu:
                h, g = h.chunk(2, dim=-1)
                h = self.activation(h) * g
            else:
                h = self.activation(h)
            y = F.linear(h, self.output_linear.weight[expert_idx]) * top_k_gates[0, i]

            y_list.append(y)

        y = sum(y_list)
        y = y.view(bsz, length, self.input_size)
        if self.bias is not None:
            y = y + self.bias
        return y, loss

    def forward(self, x):
        """
        Forward pass of the mixture of experts layer.

        Dispatches to `single_forward` for a single token and to `batch_forward` otherwise.

        Args:
            x (Tensor): Input tensor of shape [bsz, length, emb_size].

        Returns:
            Tensor: Output tensor.
            The router's auxiliary loss.
        """
        bsz, length, emb_size = x.size()
        if bsz * length == 1:
            return self.single_forward(x)
        else:
            return self.batch_forward(x)

    def batch_map(self, x):
        """
        Map input through the experts' input projection for more than one token.

        Args:
            x (Tensor): Input tensor of shape [bsz, length, emb_size].

        Returns:
            Tensor: Tensor of shape [bsz, length, top_k, -1] with one slice per selected expert.
            The router's auxiliary loss.
        """
        bsz, length, emb_size = x.size()
        x = x.reshape(-1, emb_size)
        loss = self.compute_gate(x)

        y = self.input_linear(
            x, self.top_k,
            self.sorted_expert_idxs, self.sorted_scattered_idxs,
            self.padded_block_idxs, self.expert_offsets,
        )
        y = y.view(bsz, length, self.top_k, -1)
        return y, loss

    def single_map(self, x):
        """
        Map exactly one token through the input projection of each selected expert.

        Caches `self.top_k_indices` and `self.top_k_gates` for the matching `single_reduce`.

        Args:
            x (Tensor): Input tensor of shape [bsz, length, emb_size] with bsz * length == 1.

        Returns:
            Tensor: Tensor of shape [bsz, length, top_k, -1].
            The router's auxiliary loss.
        """
        bsz, length, emb_size = x.size()

        x = x.reshape(1, self.input_size)
        self.top_k_indices, self.top_k_gates = self.router(x)
        loss = self.router.loss

        y_list = []
        for i in range(self.top_k):
            expert_idx = self.top_k_indices[0, i]
            y = F.linear(x, self.input_linear.weight[expert_idx])
            y_list.append(y)
        y = torch.cat(y_list, dim=0)
        y = y.view(bsz, length, self.top_k, -1)
        return y, loss

    def map(self, x):
        """
        Map input through the mixture of experts layer.

        Dispatches to `single_map` for a single token and to `batch_map` otherwise.

        Args:
            x (Tensor): Input tensor of shape [bsz, length, emb_size].

        Returns:
            Tensor: Per-expert projected tensor.
            The router's auxiliary loss.
        """
        bsz, length, emb_size = x.size()
        if bsz * length == 1:
            return self.single_map(x)
        else:
            return self.batch_map(x)

    def batch_reduce(self, x):
        """
        Reduce the mapped output for more than one token.

        Uses the dispatch indices and gates cached by the preceding `batch_map` call.
        NOTE(review): unlike `forward`, this path does not add `self.bias` - confirm
        that this is intended.

        Args:
            x (Tensor): Mapped tensor of shape [bsz, length, top_k, emb_size].

        Returns:
            Tensor: Reduced output tensor of shape [bsz, length, input_size].
        """
        bsz, length, k, emb_size = x.size()
        assert k == self.top_k
        x = x.reshape(-1, emb_size)

        y = self.output_linear(
            x, 1,
            self.sorted_expert_idxs, self.sorted_scattered_idxs,
            self.padded_block_idxs, self.expert_offsets,
            gates=self.top_k_gates,
        )
        y = y.view(bsz, length, self.input_size)
        return y

    def single_reduce(self, x):
        """
        Reduce the mapped output of exactly one token.

        Uses `self.top_k_indices` and `self.top_k_gates` cached by the preceding `single_map`.

        Args:
            x (Tensor): Mapped tensor of shape [bsz, length, top_k, emb_size] with bsz * length == 1.

        Returns:
            Tensor: Reduced output tensor of shape [bsz, length, input_size].
        """
        bsz, length, k, emb_size = x.size()

        x = x.reshape(k, emb_size)

        y_list = []
        for i in range(self.top_k):
            expert_idx = self.top_k_indices[0, i]
            y = F.linear(x[i], self.output_linear.weight[expert_idx]) * self.top_k_gates[0, i]
            y_list.append(y)
        y = sum(y_list)
        y = y.view(bsz, length, self.input_size)
        return y

    def reduce(self, x):
        """
        Reduce the mapped output.

        Dispatches to `single_reduce` for a single token and to `batch_reduce` otherwise.

        Args:
            x (Tensor): Mapped output tensor of shape [bsz, length, top_k, emb_size].

        Returns:
            Tensor: Reduced output tensor.
        """
        bsz, length, k, emb_size = x.size()
        if bsz * length == 1:
            return self.single_reduce(x)
        else:
            return self.batch_reduce(x)
|
|
@dataclass
class JetMoEBaseModelOutputWithPast(BaseModelOutputWithPast):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding),
    extended with an auxiliary loss field.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.

            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        aux_loss (`torch.FloatTensor`, *optional*):
            Auxiliary loss carried alongside the standard outputs; defaults to `None`.
            NOTE(review): presumably the MoE router (load-balancing) loss collected from the layers - confirm
            against the model's forward pass.
    """

    # The first four fields repeat the parent's fields; `aux_loss` is the only addition.
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    aux_loss: Optional[torch.FloatTensor] = None
|
|
|
|
@dataclass
class JetMoECausalLMOutputWithPast(CausalLMOutputWithPast):
    """
    Base class for causal language model (or autoregressive) outputs, extended with an auxiliary loss field.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        aux_loss (`torch.FloatTensor`, *optional*):
            Auxiliary loss carried alongside the standard outputs; defaults to `None`.
            NOTE(review): presumably the MoE router (load-balancing) loss collected from the layers - confirm
            against the model's forward pass.
    """

    # The first five fields repeat the parent's fields; `aux_loss` is the only addition.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    aux_loss: Optional[torch.FloatTensor] = None
|
|
|
|
@dataclass
class JetMoESequenceClassifierOutputWithPast(SequenceClassifierOutputWithPast):
    """
    Base class for outputs of sentence classification models, extended with an auxiliary loss field.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        aux_loss (`torch.FloatTensor`, *optional*):
            Auxiliary loss carried alongside the standard outputs; defaults to `None`.
            NOTE(review): presumably the MoE router (load-balancing) loss collected from the layers - confirm
            against the model's forward pass.
    """

    # The first five fields repeat the parent's fields; `aux_loss` is the only addition.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    aux_loss: Optional[torch.FloatTensor] = None
|
|
|
|
| |
def _get_unpad_data(attention_mask):
    """
    Derive the bookkeeping needed to strip padding from a batch.

    Args:
        attention_mask (torch.Tensor): Mask whose last dimension is summed per row and whose
            non-zero entries mark the positions to keep.

    Returns:
        tuple: `(indices, cu_seqlens, max_seqlen_in_batch)` where `indices` are the flat positions
        of the non-zero mask entries, `cu_seqlens` is the int32 cumulative row length with a
        leading zero, and `max_seqlen_in_batch` is the largest row length as a Python number.
    """
    tokens_per_row = attention_mask.sum(dim=-1, dtype=torch.int32)
    longest_row = tokens_per_row.max().item()

    # Prepend a zero so that entry i is the start offset of row i.
    running_total = torch.cumsum(tokens_per_row, dim=0, dtype=torch.int32)
    cu_seqlens = F.pad(running_total, (1, 0))

    flat_mask = attention_mask.flatten()
    indices = torch.nonzero(flat_mask, as_tuple=False).flatten()

    return indices, cu_seqlens, longest_row
|
|
class JetMoERMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learned per-channel scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Args:
            hidden_size (int): Size of the last (normalised) dimension.
            eps (float): Value added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalise over the last dimension in float32, cast back, then apply the scale."""
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalised = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)
|
|
|
|
| |
class JetMoERotaryEmbedding(nn.Module):
    """Precomputes and serves the cos/sin tables used for rotary position embeddings."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        """
        Args:
            dim (int): Rotary dimension; the tables have this many columns.
            max_position_embeddings (int): Number of positions to precompute up front.
            base (float): Base of the geometric frequency progression.
            device: Device on which the inverse frequencies are created.
        """
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        # Non-persistent: derived from the constructor arguments, so not stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build the initial cache here so the tables exist before the first forward call.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions `0 .. seq_len - 1`."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # The frequencies are duplicated so that the two halves of the last dimension share
        # angles, matching the half-split layout used by `rotate_half`.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """
        Return the cos/sin tables for the first `seq_len` positions.

        Args:
            x (torch.Tensor): Only its `device` and `dtype` are used.
            seq_len (int, optional): Number of positions required. `None` returns every
                position currently cached.

        Returns:
            Tuple of `(cos, sin)`, each of shape `[seq_len, dim]` and cast to `x.dtype`.
        """
        # The signature advertises `seq_len=None`; treat it as "everything cached" instead of
        # failing on the `None > int` comparison below.
        if seq_len is None:
            seq_len = self.max_seq_len_cached

        # Grow the cache on demand when a longer sequence than ever seen before arrives.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
|
|
|
|
| |
def rotate_half(x):
    """Swap the two halves of the last dimension and negate the half that moves to the front."""
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)
|
|
|
|
| |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=2):
    """Rotate the query and key tensors with rotary position embeddings.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): Cosine table, indexed by position along its first dimension.
        sin (`torch.Tensor`): Sine table, indexed by position along its first dimension.
        position_ids (`torch.Tensor`):
            Positions of the tokens held in `q` and `k`; used to index `cos` and `sin`. Offset
            position ids can be passed here when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 2):
            Dimension inserted into `cos[position_ids]` and `sin[position_ids]` so that they
            broadcast against `q` and `k`. Use 2 when `q`/`k` are laid out as
            `[batch_size, seq_len, heads, head_dim]` and 1 when they are laid out as
            `[batch_size, heads, seq_len, head_dim]`.
    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    cos_at_pos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_at_pos = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same rotation for queries and keys: x * cos + rotate_half(x) * sin.
        return (states * cos_at_pos) + (rotate_half(states) * sin_at_pos)

    return _rotate(q), _rotate(k)
|
|
|
|
class JetMoEAttention(nn.Module):
    """
    Eager (matmul + softmax) attention whose query and output projections are routed through a
    mixture of experts, while a single dense key/value projection is shared by all experts.

    Each token is routed to `top_k` experts by `self.experts`; every selected expert contributes
    its own set of query heads, so the effective number of query heads is
    `num_key_value_heads * top_k`. Keys and values are tiled `top_k` times to match.
    """

    def __init__(self, config: JetMoEConfig, layer_idx: Optional[int] = None):
        """
        Initialize the JetMoEAttention module.

        Args:
            config: Configuration object with model hyperparameters. Reads `moe_top_k`,
                `moe_num_experts`, `kv_channels`, `num_attention_heads`, `hidden_size`,
                `max_position_embeddings` and `rope_theta`.
            layer_idx: Index of this layer in the key/value cache. Required whenever a cache is
                passed to `forward`.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.is_causal = True
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.top_k = config.moe_top_k

        # Width of one expert's query projection (and of the key and value projections).
        self.kv_projection_size = config.kv_channels * config.num_attention_heads
        self.num_key_value_heads = config.num_attention_heads
        # One full set of query heads per selected expert.
        self.num_heads = self.num_key_value_heads * self.top_k
        self.hidden_size_per_attention_head = config.kv_channels

        # Only `map` (query projection) and `reduce` (output projection) of this MoE are used
        # by the forward pass, so no activation is supplied.
        self.experts = MoE(
            input_size=config.hidden_size,
            hidden_size=self.kv_projection_size,
            num_experts=config.moe_num_experts,
            top_k=config.moe_top_k,
            glu=False
        )

        # Keys and values come from one fused dense projection and are split in `forward`.
        self.kv_proj = torch.nn.Linear(
            config.hidden_size, self.kv_projection_size * 2, bias=False
        )

        self.rotary_emb = JetMoERotaryEmbedding(
            config.kv_channels,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache], Union[torch.Tensor, int]]:
        """
        Run attention over `hidden_states`.

        Args:
            hidden_states: Input of shape `(bsz, q_len, hidden)`.
            attention_mask: Optional additive mask; must have shape `(bsz, 1, q_len, kv_seq_len)`.
            position_ids: Positions used to index the rotary cos/sin tables.
            past_key_value: Optional cache; updated in place with this step's keys and values.
            output_attentions: Whether to return the attention probabilities.
            use_cache: Accepted for interface compatibility; not read in this method.
            **kwargs: Only inspected for the deprecated `padding_mask` key.

        Returns:
            `(attn_output, attn_weights, past_key_value, aux_loss)`, where `attn_weights` is `None`
            unless `output_attentions` is set and `aux_loss` is whatever `self.experts.map`
            reported for this call.

        Raises:
            ValueError: If a cache is passed but `layer_idx` is `None`, or if the attention
                weights, the mask or the attention output have an unexpected shape.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        # Queries come from the routed experts, keys/values from the shared dense projection.
        query_states, aux_loss = self.experts.map(hidden_states)
        key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1)

        # Split into heads and move the head dimension in front of the sequence dimension.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.hidden_size_per_attention_head).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.hidden_size_per_attention_head).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.hidden_size_per_attention_head).transpose(1, 2)

        kv_seq_len = key_states.shape[2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        # `value_states` is passed only so the rotary module can read its dtype and device.
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids, unsqueeze_dim=1)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}
            # The cache stores the un-tiled keys/values; tiling happens after the update.
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Tile keys/values along the head dimension so every expert's query heads have a match.
        key_states = key_states.repeat(1, self.top_k, 1, 1)
        value_states = value_states.repeat(1, self.top_k, 1, 1)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.hidden_size_per_attention_head)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

            attn_weights = attn_weights + attention_mask

        # Softmax is computed in float32 and cast back to the query dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.hidden_size_per_attention_head):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.hidden_size_per_attention_head)}, but is"
                f" {attn_output.size()}"
            )

        # Regroup the heads per expert: (bsz, q_len, top_k, kv_projection_size).
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.top_k, self.kv_projection_size)

        # Output projection through the same experts that `map` selected above.
        attn_output = self.experts.reduce(attn_output)
        attn_output = attn_output.view(bsz, q_len, -1)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value, aux_loss
| |
|
|
| |
class JetMoESdpaAttention(JetMoEAttention):
    """
    JetMoE attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `JetMoEAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache], Union[torch.Tensor, int]]:
        """
        Run attention over `hidden_states` with the fused SDPA kernel.

        Args:
            hidden_states: Input of shape `(bsz, q_len, hidden)`.
            attention_mask: Optional mask; must have shape `(bsz, 1, q_len, kv_seq_len)`.
            position_ids: Positions used to index the rotary cos/sin tables.
            past_key_value: Optional cache; updated in place with this step's keys and values.
            output_attentions: If set, falls back to the eager implementation of the parent class.
            use_cache: Accepted for interface compatibility; not read in this method.
            **kwargs: Accepted so the signature matches the parent class; only forwarded on the
                eager fallback path.

        Returns:
            `(attn_output, None, past_key_value, aux_loss)`; attention probabilities are never
            returned by the SDPA path.

        Raises:
            ValueError: If a cache is passed but `layer_idx` is `None`, or if the mask has an
                unexpected shape.
        """
        if output_attentions:
            # SDPA cannot return attention probabilities, so defer to the eager implementation.
            logger.warning_once(
                "JetMoEModel is using JetMoESdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                **kwargs,
            )

        bsz, q_len, _ = hidden_states.size()

        # Queries come from the routed experts, keys/values from the shared dense projection.
        query_states, aux_loss = self.experts.map(hidden_states)
        key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.hidden_size_per_attention_head).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.hidden_size_per_attention_head).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.hidden_size_per_attention_head).transpose(1, 2)

        kv_seq_len = key_states.shape[2]
        if past_key_value is not None:
            # Same guard as the eager and flash-attention variants: the cache is indexed by
            # layer, so a missing index must fail with a clear message.
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        # `value_states` is passed only so the rotary module can read its dtype and device.
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids, unsqueeze_dim=1)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}
            # The cache stores the un-tiled keys/values; tiling happens after the update.
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Tile keys/values along the head dimension so every expert's query heads have a match.
        key_states = key_states.repeat(1, self.top_k, 1, 1)
        value_states = value_states.repeat(1, self.top_k, 1, 1)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

        # Make the inputs contiguous on CUDA whenever a custom mask is supplied.
        # NOTE(review): presumably a workaround for an SDPA kernel issue with non-contiguous
        # inputs - confirm against the upstream Transformers implementation.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=0.0,
            # Let the kernel build the causal mask only when no explicit mask is given and
            # there is more than one query position.
            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )

        # Regroup the heads per expert: (bsz, q_len, top_k, kv_projection_size).
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.top_k, self.kv_projection_size)

        # Output projection through the same experts that `map` selected above.
        attn_output = self.experts.reduce(attn_output)
        attn_output = attn_output.view(bsz, q_len, -1)

        return attn_output, None, past_key_value, aux_loss
|
|
|
|
| class JetMoEFlashAttention2(JetMoEAttention): |
    def __init__(self, *args, **kwargs):
        """
        Build the attention module and record a flash-attn version flag.

        All positional and keyword arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(*args, **kwargs)

        # True when the installed flash-attn is NOT version 2.10 or newer.
        # NOTE(review): presumably older releases align the causal mask to the top-left corner
        # and need special handling where this flag is consumed - confirm against the
        # flash-attention forward helper, which is outside this view.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
|
| def forward( |
| self, |
| hidden_states: Optional[torch.FloatTensor], |
| attention_mask: Optional[torch.FloatTensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_value: Optional[Cache] = None, |
| use_cache: Optional[bool] = False, |
| output_attentions: Optional[bool] = False, |
| **kwargs, |
| ) -> Union[ |
| Tuple[torch.Tensor, Tuple[torch.Tensor]], |
| Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], |
| ]: |
| """ |
| Forward pass of the JetMoEAttention module. |
| |
| Args: |
| hidden_states (Optional[torch.FloatTensor]): Input hidden states. |
| attention_mask (Optional[torch.FloatTensor]): Attention mask. |
| layer_past (Optional[Tuple[torch.Tensor]]): Past layer state. |
| use_cache (Optional[bool]): Whether to use cached states. |
| output_attentions (Optional[bool]): Whether to output attention weights. |
| |
| Returns: |
| Union[Tuple[torch.Tensor, Tuple[torch.Tensor]], Optional[Tuple[...]]]: Tuple containing outputs. |
| """ |
| |
| assert output_attentions is False, "output_attentions is not supported" |
|
|
| B, T, C = hidden_states.size() |
|
|
| |
| query_layer, aux_loss = self.experts.map(hidden_states) |
| key_layer, value_layer = self.kv_proj(hidden_states).chunk(2, dim=-1) |
|
|
| query_layer = query_layer.view(B, T, self.num_heads, self.hidden_size_per_attention_head) |
| key_layer = key_layer.view(B, T, self.num_key_value_heads, self.hidden_size_per_attention_head) |
| value_layer = value_layer.view(B, T, self.num_key_value_heads, self.hidden_size_per_attention_head) |
|
|
| kv_seq_len = key_layer.shape[1] |
| if past_key_value is not None: |
| if self.layer_idx is None: |
| raise ValueError( |
| f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " |
| "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " |
| "with a layer index." |
| ) |
| kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) |
| cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) |
| query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) |
|
|
| |
| |
| key_layer = key_layer.repeat(1, 1, self.top_k, 1) |
| value_layer = value_layer.repeat(1, 1, self.top_k, 1) |
|
|
| if past_key_value is not None: |
| cache_kwargs = {"sin": sin, "cos": cos} |
| |
| key_layer = key_layer.transpose(1, 2) |
| value_layer = value_layer.transpose(1, 2) |
| key_layer, value_layer = past_key_value.update(key_layer, value_layer, self.layer_idx, cache_kwargs) |
| key_layer = key_layer.transpose(1, 2) |
| value_layer = value_layer.transpose(1, 2) |
|
|
| context_layer = self._flash_attention_forward( |
| query_layer, |
| key_layer, |
| value_layer, |
| attention_mask, |
| T, |
| ) |
|
|
| |
| y = self.experts.reduce(context_layer.reshape(T, B, self.top_k, self.kv_projection_size)) |
| y = y.view(B, T, C) |
|
|
| if not output_attentions: |
| attn_weights = None |
|
|
| return y, attn_weights, past_key_value, aux_loss |
|
|
| def _flash_attention_forward( |
| self, |
| query_states, |
| key_states, |
| value_states, |
| attention_mask, |
| query_length, |
| dropout=0.0, |
| softmax_scale=None, |
| ): |
| """ |
| Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token |
| first unpad the input, then computes the attention scores and pad the final attention scores. |
| |
| Args: |
| query_states (`torch.Tensor`): |
| Input query states to be passed to Flash Attention API |
| key_states (`torch.Tensor`): |
| Input key states to be passed to Flash Attention API |
| value_states (`torch.Tensor`): |
| Input value states to be passed to Flash Attention API |
| attention_mask (`torch.Tensor`): |
| The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the |
| position of padding tokens and 1 for the position of non-padding tokens. |
| dropout (`float`): |
| Attention dropout |
| softmax_scale (`float`, *optional*): |
| The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) |
| """ |
| if not self._flash_attn_uses_top_left_mask: |
| causal = self.is_causal |
| else: |
| |
| causal = self.is_causal and query_length != 1 |
|
|
| |
| if attention_mask is not None: |
| batch_size = query_states.shape[0] |
| query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( |
| query_states, key_states, value_states, attention_mask, query_length |
| ) |
|
|
| cu_seqlens_q, cu_seqlens_k = cu_seq_lens |
| max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens |
|
|
| attn_output_unpad = flash_attn_varlen_func( |
| query_states, |
| key_states, |
| value_states, |
| cu_seqlens_q=cu_seqlens_q, |
| cu_seqlens_k=cu_seqlens_k, |
| max_seqlen_q=max_seqlen_in_batch_q, |
| max_seqlen_k=max_seqlen_in_batch_k, |
| dropout_p=dropout, |
| softmax_scale=softmax_scale, |
| causal=causal, |
| ) |
|
|
| attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) |
| else: |
| attn_output = flash_attn_func( |
| query_states, |
| key_states, |
| value_states, |
| dropout, |
| softmax_scale=softmax_scale, |
| causal=causal |
| ) |
|
|
| return attn_output |
|
|
|
|
| def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): |
| indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) |
| batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape |
|
|
| key_layer = index_first_axis( |
| key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k |
| ) |
| value_layer = index_first_axis( |
| value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k |
| ) |
| if query_length == kv_seq_len: |
| query_layer = index_first_axis( |
| query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k |
| ) |
| cu_seqlens_q = cu_seqlens_k |
| max_seqlen_in_batch_q = max_seqlen_in_batch_k |
| indices_q = indices_k |
| elif query_length == 1: |
| max_seqlen_in_batch_q = 1 |
| cu_seqlens_q = torch.arange( |
| batch_size + 1, dtype=torch.int32, device=query_layer.device |
| ) |
| indices_q = cu_seqlens_q[:-1] |
| query_layer = query_layer.squeeze(1) |
| else: |
| |
| attention_mask = attention_mask[:, -query_length:] |
| query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) |
|
|
| return ( |
| query_layer, |
| key_layer, |
| value_layer, |
| indices_q, |
| (cu_seqlens_q, cu_seqlens_k), |
| (max_seqlen_in_batch_q, max_seqlen_in_batch_k), |
| ) |
|
|
|
|
# Registry mapping an attention-implementation key (looked up with `config._attn_implementation`
# in `JetMoEBlock`) to the attention class that implements it.
JETMOE_ATTENTION_CLASSES = {
    "eager": JetMoEAttention,
    "flash_attention_2": JetMoEFlashAttention2,
    "sdpa": JetMoESdpaAttention,
}
|
|
|
|
class JetMoEBlock(nn.Module):
    """
    A single JetMoE decoder layer.

    The layer applies self-attention and then a mixture-of-experts MLP. Each sub-layer reads an
    RMS-normalised copy of the running hidden states and its output is added back as a residual.
    """

    def __init__(self, config: JetMoEConfig, layer_idx: Optional[int] = None):
        """
        Build the sub-modules of the layer.

        Args:
            config: Configuration object with model hyperparameters.
            layer_idx: Index of this layer, handed to the attention module.
        """
        super().__init__()

        self.input_layernorm = JetMoERMSNorm(config.hidden_size)

        # The attention implementation is chosen by the key stored on the config.
        attention_cls = JETMOE_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attention = attention_cls(config, layer_idx)

        self.post_attention_layernorm = JetMoERMSNorm(config.hidden_size)

        self.mlp = MoE(
            input_size=config.hidden_size,
            hidden_size=config.ffn_hidden_size,
            num_experts=config.moe_num_experts,
            activation=F.silu,
            top_k=config.moe_top_k,
            glu=config.glu,
        )

    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor],
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        """
        Run the layer on `hidden_states`.

        Args:
            hidden_states (Optional[torch.FloatTensor]): Input hidden states.
            position_ids (Optional[torch.LongTensor]): Position indices passed to the attention module.
            past_key_value: Key/value cache passed to (and returned by) the attention module.
            attention_mask (Optional[torch.FloatTensor]): Attention mask passed to the attention module.
            output_attentions (Optional[bool]): Whether to include the attention weights in the result.
            use_cache (Optional[bool]): Whether to include the key/value cache in the result.

        Returns:
            A tuple laid out as `(hidden_states, [attn_weights], [present_key_value], aux_loss)`.
            The bracketed entries are present only when `output_attentions` / `use_cache` are truthy;
            the final entry is the sum of the attention and MLP auxiliary losses.
        """
        # --- attention sub-layer -------------------------------------------------------------
        normed_input = self.input_layernorm(hidden_states)
        attn_output, self_attn_weights, present_key_value, att_aux_loss = self.self_attention(
            hidden_states=normed_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        after_attention = hidden_states + attn_output

        # --- mixture-of-experts sub-layer ----------------------------------------------------
        x_mlp, mlp_aux_loss = self.mlp(self.post_attention_layernorm(after_attention))
        layer_output = after_attention + x_mlp

        # Assemble the result; optional entries sit between the hidden states and the aux loss.
        result = [layer_output]
        if output_attentions:
            result.append(self_attn_weights)
        if use_cache:
            result.append(present_key_value)
        result.append(att_aux_loss + mlp_aux_loss)

        return tuple(result)
|
|
|
|
|
|
class JetMoEPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = JetMoEConfig
    # Name of the attribute under which head models keep the bare backbone. The head classes in this
    # file (`JetMoEForCausalLM`, `JetMoEForSequenceClassification`) store it as `self.model`, so the
    # prefix has to be "model" for prefix-based matching of base/head checkpoint keys to work.
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["JetMoEBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        """
        Initialize the JetMoEPreTrainedModel.

        Args:
            *inputs: Variable length input arguments, forwarded to `PreTrainedModel.__init__`.
            **kwargs: Keyword arguments, forwarded to `PreTrainedModel.__init__`.
        """
        super().__init__(*inputs, **kwargs)

        # Read by `JetMoEModel.forward` to decide whether layers run through the checkpointing function.
        self.gradient_checkpointing = False

    def _init_weights(self, module):
        """
        Initialize the weights of one sub-module.

        `nn.Linear` and `nn.Embedding` weights are drawn from a normal distribution with standard
        deviation `config.initializer_range`; linear biases and the embedding row at `padding_idx`
        are zeroed. `nn.LayerNorm` gets zero bias and unit weight. Other module types are left untouched.
        """
        if isinstance(module, (nn.Linear,)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# Shared documentation fragments. They are attached to the model classes and their `forward`
# methods through the `add_start_docstrings*` decorators used further down in this file.
MODULEFORMER_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`JetMoEConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


# `{0}` is a placeholder for the input shape (e.g. "batch_size, sequence_length").
MODULEFORMER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_attention_heads,)` or `(n_layer, num_attention_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_dim)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
|
|
|
@add_start_docstrings(
    "The bare JetMoE Model outputting raw hidden-states without any specific head on top.",
    MODULEFORMER_START_DOCSTRING,
)
class JetMoEModel(JetMoEPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`JetMoEBlock`]

    Args:
        config: JetMoEConfig
    """

    def __init__(self, config: JetMoEConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [JetMoEBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # Cached so `forward` can pick the matching attention-mask preparation.
        self._attn_implementation = config._attn_implementation
        self.norm = JetMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module with `value`."""
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(MODULEFORMER_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Returns a `JetMoEBaseModelOutputWithPast` (which also carries the summed per-layer
        # `aux_loss`) when `return_dict` is true. Otherwise returns a tuple of the non-None values
        # among (hidden_states, cache, all_hidden_states, all_attentions); the tuple form does not
        # include the auxiliary loss.

        # Flags left as None fall back to the values stored on the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of `input_ids` / `inputs_embeds` must be provided.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        past_key_values_length = 0

        if use_cache:
            # Legacy (tuple-of-tuples) caches are wrapped in a `DynamicCache` for the duration of the
            # call and converted back before returning.
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            # Default positions continue from the end of the cached prefix.
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
            # If any row has a 0 in its last mask column, the batch is treated as right-padded.
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right'"
                    " this may lead to unexpected behaviour for Flash Attention version of JetMoE. Make sure to "
                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                )

        if self._attn_implementation == "flash_attention_2":
            # Flash attention takes the 2D mask as-is, and only when it actually masks something.
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._attn_implementation == "sdpa" and not output_attentions:
            # SDPA path: build the 4D causal mask with the SDPA-specific helper.
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )
        else:
            # Every other case gets the generic 4D causal mask.
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )

        hidden_states = inputs_embeds

        # Accumulators for the optional outputs.
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        aux_loss = 0
        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional order follows `JetMoEBlock.forward`:
                # hidden_states, position_ids, past_key_value, attention_mask, output_attentions, use_cache.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer,
                    hidden_states,
                    position_ids,
                    past_key_values,
                    attention_mask,
                    output_attentions,
                    use_cache,
                    use_reentrant=False,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache follows the attention weights when those are returned as well.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            # The last entry of every layer output is that layer's auxiliary loss.
            aux_loss += layer_outputs[-1]

        hidden_states = self.norm(hidden_states)

        # Add the hidden states produced by the final norm.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return JetMoEBaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            aux_loss=aux_loss,
        )
|
|
|
|
class JetMoEForCausalLM(JetMoEPreTrainedModel):
    """JetMoE backbone (`self.model`) with a bias-free linear language-modeling head (`self.lm_head`)."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = JetMoEModel(config)
        self.vocab_size = config.vocab_size
        # Weight of the auxiliary loss in the training loss; 0.01 when the config does not define it.
        self.aux_loss_coef = getattr(config, 'aux_loss_coef', 0.01)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module with `value`."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modeling head with `new_embeddings`."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the backbone model with `decoder`."""
        self.model = decoder

    def get_decoder(self):
        """Return the backbone model."""
        return self.model

    @add_start_docstrings_to_model_forward(MODULEFORMER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The backbone is always asked for its dataclass output: its tuple form does not carry the
        # auxiliary loss, which is needed below regardless of the caller's `return_dict` choice.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Move the labels next to the logits before computing the loss
            shift_labels = shift_labels.to(shift_logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits, shift_labels)

            # While training, add the weighted auxiliary loss. This happens before the tuple return
            # below so that both output formats yield the same loss.
            if self.model.training:
                loss = loss + self.aux_loss_coef * outputs.aux_loss.to(loss.device)

        if not return_dict:
            # Same layout the backbone's tuple output has after its first element:
            # the non-None values among (cache, hidden states, attentions).
            extras = tuple(
                v for v in (outputs.past_key_values, outputs.hidden_states, outputs.attentions) if v is not None
            )
            output = (logits,) + extras
            return (loss,) + output if loss is not None else output

        return JetMoECausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            aux_loss=outputs.aux_loss,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """
        Build the keyword arguments for one generation step.

        Trims `input_ids` to the tokens not yet covered by the cache, crops `attention_mask` when the
        cache is about to exceed its maximum length, and derives `position_ids` from the mask when
        none are supplied through `kwargs`.

        Returns:
            Dict with either `inputs_embeds` or `input_ids`, plus `position_ids`, `past_key_values`,
            `use_cache` and `attention_mask`.
        """
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                # Legacy tuple cache: length is read from axis 2 of the first layer's key tensor.
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - the mask is longer than `input_ids`: part of the input lives only in the cache, so
            #     keep the trailing tokens not yet accounted for by `past_length`;
            # 2 - `past_length` is shorter than `input_ids`: drop the already-processed prefix;
            # 3 - otherwise `input_ids` is taken to hold only unprocessed tokens.
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]

            # About to go beyond the maximum cache length: crop the attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Positions are a running count of unmasked tokens; masked slots get the filler value 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # `inputs_embeds` are only used when no cache exists yet, i.e. on the first step.
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Return a tuple cache whose tensors are re-indexed along axis 0 by `beam_idx`, layer by layer."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
|
|
|
|
@add_start_docstrings(
    """
    The JetMoE Model transformer with a sequence classification head on top (linear layer).

    [`JetMoEForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    MODULEFORMER_START_DOCSTRING,
)
class JetMoEForSequenceClassification(JetMoEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = JetMoEModel(config)
        # Bias-free projection from the hidden size to one logit per label.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module with `value`."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(MODULEFORMER_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Per-token logits; one row per sequence is selected below.
        logits = self.score(transformer_outputs[0])

        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]

        pad_token_id = self.config.pad_token_id
        if pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if pad_token_id is not None and input_ids is not None:
            # Index just before the first pad token. For a row without padding argmax yields 0,
            # so the modulo wraps the resulting -1 around to the final position.
            first_pad = torch.eq(input_ids, pad_token_id).int().argmax(-1)
            sequence_lengths = (first_pad - 1) % input_ids.shape[-1]
            sequence_lengths = sequence_lengths.to(logits.device)
        else:
            # No pad token configured, or only embeddings given: use the last position of each row.
            sequence_lengths = -1

        row_index = torch.arange(batch_size, device=logits.device)
        pooled_logits = logits[row_index, sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            # Decide the problem type once and remember it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(pooled_logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return output if loss is None else ((loss,) + output)

        return JetMoESequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            aux_loss=transformer_outputs.aux_loss,
        )