Charlie81 committed on
Commit
a875a53
·
1 Parent(s): 2b77d15

expert usage stats

Browse files
Files changed (1) hide show
  1. myolmoe/modeling_myolmoe.py +23 -11
myolmoe/modeling_myolmoe.py CHANGED
@@ -1,5 +1,6 @@
1
  import math
2
  from typing import List, Optional, Tuple, Union
 
3
  import torch
4
  import torch.nn.functional as F
5
  import torch.utils.checkpoint
@@ -558,20 +559,17 @@ class OlmoeSparseMoeBlock(nn.Module):
558
  self.top_k = config.num_experts_per_tok
559
  self.norm_topk_prob = config.norm_topk_prob
560
 
561
- # Determine if this block is in the second half
562
  in_second_half = layer_idx >= self.total_layers // 2
563
 
564
- # Determine small expert count for this layer
565
  if in_second_half:
566
  second_half_idx = layer_idx - (self.total_layers // 2)
567
  num_second_half_blocks = self.total_layers - (self.total_layers // 2)
568
-
569
  if config.small_expert_strategy == "constant":
570
  self.num_small_experts = config.max_small_expert_count // num_second_half_blocks
571
  elif config.small_expert_strategy == "increment":
572
- # Linearly scale small experts from 1 to max_small_expert_count
573
  self.num_small_experts = (
574
- (second_half_idx + 1) * config.max_small_expert_count // ((num_second_half_blocks * (num_second_half_blocks + 1)) // 2)
 
575
  )
576
  else:
577
  raise ValueError(f"Unknown strategy: {config.small_expert_strategy}")
@@ -584,20 +582,19 @@ class OlmoeSparseMoeBlock(nn.Module):
584
  ]) if self.num_small_experts > 0 else None
585
 
586
  self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
587
-
588
- if self.num_small_experts > 0:
589
- self.small_gate = nn.Linear(config.hidden_size, self.num_small_experts, bias=False)
590
- else:
591
- self.small_gate = None
592
 
593
  self.small_expert_sparsity_coef = config.small_expert_sparsity_coef
594
 
 
 
 
595
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
596
  batch_size, sequence_length, hidden_dim = hidden_states.shape
597
  hidden_states = hidden_states.view(-1, hidden_dim)
598
 
599
  router_logits = self.gate(hidden_states)
600
-
601
  if self.num_small_experts > 0:
602
  small_router_logits = self.small_gate(hidden_states)
603
  combined_logits = torch.cat([router_logits, small_router_logits], dim=-1)
@@ -607,6 +604,12 @@ class OlmoeSparseMoeBlock(nn.Module):
607
  routing_probs = F.softmax(combined_logits, dim=1, dtype=torch.float)
608
  routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
609
 
 
 
 
 
 
 
610
  if self.norm_topk_prob:
611
  routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
612
 
@@ -632,6 +635,15 @@ class OlmoeSparseMoeBlock(nn.Module):
632
 
633
  return final_hidden_states.view(batch_size, sequence_length, hidden_dim), combined_logits
634
 
 
 
 
 
 
 
 
 
 
635
 
636
  class OlmoeDecoderLayer(nn.Module):
637
  def __init__(self, config: OlmoeConfig, layer_idx: int):
 
1
  import math
2
  from typing import List, Optional, Tuple, Union
3
+ from collections import defaultdict
4
  import torch
5
  import torch.nn.functional as F
6
  import torch.utils.checkpoint
 
559
  self.top_k = config.num_experts_per_tok
560
  self.norm_topk_prob = config.norm_topk_prob
561
 
 
562
  in_second_half = layer_idx >= self.total_layers // 2
563
 
 
564
  if in_second_half:
565
  second_half_idx = layer_idx - (self.total_layers // 2)
566
  num_second_half_blocks = self.total_layers - (self.total_layers // 2)
 
567
  if config.small_expert_strategy == "constant":
568
  self.num_small_experts = config.max_small_expert_count // num_second_half_blocks
569
  elif config.small_expert_strategy == "increment":
 
570
  self.num_small_experts = (
571
+ (second_half_idx + 1) * config.max_small_expert_count //
572
+ ((num_second_half_blocks * (num_second_half_blocks + 1)) // 2)
573
  )
574
  else:
575
  raise ValueError(f"Unknown strategy: {config.small_expert_strategy}")
 
582
  ]) if self.num_small_experts > 0 else None
583
 
584
  self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
585
+ self.small_gate = nn.Linear(config.hidden_size, self.num_small_experts, bias=False) \
586
+ if self.num_small_experts > 0 else None
 
 
 
587
 
588
  self.small_expert_sparsity_coef = config.small_expert_sparsity_coef
589
 
590
+ # Usage tracking (not a buffer, no gradient)
591
+ self.expert_usage = defaultdict(int)
592
+
593
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
594
  batch_size, sequence_length, hidden_dim = hidden_states.shape
595
  hidden_states = hidden_states.view(-1, hidden_dim)
596
 
597
  router_logits = self.gate(hidden_states)
 
598
  if self.num_small_experts > 0:
599
  small_router_logits = self.small_gate(hidden_states)
600
  combined_logits = torch.cat([router_logits, small_router_logits], dim=-1)
 
604
  routing_probs = F.softmax(combined_logits, dim=1, dtype=torch.float)
605
  routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
606
 
607
+ # Track expert usage
608
+ for i in range(selected_experts.size(0)):
609
+ for j in range(self.top_k):
610
+ expert_id = selected_experts[i, j].item()
611
+ self.expert_usage[expert_id] += 1
612
+
613
  if self.norm_topk_prob:
614
  routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
615
 
 
635
 
636
  return final_hidden_states.view(batch_size, sequence_length, hidden_dim), combined_logits
637
 
638
+ def __del__(self):
639
+ if self.expert_usage:
640
+ print(f"\n[Expert Usage Report for Layer {self.layer_idx}]")
641
+ total = sum(self.expert_usage.values())
642
+ for expert_id in sorted(self.expert_usage):
643
+ count = self.expert_usage[expert_id]
644
+ percent = 100.0 * count / total if total > 0 else 0.0
645
+ print(f" Expert {expert_id:2d}: {count} times ({percent:.2f}%)")
646
+
647
 
648
  class OlmoeDecoderLayer(nn.Module):
649
  def __init__(self, config: OlmoeConfig, layer_idx: int):