lhallee committed on
Commit
019fa36
·
verified ·
1 Parent(s): ace7399

Upload modeling_dplm2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_dplm2.py +3 -3
modeling_dplm2.py CHANGED
@@ -1139,7 +1139,7 @@ class DPLM2ForMaskedLM(DPLM2PreTrainedModel, EmbeddingMixin):
1139
  self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
1140
 
1141
  def get_input_embeddings(self) -> nn.Module:
1142
- return self.esm.embeddings.word_embeddings
1143
 
1144
  def get_output_embeddings(self):
1145
  return self.lm_head.decoder
@@ -1238,7 +1238,7 @@ class DPLM2ForSequenceClassification(DPLM2PreTrainedModel, EmbeddingMixin):
1238
  self.post_init()
1239
 
1240
  def get_input_embeddings(self) -> nn.Module:
1241
- return self.esm.embeddings.word_embeddings
1242
 
1243
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1244
  return self.esm._embed(input_ids, attention_mask)
@@ -1314,7 +1314,7 @@ class DPLM2ForTokenClassification(DPLM2PreTrainedModel, EmbeddingMixin):
1314
  self.post_init()
1315
 
1316
  def get_input_embeddings(self) -> nn.Module:
1317
- return self.esm.embeddings.word_embeddings
1318
 
1319
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1320
  return self.esm._embed(input_ids, attention_mask)
 
1139
  self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
1140
 
1141
  def get_input_embeddings(self) -> nn.Module:
1142
+ return self.esm.get_input_embeddings()
1143
 
1144
  def get_output_embeddings(self):
1145
  return self.lm_head.decoder
 
1238
  self.post_init()
1239
 
1240
  def get_input_embeddings(self) -> nn.Module:
1241
+ return self.esm.get_input_embeddings()
1242
 
1243
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1244
  return self.esm._embed(input_ids, attention_mask)
 
1314
  self.post_init()
1315
 
1316
  def get_input_embeddings(self) -> nn.Module:
1317
+ return self.esm.get_input_embeddings()
1318
 
1319
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1320
  return self.esm._embed(input_ids, attention_mask)