Upload modeling_esm_plusplus.py with huggingface_hub
Browse files- modeling_esm_plusplus.py +5 -5
modeling_esm_plusplus.py
CHANGED
|
@@ -796,7 +796,7 @@ class ESMplusplusModel(PreTrainedESMplusplusModel, EmbeddingMixin):
|
|
| 796 |
"""
|
| 797 |
config_class = ESMplusplusConfig
|
| 798 |
def __init__(self, config: ESMplusplusConfig, **kwargs):
|
| 799 |
-
|
| 800 |
self.config = config
|
| 801 |
self.vocab_size = config.vocab_size
|
| 802 |
self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
|
|
@@ -849,7 +849,7 @@ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
|
|
| 849 |
"""
|
| 850 |
config_class = ESMplusplusConfig
|
| 851 |
def __init__(self, config: ESMplusplusConfig, **kwargs):
|
| 852 |
-
|
| 853 |
self.config = config
|
| 854 |
self.vocab_size = config.vocab_size
|
| 855 |
self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
|
|
@@ -923,7 +923,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixi
|
|
| 923 |
Extends the base ESM++ model with a classification head.
|
| 924 |
"""
|
| 925 |
def __init__(self, config: ESMplusplusConfig, **kwargs):
|
| 926 |
-
|
| 927 |
self.config = config
|
| 928 |
self.num_labels = config.num_labels
|
| 929 |
self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
|
|
@@ -1007,8 +1007,8 @@ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
|
|
| 1007 |
ESM++ model for token classification.
|
| 1008 |
Extends the base ESM++ model with a token classification head.
|
| 1009 |
"""
|
| 1010 |
-
def __init__(self, config: ESMplusplusConfig):
|
| 1011 |
-
|
| 1012 |
self.config = config
|
| 1013 |
self.num_labels = config.num_labels
|
| 1014 |
self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
|
|
|
|
| 796 |
"""
|
| 797 |
config_class = ESMplusplusConfig
|
| 798 |
def __init__(self, config: ESMplusplusConfig, **kwargs):
|
| 799 |
+
PreTrainedESMplusplusModel.__init__(self, config, **kwargs)
|
| 800 |
self.config = config
|
| 801 |
self.vocab_size = config.vocab_size
|
| 802 |
self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
|
|
|
|
| 849 |
"""
|
| 850 |
config_class = ESMplusplusConfig
|
| 851 |
def __init__(self, config: ESMplusplusConfig, **kwargs):
|
| 852 |
+
PreTrainedESMplusplusModel.__init__(self, config, **kwargs)
|
| 853 |
self.config = config
|
| 854 |
self.vocab_size = config.vocab_size
|
| 855 |
self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
|
|
|
|
| 923 |
Extends the base ESM++ model with a classification head.
|
| 924 |
"""
|
| 925 |
def __init__(self, config: ESMplusplusConfig, **kwargs):
|
| 926 |
+
ESMplusplusForMaskedLM.__init__(self, config, **kwargs)
|
| 927 |
self.config = config
|
| 928 |
self.num_labels = config.num_labels
|
| 929 |
self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
|
|
|
|
| 1007 |
ESM++ model for token classification.
|
| 1008 |
Extends the base ESM++ model with a token classification head.
|
| 1009 |
"""
|
| 1010 |
+
def __init__(self, config: ESMplusplusConfig, **kwargs):
|
| 1011 |
+
ESMplusplusForMaskedLM.__init__(self, config, **kwargs)
|
| 1012 |
self.config = config
|
| 1013 |
self.num_labels = config.num_labels
|
| 1014 |
self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
|