Fill-Mask
Transformers
Safetensors
ESMplusplus
custom_code
lhallee commited on
Commit
bb43645
·
verified ·
1 Parent(s): 947b9b6

Upload modeling_esm_plusplus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +5 -1
modeling_esm_plusplus.py CHANGED
@@ -1949,7 +1949,7 @@ class ESMplusplusModel(PreTrainedESMplusplusModel, EmbeddingMixin):
1949
  s_max=transformer_output.s_max,
1950
  )
1951
 
1952
- class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
1953
  """
1954
  ESM++ model for masked language modeling.
1955
  Implements the base ESM++ architecture with a masked language modeling head.
@@ -1971,6 +1971,7 @@ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
1971
  self.ce_loss = nn.CrossEntropyLoss()
1972
  self.tokenizer = EsmSequenceTokenizer()
1973
  self.init_weights()
 
1974
 
1975
  def get_input_embeddings(self):
1976
  return self.embed
@@ -2006,6 +2007,9 @@ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
2006
  store_all_hidden_states=store_all_hidden_states,
2007
  )
2008
 
 
 
 
2009
  def forward(
2010
  self,
2011
  input_ids: Optional[torch.Tensor] = None,
 
1949
  s_max=transformer_output.s_max,
1950
  )
1951
 
1952
+ class ESMplusplusForMaskedLM(FastPLMTestTimeTrainingMixin, PreTrainedESMplusplusModel, EmbeddingMixin):
1953
  """
1954
  ESM++ model for masked language modeling.
1955
  Implements the base ESM++ architecture with a masked language modeling head.
 
1971
  self.ce_loss = nn.CrossEntropyLoss()
1972
  self.tokenizer = EsmSequenceTokenizer()
1973
  self.init_weights()
1974
+ self.init_ttt({"lora_target_replace_module": "MultiHeadAttention"})
1975
 
1976
  def get_input_embeddings(self):
1977
  return self.embed
 
2007
  store_all_hidden_states=store_all_hidden_states,
2008
  )
2009
 
2010
+ def _ttt_get_trainable_modules(self) -> list[nn.Module]:
2011
+ return [self.transformer]
2012
+
2013
  def forward(
2014
  self,
2015
  input_ids: Optional[torch.Tensor] = None,