Upload modeling_dplm.py with huggingface_hub
Browse files- modeling_dplm.py +103 -36
modeling_dplm.py
CHANGED
|
@@ -797,46 +797,24 @@ class ModifiedEsmEncoder(EsmEncoder):
|
|
| 797 |
)
|
| 798 |
|
| 799 |
|
| 800 |
-
class
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
|
| 804 |
-
return self.embeddings.word_embeddings
|
| 805 |
|
| 806 |
-
def __init__(self, config,
|
| 807 |
-
DPLMPreTrainedModel.__init__(self, config)
|
| 808 |
self.config = config
|
| 809 |
self.embeddings = EsmEmbeddings(config)
|
| 810 |
self.encoder = ModifiedEsmEncoder(config)
|
| 811 |
-
self.pooler = EsmPooler(config) if add_pooling_layer else None
|
| 812 |
self.contact_head = EsmContactPredictionHead(
|
| 813 |
in_features=config.num_hidden_layers * config.num_attention_heads,
|
| 814 |
bias=True,
|
| 815 |
)
|
| 816 |
self.post_init()
|
| 817 |
|
| 818 |
-
def
|
| 819 |
-
|
| 820 |
-
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
| 821 |
-
head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
|
| 822 |
-
elif head_mask.dim() == 2:
|
| 823 |
-
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
|
| 824 |
-
assert head_mask.dim() == 5, f"head_mask.dim != 5, got {head_mask.dim()}"
|
| 825 |
-
head_mask = head_mask.to(dtype=self.dtype)
|
| 826 |
-
return head_mask
|
| 827 |
-
|
| 828 |
-
def get_head_mask(
|
| 829 |
-
self,
|
| 830 |
-
head_mask: Optional[torch.Tensor],
|
| 831 |
-
num_hidden_layers: int,
|
| 832 |
-
is_attention_chunked: bool = False,
|
| 833 |
-
) -> Union[torch.Tensor, List[None]]:
|
| 834 |
-
if head_mask is None:
|
| 835 |
-
return [None] * num_hidden_layers
|
| 836 |
-
head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
|
| 837 |
-
if is_attention_chunked:
|
| 838 |
-
head_mask = head_mask.unsqueeze(-1)
|
| 839 |
-
return head_mask
|
| 840 |
|
| 841 |
def set_input_embeddings(self, value):
|
| 842 |
self.embeddings.word_embeddings = value
|
|
@@ -860,6 +838,29 @@ class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
|
|
| 860 |
attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
|
| 861 |
return self.contact_head(input_ids, attns)
|
| 862 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 863 |
def forward(
|
| 864 |
self,
|
| 865 |
input_ids: Optional[torch.Tensor] = None,
|
|
@@ -953,14 +954,12 @@ class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
|
|
| 953 |
flex_block_mask=flex_block_mask,
|
| 954 |
)
|
| 955 |
sequence_output = encoder_outputs[0]
|
| 956 |
-
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
| 957 |
|
| 958 |
if return_dict is False:
|
| 959 |
-
return (sequence_output,
|
| 960 |
|
| 961 |
return BaseModelOutputWithPoolingAndCrossAttentions(
|
| 962 |
last_hidden_state=sequence_output,
|
| 963 |
-
pooler_output=pooled_output,
|
| 964 |
past_key_values=None,
|
| 965 |
hidden_states=encoder_outputs.hidden_states,
|
| 966 |
attentions=encoder_outputs.attentions,
|
|
@@ -968,6 +967,74 @@ class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
|
|
| 968 |
)
|
| 969 |
|
| 970 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 971 |
class DPLMForMaskedLM(DPLMPreTrainedModel, EmbeddingMixin):
|
| 972 |
config_class = DPLMConfig
|
| 973 |
|
|
@@ -994,7 +1061,7 @@ class DPLMForMaskedLM(DPLMPreTrainedModel, EmbeddingMixin):
|
|
| 994 |
self.contact_head = None
|
| 995 |
|
| 996 |
def get_input_embeddings(self) -> nn.Module:
|
| 997 |
-
return self.esm.
|
| 998 |
|
| 999 |
def get_output_embeddings(self):
|
| 1000 |
return self.lm_head.decoder
|
|
@@ -1064,7 +1131,7 @@ class DPLMForSequenceClassification(DPLMPreTrainedModel, EmbeddingMixin):
|
|
| 1064 |
config_class = DPLMConfig
|
| 1065 |
|
| 1066 |
def get_input_embeddings(self) -> nn.Module:
|
| 1067 |
-
return self.esm.
|
| 1068 |
|
| 1069 |
def __init__(self, config):
|
| 1070 |
DPLMPreTrainedModel.__init__(self, config)
|
|
@@ -1134,7 +1201,7 @@ class DPLMForTokenClassification(DPLMPreTrainedModel, EmbeddingMixin):
|
|
| 1134 |
config_class = DPLMConfig
|
| 1135 |
|
| 1136 |
def get_input_embeddings(self) -> nn.Module:
|
| 1137 |
-
return self.esm.
|
| 1138 |
|
| 1139 |
def __init__(self, config):
|
| 1140 |
DPLMPreTrainedModel.__init__(self, config)
|
|
|
|
| 797 |
)
|
| 798 |
|
| 799 |
|
| 800 |
+
class FAST_DPLM_ENCODER(DPLMPreTrainedModel, EmbeddingMixin):
|
| 801 |
+
"""Inner encoder class that holds the actual ESM-style weights (embeddings, encoder,
|
| 802 |
+
contact_head) so that the weight keys are prefixed with 'esm.' in the outer DPLMModel,
|
| 803 |
+
matching pretrained DPLM checkpoints."""
|
|
|
|
| 804 |
|
| 805 |
+
def __init__(self, config, **kwargs):
|
| 806 |
+
DPLMPreTrainedModel.__init__(self, config, **kwargs)
|
| 807 |
self.config = config
|
| 808 |
self.embeddings = EsmEmbeddings(config)
|
| 809 |
self.encoder = ModifiedEsmEncoder(config)
|
|
|
|
| 810 |
self.contact_head = EsmContactPredictionHead(
|
| 811 |
in_features=config.num_hidden_layers * config.num_attention_heads,
|
| 812 |
bias=True,
|
| 813 |
)
|
| 814 |
self.post_init()
|
| 815 |
|
| 816 |
+
def get_input_embeddings(self) -> nn.Module:
|
| 817 |
+
return self.embeddings.word_embeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 818 |
|
| 819 |
def set_input_embeddings(self, value):
|
| 820 |
self.embeddings.word_embeddings = value
|
|
|
|
| 838 |
attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
|
| 839 |
return self.contact_head(input_ids, attns)
|
| 840 |
|
| 841 |
+
def _convert_head_mask_to_5d(self, head_mask: torch.Tensor, num_hidden_layers: int) -> torch.Tensor:
|
| 842 |
+
if head_mask.dim() == 1:
|
| 843 |
+
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
| 844 |
+
head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
|
| 845 |
+
elif head_mask.dim() == 2:
|
| 846 |
+
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
|
| 847 |
+
assert head_mask.dim() == 5, f"head_mask.dim != 5, got {head_mask.dim()}"
|
| 848 |
+
head_mask = head_mask.to(dtype=self.dtype)
|
| 849 |
+
return head_mask
|
| 850 |
+
|
| 851 |
+
def get_head_mask(
|
| 852 |
+
self,
|
| 853 |
+
head_mask: Optional[torch.Tensor],
|
| 854 |
+
num_hidden_layers: int,
|
| 855 |
+
is_attention_chunked: bool = False,
|
| 856 |
+
) -> Union[torch.Tensor, List[None]]:
|
| 857 |
+
if head_mask is None:
|
| 858 |
+
return [None] * num_hidden_layers
|
| 859 |
+
head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
|
| 860 |
+
if is_attention_chunked:
|
| 861 |
+
head_mask = head_mask.unsqueeze(-1)
|
| 862 |
+
return head_mask
|
| 863 |
+
|
| 864 |
def forward(
|
| 865 |
self,
|
| 866 |
input_ids: Optional[torch.Tensor] = None,
|
|
|
|
| 954 |
flex_block_mask=flex_block_mask,
|
| 955 |
)
|
| 956 |
sequence_output = encoder_outputs[0]
|
|
|
|
| 957 |
|
| 958 |
if return_dict is False:
|
| 959 |
+
return (sequence_output,) + encoder_outputs[1:]
|
| 960 |
|
| 961 |
return BaseModelOutputWithPoolingAndCrossAttentions(
|
| 962 |
last_hidden_state=sequence_output,
|
|
|
|
| 963 |
past_key_values=None,
|
| 964 |
hidden_states=encoder_outputs.hidden_states,
|
| 965 |
attentions=encoder_outputs.attentions,
|
|
|
|
| 967 |
)
|
| 968 |
|
| 969 |
|
| 970 |
+
class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
|
| 971 |
+
config_class = DPLMConfig
|
| 972 |
+
|
| 973 |
+
def __init__(self, config, add_pooling_layer=True):
|
| 974 |
+
DPLMPreTrainedModel.__init__(self, config)
|
| 975 |
+
self.config = config
|
| 976 |
+
self.esm = FAST_DPLM_ENCODER(config)
|
| 977 |
+
self.pooler = EsmPooler(config) if add_pooling_layer else None
|
| 978 |
+
self.post_init()
|
| 979 |
+
|
| 980 |
+
def get_input_embeddings(self) -> nn.Module:
|
| 981 |
+
return self.esm.embeddings.word_embeddings
|
| 982 |
+
|
| 983 |
+
def set_input_embeddings(self, value):
|
| 984 |
+
self.esm.embeddings.word_embeddings = value
|
| 985 |
+
|
| 986 |
+
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 987 |
+
return self.esm._embed(input_ids, attention_mask)
|
| 988 |
+
|
| 989 |
+
def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
| 990 |
+
return self.esm.predict_contacts(input_ids, attention_mask)
|
| 991 |
+
|
| 992 |
+
def forward(
|
| 993 |
+
self,
|
| 994 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 995 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 996 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 997 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 998 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 999 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 1000 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 1001 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 1002 |
+
use_cache: Optional[bool] = None,
|
| 1003 |
+
output_attentions: Optional[bool] = None,
|
| 1004 |
+
output_hidden_states: Optional[bool] = None,
|
| 1005 |
+
return_dict: Optional[bool] = None,
|
| 1006 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
|
| 1007 |
+
outputs = self.esm(
|
| 1008 |
+
input_ids=input_ids,
|
| 1009 |
+
attention_mask=attention_mask,
|
| 1010 |
+
position_ids=position_ids,
|
| 1011 |
+
head_mask=head_mask,
|
| 1012 |
+
inputs_embeds=inputs_embeds,
|
| 1013 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 1014 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 1015 |
+
past_key_values=past_key_values,
|
| 1016 |
+
use_cache=use_cache,
|
| 1017 |
+
output_attentions=output_attentions,
|
| 1018 |
+
output_hidden_states=output_hidden_states,
|
| 1019 |
+
return_dict=return_dict,
|
| 1020 |
+
)
|
| 1021 |
+
sequence_output = outputs[0]
|
| 1022 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
| 1023 |
+
|
| 1024 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1025 |
+
if return_dict is False:
|
| 1026 |
+
return (sequence_output, pooled_output) + outputs[1:]
|
| 1027 |
+
|
| 1028 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
| 1029 |
+
last_hidden_state=sequence_output,
|
| 1030 |
+
pooler_output=pooled_output,
|
| 1031 |
+
past_key_values=None,
|
| 1032 |
+
hidden_states=outputs.hidden_states,
|
| 1033 |
+
attentions=outputs.attentions,
|
| 1034 |
+
cross_attentions=outputs.cross_attentions,
|
| 1035 |
+
)
|
| 1036 |
+
|
| 1037 |
+
|
| 1038 |
class DPLMForMaskedLM(DPLMPreTrainedModel, EmbeddingMixin):
|
| 1039 |
config_class = DPLMConfig
|
| 1040 |
|
|
|
|
| 1061 |
self.contact_head = None
|
| 1062 |
|
| 1063 |
def get_input_embeddings(self) -> nn.Module:
|
| 1064 |
+
return self.esm.get_input_embeddings()
|
| 1065 |
|
| 1066 |
def get_output_embeddings(self):
|
| 1067 |
return self.lm_head.decoder
|
|
|
|
| 1131 |
config_class = DPLMConfig
|
| 1132 |
|
| 1133 |
def get_input_embeddings(self) -> nn.Module:
|
| 1134 |
+
return self.esm.get_input_embeddings()
|
| 1135 |
|
| 1136 |
def __init__(self, config):
|
| 1137 |
DPLMPreTrainedModel.__init__(self, config)
|
|
|
|
| 1201 |
config_class = DPLMConfig
|
| 1202 |
|
| 1203 |
def get_input_embeddings(self) -> nn.Module:
|
| 1204 |
+
return self.esm.get_input_embeddings()
|
| 1205 |
|
| 1206 |
def __init__(self, config):
|
| 1207 |
DPLMPreTrainedModel.__init__(self, config)
|