Update modeling_prot2text.py
modeling_prot2text.py  (CHANGED: +2 -18)
```diff
@@ -323,8 +323,8 @@ class Prot2TextModel(PreTrainedModel):
         tok_ids = self.decoder.generate(input_ids=inputs['decoder_input_ids'],
                                         encoder_outputs=encoder_state,
                                         use_cache=True,
-                                        output_attentions=True,
-                                        output_scores=True,
+                                        output_attentions=False,
+                                        output_scores=False,
                                         return_dict_in_generate=True,
                                         encoder_attention_mask=inputs['attention_mask'],
                                         length_penalty=1.0,
@@ -333,22 +333,6 @@ class Prot2TextModel(PreTrainedModel):
                                         num_beams=1)
 
         generated = tokenizer.batch_decode(tok_ids.get('sequences'), skip_special_tokens=True)
-        print(tok_ids.get('scores')[0].size())
-        m = torch.nn.Softmax()
-        att_w = []
-        print(len(gpdb.sequence[0]))
-        score = 0
-        for i in range(len(tok_ids.get('cross_attentions'))):
-            att_w.append(torch.mul(tok_ids.get('cross_attentions')[i][-1].squeeze().mean(dim=0), inputs['attention_mask'][-1].squeeze())[:len(gpdb.sequence[0])].tolist())
-            score += np.log(torch.max(m(tok_ids.get('scores')[i]).squeeze()).item())
-        score = score / len(tok_ids.get('cross_attentions'))
-        # print(str(score))
-
-        # import seaborn as sns
-        # import matplotlib.pylab as plt
-        # plt.figure().set_figwidth(150)
-        # ax = sns.heatmap(att_w, cmap="YlGnBu", robust=True, xticklabels=gpdb.sequence[0])#, yticklabels=generated[0])
-        # plt.savefig("seaborn_plot.png")
 
         os.remove(structure_filename)
         os.remove(graph_filename)
```