import gradio as gr
import functools

import plotly.express as px
import torch
from transformers import AutoTokenizer, AutoModel
# Default model
DEFAULT_MODEL = "bert-base-uncased"
# Function to get attention maps
@functools.lru_cache(maxsize=4)
def _load_model(model_name):
    """Load and cache the tokenizer/model pair for *model_name*.

    Caching avoids re-downloading and re-instantiating the model on every
    Gradio request, which is by far the dominant cost of a call.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, output_attentions=True)
    model.eval()  # disable dropout etc. — we only ever run inference
    return tokenizer, model


def visualize_attention(text, model_name):
    """Return a Plotly heatmap of self-attention for *text*.

    Extracts the attention weights of the LAST layer, head 0, from the
    given transformer model and renders them as a token-by-token heatmap.

    Args:
        text: Input sentence to tokenize and feed to the model.
        model_name: HuggingFace model identifier (e.g. "bert-base-uncased").

    Returns:
        A plotly Figure with tokens on both axes and attention as color.
    """
    tokenizer, model = _load_model(model_name)
    inputs = tokenizer(text, return_tensors="pt")
    # no_grad: inference only — skips autograd graph construction.
    with torch.no_grad():
        outputs = model(**inputs)
    # outputs.attentions is a tuple of (batch, num_heads, seq_len, seq_len)
    # tensors, one per layer. Select last layer, batch item 0, head 0.
    attention = outputs.attentions[-1][0][0].numpy()  # (seq_len, seq_len)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    fig = px.imshow(
        attention,
        labels=dict(x="Tokens", y="Tokens", color="Attention"),
        x=tokens,
        y=tokens,
        title=f"{model_name} - Attention Map (Last Layer, Head 0)",
    )
    return fig
# Gradio interface
def app(text, model):
    """Gradio callback: render the attention heatmap for the given inputs."""
    return visualize_attention(text, model)
# Wire the UI: a text box plus a model picker in, a Plotly figure out.
_text_input = gr.Textbox(label="Input Text", placeholder="Enter a sentence")
_model_input = gr.Dropdown(
    label="Model",
    choices=["bert-base-uncased", "distilbert-base-uncased", "roberta-base"],
    value=DEFAULT_MODEL,
)

interface = gr.Interface(
    fn=app,
    inputs=[_text_input, _model_input],
    outputs=gr.Plot(label="Attention Map"),
    title="Transformer Attention Visualizer",
    description="""
Understand how transformer models interpret text through self-attention.
๐ง This tool extracts attention weights from the **last layer** and **first attention head** of popular transformer models.
๐ The attention map shows how each token focuses on others during processing.
๐ Try different models and sentences to compare how they handle language and context.
Ideal for NLP learners, researchers, and anyone curious about how transformers "pay attention".
""",
)

if __name__ == "__main__":
    interface.launch()