Spaces:
Sleeping
Sleeping
import functools

import gradio as gr
import plotly.express as px
import torch
from transformers import AutoModel, AutoTokenizer
# Model preselected in the "Model" dropdown; should be one of its choices.
DEFAULT_MODEL: str = "bert-base-uncased"
@functools.lru_cache(maxsize=4)
def _load_model(model_name):
    """Load the tokenizer and attention-returning model once per process.

    Loading weights is the most expensive step, so results are memoized per
    ``model_name`` instead of being reloaded on every request. The cache is
    bounded because every entry keeps a full model in memory.

    Args:
        model_name: Hugging Face model identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` pair.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, output_attentions=True)
    model.eval()  # inference only: make sure dropout is disabled
    return tokenizer, model


def visualize_attention(text, model_name):
    """Build a heatmap of the last layer's first attention head for *text*.

    Args:
        text: Sentence to tokenize and run through the model.
        model_name: Hugging Face model identifier (e.g. ``"bert-base-uncased"``).

    Returns:
        A Plotly figure whose row ``i`` shows how token ``i`` distributes its
        attention over all tokens (columns). Axis ticks show the tokens.
    """
    tokenizer, model = _load_model(model_name)

    # Truncate to the model's maximum length. Longer inputs would exceed the
    # position embeddings and make the forward pass fail.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only; skip building the autograd graph
        outputs = model(**inputs)

    # outputs.attentions holds one tensor per layer, each shaped
    # (batch, num_heads, seq_len, seq_len). Select the last layer, the first
    # (only) batch item, and head 0.
    attention = outputs.attentions[-1][0][0].numpy()  # shape: (seq_len, seq_len)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())

    # Plot on integer positions and show tokens only as tick labels. If the
    # tokens themselves were passed as x/y, Plotly would build a categorical
    # axis. Repeated tokens (e.g. "the" appearing twice) would then collapse
    # onto one row/column and hide part of the matrix.
    positions = list(range(len(tokens)))
    fig = px.imshow(
        attention,
        labels=dict(x="Tokens", y="Tokens", color="Attention"),
        x=positions,
        y=positions,
        title=f"{model_name} - Attention Map (Last Layer, Head 0)",
    )
    fig.update_xaxes(tickmode="array", tickvals=positions, ticktext=tokens)
    fig.update_yaxes(tickmode="array", tickvals=positions, ticktext=tokens)
    # Make the hover show the token pair rather than bare positions.
    fig.update_traces(
        customdata=[[f"{query} -> {key}" for key in tokens] for query in tokens],
        hovertemplate="%{customdata}<br>Attention: %{z:.3f}<extra></extra>",
    )
    return fig
| # Gradio interface | |
| def app(text, model): | |
| fig = visualize_attention(text, model) | |
| return fig | |
# Web UI: a text box and model dropdown feeding `app`, which returns a Plotly
# attention heatmap.
interface = gr.Interface(
    fn=app,
    inputs=[
        gr.Textbox(label="Input Text", placeholder="Enter a sentence"),
        gr.Dropdown(
            label="Model",
            choices=["bert-base-uncased", "distilbert-base-uncased", "roberta-base"],
            value=DEFAULT_MODEL,
        ),
    ],
    outputs=gr.Plot(label="Attention Map"),
    title="Transformer Attention Visualizer",
    # Gradio renders the description as Markdown. The bullet markers were
    # previously mis-encoded emoji bytes that displayed as garbage; they are
    # now plain Markdown list items, separated by blank lines so the list
    # renders on its own.
    description="""
Understand how transformer models interpret text through self-attention.

- This tool extracts attention weights from the **last layer** and **first attention head** of popular transformer models.
- The attention map shows how each token focuses on others during processing.
- Try different models and sentences to compare how they handle language and context.

Ideal for NLP learners, researchers, and anyone curious about how transformers "pay attention".
""",
)
def _launch():
    """Start the Gradio app defined by the module-level ``interface``."""
    interface.launch()


# Only start the server when executed as a script, not when imported.
if __name__ == "__main__":
    _launch()