Commit fe340b5
Parent(s): 1805272

use torchtitan moe impl

modeling_deepseek.py  +15 -1
@@ -59,6 +59,8 @@ from .configuration_deepseek import DeepseekV3Config
 import torch.distributed as dist
 import numpy as np
 
+from torchtitan.models.moe import MoE, MoEArgs
+
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
@@ -1150,8 +1152,20 @@ class DeepseekV3DecoderLayer(nn.Module):
             config=config, layer_idx=layer_idx
         )
 
+        moe_args = MoEArgs(
+            num_experts=config.n_routed_experts,
+            num_shared_experts=config.n_shared_experts,
+            score_func=config.scoring_func,
+            route_norm=config.norm_topk_prob,
+            route_scale=config.routed_scaling_factor,
+            score_before_experts=False,
+            top_k=config.num_experts_per_tok,
+            use_grouped_mm=True,
+            load_balance_coeff=1e-3,
+        )
+
         self.mlp = (
-            DeepseekV3MoE(config)
+            MoE(moe_args, dim=config.hidden_size, hidden_dim=config.moe_intermediate_size)
             if (
                 config.n_routed_experts is not None
                 and layer_idx >= config.first_k_dense_replace
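For context, here is a minimal standalone sketch of what the second hunk wires up, using only the names visible in the diff. The concrete values are hypothetical stand-ins for the DeepseekV3Config fields the commit reads; that MoE's forward takes a (batch, seq, hidden) tensor and returns the same shape, and that use_grouped_mm=True may need bf16 inputs on CUDA, are my assumptions, not anything the commit states.

import torch
from torchtitan.models.moe import MoE, MoEArgs

# Hypothetical stand-ins for the DeepseekV3Config fields read in the diff.
hidden_size = 1024           # config.hidden_size
moe_intermediate_size = 512  # config.moe_intermediate_size

moe_args = MoEArgs(
    num_experts=8,           # config.n_routed_experts
    num_shared_experts=1,    # config.n_shared_experts
    score_func="sigmoid",    # config.scoring_func
    route_norm=True,         # config.norm_topk_prob
    route_scale=2.5,         # config.routed_scaling_factor
    score_before_experts=False,
    top_k=2,                 # config.num_experts_per_tok
    use_grouped_mm=False,    # the diff sets True; grouped mm likely needs bf16 on CUDA (assumption)
    load_balance_coeff=1e-3,
)
mlp = MoE(moe_args, dim=hidden_size, hidden_dim=moe_intermediate_size)

# Assumed contract: one (batch, seq, hidden) tensor in, same shape out,
# matching the dense DeepseekV3MLP branch it replaces in the decoder layer.
hidden_states = torch.randn(2, 16, hidden_size)
out = mlp(hidden_states)
print(out.shape)  # expected: torch.Size([2, 16, 1024])

Worth noting: score_before_experts=False, use_grouped_mm=True, and load_balance_coeff=1e-3 are hard-coded in the commit rather than read from the checkpoint's config, so they apply to every MoE layer this branch constructs.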