Charlie81 committed on
Commit
f9e2c3f
·
1 Parent(s): ee85822
lm-evaluation-harness/lm_eval/models/my_olmoe.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LM Evaluation Harness Wrapper for Modified MyOLMoE
3
+ """
4
+ import torch
5
+ from typing import List, Optional, Union, Dict, Any
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+ from lm_eval.api.model import LM
8
+ from lm_eval.api.registry import register_model
9
+ import numpy as np
10
+
11
+
12
@register_model("myolmoe")
class MyOLMoELM(LM):
    """LM Evaluation Harness wrapper for MYOLMoE model.

    Loads a causal LM and its tokenizer through the ``transformers`` Auto
    classes and implements the three request types the harness issues:
    ``loglikelihood``, ``loglikelihood_rolling`` and ``generate_until``.

    Requests may be plain tuples or objects that expose their arguments via an
    ``args`` attribute (as harness ``Instance`` objects do).
    """

    def __init__(
        self,
        pretrained: str = None,
        device: str = "cuda",
        batch_size: int = 1,
        max_length: int = 2048,
        trust_remote_code: bool = False,
        dtype: str = "float16",
        parallelize: bool = False,
        device_map: Optional[str] = None,
        **kwargs
    ):
        """Load the tokenizer and model.

        Args:
            pretrained: Model name or path handed to ``from_pretrained``.
            device: Target device; "cuda" falls back to "cpu" when CUDA is
                unavailable.
            batch_size: Requests per forward pass. Numeric strings are
                accepted; anything non-numeric (e.g. "auto") falls back to 1.
            max_length: Maximum number of tokens fed to the model.
            trust_remote_code: Forwarded to ``from_pretrained``.
            dtype: "float16", "bfloat16"; any other value selects float32.
            parallelize: Load with a ``device_map`` instead of moving the
                model to ``device``.
            device_map: Device map used when ``parallelize`` is set
                (defaults to "auto" in that case).
            **kwargs: Extra arguments forwarded to the model's
                ``from_pretrained``.

        Raises:
            ValueError: If ``pretrained`` is not given.
        """
        super().__init__()

        # Fail before doing any other work when there is nothing to load.
        if not pretrained:
            raise ValueError("pretrained model path must be specified")

        # Initialize device and batch size
        if device == "cuda" and not torch.cuda.is_available():
            device = "cpu"
        self._device = torch.device(device)
        self._batch_size = self._coerce_batch_size(batch_size)
        self._max_length = int(max_length)

        # Set dtype (unknown names keep the original float32 fallback)
        dtype_by_name = {"float16": torch.float16, "bfloat16": torch.bfloat16}
        self._dtype = dtype_by_name.get(dtype, torch.float32)

        self.tokenizer = AutoTokenizer.from_pretrained(
            pretrained,
            trust_remote_code=trust_remote_code,
            padding_side="left"
        )

        # Ensure pad token is set
        added_pad_token = False
        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                self.tokenizer.add_special_tokens({'pad_token': '<pad>'})
                added_pad_token = True

        # With parallelize=True and no explicit map the model would otherwise
        # be loaded without any placement and never moved off the CPU.
        if parallelize and device_map is None:
            device_map = "auto"

        self.model = AutoModelForCausalLM.from_pretrained(
            pretrained,
            torch_dtype=self._dtype,
            device_map=device_map if parallelize else None,
            trust_remote_code=trust_remote_code,
            **kwargs
        )

        # A freshly added pad token needs an embedding row, otherwise padded
        # batches index past the end of the embedding matrix.
        if added_pad_token:
            embedding_rows = self.model.get_input_embeddings().weight.shape[0]
            if len(self.tokenizer) > embedding_rows:
                self.model.resize_token_embeddings(len(self.tokenizer))

        if not parallelize:
            self.model = self.model.to(self._device)

        self.model.eval()

    @staticmethod
    def _coerce_batch_size(batch_size) -> int:
        """Convert ``batch_size`` to a positive int (1 if it is not numeric)."""
        try:
            value = int(batch_size)
        except (TypeError, ValueError):
            # e.g. "auto": automatic batch sizing is not implemented here.
            return 1
        return max(1, value)

    @staticmethod
    def _request_args(request):
        """Return a request's argument tuple, whether it is a tuple or has ``.args``."""
        return getattr(request, "args", request)

    @property
    def eot_token_id(self):
        """End of text token ID."""
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Maximum sequence length."""
        return self._max_length

    @property
    def max_gen_toks(self):
        """Maximum number of tokens to generate."""
        return 256

    @property
    def batch_size(self):
        """Batch size for evaluation."""
        return self._batch_size

    @property
    def device(self):
        """Device used for evaluation."""
        return self._device

    def tok_encode(self, string: str, add_special_tokens=True) -> List[int]:
        """Encode a string to token IDs."""
        return self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

    def tok_decode(self, tokens: List[int]) -> str:
        """Decode token IDs to string."""
        return self.tokenizer.decode(tokens, skip_special_tokens=True)

    def loglikelihood(self, requests: List[tuple]) -> List[tuple]:
        """Compute log-likelihood for each request.

        Each request carries ``(context, continuation)``. Returns one
        ``(log_likelihood, is_greedy)`` tuple per request, in request order.
        """
        pairs = [self._request_args(request) for request in requests]

        results = []
        for start in range(0, len(pairs), self.batch_size):
            results.extend(self._loglikelihood_batch(pairs[start:start + self.batch_size]))
        return results

    def _loglikelihood_batch(self, batch: List[tuple]) -> List[tuple]:
        """Score one batch of ``(context, continuation)`` pairs.

        Sequences are left-padded to a common length. For each row only the
        continuation tokens are scored; ``is_greedy`` is True when every
        continuation token is also the model's argmax prediction.
        """
        if not batch:
            return []

        pad_id = self.tokenizer.pad_token_id
        eot = self.eot_token_id

        rows = []  # (token ids, number of leading context tokens)
        for context, continuation in batch:
            ctx_tokens = self.tok_encode(context)
            full_tokens = self.tok_encode(context + continuation)

            if not ctx_tokens and eot is not None:
                # Without any context token the first continuation token has
                # nothing to be predicted from; condition on end-of-text.
                full_tokens = [eot] + full_tokens
                ctx_len = 1
            else:
                ctx_len = len(ctx_tokens)

            # Truncate from the left (keep the end) using THIS row's overflow.
            overflow = len(full_tokens) - self.max_length
            if overflow > 0:
                full_tokens = full_tokens[overflow:]
                ctx_len = max(0, ctx_len - overflow)

            rows.append((full_tokens, ctx_len))

        max_len = max(len(tokens) for tokens, _ in rows)
        if max_len < 2:
            # No (previous token -> next token) pair exists to score.
            return [(float('-inf'), False) for _ in batch]

        input_rows, attention_rows, continuation_rows = [], [], []
        for tokens, ctx_len in rows:
            pad_length = max_len - len(tokens)
            continuation_start = min(max_len, pad_length + ctx_len)
            input_rows.append([pad_id] * pad_length + tokens)
            attention_rows.append([0] * pad_length + [1] * len(tokens))
            continuation_rows.append(
                [0] * continuation_start + [1] * (max_len - continuation_start)
            )

        input_ids = torch.tensor(input_rows, dtype=torch.long, device=self.device)
        attention_masks = torch.tensor(attention_rows, dtype=torch.long, device=self.device)
        continuation_masks = torch.tensor(continuation_rows, dtype=torch.bool, device=self.device)

        with torch.no_grad():
            logits = self.model(input_ids=input_ids, attention_mask=attention_masks).logits

        # Shift for next-token prediction: logits at position t predict token t+1.
        targets = input_ids[:, 1:].to(logits.device)
        scored = continuation_masks[:, 1:].to(logits.device)

        results = []
        for row in range(len(batch)):
            positions = scored[row]
            if not positions.any():
                results.append((float('-inf'), False))
                continue

            # Only the scored positions are normalised, in float32 so that
            # half-precision logits do not lose accuracy in log_softmax.
            row_log_probs = torch.log_softmax(logits[row, :-1][positions].float(), dim=-1)
            row_targets = targets[row][positions]

            token_log_probs = row_log_probs.gather(1, row_targets.unsqueeze(1)).squeeze(1)
            total_log_prob = token_log_probs.sum().item()
            is_greedy = bool((row_log_probs.argmax(dim=-1) == row_targets).all().item())

            results.append((total_log_prob, is_greedy))

        return results

    def generate_until(self, requests: List[tuple]) -> List[str]:
        """Generate text until stopping criteria are met.

        Each request carries ``(context, generation_kwargs)``. Returns one
        generated string per request, in request order.
        """
        pairs = [self._request_args(request) for request in requests]

        results = []
        for start in range(0, len(pairs), self.batch_size):
            results.extend(self._generate_until_batch(pairs[start:start + self.batch_size]))
        return results

    def _generate_until_batch(self, batch: List[tuple]) -> List[str]:
        """Generate continuations for one batch of ``(context, gen_kwargs)`` pairs.

        Decoding settings (``max_gen_toks``, ``do_sample``, ``temperature``,
        ``top_p``) are taken from the first request of the batch because a
        single ``generate`` call serves every row; the ``until`` stop strings
        are applied per request.
        """
        if not batch:
            return []

        contexts = [context for context, _ in batch]
        gen_kwargs_list = [gen_kwargs or {} for _, gen_kwargs in batch]

        shared_kwargs = gen_kwargs_list[0]
        max_new_tokens = shared_kwargs.get('max_gen_toks', self.max_gen_toks)
        do_sample = shared_kwargs.get('do_sample', False)

        pad_id = self.tokenizer.pad_token_id
        eot = self.eot_token_id
        fallback_token = eot if eot is not None else pad_id

        # An empty prompt would give the model no token to start from.
        context_encodings = [self.tok_encode(ctx) or [fallback_token] for ctx in contexts]

        # Leave room for the new tokens. Keep at least one context token:
        # a budget of 0 would make ``seq[-0:]`` silently keep everything.
        longest = max(len(seq) for seq in context_encodings)
        max_ctx_len = max(1, min(longest, self.max_length - max_new_tokens))

        input_rows, attention_rows = [], []
        for ctx_seq in context_encodings:
            ctx_seq = ctx_seq[-max_ctx_len:]  # keep the end of long prompts
            pad_length = max_ctx_len - len(ctx_seq)
            input_rows.append([pad_id] * pad_length + ctx_seq)
            attention_rows.append([0] * pad_length + [1] * len(ctx_seq))

        input_ids = torch.tensor(input_rows, dtype=torch.long, device=self.device)
        attention_masks = torch.tensor(attention_rows, dtype=torch.long, device=self.device)

        generation_kwargs = {
            'max_new_tokens': max_new_tokens,
            'do_sample': do_sample,
            'pad_token_id': pad_id,
            'eos_token_id': self.tokenizer.eos_token_id,
            'attention_mask': attention_masks,
            'use_cache': True,
        }
        if do_sample:
            # Sampling parameters are only meaningful when sampling is on.
            generation_kwargs['temperature'] = shared_kwargs.get('temperature', 1.0)
            generation_kwargs['top_p'] = shared_kwargs.get('top_p', 1.0)

        with torch.no_grad():
            generated = self.model.generate(input_ids=input_ids, **generation_kwargs)

        # Every row shares the same padded prompt length, so the newly
        # generated tokens start at the same offset for all rows.
        prompt_len = input_ids.shape[1]

        results = []
        for row, gen_seq in enumerate(generated):
            new_tokens = gen_seq[prompt_len:].tolist()
            generated_text = self.tok_decode(new_tokens) if new_tokens else ""

            stop_strings = gen_kwargs_list[row].get('until') or []
            if isinstance(stop_strings, str):
                stop_strings = [stop_strings]

            # Splitting on each stop string in turn cuts at the earliest
            # occurrence of ANY of them, regardless of their list order.
            for stop_str in stop_strings:
                if stop_str:
                    generated_text = generated_text.split(stop_str)[0]

            results.append(generated_text)

        return results

    def _score_window(self, window_tokens: List[int], first_target: int) -> float:
        """Return the summed log-probability of ``window_tokens[first_target:]``.

        Each scored token is conditioned on all tokens before it in the
        window. Position 0 is never scored because it has no predecessor.
        """
        first_target = max(1, first_target)
        if first_target >= len(window_tokens):
            return 0.0

        input_ids = torch.tensor([window_tokens], dtype=torch.long, device=self.device)
        with torch.no_grad():
            logits = self.model(input_ids=input_ids).logits

        # logits[t] predicts token t+1, hence the one-position shift.
        log_probs = torch.log_softmax(logits[0, first_target - 1:-1].float(), dim=-1)
        targets = input_ids[0, first_target:].to(log_probs.device)
        return log_probs.gather(1, targets.unsqueeze(1)).sum().item()

    def loglikelihood_rolling(self, requests: List[tuple]) -> List[float]:
        """Compute the rolling log-likelihood of each request's text.

        Returns, per request, the SUM of token log-probabilities over the
        text (every token except the first, which has no preceding token).
        The harness derives per-word / per-byte perplexities from this total,
        so it must not be averaged here. Texts of at most one token yield 0.0.

        Tokens inside the first ``max_length`` window are scored with a single
        forward pass; each later token is scored with a window that ends on it
        so that it still sees the maximum available context.
        """
        results = []

        for request in requests:
            args = self._request_args(request)
            text = args[0] if isinstance(args, tuple) else args
            tokens = self.tok_encode(text)

            if len(tokens) <= 1:
                results.append(0.0)
                continue

            window_size = min(self.max_length, len(tokens))

            # One pass covers every token that fits in the first window.
            total_log_prob = self._score_window(tokens[:window_size], first_target=1)

            # Remaining tokens: slide by one so each keeps a full context.
            for end in range(window_size, len(tokens)):
                window = tokens[end - window_size + 1:end + 1]
                total_log_prob += self._score_window(window, first_target=window_size - 1)

            results.append(total_log_prob)

        return results
myolmoe/modeling_myolmoe.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modified OLMoE Model with Configurable Routing
3
+ """
4
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GenerationMixin, PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
11
+
12
+
13
class MyOLMoEConfig(PretrainedConfig):
    """Hyper-parameters of the MyOLMoE model, including its routing mode.

    Stores the backbone dimensions, the mixture-of-experts settings
    (``num_experts``, ``num_experts_per_tok``, ``router_aux_loss_coef``) and
    the routing configuration (``routing_type``, ``router_temperature``).
    Special-token ids, ``tie_word_embeddings`` and any extra keyword
    arguments are handed to ``PretrainedConfig``.
    """

    model_type = "myolmoe"

    def __init__(
        self,
        vocab_size=50304,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_key_value_heads=None,
        hidden_act="swish",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        # MoE specific parameters
        num_experts=8,
        num_experts_per_tok=2,
        router_aux_loss_coef=0.001,
        # Routing configuration
        routing_type="dense",  # "dense", "sparse", "non_deterministic"
        router_temperature=1.0,  # For non-deterministic routing
        **kwargs
    ):
        # A falsy key/value head count means "same as the attention heads".
        if not num_key_value_heads:
            num_key_value_heads = num_attention_heads

        # Backbone
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta

        # Mixture of experts
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.router_aux_loss_coef = router_aux_loss_coef

        # Routing
        self.routing_type = routing_type
        self.router_temperature = router_temperature

        # Everything else is handled by the base configuration class.
        base_options = dict(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
        )
        super().__init__(**base_options, **kwargs)
74
+
75
+
76
class MyOLMoERouter(nn.Module):
    """Configurable router for OLMoE experts.

    A bias-free linear gate scores every expert for every token. How the
    scores become mixing weights depends on ``config.routing_type``:

    * ``"dense"``: softmax over all experts.
    * ``"sparse"``: keep the top ``num_experts_per_tok`` experts and
      renormalise their softmax weights.
    * ``"non_deterministic"``: same top-k selection, but restricted to the
      first half of the experts. NOTE(review): despite the name, this branch
      is deterministic and ``config.router_temperature`` is not used
      anywhere in this class; confirm the intended behaviour.
    """

    def __init__(self, config: "MyOLMoEConfig"):
        super().__init__()
        self.config = config
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

    def _keep_top_k(self, probs: torch.Tensor, num_candidates: int, eps: float = 0.0) -> torch.Tensor:
        """Zero all but the top-k of the first ``num_candidates`` experts and renormalise.

        ``k`` is ``num_experts_per_tok`` capped at ``num_candidates`` so that
        ``topk`` is never asked for more experts than are available.
        """
        k = min(self.config.num_experts_per_tok, num_candidates)
        # Candidates are a prefix of the expert axis, so the indices returned
        # for the slice are already valid indices into the full axis.
        _, top_indices = torch.topk(probs[:, :num_candidates], k, dim=-1)
        keep = torch.zeros_like(probs).scatter_(-1, top_indices, 1.0)
        kept = probs * keep
        return kept / (kept.sum(dim=-1, keepdim=True) + eps)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Route tokens to experts based on configuration.

        Args:
            hidden_states: Tensor of shape (batch, seq_len, hidden_dim).

        Returns:
            router_logits: Raw gate outputs, shape (batch, seq_len, num_experts).
            router_probs: Expert mixing weights, same shape; experts that are
                not selected have weight 0.

        Raises:
            ValueError: If ``config.routing_type`` is not supported.
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        flat_states = hidden_states.reshape(-1, hidden_dim)

        # Compute router logits
        router_logits = self.gate(flat_states)
        router_probs = F.softmax(router_logits, dim=-1)

        routing_type = self.config.routing_type
        if routing_type == "dense":
            # Dense routing: use all experts with softmax weights
            pass
        elif routing_type == "sparse":
            # Sparse routing: select top-k experts among all of them
            router_probs = self._keep_top_k(router_probs, self.config.num_experts)
        elif routing_type == "non_deterministic":
            # Only the first half of the experts is eligible. Keep at least
            # one candidate, otherwise a single-expert model would select
            # nothing and every routing weight would collapse to zero.
            num_first_half = max(1, self.config.num_experts // 2)
            router_probs = self._keep_top_k(router_probs, num_first_half, eps=1e-8)
        else:
            raise ValueError(f"Unsupported routing type: {self.config.routing_type}")

        router_logits = router_logits.reshape(batch_size, seq_len, -1)
        router_probs = router_probs.reshape(batch_size, seq_len, -1)

        return router_logits, router_probs
148
+
149
+
150
class MyOLMoEExpert(nn.Module):
    """A single gated feed-forward expert of the MoE layer.

    Computes ``down_proj(act(gate_proj(x)) * up_proj(x))`` with three
    bias-free linear projections.
    """

    def __init__(self, config: "MyOLMoEConfig"):
        super().__init__()
        model_dim = config.hidden_size
        inner_dim = config.intermediate_size
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)
        self.act_fn = self._get_activation_fn(config.hidden_act)

    def _get_activation_fn(self, activation):
        """Map an activation name to its functional implementation.

        Raises:
            ValueError: If the name is not one of swish/silu, relu, gelu.
        """
        known_activations = (
            (("swish", "silu"), F.silu),
            (("relu",), F.relu),
            (("gelu",), F.gelu),
        )
        for names, fn in known_activations:
            if activation in names:
                return fn
        raise ValueError(f"Unsupported activation: {activation}")

    def forward(self, x):
        """Apply the gated feed-forward transformation to ``x``."""
        gated = self.act_fn(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gated * lifted)
172
+
173
+
174
class MyOLMoEMLP(nn.Module):
    """MoE MLP layer with configurable routing.

    A router produces per-token mixing weights over the experts; the layer
    output is the weighted sum of the expert outputs. A load-balancing
    auxiliary loss is returned alongside the output.
    """

    def __init__(self, config: "MyOLMoEConfig"):
        super().__init__()
        self.config = config
        self.router = MyOLMoERouter(config)
        self.experts = nn.ModuleList([
            MyOLMoEExpert(config) for _ in range(config.num_experts)
        ])

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Mix the expert outputs according to the router weights.

        Args:
            hidden_states: Tensor of shape (batch, seq_len, hidden_dim).

        Returns:
            A tuple ``(output, aux_loss)`` where ``output`` has the same
            shape as ``hidden_states``.
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape
        flat_states = hidden_states.reshape(-1, hidden_dim)

        # Route to experts
        router_logits, router_probs = self.router(hidden_states)
        probs_flat = router_probs.reshape(-1, self.config.num_experts)

        # Accumulate the weighted expert outputs directly. Each expert only
        # processes the tokens that give it a non-zero weight, and no
        # [tokens, hidden, num_experts] tensor is ever materialised.
        output = torch.zeros_like(flat_states)
        for expert_idx, expert in enumerate(self.experts):
            weights = probs_flat[:, expert_idx]
            token_idx = torch.nonzero(weights, as_tuple=True)[0]
            if token_idx.numel() == 0:
                continue  # no token is routed to this expert
            weighted = expert(flat_states[token_idx]) * weights[token_idx].unsqueeze(-1)
            output.index_add_(0, token_idx, weighted.to(output.dtype))

        output = output.view(batch_size, seq_len, hidden_dim)

        # Compute auxiliary loss
        aux_loss = self._compute_aux_loss(
            probs_flat, router_logits.reshape(-1, self.config.num_experts)
        )

        return output, aux_loss

    def _compute_aux_loss(self, router_probs, router_logits):
        """Compute auxiliary loss for load balancing.

        The loss is ``coef * num_experts * sum(usage ** 2)`` where ``usage``
        is the mean routing weight each expert receives over all tokens.
        ``router_logits`` is accepted but not used. Returns a zero tensor
        when the coefficient is 0.
        """
        if self.config.router_aux_loss_coef == 0:
            return torch.tensor(0.0, device=router_probs.device)

        # Load balancing loss
        num_tokens = router_probs.shape[0]
        expert_usage = router_probs.sum(dim=0) / num_tokens  # Average usage per expert
        aux_loss = self.config.num_experts * torch.sum(expert_usage * expert_usage)

        return self.config.router_aux_loss_coef * aux_loss
221
+
222
+
223
class MyOLMoEDecoderLayer(nn.Module):
    """Transformer decoder layer with MoE MLP.

    Pre-norm residual block: RMSNorm -> causal self-attention -> residual,
    then RMSNorm -> MoE MLP -> residual.
    """

    def __init__(self, config: "MyOLMoEConfig"):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = nn.MultiheadAttention(
            config.hidden_size,
            config.num_attention_heads,
            dropout=0.0,
            batch_first=True
        )
        self.mlp = MyOLMoEMLP(config)
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def _build_attn_mask(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Build the boolean mask handed to ``nn.MultiheadAttention`` (True = blocked).

        * ``attention_mask is None``: a (seq, seq) causal mask.
        * 2-D ``attention_mask``: interpreted as a (batch, seq) padding mask
          where 0 marks padding. It is merged with the causal mask into a
          (batch * num_heads, seq, seq) mask.
        * Any other rank: returned unchanged, i.e. the caller supplies a mask
          already in the attention module's own format.
        """
        _, seq_len, _ = hidden_states.shape
        device = hidden_states.device

        positions = torch.arange(seq_len, device=device)
        # future[i, j] is True when key j lies after query i.
        future = positions.unsqueeze(0) > positions.unsqueeze(1)

        if attention_mask is None:
            return future
        if attention_mask.dim() != 2:
            return attention_mask

        is_padding = attention_mask.to(device) == 0  # (batch, seq)
        blocked = future.unsqueeze(0) | is_padding.unsqueeze(1)  # (batch, seq, seq)

        # Always let a position attend to itself. With left padding a padded
        # query would otherwise have every key blocked, and the softmax over
        # an empty set would fill the sequence with NaNs.
        diagonal = torch.eye(seq_len, dtype=torch.bool, device=device)
        blocked = blocked & ~diagonal

        # One copy per head, batch-major, as the attention module expects
        # for 3-D masks.
        return blocked.repeat_interleave(self.self_attn.num_heads, dim=0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer.

        Args:
            hidden_states: Tensor of shape (batch, seq_len, hidden_size).
            attention_mask: Optional padding mask of shape (batch, seq_len)
                with 1 for real tokens and 0 for padding.

        Returns:
            ``(hidden_states, aux_loss)``: the updated hidden states and the
            auxiliary loss reported by the MoE MLP.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self attention (causal; padded keys are never attended to)
        attn_mask = self._build_attn_mask(hidden_states, attention_mask)
        attn_output, _ = self.self_attn(
            hidden_states, hidden_states, hidden_states,
            attn_mask=attn_mask,
            need_weights=False,  # the attention weights are discarded anyway
        )
        hidden_states = residual + attn_output

        # MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        mlp_output, aux_loss = self.mlp(hidden_states)
        hidden_states = residual + mlp_output

        return hidden_states, aux_loss
261
+
262
+
263
class MyOLMoEModel(PreTrainedModel):
    """Backbone of the MyOLMoE model: embeddings, decoder stack and final norm."""

    config_class = MyOLMoEConfig

    def __init__(self, config: MyOLMoEConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        decoder_stack = [MyOLMoEDecoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_stack)
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Embed the inputs and run them through every decoder layer.

        Exactly one of ``input_ids`` (2-D) or ``inputs_embeds`` (3-D) must be
        given. Unset flags fall back to the values stored in the config.

        Returns:
            With ``return_dict``: a dict with keys ``last_hidden_state``,
            ``hidden_states`` (tuple or None) and ``aux_loss`` (sum of the
            per-layer auxiliary losses). Otherwise a tuple of the non-None
            values among the final hidden state and the collected hidden
            states.

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are provided.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        have_ids = input_ids is not None
        have_embeds = inputs_embeds is not None
        if have_ids and have_embeds:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if not have_ids and not have_embeds:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if have_ids:
            # Unpacking doubles as a rank check on the input.
            _batch, _seq = input_ids.shape
            inputs_embeds = self.embed_tokens(input_ids)
        else:
            _batch, _seq, _ = inputs_embeds.shape

        hidden_states = inputs_embeds
        collected_states = []
        total_aux_loss = 0.0

        for decoder_layer in self.layers:
            if output_hidden_states:
                collected_states.append(hidden_states)

            hidden_states, layer_aux_loss = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
            )
            total_aux_loss = total_aux_loss + layer_aux_loss

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            collected_states.append(hidden_states)
            all_hidden_states = tuple(collected_states)
        else:
            all_hidden_states = None

        if return_dict:
            return {
                'last_hidden_state': hidden_states,
                'hidden_states': all_hidden_states,
                'aux_loss': total_aux_loss
            }

        return tuple(
            value for value in (hidden_states, all_hidden_states) if value is not None
        )
335
+
336
+
337
class MyOLMoEForCausalLM(PreTrainedModel, GenerationMixin):
    """MyOLMoE model for causal language modeling.

    Wraps ``MyOLMoEModel`` with a bias-free linear language-modeling head.
    ``GenerationMixin`` is inherited explicitly so ``generate()`` is
    available even on library versions where ``PreTrainedModel`` no longer
    provides it.
    """

    config_class = MyOLMoEConfig

    def __init__(self, config):
        print("⚡ Using CUSTOM MyOLMoE implementation!")  # Will show during loading
        super().__init__(config)
        self.model = MyOLMoEModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """Build the forward() inputs for one generation step.

        This model keeps no key/value cache, so every step re-runs the whole
        sequence: the full ``input_ids`` and ``attention_mask`` are passed on
        and any cache-related arguments are dropped.
        """
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Compute next-token logits and, when ``labels`` are given, the loss.

        Args:
            input_ids / inputs_embeds: Model inputs (exactly one of them).
            attention_mask: Optional (batch, seq_len) padding mask.
            labels: Optional target ids; they are shifted by one position
                inside this method. The loss is cross-entropy plus the
                backbone's auxiliary routing loss.
            output_hidden_states: Whether to return all hidden states.
            return_dict: Whether to return a ``CausalLMOutputWithPast``.
            **kwargs: Accepted and ignored. Callers such as ``generate()``
                pass extra arguments (e.g. ``use_cache``,
                ``past_key_values``, ``output_attentions``) that this model
                has no use for.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss``, ``logits`` and
            ``hidden_states``, or the equivalent tuple when ``return_dict``
            is false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        hidden_states = outputs['last_hidden_state']
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

            # Add auxiliary loss
            if 'aux_loss' in outputs:
                loss = loss + outputs['aux_loss']

        if not return_dict:
            output = (logits,) + tuple(v for k, v in outputs.items() if k != 'last_hidden_state')
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            hidden_states=outputs.get('hidden_states'),
        )
408
+
409
+
410
# Register the model
# Importing this module registers the "myolmoe" model type and its config /
# causal-LM classes with the transformers Auto* factories.
from transformers import AutoConfig, AutoModelForCausalLM
AutoConfig.register("myolmoe", MyOLMoEConfig)
AutoModelForCausalLM.register(MyOLMoEConfig, MyOLMoEForCausalLM)