multinomial8
Browse files- myolmoe/config.json +1 -1
- myolmoe/modeling_myolmoe.py +3 -3
myolmoe/config.json — CHANGED (+1 −1)

@@ -15,7 +15,7 @@
   "norm_topk_prob": false,
   "num_attention_heads": 16,
   "num_experts": 64,
-  "num_experts_per_tok":
+  "num_experts_per_tok": 8,
   "num_hidden_layers": 16,
   "num_key_value_heads": 16,
   "output_router_logits": false,

NOTE(review): the removed line's old value did not survive extraction — only the key `"num_experts_per_tok":` is visible; confirm the previous value against the repository history.
myolmoe/modeling_myolmoe.py — CHANGED (+3 −3)

@@ -462,9 +462,9 @@ class OlmoeSparseMoeBlock(nn.Module):
         hidden_states = hidden_states.view(-1, hidden_dim)
         router_logits = self.gate(hidden_states)
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        (three removed lines — their content was lost in extraction)
-
-
+        selected_experts = torch.multinomial(routing_weights, self.top_k, replacement=False)
+        routing_weights = routing_weights.gather(1, selected_experts)
+
         if self.norm_topk_prob:
             routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         routing_weights = routing_weights.to(hidden_states.dtype)

NOTE(review): the three removed lines did not survive extraction (only bare `-` markers remain); presumably they performed the previous expert-selection step that the added `torch.multinomial` sampling replaces — verify against the repository history.