Charlie81 committed on
Commit a83c539 · 1 Parent(s): a875a53

add evalexperts

Files changed (1)
  1. scripts/evalexperts.py +434 -0
scripts/evalexperts.py ADDED
@@ -0,0 +1,434 @@
+ #!/usr/bin/env python3
+ """
+ evalexperts.py - Evaluation script for MyOLMoE models with expert usage tracking.
+
+ This script evaluates a custom MyOLMoE model on benchmark tasks and tracks expert usage per layer.
+
+ Usage example:
+     python scripts/evalexperts.py --model_path allenai/OLMoE-1B-7B-0924 --tasks mmlu hellaswag --num_fewshot 5
+ """
+
+ import argparse
+ import json
+ import logging
+ import os
+ import sys
+ from typing import Any, Dict, List, Tuple
+
+ import numpy as np
+ import torch
+ from lm_eval import evaluator
+ from lm_eval.models.huggingface import HFLM
+ from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
+
+ # Set up logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+
+ def parse_args():
+     """Parse command line arguments."""
+     parser = argparse.ArgumentParser(
+         description="Evaluate MyOLMoE model on benchmark tasks with expert usage tracking",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+     )
+
+     # Model arguments
+     parser.add_argument(
+         "--model_path",
+         type=str,
+         default="allenai/OLMoE-1B-7B-0924",
+         help="Path or name of the pretrained MyOLMoE model"
+     )
+     parser.add_argument(
+         "--custom_model_path",
+         type=str,
+         default="./myolmoe_model",
+         help="Path to custom MyOLMoE model code"
+     )
+     parser.add_argument(
+         "--device",
+         type=str,
+         default="auto",
+         help="Device to use ('auto', 'cuda', 'cpu')"
+     )
+     parser.add_argument(
+         "--dtype",
+         type=str,
+         default="auto",
+         choices=["auto", "float16", "bfloat16", "float32"],
+         help="Data type for model weights"
+     )
+     parser.add_argument(
+         "--trust_remote_code",
+         action="store_true",
+         help="Trust remote code when loading model"
+     )
+
+     # Evaluation arguments
+     parser.add_argument(
+         "--tasks",
+         type=str,
+         nargs="+",
+         default=["mmlu"],
+         help="Tasks to evaluate on (e.g., mmlu, hellaswag, arc_easy)"
+     )
+     parser.add_argument(
+         "--num_fewshot",
+         type=int,
+         default=0,
+         help="Number of few-shot examples"
+     )
+     parser.add_argument(
+         "--batch_size",
+         type=int,
+         default=8,
+         help="Batch size for evaluation"
+     )
+     parser.add_argument(
+         "--max_batch_size",
+         type=int,
+         default=None,
+         help="Maximum batch size (auto if None)"
+     )
+     parser.add_argument(
+         "--limit",
+         type=int,
+         default=None,
+         help="Limit number of examples per task (for testing)"
+     )
+
+     # Output arguments
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="./eval_results",
+         help="Directory to save evaluation results and expert usage"
+     )
+     parser.add_argument(
+         "--output_filename",
+         type=str,
+         default=None,
+         help="Custom filename for results (auto-generated if not provided)"
+     )
+
+     return parser.parse_args()
+
+
+ def load_custom_model(args) -> Tuple[AutoModelForCausalLM, AutoTokenizer, HFLM]:
+     """
+     Load custom MyOLMoE model, tokenizer, and HFLM wrapper.
+
+     Args:
+         args: Parsed command line arguments
+
+     Returns:
+         Tuple of (model, tokenizer, HFLM wrapper)
+     """
+     logger.info(f"Loading custom MyOLMoE model: {args.model_path}")
+
+     # Add custom model path to Python path
+     if os.path.exists(args.custom_model_path):
+         sys.path.insert(0, args.custom_model_path)
+         logger.info(f"Added {args.custom_model_path} to Python path")
+     else:
+         logger.error(f"Custom model path not found: {args.custom_model_path}")
+         raise FileNotFoundError(f"Custom model path not found: {args.custom_model_path}")
+
+     try:
+         from modeling_myolmoe import MyOlmoeForCausalLM
+         logger.info("Successfully imported MyOlmoeForCausalLM")
+     except ImportError as e:
+         logger.error(f"Failed to import custom model: {e}")
+         raise
+
+     # Load model configuration
+     config = AutoConfig.from_pretrained(
+         args.model_path,
+         trust_remote_code=args.trust_remote_code
+     )
+
+     # Determine torch dtype ("auto" is passed through to from_pretrained)
+     torch_dtype = args.dtype
+     if args.dtype != "auto":
+         torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[args.dtype]
+
+     # Load model and tokenizer
+     model = MyOlmoeForCausalLM.from_pretrained(
+         args.model_path,
+         config=config,
+         torch_dtype=torch_dtype,
+         device_map="auto" if args.device == "auto" else None,
+         trust_remote_code=args.trust_remote_code
+     ).eval()
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         args.model_path,
+         trust_remote_code=args.trust_remote_code
+     )
+
+     # Create HFLM wrapper for evaluation, passing the preloaded model via `pretrained`
+     hf_model = HFLM(
+         pretrained=model,
+         tokenizer=tokenizer,
+         device=args.device,
+         batch_size=args.batch_size,
+         max_batch_size=args.max_batch_size,
+         dtype=args.dtype
+     )
+
+     logger.info("Custom model, tokenizer, and HFLM wrapper loaded successfully")
+     return model, tokenizer, hf_model
+
+
+ def track_expert_usage(model, input_ids: torch.Tensor) -> List[Dict[int, int]]:
+     """
+     Track expert usage per layer during a single forward pass.
+
+     Args:
+         model: MyOLMoE model
+         input_ids: Input token IDs (batched)
+
+     Returns:
+         List of dictionaries, one per layer, mapping expert indices to usage counts
+     """
+     expert_usage = [{} for _ in range(model.config.num_hidden_layers)]
+
+     def hook_fn(module, input, output, layer_idx):
+         # Assuming the module exposes the selected expert indices
+         if hasattr(module, 'selected_experts'):  # Hypothetical attribute
+             selected_experts = module.selected_experts  # Shape: (batch_size, seq_len, top_k)
+             for expert_idx in selected_experts.flatten().tolist():
+                 expert_usage[layer_idx][expert_idx] = expert_usage[layer_idx].get(expert_idx, 0) + 1
+
+     # Register hooks for each MoE layer. The `idx=i` default binds the layer index at
+     # definition time; a bare `i` inside the lambda would be captured late and every
+     # hook would report the last layer. Adjust the attribute paths to the actual model
+     # structure (the HF OLMoE reference implementation, for comparison, keeps its
+     # decoder layers at `model.model.layers`).
+     hooks = []
+     for i, layer in enumerate(model.transformer.layers):
+         if hasattr(layer, 'moe'):  # Check if layer has an MoE component
+             hook = layer.moe.register_forward_hook(lambda m, inp, out, idx=i: hook_fn(m, inp, out, idx))
+             hooks.append(hook)
+
+     # Run a forward pass
+     with torch.no_grad():
+         model(input_ids)
+
+     # Remove hooks
+     for hook in hooks:
+         hook.remove()
+
+     return expert_usage
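+
+
+ # A hedged alternative when the MoE block does not expose a `selected_experts`
+ # attribute: hook the router gate directly and recover the top-k indices from its
+ # logits. This sketch assumes a Mixtral/OLMoE-style block whose `gate` is an
+ # nn.Linear emitting (num_tokens, num_experts) logits; the attribute name `gate`
+ # and the config field `num_experts_per_tok` are assumptions about the custom
+ # model, not confirmed API.
+ def make_router_hook(layer_idx: int, top_k: int, expert_usage: List[Dict[int, int]]):
+     def hook(module, inputs, output):
+         # `output` is assumed to be the router logits, shape (num_tokens, num_experts)
+         top_experts = output.topk(top_k, dim=-1).indices
+         for expert_idx in top_experts.flatten().tolist():
+             expert_usage[layer_idx][expert_idx] = expert_usage[layer_idx].get(expert_idx, 0) + 1
+     return hook
+
+ # Usage sketch inside track_expert_usage, replacing the `selected_experts` path:
+ #   hooks.append(layer.moe.gate.register_forward_hook(
+ #       make_router_hook(i, model.config.num_experts_per_tok, expert_usage)))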
+
+
+ def run_evaluation_with_tracking(model, hf_model, tokenizer, args) -> Tuple[Dict[str, Any], Dict[str, List[Dict[int, int]]]]:
+     """
+     Run evaluation on benchmark tasks and track expert usage.
+
+     Args:
+         model: MyOLMoE model
+         hf_model: HFLM wrapper for evaluation
+         tokenizer: Tokenizer
+         args: Parsed command line arguments
+
+     Returns:
+         Tuple of (evaluation results, task-wise expert usage)
+     """
+     logger.info(f"Running evaluation on tasks: {args.tasks}")
+     logger.info(f"Few-shot examples: {args.num_fewshot}")
+     logger.info(f"Batch size: {args.batch_size}")
+
+     # Initialize expert usage tracking for each task
+     task_expert_usage = {task: [] for task in args.tasks}
+
+     # Patch HFLM's low-level forward (`_model_call`, the method lm-eval invokes for
+     # loglikelihood-style tasks) to record expert usage per batch. lm-eval passes a
+     # plain input_ids tensor here and does not expose the task name, so usage is
+     # attributed to the first task as a fallback; with one task per run this is exact.
+     original_model_call = hf_model._model_call
+
+     def tracked_model_call(inps, *call_args, **call_kwargs):
+         # Extra forward pass with tracking hooks attached (roughly doubles compute)
+         batch_expert_usage = track_expert_usage(model, inps.to(model.device))
+         task_expert_usage[args.tasks[0]].append(batch_expert_usage)
+         return original_model_call(inps, *call_args, **call_kwargs)
+
+     hf_model._model_call = tracked_model_call
+
+     # Run evaluation
+     results = evaluator.simple_evaluate(
+         model=hf_model,
+         tasks=args.tasks,
+         num_fewshot=args.num_fewshot,
+         limit=args.limit,
+         batch_size=args.batch_size,
+         max_batch_size=args.max_batch_size,
+     )
+
+     # Restore the original method
+     hf_model._model_call = original_model_call
+
+     # Aggregate expert usage per task: zip(*batches) transposes the per-batch lists so
+     # that each `layer_usages` tuple holds every batch's counts for a single layer
+     aggregated_usage = {}
+     for task in args.tasks:
+         if task_expert_usage[task]:
+             aggregated_usage[task] = [
+                 {k: sum(d.get(k, 0) for d in layer_usages) for k in set().union(*layer_usages)}
+                 for layer_usages in zip(*task_expert_usage[task])
+             ]
+         else:
+             aggregated_usage[task] = [{} for _ in range(model.config.num_hidden_layers)]
+
+     logger.info("Evaluation and expert usage tracking completed")
+     return results, aggregated_usage
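+
+
+ # Tiny worked example of the aggregation above (illustrative numbers): two batches,
+ # each a per-layer list of {expert_index: count} dicts for a two-layer model.
+ #   batches = [[{0: 3, 1: 1}, {2: 4}], [{0: 2}, {2: 1, 5: 3}]]
+ #   zip(*batches) -> layer 0: ({0: 3, 1: 1}, {0: 2})    => {0: 5, 1: 1}
+ #                    layer 1: ({2: 4}, {2: 1, 5: 3})    => {2: 5, 5: 3}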
+
+
+ def make_serializable(obj: Any) -> Any:
+     """
+     Convert objects to JSON-serializable format.
+
+     Args:
+         obj: Object to convert
+
+     Returns:
+         JSON-serializable version of the object
+     """
+     if isinstance(obj, dict):
+         return {k: make_serializable(v) for k, v in obj.items()}
+     elif isinstance(obj, list):
+         return [make_serializable(v) for v in obj]
+     elif isinstance(obj, tuple):
+         return tuple(make_serializable(v) for v in obj)
+     elif isinstance(obj, (np.integer, np.floating)):
+         return obj.item()
+     elif isinstance(obj, np.dtype):
+         return str(obj)
+     elif isinstance(obj, torch.Tensor):
+         return obj.tolist()
+     elif isinstance(obj, torch.dtype):
+         return str(obj)
+     else:
+         return obj
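+
+
+ # Example: make_serializable({"acc": np.float64(0.5), "ids": torch.tensor([1, 2])})
+ # returns {"acc": 0.5, "ids": [1, 2]}.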
+
+
+ def save_results(results: Dict[str, Any], expert_usage: Dict[str, List[Dict[int, int]]], args) -> str:
+     """
+     Save evaluation results and expert usage to file.
+
+     Args:
+         results: Evaluation results
+         expert_usage: Expert usage per task and layer
+         args: Parsed command line arguments
+
+     Returns:
+         str: Path to saved results file
+     """
+     os.makedirs(args.output_dir, exist_ok=True)
+
+     # Generate filename
+     if args.output_filename is None:
+         model_name = os.path.basename(args.model_path.rstrip('/'))
+         tasks_str = "_".join(args.tasks[:3])
+         if len(args.tasks) > 3:
+             tasks_str += f"_and_{len(args.tasks)-3}_more"
+         filename = f"{model_name}_{tasks_str}_eval_expert_usage.json"
+     else:
+         filename = args.output_filename
+
+     if not filename.endswith('.json'):
+         filename += '.json'
+
+     output_path = os.path.join(args.output_dir, filename)
+
+     # Prepare results
+     results_with_metadata = {
+         "metadata": {
+             "model_path": args.model_path,
+             "tasks": args.tasks,
+             "num_fewshot": args.num_fewshot,
+             "batch_size": args.batch_size,
+             "device": args.device,
+             "dtype": args.dtype,
+             "limit": args.limit,
+             "routing_type": "top-k (default)",
+         },
+         "results": results,
+         "expert_usage": {
+             # JSON object keys must be strings, so expert indices are stringified
+             task: [{str(k): v for k, v in layer_usage.items()} for layer_usage in task_usage]
+             for task, task_usage in expert_usage.items()
+         }
+     }
+
+     # Convert to JSON-serializable format
+     serializable_results = make_serializable(results_with_metadata)
+
+     # Save to file
+     with open(output_path, 'w') as f:
+         json.dump(serializable_results, f, indent=2)
+
+     logger.info(f"Results saved to {output_path}")
+     return output_path
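+
+
+ # For reference, the saved JSON has this shape (illustrative values, single task):
+ # {
+ #   "metadata": {"model_path": "allenai/OLMoE-1B-7B-0924", "tasks": ["mmlu"], ...},
+ #   "results": {...},                      # lm-eval's standard results dict
+ #   "expert_usage": {
+ #     "mmlu": [                            # one entry per hidden layer
+ #       {"0": 1024, "5": 312},             # layer 0: expert index -> routing count
+ #       ...
+ #     ]
+ #   }
+ # }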
+
+
+ def print_summary(results: Dict[str, Any], expert_usage: Dict[str, List[Dict[int, int]]], args) -> None:
+     """
+     Print a summary of evaluation results and expert usage.
+
+     Args:
+         results: Evaluation results
+         expert_usage: Expert usage per task and layer
+         args: Parsed command line arguments
+     """
+     print(f"\n{'='*80}")
+     print("EVALUATION SUMMARY")
+     print(f"Model: {args.model_path}")
+     print(f"Tasks: {', '.join(args.tasks)}")
+     print(f"{'='*80}")
+
+     if "results" in results:
+         for task, metrics in results["results"].items():
+             if isinstance(metrics, dict):
+                 print(f"\n📊 {task.upper()}:")
+                 for metric, value in metrics.items():
+                     if isinstance(value, (int, float)) and not metric.endswith('_stderr'):
+                         stderr_key = f"{metric}_stderr"
+                         stderr = metrics.get(stderr_key, 0)
+                         print(f"  {metric:.<20} {value:.4f} (±{stderr:.4f})")
+
+     print("\nEXPERT USAGE PER TASK AND LAYER")
+     for task, task_usage in expert_usage.items():
+         print(f"\nTask: {task.upper()}")
+         for i, layer_usage in enumerate(task_usage):
+             print(f"  Layer {i}:")
+             # Sort by expert index for stable, readable output
+             for expert_idx, count in sorted(layer_usage.items()):
+                 print(f"    Expert {expert_idx}: {count} times")
+
+     print(f"\n{'='*80}")
+
+
+ def main():
+     """Main function for evaluation with expert usage tracking."""
+     args = parse_args()
+
+     try:
+         logger.info("="*80)
+         logger.info("Starting MyOLMoE Evaluation with Expert Usage Tracking")
+         logger.info("="*80)
+
+         # Load model, tokenizer, and HFLM wrapper
+         model, tokenizer, hf_model = load_custom_model(args)
+
+         # Run evaluation with expert usage tracking
+         results, expert_usage = run_evaluation_with_tracking(model, hf_model, tokenizer, args)
+
+         # Save results
+         output_path = save_results(results, expert_usage, args)
+
+         # Print summary
+         print_summary(results, expert_usage, args)
+
+         logger.info("✅ Evaluation completed successfully!")
+         logger.info(f"📁 Results saved to: {output_path}")
+
+     except KeyboardInterrupt:
+         logger.info("Evaluation interrupted by user")
+         sys.exit(1)
+     except Exception as e:
+         logger.error(f"❌ Evaluation failed: {e}")
+         logger.debug("Full traceback:", exc_info=True)
+         sys.exit(1)
+
+
+ if __name__ == "__main__":
+     main()
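+
+ # Example smoke test (assumes lm-eval and the custom model code are available):
+ #   python scripts/evalexperts.py --tasks arc_easy --limit 8 --batch_size 2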