Eunhwan Park committed on
Commit
4293a01
·
verified ·
1 Parent(s): 0ff3917

Update modeling_motif.py

Browse files

Remove MorehFlashAttention

Files changed (1) hide show
  1. modeling_motif.py +2 -6
modeling_motif.py CHANGED
@@ -63,10 +63,8 @@ if is_flash_attn_2_available():
63
 
64
  try:
65
  moreh_ops = torch.ops.moreh
66
- MorehFlashAttention = moreh_ops.flash_attention
67
  logger.warning_once("Using moreh ops")
68
  except AttributeError:
69
- MorehFlashAttention = None
70
  logger.warning_once("Failed to import moreh ops")
71
 
72
  #_CHECKPOINT_FOR_DOC = "moreh/Motif-102B"
@@ -597,7 +595,7 @@ class MotifFlashAttention2(MotifAttention):
597
  # therefore the input hidden states gets silently casted in float32. Hence, we need
598
  # cast them back in float16 just to be sure everything works as expected.
599
  input_dtype = query_states.dtype
600
- if input_dtype == torch.float32 and MorehFlashAttention is None:
601
  if torch.is_autocast_enabled():
602
  target_dtype = torch.get_autocast_gpu_dtype()
603
  # Handle the case where the model is quantized
@@ -624,7 +622,7 @@ class MotifFlashAttention2(MotifAttention):
624
  value_states = value_states.transpose(1, 2)
625
 
626
  if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None
627
- and self.layer_idx >= self.config.max_window_layers and MorehFlashAttention is None):
628
  sliding_window = self.config.sliding_window
629
  else:
630
  sliding_window = None
@@ -1177,8 +1175,6 @@ class MotifModel(MotifPreTrainedModel):
1177
  output_attentions: bool,
1178
  ):
1179
  if self.config._attn_implementation == "flash_attention_2":
1180
- if MorehFlashAttention is not None:
1181
- return attention_mask
1182
  if attention_mask is not None and 0.0 in attention_mask:
1183
  return attention_mask
1184
  return None
 
63
 
64
  try:
65
  moreh_ops = torch.ops.moreh
 
66
  logger.warning_once("Using moreh ops")
67
  except AttributeError:
 
68
  logger.warning_once("Failed to import moreh ops")
69
 
70
  #_CHECKPOINT_FOR_DOC = "moreh/Motif-102B"
 
595
  # therefore the input hidden states gets silently casted in float32. Hence, we need
596
  # cast them back in float16 just to be sure everything works as expected.
597
  input_dtype = query_states.dtype
598
+ if input_dtype == torch.float32:
599
  if torch.is_autocast_enabled():
600
  target_dtype = torch.get_autocast_gpu_dtype()
601
  # Handle the case where the model is quantized
 
622
  value_states = value_states.transpose(1, 2)
623
 
624
  if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None
625
+ and self.layer_idx >= self.config.max_window_layers):
626
  sliding_window = self.config.sliding_window
627
  else:
628
  sliding_window = None
 
1175
  output_attentions: bool,
1176
  ):
1177
  if self.config._attn_implementation == "flash_attention_2":
 
 
1178
  if attention_mask is not None and 0.0 in attention_mask:
1179
  return attention_mask
1180
  return None