Files changed (1) hide show
  1. modeling_deepseek.py +4 -0
modeling_deepseek.py CHANGED
@@ -461,6 +461,10 @@ class MoEGate(nn.Module):
461
  tmp_scores, k=self.top_k, dim=-1, sorted=False
462
  )
463
  topk_weight = scores.gather(1, topk_idx)
 
 
 
 
464
  else:
465
  raise NotImplementedError(
466
  f"insupportable TopK function for MoE gating: {self.topk_method}"
 
461
  tmp_scores, k=self.top_k, dim=-1, sorted=False
462
  )
463
  topk_weight = scores.gather(1, topk_idx)
464
+ elif self.topk_method == "greedy":
465
+ topk_weight, topk_idx = torch.topk(
466
+ scores, k=self.top_k, dim=-1, sorted=False
467
+ )
468
  else:
469
  raise NotImplementedError(
470
  f"insupportable TopK function for MoE gating: {self.topk_method}"