Update modeling_motif.py
modeling_motif.py CHANGED (+5, -6)
@@ -608,7 +608,7 @@ class MotifFlashAttention2(MotifAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
-        if input_dtype == torch.float32
+        if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
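This hunk restores the colon dropped from the `if input_dtype == torch.float32` guard, a syntax error in the previous revision. For context, here is a minimal, self-contained sketch of the cast-back pattern the guard implements; the helper name and the `torch.float16` fallback are illustrative, not taken from the file:

import torch

def cast_for_flash_attention(x: torch.Tensor, fallback: torch.dtype = torch.float16) -> torch.Tensor:
    # FlashAttention kernels expect fp16/bf16 inputs, so fp32 activations
    # (e.g. silently upcast by a preceding layernorm) are cast back down.
    if x.dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()  # autocast's half-precision dtype
        else:
            target_dtype = fallback  # illustrative default; the real code also handles quantized models
        return x.to(target_dtype)
    return x

query_states = torch.randn(1, 8, 128, 64)            # fp32 by default
print(cast_for_flash_attention(query_states).dtype)  # torch.float16 outside autocast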
@@ -635,7 +635,7 @@ class MotifFlashAttention2(MotifAttention):
         value_states = value_states.transpose(1, 2)
 
         if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None
-                and self.layer_idx >= self.config.max_window_layers
+                and self.layer_idx >= self.config.max_window_layers):
             sliding_window = self.config.sliding_window
         else:
             sliding_window = None
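The second hunk closes the multi-line condition (the `):` was missing), which decides per layer whether FlashAttention receives a bounded attention window. A self-contained sketch of the rule it encodes, with plain function arguments standing in for the config and attribute lookups:

from typing import Optional

def resolve_sliding_window(use_sliding_window: bool,
                           sliding_window: Optional[int],
                           max_window_layers: int,
                           layer_idx: int) -> Optional[int]:
    # Layers at index >= max_window_layers attend within a bounded window;
    # earlier layers keep full (global) attention.
    if (use_sliding_window and sliding_window is not None
            and layer_idx >= max_window_layers):
        return sliding_window
    return None

for layer_idx in range(4):
    print(layer_idx, resolve_sliding_window(True, 4096, 2, layer_idx))
# 0 None
# 1 None
# 2 4096
# 3 4096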
@@ -789,8 +789,7 @@ class MotifDecoderLayer(nn.Module):
     def __init__(self, config: MotifConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
-
-        config._attn_implementation = "flash_attention_2"
+
         if config.sliding_window and config._attn_implementation != "flash_attention_2":
             logger.warning_once(
                 f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
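Dropping the hard-coded `config._attn_implementation = "flash_attention_2"` assignment means the layer no longer overrides the backend chosen at load time, so the sliding-window warning just below it can actually fire for non-FlashAttention backends. A hedged usage sketch; the checkpoint id is illustrative, and any checkpoint shipping this `modeling_motif.py` would behave the same way:

from transformers import AutoModelForCausalLM

# With the override removed, the backend requested here is honored instead
# of being silently forced to FlashAttention 2.
model = AutoModelForCausalLM.from_pretrained(
    "Motif-Technologies/Motif-2.6B",   # illustrative checkpoint id
    attn_implementation="sdpa",        # or "eager" / "flash_attention_2"
    trust_remote_code=True,            # the Motif modeling code lives in the repo
)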
@@ -801,7 +800,7 @@ class MotifDecoderLayer(nn.Module):
         self.self_attn = MOTIF_ATTENTION_CLASSES["eager"](config, layer_idx)
         self.mlp = MotifMLP(config)
 
-        RMSNorm =
+        RMSNorm = MotifRMSNorm
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
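The previous revision left `RMSNorm =` dangling, another syntax error; this hunk and the matching one in `MotifModel` below alias it to `MotifRMSNorm`, which is defined earlier in the file. Assuming it follows the standard Llama/T5-style formulation, a minimal reference implementation looks like this (the class name here is illustrative):

import torch
from torch import nn

class RMSNormSketch(nn.Module):
    # Standard RMSNorm: scale features by the reciprocal root mean square;
    # the variance is computed in fp32 for numerical stability.
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)

norm = RMSNormSketch(hidden_size=16)
print(norm(torch.randn(2, 16)).shape)  # torch.Size([2, 16])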
@@ -1055,7 +1054,7 @@ class MotifModel(MotifPreTrainedModel):
             MotifDecoderLayer(config = config, layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)
         ])
         self._attn_implementation = config._attn_implementation
-        RMSNorm =
+        RMSNorm = MotifRMSNorm
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads