Set expert selection to multinomial top-k sampling (replaces torch.topk)
Browse files
myolmoe/modeling_myolmoe.py
CHANGED
|
@@ -462,11 +462,11 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 462 |
hidden_states = hidden_states.view(-1, hidden_dim)
|
| 463 |
router_logits = self.gate(hidden_states)
|
| 464 |
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
| 465 |
-
routing_weights, selected_experts = torch.topk(
|
| 466 |
-
routing_weights, self.top_k, dim=-1
|
| 467 |
-
)
|
| 468 |
-
|
| 469 |
-
|
| 470 |
|
| 471 |
if self.norm_topk_prob:
|
| 472 |
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
|
|
|
| 462 |
hidden_states = hidden_states.view(-1, hidden_dim)
|
| 463 |
router_logits = self.gate(hidden_states)
|
| 464 |
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
| 465 |
+
# routing_weights, selected_experts = torch.topk(
|
| 466 |
+
# routing_weights, self.top_k, dim=-1
|
| 467 |
+
# )
|
| 468 |
+
selected_experts = torch.multinomial(routing_weights, self.top_k, replacement=False)
|
| 469 |
+
routing_weights = routing_weights.gather(1, selected_experts)
|
| 470 |
|
| 471 |
if self.norm_topk_prob:
|
| 472 |
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|