Charlie81 commited on
Commit
3aa53b4
·
1 Parent(s): 4d16af6

set to multinomial top k

Browse files
Files changed (1) hide show
  1. myolmoe/modeling_myolmoe.py +5 -5
myolmoe/modeling_myolmoe.py CHANGED
@@ -462,11 +462,11 @@ class OlmoeSparseMoeBlock(nn.Module):
462
  hidden_states = hidden_states.view(-1, hidden_dim)
463
  router_logits = self.gate(hidden_states)
464
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
465
- routing_weights, selected_experts = torch.topk(
466
- routing_weights, self.top_k, dim=-1
467
- )
468
- # selected_experts = torch.multinomial(routing_weights, self.top_k, replacement=False)
469
- # routing_weights = routing_weights.gather(1, selected_experts)
470
 
471
  if self.norm_topk_prob:
472
  routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
 
462
  hidden_states = hidden_states.view(-1, hidden_dim)
463
  router_logits = self.gate(hidden_states)
464
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
465
+ # routing_weights, selected_experts = torch.topk(
466
+ # routing_weights, self.top_k, dim=-1
467
+ # )
468
+ selected_experts = torch.multinomial(routing_weights, self.top_k, replacement=False)
469
+ routing_weights = routing_weights.gather(1, selected_experts)
470
 
471
  if self.norm_topk_prob:
472
  routing_weights /= routing_weights.sum(dim=-1, keepdim=True)