dongseokmotif committed on
Commit a03ff9a · verified · 1 Parent(s): ddfcf5f

Update modeling_motif.py

Files changed (1)
  1. modeling_motif.py +2 -6
modeling_motif.py CHANGED
@@ -38,7 +38,6 @@ if is_flash_attn_2_available():
 
     import einops
 
-MorehFlashAttention = None
 try:
     kernelRMSNorm = activation.layers.RMSNorm
     PolyNormKernel = activation.layers.PolyNorm
@@ -46,7 +45,7 @@ try:
 except AttributeError:
     kernelRMSNorm = None
     PolyNormKernel = None
-    logger.warning_once("Failed to import moreh ops")
+    logger.warning_once("Failed to import kernel ops")
 
 _CONFIG_FOR_DOC = "MotifConfig"
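For context on the hunk above: when the fused activation kernels cannot be imported, kernelRMSNorm and PolyNormKernel fall back to None and a warning is logged, so callers are expected to drop back to a plain PyTorch implementation. A minimal sketch of such a fallback, assuming a standard RMSNorm formulation; the MotifRMSNorm name and the final selection line are illustrative, not taken from this file:

import torch
from torch import nn

class MotifRMSNorm(nn.Module):
    # Plain-PyTorch RMSNorm used when no fused kernel is importable (illustrative).
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)

# Assumed usage: prefer the fused kernel when the import succeeded, else the fallback.
norm_cls = kernelRMSNorm if kernelRMSNorm is not None else MotifRMSNorm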
 
@@ -617,7 +616,7 @@ class MotifFlashAttention2(MotifAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
-        if input_dtype == torch.float32 and MorehFlashAttention is None:
+        if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
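The block this hunk touches follows the usual flash-attention dtype handling: if the hidden states were silently upcast to float32 (for example by autocast or a float32 layer norm), they are cast back to a half-precision dtype before the kernel call; the change simply drops the extra MorehFlashAttention guard. A minimal sketch of how the target dtype is typically chosen, assuming a _pre_quantization_dtype config attribute and a q_proj weight as in other Hugging Face attention modules (both are assumptions here, not read from this file):

import torch

def pick_flash_attn_dtype(module) -> torch.dtype:
    # Prefer the active autocast dtype, then the pre-quantization dtype,
    # then the dtype of the module's own projection weights.
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    if hasattr(module.config, "_pre_quantization_dtype"):
        return module.config._pre_quantization_dtype
    return module.q_proj.weight.dtype

# Illustrative use inside the attention forward:
# if query_states.dtype == torch.float32:
#     target_dtype = pick_flash_attn_dtype(self)
#     query_states = query_states.to(target_dtype)
#     key_states = key_states.to(target_dtype)
#     value_states = value_states.to(target_dtype)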
@@ -648,7 +647,6 @@ class MotifFlashAttention2(MotifAttention):
             self.config.use_sliding_window
             and getattr(self.config, "sliding_window", None) is not None
             and self.layer_idx >= self.config.max_window_layers
-            and MorehFlashAttention is None
         ):
             sliding_window = self.config.sliding_window
         else:
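With the MorehFlashAttention clause removed, the sliding-window decision depends only on the config and the layer index: layers at or above max_window_layers use the configured window, earlier layers attend over the full context. A standalone sketch of that per-layer decision, mirroring the fields shown in the hunk:

from typing import Optional

def layer_sliding_window(config, layer_idx: int) -> Optional[int]:
    # Window size for this layer, or None for full attention.
    if (
        config.use_sliding_window
        and getattr(config, "sliding_window", None) is not None
        and layer_idx >= config.max_window_layers
    ):
        return config.sliding_window
    return None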
@@ -1241,8 +1239,6 @@ class MotifModel(MotifPreTrainedModel):
         output_attentions: bool,
     ):
         if self.config._attn_implementation == "flash_attention_2":
-            if MorehFlashAttention is not None:
-                return attention_mask
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None
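After dropping the MorehFlashAttention branch, the flash_attention_2 path keeps only the standard shortcut: a mask that actually marks padded positions (contains a 0.0) is passed through, otherwise None is returned so the kernel runs as plain causal attention. A quick illustration of the resulting behavior, with made-up tensor values:

import torch

def fa2_mask(attention_mask):
    # Mirrors the updated branch: keep the mask only if it marks padding.
    if attention_mask is not None and 0.0 in attention_mask:
        return attention_mask
    return None

padded = torch.tensor([[1.0, 1.0, 0.0, 0.0]])  # batch with padding -> mask is kept
full = torch.tensor([[1.0, 1.0, 1.0, 1.0]])    # no padding -> None, pure causal path
assert fa2_mask(padded) is padded
assert fa2_mask(full) is None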
 