Update modeling_fastesm.py
Browse files- modeling_fastesm.py +1 -9
modeling_fastesm.py
CHANGED
|
@@ -4,7 +4,7 @@ from torch.nn import functional as F
|
|
| 4 |
from torch.utils.data import Dataset, DataLoader
|
| 5 |
from typing import Optional, Tuple, Union
|
| 6 |
from einops import rearrange
|
| 7 |
-
from transformers import PreTrainedModel, PretrainedConfig
|
| 8 |
from transformers.modeling_outputs import (
|
| 9 |
MaskedLMOutput,
|
| 10 |
BaseModelOutputWithPastAndCrossAttentions,
|
|
@@ -145,13 +145,6 @@ class EsmEmbeddings(nn.Module):
|
|
| 145 |
def forward(
|
| 146 |
self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
| 147 |
):
|
| 148 |
-
if position_ids is None:
|
| 149 |
-
if input_ids is not None:
|
| 150 |
-
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
| 151 |
-
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
|
| 152 |
-
else:
|
| 153 |
-
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
|
| 154 |
-
|
| 155 |
if inputs_embeds is None:
|
| 156 |
inputs_embeds = self.word_embeddings(input_ids)
|
| 157 |
|
|
@@ -346,7 +339,6 @@ class FastEsmPreTrainedModel(PreTrainedModel):
|
|
| 346 |
config_class = FastEsmConfig
|
| 347 |
base_model_prefix = "fastesm"
|
| 348 |
supports_gradient_checkpointing = True
|
| 349 |
-
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
| 350 |
def _init_weights(self, module):
|
| 351 |
"""Initialize the weights"""
|
| 352 |
if isinstance(module, nn.Linear):
|
|
|
|
| 4 |
from torch.utils.data import Dataset, DataLoader
|
| 5 |
from typing import Optional, Tuple, Union
|
| 6 |
from einops import rearrange
|
| 7 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
| 8 |
from transformers.modeling_outputs import (
|
| 9 |
MaskedLMOutput,
|
| 10 |
BaseModelOutputWithPastAndCrossAttentions,
|
|
|
|
| 145 |
def forward(
|
| 146 |
self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
| 147 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
if inputs_embeds is None:
|
| 149 |
inputs_embeds = self.word_embeddings(input_ids)
|
| 150 |
|
|
|
|
| 339 |
config_class = FastEsmConfig
|
| 340 |
base_model_prefix = "fastesm"
|
| 341 |
supports_gradient_checkpointing = True
|
|
|
|
| 342 |
def _init_weights(self, module):
|
| 343 |
"""Initialize the weights"""
|
| 344 |
if isinstance(module, nn.Linear):
|