lhallee committed on
Commit
91d5e9d
·
verified ·
1 Parent(s): 6424fe2

Upload modeling_dplm2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_dplm2.py +4 -1
modeling_dplm2.py CHANGED
@@ -12,7 +12,7 @@ from torch.nn import functional as F
12
  from dataclasses import dataclass
13
  from typing import Dict, List, Optional, Tuple, Union
14
 
15
- from transformers import EsmTokenizer
16
  from transformers.modeling_outputs import (
17
  BaseModelOutputWithPastAndCrossAttentions,
18
  BaseModelOutputWithPoolingAndCrossAttentions,
@@ -694,6 +694,9 @@ class DPLM2ForMaskedLM(DPLM2PreTrainedModel, EmbeddingMixin):
694
  self.loss_fct = nn.CrossEntropyLoss()
695
  self.post_init()
696
  self.pad_id = config.pad_token_id
 
 
 
697
 
698
  def get_input_embeddings(self) -> nn.Module:
699
  return self.esm.embeddings.word_embeddings
 
12
  from dataclasses import dataclass
13
  from typing import Dict, List, Optional, Tuple, Union
14
 
15
+ from transformers import AutoTokenizer, EsmTokenizer
16
  from transformers.modeling_outputs import (
17
  BaseModelOutputWithPastAndCrossAttentions,
18
  BaseModelOutputWithPoolingAndCrossAttentions,
 
694
  self.loss_fct = nn.CrossEntropyLoss()
695
  self.post_init()
696
  self.pad_id = config.pad_token_id
697
+ self.tokenizer = self.__class__.tokenizer
698
+ if isinstance(config._name_or_path, str) and len(config._name_or_path) > 0:
699
+ self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
700
 
701
  def get_input_embeddings(self) -> nn.Module:
702
  return self.esm.embeddings.word_embeddings