Fix small-expert MoE routing (separate small-expert pool and gate) and add `small_expert_sparsity_coef` config option
Browse files- myolmoe/config.json +1 -1
- myolmoe/modeling_myolmoe.py +65 -74
myolmoe/config.json
CHANGED
|
@@ -33,5 +33,5 @@
|
|
| 33 |
"vocab_size": 50304,
|
| 34 |
"small_expert_intermediate_ratio": 16,
|
| 35 |
"small_expert_count": 64,
|
| 36 |
-
"
|
| 37 |
}
|
|
|
|
| 33 |
"vocab_size": 50304,
|
| 34 |
"small_expert_intermediate_ratio": 16,
|
| 35 |
"small_expert_count": 64,
|
| 36 |
+
"small_expert_sparsity_coef": 0.1
|
| 37 |
}
|
myolmoe/modeling_myolmoe.py
CHANGED
|
@@ -29,7 +29,7 @@ class OlmoeConfig(PretrainedConfig):
|
|
| 29 |
Ratio of intermediate size for small experts compared to regular experts.
|
| 30 |
small_expert_count (`int`, *optional*, defaults to 64):
|
| 31 |
Frequency of small experts - every Nth expert will be small.
|
| 32 |
-
|
| 33 |
Coefficient for small expert load balancing loss.
|
| 34 |
"""
|
| 35 |
model_type = "olmoe"
|
|
@@ -64,7 +64,7 @@ class OlmoeConfig(PretrainedConfig):
|
|
| 64 |
norm_topk_prob=False,
|
| 65 |
small_expert_intermediate_ratio=64,
|
| 66 |
small_expert_count=64,
|
| 67 |
-
|
| 68 |
**kwargs,
|
| 69 |
):
|
| 70 |
self.vocab_size = vocab_size
|
|
@@ -97,7 +97,7 @@ class OlmoeConfig(PretrainedConfig):
|
|
| 97 |
# Small expert parameters
|
| 98 |
self.small_expert_intermediate_ratio = small_expert_intermediate_ratio
|
| 99 |
self.small_expert_count = small_expert_count
|
| 100 |
-
self.
|
| 101 |
|
| 102 |
# Validate the correctness of rotary position embeddings parameters
|
| 103 |
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
|
@@ -546,65 +546,60 @@ OLMOE_ATTENTION_CLASSES = {
|
|
| 546 |
}
|
| 547 |
|
| 548 |
|
| 549 |
-
|
| 550 |
class OlmoeSparseMoeBlock(nn.Module):
|
| 551 |
def __init__(self, config, layer_idx: int):
|
| 552 |
super().__init__()
|
| 553 |
self.layer_idx = layer_idx
|
| 554 |
self.num_experts = config.num_experts
|
|
|
|
| 555 |
self.top_k = config.num_experts_per_tok
|
| 556 |
self.norm_topk_prob = config.norm_topk_prob
|
| 557 |
-
self.
|
| 558 |
-
self.
|
| 559 |
-
|
| 560 |
-
# Track which experts are small
|
| 561 |
-
self.small_expert_indices = list(range(config.num_experts - config.small_expert_count, config.num_experts))
|
| 562 |
-
self.experts = nn.ModuleList()
|
| 563 |
|
| 564 |
-
for
|
| 565 |
-
# Small experts are now at the end indices
|
| 566 |
-
is_small = i in self.small_expert_indices
|
| 567 |
-
self.experts.append(OlmoeMLP(config, is_small=is_small))
|
| 568 |
-
|
| 569 |
self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
|
| 570 |
-
self.
|
|
|
|
| 571 |
|
| 572 |
-
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 573 |
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
| 574 |
hidden_states = hidden_states.view(-1, hidden_dim)
|
|
|
|
|
|
|
| 575 |
router_logits = self.gate(hidden_states)
|
| 576 |
-
|
| 577 |
-
|
|
|
|
|
|
|
|
|
|
| 578 |
routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
|
| 579 |
|
| 580 |
if self.norm_topk_prob:
|
| 581 |
-
routing_weights
|
| 582 |
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
)
|
| 589 |
-
|
| 590 |
-
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
|
| 591 |
-
|
| 592 |
-
# Calculate small expert load balancing loss
|
| 593 |
-
small_expert_mask = torch.zeros_like(expert_mask)
|
| 594 |
-
for idx in self.small_expert_indices:
|
| 595 |
-
small_expert_mask[idx] = expert_mask[idx]
|
| 596 |
|
| 597 |
-
|
| 598 |
-
|
| 599 |
idx, top_x = torch.where(expert_mask[expert_idx])
|
| 600 |
-
if top_x.
|
| 601 |
continue
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
|
| 606 |
-
|
| 607 |
-
return final_hidden_states, router_logits
|
| 608 |
|
| 609 |
class OlmoeDecoderLayer(nn.Module):
|
| 610 |
def __init__(self, config: OlmoeConfig, layer_idx: int):
|
|
@@ -1042,42 +1037,38 @@ class MyOlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
|
|
| 1042 |
if labels is not None:
|
| 1043 |
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
| 1044 |
#
|
| 1045 |
-
|
| 1046 |
-
|
| 1047 |
-
|
| 1048 |
-
|
| 1049 |
-
|
| 1050 |
-
|
| 1051 |
-
|
| 1052 |
-
|
| 1053 |
-
|
| 1054 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1055 |
)
|
|
|
|
| 1056 |
|
| 1057 |
-
|
| 1058 |
-
|
| 1059 |
-
|
| 1060 |
-
|
| 1061 |
-
|
| 1062 |
-
for idx in range(self.config.num_experts - self.config.small_expert_count,
|
| 1063 |
-
self.config.num_experts):
|
| 1064 |
-
small_expert_mask = small_expert_mask.scatter(-1, torch.tensor([idx]), 1.0)
|
| 1065 |
|
| 1066 |
-
|
| 1067 |
-
masked_router_logits
|
| 1068 |
-
|
| 1069 |
-
|
| 1070 |
-
|
| 1071 |
-
|
| 1072 |
-
|
| 1073 |
-
|
| 1074 |
-
|
| 1075 |
-
if labels is not None:
|
| 1076 |
-
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
| 1077 |
-
if aux_loss is not None:
|
| 1078 |
-
loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
|
| 1079 |
-
if total_small_expert_loss is not None:
|
| 1080 |
-
loss += total_small_expert_loss.to(loss.device)
|
| 1081 |
#
|
| 1082 |
return MoeCausalLMOutputWithPast(
|
| 1083 |
loss=loss,
|
|
|
|
| 29 |
Ratio of intermediate size for small experts compared to regular experts.
|
| 30 |
small_expert_count (`int`, *optional*, defaults to 64):
|
| 31 |
Frequency of small experts - every Nth expert will be small.
|
| 32 |
+
small_expert_sparsity_coef (`float`, *optional*, defaults to 0.1):
|
| 33 |
Coefficient for small expert load balancing loss.
|
| 34 |
"""
|
| 35 |
model_type = "olmoe"
|
|
|
|
| 64 |
norm_topk_prob=False,
|
| 65 |
small_expert_intermediate_ratio=64,
|
| 66 |
small_expert_count=64,
|
| 67 |
+
small_expert_sparsity_coef=0.1,
|
| 68 |
**kwargs,
|
| 69 |
):
|
| 70 |
self.vocab_size = vocab_size
|
|
|
|
| 97 |
# Small expert parameters
|
| 98 |
self.small_expert_intermediate_ratio = small_expert_intermediate_ratio
|
| 99 |
self.small_expert_count = small_expert_count
|
| 100 |
+
self.small_expert_sparsity_coef = small_expert_sparsity_coef
|
| 101 |
|
| 102 |
# Validate the correctness of rotary position embeddings parameters
|
| 103 |
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
|
|
|
| 546 |
}
|
| 547 |
|
| 548 |
|
|
|
|
class OlmoeSparseMoeBlock(nn.Module):
    """Sparse MoE block with two expert pools: regular experts and "small"
    experts (reduced intermediate size, per ``config.small_expert_intermediate_ratio``
    — presumably; confirm against OlmoeMLP). One gate scores each pool; the
    concatenated logits are routed jointly with a single top-k selection, so a
    token may mix regular and small experts.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.num_experts = config.num_experts
        self.num_small_experts = config.small_expert_count
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob

        # Regular experts occupy combined indices [0, num_experts); small
        # experts occupy [num_experts, num_experts + num_small_experts).
        self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
        self.small_experts = nn.ModuleList(
            [OlmoeMLP(config, is_small=True) for _ in range(self.num_small_experts)]
        )

        # Gates for both expert types
        self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
        self.small_gate = nn.Linear(config.hidden_size, self.num_small_experts, bias=False)
        self.small_expert_sparsity_coef = config.small_expert_sparsity_coef

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route each token to its top-k experts across both pools.

        Args:
            hidden_states: tensor of shape (batch, seq_len, hidden_dim).

        Returns:
            A pair of (mixed hidden states with the input's shape, combined
            router logits of shape (batch*seq_len, num_experts + num_small_experts)).
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # Get logits for both expert types and combine them for joint routing.
        router_logits = self.gate(hidden_states)
        small_router_logits = self.small_gate(hidden_states)
        combined_logits = torch.cat([router_logits, small_router_logits], dim=-1)

        # Softmax in float32 for numerical stability; logits are 2-D so dim=1
        # is the expert dimension.
        routing_probs = F.softmax(combined_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)

        if self.norm_topk_prob:
            # Renormalize so the k selected weights sum to 1 for every token.
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        final_hidden_states = torch.zeros_like(hidden_states)
        # Mask of shape (total_experts, top_k, num_tokens): which tokens picked
        # each expert at which top-k slot.
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.num_experts + self.num_small_experts
        ).permute(2, 1, 0)

        # Process all experts (regular + small); skip experts no token selected.
        for expert_idx in range(self.num_experts + self.num_small_experts):
            idx, top_x = torch.where(expert_mask[expert_idx])
            if top_x.shape[0] == 0:
                continue

            if expert_idx < self.num_experts:
                expert = self.experts[expert_idx]
            else:
                expert = self.small_experts[expert_idx - self.num_experts]

            current_states = hidden_states[top_x]
            current_output = expert(current_states) * routing_weights[top_x, idx, None]
            # Scatter-accumulate each expert's weighted output back to its tokens.
            final_hidden_states.index_add_(0, top_x, current_output.to(hidden_states.dtype))

        return final_hidden_states.view(batch_size, sequence_length, hidden_dim), combined_logits
|
|
|
| 603 |
|
| 604 |
class OlmoeDecoderLayer(nn.Module):
|
| 605 |
def __init__(self, config: OlmoeConfig, layer_idx: int):
|
|
|
|
| 1037 |
if labels is not None:
|
| 1038 |
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
| 1039 |
#
|
| 1040 |
+
total_aux_loss = 0
|
| 1041 |
+
if output_router_logits and outputs.router_logits is not None:
|
| 1042 |
+
# Regular load balancing loss
|
| 1043 |
+
total_aux_loss += load_balancing_loss_func(
|
| 1044 |
+
outputs.router_logits,
|
| 1045 |
+
num_experts=self.config.num_experts + self.config.small_expert_count,
|
| 1046 |
+
top_k=self.config.num_experts_per_tok,
|
| 1047 |
+
attention_mask=attention_mask
|
| 1048 |
+
)
|
| 1049 |
+
|
| 1050 |
+
# Small expert sparsity loss
|
| 1051 |
+
small_expert_mask = torch.zeros(
|
| 1052 |
+
self.config.num_experts + self.config.small_expert_count,
|
| 1053 |
+
device=outputs.router_logits[0].device
|
| 1054 |
)
|
| 1055 |
+
small_expert_mask[self.config.num_experts:] = 1.0
|
| 1056 |
|
| 1057 |
+
masked_router_logits = []
|
| 1058 |
+
for logits in outputs.router_logits:
|
| 1059 |
+
# Apply mask to emphasize small experts
|
| 1060 |
+
masked_logits = logits * small_expert_mask * self.config.small_expert_sparsity_coef
|
| 1061 |
+
masked_router_logits.append(masked_logits)
|
|
|
|
|
|
|
|
|
|
| 1062 |
|
| 1063 |
+
total_aux_loss += load_balancing_loss_func(
|
| 1064 |
+
tuple(masked_router_logits),
|
| 1065 |
+
num_experts=self.config.num_experts + self.config.small_expert_count,
|
| 1066 |
+
top_k=self.config.num_experts_per_tok,
|
| 1067 |
+
attention_mask=attention_mask
|
| 1068 |
+
)
|
| 1069 |
+
|
| 1070 |
+
if loss is not None:
|
| 1071 |
+
loss += self.router_aux_loss_coef * total_aux_loss.to(loss.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1072 |
#
|
| 1073 |
return MoeCausalLMOutputWithPast(
|
| 1074 |
loss=loss,
|