lhallee committed (verified) · Commit ace7399 · Parent(s): b371243

Upload modeling_dplm2.py with huggingface_hub
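
The commit message indicates the file was pushed with the huggingface_hub client rather than through the web UI. A minimal sketch of such an upload with huggingface_hub.upload_file, assuming a hypothetical repo id and that the file sits in the current working directory (neither is recorded in this commit):

from huggingface_hub import upload_file

upload_file(
    path_or_fileobj="modeling_dplm2.py",   # local file to push
    path_in_repo="modeling_dplm2.py",      # destination path inside the repo
    repo_id="lhallee/DPLM2",               # hypothetical repo id, not taken from this commit
    commit_message="Upload modeling_dplm2.py with huggingface_hub",
)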

Files changed (1):
  1. modeling_dplm2.py  +94 −28
modeling_dplm2.py CHANGED
@@ -890,17 +890,38 @@ class ModifiedEsmEncoder(EsmEncoder):
         )
 
 
-class DPLM2Model(DPLM2PreTrainedModel, EmbeddingMixin):
-    config_class = DPLM2Config
+class FAST_DPLM2_ENCODER(DPLM2PreTrainedModel, EmbeddingMixin):
+    """Inner encoder class that holds the actual ESM-style weights (embeddings, encoder)
+    so that the weight keys are prefixed with 'esm.' in the outer DPLM2Model,
+    matching pretrained DPLM2 checkpoints."""
 
-    def __init__(self, config, add_pooling_layer=True):
-        DPLM2PreTrainedModel.__init__(self, config)
+    def __init__(self, config, **kwargs):
+        DPLM2PreTrainedModel.__init__(self, config, **kwargs)
         self.config = config
         self.embeddings = EsmEmbeddings(config)
         self.encoder = ModifiedEsmEncoder(config)
-        self.pooler = EsmPooler(config) if add_pooling_layer else None
         self.post_init()
 
+    def get_input_embeddings(self) -> nn.Module:
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if attention_mask is None:
+            attention_mask = input_ids.ne(self.config.pad_token_id)
+        type_ids = _infer_modality_type(input_ids, attention_mask)
+        outputs = self(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            type_ids=type_ids,
+            output_hidden_states=False,
+            output_attentions=False,
+            return_dict=True,
+        )
+        return outputs.last_hidden_state
+
     def _convert_head_mask_to_5d(self, head_mask: torch.Tensor, num_hidden_layers: int) -> torch.Tensor:
         if head_mask.dim() == 1:
             head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
@@ -924,26 +945,6 @@ class DPLM2Model(DPLM2PreTrainedModel, EmbeddingMixin):
             head_mask = head_mask.unsqueeze(-1)
         return head_mask
 
-    def get_input_embeddings(self) -> nn.Module:
-        return self.embeddings.word_embeddings
-
-    def set_input_embeddings(self, value):
-        self.embeddings.word_embeddings = value
-
-    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        if attention_mask is None:
-            attention_mask = input_ids.ne(self.config.pad_token_id)
-        type_ids = _infer_modality_type(input_ids, attention_mask)
-        outputs = self(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            type_ids=type_ids,
-            output_hidden_states=False,
-            output_attentions=False,
-            return_dict=True,
-        )
-        return outputs.last_hidden_state
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@@ -1039,14 +1040,12 @@ class DPLM2Model(DPLM2PreTrainedModel, EmbeddingMixin):
             flex_block_mask=flex_block_mask,
         )
         sequence_output = encoder_outputs[0]
-        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
         if return_dict is False:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
+            return (sequence_output,) + encoder_outputs[1:]
 
         return BaseModelOutputWithPoolingAndCrossAttentions(
             last_hidden_state=sequence_output,
-            pooler_output=pooled_output,
             past_key_values=None,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
@@ -1054,6 +1053,73 @@ class DPLM2Model(DPLM2PreTrainedModel, EmbeddingMixin):
         )
 
 
+class DPLM2Model(DPLM2PreTrainedModel, EmbeddingMixin):
+    config_class = DPLM2Config
+
+    def __init__(self, config, add_pooling_layer=True):
+        DPLM2PreTrainedModel.__init__(self, config)
+        self.config = config
+        self.esm = FAST_DPLM2_ENCODER(config)
+        self.pooler = EsmPooler(config) if add_pooling_layer else None
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.esm.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.esm.embeddings.word_embeddings = value
+
+    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return self.esm._embed(input_ids, attention_mask)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        type_ids: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        outputs = self.esm(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            type_ids=type_ids,
+        )
+        sequence_output = outputs[0]
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if return_dict is False:
+            return (sequence_output, pooled_output) + outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=None,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+
 class DPLM2ForMaskedLM(DPLM2PreTrainedModel, EmbeddingMixin):
     config_class = DPLM2Config
 
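
The docstring on FAST_DPLM2_ENCODER states the point of the refactor: the ESM-style weights now live on an inner module stored as self.esm, so every parameter key of the outer DPLM2Model carries the "esm." prefix used by pretrained DPLM2 checkpoints, while the pooler stays on the wrapper. A minimal sketch of what that looks like from the caller's side, assuming the repo ships this file as custom code and its auto_map points AutoModel at DPLM2Model (the repo id below is hypothetical):

import torch
from transformers import AutoModel

# Hypothetical repo id, not taken from this commit.
model = AutoModel.from_pretrained("lhallee/DPLM2", trust_remote_code=True)

# Encoder weights are keyed under "esm." (only the pooler stays on the wrapper),
# matching the layout of pretrained DPLM2 checkpoints.
print({k.split(".")[0] for k in model.state_dict()})  # expected: {"esm", "pooler"}

# The caller-facing forward is unchanged: the wrapper delegates to the inner
# FAST_DPLM2_ENCODER and re-attaches pooler_output to the returned ModelOutput.
input_ids = torch.randint(4, 24, (1, 16))  # placeholder token ids, for shape illustration only
with torch.no_grad():
    out = model(input_ids=input_ids)
print(out.last_hidden_state.shape, out.pooler_output.shape)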