lhallee commited on
Commit
d04407b
·
verified ·
1 Parent(s): 1d76746

Upload modeling_esm_plusplus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +1 -2
modeling_esm_plusplus.py CHANGED
@@ -35,7 +35,6 @@ except ImportError:
35
  from embedding_mixin import EmbeddingMixin, Pooler
36
 
37
 
38
-
39
  def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
40
  assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
41
  token_valid = attention_mask_2d.bool()
@@ -823,7 +822,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixi
823
  if 'pooling_types' in kwargs and isinstance(kwargs['pooling_types'], List[str]) and len(kwargs['pooling_types']) > 0:
824
  pooling_types = kwargs['pooling_types']
825
  else:
826
- pooling_types = ['cls', 'mean']
827
  self.pooler = Pooler(pooling_types)
828
  self.init_weights()
829
 
 
35
  from embedding_mixin import EmbeddingMixin, Pooler
36
 
37
 
 
38
  def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
39
  assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
40
  token_valid = attention_mask_2d.bool()
 
822
  if 'pooling_types' in kwargs and isinstance(kwargs['pooling_types'], List[str]) and len(kwargs['pooling_types']) > 0:
823
  pooling_types = kwargs['pooling_types']
824
  else:
825
+ pooling_types = ['mean', 'var']
826
  self.pooler = Pooler(pooling_types)
827
  self.init_weights()
828