Make get_expert_stats robust (JSON-serializable output) and remove the repeated method definitions
Browse files
- scripts/evalexperts.py +28 -162
scripts/evalexperts.py
CHANGED
|
@@ -146,158 +146,25 @@ class ExpertTrackingHFLM(HFLM):
|
|
| 146 |
|
| 147 |
def get_expert_stats(self) -> Dict[str, Any]:
|
| 148 |
"""Return expert usage statistics in a serializable format."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
stats = {
|
| 150 |
-
'total_tokens': self.expert_stats['total_tokens'],
|
| 151 |
-
'regular_expert_usage': {},
|
| 152 |
-
'small_expert_usage': {},
|
| 153 |
-
'layer_stats': {}
|
| 154 |
-
}
|
| 155 |
-
|
| 156 |
-
# Convert regular expert usage
|
| 157 |
-
for expert_idx, count in self.expert_stats['regular_expert_usage'].items():
|
| 158 |
-
stats['regular_expert_usage'][expert_idx] = {
|
| 159 |
-
'count': count,
|
| 160 |
-
'percentage': count / (self.expert_stats['total_tokens'] * getattr(self.model.config, 'top_k', 1)) * 100
|
| 161 |
-
}
|
| 162 |
-
|
| 163 |
-
# Convert small expert usage if they exist
|
| 164 |
-
if self.expert_stats['small_expert_usage']:
|
| 165 |
-
for expert_idx, count in self.expert_stats['small_expert_usage'].items():
|
| 166 |
-
stats['small_expert_usage'][expert_idx] = {
|
| 167 |
-
'count': count,
|
| 168 |
-
'percentage': count / (self.expert_stats['total_tokens'] * getattr(self.model.config, 'top_k', 1)) * 100
|
| 169 |
-
}
|
| 170 |
-
|
| 171 |
-
# Convert layer stats
|
| 172 |
-
for layer_idx, layer_stat in self.expert_stats['layer_stats'].items():
|
| 173 |
-
stats['layer_stats'][layer_idx] = {
|
| 174 |
-
'total_tokens': layer_stat['total_tokens'],
|
| 175 |
-
'regular_expert_counts': layer_stat['regular_expert_counts'],
|
| 176 |
-
'regular_expert_load': layer_stat['regular_expert_load'],
|
| 177 |
-
'small_expert_counts': layer_stat['small_expert_counts'],
|
| 178 |
-
'small_expert_load': layer_stat['small_expert_load']
|
| 179 |
-
}
|
| 180 |
-
|
| 181 |
-
return stats
|
| 182 |
-
|
| 183 |
-
def print_expert_stats(self) -> None:
|
| 184 |
-
"""Print expert usage statistics in a human-readable format."""
|
| 185 |
-
if not self.expert_stats['total_tokens']:
|
| 186 |
-
print("No expert usage statistics collected.")
|
| 187 |
-
return
|
| 188 |
-
|
| 189 |
-
total_tokens = self.expert_stats['total_tokens']
|
| 190 |
-
top_k = getattr(self.model.config, 'top_k', 1)
|
| 191 |
-
total_expert_activations = total_tokens * top_k
|
| 192 |
-
|
| 193 |
-
print("\n" + "="*80)
|
| 194 |
-
print("EXPERT USAGE STATISTICS")
|
| 195 |
-
print("="*80)
|
| 196 |
-
print(f"Total tokens processed: {total_tokens:,}")
|
| 197 |
-
print(f"Total expert activations (top-{top_k}): {total_expert_activations:,}")
|
| 198 |
-
print("\nOverall Expert Usage:")
|
| 199 |
-
|
| 200 |
-
# Print regular experts
|
| 201 |
-
if self.expert_stats['regular_expert_usage']:
|
| 202 |
-
print("\nRegular Experts:")
|
| 203 |
-
for expert_idx, count in sorted(self.expert_stats['regular_expert_usage'].items()):
|
| 204 |
-
percentage = count / total_expert_activations * 100
|
| 205 |
-
print(f" Expert {expert_idx}: {count:,} ({percentage:.2f}%)")
|
| 206 |
-
|
| 207 |
-
# Print small experts if they exist
|
| 208 |
-
if self.expert_stats['small_expert_usage']:
|
| 209 |
-
print("\nSmall Experts:")
|
| 210 |
-
for expert_idx, count in sorted(self.expert_stats['small_expert_usage'].items()):
|
| 211 |
-
percentage = count / total_expert_activations * 100
|
| 212 |
-
print(f" Small Expert {expert_idx}: {count:,} ({percentage:.2f}%)")
|
| 213 |
-
|
| 214 |
-
# Print layer-wise statistics
|
| 215 |
-
print("\nLayer-wise Statistics:")
|
| 216 |
-
for layer_idx, layer_stat in self.expert_stats['layer_stats'].items():
|
| 217 |
-
print(f"\nLayer {layer_idx}:")
|
| 218 |
-
print(f" Tokens processed: {layer_stat['total_tokens']:,}")
|
| 219 |
-
|
| 220 |
-
# Regular experts
|
| 221 |
-
print(" Regular Experts:")
|
| 222 |
-
for expert_idx, (count, load) in enumerate(zip(
|
| 223 |
-
layer_stat['regular_expert_counts'],
|
| 224 |
-
layer_stat['regular_expert_load']
|
| 225 |
-
)):
|
| 226 |
-
count_pct = count / (layer_stat['total_tokens'] * top_k) * 100
|
| 227 |
-
load_pct = load / layer_stat['total_tokens'] * 100
|
| 228 |
-
print(f" Expert {expert_idx}: Count={count:,} ({count_pct:.2f}%), Load={load:.2f} ({load_pct:.2f}%)")
|
| 229 |
-
|
| 230 |
-
# Small experts if they exist
|
| 231 |
-
if layer_stat['small_expert_counts'] is not None:
|
| 232 |
-
print(" Small Experts:")
|
| 233 |
-
for expert_idx, (count, load) in enumerate(zip(
|
| 234 |
-
layer_stat['small_expert_counts'],
|
| 235 |
-
layer_stat['small_expert_load']
|
| 236 |
-
)):
|
| 237 |
-
count_pct = count / (layer_stat['total_tokens'] * top_k) * 100
|
| 238 |
-
load_pct = load / layer_stat['total_tokens'] * 100
|
| 239 |
-
print(f" Small Expert {expert_idx}: Count={count:,} ({count_pct:.2f}%), Load={load:.2f} ({load_pct:.2f}%)")
|
| 240 |
-
|
| 241 |
-
print("="*80 + "\n")
|
| 242 |
-
def _update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
|
| 243 |
-
topk_probs: torch.Tensor, num_regular_experts: int,
|
| 244 |
-
num_small_experts: int, batch_size: int, seq_len: int):
|
| 245 |
-
"""Update expert usage statistics with serializable data types."""
|
| 246 |
-
# Flatten the batch and sequence dimensions
|
| 247 |
-
topk_experts_flat = topk_experts.view(-1, topk_experts.size(-1))
|
| 248 |
-
topk_probs_flat = topk_probs.view(-1, topk_probs.size(-1))
|
| 249 |
-
|
| 250 |
-
# Initialize layer stats if not present
|
| 251 |
-
if layer_idx not in self.expert_stats['layer_stats']:
|
| 252 |
-
self.expert_stats['layer_stats'][layer_idx] = {
|
| 253 |
-
'total_tokens': 0,
|
| 254 |
-
'regular_expert_counts': [0] * num_regular_experts, # Use list instead of tensor
|
| 255 |
-
'small_expert_counts': [0] * num_small_experts if num_small_experts > 0 else None,
|
| 256 |
-
'regular_expert_load': [0.0] * num_regular_experts,
|
| 257 |
-
'small_expert_load': [0.0] * num_small_experts if num_small_experts > 0 else None
|
| 258 |
-
}
|
| 259 |
-
|
| 260 |
-
layer_stats = self.expert_stats['layer_stats'][layer_idx]
|
| 261 |
-
num_tokens = topk_experts_flat.size(0)
|
| 262 |
-
|
| 263 |
-
# Update global stats
|
| 264 |
-
self.expert_stats['total_tokens'] += num_tokens
|
| 265 |
-
|
| 266 |
-
# Update layer stats
|
| 267 |
-
layer_stats['total_tokens'] += num_tokens
|
| 268 |
-
|
| 269 |
-
# Track regular experts
|
| 270 |
-
for expert_idx in range(num_regular_experts):
|
| 271 |
-
mask = (topk_experts_flat == expert_idx)
|
| 272 |
-
count = mask.sum().item()
|
| 273 |
-
load = topk_probs_flat[mask].sum().item()
|
| 274 |
-
|
| 275 |
-
layer_stats['regular_expert_counts'][expert_idx] += count
|
| 276 |
-
layer_stats['regular_expert_load'][expert_idx] += load
|
| 277 |
-
|
| 278 |
-
if expert_idx not in self.expert_stats['regular_expert_usage']:
|
| 279 |
-
self.expert_stats['regular_expert_usage'][expert_idx] = 0
|
| 280 |
-
self.expert_stats['regular_expert_usage'][expert_idx] += count
|
| 281 |
-
|
| 282 |
-
# Track small experts if they exist
|
| 283 |
-
if num_small_experts > 0:
|
| 284 |
-
for expert_idx in range(num_small_experts):
|
| 285 |
-
small_expert_num = expert_idx + num_regular_experts
|
| 286 |
-
mask = (topk_experts_flat == small_expert_num)
|
| 287 |
-
count = mask.sum().item()
|
| 288 |
-
load = topk_probs_flat[mask].sum().item()
|
| 289 |
-
|
| 290 |
-
layer_stats['small_expert_counts'][expert_idx] += count
|
| 291 |
-
layer_stats['small_expert_load'][expert_idx] += load
|
| 292 |
-
|
| 293 |
-
if expert_idx not in self.expert_stats['small_expert_usage']:
|
| 294 |
-
self.expert_stats['small_expert_usage'][expert_idx] = 0
|
| 295 |
-
self.expert_stats['small_expert_usage'][expert_idx] += count
|
| 296 |
-
|
| 297 |
-
def get_expert_stats(self) -> Dict[str, Any]:
|
| 298 |
-
"""Return expert usage statistics in a serializable format."""
|
| 299 |
-
stats = {
|
| 300 |
-
'total_tokens': self.expert_stats['total_tokens'],
|
| 301 |
'regular_expert_usage': {},
|
| 302 |
'small_expert_usage': {},
|
| 303 |
'layer_stats': {}
|
|
@@ -306,30 +173,30 @@ def _update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
|
|
| 306 |
# Convert regular expert usage
|
| 307 |
for expert_idx, count in self.expert_stats['regular_expert_usage'].items():
|
| 308 |
stats['regular_expert_usage'][expert_idx] = {
|
| 309 |
-
'count': count,
|
| 310 |
-
'percentage': count / (self.expert_stats['total_tokens'] * self.model.config
|
| 311 |
}
|
| 312 |
|
| 313 |
# Convert small expert usage if they exist
|
| 314 |
if self.expert_stats['small_expert_usage']:
|
| 315 |
for expert_idx, count in self.expert_stats['small_expert_usage'].items():
|
| 316 |
stats['small_expert_usage'][expert_idx] = {
|
| 317 |
-
'count': count,
|
| 318 |
-
'percentage': count / (self.expert_stats['total_tokens'] * self.model.config
|
| 319 |
}
|
| 320 |
|
| 321 |
# Convert layer stats
|
| 322 |
for layer_idx, layer_stat in self.expert_stats['layer_stats'].items():
|
| 323 |
stats['layer_stats'][layer_idx] = {
|
| 324 |
-
'total_tokens': layer_stat['total_tokens'],
|
| 325 |
-
'regular_expert_counts': layer_stat['regular_expert_counts']
|
| 326 |
-
'regular_expert_load': layer_stat['regular_expert_load']
|
| 327 |
-
'small_expert_counts': layer_stat['small_expert_counts']
|
| 328 |
-
'small_expert_load': layer_stat['small_expert_load']
|
| 329 |
}
|
| 330 |
|
| 331 |
return stats
|
| 332 |
-
|
| 333 |
def print_expert_stats(self) -> None:
|
| 334 |
"""Print expert usage statistics in a human-readable format."""
|
| 335 |
if not self.expert_stats['total_tokens']:
|
|
@@ -390,7 +257,6 @@ def _update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
|
|
| 390 |
|
| 391 |
print("="*80 + "\n")
|
| 392 |
|
| 393 |
-
|
| 394 |
def parse_args():
|
| 395 |
"""Parse command line arguments."""
|
| 396 |
parser = argparse.ArgumentParser(
|
|
|
|
| 146 |
|
| 147 |
def get_expert_stats(self) -> Dict[str, Any]:
    """Return expert usage statistics in a JSON-serializable format.

    Reads ``self.expert_stats`` and builds a new dictionary in which NumPy
    scalars/arrays and torch tensors/dtypes are replaced by plain Python
    values, so the result can be passed to ``json.dumps``.

    Returns:
        A dict with the keys:

        * ``'total_tokens'`` -- the converted ``expert_stats['total_tokens']``.
        * ``'regular_expert_usage'`` / ``'small_expert_usage'`` -- one entry
          per expert index, each ``{'count': ..., 'percentage': ...}`` where
          the percentage is ``count / (total_tokens * top_k) * 100`` and
          ``top_k`` is ``self.model.config.top_k`` (1 when the attribute is
          missing).
        * ``'layer_stats'`` -- one entry per layer index holding the
          converted ``total_tokens``, ``regular_expert_counts``,
          ``regular_expert_load``, ``small_expert_counts`` and
          ``small_expert_load`` values.

        Percentages are reported as ``0.0`` when ``total_tokens * top_k`` is
        zero instead of raising ``ZeroDivisionError``.
    """
    def convert(obj):
        """Recursively convert ``obj`` to plain Python types."""
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, torch.Tensor):
            # Tensor.tolist() avoids the NumPy round-trip, which rejects
            # tensors that require grad and dtypes NumPy lacks (bfloat16).
            return obj.detach().cpu().tolist()
        if isinstance(obj, torch.dtype):
            return str(obj)
        if isinstance(obj, dict):
            return {convert_key(k): convert(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [convert(v) for v in obj]
        return obj

    def convert_key(key):
        """Convert a scalar dict key; json.dumps rejects NumPy scalar keys."""
        if isinstance(key, np.generic):
            return convert(key)
        if isinstance(key, torch.Tensor) and key.numel() == 1:
            return key.item()
        return key

    raw_stats = self.expert_stats
    total_tokens = convert(raw_stats['total_tokens'])
    top_k = getattr(self.model.config, 'top_k', 1)
    # Every token activates top_k experts, so this is the percentage base.
    total_activations = total_tokens * top_k

    def summarize(usage):
        """Build the {expert_idx: {'count', 'percentage'}} mapping."""
        summary = {}
        # A missing/empty usage table simply yields an empty summary.
        for expert_idx, count in (usage or {}).items():
            count = convert(count)
            if total_activations:
                percentage = count / total_activations * 100
            else:
                percentage = 0.0
            summary[convert_key(expert_idx)] = {
                'count': count,
                'percentage': percentage,
            }
        return summary

    layer_fields = (
        'total_tokens',
        'regular_expert_counts',
        'regular_expert_load',
        'small_expert_counts',
        'small_expert_load',
    )

    return {
        'total_tokens': total_tokens,
        'regular_expert_usage': summarize(raw_stats['regular_expert_usage']),
        'small_expert_usage': summarize(raw_stats['small_expert_usage']),
        'layer_stats': {
            convert_key(layer_idx): {
                field: convert(layer_stat[field]) for field in layer_fields
            }
            for layer_idx, layer_stat in raw_stats['layer_stats'].items()
        },
    }
|
| 199 |
+
|
| 200 |
def print_expert_stats(self) -> None:
|
| 201 |
"""Print expert usage statistics in a human-readable format."""
|
| 202 |
if not self.expert_stats['total_tokens']:
|
|
|
|
| 257 |
|
| 258 |
print("="*80 + "\n")
|
| 259 |
|
|
|
|
| 260 |
def parse_args():
|
| 261 |
"""Parse command line arguments."""
|
| 262 |
parser = argparse.ArgumentParser(
|