Charlie81 commited on
Commit
7050cb6
·
1 Parent(s): 356573e

reorder small experts

Browse files
myolmoe/config.json CHANGED
@@ -32,6 +32,6 @@
32
  "use_cache": true,
33
  "vocab_size": 50304,
34
  "small_expert_intermediate_ratio": 16,
35
- "small_expert_frequency": 4,
36
  "small_expert_load_balancing_coef": 0.1
37
  }
 
32
  "use_cache": true,
33
  "vocab_size": 50304,
34
  "small_expert_intermediate_ratio": 16,
35
+ "small_expert_count": 64,
36
  "small_expert_load_balancing_coef": 0.1
37
  }
myolmoe/modeling_myolmoe.py CHANGED
@@ -27,7 +27,7 @@ class OlmoeConfig(PretrainedConfig):
27
  [Previous args remain the same...]
28
  small_expert_intermediate_ratio (`float`, *optional*, defaults to 0.5):
29
  Ratio of intermediate size for small experts compared to regular experts.
30
- small_expert_frequency (`int`, *optional*, defaults to 4):
31
  Frequency of small experts - every Nth expert will be small.
32
  small_expert_load_balancing_coef (`float`, *optional*, defaults to 0.1):
33
  Coefficient for small expert load balancing loss.
@@ -63,7 +63,7 @@ class OlmoeConfig(PretrainedConfig):
63
  router_aux_loss_coef=0.01,
64
  norm_topk_prob=False,
65
  small_expert_intermediate_ratio=64,
66
- small_expert_frequency=4,
67
  small_expert_load_balancing_coef=0.1,
68
  **kwargs,
69
  ):
@@ -96,7 +96,7 @@ class OlmoeConfig(PretrainedConfig):
96
 
97
  # Small expert parameters
98
  self.small_expert_intermediate_ratio = small_expert_intermediate_ratio
99
- self.small_expert_frequency = small_expert_frequency
100
  self.small_expert_load_balancing_coef = small_expert_load_balancing_coef
101
 
102
  # Validate the correctness of rotary position embeddings parameters
@@ -558,13 +558,12 @@ class OlmoeSparseMoeBlock(nn.Module):
558
  self.n_step = getattr(config, "nth_step", 2)
559
 
560
  # Track which experts are small
561
- self.small_expert_indices = []
562
  self.experts = nn.ModuleList()
563
 
564
  for i in range(self.num_experts):
565
- is_small = (i % config.small_expert_frequency == 0)
566
- if is_small:
567
- self.small_expert_indices.append(i)
568
  self.experts.append(OlmoeMLP(config, is_small=is_small))
569
 
570
  self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
 
27
  [Previous args remain the same...]
28
  small_expert_intermediate_ratio (`float`, *optional*, defaults to 0.5):
29
  Ratio of intermediate size for small experts compared to regular experts.
30
+ small_expert_count (`int`, *optional*, defaults to 64):
31
  Number of small experts — the last N experts in the expert list will be small.
32
  small_expert_load_balancing_coef (`float`, *optional*, defaults to 0.1):
33
  Coefficient for small expert load balancing loss.
 
63
  router_aux_loss_coef=0.01,
64
  norm_topk_prob=False,
65
  small_expert_intermediate_ratio=64,
66
+ small_expert_count=64,
67
  small_expert_load_balancing_coef=0.1,
68
  **kwargs,
69
  ):
 
96
 
97
  # Small expert parameters
98
  self.small_expert_intermediate_ratio = small_expert_intermediate_ratio
99
+ self.small_expert_count = small_expert_count
100
  self.small_expert_load_balancing_coef = small_expert_load_balancing_coef
101
 
102
  # Validate the correctness of rotary position embeddings parameters
 
558
  self.n_step = getattr(config, "nth_step", 2)
559
 
560
  # Track which experts are small
561
+ self.small_expert_indices = list(range(config.num_experts - config.small_expert_count, config.num_experts))
562
  self.experts = nn.ModuleList()
563
 
564
  for i in range(self.num_experts):
565
+ # Small experts are now at the end indices
566
+ is_small = i in self.small_expert_indices
 
567
  self.experts.append(OlmoeMLP(config, is_small=is_small))
568
 
569
  self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
scripts/train.py CHANGED
@@ -95,7 +95,7 @@ def main():
95
  # Unfreeze only the small experts and their gating networks
96
  for name, param in model.named_parameters():
97
  # Unfreeze small expert layers
98
- if "mlp.experts" in name and any(f"mlp.experts.{i}." in name for i in range(0, config.num_experts, config.small_expert_frequency)):
99
  param.requires_grad = True
100
  print(f"Unfreezing small expert parameter: {name}")
101
 
@@ -103,7 +103,7 @@ def main():
103
  if "mlp.gate" in name:
104
  param.requires_grad = True
105
  print(f"Unfreezing gating network parameter: {name}")
106
-
107
  # Trainer
108
  trainer = Trainer(
109
  model=model,
 
95
  # Unfreeze only the small experts and their gating networks
96
  for name, param in model.named_parameters():
97
  # Unfreeze small expert layers
98
+ if "mlp.experts" in name and any(f"mlp.experts.{i}." in name for i in range(config.num_experts - config.small_expert_count, config.num_experts)):
99
  param.requires_grad = True
100
  print(f"Unfreezing small expert parameter: {name}")
101
 
 
103
  if "mlp.gate" in name:
104
  param.requires_grad = True
105
  print(f"Unfreezing gating network parameter: {name}")
106
+
107
  # Trainer
108
  trainer = Trainer(
109
  model=model,