Update modeling_afmoe_scm_liger.py
Browse files
modeling_afmoe_scm_liger.py
CHANGED
|
@@ -25,7 +25,6 @@ from transformers.integrations import use_kernel_forward_from_hub
|
|
| 25 |
import scattermoe
|
| 26 |
|
| 27 |
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
| 28 |
-
from liger_kernel.transformers.rms_norm import LigerRMSNorm as AfmoeSCMRMSNorm
|
| 29 |
|
| 30 |
try:
|
| 31 |
from .configuration_afmoe_scm import AfmoeSCMConfig
|
|
@@ -141,8 +140,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
| 141 |
)
|
| 142 |
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 143 |
|
| 144 |
-
|
| 145 |
-
class
|
| 146 |
def __init__(self, hidden_size: int, eps: float):
|
| 147 |
"""
|
| 148 |
AfmoeSCMRMSNorm is equivalent to T5LayerNorm
|
|
|
|
| 25 |
import scattermoe
|
| 26 |
|
| 27 |
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
|
|
| 28 |
|
| 29 |
try:
|
| 30 |
from .configuration_afmoe_scm import AfmoeSCMConfig
|
|
|
|
| 140 |
)
|
| 141 |
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 142 |
|
| 143 |
+
@use_kernel_forward_from_hub("RMSNorm")
|
| 144 |
+
class AfmoeSCMRMSNorm(nn.Module):
|
| 145 |
def __init__(self, hidden_size: int, eps: float):
|
| 146 |
"""
|
| 147 |
AfmoeSCMRMSNorm is equivalent to T5LayerNorm
|