Eunhwan Park committed on
Update modeling_motif.py
Remove MorehFlashAttention

modeling_motif.py CHANGED (+2 -6)
@@ -63,10 +63,8 @@ if is_flash_attn_2_available():
 
 try:
     moreh_ops = torch.ops.moreh
-    MorehFlashAttention = moreh_ops.flash_attention
     logger.warning_once("Using moreh ops")
 except AttributeError:
-    MorehFlashAttention = None
     logger.warning_once("Failed to import moreh ops")
 
 #_CHECKPOINT_FOR_DOC = "moreh/Motif-102B"
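For context, the removed lines implemented op feature-detection against PyTorch's lazy `torch.ops` namespaces. A minimal runnable sketch of that pattern (the `custom_flash_attention` name is hypothetical; only the `moreh` namespace comes from the diff):

```python
import torch

# torch.ops.<namespace> is created lazily, so accessing torch.ops.moreh
# succeeds even when no extension is loaded; AttributeError is raised only
# when looking up an op (here flash_attention) that was never registered.
try:
    custom_flash_attention = torch.ops.moreh.flash_attention
except AttributeError:
    custom_flash_attention = None  # fall back to the stock flash-attn path

print("moreh kernel available:", custom_flash_attention is not None)
```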
@@ -597,7 +595,7 @@ class MotifFlashAttention2(MotifAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
-        if input_dtype == torch.float32 and MorehFlashAttention is None:
+        if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
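The surrounding code follows the usual Transformers flash-attention cast-back step: fp32 hidden states (e.g. silently upcast by fp32 layer norms) must be brought back to a half dtype before the kernel runs. A self-contained sketch under that assumption (`cast_for_flash_attn` is a hypothetical helper, not part of the model file):

```python
import torch

def cast_for_flash_attn(query_states, key_states, value_states,
                        default_dtype=torch.float16):
    # Flash-attention kernels only support fp16/bf16; pick the autocast
    # dtype when autocast is active, otherwise fall back to a default.
    if query_states.dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        else:
            target_dtype = default_dtype
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)
    return query_states, key_states, value_states
```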
@@ -624,7 +622,7 @@ class MotifFlashAttention2(MotifAttention):
         value_states = value_states.transpose(1, 2)
 
         if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None
-                and self.layer_idx >= self.config.max_window_layers and MorehFlashAttention is None):
+                and self.layer_idx >= self.config.max_window_layers):
             sliding_window = self.config.sliding_window
         else:
             sliding_window = None
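This condition is the Qwen2-style per-layer sliding-window gate: shallow layers keep full attention, and only layers at or above `max_window_layers` attend within a window. A hypothetical standalone version with a usage check:

```python
from typing import Optional

def resolve_sliding_window(use_sliding_window: bool,
                           sliding_window: Optional[int],
                           layer_idx: int,
                           max_window_layers: int) -> Optional[int]:
    # Returns the window size for this layer, or None for full attention.
    if (use_sliding_window and sliding_window is not None
            and layer_idx >= max_window_layers):
        return sliding_window
    return None

# e.g. with a 4096-token window starting at layer 28:
assert resolve_sliding_window(True, 4096, 30, 28) == 4096
assert resolve_sliding_window(True, 4096, 10, 28) is None
```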
@@ -1177,8 +1175,6 @@ class MotifModel(MotifPreTrainedModel):
         output_attentions: bool,
     ):
         if self.config._attn_implementation == "flash_attention_2":
-            if MorehFlashAttention is not None:
-                return attention_mask
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None
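With the Moreh short-circuit gone, mask preparation for `flash_attention_2` reduces to the padding check: the 2-D mask is forwarded only if it actually contains padded (0.0) positions, and `None` is returned otherwise so the kernel can skip unpadding. A sketch of that surviving branch as a hypothetical free function:

```python
import torch
from typing import Optional

def prepare_fa2_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # `0.0 in tensor` is an element-wise membership test, i.e.
    # (attention_mask == 0.0).any() -- true only when padding exists.
    if attention_mask is not None and 0.0 in attention_mask:
        return attention_mask
    return None

print(prepare_fa2_mask(torch.tensor([[1.0, 1.0, 0.0]])))  # returns the mask
print(prepare_fa2_mask(torch.tensor([[1.0, 1.0, 1.0]])))  # None
```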