Jiqing committed · verified
Commit 01eb671 · 1 Parent(s): e96ac65

Update modeling_protst.py

Files changed (1):
  modeling_protst.py  +62 -58
modeling_protst.py CHANGED
@@ -25,7 +25,7 @@ class BertTextRepresentationOutput(ModelOutput):
 
 
 @dataclass
-class EsmProteinClassificationOutput(ModelOutput):
+class ProtSTClassificationOutput(ModelOutput):
 
     loss: Optional[torch.FloatTensor] = None
     logits: torch.FloatTensor = None
@@ -125,7 +125,7 @@ class EsmForProteinRepresentation(EsmPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, EsmProteinClassificationOutput]:
+    ) -> Union[Tuple, EsmProteinRepresentationOutput]:
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
@@ -159,61 +159,6 @@ class EsmForProteinRepresentation(EsmPreTrainedModel):
         )
 
 
-class EsmForProteinPropertyPrediction(EsmPreTrainedModel):
-    def __init__(self, config):
-        super().__init__(config)
-        self.model = EsmForProteinRepresentation(config)
-        self.classifier = ProtSTHead(config, out_dim=config.num_labels)
-
-        self.post_init() # NOTE
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, EsmProteinClassificationOutput]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the protein classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
-        Returns:
-        Examples:
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        outputs = self.model(
-            input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        logits = self.classifier(outputs.protein_feature) # [bsz, xxx] -> [bsz, num_labels]
-
-        loss = None
-        if labels is not None:
-            loss_fct = nn.CrossEntropyLoss()
-
-            labels = labels.to(logits.device)
-            loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
-
-        if not return_dict:
-            output = (logits,)
-            return ((loss,) + output) if loss is not None else output
-
-        return EsmProteinClassificationOutput(loss=loss, logits=logits)
-
-
 class ProtSTPreTrainedModel(PreTrainedModel):
     config_class = ProtSTConfig
 
@@ -265,7 +210,7 @@ class ProtSTModel(ProtSTPreTrainedModel):
         self.protein_model = EsmForProteinRepresentation(config.protein_config)
         self.text_model = BertForPubMed(config.text_config)
         self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))
-
+
         self.post_init() # NOTE
 
     def forward(self,
@@ -280,3 +225,62 @@ class ProtSTModel(ProtSTPreTrainedModel):
     ):
         # Not implement yet
         return None
+
+
+class ProtSTForProteinPropertyPrediction(ProtSTPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.config = config
+        self.protein_model = EsmForProteinRepresentation(config.protein_config)
+        self.text_model = BertForPubMed(config.text_config)
+        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))
+        self.classifier = ProtSTHead(config, out_dim=config.num_labels)
+
+        self.post_init() # NOTE
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, ProtSTClassificationOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the protein classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+        Returns:
+        Examples:
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.protein_model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        logits = self.classifier(outputs.protein_feature) # [bsz, xxx] -> [bsz, num_labels]
+
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+
+            labels = labels.to(logits.device)
+            loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,)
+            return ((loss,) + output) if loss is not None else output
+
+        return ProtSTClassificationOutput(loss=loss, logits=logits)
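
For context, a minimal usage sketch of the class this commit introduces. It is not part of the commit itself: the checkpoint id is a placeholder, and it assumes the repository's auto_map routes AutoModel to ProtSTForProteinPropertyPrediction when loaded with trust_remote_code.

import torch
from transformers import AutoModel, AutoTokenizer

# Placeholder repo id; substitute the actual Hub checkpoint that ships modeling_protst.py.
ckpt = "path/to/protst-checkpoint"

tokenizer = AutoTokenizer.from_pretrained(ckpt)
# trust_remote_code=True is required because ProtSTForProteinPropertyPrediction
# is defined in this repo's modeling_protst.py, not inside transformers itself.
model = AutoModel.from_pretrained(ckpt, trust_remote_code=True)

inputs = tokenizer("MSKGEELFTGVVPILVELDGD", return_tensors="pt")  # toy protein sequence
labels = torch.tensor([1])  # one class index per sequence, in [0, num_labels - 1]

# With return_dict=True, forward returns ProtSTClassificationOutput(loss, logits);
# without it, a plain tuple: (loss, logits), or (logits,) when labels is None.
outputs = model(**inputs, labels=labels, return_dict=True)
print(outputs.loss.item(), outputs.logits.shape)  # logits: [batch_size, num_labels]

The forward path mirrors the diff above: ProtSTHead maps the pooled protein_feature to num_labels logits, and nn.CrossEntropyLoss is computed after moving labels onto the logits device.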