Charlie81 committed on
Commit
44c43d7
·
1 Parent(s): 7050cb6

fix small experts loss calculation for gradient

Browse files
Files changed (1) hide show
  1. myolmoe/modeling_myolmoe.py +30 -23
myolmoe/modeling_myolmoe.py CHANGED
@@ -593,13 +593,6 @@ class OlmoeSparseMoeBlock(nn.Module):
593
  small_expert_mask = torch.zeros_like(expert_mask)
594
  for idx in self.small_expert_indices:
595
  small_expert_mask[idx] = expert_mask[idx]
596
-
597
- small_expert_loss = load_balancing_loss_func(
598
- router_logits,
599
- self.num_experts,
600
- self.top_k,
601
- None
602
- ) * self.small_expert_load_balancing_coef
603
 
604
  for expert_idx in range(self.num_experts):
605
  expert_layer = self.experts[expert_idx]
@@ -611,7 +604,7 @@ class OlmoeSparseMoeBlock(nn.Module):
611
  final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
612
 
613
  final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
614
- return final_hidden_states, router_logits, small_expert_loss
615
 
616
  class OlmoeDecoderLayer(nn.Module):
617
  def __init__(self, config: OlmoeConfig, layer_idx: int):
@@ -657,9 +650,9 @@ class OlmoeDecoderLayer(nn.Module):
657
  hidden_states = residual + hidden_states
658
  residual = hidden_states
659
  hidden_states = self.post_attention_layernorm(hidden_states)
660
- hidden_states, router_logits, small_expert_loss = self.mlp(hidden_states) #
661
  hidden_states = residual + hidden_states #
662
- outputs = (hidden_states, small_expert_loss) #
663
  if output_attentions:
664
  outputs += (self_attn_weights,)
665
  if use_cache:
@@ -1048,29 +1041,43 @@ class MyOlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
1048
  loss = None
1049
  if labels is not None:
1050
  loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
 
1051
  aux_loss = None
 
 
1052
  if output_router_logits:
 
1053
  aux_loss = load_balancing_loss_func(
1054
  outputs.router_logits if return_dict else outputs[-1],
1055
  self.num_experts,
1056
  self.num_experts_per_tok,
1057
  attention_mask,
1058
  )
1059
- if labels is not None:
1060
- loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
1061
- if not return_dict:
1062
- output = (logits,) + outputs[1:]
1063
- if output_router_logits:
1064
- output = (aux_loss,) + output
1065
- return (loss,) + output if loss is not None else output
1066
- #
1067
- total_small_expert_loss = torch.tensor(0.0, device=logits.device)
1068
- for layer_output in outputs:
1069
- if len(layer_output) > 1 and isinstance(layer_output[1], torch.Tensor):
1070
- total_small_expert_loss += layer_output[1]
1071
 
 
 
 
 
 
 
 
 
 
1072
  if labels is not None:
1073
- loss += total_small_expert_loss.to(loss.device)
 
 
 
 
1074
  #
1075
  return MoeCausalLMOutputWithPast(
1076
  loss=loss,
 
593
  small_expert_mask = torch.zeros_like(expert_mask)
594
  for idx in self.small_expert_indices:
595
  small_expert_mask[idx] = expert_mask[idx]
 
 
 
 
 
 
 
596
 
597
  for expert_idx in range(self.num_experts):
598
  expert_layer = self.experts[expert_idx]
 
604
  final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
605
 
606
  final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
607
+ return final_hidden_states, router_logits
608
 
609
  class OlmoeDecoderLayer(nn.Module):
610
  def __init__(self, config: OlmoeConfig, layer_idx: int):
 
650
  hidden_states = residual + hidden_states
651
  residual = hidden_states
652
  hidden_states = self.post_attention_layernorm(hidden_states)
653
+ hidden_states, router_logits = self.mlp(hidden_states) #
654
  hidden_states = residual + hidden_states #
655
+ outputs = (hidden_states,) #
656
  if output_attentions:
657
  outputs += (self_attn_weights,)
658
  if use_cache:
 
1041
  loss = None
1042
  if labels is not None:
1043
  loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
1044
+ #
1045
  aux_loss = None
1046
+ total_small_expert_loss = torch.tensor(0.0, device=logits.device)
1047
+
1048
  if output_router_logits:
1049
+ # Calculate regular load balancing loss
1050
  aux_loss = load_balancing_loss_func(
1051
  outputs.router_logits if return_dict else outputs[-1],
1052
  self.num_experts,
1053
  self.num_experts_per_tok,
1054
  attention_mask,
1055
  )
1056
+
1057
+ # Calculate small expert load balancing loss
1058
+ router_logits = outputs.router_logits if return_dict else outputs[-1]
1059
+ if isinstance(router_logits, tuple):
1060
+ small_expert_mask = torch.zeros_like(router_logits[0])
1061
+ # Create mask for small experts
1062
+ for idx in range(self.config.num_experts - self.config.small_expert_count,
1063
+ self.config.num_experts):
1064
+ small_expert_mask = small_expert_mask.scatter(-1, torch.tensor([idx]), 1.0)
 
 
 
1065
 
1066
+ # Apply mask and calculate loss
1067
+ masked_router_logits = [rl * small_expert_mask for rl in router_logits]
1068
+ total_small_expert_loss = load_balancing_loss_func(
1069
+ tuple(masked_router_logits),
1070
+ self.num_experts,
1071
+ self.num_experts_per_tok,
1072
+ attention_mask,
1073
+ ) * self.config.small_expert_load_balancing_coef
1074
+
1075
  if labels is not None:
1076
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
1077
+ if aux_loss is not None:
1078
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
1079
+ if total_small_expert_loss is not None:
1080
+ loss += total_small_expert_loss.to(loss.device)
1081
  #
1082
  return MoeCausalLMOutputWithPast(
1083
  loss=loss,