|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from transformers import PreTrainedModel, GenerationMixin
|
|
|
from transformers.utils import logging
|
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
|
|
|
|
from .configuration_slim_moe import SlimMoEConfig
|
|
|
from .slim_moe_transformer import SlimMOETransformer
|
|
|
|
|
|
# Module-level logger (transformers' wrapper around stdlib logging).
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SlimMoEModel(PreTrainedModel):
    """Base ``PreTrainedModel`` for SlimMoE checkpoints.

    Supplies the config class, serialization prefix, device-sharding hints,
    and the per-module weight initialization used by ``post_init()`` /
    ``from_pretrained``.
    """

    config_class = SlimMoEConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    # Blocks that must stay whole when accelerate shards the model across devices.
    _no_split_modules = ["SlimMoETransformerBlock"]

    def _init_weights(self, module):
        """Initialize the weights of a single submodule.

        Called by the HF machinery (via ``post_init``) for every submodule.
        Linear/Embedding weights get a normal init with std taken from
        ``config.initializer_range`` (0.02 when the config lacks the field);
        LayerNorm is reset to identity (weight=1, bias=0).
        """
        std = getattr(self.config, 'initializer_range', 0.02)
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        elif isinstance(module, nn.LayerNorm):
            # Guard against LayerNorm layers without learnable parameters
            # (elementwise_affine=False, or bias=False on torch >= 2.1);
            # the previous unconditional access crashed on module.bias=None.
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
            if module.weight is not None:
                torch.nn.init.ones_(module.weight)
|
|
|
|
|
|
|
|
|
|
|
|
class SlimMoEForCausalLM(SlimMoEModel, GenerationMixin):
    """Causal-LM head on top of the SlimMoE transformer.

    Wraps ``SlimMOETransformer`` with a weight-tied LM head and, during
    training, folds the MoE auxiliary (load-balancing) loss into the LM loss.
    """

    def __init__(self, config):
        """Build the transformer stack and the tied LM head from ``config``."""
        super().__init__(config)

        # Core transformer; ``adaptive_routing`` defaults to True for configs
        # that predate the option.
        self.transformer = SlimMOETransformer(
            vocab_size=config.vocab_size,
            dim=config.dim,
            num_layers=config.num_hidden_layers,
            num_heads=config.num_heads,
            hidden_dim=config.hidden_dim,
            num_experts=config.num_experts,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            adaptive_routing=getattr(config, 'adaptive_routing', True)
        )

        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        self.post_init()

        # Tie the LM head to the token embedding AFTER post_init so the
        # generic weight init does not overwrite/untie the shared tensor.
        self.lm_head.weight = self.transformer.token_embedding.weight
        self._dynamic_tied_weights_keys = ['lm_head.weight', 'transformer.token_embedding.weight']

        # Most recent auxiliary loss value (float), kept for external logging.
        self.aux_loss = 0.0
        # Weight of the MoE load-balancing loss in the total training loss.
        self.aux_loss_coefficient = getattr(config, 'aux_loss_coefficient', 0.01)

    @classmethod
    def from_pretrained_with_tokenizer(cls, model_path: str, tokenizer_path: str = None):
        """
        Load model from pretrained and optionally use a custom tokenizer.

        Args:
            model_path: Path to the pretrained model
            tokenizer_path: Path to custom tokenizer (if None, uses default)

        Returns:
            model, tokenizer tuple
        """
        from transformers import AutoTokenizer

        model = cls.from_pretrained(model_path, trust_remote_code=True)

        if tokenizer_path:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            # Use the module logger instead of bare print so the warning goes
            # through the transformers logging configuration.
            if tokenizer.vocab_size != model.config.vocab_size:
                logger.warning(
                    "Tokenizer vocab size (%d) != model vocab size (%d). "
                    "Consider retraining model with matching vocab size",
                    tokenizer.vocab_size, model.config.vocab_size,
                )
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_path)

        return model, tokenizer

    def get_input_embeddings(self):
        """Return the token-embedding module (HF resize/tie hook)."""
        return self.transformer.token_embedding

    def set_input_embeddings(self, value):
        """Replace the token-embedding module (HF resize/tie hook)."""
        self.transformer.token_embedding = value

    def get_output_embeddings(self):
        """Return the LM head (HF tie-weights hook)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head (HF tie-weights hook)."""
        self.lm_head = new_embeddings

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the transformer and LM head.

        Args:
            input_ids: Token ids, shape (batch, seq).
            attention_mask: Optional attention mask forwarded to the transformer.
            labels: Optional target ids; when given, a shifted cross-entropy
                loss is computed.
            **kwargs: Ignored (accepted for HF generation compatibility).

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` and, when labels were
            provided, ``loss`` (LM loss plus weighted aux loss in training).
        """
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        hidden_states = transformer_outputs['last_hidden_state']

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))

        if self.training:
            aux_loss = transformer_outputs['aux_loss']
            self.aux_loss = aux_loss.item() if isinstance(aux_loss, torch.Tensor) else aux_loss
            # BUGFIX: only add the auxiliary loss when an LM loss exists.
            # Previously a training-mode call without labels crashed with
            # ``TypeError`` on ``None + tensor``.
            if loss is not None:
                loss = loss + self.aux_loss_coefficient * aux_loss
        else:
            self.aux_loss = 0.0

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Minimal generation inputs: no KV cache is used, so the full
        ``input_ids`` sequence is re-fed at every step."""
        return {
            "input_ids": input_ids,
            "attention_mask": kwargs.get("attention_mask"),
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_moe_causal_lm(vocab_size: int = 50257):
    """Build a ready-to-train SlimMoE causal LM with the preset config.

    Uses ``SlimMoEConfig.for_300m`` (the module elsewhere describes this
    size as ~250M parameters — the exact count depends on the preset).

    Args:
        vocab_size: Token vocabulary size (defaults to GPT-2's 50257).

    Returns:
        A freshly initialized ``SlimMoEForCausalLM``.
    """
    from .configuration_slim_moe import SlimMoEConfig

    preset = SlimMoEConfig.for_300m(vocab_size=vocab_size)
    return SlimMoEForCausalLM(preset)
|
|
|
|
|
|
|
|
|
|