lhallee committed on
Commit d7e2218 · verified · 1 Parent(s): 5a6ced7

Upload modeling_fastesm.py with huggingface_hub

Files changed (1):
  1. modeling_fastesm.py  +222 -6
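
The commit message says the file was pushed programmatically with huggingface_hub. A minimal sketch of that upload path, assuming the standard HfApi.upload_file call; the repo id below is a placeholder, not taken from this commit:

# Sketch only: pushes a local modeling file to a Hub repo as a single commit.
from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="modeling_fastesm.py",   # local file to upload
    path_in_repo="modeling_fastesm.py",      # destination path inside the repo
    repo_id="user/FastESM2",                 # placeholder repo id (assumption)
    commit_message="Upload modeling_fastesm.py with huggingface_hub",
)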
modeling_fastesm.py CHANGED
@@ -4,12 +4,10 @@ import os
 import warnings
 import networkx as nx
 from torch.nn import functional as F
-from torch.utils.data import Dataset as TorchDataset
-from torch.utils.data import DataLoader as DataLoader
-from typing import Optional, Tuple, Union, Callable, List, Dict, Any
+from typing import Optional, Tuple, Union, Dict, Any
 from einops import rearrange
 from dataclasses import dataclass
-from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer, PreTrainedTokenizerBase
+from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
 from transformers.modeling_outputs import (
     ModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -25,8 +23,8 @@ from transformers.models.esm.modeling_esm import (
     EsmSelfOutput,
     EsmClassificationHead,
 )
-from tqdm.auto import tqdm
-from embedding_mixin import EmbeddingMixin, Pooler
+
+from .embedding_mixin import EmbeddingMixin
 
 try:
     from torch.nn.attention.flex_attention import create_block_mask
@@ -586,6 +584,224 @@ class EsmEncoder(nn.Module):
         )
 
 
+class FastEsmPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+    config_class = FastEsmConfig
+    base_model_prefix = "fastesm"
+    supports_gradient_checkpointing = True
+    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
+    all_tied_weights_keys = {}
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            if module.bias is not None:
+                module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    def get_input_embeddings(self) -> nn.Module:
+        try:
+            return self.embeddings.word_embeddings
+        except AttributeError:
+            return self.esm.embeddings.word_embeddings
+
+
+class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
+    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
+        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
+        self.config = config
+        self.embeddings = EsmEmbeddings(config)
+        self.encoder = EsmEncoder(config)
+        self.contact_head = EsmContactPredictionHead(
+            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
+        )
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        token_embedding_output = self.embeddings(input_ids, attention_mask=attention_mask)
+        batch_size, seq_length = input_ids.shape
+        if attention_mask is not None:
+            extended_attention_mask = attention_mask[:, None, None, :].expand(
+                batch_size, 1, seq_length, seq_length
+            ).bool()
+        else:
+            extended_attention_mask = None
+        encoder_outputs = self.encoder(
+            token_embedding_output,
+            attention_mask=extended_attention_mask,
+            output_hidden_states=False,
+            output_attentions=False,
+        )
+        return encoder_outputs.last_hidden_state
+
+    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+        attns = self(input_ids, attention_mask=attention_mask, output_attentions=True).attentions
+        attns = torch.stack(attns, dim=1)
+        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
+        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
+        return self.contact_head(input_ids, attns)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        """Forward pass for base model.
+
+        Args:
+            input_ids: Input token IDs
+            attention_mask: Optional attention mask
+            position_ids: Optional position IDs
+            inputs_embeds: Optional input embeddings
+            output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights
+
+        Returns:
+            Model outputs including hidden states and optionally attention weights
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        batch_size, seq_length = input_shape
+        token_embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+        )
+
+        if attention_mask is not None:
+            extended_attention_mask = attention_mask[:, None, None, :].expand(
+                batch_size, 1, seq_length, seq_length
+            ).bool()
+        else:
+            extended_attention_mask = None
+
+        encoder_outputs = self.encoder(
+            token_embedding_output,
+            attention_mask=extended_attention_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+        )
+        sequence_output = encoder_outputs.last_hidden_state
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
+    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
+        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
+        self.config = config
+        self.esm = FAST_ESM_ENCODER(config)
+        self.pooler = EsmPooler(config) if add_pooling_layer else None
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return self.esm._embed(input_ids, attention_mask)
+
+    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
+        **kwargs,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        """Forward pass for base model.
+
+        Args:
+            input_ids: Input token IDs
+            attention_mask: Optional attention mask
+            position_ids: Optional position IDs
+            inputs_embeds: Optional input embeddings
+            output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights
+
+        Returns:
+            Model outputs including hidden states and optionally attention weights
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        outputs = self.esm(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+        )
+        sequence_output = outputs.last_hidden_state
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
 class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
     def __init__(self, config, **kwargs):
         FastEsmPreTrainedModel.__init__(self, config, **kwargs)
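
Since the commit uploads a custom modeling file, the classes above are meant to be loaded through the Hub's remote-code mechanism. The following is a minimal usage sketch, assuming a repository that ships this modeling_fastesm.py with a config whose auto_map points AutoModel at FastEsmModel; the repo id is a placeholder, not taken from this commit.

# Usage sketch (placeholder repo id; trust_remote_code=True lets transformers
# import modeling_fastesm.py from the Hub repository).
import torch
from transformers import AutoModel, EsmTokenizer

repo_id = "user/FastESM2"  # placeholder (assumption)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")  # same vocab as the class-level tokenizer above

seqs = ["MSEQWENCE", "MKTAYIAKQR"]
batch = tokenizer(seqs, padding=True, return_tensors="pt")

with torch.no_grad():
    out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
    residue_embeddings = out.last_hidden_state  # (batch, seq_len, hidden)
    # Contact prediction routes stacked attention maps through EsmContactPredictionHead
    contacts = model.predict_contacts(batch["input_ids"], batch["attention_mask"])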