try patch hook
Browse files- myolmoe/config.json +1 -5
- scripts/evalexperts.py +25 -32
myolmoe/config.json
CHANGED
|
@@ -30,9 +30,5 @@
|
|
| 30 |
"torch_dtype": "float32",
|
| 31 |
"transformers_version": "4.52.4",
|
| 32 |
"use_cache": true,
|
| 33 |
-
"vocab_size": 50304
|
| 34 |
-
"small_expert_intermediate_ratio": 16,
|
| 35 |
-
"small_expert_count": 64,
|
| 36 |
-
"small_expert_sparsity_coef": 0.1,
|
| 37 |
-
"max_small_expert_count": 64
|
| 38 |
}
|
|
|
|
| 30 |
"torch_dtype": "float32",
|
| 31 |
"transformers_version": "4.52.4",
|
| 32 |
"use_cache": true,
|
| 33 |
+
"vocab_size": 50304
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
}
|
scripts/evalexperts.py
CHANGED
|
@@ -57,38 +57,31 @@ class ExpertTrackingHFLM(HFLM):
|
|
| 57 |
self._make_expert_hook(layer_idx)
|
| 58 |
)
|
| 59 |
|
| 60 |
-
def _make_expert_hook(
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
if
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
#
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
dim=-1
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
num_small_experts=module.num_small_experts if hasattr(module, 'num_small_experts') else 0,
|
| 86 |
-
batch_size=batch_size,
|
| 87 |
-
seq_len=seq_len
|
| 88 |
-
)
|
| 89 |
-
|
| 90 |
-
return expert_hook
|
| 91 |
-
|
| 92 |
def update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
|
| 93 |
topk_probs: torch.Tensor, num_regular_experts: int,
|
| 94 |
num_small_experts: int, batch_size: int, seq_len: int):
|
|
|
|
| 57 |
self._make_expert_hook(layer_idx)
|
| 58 |
)
|
| 59 |
|
| 60 |
+
def _make_expert_hook(layer_idx, model):
|
| 61 |
+
def hook(module, input, output):
|
| 62 |
+
# Get expert routing data from output
|
| 63 |
+
if isinstance(output, tuple) and len(output) == 2:
|
| 64 |
+
hidden_states, routing_weights = output
|
| 65 |
+
else:
|
| 66 |
+
hidden_states = output
|
| 67 |
+
routing_weights = None
|
| 68 |
+
|
| 69 |
+
# Always use the config value for num_small_experts
|
| 70 |
+
num_small_experts = getattr(model.config, 'small_expert_count', 0)
|
| 71 |
+
|
| 72 |
+
expert_stats[layer_idx] = expert_stats.get(layer_idx, {})
|
| 73 |
+
expert_stats[layer_idx]['total'] = expert_stats[layer_idx].get('total', 0) + 1
|
| 74 |
+
|
| 75 |
+
if routing_weights is not None:
|
| 76 |
+
top_expert = routing_weights.argmax(dim=-1)
|
| 77 |
+
for expert_id in top_expert.view(-1).tolist():
|
| 78 |
+
expert_stats[layer_idx][expert_id] = expert_stats[layer_idx].get(expert_id, 0) + 1
|
| 79 |
+
|
| 80 |
+
if expert_id < num_small_experts:
|
| 81 |
+
expert_stats[layer_idx]['small'] = expert_stats[layer_idx].get('small', 0) + 1
|
| 82 |
+
|
| 83 |
+
return hook
|
| 84 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
def update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
|
| 86 |
topk_probs: torch.Tensor, num_regular_experts: int,
|
| 87 |
num_small_experts: int, batch_size: int, seq_len: int):
|