File size: 2,006 Bytes
6282cd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f010d8e
6282cd6
f010d8e
6282cd6
 
 
 
 
f010d8e
 
 
 
 
 
 
 
 
 
 
6282cd6
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import plotly.express as px

# Model pre-selected in the UI dropdown; must be one of the dropdown choices.
DEFAULT_MODEL = "bert-base-uncased"

# Cache of loaded (tokenizer, model) pairs keyed by model name, so repeated
# calls don't re-download / re-initialize weights on every UI interaction.
_MODEL_CACHE = {}

# Function to get attention maps
def visualize_attention(text, model_name):
    """Return a Plotly heatmap of last-layer / head-0 self-attention.

    Args:
        text: Input sentence to tokenize and run through the model.
        model_name: Hugging Face model identifier (e.g. "bert-base-uncased").

    Returns:
        plotly.graph_objects.Figure: token-by-token attention heatmap.
    """
    if model_name not in _MODEL_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name, output_attentions=True)
        # Inference only: disable dropout etc. for deterministic attention maps.
        model.eval()
        _MODEL_CACHE[model_name] = (tokenizer, model)
    tokenizer, model = _MODEL_CACHE[model_name]

    inputs = tokenizer(text, return_tensors="pt")
    # No gradients needed for visualization — avoids building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    attentions = outputs.attentions  # Tuple: (num_layers, batch, num_heads, seq_len, seq_len)

    # Select last layer, batch 0, head 0 -> (seq_len, seq_len) matrix.
    attention = attentions[-1][0][0].detach().numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    fig = px.imshow(
        attention,
        labels=dict(x="Tokens", y="Tokens", color="Attention"),
        x=tokens,
        y=tokens,
        title=f"{model_name} - Attention Map (Last Layer, Head 0)"
    )
    return fig

# Gradio interface
def app(text, model):
    """Gradio callback: delegate straight to the attention visualizer."""
    return visualize_attention(text, model)

# Top-level UI definition: one textbox + model dropdown in, one Plot out.
interface = gr.Interface(
    fn=app,
    inputs=[
        gr.Textbox(label="Input Text", placeholder="Enter a sentence"),
        gr.Dropdown(
            label="Model",
            choices=["bert-base-uncased", "distilbert-base-uncased", "roberta-base"],
            value=DEFAULT_MODEL
        )
    ],
    outputs=gr.Plot(label="Attention Map"),
    title="Transformer Attention Visualizer",
    # NOTE(review): the original description contained mojibake (UTF-8 emoji
    # bytes misdecoded as windows-874); restored to the intended emoji.
    description="""
Understand how transformer models interpret text through self-attention.

🧠 This tool extracts attention weights from the **last layer** and **first attention head** of popular transformer models.

🔍 The attention map shows how each token focuses on others during processing.

📚 Try different models and sentences to compare how they handle language and context.

Ideal for NLP learners, researchers, and anyone curious about how transformers "pay attention".
"""
)

# Launch the web app only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()