corrected depth routing
Browse files
myolmoe/modeling_myolmoe.py
CHANGED
|
@@ -483,11 +483,18 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 483 |
selected_experts = sorted_indices[:, ::self.n_step][:, :self.top_k]
|
| 484 |
routing_weights = routing_probs.gather(1, selected_experts)
|
| 485 |
elif self.routing_type == "depthconstant":
|
| 486 |
-
|
|
|
|
|
|
|
| 487 |
routing_weights, selected_experts = torch.topk(routing_probs, effective_top_k, dim=-1)
|
| 488 |
elif self.routing_type == "depthlatter":
|
| 489 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
routing_weights, selected_experts = torch.topk(routing_probs, effective_top_k, dim=-1)
|
|
|
|
| 491 |
else:
|
| 492 |
raise ValueError(f"Unknown routing type: {self.routing_type}")
|
| 493 |
|
|
|
|
| 483 |
selected_experts = sorted_indices[:, ::self.n_step][:, :self.top_k]
|
| 484 |
routing_weights = routing_probs.gather(1, selected_experts)
|
| 485 |
elif self.routing_type == "depthconstant":
|
| 486 |
+
# Assumes there are 16 layers
|
| 487 |
+
slope = (self.top_k - 1) / 15
|
| 488 |
+
effective_top_k = max(1, round(self.top_k - self.layer_idx * slope))
|
| 489 |
routing_weights, selected_experts = torch.topk(routing_probs, effective_top_k, dim=-1)
|
| 490 |
elif self.routing_type == "depthlatter":
|
| 491 |
+
if self.layer_idx < 8:
|
| 492 |
+
effective_top_k = self.top_k
|
| 493 |
+
else:
|
| 494 |
+
slope = (self.top_k - 1) / 7
|
| 495 |
+
effective_top_k = max(1, round(self.top_k - (self.layer_idx - 8) * slope))
|
| 496 |
routing_weights, selected_experts = torch.topk(routing_probs, effective_top_k, dim=-1)
|
| 497 |
+
|
| 498 |
else:
|
| 499 |
raise ValueError(f"Unknown routing type: {self.routing_type}")
|
| 500 |
|