eunhwanpark-motiftech committed
Commit 7ec81c6 · verified · Parent: ba7c576

Update modeling_motif.py

Files changed (1): modeling_motif.py (+3, -5)
modeling_motif.py CHANGED
@@ -63,12 +63,10 @@ if is_flash_attn_2_available():
 
 try:
     moreh_ops = torch.ops.moreh
-    MorehRMSNorm = moreh_ops.T5LayerNorm
     ScaledDotProductAttention = moreh_ops.scaled_dot_product_attention
     MorehFlashAttention = moreh_ops.flash_attention
     logger.warning_once("Using moreh ops")
 except AttributeError:
-    MorehRMSNorm = None
     ScaledDotProductAttention = None
     MorehFlashAttention = None
     logger.warning_once("Failed to import moreh ops")
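The None sentinels set in the except branch are what the rest of the file can check to fall back to stock PyTorch kernels. A minimal sketch of that dispatch pattern, assuming the fallback is torch.nn.functional.scaled_dot_product_attention; the moreh kernel's call signature is not shown in this diff, so the call below is a placeholder:

```python
import torch
import torch.nn.functional as F

MorehFlashAttention = None  # sentinel, as set by the except branch above

def attention(q, k, v):
    if MorehFlashAttention is not None:
        # Hypothetical call; the real moreh op's signature is not shown here.
        return MorehFlashAttention(q, k, v)
    # Pure-PyTorch fallback (available since torch 2.0).
    return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)
print(attention(q, k, v).shape)      # torch.Size([1, 2, 4, 8])
```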
@@ -100,7 +98,7 @@ class MotifRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 
-ALL_LAYERNORM_LAYERS.append(MotifRMSNorm if MorehRMSNorm is None else MorehRMSNorm)
+ALL_LAYERNORM_LAYERS.append(MotifRMSNorm)
 
 
 class MotifRotaryEmbeddingWithCache(nn.Module):
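With MorehRMSNorm gone, MotifRMSNorm is registered unconditionally. The practical effect of appearing in transformers' ALL_LAYERNORM_LAYERS is that Trainer excludes those layers' weights from weight decay. A small illustration using a hypothetical stand-in model, where nn.LayerNorm plays the role MotifRMSNorm takes after the append above:

```python
import torch.nn as nn
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names

# Stand-in model: nn.LayerNorm is already registered in ALL_LAYERNORM_LAYERS.
model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

# Trainer builds its weight-decay parameter group this way: parameters owned
# by registered norm layers (and biases) are excluded from decay.
decay = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
decay = [name for name in decay if "bias" not in name]
print(decay)  # ['0.weight'] -- the LayerNorm weight is left out
```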
@@ -813,7 +811,7 @@ class MotifDecoderLayer(nn.Module):
         self.self_attn = MOTIF_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
         self.mlp = MotifMLP(config)
 
-        RMSNorm = MorehRMSNorm if MorehRMSNorm is not None else MotifRMSNorm
+        RMSNorm = MotifRMSNorm
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
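The forward pass is not part of this diff, but the two norms constructed here follow the usual pre-norm decoder layout. A sketch under that assumption, with placeholder sublayers (requires torch >= 2.4 for nn.RMSNorm):

```python
import torch
from torch import nn

def pre_norm_block(hidden_states, self_attn, mlp, input_ln, post_attn_ln):
    # Each sublayer sees normalized input; its output is added to the residual.
    hidden_states = hidden_states + self_attn(input_ln(hidden_states))
    hidden_states = hidden_states + mlp(post_attn_ln(hidden_states))
    return hidden_states

h = torch.randn(1, 4, 16)
out = pre_norm_block(h, nn.Identity(), nn.Identity(), nn.RMSNorm(16), nn.RMSNorm(16))
print(out.shape)  # torch.Size([1, 4, 16])
```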
@@ -1051,7 +1049,7 @@ class MotifModel(MotifPreTrainedModel):
         num_hidden_layers = config.num_hidden_layers if self.multi_token_heads is None else config.num_hidden_layers - 1
         self.layers = nn.ModuleList([MotifDecoderLayer(config = config, layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)])
         self._attn_implementation = config._attn_implementation
-        RMSNorm = MorehRMSNorm if MorehRMSNorm is not None else MotifRMSNorm
+        RMSNorm = MotifRMSNorm
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
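The body of MotifRMSNorm is not shown in this diff, but the extra_repr above (weight shape plus variance_epsilon) and the dropped moreh_ops.T5LayerNorm alias both point at T5-style RMS normalization. A minimal sketch under that assumption:

```python
import torch
from torch import nn

class RMSNormSketch(nn.Module):
    """Assumed equivalent of MotifRMSNorm: x / sqrt(mean(x^2) + eps), scaled
    by a learned weight, as in T5's LayerNorm (no mean subtraction, no bias)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        # Accumulate the mean of squares in float32 to avoid half-precision overflow.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self) -> str:
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
```

Casting back to the input dtype before the weight multiply keeps the module transparent to mixed-precision training.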