Update modeling_deepseek.py
Browse files — modeling_deepseek.py (+86 −11)
modeling_deepseek.py
CHANGED
|
@@ -521,12 +521,24 @@ class AddAuxiliaryLoss(torch.autograd.Function):
|
|
| 521 |
class DeepseekV2MoE(nn.Module):
|
| 522 |
"""
|
| 523 |
A mixed expert module containing shared experts.
|
|
|
|
| 524 |
"""
|
| 525 |
|
| 526 |
-
def __init__(self, config):
|
| 527 |
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 528 |
self.config = config
|
| 529 |
self.num_experts_per_tok = config.num_experts_per_tok
|
|
|
|
|
|
|
|
|
|
|
|
|
| 530 |
|
| 531 |
if hasattr(config, "ep_size") and config.ep_size > 1:
|
| 532 |
assert config.ep_size == dist.get_world_size()
|
|
@@ -565,24 +577,87 @@ class DeepseekV2MoE(nn.Module):
|
|
| 565 |
config=config, intermediate_size=intermediate_size
|
| 566 |
)
|
| 567 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 568 |
def forward(self, hidden_states):
|
| 569 |
identity = hidden_states
|
| 570 |
orig_shape = hidden_states.shape
|
|
|
|
|
|
|
| 571 |
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
|
| 572 |
-
|
| 573 |
-
|
|
|
|
| 574 |
if self.training:
|
| 575 |
-
|
| 576 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
)
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 583 |
y = AddAuxiliaryLoss.apply(y, aux_loss)
|
| 584 |
else:
|
| 585 |
-
|
|
|
|
|
|
|
|
|
|
| 586 |
if self.config.n_shared_experts is not None:
|
| 587 |
y = y + self.shared_experts(identity)
|
| 588 |
return y
|
|
|
|
| 521 |
class DeepseekV2MoE(nn.Module):
|
| 522 |
"""
|
| 523 |
A mixed expert module containing shared experts.
|
| 524 |
+
Modified to use Default MoE for dense backpropagation.
|
| 525 |
"""
|
| 526 |
|
| 527 |
+
def __init__(self, config, beta=0.9):
|
| 528 |
super().__init__()
|
| 529 |
+
print("=" * 80)
|
| 530 |
+
print("初始化 Default MoE 版本的 DeepseekV2MoE")
|
| 531 |
+
print(f" - 路由专家数量: {config.n_routed_experts}")
|
| 532 |
+
print(f" - Top-K: {config.num_experts_per_tok}")
|
| 533 |
+
print(f" - EMA beta: {beta}")
|
| 534 |
+
print("=" * 80)
|
| 535 |
+
|
| 536 |
self.config = config
|
| 537 |
self.num_experts_per_tok = config.num_experts_per_tok
|
| 538 |
+
self.n_routed_experts = config.n_routed_experts
|
| 539 |
+
|
| 540 |
+
# Default MoE: EMA 参数
|
| 541 |
+
self.beta = beta
|
| 542 |
|
| 543 |
if hasattr(config, "ep_size") and config.ep_size > 1:
|
| 544 |
assert config.ep_size == dist.get_world_size()
|
|
|
|
| 577 |
config=config, intermediate_size=intermediate_size
|
| 578 |
)
|
| 579 |
|
| 580 |
+
# Default MoE: 为每个路由专家注册 default vector
|
| 581 |
+
# persistent=False: 不保存到 checkpoint,避免兼容性问题
|
| 582 |
+
for expert_idx in range(config.n_routed_experts):
|
| 583 |
+
self.register_buffer(
|
| 584 |
+
f'default_vector_{expert_idx}',
|
| 585 |
+
torch.zeros(config.hidden_size),
|
| 586 |
+
persistent=False
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
def forward(self, hidden_states):
    """Route tokens through the MoE layer.

    Training uses a "Default MoE" dense-backprop scheme: tokens routed to an
    expert receive the real expert output (weighted by the gate's normalized
    top-k weights), while every other token receives that expert's "default
    vector" — an EMA of the expert's recent mean output — weighted by the
    full-softmax router probability.  Because the softmax weight on the
    default vector is differentiable, the router receives gradient for *all*
    experts, not just the top-k ones.  Inference uses the original efficient
    sparse `moe_infer` path unchanged.

    Args:
        hidden_states: input activations whose last dimension is
            ``hidden_size`` (typically ``(batch, seq_len, hidden_size)``).

    Returns:
        Tensor with the same shape as ``hidden_states``.
    """
    identity = hidden_states
    orig_shape = hidden_states.shape
    # Only the last dim is assumed to be hidden_size; the previous version
    # unpacked a strict 3-D (bsz, seq, hidden) shape, which needlessly broke
    # non-3-D callers and bound two unused locals.
    hidden_dim = hidden_states.shape[-1]

    topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
    flat_hidden = hidden_states.view(-1, hidden_dim)
    n_tokens = flat_hidden.size(0)

    if self.training:
        # ========== Default MoE training path ==========
        dtype = hidden_states.dtype
        device = hidden_states.device

        # Recompute the full routing softmax over *all* experts (in fp32 for
        # numerical stability) so non-activated experts also get a gradient
        # signal through the default-vector term below.
        # NOTE(review): assumes self.gate exposes a `weight` matrix of shape
        # (n_routed_experts, hidden_size) — true for the stock MoEGate.
        router_logits = F.linear(
            flat_hidden.type(torch.float32),
            self.gate.weight.type(torch.float32),
            None,
        )
        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
        routing_weights = routing_weights.to(dtype=dtype)

        final_output = torch.zeros((n_tokens, hidden_dim), dtype=dtype, device=device)

        # Iterate over every routed expert.
        for expert_idx in range(self.n_routed_experts):
            expert_layer = self.experts[expert_idx]
            # fp32 EMA buffer for this expert (registered in __init__,
            # persistent=False so checkpoints stay compatible).
            default_buf = getattr(self, f'default_vector_{expert_idx}')

            # Tokens whose top-k selection includes this expert.
            matches = (topk_idx == expert_idx)   # (n_tokens, top_k)
            is_activated = matches.any(dim=1)    # (n_tokens,)

            if is_activated.any():
                # ----- activated tokens: real expert output -----
                activated_token_indices = torch.where(is_activated)[0]
                real_expert_output = expert_layer(
                    flat_hidden[activated_token_indices]
                ).to(dtype=dtype)

                # ----- EMA update of the default vector (train only) -----
                # Accumulate in the buffer's own fp32 precision; the previous
                # version down-cast the buffer to the model dtype first, so in
                # bf16/fp16 the running mean slowly degraded.
                with torch.no_grad():
                    mean_output = real_expert_output.float().mean(dim=0)
                    default_buf.mul_(self.beta).add_(mean_output, alpha=1 - self.beta)

                # ----- accumulate real outputs (normalized top-k weights) -----
                # torch.where(matches) enumerates activated tokens in the same
                # ascending order as activated_token_indices (each token hits
                # a given expert at most once in top-k), so rows line up with
                # real_expert_output.
                token_indices, k_indices = torch.where(matches)
                weights = topk_weight[token_indices, k_indices, None]
                final_output.index_add_(
                    0, token_indices, (real_expert_output * weights).to(dtype)
                )

            # ----- non-activated tokens: default vector, softmax-weighted -----
            non_activated_indices = torch.where(~is_activated)[0]
            if len(non_activated_indices) > 0:
                w_default = routing_weights[non_activated_indices, expert_idx].unsqueeze(-1)
                final_output[non_activated_indices] += w_default * default_buf.to(dtype=dtype)

        y = final_output.view(*orig_shape)
        y = AddAuxiliaryLoss.apply(y, aux_loss)
    else:
        # ========== inference: original efficient sparse dispatch ==========
        y = self.moe_infer(flat_hidden, topk_idx, topk_weight).view(*orig_shape)

    # Shared experts are always applied on top of the routed output.
    if self.config.n_shared_experts is not None:
        y = y + self.shared_experts(identity)
    return y
|