Upload modeling_hunyuan.py
Browse files- modeling_hunyuan.py +143 -102
modeling_hunyuan.py
CHANGED
|
@@ -1,16 +1,5 @@
|
|
| 1 |
-
#
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# https://github.com/Tencent/Tencent-Hunyuan-Large/blob/main/License.docx
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
#
|
| 15 |
""" PyTorch HunYuan model."""
|
| 16 |
|
|
@@ -74,8 +63,7 @@ _CONFIG_FOR_DOC = "HunYuanConfig"
|
|
| 74 |
def topkgating(logits: Tensor, topk: int):
|
| 75 |
logits = logits.float()
|
| 76 |
gates = F.softmax(logits, dim=1)
|
| 77 |
-
|
| 78 |
-
expert_capacity = max(topk, topk * gates.shape[0] // gates.shape[1])
|
| 79 |
num_experts = int(gates.shape[1])
|
| 80 |
# Top-k router probability and corresponding expert indices for each token.
|
| 81 |
# Shape: [tokens_per_group, num_selected_experts].
|
|
@@ -265,7 +253,7 @@ class HunYuanRotaryEmbedding(nn.Module):
|
|
| 265 |
self.max_position_embeddings = max_position_embeddings
|
| 266 |
self.base = base
|
| 267 |
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
| 268 |
-
|
| 269 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 270 |
|
| 271 |
# Build here to make `torch.jit.trace` work.
|
|
@@ -277,7 +265,6 @@ class HunYuanRotaryEmbedding(nn.Module):
|
|
| 277 |
self.max_seq_len_cached = seq_len
|
| 278 |
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
|
| 279 |
|
| 280 |
-
self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
| 281 |
freqs = torch.outer(t, self.inv_freq)
|
| 282 |
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
| 283 |
emb = torch.cat((freqs, freqs), dim=-1).float()
|
|
@@ -286,7 +273,7 @@ class HunYuanRotaryEmbedding(nn.Module):
|
|
| 286 |
|
| 287 |
def forward(self, x, seq_len=None):
|
| 288 |
# x: [bs, num_attention_heads, seq_len, head_size]
|
| 289 |
-
if seq_len > self.max_seq_len_cached
|
| 290 |
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
| 291 |
|
| 292 |
return (
|
|
@@ -411,7 +398,7 @@ class HunYuanMLP(nn.Module):
|
|
| 411 |
self.layer_idx = layer_idx
|
| 412 |
self.hidden_size = config.hidden_size
|
| 413 |
if is_shared_mlp:
|
| 414 |
-
self.intermediate_size = config.intermediate_size * config.num_shared_expert
|
| 415 |
else:
|
| 416 |
self.intermediate_size = config.intermediate_size
|
| 417 |
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
|
@@ -462,66 +449,142 @@ class HunYuanTopKGate(nn.Module):
|
|
| 462 |
if self.moe_topk == 1:
|
| 463 |
gate_output = top1gating(logits, random_routing_dropped_token=self.random_routing_dropped_token)
|
| 464 |
else:
|
| 465 |
-
gate_output = topkgating(logits, self.moe_topk
|
| 466 |
|
| 467 |
return gate_output
|
| 468 |
|
| 469 |
|
| 470 |
class HunYuanMoE(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
def __init__(self, config: HunYuanConfig, layer_idx: Optional[int] = None):
|
| 472 |
super().__init__()
|
| 473 |
self.config = config
|
| 474 |
self.layer_idx = layer_idx
|
| 475 |
self.moe_topk = config.moe_topk
|
| 476 |
self.num_experts = config.num_experts
|
|
|
|
|
|
|
| 477 |
if config.use_mixed_mlp_moe:
|
| 478 |
self.shared_mlp = HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True)
|
|
|
|
|
|
|
| 479 |
self.gate = HunYuanTopKGate(config, layer_idx=layer_idx)
|
|
|
|
|
|
|
| 480 |
self.experts = nn.ModuleList(
|
| 481 |
[HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False) for _ in range(config.num_experts)]
|
| 482 |
)
|
| 483 |
|
| 484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
bsz, seq_len, hidden_size = hidden_states.shape
|
| 486 |
|
|
|
|
| 487 |
if self.config.use_mixed_mlp_moe:
|
| 488 |
hidden_states_mlp = self.shared_mlp(hidden_states)
|
| 489 |
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
|
| 494 |
-
|
|
|
|
| 495 |
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
|
|
|
|
|
|
|
|
|
| 504 |
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
|
|
|
| 509 |
|
| 510 |
-
|
|
|
|
|
|
|
|
|
|
| 511 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 512 |
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 516 |
-
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 517 |
-
"""
|
| 518 |
-
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 519 |
-
if n_rep == 1:
|
| 520 |
-
return hidden_states
|
| 521 |
-
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 522 |
-
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 523 |
|
|
|
|
|
|
|
|
|
|
| 524 |
|
|
|
|
| 525 |
class HunYuanAttention(nn.Module):
|
| 526 |
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 527 |
|
|
@@ -1064,33 +1127,18 @@ class HunYuanDecoderLayer(nn.Module):
|
|
| 1064 |
kv_states: Optional[Tuple[torch.Tensor]] = None,
|
| 1065 |
**kwargs,
|
| 1066 |
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 1067 |
-
"""
|
| 1068 |
-
|
| 1069 |
-
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 1070 |
-
attention_mask (`torch.FloatTensor`, *optional*):
|
| 1071 |
-
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
| 1072 |
-
query_sequence_length, key_sequence_length)` if default attention is used.
|
| 1073 |
-
output_attentions (`bool`, *optional*):
|
| 1074 |
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 1075 |
-
returned tensors for more detail.
|
| 1076 |
-
use_cache (`bool`, *optional*):
|
| 1077 |
-
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 1078 |
-
(see `past_key_values`).
|
| 1079 |
-
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
| 1080 |
-
kv_states (`Tuple(torch.FloatTensor)`, *optional*): Used when CLA is enabled,
|
| 1081 |
-
key and value states from past attention blocks
|
| 1082 |
-
"""
|
| 1083 |
if "padding_mask" in kwargs:
|
| 1084 |
warnings.warn(
|
| 1085 |
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use "
|
| 1086 |
"`attention_mask` instead.`"
|
| 1087 |
)
|
| 1088 |
-
|
|
|
|
| 1089 |
residual = hidden_states
|
| 1090 |
-
|
| 1091 |
hidden_states = self.input_layernorm(hidden_states)
|
| 1092 |
-
|
| 1093 |
-
# Self Attention
|
| 1094 |
hidden_states, self_attn_weights, present_key_value, kv_states = self.self_attn(
|
| 1095 |
hidden_states=hidden_states,
|
| 1096 |
attention_mask=attention_mask,
|
|
@@ -1102,47 +1150,40 @@ class HunYuanDecoderLayer(nn.Module):
|
|
| 1102 |
**kwargs,
|
| 1103 |
)
|
| 1104 |
hidden_states = residual + hidden_states
|
| 1105 |
-
|
| 1106 |
-
#
|
| 1107 |
residual = hidden_states
|
| 1108 |
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 1109 |
-
|
| 1110 |
-
|
| 1111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1112 |
outputs = (hidden_states,)
|
| 1113 |
-
|
| 1114 |
if output_attentions:
|
| 1115 |
outputs += (self_attn_weights,)
|
| 1116 |
-
|
| 1117 |
if use_cache:
|
| 1118 |
outputs += (present_key_value,)
|
| 1119 |
-
|
| 1120 |
outputs += (kv_states,)
|
| 1121 |
-
|
| 1122 |
return outputs
|
| 1123 |
-
|
| 1124 |
-
|
| 1125 |
-
HUNYUAN_START_DOCSTRING = r"""
|
| 1126 |
-
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 1127 |
-
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 1128 |
-
etc.)
|
| 1129 |
-
|
| 1130 |
-
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 1131 |
-
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 1132 |
-
and behavior.
|
| 1133 |
-
|
| 1134 |
-
Parameters:
|
| 1135 |
-
config ([`HunYuanConfig`]):
|
| 1136 |
-
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 1137 |
-
load the weights associated with the model, only the configuration. Check out the
|
| 1138 |
-
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 1139 |
-
"""
|
| 1140 |
-
|
| 1141 |
-
|
| 1142 |
-
@add_start_docstrings(
|
| 1143 |
-
"The bare HunYuan Model outputting raw hidden-states without any specific head on top.",
|
| 1144 |
-
HUNYUAN_START_DOCSTRING,
|
| 1145 |
-
)
|
| 1146 |
class HunYuanPreTrainedModel(PreTrainedModel):
|
| 1147 |
config_class = HunYuanConfig
|
| 1148 |
base_model_prefix = "model"
|
|
@@ -1417,7 +1458,7 @@ class HunYuanModel(HunYuanPreTrainedModel):
|
|
| 1417 |
)
|
| 1418 |
|
| 1419 |
|
| 1420 |
-
class
|
| 1421 |
_tied_weights_keys = ["lm_head.weight"]
|
| 1422 |
|
| 1423 |
def __init__(self, config: HunYuanConfig):
|
|
@@ -1547,7 +1588,7 @@ class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
|
|
| 1547 |
if isinstance(past_key_values, Cache):
|
| 1548 |
cache_length = past_key_values.get_seq_length()
|
| 1549 |
past_length = past_key_values.seen_tokens
|
| 1550 |
-
max_cache_length = past_key_values.
|
| 1551 |
else:
|
| 1552 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 1553 |
max_cache_length = None
|
|
@@ -1725,4 +1766,4 @@ class HunYuanForSequenceClassification(HunYuanPreTrainedModel):
|
|
| 1725 |
past_key_values=transformer_outputs.past_key_values,
|
| 1726 |
hidden_states=transformer_outputs.hidden_states,
|
| 1727 |
attentions=transformer_outputs.attentions,
|
| 1728 |
-
)
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 Tencent Inc. All Rights Reserved.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
#
|
| 4 |
""" PyTorch HunYuan model."""
|
| 5 |
|
|
|
|
| 63 |
def topkgating(logits: Tensor, topk: int):
|
| 64 |
logits = logits.float()
|
| 65 |
gates = F.softmax(logits, dim=1)
|
| 66 |
+
expert_capacity = topk * gates.shape[0]
|
|
|
|
| 67 |
num_experts = int(gates.shape[1])
|
| 68 |
# Top-k router probability and corresponding expert indices for each token.
|
| 69 |
# Shape: [tokens_per_group, num_selected_experts].
|
|
|
|
| 253 |
self.max_position_embeddings = max_position_embeddings
|
| 254 |
self.base = base
|
| 255 |
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
| 256 |
+
inv_freq = inv_freq.bfloat16()
|
| 257 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 258 |
|
| 259 |
# Build here to make `torch.jit.trace` work.
|
|
|
|
| 265 |
self.max_seq_len_cached = seq_len
|
| 266 |
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
|
| 267 |
|
|
|
|
| 268 |
freqs = torch.outer(t, self.inv_freq)
|
| 269 |
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
| 270 |
emb = torch.cat((freqs, freqs), dim=-1).float()
|
|
|
|
| 273 |
|
| 274 |
def forward(self, x, seq_len=None):
|
| 275 |
# x: [bs, num_attention_heads, seq_len, head_size]
|
| 276 |
+
if seq_len > self.max_seq_len_cached:
|
| 277 |
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
| 278 |
|
| 279 |
return (
|
|
|
|
| 398 |
self.layer_idx = layer_idx
|
| 399 |
self.hidden_size = config.hidden_size
|
| 400 |
if is_shared_mlp:
|
| 401 |
+
self.intermediate_size = config.intermediate_size * config.num_shared_expert
|
| 402 |
else:
|
| 403 |
self.intermediate_size = config.intermediate_size
|
| 404 |
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
|
|
|
| 449 |
if self.moe_topk == 1:
|
| 450 |
gate_output = top1gating(logits, random_routing_dropped_token=self.random_routing_dropped_token)
|
| 451 |
else:
|
| 452 |
+
gate_output = topkgating(logits, self.moe_topk)
|
| 453 |
|
| 454 |
return gate_output
|
| 455 |
|
| 456 |
|
| 457 |
class HunYuanMoE(nn.Module):
    """Mixture-of-experts feed-forward block with batched expert execution.

    Each expert is kept as its own ``HunYuanMLP`` sub-module inside
    ``self.experts`` (so parameter names are per-expert), but at run time the
    expert weights are stacked and all experts are evaluated with batched
    matmuls instead of a Python loop over experts.

    Two entry points are provided:

    * ``forward`` - routed ("sparse") path driven by ``self.gate``.
    * ``forward_all`` - evaluates every expert on every token, without
      gradients.
    """

    def __init__(self, config: "HunYuanConfig", layer_idx: Optional[int] = None):
        """Build the router, the experts and (optionally) the shared MLP.

        Args:
            config: model configuration; this block reads ``moe_topk``,
                ``num_experts`` and ``use_mixed_mlp_moe`` from it.
            layer_idx: index of the decoder layer owning this block; it is
                stored and forwarded to the sub-modules.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.moe_topk = config.moe_topk
        self.num_experts = config.num_experts

        # Optional shared MLP branch whose output is added to the MoE output.
        if config.use_mixed_mlp_moe:
            self.shared_mlp = HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True)

        # Router.
        self.gate = HunYuanTopKGate(config, layer_idx=layer_idx)

        # One sub-module per expert; weights are stacked lazily in
        # `_stack_weights` rather than stored as a single fused parameter.
        self.experts = nn.ModuleList(
            [HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False) for _ in range(config.num_experts)]
        )

    # ---------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------
    def _stack_weights(self):
        """Return the experts' projection weights stacked along a new dim 0.

        The tensors are the raw ``nn.Linear`` weights (layout
        ``(out_features, in_features)``), NOT transposed.  With ``E`` experts,
        hidden size ``H`` and intermediate size ``I`` this gives:

            Wg : (E, I, H)   stacked ``gate_proj.weight``
            Wu : (E, I, H)   stacked ``up_proj.weight``
            Wd : (E, H, I)   stacked ``down_proj.weight``

        (Shapes assume ``up_proj`` mirrors ``gate_proj`` and ``down_proj``
        maps I -> H, which the matmul chain in `_run_experts` requires.)

        ``torch.stack`` copies, so this allocates fresh tensors on each call.
        """
        Wg = torch.stack([exp.gate_proj.weight for exp in self.experts], dim=0)
        Wu = torch.stack([exp.up_proj.weight for exp in self.experts], dim=0)
        Wd = torch.stack([exp.down_proj.weight for exp in self.experts], dim=0)
        return Wg, Wu, Wd

    def _run_experts(self, expert_input: torch.Tensor) -> torch.Tensor:
        """Apply every expert's gated MLP to ``expert_input`` in one batch.

        Args:
            expert_input: either ``(E, N, H)`` (a separate batch of rows per
                expert) or ``(N, H)`` (the same rows fed to all experts;
                ``torch.matmul`` broadcasts it over the expert dimension).

        Returns:
            Tensor of shape ``(E, N, H)``: ``down(act(gate(x)) * up(x))`` per
            expert, computed in ``expert_input``'s dtype.
        """
        dtype = expert_input.dtype
        Wg, Wu, Wd = (w.to(dtype) for w in self._stack_weights())

        # nn.Linear computes x @ W^T, hence the transposes.
        gate_out = torch.matmul(expert_input, Wg.transpose(-2, -1))  # (E, N, I)
        up_out = torch.matmul(expert_input, Wu.transpose(-2, -1))  # (E, N, I)
        # The activation of the first expert is used for all experts.
        act_fn = self.experts[0].act_fn
        interm = act_fn(gate_out) * up_out  # (E, N, I)
        return torch.matmul(interm, Wd.transpose(-2, -1))  # (E, N, H)

    # ---------------------------------------------------------------------
    # Public API
    # ---------------------------------------------------------------------
    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        return_router_logits: bool = False,
    ):
        """Routed (sparse) MoE forward pass.

        Args:
            hidden_states: tensor of shape ``(B, T, H)``.
            return_router_logits: when True, also return the router logits.

        Returns:
            ``combined_output`` of shape ``(B, T, H)``, or the tuple
            ``(combined_output, router_logits)`` with ``router_logits`` of
            shape ``(B, T, num_experts)`` when ``return_router_logits`` is set.
        """
        bsz, seq_len, hidden_size = hidden_states.shape

        # Optional shared branch when using mixed MoE.
        if self.config.use_mixed_mlp_moe:
            hidden_states_mlp = self.shared_mlp(hidden_states)

        # ---------------- Routing ----------------
        # The gate returns a 4-tuple; only the combine weights and dispatch
        # mask (both indexed as "sec" below, i.e. 3-D with the flattened token
        # axis first) are used here.  The first element (named `l_moe` in the
        # previous revision) and the last one are ignored.
        _, combine_weights, dispatch_mask, _ = self.gate(hidden_states)

        flat_input = hidden_states.reshape(-1, hidden_size)  # (S, H), S = B*T
        # Scatter tokens into per-expert slots.
        dispatched_input = torch.einsum(
            "sec,sm->ecm",
            dispatch_mask.to(hidden_states.dtype),
            flat_input,
        )  # (E, C, H)

        # ---------------- Expert computation ----------------
        expert_output = self._run_experts(dispatched_input)  # (E, C, H)

        # ---------------- Combine ----------------
        combined_output = torch.einsum(
            "sec,ecm->sm",
            combine_weights.to(hidden_states.dtype),
            expert_output,
        )  # (S, H)
        combined_output = combined_output.reshape(bsz, seq_len, hidden_size)

        if self.config.use_mixed_mlp_moe:
            combined_output = hidden_states_mlp + combined_output

        if return_router_logits:
            # NOTE(review): this re-applies the router projection `gate.wg`
            # directly instead of reusing logits computed inside `self.gate`;
            # confirm the gate does no extra preprocessing (e.g. a dtype cast)
            # that this direct call would skip.
            router_logits = self.gate.wg(flat_input).view(bsz, seq_len, self.num_experts)
            return combined_output, router_logits
        return combined_output

    # ---------------------------------------------------------------------
    # Dense branch - outputs *all* expert activations (no routing)
    # ---------------------------------------------------------------------
    @torch.no_grad()
    def forward_all(self, hidden_states: torch.Tensor):
        """Evaluate every expert on every token (no routing, no gradients).

        Because of the ``torch.no_grad`` decorator the result never carries
        gradient information, even when called during training.

        Args:
            hidden_states: tensor of shape ``(B, T, H)``.

        Returns:
            Tensor of shape ``(B, T, E, H)`` with the per-expert outputs.  When
            ``config.use_mixed_mlp_moe`` is set, the shared MLP output is added
            to each expert's slice.
        """
        bsz, seq_len, hidden_size = hidden_states.shape
        flat_input = hidden_states.reshape(-1, hidden_size)  # (S, H)

        # A 2-D input is broadcast across the expert dimension.
        dense_out = self._run_experts(flat_input)  # (E, S, H)

        dense_out = dense_out.permute(1, 0, 2).contiguous()  # (S, E, H)
        dense_out = dense_out.view(bsz, seq_len, self.num_experts, hidden_size)  # (B, T, E, H)

        if self.config.use_mixed_mlp_moe:
            hidden_states_mlp = self.shared_mlp(hidden_states)  # (B, T, H)
            dense_out = dense_out + hidden_states_mlp.unsqueeze(2)

        return dense_out
|
| 588 |
class HunYuanAttention(nn.Module):
|
| 589 |
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 590 |
|
|
|
|
| 1127 |
kv_states: Optional[Tuple[torch.Tensor]] = None,
|
| 1128 |
**kwargs,
|
| 1129 |
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 1130 |
+
"""Modified to include Straight‑Through Estimator (STE) training logic."""
|
| 1131 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1132 |
if "padding_mask" in kwargs:
|
| 1133 |
warnings.warn(
|
| 1134 |
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use "
|
| 1135 |
"`attention_mask` instead.`"
|
| 1136 |
)
|
| 1137 |
+
|
| 1138 |
+
# ---------------- Self‑Attention ----------------
|
| 1139 |
residual = hidden_states
|
|
|
|
| 1140 |
hidden_states = self.input_layernorm(hidden_states)
|
| 1141 |
+
|
|
|
|
| 1142 |
hidden_states, self_attn_weights, present_key_value, kv_states = self.self_attn(
|
| 1143 |
hidden_states=hidden_states,
|
| 1144 |
attention_mask=attention_mask,
|
|
|
|
| 1150 |
**kwargs,
|
| 1151 |
)
|
| 1152 |
hidden_states = residual + hidden_states
|
| 1153 |
+
|
| 1154 |
+
# ---------------- MLP / MoE (+STE) ----------------
|
| 1155 |
residual = hidden_states
|
| 1156 |
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 1157 |
+
|
| 1158 |
+
if self.training and isinstance(self.mlp, HunYuanMoE):
|
| 1159 |
+
# Sparse path + router logits
|
| 1160 |
+
y_sparse, router_logits = self.mlp(hidden_states, return_router_logits=True)
|
| 1161 |
+
|
| 1162 |
+
# Dense (all‑experts) path – memory‑efficient, no grad
|
| 1163 |
+
with torch.no_grad():
|
| 1164 |
+
y_all = self.mlp.forward_all(hidden_states) # (B, T, E, H)
|
| 1165 |
+
|
| 1166 |
+
gate = router_logits.softmax(-1).unsqueeze(-1) # (B, T, E, 1)
|
| 1167 |
+
y_dense = (gate * y_all).sum(-2) # (B, T, H)
|
| 1168 |
+
|
| 1169 |
+
mlp_out = y_dense + (y_sparse - y_dense).detach()
|
| 1170 |
+
else:
|
| 1171 |
+
mlp_out = self.mlp(hidden_states)
|
| 1172 |
+
|
| 1173 |
+
hidden_states = residual + mlp_out
|
| 1174 |
+
|
| 1175 |
+
# ---------------- Outputs ----------------
|
| 1176 |
outputs = (hidden_states,)
|
| 1177 |
+
|
| 1178 |
if output_attentions:
|
| 1179 |
outputs += (self_attn_weights,)
|
| 1180 |
+
|
| 1181 |
if use_cache:
|
| 1182 |
outputs += (present_key_value,)
|
| 1183 |
+
|
| 1184 |
outputs += (kv_states,)
|
| 1185 |
+
|
| 1186 |
return outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1187 |
class HunYuanPreTrainedModel(PreTrainedModel):
|
| 1188 |
config_class = HunYuanConfig
|
| 1189 |
base_model_prefix = "model"
|
|
|
|
| 1458 |
)
|
| 1459 |
|
| 1460 |
|
| 1461 |
+
class HunYuanForCausalLM(HunYuanPreTrainedModel):
|
| 1462 |
_tied_weights_keys = ["lm_head.weight"]
|
| 1463 |
|
| 1464 |
def __init__(self, config: HunYuanConfig):
|
|
|
|
| 1588 |
if isinstance(past_key_values, Cache):
|
| 1589 |
cache_length = past_key_values.get_seq_length()
|
| 1590 |
past_length = past_key_values.seen_tokens
|
| 1591 |
+
max_cache_length = past_key_values.get_max_length()
|
| 1592 |
else:
|
| 1593 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 1594 |
max_cache_length = None
|
|
|
|
| 1766 |
past_key_values=transformer_outputs.past_key_values,
|
| 1767 |
hidden_states=transformer_outputs.hidden_states,
|
| 1768 |
attentions=transformer_outputs.attentions,
|
| 1769 |
+
)
|