Upload modeling_dplm2.py with huggingface_hub
Browse files — modeling_dplm2.py: +4 additions, −1 deletion
modeling_dplm2.py
CHANGED
|
@@ -12,7 +12,7 @@ from torch.nn import functional as F
|
|
| 12 |
from dataclasses import dataclass
|
| 13 |
from typing import Dict, List, Optional, Tuple, Union
|
| 14 |
|
| 15 |
-
from transformers import EsmTokenizer
|
| 16 |
from transformers.modeling_outputs import (
|
| 17 |
BaseModelOutputWithPastAndCrossAttentions,
|
| 18 |
BaseModelOutputWithPoolingAndCrossAttentions,
|
|
@@ -694,6 +694,9 @@ class DPLM2ForMaskedLM(DPLM2PreTrainedModel, EmbeddingMixin):
|
|
| 694 |
self.loss_fct = nn.CrossEntropyLoss()
|
| 695 |
self.post_init()
|
| 696 |
self.pad_id = config.pad_token_id
|
|
|
|
|
|
|
|
|
|
| 697 |
|
| 698 |
def get_input_embeddings(self) -> nn.Module:
|
| 699 |
return self.esm.embeddings.word_embeddings
|
|
|
|
| 12 |
from dataclasses import dataclass
|
| 13 |
from typing import Dict, List, Optional, Tuple, Union
|
| 14 |
|
| 15 |
+
from transformers import AutoTokenizer, EsmTokenizer
|
| 16 |
from transformers.modeling_outputs import (
|
| 17 |
BaseModelOutputWithPastAndCrossAttentions,
|
| 18 |
BaseModelOutputWithPoolingAndCrossAttentions,
|
|
|
|
| 694 |
self.loss_fct = nn.CrossEntropyLoss()
|
| 695 |
self.post_init()
|
| 696 |
self.pad_id = config.pad_token_id
|
| 697 |
+
self.tokenizer = self.__class__.tokenizer
|
| 698 |
+
if isinstance(config._name_or_path, str) and len(config._name_or_path) > 0:
|
| 699 |
+
self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
|
| 700 |
|
| 701 |
def get_input_embeddings(self) -> nn.Module:
|
| 702 |
return self.esm.embeddings.word_embeddings
|