Add per-layer expert usage statistics tracking
Browse files- myolmoe/modeling_myolmoe.py +23 -11
myolmoe/modeling_myolmoe.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import math
|
| 2 |
from typing import List, Optional, Tuple, Union
|
|
|
|
| 3 |
import torch
|
| 4 |
import torch.nn.functional as F
|
| 5 |
import torch.utils.checkpoint
|
|
@@ -558,20 +559,17 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 558 |
self.top_k = config.num_experts_per_tok
|
| 559 |
self.norm_topk_prob = config.norm_topk_prob
|
| 560 |
|
| 561 |
-
# Determine if this block is in the second half
|
| 562 |
in_second_half = layer_idx >= self.total_layers // 2
|
| 563 |
|
| 564 |
-
# Determine small expert count for this layer
|
| 565 |
if in_second_half:
|
| 566 |
second_half_idx = layer_idx - (self.total_layers // 2)
|
| 567 |
num_second_half_blocks = self.total_layers - (self.total_layers // 2)
|
| 568 |
-
|
| 569 |
if config.small_expert_strategy == "constant":
|
| 570 |
self.num_small_experts = config.max_small_expert_count // num_second_half_blocks
|
| 571 |
elif config.small_expert_strategy == "increment":
|
| 572 |
-
# Linearly scale small experts from 1 to max_small_expert_count
|
| 573 |
self.num_small_experts = (
|
| 574 |
-
(second_half_idx + 1) * config.max_small_expert_count //
|
|
|
|
| 575 |
)
|
| 576 |
else:
|
| 577 |
raise ValueError(f"Unknown strategy: {config.small_expert_strategy}")
|
|
@@ -584,20 +582,19 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 584 |
]) if self.num_small_experts > 0 else None
|
| 585 |
|
| 586 |
self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
self.small_gate = nn.Linear(config.hidden_size, self.num_small_experts, bias=False)
|
| 590 |
-
else:
|
| 591 |
-
self.small_gate = None
|
| 592 |
|
| 593 |
self.small_expert_sparsity_coef = config.small_expert_sparsity_coef
|
| 594 |
|
|
|
|
|
|
|
|
|
|
| 595 |
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 596 |
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
| 597 |
hidden_states = hidden_states.view(-1, hidden_dim)
|
| 598 |
|
| 599 |
router_logits = self.gate(hidden_states)
|
| 600 |
-
|
| 601 |
if self.num_small_experts > 0:
|
| 602 |
small_router_logits = self.small_gate(hidden_states)
|
| 603 |
combined_logits = torch.cat([router_logits, small_router_logits], dim=-1)
|
|
@@ -607,6 +604,12 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 607 |
routing_probs = F.softmax(combined_logits, dim=1, dtype=torch.float)
|
| 608 |
routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
|
| 609 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 610 |
if self.norm_topk_prob:
|
| 611 |
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
| 612 |
|
|
@@ -632,6 +635,15 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 632 |
|
| 633 |
return final_hidden_states.view(batch_size, sequence_length, hidden_dim), combined_logits
|
| 634 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 635 |
|
| 636 |
class OlmoeDecoderLayer(nn.Module):
|
| 637 |
def __init__(self, config: OlmoeConfig, layer_idx: int):
|
|
|
|
| 1 |
import math
|
| 2 |
from typing import List, Optional, Tuple, Union
|
| 3 |
+
from collections import defaultdict
|
| 4 |
import torch
|
| 5 |
import torch.nn.functional as F
|
| 6 |
import torch.utils.checkpoint
|
|
|
|
| 559 |
self.top_k = config.num_experts_per_tok
|
| 560 |
self.norm_topk_prob = config.norm_topk_prob
|
| 561 |
|
|
|
|
| 562 |
in_second_half = layer_idx >= self.total_layers // 2
|
| 563 |
|
|
|
|
| 564 |
if in_second_half:
|
| 565 |
second_half_idx = layer_idx - (self.total_layers // 2)
|
| 566 |
num_second_half_blocks = self.total_layers - (self.total_layers // 2)
|
|
|
|
| 567 |
if config.small_expert_strategy == "constant":
|
| 568 |
self.num_small_experts = config.max_small_expert_count // num_second_half_blocks
|
| 569 |
elif config.small_expert_strategy == "increment":
|
|
|
|
| 570 |
self.num_small_experts = (
|
| 571 |
+
(second_half_idx + 1) * config.max_small_expert_count //
|
| 572 |
+
((num_second_half_blocks * (num_second_half_blocks + 1)) // 2)
|
| 573 |
)
|
| 574 |
else:
|
| 575 |
raise ValueError(f"Unknown strategy: {config.small_expert_strategy}")
|
|
|
|
| 582 |
]) if self.num_small_experts > 0 else None
|
| 583 |
|
| 584 |
self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
|
| 585 |
+
self.small_gate = nn.Linear(config.hidden_size, self.num_small_experts, bias=False) \
|
| 586 |
+
if self.num_small_experts > 0 else None
|
|
|
|
|
|
|
|
|
|
| 587 |
|
| 588 |
self.small_expert_sparsity_coef = config.small_expert_sparsity_coef
|
| 589 |
|
| 590 |
+
# Usage tracking (not a buffer, no gradient)
|
| 591 |
+
self.expert_usage = defaultdict(int)
|
| 592 |
+
|
| 593 |
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 594 |
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
| 595 |
hidden_states = hidden_states.view(-1, hidden_dim)
|
| 596 |
|
| 597 |
router_logits = self.gate(hidden_states)
|
|
|
|
| 598 |
if self.num_small_experts > 0:
|
| 599 |
small_router_logits = self.small_gate(hidden_states)
|
| 600 |
combined_logits = torch.cat([router_logits, small_router_logits], dim=-1)
|
|
|
|
| 604 |
routing_probs = F.softmax(combined_logits, dim=1, dtype=torch.float)
|
| 605 |
routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
|
| 606 |
|
| 607 |
+
# Track expert usage
|
| 608 |
+
for i in range(selected_experts.size(0)):
|
| 609 |
+
for j in range(self.top_k):
|
| 610 |
+
expert_id = selected_experts[i, j].item()
|
| 611 |
+
self.expert_usage[expert_id] += 1
|
| 612 |
+
|
| 613 |
if self.norm_topk_prob:
|
| 614 |
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
| 615 |
|
|
|
|
| 635 |
|
| 636 |
return final_hidden_states.view(batch_size, sequence_length, hidden_dim), combined_logits
|
| 637 |
|
| 638 |
+
def __del__(self):
|
| 639 |
+
if self.expert_usage:
|
| 640 |
+
print(f"\n[Expert Usage Report for Layer {self.layer_idx}]")
|
| 641 |
+
total = sum(self.expert_usage.values())
|
| 642 |
+
for expert_id in sorted(self.expert_usage):
|
| 643 |
+
count = self.expert_usage[expert_id]
|
| 644 |
+
percent = 100.0 * count / total if total > 0 else 0.0
|
| 645 |
+
print(f" Expert {expert_id:2d}: {count} times ({percent:.2f}%)")
|
| 646 |
+
|
| 647 |
|
| 648 |
class OlmoeDecoderLayer(nn.Module):
|
| 649 |
def __init__(self, config: OlmoeConfig, layer_idx: int):
|