Charlie81 committed on
Commit
c37387b
·
1 Parent(s): 3adfc62

Revert "expert usage stats"

Browse files

This reverts commit a875a536fb5cb362c243ac2a09bbf7e6ec37db66.

Files changed (2) hide show
  1. myolmoe/modeling_myolmoe.py +11 -23
  2. scripts/evalexperts.py +0 -441
myolmoe/modeling_myolmoe.py CHANGED
@@ -1,6 +1,5 @@
1
  import math
2
  from typing import List, Optional, Tuple, Union
3
- from collections import defaultdict
4
  import torch
5
  import torch.nn.functional as F
6
  import torch.utils.checkpoint
@@ -559,17 +558,20 @@ class OlmoeSparseMoeBlock(nn.Module):
559
  self.top_k = config.num_experts_per_tok
560
  self.norm_topk_prob = config.norm_topk_prob
561
 
 
562
  in_second_half = layer_idx >= self.total_layers // 2
563
 
 
564
  if in_second_half:
565
  second_half_idx = layer_idx - (self.total_layers // 2)
566
  num_second_half_blocks = self.total_layers - (self.total_layers // 2)
 
567
  if config.small_expert_strategy == "constant":
568
  self.num_small_experts = config.max_small_expert_count // num_second_half_blocks
569
  elif config.small_expert_strategy == "increment":
 
570
  self.num_small_experts = (
571
- (second_half_idx + 1) * config.max_small_expert_count //
572
- ((num_second_half_blocks * (num_second_half_blocks + 1)) // 2)
573
  )
574
  else:
575
  raise ValueError(f"Unknown strategy: {config.small_expert_strategy}")
@@ -582,19 +584,20 @@ class OlmoeSparseMoeBlock(nn.Module):
582
  ]) if self.num_small_experts > 0 else None
583
 
584
  self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
585
- self.small_gate = nn.Linear(config.hidden_size, self.num_small_experts, bias=False) \
586
- if self.num_small_experts > 0 else None
587
 
588
- self.small_expert_sparsity_coef = config.small_expert_sparsity_coef
 
 
 
589
 
590
- # Usage tracking (not a buffer, no gradient)
591
- self.expert_usage = defaultdict(int)
592
 
593
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
594
  batch_size, sequence_length, hidden_dim = hidden_states.shape
595
  hidden_states = hidden_states.view(-1, hidden_dim)
596
 
597
  router_logits = self.gate(hidden_states)
 
598
  if self.num_small_experts > 0:
599
  small_router_logits = self.small_gate(hidden_states)
600
  combined_logits = torch.cat([router_logits, small_router_logits], dim=-1)
@@ -604,12 +607,6 @@ class OlmoeSparseMoeBlock(nn.Module):
604
  routing_probs = F.softmax(combined_logits, dim=1, dtype=torch.float)
605
  routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
606
 
607
- # Track expert usage
608
- for i in range(selected_experts.size(0)):
609
- for j in range(self.top_k):
610
- expert_id = selected_experts[i, j].item()
611
- self.expert_usage[expert_id] += 1
612
-
613
  if self.norm_topk_prob:
614
  routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
615
 
@@ -635,15 +632,6 @@ class OlmoeSparseMoeBlock(nn.Module):
635
 
636
  return final_hidden_states.view(batch_size, sequence_length, hidden_dim), combined_logits
637
 
638
- def __del__(self):
639
- if self.expert_usage:
640
- print(f"\n[Expert Usage Report for Layer {self.layer_idx}]")
641
- total = sum(self.expert_usage.values())
642
- for expert_id in sorted(self.expert_usage):
643
- count = self.expert_usage[expert_id]
644
- percent = 100.0 * count / total if total > 0 else 0.0
645
- print(f" Expert {expert_id:2d}: {count} times ({percent:.2f}%)")
646
-
647
 
648
  class OlmoeDecoderLayer(nn.Module):
649
  def __init__(self, config: OlmoeConfig, layer_idx: int):
 
1
  import math
2
  from typing import List, Optional, Tuple, Union
 
3
  import torch
4
  import torch.nn.functional as F
5
  import torch.utils.checkpoint
 
558
  self.top_k = config.num_experts_per_tok
559
  self.norm_topk_prob = config.norm_topk_prob
560
 
561
+ # Determine if this block is in the second half
562
  in_second_half = layer_idx >= self.total_layers // 2
563
 
564
+ # Determine small expert count for this layer
565
  if in_second_half:
566
  second_half_idx = layer_idx - (self.total_layers // 2)
567
  num_second_half_blocks = self.total_layers - (self.total_layers // 2)
568
+
569
  if config.small_expert_strategy == "constant":
570
  self.num_small_experts = config.max_small_expert_count // num_second_half_blocks
571
  elif config.small_expert_strategy == "increment":
572
+ # Linearly scale small experts from 1 to max_small_expert_count
573
  self.num_small_experts = (
574
+ (second_half_idx + 1) * config.max_small_expert_count // ((num_second_half_blocks * (num_second_half_blocks + 1)) // 2)
 
575
  )
576
  else:
577
  raise ValueError(f"Unknown strategy: {config.small_expert_strategy}")
 
584
  ]) if self.num_small_experts > 0 else None
585
 
586
  self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
 
 
587
 
588
+ if self.num_small_experts > 0:
589
+ self.small_gate = nn.Linear(config.hidden_size, self.num_small_experts, bias=False)
590
+ else:
591
+ self.small_gate = None
592
 
593
+ self.small_expert_sparsity_coef = config.small_expert_sparsity_coef
 
594
 
595
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
596
  batch_size, sequence_length, hidden_dim = hidden_states.shape
597
  hidden_states = hidden_states.view(-1, hidden_dim)
598
 
599
  router_logits = self.gate(hidden_states)
600
+
601
  if self.num_small_experts > 0:
602
  small_router_logits = self.small_gate(hidden_states)
603
  combined_logits = torch.cat([router_logits, small_router_logits], dim=-1)
 
607
  routing_probs = F.softmax(combined_logits, dim=1, dtype=torch.float)
608
  routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
609
 
 
 
 
 
 
 
610
  if self.norm_topk_prob:
611
  routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
612
 
 
632
 
633
  return final_hidden_states.view(batch_size, sequence_length, hidden_dim), combined_logits
634
 
 
 
 
 
 
 
 
 
 
635
 
636
  class OlmoeDecoderLayer(nn.Module):
637
  def __init__(self, config: OlmoeConfig, layer_idx: int):
scripts/evalexperts.py DELETED
@@ -1,441 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- eval_with_expert_tracking.py - Evaluation script for MyOLMoE models with expert usage tracking
4
-
5
- This script evaluates a custom MyOLMoE model on benchmark tasks and tracks expert usage per layer.
6
-
7
- Usage Example:
8
- python eval_with_expert_tracking.py --model_path allenai/OLMoE-1B-7B-0924 --tasks mmlu hellaswag --num_fewshot 5
9
- """
10
-
11
- import argparse
12
- import json
13
- import os
14
- import sys
15
- import logging
16
- from typing import Dict, List, Tuple, Any
17
- import torch
18
- import numpy as np
19
- from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
20
- from lm_eval import evaluator
21
- from lm_eval.models.huggingface import HFLM
22
-
23
- # Set up logging
24
- logging.basicConfig(
25
- level=logging.INFO,
26
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
27
- )
28
- logger = logging.getLogger(__name__)
29
-
30
def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse reads ``sys.argv[1:]`` as before. Accepting
            an explicit list keeps the parser unit-testable without touching
            the process arguments.

    Returns:
        argparse.Namespace holding model, evaluation, and output settings.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate MyOLMoE model on benchmark tasks with expert usage tracking",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Model arguments
    parser.add_argument(
        "--model_path",
        type=str,
        default="allenai/OLMoE-1B-7B-0924",
        help="Path or name of the pretrained MyOLMoE model"
    )
    parser.add_argument(
        "--custom_model_path",
        type=str,
        default="./myolmoe_model",
        help="Path to custom MyOLMoE model code"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help="Device to use ('auto', 'cuda', 'cpu')"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "float16", "bfloat16", "float32"],
        help="Data type for model weights"
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Trust remote code when loading model"
    )

    # Evaluation arguments
    parser.add_argument(
        "--tasks",
        type=str,
        nargs="+",
        default=["mmlu"],
        help="Tasks to evaluate on (e.g., mmlu, hellaswag, arc_easy)"
    )
    parser.add_argument(
        "--num_fewshot",
        type=int,
        default=0,
        help="Number of few-shot examples"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size for evaluation"
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=None,
        help="Maximum batch size (auto if None)"
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Limit number of examples per task (for testing)"
    )

    # Output arguments
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./eval_results",
        help="Directory to save evaluation results and expert usage"
    )
    parser.add_argument(
        "--output_filename",
        type=str,
        default=None,
        help="Custom filename for results (auto-generated if not provided)"
    )

    return parser.parse_args(argv)
117
-
118
def load_custom_model(args) -> Tuple[AutoModelForCausalLM, AutoTokenizer, HFLM]:
    """
    Load custom MyOLMoE model, tokenizer, and HFLM wrapper.

    Args:
        args: Parsed command line arguments (reads model_path,
            custom_model_path, device, dtype, batch_size, max_batch_size,
            trust_remote_code)

    Returns:
        Tuple of (model, tokenizer, HFLM wrapper)

    Raises:
        FileNotFoundError: if ``args.custom_model_path`` does not exist.
        ImportError: if ``modeling_myolmoe`` cannot be imported from it.
    """
    logger.info(f"Loading custom MyOLMoE model: {args.model_path}")

    # Add custom model path to Python path so the local module can be imported
    if os.path.exists(args.custom_model_path):
        sys.path.insert(0, args.custom_model_path)
        logger.info(f"Added {args.custom_model_path} to Python path")
    else:
        logger.error(f"Custom model path not found: {args.custom_model_path}")
        raise FileNotFoundError(f"Custom model path not found: {args.custom_model_path}")

    try:
        from modeling_myolmoe import MyOlmoeForCausalLM
        logger.info("Successfully imported MyOlmoeForCausalLM")
    except ImportError as e:
        logger.error(f"Failed to import custom model: {e}")
        raise

    # Load model configuration
    config = AutoConfig.from_pretrained(
        args.model_path,
        trust_remote_code=args.trust_remote_code
    )

    # Determine torch dtype: keep the string "auto" as-is, otherwise map the
    # CLI string onto the concrete torch dtype object.
    torch_dtype = args.dtype
    if args.dtype != "auto":
        torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[args.dtype]

    # Load model and tokenizer
    model = MyOlmoeForCausalLM.from_pretrained(
        args.model_path,
        config=config,
        torch_dtype=torch_dtype,
        device_map="auto" if args.device == "auto" else None,
        trust_remote_code=args.trust_remote_code
    ).eval()

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        trust_remote_code=args.trust_remote_code
    )

    # Create HFLM wrapper for evaluation.
    # NOTE(review): passing the path as a string makes HFLM load a SECOND
    # copy of the weights, separate from `model` above — roughly doubling
    # memory — and HFLM's copy is loaded via transformers' auto classes, not
    # necessarily MyOlmoeForCausalLM. Verify whether HFLM should instead be
    # handed the already-loaded `model` instance.
    hf_model = HFLM(
        pretrained=args.model_path,  # Pass model path as string
        device=args.device,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        dtype=args.dtype,
        trust_remote_code=args.trust_remote_code
    )

    logger.info("Custom model, tokenizer, and HFLM wrapper loaded successfully")
    return model, tokenizer, hf_model
182
-
183
def track_expert_usage(model, input_ids: torch.Tensor) -> List[Dict[int, int]]:
    """
    Track expert usage per layer during a single forward pass.

    Args:
        model: MyOLMoE model. Expected to expose ``model.transformer.layers``
            where each layer may carry a ``moe`` sub-module, and
            ``model.config.num_hidden_layers``.
        input_ids: Input token IDs (batched)

    Returns:
        List of dictionaries, one per layer, mapping expert indices to their
        usage counts for that layer.
    """
    expert_usage = [{} for _ in range(model.config.num_hidden_layers)]

    def hook_fn(module, input, output, layer_idx):
        # Prefer explicit selection indices when the MoE block exposes them.
        if hasattr(module, 'selected_experts'):  # Hypothetical attribute
            selected_experts = module.selected_experts  # assumed (batch, seq, top_k) — TODO confirm
            for expert_idx in selected_experts.flatten().tolist():
                expert_usage[layer_idx][expert_idx] = expert_usage[layer_idx].get(expert_idx, 0) + 1
        elif hasattr(module, 'routing_weights'):  # Alternative: use routing weights
            weights = module.routing_weights  # assumed (batch, seq, num_experts) — TODO confirm
            top_k_indices = torch.topk(weights, k=model.config.top_k, dim=-1).indices
            for expert_idx in top_k_indices.flatten().tolist():
                expert_usage[layer_idx][expert_idx] = expert_usage[layer_idx].get(expert_idx, 0) + 1

    # Register hooks for each MoE layer.
    # BUG FIX: the loop variable must be bound as a lambda default (idx=i).
    # The previous plain closure over ``i`` was late-bound, so every hook
    # recorded its counts into the slot of the LAST MoE layer.
    hooks = []
    for i, layer in enumerate(model.transformer.layers):  # NOTE(review): verify attribute path — HF decoder models often use model.model.layers
        if hasattr(layer, 'moe'):
            hook = layer.moe.register_forward_hook(
                lambda m, inp, out, idx=i: hook_fn(m, inp, out, idx)
            )
            hooks.append(hook)

    # Run a forward pass with gradients disabled; hooks populate expert_usage.
    with torch.no_grad():
        model(input_ids)

    # Remove hooks so repeated calls don't stack handlers on the modules.
    for hook in hooks:
        hook.remove()

    return expert_usage
223
-
224
def run_evaluation_with_tracking(model, hf_model, tokenizer, args) -> Tuple[Dict[str, Any], Dict[str, List[Dict[int, int]]]]:
    """
    Run evaluation on benchmark tasks and track expert usage.

    Temporarily monkey-patches ``hf_model.loglikelihood`` so that every
    request additionally triggers an extra tracked forward pass on ``model``.
    NOTE(review): this means each request runs the model twice (one tracking
    pass on ``model``, one scoring pass inside HFLM) — confirm this cost is
    acceptable, and that the tracked ``model`` is the same weights HFLM uses.

    Args:
        model: MyOLMoE model
        hf_model: HFLM wrapper for evaluation
        tokenizer: Tokenizer
        args: Parsed command line arguments

    Returns:
        Tuple of (evaluation results, task-wise expert usage)
    """
    logger.info(f"Running evaluation on tasks: {args.tasks}")
    logger.info(f"Few-shot examples: {args.num_fewshot}")
    logger.info(f"Batch size: {args.batch_size}")

    # Initialize expert usage tracking for each task: task -> list of
    # per-request results, each itself a per-layer list of usage dicts.
    task_expert_usage = {task: [] for task in args.tasks}

    # Custom batch processing to track expert usage
    def custom_loglikelihood(self, requests):
        from lm_eval.api.instance import Instance  # NOTE(review): imported but unused — remove?
        res = []
        for request in requests:
            input_ids = tokenizer(request.arguments[0], return_tensors="pt").input_ids.to(model.device)
            # Track expert usage
            batch_expert_usage = track_expert_usage(model, input_ids)
            # NOTE(review): assumes requests carry `task_name` — verify
            # against the installed lm_eval version's Instance API.
            task_expert_usage[request.task_name].append(batch_expert_usage)
            # Original loglikelihood computation.
            # NOTE(review): relies on a private `_loglikelihood` existing on
            # HFLM; confirm, otherwise this recurses into the patched method.
            res.append(self._loglikelihood([request]))
        return [item for sublist in res for item in sublist]

    # Override HFLM's loglikelihood method (bound to the instance)
    original_loglikelihood = hf_model.loglikelihood
    hf_model.loglikelihood = custom_loglikelihood.__get__(hf_model, HFLM)

    # Run evaluation
    results = evaluator.simple_evaluate(
        model=hf_model,
        tasks=args.tasks,
        num_fewshot=args.num_fewshot,
        limit=args.limit,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
    )

    # Restore original method
    hf_model.loglikelihood = original_loglikelihood

    # Aggregate expert usage per task: transpose per-request lists into
    # per-layer groups, then sum counts per expert across requests.
    aggregated_usage = {}
    for task in args.tasks:
        if task_expert_usage[task]:
            aggregated_usage[task] = [
                {k: sum(d.get(k, 0) for d in layer_usages) for k in set().union(*layer_usages)}
                for layer_usages in zip(*task_expert_usage[task])
            ]
        else:
            # No requests were seen for this task; emit empty dicts per layer.
            aggregated_usage[task] = [{} for _ in range(model.config.num_hidden_layers)]

    logger.info("Evaluation and expert usage tracking completed")
    return results, aggregated_usage
287
-
288
def make_serializable(obj: Any) -> Any:
    """
    Convert objects to JSON-serializable format.

    Recursively walks containers and converts numpy/torch scalar, dtype, and
    array types to plain Python values. Unknown types are returned unchanged.

    Args:
        obj: Object to convert

    Returns:
        JSON-serializable version of the object
    """
    if isinstance(obj, dict):
        return {k: make_serializable(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [make_serializable(v) for v in obj]
    elif isinstance(obj, tuple):
        # json.dump writes tuples as arrays, so keeping a tuple is safe.
        return tuple(make_serializable(v) for v in obj)
    elif isinstance(obj, (set, frozenset)):
        # BUG FIX: sets previously fell through unhandled and crashed
        # json.dump; emit a list instead.
        return [make_serializable(v) for v in obj]
    elif isinstance(obj, (np.integer, np.floating)):
        return obj.item()
    elif isinstance(obj, np.ndarray):
        # BUG FIX: ndarrays previously fell through unhandled and crashed
        # json.dump; convert to (nested) lists of Python scalars.
        return obj.tolist()
    elif isinstance(obj, np.dtype):
        return str(obj)
    elif isinstance(obj, torch.Tensor):
        return obj.tolist()
    elif isinstance(obj, torch.dtype):
        return str(obj)
    else:
        return obj
314
-
315
def save_results(results: Dict[str, Any], expert_usage: Dict[str, List[Dict[int, int]]], args) -> str:
    """
    Save evaluation results and expert usage to file.

    Args:
        results: Evaluation results
        expert_usage: Expert usage per task and layer
        args: Parsed command line arguments (reads output_dir,
            output_filename, model_path, tasks, and eval settings)

    Returns:
        str: Path to saved results file
    """
    os.makedirs(args.output_dir, exist_ok=True)

    # Generate filename
    if args.output_filename is None:
        model_name = os.path.basename(args.model_path.rstrip('/'))
        tasks_str = "_".join(args.tasks[:3])
        if len(args.tasks) > 3:
            tasks_str += f"_and_{len(args.tasks)-3}_more"
        # BUG FIX: tasks_str was computed but never used, so runs over
        # different task sets silently overwrote the same file; include it
        # in the auto-generated filename.
        filename = f"{model_name}_{tasks_str}_eval_expert_usage.json"
    else:
        filename = args.output_filename

    if not filename.endswith('.json'):
        filename += '.json'

    output_path = os.path.join(args.output_dir, filename)

    # Prepare results: bundle run settings alongside metrics and usage so
    # the output file is self-describing.
    results_with_metadata = {
        "metadata": {
            "model_path": args.model_path,
            "tasks": args.tasks,
            "num_fewshot": args.num_fewshot,
            "batch_size": args.batch_size,
            "device": args.device,
            "dtype": args.dtype,
            "limit": args.limit,
            "routing_type": "top-k (default)",
        },
        "results": results,
        "expert_usage": {
            # JSON object keys must be strings; stringify expert indices.
            task: [{str(k): v for k, v in layer_usage.items()} for layer_usage in task_usage]
            for task, task_usage in expert_usage.items()
        }
    }

    # Convert to JSON-serializable format
    serializable_results = make_serializable(results_with_metadata)

    # Save to file
    with open(output_path, 'w') as f:
        json.dump(serializable_results, f, indent=2)

    logger.info(f"Results saved to {output_path}")
    return output_path
372
-
373
def print_summary(results: Dict[str, Any], expert_usage: Dict[str, List[Dict[int, int]]], args) -> None:
    """
    Print a human-readable summary of evaluation metrics and per-layer
    expert usage to stdout.

    Args:
        results: Evaluation results (expects an optional "results" mapping
            of task name -> metric dict)
        expert_usage: Expert usage per task and layer
        args: Parsed command line arguments (reads model_path and tasks)
    """
    banner = '=' * 80
    print(f"\n{banner}")
    print("EVALUATION SUMMARY")
    print(f"Model: {args.model_path}")
    print(f"Tasks: {', '.join(args.tasks)}")
    print(banner)

    # Metric lines per task; each metric's *_stderr companion is folded into
    # the metric's own line rather than printed separately.
    for task, metrics in results.get("results", {}).items():
        if not isinstance(metrics, dict):
            continue
        print(f"\n📊 {task.upper()}:")
        for metric, value in metrics.items():
            if not isinstance(value, (int, float)) or metric.endswith('_stderr'):
                continue
            stderr = metrics.get(f"{metric}_stderr", 0)
            print(f"  {metric:.<20} {value:.4f} (±{stderr:.4f})")

    print("\nEXPERT USAGE PER TASK AND LAYER")
    for task, task_usage in expert_usage.items():
        print(f"\nTask: {task.upper()}")
        for layer_idx, layer_usage in enumerate(task_usage):
            print(f"  Layer {layer_idx}:")
            for expert_idx, count in layer_usage.items():
                print(f"    Expert {expert_idx}: {count} times")

    print(f"\n{banner}")
407
-
408
def main():
    """Entry point: parse CLI options, run the tracked evaluation, then
    persist and summarize results. Exits with status 1 on interrupt/error."""
    opts = parse_args()

    try:
        divider = "=" * 80
        logger.info(divider)
        logger.info("Starting MyOLMoE Evaluation with Expert Usage Tracking")
        logger.info(divider)

        # Model, tokenizer, and the lm-eval HFLM wrapper
        lm, tok, wrapped_lm = load_custom_model(opts)

        # Benchmark run with per-layer expert usage collection
        eval_results, usage_by_task = run_evaluation_with_tracking(lm, wrapped_lm, tok, opts)

        # Persist to disk, then show a console summary
        saved_path = save_results(eval_results, usage_by_task, opts)
        print_summary(eval_results, usage_by_task, opts)

        logger.info("✅ Evaluation completed successfully!")
        logger.info(f"📁 Results saved to: {saved_path}")

    except KeyboardInterrupt:
        logger.info("Evaluation interrupted by user")
        sys.exit(1)
    except Exception as e:
        logger.error(f"❌ Evaluation failed: {e}")
        logger.debug("Full traceback:", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    main()