File size: 6,371 Bytes
e65ee65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import torch
import torch.nn as nn
from transformers import PreTrainedModel, GenerationMixin
from transformers.utils import logging
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_slim_moe import SlimMoEConfig
from .slim_moe_transformer import SlimMOETransformer

logger = logging.get_logger(__name__)

# AutoConfig.register('slim_moe', SlimMoEConfig)
# CONFIG_MAPPING.register("slim_moe", SlimMoEConfig)


class SlimMoEModel(PreTrainedModel):
    """Base ``PreTrainedModel`` wrapper for the SlimMoE architecture.

    Supplies the Hugging Face integration attributes (config class,
    gradient-checkpointing support, no-split module names) and the
    per-module weight initialization used by ``post_init``.
    """

    config_class = SlimMoEConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SlimMoETransformerBlock"]

    def _init_weights(self, module):
        """Initialize one submodule in place.

        Linear and Embedding weights are drawn from N(0, std) where
        ``std`` comes from ``config.initializer_range`` (default 0.02);
        Linear biases are zeroed. LayerNorm is reset to the identity
        transform (weight=1, bias=0).
        """
        std = getattr(self.config, 'initializer_range', 0.02)
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        elif isinstance(module, nn.LayerNorm):
            # Guard both parameters: a LayerNorm constructed with
            # elementwise_affine=False has weight=None and bias=None,
            # and init on None would raise.
            if module.weight is not None:
                torch.nn.init.ones_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

# MODEL_MAPPING.register(SlimMoEConfig, SlimMoEModel)

class SlimMoEForCausalLM(SlimMoEModel, GenerationMixin):
    """Causal language model head on top of the SlimMoE transformer.

    Ties the output projection (``lm_head``) to the input token embedding
    and, during training, adds the MoE auxiliary (load-balancing) loss to
    the cross-entropy loss.
    """

    def __init__(self, config):
        super().__init__(config)

        self.transformer = SlimMOETransformer(
            vocab_size=config.vocab_size,
            dim=config.dim,
            num_layers=config.num_hidden_layers,
            num_heads=config.num_heads,
            hidden_dim=config.hidden_dim,
            num_experts=config.num_experts,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            adaptive_routing=getattr(config, 'adaptive_routing', True)
        )

        # The lm_head is defined at the top level of this model so that
        # Hugging Face utilities (weight tying, embedding resizing) find it.
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        # Initialize weights and apply final processing (including weight tying).
        self.post_init()

        # Explicitly tie the output projection to the input embedding.
        self.lm_head.weight = self.transformer.token_embedding.weight

        # NOTE(review): non-standard attribute name; the HF convention is the
        # class-level `_tied_weights_keys` — confirm this is consumed anywhere.
        self._dynamic_tied_weights_keys = ['lm_head.weight', 'transformer.token_embedding.weight']

        # Most recent auxiliary loss value, exposed for logging via model.aux_loss.
        self.aux_loss = 0.0

        # Weight of the MoE auxiliary loss in the total training loss
        # (can be modified after initialization).
        self.aux_loss_coefficient = getattr(config, 'aux_loss_coefficient', 0.01)

    @classmethod
    def from_pretrained_with_tokenizer(cls, model_path: str, tokenizer_path: str = None):
        """Load a pretrained model and its tokenizer in one call.

        Args:
            model_path: Path to the pretrained model.
            tokenizer_path: Path to a custom tokenizer. If ``None``, the
                tokenizer is loaded from ``model_path``.

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        from transformers import AutoTokenizer

        model = cls.from_pretrained(model_path, trust_remote_code=True)

        if tokenizer_path:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            # Warn (via the module logger, consistent with the rest of the
            # file) when the custom tokenizer does not match the model.
            if tokenizer.vocab_size != model.config.vocab_size:
                logger.warning(
                    "Tokenizer vocab size (%d) != model vocab size (%d). "
                    "Consider retraining model with matching vocab size",
                    tokenizer.vocab_size, model.config.vocab_size,
                )
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_path)

        return model, tokenizer

    def get_input_embeddings(self):
        """Return the token embedding module (HF accessor)."""
        return self.transformer.token_embedding

    def set_input_embeddings(self, value):
        """Replace the token embedding module (HF accessor)."""
        self.transformer.token_embedding = value

    def get_output_embeddings(self):
        """Return the top-level lm_head (HF accessor for weight tying)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the top-level lm_head (HF accessor for weight tying)."""
        self.lm_head = new_embeddings

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the transformer and project to vocabulary logits.

        Args:
            input_ids: Token id tensor passed to the base transformer.
            attention_mask: Optional attention mask forwarded as-is.
            labels: Optional target ids; when given, a shifted
                cross-entropy loss is computed (and, in training mode, the
                MoE auxiliary loss is added).
            **kwargs: Ignored; accepted for HF-API compatibility.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (or ``None``) and ``logits``.
        """
        # 1. Get hidden states from the base transformer.
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        hidden_states = transformer_outputs['last_hidden_state']

        # 2. Project hidden states to logits.
        logits = self.lm_head(hidden_states)

        # 3. Calculate loss if labels are provided.
        loss = None
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from position t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))

            # Add auxiliary (load-balancing) loss from the MoE layers —
            # training only, so eval loss stays pure cross-entropy.
            if self.training:
                aux_loss = transformer_outputs['aux_loss']
                # Store a plain float for logging (accessible via model.aux_loss).
                self.aux_loss = aux_loss.item() if isinstance(aux_loss, torch.Tensor) else aux_loss
                loss = loss + self.aux_loss_coefficient * aux_loss
            else:
                self.aux_loss = 0.0

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build the forward() kwargs for one generation step.

        No KV cache is used, so the full input_ids sequence is passed
        every step.
        """
        return {
            "input_ids": input_ids,
            "attention_mask": kwargs.get("attention_mask"),
        }

# AutoModelForCausalLM.register(SlimMoEConfig, SlimMoEForCausalLM)

# MODEL_FOR_CAUSAL_LM_MAPPING.register(SlimMoEConfig, SlimMoEForCausalLM)


def create_moe_causal_lm(vocab_size: int = 50257):
    """Create a ``SlimMoEForCausalLM`` from the ``for_300m`` config preset.

    Returns the full CausalLM model (not just the bare transformer).
    NOTE(review): earlier docs described this as "~250M parameters" while
    the code uses the ``for_300m`` preset — confirm the intended size.

    Args:
        vocab_size: Tokenizer vocabulary size (default 50257, GPT-2's).

    Returns:
        A freshly initialized ``SlimMoEForCausalLM``.
    """
    # SlimMoEConfig is already imported at module level from
    # .configuration_slim_moe; no local re-import is needed.
    config = SlimMoEConfig.for_300m(vocab_size=vocab_size)
    return SlimMoEForCausalLM(config)