Update modeling_bailing_moe.py
Browse files- modeling_bailing_moe.py +13 -8
modeling_bailing_moe.py
CHANGED
|
@@ -117,8 +117,8 @@ class BailingMoeRMSNorm(nn.Module):
|
|
| 117 |
hidden_states = hidden_states.to(torch.float32)
|
| 118 |
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 119 |
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
|
|
|
| 120 |
|
| 121 |
-
return (self.weight.float() * hidden_states).to(input_dtype)
|
| 122 |
|
| 123 |
ALL_LAYERNORM_LAYERS.append(BailingMoeRMSNorm)
|
| 124 |
|
|
@@ -495,7 +495,7 @@ class BailingMoeAttention(nn.Module):
|
|
| 495 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 496 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 497 |
|
| 498 |
-
attn_weights = torch.matmul(query_states / math.sqrt(self.head_dim), key_states.transpose(2, 3))
|
| 499 |
|
| 500 |
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
| 501 |
raise ValueError(
|
|
@@ -825,7 +825,6 @@ class BailingMoeSdpaAttention(BailingMoeAttention):
|
|
| 825 |
dropout_p=self.attention_dropout if self.training else 0.0,
|
| 826 |
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
| 827 |
is_causal=self.is_causal and attention_mask is None and q_len > 1,
|
| 828 |
-
# enable_gqa=True
|
| 829 |
)
|
| 830 |
|
| 831 |
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
@@ -847,6 +846,7 @@ class BailingMoeDecoderLayer(nn.Module):
|
|
| 847 |
def __init__(self, config: BailingMoeConfig, layer_idx: int):
|
| 848 |
super().__init__()
|
| 849 |
self.hidden_size = config.hidden_size
|
|
|
|
| 850 |
self.attention = BAILING_MOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
| 851 |
|
| 852 |
self.mlp = (
|
|
@@ -1167,7 +1167,7 @@ class BailingMoeModel(BailingMoePreTrainedModel):
|
|
| 1167 |
all_router_logits = () if output_router_logits else None
|
| 1168 |
next_decoder_cache = None
|
| 1169 |
|
| 1170 |
-
for
|
| 1171 |
if output_hidden_states:
|
| 1172 |
all_hidden_states += (hidden_states,)
|
| 1173 |
|
|
@@ -1332,9 +1332,10 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
|
|
| 1332 |
)
|
| 1333 |
logits = F.linear(hidden_states, norm_weight, None)
|
| 1334 |
else:
|
| 1335 |
-
self.lm_head.weight.data = (
|
| 1336 |
-
|
| 1337 |
-
|
|
|
|
| 1338 |
logits = F.linear(hidden_states, self.lm_head.weight.data, None)
|
| 1339 |
self.norm_head = False
|
| 1340 |
else:
|
|
@@ -1380,7 +1381,11 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
|
|
| 1380 |
if isinstance(past_key_values, Cache):
|
| 1381 |
cache_length = past_key_values.get_seq_length()
|
| 1382 |
past_length = past_key_values.seen_tokens
|
| 1383 |
-
max_cache_length =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1384 |
else:
|
| 1385 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 1386 |
max_cache_length = None
|
|
|
|
| 117 |
hidden_states = hidden_states.to(torch.float32)
|
| 118 |
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 119 |
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 120 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 121 |
|
|
|
|
| 122 |
|
| 123 |
ALL_LAYERNORM_LAYERS.append(BailingMoeRMSNorm)
|
| 124 |
|
|
|
|
| 495 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 496 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 497 |
|
| 498 |
+
attn_weights = torch.matmul(query_states / math.sqrt(self.head_dim), key_states.transpose(2, 3))
|
| 499 |
|
| 500 |
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
| 501 |
raise ValueError(
|
|
|
|
| 825 |
dropout_p=self.attention_dropout if self.training else 0.0,
|
| 826 |
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
| 827 |
is_causal=self.is_causal and attention_mask is None and q_len > 1,
|
|
|
|
| 828 |
)
|
| 829 |
|
| 830 |
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
|
|
| 846 |
def __init__(self, config: BailingMoeConfig, layer_idx: int):
|
| 847 |
super().__init__()
|
| 848 |
self.hidden_size = config.hidden_size
|
| 849 |
+
|
| 850 |
self.attention = BAILING_MOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
| 851 |
|
| 852 |
self.mlp = (
|
|
|
|
| 1167 |
all_router_logits = () if output_router_logits else None
|
| 1168 |
next_decoder_cache = None
|
| 1169 |
|
| 1170 |
+
for decoder_layer in self.layers:
|
| 1171 |
if output_hidden_states:
|
| 1172 |
all_hidden_states += (hidden_states,)
|
| 1173 |
|
|
|
|
| 1332 |
)
|
| 1333 |
logits = F.linear(hidden_states, norm_weight, None)
|
| 1334 |
else:
|
| 1335 |
+
self.lm_head.weight.data = (
|
| 1336 |
+
self.lm_head.weight.data.float()
|
| 1337 |
+
/ (torch.norm(self.lm_head.weight.data.float(), p=2, dim=0, keepdim=True) + 1e-7)
|
| 1338 |
+
).to(hidden_states.dtype)
|
| 1339 |
logits = F.linear(hidden_states, self.lm_head.weight.data, None)
|
| 1340 |
self.norm_head = False
|
| 1341 |
else:
|
|
|
|
| 1381 |
if isinstance(past_key_values, Cache):
|
| 1382 |
cache_length = past_key_values.get_seq_length()
|
| 1383 |
past_length = past_key_values.seen_tokens
|
| 1384 |
+
max_cache_length = (
|
| 1385 |
+
past_key_values.get_max_length()
|
| 1386 |
+
if hasattr(past_key_values, "get_max_length")
|
| 1387 |
+
else past_key_values.get_max_cache_shape()
|
| 1388 |
+
)
|
| 1389 |
else:
|
| 1390 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 1391 |
max_cache_length = None
|