| from typing import Optional, Tuple | |
| import torch | |
| from torch import nn | |
| from transformers.cache_utils import Cache | |
| from transformers.modeling_utils import PreTrainedModel | |
| from transformers.models.qwen2.modeling_qwen2 import (Qwen2Attention, | |
| Qwen2MLP, | |
| Qwen2RMSNorm) | |
| from .configuration_mimo import MiMoConfig | |
class MiMoMTPLayers(nn.Module):
    """One multi-token-prediction (MTP) layer assembled from Qwen2 components.

    The layer RMS-normalizes the incoming token embeddings and the previous
    hidden states separately and concatenates them as ``[hidden, embeds]`` on
    the last dimension. It projects ``2 * hidden_size -> hidden_size``, then
    runs a pre-norm self-attention + MLP residual block followed by a final
    RMSNorm.
    """

    def __init__(self, config):
        super().__init__()
        # Construction order is kept stable: it determines parameter
        # registration order (state_dict layout) and RNG consumption at init.
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.token_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hidden_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Fuses the concatenated [hidden_states, input_embeds] back to hidden_size.
        self.input_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
        self.final_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # NOTE(review): layer_idx=0 is hard-coded. If the cache handed to
        # forward() is shared with the main decoder, this layer's K/V would use
        # the same cache slot as decoder layer 0 — confirm callers pass a
        # dedicated cache.
        self.self_attn = Qwen2Attention(config, layer_idx=0)
        self.mlp = Qwen2MLP(config)

    def forward(self, input_embeds,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_values: Optional[Cache] = None,
                output_attentions: Optional[bool] = False,
                use_cache: Optional[bool] = False,
                position_embedding: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                cache_position=None,
                **kwargs):
        """Run the MTP layer and return its output hidden states.

        Args:
            input_embeds: Token embeddings. They are normalized by
                ``token_layernorm``.
            hidden_states: Hidden states from the previous stage. They are
                normalized by ``hidden_layernorm``. Both inputs are concatenated
                on the last dimension, so their leading dimensions must match.
            attention_mask, position_ids, past_key_values, output_attentions,
            use_cache, cache_position: Forwarded unchanged to ``self_attn``.
            position_embedding: ``(cos, sin)`` rotary embeddings, forwarded to
                ``self_attn`` as ``position_embeddings``. The plural keyword
                ``position_embeddings`` is also accepted and is used when this
                argument is ``None``.
            **kwargs: Extra keyword arguments forwarded to ``self_attn``.

        Returns:
            Hidden states after the attention/MLP residual block and
            ``final_layernorm``. Attention weights from ``self_attn`` are
            discarded.
        """
        # Callers following the Qwen2 decoder-layer convention pass the rotary
        # pair as ``position_embeddings``. Pop it so it is not forwarded to
        # self_attn a second time through **kwargs.
        plural_position_embeddings = kwargs.pop("position_embeddings", None)
        if position_embedding is None:
            position_embedding = plural_position_embeddings

        input_embeds = self.token_layernorm(input_embeds)
        previous_hidden_states = self.hidden_layernorm(hidden_states)
        hidden_states = self.input_proj(torch.cat([previous_hidden_states, input_embeds], dim=-1))

        # Self-attention sub-block (pre-norm, residual).
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # NOTE(review): the cache is passed as ``past_key_values``. In
        # transformers releases whose Qwen2Attention.forward names it
        # ``past_key_value``, it would land in **kwargs instead — confirm
        # against the pinned transformers version.
        hidden_states, _ = self.self_attn(hidden_states,
                                          attention_mask=attention_mask,
                                          position_ids=position_ids,
                                          past_key_values=past_key_values,
                                          output_attentions=output_attentions,
                                          use_cache=use_cache,
                                          cache_position=cache_position,
                                          # Qwen2Attention takes the rotary pair as
                                          # ``position_embeddings``. Any other keyword
                                          # name would be swallowed by its **kwargs.
                                          position_embeddings=position_embedding,
                                          **kwargs)
        hidden_states = residual + hidden_states

        # MLP sub-block (pre-norm, residual).
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return self.final_layernorm(hidden_states)
class MiMoMTPBlock(PreTrainedModel):
    """Container that holds the stack of MTP layers in ``mtp_layers``.

    ``mtp_layers[0]`` is a parameter-free ``nn.Identity`` placeholder.
    Indices ``1 .. num_nextn_predict_layers - 1`` hold ``MiMoMTPLayers``
    instances, so the real layers' parameters are registered starting at
    ``mtp_layers.1``. No ``forward`` is defined on this class.
    """

    config_class = MiMoConfig

    def __init__(self, config: MiMoConfig):
        super().__init__(config)
        # NOTE(review): the Identity in slot 0 carries no weights, presumably
        # so parameter names line up with a checkpoint whose first MTP layer
        # lives elsewhere — confirm against the weights being loaded.
        stack = [nn.Identity()]
        stack.extend(
            MiMoMTPLayers(config)
            for _ in range(config.num_nextn_predict_layers - 1)
        )
        self.mtp_layers = nn.ModuleList(stack)
class MiMoMTPModel(PreTrainedModel):
    """Top-level wrapper that exposes a ``MiMoMTPBlock`` as ``self.model``.

    Parameters therefore appear under the ``model.mtp_layers.<i>.*`` prefix.
    """

    config_class = MiMoConfig

    def __init__(self, config: MiMoConfig):
        super().__init__(config)
        # NOTE(review): self.post_init() is not called here, unlike most
        # transformers models — confirm that weight init/tying is handled by
        # the loading path.
        mtp_block = MiMoMTPBlock(config)
        self.model = mtp_block