Charlie81 commited on
Commit
39baffa
·
1 Parent(s): 2daadcc

total repo overhaul

Browse files
__init__.py DELETED
File without changes
myolmoe/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .modeling_olmoe import OLMoEForCausalLM
myolmoe/config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "OlmoeForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "clip_qkv": null,
8
+ "eos_token_id": 50279,
9
+ "hidden_act": "silu",
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 1024,
13
+ "max_position_embeddings": 4096,
14
+ "model_type": "olmoe",
15
+ "norm_topk_prob": false,
16
+ "num_attention_heads": 16,
17
+ "num_experts": 64,
18
+ "num_experts_per_tok": 8,
19
+ "num_hidden_layers": 16,
20
+ "num_key_value_heads": 16,
21
+ "output_router_logits": false,
22
+ "pad_token_id": 1,
23
+ "rms_norm_eps": 1e-05,
24
+ "rope_scaling": null,
25
+ "rope_theta": 10000.0,
26
+ "router_aux_loss_coef": 0.01,
27
+ "tie_word_embeddings": false,
28
+ "torch_dtype": "float32",
29
+ "transformers_version": "4.52.4",
30
+ "use_cache": true,
31
+ "vocab_size": 50304
32
+ }
myolmoe/generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "eos_token_id": 50279,
4
+ "pad_token_id": 1,
5
+ "transformers_version": "4.52.4"
6
+ }
myolmoe/model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e3cff7e367794685c241169072c940d200918617d5e2813f1c387dff52d845e
3
+ size 4997744872
myolmoe/model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:15ef5c730ee3cfed7199498788cd2faf337203fc74b529625e7502cdd759f4a7
3
+ size 4997235176
myolmoe/model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9abac4ac1b55c9adabac721a02fa39971f103eea9a65c310972b1246de76e04
3
+ size 3843741912
modeling_myolmoe.py → myolmoe/modeling_myolmoe.py RENAMED
@@ -1,26 +1,25 @@
1
- """PyTorch MyOLMoE model with custom routing mechanisms."""
2
-
3
  import math
4
  from typing import List, Optional, Tuple, Union
5
-
6
  import torch
7
  import torch.nn.functional as F
8
  import torch.utils.checkpoint
9
  from torch import nn
10
  from torch.distributions import Categorical
11
-
12
  from transformers.activations import ACT2FN
13
  from transformers.cache_utils import Cache, DynamicCache
14
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
15
- from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 
 
 
16
  from transformers.modeling_utils import PreTrainedModel
17
  from transformers.models.olmoe.configuration_olmoe import OlmoeConfig
18
  from transformers.utils import logging
19
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
20
 
21
-
22
  logger = logging.get_logger(__name__)
23
 
 
24
  def load_balancing_loss_func(
25
  gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
26
  num_experts: Optional[int] = None,
@@ -29,77 +28,195 @@ def load_balancing_loss_func(
29
  ) -> Union[torch.Tensor, int]:
30
  if gate_logits is None or not isinstance(gate_logits, tuple):
31
  return 0
32
-
33
  if isinstance(gate_logits, tuple):
34
  compute_device = gate_logits[0].device
35
- concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
36
-
 
37
  routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
38
-
39
  _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
40
-
41
  expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
42
-
43
  if attention_mask is None:
44
- # Compute the percentage of tokens routed to each experts
45
  tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
46
-
47
- # Compute the average probability of routing to these experts
48
  router_prob_per_expert = torch.mean(routing_weights, dim=0)
49
  else:
50
  batch_size, sequence_length = attention_mask.shape
51
- num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
52
-
53
- # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
54
  expert_attention_mask = (
55
  attention_mask[None, :, :, None, None]
56
- .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
 
 
57
  .reshape(-1, top_k, num_experts)
58
  .to(compute_device)
59
  )
60
-
61
- # Compute the percentage of tokens routed to each experts
62
- tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
63
- expert_attention_mask, dim=0
64
- )
65
-
66
- # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
67
  router_per_expert_attention_mask = (
68
  attention_mask[None, :, :, None]
69
  .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
70
  .reshape(-1, num_experts)
71
  .to(compute_device)
72
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
- # Compute the average probability of routing to these experts
75
- router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
76
- router_per_expert_attention_mask, dim=0
 
 
 
 
 
77
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
80
- return overall_loss * num_experts
81
 
82
  class OlmoeAttention(nn.Module):
83
- """Multi-headed attention from 'Attention Is All You Need' paper"""
84
-
85
  def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None):
86
  super().__init__()
87
  self.config = config
88
  self.layer_idx = layer_idx
 
 
 
 
 
 
 
89
  self.hidden_size = config.hidden_size
90
  self.num_heads = config.num_attention_heads
91
  self.head_dim = self.hidden_size // self.num_heads
92
  self.num_key_value_heads = config.num_key_value_heads
93
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
94
- self.attention_dropout = config.attention_dropout
95
-
96
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
97
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
98
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
99
- self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  self.q_norm = OlmoeRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
101
- self.k_norm = OlmoeRMSNorm((self.hidden_size // self.num_heads) * self.num_key_value_heads, eps=config.rms_norm_eps)
102
-
 
 
 
103
  def forward(
104
  self,
105
  hidden_states: torch.Tensor,
@@ -113,274 +230,283 @@ class OlmoeAttention(nn.Module):
113
  **kwargs,
114
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
115
  bsz, q_len, _ = hidden_states.size()
116
-
117
  query_states = self.q_norm(self.q_proj(hidden_states))
118
  key_states = self.k_norm(self.k_proj(hidden_states))
119
  value_states = self.v_proj(hidden_states)
120
-
121
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
122
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
123
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
124
-
 
 
 
 
 
 
 
 
125
  cos, sin = position_embeddings
126
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
127
-
 
128
  if past_key_value is not None:
129
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
130
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
131
-
 
132
  key_states = repeat_kv(key_states, self.num_key_value_groups)
133
  value_states = repeat_kv(value_states, self.num_key_value_groups)
134
-
135
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
136
-
137
  if attention_mask is not None:
138
- attn_weights = attn_weights + attention_mask
139
-
140
- # upcast attention to fp32
141
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
142
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
 
 
 
143
  attn_output = torch.matmul(attn_weights, value_states)
144
-
145
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
146
  raise ValueError(
147
  f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
148
  f" {attn_output.size()}"
149
  )
150
-
151
  attn_output = attn_output.transpose(1, 2).contiguous()
152
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
153
  attn_output = self.o_proj(attn_output)
154
-
155
  if not output_attentions:
156
  attn_weights = None
 
 
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  return attn_output, attn_weights, past_key_value
159
 
160
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
161
- """Repeat key/value heads for grouped query attention"""
162
- batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape
163
- if n_rep == 1:
164
- return hidden_states
165
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim)
166
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)
167
 
168
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
169
- """Apply rotary position embedding to query and key tensors"""
170
- cos = cos.unsqueeze(1)
171
- sin = sin.unsqueeze(1)
172
- q_embed = (q * cos) + (rotate_half(q) * sin)
173
- k_embed = (k * cos) + (rotate_half(k) * sin)
174
- return q_embed, k_embed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
- def rotate_half(x):
177
- """Rotates half the hidden dims of the input."""
178
- x1 = x[..., : x.shape[-1] // 2]
179
- x2 = x[..., x.shape[-1] // 2 :]
180
- return torch.cat((-x2, x1), dim=-1)
181
 
182
- # Define the attention classes dictionary
183
  OLMOE_ATTENTION_CLASSES = {
184
  "eager": OlmoeAttention,
 
 
185
  }
186
 
187
- class OlmoeRMSNorm(nn.Module):
188
- """RMSNorm implementation matching the original OLMoE implementation"""
189
- def __init__(self, hidden_size, eps=1e-5):
190
- super().__init__()
191
- self.weight = nn.Parameter(torch.ones(hidden_size))
192
- self.variance_epsilon = eps
193
 
194
- def forward(self, hidden_states):
195
- input_dtype = hidden_states.dtype
196
- hidden_states = hidden_states.to(torch.float32)
197
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
198
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
199
- return self.weight * hidden_states.to(input_dtype)
200
-
201
- class OlmoeMLP(nn.Module):
202
- """Feed-forward network implementation"""
203
  def __init__(self, config):
204
- super().__init__()
205
- self.config = config
206
- self.hidden_size = config.hidden_size
207
- self.intermediate_size = config.intermediate_size
208
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
209
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
210
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
211
- self.act_fn = ACT2FN[config.hidden_act]
212
-
213
- def forward(self, x):
214
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
215
-
216
- class MyOLMoERouting(nn.Module):
217
- """Custom routing mechanism for MyOLMoE with different routing strategies."""
218
-
219
- def __init__(self, config: OlmoeConfig):
220
  super().__init__()
221
  self.num_experts = config.num_experts
222
  self.top_k = config.num_experts_per_tok
223
- self.hidden_size = config.hidden_size
224
- self.routing_type = getattr(config, "routing_type", "sparse")
225
- self.router_temperature = getattr(config, "router_temperature", 1.0)
226
-
227
- # Shared components
228
  self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
229
-
230
- # For non-deterministic routing
231
- self.gumbel_noise = getattr(config, "gumbel_noise", 0.1)
232
-
233
- def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
234
- batch_size, sequence_length, hidden_dim = hidden_states.shape
235
- hidden_states = hidden_states.view(-1, hidden_dim)
236
- router_logits = self.gate(hidden_states)
237
-
238
- # Always use softmax, even for "dense" routing
239
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
240
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
241
-
242
- if self.norm_topk_prob:
243
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
244
-
245
- routing_weights = routing_weights.to(hidden_states.dtype)
246
-
247
- if self.routing_type == "dense":
248
- # Dense routing - use all experts equally
249
- routing_weights = torch.ones_like(router_logits) / self.num_experts
250
- selected_experts = torch.topk(routing_weights, self.top_k, dim=-1).indices
251
-
252
- elif self.routing_type == "non_deterministic":
253
- # Non-deterministic routing with temperature and Gumbel noise
254
- if self.training:
255
- # Add Gumbel noise during training
256
- noise = torch.rand_like(router_logits) * self.gumbel_noise
257
- router_logits = router_logits + noise
258
-
259
- # Apply temperature scaling
260
- routing_weights = F.softmax(router_logits / self.router_temperature, dim=-1)
261
- selected_experts = torch.multinomial(routing_weights, self.top_k)
262
-
263
- else: # Default sparse routing
264
- # Standard sparse top-k routing
265
- routing_weights = F.softmax(router_logits, dim=-1)
266
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
267
-
268
- return routing_weights, selected_experts, router_logits
269
-
270
- class OlmoeRotaryEmbedding(nn.Module):
271
- def __init__(self, config: OlmoeConfig, device=None):
272
- super().__init__()
273
- # BC: "rope_type" was originally "type"
274
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
275
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
276
- else:
277
- self.rope_type = "default"
278
- self.max_seq_len_cached = config.max_position_embeddings
279
- self.original_max_seq_len = config.max_position_embeddings
280
-
281
- self.config = config
282
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
283
-
284
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
285
- self.register_buffer("inv_freq", inv_freq, persistent=False)
286
- self.original_inv_freq = self.inv_freq
287
-
288
- @torch.no_grad()
289
- @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
290
- def forward(self, x, position_ids):
291
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
292
- position_ids_expanded = position_ids[:, None, :].float()
293
-
294
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
295
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
296
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
297
- emb = torch.cat((freqs, freqs), dim=-1)
298
- cos = emb.cos() * self.attention_scaling
299
- sin = emb.sin() * self.attention_scaling
300
-
301
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
302
 
303
- class MyOLMoESparseMoeBlock(nn.Module):
304
- """Modified sparse MoE block with custom routing."""
305
-
306
- def __init__(self, config: OlmoeConfig):
307
- super().__init__()
308
- self.num_experts = config.num_experts
309
- self.top_k = config.num_experts_per_tok
310
- self.hidden_size = config.hidden_size
311
- self.intermediate_size = config.intermediate_size
312
- self.norm_topk_prob = config.norm_topk_prob
313
-
314
- # Custom routing mechanism
315
- self.router = MyOLMoERouting(config)
316
-
317
- # Expert networks
318
- self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
319
-
320
- def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
321
- print(f"DEBUG: MoE forward start - hidden_states shape: {hidden_states.shape}")
322
- batch_size, seq_len, hidden_dim = hidden_states.shape
323
- print("absolute precision")
324
  hidden_states = hidden_states.view(-1, hidden_dim)
325
-
326
- # Get routing weights and selected experts
327
- print(f"DEBUG: 123: {self.router(hidden_states).shape}")
328
- routing_weights, selected_experts, router_logits = self.router(hidden_states)
329
  router_logits = self.gate(hidden_states)
330
-
331
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
332
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
 
 
333
  if self.norm_topk_prob:
334
  routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
335
- # we cast back to the input dtype
336
  routing_weights = routing_weights.to(hidden_states.dtype)
337
- print(f"DEBUG: MoE forward mid - routing_weights shape: {routing_weights.shape}, selected_experts shape: {selected_experts.shape}")
338
-
339
  final_hidden_states = torch.zeros(
340
- (batch_size * seq_len, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
 
 
341
  )
342
-
343
- # One-hot expert mask
344
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts)
345
-
346
- # Dispatch to experts
347
  for expert_idx in range(self.num_experts):
348
  expert_layer = self.experts[expert_idx]
349
- idx, top_x = torch.where(expert_mask[:, :, expert_idx])
350
-
351
- if idx.shape[0] == 0:
352
- continue
353
-
354
- current_state = hidden_states[None, top_x].reshape(-1, self.hidden_size)
355
- current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
356
-
357
- final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
358
-
359
- final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, self.hidden_size)
360
-
361
- # Return 3 values: (hidden_states, router_logits, aux_loss)
362
- # For now, we'll return None for aux_loss - you can compute it here if needed
363
- aux_loss = None
364
- print(f"DEBUG: MoE forward returning - final_hidden_states shape: {final_hidden_states.shape}, router_logits shape: {router_logits.shape}, aux_loss: {aux_loss}")
365
- return final_hidden_states, router_logits, aux_loss
366
 
367
- class MyOLMoEDecoderLayer(nn.Module):
368
- """Modified decoder layer with custom MoE routing."""
369
-
370
  def __init__(self, config: OlmoeConfig, layer_idx: int):
371
  super().__init__()
372
  self.hidden_size = config.hidden_size
373
-
374
- # Self-attention
375
- self.self_attn = OLMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
376
-
377
- # Custom MoE layer
378
- self.mlp = MyOLMoESparseMoeBlock(config)
379
-
380
- # Layer norms
381
  self.input_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
382
- self.post_attention_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
383
-
 
 
384
  def forward(
385
  self,
386
  hidden_states: torch.Tensor,
@@ -392,12 +518,12 @@ class MyOLMoEDecoderLayer(nn.Module):
392
  use_cache: Optional[bool] = False,
393
  cache_position: Optional[torch.LongTensor] = None,
394
  position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
395
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache], Optional[torch.Tensor]]:
 
 
 
396
  residual = hidden_states
397
-
398
- # Self-attention
399
  hidden_states = self.input_layernorm(hidden_states)
400
- print(f"DEBUG: Before MoE call - hidden_states shape: {hidden_states.shape}")
401
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
402
  hidden_states=hidden_states,
403
  attention_mask=attention_mask,
@@ -407,32 +533,36 @@ class MyOLMoEDecoderLayer(nn.Module):
407
  use_cache=use_cache,
408
  cache_position=cache_position,
409
  position_embeddings=position_embeddings,
 
410
  )
411
- print(f"DEBUG: After MoE call - hidden_states shape: {hidden_states.shape}")
412
  hidden_states = residual + hidden_states
413
-
414
- # MoE layer
415
  residual = hidden_states
416
  hidden_states = self.post_attention_layernorm(hidden_states)
417
- hidden_states, router_logits, aux_loss = self.mlp(hidden_states)
418
  hidden_states = residual + hidden_states
419
- print(f"DEBUG: End Decoder call - hidden_states shape: {hidden_states.shape}")
420
-
421
- # Always return 4 values in consistent order
422
- return (
423
- hidden_states,
424
- self_attn_weights if output_attentions else None,
425
- present_key_value if use_cache else None,
426
- router_logits if output_router_logits else None
427
- )
428
 
429
- class MyOLMoEPreTrainedModel(PreTrainedModel):
 
 
430
  config_class = OlmoeConfig
431
  base_model_prefix = "model"
432
  supports_gradient_checkpointing = True
433
- _no_split_modules = ["MyOLMoEDecoderLayer"]
434
  _skip_keys_device_placement = ["past_key_values"]
435
-
 
 
 
 
 
436
  def _init_weights(self, module):
437
  std = self.config.initializer_range
438
  if isinstance(module, nn.Linear):
@@ -447,31 +577,33 @@ class MyOLMoEPreTrainedModel(PreTrainedModel):
447
  module.weight.data[module.padding_idx].zero_()
448
 
449
 
450
- class MyOLMoEModel(MyOLMoEPreTrainedModel):
451
- """Modified OLMoE model with custom routing."""
452
-
453
  def __init__(self, config: OlmoeConfig):
454
  super().__init__(config)
455
  self.padding_idx = config.pad_token_id
456
  self.vocab_size = config.vocab_size
457
-
458
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
459
  self.layers = nn.ModuleList(
460
- [MyOLMoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
 
 
 
461
  )
462
  self.norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
463
- self.rotary_emb = OlmoeRotaryEmbedding(config)
464
  self.gradient_checkpointing = False
465
-
466
- # Initialize weights
467
  self.post_init()
468
-
469
  def get_input_embeddings(self):
470
  return self.embed_tokens
471
-
472
  def set_input_embeddings(self, value):
473
  self.embed_tokens = value
474
-
 
475
  def forward(
476
  self,
477
  input_ids: Optional[torch.LongTensor] = None,
@@ -486,88 +618,119 @@ class MyOLMoEModel(MyOLMoEPreTrainedModel):
486
  return_dict: Optional[bool] = None,
487
  cache_position: Optional[torch.LongTensor] = None,
488
  ) -> Union[Tuple, MoeModelOutputWithPast]:
489
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
490
  output_router_logits = (
491
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
 
 
492
  )
493
  output_hidden_states = (
494
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
495
  )
496
  use_cache = use_cache if use_cache is not None else self.config.use_cache
497
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
498
-
 
 
 
 
 
 
 
 
 
 
499
  if inputs_embeds is None:
500
  inputs_embeds = self.embed_tokens(input_ids)
501
-
502
- if past_key_values is None:
503
- past_key_values = DynamicCache()
504
-
 
 
 
 
 
 
 
 
505
  if cache_position is None:
506
- past_seen_tokens = past_key_values.get_seq_length()
 
 
507
  cache_position = torch.arange(
508
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
 
 
509
  )
510
-
511
  if position_ids is None:
512
  position_ids = cache_position.unsqueeze(0)
513
-
514
  causal_mask = self._update_causal_mask(
515
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
 
 
 
 
516
  )
517
-
518
  hidden_states = inputs_embeds
519
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
520
-
521
  all_hidden_states = () if output_hidden_states else None
522
  all_self_attns = () if output_attentions else None
523
  all_router_logits = () if output_router_logits else None
524
  next_decoder_cache = None
525
-
526
- # In MyOLMoEModel.forward(), replace the layer processing loop with:
527
-
528
- for decoder_layer in self.layers:
529
  if output_hidden_states:
530
  all_hidden_states += (hidden_states,)
531
-
532
- layer_outputs = decoder_layer(
533
- hidden_states,
534
- attention_mask=causal_mask,
535
- position_ids=position_ids,
536
- past_key_value=past_key_values,
537
- output_attentions=output_attentions,
538
- output_router_logits=output_router_logits,
539
- use_cache=use_cache,
540
- cache_position=cache_position,
541
- position_embeddings=position_embeddings,
542
- )
543
-
544
- # Unpack the consistent 4-value return
545
- hidden_states, self_attn_weights, present_key_value, router_logits = layer_outputs
546
-
 
 
 
 
 
 
 
 
 
 
547
  if use_cache:
548
- next_decoder_cache = present_key_value
549
-
550
- if output_attentions and self_attn_weights is not None:
551
- all_self_attns += (self_attn_weights,)
552
-
553
- if output_router_logits and router_logits is not None:
554
- all_router_logits += (router_logits,)
555
-
556
  if output_router_logits and layer_outputs[-1] is not None:
557
  all_router_logits += (layer_outputs[-1],)
558
-
559
  hidden_states = self.norm(hidden_states)
560
-
561
  if output_hidden_states:
562
  all_hidden_states += (hidden_states,)
563
-
564
  next_cache = next_decoder_cache if use_cache else None
565
-
 
566
  if not return_dict:
567
  return tuple(
568
- v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] if v is not None
 
 
569
  )
570
-
571
  return MoeModelOutputWithPast(
572
  last_hidden_state=hidden_states,
573
  past_key_values=next_cache,
@@ -575,44 +738,139 @@ class MyOLMoEModel(MyOLMoEPreTrainedModel):
575
  attentions=all_self_attns,
576
  router_logits=all_router_logits,
577
  )
578
-
579
- def _update_causal_mask(self, attention_mask, input_tensor, cache_position, past_key_values, output_attentions):
580
- # Same as original implementation
581
- if attention_mask is not None and 0.0 in attention_mask:
582
- return attention_mask
583
- return None
584
 
585
-
586
- class MyOLMoEForCausalLM(MyOLMoEPreTrainedModel):
587
- """MyOLMoE model for causal language modeling with custom routing."""
588
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
589
  _tied_weights_keys = ["lm_head.weight"]
590
-
591
- def __init__(self, config: OlmoeConfig):
592
  super().__init__(config)
593
- self.model = MyOLMoEModel(config)
594
  self.vocab_size = config.vocab_size
595
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
596
-
597
  self.router_aux_loss_coef = config.router_aux_loss_coef
598
  self.num_experts = config.num_experts
599
  self.num_experts_per_tok = config.num_experts_per_tok
600
-
601
- # Initialize weights
602
  self.post_init()
603
-
604
  def get_input_embeddings(self):
605
  return self.model.embed_tokens
606
-
607
  def set_input_embeddings(self, value):
608
  self.model.embed_tokens = value
609
-
610
  def get_output_embeddings(self):
611
  return self.lm_head
612
-
613
  def set_output_embeddings(self, new_embeddings):
614
  self.lm_head = new_embeddings
615
-
 
 
 
 
 
 
 
616
  def forward(
617
  self,
618
  input_ids: Optional[torch.LongTensor] = None,
@@ -627,17 +885,27 @@ class MyOLMoEForCausalLM(MyOLMoEPreTrainedModel):
627
  output_router_logits: Optional[bool] = None,
628
  return_dict: Optional[bool] = None,
629
  cache_position: Optional[torch.LongTensor] = None,
 
 
630
  ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
631
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
632
  output_router_logits = (
633
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
 
 
634
  )
635
  output_hidden_states = (
636
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
637
  )
638
- use_cache = use_cache if use_cache is not None else self.config.use_cache
639
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
640
-
641
  outputs = self.model(
642
  input_ids=input_ids,
643
  attention_mask=attention_mask,
@@ -651,14 +919,16 @@ class MyOLMoEForCausalLM(MyOLMoEPreTrainedModel):
651
  return_dict=return_dict,
652
  cache_position=cache_position,
653
  )
654
-
655
  hidden_states = outputs[0]
656
- logits = self.lm_head(hidden_states)
657
-
 
 
 
 
658
  loss = None
659
  if labels is not None:
660
- loss = F.cross_entropy(logits.view(-1, self.vocab_size), labels.view(-1))
661
-
662
  aux_loss = None
663
  if output_router_logits:
664
  aux_loss = load_balancing_loss_func(
@@ -669,13 +939,11 @@ class MyOLMoEForCausalLM(MyOLMoEPreTrainedModel):
669
  )
670
  if labels is not None:
671
  loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
672
-
673
  if not return_dict:
674
  output = (logits,) + outputs[1:]
675
  if output_router_logits:
676
  output = (aux_loss,) + output
677
  return (loss,) + output if loss is not None else output
678
-
679
  return MoeCausalLMOutputWithPast(
680
  loss=loss,
681
  aux_loss=aux_loss,
@@ -685,4 +953,3 @@ class MyOLMoEForCausalLM(MyOLMoEPreTrainedModel):
685
  attentions=outputs.attentions,
686
  router_logits=outputs.router_logits,
687
  )
688
-
 
 
 
1
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.distributions import Categorical

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import (
    _flash_attention_forward,
    flash_attn_supports_top_left_mask,
)
from transformers.modeling_outputs import (
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import PreTrainedModel
from transformers.models.olmoe.configuration_olmoe import OlmoeConfig
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import auto_docstring, logging
19
 
 
20
  logger = logging.get_logger(__name__)
21
 
22
+
23
  def load_balancing_loss_func(
24
  gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
25
  num_experts: Optional[int] = None,
 
28
  ) -> Union[torch.Tensor, int]:
29
  if gate_logits is None or not isinstance(gate_logits, tuple):
30
  return 0
 
31
  if isinstance(gate_logits, tuple):
32
  compute_device = gate_logits[0].device
33
+ concatenated_gate_logits = torch.cat(
34
+ [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
35
+ )
36
  routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
 
37
  _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
 
38
  expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
 
39
  if attention_mask is None:
 
40
  tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
 
 
41
  router_prob_per_expert = torch.mean(routing_weights, dim=0)
42
  else:
43
  batch_size, sequence_length = attention_mask.shape
44
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (
45
+ batch_size * sequence_length
46
+ )
47
  expert_attention_mask = (
48
  attention_mask[None, :, :, None, None]
49
+ .expand(
50
+ (num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
51
+ )
52
  .reshape(-1, top_k, num_experts)
53
  .to(compute_device)
54
  )
55
+ tokens_per_expert = torch.sum(
56
+ expert_mask.float() * expert_attention_mask, dim=0
57
+ ) / torch.sum(expert_attention_mask, dim=0)
 
 
 
 
58
  router_per_expert_attention_mask = (
59
  attention_mask[None, :, :, None]
60
  .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
61
  .reshape(-1, num_experts)
62
  .to(compute_device)
63
  )
64
+ router_prob_per_expert = torch.sum(
65
+ routing_weights * router_per_expert_attention_mask, dim=0
66
+ ) / torch.sum(router_per_expert_attention_mask, dim=0)
67
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
68
+ return overall_loss * num_experts
69
+
70
+
71
class OlmoeRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    The statistics are computed in float32 and the result is cast back to the
    dtype of the input before the gain is applied.
    """

    def __init__(self, hidden_size, eps=1e-5):
        """Create a gain vector of ones with `hidden_size` entries.

        Args:
            hidden_size: Size of the normalised (last) dimension.
            eps: Constant added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Return `hidden_states` scaled to unit RMS and multiplied by the gain."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalised = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)

    def extra_repr(self):
        """Describe the gain shape and epsilon for `repr(module)`."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
86
+
87
+
88
+ ALL_LAYERNORM_LAYERS.append(OlmoeRMSNorm)
89
+
90
+
91
class OlmoeRotaryEmbedding(nn.Module):
    """Produces the cosine/sine tables consumed by `apply_rotary_pos_emb`.

    The inverse frequencies and the attention scaling factor come from the
    initialiser registered in `ROPE_INIT_FUNCTIONS` for the configured rope type.
    """

    def __init__(self, config: OlmoeConfig, device=None):
        """Select the rope type from `config` and build the inverse-frequency buffer.

        Args:
            config: Model configuration; `rope_scaling` (if set) picks the rope
                type and `max_position_embeddings` seeds the cached lengths.
            device: Forwarded to the rope initialiser.
        """
        super().__init__()
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is None:
            self.rope_type = "default"
        else:
            # Prefer the "rope_type" key and fall back to the "type" key.
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))

        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config

        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: the buffer is not written to the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`.

        Args:
            x: Tensor whose device and dtype determine those of the outputs.
            position_ids: Positions indexed as `[batch, position]`.
        """
        batch = position_ids.shape[0]
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        )
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        # Autocast is disabled so the angle computation stays in float32.
        with torch.autocast(device_type=device_type, enabled=False):
            angles = inv_freq_expanded.float() @ position_ids_expanded.float()
            freqs = angles.transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
131
+
132
+
133
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For a last dimension split as `[a, b]` the result is `[-b, a]`.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)
137
+
138
+
139
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings to the query and key tensors.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table; a singleton axis is inserted at `unsqueeze_dim`.
        sin: Sine table; a singleton axis is inserted at `unsqueeze_dim`.
        position_ids: Accepted but not used by this function.
        unsqueeze_dim: Axis at which `cos` and `sin` are unsqueezed so they
            broadcast against `q` and `k`.

    Returns:
        The rotated `(q, k)` pair.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
145
+
146
+
147
class OlmoeMLP(nn.Module):
    """Gated feed-forward network: `down_proj(act(gate_proj(x)) * up_proj(x))`."""

    def __init__(self, config):
        """Build the three bias-free projections and look up the activation.

        Args:
            config: Provides `hidden_size`, `intermediate_size` and `hidden_act`.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        model_dim, inner_dim = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Return the gated projection of `x` mapped back to the hidden size."""
        gate = self.act_fn(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
161
+
162
+
163
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head `n_rep` times along the head axis.

    The input is unpacked as `(batch, num_key_value_heads, seq_len, head_dim)`
    and the output has `num_key_value_heads * n_rep` heads, with the copies of
    each head adjacent to one another. When `n_rep` is 1 the input is returned
    unchanged.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it in.
    expanded = hidden_states.unsqueeze(2).expand(
        batch, kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
171
 
 
 
172
 
173
  class OlmoeAttention(nn.Module):
 
 
174
  def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None):
175
  super().__init__()
176
  self.config = config
177
  self.layer_idx = layer_idx
178
+ if layer_idx is None:
179
+ logger.warning_once(
180
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
181
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
182
+ "when creating this class."
183
+ )
184
+ self.attention_dropout = config.attention_dropout
185
  self.hidden_size = config.hidden_size
186
  self.num_heads = config.num_attention_heads
187
  self.head_dim = self.hidden_size // self.num_heads
188
  self.num_key_value_heads = config.num_key_value_heads
189
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
190
+ self.max_position_embeddings = config.max_position_embeddings
191
+ self.rope_theta = config.rope_theta
192
+ self.is_causal = True
193
+ if (self.head_dim * self.num_heads) != self.hidden_size:
194
+ raise ValueError(
195
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
196
+ f" and `num_heads`: {self.num_heads})."
197
+ )
198
+ self.q_proj = nn.Linear(
199
+ self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
200
+ )
201
+ self.k_proj = nn.Linear(
202
+ self.hidden_size,
203
+ self.num_key_value_heads * self.head_dim,
204
+ bias=config.attention_bias,
205
+ )
206
+ self.v_proj = nn.Linear(
207
+ self.hidden_size,
208
+ self.num_key_value_heads * self.head_dim,
209
+ bias=config.attention_bias,
210
+ )
211
+ self.o_proj = nn.Linear(
212
+ self.hidden_size, self.hidden_size, bias=config.attention_bias
213
+ )
214
  self.q_norm = OlmoeRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
215
+ self.k_norm = OlmoeRMSNorm(
216
+ (self.hidden_size // self.num_heads) * self.num_key_value_heads,
217
+ eps=config.rms_norm_eps,
218
+ )
219
+
220
  def forward(
221
  self,
222
  hidden_states: torch.Tensor,
 
230
  **kwargs,
231
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
232
  bsz, q_len, _ = hidden_states.size()
 
233
  query_states = self.q_norm(self.q_proj(hidden_states))
234
  key_states = self.k_norm(self.k_proj(hidden_states))
235
  value_states = self.v_proj(hidden_states)
236
+ if self.config.clip_qkv is not None:
237
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
238
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
239
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
240
+ query_states = query_states.view(
241
+ bsz, q_len, self.num_heads, self.head_dim
242
+ ).transpose(1, 2)
243
+ key_states = key_states.view(
244
+ bsz, q_len, self.num_key_value_heads, self.head_dim
245
+ ).transpose(1, 2)
246
+ value_states = value_states.view(
247
+ bsz, q_len, self.num_key_value_heads, self.head_dim
248
+ ).transpose(1, 2)
249
  cos, sin = position_embeddings
250
+ query_states, key_states = apply_rotary_pos_emb(
251
+ query_states, key_states, cos, sin
252
+ )
253
  if past_key_value is not None:
254
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
255
+ key_states, value_states = past_key_value.update(
256
+ key_states, value_states, self.layer_idx, cache_kwargs
257
+ )
258
  key_states = repeat_kv(key_states, self.num_key_value_groups)
259
  value_states = repeat_kv(value_states, self.num_key_value_groups)
260
+ attn_weights = torch.matmul(
261
+ query_states, key_states.transpose(2, 3)
262
+ ) / math.sqrt(self.head_dim)
263
  if attention_mask is not None:
264
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
265
+ attn_weights = attn_weights + causal_mask
266
+ attn_weights = nn.functional.softmax(
267
+ attn_weights, dim=-1, dtype=torch.float32
268
+ ).to(query_states.dtype)
269
+ attn_weights = nn.functional.dropout(
270
+ attn_weights, p=self.attention_dropout, training=self.training
271
+ )
272
  attn_output = torch.matmul(attn_weights, value_states)
 
273
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
274
  raise ValueError(
275
  f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
276
  f" {attn_output.size()}"
277
  )
 
278
  attn_output = attn_output.transpose(1, 2).contiguous()
279
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
280
  attn_output = self.o_proj(attn_output)
 
281
  if not output_attentions:
282
  attn_weights = None
283
+ return attn_output, attn_weights, past_key_value
284
+
285
 
286
class OlmoeFlashAttention2(OlmoeAttention):
    """`OlmoeAttention` variant that delegates the attention computation to
    `_flash_attention_forward`.

    The projections and norms are those of the parent class. This variant never
    returns attention probabilities: the second element of the result is always
    `None`, whatever `output_attentions` is set to.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Queried once here and passed to every `_flash_attention_forward` call.
        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run flash attention over `hidden_states`.

        Args:
            hidden_states: Input unpacked as `(batch, seq_len, hidden)`.
            attention_mask: Passed through to `_flash_attention_forward`.
            position_ids: Accepted but not used by this method.
            past_key_value: Optional cache, updated with the new keys and values.
            output_attentions: Ignored; attention weights are never produced.
            use_cache: Accepted but not used by this method.
            cache_position: Forwarded to the cache update.
            position_embeddings: `(cos, sin)` pair for the rotary embedding.

        Returns:
            `(attn_output, None, past_key_value)`.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_norm(self.q_proj(hidden_states))
        key_states = self.k_norm(self.k_proj(hidden_states))
        value_states = self.v_proj(hidden_states)

        clip = self.config.clip_qkv
        if clip is not None:
            for states in (query_states, key_states, value_states):
                states.clamp_(min=-clip, max=clip)

        # Split the projections into heads: (batch, heads, seq, head_dim).
        kv_shape = (bsz, q_len, self.num_key_value_heads, self.head_dim)
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(kv_shape).transpose(1, 2)
        value_states = value_states.view(kv_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # Swap the head and sequence axes back before calling the kernel wrapper.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        if query_states.dtype == torch.float32:
            # float32 inputs are cast down before the kernel call; the target
            # dtype is chosen from autocast, quantisation config, or the weights.
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype
            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, None, past_key_value
365
 
 
 
 
 
 
 
 
366
 
367
class OlmoeSdpaAttention(OlmoeAttention):
    """`OlmoeAttention` variant built on
    `torch.nn.functional.scaled_dot_product_attention`.

    The fused kernel cannot return attention probabilities, so a call with
    `output_attentions=True` is delegated to the parent implementation.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run scaled-dot-product attention over `hidden_states`.

        Args:
            hidden_states: Input unpacked as `(batch, seq_len, hidden)`.
            attention_mask: Optional additive mask; it is sliced on its last
                axis to the key length before use.
            position_ids: Only forwarded to the parent on the fallback path.
            past_key_value: Optional cache, updated with the new keys and values.
            output_attentions: When true, fall back to `OlmoeAttention.forward`.
            use_cache: Only forwarded to the parent on the fallback path.
            cache_position: Forwarded to the cache update.
            position_embeddings: `(cos, sin)` pair for the rotary embedding.
            **kwargs: Accepted for signature parity with the eager and
                flash-attention variants; only forwarded on the fallback path.

        Returns:
            `(attn_output, None, past_key_value)` on the SDPA path, or whatever
            the parent implementation returns on the fallback path.
        """
        if output_attentions:
            logger.warning_once(
                "OlmoeModel is using OlmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_norm(self.q_proj(hidden_states))
        key_states = self.k_norm(self.k_proj(hidden_states))
        value_states = self.v_proj(hidden_states)

        if self.config.clip_qkv is not None:
            query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
            key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
            value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)

        # Split the projections into heads: (batch, heads, seq, head_dim).
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # Expand the key/value heads so they match the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # With an explicit mask on CUDA the inputs are made contiguous first.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # Use the kernel's built-in causal masking only when no explicit mask
        # was supplied and there is more than one query position.
        is_causal = causal_mask is None and q_len > 1

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output, None, past_key_value
442
 
 
 
 
 
 
443
 
 
444
# Dispatch table from attention-implementation name to the attention class;
# it is indexed with `config._attn_implementation` when a decoder layer is built.
OLMOE_ATTENTION_CLASSES = {
    "eager": OlmoeAttention,
    "flash_attention_2": OlmoeFlashAttention2,
    "sdpa": OlmoeSdpaAttention,
}
449
 
 
 
 
 
 
 
450
 
451
class OlmoeSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts feed-forward block.

    A linear gate scores every token against all experts; each token is sent to
    its `top_k` highest-probability experts and the expert outputs are summed,
    weighted by the (optionally renormalised) routing probabilities.
    """

    def __init__(self, config):
        """Build the router and the expert MLPs.

        Args:
            config: Provides `num_experts`, `num_experts_per_tok`,
                `norm_topk_prob` and `hidden_size`, plus whatever `OlmoeMLP`
                reads from it.
        """
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
        self.experts = nn.ModuleList(
            [OlmoeMLP(config) for _ in range(self.num_experts)]
        )

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route every token through its selected experts.

        Args:
            hidden_states: Tensor of shape `(batch, seq_len, hidden_dim)`.

        Returns:
            A pair `(final_hidden_states, router_logits)`:
            `final_hidden_states` has the same shape as the input, and
            `router_logits` has shape `(batch * seq_len, num_experts)`.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        # Flatten to one row per token for routing.
        hidden_states = hidden_states.view(-1, hidden_dim)
        router_logits = self.gate(hidden_states)

        # Softmax in float32, then keep the top_k probabilities per token.
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        if self.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # After the permute, expert_mask[e, k, t] == 1 iff expert e is the k-th
        # choice of token t.
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.num_experts
        ).permute(2, 1, 0)

        # Every expert is invoked, including those that received no tokens.
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            # idx: slot within top_k, top_x: token row index.
            idx, top_x = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = (
                expert_layer(current_state) * routing_weights[top_x, idx, None]
            )
            # Accumulate this expert's weighted output into its tokens' rows.
            final_hidden_states.index_add_(
                0, top_x, current_hidden_states.to(hidden_states.dtype)
            )

        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, hidden_dim
        )
        return final_hidden_states, router_logits
 
 
 
 
 
495
 
496
+
497
+ class OlmoeDecoderLayer(nn.Module):
 
498
  def __init__(self, config: OlmoeConfig, layer_idx: int):
499
  super().__init__()
500
  self.hidden_size = config.hidden_size
501
+ self.self_attn = OLMOE_ATTENTION_CLASSES[config._attn_implementation](
502
+ config=config, layer_idx=layer_idx
503
+ )
504
+ self.mlp = OlmoeSparseMoeBlock(config)
 
 
 
 
505
  self.input_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
506
+ self.post_attention_layernorm = OlmoeRMSNorm(
507
+ config.hidden_size, eps=config.rms_norm_eps
508
+ )
509
+
510
  def forward(
511
  self,
512
  hidden_states: torch.Tensor,
 
518
  use_cache: Optional[bool] = False,
519
  cache_position: Optional[torch.LongTensor] = None,
520
  position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
521
+ **kwargs,
522
+ ) -> Tuple[
523
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
524
+ ]:
525
  residual = hidden_states
 
 
526
  hidden_states = self.input_layernorm(hidden_states)
 
527
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
528
  hidden_states=hidden_states,
529
  attention_mask=attention_mask,
 
533
  use_cache=use_cache,
534
  cache_position=cache_position,
535
  position_embeddings=position_embeddings,
536
+ **kwargs,
537
  )
 
538
  hidden_states = residual + hidden_states
 
 
539
  residual = hidden_states
540
  hidden_states = self.post_attention_layernorm(hidden_states)
541
+ hidden_states, router_logits = self.mlp(hidden_states)
542
  hidden_states = residual + hidden_states
543
+ outputs = (hidden_states,)
544
+ if output_attentions:
545
+ outputs += (self_attn_weights,)
546
+ if use_cache:
547
+ outputs += (present_key_value,)
548
+ if output_router_logits:
549
+ outputs += (router_logits,)
550
+ return outputs
 
551
 
552
+
553
+ @auto_docstring
554
+ class OlmoePreTrainedModel(PreTrainedModel):
555
  config_class = OlmoeConfig
556
  base_model_prefix = "model"
557
  supports_gradient_checkpointing = True
558
+ _no_split_modules = ["OlmoeDecoderLayer"]
559
  _skip_keys_device_placement = ["past_key_values"]
560
+ _supports_flash_attn_2 = True
561
+ _supports_sdpa = True
562
+ _supports_cache_class = True
563
+ _supports_quantized_cache = True
564
+ _supports_static_cache = False
565
+
566
  def _init_weights(self, module):
567
  std = self.config.initializer_range
568
  if isinstance(module, nn.Linear):
 
577
  module.weight.data[module.padding_idx].zero_()
578
 
579
 
580
+ @auto_docstring
581
+ class OlmoeModel(OlmoePreTrainedModel):
 
582
  def __init__(self, config: OlmoeConfig):
583
  super().__init__(config)
584
  self.padding_idx = config.pad_token_id
585
  self.vocab_size = config.vocab_size
586
+ self.embed_tokens = nn.Embedding(
587
+ config.vocab_size, config.hidden_size, self.padding_idx
588
+ )
589
  self.layers = nn.ModuleList(
590
+ [
591
+ OlmoeDecoderLayer(config, layer_idx)
592
+ for layer_idx in range(config.num_hidden_layers)
593
+ ]
594
  )
595
  self.norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
596
+ self.rotary_emb = OlmoeRotaryEmbedding(config=config)
597
  self.gradient_checkpointing = False
 
 
598
  self.post_init()
599
+
600
  def get_input_embeddings(self):
601
  return self.embed_tokens
602
+
603
  def set_input_embeddings(self, value):
604
  self.embed_tokens = value
605
+
606
+ @auto_docstring
607
  def forward(
608
  self,
609
  input_ids: Optional[torch.LongTensor] = None,
 
618
  return_dict: Optional[bool] = None,
619
  cache_position: Optional[torch.LongTensor] = None,
620
  ) -> Union[Tuple, MoeModelOutputWithPast]:
621
+ output_attentions = (
622
+ output_attentions
623
+ if output_attentions is not None
624
+ else self.config.output_attentions
625
+ )
626
  output_router_logits = (
627
+ output_router_logits
628
+ if output_router_logits is not None
629
+ else self.config.output_router_logits
630
  )
631
  output_hidden_states = (
632
+ output_hidden_states
633
+ if output_hidden_states is not None
634
+ else self.config.output_hidden_states
635
  )
636
  use_cache = use_cache if use_cache is not None else self.config.use_cache
637
+ return_dict = (
638
+ return_dict if return_dict is not None else self.config.use_return_dict
639
+ )
640
+ if (input_ids is None) ^ (inputs_embeds is not None):
641
+ raise ValueError(
642
+ "You must specify exactly one of input_ids or inputs_embeds"
643
+ )
644
+ if self.gradient_checkpointing and self.training and use_cache:
645
+ logger.warning_once(
646
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
647
+ )
648
+ use_cache = False
649
  if inputs_embeds is None:
650
  inputs_embeds = self.embed_tokens(input_ids)
651
+ return_legacy_cache = False
652
+ if use_cache and not isinstance(past_key_values, Cache):
653
+ return_legacy_cache = True
654
+ if past_key_values is None:
655
+ past_key_values = DynamicCache()
656
+ else:
657
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
658
+ logger.warning_once(
659
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
660
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
661
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
662
+ )
663
  if cache_position is None:
664
+ past_seen_tokens = (
665
+ past_key_values.get_seq_length() if past_key_values is not None else 0
666
+ )
667
  cache_position = torch.arange(
668
+ past_seen_tokens,
669
+ past_seen_tokens + inputs_embeds.shape[1],
670
+ device=inputs_embeds.device,
671
  )
 
672
  if position_ids is None:
673
  position_ids = cache_position.unsqueeze(0)
 
674
  causal_mask = self._update_causal_mask(
675
+ attention_mask,
676
+ inputs_embeds,
677
+ cache_position,
678
+ past_key_values,
679
+ output_attentions,
680
  )
 
681
  hidden_states = inputs_embeds
682
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
683
  all_hidden_states = () if output_hidden_states else None
684
  all_self_attns = () if output_attentions else None
685
  all_router_logits = () if output_router_logits else None
686
  next_decoder_cache = None
687
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
 
 
 
688
  if output_hidden_states:
689
  all_hidden_states += (hidden_states,)
690
+ if self.gradient_checkpointing and self.training:
691
+ layer_outputs = self._gradient_checkpointing_func(
692
+ decoder_layer.__call__,
693
+ hidden_states,
694
+ causal_mask,
695
+ position_ids,
696
+ past_key_values,
697
+ output_attentions,
698
+ output_router_logits,
699
+ use_cache,
700
+ cache_position,
701
+ position_embeddings,
702
+ )
703
+ else:
704
+ layer_outputs = decoder_layer(
705
+ hidden_states,
706
+ attention_mask=causal_mask,
707
+ position_ids=position_ids,
708
+ past_key_value=past_key_values,
709
+ output_attentions=output_attentions,
710
+ output_router_logits=output_router_logits,
711
+ use_cache=use_cache,
712
+ cache_position=cache_position,
713
+ position_embeddings=position_embeddings,
714
+ )
715
+ hidden_states = layer_outputs[0]
716
  if use_cache:
717
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
718
+ if output_attentions:
719
+ all_self_attns += (layer_outputs[1],)
 
 
 
 
 
720
  if output_router_logits and layer_outputs[-1] is not None:
721
  all_router_logits += (layer_outputs[-1],)
 
722
  hidden_states = self.norm(hidden_states)
 
723
  if output_hidden_states:
724
  all_hidden_states += (hidden_states,)
 
725
  next_cache = next_decoder_cache if use_cache else None
726
+ if return_legacy_cache:
727
+ next_cache = next_cache.to_legacy_cache()
728
  if not return_dict:
729
  return tuple(
730
+ v
731
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
732
+ if v is not None
733
  )
 
734
  return MoeModelOutputWithPast(
735
  last_hidden_state=hidden_states,
736
  past_key_values=next_cache,
 
738
  attentions=all_self_attns,
739
  router_logits=all_router_logits,
740
  )
 
 
 
 
 
 
741
 
742
+ def _update_causal_mask(
743
+ self,
744
+ attention_mask: torch.Tensor,
745
+ input_tensor: torch.Tensor,
746
+ cache_position: torch.Tensor,
747
+ past_key_values: Cache,
748
+ output_attentions: bool,
749
+ ):
750
+ if self.config._attn_implementation == "flash_attention_2":
751
+ if attention_mask is not None and 0.0 in attention_mask:
752
+ return attention_mask
753
+ return None
754
+ past_seen_tokens = (
755
+ past_key_values.get_seq_length() if past_key_values is not None else 0
756
+ )
757
+ using_static_cache = isinstance(past_key_values, StaticCache)
758
+ if (
759
+ self.config._attn_implementation == "sdpa"
760
+ and not using_static_cache
761
+ and not output_attentions
762
+ ):
763
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
764
+ attention_mask,
765
+ inputs_embeds=input_tensor,
766
+ past_key_values_length=past_seen_tokens,
767
+ is_training=self.training,
768
+ ):
769
+ return None
770
+ dtype, device = input_tensor.dtype, input_tensor.device
771
+ sequence_length = input_tensor.shape[1]
772
+ if using_static_cache:
773
+ target_length = past_key_values.get_max_cache_shape()
774
+ else:
775
+ target_length = (
776
+ attention_mask.shape[-1]
777
+ if isinstance(attention_mask, torch.Tensor)
778
+ else past_seen_tokens + sequence_length + 1
779
+ )
780
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
781
+ attention_mask,
782
+ sequence_length=sequence_length,
783
+ target_length=target_length,
784
+ dtype=dtype,
785
+ device=device,
786
+ cache_position=cache_position,
787
+ batch_size=input_tensor.shape[0],
788
+ )
789
+ if (
790
+ self.config._attn_implementation == "sdpa"
791
+ and attention_mask is not None
792
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
793
+ and not output_attentions
794
+ ):
795
+ min_dtype = torch.finfo(dtype).min
796
+ causal_mask = AttentionMaskConverter._unmask_unattended(
797
+ causal_mask, min_dtype
798
+ )
799
+ return causal_mask
800
+
801
+ @staticmethod
802
+ def _prepare_4d_causal_attention_mask_with_cache_position(
803
+ attention_mask: torch.Tensor,
804
+ sequence_length: int,
805
+ target_length: int,
806
+ dtype: torch.dtype,
807
+ device: torch.device,
808
+ cache_position: torch.Tensor,
809
+ batch_size: int,
810
+ **kwargs,
811
+ ):
812
+ if attention_mask is not None and attention_mask.dim() == 4:
813
+ causal_mask = attention_mask
814
+ else:
815
+ min_dtype = torch.finfo(dtype).min
816
+ causal_mask = torch.full(
817
+ (sequence_length, target_length),
818
+ fill_value=min_dtype,
819
+ dtype=dtype,
820
+ device=device,
821
+ )
822
+ if sequence_length != 1:
823
+ causal_mask = torch.triu(causal_mask, diagonal=1)
824
+ causal_mask *= torch.arange(
825
+ target_length, device=device
826
+ ) > cache_position.reshape(-1, 1)
827
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
828
+ if attention_mask is not None:
829
+ causal_mask = causal_mask.clone()
830
+ mask_length = attention_mask.shape[-1]
831
+ padding_mask = (
832
+ causal_mask[:, :, :, :mask_length]
833
+ + attention_mask[:, None, None, :]
834
+ )
835
+ padding_mask = padding_mask == 0
836
+ causal_mask[:, :, :, :mask_length] = causal_mask[
837
+ :, :, :, :mask_length
838
+ ].masked_fill(padding_mask, min_dtype)
839
+ return causal_mask
840
+
841
+
842
+ class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
843
  _tied_weights_keys = ["lm_head.weight"]
844
+
845
+ def __init__(self, config):
846
  super().__init__(config)
847
+ self.model = OlmoeModel(config)
848
  self.vocab_size = config.vocab_size
849
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
850
  self.router_aux_loss_coef = config.router_aux_loss_coef
851
  self.num_experts = config.num_experts
852
  self.num_experts_per_tok = config.num_experts_per_tok
 
 
853
  self.post_init()
854
+
855
  def get_input_embeddings(self):
856
  return self.model.embed_tokens
857
+
858
  def set_input_embeddings(self, value):
859
  self.model.embed_tokens = value
860
+
861
  def get_output_embeddings(self):
862
  return self.lm_head
863
+
864
  def set_output_embeddings(self, new_embeddings):
865
  self.lm_head = new_embeddings
866
+
867
+ def set_decoder(self, decoder):
868
+ self.model = decoder
869
+
870
+ def get_decoder(self):
871
+ return self.model
872
+
873
+ @auto_docstring
874
  def forward(
875
  self,
876
  input_ids: Optional[torch.LongTensor] = None,
 
885
  output_router_logits: Optional[bool] = None,
886
  return_dict: Optional[bool] = None,
887
  cache_position: Optional[torch.LongTensor] = None,
888
+ logits_to_keep: Union[int, torch.Tensor] = 0,
889
+ **loss_kwargs,
890
  ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
891
+ output_attentions = (
892
+ output_attentions
893
+ if output_attentions is not None
894
+ else self.config.output_attentions
895
+ )
896
  output_router_logits = (
897
+ output_router_logits
898
+ if output_router_logits is not None
899
+ else self.config.output_router_logits
900
  )
901
  output_hidden_states = (
902
+ output_hidden_states
903
+ if output_hidden_states is not None
904
+ else self.config.output_hidden_states
905
+ )
906
+ return_dict = (
907
+ return_dict if return_dict is not None else self.config.use_return_dict
908
  )
 
 
 
909
  outputs = self.model(
910
  input_ids=input_ids,
911
  attention_mask=attention_mask,
 
919
  return_dict=return_dict,
920
  cache_position=cache_position,
921
  )
 
922
  hidden_states = outputs[0]
923
+ slice_indices = (
924
+ slice(-logits_to_keep, None)
925
+ if isinstance(logits_to_keep, int)
926
+ else logits_to_keep
927
+ )
928
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
929
  loss = None
930
  if labels is not None:
931
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
 
932
  aux_loss = None
933
  if output_router_logits:
934
  aux_loss = load_balancing_loss_func(
 
939
  )
940
  if labels is not None:
941
  loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
 
942
  if not return_dict:
943
  output = (logits,) + outputs[1:]
944
  if output_router_logits:
945
  output = (aux_loss,) + output
946
  return (loss,) + output if loss is not None else output
 
947
  return MoeCausalLMOutputWithPast(
948
  loss=loss,
949
  aux_loss=aux_loss,
 
953
  attentions=outputs.attentions,
954
  router_logits=outputs.router_logits,
955
  )
 
myolmoe/special_tokens_map.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "eos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "pad_token": {
10
+ "content": "<|padding|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ }
16
+ }
myolmoe/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
myolmoe/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "|||IP_ADDRESS|||",
8
+ "lstrip": false,
9
+ "normalized": true,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": false
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ },
206
+ "50277": {
207
+ "content": "|||EMAIL_ADDRESS|||",
208
+ "lstrip": false,
209
+ "normalized": true,
210
+ "rstrip": false,
211
+ "single_word": false,
212
+ "special": false
213
+ },
214
+ "50278": {
215
+ "content": "|||PHONE_NUMBER|||",
216
+ "lstrip": false,
217
+ "normalized": true,
218
+ "rstrip": false,
219
+ "single_word": false,
220
+ "special": false
221
+ },
222
+ "50279": {
223
+ "content": "<|endoftext|>",
224
+ "lstrip": false,
225
+ "normalized": false,
226
+ "rstrip": false,
227
+ "single_word": false,
228
+ "special": true
229
+ }
230
+ },
231
+ "bos_token": null,
232
+ "clean_up_tokenization_spaces": true,
233
+ "eos_token": "<|endoftext|>",
234
+ "extra_special_tokens": {},
235
+ "model_max_length": 1000000000000000019884624838656,
236
+ "pad_token": "<|padding|>",
237
+ "tokenizer_class": "GPTNeoXTokenizer",
238
+ "unk_token": null
239
+ }
olmoe_wrapper.py DELETED
@@ -1,358 +0,0 @@
1
- """
2
- LM Evaluation Harness Wrapper for Modified MyOLMoE
3
- """
4
- import torch
5
- from typing import List, Optional, Union, Dict, Any
6
- from transformers import AutoTokenizer, AutoModelForCausalLM
7
- from lm_eval.api.model import LM
8
- from lm_eval.api.registry import register_model
9
- import numpy as np
10
-
11
-
12
@register_model("myolmoe")
class MyOLMoELM(LM):
    """LM Evaluation Harness wrapper for the MyOLMoE model.

    Loads a tokenizer and causal LM with ``transformers`` and implements the
    three scoring/generation entry points of the harness ``LM`` base class:
    ``loglikelihood``, ``loglikelihood_rolling`` and ``generate_until``.

    Requests may be plain tuples or objects exposing an ``args`` tuple; both
    are normalised through ``_request_args``.
    """

    # Maps the ``dtype`` constructor string to a torch dtype. Any other value
    # falls back to float32.
    _DTYPES = {"float16": torch.float16, "bfloat16": torch.bfloat16}

    def __init__(
        self,
        pretrained: Optional[str] = None,
        device: str = "cuda",
        batch_size: int = 1,
        max_length: int = 2048,
        trust_remote_code: bool = False,
        dtype: str = "float16",
        parallelize: bool = False,
        device_map: Optional[str] = None,
        **kwargs
    ):
        """Load the tokenizer and model.

        Args:
            pretrained: Model name or path handed to ``from_pretrained``.
            device: Target device; "cuda" silently falls back to "cpu" when
                CUDA is unavailable.
            batch_size: Number of requests processed per forward pass.
            max_length: Maximum number of tokens fed to the model.
            trust_remote_code: Forwarded to ``from_pretrained``.
            dtype: "float16", "bfloat16", or anything else for float32.
            parallelize: When true, ``device_map`` is forwarded and the model
                is not moved to ``device`` afterwards.
            device_map: Only used when ``parallelize`` is true.
            **kwargs: Extra arguments for ``AutoModelForCausalLM.from_pretrained``.

        Raises:
            ValueError: If ``pretrained`` is not given.
        """
        super().__init__()

        if not pretrained:
            raise ValueError("pretrained model path must be specified")

        # Initialize device and batch size
        if device == "cuda" and not torch.cuda.is_available():
            device = "cpu"
        self._device = torch.device(device)
        self._batch_size = int(batch_size)
        self._max_length = max_length
        self._dtype = self._DTYPES.get(dtype, torch.float32)

        # Left padding keeps the most recent tokens adjacent to the position
        # the next token is generated from.
        self.tokenizer = AutoTokenizer.from_pretrained(
            pretrained,
            trust_remote_code=trust_remote_code,
            padding_side="left"
        )

        # Ensure pad token is set
        added_pad_token = False
        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                self.tokenizer.add_special_tokens({'pad_token': '<pad>'})
                added_pad_token = True

        self.model = AutoModelForCausalLM.from_pretrained(
            pretrained,
            torch_dtype=self._dtype,
            device_map=device_map if parallelize else None,
            trust_remote_code=trust_remote_code,
            **kwargs
        )

        # A freshly added pad token may get an id beyond the embedding table;
        # grow the table so padded batches do not index out of range.
        if added_pad_token:
            embedding_rows = self.model.get_input_embeddings().num_embeddings
            if len(self.tokenizer) > embedding_rows:
                self.model.resize_token_embeddings(len(self.tokenizer))

        if not parallelize:
            self.model = self.model.to(self._device)

        self.model.eval()

    @property
    def eot_token_id(self):
        """End of text token ID."""
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Maximum sequence length."""
        return self._max_length

    @property
    def max_gen_toks(self):
        """Maximum number of tokens to generate."""
        return 256

    @property
    def batch_size(self):
        """Batch size for evaluation."""
        return self._batch_size

    @property
    def device(self):
        """Device used for evaluation."""
        return self._device

    @staticmethod
    def _request_args(request):
        """Return the argument tuple of ``request``.

        Objects exposing an ``args`` attribute are unwrapped; anything else
        (for example a plain tuple) is returned unchanged.
        """
        return getattr(request, "args", request)

    def tok_encode(self, string: str, add_special_tokens=True) -> List[int]:
        """Encode a string to token IDs."""
        return self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

    def tok_decode(self, tokens: List[int]) -> str:
        """Decode token IDs to string."""
        return self.tokenizer.decode(tokens, skip_special_tokens=True)

    def loglikelihood(self, requests: List[tuple]) -> List[tuple]:
        """Compute the log-likelihood of each continuation given its context.

        Args:
            requests: ``(context, continuation)`` pairs, or objects whose
                ``args`` attribute holds such a pair.

        Returns:
            One ``(log_likelihood, is_greedy)`` tuple per request, in order.
        """
        results = []
        step = self.batch_size
        for start in range(0, len(requests), step):
            batch = [self._request_args(r) for r in requests[start:start + step]]
            results.extend(self._loglikelihood_batch(batch))
        return results

    def _loglikelihood_batch(self, batch: List[tuple]) -> List[tuple]:
        """Score one batch of ``(context, continuation)`` pairs.

        Sequences are left-padded to a common length (capped at
        ``max_length``); over-long sequences keep their final tokens.

        Returns:
            ``(summed continuation log-prob, is_greedy)`` per pair, where
            ``is_greedy`` is true when every continuation token is the
            model's argmax prediction. Pairs with no scorable continuation
            token yield ``(-inf, False)``.
        """
        batch = [self._request_args(r) for r in batch]

        # Encode full sequences (context + continuation) and contexts alone;
        # the context token count marks where the continuation starts.
        full_encodings = [self.tok_encode(ctx + cont) for ctx, cont in batch]
        context_encodings = [self.tok_encode(ctx) for ctx, _ in batch]

        max_len = min(max(len(seq) for seq in full_encodings), self.max_length)
        if max_len == 0:
            # Nothing to feed to the model.
            return [(float('-inf'), False) for _ in batch]

        pad_id = self.tokenizer.pad_token_id
        input_ids = []
        attention_masks = []
        continuation_masks = []

        for full_seq, ctx_seq in zip(full_encodings, context_encodings):
            # Truncate from the left (keep the end). The number of dropped
            # tokens is specific to THIS sequence and is removed from its
            # context length.
            overflow = max(0, len(full_seq) - max_len)
            if overflow:
                full_seq = full_seq[overflow:]
            ctx_len = min(max(0, len(ctx_seq) - overflow), len(full_seq))

            pad_length = max_len - len(full_seq)
            continuation_start = pad_length + ctx_len

            input_ids.append([pad_id] * pad_length + full_seq)
            attention_masks.append([0] * pad_length + [1] * len(full_seq))
            continuation_masks.append(
                [0] * continuation_start + [1] * (max_len - continuation_start)
            )

        # Convert to tensors
        input_ids = torch.tensor(input_ids, device=self.device)
        attention_masks = torch.tensor(attention_masks, device=self.device)
        continuation_masks = torch.tensor(continuation_masks, device=self.device)

        # Forward pass
        with torch.no_grad():
            logits = self.model(input_ids=input_ids, attention_mask=attention_masks).logits

        results = []
        for row in range(len(batch)):
            # Next-token alignment: logits at position t predict token t + 1.
            targets = input_ids[row, 1:]
            valid_positions = continuation_masks[row, 1:].bool()
            if not valid_positions.any():
                results.append((float('-inf'), False))
                continue

            # float32 keeps the softmax stable when the model runs in half precision.
            log_probs = torch.log_softmax(logits[row, :-1].float(), dim=-1)
            token_log_probs = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
            total_log_prob = token_log_probs[valid_positions].sum().item()

            greedy_hits = log_probs.argmax(dim=-1) == targets
            is_greedy = bool(greedy_hits[valid_positions].all().item())

            results.append((total_log_prob, is_greedy))

        return results

    def generate_until(self, requests: List[tuple]) -> List[str]:
        """Generate text for each request until a stopping criterion is met.

        Args:
            requests: ``(context, generation_kwargs)`` pairs, or objects whose
                ``args`` attribute holds such a pair.

        Returns:
            One generated string per request, in order.
        """
        results = []
        step = self.batch_size
        for start in range(0, len(requests), step):
            batch = [self._request_args(r) for r in requests[start:start + step]]
            results.extend(self._generate_until_batch(batch))
        return results

    def _generate_until_batch(self, batch: List[tuple]) -> List[str]:
        """Generate continuations for one batch of ``(context, gen_kwargs)`` pairs.

        Decoding parameters are taken from the FIRST request of the batch;
        the ``until`` stop strings are applied per request after decoding.
        """
        batch = [self._request_args(r) for r in batch]
        contexts = [context for context, _ in batch]
        gen_kwargs_list = [gen_kwargs or {} for _, gen_kwargs in batch]

        # Encode contexts
        context_encodings = [self.tok_encode(ctx) for ctx in contexts]

        # Leave room for generated tokens, but never a non-positive budget
        # (a zero/negative slice bound would keep the whole sequence).
        context_budget = max(1, self.max_length - self.max_gen_toks)
        max_ctx_len = min(max(len(seq) for seq in context_encodings), context_budget)

        pad_id = self.tokenizer.pad_token_id
        input_ids = []
        attention_masks = []

        for ctx_seq in context_encodings:
            # Truncate if necessary (keep the end)
            if len(ctx_seq) > max_ctx_len:
                ctx_seq = ctx_seq[-max_ctx_len:]

            pad_length = max_ctx_len - len(ctx_seq)
            input_ids.append([pad_id] * pad_length + ctx_seq)
            attention_masks.append([0] * pad_length + [1] * len(ctx_seq))

        # Convert to tensors
        input_ids = torch.tensor(input_ids, device=self.device)
        attention_masks = torch.tensor(attention_masks, device=self.device)

        # Use first gen_kwargs for simplicity (can be extended)
        batch_kwargs = gen_kwargs_list[0] if gen_kwargs_list else {}
        do_sample = batch_kwargs.get('do_sample', False)

        generation_kwargs = {
            'max_new_tokens': batch_kwargs.get('max_gen_toks', self.max_gen_toks),
            'do_sample': do_sample,
            'pad_token_id': self.tokenizer.pad_token_id,
            'eos_token_id': self.tokenizer.eos_token_id,
            'attention_mask': attention_masks,
            'use_cache': True,
        }
        if do_sample:
            # Sampling knobs are only meaningful (and only validated cleanly)
            # when sampling is enabled.
            generation_kwargs['temperature'] = batch_kwargs.get('temperature', 1.0)
            generation_kwargs['top_p'] = batch_kwargs.get('top_p', 1.0)

        with torch.no_grad():
            generated = self.model.generate(
                input_ids=input_ids,
                **generation_kwargs
            )

        # All prompts share one padded length, so new tokens start there.
        prompt_len = input_ids.shape[1]

        results = []
        for gen_seq, gen_kwargs in zip(generated, gen_kwargs_list):
            new_tokens = gen_seq[prompt_len:].tolist()
            generated_text = self.tok_decode(new_tokens) if new_tokens else ""

            stop_strings = gen_kwargs.get('until') or []
            if isinstance(stop_strings, str):
                stop_strings = [stop_strings]

            # Cutting at every stop string in turn leaves the text before the
            # EARLIEST stop occurrence; empty stop strings are ignored.
            for stop_str in stop_strings:
                if stop_str:
                    generated_text = generated_text.split(stop_str)[0]

            results.append(generated_text)

        return results

    def _window_logits(self, window_tokens: List[int]):
        """Run the model on one token window and return its per-position logits."""
        input_ids = torch.tensor([window_tokens], device=self.device)
        with torch.no_grad():
            return self.model(input_ids=input_ids).logits[0]

    def loglikelihood_rolling(self, requests: List[tuple]) -> List[float]:
        """Compute the rolling log-likelihood of each text.

        Every token after the first is scored given up to ``max_length - 1``
        preceding tokens. Texts of at most one token score ``0.0``.

        Returns:
            The MEAN log-likelihood per scored token for each request.
        """
        # NOTE(review): the harness perplexity metrics presumably expect the
        # summed log-likelihood rather than the per-token mean; the mean is
        # kept here to preserve existing behaviour -- confirm against callers.
        results = []

        for request in requests:
            args = self._request_args(request)
            text = args[0] if isinstance(args, tuple) else args
            tokens = self.tok_encode(text)

            if len(tokens) <= 1:
                results.append(0.0)
                continue

            window_size = min(self.max_length, len(tokens))
            total_log_prob = 0.0
            total_tokens = 0

            if window_size >= 2:
                # Tokens inside the first window all see the text from its
                # very start, so a single causal forward pass scores them all.
                head = tokens[:window_size]
                head_log_probs = torch.log_softmax(
                    self._window_logits(head)[:-1].float(), dim=-1
                )
                head_targets = torch.tensor(head[1:], device=self.device)
                picked = head_log_probs.gather(1, head_targets.unsqueeze(1))
                total_log_prob += picked.double().sum().item()
                total_tokens += len(head) - 1

                # Later tokens each need their own window so that they keep
                # the maximum amount of preceding context.
                for i in range(window_size, len(tokens)):
                    window_tokens = tokens[i - window_size + 1:i + 1]
                    token_logits = self._window_logits(window_tokens)[-2]
                    log_prob = torch.log_softmax(token_logits.float(), dim=-1)
                    total_log_prob += log_prob[window_tokens[-1]].item()
                    total_tokens += 1

            # Return mean log-likelihood per token
            avg_log_prob = total_log_prob / total_tokens if total_tokens > 0 else 0.0
            results.append(avg_log_prob)

        return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/downloadmodel.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Fetch the OLMoE-1B-7B-0924 configuration from the Hugging Face Hub and print it.
import os

# Timeout settings have to be in the environment BEFORE transformers (and with
# it the hub client) is imported, otherwise the import has already captured
# the defaults.
# NOTE(review): HF_HUB_READ_TIMEOUT is kept from the original script, but the
# hub client's documented variables appear to be HF_HUB_ETAG_TIMEOUT and
# HF_HUB_DOWNLOAD_TIMEOUT -- confirm and drop the unused name.
os.environ["HF_HUB_READ_TIMEOUT"] = "60"
os.environ.setdefault("HF_HUB_ETAG_TIMEOUT", "60")
os.environ.setdefault("HF_HUB_DOWNLOAD_TIMEOUT", "60")

from transformers import AutoConfig  # noqa: E402  (must follow the env setup)

# NOTE(review): `timeout` is presumably not a recognised `from_pretrained`
# argument; kept as in the original -- verify it has any effect.
config = AutoConfig.from_pretrained("allenai/OLMoE-1B-7B-0924", timeout=60)

print(config)
scripts/downloadweights.py DELETED
@@ -1,20 +0,0 @@
1
# Download the OLMoE weights and tokenizer, then print the model config, the
# model class, and the local cache location of the downloaded files.
import os

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils.hub import cached_file

# Single source of truth for the repository id used by every call below.
MODEL_ID = "allenai/OLMoE-7B"  # Exact name from Hugging Face

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,  # Required if they use custom modeling_olmoe.py
    use_safetensors=True,  # Ensures .safetensors file is used
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print(model.config)
print(model.__class__)

# Example: get the path to the config file or model weights index
config_path = cached_file(MODEL_ID, "config.json", trust_remote_code=True)
print(config_path)

# The directory holding config.json is the cached snapshot of the repository.
model_path = os.path.dirname(config_path)
print(model_path)