# SlimMoE-250M-instruct / modeling_slim_moe.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel, GenerationMixin
from transformers.utils import logging
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_slim_moe import SlimMoEConfig
from .slim_moe_transformer import SlimMOETransformer
logger = logging.get_logger(__name__)
# AutoConfig.register('slim_moe', SlimMoEConfig)
# CONFIG_MAPPING.register("slim_moe", SlimMoEConfig)
class SlimMoEModel(PreTrainedModel):
    config_class = SlimMoEConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SlimMoETransformerBlock"]

    def _init_weights(self, module):
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
# MODEL_MAPPING.register(SlimMoEConfig, SlimMoEModel)
class SlimMoEForCausalLM(SlimMoEModel, GenerationMixin):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = SlimMOETransformer(
            vocab_size=config.vocab_size,
            dim=config.dim,
            num_layers=config.num_hidden_layers,
            num_heads=config.num_heads,
            hidden_dim=config.hidden_dim,
            num_experts=config.num_experts,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            adaptive_routing=getattr(config, "adaptive_routing", True),
        )
        # Define the lm_head at the top level of this model so PreTrainedModel
        # utilities (weight tying, embedding resizing) can find it.
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
        # Initialize weights and apply final processing, then tie the output
        # projection to the input embeddings.
        self.post_init()
        self.lm_head.weight = self.transformer.token_embedding.weight
        self._dynamic_tied_weights_keys = ["lm_head.weight", "transformer.token_embedding.weight"]
        # Most recent MoE auxiliary loss, exposed for logging.
        self.aux_loss = 0.0
        # Auxiliary loss coefficient (can be modified after initialization).
        self.aux_loss_coefficient = getattr(config, "aux_loss_coefficient", 0.01)
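    # A quick sanity check of the tying above (a sketch, not part of the model):
    # after __init__, lm_head and the token embedding share one storage, e.g.
    #   m = SlimMoEForCausalLM(config)
    #   assert m.lm_head.weight.data_ptr() == m.transformer.token_embedding.weight.data_ptr()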
    @classmethod
    def from_pretrained_with_tokenizer(cls, model_path: str, tokenizer_path: str = None):
        """
        Load a pretrained model and, optionally, a custom tokenizer.

        Args:
            model_path: Path to the pretrained model.
            tokenizer_path: Path to a custom tokenizer (if None, the tokenizer
                is loaded from ``model_path``).

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        from transformers import AutoTokenizer

        model = cls.from_pretrained(model_path, trust_remote_code=True)
        if tokenizer_path:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            # Warn when the tokenizer and model vocabularies disagree.
            if tokenizer.vocab_size != model.config.vocab_size:
                logger.warning(
                    f"Tokenizer vocab size ({tokenizer.vocab_size}) != "
                    f"model vocab size ({model.config.vocab_size}); "
                    "consider retraining the model with a matching vocab size."
                )
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
        return model, tokenizer
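    # Usage sketch (the repo id below is an assumption; substitute your own):
    #   model, tok = SlimMoEForCausalLM.from_pretrained_with_tokenizer(
    #       "SlimFactory/SlimMoE-250M-instruct")
    #   ids = tok("Hello", return_tensors="pt")
    #   out = model.generate(**ids, max_new_tokens=20)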
    def get_input_embeddings(self):
        return self.transformer.token_embedding

    def set_input_embeddings(self, value):
        self.transformer.token_embedding = value

    def get_output_embeddings(self):
        # Return the top-level lm_head (tied to the input embeddings).
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings
    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        # 1. Get hidden states from the base transformer.
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = transformer_outputs["last_hidden_state"]

        # 2. Project hidden states to logits.
        logits = self.lm_head(hidden_states)

        # 3. Compute the loss if labels are provided.
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
            # Add the auxiliary load-balancing loss from the MoE layers.
            if self.training:
                aux_loss = transformer_outputs["aux_loss"]
                # Store aux_loss for logging (accessible via model.aux_loss).
                self.aux_loss = aux_loss.item() if isinstance(aux_loss, torch.Tensor) else aux_loss
                loss = loss + self.aux_loss_coefficient * aux_loss
            else:
                self.aux_loss = 0.0

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )
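    # Training-step sketch showing how the auxiliary loss surfaces (assumed
    # usage; `batch` is a hypothetical LongTensor of token ids):
    #   out = model(input_ids=batch, labels=batch)
    #   out.loss.backward()   # includes aux_loss_coefficient * aux_loss
    #   print("moe aux loss:", model.aux_loss)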
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {
            "input_ids": input_ids,
            "attention_mask": kwargs.get("attention_mask"),
        }
# AutoModelForCausalLM.register(SlimMoEConfig, SlimMoEForCausalLM)
# MODEL_FOR_CAUSAL_LM_MAPPING.register(SlimMoEConfig, SlimMoEForCausalLM)
def create_moe_causal_lm(vocab_size: int = 50257):
    """
    Create a SlimMoEForCausalLM model with approximately 250M parameters.

    Returns a full CausalLM model (not just the transformer), built from the
    ``for_300m`` config preset.
    """
    from .configuration_slim_moe import SlimMoEConfig

    config = SlimMoEConfig.for_300m(vocab_size=vocab_size)
    model = SlimMoEForCausalLM(config)
    return model
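

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: exercises create_moe_causal_lm() with
    # random token ids. Because this file uses package-relative imports, run it
    # as a module (e.g. `python -m <package>.modeling_slim_moe`).
    model = create_moe_causal_lm(vocab_size=50257)
    model.train()
    ids = torch.randint(0, model.config.vocab_size, (1, 16))
    out = model(input_ids=ids, labels=ids)
    print("logits shape:", tuple(out.logits.shape), "loss:", float(out.loss))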