Charlie81 committed on
Commit
36acce3
·
1 Parent(s): a82f934

reset modeling file

Browse files
Files changed (2) hide show
  1. myolmoe/modeling_myolmoe.py +7 -52
  2. scripts/train.py +1 -2
myolmoe/modeling_myolmoe.py CHANGED
@@ -1,4 +1,3 @@
1
- # modeling_myolmoe.py
2
  import math
3
  from typing import List, Optional, Tuple, Union
4
  import torch
@@ -157,21 +156,6 @@ class OlmoeMLP(nn.Module):
157
def forward(self, x):
    """Apply the gated MLP: down_proj(act(gate_proj(x)) * up_proj(x))."""
    gated = self.act_fn(self.gate_proj(x))
    return self.down_proj(gated * self.up_proj(x))
160
-
161
class SmallOlmoeMLP(nn.Module):
    """SwiGLU-style expert MLP whose intermediate width is passed in
    explicitly, allowing experts narrower than ``config.intermediate_size``.

    Mirrors ``OlmoeMLP`` except for the externally supplied width.
    """

    def __init__(self, config, small_expert_intermediate_size):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = small_expert_intermediate_size
        hidden, inter = self.hidden_size, self.intermediate_size
        # Bias-free projections, matching OlmoeMLP's layout.
        self.gate_proj = nn.Linear(hidden, inter, bias=False)
        self.up_proj = nn.Linear(hidden, inter, bias=False)
        self.down_proj = nn.Linear(inter, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Return down_proj(act(gate_proj(x)) * up_proj(x))."""
        activated = self.act_fn(self.gate_proj(x))
        return self.down_proj(activated * self.up_proj(x))
175
 
176
 
177
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -462,34 +446,17 @@ OLMOE_ATTENTION_CLASSES = {
462
  }
463
 
464
 
465
-
466
  class OlmoeSparseMoeBlock(nn.Module):
467
def __init__(self, config, layer_idx: int):
    """Build the router plus the pools of regular and small experts.

    NOTE(review): optional config fields (num_small_experts, routing_type,
    nth_step, small_expert_intermediate_size) are read via getattr with
    defaults, so configs that predate small experts keep working.
    """
    super().__init__()
    self.layer_idx = layer_idx
    self.num_experts = config.num_experts
    # Zero small experts unless the config asks for them.
    self.num_small_experts = getattr(config, "num_small_experts", 0)
    self.total_experts = self.num_experts + self.num_small_experts
    self.top_k = config.num_experts_per_tok
    self.norm_topk_prob = config.norm_topk_prob
    self.routing_type = getattr(config, "routing_type", "topk")
    self.n_step = getattr(config, "nth_step", 2)

    # The router scores every expert, regular and small alike.
    self.gate = nn.Linear(config.hidden_size, self.total_experts, bias=False)

    # Regular (full-width) experts.
    self.experts = nn.ModuleList(OlmoeMLP(config) for _ in range(self.num_experts))

    # Small experts occupy router slots num_experts..total_experts-1.
    self.small_experts = nn.ModuleList()
    if self.num_small_experts > 0:
        # Default small width: half the regular intermediate size.
        small_inter = getattr(
            config, "small_expert_intermediate_size", config.intermediate_size // 2
        )
        self.small_experts = nn.ModuleList(
            SmallOlmoeMLP(config, small_inter)
            for _ in range(self.num_small_experts)
        )
493
 
494
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
495
  batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -497,6 +464,7 @@ class OlmoeSparseMoeBlock(nn.Module):
497
  router_logits = self.gate(hidden_states)
498
  routing_probs = F.softmax(router_logits, dim=1, dtype=torch.float)
499
 
 
500
  routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
501
 
502
  if self.norm_topk_prob:
@@ -509,9 +477,8 @@ class OlmoeSparseMoeBlock(nn.Module):
509
  device=hidden_states.device,
510
  )
511
 
512
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.total_experts).permute(2, 1, 0)
513
 
514
- # Process regular experts
515
  for expert_idx in range(self.num_experts):
516
  expert_layer = self.experts[expert_idx]
517
  idx, top_x = torch.where(expert_mask[expert_idx])
@@ -521,21 +488,9 @@ class OlmoeSparseMoeBlock(nn.Module):
521
  current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
522
  final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
523
 
524
- # Process small experts
525
- for small_expert_idx in range(self.num_small_experts):
526
- expert_layer = self.small_experts[small_expert_idx]
527
- # Offset by num_experts since small experts come after regular ones
528
- global_expert_idx = self.num_experts + small_expert_idx
529
- idx, top_x = torch.where(expert_mask[global_expert_idx])
530
- if top_x.numel() == 0:
531
- continue
532
- current_state = hidden_states[top_x]
533
- current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
534
- final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
535
-
536
  final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
537
  return final_hidden_states, router_logits
538
-
539
 
540
  class OlmoeDecoderLayer(nn.Module):
541
  def __init__(self, config: OlmoeConfig, layer_idx: int):
@@ -997,4 +952,4 @@ class MyOlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
997
  router_logits=outputs.router_logits,
998
  )
999
 
1000
- __all__ = ["MyOlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"]
 
 
1
  import math
2
  from typing import List, Optional, Tuple, Union
3
  import torch
 
156
def forward(self, x):
    """Gated-MLP forward pass: down_proj(act_fn(gate_proj(x)) * up_proj(x))."""
    hidden = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
    return self.down_proj(hidden)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
 
161
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
446
  }
447
 
448
 
 
449
  class OlmoeSparseMoeBlock(nn.Module):
450
def __init__(self, config, layer_idx: int):
    """Set up the token router and this layer's pool of expert MLPs."""
    super().__init__()
    self.layer_idx = layer_idx
    self.num_experts = config.num_experts
    self.top_k = config.num_experts_per_tok
    self.norm_topk_prob = config.norm_topk_prob
    # Optional routing knobs; absent config fields fall back to plain top-k.
    self.routing_type = getattr(config, "routing_type", "topk")
    # presumably the stride used by the "nth-descending" routing mode — confirm
    self.n_step = getattr(config, "nth_step", 2)
    # One router logit per expert.
    self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
    expert_bank = [OlmoeMLP(config) for _ in range(self.num_experts)]
    self.experts = nn.ModuleList(expert_bank)
 
 
 
 
 
 
 
 
 
 
460
 
461
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
462
  batch_size, sequence_length, hidden_dim = hidden_states.shape
 
464
  router_logits = self.gate(hidden_states)
465
  routing_probs = F.softmax(router_logits, dim=1, dtype=torch.float)
466
 
467
+ # === Routing ===
468
  routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
469
 
470
  if self.norm_topk_prob:
 
477
  device=hidden_states.device,
478
  )
479
 
480
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
481
 
 
482
  for expert_idx in range(self.num_experts):
483
  expert_layer = self.experts[expert_idx]
484
  idx, top_x = torch.where(expert_mask[expert_idx])
 
488
  current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
489
  final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
490
 
 
 
 
 
 
 
 
 
 
 
 
 
491
  final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
492
  return final_hidden_states, router_logits
493
+
494
 
495
  class OlmoeDecoderLayer(nn.Module):
496
  def __init__(self, config: OlmoeConfig, layer_idx: int):
 
952
  router_logits=outputs.router_logits,
953
  )
954
 
955
+ __all__ = ["MyOlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"]
scripts/train.py CHANGED
@@ -41,8 +41,7 @@ def expand_model_with_small_experts(base_model):
41
  print("# DEBUG: Expanding model with small experts...")
42
  config = base_model.config
43
  config.num_small_experts = 64 # Add 64 small experts
44
- # Changed from //16 to //2 for more reasonable size
45
- config.small_expert_intermediate_size = config.intermediate_size // 2
46
  expanded_model = MyOlmoeForCausalLM(config)
47
 
48
  base_state_dict = base_model.state_dict()
 
41
  print("# DEBUG: Expanding model with small experts...")
42
  config = base_model.config
43
  config.num_small_experts = 64 # Add 64 small experts
44
+ config.small_expert_intermediate_size = config.intermediate_size // 32
 
45
  expanded_model = MyOlmoeForCausalLM(config)
46
 
47
  base_state_dict = base_model.state_dict()