Charlie81 committed on
Commit
67ed347
·
1 Parent(s): 4803c83

multinomial8

Browse files
myolmoe/config.json CHANGED
@@ -15,7 +15,7 @@
15
  "norm_topk_prob": false,
16
  "num_attention_heads": 16,
17
  "num_experts": 64,
18
- "num_experts_per_tok": 1,
19
  "num_hidden_layers": 16,
20
  "num_key_value_heads": 16,
21
  "output_router_logits": false,
 
15
  "norm_topk_prob": false,
16
  "num_attention_heads": 16,
17
  "num_experts": 64,
18
+ "num_experts_per_tok": 8,
19
  "num_hidden_layers": 16,
20
  "num_key_value_heads": 16,
21
  "output_router_logits": false,
myolmoe/modeling_myolmoe.py CHANGED
@@ -462,9 +462,9 @@ class OlmoeSparseMoeBlock(nn.Module):
462
  hidden_states = hidden_states.view(-1, hidden_dim)
463
  router_logits = self.gate(hidden_states)
464
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
465
- routing_weights, selected_experts = torch.topk(
466
- routing_weights, self.top_k, dim=-1
467
- )
468
  if self.norm_topk_prob:
469
  routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
470
  routing_weights = routing_weights.to(hidden_states.dtype)
 
462
  hidden_states = hidden_states.view(-1, hidden_dim)
463
  router_logits = self.gate(hidden_states)
464
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
465
+ selected_experts = torch.multinomial(routing_weights, self.top_k, replacement=False)
466
+ routing_weights = routing_weights.gather(1, selected_experts)
467
+
468
  if self.norm_topk_prob:
469
  routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
470
  routing_weights = routing_weights.to(hidden_states.dtype)