attribute error
scripts/evalexperts.py
CHANGED  +56 -1
@@ -77,7 +77,7 @@ class ExpertTrackingHFLM(HFLM):
             )
 
             # Update statistics
-            self._update_expert_stats(
+            self.update_expert_stats(  # Changed from _update_expert_stats to update_expert_stats
                 layer_idx=layer_idx,
                 topk_experts=topk_experts,
                 topk_probs=topk_probs,
@@ -89,6 +89,61 @@ class ExpertTrackingHFLM(HFLM):
 
         return expert_hook
 
+    def update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,  # Renamed from _update_expert_stats
+                            topk_probs: torch.Tensor, num_regular_experts: int,
+                            num_small_experts: int, batch_size: int, seq_len: int):
+        """Update expert usage statistics."""
+        # Flatten the batch and sequence dimensions
+        topk_experts_flat = topk_experts.view(-1, topk_experts.size(-1))
+        topk_probs_flat = topk_probs.view(-1, topk_probs.size(-1))
+
+        # Initialize layer stats if not present
+        if layer_idx not in self.expert_stats['layer_stats']:
+            self.expert_stats['layer_stats'][layer_idx] = {
+                'total_tokens': 0,
+                'regular_expert_counts': [0] * num_regular_experts,
+                'small_expert_counts': [0] * num_small_experts if num_small_experts > 0 else None,
+                'regular_expert_load': [0.0] * num_regular_experts,
+                'small_expert_load': [0.0] * num_small_experts if num_small_experts > 0 else None
+            }
+
+        layer_stats = self.expert_stats['layer_stats'][layer_idx]
+        num_tokens = topk_experts_flat.size(0)
+
+        # Update global stats
+        self.expert_stats['total_tokens'] += num_tokens
+
+        # Update layer stats
+        layer_stats['total_tokens'] += num_tokens
+
+        # Track regular experts
+        for expert_idx in range(num_regular_experts):
+            mask = (topk_experts_flat == expert_idx)
+            count = mask.sum().item()
+            load = topk_probs_flat[mask].sum().item()
+
+            layer_stats['regular_expert_counts'][expert_idx] += count
+            layer_stats['regular_expert_load'][expert_idx] += load
+
+            if expert_idx not in self.expert_stats['regular_expert_usage']:
+                self.expert_stats['regular_expert_usage'][expert_idx] = 0
+            self.expert_stats['regular_expert_usage'][expert_idx] += count
+
+        # Track small experts if they exist
+        if num_small_experts > 0:
+            for expert_idx in range(num_small_experts):
+                small_expert_num = expert_idx + num_regular_experts
+                mask = (topk_experts_flat == small_expert_num)
+                count = mask.sum().item()
+                load = topk_probs_flat[mask].sum().item()
+
+                layer_stats['small_expert_counts'][expert_idx] += count
+                layer_stats['small_expert_load'][expert_idx] += load
+
+                if expert_idx not in self.expert_stats['small_expert_usage']:
+                    self.expert_stats['small_expert_usage'][expert_idx] = 0
+                self.expert_stats['small_expert_usage'][expert_idx] += count
+
     def _update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
                              topk_probs: torch.Tensor, num_regular_experts: int,
                              num_small_experts: int, batch_size: int, seq_len: int):
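
Note: the added update_expert_stats method reads several keys from self.expert_stats ('total_tokens', 'layer_stats', 'regular_expert_usage', 'small_expert_usage') that must already exist on the instance. The actual initialization in scripts/evalexperts.py is outside this diff; a minimal sketch of the structure the method assumes:

    # Hypothetical sketch only; the real dict is created elsewhere in ExpertTrackingHFLM (e.g. in __init__).
    expert_stats = {
        'total_tokens': 0,            # incremented by num_tokens on every hook invocation
        'layer_stats': {},            # layer_idx -> per-layer counts/loads, created lazily by update_expert_stats
        'regular_expert_usage': {},   # expert_idx -> cumulative routed-token count
        'small_expert_usage': {},     # expert_idx -> cumulative routed-token count
    }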