Fix small-expert load-balancing loss calculation so it contributes to the gradient
Browse files- myolmoe/modeling_myolmoe.py +30 -23
myolmoe/modeling_myolmoe.py
CHANGED
|
@@ -593,13 +593,6 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 593 |
small_expert_mask = torch.zeros_like(expert_mask)
|
| 594 |
for idx in self.small_expert_indices:
|
| 595 |
small_expert_mask[idx] = expert_mask[idx]
|
| 596 |
-
|
| 597 |
-
small_expert_loss = load_balancing_loss_func(
|
| 598 |
-
router_logits,
|
| 599 |
-
self.num_experts,
|
| 600 |
-
self.top_k,
|
| 601 |
-
None
|
| 602 |
-
) * self.small_expert_load_balancing_coef
|
| 603 |
|
| 604 |
for expert_idx in range(self.num_experts):
|
| 605 |
expert_layer = self.experts[expert_idx]
|
|
@@ -611,7 +604,7 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 611 |
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
| 612 |
|
| 613 |
final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
|
| 614 |
-
return final_hidden_states, router_logits
|
| 615 |
|
| 616 |
class OlmoeDecoderLayer(nn.Module):
|
| 617 |
def __init__(self, config: OlmoeConfig, layer_idx: int):
|
|
@@ -657,9 +650,9 @@ class OlmoeDecoderLayer(nn.Module):
|
|
| 657 |
hidden_states = residual + hidden_states
|
| 658 |
residual = hidden_states
|
| 659 |
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 660 |
-
hidden_states, router_logits
|
| 661 |
hidden_states = residual + hidden_states #
|
| 662 |
-
outputs = (hidden_states,
|
| 663 |
if output_attentions:
|
| 664 |
outputs += (self_attn_weights,)
|
| 665 |
if use_cache:
|
|
@@ -1048,29 +1041,43 @@ class MyOlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
|
|
| 1048 |
loss = None
|
| 1049 |
if labels is not None:
|
| 1050 |
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
|
|
|
| 1051 |
aux_loss = None
|
|
|
|
|
|
|
| 1052 |
if output_router_logits:
|
|
|
|
| 1053 |
aux_loss = load_balancing_loss_func(
|
| 1054 |
outputs.router_logits if return_dict else outputs[-1],
|
| 1055 |
self.num_experts,
|
| 1056 |
self.num_experts_per_tok,
|
| 1057 |
attention_mask,
|
| 1058 |
)
|
| 1059 |
-
|
| 1060 |
-
|
| 1061 |
-
|
| 1062 |
-
|
| 1063 |
-
|
| 1064 |
-
|
| 1065 |
-
|
| 1066 |
-
|
| 1067 |
-
|
| 1068 |
-
for layer_output in outputs:
|
| 1069 |
-
if len(layer_output) > 1 and isinstance(layer_output[1], torch.Tensor):
|
| 1070 |
-
total_small_expert_loss += layer_output[1]
|
| 1071 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1072 |
if labels is not None:
|
| 1073 |
-
loss
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1074 |
#
|
| 1075 |
return MoeCausalLMOutputWithPast(
|
| 1076 |
loss=loss,
|
|
|
|
| 593 |
small_expert_mask = torch.zeros_like(expert_mask)
|
| 594 |
for idx in self.small_expert_indices:
|
| 595 |
small_expert_mask[idx] = expert_mask[idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 596 |
|
| 597 |
for expert_idx in range(self.num_experts):
|
| 598 |
expert_layer = self.experts[expert_idx]
|
|
|
|
| 604 |
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
| 605 |
|
| 606 |
final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
|
| 607 |
+
return final_hidden_states, router_logits
|
| 608 |
|
| 609 |
class OlmoeDecoderLayer(nn.Module):
|
| 610 |
def __init__(self, config: OlmoeConfig, layer_idx: int):
|
|
|
|
| 650 |
hidden_states = residual + hidden_states
|
| 651 |
residual = hidden_states
|
| 652 |
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 653 |
+
hidden_states, router_logits = self.mlp(hidden_states) #
|
| 654 |
hidden_states = residual + hidden_states #
|
| 655 |
+
outputs = (hidden_states,) #
|
| 656 |
if output_attentions:
|
| 657 |
outputs += (self_attn_weights,)
|
| 658 |
if use_cache:
|
|
|
|
| 1041 |
loss = None
|
| 1042 |
if labels is not None:
|
| 1043 |
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
| 1044 |
+
#
|
| 1045 |
aux_loss = None
|
| 1046 |
+
total_small_expert_loss = torch.tensor(0.0, device=logits.device)
|
| 1047 |
+
|
| 1048 |
if output_router_logits:
|
| 1049 |
+
# Calculate regular load balancing loss
|
| 1050 |
aux_loss = load_balancing_loss_func(
|
| 1051 |
outputs.router_logits if return_dict else outputs[-1],
|
| 1052 |
self.num_experts,
|
| 1053 |
self.num_experts_per_tok,
|
| 1054 |
attention_mask,
|
| 1055 |
)
|
| 1056 |
+
|
| 1057 |
+
# Calculate small expert load balancing loss
|
| 1058 |
+
router_logits = outputs.router_logits if return_dict else outputs[-1]
|
| 1059 |
+
if isinstance(router_logits, tuple):
|
| 1060 |
+
small_expert_mask = torch.zeros_like(router_logits[0])
|
| 1061 |
+
# Create mask for small experts
|
| 1062 |
+
for idx in range(self.config.num_experts - self.config.small_expert_count,
|
| 1063 |
+
self.config.num_experts):
|
| 1064 |
+
small_expert_mask = small_expert_mask.scatter(-1, torch.tensor([idx]), 1.0)
|
|
|
|
|
|
|
|
|
|
| 1065 |
|
| 1066 |
+
# Apply mask and calculate loss
|
| 1067 |
+
masked_router_logits = [rl * small_expert_mask for rl in router_logits]
|
| 1068 |
+
total_small_expert_loss = load_balancing_loss_func(
|
| 1069 |
+
tuple(masked_router_logits),
|
| 1070 |
+
self.num_experts,
|
| 1071 |
+
self.num_experts_per_tok,
|
| 1072 |
+
attention_mask,
|
| 1073 |
+
) * self.config.small_expert_load_balancing_coef
|
| 1074 |
+
|
| 1075 |
if labels is not None:
|
| 1076 |
+
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
| 1077 |
+
if aux_loss is not None:
|
| 1078 |
+
loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
|
| 1079 |
+
if total_small_expert_loss is not None:
|
| 1080 |
+
loss += total_small_expert_loss.to(loss.device)
|
| 1081 |
#
|
| 1082 |
return MoeCausalLMOutputWithPast(
|
| 1083 |
loss=loss,
|