Update app.py
Browse files
app.py
CHANGED
|
@@ -5,7 +5,7 @@ import plotly.graph_objects as go
|
|
| 5 |
from transformers import AutoModel, AutoTokenizer
|
| 6 |
|
| 7 |
############################################
|
| 8 |
-
# Activation Plot
|
| 9 |
############################################
|
| 10 |
|
| 11 |
def activation_plot():
|
|
@@ -34,7 +34,7 @@ def activation_plot():
|
|
| 34 |
return fig
|
| 35 |
|
| 36 |
############################################
|
| 37 |
-
#
|
| 38 |
############################################
|
| 39 |
|
| 40 |
def attention_viz(model_name, text):
|
|
@@ -45,7 +45,7 @@ def attention_viz(model_name, text):
|
|
| 45 |
with torch.no_grad():
|
| 46 |
outputs = model(**tokens)
|
| 47 |
|
| 48 |
-
attn = outputs.attentions[-1][0]
|
| 49 |
num_heads = attn.shape[0]
|
| 50 |
|
| 51 |
figs = []
|
|
@@ -65,31 +65,28 @@ def attention_viz(model_name, text):
|
|
| 65 |
return figs
|
| 66 |
|
| 67 |
############################################
|
| 68 |
-
# UI
|
| 69 |
############################################
|
| 70 |
|
| 71 |
with gr.Blocks(title="Transformer Attention Visualizer") as demo:
|
| 72 |
|
| 73 |
gr.Markdown("# 🔥 Activation + Transformer Attention Visualizer")
|
| 74 |
-
gr.Markdown("### 🚀 Fully interactive Plotly (zoom
|
| 75 |
|
| 76 |
-
#
|
| 77 |
gr.Markdown("## 1️⃣ Activation Functions")
|
| 78 |
-
|
| 79 |
|
| 80 |
-
#
|
| 81 |
-
gr.Markdown("## 2️⃣ Multi-Head Attention
|
| 82 |
|
| 83 |
model_box = gr.Textbox(value="bert-base-uncased", label="Model name")
|
| 84 |
-
text_box = gr.Textbox(value="Transformers are amazing.", label="Text
|
| 85 |
|
| 86 |
run_btn = gr.Button("Generate Attention")
|
| 87 |
-
gallery = gr.Gallery(label="Attention Heads (Zoomable)").style(
|
| 88 |
-
columns=2,
|
| 89 |
-
object_fit="contain",
|
| 90 |
-
height="auto"
|
| 91 |
-
)
|
| 92 |
|
| 93 |
-
|
|
|
|
|
|
|
| 94 |
|
| 95 |
demo.launch()
|
|
|
|
| 5 |
from transformers import AutoModel, AutoTokenizer
|
| 6 |
|
| 7 |
############################################
|
| 8 |
+
# Activation Plot
|
| 9 |
############################################
|
| 10 |
|
| 11 |
def activation_plot():
|
|
|
|
| 34 |
return fig
|
| 35 |
|
| 36 |
############################################
|
| 37 |
+
# Attention Visualizer
|
| 38 |
############################################
|
| 39 |
|
| 40 |
def attention_viz(model_name, text):
|
|
|
|
| 45 |
with torch.no_grad():
|
| 46 |
outputs = model(**tokens)
|
| 47 |
|
| 48 |
+
attn = outputs.attentions[-1][0] # (heads, seq, seq)
|
| 49 |
num_heads = attn.shape[0]
|
| 50 |
|
| 51 |
figs = []
|
|
|
|
| 65 |
return figs
|
| 66 |
|
| 67 |
############################################
|
| 68 |
+
# UI
|
| 69 |
############################################
|
| 70 |
|
| 71 |
with gr.Blocks(title="Transformer Attention Visualizer") as demo:
|
| 72 |
|
| 73 |
gr.Markdown("# 🔥 Activation + Transformer Attention Visualizer")
|
| 74 |
+
gr.Markdown("### 🚀 Fully interactive Plotly (zoom / pan / hover)")
|
| 75 |
|
| 76 |
+
# Activation section
|
| 77 |
gr.Markdown("## 1️⃣ Activation Functions")
|
| 78 |
+
gr.Plot(value=activation_plot())
|
| 79 |
|
| 80 |
+
# Attention section
|
| 81 |
+
gr.Markdown("## 2️⃣ Multi-Head Attention")
|
| 82 |
|
| 83 |
model_box = gr.Textbox(value="bert-base-uncased", label="Model name")
|
| 84 |
+
text_box = gr.Textbox(value="Transformers are amazing.", label="Text")
|
| 85 |
|
| 86 |
run_btn = gr.Button("Generate Attention")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
+
gallery = gr.Gallery(label="Attention Heads (zoomable)", show_label=True)
|
| 89 |
+
|
| 90 |
+
run_btn.click(attention_viz, inputs=[model_box, text_box], outputs=[gallery])
|
| 91 |
|
| 92 |
demo.launch()
|