Charlie81 committed on
Commit
a62f0f3
·
1 Parent(s): f126bc5

robust get_expert_stats and remove repeat

Browse files
Files changed (1) hide show
  1. scripts/evalexperts.py +28 -162
scripts/evalexperts.py CHANGED
@@ -146,158 +146,25 @@ class ExpertTrackingHFLM(HFLM):
146
 
147
  def get_expert_stats(self) -> Dict[str, Any]:
148
  """Return expert usage statistics in a serializable format."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  stats = {
150
- 'total_tokens': self.expert_stats['total_tokens'],
151
- 'regular_expert_usage': {},
152
- 'small_expert_usage': {},
153
- 'layer_stats': {}
154
- }
155
-
156
- # Convert regular expert usage
157
- for expert_idx, count in self.expert_stats['regular_expert_usage'].items():
158
- stats['regular_expert_usage'][expert_idx] = {
159
- 'count': count,
160
- 'percentage': count / (self.expert_stats['total_tokens'] * getattr(self.model.config, 'top_k', 1)) * 100
161
- }
162
-
163
- # Convert small expert usage if they exist
164
- if self.expert_stats['small_expert_usage']:
165
- for expert_idx, count in self.expert_stats['small_expert_usage'].items():
166
- stats['small_expert_usage'][expert_idx] = {
167
- 'count': count,
168
- 'percentage': count / (self.expert_stats['total_tokens'] * getattr(self.model.config, 'top_k', 1)) * 100
169
- }
170
-
171
- # Convert layer stats
172
- for layer_idx, layer_stat in self.expert_stats['layer_stats'].items():
173
- stats['layer_stats'][layer_idx] = {
174
- 'total_tokens': layer_stat['total_tokens'],
175
- 'regular_expert_counts': layer_stat['regular_expert_counts'],
176
- 'regular_expert_load': layer_stat['regular_expert_load'],
177
- 'small_expert_counts': layer_stat['small_expert_counts'],
178
- 'small_expert_load': layer_stat['small_expert_load']
179
- }
180
-
181
- return stats
182
-
183
- def print_expert_stats(self) -> None:
184
- """Print expert usage statistics in a human-readable format."""
185
- if not self.expert_stats['total_tokens']:
186
- print("No expert usage statistics collected.")
187
- return
188
-
189
- total_tokens = self.expert_stats['total_tokens']
190
- top_k = getattr(self.model.config, 'top_k', 1)
191
- total_expert_activations = total_tokens * top_k
192
-
193
- print("\n" + "="*80)
194
- print("EXPERT USAGE STATISTICS")
195
- print("="*80)
196
- print(f"Total tokens processed: {total_tokens:,}")
197
- print(f"Total expert activations (top-{top_k}): {total_expert_activations:,}")
198
- print("\nOverall Expert Usage:")
199
-
200
- # Print regular experts
201
- if self.expert_stats['regular_expert_usage']:
202
- print("\nRegular Experts:")
203
- for expert_idx, count in sorted(self.expert_stats['regular_expert_usage'].items()):
204
- percentage = count / total_expert_activations * 100
205
- print(f" Expert {expert_idx}: {count:,} ({percentage:.2f}%)")
206
-
207
- # Print small experts if they exist
208
- if self.expert_stats['small_expert_usage']:
209
- print("\nSmall Experts:")
210
- for expert_idx, count in sorted(self.expert_stats['small_expert_usage'].items()):
211
- percentage = count / total_expert_activations * 100
212
- print(f" Small Expert {expert_idx}: {count:,} ({percentage:.2f}%)")
213
-
214
- # Print layer-wise statistics
215
- print("\nLayer-wise Statistics:")
216
- for layer_idx, layer_stat in self.expert_stats['layer_stats'].items():
217
- print(f"\nLayer {layer_idx}:")
218
- print(f" Tokens processed: {layer_stat['total_tokens']:,}")
219
-
220
- # Regular experts
221
- print(" Regular Experts:")
222
- for expert_idx, (count, load) in enumerate(zip(
223
- layer_stat['regular_expert_counts'],
224
- layer_stat['regular_expert_load']
225
- )):
226
- count_pct = count / (layer_stat['total_tokens'] * top_k) * 100
227
- load_pct = load / layer_stat['total_tokens'] * 100
228
- print(f" Expert {expert_idx}: Count={count:,} ({count_pct:.2f}%), Load={load:.2f} ({load_pct:.2f}%)")
229
-
230
- # Small experts if they exist
231
- if layer_stat['small_expert_counts'] is not None:
232
- print(" Small Experts:")
233
- for expert_idx, (count, load) in enumerate(zip(
234
- layer_stat['small_expert_counts'],
235
- layer_stat['small_expert_load']
236
- )):
237
- count_pct = count / (layer_stat['total_tokens'] * top_k) * 100
238
- load_pct = load / layer_stat['total_tokens'] * 100
239
- print(f" Small Expert {expert_idx}: Count={count:,} ({count_pct:.2f}%), Load={load:.2f} ({load_pct:.2f}%)")
240
-
241
- print("="*80 + "\n")
242
- def _update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
243
- topk_probs: torch.Tensor, num_regular_experts: int,
244
- num_small_experts: int, batch_size: int, seq_len: int):
245
- """Update expert usage statistics with serializable data types."""
246
- # Flatten the batch and sequence dimensions
247
- topk_experts_flat = topk_experts.view(-1, topk_experts.size(-1))
248
- topk_probs_flat = topk_probs.view(-1, topk_probs.size(-1))
249
-
250
- # Initialize layer stats if not present
251
- if layer_idx not in self.expert_stats['layer_stats']:
252
- self.expert_stats['layer_stats'][layer_idx] = {
253
- 'total_tokens': 0,
254
- 'regular_expert_counts': [0] * num_regular_experts, # Use list instead of tensor
255
- 'small_expert_counts': [0] * num_small_experts if num_small_experts > 0 else None,
256
- 'regular_expert_load': [0.0] * num_regular_experts,
257
- 'small_expert_load': [0.0] * num_small_experts if num_small_experts > 0 else None
258
- }
259
-
260
- layer_stats = self.expert_stats['layer_stats'][layer_idx]
261
- num_tokens = topk_experts_flat.size(0)
262
-
263
- # Update global stats
264
- self.expert_stats['total_tokens'] += num_tokens
265
-
266
- # Update layer stats
267
- layer_stats['total_tokens'] += num_tokens
268
-
269
- # Track regular experts
270
- for expert_idx in range(num_regular_experts):
271
- mask = (topk_experts_flat == expert_idx)
272
- count = mask.sum().item()
273
- load = topk_probs_flat[mask].sum().item()
274
-
275
- layer_stats['regular_expert_counts'][expert_idx] += count
276
- layer_stats['regular_expert_load'][expert_idx] += load
277
-
278
- if expert_idx not in self.expert_stats['regular_expert_usage']:
279
- self.expert_stats['regular_expert_usage'][expert_idx] = 0
280
- self.expert_stats['regular_expert_usage'][expert_idx] += count
281
-
282
- # Track small experts if they exist
283
- if num_small_experts > 0:
284
- for expert_idx in range(num_small_experts):
285
- small_expert_num = expert_idx + num_regular_experts
286
- mask = (topk_experts_flat == small_expert_num)
287
- count = mask.sum().item()
288
- load = topk_probs_flat[mask].sum().item()
289
-
290
- layer_stats['small_expert_counts'][expert_idx] += count
291
- layer_stats['small_expert_load'][expert_idx] += load
292
-
293
- if expert_idx not in self.expert_stats['small_expert_usage']:
294
- self.expert_stats['small_expert_usage'][expert_idx] = 0
295
- self.expert_stats['small_expert_usage'][expert_idx] += count
296
-
297
- def get_expert_stats(self) -> Dict[str, Any]:
298
- """Return expert usage statistics in a serializable format."""
299
- stats = {
300
- 'total_tokens': self.expert_stats['total_tokens'],
301
  'regular_expert_usage': {},
302
  'small_expert_usage': {},
303
  'layer_stats': {}
@@ -306,30 +173,30 @@ def _update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
306
  # Convert regular expert usage
307
  for expert_idx, count in self.expert_stats['regular_expert_usage'].items():
308
  stats['regular_expert_usage'][expert_idx] = {
309
- 'count': count,
310
- 'percentage': count / (self.expert_stats['total_tokens'] * self.model.config.top_k) * 100
311
  }
312
 
313
  # Convert small expert usage if they exist
314
  if self.expert_stats['small_expert_usage']:
315
  for expert_idx, count in self.expert_stats['small_expert_usage'].items():
316
  stats['small_expert_usage'][expert_idx] = {
317
- 'count': count,
318
- 'percentage': count / (self.expert_stats['total_tokens'] * self.model.config.top_k) * 100
319
  }
320
 
321
  # Convert layer stats
322
  for layer_idx, layer_stat in self.expert_stats['layer_stats'].items():
323
  stats['layer_stats'][layer_idx] = {
324
- 'total_tokens': layer_stat['total_tokens'],
325
- 'regular_expert_counts': layer_stat['regular_expert_counts'].tolist(),
326
- 'regular_expert_load': layer_stat['regular_expert_load'].tolist(),
327
- 'small_expert_counts': layer_stat['small_expert_counts'].tolist() if layer_stat['small_expert_counts'] is not None else None,
328
- 'small_expert_load': layer_stat['small_expert_load'].tolist() if layer_stat['small_expert_load'] is not None else None
329
  }
330
 
331
  return stats
332
-
333
  def print_expert_stats(self) -> None:
334
  """Print expert usage statistics in a human-readable format."""
335
  if not self.expert_stats['total_tokens']:
@@ -390,7 +257,6 @@ def _update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
390
 
391
  print("="*80 + "\n")
392
 
393
-
394
  def parse_args():
395
  """Parse command line arguments."""
396
  parser = argparse.ArgumentParser(
 
146
 
147
  def get_expert_stats(self) -> Dict[str, Any]:
148
  """Return expert usage statistics in a serializable format."""
149
+ def convert(obj):
150
+ """Recursively convert objects to JSON-serializable formats."""
151
+ if isinstance(obj, (np.integer, np.floating)):
152
+ return int(obj) if isinstance(obj, np.integer) else float(obj)
153
+ elif isinstance(obj, np.ndarray):
154
+ return obj.tolist()
155
+ elif isinstance(obj, torch.Tensor):
156
+ return obj.cpu().numpy().tolist()
157
+ elif isinstance(obj, torch.dtype):
158
+ return str(obj)
159
+ elif isinstance(obj, (dict)):
160
+ return {k: convert(v) for k, v in obj.items()}
161
+ elif isinstance(obj, (list, tuple)):
162
+ return [convert(v) for v in obj]
163
+ else:
164
+ return obj
165
+
166
  stats = {
167
+ 'total_tokens': convert(self.expert_stats['total_tokens']),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  'regular_expert_usage': {},
169
  'small_expert_usage': {},
170
  'layer_stats': {}
 
173
  # Convert regular expert usage
174
  for expert_idx, count in self.expert_stats['regular_expert_usage'].items():
175
  stats['regular_expert_usage'][expert_idx] = {
176
+ 'count': convert(count),
177
+ 'percentage': convert(count / (self.expert_stats['total_tokens'] * getattr(self.model.config, 'top_k', 1)) * 100)
178
  }
179
 
180
  # Convert small expert usage if they exist
181
  if self.expert_stats['small_expert_usage']:
182
  for expert_idx, count in self.expert_stats['small_expert_usage'].items():
183
  stats['small_expert_usage'][expert_idx] = {
184
+ 'count': convert(count),
185
+ 'percentage': convert(count / (self.expert_stats['total_tokens'] * getattr(self.model.config, 'top_k', 1)) * 100)
186
  }
187
 
188
  # Convert layer stats
189
  for layer_idx, layer_stat in self.expert_stats['layer_stats'].items():
190
  stats['layer_stats'][layer_idx] = {
191
+ 'total_tokens': convert(layer_stat['total_tokens']),
192
+ 'regular_expert_counts': convert(layer_stat['regular_expert_counts']),
193
+ 'regular_expert_load': convert(layer_stat['regular_expert_load']),
194
+ 'small_expert_counts': convert(layer_stat['small_expert_counts']),
195
+ 'small_expert_load': convert(layer_stat['small_expert_load'])
196
  }
197
 
198
  return stats
199
+
200
  def print_expert_stats(self) -> None:
201
  """Print expert usage statistics in a human-readable format."""
202
  if not self.expert_stats['total_tokens']:
 
257
 
258
  print("="*80 + "\n")
259
 
 
260
  def parse_args():
261
  """Parse command line arguments."""
262
  parser = argparse.ArgumentParser(