Update modeling_gemmoe.py
Browse files- modeling_gemmoe.py +8 -18
modeling_gemmoe.py
CHANGED
|
@@ -670,16 +670,11 @@ class GemmoeBlockSparseTop2MLP(nn.Module):
|
|
| 670 |
self.act_fn = approx_gelu
|
| 671 |
|
| 672 |
def forward(self, hidden_states):
|
|
|
|
| 673 |
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
|
| 674 |
-
current_hidden_states = self.w2(current_hidden_states)
|
| 675 |
return current_hidden_states
|
| 676 |
|
| 677 |
-
class GemmoeBlockSparseTop2MLP(GemmoeBlockSparseTop2MLP):
|
| 678 |
-
def __init__(self, *args, **kwargs):
|
| 679 |
-
logger.warning_once(
|
| 680 |
-
"GemmoeBLockSparseTop2MLP is deprecated by GemmoeBlockSparseTop2MLP and will be removed in v4.40."
|
| 681 |
-
)
|
| 682 |
-
super().__init__(*args, **kwargs)
|
| 683 |
|
| 684 |
class GemmoeSparseMoeBlock(nn.Module):
|
| 685 |
def __init__(self, config):
|
|
@@ -699,8 +694,9 @@ class GemmoeSparseMoeBlock(nn.Module):
|
|
| 699 |
hidden_states = hidden_states.view(-1, hidden_dim)
|
| 700 |
|
| 701 |
# router_logits: (batch * sequence_length, n_experts)
|
| 702 |
-
|
| 703 |
-
|
|
|
|
| 704 |
topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
|
| 705 |
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
| 706 |
|
|
@@ -715,7 +711,7 @@ class GemmoeSparseMoeBlock(nn.Module):
|
|
| 715 |
for i in range(self.num_experts):
|
| 716 |
expert = self.experts[i]
|
| 717 |
expert_output = expert(hidden_states[flat_topk_idx == i])
|
| 718 |
-
y[flat_topk_idx == i] = expert_output
|
| 719 |
|
| 720 |
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
| 721 |
|
|
@@ -983,7 +979,6 @@ class GemmoeModel(GemmoePreTrainedModel):
|
|
| 983 |
self.embed_tokens = value
|
| 984 |
|
| 985 |
@add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
|
| 986 |
-
# Ignore copy
|
| 987 |
def forward(
|
| 988 |
self,
|
| 989 |
input_ids: torch.LongTensor = None,
|
|
@@ -994,7 +989,7 @@ class GemmoeModel(GemmoePreTrainedModel):
|
|
| 994 |
use_cache: Optional[bool] = None,
|
| 995 |
output_attentions: Optional[bool] = None,
|
| 996 |
output_hidden_states: Optional[bool] = None,
|
| 997 |
-
output_router_logits: Optional[bool] = None,
|
| 998 |
return_dict: Optional[bool] = None,
|
| 999 |
cache_position: Optional[torch.LongTensor] = None,
|
| 1000 |
) -> Union[Tuple, MoeModelOutputWithPast]:
|
|
@@ -1023,7 +1018,6 @@ class GemmoeModel(GemmoePreTrainedModel):
|
|
| 1023 |
# Fix for precision issue when casting to bfloat16 (NOTE(review): the bfloat16 branch below is currently a no-op `pass`)
|
| 1024 |
hidden_size_sqrt = math.sqrt(self.config.hidden_size)
|
| 1025 |
if inputs_embeds.dtype == torch.bfloat16:
|
| 1026 |
-
|
| 1027 |
pass
|
| 1028 |
|
| 1029 |
hidden_states = inputs_embeds * hidden_size_sqrt
|
|
@@ -1110,10 +1104,6 @@ class GemmoeModel(GemmoePreTrainedModel):
|
|
| 1110 |
attentions=all_self_attns,
|
| 1111 |
)
|
| 1112 |
|
| 1113 |
-
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
| 1114 |
-
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.
|
| 1115 |
-
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
| 1116 |
-
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
| 1117 |
def _update_causal_mask(self, attention_mask, input_tensor):
|
| 1118 |
if self.config._attn_implementation == "flash_attention_2":
|
| 1119 |
if attention_mask is not None and 0.0 in attention_mask:
|
|
@@ -1135,7 +1125,7 @@ class GemmoeModel(GemmoePreTrainedModel):
|
|
| 1135 |
causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
|
| 1136 |
causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
|
| 1137 |
if attention_mask is not None:
|
| 1138 |
-
causal_mask = causal_mask.clone()
|
| 1139 |
if attention_mask.dim() == 2:
|
| 1140 |
mask_length = attention_mask.shape[-1]
|
| 1141 |
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
|
|
|
| 670 |
self.act_fn = approx_gelu
|
| 671 |
|
| 672 |
def forward(self, hidden_states):
|
| 673 |
+
hidden_states = hidden_states.to(torch.float32) # Cast to float32
|
| 674 |
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
|
| 675 |
+
current_hidden_states = self.w2(current_hidden_states.to(hidden_states.dtype)) # NOTE(review): hidden_states was rebound to float32 above, so this casts to float32, not the caller's original dtype
|
| 676 |
return current_hidden_states
|
| 677 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 678 |
|
| 679 |
class GemmoeSparseMoeBlock(nn.Module):
|
| 680 |
def __init__(self, config):
|
|
|
|
| 694 |
hidden_states = hidden_states.view(-1, hidden_dim)
|
| 695 |
|
| 696 |
# router_logits: (batch * sequence_length, n_experts)
|
| 697 |
+
hidden_states_float = hidden_states.float() # Cast to float32
|
| 698 |
+
router_logits = self.gate(hidden_states_float)
|
| 699 |
+
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
|
| 700 |
topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
|
| 701 |
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
| 702 |
|
|
|
|
| 711 |
for i in range(self.num_experts):
|
| 712 |
expert = self.experts[i]
|
| 713 |
expert_output = expert(hidden_states[flat_topk_idx == i])
|
| 714 |
+
y[flat_topk_idx == i] = expert_output
|
| 715 |
|
| 716 |
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
| 717 |
|
|
|
|
| 979 |
self.embed_tokens = value
|
| 980 |
|
| 981 |
@add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
|
|
|
|
| 982 |
def forward(
|
| 983 |
self,
|
| 984 |
input_ids: torch.LongTensor = None,
|
|
|
|
| 989 |
use_cache: Optional[bool] = None,
|
| 990 |
output_attentions: Optional[bool] = None,
|
| 991 |
output_hidden_states: Optional[bool] = None,
|
| 992 |
+
output_router_logits: Optional[bool] = None,
|
| 993 |
return_dict: Optional[bool] = None,
|
| 994 |
cache_position: Optional[torch.LongTensor] = None,
|
| 995 |
) -> Union[Tuple, MoeModelOutputWithPast]:
|
|
|
|
| 1018 |
# Fix for precision issue when casting to bfloat16 (NOTE(review): the bfloat16 branch below is currently a no-op `pass`)
|
| 1019 |
hidden_size_sqrt = math.sqrt(self.config.hidden_size)
|
| 1020 |
if inputs_embeds.dtype == torch.bfloat16:
|
|
|
|
| 1021 |
pass
|
| 1022 |
|
| 1023 |
hidden_states = inputs_embeds * hidden_size_sqrt
|
|
|
|
| 1104 |
attentions=all_self_attns,
|
| 1105 |
)
|
| 1106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1107 |
def _update_causal_mask(self, attention_mask, input_tensor):
|
| 1108 |
if self.config._attn_implementation == "flash_attention_2":
|
| 1109 |
if attention_mask is not None and 0.0 in attention_mask:
|
|
|
|
| 1125 |
causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
|
| 1126 |
causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
|
| 1127 |
if attention_mask is not None:
|
| 1128 |
+
causal_mask = causal_mask.clone()
|
| 1129 |
if attention_mask.dim() == 2:
|
| 1130 |
mask_length = attention_mask.shape[-1]
|
| 1131 |
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|