Update modeling_motif.py
modeling_motif.py CHANGED (+3 -5)
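
The commit drops the optional Moreh RMSNorm binding (moreh_ops.T5LayerNorm, plus its None fallback) from the moreh-ops block and points the remaining norm call sites (the module-level ALL_LAYERNORM_LAYERS registration, the per-layer norms in MotifDecoderLayer, and the final norm in MotifModel) at the MotifRMSNorm class defined in this file. The moreh scaled-dot-product-attention and flash-attention ops are left in place.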
--- a/modeling_motif.py
+++ b/modeling_motif.py
@@ -63,12 +63,10 @@ if is_flash_attn_2_available():
 
 try:
     moreh_ops = torch.ops.moreh
-    MorehRMSNorm = moreh_ops.T5LayerNorm
     ScaledDotProductAttention = moreh_ops.scaled_dot_product_attention
     MorehFlashAttention = moreh_ops.flash_attention
     logger.warning_once("Using moreh ops")
 except AttributeError:
-    MorehRMSNorm = None
     ScaledDotProductAttention = None
     MorehFlashAttention = None
     logger.warning_once("Failed to import moreh ops")
@@ -100,7 +98,7 @@ class MotifRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 
-ALL_LAYERNORM_LAYERS.append(MorehRMSNorm)
+ALL_LAYERNORM_LAYERS.append(MotifRMSNorm)
 
 
 class MotifRotaryEmbeddingWithCache(nn.Module):
@@ -813,7 +811,7 @@ class MotifDecoderLayer(nn.Module):
         self.self_attn = MOTIF_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
         self.mlp = MotifMLP(config)
 
-        RMSNorm = MorehRMSNorm
+        RMSNorm = MotifRMSNorm
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -1051,7 +1049,7 @@ class MotifModel(MotifPreTrainedModel):
         num_hidden_layers = config.num_hidden_layers if self.multi_token_heads is None else config.num_hidden_layers - 1
         self.layers = nn.ModuleList([MotifDecoderLayer(config = config, layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)])
         self._attn_implementation = config._attn_implementation
-        RMSNorm = MorehRMSNorm
+        RMSNorm = MotifRMSNorm
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
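
Context on the first hunk: custom operators registered under torch.ops only become visible once their extension library is loaded, and torch.ops.<namespace> itself always resolves to a namespace proxy, so it is the attribute lookup of an individual op that raises AttributeError. A minimal sketch of the same probe-and-fallback pattern, using a hypothetical my_ext namespace and fused_rms_norm op (placeholders, not names from this repo):

import torch

# Sketch of the probe-and-fallback pattern used above. "my_ext" and
# "fused_rms_norm" are hypothetical placeholders for an optional extension:
# torch.ops.my_ext always resolves to a namespace proxy, but reading an op
# that no loaded library has registered raises AttributeError.
try:
    ext_ops = torch.ops.my_ext
    fused_rms_norm = ext_ops.fused_rms_norm  # AttributeError if not registered
except AttributeError:
    fused_rms_norm = None  # sentinel for "extension unavailable"


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Use the fused op when present, otherwise a plain PyTorch fallback."""
    if fused_rms_norm is not None:
        return fused_rms_norm(x, weight, eps)
    # Fallback: compute statistics in float32, then cast back.
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return weight * (x.float() * torch.rsqrt(variance + eps)).to(x.dtype)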
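
With the Moreh binding gone, every norm in the model is the file's own MotifRMSNorm, whose forward pass is not part of this diff. For reference, the T5LayerNorm-style RMSNorm that moreh_ops.T5LayerNorm provided is conventionally implemented as follows; this is a sketch of the standard formulation, assumed rather than copied from MotifRMSNorm:

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Sketch of a T5LayerNorm-style RMSNorm: scale by the root mean square
    of the last dimension, with a learned weight and no mean subtraction or
    bias. Standard formulation, not the actual MotifRMSNorm source."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        # Compute statistics in float32 so half-precision inputs do not overflow.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self) -> str:
        # Mirrors the extra_repr visible in the second hunk above.
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

Registering the class in ALL_LAYERNORM_LAYERS, as the second hunk does, keeps its weight out of weight decay in the Transformers Trainer, the same treatment the stock LayerNorm layers get.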