# MiMo-7B-MTPs / modeling_mimo.py
# Uploaded by bwshen-mi ("Upload folder using huggingface_hub")
# commit 75dbb4d (verified)
from typing import Optional, Tuple
import torch
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen2.modeling_qwen2 import (Qwen2Attention,
Qwen2MLP,
Qwen2RMSNorm)
from .configuration_mimo import MiMoConfig
class MiMoMTPLayers(nn.Module):
    """One multi-token-prediction (MTP) transformer layer.

    Fuses the (normalized) incoming token embeddings with the (normalized)
    hidden states from the previous stage through a linear projection, then
    applies a standard pre-norm Qwen2 attention + MLP block and a final
    RMSNorm on the output.
    """

    def __init__(self, config):
        super().__init__()
        dim, eps = config.hidden_size, config.rms_norm_eps
        # NOTE: submodules are created in this exact order so parameter
        # initialization consumes the RNG in the same sequence as before.
        self.input_layernorm = Qwen2RMSNorm(dim, eps=eps)
        self.post_attention_layernorm = Qwen2RMSNorm(dim, eps=eps)
        self.token_layernorm = Qwen2RMSNorm(dim, eps=eps)
        self.hidden_layernorm = Qwen2RMSNorm(dim, eps=eps)
        # Projects the concatenated [hidden, embed] pair back to hidden_size.
        self.input_proj = nn.Linear(dim * 2, dim, bias=False)
        self.final_layernorm = Qwen2RMSNorm(dim, eps=eps)
        self.self_attn = Qwen2Attention(config, layer_idx=0)
        self.mlp = Qwen2MLP(config)

    def forward(self, input_embeds,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_values: Optional[Cache] = None,
                output_attentions: Optional[bool] = False,
                use_cache: Optional[bool] = False,
                position_embedding: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                cache_position=None,
                **kwargs):
        """Run one MTP stage and return the final-normalized hidden states.

        ``input_embeds`` are the token embeddings for this stage and
        ``hidden_states`` the output of the previous stage; both are
        RMS-normalized, concatenated (hidden first, embeds second), and
        projected back down to ``hidden_size``.
        """
        # Fuse previous hidden state with the current token embedding.
        fused = self.input_proj(
            torch.cat(
                [self.hidden_layernorm(hidden_states),
                 self.token_layernorm(input_embeds)],
                dim=-1,
            )
        )
        # Pre-norm self-attention with a residual connection.
        # NOTE(review): the kwarg names `past_key_values` / `position_embedding`
        # (singular) must match the pinned transformers version's
        # Qwen2Attention.forward signature — newer releases spell them
        # `past_key_value` / `position_embeddings`; confirm before upgrading.
        attn_out, _ = self.self_attn(self.input_layernorm(fused),
                                     attention_mask=attention_mask,
                                     position_ids=position_ids,
                                     past_key_values=past_key_values,
                                     output_attentions=output_attentions,
                                     use_cache=use_cache,
                                     cache_position=cache_position,
                                     position_embedding=position_embedding,
                                     **kwargs)
        fused = fused + attn_out
        # Pre-norm MLP with a residual connection, then the output norm.
        fused = fused + self.mlp(self.post_attention_layernorm(fused))
        return self.final_layernorm(fused)
class MiMoMTPBlock(PreTrainedModel):
    """Container holding the stack of MTP layers.

    Slot 0 of ``mtp_layers`` is an ``nn.Identity`` placeholder and the
    remaining ``num_nextn_predict_layers - 1`` slots hold real
    ``MiMoMTPLayers`` — presumably so indices line up with the checkpoint's
    layer numbering; confirm against the uploaded weights.
    """

    config_class = MiMoConfig

    def __init__(self, config: MiMoConfig):
        super().__init__(config)
        layers = [nn.Identity()]
        layers.extend(MiMoMTPLayers(config)
                      for _ in range(config.num_nextn_predict_layers - 1))
        self.mtp_layers = nn.ModuleList(layers)
class MiMoMTPModel(PreTrainedModel):
    """Top-level MTP model wrapper.

    Exposes the layer stack under the ``model`` attribute — presumably so
    parameter names match the checkpoint layout (``model.mtp_layers...``);
    verify against the saved state dict.
    """

    config_class = MiMoConfig

    def __init__(self, config: MiMoConfig):
        super().__init__(config)
        self.model = MiMoMTPBlock(config)