import torch
import torch.nn as nn
from transformers import PreTrainedModel, GenerationMixin
from transformers.utils import logging
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_slim_moe import SlimMoEConfig
from .slim_moe_transformer import SlimMOETransformer
logger = logging.get_logger(__name__)
# AutoConfig.register('slim_moe', SlimMoEConfig)
# CONFIG_MAPPING.register("slim_moe", SlimMoEConfig)
class SlimMoEModel(PreTrainedModel):
    """Base HuggingFace wrapper for the SlimMoE transformer stack.

    Provides the config class binding, gradient-checkpointing support flag,
    and weight initialization shared by all SlimMoE model heads.
    """

    config_class = SlimMoEConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    # Keep each transformer block on a single device when sharding.
    _no_split_modules = ["SlimMoETransformerBlock"]

    def _init_weights(self, module):
        """Initialize weights for a single submodule (called by post_init).

        Linear/Embedding weights are drawn from N(0, std); LayerNorm is reset
        to the identity transform (weight=1, bias=0).
        """
        # Fall back to the conventional 0.02 when the config predates the field.
        std = self.config.initializer_range if hasattr(self.config, 'initializer_range') else 0.02
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        elif isinstance(module, nn.LayerNorm):
            # Guard for LayerNorm(elementwise_affine=False), where both
            # weight and bias are None and the unconditional init would crash.
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
            if module.weight is not None:
                torch.nn.init.ones_(module.weight)
# MODEL_MAPPING.register(SlimMoEConfig, SlimMoEModel)
class SlimMoEForCausalLM(SlimMoEModel, GenerationMixin):
    """Causal language model head on top of the SlimMoE transformer.

    Wraps ``SlimMOETransformer`` with a weight-tied ``lm_head`` projection and
    HuggingFace ``GenerationMixin`` support. During training, the MoE routing
    auxiliary loss is added to the cross-entropy loss, scaled by
    ``aux_loss_coefficient``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Core decoder stack; hyperparameters come straight from the config.
        self.transformer = SlimMOETransformer(
            vocab_size=config.vocab_size,
            dim=config.dim,
            num_layers=config.num_hidden_layers,
            num_heads=config.num_heads,
            hidden_dim=config.hidden_dim,
            num_experts=config.num_experts,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            adaptive_routing=getattr(config, 'adaptive_routing', True)
        )
        # --- FIX: Define the lm_head at the top level of this model ---
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
        # Initialize weights and apply final processing (including weight tying)
        self.post_init()
        # Tie output projection to the input embedding matrix. NOTE(review):
        # intentionally done AFTER post_init() — presumably so that weight
        # re-initialization does not break the tie; confirm against training runs.
        self.lm_head.weight = self.transformer.token_embedding.weight
        self._dynamic_tied_weights_keys = ['lm_head.weight', 'transformer.token_embedding.weight']
        # Initialize aux_loss for logging
        self.aux_loss = 0.0
        # Auxiliary loss coefficient (can be modified after initialization)
        self.aux_loss_coefficient = getattr(config, 'aux_loss_coefficient', 0.01)

    @classmethod
    def from_pretrained_with_tokenizer(cls, model_path: str, tokenizer_path: str = None):
        """
        Load model from pretrained and optionally use a custom tokenizer.
        Args:
            model_path: Path to the pretrained model
            tokenizer_path: Path to custom tokenizer (if None, uses default)
        Returns:
            model, tokenizer tuple
        """
        from transformers import AutoTokenizer
        model = cls.from_pretrained(model_path, trust_remote_code=True)
        if tokenizer_path:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            # Update vocab size if needed
            if tokenizer.vocab_size != model.config.vocab_size:
                # Warn only — the mismatch is not fatal here, but embedding
                # rows beyond config.vocab_size would be out of range.
                print(f"Warning: Tokenizer vocab size ({tokenizer.vocab_size}) != "
                      f"model vocab size ({model.config.vocab_size})")
                print(" Consider retraining model with matching vocab size")
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
        return model, tokenizer

    def get_input_embeddings(self):
        """Return the token embedding module (the tied input matrix)."""
        return self.transformer.token_embedding

    def set_input_embeddings(self, value):
        """Replace the token embedding module on the inner transformer."""
        self.transformer.token_embedding = value

    def get_output_embeddings(self):
        # --- FIX: Return the top-level lm_head ---
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        # --- FIX: Set the top-level lm_head ---
        self.lm_head = new_embeddings

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the decoder, project to vocabulary logits, and optionally
        compute the shifted causal-LM cross-entropy loss.

        During training, the MoE auxiliary (load-balancing) loss returned by
        the transformer is added to the LM loss, scaled by
        ``self.aux_loss_coefficient``; its raw value is stored on
        ``self.aux_loss`` for logging.

        Returns:
            CausalLMOutputWithPast with ``logits`` and (if labels given) ``loss``.
            NOTE(review): no ``past_key_values`` are returned — KV caching is
            not implemented.
        """
        # 1. Get hidden states from the base transformer
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        hidden_states = transformer_outputs['last_hidden_state']
        # 2. Project hidden states to logits
        logits = self.lm_head(hidden_states)
        # 3. Calculate loss if labels are provided
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n (standard causal LM).
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            # Add auxiliary loss from MOE layers
            if self.training:
                aux_loss = transformer_outputs['aux_loss']
                # Store aux_loss for logging (accessible via model.aux_loss)
                self.aux_loss = aux_loss.item() if isinstance(aux_loss, torch.Tensor) else aux_loss
                loss = loss + self.aux_loss_coefficient * aux_loss
            else:
                self.aux_loss = 0.0
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # NOTE(review): no KV cache — the full sequence is re-encoded on every
        # generation step, so only input_ids and attention_mask are forwarded.
        return {
            "input_ids": input_ids,
            "attention_mask": kwargs.get("attention_mask"),
        }
# AutoModelForCausalLM.register(SlimMoEConfig, SlimMoEForCausalLM)
# MODEL_FOR_CAUSAL_LM_MAPPING.register(SlimMoEConfig, SlimMoEForCausalLM)
def create_moe_causal_lm(vocab_size: int = 50257):
    """
    Create a SlimMoEForCausalLM model from the ~300M-parameter preset.

    Returns a full CausalLM model (not just the transformer) built with
    ``SlimMoEConfig.for_300m``.

    Args:
        vocab_size: Vocabulary size for embeddings and lm_head
            (default 50257, the GPT-2 tokenizer size).

    Returns:
        A freshly initialized SlimMoEForCausalLM.
    """
    from .configuration_slim_moe import SlimMoEConfig
    # The docstring previously claimed ~250M params; the preset used here
    # is the 300M configuration, so the documentation now matches the code.
    config = SlimMoEConfig.for_300m(vocab_size=vocab_size)
    model = SlimMoEForCausalLM(config)
    return model