import torch
import torch.nn as nn
from transformers import PreTrainedModel, GenerationMixin
from transformers.utils import logging
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_slim_moe import SlimMoEConfig
from .slim_moe_transformer import SlimMOETransformer

logger = logging.get_logger(__name__)

# AutoConfig.register('slim_moe', SlimMoEConfig)
# CONFIG_MAPPING.register("slim_moe", SlimMoEConfig)


class SlimMoEModel(PreTrainedModel):
    """Base Hugging Face wrapper for the SlimMoE architecture.

    Provides the config class, weight-initialization scheme, and the
    metadata (``base_model_prefix``, ``_no_split_modules``) that the
    transformers loading/sharding machinery relies on.
    """

    config_class = SlimMoEConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SlimMoETransformerBlock"]

    def _init_weights(self, module):
        """Initialize weights for a single submodule.

        Called by ``post_init()`` for every module in the model:
        - Linear: normal(0, std) weights, zero bias.
        - Embedding: normal(0, std) weights.
        - LayerNorm: zero bias, unit weight.

        ``std`` falls back to 0.02 when the config lacks ``initializer_range``.
        """
        std = self.config.initializer_range if hasattr(self.config, 'initializer_range') else 0.02
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)


# MODEL_MAPPING.register(SlimMoEConfig, SlimMoEModel)


class SlimMoEForCausalLM(SlimMoEModel, GenerationMixin):
    """Causal language model head on top of ``SlimMOETransformer``.

    Ties ``lm_head.weight`` to the token embedding and adds the MoE
    auxiliary (load-balancing) loss to the LM loss during training.
    """

    def __init__(self, config):
        super().__init__(config)
        self.transformer = SlimMOETransformer(
            vocab_size=config.vocab_size,
            dim=config.dim,
            num_layers=config.num_hidden_layers,
            num_heads=config.num_heads,
            hidden_dim=config.hidden_dim,
            num_experts=config.num_experts,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            adaptive_routing=getattr(config, 'adaptive_routing', True),
        )
        # --- FIX: Define the lm_head at the top level of this model ---
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
        # Initialize weights and apply final processing (including weight tying)
        self.post_init()
        # Tie output projection to the input embedding; re-tied after
        # post_init so the shared tensor survives re-initialization.
        self.lm_head.weight = self.transformer.token_embedding.weight
        self._dynamic_tied_weights_keys = ['lm_head.weight', 'transformer.token_embedding.weight']
        # Initialize aux_loss for logging
        self.aux_loss = 0.0
        # Auxiliary loss coefficient (can be modified after initialization)
        self.aux_loss_coefficient = getattr(config, 'aux_loss_coefficient', 0.01)

    @classmethod
    def from_pretrained_with_tokenizer(cls, model_path: str, tokenizer_path: str = None):
        """
        Load model from pretrained and optionally use a custom tokenizer.

        Args:
            model_path: Path to the pretrained model
            tokenizer_path: Path to custom tokenizer (if None, uses default)

        Returns:
            model, tokenizer tuple
        """
        from transformers import AutoTokenizer

        model = cls.from_pretrained(model_path, trust_remote_code=True)
        if tokenizer_path:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            # Update vocab size if needed
            if tokenizer.vocab_size != model.config.vocab_size:
                print(f"Warning: Tokenizer vocab size ({tokenizer.vocab_size}) != "
                      f"model vocab size ({model.config.vocab_size})")
                print("   Consider retraining model with matching vocab size")
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
        return model, tokenizer

    def get_input_embeddings(self):
        return self.transformer.token_embedding

    def set_input_embeddings(self, value):
        self.transformer.token_embedding = value

    def get_output_embeddings(self):
        # --- FIX: Return the top-level lm_head ---
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        # --- FIX: Set the top-level lm_head ---
        self.lm_head = new_embeddings

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the transformer and project to vocabulary logits.

        When ``labels`` is given, computes the shifted next-token
        cross-entropy loss; during training the MoE auxiliary loss
        (scaled by ``aux_loss_coefficient``) is added on top, and the
        raw value is stored on ``self.aux_loss`` for logging.
        """
        # 1. Get hidden states from the base transformer
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = transformer_outputs['last_hidden_state']

        # 2. Project hidden states to logits
        logits = self.lm_head(hidden_states)

        # 3. Calculate loss if labels are provided
        loss = None
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from position t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))

            # Add auxiliary loss from MOE layers
            if self.training:
                aux_loss = transformer_outputs['aux_loss']
                # Store aux_loss for logging (accessible via model.aux_loss)
                self.aux_loss = aux_loss.item() if isinstance(aux_loss, torch.Tensor) else aux_loss
                loss = loss + self.aux_loss_coefficient * aux_loss
            else:
                self.aux_loss = 0.0

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV-cache support: the full input_ids sequence is re-fed on
        # every generation step.
        return {
            "input_ids": input_ids,
            "attention_mask": kwargs.get("attention_mask"),
        }


# AutoModelForCausalLM.register(SlimMoEConfig, SlimMoEForCausalLM)
# MODEL_FOR_CAUSAL_LM_MAPPING.register(SlimMoEConfig, SlimMoEForCausalLM)


def create_moe_causal_lm(vocab_size: int = 50257):
    """
    Create a SlimMoEForCausalLM model with approximately 250M parameters.

    Returns a full CausalLM model (not just the transformer) configured
    for ~250M params.
    """
    # NOTE(review): docstring says ~250M params but this calls for_300m —
    # confirm which target size is intended.
    config = SlimMoEConfig.for_300m(vocab_size=vocab_size)
    model = SlimMoEForCausalLM(config)
    return model