Update modeling_motif.py
modeling_motif.py  CHANGED  +2 -6
@@ -38,7 +38,6 @@ if is_flash_attn_2_available():

import einops

-MorehFlashAttention = None
try:
    kernelRMSNorm = activation.layers.RMSNorm
    PolyNormKernel = activation.layers.PolyNorm
@@ -46,7 +45,7 @@ try:
except AttributeError:
    kernelRMSNorm = None
    PolyNormKernel = None
-    logger.warning_once("Failed to import
+    logger.warning_once("Failed to import kernel ops")

_CONFIG_FOR_DOC = "MotifConfig"
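The two hunks above drop the module-level MorehFlashAttention flag and fix the warning in the optional-kernel fallback. For context, a minimal sketch of that import-fallback pattern, assuming a hypothetical `activation` package and plain `logging` in place of transformers' `logger.warning_once`:

```python
import logging

logger = logging.getLogger(__name__)

kernelRMSNorm = None
PolyNormKernel = None
try:
    # `activation` is a hypothetical fused-kernel package standing in
    # for whatever the real file imports above this block.
    import activation
    kernelRMSNorm = activation.layers.RMSNorm
    PolyNormKernel = activation.layers.PolyNorm
except (ImportError, AttributeError):
    # Fall back to the pure-PyTorch implementations defined elsewhere.
    logger.warning("Failed to import kernel ops")
```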
@@ -617,7 +616,7 @@ class MotifFlashAttention2(MotifAttention):
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
-        if input_dtype == torch.float32
+        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
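This hunk restores the float32 guard around the cast-back logic. A standalone sketch of that dtype selection, where the helper name and the `config` fallback are assumptions rather than the file's actual structure:

```python
import torch

def resolve_fa2_dtype(input_dtype: torch.dtype, config=None) -> torch.dtype:
    # Flash-attention kernels only support fp16/bf16, so float32 inputs
    # must be cast back to a half-precision dtype before the kernel call.
    if input_dtype != torch.float32:
        return input_dtype
    if torch.is_autocast_enabled():
        # Respect the active autocast dtype (e.g. torch.bfloat16).
        return torch.get_autocast_gpu_dtype()
    # Handle the case where the model is quantized or a dtype was set
    # explicitly on the config (attribute name is an assumption).
    if config is not None and getattr(config, "torch_dtype", None) is not None:
        return config.torch_dtype
    return torch.float16
```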
@@ -648,7 +647,6 @@ class MotifFlashAttention2(MotifAttention):
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
-            and MorehFlashAttention is None
        ):
            sliding_window = self.config.sliding_window
        else:
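With the `MorehFlashAttention is None` guard removed, the sliding-window choice depends only on the config and the layer index. An illustrative standalone version of the post-change condition (the function name is an assumption):

```python
from types import SimpleNamespace

def pick_sliding_window(config, layer_idx):
    # Only layers at or beyond max_window_layers use the sliding window;
    # earlier layers attend over the full sequence.
    if (
        config.use_sliding_window
        and getattr(config, "sliding_window", None) is not None
        and layer_idx >= config.max_window_layers
    ):
        return config.sliding_window
    return None

cfg = SimpleNamespace(use_sliding_window=True, sliding_window=4096, max_window_layers=16)
assert pick_sliding_window(cfg, layer_idx=20) == 4096
assert pick_sliding_window(cfg, layer_idx=3) is None
```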
@@ -1241,8 +1239,6 @@ class MotifModel(MotifPreTrainedModel):
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
-            if MorehFlashAttention is not None:
-                return attention_mask
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None
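The removed branch passed the mask through unconditionally whenever MorehFlashAttention was available; after this change, flash_attention_2 always takes the padding check. A sketch of the remaining logic (the function name is assumed, not taken from the file):

```python
import torch

def fa2_attention_mask(attention_mask, attn_implementation):
    # flash_attention_2 only needs an explicit mask when real padding
    # (a 0.0 entry) is present; returning None selects the fast causal path.
    if attn_implementation == "flash_attention_2":
        if attention_mask is not None and 0.0 in attention_mask:
            return attention_mask
        return None
    return attention_mask

mask = torch.tensor([[1.0, 1.0, 0.0]])
assert fa2_attention_mask(mask, "flash_attention_2") is mask
assert fa2_attention_mask(torch.ones(1, 3), "flash_attention_2") is None
```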
|