File size: 615 Bytes
ddaf22b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import matplotlib.pyplot as plt
import seaborn as sns
import torch

def plot_attention_heatmap(model, tokenizer, input_seq):
    """Plot the last layer's attention weights for *input_seq* as a heatmap.

    Args:
        model: a transformers-style model whose ``base_model`` accepts
            ``output_attentions=True`` — NOTE(review): assumed HF interface,
            confirm against callers.
        tokenizer: the paired tokenizer (callable, returning ``input_ids``,
            and providing ``convert_ids_to_tokens``).
        input_seq: raw text to tokenize and run through the model.

    Returns:
        The matplotlib Figure containing the token-by-token heatmap.
    """
    model.eval()  # disable dropout so the attention weights are deterministic
    inputs = tokenizer(input_seq, return_tensors="pt")
    with torch.no_grad():  # inference only — no autograd graph needed
        output = model.base_model(**inputs, output_attentions=True)

    # output.attentions[-1] is the last layer: (batch, heads, seq, seq).
    # Take batch element 0, then average over the head dimension to get a
    # 2-D (seq, seq) matrix. BUG FIX: the original passed the 3-D
    # (heads, seq, seq) tensor straight to sns.heatmap, which only accepts
    # 2-D data and raises ValueError. .cpu().numpy() makes the conversion
    # explicit and safe if the model ran on an accelerator.
    attention = output.attentions[-1][0].mean(dim=0).cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(attention, xticklabels=tokens, yticklabels=tokens,
                cmap="viridis", ax=ax)
    plt.xticks(rotation=90)
    plt.title("Attention Heatmap")
    return fig