Update modeling_fastesm.py
modeling_fastesm.py  CHANGED  (+87 -4)
@@ -612,12 +612,95 @@ class FastEsmPreTrainedModel(PreTrainedModel):
 
         return embeddings_dict
 
-
+
+class FAST_ESM_ENCODER(FastEsmPreTrainedModel):
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
         self.embeddings = EsmEmbeddings(config)
         self.encoder = EsmEncoder(config)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        """Forward pass for base model.
+
+        Args:
+            input_ids: Input token IDs
+            attention_mask: Optional attention mask
+            position_ids: Optional position IDs
+            inputs_embeds: Optional input embeddings
+            output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights
+
+        Returns:
+            Model outputs including hidden states and optionally attention weights
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        batch_size, seq_length = input_shape
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+        )
+
+        if attention_mask is not None:
+            extended_attention_mask = attention_mask[:, None, None, :].expand(
+                batch_size, 1, seq_length, seq_length
+            ).bool()
+        else:
+            extended_attention_mask = None
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+        )
+        sequence_output = encoder_outputs.last_hidden_state
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class FastEsmModel(FastEsmPreTrainedModel):
+    def __init__(self, config, add_pooling_layer=True):
+        super().__init__(config)
+        self.config = config
+        self.esm = FAST_ESM_ENCODER(config)
         self.pooler = EsmPooler(config) if add_pooling_layer else None
         # Initialize weights and apply final processing
         self.post_init()
@@ -703,7 +786,7 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
-        self.esm =
+        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
         self.lm_head = EsmLMHead(config)
         self.loss_fct = nn.CrossEntropyLoss()
         self.init_weights()
@@ -757,7 +840,7 @@ class FastEsmForSequenceClassification(FastEsmPreTrainedModel):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.config = config
-        self.esm =
+        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
         self.classifier = EsmClassificationHead(config)
         self.mse = nn.MSELoss()
         self.ce = nn.CrossEntropyLoss()
@@ -818,7 +901,7 @@ class FastEsmForTokenClassification(FastEsmPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.esm =
+        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
         self.loss_fct = nn.CrossEntropyLoss()