Charlie81 committed
Commit 97a7f0a · 1 Parent(s): b1da2be
Files changed (2):
  1. scripts/evalexperts.py +19 -19
  2. scripts/inspectexperts.py +51 -0
scripts/evalexperts.py CHANGED
@@ -57,30 +57,30 @@ class ExpertTrackingHFLM(HFLM):
                 self._make_expert_hook(layer_idx)
             )
 
-    def _make_expert_hook(layer_idx, model):
-        def hook(module, input, output):
-            # Get expert routing data from output
-            if isinstance(output, tuple) and len(output) == 2:
-                hidden_states, routing_weights = output
-            else:
-                hidden_states = output
-                routing_weights = None
+    def _make_expert_hook(layer_idx, model):
+        def hook(module, input, output):
+            # Get expert routing data from output
+            if isinstance(output, tuple) and len(output) == 2:
+                hidden_states, routing_weights = output
+            else:
+                hidden_states = output
+                routing_weights = None
 
-            # Always use the config value for num_small_experts
-            num_small_experts = getattr(model.config, 'small_expert_count', 0)
+            # Always use the config value for num_small_experts
+            num_small_experts = getattr(model.config, 'small_expert_count', 0)
 
-            expert_stats[layer_idx] = expert_stats.get(layer_idx, {})
-            expert_stats[layer_idx]['total'] = expert_stats[layer_idx].get('total', 0) + 1
+            expert_stats[layer_idx] = expert_stats.get(layer_idx, {})
+            expert_stats[layer_idx]['total'] = expert_stats[layer_idx].get('total', 0) + 1
 
-            if routing_weights is not None:
-                top_expert = routing_weights.argmax(dim=-1)
-                for expert_id in top_expert.view(-1).tolist():
-                    expert_stats[layer_idx][expert_id] = expert_stats[layer_idx].get(expert_id, 0) + 1
+            if routing_weights is not None:
+                top_expert = routing_weights.argmax(dim=-1)
+                for expert_id in top_expert.view(-1).tolist():
+                    expert_stats[layer_idx][expert_id] = expert_stats[layer_idx].get(expert_id, 0) + 1
 
-                    if expert_id < num_small_experts:
-                        expert_stats[layer_idx]['small'] = expert_stats[layer_idx].get('small', 0) + 1
+                    if expert_id < num_small_experts:
+                        expert_stats[layer_idx]['small'] = expert_stats[layer_idx].get('small', 0) + 1
 
-        return hook
+        return hook
 
     def update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
                             topk_probs: torch.Tensor, num_regular_experts: int,
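
Review note: as committed, _make_expert_hook is defined with the signature (layer_idx, model) but is invoked as self._make_expert_hook(layer_idx), which binds self to the layer_idx parameter and the integer index to model; expert_stats is also never resolved in the scope shown by this hunk. A minimal self-contained sketch of the intended hook-factory pattern, assuming a module-level stats dict and a hooked module that returns (hidden_states, routing_weights); the names make_expert_hook, expert_stats, and num_small_experts below are illustrative, not part of this commit:

expert_stats = {}  # layer_idx -> {'total': calls, 'small': small-expert hits, expert_id: hits}

def make_expert_hook(layer_idx, num_small_experts):
    def hook(module, input, output):
        # The hooked module is assumed to return (hidden_states, routing_weights).
        routing_weights = output[1] if isinstance(output, tuple) and len(output) == 2 else None
        stats = expert_stats.setdefault(layer_idx, {})
        stats['total'] = stats.get('total', 0) + 1
        if routing_weights is not None:
            # Count the top-1 expert chosen at every routed position.
            for expert_id in routing_weights.argmax(dim=-1).view(-1).tolist():
                stats[expert_id] = stats.get(expert_id, 0) + 1
                if expert_id < num_small_experts:
                    stats['small'] = stats.get('small', 0) + 1
    return hook

Registration would then follow the pattern already visible in the hunk, e.g. layer.register_forward_hook(make_expert_hook(layer_idx, num_small_experts)).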
scripts/inspectexperts.py ADDED
@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+
+import argparse
+import os
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model_path",
+        type=str,
+        required=True,
+        help="Path to the fine-tuned checkpoint directory (e.g., ./checkpoints/checkpoint-16000)",
+    )
+    parser.add_argument(
+        "--custom_model_path",
+        type=str,
+        required=False,
+        help="(Optional) Path to the model implementation source if needed",
+    )
+    args = parser.parse_args()
+
+    print(f"Loading config from: {args.model_path}")
+    config = AutoConfig.from_pretrained(args.model_path)
+
+    if hasattr(config, "num_small_experts"):
+        num_small_experts = config.num_small_experts
+    else:
+        raise ValueError("The model config does not contain 'num_small_experts'.")
+
+    print(f"Number of small experts: {num_small_experts}")
+
+    print("Loading model...")
+    model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
+    model.eval()
+
+    print("Inspecting small expert weights...")
+    total_params = 0
+    matched_params = 0
+    for name, param in model.named_parameters():
+        total_params += 1
+        if "small_experts." in name:
+            matched_params += 1
+            print(f"[Matched] {name} - shape: {tuple(param.shape)}")
+    print(f"\nMatched {matched_params}/{total_params} parameters containing 'small_experts.'")
+
+    print("Done.")
+
+if __name__ == "__main__":
+    main()
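
Note: the new script requires config.num_small_experts and raises otherwise, whereas the hook in scripts/evalexperts.py reads getattr(model.config, 'small_expert_count', 0); a checkpoint that defines only num_small_experts would pass this script yet make the hook silently fall back to zero small experts. A usage sketch, reusing the illustrative checkpoint path from the script's own --model_path help text:

python scripts/inspectexperts.py --model_path ./checkpoints/checkpoint-16000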