# lrm / app.py — Gradio Space by ritwika96 (commit b9cd478, verified)
from functools import lru_cache

import gradio as gr
import plotly.graph_objects as go
import torch
import torch.nn.functional as F
from plotly.subplots import make_subplots
from transformers import AutoModel, AutoTokenizer
############################################
# Activation Plot
############################################
def activation_plot():
    """Build an interactive Plotly figure comparing common activation functions.

    Returns:
        plotly.graph_objects.Figure: one line trace per activation, evaluated
        on 500 points over [-6, 6], styled with the dark template.
    """
    xs = torch.linspace(-6, 6, 500)
    # (label, y-values) pairs; Hard-Swish is computed from its ReLU6 form.
    curves = [
        ("GELU", F.gelu(xs)),
        ("ReLU", F.relu(xs)),
        ("SiLU", xs * torch.sigmoid(xs)),
        ("Hard-Swish", xs * F.relu6(xs + 3) / 6),
    ]
    fig = go.Figure()
    x_np = xs.numpy()  # hoisted: the x-axis is identical for every trace
    for label, ys in curves:
        fig.add_trace(go.Scatter(x=x_np, y=ys.numpy(), mode="lines", name=label))
    fig.update_layout(title="Activation Functions", template="plotly_dark", height=450)
    return fig
############################################
# Attention Visualizer
############################################
@lru_cache(maxsize=2)
def _load_model(model_name):
    """Load and cache the tokenizer/model pair.

    Caching fixes the original defect of re-instantiating (and potentially
    re-downloading) the model on every button click.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, output_attentions=True)
    model.eval()  # inference only; disables dropout etc.
    return tokenizer, model


def attention_viz(model_name, text):
    """Return one Plotly heatmap per attention head of the model's last layer.

    Args:
        model_name: Hugging Face model id (e.g. "bert-base-uncased").
        text: input sentence to tokenize and run through the model.

    Returns:
        list[plotly.graph_objects.Figure]: a (seq x seq) heatmap for each head.
    """
    tokenizer, model = _load_model(model_name)
    encoded = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoded)
    # Last layer, first (only) batch element: shape (num_heads, seq, seq).
    attn = outputs.attentions[-1][0]
    figs = []
    for head_idx in range(attn.shape[0]):
        fig = go.Figure(data=go.Heatmap(
            z=attn[head_idx].numpy(),
            colorscale="Viridis",
        ))
        fig.update_layout(
            title=f"Head {head_idx}",
            height=400,
            template="plotly_dark",
        )
        figs.append(fig)
    return figs
############################################
# UI
############################################
with gr.Blocks(title="Transformer Attention Visualizer") as demo:
    gr.Markdown("# 🔥 Activation + Transformer Attention Visualizer")
    gr.Markdown("### 🚀 Fully interactive Plotly (zoom / pan / hover)")

    # Activation section
    gr.Markdown("## 1️⃣ Activation Functions")
    gr.Plot(value=activation_plot())

    # Attention section
    gr.Markdown("## 2️⃣ Multi-Head Attention")
    model_box = gr.Textbox(value="bert-base-uncased", label="Model name")
    text_box = gr.Textbox(value="Transformers are amazing.", label="Text")
    run_btn = gr.Button("Generate Attention")

    # BUG FIX: gr.Gallery accepts images/filepaths, not Plotly figures, so the
    # list returned by attention_viz never rendered. Merge all heads into one
    # subplot figure and display it in a gr.Plot, which does accept Plotly.
    attn_plot = gr.Plot(label="Attention Heads (zoomable)", show_label=True)

    def _attention_grid(model_name, text):
        """Run attention_viz and merge the per-head heatmaps into a grid figure."""
        figs = attention_viz(model_name, text)
        cols = 4
        rows = max(1, (len(figs) + cols - 1) // cols)  # ceil-divide into rows
        grid = make_subplots(
            rows=rows,
            cols=cols,
            subplot_titles=[f"Head {i}" for i in range(len(figs))],
        )
        for i, fig in enumerate(figs):
            # Each source figure holds a single Heatmap trace; re-home it.
            grid.add_trace(fig.data[0], row=i // cols + 1, col=i % cols + 1)
        grid.update_layout(template="plotly_dark", height=350 * rows)
        return grid

    run_btn.click(_attention_grid, inputs=[model_box, text_box], outputs=[attn_plot])

demo.launch()