oweller2 commited on
Commit ·
e44547d
1
Parent(s): 6aca308
add modeling
Browse files- modeling_flexbert.py +12 -12
modeling_flexbert.py
CHANGED
|
@@ -68,10 +68,10 @@ from transformers.modeling_outputs import (
|
|
| 68 |
)
|
| 69 |
from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
| 70 |
|
| 71 |
-
from bert_padding import index_put_first_axis
|
| 72 |
|
| 73 |
-
from src.bert_layers.activation import get_act_fn
|
| 74 |
-
from src.bert_layers.attention import (
|
| 75 |
FlexBertPaddedAttention,
|
| 76 |
FlexBertPaddedParallelAttention,
|
| 77 |
FlexBertPaddedRopeAttention,
|
|
@@ -81,15 +81,15 @@ from src.bert_layers.attention import (
|
|
| 81 |
FlexBertUnpadRopeAttention,
|
| 82 |
FlexBertUnpadRopeParallelAttention,
|
| 83 |
)
|
| 84 |
-
from src.bert_layers.configuration_bert import FlexBertConfig
|
| 85 |
-
from src.bert_layers.embeddings import (
|
| 86 |
BertAlibiEmbeddings,
|
| 87 |
FlexBertAbsoluteEmbeddings,
|
| 88 |
FlexBertCompiledSansPositionEmbeddings,
|
| 89 |
FlexBertSansPositionEmbeddings,
|
| 90 |
get_embedding_layer,
|
| 91 |
)
|
| 92 |
-
from src.bert_layers.initialization import (
|
| 93 |
ModuleType,
|
| 94 |
TileLinear,
|
| 95 |
TileMode,
|
|
@@ -98,7 +98,7 @@ from src.bert_layers.initialization import (
|
|
| 98 |
tile_linear,
|
| 99 |
tile_norm,
|
| 100 |
)
|
| 101 |
-
from src.bert_layers.layers import (
|
| 102 |
BertAlibiEncoder,
|
| 103 |
BertPooler,
|
| 104 |
BertPredictionHeadTransform,
|
|
@@ -113,10 +113,10 @@ from src.bert_layers.layers import (
|
|
| 113 |
FlexBertUnpadPreNormLayer,
|
| 114 |
get_encoder_layer,
|
| 115 |
)
|
| 116 |
-
from src.bert_layers.loss import get_loss_fn
|
| 117 |
-
from src.bert_layers.mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
|
| 118 |
-
from src.bert_layers.normalization import get_norm_layer
|
| 119 |
-
from src.bert_layers.padding import pad_input, unpad_input
|
| 120 |
|
| 121 |
logger = logging.getLogger(__name__)
|
| 122 |
|
|
@@ -868,7 +868,7 @@ class FlexBertPreTrainedModel(BertPreTrainedModel):
|
|
| 868 |
|
| 869 |
def _init_module_weights(self, module: nn.Module):
|
| 870 |
"""
|
| 871 |
-
Custom weight init of modules using src.bert_layers.initialization.init_weights
|
| 872 |
Currently only supports init of embedding modules
|
| 873 |
"""
|
| 874 |
assert isinstance(module, nn.Module)
|
|
|
|
| 68 |
)
|
| 69 |
from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
| 70 |
|
| 71 |
+
from .bert_padding import index_put_first_axis
|
| 72 |
|
| 73 |
+
from .bert_layers.activation import get_act_fn
|
| 74 |
+
from .bert_layers.attention import (
|
| 75 |
FlexBertPaddedAttention,
|
| 76 |
FlexBertPaddedParallelAttention,
|
| 77 |
FlexBertPaddedRopeAttention,
|
|
|
|
| 81 |
FlexBertUnpadRopeAttention,
|
| 82 |
FlexBertUnpadRopeParallelAttention,
|
| 83 |
)
|
| 84 |
+
from .bert_layers.configuration_bert import FlexBertConfig
|
| 85 |
+
from .bert_layers.embeddings import (
|
| 86 |
BertAlibiEmbeddings,
|
| 87 |
FlexBertAbsoluteEmbeddings,
|
| 88 |
FlexBertCompiledSansPositionEmbeddings,
|
| 89 |
FlexBertSansPositionEmbeddings,
|
| 90 |
get_embedding_layer,
|
| 91 |
)
|
| 92 |
+
from .bert_layers.initialization import (
|
| 93 |
ModuleType,
|
| 94 |
TileLinear,
|
| 95 |
TileMode,
|
|
|
|
| 98 |
tile_linear,
|
| 99 |
tile_norm,
|
| 100 |
)
|
| 101 |
+
from .bert_layers.layers import (
|
| 102 |
BertAlibiEncoder,
|
| 103 |
BertPooler,
|
| 104 |
BertPredictionHeadTransform,
|
|
|
|
| 113 |
FlexBertUnpadPreNormLayer,
|
| 114 |
get_encoder_layer,
|
| 115 |
)
|
| 116 |
+
from .bert_layers.loss import get_loss_fn
|
| 117 |
+
from .bert_layers.mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
|
| 118 |
+
from .bert_layers.normalization import get_norm_layer
|
| 119 |
+
from .bert_layers.padding import pad_input, unpad_input
|
| 120 |
|
| 121 |
logger = logging.getLogger(__name__)
|
| 122 |
|
|
|
|
| 868 |
|
| 869 |
def _init_module_weights(self, module: nn.Module):
|
| 870 |
"""
|
| 871 |
+
Custom weight init of modules using .bert_layers.initialization.init_weights
|
| 872 |
Currently only supports init of embedding modules
|
| 873 |
"""
|
| 874 |
assert isinstance(module, nn.Module)
|