reorder small experts
Browse files- myolmoe/config.json +1 -1
- myolmoe/modeling_myolmoe.py +6 -7
- scripts/train.py +2 -2
myolmoe/config.json
CHANGED
|
@@ -32,6 +32,6 @@
|
|
| 32 |
"use_cache": true,
|
| 33 |
"vocab_size": 50304,
|
| 34 |
"small_expert_intermediate_ratio": 16,
|
| 35 |
-
"
|
| 36 |
"small_expert_load_balancing_coef": 0.1
|
| 37 |
}
|
|
|
|
| 32 |
"use_cache": true,
|
| 33 |
"vocab_size": 50304,
|
| 34 |
"small_expert_intermediate_ratio": 16,
|
| 35 |
+
"small_expert_count": 64,
|
| 36 |
"small_expert_load_balancing_coef": 0.1
|
| 37 |
}
|
myolmoe/modeling_myolmoe.py
CHANGED
|
@@ -27,7 +27,7 @@ class OlmoeConfig(PretrainedConfig):
|
|
| 27 |
[Previous args remain the same...]
|
| 28 |
small_expert_intermediate_ratio (`float`, *optional*, defaults to 0.5):
|
| 29 |
Ratio of intermediate size for small experts compared to regular experts.
|
| 30 |
-
|
| 31 |
Frequency of small experts - every Nth expert will be small.
|
| 32 |
small_expert_load_balancing_coef (`float`, *optional*, defaults to 0.1):
|
| 33 |
Coefficient for small expert load balancing loss.
|
|
@@ -63,7 +63,7 @@ class OlmoeConfig(PretrainedConfig):
|
|
| 63 |
router_aux_loss_coef=0.01,
|
| 64 |
norm_topk_prob=False,
|
| 65 |
small_expert_intermediate_ratio=64,
|
| 66 |
-
|
| 67 |
small_expert_load_balancing_coef=0.1,
|
| 68 |
**kwargs,
|
| 69 |
):
|
|
@@ -96,7 +96,7 @@ class OlmoeConfig(PretrainedConfig):
|
|
| 96 |
|
| 97 |
# Small expert parameters
|
| 98 |
self.small_expert_intermediate_ratio = small_expert_intermediate_ratio
|
| 99 |
-
self.
|
| 100 |
self.small_expert_load_balancing_coef = small_expert_load_balancing_coef
|
| 101 |
|
| 102 |
# Validate the correctness of rotary position embeddings parameters
|
|
@@ -558,13 +558,12 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 558 |
self.n_step = getattr(config, "nth_step", 2)
|
| 559 |
|
| 560 |
# Track which experts are small
|
| 561 |
-
self.small_expert_indices =
|
| 562 |
self.experts = nn.ModuleList()
|
| 563 |
|
| 564 |
for i in range(self.num_experts):
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
self.small_expert_indices.append(i)
|
| 568 |
self.experts.append(OlmoeMLP(config, is_small=is_small))
|
| 569 |
|
| 570 |
self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
|
|
|
|
| 27 |
[Previous args remain the same...]
|
| 28 |
small_expert_intermediate_ratio (`float`, *optional*, defaults to 64):
|
| 29 |
Ratio of intermediate size for small experts compared to regular experts.
|
| 30 |
+
small_expert_count (`int`, *optional*, defaults to 64):
|
| 31 |
Number of small experts; the last `small_expert_count` experts in the expert list are small.
|
| 32 |
small_expert_load_balancing_coef (`float`, *optional*, defaults to 0.1):
|
| 33 |
Coefficient for small expert load balancing loss.
|
|
|
|
| 63 |
router_aux_loss_coef=0.01,
|
| 64 |
norm_topk_prob=False,
|
| 65 |
small_expert_intermediate_ratio=64,
|
| 66 |
+
small_expert_count=64,
|
| 67 |
small_expert_load_balancing_coef=0.1,
|
| 68 |
**kwargs,
|
| 69 |
):
|
|
|
|
| 96 |
|
| 97 |
# Small expert parameters
|
| 98 |
self.small_expert_intermediate_ratio = small_expert_intermediate_ratio
|
| 99 |
+
self.small_expert_count = small_expert_count
|
| 100 |
self.small_expert_load_balancing_coef = small_expert_load_balancing_coef
|
| 101 |
|
| 102 |
# Validate the correctness of rotary position embeddings parameters
|
|
|
|
| 558 |
self.n_step = getattr(config, "nth_step", 2)
|
| 559 |
|
| 560 |
# Track which experts are small
|
| 561 |
+
self.small_expert_indices = list(range(config.num_experts - config.small_expert_count, config.num_experts))
|
| 562 |
self.experts = nn.ModuleList()
|
| 563 |
|
| 564 |
for i in range(self.num_experts):
|
| 565 |
+
# Small experts are now at the end indices
|
| 566 |
+
is_small = i in self.small_expert_indices
|
|
|
|
| 567 |
self.experts.append(OlmoeMLP(config, is_small=is_small))
|
| 568 |
|
| 569 |
self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
|
scripts/train.py
CHANGED
|
@@ -95,7 +95,7 @@ def main():
|
|
| 95 |
# Unfreeze only the small experts and their gating networks
|
| 96 |
for name, param in model.named_parameters():
|
| 97 |
# Unfreeze small expert layers
|
| 98 |
-
if "mlp.experts" in name and any(f"mlp.experts.{i}." in name for i in range(0, config.num_experts, config.
|
| 99 |
param.requires_grad = True
|
| 100 |
print(f"Unfreezing small expert parameter: {name}")
|
| 101 |
|
|
@@ -103,7 +103,7 @@ def main():
|
|
| 103 |
if "mlp.gate" in name:
|
| 104 |
param.requires_grad = True
|
| 105 |
print(f"Unfreezing gating network parameter: {name}")
|
| 106 |
-
|
| 107 |
# Trainer
|
| 108 |
trainer = Trainer(
|
| 109 |
model=model,
|
|
|
|
| 95 |
# Unfreeze only the small experts and their gating networks
|
| 96 |
for name, param in model.named_parameters():
|
| 97 |
# Unfreeze small expert layers
|
| 98 |
+
if "mlp.experts" in name and any(f"mlp.experts.{i}." in name for i in range(config.num_experts - config.small_expert_count, config.num_experts)):
|
| 99 |
param.requires_grad = True
|
| 100 |
print(f"Unfreezing small expert parameter: {name}")
|
| 101 |
|
|
|
|
| 103 |
if "mlp.gate" in name:
|
| 104 |
param.requires_grad = True
|
| 105 |
print(f"Unfreezing gating network parameter: {name}")
|
| 106 |
+
|
| 107 |
# Trainer
|
| 108 |
trainer = Trainer(
|
| 109 |
model=model,
|