Charlie81 committed on
Commit
9d66be3
·
1 Parent(s): 08ab65a

init no routing

Browse files
.gitignore ADDED
File without changes
README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
myolmoe/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
# The model implementation added in this commit lives in
# `myolmoe/modeling_myolmoe.py`; importing from `.modeling_olmoe` would fail
# with ModuleNotFoundError.
from .modeling_myolmoe import MyOlmoeForCausalLM
myolmoe/config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "OlmoeForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "clip_qkv": null,
8
+ "eos_token_id": 50279,
9
+ "hidden_act": "silu",
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 1024,
13
+ "max_position_embeddings": 4096,
14
+ "model_type": "olmoe",
15
+ "norm_topk_prob": false,
16
+ "num_attention_heads": 16,
17
+ "num_experts": 64,
18
+ "num_experts_per_tok": 2,
19
+ "num_hidden_layers": 16,
20
+ "num_key_value_heads": 16,
21
+ "output_router_logits": false,
22
+ "pad_token_id": 1,
23
+ "rms_norm_eps": 1e-05,
24
+ "rope_scaling": null,
25
+ "rope_theta": 10000.0,
26
+ "router_aux_loss_coef": 0.01,
27
+ "tie_word_embeddings": false,
28
+ "torch_dtype": "float32",
29
+ "transformers_version": "4.52.4",
30
+ "use_cache": true,
31
+ "vocab_size": 50304
32
+ }
myolmoe/generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "eos_token_id": 50279,
4
+ "pad_token_id": 1,
5
+ "transformers_version": "4.52.4"
6
+ }
myolmoe/model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e3cff7e367794685c241169072c940d200918617d5e2813f1c387dff52d845e
3
+ size 4997744872
myolmoe/model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:15ef5c730ee3cfed7199498788cd2faf337203fc74b529625e7502cdd759f4a7
3
+ size 4997235176
myolmoe/model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9abac4ac1b55c9adabac721a02fa39971f103eea9a65c310972b1246de76e04
3
+ size 3843741912
myolmoe/model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
myolmoe/modeling_myolmoe.py ADDED
@@ -0,0 +1,955 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.distributions import Categorical

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.generation.utils import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import (
    _flash_attention_forward,
    flash_attn_supports_top_left_mask,
)
from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import PreTrainedModel
from transformers.models.olmoe.configuration_olmoe import OlmoeConfig
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import logging
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    """Compute the Switch-Transformer-style auxiliary load-balancing loss.

    Args:
        gate_logits: Tuple (one entry per layer) of router logits, each of
            shape [batch_size * sequence_length, num_experts]. Anything that
            is not a tuple (including None) short-circuits to 0.
        num_experts: Total number of experts the router chooses from.
        top_k: Number of experts selected per token.
        attention_mask: Optional [batch_size, sequence_length] mask; when
            given, padded tokens are excluded from the expert statistics.

    Returns:
        Scalar tensor with the auxiliary loss, or the int 0 when no usable
        logits were provided.
    """
    # Original code re-checked `isinstance(gate_logits, tuple)` after this
    # guard; that branch was always true, so the body is dedented here.
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0
    # Gather all layers' logits onto one device so the statistics below are
    # computed over layers and tokens at once.
    compute_device = gate_logits[0].device
    concatenated_gate_logits = torch.cat(
        [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
    )
    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    # [tokens, top_k, num_experts] one-hot of which experts each token picked.
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
    if attention_mask is None:
        # Unmasked case: plain averages over all (layer, token) rows.
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (
            batch_size * sequence_length
        )
        # Broadcast the [batch, seq] mask so it lines up with the
        # concatenated per-layer, per-token expert assignments.
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand(
                (num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
            )
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )
        tokens_per_expert = torch.sum(
            expert_mask.float() * expert_attention_mask, dim=0
        ) / torch.sum(expert_attention_mask, dim=0)
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )
        router_prob_per_expert = torch.sum(
            routing_weights * router_per_expert_attention_mask, dim=0
        ) / torch.sum(router_per_expert_attention_mask, dim=0)
    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts
67
+
68
+
69
class OlmoeRMSNorm(nn.Module):
    """Root-mean-square layer norm: no mean subtraction, no bias, learnable scale."""

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        # Per-channel scale, initialised to the identity.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalise in float32 for numerical stability, then cast back.
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalised = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
84
+
85
+
86
+ ALL_LAYERNORM_LAYERS.append(OlmoeRMSNorm)
87
+
88
+
89
class OlmoeRotaryEmbedding(nn.Module):
    """Precomputes rotary-embedding inverse frequencies and emits cos/sin tables."""

    def __init__(self, config: OlmoeConfig, device=None):
        super().__init__()
        # Pick the RoPE variant from config.rope_scaling, falling back to the
        # legacy "type" key, else the default (unscaled) implementation.
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get(
                "rope_type", config.rope_scaling.get("type")
            )
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: recomputed from config on load rather than stored
        # in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        # [batch, dim/2, 1] @ [batch, 1, seq]: outer product of inverse
        # frequencies with positions.
        inv_freq_expanded = (
            self.inv_freq[None, :, None]
            .float()
            .expand(position_ids.shape[0], -1, 1)
            .to(x.device)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # Force the matmul to run in float32 (autocast disabled); "mps" has
        # no autocast context so fall back to "cpu".
        device_type = (
            x.device.type
            if isinstance(x.device.type, str) and x.device.type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(1, 2)
            # Duplicate the half-dim frequencies so cos/sin span head_dim.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
129
+
130
+
131
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the (moved) second half."""
    half = x.shape[-1] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]
    return torch.cat((-second_half, first_half), dim=-1)
135
+
136
+
137
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query/key tensors by the given cos/sin position tables.

    `unsqueeze_dim` inserts a broadcast axis (the heads axis by default) so
    [batch, seq, dim] tables line up with [batch, heads, seq, dim] states.
    `position_ids` is accepted for interface compatibility but unused.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # Inlined rotate_half: (-x2, x1) over the split last dimension.
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    rotated_q = q * cos_b + _rotate(q) * sin_b
    rotated_k = k * cos_b + _rotate(k) * sin_b
    return rotated_q, rotated_k
143
+
144
+
145
class OlmoeMLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block used as each MoE expert."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Two parallel up projections (gate + up) and one down projection,
        # all bias-free.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # down( act(gate(x)) * up(x) )
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
159
+
160
+
161
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head `n_rep` times along the head axis.

    Equivalent to torch.repeat_interleave(hidden_states, n_rep, dim=1):
    turns [batch, kv_heads, seq, head_dim] into
    [batch, kv_heads * n_rep, seq, head_dim] for grouped-query attention.
    """
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
169
+
170
+
171
class OlmoeAttention(nn.Module):
    """Eager multi-head attention with QK RMS-norm, optional QKV clipping and RoPE."""

    def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Query heads per key/value head (grouped-query attention factor).
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.hidden_size, self.hidden_size, bias=config.attention_bias
        )
        # QK-norm is applied to the full projected q/k vectors (not per head).
        self.q_norm = OlmoeRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.k_norm = OlmoeRMSNorm(
            (self.hidden_size // self.num_heads) * self.num_key_value_heads,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run eager attention; returns (output, attn_weights or None, past_key_value)."""
        bsz, q_len, _ = hidden_states.size()
        # Project then RMS-normalise queries and keys.
        query_states = self.q_norm(self.q_proj(hidden_states))
        key_states = self.k_norm(self.k_proj(hidden_states))
        value_states = self.v_proj(hidden_states)
        if self.config.clip_qkv is not None:
            # In-place clamp of all three projections to [-clip_qkv, clip_qkv].
            query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
            key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
            value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
        # [bsz, q_len, hidden] -> [bsz, heads, q_len, head_dim]
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        # Rotary tables are precomputed by the caller (model-level rotary_emb).
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )
        # Expand grouped KV heads to match the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            # Mask is additive (-inf on disallowed positions), sliced to the
            # current key length.
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask
        # Softmax in float32 for stability, then cast back to the query dtype.
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)
        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
282
+
283
+
284
class OlmoeFlashAttention2(OlmoeAttention):
    """OlmoeAttention variant dispatching to FlashAttention-2 kernels.

    NOTE(review): `flash_attn_supports_top_left_mask` and
    `_flash_attention_forward` are not imported anywhere in this file —
    presumably they should come from
    `transformers.modeling_flash_attention_utils`; confirm before selecting
    the "flash_attention_2" implementation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Older flash-attn releases use a top-left-aligned causal mask;
        # record which convention the installed version expects.
        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # Flash attention never materialises the weight matrix, so
        # output_attentions is forcibly disabled.
        output_attentions = False
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_norm(self.q_proj(hidden_states))
        key_states = self.k_norm(self.k_proj(hidden_states))
        value_states = self.v_proj(hidden_states)
        if self.config.clip_qkv is not None:
            query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
            key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
            value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )
        # Flash attention expects [batch, seq, heads, head_dim] layout.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        dropout_rate = self.attention_dropout if self.training else 0.0
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # Flash attention only supports fp16/bf16; pick a target dtype
            # and downcast, warning if we had to guess from the weights.
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype
                logger.warning_once(
                    f"The input hidden states seems to be silently casted in float32, this might be related to"
                    f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                    f" {target_dtype}."
                )
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)
        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
363
+
364
+
365
class OlmoeSdpaAttention(OlmoeAttention):
    """OlmoeAttention variant using torch.nn.functional.scaled_dot_product_attention."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # SDPA cannot return attention weights; fall back to the eager path.
            logger.warning_once(
                "OlmoeModel is using OlmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_norm(self.q_proj(hidden_states))
        key_states = self.k_norm(self.k_proj(hidden_states))
        value_states = self.v_proj(hidden_states)
        if self.config.clip_qkv is not None:
            query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
            key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
            value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
        # CUDA SDPA with a custom mask requires contiguous inputs.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()
        # Use the built-in causal masking only when no explicit mask is given
        # and we are not decoding a single token.
        is_causal = True if causal_mask is None and q_len > 1 else False
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        # SDPA never exposes attention weights, hence the None.
        return attn_output, None, past_key_value
440
+
441
+
442
# Maps config._attn_implementation to the attention class instantiated in
# each decoder layer.
OLMOE_ATTENTION_CLASSES = {
    "eager": OlmoeAttention,
    "flash_attention_2": OlmoeFlashAttention2,
    "sdpa": OlmoeSdpaAttention,
}
447
+
448
+
449
class OlmoeSparseMoeBlock(nn.Module):
    """Token-choice MoE layer: a linear router plus `num_experts` MLP experts.

    Each token is dispatched to its `top_k` highest-probability experts and
    the expert outputs are summed, weighted by the (optionally renormalised)
    routing probabilities.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        # Knobs for alternative routing schemes; only plain top-k routing is
        # implemented in forward() at the moment.
        self.routing_type = getattr(config, "routing_type", "topk")  # default to topk
        self.n_step = getattr(config, "nth_step", 2)  # used in nth-descending
        self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
        self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route tokens through their top-k experts.

        Args:
            hidden_states: [batch, seq, hidden] activations.

        Returns:
            Tuple of the mixed expert output [batch, seq, hidden] and the raw
            router logits [batch * seq, num_experts].

        Note: the original annotation claimed `-> torch.Tensor` although the
        method has always returned this 2-tuple; the annotation is corrected
        here (no behavior change).
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        router_logits = self.gate(hidden_states)
        # Softmax in float32 so routing stays stable under half precision.
        routing_probs = F.softmax(router_logits, dim=1, dtype=torch.float)

        # === Routing ===
        routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)

        if self.norm_topk_prob:
            # Renormalise so each token's selected-expert weights sum to 1.
            routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

        routing_weights = routing_weights.to(hidden_states.dtype)
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # [num_experts, top_k, tokens] one-hot of which tokens picked each
        # expert, so every expert only processes its own tokens below.
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])
            if top_x.numel() == 0:
                # No tokens routed here; skip the expert entirely.
                continue
            current_state = hidden_states[top_x]
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
            # Scatter-add each expert's contribution back to its tokens.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

        final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
493
+
494
+
495
class OlmoeDecoderLayer(nn.Module):
    """One transformer block: pre-norm attention + pre-norm sparse-MoE MLP."""

    def __init__(self, config: OlmoeConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Attention backend is picked from config._attn_implementation.
        self.self_attn = OLMOE_ATTENTION_CLASSES[config._attn_implementation](
            config=config, layer_idx=layer_idx
        )
        self.mlp = OlmoeSparseMoeBlock(config, layer_idx)
        self.input_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = OlmoeRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """Returns (hidden_states, [attn_weights], [present_key_value], [router_logits])
        with the optional entries appended in that order when requested."""
        # Self-attention sub-block (pre-norm + residual).
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states
        # MoE feed-forward sub-block (pre-norm + residual).
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states, router_logits = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        if output_router_logits:
            outputs += (router_logits,)
        return outputs
549
+
550
+
551
+
552
class OlmoePreTrainedModel(PreTrainedModel):
    """Base class wiring OlmoeConfig into transformers' loading/initialisation machinery."""

    config_class = OlmoeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Keep each decoder layer on a single device under model parallelism.
    _no_split_modules = ["OlmoeDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = False

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for linear/embedding weights; ones for RMSNorm."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, OlmoeRMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding token's embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
576
+
577
+
578
+
579
+ class OlmoeModel(OlmoePreTrainedModel):
580
+ def __init__(self, config: OlmoeConfig):
581
+ super().__init__(config)
582
+ self.padding_idx = config.pad_token_id
583
+ self.vocab_size = config.vocab_size
584
+ self.embed_tokens = nn.Embedding(
585
+ config.vocab_size, config.hidden_size, self.padding_idx
586
+ )
587
+ self.layers = nn.ModuleList(
588
+ [
589
+ OlmoeDecoderLayer(config, layer_idx)
590
+ for layer_idx in range(config.num_hidden_layers)
591
+ ]
592
+ )
593
+ self.norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
594
+ self.rotary_emb = OlmoeRotaryEmbedding(config=config)
595
+ self.gradient_checkpointing = False
596
+ self.post_init()
597
+
598
+ def get_input_embeddings(self):
599
+ return self.embed_tokens
600
+
601
+ def set_input_embeddings(self, value):
602
+ self.embed_tokens = value
603
+
604
+
605
+ def forward(
606
+ self,
607
+ input_ids: Optional[torch.LongTensor] = None,
608
+ attention_mask: Optional[torch.Tensor] = None,
609
+ position_ids: Optional[torch.LongTensor] = None,
610
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
611
+ inputs_embeds: Optional[torch.FloatTensor] = None,
612
+ use_cache: Optional[bool] = None,
613
+ output_attentions: Optional[bool] = None,
614
+ output_hidden_states: Optional[bool] = None,
615
+ output_router_logits: Optional[bool] = None,
616
+ return_dict: Optional[bool] = None,
617
+ cache_position: Optional[torch.LongTensor] = None,
618
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
619
+ output_attentions = (
620
+ output_attentions
621
+ if output_attentions is not None
622
+ else self.config.output_attentions
623
+ )
624
+ output_router_logits = (
625
+ output_router_logits
626
+ if output_router_logits is not None
627
+ else self.config.output_router_logits
628
+ )
629
+ output_hidden_states = (
630
+ output_hidden_states
631
+ if output_hidden_states is not None
632
+ else self.config.output_hidden_states
633
+ )
634
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
635
+ return_dict = (
636
+ return_dict if return_dict is not None else self.config.use_return_dict
637
+ )
638
+ if (input_ids is None) ^ (inputs_embeds is not None):
639
+ raise ValueError(
640
+ "You must specify exactly one of input_ids or inputs_embeds"
641
+ )
642
+ if self.gradient_checkpointing and self.training and use_cache:
643
+ logger.warning_once(
644
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
645
+ )
646
+ use_cache = False
647
+ if inputs_embeds is None:
648
+ inputs_embeds = self.embed_tokens(input_ids)
649
+ return_legacy_cache = False
650
+ if use_cache and not isinstance(past_key_values, Cache):
651
+ return_legacy_cache = True
652
+ if past_key_values is None:
653
+ past_key_values = DynamicCache()
654
+ else:
655
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
656
+ logger.warning_once(
657
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
658
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
659
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
660
+ )
661
+ if cache_position is None:
662
+ past_seen_tokens = (
663
+ past_key_values.get_seq_length() if past_key_values is not None else 0
664
+ )
665
+ cache_position = torch.arange(
666
+ past_seen_tokens,
667
+ past_seen_tokens + inputs_embeds.shape[1],
668
+ device=inputs_embeds.device,
669
+ )
670
+ if position_ids is None:
671
+ position_ids = cache_position.unsqueeze(0)
672
+ causal_mask = self._update_causal_mask(
673
+ attention_mask,
674
+ inputs_embeds,
675
+ cache_position,
676
+ past_key_values,
677
+ output_attentions,
678
+ )
679
+ hidden_states = inputs_embeds
680
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
681
+ all_hidden_states = () if output_hidden_states else None
682
+ all_self_attns = () if output_attentions else None
683
+ all_router_logits = () if output_router_logits else None
684
+ next_decoder_cache = None
685
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
686
+ if output_hidden_states:
687
+ all_hidden_states += (hidden_states,)
688
+ if self.gradient_checkpointing and self.training:
689
+ layer_outputs = self._gradient_checkpointing_func(
690
+ decoder_layer.__call__,
691
+ hidden_states,
692
+ causal_mask,
693
+ position_ids,
694
+ past_key_values,
695
+ output_attentions,
696
+ output_router_logits,
697
+ use_cache,
698
+ cache_position,
699
+ position_embeddings,
700
+ )
701
+ else:
702
+ layer_outputs = decoder_layer(
703
+ hidden_states,
704
+ attention_mask=causal_mask,
705
+ position_ids=position_ids,
706
+ past_key_value=past_key_values,
707
+ output_attentions=output_attentions,
708
+ output_router_logits=output_router_logits,
709
+ use_cache=use_cache,
710
+ cache_position=cache_position,
711
+ position_embeddings=position_embeddings,
712
+ )
713
+ hidden_states = layer_outputs[0]
714
+ if use_cache:
715
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
716
+ if output_attentions:
717
+ all_self_attns += (layer_outputs[1],)
718
+ if output_router_logits and layer_outputs[-1] is not None:
719
+ all_router_logits += (layer_outputs[-1],)
720
+ hidden_states = self.norm(hidden_states)
721
+ if output_hidden_states:
722
+ all_hidden_states += (hidden_states,)
723
+ next_cache = next_decoder_cache if use_cache else None
724
+ if return_legacy_cache:
725
+ next_cache = next_cache.to_legacy_cache()
726
+ if not return_dict:
727
+ return tuple(
728
+ v
729
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
730
+ if v is not None
731
+ )
732
+ return MoeModelOutputWithPast(
733
+ last_hidden_state=hidden_states,
734
+ past_key_values=next_cache,
735
+ hidden_states=all_hidden_states,
736
+ attentions=all_self_attns,
737
+ router_logits=all_router_logits,
738
+ )
739
+
740
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """Build the 4D additive causal mask for this forward pass, or return
        early when the attention backend does not need one.

        Args:
            attention_mask: 2D padding mask ``(batch, seq)`` or an already
                4D mask; may be ``None``.
            input_tensor: The input embeddings; only its dtype/device/shape
                are used here.
            cache_position: Positions of the current tokens in the KV cache.
            past_key_values: Cache object (used for seq length / static size).
            output_attentions: When True the SDPA fast paths are disabled,
                since eager attention needs an explicit mask.

        Returns:
            ``None`` (backend applies causality itself), the raw 2D mask
            (flash-attention with padding), or a 4D float mask.
        """
        # Flash-attention applies causality internally: pass the 2D padding
        # mask through only if there is actual padding (a 0.0 entry).
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        using_static_cache = isinstance(past_key_values, StaticCache)
        # SDPA can use its built-in `is_causal` path and skip mask
        # materialization entirely in the common (non-static-cache) case.
        if (
            self.config._attn_implementation == "sdpa"
            and not using_static_cache
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None
        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        # Static caches have a fixed KV length; otherwise size the mask to the
        # provided attention_mask or to past + current (+1) tokens.
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
        # SDPA on these accelerators can produce NaNs for rows that are fully
        # masked; unmask such rows (memory-efficient path workaround).
        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype
            )
        return causal_mask
798
+
799
    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """Materialize a 4D additive causal mask ``(batch, 1, q_len, kv_len)``.

        Masked positions hold the dtype's minimum value; attendable positions
        hold 0. If ``attention_mask`` is already 4D it is returned as-is.

        Args:
            attention_mask: 2D padding mask ``(batch, kv_len)``, a 4D mask,
                or ``None``.
            sequence_length: Query length of the current step.
            target_length: Total KV length the mask must cover.
            dtype: Output dtype (determines the "minus infinity" fill value).
            device: Output device.
            cache_position: Absolute position of each query token; used so
                tokens may attend to all cached positions at or before them.
            batch_size: Leading dimension of the output.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # Caller supplied a ready-made 4D mask; trust it.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )
            if sequence_length != 1:
                # Strict upper triangle masked; diagonal and below visible.
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Zero out (i.e. allow) every KV position <= the token's absolute
            # cache position, so cached past tokens stay visible.
            causal_mask *= torch.arange(
                target_length, device=device
            ) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # expand() returns a view; copy before in-place edit
                mask_length = attention_mask.shape[-1]
                # Positions that are causally visible (0) AND padded (0) sum
                # to 0 -> re-mask them with min_dtype.
                padding_mask = (
                    causal_mask[:, :, :, :mask_length]
                    + attention_mask[:, None, None, :]
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)
        return causal_mask
838
+
839
+
840
class MyOlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
    """OLMoE causal-LM head: an `OlmoeModel` backbone plus an untied
    `lm_head` projection, with MoE auxiliary (load-balancing) loss support.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = OlmoeModel(config)
        self.vocab_size = config.vocab_size
        # config declares tie_word_embeddings=False, so lm_head has its own weights.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # MoE routing hyperparameters, cached for the aux-loss computation.
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding table."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding table."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM projection layer."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM projection layer."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Swap in a different decoder backbone."""
        self.model = decoder

    def get_decoder(self):
        """Return the decoder backbone."""
        return self.model


    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **loss_kwargs,
    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
        """Run the backbone, project to vocab logits, and optionally compute
        the LM loss plus the router load-balancing auxiliary loss.

        Args:
            labels: When provided, the cross-entropy LM loss is computed via
                ``self.loss_function``.
            logits_to_keep: If an int N, only the last N positions are
                projected (0 keeps all, since ``slice(-0, None)`` is the full
                slice); if a tensor, it is used directly as an index.
            **loss_kwargs: Forwarded to ``self.loss_function``.

        Returns:
            A ``MoeCausalLMOutputWithPast`` (or an equivalent tuple when
            ``return_dict`` is False).
        """
        # Resolve per-call flags against config defaults.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_router_logits = (
            output_router_logits
            if output_router_logits is not None
            else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = outputs[0]
        # Project only the requested suffix of positions to save memory
        # during generation (logits_to_keep=0 -> slice(0, None) -> all).
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int)
            else logits_to_keep
        )
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
        aux_loss = None
        if output_router_logits:
            # NOTE(review): when return_dict=False, the backbone's tuple
            # return (see OlmoeModel.forward) does not include router logits,
            # so `outputs[-1]` would actually be attentions/hidden states
            # there — confirm the tuple path before relying on it.
            aux_loss = load_balancing_loss_func(
                outputs.router_logits if return_dict else outputs[-1],
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                # Fold the load-balancing penalty into the training loss.
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
        if not return_dict:
            output = (logits,) + outputs[1:]
            if output_router_logits:
                output = (aux_loss,) + output
            return (loss,) + output if loss is not None else output
        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
954
+
955
+ __all__ = ["MyOlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"]
myolmoe/special_tokens_map.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "eos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "pad_token": {
10
+ "content": "<|padding|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ }
16
+ }
myolmoe/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
myolmoe/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "|||IP_ADDRESS|||",
8
+ "lstrip": false,
9
+ "normalized": true,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": false
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ },
206
+ "50277": {
207
+ "content": "|||EMAIL_ADDRESS|||",
208
+ "lstrip": false,
209
+ "normalized": true,
210
+ "rstrip": false,
211
+ "single_word": false,
212
+ "special": false
213
+ },
214
+ "50278": {
215
+ "content": "|||PHONE_NUMBER|||",
216
+ "lstrip": false,
217
+ "normalized": true,
218
+ "rstrip": false,
219
+ "single_word": false,
220
+ "special": false
221
+ },
222
+ "50279": {
223
+ "content": "<|endoftext|>",
224
+ "lstrip": false,
225
+ "normalized": false,
226
+ "rstrip": false,
227
+ "single_word": false,
228
+ "special": true
229
+ }
230
+ },
231
+ "bos_token": null,
232
+ "clean_up_tokenization_spaces": true,
233
+ "eos_token": "<|endoftext|>",
234
+ "extra_special_tokens": {},
235
+ "model_max_length": 1000000000000000019884624838656,
236
+ "pad_token": "<|padding|>",
237
+ "tokenizer_class": "GPTNeoXTokenizer",
238
+ "unk_token": null
239
+ }
oldreqs.txt ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.2.2
2
+ accelerate==1.4.0
3
+ ai2-olmo==0.6.0
4
+ ai2-olmo-core==0.1.0
5
+ aiohappyeyeballs==2.4.6
6
+ aiohttp==3.11.12
7
+ aiosignal==1.3.2
8
+ annotated-types==0.7.0
9
+ antlr4-python3-runtime==4.9.3
10
+ attrs==25.1.0
11
+ bitsandbytes==0.45.5
12
+ boto3==1.38.18
13
+ botocore==1.38.18
14
+ cached_path==1.7.3
15
+ cachetools==5.5.2
16
+ certifi==2025.1.31
17
+ chardet==5.2.0
18
+ charset-normalizer==3.4.1
19
+ click==8.2.1
20
+ colorama==0.4.6
21
+ DataProperty==1.1.0
22
+ datasets==3.3.1
23
+ dill==0.3.8
24
+ docstring_parser==0.16
25
+ einops==0.8.1
26
+ evaluate==0.4.3
27
+ filelock==3.17.0
28
+ flash_attn==2.7.4.post1
29
+ frozenlist==1.5.0
30
+ fsspec==2024.12.0
31
+ google-api-core==2.25.0rc1
32
+ google-auth==2.40.1
33
+ google-cloud-core==2.4.3
34
+ google-cloud-storage==2.19.0
35
+ google-crc32c==1.7.1
36
+ google-resumable-media==2.7.2
37
+ googleapis-common-protos==1.70.0
38
+ grpcio==1.71.0
39
+ huggingface-hub==0.31.2
40
+ idna==3.10
41
+ importlib_resources==6.5.2
42
+ Jinja2==3.1.5
43
+ jmespath==1.0.1
44
+ joblib==1.5.1
45
+ jsonlines==4.0.0
46
+ # Editable install with no version control (lm_eval==0.4.8)
47
+ -e /home/ianwu/SkipMoE/SkipPassMoE/lm-evaluation-harness
48
+ lxml==5.4.0
49
+ Markdown==3.8
50
+ markdown-it-py==3.0.0
51
+ MarkupSafe==3.0.2
52
+ mbstrdecoder==1.1.4
53
+ mdurl==0.1.2
54
+ more-itertools==10.7.0
55
+ mpmath==1.3.0
56
+ multidict==6.1.0
57
+ multiprocess==0.70.16
58
+ networkx==3.4.2
59
+ nltk==3.9.1
60
+ numexpr==2.10.2
61
+ numpy==1.26.4
62
+ nvidia-cublas-cu12==12.4.5.8
63
+ nvidia-cuda-cupti-cu12==12.4.127
64
+ nvidia-cuda-nvrtc-cu12==12.4.127
65
+ nvidia-cuda-runtime-cu12==12.4.127
66
+ nvidia-cudnn-cu12==9.1.0.70
67
+ nvidia-cufft-cu12==11.2.1.3
68
+ nvidia-curand-cu12==10.3.5.147
69
+ nvidia-cusolver-cu12==11.6.1.9
70
+ nvidia-cusparse-cu12==12.3.1.170
71
+ nvidia-cusparselt-cu12==0.6.2
72
+ nvidia-nccl-cu12==2.21.5
73
+ nvidia-nvjitlink-cu12==12.4.127
74
+ nvidia-nvtx-cu12==12.4.127
75
+ omegaconf==2.3.0
76
+ packaging==24.2
77
+ pandas==2.2.3
78
+ pathvalidate==3.2.3
79
+ peft==0.15.2
80
+ portalocker==3.1.1
81
+ propcache==0.2.1
82
+ proto-plus==1.26.1
83
+ protobuf==6.31.0
84
+ psutil==7.0.0
85
+ pyarrow==19.0.1
86
+ pyasn1==0.6.1
87
+ pyasn1_modules==0.4.2
88
+ pybind11==2.13.6
89
+ pydantic==2.11.4
90
+ pydantic_core==2.33.2
91
+ Pygments==2.19.1
92
+ pytablewriter==1.2.1
93
+ python-dateutil==2.9.0.post0
94
+ pytz==2025.1
95
+ PyYAML==6.0.2
96
+ regex==2024.11.6
97
+ requests==2.32.3
98
+ rich==13.9.4
99
+ rouge_score==0.1.2
100
+ rsa==4.9.1
101
+ s3transfer==0.12.0
102
+ sacrebleu==2.5.1
103
+ safetensors==0.5.2
104
+ scikit-learn==1.6.1
105
+ scipy==1.15.3
106
+ sentencepiece==0.2.0
107
+ shtab==1.7.1
108
+ six==1.17.0
109
+ sqlitedict==2.1.0
110
+ sympy==1.13.1
111
+ tabledata==1.3.4
112
+ tabulate==0.9.0
113
+ tcolorpy==0.1.7
114
+ tensorboard==2.19.0
115
+ tensorboard-data-server==0.7.2
116
+ threadpoolctl==3.6.0
117
+ tokenizers==0.21.0
118
+ torch==2.6.0
119
+ tqdm==4.67.1
120
+ tqdm-multiprocess==0.0.11
121
+ transformers==4.51.3
122
+ triton==3.2.0
123
+ trl==0.10.1
124
+ typeguard==4.4.2
125
+ typepy==1.3.4
126
+ typing-inspection==0.4.0
127
+ typing_extensions==4.12.2
128
+ tyro==0.9.15
129
+ tzdata==2025.1
130
+ urllib3==2.3.0
131
+ Werkzeug==3.1.3
132
+ word2number==1.1
133
+ xxhash==3.5.0
134
+ yarl==1.18.3
135
+ zstandard==0.23.0
requirements.txt ADDED
File without changes
scripts/downloadmodel.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
"""Fetch and print the configuration of the OLMoE-1B-7B model from the HF Hub."""

import os

# Must be set *before* transformers/huggingface_hub is imported so the hub
# client picks it up when configuring its HTTP session.
os.environ["HF_HUB_READ_TIMEOUT"] = "60"

from transformers import AutoConfig  # noqa: E402  (import deliberately after env setup)

# Fix: the original passed `timeout=60` to from_pretrained; AutoConfig treats
# unknown kwargs as config attributes, silently polluting the returned config.
# The read timeout is handled by HF_HUB_READ_TIMEOUT above instead.
config = AutoConfig.from_pretrained("allenai/OLMoE-1B-7B-0924")

print(config)
scripts/eval.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ eval.py - Evaluation script for OLMoE models using lm-evaluation-harness
4
+
5
+ This script supports evaluation of both:
6
+ 1. Standard Transformers OLMoE models
7
+ 2. Custom MyOLMoE models (uses top-k routing by default)
8
+
9
+ Usage Examples:
10
+ # Evaluate standard OLMoE model
11
+ python eval.py --model_type transformers --tasks mmlu hellaswag
12
+
13
+ # Evaluate custom MyOLMoE model
14
+ python eval.py --model_type custom --tasks mmlu
15
+ """
16
+
17
+ import argparse
18
+ import json
19
+ import os
20
+ import sys
21
+ import logging
22
+ from typing import Dict, List, Optional, Any
23
+ import numpy as np
24
+ import torch
25
+ from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
26
+
27
+ # lm-eval imports
28
+ from lm_eval import evaluator
29
+ from lm_eval.models.huggingface import HFLM
30
+
31
+ # Set up logging
32
+ logging.basicConfig(
33
+ level=logging.INFO,
34
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
35
+ )
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional explicit argument list. Defaults to ``None``, in which
            case argparse reads ``sys.argv[1:]`` (backward compatible with the
            original zero-argument call); passing a list makes the function
            testable without touching process state.

    Returns:
        argparse.Namespace with all evaluation settings.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate OLMoE models using lm-evaluation-harness",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Standard OLMoE evaluation
  python eval.py --model_type transformers --tasks mmlu arc_easy

  # Custom MyOLMoE evaluation (uses top-k routing by default)
  python eval.py --model_type custom --tasks mmlu hellaswag
        """
    )

    # Model arguments
    parser.add_argument(
        "--model_path",
        type=str,
        default="allenai/OLMoE-1B-7B-0924",
        help="Path or name of the pretrained model"
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="transformers",
        choices=["transformers", "custom"],
        help="Model type: 'transformers' for standard OLMoE, 'custom' for MyOLMoE"
    )
    parser.add_argument(
        "--custom_model_path",
        type=str,
        default="./myolmoe_model",
        help="Path to custom MyOLMoE model code (when using --model_type custom)"
    )

    # Evaluation arguments
    parser.add_argument(
        "--tasks",
        type=str,
        nargs="+",
        default=["mmlu"],
        help="Tasks to evaluate on (e.g., mmlu, hellaswag, arc_easy, gsm8k)"
    )
    parser.add_argument(
        "--num_fewshot",
        type=int,
        default=0,
        help="Number of few-shot examples"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size for evaluation"
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=None,
        help="Maximum batch size (auto if None)"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help="Device to use ('auto', 'cuda', 'cpu')"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "float16", "bfloat16", "float32"],
        help="Data type for model weights"
    )

    # Output arguments
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./eval_results",
        help="Directory to save evaluation results"
    )
    parser.add_argument(
        "--output_filename",
        type=str,
        default=None,
        help="Custom filename for results (auto-generated if not provided)"
    )

    # Additional arguments
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Limit number of examples per task (for testing)"
    )
    parser.add_argument(
        "--write_out",
        action="store_true",
        help="Write out individual predictions to files"
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Trust remote code when loading model"
    )
    parser.add_argument(
        "--verbosity",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging verbosity level"
    )

    return parser.parse_args(argv)
155
+
156
+
157
def load_transformers_model(args) -> HFLM:
    """
    Load standard Transformers OLMoE model.

    Args:
        args: Parsed command line arguments

    Returns:
        HFLM: Wrapped model ready for evaluation
    """
    logger.info(f"Loading Transformers OLMoE model: {args.model_path}")

    # Collect constructor settings first, then build the HFLM wrapper in one go.
    hflm_kwargs = {
        "pretrained": args.model_path,
        "device": args.device,
        "batch_size": args.batch_size,
        "max_batch_size": args.max_batch_size,
        "dtype": args.dtype,
        "trust_remote_code": args.trust_remote_code,
    }
    model = HFLM(**hflm_kwargs)

    logger.info("Transformers model loaded successfully")
    return model
181
+
182
+
183
def load_custom_model(args) -> HFLM:
    """
    Load custom MyOLMoE model (uses top-k routing by default).

    Adds ``args.custom_model_path`` to ``sys.path``, imports the custom model
    class from there, loads the pretrained weights, and wraps the result in
    an lm-eval ``HFLM`` adapter.

    Args:
        args: Parsed command line arguments

    Returns:
        HFLM: Wrapped model ready for evaluation

    Raises:
        ImportError: If the custom model module cannot be imported.
    """
    logger.info(f"Loading custom MyOLMoE model: {args.model_path}")
    logger.info("Using top-k routing (default)")

    # Add custom model path to Python path
    if os.path.exists(args.custom_model_path):
        sys.path.insert(0, args.custom_model_path)
        logger.info(f"Added {args.custom_model_path} to Python path")
    else:
        logger.warning(f"Custom model path not found: {args.custom_model_path}")

    try:
        # NOTE(review): this imports `modeling_myolmoe`, but the repo package
        # is `myolmoe/` and its __init__ imports from `.modeling_olmoe` —
        # confirm the expected module name / directory layout at
        # args.custom_model_path before relying on this.
        from modeling_myolmoe import MyOlmoeForCausalLM
        logger.info("Successfully imported MyOlmoeForCausalLM")
    except ImportError as e:
        logger.error(f"Failed to import custom model: {e}")
        logger.error("Make sure the custom model code is available in the specified path")
        raise

    # Load model configuration
    config = AutoConfig.from_pretrained(
        args.model_path,
        trust_remote_code=args.trust_remote_code
    )

    logger.info("Model will use default top-k routing configuration")

    # Determine torch dtype ("auto" is passed through to from_pretrained).
    if args.dtype == "auto":
        torch_dtype = "auto"
    else:
        torch_dtype = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32
        }[args.dtype]

    # Load the custom model in eval mode (disables dropout etc.).
    hf_model = MyOlmoeForCausalLM.from_pretrained(
        args.model_path,
        config=config,
        torch_dtype=torch_dtype,
        device_map="auto" if args.device == "auto" else None,
        trust_remote_code=args.trust_remote_code
    ).eval()

    # Wrap in HFLM so lm-eval can drive the already-instantiated model.
    model = HFLM(
        pretrained=hf_model,
        device=args.device,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        dtype=args.dtype
    )

    logger.info("Custom model loaded successfully")
    return model
250
+
251
+
252
def validate_model_config(model_path: str, trust_remote_code: bool = False) -> Dict[str, Any]:
    """
    Validate model configuration and return key information.

    Args:
        model_path: Path to the model
        trust_remote_code: Whether to trust remote code

    Returns:
        Dict containing model configuration information
    """
    try:
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
        # Loading the tokenizer doubles as a sanity check that it resolves.
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=trust_remote_code)

        # (report key, config attribute, fallback) triples for the summary.
        probes = [
            ("model_type", "model_type", "unknown"),
            ("vocab_size", "vocab_size", "unknown"),
            ("hidden_size", "hidden_size", "unknown"),
            ("num_layers", "num_hidden_layers", "unknown"),
            ("num_experts", "num_experts", "not specified"),
        ]
        model_info = {
            key: getattr(config, attr, fallback) for key, attr, fallback in probes
        }

        logger.info("Model validation successful:")
        for key, value in model_info.items():
            logger.info(f"  {key}: {value}")

        return model_info

    except Exception as e:
        logger.error(f"Model validation failed: {e}")
        raise
284
+
285
+
286
def make_serializable(obj: Any) -> Any:
    """
    Convert objects to JSON-serializable format.

    Recurses through dicts/lists/tuples and converts numpy and torch scalar,
    array, and dtype objects into plain Python builtins. Anything already
    serializable is returned unchanged.

    Args:
        obj: Object to convert

    Returns:
        JSON-serializable version of the object
    """
    if isinstance(obj, dict):
        return {k: make_serializable(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [make_serializable(v) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(make_serializable(v) for v in obj)
    elif isinstance(obj, np.ndarray):
        # Fix: ndarrays previously fell through unchanged and made
        # json.dump fail on any array-valued metric.
        return obj.tolist()
    elif isinstance(obj, (np.integer, np.floating)):
        return obj.item()
    elif isinstance(obj, np.dtype):
        return str(obj)
    elif isinstance(obj, torch.Tensor):
        return obj.tolist()
    elif isinstance(obj, torch.dtype):
        return str(obj)
    else:
        return obj
312
+
313
+
314
def run_evaluation(args) -> Dict[str, Any]:
    """
    Run evaluation on the specified model.

    Args:
        args: Parsed command line arguments

    Returns:
        Dict containing evaluation results
    """
    logger.info("Starting evaluation...")

    # Validate model first
    validate_model_config(args.model_path, args.trust_remote_code)

    # Dispatch to the matching loader instead of an if/elif chain.
    loaders = {
        "transformers": load_transformers_model,
        "custom": load_custom_model,
    }
    if args.model_type not in loaders:
        raise ValueError(f"Unknown model type: {args.model_type}")
    model = loaders[args.model_type](args)

    # Run evaluation
    logger.info(f"Running evaluation on tasks: {args.tasks}")
    logger.info(f"Few-shot examples: {args.num_fewshot}")
    logger.info(f"Batch size: {args.batch_size}")

    results = evaluator.simple_evaluate(
        model=model,
        tasks=args.tasks,
        num_fewshot=args.num_fewshot,
        limit=args.limit,
        write_out=args.write_out,
    )

    logger.info("Evaluation completed successfully")
    return results
352
+
353
+
354
def save_results(results: Dict[str, Any], args) -> str:
    """
    Save evaluation results to file.

    Args:
        results: Evaluation results
        args: Parsed command line arguments

    Returns:
        str: Path to saved results file
    """
    os.makedirs(args.output_dir, exist_ok=True)

    # Generate filename if not provided
    if args.output_filename is not None:
        filename = args.output_filename
    else:
        model_name = os.path.basename(args.model_path.rstrip('/'))
        tasks_str = "_".join(args.tasks[:3])
        extra = len(args.tasks) - 3
        if extra > 0:
            tasks_str += f"_and_{extra}_more"
        variant = "custom" if args.model_type == "custom" else "transformers"
        filename = f"{model_name}_{variant}_{tasks_str}_results.json"

    if not filename.endswith('.json'):
        filename += '.json'

    output_path = os.path.join(args.output_dir, filename)

    # Record the run settings alongside the raw results.
    metadata = {
        "model_path": args.model_path,
        "model_type": args.model_type,
        "tasks": args.tasks,
        "num_fewshot": args.num_fewshot,
        "batch_size": args.batch_size,
        "device": args.device,
        "dtype": args.dtype,
        "limit": args.limit,
    }
    if args.model_type == "custom":
        # Routing info only applies to the custom model variant.
        metadata["routing_type"] = "top-k (default)"

    payload = make_serializable({"metadata": metadata, "results": results})

    with open(output_path, 'w') as f:
        json.dump(payload, f, indent=2)

    logger.info(f"Results saved to {output_path}")
    return output_path
416
+
417
+
418
+ def print_summary(results: Dict[str, Any], args) -> None:
419
+ """
420
+ Print a formatted summary of evaluation results.
421
+
422
+ Args:
423
+ results: Evaluation results
424
+ args: Parsed command line arguments
425
+ """
426
+ print(f"\n{'='*80}")
427
+ print(f"EVALUATION SUMMARY")
428
+ print(f"Model: {args.model_path}")
429
+ print(f"Type: {args.model_type.upper()}")
430
+ if args.model_type == "custom":
431
+ print(f"Routing: TOP-K (default)")
432
+ print(f"Tasks: {', '.join(args.tasks)}")
433
+ print(f"{'='*80}")
434
+
435
+ if "results" in results:
436
+ for task, metrics in results["results"].items():
437
+ if isinstance(metrics, dict):
438
+ print(f"\n📊 {task.upper()}:")
439
+ for metric, value in metrics.items():
440
+ if isinstance(value, (int, float)) and not metric.endswith('_stderr'):
441
+ stderr_key = f"{metric}_stderr"
442
+ stderr = metrics.get(stderr_key, 0)
443
+ print(f" {metric:.<20} {value:.4f} (±{stderr:.4f})")
444
+ else:
445
+ print("\n⚠️ No results found in evaluation output")
446
+
447
+ print(f"\n{'='*80}")
448
+
449
+
450
def main():
    """Main evaluation function."""
    args = parse_args()

    # Apply the requested logging verbosity to both root and module loggers.
    level = getattr(logging, args.verbosity.upper(), None)
    if isinstance(level, int):
        for target in (logging.getLogger(), logger):
            target.setLevel(level)

    banner = "=" * 80
    try:
        logger.info(banner)
        logger.info("Starting OLMoE Model Evaluation")
        logger.info(banner)

        results = run_evaluation(args)
        output_path = save_results(results, args)
        print_summary(results, args)

        logger.info(f"✅ Evaluation completed successfully!")
        logger.info(f"📁 Results saved to: {output_path}")

    except KeyboardInterrupt:
        logger.info("Evaluation interrupted by user")
        sys.exit(1)
    except Exception as e:
        logger.error(f"❌ Evaluation failed: {e}")
        logger.debug("Full traceback:", exc_info=True)
        sys.exit(1)
484
+
485
+
486
+ if __name__ == "__main__":
487
+ main()
scripts/hello.py ADDED
@@ -0,0 +1 @@
 
 
1
+ print("hello helo")