leejunhyeok committed
Commit 38eae03 · verified · 1 Parent(s): bdd0329

Update modeling_motif.py

Files changed (1)
  1. modeling_motif.py +5 -6
modeling_motif.py CHANGED
@@ -608,7 +608,7 @@ class MotifFlashAttention2(MotifAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
-        if input_dtype == torch.float32 and MorehFlashAttention is None:
+        if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
@@ -635,7 +635,7 @@ class MotifFlashAttention2(MotifAttention):
         value_states = value_states.transpose(1, 2)
 
         if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None
-                and self.layer_idx >= self.config.max_window_layers and MorehFlashAttention is None):
+                and self.layer_idx >= self.config.max_window_layers):
             sliding_window = self.config.sliding_window
         else:
             sliding_window = None
@@ -789,8 +789,7 @@ class MotifDecoderLayer(nn.Module):
     def __init__(self, config: MotifConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
-        if config.use_moreh_attention:
-            config._attn_implementation = "flash_attention_2"
+
         if config.sliding_window and config._attn_implementation != "flash_attention_2":
             logger.warning_once(
                 f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
@@ -801,7 +800,7 @@ class MotifDecoderLayer(nn.Module):
         self.self_attn = MOTIF_ATTENTION_CLASSES["eager"](config, layer_idx)
         self.mlp = MotifMLP(config)
 
-        RMSNorm = MorehRMSNorm if MorehRMSNorm is not None else MotifRMSNorm
+        RMSNorm = MotifRMSNorm
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -1055,7 +1054,7 @@ class MotifModel(MotifPreTrainedModel):
             MotifDecoderLayer(config = config, layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)
         ])
         self._attn_implementation = config._attn_implementation
-        RMSNorm = MorehRMSNorm if MorehRMSNorm is not None else MotifRMSNorm
+        RMSNorm = MotifRMSNorm
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
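
Note: with the MorehRMSNorm fallback removed, the decoder layers and the final norm always use MotifRMSNorm. Its definition is not part of this diff; the sketch below is a standard Llama-style RMSNorm given as a stand-in, assuming MotifRMSNorm follows the same pattern (the class name here is illustrative, not taken from the file).

import torch
import torch.nn as nn

class IllustrativeRMSNorm(nn.Module):
    # Hypothetical stand-in for MotifRMSNorm; standard HF-style RMS layer norm.
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Compute the norm in float32 for numerical stability.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Scale by the learned weight and cast back to the caller's dtype.
        return self.weight * hidden_states.to(input_dtype)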
 
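Note: dropping the `MorehFlashAttention is None` guard in the hunk at line 608 means the float32 re-cast now runs unconditionally before flash attention. The hunk's context cuts off mid-branch; the sketch below fills in the usual Transformers ordering (autocast dtype, then pre-quantization dtype, then weight dtype) only as an assumption about how that block continues. The helper name and its `config` argument are illustrative.

import torch

def pick_flash_attn_dtype(module, config=None):
    # Illustrative helper: choose the dtype to re-cast fp32 states to
    # before calling flash attention, mirroring the usual Transformers order.
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()  # autocast target dtype
    if config is not None and hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype  # quantized models keep their original dtype here
    return next(module.parameters()).dtype  # fall back to the weights' dtype

# Usage (sketch):
# q, k, v = (t.to(pick_flash_attn_dtype(self, self.config)) for t in (q, k, v))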