Upload modeling_esm_plusplus.py with huggingface_hub
Browse files — modeling_esm_plusplus.py (+1 line, -2 lines)
modeling_esm_plusplus.py
CHANGED
|
@@ -35,7 +35,6 @@ except ImportError:
|
|
| 35 |
from embedding_mixin import EmbeddingMixin, Pooler
|
| 36 |
|
| 37 |
|
| 38 |
-
|
| 39 |
def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
|
| 40 |
assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
|
| 41 |
token_valid = attention_mask_2d.bool()
|
|
@@ -823,7 +822,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixi
|
|
| 823 |
if 'pooling_types' in kwargs and isinstance(kwargs['pooling_types'], List[str]) and len(kwargs['pooling_types']) > 0:
|
| 824 |
pooling_types = kwargs['pooling_types']
|
| 825 |
else:
|
| 826 |
-
pooling_types = ['
|
| 827 |
self.pooler = Pooler(pooling_types)
|
| 828 |
self.init_weights()
|
| 829 |
|
|
|
|
| 35 |
from embedding_mixin import EmbeddingMixin, Pooler
|
| 36 |
|
| 37 |
|
|
|
|
| 38 |
def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
|
| 39 |
assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
|
| 40 |
token_valid = attention_mask_2d.bool()
|
|
|
|
| 822 |
if 'pooling_types' in kwargs and isinstance(kwargs['pooling_types'], List[str]) and len(kwargs['pooling_types']) > 0:
|
| 823 |
pooling_types = kwargs['pooling_types']
|
| 824 |
else:
|
| 825 |
+
pooling_types = ['mean', 'var']
|
| 826 |
self.pooler = Pooler(pooling_types)
|
| 827 |
self.init_weights()
|
| 828 |
|