from typing import Optional, Tuple

import torch
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen2.modeling_qwen2 import (Qwen2Attention, Qwen2MLP,
                                                      Qwen2RMSNorm)

from .configuration_mimo import MiMoConfig


class MiMoMTPLayers(nn.Module):
    """One multi-token-prediction (MTP) layer.

    Fuses the (normalized) token embeddings with the (normalized) hidden states
    coming from the main trunk via a linear projection over their concatenation,
    then applies a single Qwen2-style transformer sub-block
    (pre-norm self-attention + pre-norm MLP, both with residual connections)
    and a final RMSNorm.
    """

    def __init__(self, config):
        super().__init__()
        # Pre-attention / pre-MLP norms of the standard Qwen2 block.
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Separate norms for the two fusion inputs: token embeddings and the
        # hidden states produced by the previous stage.
        self.token_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hidden_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Projects concat([prev_hidden, token_embeds]) (2*H) back down to H.
        self.input_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
        self.final_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # layer_idx=0: each MTP layer owns its single attention layer, so its
        # cache slot index is always 0 within this module.
        self.self_attn = Qwen2Attention(config, layer_idx=0)
        self.mlp = Qwen2MLP(config)

    def forward(self,
                input_embeds,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_values: Optional[Cache] = None,
                output_attentions: Optional[bool] = False,
                use_cache: Optional[bool] = False,
                position_embedding: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                cache_position=None,
                **kwargs):
        """Run one MTP step.

        Args:
            input_embeds: token embeddings for the predicted positions.
            hidden_states: hidden states from the previous stage / trunk.
            attention_mask, position_ids, past_key_values, output_attentions,
            use_cache, position_embedding, cache_position: forwarded to
            ``Qwen2Attention``.

        Returns:
            The transformed hidden states after the final RMSNorm.
        """
        # Normalize both fusion inputs independently, then project the
        # concatenation back to hidden_size.
        input_embeds = self.token_layernorm(input_embeds)
        previous_hidden_states = self.hidden_layernorm(hidden_states)
        hidden_states = self.input_proj(
            torch.cat([previous_hidden_states, input_embeds], dim=-1))

        # --- Self-attention sub-block (pre-norm + residual) ---
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # NOTE(review): upstream Qwen2Attention.forward in recent transformers
        # releases spells this argument ``position_embeddings`` (plural); this
        # file passes ``position_embedding`` — confirm against the pinned
        # transformers version that this lands where intended.
        hidden_states, _ = self.self_attn(hidden_states,
                                          attention_mask=attention_mask,
                                          position_ids=position_ids,
                                          past_key_values=past_key_values,
                                          output_attentions=output_attentions,
                                          use_cache=use_cache,
                                          cache_position=cache_position,
                                          position_embedding=position_embedding,
                                          **kwargs)
        hidden_states = residual + hidden_states

        # --- MLP sub-block (pre-norm + residual) ---
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states


class MiMoMTPBlock(PreTrainedModel):
    """Container for the stack of MTP layers.

    The first entry is ``nn.Identity()``; only ``num_nextn_predict_layers - 1``
    real ``MiMoMTPLayers`` follow it — presumably so that layer indices in the
    checkpoint line up with the prediction-step index (index 0 is a no-op
    placeholder). TODO(review): confirm against the released checkpoint layout.
    """
    config_class = MiMoConfig

    def __init__(self, config: MiMoConfig):
        super().__init__(config)
        self.mtp_layers = nn.ModuleList(
            [nn.Identity()]
            + [MiMoMTPLayers(config)
               for _ in range(config.num_nextn_predict_layers - 1)]
        )


class MiMoMTPModel(PreTrainedModel):
    """Top-level MTP model wrapper holding a single ``MiMoMTPBlock``."""
    config_class = MiMoConfig

    def __init__(self, config: MiMoConfig):
        super().__init__(config)
        self.model = MiMoMTPBlock(config)