Add two new depth-based routing types: `depthconstant` and `depthlatter`
Browse files
myolmoe/modeling_myolmoe.py
CHANGED
|
@@ -447,8 +447,9 @@ OLMOE_ATTENTION_CLASSES = {
|
|
| 447 |
|
| 448 |
|
| 449 |
class OlmoeSparseMoeBlock(nn.Module):
|
| 450 |
-
def __init__(self, config):
|
| 451 |
super().__init__()
|
|
|
|
| 452 |
self.num_experts = config.num_experts
|
| 453 |
self.top_k = config.num_experts_per_tok
|
| 454 |
self.norm_topk_prob = config.norm_topk_prob
|
|
@@ -481,6 +482,12 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 481 |
sorted_weights, sorted_indices = torch.sort(routing_probs, dim=-1, descending=True)
|
| 482 |
selected_experts = sorted_indices[:, ::self.n_step][:, :self.top_k]
|
| 483 |
routing_weights = routing_probs.gather(1, selected_experts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 484 |
else:
|
| 485 |
raise ValueError(f"Unknown routing type: {self.routing_type}")
|
| 486 |
|
|
|
|
| 447 |
|
| 448 |
|
| 449 |
class OlmoeSparseMoeBlock(nn.Module):
|
| 450 |
+
def __init__(self, config, layer_idx: int):
|
| 451 |
super().__init__()
|
| 452 |
+
self.layer_idx = layer_idx
|
| 453 |
self.num_experts = config.num_experts
|
| 454 |
self.top_k = config.num_experts_per_tok
|
| 455 |
self.norm_topk_prob = config.norm_topk_prob
|
|
|
|
| 482 |
sorted_weights, sorted_indices = torch.sort(routing_probs, dim=-1, descending=True)
|
| 483 |
selected_experts = sorted_indices[:, ::self.n_step][:, :self.top_k]
|
| 484 |
routing_weights = routing_probs.gather(1, selected_experts)
|
| 485 |
+
elif self.routing_type == "depthconstant":
|
| 486 |
+
effective_top_k = max(1, self.top_k - (self.layer_idx // 2))
|
| 487 |
+
routing_weights, selected_experts = torch.topk(routing_probs, effective_top_k, dim=-1)
|
| 488 |
+
elif self.routing_type == "depthlatter":
|
| 489 |
+
effective_top_k = self.top_k if self.layer_idx < 8 else max(1, self.top_k + 8 - self.layer_idx)
|
| 490 |
+
routing_weights, selected_experts = torch.topk(routing_probs, effective_top_k, dim=-1)
|
| 491 |
else:
|
| 492 |
raise ValueError(f"Unknown routing type: {self.routing_type}")
|
| 493 |
|