Charlie81 committed on
Commit
3e8c5b1
·
1 Parent(s): c37387b

claude to deepseek

Browse files
Files changed (1) hide show
  1. scripts/evalexperts.py +684 -0
scripts/evalexperts.py ADDED
@@ -0,0 +1,684 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ eval_with_expert_tracking.py - Evaluation script for OLMoE models with expert usage tracking
4
+
5
+ This script extends the standard evaluation to track:
6
+ 1. Which experts are being used
7
+ 2. Frequency of expert usage
8
+ 3. Distribution across experts
9
+ 4. Small vs regular expert usage
10
+ """
11
+
12
+ import argparse
13
+ import json
14
+ import os
15
+ import sys
16
+ import logging
17
+ from typing import Dict, List, Optional, Any, Tuple
18
+ import numpy as np
19
+ import torch
20
+ from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
21
+
22
+ # lm-eval imports
23
+ from lm_eval import evaluator
24
+ from lm_eval.models.huggingface import HFLM
25
+
26
+ # Set up logging
27
# Configure the root logger once at import time (INFO by default); main()
# may adjust the level later from the --verbosity flag.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)
32
+
33
+
34
class ExpertTrackingHFLM(HFLM):
    """HFLM wrapper that records how often each MoE expert is selected.

    A forward hook is attached to every decoder layer whose ``mlp`` exposes an
    ``experts`` attribute. Each hook recomputes the top-k routing decision
    from the router logits returned by the block and accumulates, per layer
    and globally, how many times each expert was chosen (``count``) and the
    summed routing probability it received (``load``).

    Expert ids ``>= num_regular_experts`` are reported as "small" experts,
    re-indexed from zero.

    Note: ``expert_stats['total_tokens']`` is incremented by every hooked
    layer, so it equals (tokens seen) x (number of hooked layers). The global
    usage counts are accumulated the same way, so the ratios stay consistent.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.expert_stats = {
            'total_tokens': 0,
            'regular_expert_usage': {},
            'small_expert_usage': {},
            'layer_stats': {}
        }
        # Routing top-k as actually observed by the hooks (None until the
        # first tracked forward pass).
        self._observed_top_k = None
        self._register_hooks()

    def _register_hooks(self):
        """Register forward hooks to track expert usage."""
        if not hasattr(self.model, 'model') or not hasattr(self.model.model, 'layers'):
            logger.warning("Model doesn't have expected layer structure - expert tracking disabled")
            return

        for layer_idx, layer in enumerate(self.model.model.layers):
            if hasattr(layer, 'mlp') and hasattr(layer.mlp, 'experts'):
                # Keep the handle on the module so the hook can be removed later.
                layer.mlp._expert_hook_handle = layer.mlp.register_forward_hook(
                    self._make_expert_hook(layer_idx)
                )

    def _routing_top_k(self) -> int:
        """Return the number of experts routed per token.

        Preference order: the k observed by the hooks, then the config's
        ``num_experts_per_tok``, then 1. ``config.top_k`` is deliberately not
        consulted: it is not the MoE routing k (on Hugging Face configs that
        name, when present, refers to generation sampling - verify against
        the installed transformers version).
        """
        observed = getattr(self, '_observed_top_k', None)
        if observed:
            return int(observed)
        config = getattr(self.model, 'config', None)
        configured = getattr(config, 'num_experts_per_tok', None)
        if configured:
            return int(configured)
        return 1

    def _make_expert_hook(self, layer_idx: int):
        """Create a forward hook for tracking expert usage in a specific layer."""
        def expert_hook(module, inputs, output):
            if not hasattr(module, 'gate') or not hasattr(module, 'experts'):
                return
            # The hook needs the positional hidden states and a
            # (hidden_states, router_logits) style output; skip silently
            # otherwise instead of breaking the forward pass.
            if not inputs or not isinstance(output, (tuple, list)) or len(output) < 2:
                return

            hidden_states, router_logits = inputs[0], output[1]
            if router_logits is None:
                return

            if hidden_states.dim() == 3:
                batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1]
            else:
                # Already flattened to (tokens, hidden): treat as one batch.
                batch_size, seq_len = 1, hidden_states.shape[0]

            # Routing probabilities, computed in float32 for stable sums.
            routing_probs = torch.softmax(router_logits.detach().float(), dim=-1)

            top_k = getattr(module, 'top_k', None)
            if top_k is None:
                top_k = self._routing_top_k()
            top_k = min(int(top_k), routing_probs.size(-1))

            topk_probs, topk_experts = torch.topk(routing_probs, k=top_k, dim=-1)

            self._update_expert_stats(
                layer_idx=layer_idx,
                topk_experts=topk_experts,
                topk_probs=topk_probs,
                num_regular_experts=module.num_experts,
                num_small_experts=module.num_small_experts if hasattr(module, 'num_small_experts') else 0,
                batch_size=batch_size,
                seq_len=seq_len
            )

        return expert_hook

    def _update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
                             topk_probs: torch.Tensor, num_regular_experts: int,
                             num_small_experts: int, batch_size: int, seq_len: int):
        """Accumulate selection counts and routing load for one forward pass.

        Args:
            layer_idx: Index of the decoder layer the hook is attached to.
            topk_experts: Selected expert ids; last dimension is top-k.
            topk_probs: Routing probabilities matching ``topk_experts``.
            num_regular_experts: Ids ``[0, num_regular_experts)`` are regular.
            num_small_experts: Ids following the regular ones are "small".
            batch_size: Unused; kept for interface compatibility.
            seq_len: Unused; kept for interface compatibility.
        """
        top_k = topk_experts.size(-1)
        self._observed_top_k = top_k

        expert_ids = topk_experts.reshape(-1)
        expert_probs = topk_probs.reshape(-1).to(torch.float)
        num_tokens = expert_ids.numel() // top_k if top_k else 0

        # Initialize layer stats if not present
        if layer_idx not in self.expert_stats['layer_stats']:
            self.expert_stats['layer_stats'][layer_idx] = {
                'total_tokens': 0,
                'regular_expert_counts': torch.zeros(num_regular_experts, dtype=torch.long),
                'small_expert_counts': torch.zeros(num_small_experts, dtype=torch.long) if num_small_experts > 0 else None,
                'regular_expert_load': torch.zeros(num_regular_experts, dtype=torch.float),
                'small_expert_load': torch.zeros(num_small_experts, dtype=torch.float) if num_small_experts > 0 else None
            }

        layer_stats = self.expert_stats['layer_stats'][layer_idx]

        self.expert_stats['total_tokens'] += num_tokens
        layer_stats['total_tokens'] += num_tokens

        # One histogram per quantity replaces a per-expert Python loop that
        # forced two device synchronisations for every expert. Ids beyond the
        # known expert range are dropped by the slice.
        total_experts = num_regular_experts + num_small_experts
        counts = torch.bincount(expert_ids, minlength=total_experts)[:total_experts]
        loads = torch.bincount(
            expert_ids, weights=expert_probs, minlength=total_experts
        )[:total_experts]
        counts = counts.to(device='cpu', dtype=torch.long)
        loads = loads.to(device='cpu', dtype=torch.float)

        # Regular experts
        layer_stats['regular_expert_counts'] += counts[:num_regular_experts]
        layer_stats['regular_expert_load'] += loads[:num_regular_experts]
        regular_usage = self.expert_stats['regular_expert_usage']
        for expert_idx, count in enumerate(counts[:num_regular_experts].tolist()):
            regular_usage[expert_idx] = regular_usage.get(expert_idx, 0) + count

        # Small experts, re-indexed from zero
        if num_small_experts > 0:
            layer_stats['small_expert_counts'] += counts[num_regular_experts:]
            layer_stats['small_expert_load'] += loads[num_regular_experts:]
            small_usage = self.expert_stats['small_expert_usage']
            for expert_idx, count in enumerate(counts[num_regular_experts:].tolist()):
                small_usage[expert_idx] = small_usage.get(expert_idx, 0) + count

    def get_expert_stats(self) -> Dict[str, Any]:
        """Return expert usage statistics in a serializable format.

        Percentages are relative to the total number of expert activations
        (``total_tokens * top_k``). The routing ``top_k`` used is included in
        the returned dict under ``'top_k'``.
        """
        total_tokens = self.expert_stats['total_tokens']
        top_k = self._routing_top_k()
        total_activations = total_tokens * top_k

        def _percentage(count):
            return count / total_activations * 100 if total_activations else 0.0

        stats = {
            'total_tokens': total_tokens,
            'top_k': top_k,
            'regular_expert_usage': {},
            'small_expert_usage': {},
            'layer_stats': {}
        }

        for expert_idx, count in self.expert_stats['regular_expert_usage'].items():
            stats['regular_expert_usage'][expert_idx] = {
                'count': count,
                'percentage': _percentage(count)
            }

        for expert_idx, count in self.expert_stats['small_expert_usage'].items():
            stats['small_expert_usage'][expert_idx] = {
                'count': count,
                'percentage': _percentage(count)
            }

        # Tensors are converted to plain lists so the result can be dumped as JSON.
        for layer_idx, layer_stat in self.expert_stats['layer_stats'].items():
            stats['layer_stats'][layer_idx] = {
                'total_tokens': layer_stat['total_tokens'],
                'regular_expert_counts': layer_stat['regular_expert_counts'].tolist(),
                'regular_expert_load': layer_stat['regular_expert_load'].tolist(),
                'small_expert_counts': layer_stat['small_expert_counts'].tolist() if layer_stat['small_expert_counts'] is not None else None,
                'small_expert_load': layer_stat['small_expert_load'].tolist() if layer_stat['small_expert_load'] is not None else None
            }

        return stats

    def print_expert_stats(self) -> None:
        """Print expert usage statistics in a human-readable format."""
        if not self.expert_stats['total_tokens']:
            print("No expert usage statistics collected.")
            return

        total_tokens = self.expert_stats['total_tokens']
        top_k = self._routing_top_k()
        total_expert_activations = total_tokens * top_k

        print("\n" + "="*80)
        print("EXPERT USAGE STATISTICS")
        print("="*80)
        print(f"Total tokens processed: {total_tokens:,}")
        print(f"Total expert activations (top-{top_k}): {total_expert_activations:,}")
        print("\nOverall Expert Usage:")

        if self.expert_stats['regular_expert_usage']:
            print("\nRegular Experts:")
            for expert_idx, count in sorted(self.expert_stats['regular_expert_usage'].items()):
                percentage = count / total_expert_activations * 100
                print(f" Expert {expert_idx}: {count:,} ({percentage:.2f}%)")

        if self.expert_stats['small_expert_usage']:
            print("\nSmall Experts:")
            for expert_idx, count in sorted(self.expert_stats['small_expert_usage'].items()):
                percentage = count / total_expert_activations * 100
                print(f" Small Expert {expert_idx}: {count:,} ({percentage:.2f}%)")

        print("\nLayer-wise Statistics:")
        for layer_idx, layer_stat in self.expert_stats['layer_stats'].items():
            # Guard against a layer that only ever saw empty batches.
            layer_tokens = layer_stat['total_tokens'] or 1
            print(f"\nLayer {layer_idx}:")
            print(f" Tokens processed: {layer_stat['total_tokens']:,}")

            print(" Regular Experts:")
            for expert_idx, (count, load) in enumerate(zip(
                layer_stat['regular_expert_counts'].tolist(),
                layer_stat['regular_expert_load'].tolist()
            )):
                count_pct = count / (layer_tokens * top_k) * 100
                load_pct = load / layer_tokens * 100
                print(f" Expert {expert_idx}: Count={count:,} ({count_pct:.2f}%), Load={load:.2f} ({load_pct:.2f}%)")

            if layer_stat['small_expert_counts'] is not None:
                print(" Small Experts:")
                for expert_idx, (count, load) in enumerate(zip(
                    layer_stat['small_expert_counts'].tolist(),
                    layer_stat['small_expert_load'].tolist()
                )):
                    count_pct = count / (layer_tokens * top_k) * 100
                    load_pct = load / layer_tokens * 100
                    print(f" Small Expert {expert_idx}: Count={count:,} ({count_pct:.2f}%), Load={load:.2f} ({load_pct:.2f}%)")

        print("="*80 + "\n")
242
+
243
+
244
def parse_args(argv: Optional[List[str]] = None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Leaving it as ``None`` keeps the original
            command-line behaviour; passing a list allows programmatic use.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate OLMoE models with expert usage tracking",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Standard evaluation with expert tracking
  python eval_with_expert_tracking.py --model_type transformers --tasks mmlu arc_easy

  # Custom model evaluation with expert tracking
  python eval_with_expert_tracking.py --model_type custom --tasks mmlu hellaswag
        """
    )

    # Model arguments
    parser.add_argument(
        "--model_path",
        type=str,
        default="allenai/OLMoE-1B-7B-0924",
        help="Path or name of the pretrained model"
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="transformers",
        choices=["transformers", "custom"],
        help="Model type: 'transformers' for standard OLMoE, 'custom' for MyOLMoE"
    )
    parser.add_argument(
        "--custom_model_path",
        type=str,
        default="./myolmoe_model",
        help="Path to custom MyOLMoE model code (when using --model_type custom)"
    )

    # Evaluation arguments
    parser.add_argument(
        "--tasks",
        type=str,
        nargs="+",
        default=["mmlu"],
        help="Tasks to evaluate on (e.g., mmlu, hellaswag, arc_easy, gsm8k)"
    )
    parser.add_argument(
        "--num_fewshot",
        type=int,
        default=0,
        help="Number of few-shot examples"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size for evaluation"
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=None,
        help="Maximum batch size (auto if None)"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help="Device to use ('auto', 'cuda', 'cpu')"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "float16", "bfloat16", "float32"],
        help="Data type for model weights"
    )

    # Output arguments
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./eval_results",
        help="Directory to save evaluation results"
    )
    parser.add_argument(
        "--output_filename",
        type=str,
        default=None,
        help="Custom filename for results (auto-generated if not provided)"
    )

    # Additional arguments
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Limit number of examples per task (for testing)"
    )
    parser.add_argument(
        "--write_out",
        action="store_true",
        help="Write out individual predictions to files"
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Trust remote code when loading model"
    )
    parser.add_argument(
        "--verbosity",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging verbosity level"
    )

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
360
+
361
+
362
def load_transformers_model(args) -> ExpertTrackingHFLM:
    """
    Build an expert-tracking wrapper around the stock Transformers OLMoE model.

    Args:
        args: Parsed command line arguments; ``model_path``, ``device``,
            ``batch_size``, ``max_batch_size``, ``dtype`` and
            ``trust_remote_code`` are forwarded to the wrapper.

    Returns:
        ExpertTrackingHFLM: Wrapped model ready for evaluation with expert tracking
    """
    logger.info(f"Loading Transformers OLMoE model with expert tracking: {args.model_path}")

    # Collect the wrapper options first so the constructor call stays short.
    wrapper_kwargs = {
        "pretrained": args.model_path,
        "device": args.device,
        "batch_size": args.batch_size,
        "max_batch_size": args.max_batch_size,
        "dtype": args.dtype,
        "trust_remote_code": args.trust_remote_code,
    }
    tracked_model = ExpertTrackingHFLM(**wrapper_kwargs)

    logger.info("Transformers model with expert tracking loaded successfully")
    return tracked_model
386
+
387
+
388
def load_custom_model(args) -> ExpertTrackingHFLM:
    """
    Load custom MyOLMoE model with expert tracking.

    The model is instantiated with ``MyOlmoeForCausalLM`` (imported from
    ``args.custom_model_path``) and that instance is handed to the wrapper.
    Passing only the path string, as before, left the custom instance unused
    and let the wrapper load its own model from the path.

    Args:
        args: Parsed command line arguments

    Returns:
        ExpertTrackingHFLM: Wrapped model ready for evaluation with expert tracking

    Raises:
        ImportError: If ``modeling_myolmoe`` cannot be imported.
    """
    logger.info(f"Loading custom MyOLMoE model with expert tracking: {args.model_path}")

    # Add custom model path to Python path (once) so the import below resolves.
    if os.path.exists(args.custom_model_path):
        if args.custom_model_path not in sys.path:
            sys.path.insert(0, args.custom_model_path)
        logger.info(f"Added {args.custom_model_path} to Python path")
    else:
        logger.warning(f"Custom model path not found: {args.custom_model_path}")

    try:
        # Import custom model class
        from modeling_myolmoe import MyOlmoeForCausalLM
        logger.info("Successfully imported MyOlmoeForCausalLM")
    except ImportError as e:
        logger.error(f"Failed to import custom model: {e}")
        logger.error("Make sure the custom model code is available in the specified path")
        raise

    # Load model configuration
    config = AutoConfig.from_pretrained(
        args.model_path,
        trust_remote_code=args.trust_remote_code
    )

    logger.info("Model will use default top-k routing configuration")

    # Determine torch dtype
    if args.dtype == "auto":
        torch_dtype = "auto"
    else:
        torch_dtype = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32
        }[args.dtype]

    # Load the custom model
    hf_model = MyOlmoeForCausalLM.from_pretrained(
        args.model_path,
        config=config,
        torch_dtype=torch_dtype,
        device_map="auto" if args.device == "auto" else None,
        trust_remote_code=args.trust_remote_code
    ).eval()

    # Without a device_map the weights stay where from_pretrained put them,
    # so honour an explicit --device here.
    if args.device != "auto":
        hf_model = hf_model.to(args.device)

    # Load the tokenizer explicitly instead of relying on the wrapper to
    # derive it from the model instance.
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        trust_remote_code=args.trust_remote_code
    )

    # Wrap the already-loaded custom model instance so that it is the model
    # being evaluated and hooked.
    # NOTE(review): assumes HFLM accepts a model instance as `pretrained`
    # together with a `tokenizer` argument - confirm for the installed lm-eval.
    model = ExpertTrackingHFLM(
        pretrained=hf_model,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size
    )

    logger.info("Custom model with expert tracking loaded successfully")
    return model
454
+
455
+
456
def run_evaluation(args) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Load the requested model, evaluate it, and collect the expert statistics.

    Args:
        args: Parsed command line arguments

    Returns:
        Tuple of (evaluation_results, expert_stats)

    Raises:
        ValueError: If ``args.model_type`` is neither "transformers" nor "custom".
    """
    logger.info("Starting evaluation with expert tracking...")

    # Pick the loader matching the requested model type.
    loaders = {
        "transformers": load_transformers_model,
        "custom": load_custom_model,
    }
    loader = loaders.get(args.model_type)
    if loader is None:
        raise ValueError(f"Unknown model type: {args.model_type}")
    tracked_model = loader(args)

    logger.info(f"Running evaluation on tasks: {args.tasks}")
    logger.info(f"Few-shot examples: {args.num_fewshot}")
    logger.info(f"Batch size: {args.batch_size}")

    eval_results = evaluator.simple_evaluate(
        model=tracked_model,
        tasks=args.tasks,
        num_fewshot=args.num_fewshot,
        limit=args.limit,
        write_out=args.write_out,
    )

    # The hooks have been accumulating during evaluation; snapshot them now.
    usage_stats = tracked_model.get_expert_stats()

    logger.info("Evaluation completed successfully")
    return eval_results, usage_stats
494
+
495
+
496
def save_results(results: Dict[str, Any], expert_stats: Dict[str, Any], args) -> str:
    """
    Save evaluation results and expert statistics to file.

    The output is one JSON document with the keys ``metadata``,
    ``task_results`` and ``expert_statistics``. Values the ``json`` module
    cannot encode natively (NumPy scalars/arrays, tensors, sets, arbitrary
    objects) are converted instead of aborting the dump, so a completed
    evaluation is not lost at the final write.

    Args:
        results: Evaluation results
        expert_stats: Expert usage statistics
        args: Parsed command line arguments

    Returns:
        str: Path to saved results file
    """
    def _json_default(obj):
        """Convert values that json cannot encode natively."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().tolist()
        if isinstance(obj, (set, frozenset)):
            return list(obj)
        # Last resort: keep a readable representation rather than failing.
        return str(obj)

    os.makedirs(args.output_dir, exist_ok=True)

    # Generate filename if not provided
    if args.output_filename is None:
        model_name = os.path.basename(args.model_path.rstrip('/'))
        tasks_str = "_".join(args.tasks[:3])
        if len(args.tasks) > 3:
            tasks_str += f"_and_{len(args.tasks)-3}_more"

        if args.model_type == "custom":
            filename = f"{model_name}_custom_{tasks_str}_results_with_expert_stats.json"
        else:
            filename = f"{model_name}_transformers_{tasks_str}_results_with_expert_stats.json"
    else:
        filename = args.output_filename

    if not filename.endswith('.json'):
        filename += '.json'

    output_path = os.path.join(args.output_dir, filename)

    # Prepare metadata
    metadata = {
        "model_path": args.model_path,
        "model_type": args.model_type,
        "tasks": args.tasks,
        "num_fewshot": args.num_fewshot,
        "batch_size": args.batch_size,
        "device": args.device,
        "dtype": args.dtype,
        "limit": args.limit,
    }

    # Add routing info for custom models
    if args.model_type == "custom":
        metadata["routing_type"] = "top-k (default)"

    combined_results = {
        "metadata": metadata,
        "task_results": results,
        "expert_statistics": expert_stats
    }

    # Save to file; explicit encoding keeps the output platform-independent.
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(combined_results, f, indent=2, default=_json_default)

    logger.info(f"Results saved to {output_path}")
    return output_path
557
+
558
+
559
def print_summary(results: Dict[str, Any], expert_stats: Dict[str, Any], args) -> None:
    """
    Print a formatted summary of evaluation results and expert statistics.

    The routing top-k used for the percentages is taken from ``args.top_k``
    if present, else from ``expert_stats['top_k']`` if present, else inferred
    as (total recorded activations) / (total tokens). Falling back to a fixed
    1 would overstate the per-layer percentages for any top-k > 1 model.

    Metric keys are accepted both as ``acc`` / ``acc_stderr`` and in the
    ``acc,filter`` / ``acc_stderr,filter`` form.

    Args:
        results: Evaluation results
        expert_stats: Expert usage statistics
        args: Parsed command line arguments
    """
    print(f"\n{'='*80}")
    print("EVALUATION SUMMARY")
    print(f"Model: {args.model_path}")
    print(f"Type: {args.model_type.upper()}")
    if args.model_type == "custom":
        print("Routing: TOP-K (default)")
    print(f"Tasks: {', '.join(args.tasks)}")
    print(f"{'='*80}")

    # Print task results
    if results and "results" in results:
        for task, metrics in results["results"].items():
            if not isinstance(metrics, dict):
                continue
            print(f"\n📊 {task.upper()}:")
            for metric, value in metrics.items():
                if not isinstance(value, (int, float)):
                    continue
                # Split "name,filter" keys; plain keys yield an empty filter.
                name, sep, filter_name = metric.partition(',')
                if name.endswith('_stderr'):
                    continue
                stderr = metrics.get(f"{name}_stderr{sep}{filter_name}", 0)
                if isinstance(stderr, (int, float)):
                    stderr_text = f"{stderr:.4f}"
                else:
                    # A non-numeric stderr is shown verbatim.
                    stderr_text = str(stderr)
                print(f" {metric:.<20} {value:.4f} (±{stderr_text})")
    else:
        print("\n⚠️ No results found in evaluation output")

    # Print expert statistics
    if expert_stats:
        total_tokens = expert_stats.get('total_tokens', 0)
        if total_tokens > 0:
            regular_usage = expert_stats.get('regular_expert_usage') or {}
            small_usage = expert_stats.get('small_expert_usage') or {}

            top_k = getattr(args, 'top_k', None) or expert_stats.get('top_k')
            if not top_k:
                # Every token contributes exactly top_k activations, so the
                # ratio of recorded activations to tokens recovers top_k.
                recorded = sum(entry['count'] for entry in regular_usage.values())
                recorded += sum(entry['count'] for entry in small_usage.values())
                top_k = max(1, round(recorded / total_tokens))
            top_k = int(top_k)
            total_expert_activations = total_tokens * top_k

            print(f"\n🔍 EXPERT USAGE SUMMARY (Top-{top_k})")
            print(f"Total tokens processed: {total_tokens:,}")
            print(f"Total expert activations: {total_expert_activations:,}")

            # Regular experts
            if regular_usage:
                print("\nRegular Experts:")
                for expert_idx, stats in sorted(regular_usage.items()):
                    print(f" Expert {expert_idx}: {stats['count']:,} ({stats['percentage']:.2f}%)")

            # Small experts
            if small_usage:
                print("\nSmall Experts:")
                for expert_idx, stats in sorted(small_usage.items()):
                    print(f" Small Expert {expert_idx}: {stats['count']:,} ({stats['percentage']:.2f}%)")

            # Layer statistics
            if expert_stats.get('layer_stats'):
                print("\nLayer-wise Statistics (Top 3 most used experts per layer):")
                for layer_idx, layer_stat in expert_stats['layer_stats'].items():
                    # Avoid dividing by zero for a layer that saw no tokens.
                    layer_tokens = layer_stat['total_tokens'] or 1
                    print(f"\nLayer {layer_idx}:")
                    print(f" Tokens processed: {layer_stat['total_tokens']:,}")

                    sections = (
                        ('regular_expert_counts', 'regular_expert_load', " Top Regular Experts:", "Expert"),
                        ('small_expert_counts', 'small_expert_load', " Top Small Experts:", "Small Expert"),
                    )
                    for counts_key, load_key, heading, label in sections:
                        counts = layer_stat.get(counts_key)
                        if not counts:
                            continue
                        loads = layer_stat[load_key]
                        # Indices of the three largest counts, biggest first.
                        top_indices = np.argsort(counts)[-3:][::-1]
                        print(heading)
                        for idx in (int(i) for i in top_indices):
                            count = counts[idx]
                            load = loads[idx]
                            count_pct = count / (layer_tokens * top_k) * 100
                            load_pct = load / layer_tokens * 100
                            print(f" {label} {idx}: Count={count:,} ({count_pct:.2f}%), Load={load:.2f} ({load_pct:.2f}%)")

    print(f"\n{'='*80}")
645
+
646
+
647
def main():
    """Script entry point: parse arguments, evaluate, save and summarise."""
    args = parse_args()

    # Apply the requested verbosity to both the root and this module's logger.
    requested_level = getattr(logging, args.verbosity.upper(), None)
    if isinstance(requested_level, int):
        for target in (logging.getLogger(), logger):
            target.setLevel(requested_level)

    separator = "=" * 80
    try:
        logger.info(separator)
        logger.info("Starting OLMoE Model Evaluation with Expert Tracking")
        logger.info(separator)

        results, expert_stats = run_evaluation(args)
        output_path = save_results(results, expert_stats, args)
        print_summary(results, expert_stats, args)

        logger.info("✅ Evaluation completed successfully!")
        logger.info(f"📁 Results saved to: {output_path}")

    except KeyboardInterrupt:
        logger.info("Evaluation interrupted by user")
        sys.exit(1)
    except Exception as e:
        # The traceback is emitted at DEBUG level only; the one-line error
        # is always shown.
        logger.error(f"❌ Evaluation failed: {e}")
        logger.debug("Full traceback:", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    main()