Upload modeling_fastesm.py with huggingface_hub
Browse files- modeling_fastesm.py +10 -10
modeling_fastesm.py
CHANGED
|
@@ -756,8 +756,8 @@ class FastEsmPreTrainedModel(PreTrainedModel):
|
|
| 756 |
|
| 757 |
|
| 758 |
class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
|
| 759 |
-
def __init__(self, config, add_pooling_layer: Optional[bool] = True):
|
| 760 |
-
|
| 761 |
self.config = config
|
| 762 |
self.embeddings = EsmEmbeddings(config)
|
| 763 |
self.encoder = EsmEncoder(config)
|
|
@@ -864,8 +864,8 @@ class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
| 864 |
|
| 865 |
|
| 866 |
class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
|
| 867 |
-
def __init__(self, config, add_pooling_layer: Optional[bool] = True):
|
| 868 |
-
|
| 869 |
self.config = config
|
| 870 |
self.esm = FAST_ESM_ENCODER(config)
|
| 871 |
self.pooler = EsmPooler(config) if add_pooling_layer else None
|
|
@@ -942,8 +942,8 @@ class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
| 942 |
class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
|
| 943 |
_tied_weights_keys = ["lm_head.decoder.weight"]
|
| 944 |
|
| 945 |
-
def __init__(self, config):
|
| 946 |
-
|
| 947 |
self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
|
| 948 |
self.lm_head = EsmLMHead(config)
|
| 949 |
self.loss_fct = nn.CrossEntropyLoss()
|
|
@@ -998,8 +998,8 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
| 998 |
|
| 999 |
|
| 1000 |
class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
|
| 1001 |
-
def __init__(self, config):
|
| 1002 |
-
|
| 1003 |
self.num_labels = config.num_labels
|
| 1004 |
self.config = config
|
| 1005 |
self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
|
|
@@ -1067,8 +1067,8 @@ class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
| 1067 |
|
| 1068 |
|
| 1069 |
class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin):
|
| 1070 |
-
def __init__(self, config):
|
| 1071 |
-
|
| 1072 |
self.num_labels = config.num_labels
|
| 1073 |
self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
|
| 1074 |
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
|
|
|
| 756 |
|
| 757 |
|
| 758 |
class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
|
| 759 |
+
def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
|
| 760 |
+
FastEsmPreTrainedModel.__init__(self, config, **kwargs)
|
| 761 |
self.config = config
|
| 762 |
self.embeddings = EsmEmbeddings(config)
|
| 763 |
self.encoder = EsmEncoder(config)
|
|
|
|
| 864 |
|
| 865 |
|
| 866 |
class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
|
| 867 |
+
def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
|
| 868 |
+
FastEsmPreTrainedModel.__init__(self, config, **kwargs)
|
| 869 |
self.config = config
|
| 870 |
self.esm = FAST_ESM_ENCODER(config)
|
| 871 |
self.pooler = EsmPooler(config) if add_pooling_layer else None
|
|
|
|
| 942 |
class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
|
| 943 |
_tied_weights_keys = ["lm_head.decoder.weight"]
|
| 944 |
|
| 945 |
+
def __init__(self, config, **kwargs):
|
| 946 |
+
FastEsmPreTrainedModel.__init__(self, config, **kwargs)
|
| 947 |
self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
|
| 948 |
self.lm_head = EsmLMHead(config)
|
| 949 |
self.loss_fct = nn.CrossEntropyLoss()
|
|
|
|
| 998 |
|
| 999 |
|
| 1000 |
class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
|
| 1001 |
+
def __init__(self, config, **kwargs):
|
| 1002 |
+
FastEsmPreTrainedModel.__init__(self, config, **kwargs)
|
| 1003 |
self.num_labels = config.num_labels
|
| 1004 |
self.config = config
|
| 1005 |
self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
|
|
|
|
| 1067 |
|
| 1068 |
|
| 1069 |
class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin):
|
| 1070 |
+
def __init__(self, config, **kwargs):
|
| 1071 |
+
FastEsmPreTrainedModel.__init__(self, config, **kwargs)
|
| 1072 |
self.num_labels = config.num_labels
|
| 1073 |
self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
|
| 1074 |
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|