if001 commited on
Commit
daf451b
·
1 Parent(s): d799841

add v2 gready

Browse files
Files changed (1) hide show
  1. modeling_deepseek.py +31 -2
modeling_deepseek.py CHANGED
@@ -398,7 +398,8 @@ class MoEGate(nn.Module):
398
  self.n_routed_experts = config.n_routed_experts
399
  self.routed_scaling_factor = config.routed_scaling_factor
400
  self.scoring_func = config.scoring_func
401
- self.topk_method = config.topk_method
 
402
  self.n_group = config.n_group
403
  self.topk_group = config.topk_group
404
 
@@ -459,6 +460,14 @@ class MoEGate(nn.Module):
459
  tmp_scores, k=self.top_k, dim=-1, sorted=False
460
  )
461
  topk_weight = scores.gather(1, topk_idx)
 
 
 
 
 
 
 
 
462
  else:
463
  raise NotImplementedError(
464
  f"insupportable TopK function for MoE gating: {self.topk_method}"
@@ -528,11 +537,31 @@ class DeepseekV3MoE(nn.Module):
528
  if not self.training:
529
  y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
530
  else:
531
- y = self.moe_train(hidden_states, topk_idx, topk_weight).view(*orig_shape)
 
532
  if self.config.n_shared_experts is not None:
533
  y = y + self.shared_experts(identity)
534
  return y
535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
536
  def moe_train(self, x, topk_ids, topk_weight):
537
  """
538
  self.ep_size = 1の想定
 
398
  self.n_routed_experts = config.n_routed_experts
399
  self.routed_scaling_factor = config.routed_scaling_factor
400
  self.scoring_func = config.scoring_func
401
+ # self.topk_method = config.topk_method
402
+ self.topk_method = "gready"
403
  self.n_group = config.n_group
404
  self.topk_group = config.topk_group
405
 
 
460
  tmp_scores, k=self.top_k, dim=-1, sorted=False
461
  )
462
  topk_weight = scores.gather(1, topk_idx)
463
+ elif self.topk_method == "gready":
464
+ """
465
+ impl from deepseek v2
466
+ https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py#L435
467
+ """
468
+ topk_weight, topk_idx = torch.topk(
469
+ scores, k=self.top_k, dim=-1, sorted=False
470
+ )
471
  else:
472
  raise NotImplementedError(
473
  f"insupportable TopK function for MoE gating: {self.topk_method}"
 
537
  if not self.training:
538
  y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
539
  else:
540
+ # y = self.moe_train(hidden_states, topk_idx, topk_weight).view(*orig_shape)
541
+ y = self.moe_train_v2(hidden_states, topk_idx, topk_weight).view(*orig_shape)
542
  if self.config.n_shared_experts is not None:
543
  y = y + self.shared_experts(identity)
544
  return y
545
 
546
def moe_train_v2(self, hidden_states, topk_idx, topk_weight):
    """Training-time MoE dispatch (dense loop over experts), ported from DeepSeek-V2.

    Reference: https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py#L566

    Args:
        hidden_states: (num_tokens, hidden_dim) token activations.
        topk_idx: (num_tokens, top_k) integer expert index per token slot.
        topk_weight: (num_tokens, top_k) gating weight per token slot.

    Returns:
        (num_tokens, hidden_dim) tensor: gating-weighted sum of expert outputs,
        cast back to the input dtype.

    NOTE(review): assumes all experts are local (ep_size == 1) — confirm.
    """
    flat_expert_ids = topk_idx.view(-1)
    # One row per (token, expert-slot) pair so each copy routes independently.
    hidden_states = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0)
    y = torch.empty_like(hidden_states)
    for expert_id, expert in enumerate(self.experts):
        # Hoist the mask: it is needed for both the gather and the scatter.
        mask = flat_expert_ids == expert_id
        if mask.any():  # skip experts that received no tokens this step
            y[mask] = expert(hidden_states[mask])
    # Weight each slot's output by its gating score, then sum over the top-k slots.
    y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
    # The weighting may upcast (gate weights are typically fp32); restore input dtype.
    return y.to(hidden_states.dtype)
564
+
565
  def moe_train(self, x, topk_ids, topk_weight):
566
  """
567
  self.ep_size = 1の想定