if001 committed on
Commit
f08d3a2
·
verified ·
1 Parent(s): e47b14c
Files changed (1) hide show
  1. modeling_deepseek.py +2 -1
modeling_deepseek.py CHANGED
@@ -435,7 +435,7 @@ class MoEGate(nn.Module):
435
 
436
  ### select top-k experts
437
  if self.topk_method == "noaux_tc":
438
- assert not self.training
439
  scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
440
  group_scores = (
441
  scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim = -1)
@@ -537,6 +537,7 @@ class DeepseekV3MoE(nn.Module):
537
  """
538
  self.ep_size = 1の想定
539
  """
 
540
  cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
541
  cnts.scatter_(1, topk_ids, 1)
542
  tokens_per_expert = cnts.sum(dim=0)
 
435
 
436
  ### select top-k experts
437
  if self.topk_method == "noaux_tc":
438
+ # assert not self.training ## for lora training
439
  scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
440
  group_scores = (
441
  scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim = -1)
 
537
  """
538
  self.ep_size = 1の想定
539
  """
540
+ assert self.ep_size == 1
541
  cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
542
  cnts.scatter_(1, topk_ids, 1)
543
  tokens_per_expert = cnts.sum(dim=0)