# TransPolymer_Demo_v001/src/attention_plot.py
import matplotlib.pyplot as plt
import seaborn as sns
import torch


def plot_attention_heatmap(model, tokenizer, input_seq):
    """Plot a token-to-token attention heatmap for a single input sequence.

    Uses the attention weights from the last layer of the model's backbone,
    averaged over attention heads.
    """
    model.eval()
    inputs = tokenizer(input_seq, return_tensors="pt")
    with torch.no_grad():
        output = model.base_model(**inputs, output_attentions=True)
    # output.attentions[-1] has shape (batch, num_heads, seq_len, seq_len);
    # take the single batch element and average over heads to get a 2D matrix
    # that seaborn can plot.
    attention = output.attentions[-1][0].mean(dim=0).cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(attention, xticklabels=tokens, yticklabels=tokens, cmap="viridis", ax=ax)
    plt.xticks(rotation=90)
    plt.title("Attention Heatmap")
    return fig
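

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions not present in the original script:
    # the "roberta-base" checkpoint and the example SMILES string are
    # illustrative stand-ins; in the demo, the actual TransPolymer model and
    # tokenizer would be loaded instead.
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    model = AutoModel.from_pretrained("roberta-base")

    fig = plot_attention_heatmap(model, tokenizer, "CC(C)C(=O)O")
    fig.savefig("attention_heatmap.png", bbox_inches="tight")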