lhallee committed
Commit 2f7d618 · verified · 1 Parent(s): 496ae89

Upload modeling_dplm.py with huggingface_hub

Files changed (1):
  modeling_dplm.py  +103 -36
modeling_dplm.py CHANGED
@@ -797,46 +797,24 @@ class ModifiedEsmEncoder(EsmEncoder):
         )
 
 
-class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
-    config_class = DPLMConfig
-
-    def get_input_embeddings(self) -> nn.Module:
-        return self.embeddings.word_embeddings
-
-    def __init__(self, config, add_pooling_layer=True):
-        DPLMPreTrainedModel.__init__(self, config)
+class FAST_DPLM_ENCODER(DPLMPreTrainedModel, EmbeddingMixin):
+    """Inner encoder class that holds the actual ESM-style weights (embeddings, encoder,
+    contact_head) so that the weight keys are prefixed with 'esm.' in the outer DPLMModel,
+    matching pretrained DPLM checkpoints."""
+
+    def __init__(self, config, **kwargs):
+        DPLMPreTrainedModel.__init__(self, config, **kwargs)
         self.config = config
         self.embeddings = EsmEmbeddings(config)
         self.encoder = ModifiedEsmEncoder(config)
-        self.pooler = EsmPooler(config) if add_pooling_layer else None
         self.contact_head = EsmContactPredictionHead(
             in_features=config.num_hidden_layers * config.num_attention_heads,
             bias=True,
         )
         self.post_init()
 
-    def _convert_head_mask_to_5d(self, head_mask: torch.Tensor, num_hidden_layers: int) -> torch.Tensor:
-        if head_mask.dim() == 1:
-            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
-        elif head_mask.dim() == 2:
-            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
-        assert head_mask.dim() == 5, f"head_mask.dim != 5, got {head_mask.dim()}"
-        head_mask = head_mask.to(dtype=self.dtype)
-        return head_mask
-
-    def get_head_mask(
-        self,
-        head_mask: Optional[torch.Tensor],
-        num_hidden_layers: int,
-        is_attention_chunked: bool = False,
-    ) -> Union[torch.Tensor, List[None]]:
-        if head_mask is None:
-            return [None] * num_hidden_layers
-        head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
-        if is_attention_chunked:
-            head_mask = head_mask.unsqueeze(-1)
-        return head_mask
+    def get_input_embeddings(self) -> nn.Module:
+        return self.embeddings.word_embeddings
 
     def set_input_embeddings(self, value):
         self.embeddings.word_embeddings = value
@@ -860,6 +838,29 @@ class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
         attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
         return self.contact_head(input_ids, attns)
 
+    def _convert_head_mask_to_5d(self, head_mask: torch.Tensor, num_hidden_layers: int) -> torch.Tensor:
+        if head_mask.dim() == 1:
+            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
+        elif head_mask.dim() == 2:
+            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
+        assert head_mask.dim() == 5, f"head_mask.dim != 5, got {head_mask.dim()}"
+        head_mask = head_mask.to(dtype=self.dtype)
+        return head_mask
+
+    def get_head_mask(
+        self,
+        head_mask: Optional[torch.Tensor],
+        num_hidden_layers: int,
+        is_attention_chunked: bool = False,
+    ) -> Union[torch.Tensor, List[None]]:
+        if head_mask is None:
+            return [None] * num_hidden_layers
+        head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
+        if is_attention_chunked:
+            head_mask = head_mask.unsqueeze(-1)
+        return head_mask
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@@ -953,14 +954,12 @@ class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
             flex_block_mask=flex_block_mask,
         )
         sequence_output = encoder_outputs[0]
-        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
         if return_dict is False:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
+            return (sequence_output,) + encoder_outputs[1:]
 
         return BaseModelOutputWithPoolingAndCrossAttentions(
             last_hidden_state=sequence_output,
-            pooler_output=pooled_output,
             past_key_values=None,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
@@ -968,6 +967,74 @@ class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
         )
 
 
+class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
+    config_class = DPLMConfig
+
+    def __init__(self, config, add_pooling_layer=True):
+        DPLMPreTrainedModel.__init__(self, config)
+        self.config = config
+        self.esm = FAST_DPLM_ENCODER(config)
+        self.pooler = EsmPooler(config) if add_pooling_layer else None
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.esm.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.esm.embeddings.word_embeddings = value
+
+    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return self.esm._embed(input_ids, attention_mask)
+
+    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+        return self.esm.predict_contacts(input_ids, attention_mask)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        outputs = self.esm(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if return_dict is False:
+            return (sequence_output, pooled_output) + outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=None,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+
 class DPLMForMaskedLM(DPLMPreTrainedModel, EmbeddingMixin):
     config_class = DPLMConfig
 
@@ -994,7 +1061,7 @@ class DPLMForMaskedLM(DPLMPreTrainedModel, EmbeddingMixin):
         self.contact_head = None
 
     def get_input_embeddings(self) -> nn.Module:
-        return self.esm.embeddings.word_embeddings
+        return self.esm.get_input_embeddings()
 
     def get_output_embeddings(self):
         return self.lm_head.decoder
@@ -1064,7 +1131,7 @@ class DPLMForSequenceClassification(DPLMPreTrainedModel, EmbeddingMixin):
     config_class = DPLMConfig
 
     def get_input_embeddings(self) -> nn.Module:
-        return self.esm.embeddings.word_embeddings
+        return self.esm.get_input_embeddings()
 
     def __init__(self, config):
         DPLMPreTrainedModel.__init__(self, config)
@@ -1134,7 +1201,7 @@ class DPLMForTokenClassification(DPLMPreTrainedModel, EmbeddingMixin):
    config_class = DPLMConfig
 
     def get_input_embeddings(self) -> nn.Module:
-        return self.esm.embeddings.word_embeddings
+        return self.esm.get_input_embeddings()
 
     def __init__(self, config):
         DPLMPreTrainedModel.__init__(self, config)
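
Note on the refactor: the FAST_DPLM_ENCODER docstring is the crux of the change. Registering the encoder as `self.esm` inside the outer DPLMModel is what puts the "esm." prefix on every weight key, which is how the wrapper lines up with pretrained DPLM checkpoints. A minimal sketch of that effect, using made-up stand-in modules rather than the real DPLM classes:

import torch.nn as nn

class InnerEncoder(nn.Module):
    # Stand-in for FAST_DPLM_ENCODER: holds the actual weights.
    def __init__(self):
        super().__init__()
        self.word_embeddings = nn.Embedding(16, 8)

class OuterModel(nn.Module):
    # Stand-in for DPLMModel: registering the encoder as `self.esm`
    # prefixes every parameter key with "esm." in the state dict.
    def __init__(self):
        super().__init__()
        self.esm = InnerEncoder()

print(list(OuterModel().state_dict().keys()))
# ['esm.word_embeddings.weight']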
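
The head-mask helpers (`_convert_head_mask_to_5d` / `get_head_mask`) move from the old DPLMModel onto the inner encoder unchanged. A standalone shape check of the 5-D broadcast they perform (plain torch, no DPLM code required):

import torch

num_layers, num_heads = 4, 8
head_mask = torch.ones(num_heads)  # 1-D per-head mask
# Same unsqueeze/expand sequence as _convert_head_mask_to_5d:
mask_5d = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
mask_5d = mask_5d.expand(num_layers, -1, -1, -1, -1)
print(mask_5d.shape)  # torch.Size([4, 1, 8, 1, 1])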
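
Finally, a hedged usage sketch of the rebuilt DPLMModel wrapper. The default-constructed DPLMConfig and the shapes in the comments are assumptions for illustration; real checkpoints would be loaded via from_pretrained:

import torch
# Assumes DPLMConfig and DPLMModel from this modeling_dplm.py are importable
# and that a default-constructed DPLMConfig is valid (illustrative assumption).

config = DPLMConfig()
model = DPLMModel(config, add_pooling_layer=True)
input_ids = torch.randint(0, config.vocab_size, (1, 12))
with torch.no_grad():
    out = model(input_ids=input_ids, return_dict=True)
print(out.last_hidden_state.shape)  # expected (1, 12, hidden_size)
print(out.pooler_output.shape)      # expected (1, hidden_size)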