Charlie81 committed on
Commit
4d16af6
·
1 Parent(s): 858f5a5

change back to topk

Browse files
Files changed (1) hide show
  1. myolmoe/modeling_myolmoe.py +5 -2
myolmoe/modeling_myolmoe.py CHANGED
@@ -462,8 +462,11 @@ class OlmoeSparseMoeBlock(nn.Module):
462
  hidden_states = hidden_states.view(-1, hidden_dim)
463
  router_logits = self.gate(hidden_states)
464
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
465
- selected_experts = torch.multinomial(routing_weights, self.top_k, replacement=False)
466
- routing_weights = routing_weights.gather(1, selected_experts)
 
 
 
467
 
468
  if self.norm_topk_prob:
469
  routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
 
462
  hidden_states = hidden_states.view(-1, hidden_dim)
463
  router_logits = self.gate(hidden_states)
464
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
465
+ routing_weights, selected_experts = torch.topk(
466
+ routing_weights, self.top_k, dim=-1
467
+ )
468
+ # selected_experts = torch.multinomial(routing_weights, self.top_k, replacement=False)
469
+ # routing_weights = routing_weights.gather(1, selected_experts)
470
 
471
  if self.norm_topk_prob:
472
  routing_weights /= routing_weights.sum(dim=-1, keepdim=True)