Spaces:
Sleeping
Sleeping
File size: 615 Bytes
ddaf22b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import matplotlib.pyplot as plt
import seaborn as sns
import torch
def plot_attention_heatmap(model, tokenizer, input_seq):
    """Draw the last layer's attention weights for ``input_seq`` as a heatmap.

    The model is switched to eval mode, ``input_seq`` is tokenized with
    ``return_tensors="pt"``, and ``model.base_model`` is run under
    ``torch.no_grad()`` with ``output_attentions=True``. The attention of the
    last layer for the first sequence in the batch is plotted with the
    tokens as tick labels on both axes.

    Args:
        model: Object exposing ``eval()`` and a ``base_model`` callable whose
            output has an ``attentions`` sequence (presumably a Hugging Face
            Transformers model -- confirm against the caller).
        tokenizer: Callable returning a mapping with an ``"input_ids"``
            tensor, and providing ``convert_ids_to_tokens``.
        input_seq: Text to tokenize and feed to the model.

    Returns:
        The matplotlib ``Figure`` containing the heatmap.
    """
    model.eval()
    inputs = tokenizer(input_seq, return_tensors="pt")
    with torch.no_grad():
        output = model.base_model(**inputs, output_attentions=True)

    # Last layer, first sequence in the batch.
    attention = output.attentions[-1][0]
    # With per-head attentions this is (num_heads, seq_len, seq_len), but
    # seaborn's heatmap only accepts a 2-D matrix: average over the heads.
    if attention.dim() == 3:
        attention = attention.mean(dim=0)
    # Plotting needs a CPU array; cast to float32 so half-precision outputs
    # can be converted to NumPy.
    attention = attention.detach().float().cpu().numpy()

    # Hand the tokenizer plain ints rather than 0-d tensors.
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        attention,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap="viridis",
        ax=ax,
    )
    # Configure the axes object directly instead of relying on pyplot's
    # notion of the "current" axes.
    ax.tick_params(axis="x", labelrotation=90)
    ax.set_title("Attention Heatmap")
    return fig
|