Charlie81 committed on
Commit
c306fa9
·
1 Parent(s): 44c43d7

huge fixes

Browse files
Files changed (2) hide show
  1. myolmoe/config.json +1 -1
  2. myolmoe/modeling_myolmoe.py +65 -74
myolmoe/config.json CHANGED
@@ -33,5 +33,5 @@
33
  "vocab_size": 50304,
34
  "small_expert_intermediate_ratio": 16,
35
  "small_expert_count": 64,
36
- "small_expert_load_balancing_coef": 0.1
37
  }
 
33
  "vocab_size": 50304,
34
  "small_expert_intermediate_ratio": 16,
35
  "small_expert_count": 64,
36
+ "small_expert_sparsity_coef": 0.1
37
  }
myolmoe/modeling_myolmoe.py CHANGED
@@ -29,7 +29,7 @@ class OlmoeConfig(PretrainedConfig):
29
  Ratio of intermediate size for small experts compared to regular experts.
30
  small_expert_count (`int`, *optional*, defaults to 64):
31
  Frequency of small experts - every Nth expert will be small.
32
- small_expert_load_balancing_coef (`float`, *optional*, defaults to 0.1):
33
  Coefficient for small expert load balancing loss.
34
  """
35
  model_type = "olmoe"
@@ -64,7 +64,7 @@ class OlmoeConfig(PretrainedConfig):
64
  norm_topk_prob=False,
65
  small_expert_intermediate_ratio=64,
66
  small_expert_count=64,
67
- small_expert_load_balancing_coef=0.1,
68
  **kwargs,
69
  ):
70
  self.vocab_size = vocab_size
@@ -97,7 +97,7 @@ class OlmoeConfig(PretrainedConfig):
97
  # Small expert parameters
98
  self.small_expert_intermediate_ratio = small_expert_intermediate_ratio
99
  self.small_expert_count = small_expert_count
100
- self.small_expert_load_balancing_coef = small_expert_load_balancing_coef
101
 
102
  # Validate the correctness of rotary position embeddings parameters
103
  if self.rope_scaling is not None and "type" in self.rope_scaling:
@@ -546,65 +546,60 @@ OLMOE_ATTENTION_CLASSES = {
546
  }
547
 
548
 
549
-
550
  class OlmoeSparseMoeBlock(nn.Module):
551
  def __init__(self, config, layer_idx: int):
552
  super().__init__()
553
  self.layer_idx = layer_idx
554
  self.num_experts = config.num_experts
 
555
  self.top_k = config.num_experts_per_tok
556
  self.norm_topk_prob = config.norm_topk_prob
557
- self.routing_type = getattr(config, "routing_type", "topk")
558
- self.n_step = getattr(config, "nth_step", 2)
559
-
560
- # Track which experts are small
561
- self.small_expert_indices = list(range(config.num_experts - config.small_expert_count, config.num_experts))
562
- self.experts = nn.ModuleList()
563
 
564
- for i in range(self.num_experts):
565
- # Small experts are now at the end indices
566
- is_small = i in self.small_expert_indices
567
- self.experts.append(OlmoeMLP(config, is_small=is_small))
568
-
569
  self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
570
- self.small_expert_load_balancing_coef = config.small_expert_load_balancing_coef
 
571
 
572
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
573
  batch_size, sequence_length, hidden_dim = hidden_states.shape
574
  hidden_states = hidden_states.view(-1, hidden_dim)
 
 
575
  router_logits = self.gate(hidden_states)
576
- routing_probs = F.softmax(router_logits, dim=1, dtype=torch.float)
577
-
 
 
 
578
  routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
579
 
580
  if self.norm_topk_prob:
581
- routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
582
 
583
- routing_weights = routing_weights.to(hidden_states.dtype)
584
- final_hidden_states = torch.zeros(
585
- (batch_size * sequence_length, hidden_dim),
586
- dtype=hidden_states.dtype,
587
- device=hidden_states.device,
588
- )
589
-
590
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
591
-
592
- # Calculate small expert load balancing loss
593
- small_expert_mask = torch.zeros_like(expert_mask)
594
- for idx in self.small_expert_indices:
595
- small_expert_mask[idx] = expert_mask[idx]
596
 
597
- for expert_idx in range(self.num_experts):
598
- expert_layer = self.experts[expert_idx]
599
  idx, top_x = torch.where(expert_mask[expert_idx])
600
- if top_x.numel() == 0:
601
  continue
602
- current_state = hidden_states[top_x]
603
- current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
604
- final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
 
 
 
 
 
 
605
 
606
- final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
607
- return final_hidden_states, router_logits
608
 
609
  class OlmoeDecoderLayer(nn.Module):
610
  def __init__(self, config: OlmoeConfig, layer_idx: int):
@@ -1042,42 +1037,38 @@ class MyOlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
1042
  if labels is not None:
1043
  loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
1044
  #
1045
- aux_loss = None
1046
- total_small_expert_loss = torch.tensor(0.0, device=logits.device)
1047
-
1048
- if output_router_logits:
1049
- # Calculate regular load balancing loss
1050
- aux_loss = load_balancing_loss_func(
1051
- outputs.router_logits if return_dict else outputs[-1],
1052
- self.num_experts,
1053
- self.num_experts_per_tok,
1054
- attention_mask,
 
 
 
 
1055
  )
 
1056
 
1057
- # Calculate small expert load balancing loss
1058
- router_logits = outputs.router_logits if return_dict else outputs[-1]
1059
- if isinstance(router_logits, tuple):
1060
- small_expert_mask = torch.zeros_like(router_logits[0])
1061
- # Create mask for small experts
1062
- for idx in range(self.config.num_experts - self.config.small_expert_count,
1063
- self.config.num_experts):
1064
- small_expert_mask = small_expert_mask.scatter(-1, torch.tensor([idx]), 1.0)
1065
 
1066
- # Apply mask and calculate loss
1067
- masked_router_logits = [rl * small_expert_mask for rl in router_logits]
1068
- total_small_expert_loss = load_balancing_loss_func(
1069
- tuple(masked_router_logits),
1070
- self.num_experts,
1071
- self.num_experts_per_tok,
1072
- attention_mask,
1073
- ) * self.config.small_expert_load_balancing_coef
1074
-
1075
- if labels is not None:
1076
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
1077
- if aux_loss is not None:
1078
- loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
1079
- if total_small_expert_loss is not None:
1080
- loss += total_small_expert_loss.to(loss.device)
1081
  #
1082
  return MoeCausalLMOutputWithPast(
1083
  loss=loss,
 
29
  Ratio of intermediate size for small experts compared to regular experts.
30
  small_expert_count (`int`, *optional*, defaults to 64):
31
  Frequency of small experts - every Nth expert will be small.
32
+ small_expert_sparsity_coef (`float`, *optional*, defaults to 0.1):
33
  Coefficient for small expert load balancing loss.
34
  """
35
  model_type = "olmoe"
 
64
  norm_topk_prob=False,
65
  small_expert_intermediate_ratio=64,
66
  small_expert_count=64,
67
+ small_expert_sparsity_coef=0.1,
68
  **kwargs,
69
  ):
70
  self.vocab_size = vocab_size
 
97
  # Small expert parameters
98
  self.small_expert_intermediate_ratio = small_expert_intermediate_ratio
99
  self.small_expert_count = small_expert_count
100
+ self.small_expert_sparsity_coef = small_expert_sparsity_coef
101
 
102
  # Validate the correctness of rotary position embeddings parameters
103
  if self.rope_scaling is not None and "type" in self.rope_scaling:
 
546
  }
547
 
548
 
 
549
  class OlmoeSparseMoeBlock(nn.Module):
550
  def __init__(self, config, layer_idx: int):
551
  super().__init__()
552
  self.layer_idx = layer_idx
553
  self.num_experts = config.num_experts
554
+ self.num_small_experts = config.small_expert_count
555
  self.top_k = config.num_experts_per_tok
556
  self.norm_topk_prob = config.norm_topk_prob
557
+ self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
558
+ self.small_experts = nn.ModuleList([OlmoeMLP(config, is_small=True) for _ in range(self.num_small_experts)])
 
 
 
 
559
 
560
+ # Gates for both expert types
 
 
 
 
561
  self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
562
+ self.small_gate = nn.Linear(config.hidden_size, self.num_small_experts, bias=False)
563
+ self.small_expert_sparsity_coef = config.small_expert_sparsity_coef
564
 
565
+ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
566
  batch_size, sequence_length, hidden_dim = hidden_states.shape
567
  hidden_states = hidden_states.view(-1, hidden_dim)
568
+
569
+ # Get logits for both expert types
570
  router_logits = self.gate(hidden_states)
571
+ small_router_logits = self.small_gate(hidden_states)
572
+
573
+ # Combine logits for routing
574
+ combined_logits = torch.cat([router_logits, small_router_logits], dim=-1)
575
+ routing_probs = F.softmax(combined_logits, dim=1, dtype=torch.float)
576
  routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
577
 
578
  if self.norm_topk_prob:
579
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
580
 
581
+ final_hidden_states = torch.zeros_like(hidden_states)
582
+ expert_mask = torch.nn.functional.one_hot(
583
+ selected_experts,
584
+ num_classes=self.num_experts + self.num_small_experts
585
+ ).permute(2, 1, 0)
 
 
 
 
 
 
 
 
586
 
587
+ # Process all experts (regular + small)
588
+ for expert_idx in range(self.num_experts + self.num_small_experts):
589
  idx, top_x = torch.where(expert_mask[expert_idx])
590
+ if top_x.shape[0] == 0:
591
  continue
592
+
593
+ if expert_idx < self.num_experts:
594
+ expert = self.experts[expert_idx]
595
+ else:
596
+ expert = self.small_experts[expert_idx - self.num_experts]
597
+
598
+ current_states = hidden_states[top_x]
599
+ current_output = expert(current_states) * routing_weights[top_x, idx, None]
600
+ final_hidden_states.index_add_(0, top_x, current_output.to(hidden_states.dtype))
601
 
602
+ return final_hidden_states.view(batch_size, sequence_length, hidden_dim), combined_logits
 
603
 
604
  class OlmoeDecoderLayer(nn.Module):
605
  def __init__(self, config: OlmoeConfig, layer_idx: int):
 
1037
  if labels is not None:
1038
  loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
1039
  #
1040
+ total_aux_loss = 0
1041
+ if output_router_logits and outputs.router_logits is not None:
1042
+ # Regular load balancing loss
1043
+ total_aux_loss += load_balancing_loss_func(
1044
+ outputs.router_logits,
1045
+ num_experts=self.config.num_experts + self.config.small_expert_count,
1046
+ top_k=self.config.num_experts_per_tok,
1047
+ attention_mask=attention_mask
1048
+ )
1049
+
1050
+ # Small expert sparsity loss
1051
+ small_expert_mask = torch.zeros(
1052
+ self.config.num_experts + self.config.small_expert_count,
1053
+ device=outputs.router_logits[0].device
1054
  )
1055
+ small_expert_mask[self.config.num_experts:] = 1.0
1056
 
1057
+ masked_router_logits = []
1058
+ for logits in outputs.router_logits:
1059
+ # Apply mask to emphasize small experts
1060
+ masked_logits = logits * small_expert_mask * self.config.small_expert_sparsity_coef
1061
+ masked_router_logits.append(masked_logits)
 
 
 
1062
 
1063
+ total_aux_loss += load_balancing_loss_func(
1064
+ tuple(masked_router_logits),
1065
+ num_experts=self.config.num_experts + self.config.small_expert_count,
1066
+ top_k=self.config.num_experts_per_tok,
1067
+ attention_mask=attention_mask
1068
+ )
1069
+
1070
+ if loss is not None:
1071
+ loss += self.router_aux_loss_coef * total_aux_loss.to(loss.device)
 
 
 
 
 
 
1072
  #
1073
  return MoeCausalLMOutputWithPast(
1074
  loss=loss,