inspect
Browse files
- scripts/evalexperts.py (+19, −19)
- scripts/inspectexperts.py (+51, −0)
scripts/evalexperts.py
CHANGED
|
@@ -57,30 +57,30 @@ class ExpertTrackingHFLM(HFLM):
|
|
| 57 |
self._make_expert_hook(layer_idx)
|
| 58 |
)
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
|
| 72 |
-
|
| 73 |
-
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
|
| 83 |
-
|
| 84 |
|
| 85 |
def update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
|
| 86 |
topk_probs: torch.Tensor, num_regular_experts: int,
|
|
|
|
| 57 |
self._make_expert_hook(layer_idx)
|
| 58 |
)
|
| 59 |
|
| 60 |
+
def _make_expert_hook(self, layer_idx):
    """Build a forward hook that records expert-routing statistics for one MoE layer.

    Bug fix: the previous version was declared ``def _make_expert_hook(layer_idx,
    model)`` yet registered via ``self._make_expert_hook(layer_idx)`` — as a bound
    method, ``self`` landed in ``layer_idx`` and the integer layer index in
    ``model``, so ``model.config`` failed. It also referenced a bare
    ``expert_stats`` name that exists nowhere in the function's scope.

    Args:
        layer_idx: Index of the transformer layer this hook is registered on.

    Returns:
        A callable suitable for ``module.register_forward_hook``; it increments
        per-layer counters in ``self.expert_stats``.
    """
    def hook(module, input, output):
        # Some MoE module variants return (hidden_states, routing_weights);
        # others return only hidden states, in which case no routing is logged.
        if isinstance(output, tuple) and len(output) == 2:
            hidden_states, routing_weights = output
        else:
            hidden_states = output
            routing_weights = None

        # Always take the small-expert count from the model config.
        # NOTE(review): assumes the HFLM wrapper exposes the underlying HF
        # model as ``self.model`` — confirm against the wrapper implementation.
        num_small_experts = getattr(self.model.config, 'small_expert_count', 0)

        # NOTE(review): the original read a bare ``expert_stats`` global that is
        # never defined; per-instance storage on ``self`` matches the sibling
        # ``update_expert_stats`` method — confirm intent.
        stats = self.expert_stats.setdefault(layer_idx, {})
        # 'total' counts forward passes through this layer, not tokens.
        stats['total'] = stats.get('total', 0) + 1

        if routing_weights is not None:
            # One top-1 expert per token; flatten batch/sequence dims.
            top_expert = routing_weights.argmax(dim=-1)
            for expert_id in top_expert.view(-1).tolist():
                stats[expert_id] = stats.get(expert_id, 0) + 1
                # Small experts occupy the low index range [0, num_small_experts).
                if expert_id < num_small_experts:
                    stats['small'] = stats.get('small', 0) + 1

    return hook
|
| 84 |
|
| 85 |
def update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
|
| 86 |
topk_probs: torch.Tensor, num_regular_experts: int,
|
scripts/inspectexperts.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
| 7 |
+
|
| 8 |
+
def main():
    """Inspect a fine-tuned MoE checkpoint and report its small-expert parameters.

    Loads only the config first to fail fast when the checkpoint has no
    ``num_small_experts`` field, then loads the model and lists every parameter
    tensor whose name contains ``small_experts.``.

    Raises:
        ValueError: if the checkpoint config lacks ``num_small_experts``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the fine-tuned checkpoint directory (e.g., ./checkpoints/checkpoint-16000)",
    )
    # NOTE(review): parsed but currently unused — kept for CLI compatibility.
    parser.add_argument(
        "--custom_model_path",
        type=str,
        required=False,
        help="(Optional) Path to the model implementation source if needed",
    )
    args = parser.parse_args()

    print(f"Loading config from: {args.model_path}")
    config = AutoConfig.from_pretrained(args.model_path)

    # Guard clause instead of if/else: fail fast on non-MoE checkpoints.
    if not hasattr(config, "num_small_experts"):
        raise ValueError("The model config does not contain 'num_small_experts'.")
    num_small_experts = config.num_small_experts

    print(f"Number of small experts: {num_small_experts}")

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
    model.eval()

    print("Inspecting small expert weights...")
    total_params = 0    # counts parameter *tensors*, not scalar elements
    matched_params = 0
    for name, param in model.named_parameters():
        total_params += 1
        # Plain string literal — the original used an f-string with no
        # placeholders here.
        if "small_experts." in name:
            matched_params += 1
            print(f"[Matched] {name} - shape: {tuple(param.shape)}")
    print(f"\nMatched {matched_params}/{total_params} parameters containing 'small_experts.'")

    print("Done.")


if __name__ == "__main__":
    main()
|