oweller2 committed on
Commit ·
0953ea5
1
Parent(s): 64c9f71
updates
Browse files- modeling_flexbert.py +7 -9
modeling_flexbert.py
CHANGED
|
@@ -64,14 +64,13 @@ from transformers.modeling_outputs import (
|
|
| 64 |
ModelOutput,
|
| 65 |
MultipleChoiceModelOutput,
|
| 66 |
SequenceClassifierOutput,
|
| 67 |
-
CausalLMOutput,
|
| 68 |
)
|
| 69 |
from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
| 70 |
-
|
| 71 |
from .bert_padding import index_put_first_axis
|
| 72 |
|
| 73 |
-
from .bert_layers.activation import get_act_fn
|
| 74 |
-
from .bert_layers.attention import (
|
| 75 |
FlexBertPaddedAttention,
|
| 76 |
FlexBertPaddedParallelAttention,
|
| 77 |
FlexBertPaddedRopeAttention,
|
|
@@ -81,15 +80,15 @@ from .bert_layers.attention import (
|
|
| 81 |
FlexBertUnpadRopeAttention,
|
| 82 |
FlexBertUnpadRopeParallelAttention,
|
| 83 |
)
|
| 84 |
-
from .bert_layers.configuration_bert import FlexBertConfig
|
| 85 |
-
from .bert_layers.embeddings import (
|
| 86 |
BertAlibiEmbeddings,
|
| 87 |
FlexBertAbsoluteEmbeddings,
|
| 88 |
FlexBertCompiledSansPositionEmbeddings,
|
| 89 |
FlexBertSansPositionEmbeddings,
|
| 90 |
get_embedding_layer,
|
| 91 |
)
|
| 92 |
-
from .bert_layers.initialization import (
|
| 93 |
ModuleType,
|
| 94 |
TileLinear,
|
| 95 |
TileMode,
|
|
@@ -98,7 +97,7 @@ from .bert_layers.initialization import (
|
|
| 98 |
tile_linear,
|
| 99 |
tile_norm,
|
| 100 |
)
|
| 101 |
-
from .bert_layers.layers import (
|
| 102 |
BertAlibiEncoder,
|
| 103 |
BertPooler,
|
| 104 |
BertPredictionHeadTransform,
|
|
@@ -113,7 +112,6 @@ from .bert_layers.layers import (
|
|
| 113 |
FlexBertUnpadPreNormLayer,
|
| 114 |
get_encoder_layer,
|
| 115 |
)
|
| 116 |
-
from .bert_layers.loss import get_loss_fn
|
| 117 |
from .mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
|
| 118 |
from .normalization import get_norm_layer
|
| 119 |
from .padding import pad_input, unpad_input
|
|
|
|
| 64 |
ModelOutput,
|
| 65 |
MultipleChoiceModelOutput,
|
| 66 |
SequenceClassifierOutput,
|
|
|
|
| 67 |
)
|
| 68 |
from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
| 69 |
+
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
|
| 70 |
from .bert_padding import index_put_first_axis
|
| 71 |
|
| 72 |
+
from .activation import get_act_fn
|
| 73 |
+
from .attention import (
|
| 74 |
FlexBertPaddedAttention,
|
| 75 |
FlexBertPaddedParallelAttention,
|
| 76 |
FlexBertPaddedRopeAttention,
|
|
|
|
| 80 |
FlexBertUnpadRopeAttention,
|
| 81 |
FlexBertUnpadRopeParallelAttention,
|
| 82 |
)
|
| 83 |
+
from .configuration_bert import FlexBertConfig
|
| 84 |
+
from .embeddings import (
|
| 85 |
BertAlibiEmbeddings,
|
| 86 |
FlexBertAbsoluteEmbeddings,
|
| 87 |
FlexBertCompiledSansPositionEmbeddings,
|
| 88 |
FlexBertSansPositionEmbeddings,
|
| 89 |
get_embedding_layer,
|
| 90 |
)
|
| 91 |
+
from .initialization import (
|
| 92 |
ModuleType,
|
| 93 |
TileLinear,
|
| 94 |
TileMode,
|
|
|
|
| 97 |
tile_linear,
|
| 98 |
tile_norm,
|
| 99 |
)
|
| 100 |
+
from .layers import (
|
| 101 |
BertAlibiEncoder,
|
| 102 |
BertPooler,
|
| 103 |
BertPredictionHeadTransform,
|
|
|
|
| 112 |
FlexBertUnpadPreNormLayer,
|
| 113 |
get_encoder_layer,
|
| 114 |
)
|
|
|
|
| 115 |
from .mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
|
| 116 |
from .normalization import get_norm_layer
|
| 117 |
from .padding import pad_input, unpad_input
|