add v2 gready
Browse files- modeling_deepseek.py +31 -2
modeling_deepseek.py
CHANGED
|
@@ -398,7 +398,8 @@ class MoEGate(nn.Module):
|
|
| 398 |
self.n_routed_experts = config.n_routed_experts
|
| 399 |
self.routed_scaling_factor = config.routed_scaling_factor
|
| 400 |
self.scoring_func = config.scoring_func
|
| 401 |
-
self.topk_method = config.topk_method
|
|
|
|
| 402 |
self.n_group = config.n_group
|
| 403 |
self.topk_group = config.topk_group
|
| 404 |
|
|
@@ -459,6 +460,14 @@ class MoEGate(nn.Module):
|
|
| 459 |
tmp_scores, k=self.top_k, dim=-1, sorted=False
|
| 460 |
)
|
| 461 |
topk_weight = scores.gather(1, topk_idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 462 |
else:
|
| 463 |
raise NotImplementedError(
|
| 464 |
f"insupportable TopK function for MoE gating: {self.topk_method}"
|
|
@@ -528,11 +537,31 @@ class DeepseekV3MoE(nn.Module):
|
|
| 528 |
if not self.training:
|
| 529 |
y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
|
| 530 |
else:
|
| 531 |
-
y = self.moe_train(hidden_states, topk_idx, topk_weight).view(*orig_shape)
|
|
|
|
| 532 |
if self.config.n_shared_experts is not None:
|
| 533 |
y = y + self.shared_experts(identity)
|
| 534 |
return y
|
| 535 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
def moe_train(self, x, topk_ids, topk_weight):
|
| 537 |
"""
|
| 538 |
self.ep_size = 1の想定
|
|
|
|
| 398 |
self.n_routed_experts = config.n_routed_experts
|
| 399 |
self.routed_scaling_factor = config.routed_scaling_factor
|
| 400 |
self.scoring_func = config.scoring_func
|
| 401 |
+
# self.topk_method = config.topk_method
|
| 402 |
+
self.topk_method == "gready"
|
| 403 |
self.n_group = config.n_group
|
| 404 |
self.topk_group = config.topk_group
|
| 405 |
|
|
|
|
| 460 |
tmp_scores, k=self.top_k, dim=-1, sorted=False
|
| 461 |
)
|
| 462 |
topk_weight = scores.gather(1, topk_idx)
|
| 463 |
+
elif self.topk_method == "gready":
|
| 464 |
+
"""
|
| 465 |
+
impl from deepseek v2
|
| 466 |
+
https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py#L435
|
| 467 |
+
"""
|
| 468 |
+
topk_weight, topk_idx = torch.topk(
|
| 469 |
+
scores, k=self.top_k, dim=-1, sorted=False
|
| 470 |
+
)
|
| 471 |
else:
|
| 472 |
raise NotImplementedError(
|
| 473 |
f"insupportable TopK function for MoE gating: {self.topk_method}"
|
|
|
|
| 537 |
if not self.training:
|
| 538 |
y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
|
| 539 |
else:
|
| 540 |
+
# y = self.moe_train(hidden_states, topk_idx, topk_weight).view(*orig_shape)
|
| 541 |
+
y = self.moe_train_v2(hidden_states, topk_idx, topk_weight).view(*orig_shape)
|
| 542 |
if self.config.n_shared_experts is not None:
|
| 543 |
y = y + self.shared_experts(identity)
|
| 544 |
return y
|
| 545 |
|
| 546 |
+
def moe_train_v2(self, hidden_states, topk_idx, topk_weight):
|
| 547 |
+
"""
|
| 548 |
+
impl from deepseek v2
|
| 549 |
+
https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py#L566
|
| 550 |
+
"""
|
| 551 |
+
flat_topk_idx = topk_idx.view(-1)
|
| 552 |
+
|
| 553 |
+
hidden_states = hidden_states.repeat_interleave(
|
| 554 |
+
self.num_experts_per_tok, dim=0
|
| 555 |
+
)
|
| 556 |
+
y = torch.empty_like(hidden_states)
|
| 557 |
+
for i, expert in enumerate(self.experts):
|
| 558 |
+
y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
|
| 559 |
+
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
| 560 |
+
y = y.type(hidden_states.dtype)
|
| 561 |
+
# y = y.view(*orig_shape)
|
| 562 |
+
# y = AddAuxiliaryLoss.apply(y, aux_loss)
|
| 563 |
+
return y
|
| 564 |
+
|
| 565 |
def moe_train(self, x, topk_ids, topk_weight):
|
| 566 |
"""
|
| 567 |
self.ep_size = 1の想定
|