Charlie81 committed on
Commit
f126bc5
·
1 Parent(s): 61a401c

everything into the expertracking class

Browse files
Files changed (1) hide show
  1. scripts/evalexperts.py +98 -3
scripts/evalexperts.py CHANGED
@@ -77,7 +77,7 @@ class ExpertTrackingHFLM(HFLM):
77
  )
78
 
79
  # Update statistics
80
- self.update_expert_stats( # Changed from _update_expert_stats to update_expert_stats
81
  layer_idx=layer_idx,
82
  topk_experts=topk_experts,
83
  topk_probs=topk_probs,
@@ -89,7 +89,7 @@ class ExpertTrackingHFLM(HFLM):
89
 
90
  return expert_hook
91
 
92
- def update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor, # Renamed from _update_expert_stats
93
  topk_probs: torch.Tensor, num_regular_experts: int,
94
  num_small_experts: int, batch_size: int, seq_len: int):
95
  """Update expert usage statistics."""
@@ -143,7 +143,102 @@ class ExpertTrackingHFLM(HFLM):
143
  if expert_idx not in self.expert_stats['small_expert_usage']:
144
  self.expert_stats['small_expert_usage'][expert_idx] = 0
145
  self.expert_stats['small_expert_usage'][expert_idx] += count
146
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  def _update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
148
  topk_probs: torch.Tensor, num_regular_experts: int,
149
  num_small_experts: int, batch_size: int, seq_len: int):
 
77
  )
78
 
79
  # Update statistics
80
+ self.update_expert_stats(
81
  layer_idx=layer_idx,
82
  topk_experts=topk_experts,
83
  topk_probs=topk_probs,
 
89
 
90
  return expert_hook
91
 
92
+ def update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
93
  topk_probs: torch.Tensor, num_regular_experts: int,
94
  num_small_experts: int, batch_size: int, seq_len: int):
95
  """Update expert usage statistics."""
 
143
  if expert_idx not in self.expert_stats['small_expert_usage']:
144
  self.expert_stats['small_expert_usage'][expert_idx] = 0
145
  self.expert_stats['small_expert_usage'][expert_idx] += count
146
+
147
def get_expert_stats(self) -> Dict[str, Any]:
    """Return expert usage statistics in a serializable format.

    Returns:
        Dict with keys:
        - 'total_tokens': total number of tokens processed.
        - 'regular_expert_usage' / 'small_expert_usage': per-expert dicts
          of {'count', 'percentage'} where percentage is relative to
          total expert activations (tokens * top_k).
        - 'layer_stats': per-layer token/count/load data copied through.
    """
    total_tokens = self.expert_stats['total_tokens']
    top_k = getattr(self.model.config, 'top_k', 1)
    # Guard against division by zero when no tokens have been processed
    # yet (print_expert_stats already guards this case; this method did not).
    total_activations = (total_tokens * top_k) or 1

    stats = {
        'total_tokens': total_tokens,
        'regular_expert_usage': {},
        'small_expert_usage': {},
        'layer_stats': {},
    }

    # Overall regular-expert usage.
    for expert_idx, count in self.expert_stats['regular_expert_usage'].items():
        stats['regular_expert_usage'][expert_idx] = {
            'count': count,
            'percentage': count / total_activations * 100,
        }

    # Overall small-expert usage (empty dict when the model has none;
    # iterating an empty dict is a no-op, so no explicit guard is needed).
    for expert_idx, count in self.expert_stats['small_expert_usage'].items():
        stats['small_expert_usage'][expert_idx] = {
            'count': count,
            'percentage': count / total_activations * 100,
        }

    # Per-layer statistics are already plain Python values; copy them through.
    for layer_idx, layer_stat in self.expert_stats['layer_stats'].items():
        stats['layer_stats'][layer_idx] = {
            'total_tokens': layer_stat['total_tokens'],
            'regular_expert_counts': layer_stat['regular_expert_counts'],
            'regular_expert_load': layer_stat['regular_expert_load'],
            'small_expert_counts': layer_stat['small_expert_counts'],
            'small_expert_load': layer_stat['small_expert_load'],
        }

    return stats
182
+
183
+ def print_expert_stats(self) -> None:
184
+ """Print expert usage statistics in a human-readable format."""
185
+ if not self.expert_stats['total_tokens']:
186
+ print("No expert usage statistics collected.")
187
+ return
188
+
189
+ total_tokens = self.expert_stats['total_tokens']
190
+ top_k = getattr(self.model.config, 'top_k', 1)
191
+ total_expert_activations = total_tokens * top_k
192
+
193
+ print("\n" + "="*80)
194
+ print("EXPERT USAGE STATISTICS")
195
+ print("="*80)
196
+ print(f"Total tokens processed: {total_tokens:,}")
197
+ print(f"Total expert activations (top-{top_k}): {total_expert_activations:,}")
198
+ print("\nOverall Expert Usage:")
199
+
200
+ # Print regular experts
201
+ if self.expert_stats['regular_expert_usage']:
202
+ print("\nRegular Experts:")
203
+ for expert_idx, count in sorted(self.expert_stats['regular_expert_usage'].items()):
204
+ percentage = count / total_expert_activations * 100
205
+ print(f" Expert {expert_idx}: {count:,} ({percentage:.2f}%)")
206
+
207
+ # Print small experts if they exist
208
+ if self.expert_stats['small_expert_usage']:
209
+ print("\nSmall Experts:")
210
+ for expert_idx, count in sorted(self.expert_stats['small_expert_usage'].items()):
211
+ percentage = count / total_expert_activations * 100
212
+ print(f" Small Expert {expert_idx}: {count:,} ({percentage:.2f}%)")
213
+
214
+ # Print layer-wise statistics
215
+ print("\nLayer-wise Statistics:")
216
+ for layer_idx, layer_stat in self.expert_stats['layer_stats'].items():
217
+ print(f"\nLayer {layer_idx}:")
218
+ print(f" Tokens processed: {layer_stat['total_tokens']:,}")
219
+
220
+ # Regular experts
221
+ print(" Regular Experts:")
222
+ for expert_idx, (count, load) in enumerate(zip(
223
+ layer_stat['regular_expert_counts'],
224
+ layer_stat['regular_expert_load']
225
+ )):
226
+ count_pct = count / (layer_stat['total_tokens'] * top_k) * 100
227
+ load_pct = load / layer_stat['total_tokens'] * 100
228
+ print(f" Expert {expert_idx}: Count={count:,} ({count_pct:.2f}%), Load={load:.2f} ({load_pct:.2f}%)")
229
+
230
+ # Small experts if they exist
231
+ if layer_stat['small_expert_counts'] is not None:
232
+ print(" Small Experts:")
233
+ for expert_idx, (count, load) in enumerate(zip(
234
+ layer_stat['small_expert_counts'],
235
+ layer_stat['small_expert_load']
236
+ )):
237
+ count_pct = count / (layer_stat['total_tokens'] * top_k) * 100
238
+ load_pct = load / layer_stat['total_tokens'] * 100
239
+ print(f" Small Expert {expert_idx}: Count={count:,} ({count_pct:.2f}%), Load={load:.2f} ({load_pct:.2f}%)")
240
+
241
+ print("="*80 + "\n")
242
  def _update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
243
  topk_probs: torch.Tensor, num_regular_experts: int,
244
  num_small_experts: int, batch_size: int, seq_len: int):