Update modeling_prot2text.py
Browse files- modeling_prot2text.py +46 -18
modeling_prot2text.py
CHANGED
|
@@ -240,7 +240,15 @@ class Prot2TextModel(PreTrainedModel):
|
|
| 240 |
x: Optional[torch.FloatTensor] = None,
|
| 241 |
edge_type: Optional[torch.LongTensor] = None,
|
| 242 |
tokenizer=None,
|
| 243 |
-
device='cpu'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
):
|
| 245 |
|
| 246 |
if self.config.esm and not self.config.rgcn and protein_sequence==None:
|
|
@@ -326,25 +334,45 @@ class Prot2TextModel(PreTrainedModel):
|
|
| 326 |
encoder_state['attentions'] = inputs['attention_mask']
|
| 327 |
for key in ['edge_index', 'edge_type', 'x', 'encoder_input_ids']:
|
| 328 |
inputs.pop(key)
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
|
|
|
| 340 |
|
| 341 |
-
|
| 342 |
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
|
| 349 |
else:
|
| 350 |
seq = esmtokenizer([protein_sequence], add_special_tokens=True, truncation=True, max_length=1021, padding='max_length', return_tensors="pt")
|
|
|
|
| 240 |
x: Optional[torch.FloatTensor] = None,
|
| 241 |
edge_type: Optional[torch.LongTensor] = None,
|
| 242 |
tokenizer=None,
|
| 243 |
+
device='cpu',
|
| 244 |
+
streamer=None,
|
| 245 |
+
max_new_tokens=None,
|
| 246 |
+
do_sample=None,
|
| 247 |
+
top_p=None,
|
| 248 |
+
top_k=None,
|
| 249 |
+
temperature=None,
|
| 250 |
+
num_beams=1,
|
| 251 |
+
repetition_penalty=None
|
| 252 |
):
|
| 253 |
|
| 254 |
if self.config.esm and not self.config.rgcn and protein_sequence==None:
|
|
|
|
| 334 |
encoder_state['attentions'] = inputs['attention_mask']
|
| 335 |
for key in ['edge_index', 'edge_type', 'x', 'encoder_input_ids']:
|
| 336 |
inputs.pop(key)
|
| 337 |
+
if streamer is None:
|
| 338 |
+
tok_ids = self.decoder.generate(input_ids=inputs['decoder_input_ids'],
|
| 339 |
+
encoder_outputs=encoder_state,
|
| 340 |
+
use_cache=True,
|
| 341 |
+
output_attentions=False,
|
| 342 |
+
output_scores=False,
|
| 343 |
+
return_dict_in_generate=True,
|
| 344 |
+
encoder_attention_mask=inputs['attention_mask'],
|
| 345 |
+
length_penalty=1.0,
|
| 346 |
+
no_repeat_ngram_size=None,
|
| 347 |
+
early_stopping=False,
|
| 348 |
+
num_beams=1)
|
| 349 |
|
| 350 |
+
generated = tokenizer.batch_decode(tok_ids.get('sequences'), skip_special_tokens=True)
|
| 351 |
|
| 352 |
+
os.remove(structure_filename)
|
| 353 |
+
os.remove(graph_filename)
|
| 354 |
+
os.remove(process_filename)
|
| 355 |
+
|
| 356 |
+
return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')
|
| 357 |
+
else:
|
| 358 |
+
os.remove(structure_filename)
|
| 359 |
+
os.remove(graph_filename)
|
| 360 |
+
os.remove(process_filename)
|
| 361 |
+
return self.decoder.generate(input_ids=inputs['decoder_input_ids'],
|
| 362 |
+
encoder_outputs=encoder_state,
|
| 363 |
+
use_cache=True,
|
| 364 |
+
encoder_attention_mask=inputs['attention_mask'],
|
| 365 |
+
length_penalty=1.0,
|
| 366 |
+
no_repeat_ngram_size=None,
|
| 367 |
+
early_stopping=False,
|
| 368 |
+
streamer=streamer,
|
| 369 |
+
max_new_tokens=max_new_tokens,
|
| 370 |
+
do_sample=do_sample,
|
| 371 |
+
top_p=top_p,
|
| 372 |
+
top_k=top_k,
|
| 373 |
+
temperature=temperature,
|
| 374 |
+
num_beams=1,
|
| 375 |
+
repetition_penalty=repetition_penalty)
|
| 376 |
|
| 377 |
else:
|
| 378 |
seq = esmtokenizer([protein_sequence], add_special_tokens=True, truncation=True, max_length=1021, padding='max_length', return_tensors="pt")
|