fix issues
Browse files
myolmoe/modeling_myolmoe.py
CHANGED
|
@@ -16,9 +16,11 @@ from transformers.modeling_utils import PreTrainedModel
|
|
| 16 |
from transformers.models.olmoe.configuration_olmoe import OlmoeConfig
|
| 17 |
from transformers.utils import logging
|
| 18 |
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
|
|
|
| 19 |
|
| 20 |
logger = logging.get_logger(__name__)
|
| 21 |
|
|
|
|
| 22 |
|
| 23 |
def load_balancing_loss_func(
|
| 24 |
gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
|
|
|
|
| 16 |
from transformers.models.olmoe.configuration_olmoe import OlmoeConfig
|
| 17 |
from transformers.utils import logging
|
| 18 |
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 19 |
+
from transformers.generation.utils import GenerationMixin
|
| 20 |
|
| 21 |
logger = logging.get_logger(__name__)
|
| 22 |
|
| 23 |
+
ALL_LAYERNORM_LAYERS = []
|
| 24 |
|
| 25 |
def load_balancing_loss_func(
|
| 26 |
gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
|