# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/longcat/modular_longcat.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_longcat.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from functools import partial
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.processing_utils import Unpack
from transformers.utils import logging
from transformers.cache_utils import DynamicCache, StaticCache
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    LossKwargs,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    can_return_tuple,
    is_torch_flex_attn_available,
    replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg

from .configuration_longcat import LongcatConfig


if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from transformers.integrations.flex_attention import make_flex_block_causal_mask


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LongcatConfig"


@use_kernel_forward_from_hub("RMSNorm")
class LongcatRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LongcatRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class LongcatRotaryEmbedding(nn.Module):
    def __init__(self, config: LongcatConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class LongcatMLP(nn.Module):
    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
        self.intermediate_size = config.ffn_hidden_size if intermediate_size is None else intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class LongcatTopkRouter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.moe_topk
        self.n_routed_experts = (
            config.n_routed_experts
            if config.zero_expert_num is None
            else config.n_routed_experts + config.zero_expert_num
        )
        self.routed_scaling_factor = config.routed_scaling_factor
        self.norm_topk_prob = config.norm_topk_prob
        self.router_bias = config.router_bias

        self.classifier = nn.Linear(config.hidden_size, self.n_routed_experts, bias=self.router_bias)
        self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts)))

    @torch.no_grad()
    def get_topk_indices(self, scores):
        scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
        return topk_indices

    def forward(self, hidden_states):
        hidden_states = hidden_states.view(
            -1, self.config.hidden_size
        )  # hidden_states: [batchsize*seq_len, hidden_size]
        router_logits = F.linear(hidden_states.type(torch.float32), self.classifier.weight.type(torch.float32))
        scores = router_logits.softmax(dim=-1)
        topk_indices = self.get_topk_indices(scores)
        topk_weights = scores.gather(1, topk_indices)
        if self.norm_topk_prob:
            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights /= denominator
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights


class LongcatMoE(nn.Module):
    """
    A mixture-of-experts module.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList(
            [
                LongcatMLP(config, intermediate_size=config.expert_ffn_hidden_size)
                for _ in range(config.n_routed_experts)
            ]
        )
        self.router = LongcatTopkRouter(config)
        self.zero_expert_num = config.zero_expert_num
        self.zero_expert_type = config.zero_expert_type

    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
        r"""
        CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
        to not have to do a loop here (deepseek has 256 experts soooo yeah).
        """
        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
        total_experts = len(self.experts) if self.zero_expert_num is None else len(self.experts) + self.zero_expert_num
        expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=total_experts)  # (T, K, E)
        expert_mask = expert_mask.permute(2, 0, 1)  # (E, T, K)

        for expert_idx in range(total_experts):
            expert = self.experts[expert_idx] if expert_idx < len(self.experts) else None
            mask = expert_mask[expert_idx]
            token_indices, weight_indices = torch.where(mask)

            if token_indices.numel() > 0:
                expert_weights = topk_weights[token_indices, weight_indices]
                expert_input = hidden_states[token_indices]  # (T, H) --> (N, H)
                if self.zero_expert_num is None or expert_idx