Jiqing committed · Commit d953af1 · verified · 1 Parent(s): dc52ebf

Delete modeling_protst.py

Files changed (1):
    modeling_protst.py +0 -286
modeling_protst.py DELETED
@@ -1,286 +0,0 @@
- import math
- import torch
- import torch.nn as nn
- from typing import Optional, Tuple, Union
- from dataclasses import dataclass
- from transformers import PreTrainedModel
- from transformers.modeling_outputs import ModelOutput
- from transformers.models.esm import EsmPreTrainedModel, EsmModel
- from transformers.models.bert import BertPreTrainedModel, BertModel
- from .configuration_protst import ProtSTConfig
-
-
- @dataclass
- class EsmProteinRepresentationOutput(ModelOutput):
-
-     protein_feature: torch.FloatTensor = None
-     residue_feature: torch.FloatTensor = None
-
-
- @dataclass
- class BertTextRepresentationOutput(ModelOutput):
-
-     text_feature: torch.FloatTensor = None
-     word_feature: torch.FloatTensor = None
-
-
- @dataclass
- class ProtSTClassificationOutput(ModelOutput):
-
-     loss: Optional[torch.FloatTensor] = None
-     logits: torch.FloatTensor = None
-
-
- class ProtSTHead(nn.Module):
-     def __init__(self, config, out_dim=512):
-         super().__init__()
-         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-         self.out_proj = nn.Linear(config.hidden_size, out_dim)
-
-     def forward(self, x):
-         x = self.dense(x)
-         x = nn.functional.relu(x)
-         x = self.out_proj(x)
-         return x
-
-
- class BertForPubMed(BertPreTrainedModel):
-     def __init__(self, config):
-         super().__init__(config)
-
-         self.pad_token_id = config.pad_token_id
-         self.cls_token_id = config.cls_token_id
-         self.sep_token_id = config.sep_token_id
-
-         self.bert = BertModel(config, add_pooling_layer=False)
-         self.text_mlp = ProtSTHead(config)
-         self.word_mlp = ProtSTHead(config)
-
-         # Initialize weights and apply final processing
-         self.post_init()
-
-     def forward(
-         self,
-         input_ids: Optional[torch.Tensor] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         token_type_ids: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.Tensor] = None,
-         head_mask: Optional[torch.Tensor] = None,
-         inputs_embeds: Optional[torch.Tensor] = None,
-         encoder_hidden_states: Optional[torch.Tensor] = None,
-         encoder_attention_mask: Optional[torch.Tensor] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, BertTextRepresentationOutput]:
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         outputs = self.bert(
-             input_ids,
-             attention_mask=attention_mask,
-             token_type_ids=token_type_ids,
-             position_ids=position_ids,
-             head_mask=head_mask,
-             inputs_embeds=inputs_embeds,
-             encoder_hidden_states=encoder_hidden_states,
-             encoder_attention_mask=encoder_attention_mask,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-         word_feature = outputs[0]  # works for both tuple and ModelOutput returns
-         # Mean-pool over non-special tokens: CLS, SEP and PAD positions are masked out.
-         is_special = (
-             (input_ids == self.cls_token_id) | (input_ids == self.sep_token_id) | (input_ids == self.pad_token_id)
-         )
-         special_mask = (~is_special).to(torch.int64).unsqueeze(-1)
-         pooled_feature = ((word_feature * special_mask).sum(1) / (special_mask.sum(1) + 1.0e-6)).to(word_feature.dtype)
-         pooled_feature = self.text_mlp(pooled_feature)
-         word_feature = self.word_mlp(word_feature)
-
-         if not return_dict:
-             return (pooled_feature, word_feature)
-
-         return BertTextRepresentationOutput(text_feature=pooled_feature, word_feature=word_feature)
-
-
- class EsmForProteinRepresentation(EsmPreTrainedModel):
-     def __init__(self, config):
-         super().__init__(config)
-
-         self.cls_token_id = config.cls_token_id
-         self.pad_token_id = config.pad_token_id
-         self.eos_token_id = config.eos_token_id
-
-         self.esm = EsmModel(config, add_pooling_layer=False)
-         self.protein_mlp = ProtSTHead(config)
-         self.residue_mlp = ProtSTHead(config)
-
-         # Initialize weights and apply final processing
-         self.post_init()
-
-     def forward(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         head_mask: Optional[torch.Tensor] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, EsmProteinRepresentationOutput]:
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         outputs = self.esm(
-             input_ids,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             head_mask=head_mask,
-             inputs_embeds=inputs_embeds,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-
-         residue_feature = outputs[0]  # [batch_size, seq_len, hidden_dim]
-
-         # Mean readout over non-special tokens: CLS, EOS and PAD positions are masked out.
-         is_special = (
-             (input_ids == self.cls_token_id) | (input_ids == self.eos_token_id) | (input_ids == self.pad_token_id)
-         )
-         special_mask = (~is_special).to(torch.int64).unsqueeze(-1)
-         protein_feature = ((residue_feature * special_mask).sum(1) / (special_mask.sum(1) + 1.0e-6)).to(residue_feature.dtype)
-
-         # Projection heads used for ProtST pretraining and zero-shot prediction
-         protein_feature = self.protein_mlp(protein_feature)
-         residue_feature = self.residue_mlp(residue_feature)
-
-         if not return_dict:
-             return (protein_feature, residue_feature)
-
-         return EsmProteinRepresentationOutput(
-             protein_feature=protein_feature, residue_feature=residue_feature
-         )
-
-
- class ProtSTPreTrainedModel(PreTrainedModel):
-     config_class = ProtSTConfig
-
-     def _compute_protein_feature(
-         self, protein_input_ids, protein_attention_mask, protein_position_ids,
-         output_attentions, output_hidden_states
-     ):
-         # EsmForProteinRepresentation.forward takes no encoder_* kwargs, so they are not passed here.
-         protein_outputs = self.protein_model(
-             protein_input_ids,
-             attention_mask=protein_attention_mask,
-             position_ids=protein_position_ids,
-             head_mask=None,
-             inputs_embeds=None,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=None,
-         )
-
-         return protein_outputs
-
-     def _compute_text_feature(
-         self, text_input_ids, text_attention_mask, text_position_ids,
-         output_attentions, output_hidden_states
-     ):
-         text_outputs = self.text_model(
-             text_input_ids,
-             attention_mask=text_attention_mask,
-             position_ids=text_position_ids,
-             head_mask=None,
-             inputs_embeds=None,
-             encoder_hidden_states=None,
-             encoder_attention_mask=None,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=None,
-         )
-
-         return text_outputs
-
-
- class ProtSTModel(ProtSTPreTrainedModel):
-     def __init__(self, config):
-         super().__init__(config)
-
-         self.config = config
-         self.protein_model = EsmForProteinRepresentation(config.protein_config)
-         self.text_model = BertForPubMed(config.text_config)
-         self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))
-
-         # Initialize weights and apply final processing
-         self.post_init()
-
-     def forward(
-         self,
-         protein_input_ids: Optional[torch.LongTensor] = None,
-         text_input_ids: Optional[torch.LongTensor] = None,
-         protein_attention_mask: Optional[torch.Tensor] = None,
-         text_attention_mask: Optional[torch.Tensor] = None,
-         protein_position_ids: Optional[torch.LongTensor] = None,
-         text_position_ids: Optional[torch.LongTensor] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-     ):
-         # Not implemented yet
-         return None
-
-
- class ProtSTForProteinPropertyPrediction(ProtSTPreTrainedModel):
-     def __init__(self, config):
-         super().__init__(config)
-
-         self.config = config
-         self.protein_model = EsmForProteinRepresentation(config.protein_config)
-         self.text_model = BertForPubMed(config.text_config)
-         self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))
-         self.classifier = ProtSTHead(config, out_dim=config.num_labels)
-
-         # Initialize weights and apply final processing
-         self.post_init()
-
-     def forward(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         head_mask: Optional[torch.Tensor] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, ProtSTClassificationOutput]:
-         r"""
-         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-             Labels for computing the protein classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
-         """
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         outputs = self.protein_model(
-             input_ids,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             head_mask=head_mask,
-             inputs_embeds=inputs_embeds,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-
-         protein_feature = outputs.protein_feature if return_dict else outputs[0]
-         logits = self.classifier(protein_feature)  # [bsz, feature_dim] -> [bsz, num_labels]
-
-         loss = None
-         if labels is not None:
-             loss_fct = nn.CrossEntropyLoss()
-             labels = labels.to(logits.device)
-             loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
-
-         if not return_dict:
-             output = (logits,)
-             return ((loss,) + output) if loss is not None else output
-
-         return ProtSTClassificationOutput(loss=loss, logits=logits)
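
For reference, the core computation the deleted module implemented is a masked mean readout, used by both BertForPubMed (masking CLS/SEP/PAD) and EsmForProteinRepresentation (masking CLS/EOS/PAD): the encoder's last hidden states are averaged across non-special positions only. Below is a minimal, self-contained sketch of that readout; the token ids and tensor sizes are illustrative assumptions, not values from a real checkpoint.

import torch

# Illustrative special-token ids; real values come from the model config.
cls_id, pad_id, eos_id = 0, 1, 2

input_ids = torch.tensor([[0, 5, 6, 7, 2, 1, 1]])  # [bsz, seq_len]
hidden = torch.randn(1, 7, 16)                     # [bsz, seq_len, hidden_dim]

# Mask out CLS/EOS/PAD, then average the remaining positions.
is_special = (input_ids == cls_id) | (input_ids == eos_id) | (input_ids == pad_id)
special_mask = (~is_special).to(torch.int64).unsqueeze(-1)  # [bsz, seq_len, 1]
pooled = (hidden * special_mask).sum(1) / (special_mask.sum(1) + 1.0e-6)

print(pooled.shape)  # torch.Size([1, 16]) -> one feature vector per sequence

The small epsilon (1.0e-6) in the denominator guards against division by zero when a row contains only special tokens; the pooled vector then feeds the ProtSTHead projections defined above.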