"""Gradio app that visualizes transformer self-attention as a heatmap."""

from functools import lru_cache

import gradio as gr
import plotly.express as px
import torch
from transformers import AutoModel, AutoTokenizer

# Default model pre-selected in the dropdown.
DEFAULT_MODEL = "bert-base-uncased"


@lru_cache(maxsize=3)
def _load_model(model_name):
    """Load and cache the tokenizer/model pair for *model_name*.

    Caching avoids re-downloading and re-instantiating the model on every
    request; maxsize=3 covers the three dropdown choices.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, output_attentions=True)
    model.eval()  # inference only: disable dropout so attentions are deterministic
    return tokenizer, model


def visualize_attention(text, model_name):
    """Return a Plotly heatmap of the last-layer, head-0 attention for *text*.

    Parameters
    ----------
    text : str
        Sentence to tokenize and run through the model.
    model_name : str
        Hugging Face model identifier; must support ``output_attentions``.

    Returns
    -------
    plotly.graph_objects.Figure
        seq_len x seq_len attention heatmap with per-token axis labels.
    """
    tokenizer, model = _load_model(model_name)
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():  # no gradients needed for visualization
        outputs = model(**inputs)
    # outputs.attentions: tuple of (batch, num_heads, seq_len, seq_len), one per layer.
    # Select last layer, batch 0, head 0.
    attention = outputs.attentions[-1][0][0].detach().numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # Suffix each label with its position: repeated tokens would otherwise be
    # merged into a single categorical axis entry by Plotly, collapsing
    # rows/columns of the heatmap.
    labels = [f"{tok} ({i})" for i, tok in enumerate(tokens)]
    fig = px.imshow(
        attention,
        labels=dict(x="Tokens", y="Tokens", color="Attention"),
        x=labels,
        y=labels,
        title=f"{model_name} - Attention Map (Last Layer, Head 0)",
    )
    return fig


def app(text, model):
    """Gradio callback: delegate to visualize_attention and return the figure."""
    return visualize_attention(text, model)


interface = gr.Interface(
    fn=app,
    inputs=[
        gr.Textbox(label="Input Text", placeholder="Enter a sentence"),
        gr.Dropdown(
            label="Model",
            choices=["bert-base-uncased", "distilbert-base-uncased", "roberta-base"],
            value=DEFAULT_MODEL,
        ),
    ],
    outputs=gr.Plot(label="Attention Map"),
    title="Transformer Attention Visualizer",
    description="""
    Understand how transformer models interpret text through self-attention. 🧠 This tool extracts attention weights from the **last layer** and **first attention head** of popular transformer models. 🔍 The attention map shows how each token focuses on others during processing. 📚 Try different models and sentences to compare how they handle language and context. Ideal for NLP learners, researchers, and anyone curious about how transformers "pay attention". """,
)

if __name__ == "__main__":
    interface.launch()