Upload modeling_hunyuan.py
Browse files- modeling_hunyuan.py +303 -144
modeling_hunyuan.py
CHANGED
|
@@ -1,5 +1,16 @@
|
|
| 1 |
-
#
|
| 2 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
#
|
| 4 |
""" PyTorch HunYuan model."""
|
| 5 |
|
|
@@ -8,7 +19,6 @@ import warnings
|
|
| 8 |
from typing import List, Optional, Tuple, Union
|
| 9 |
|
| 10 |
import torch
|
| 11 |
-
torch.set_default_dtype(torch.float32)
|
| 12 |
from torch import Tensor
|
| 13 |
import torch.nn.functional as F
|
| 14 |
import torch.utils.checkpoint
|
|
@@ -64,7 +74,8 @@ _CONFIG_FOR_DOC = "HunYuanConfig"
|
|
| 64 |
def topkgating(logits: Tensor, topk: int):
|
| 65 |
logits = logits.float()
|
| 66 |
gates = F.softmax(logits, dim=1)
|
| 67 |
-
expert_capacity = topk * gates.shape[0]
|
|
|
|
| 68 |
num_experts = int(gates.shape[1])
|
| 69 |
# Top-k router probability and corresponding expert indices for each token.
|
| 70 |
# Shape: [tokens_per_group, num_selected_experts].
|
|
@@ -254,7 +265,7 @@ class HunYuanRotaryEmbedding(nn.Module):
|
|
| 254 |
self.max_position_embeddings = max_position_embeddings
|
| 255 |
self.base = base
|
| 256 |
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
| 257 |
-
inv_freq = inv_freq.bfloat16()
|
| 258 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 259 |
|
| 260 |
# Build here to make `torch.jit.trace` work.
|
|
@@ -266,6 +277,7 @@ class HunYuanRotaryEmbedding(nn.Module):
|
|
| 266 |
self.max_seq_len_cached = seq_len
|
| 267 |
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
|
| 268 |
|
|
|
|
| 269 |
freqs = torch.outer(t, self.inv_freq)
|
| 270 |
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
| 271 |
emb = torch.cat((freqs, freqs), dim=-1).float()
|
|
@@ -274,7 +286,7 @@ class HunYuanRotaryEmbedding(nn.Module):
|
|
| 274 |
|
| 275 |
def forward(self, x, seq_len=None):
|
| 276 |
# x: [bs, num_attention_heads, seq_len, head_size]
|
| 277 |
-
if seq_len > self.max_seq_len_cached:
|
| 278 |
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
| 279 |
|
| 280 |
return (
|
|
@@ -399,7 +411,7 @@ class HunYuanMLP(nn.Module):
|
|
| 399 |
self.layer_idx = layer_idx
|
| 400 |
self.hidden_size = config.hidden_size
|
| 401 |
if is_shared_mlp:
|
| 402 |
-
self.intermediate_size = config.intermediate_size * config.num_shared_expert
|
| 403 |
else:
|
| 404 |
self.intermediate_size = config.intermediate_size
|
| 405 |
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
|
@@ -450,142 +462,66 @@ class HunYuanTopKGate(nn.Module):
|
|
| 450 |
if self.moe_topk == 1:
|
| 451 |
gate_output = top1gating(logits, random_routing_dropped_token=self.random_routing_dropped_token)
|
| 452 |
else:
|
| 453 |
-
gate_output = topkgating(logits, self.moe_topk)
|
| 454 |
|
| 455 |
return gate_output
|
| 456 |
|
| 457 |
|
| 458 |
class HunYuanMoE(nn.Module):
|
| 459 |
-
"""Mixture-of-Experts block with vectorized expert execution and Straight‑Through Estimator (STE) utilities.
|
| 460 |
-
|
| 461 |
-
This implementation removes all Python‑side loops over experts. Expert parameters are **stacked** on‑the‑fly and
|
| 462 |
-
all experts are executed in a single batched matmul sequence, which allows efficient tensor‑parallel execution
|
| 463 |
-
(e.g. with DeepSpeed ZeRO‑3) while keeping the state‑dict format unchanged (each expert remains an individual
|
| 464 |
-
sub‑module for full compatibility with existing checkpoints)."""
|
| 465 |
-
|
| 466 |
def __init__(self, config: HunYuanConfig, layer_idx: Optional[int] = None):
|
| 467 |
super().__init__()
|
| 468 |
self.config = config
|
| 469 |
self.layer_idx = layer_idx
|
| 470 |
self.moe_topk = config.moe_topk
|
| 471 |
self.num_experts = config.num_experts
|
| 472 |
-
|
| 473 |
-
# Optional shared MLP branch (mixed MoE + dense)
|
| 474 |
if config.use_mixed_mlp_moe:
|
| 475 |
self.shared_mlp = HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True)
|
| 476 |
-
|
| 477 |
-
# Router
|
| 478 |
self.gate = HunYuanTopKGate(config, layer_idx=layer_idx)
|
| 479 |
-
|
| 480 |
-
# Experts kept as individual sub‑modules so that load_state_dict / save_pretrained stay identical.
|
| 481 |
self.experts = nn.ModuleList(
|
| 482 |
[HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False) for _ in range(config.num_experts)]
|
| 483 |
)
|
| 484 |
|
| 485 |
-
|
| 486 |
-
# Internal helpers
|
| 487 |
-
# ---------------------------------------------------------------------
|
| 488 |
-
def _stack_weights(self):
|
| 489 |
-
"""Return stacked (batched) expert weights for gate/up/down projections.
|
| 490 |
-
|
| 491 |
-
Shapes:
|
| 492 |
-
Wg : (E, I, H)
|
| 493 |
-
Wu : (E, I, H)
|
| 494 |
-
Wd : (E, H, I) – note transposed compared to nn.Linear's weight so that
|
| 495 |
-
torch.matmul(act, Wd.transpose(-2,‑1)) produces (E, C, H)
|
| 496 |
-
"""
|
| 497 |
-
Wg = torch.stack([exp.gate_proj.weight for exp in self.experts], dim=0)
|
| 498 |
-
Wu = torch.stack([exp.up_proj.weight for exp in self.experts], dim=0)
|
| 499 |
-
Wd = torch.stack([exp.down_proj.weight for exp in self.experts], dim=0)
|
| 500 |
-
return Wg, Wu, Wd
|
| 501 |
-
|
| 502 |
-
# ---------------------------------------------------------------------
|
| 503 |
-
# Public API
|
| 504 |
-
# ---------------------------------------------------------------------
|
| 505 |
-
def forward(
|
| 506 |
-
self,
|
| 507 |
-
hidden_states: torch.Tensor,
|
| 508 |
-
*,
|
| 509 |
-
return_router_logits: bool = False,
|
| 510 |
-
):
|
| 511 |
-
"""Sparse Top‑k MoE forward pass ("y_sparse") with optional router logits.
|
| 512 |
-
|
| 513 |
-
This is the route used during both training and inference for the *sparse*
|
| 514 |
-
branch. No Python loops over experts are used."
|
| 515 |
-
"""
|
| 516 |
bsz, seq_len, hidden_size = hidden_states.shape
|
| 517 |
|
| 518 |
-
# Optional dense branch when using mixed MoE
|
| 519 |
if self.config.use_mixed_mlp_moe:
|
| 520 |
hidden_states_mlp = self.shared_mlp(hidden_states)
|
| 521 |
|
| 522 |
-
|
| 523 |
-
l_moe, combine_weights, dispatch_mask, _ = self.gate(hidden_states)
|
| 524 |
-
|
| 525 |
-
flat_input = hidden_states.reshape(-1, hidden_size) # (S, H) where S = B*T
|
| 526 |
-
# dispatch tokens → experts
|
| 527 |
-
dispatched_input = torch.einsum("sec,sm->ecm", # (E, C, H)
|
| 528 |
-
dispatch_mask.to(hidden_states.dtype),
|
| 529 |
-
flat_input)
|
| 530 |
-
|
| 531 |
-
# ---------------- Expert computation ----------------
|
| 532 |
-
Wg, Wu, Wd = self._stack_weights()
|
| 533 |
-
Wg = Wg.to(hidden_states.dtype)
|
| 534 |
-
Wu = Wu.to(hidden_states.dtype)
|
| 535 |
-
Wd = Wd.to(hidden_states.dtype)
|
| 536 |
-
|
| 537 |
-
gate_out = torch.einsum("ech,eih->eci", dispatched_input, Wg) # (E, C, I)
|
| 538 |
-
up_out = torch.einsum("ech,eih->eci", dispatched_input, Wu) # (E, C, I)
|
| 539 |
-
act_fn = self.experts[0].act_fn
|
| 540 |
-
interm = act_fn(gate_out) * up_out # (E, C, I)
|
| 541 |
-
expert_output = torch.matmul(interm, Wd.transpose(-2, -1)) # (E, C, H)
|
| 542 |
-
|
| 543 |
-
# ---------------- Combine ----------------
|
| 544 |
-
combined_output = torch.einsum("sec,ecm->sm", # (S, H)
|
| 545 |
-
combine_weights.to(hidden_states.dtype),
|
| 546 |
-
expert_output)
|
| 547 |
-
combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
|
| 548 |
|
| 549 |
-
|
| 550 |
-
combined_output = hidden_states_mlp + combined_output
|
| 551 |
|
| 552 |
-
|
| 553 |
-
router_logits = self.gate.wg(flat_input).view(bsz, seq_len, self.num_experts)
|
| 554 |
-
return combined_output, router_logits
|
| 555 |
-
return combined_output
|
| 556 |
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
def forward_all(self, hidden_states: torch.Tensor):
|
| 562 |
-
"""Compute every expert on every token ("dense" MoE path).
|
| 563 |
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
bsz, seq_len, hidden_size = hidden_states.shape
|
| 568 |
-
flat_input = hidden_states.reshape(-1, hidden_size) # (S, H)
|
| 569 |
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
|
| 575 |
-
|
| 576 |
-
up_out = torch.einsum("sh,eih->esi", flat_input, Wu) # (E, S, I)
|
| 577 |
-
act_fn = self.experts[0].act_fn
|
| 578 |
-
interm = act_fn(gate_out) * up_out # (E, S, I)
|
| 579 |
-
dense_out = torch.matmul(interm, Wd.transpose(-2, -1)) # (E, S, H)
|
| 580 |
|
| 581 |
-
dense_out = dense_out.permute(1, 0, 2).contiguous() # (S, E, H)
|
| 582 |
-
dense_out = dense_out.view(bsz, seq_len, self.num_experts, hidden_size) # (B, T, E, H)
|
| 583 |
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
|
| 588 |
-
return dense_out
|
| 589 |
class HunYuanAttention(nn.Module):
|
| 590 |
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 591 |
|
|
@@ -1128,18 +1064,33 @@ class HunYuanDecoderLayer(nn.Module):
|
|
| 1128 |
kv_states: Optional[Tuple[torch.Tensor]] = None,
|
| 1129 |
**kwargs,
|
| 1130 |
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 1131 |
-
"""
|
| 1132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1133 |
if "padding_mask" in kwargs:
|
| 1134 |
warnings.warn(
|
| 1135 |
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use "
|
| 1136 |
"`attention_mask` instead.`"
|
| 1137 |
)
|
| 1138 |
-
|
| 1139 |
-
# ---------------- Self‑Attention ----------------
|
| 1140 |
residual = hidden_states
|
|
|
|
| 1141 |
hidden_states = self.input_layernorm(hidden_states)
|
| 1142 |
-
|
|
|
|
| 1143 |
hidden_states, self_attn_weights, present_key_value, kv_states = self.self_attn(
|
| 1144 |
hidden_states=hidden_states,
|
| 1145 |
attention_mask=attention_mask,
|
|
@@ -1151,40 +1102,47 @@ class HunYuanDecoderLayer(nn.Module):
|
|
| 1151 |
**kwargs,
|
| 1152 |
)
|
| 1153 |
hidden_states = residual + hidden_states
|
| 1154 |
-
|
| 1155 |
-
#
|
| 1156 |
residual = hidden_states
|
| 1157 |
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 1158 |
-
|
| 1159 |
-
|
| 1160 |
-
|
| 1161 |
-
y_sparse, router_logits = self.mlp(hidden_states, return_router_logits=True)
|
| 1162 |
-
|
| 1163 |
-
# Dense (all‑experts) path – memory‑efficient, no grad
|
| 1164 |
-
with torch.no_grad():
|
| 1165 |
-
y_all = self.mlp.forward_all(hidden_states) # (B, T, E, H)
|
| 1166 |
-
|
| 1167 |
-
gate = router_logits.softmax(-1).unsqueeze(-1) # (B, T, E, 1)
|
| 1168 |
-
y_dense = (gate * y_all).sum(-2) # (B, T, H)
|
| 1169 |
-
|
| 1170 |
-
mlp_out = y_dense + (y_sparse - y_dense).detach()
|
| 1171 |
-
else:
|
| 1172 |
-
mlp_out = self.mlp(hidden_states)
|
| 1173 |
-
|
| 1174 |
-
hidden_states = residual + mlp_out
|
| 1175 |
-
|
| 1176 |
-
# ---------------- Outputs ----------------
|
| 1177 |
outputs = (hidden_states,)
|
| 1178 |
-
|
| 1179 |
if output_attentions:
|
| 1180 |
outputs += (self_attn_weights,)
|
| 1181 |
-
|
| 1182 |
if use_cache:
|
| 1183 |
outputs += (present_key_value,)
|
| 1184 |
-
|
| 1185 |
outputs += (kv_states,)
|
| 1186 |
-
|
| 1187 |
return outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1188 |
class HunYuanPreTrainedModel(PreTrainedModel):
|
| 1189 |
config_class = HunYuanConfig
|
| 1190 |
base_model_prefix = "model"
|
|
@@ -1277,6 +1235,10 @@ HUNYUAN_INPUTS_DOCSTRING = r"""
|
|
| 1277 |
"""
|
| 1278 |
|
| 1279 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1280 |
class HunYuanModel(HunYuanPreTrainedModel):
|
| 1281 |
"""
|
| 1282 |
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HunYuanDecoderLayer`]
|
|
@@ -1455,7 +1417,7 @@ class HunYuanModel(HunYuanPreTrainedModel):
|
|
| 1455 |
)
|
| 1456 |
|
| 1457 |
|
| 1458 |
-
class
|
| 1459 |
_tied_weights_keys = ["lm_head.weight"]
|
| 1460 |
|
| 1461 |
def __init__(self, config: HunYuanConfig):
|
|
@@ -1585,7 +1547,7 @@ class HunYuanForCausalLM(HunYuanPreTrainedModel):
|
|
| 1585 |
if isinstance(past_key_values, Cache):
|
| 1586 |
cache_length = past_key_values.get_seq_length()
|
| 1587 |
past_length = past_key_values.seen_tokens
|
| 1588 |
-
max_cache_length = past_key_values.
|
| 1589 |
else:
|
| 1590 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 1591 |
max_cache_length = None
|
|
@@ -1644,6 +1606,21 @@ class HunYuanForCausalLM(HunYuanPreTrainedModel):
|
|
| 1644 |
return reordered_past
|
| 1645 |
|
| 1646 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1647 |
class HunYuanForSequenceClassification(HunYuanPreTrainedModel):
|
| 1648 |
def __init__(self, config):
|
| 1649 |
super().__init__(config)
|
|
@@ -1748,4 +1725,186 @@ class HunYuanForSequenceClassification(HunYuanPreTrainedModel):
|
|
| 1748 |
past_key_values=transformer_outputs.past_key_values,
|
| 1749 |
hidden_states=transformer_outputs.hidden_states,
|
| 1750 |
attentions=transformer_outputs.attentions,
|
| 1751 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://github.com/Tencent/Tencent-Hunyuan-Large/blob/main/License.docx
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
#
|
| 15 |
""" PyTorch HunYuan model."""
|
| 16 |
|
|
|
|
| 19 |
from typing import List, Optional, Tuple, Union
|
| 20 |
|
| 21 |
import torch
|
|
|
|
| 22 |
from torch import Tensor
|
| 23 |
import torch.nn.functional as F
|
| 24 |
import torch.utils.checkpoint
|
|
|
|
| 74 |
def topkgating(logits: Tensor, topk: int):
|
| 75 |
logits = logits.float()
|
| 76 |
gates = F.softmax(logits, dim=1)
|
| 77 |
+
# expert_capacity = topk * gates.shape[0]
|
| 78 |
+
expert_capacity = max(topk, topk * gates.shape[0] // gates.shape[1])
|
| 79 |
num_experts = int(gates.shape[1])
|
| 80 |
# Top-k router probability and corresponding expert indices for each token.
|
| 81 |
# Shape: [tokens_per_group, num_selected_experts].
|
|
|
|
| 265 |
self.max_position_embeddings = max_position_embeddings
|
| 266 |
self.base = base
|
| 267 |
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
| 268 |
+
# inv_freq = inv_freq.bfloat16()
|
| 269 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 270 |
|
| 271 |
# Build here to make `torch.jit.trace` work.
|
|
|
|
| 277 |
self.max_seq_len_cached = seq_len
|
| 278 |
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
|
| 279 |
|
| 280 |
+
self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
| 281 |
freqs = torch.outer(t, self.inv_freq)
|
| 282 |
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
| 283 |
emb = torch.cat((freqs, freqs), dim=-1).float()
|
|
|
|
| 286 |
|
| 287 |
def forward(self, x, seq_len=None):
|
| 288 |
# x: [bs, num_attention_heads, seq_len, head_size]
|
| 289 |
+
if seq_len > self.max_seq_len_cached or self.inv_freq.dtype != torch.float32:
|
| 290 |
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
| 291 |
|
| 292 |
return (
|
|
|
|
| 411 |
self.layer_idx = layer_idx
|
| 412 |
self.hidden_size = config.hidden_size
|
| 413 |
if is_shared_mlp:
|
| 414 |
+
self.intermediate_size = config.intermediate_size * config.num_shared_expert[0]
|
| 415 |
else:
|
| 416 |
self.intermediate_size = config.intermediate_size
|
| 417 |
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
|
|
|
| 462 |
if self.moe_topk == 1:
|
| 463 |
gate_output = top1gating(logits, random_routing_dropped_token=self.random_routing_dropped_token)
|
| 464 |
else:
|
| 465 |
+
gate_output = topkgating(logits, self.moe_topk[0])
|
| 466 |
|
| 467 |
return gate_output
|
| 468 |
|
| 469 |
|
| 470 |
class HunYuanMoE(nn.Module):
    """Sparse mixture-of-experts feed-forward block.

    A `HunYuanTopKGate` router produces dispatch/combine tensors that send each token to a list of
    `HunYuanMLP` experts and merge the expert outputs back. When `config.use_mixed_mlp_moe` is set, the
    output of an additional shared `HunYuanMLP` is added to the routed output.
    """

    def __init__(self, config: HunYuanConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.moe_topk = config.moe_topk
        self.num_experts = config.num_experts

        # Sub-modules are registered in this order: optional shared MLP, router, experts.
        if config.use_mixed_mlp_moe:
            self.shared_mlp = HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True)
        self.gate = HunYuanTopKGate(config, layer_idx=layer_idx)
        expert_list = []
        for _ in range(config.num_experts):
            expert_list.append(HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False))
        self.experts = nn.ModuleList(expert_list)

    def forward(self, hidden_states):
        """Route `hidden_states` through the experts and return a tensor of the same 3-D shape.

        Args:
            hidden_states: 3-D tensor unpacked as `(bsz, seq_len, hidden_size)`.

        Returns:
            The routed expert output reshaped to `(bsz, seq_len, hidden_size)`, plus the shared-MLP output
            when `config.use_mixed_mlp_moe` is enabled.
        """
        bsz, seq_len, hidden_size = hidden_states.shape

        use_shared = self.config.use_mixed_mlp_moe
        if use_shared:
            shared_out = self.shared_mlp(hidden_states)

        # The gate returns four values; only the combine weights and the dispatch mask are used here.
        _, combine_weights, dispatch_mask, _ = self.gate(hidden_states)

        # Flatten to one row per token, then scatter tokens into per-expert slots ("sec,sm->ecm").
        tokens = hidden_states.reshape(-1, hidden_size)
        dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), tokens)

        # One chunk along dim 0 per expert; each expert processes only its own slots.
        per_expert = [
            expert(chunk) for chunk, expert in zip(dispatched_input.chunk(self.num_experts, dim=0), self.experts)
        ]
        expert_output = torch.cat(per_expert, dim=0)

        # Weighted gather of the expert outputs back to one row per token ("sec,ecm->sm").
        routed = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output)
        routed = routed.reshape(bsz, seq_len, hidden_size)

        if use_shared:
            return shared_out + routed
        return routed
|
|
|
|
|
|
|
|
|
|
|
|
|
| 511 |
|
|
|
|
|
|
|
| 512 |
|
| 513 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis, i.e. the equivalent of
    torch.repeat_interleave(x, dim=1, repeats=n_rep). The input of shape (batch, num_key_value_heads, seqlen,
    head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim). With `n_rep == 1` the input tensor
    itself is returned.
    """
    bsz, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a singleton axis after the head axis and broadcast it, so copies of one head stay adjacent.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, head_dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, head_dim)
|
| 523 |
+
|
| 524 |
|
|
|
|
| 525 |
class HunYuanAttention(nn.Module):
|
| 526 |
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 527 |
|
|
|
|
| 1064 |
kv_states: Optional[Tuple[torch.Tensor]] = None,
|
| 1065 |
**kwargs,
|
| 1066 |
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 1067 |
+
"""
|
| 1068 |
+
Args:
|
| 1069 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 1070 |
+
attention_mask (`torch.FloatTensor`, *optional*):
|
| 1071 |
+
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
| 1072 |
+
query_sequence_length, key_sequence_length)` if default attention is used.
|
| 1073 |
+
output_attentions (`bool`, *optional*):
|
| 1074 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 1075 |
+
returned tensors for more detail.
|
| 1076 |
+
use_cache (`bool`, *optional*):
|
| 1077 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 1078 |
+
(see `past_key_values`).
|
| 1079 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
| 1080 |
+
kv_states (`Tuple(torch.FloatTensor)`, *optional*): Used when CLA is enabled,
|
| 1081 |
+
key and value states from past attention blocks
|
| 1082 |
+
"""
|
| 1083 |
if "padding_mask" in kwargs:
|
| 1084 |
warnings.warn(
|
| 1085 |
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use "
|
| 1086 |
"`attention_mask` instead.`"
|
| 1087 |
)
|
| 1088 |
+
|
|
|
|
| 1089 |
residual = hidden_states
|
| 1090 |
+
|
| 1091 |
hidden_states = self.input_layernorm(hidden_states)
|
| 1092 |
+
|
| 1093 |
+
# Self Attention
|
| 1094 |
hidden_states, self_attn_weights, present_key_value, kv_states = self.self_attn(
|
| 1095 |
hidden_states=hidden_states,
|
| 1096 |
attention_mask=attention_mask,
|
|
|
|
| 1102 |
**kwargs,
|
| 1103 |
)
|
| 1104 |
hidden_states = residual + hidden_states
|
| 1105 |
+
|
| 1106 |
+
# Fully Connected
|
| 1107 |
residual = hidden_states
|
| 1108 |
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 1109 |
+
hidden_states = self.mlp(hidden_states)
|
| 1110 |
+
hidden_states = residual + hidden_states
|
| 1111 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1112 |
outputs = (hidden_states,)
|
| 1113 |
+
|
| 1114 |
if output_attentions:
|
| 1115 |
outputs += (self_attn_weights,)
|
| 1116 |
+
|
| 1117 |
if use_cache:
|
| 1118 |
outputs += (present_key_value,)
|
| 1119 |
+
|
| 1120 |
outputs += (kv_states,)
|
| 1121 |
+
|
| 1122 |
return outputs
|
| 1123 |
+
|
| 1124 |
+
|
| 1125 |
+
HUNYUAN_START_DOCSTRING = r"""
|
| 1126 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 1127 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 1128 |
+
etc.)
|
| 1129 |
+
|
| 1130 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 1131 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 1132 |
+
and behavior.
|
| 1133 |
+
|
| 1134 |
+
Parameters:
|
| 1135 |
+
config ([`HunYuanConfig`]):
|
| 1136 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 1137 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 1138 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 1139 |
+
"""
|
| 1140 |
+
|
| 1141 |
+
|
| 1142 |
+
@add_start_docstrings(
|
| 1143 |
+
"The bare HunYuan Model outputting raw hidden-states without any specific head on top.",
|
| 1144 |
+
HUNYUAN_START_DOCSTRING,
|
| 1145 |
+
)
|
| 1146 |
class HunYuanPreTrainedModel(PreTrainedModel):
|
| 1147 |
config_class = HunYuanConfig
|
| 1148 |
base_model_prefix = "model"
|
|
|
|
| 1235 |
"""
|
| 1236 |
|
| 1237 |
|
| 1238 |
+
@add_start_docstrings(
|
| 1239 |
+
"The bare HunYuan Model outputting raw hidden-states without any specific head on top.",
|
| 1240 |
+
HUNYUAN_START_DOCSTRING,
|
| 1241 |
+
)
|
| 1242 |
class HunYuanModel(HunYuanPreTrainedModel):
|
| 1243 |
"""
|
| 1244 |
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HunYuanDecoderLayer`]
|
|
|
|
| 1417 |
)
|
| 1418 |
|
| 1419 |
|
| 1420 |
+
class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
|
| 1421 |
_tied_weights_keys = ["lm_head.weight"]
|
| 1422 |
|
| 1423 |
def __init__(self, config: HunYuanConfig):
|
|
|
|
| 1547 |
if isinstance(past_key_values, Cache):
|
| 1548 |
cache_length = past_key_values.get_seq_length()
|
| 1549 |
past_length = past_key_values.seen_tokens
|
| 1550 |
+
max_cache_length = past_key_values.get_max_cache_shape()
|
| 1551 |
else:
|
| 1552 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 1553 |
max_cache_length = None
|
|
|
|
| 1606 |
return reordered_past
|
| 1607 |
|
| 1608 |
|
| 1609 |
+
@add_start_docstrings(
|
| 1610 |
+
"""
|
| 1611 |
+
The HunYuan Model transformer with a sequence classification head on top (linear layer).
|
| 1612 |
+
|
| 1613 |
+
[`HunYuanForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
| 1614 |
+
(e.g. GPT-2) do.
|
| 1615 |
+
|
| 1616 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
| 1617 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
| 1618 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
| 1619 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
| 1620 |
+
each row of the batch).
|
| 1621 |
+
""",
|
| 1622 |
+
HUNYUAN_START_DOCSTRING,
|
| 1623 |
+
)
|
| 1624 |
class HunYuanForSequenceClassification(HunYuanPreTrainedModel):
|
| 1625 |
def __init__(self, config):
|
| 1626 |
super().__init__(config)
|
|
|
|
| 1725 |
past_key_values=transformer_outputs.past_key_values,
|
| 1726 |
hidden_states=transformer_outputs.hidden_states,
|
| 1727 |
attentions=transformer_outputs.attentions,
|
| 1728 |
+
)
|
| 1729 |
+
|
| 1730 |
+
|
| 1731 |
+
# ================================================================
|
| 1732 |
+
# Dense/Sparse MoE utilities
|
| 1733 |
+
# These enable dense training (all experts active) with sparse inference,
|
| 1734 |
+
# by fusing the per‑expert parameters into single linear layers and
|
| 1735 |
+
# applying a straight‑through estimator (STE) – gradients flow through the
|
| 1736 |
+
# dense path, while the forward pass matches sparse routing behaviour.
|
| 1737 |
+
# ================================================================
|
| 1738 |
+
|
| 1739 |
+
import types
|
| 1740 |
+
|
| 1741 |
+
class HunYuanDenseMoE(nn.Module):
|
| 1742 |
+
"""Dense counterpart of :class:`HunYuanMoE`.
|
| 1743 |
+
|
| 1744 |
+
* The per‑expert linear layers (``gate_proj``, ``up_proj``, ``down_proj``)
|
| 1745 |
+
are concatenated to form three *fused* linear layers.
|
| 1746 |
+
* Forward pass:
|
| 1747 |
+
1. **Dense path** – every expert contributes,
|
| 1748 |
+
weighted by the softmax router probabilities.
|
| 1749 |
+
2. **Sparse path** – only the Top‑K experts (identical to the
|
| 1750 |
+
original sparse MoE) are evaluated.
|
| 1751 |
+
3. **STE** – ``output = dense + (sparse – dense).detach()`` –
|
| 1752 |
+
identical values to sparse inference, dense gradients.
|
| 1753 |
+
"""
|
| 1754 |
+
|
| 1755 |
+
def __init__(self, moe: "HunYuanMoE"):
    """Build the fused projections from an existing sparse MoE block.

    Args:
        moe: Block to convert. This constructor reads its `config`, `layer_idx`, `num_experts`, `gate` and
            `experts`; from the experts it reads `hidden_size`, `intermediate_size`, `act_fn` and the
            `gate_proj` / `up_proj` / `down_proj` weights.
    """
    super().__init__()
    self.config = moe.config
    self.layer_idx = moe.layer_idx
    self.num_experts = moe.num_experts

    # Sizes are read from the first expert only; all experts are assumed to share them.
    first_expert = moe.experts[0]
    self.hidden_size = first_expert.hidden_size
    self.intermediate_size = first_expert.intermediate_size

    # The router module is shared with `moe`, not copied.
    self.gate = moe.gate
    # NOTE(review): nothing else is taken over from `moe` here (e.g. no shared-MLP branch) -- confirm intended.

    # Create the fused layers on the experts' device and dtype. With the default factory settings, copy_
    # would silently cast/move the weights (e.g. bf16 on GPU -> fp32 on CPU) and the module would no longer
    # match the rest of the model.
    reference = first_expert.gate_proj.weight
    factory_kwargs = {"device": reference.device, "dtype": reference.dtype}
    fused_size = self.intermediate_size * self.num_experts
    self.fused_gate_proj = nn.Linear(self.hidden_size, fused_size, bias=False, **factory_kwargs)
    self.fused_up_proj = nn.Linear(self.hidden_size, fused_size, bias=False, **factory_kwargs)
    self.fused_down_proj = nn.Linear(fused_size, self.hidden_size, bias=False, **factory_kwargs)

    # ------------------------------------------------------------------
    # Fuse per-expert parameters
    # ------------------------------------------------------------------
    # gate/up weights are stacked along the output axis (dim 0), so rows [e*I, (e+1)*I) belong to expert e;
    # down weights are stacked along the input axis (dim 1), so the same index range selects its columns.
    # torch.cat already allocates a fresh tensor, so no extra clone is needed before copying.
    with torch.no_grad():
        self.fused_gate_proj.weight.copy_(torch.cat([expert.gate_proj.weight for expert in moe.experts], dim=0))
        self.fused_up_proj.weight.copy_(torch.cat([expert.up_proj.weight for expert in moe.experts], dim=0))
        self.fused_down_proj.weight.copy_(torch.cat([expert.down_proj.weight for expert in moe.experts], dim=1))

    self.act_fn = first_expert.act_fn
    # The gate's `moe_topk` may be a scalar or a list/tuple; in the latter case only the first entry is used.
    gate_topk = self.gate.moe_topk
    self.topk = gate_topk[0] if isinstance(gate_topk, (list, tuple)) else gate_topk
|
| 1793 |
+
|
| 1794 |
+
def _dense_path(self, x, probs):
    """Return the mixture in which every expert contributes, scaled by its entry in `probs`.

    Args:
        x: input fed to the fused gate/up projections.
        probs: per-expert weights; a trailing axis is appended so they broadcast over the intermediate
            features of each expert.

    Returns:
        Output of the fused down projection applied to the weighted activations.
    """
    num_experts, width = self.num_experts, self.intermediate_size

    # Gated activation of all experts at once; the last axis is laid out as E blocks of size I.
    activated = self.act_fn(self.fused_gate_proj(x)) * self.fused_up_proj(x)

    # Split the last axis into (E, I) and scale each expert's block by its weight.
    scaled = activated.view(-1, num_experts, width) * probs.unsqueeze(-1)

    # Flatten back; the fused down projection then sums the weighted experts in a single matmul.
    return self.fused_down_proj(scaled.reshape(-1, width * num_experts))
|
| 1809 |
+
|
| 1810 |
+
def _sparse_path(self, x, probs):
    """Return the sparse Top-K mixture of expert outputs.

    For every token only the ``self.topk`` experts with the highest router
    probability contribute; each contribution is scaled by that expert's
    probability taken directly from ``probs`` (the selected probabilities
    are not renormalised).

    The previous implementation gathered a ``(T, K, I, H)`` copy of the
    down-projection weights for every call (after permuting and copying the
    whole weight matrix). Here the selected probabilities are scattered
    into a ``(T, E)`` map that is zero for unselected experts, so the
    result is obtained with one ordinary ``fused_down_proj`` matmul and no
    per-token weight copies. The value is the same up to floating-point
    summation order.

    Args:
        x: Flattened token states; the fused projections take
            ``hidden_size`` features per row.
        probs: Router probabilities, one column per expert.

    Returns:
        The output of ``fused_down_proj``: one ``hidden_size`` row per token.

    Raises:
        RuntimeError: Propagated from ``torch.topk`` when ``self.topk``
            exceeds the number of columns in ``probs``.
    """
    # All experts are evaluated once through the fused projections.
    gate_out = self.fused_gate_proj(x)
    up_out = self.fused_up_proj(x)
    interm = self.act_fn(gate_out) * up_out
    # Fused layout is expert-major: split into one slice per expert.
    interm = interm.view(-1, self.num_experts, self.intermediate_size)

    # Top-K experts per token.
    values, indices = torch.topk(probs, self.topk, dim=1)

    # Per-token expert weights: the router probability for selected experts
    # and exactly zero elsewhere. topk indices are distinct within a row,
    # so the scatter never writes the same slot twice. The out-of-place
    # scatter keeps the op differentiable with respect to `values`.
    routing = torch.zeros_like(probs).scatter(1, indices, values)

    # Unselected experts are multiplied by zero and drop out of the sum
    # performed by the fused down projection.
    weighted = interm * routing.unsqueeze(-1)
    flat = weighted.reshape(-1, self.intermediate_size * self.num_experts)
    return self.fused_down_proj(flat)
|
| 1837 |
+
|
| 1838 |
+
def forward(self, hidden_states):
    """Route tokens and return the straight-through mixture output.

    The router probabilities are a softmax over ``self.gate.wg`` applied to
    the flattened tokens. Both the dense (all experts) and the sparse
    (Top-K) mixtures are evaluated; the returned tensor carries the sparse
    values while gradients flow only through the dense mixture, because the
    difference between the two is detached before being added back.

    Args:
        hidden_states: A 3-D tensor unpacked as ``(batch, seq_len, _)``.

    Returns:
        A tensor viewed as ``(batch, seq_len, hidden_size)``.
    """
    batch, length, _ = hidden_states.shape
    tokens = hidden_states.reshape(-1, self.hidden_size)

    router_probs = torch.softmax(self.gate.wg(tokens), dim=1)

    soft = self._dense_path(tokens, router_probs)
    hard = self._sparse_path(tokens, router_probs)

    # Straight-through estimator: forward value follows the sparse path,
    # backward pass sees only the dense path.
    correction = (hard - soft).detach()
    return (soft + correction).view(batch, length, self.hidden_size)
|
| 1849 |
+
|
| 1850 |
+
# -----------------------------------------------------------------------
|
| 1851 |
+
# Helper for module replacement
|
| 1852 |
+
# -----------------------------------------------------------------------
|
| 1853 |
+
def _replace_submodule(root: nn.Module, target: str, new_module: nn.Module):
    """Swap the sub-module found at dotted path ``target`` for ``new_module``.

    ``target`` uses the same dotted naming that ``model.named_modules()``
    yields. Every component but the last is resolved with ``getattr``
    starting from ``root``; the last component is then assigned on the
    resulting parent with ``setattr``.
    """
    *ancestors, leaf = target.split('.')
    owner = root
    for attr in ancestors:
        owner = getattr(owner, attr)
    setattr(owner, leaf, new_module)
|
| 1863 |
+
|
| 1864 |
+
# -----------------------------------------------------------------------
|
| 1865 |
+
# Public APIs
|
| 1866 |
+
# -----------------------------------------------------------------------
|
| 1867 |
+
def densify(model: nn.Module):
    """Convert every :class:`HunYuanMoE` under *model* to
    :class:`HunYuanDenseMoE`. Operates **in-place**.

    Each replacement is placed on the device of the sparse module's first
    parameter. In addition, the fused projections are cast to the dtype of
    the expert weights they were built from, and the replacement inherits
    the sparse module's train/eval flag, so a half-precision or ``eval()``
    model stays that way after conversion.

    Args:
        model: Root module to convert.

    Returns:
        The same ``model`` object, for call chaining.
    """
    # Collect first: replacing modules while named_modules() is still being
    # walked would mutate the tree under the iterator.
    replacements = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, HunYuanMoE)
    ]

    for name, sparse_moe in replacements:
        device = next(sparse_moe.parameters()).device
        dense_moe = HunYuanDenseMoE(sparse_moe).to(device)

        # The fused nn.Linear layers are constructed without an explicit
        # dtype, so they come out in torch's default dtype regardless of
        # the experts' dtype. Cast them back so their matmuls accept the
        # same activations the experts did. Only the fused layers are
        # touched; the router keeps whatever dtype it already has.
        expert_dtype = sparse_moe.experts[0].gate_proj.weight.dtype
        if expert_dtype.is_floating_point:
            for proj in (
                dense_moe.fused_gate_proj,
                dense_moe.fused_up_proj,
                dense_moe.fused_down_proj,
            ):
                proj.to(dtype=expert_dtype)

        # A freshly constructed module is always in training mode.
        dense_moe.train(sparse_moe.training)

        _replace_submodule(model, name, dense_moe)
    return model
|
| 1878 |
+
|
| 1879 |
+
|
| 1880 |
+
def sparsify(model: nn.Module):
    """Rebuild standard sparse :class:`HunYuanMoE` modules from their fused
    :class:`HunYuanDenseMoE` form. Operates **in-place**.

    A new ``HunYuanMoE`` is constructed from the dense module's ``config``
    and ``layer_idx``, the router state is copied over, and the fused
    gate/up/down weights are sliced back into the per-expert projections.
    The expert projections are cast to the fused weights' dtype before the
    copy, and the rebuilt module inherits the dense module's train/eval
    flag, so precision and mode survive the conversion.

    Args:
        model: Root module to convert.

    Returns:
        The same ``model`` object, for call chaining.
    """
    # Collect first: replacing modules while named_modules() is still being
    # walked would mutate the tree under the iterator.
    replacements = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, HunYuanDenseMoE)
    ]

    for name, dense_moe in replacements:
        device = next(dense_moe.parameters()).device
        fused_dtype = dense_moe.fused_gate_proj.weight.dtype
        width = dense_moe.intermediate_size

        sparse_moe = HunYuanMoE(dense_moe.config, layer_idx=dense_moe.layer_idx).to(device)

        # Copy router
        sparse_moe.gate.load_state_dict(dense_moe.gate.state_dict())

        # NOTE(review): only the router and the three expert projections
        # are restored here; any other parameters a freshly built
        # HunYuanMoE owns keep their initial values - confirm none exist.

        # Slice fused weights back to per-expert
        for idx, expert in enumerate(sparse_moe.experts):
            span = slice(idx * width, (idx + 1) * width)

            # The new experts are created in torch's default dtype and
            # copy_() would silently cast into it; match the fused dtype
            # first so the precision of the stored weights is preserved.
            if fused_dtype.is_floating_point:
                for proj in (expert.gate_proj, expert.up_proj, expert.down_proj):
                    proj.to(dtype=fused_dtype)

            # gate/up were fused along the output rows, down along the
            # input columns.
            expert.gate_proj.weight.data.copy_(
                dense_moe.fused_gate_proj.weight.data[span]
            )
            expert.up_proj.weight.data.copy_(
                dense_moe.fused_up_proj.weight.data[span]
            )
            expert.down_proj.weight.data.copy_(
                dense_moe.fused_down_proj.weight.data[:, span]
            )

        # A freshly constructed module is always in training mode.
        sparse_moe.train(dense_moe.training)

        _replace_submodule(model, name, sparse_moe)
    return model
|