Charlie81 committed on
Commit
b1da2be
·
1 Parent(s): 5dc5166

try patch hook

Browse files
Files changed (2) hide show
  1. myolmoe/config.json +1 -5
  2. scripts/evalexperts.py +25 -32
myolmoe/config.json CHANGED
@@ -30,9 +30,5 @@
30
  "torch_dtype": "float32",
31
  "transformers_version": "4.52.4",
32
  "use_cache": true,
33
- "vocab_size": 50304,
34
- "small_expert_intermediate_ratio": 16,
35
- "small_expert_count": 64,
36
- "small_expert_sparsity_coef": 0.1,
37
- "max_small_expert_count": 64
38
  }
 
30
  "torch_dtype": "float32",
31
  "transformers_version": "4.52.4",
32
  "use_cache": true,
33
+ "vocab_size": 50304
 
 
 
 
34
  }
scripts/evalexperts.py CHANGED
@@ -57,38 +57,31 @@ class ExpertTrackingHFLM(HFLM):
57
  self._make_expert_hook(layer_idx)
58
  )
59
 
60
- def _make_expert_hook(self, layer_idx: int):
61
- """Create a forward hook for tracking expert usage in a specific layer."""
62
- def expert_hook(module, input, output):
63
- if not hasattr(module, 'gate') or not hasattr(module, 'experts'):
64
- return
65
-
66
- hidden_states, router_logits = input[0], output[1]
67
- batch_size, seq_len, hidden_dim = hidden_states.shape
68
-
69
- # Get routing probabilities
70
- routing_probs = torch.softmax(router_logits, dim=-1)
71
-
72
- # Get top-k experts
73
- topk_probs, topk_experts = torch.topk(
74
- routing_probs,
75
- k=module.top_k,
76
- dim=-1
77
- )
78
-
79
- # Update statistics
80
- self.update_expert_stats(
81
- layer_idx=layer_idx,
82
- topk_experts=topk_experts,
83
- topk_probs=topk_probs,
84
- num_regular_experts=module.num_experts,
85
- num_small_experts=module.num_small_experts if hasattr(module, 'num_small_experts') else 0,
86
- batch_size=batch_size,
87
- seq_len=seq_len
88
- )
89
-
90
- return expert_hook
91
-
92
  def update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
93
  topk_probs: torch.Tensor, num_regular_experts: int,
94
  num_small_experts: int, batch_size: int, seq_len: int):
 
57
  self._make_expert_hook(layer_idx)
58
  )
59
 
60
def _make_expert_hook(self, layer_idx):
    """Build a forward hook that records expert-routing statistics for one MoE layer.

    The unchanged call site invokes ``self._make_expert_hook(layer_idx)``, so this
    must remain a bound method taking only ``layer_idx``.  The previous signature
    ``(layer_idx, model)`` received the instance as ``layer_idx`` and the integer
    index as ``model``, making ``model.config`` raise ``AttributeError``.

    Counts are accumulated in the module-level ``expert_stats`` dict, keyed by
    layer index; per-layer keys are expert ids plus ``'total'`` and ``'small'``.
    """

    def hook(module, input, output):
        # MoE blocks typically return (hidden_states, router_logits); treat a
        # bare tensor as hidden states with no routing information.
        if isinstance(output, tuple) and len(output) == 2:
            hidden_states, routing_weights = output
        else:
            hidden_states = output
            routing_weights = None

        # Always read the small-expert count from the model config; default to
        # 0 when the field (or the config itself) is absent.
        # NOTE(review): assumes the HFLM wrapper exposes the underlying HF
        # model as ``self.model`` -- confirm on ExpertTrackingHFLM.
        config = getattr(getattr(self, "model", None), "config", None)
        num_small_experts = getattr(config, "small_expert_count", 0)

        layer_stats = expert_stats.setdefault(layer_idx, {})
        layer_stats["total"] = layer_stats.get("total", 0) + 1

        if routing_weights is None:
            return

        # Top-1 expert per token.
        top_expert = routing_weights.argmax(dim=-1)
        for expert_id in top_expert.view(-1).tolist():
            layer_stats[expert_id] = layer_stats.get(expert_id, 0) + 1
            # BUG FIX: this check must run once per token, inside the loop --
            # outside the loop it would only ever test the final expert_id.
            if expert_id < num_small_experts:
                layer_stats["small"] = layer_stats.get("small", 0) + 1

    return hook
84
+
 
 
 
 
 
 
 
85
  def update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
86
  topk_probs: torch.Tensor, num_regular_experts: int,
87
  num_small_experts: int, batch_size: int, seq_len: int):