# app.py — Transformer Attention Visualizer (Hugging Face Space by shoaibfd26, commit f010d8e)
from functools import lru_cache

import gradio as gr
import plotly.express as px
import torch
from transformers import AutoTokenizer, AutoModel
# Default model shown pre-selected in the UI dropdown; must be one of the
# Hugging Face hub identifiers offered in the Dropdown choices below.
DEFAULT_MODEL = "bert-base-uncased"
@lru_cache(maxsize=4)
def _load_model(model_name):
    """Load and cache the tokenizer/model pair for *model_name*.

    Caching avoids re-downloading and re-instantiating the model on every
    request — loading a transformer per call is the dominant cost here.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, output_attentions=True)
    model.eval()  # inference only: disable dropout etc.
    return tokenizer, model


# Function to get attention maps
def visualize_attention(text, model_name, layer=-1, head=0):
    """Render a token-to-token self-attention heatmap for *text*.

    Args:
        text: Input sentence to tokenize and run through the model.
        model_name: Hugging Face model identifier (e.g. "bert-base-uncased").
        layer: Index of the attention layer to show (default -1 = last layer).
        head: Index of the attention head within that layer (default 0).

    Returns:
        A plotly Figure of the (seq_len, seq_len) attention matrix with
        token labels on both axes.
    """
    tokenizer, model = _load_model(model_name)
    inputs = tokenizer(text, return_tensors="pt")
    # Inference only — skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    # Tuple with one tensor per layer, each (batch, num_heads, seq_len, seq_len).
    attentions = outputs.attentions
    attention = attentions[layer][0][head].numpy()  # shape: (seq_len, seq_len)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    fig = px.imshow(
        attention,
        labels=dict(x="Tokens", y="Tokens", color="Attention"),
        x=tokens,
        y=tokens,
        # Human-readable layer label: -1 means the last layer.
        title=f"{model_name} - Attention Map (Layer {layer}, Head {head})"
    )
    return fig
# Gradio interface
def app(text, model):
    """Gradio callback: delegate to visualize_attention and return its figure."""
    return visualize_attention(text, model)
# UI wiring: two inputs (free-text sentence + model dropdown) mapped through
# app() to a single plotly output rendered as an interactive heatmap.
interface = gr.Interface(
fn=app,
inputs=[
gr.Textbox(label="Input Text", placeholder="Enter a sentence"),
gr.Dropdown(
label="Model",
choices=["bert-base-uncased", "distilbert-base-uncased", "roberta-base"],
value=DEFAULT_MODEL
)
],
outputs=gr.Plot(label="Attention Map"),
title="Transformer Attention Visualizer",
description="""
Understand how transformer models interpret text through self-attention.
🧠 This tool extracts attention weights from the **last layer** and **first attention head** of popular transformer models.
πŸ” The attention map shows how each token focuses on others during processing.
πŸ“š Try different models and sentences to compare how they handle language and context.
Ideal for NLP learners, researchers, and anyone curious about how transformers "pay attention".
"""
)
# Launch the Gradio server only when run as a script (not when imported).
if __name__ == "__main__":
interface.launch()