Charlie81 committed on
Commit
834ad70
·
1 Parent(s): f03d4f8

attempt new distribution experts

Browse files
Files changed (2) hide show
  1. myolmoe/modeling_myolmoe.py +48 -15
  2. scripts/train.py +6 -2
myolmoe/modeling_myolmoe.py CHANGED
@@ -65,6 +65,8 @@ class OlmoeConfig(PretrainedConfig):
65
  small_expert_intermediate_ratio=64,
66
  small_expert_count=64,
67
  small_expert_sparsity_coef=0.1,
 
 
68
  **kwargs,
69
  ):
70
  self.vocab_size = vocab_size
@@ -98,6 +100,8 @@ class OlmoeConfig(PretrainedConfig):
98
  self.small_expert_intermediate_ratio = small_expert_intermediate_ratio
99
  self.small_expert_count = small_expert_count
100
  self.small_expert_sparsity_coef = small_expert_sparsity_coef
 
 
101
 
102
  # Validate the correctness of rotary position embeddings parameters
103
  if self.rope_scaling is not None and "type" in self.rope_scaling:
@@ -550,28 +554,57 @@ class OlmoeSparseMoeBlock(nn.Module):
550
  def __init__(self, config, layer_idx: int):
551
  super().__init__()
552
  self.layer_idx = layer_idx
 
553
  self.num_experts = config.num_experts
554
- self.num_small_experts = config.small_expert_count
555
  self.top_k = config.num_experts_per_tok
556
  self.norm_topk_prob = config.norm_topk_prob
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
557
  self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
558
- self.small_experts = nn.ModuleList([OlmoeMLP(config, is_small=True) for _ in range(self.num_small_experts)])
559
-
560
- # Gates for both expert types
 
561
  self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
562
- self.small_gate = nn.Linear(config.hidden_size, self.num_small_experts, bias=False)
 
 
 
 
 
563
  self.small_expert_sparsity_coef = config.small_expert_sparsity_coef
564
 
565
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
566
  batch_size, sequence_length, hidden_dim = hidden_states.shape
567
  hidden_states = hidden_states.view(-1, hidden_dim)
568
-
569
- # Get logits for both expert types
570
  router_logits = self.gate(hidden_states)
571
- small_router_logits = self.small_gate(hidden_states)
572
-
573
- # Combine logits for routing
574
- combined_logits = torch.cat([router_logits, small_router_logits], dim=-1)
 
 
 
575
  routing_probs = F.softmax(combined_logits, dim=1, dtype=torch.float)
576
  routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
577
 
@@ -580,27 +613,27 @@ class OlmoeSparseMoeBlock(nn.Module):
580
 
581
  final_hidden_states = torch.zeros_like(hidden_states)
582
  expert_mask = torch.nn.functional.one_hot(
583
- selected_experts,
584
  num_classes=self.num_experts + self.num_small_experts
585
  ).permute(2, 1, 0)
586
 
587
- # Process all experts (regular + small)
588
  for expert_idx in range(self.num_experts + self.num_small_experts):
589
  idx, top_x = torch.where(expert_mask[expert_idx])
590
  if top_x.shape[0] == 0:
591
  continue
592
-
593
  if expert_idx < self.num_experts:
594
  expert = self.experts[expert_idx]
595
  else:
596
  expert = self.small_experts[expert_idx - self.num_experts]
597
-
598
  current_states = hidden_states[top_x]
599
  current_output = expert(current_states) * routing_weights[top_x, idx, None]
600
  final_hidden_states.index_add_(0, top_x, current_output.to(hidden_states.dtype))
601
 
602
  return final_hidden_states.view(batch_size, sequence_length, hidden_dim), combined_logits
603
 
 
604
  class OlmoeDecoderLayer(nn.Module):
605
  def __init__(self, config: OlmoeConfig, layer_idx: int):
606
  super().__init__()
 
65
  small_expert_intermediate_ratio=64,
66
  small_expert_count=64,
67
  small_expert_sparsity_coef=0.1,
68
+ small_expert_strategy="constant", # NEW
69
+ max_small_expert_count=64, # NEW: total possible small experts
70
  **kwargs,
71
  ):
72
  self.vocab_size = vocab_size
 
100
  self.small_expert_intermediate_ratio = small_expert_intermediate_ratio
101
  self.small_expert_count = small_expert_count
102
  self.small_expert_sparsity_coef = small_expert_sparsity_coef
103
+ self.small_expert_strategy = small_expert_strategy
104
+ self.max_small_expert_count = max_small_expert_count
105
 
106
  # Validate the correctness of rotary position embeddings parameters
107
  if self.rope_scaling is not None and "type" in self.rope_scaling:
 
554
  def __init__(self, config, layer_idx: int):
555
  super().__init__()
556
  self.layer_idx = layer_idx
557
+ self.total_layers = config.num_hidden_layers
558
  self.num_experts = config.num_experts
 
559
  self.top_k = config.num_experts_per_tok
560
  self.norm_topk_prob = config.norm_topk_prob
561
+
562
+ # Determine if this block is in the second half
563
+ in_second_half = layer_idx >= self.total_layers // 2
564
+
565
+ # Determine small expert count for this layer
566
+ if in_second_half:
567
+ second_half_idx = layer_idx - (self.total_layers // 2)
568
+ num_second_half_blocks = self.total_layers - (self.total_layers // 2)
569
+
570
+ if config.small_expert_strategy == "constant":
571
+ self.num_small_experts = config.max_small_expert_count // num_second_half_blocks
572
+ elif config.small_expert_strategy == "increment":
573
+ # Linearly scale small experts from 1 to max_small_expert_count
574
+ self.num_small_experts = (
575
+ (second_half_idx + 1) * config.max_small_expert_count // ((num_second_half_blocks * (num_second_half_blocks + 1)) // 2)
576
+ )
577
+ else:
578
+ raise ValueError(f"Unknown strategy: {config.small_expert_strategy}")
579
+ else:
580
+ self.num_small_experts = 0
581
+
582
  self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
583
+ self.small_experts = nn.ModuleList([
584
+ OlmoeMLP(config, is_small=True) for _ in range(self.num_small_experts)
585
+ ]) if self.num_small_experts > 0 else None
586
+
587
  self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
588
+
589
+ if self.num_small_experts > 0:
590
+ self.small_gate = nn.Linear(config.hidden_size, self.num_small_experts, bias=False)
591
+ else:
592
+ self.small_gate = None
593
+
594
  self.small_expert_sparsity_coef = config.small_expert_sparsity_coef
595
 
596
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
597
  batch_size, sequence_length, hidden_dim = hidden_states.shape
598
  hidden_states = hidden_states.view(-1, hidden_dim)
599
+
 
600
  router_logits = self.gate(hidden_states)
601
+
602
+ if self.num_small_experts > 0:
603
+ small_router_logits = self.small_gate(hidden_states)
604
+ combined_logits = torch.cat([router_logits, small_router_logits], dim=-1)
605
+ else:
606
+ combined_logits = router_logits
607
+
608
  routing_probs = F.softmax(combined_logits, dim=1, dtype=torch.float)
609
  routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
610
 
 
613
 
614
  final_hidden_states = torch.zeros_like(hidden_states)
615
  expert_mask = torch.nn.functional.one_hot(
616
+ selected_experts,
617
  num_classes=self.num_experts + self.num_small_experts
618
  ).permute(2, 1, 0)
619
 
 
620
  for expert_idx in range(self.num_experts + self.num_small_experts):
621
  idx, top_x = torch.where(expert_mask[expert_idx])
622
  if top_x.shape[0] == 0:
623
  continue
624
+
625
  if expert_idx < self.num_experts:
626
  expert = self.experts[expert_idx]
627
  else:
628
  expert = self.small_experts[expert_idx - self.num_experts]
629
+
630
  current_states = hidden_states[top_x]
631
  current_output = expert(current_states) * routing_weights[top_x, idx, None]
632
  final_hidden_states.index_add_(0, top_x, current_output.to(hidden_states.dtype))
633
 
634
  return final_hidden_states.view(batch_size, sequence_length, hidden_dim), combined_logits
635
 
636
+
637
  class OlmoeDecoderLayer(nn.Module):
638
  def __init__(self, config: OlmoeConfig, layer_idx: int):
639
  super().__init__()
scripts/train.py CHANGED
@@ -73,7 +73,7 @@ def main():
73
  per_device_train_batch_size=2,
74
  gradient_accumulation_steps=8,
75
  learning_rate=1e-4,
76
- num_train_epochs=1,
77
  logging_dir="./logs",
78
  logging_steps=10,
79
  save_steps=1000,
@@ -94,10 +94,14 @@ def main():
94
  # Unfreeze only the small experts and their gating networks
95
  trainable_params = []
96
  for name, param in model.named_parameters():
97
- if "mlp.small_experts" in name or "mlp.small_gate" in name:
 
 
 
98
  param.requires_grad = True
99
  trainable_params.append(name)
100
  print(f"Unfreezing parameter: {name}")
 
101
 
102
  print(f"Total trainable parameters: {len(trainable_params)}")
103
 
 
73
  per_device_train_batch_size=2,
74
  gradient_accumulation_steps=8,
75
  learning_rate=1e-4,
76
+ num_train_epochs=0.001,
77
  logging_dir="./logs",
78
  logging_steps=10,
79
  save_steps=1000,
 
94
  # Unfreeze only the small experts and their gating networks
95
  trainable_params = []
96
  for name, param in model.named_parameters():
97
+ if (
98
+ "small_experts" in name or
99
+ "small_gate" in name
100
+ ):
101
  param.requires_grad = True
102
  trainable_params.append(name)
103
  print(f"Unfreezing parameter: {name}")
104
+
105
 
106
  print(f"Total trainable parameters: {len(trainable_params)}")
107