autoprogrammer committed on
Commit
c9da69f
·
verified ·
1 Parent(s): 4c90484

Update modeling_deepseek.py

Browse files
Files changed (1) hide show
  1. modeling_deepseek.py +86 -11
modeling_deepseek.py CHANGED
@@ -521,12 +521,24 @@ class AddAuxiliaryLoss(torch.autograd.Function):
521
  class DeepseekV2MoE(nn.Module):
522
  """
523
  A mixed expert module containing shared experts.
 
524
  """
525
 
526
- def __init__(self, config):
527
  super().__init__()
 
 
 
 
 
 
 
528
  self.config = config
529
  self.num_experts_per_tok = config.num_experts_per_tok
 
 
 
 
530
 
531
  if hasattr(config, "ep_size") and config.ep_size > 1:
532
  assert config.ep_size == dist.get_world_size()
@@ -565,24 +577,87 @@ class DeepseekV2MoE(nn.Module):
565
  config=config, intermediate_size=intermediate_size
566
  )
567
 
 
 
 
 
 
 
 
 
 
568
def forward(self, hidden_states):
    """Route tokens through the top-k experts and add the shared-expert path.

    Reconstructed from a garbled diff rendering; logic matches the visible
    pre-change DeepseekV2MoE.forward.

    Args:
        hidden_states: input activations; reshaped back to ``orig_shape`` at
            the end, so any leading batch/sequence layout is preserved.
            (Assumed ``(batch, seq, hidden)`` — confirm against callers.)

    Returns:
        Tensor of the same shape as ``hidden_states``: the weighted sum of
        the selected experts' outputs, plus the shared experts' output when
        ``config.n_shared_experts`` is set. In training mode the gate's
        auxiliary loss is attached via ``AddAuxiliaryLoss``.
    """
    identity = hidden_states
    orig_shape = hidden_states.shape
    topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
    # Flatten to (num_tokens, hidden) so experts can be applied per token.
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    flat_topk_idx = topk_idx.view(-1)
    if self.training:
        # Duplicate each token once per selected expert so the flat index
        # mask below lines up with flat_topk_idx.
        hidden_states = hidden_states.repeat_interleave(
            self.num_experts_per_tok, dim=0
        )
        y = torch.empty_like(hidden_states)
        for i, expert in enumerate(self.experts):
            # Each expert processes only the (token, slot) pairs routed to it.
            y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
        # Weight each expert output by its gate weight and sum over the k slots.
        y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
        y = y.to(hidden_states.dtype).view(*orig_shape)
        # Attach the load-balancing auxiliary loss to the autograd graph.
        y = AddAuxiliaryLoss.apply(y, aux_loss)
    else:
        # Inference uses the optimized dispatch implemented in moe_infer.
        y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
    if self.config.n_shared_experts is not None:
        # Shared experts always see the original (pre-routing) activations.
        y = y + self.shared_experts(identity)
    return y
 
521
  class DeepseekV2MoE(nn.Module):
522
  """
523
  A mixed expert module containing shared experts.
524
+ Modified to use Default MoE for dense backpropagation.
525
  """
526
 
527
+ def __init__(self, config, beta=0.9):
528
  super().__init__()
529
+ print("=" * 80)
530
+ print("初始化 Default MoE 版本的 DeepseekV2MoE")
531
+ print(f" - 路由专家数量: {config.n_routed_experts}")
532
+ print(f" - Top-K: {config.num_experts_per_tok}")
533
+ print(f" - EMA beta: {beta}")
534
+ print("=" * 80)
535
+
536
  self.config = config
537
  self.num_experts_per_tok = config.num_experts_per_tok
538
+ self.n_routed_experts = config.n_routed_experts
539
+
540
+ # Default MoE: EMA 参数
541
+ self.beta = beta
542
 
543
  if hasattr(config, "ep_size") and config.ep_size > 1:
544
  assert config.ep_size == dist.get_world_size()
 
577
  config=config, intermediate_size=intermediate_size
578
  )
579
 
580
+ # Default MoE: 为每个路由专家注册 default vector
581
+ # persistent=False: 不保存到 checkpoint,避免兼容性问题
582
+ for expert_idx in range(config.n_routed_experts):
583
+ self.register_buffer(
584
+ f'default_vector_{expert_idx}',
585
+ torch.zeros(config.hidden_size),
586
+ persistent=False
587
+ )
588
+
589
def forward(self, hidden_states):
    """Forward pass of the "Default MoE" variant (dense backpropagation).

    Reconstructed from a garbled diff rendering. Training mode replaces the
    sparse dispatch with a dense scheme: activated experts contribute their
    real outputs (weighted by the normalized top-k gate weights), while every
    non-activated expert contributes a per-expert EMA "default vector"
    weighted by the full softmax routing probability — so the router receives
    gradient signal for all experts, not only the top-k. Inference keeps the
    original efficient ``moe_infer`` path.

    NOTE(review): the source diff lost indentation; the default-vector
    accumulation is placed at expert-loop level (outside the activation
    branch) so that experts with zero activated tokens still contribute
    their default vector — confirm against the upstream file.

    Args:
        hidden_states: activations, assumed ``(batch, seq, hidden)`` — the
            3-way unpack below requires rank 3.

    Returns:
        Tensor shaped like ``hidden_states``; auxiliary loss attached via
        ``AddAuxiliaryLoss`` in training mode; shared-expert output added
        when ``config.n_shared_experts`` is set.
    """
    identity = hidden_states
    orig_shape = hidden_states.shape
    bsz, seq_len, hidden_dim = hidden_states.shape
    topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
    flat_hidden = hidden_states.view(-1, hidden_dim)
    N_tokens = flat_hidden.size(0)

    if self.training:
        # ========== Default MoE training logic ==========
        dtype = hidden_states.dtype
        device = hidden_states.device

        # Recompute the full routing distribution (softmax over ALL experts)
        # so non-activated experts also receive a gradient signal.
        router_logits = F.linear(
            flat_hidden.type(torch.float32),
            self.gate.weight.type(torch.float32),
            None,
        )
        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
        routing_weights = routing_weights.to(dtype=dtype)

        final_output = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)

        # Iterate over every routed expert.
        for expert_idx in range(self.n_routed_experts):
            expert_layer = self.experts[expert_idx]

            # This expert's EMA default vector (registered buffer).
            default_vector = getattr(self, f'default_vector_{expert_idx}').to(dtype=dtype)

            # Which tokens selected this expert in their top-k.
            matches = (topk_idx == expert_idx)   # (N_tokens, top_k)
            is_activated = matches.any(dim=1)    # (N_tokens,)

            if is_activated.any():
                # ===== Activated tokens: real expert output =====
                activated_token_indices = torch.where(is_activated)[0]
                activated_inputs = flat_hidden[activated_token_indices]
                real_expert_output = expert_layer(activated_inputs).to(dtype=dtype)

                # ===== EMA update of the default vector (training only) =====
                mean_output = real_expert_output.mean(dim=0).detach()
                new_default = self.beta * default_vector + (1 - self.beta) * mean_output
                getattr(self, f'default_vector_{expert_idx}').copy_(new_default)

                # ===== Accumulate real outputs with normalized top-k weights =====
                # Top-k indices are distinct within a token's row, so
                # torch.where(matches) yields exactly one (token, slot) pair per
                # activated token, in the same order as activated_token_indices.
                token_indices, k_indices = torch.where(matches)
                if len(token_indices) > 0:
                    weights = topk_weight[token_indices, k_indices, None]
                    weighted_output = real_expert_output * weights
                    final_output.index_add_(0, token_indices, weighted_output.to(dtype))

            # ===== Non-activated tokens: add the default vector, weighted by
            # the raw softmax probability (not the normalized top-k weight) =====
            non_activated_indices = torch.where(~is_activated)[0]
            if len(non_activated_indices) > 0:
                weights_non_activated = routing_weights[non_activated_indices, expert_idx].unsqueeze(-1)
                final_output[non_activated_indices] += weights_non_activated * default_vector

        y = final_output.view(*orig_shape)
        y = AddAuxiliaryLoss.apply(y, aux_loss)
    else:
        # ========== Inference: original efficient dispatch ==========
        y = self.moe_infer(flat_hidden, topk_idx, topk_weight).view(*orig_shape)

    # Shared experts always see the original (pre-routing) activations.
    if self.config.n_shared_experts is not None:
        y = y + self.shared_experts(identity)
    return y