Attempt a new distribution of small experts across layers
Browse files- myolmoe/modeling_myolmoe.py +48 -15
- scripts/train.py +6 -2
myolmoe/modeling_myolmoe.py
CHANGED
|
@@ -65,6 +65,8 @@ class OlmoeConfig(PretrainedConfig):
|
|
| 65 |
small_expert_intermediate_ratio=64,
|
| 66 |
small_expert_count=64,
|
| 67 |
small_expert_sparsity_coef=0.1,
|
|
|
|
|
|
|
| 68 |
**kwargs,
|
| 69 |
):
|
| 70 |
self.vocab_size = vocab_size
|
|
@@ -98,6 +100,8 @@ class OlmoeConfig(PretrainedConfig):
|
|
| 98 |
self.small_expert_intermediate_ratio = small_expert_intermediate_ratio
|
| 99 |
self.small_expert_count = small_expert_count
|
| 100 |
self.small_expert_sparsity_coef = small_expert_sparsity_coef
|
|
|
|
|
|
|
| 101 |
|
| 102 |
# Validate the correctness of rotary position embeddings parameters
|
| 103 |
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
|
@@ -550,28 +554,57 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 550 |
def __init__(self, config, layer_idx: int):
|
| 551 |
super().__init__()
|
| 552 |
self.layer_idx = layer_idx
|
|
|
|
| 553 |
self.num_experts = config.num_experts
|
| 554 |
-
self.num_small_experts = config.small_expert_count
|
| 555 |
self.top_k = config.num_experts_per_tok
|
| 556 |
self.norm_topk_prob = config.norm_topk_prob
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 557 |
self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
|
| 558 |
-
self.small_experts = nn.ModuleList([
|
| 559 |
-
|
| 560 |
-
|
|
|
|
| 561 |
self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
|
| 562 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 563 |
self.small_expert_sparsity_coef = config.small_expert_sparsity_coef
|
| 564 |
|
| 565 |
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 566 |
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
| 567 |
hidden_states = hidden_states.view(-1, hidden_dim)
|
| 568 |
-
|
| 569 |
-
# Get logits for both expert types
|
| 570 |
router_logits = self.gate(hidden_states)
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
|
|
|
|
|
|
|
|
|
| 575 |
routing_probs = F.softmax(combined_logits, dim=1, dtype=torch.float)
|
| 576 |
routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
|
| 577 |
|
|
@@ -580,27 +613,27 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 580 |
|
| 581 |
final_hidden_states = torch.zeros_like(hidden_states)
|
| 582 |
expert_mask = torch.nn.functional.one_hot(
|
| 583 |
-
selected_experts,
|
| 584 |
num_classes=self.num_experts + self.num_small_experts
|
| 585 |
).permute(2, 1, 0)
|
| 586 |
|
| 587 |
-
# Process all experts (regular + small)
|
| 588 |
for expert_idx in range(self.num_experts + self.num_small_experts):
|
| 589 |
idx, top_x = torch.where(expert_mask[expert_idx])
|
| 590 |
if top_x.shape[0] == 0:
|
| 591 |
continue
|
| 592 |
-
|
| 593 |
if expert_idx < self.num_experts:
|
| 594 |
expert = self.experts[expert_idx]
|
| 595 |
else:
|
| 596 |
expert = self.small_experts[expert_idx - self.num_experts]
|
| 597 |
-
|
| 598 |
current_states = hidden_states[top_x]
|
| 599 |
current_output = expert(current_states) * routing_weights[top_x, idx, None]
|
| 600 |
final_hidden_states.index_add_(0, top_x, current_output.to(hidden_states.dtype))
|
| 601 |
|
| 602 |
return final_hidden_states.view(batch_size, sequence_length, hidden_dim), combined_logits
|
| 603 |
|
|
|
|
| 604 |
class OlmoeDecoderLayer(nn.Module):
|
| 605 |
def __init__(self, config: OlmoeConfig, layer_idx: int):
|
| 606 |
super().__init__()
|
|
|
|
| 65 |
small_expert_intermediate_ratio=64,
|
| 66 |
small_expert_count=64,
|
| 67 |
small_expert_sparsity_coef=0.1,
|
| 68 |
+
small_expert_strategy="constant", # NEW
|
| 69 |
+
max_small_expert_count=64, # NEW: total possible small experts
|
| 70 |
**kwargs,
|
| 71 |
):
|
| 72 |
self.vocab_size = vocab_size
|
|
|
|
| 100 |
self.small_expert_intermediate_ratio = small_expert_intermediate_ratio
|
| 101 |
self.small_expert_count = small_expert_count
|
| 102 |
self.small_expert_sparsity_coef = small_expert_sparsity_coef
|
| 103 |
+
self.small_expert_strategy = small_expert_strategy
|
| 104 |
+
self.max_small_expert_count = max_small_expert_count
|
| 105 |
|
| 106 |
# Validate the correctness of rotary position embeddings parameters
|
| 107 |
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
|
|
|
| 554 |
def __init__(self, config, layer_idx: int):
    """Mixture-of-experts block with optional per-layer "small" experts.

    Regular experts exist in every layer; small experts are only
    instantiated in the second half of the network, with a per-layer
    count decided by ``config.small_expert_strategy``:

    * ``"constant"``  — ``max_small_expert_count`` is split evenly
      across the second-half layers (integer division may drop a
      remainder).
    * ``"increment"`` — ``max_small_expert_count`` is split across the
      second-half layers proportionally to depth: later layers get more,
      and the total over all layers is approximately
      ``max_small_expert_count`` (again up to integer division, which
      can give the earliest second-half layers zero small experts).

    Args:
        config: model config; must provide ``num_hidden_layers``,
            ``num_experts``, ``num_experts_per_tok``, ``norm_topk_prob``,
            ``hidden_size``, ``small_expert_strategy``,
            ``max_small_expert_count`` and ``small_expert_sparsity_coef``.
        layer_idx: zero-based index of this decoder layer.

    Raises:
        ValueError: if ``config.small_expert_strategy`` is not
            ``"constant"`` or ``"increment"``.
    """
    super().__init__()
    self.layer_idx = layer_idx
    self.total_layers = config.num_hidden_layers
    self.num_experts = config.num_experts
    self.top_k = config.num_experts_per_tok
    self.norm_topk_prob = config.norm_topk_prob

    # Small experts live only in the second half of the layer stack.
    first_second_half_layer = self.total_layers // 2
    in_second_half = layer_idx >= first_second_half_layer

    # Determine the small-expert count for this layer.
    if in_second_half:
        second_half_idx = layer_idx - first_second_half_layer
        num_second_half_blocks = self.total_layers - first_second_half_layer

        if config.small_expert_strategy == "constant":
            # Even split of the budget over the second-half layers.
            self.num_small_experts = config.max_small_expert_count // num_second_half_blocks
        elif config.small_expert_strategy == "increment":
            # Depth-proportional split: the i-th second-half layer
            # (1-based) receives i / (1 + 2 + ... + num_blocks) of the
            # total budget.  NOTE(review): the previous comment claimed
            # this scales each layer "from 1 to max_small_expert_count";
            # it actually distributes a *total* of ~max_small_expert_count
            # across the second half.
            triangular = (num_second_half_blocks * (num_second_half_blocks + 1)) // 2
            self.num_small_experts = (
                (second_half_idx + 1) * config.max_small_expert_count // triangular
            )
        else:
            raise ValueError(f"Unknown strategy: {config.small_expert_strategy}")
    else:
        self.num_small_experts = 0

    self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
    # None (rather than an empty ModuleList) when this layer has no
    # small experts, so downstream code can cheaply test for presence.
    self.small_experts = nn.ModuleList([
        OlmoeMLP(config, is_small=True) for _ in range(self.num_small_experts)
    ]) if self.num_small_experts > 0 else None

    self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)

    # Separate router for the small experts; its logits are concatenated
    # with the main gate's in forward().
    if self.num_small_experts > 0:
        self.small_gate = nn.Linear(config.hidden_size, self.num_small_experts, bias=False)
    else:
        self.small_gate = None

    self.small_expert_sparsity_coef = config.small_expert_sparsity_coef
|
| 595 |
|
| 596 |
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Route each token to its top-k experts and mix their outputs.

    Regular and (optional) small experts share a single softmax over the
    concatenated router logits, so both expert kinds compete for the
    same top-k slots per token.

    Args:
        hidden_states: tensor of shape (batch, seq_len, hidden_dim).

    Returns:
        A tuple of:
        * mixed hidden states, shape (batch, seq_len, hidden_dim);
        * combined router logits, shape
          (batch * seq_len, num_experts + num_small_experts), for the
          auxiliary load-balancing loss.
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    # Flatten to (num_tokens, hidden_dim) so routing is per-token.
    hidden_states = hidden_states.view(-1, hidden_dim)

    # Get logits for both expert types.
    router_logits = self.gate(hidden_states)

    if self.num_small_experts > 0:
        small_router_logits = self.small_gate(hidden_states)
        combined_logits = torch.cat([router_logits, small_router_logits], dim=-1)
    else:
        combined_logits = router_logits

    # Softmax in float32 for numerical stability regardless of model dtype.
    routing_probs = F.softmax(combined_logits, dim=-1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
    # BUGFIX: honor config.norm_topk_prob — it was stored in __init__ but
    # never applied.  When set, renormalize the selected top-k
    # probabilities to sum to 1 per token (as in the reference OLMoE /
    # Mixtral MoE blocks).
    if self.norm_topk_prob:
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

    final_hidden_states = torch.zeros_like(hidden_states)
    # one_hot -> permute gives shape (num_classes, top_k, num_tokens), so
    # expert_mask[e] tells which (slot, token) pairs selected expert e.
    expert_mask = torch.nn.functional.one_hot(
        selected_experts,
        num_classes=self.num_experts + self.num_small_experts
    ).permute(2, 1, 0)

    # Process all experts (regular + small).
    for expert_idx in range(self.num_experts + self.num_small_experts):
        idx, top_x = torch.where(expert_mask[expert_idx])
        if top_x.shape[0] == 0:
            continue

        # Indices >= num_experts address the small-expert list.
        if expert_idx < self.num_experts:
            expert = self.experts[expert_idx]
        else:
            expert = self.small_experts[expert_idx - self.num_experts]

        current_states = hidden_states[top_x]
        current_output = expert(current_states) * routing_weights[top_x, idx, None]
        final_hidden_states.index_add_(0, top_x, current_output.to(hidden_states.dtype))

    return final_hidden_states.view(batch_size, sequence_length, hidden_dim), combined_logits
|
| 635 |
|
| 636 |
+
|
| 637 |
class OlmoeDecoderLayer(nn.Module):
|
| 638 |
def __init__(self, config: OlmoeConfig, layer_idx: int):
|
| 639 |
super().__init__()
|
scripts/train.py
CHANGED
|
@@ -73,7 +73,7 @@ def main():
|
|
| 73 |
per_device_train_batch_size=2,
|
| 74 |
gradient_accumulation_steps=8,
|
| 75 |
learning_rate=1e-4,
|
| 76 |
-
num_train_epochs=
|
| 77 |
logging_dir="./logs",
|
| 78 |
logging_steps=10,
|
| 79 |
save_steps=1000,
|
|
@@ -94,10 +94,14 @@ def main():
|
|
| 94 |
# Unfreeze only the small experts and their gating networks
|
| 95 |
trainable_params = []
|
| 96 |
for name, param in model.named_parameters():
|
| 97 |
-
if
|
|
|
|
|
|
|
|
|
|
| 98 |
param.requires_grad = True
|
| 99 |
trainable_params.append(name)
|
| 100 |
print(f"Unfreezing parameter: {name}")
|
|
|
|
| 101 |
|
| 102 |
print(f"Total trainable parameters: {len(trainable_params)}")
|
| 103 |
|
|
|
|
| 73 |
per_device_train_batch_size=2,
|
| 74 |
gradient_accumulation_steps=8,
|
| 75 |
learning_rate=1e-4,
|
| 76 |
+
num_train_epochs=0.001,
|
| 77 |
logging_dir="./logs",
|
| 78 |
logging_steps=10,
|
| 79 |
save_steps=1000,
|
|
|
|
| 94 |
# Unfreeze only the small experts and their gating networks
|
| 95 |
trainable_params = []
|
| 96 |
for name, param in model.named_parameters():
|
| 97 |
+
if (
|
| 98 |
+
"small_experts" in name or
|
| 99 |
+
"small_gate" in name
|
| 100 |
+
):
|
| 101 |
param.requires_grad = True
|
| 102 |
trainable_params.append(name)
|
| 103 |
print(f"Unfreezing parameter: {name}")
|
| 104 |
+
|
| 105 |
|
| 106 |
print(f"Total trainable parameters: {len(trainable_params)}")
|
| 107 |
|