from functools import lru_cache

import gradio as gr
import plotly.graph_objects as go
import torch
import torch.nn.functional as F
from plotly.subplots import make_subplots
from transformers import AutoModel, AutoTokenizer
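

# ---------------------------------------------------------------------------
# Activation function comparison
# ---------------------------------------------------------------------------

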
def activation_plot():
    """Build a Plotly line chart comparing four activation functions.

    GELU, ReLU, SiLU and Hard-Swish are each evaluated on 500 evenly spaced
    points in [-6, 6] and drawn as one line trace per function.

    Returns:
        plotly.graph_objects.Figure: dark-themed figure, 450 px high, titled
        "Activation Functions".
    """
    inputs = torch.linspace(-6, 6, 500)

    # (legend label, activation values) pairs, listed in legend order.
    curves = [
        ("GELU", F.gelu(inputs)),
        ("ReLU", F.relu(inputs)),
        ("SiLU", inputs * torch.sigmoid(inputs)),
        ("Hard-Swish", inputs * F.relu6(inputs + 3) / 6),
    ]

    figure = go.Figure()
    figure.add_traces([
        go.Scatter(
            x=inputs.numpy(),
            y=values.numpy(),
            mode="lines",
            name=label,
        )
        for label, values in curves
    ])
    figure.update_layout(
        title="Activation Functions",
        template="plotly_dark",
        height=450,
    )
    return figure
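

# ---------------------------------------------------------------------------
# Multi-head attention heatmaps
# ---------------------------------------------------------------------------

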
@lru_cache(maxsize=2)
def _load_tokenizer_and_model(model_name):
    """Load the tokenizer/model pair for *model_name*, memoising the result.

    Loading is the slow part of a request, so the pair is cached and reused
    across button clicks. The cache is bounded because each entry holds a
    full model in memory.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): some transformers releases only return attention weights
    # with the eager attention implementation -- confirm whether
    # attn_implementation="eager" is required for the installed version.
    model = AutoModel.from_pretrained(model_name, output_attentions=True)
    return tokenizer, model


def _head_figure(weights, labels, title):
    """Render one head's attention matrix as a dark-themed heatmap figure."""
    fig = go.Figure(data=go.Heatmap(
        z=weights,
        x=labels,
        y=labels,
        colorscale="Viridis",
    ))
    fig.update_layout(
        title=title,
        height=400,
        template="plotly_dark",
        # Force categorical axes so labels are never auto-parsed as numbers.
        xaxis={"type": "category", "title": "Key token"},
        yaxis={"type": "category", "title": "Query token"},
    )
    return fig


def attention_viz(model_name, text):
    """Plot the last-layer attention of a transformer, one heatmap per head.

    Args:
        model_name: Identifier passed to ``from_pretrained`` (surrounding
            whitespace is ignored).
        text: Input text; it is truncated to the tokenizer's maximum length
            when the tokenizer defines one.

    Returns:
        list[plotly.graph_objects.Figure]: one figure per attention head,
        titled ``"Head <i>"``, with both axes labelled by token.

    Raises:
        Whatever ``from_pretrained`` or the forward pass raises, e.g. for an
        unknown model name; errors are not swallowed here.
    """
    tokenizer, model = _load_tokenizer_and_model(model_name.strip())

    encoded = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = model(**encoded)

    # Last layer, first (only) batch item: indexed below as one matrix per head.
    last_layer = outputs.attentions[-1][0].cpu().numpy()

    # Prefix each token with its position so repeated tokens stay distinct;
    # identical category labels would be merged into one row/column.
    token_ids = encoded["input_ids"][0].tolist()
    labels = [
        f"{position}: {token}"
        for position, token in enumerate(
            tokenizer.convert_ids_to_tokens(token_ids)
        )
    ]

    return [
        _head_figure(weights, labels, f"Head {i}")
        for i, weights in enumerate(last_layer)
    ]
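

# ---------------------------------------------------------------------------
# Gradio user interface
# ---------------------------------------------------------------------------

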
def _attention_grid(model_name, text, columns=4):
    """Tile the per-head figures from ``attention_viz`` into one figure.

    A ``gr.Plot`` displays a single Plotly figure, so the list of per-head
    heatmaps is laid out as a grid of subplots (``columns`` wide) that share
    one colour bar.

    Returns:
        plotly.graph_objects.Figure: the combined grid, or an empty figure
        when ``attention_viz`` produced no heads.
    """
    head_figs = attention_viz(model_name, text)
    if not head_figs:
        return go.Figure()

    columns = min(columns, len(head_figs))
    n_rows = -(-len(head_figs) // columns)  # ceiling division
    titles = [
        fig.layout.title.text or f"Head {idx}"
        for idx, fig in enumerate(head_figs)
    ]

    grid = make_subplots(rows=n_rows, cols=columns, subplot_titles=titles)
    for idx, head_fig in enumerate(head_figs):
        row, col = divmod(idx, columns)
        for trace in head_fig.data:
            grid.add_trace(trace, row=row + 1, col=col + 1)

    # Point every heatmap at one shared colour axis; otherwise each subplot
    # draws its own colour bar and they overlap.
    grid.update_traces(coloraxis="coloraxis", selector={"type": "heatmap"})
    grid.update_layout(
        coloraxis={"colorscale": "Viridis"},
        template="plotly_dark",
        height=350 * n_rows,
    )
    return grid


# Page layout: a static activation-function chart, followed by an interactive
# section that plots attention heads for a user-supplied model and text.
with gr.Blocks(title="Transformer Attention Visualizer") as demo:
    gr.Markdown("# 🔥 Activation + Transformer Attention Visualizer")
    gr.Markdown("### 🚀 Fully interactive Plotly (zoom / pan / hover)")

    gr.Markdown("## 1️⃣ Activation Functions")
    gr.Plot(value=activation_plot())

    gr.Markdown("## 2️⃣ Multi-Head Attention")
    model_box = gr.Textbox(value="bert-base-uncased", label="Model name")
    text_box = gr.Textbox(value="Transformers are amazing.", label="Text")
    run_btn = gr.Button("Generate Attention")

    # gr.Gallery only accepts images; Plotly figures must go to gr.Plot.
    heads_plot = gr.Plot(label="Attention Heads (zoomable)", show_label=True)

    run_btn.click(
        _attention_grid,
        inputs=[model_box, text_box],
        outputs=[heads_plot],
    )


# Only start the server when run as a script, so importing this module
# (e.g. for reuse of `demo`) does not block on launch().
if __name__ == "__main__":
    demo.launch()