| | import torch
|
| | import torch.nn as nn
|
| | from transformers import PreTrainedModel, GenerationMixin
|
| | from transformers.utils import logging
|
| | from transformers.modeling_outputs import CausalLMOutputWithPast
|
| |
|
| | from .configuration_slim_moe import SlimMoEConfig
|
| | from .slim_moe_transformer import SlimMOETransformer
|
| |
|
| | logger = logging.get_logger(__name__)
|
| |
|
| |
|
| |
|
| |
|
| |
|
class SlimMoEModel(PreTrainedModel):
    """HuggingFace base wrapper for SlimMoE models.

    Declares the config class, the prefix under which the core transformer
    lives, gradient-checkpointing support, and the module types that must not
    be split across devices. Weight initialization follows the usual
    normal(0, initializer_range) scheme for linear/embedding layers and the
    identity scheme (ones/zeros) for LayerNorm.
    """

    config_class = SlimMoEConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SlimMoETransformerBlock"]

    def _init_weights(self, module):
        """Initialize one submodule; invoked by ``post_init`` over the tree.

        Args:
            module: The submodule being initialized.
        """
        # Fall back to the conventional 0.02 when the config does not
        # declare an explicit initializer range.
        std = getattr(self.config, 'initializer_range', 0.02)

        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
        elif isinstance(module, nn.LayerNorm):
            nn.init.zeros_(module.bias)
            nn.init.ones_(module.weight)
|
| |
|
| |
|
| |
|
class SlimMoEForCausalLM(SlimMoEModel, GenerationMixin):
    """Causal language model: SlimMoE transformer trunk plus a tied LM head.

    The output projection (``lm_head``) shares its weight tensor with the
    token embedding of the transformer. During training, the transformer's
    auxiliary (router load-balancing) loss is added to the cross-entropy
    loss, scaled by ``config.aux_loss_coefficient`` (default 0.01).
    """

    def __init__(self, config):
        """Build the transformer trunk and the tied LM head from ``config``."""
        super().__init__(config)

        self.transformer = SlimMOETransformer(
            vocab_size=config.vocab_size,
            dim=config.dim,
            num_layers=config.num_hidden_layers,
            num_heads=config.num_heads,
            hidden_dim=config.hidden_dim,
            num_experts=config.num_experts,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            adaptive_routing=getattr(config, 'adaptive_routing', True)
        )

        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        self.post_init()

        # Tie the LM head to the token embedding AFTER post_init so the
        # generic weight initialization cannot overwrite the shared tensor.
        self.lm_head.weight = self.transformer.token_embedding.weight

        # Tell HF serialization these two parameters are one tensor, so
        # save/load does not duplicate or de-tie them.
        self._dynamic_tied_weights_keys = ['lm_head.weight', 'transformer.token_embedding.weight']

        # Most recent auxiliary loss value, kept as a plain float for
        # logging/monitoring (0.0 outside training steps).
        self.aux_loss = 0.0

        # Weight applied to the router load-balancing loss during training.
        self.aux_loss_coefficient = getattr(config, 'aux_loss_coefficient', 0.01)

    @classmethod
    def from_pretrained_with_tokenizer(cls, model_path: str, tokenizer_path: str = None):
        """
        Load model from pretrained and optionally use a custom tokenizer.

        Args:
            model_path: Path to the pretrained model
            tokenizer_path: Path to custom tokenizer (if None, uses default)

        Returns:
            model, tokenizer tuple
        """
        from transformers import AutoTokenizer

        model = cls.from_pretrained(model_path, trust_remote_code=True)

        if tokenizer_path:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

            # A vocab mismatch does not prevent loading, but generation will
            # produce garbage for out-of-range ids — warn through the module
            # logger (consistent with the rest of the file) instead of print.
            if tokenizer.vocab_size != model.config.vocab_size:
                logger.warning(
                    "Tokenizer vocab size (%d) != model vocab size (%d); "
                    "consider retraining model with matching vocab size",
                    tokenizer.vocab_size,
                    model.config.vocab_size,
                )
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_path)

        return model, tokenizer

    def get_input_embeddings(self):
        """Return the token embedding module (HF resize/tie hook)."""
        return self.transformer.token_embedding

    def set_input_embeddings(self, value):
        """Replace the token embedding module (HF resize/tie hook)."""
        self.transformer.token_embedding = value

    def get_output_embeddings(self):
        """Return the LM head (shares weight with the input embedding)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head module (HF resize/tie hook)."""
        self.lm_head = new_embeddings

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the trunk, project to vocab logits, optionally compute loss.

        Args:
            input_ids: Token id tensor fed to the transformer.
            attention_mask: Optional mask forwarded to the transformer.
            labels: Optional target ids; when given, a shifted next-token
                cross-entropy loss is computed (plus aux loss in training).
            **kwargs: Ignored; accepted for GenerationMixin compatibility.

        Returns:
            CausalLMOutputWithPast with ``loss`` (or None) and ``logits``.
        """
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # NOTE(review): the trunk is assumed to return a dict-like with
        # 'last_hidden_state' and (at least in training) 'aux_loss' —
        # confirm against SlimMOETransformer.forward.
        hidden_states = transformer_outputs['last_hidden_state']

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard causal shift: position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))

            if self.training:
                aux_loss = transformer_outputs['aux_loss']

                # Cache a float snapshot for monitoring without keeping the
                # graph alive; the tensor itself is added to the loss so the
                # router still receives gradients.
                self.aux_loss = aux_loss.item() if isinstance(aux_loss, torch.Tensor) else aux_loss
                loss = loss + self.aux_loss_coefficient * aux_loss
            else:
                self.aux_loss = 0.0

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Minimal generation hook: no KV cache, so pass the full sequence."""
        return {
            "input_ids": input_ids,
            "attention_mask": kwargs.get("attention_mask"),
        }
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def create_moe_causal_lm(vocab_size: int = 50257):
    """Create a SlimMoEForCausalLM using the ``for_300m`` preset config.

    Args:
        vocab_size: Vocabulary size for the embedding and LM head
            (default 50257, the GPT-2 vocabulary).

    Returns:
        A full SlimMoEForCausalLM model (not just the transformer trunk).

    NOTE(review): the preset is named ``for_300m`` but earlier docs called
    this "~250M parameters" — confirm the actual parameter count.
    """
    # SlimMoEConfig is already imported at module level; no local re-import.
    config = SlimMoEConfig.for_300m(vocab_size=vocab_size)
    model = SlimMoEForCausalLM(config)

    return model
|
| |
|
| |
|
| |
|