ritwika96 committed on
Commit
b9cd478
·
verified ·
1 Parent(s): 0d61c73

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -16
app.py CHANGED
@@ -5,7 +5,7 @@ import plotly.graph_objects as go
5
  from transformers import AutoModel, AutoTokenizer
6
 
7
  ############################################
8
- # Activation Plot (Interactive Plotly)
9
  ############################################
10
 
11
  def activation_plot():
@@ -34,7 +34,7 @@ def activation_plot():
34
  return fig
35
 
36
  ############################################
37
- # Transformer Attention (one-click)
38
  ############################################
39
 
40
  def attention_viz(model_name, text):
@@ -45,7 +45,7 @@ def attention_viz(model_name, text):
45
  with torch.no_grad():
46
  outputs = model(**tokens)
47
 
48
- attn = outputs.attentions[-1][0] # (heads, seq, seq)
49
  num_heads = attn.shape[0]
50
 
51
  figs = []
@@ -65,31 +65,28 @@ def attention_viz(model_name, text):
65
  return figs
66
 
67
  ############################################
68
- # UI LAYOUT
69
  ############################################
70
 
71
  with gr.Blocks(title="Transformer Attention Visualizer") as demo:
72
 
73
  gr.Markdown("# 🔥 Activation + Transformer Attention Visualizer")
74
- gr.Markdown("### 🚀 Fully interactive Plotly (zoom in/out, hover, drag)")
75
 
76
- # ACTIVATION SECTION
77
  gr.Markdown("## 1️⃣ Activation Functions")
78
- act_plot = gr.Plot(value=activation_plot())
79
 
80
- # ATTENTION SECTION
81
- gr.Markdown("## 2️⃣ Multi-Head Attention (HuggingFace Transformers)")
82
 
83
  model_box = gr.Textbox(value="bert-base-uncased", label="Model name")
84
- text_box = gr.Textbox(value="Transformers are amazing.", label="Text input")
85
 
86
  run_btn = gr.Button("Generate Attention")
87
- gallery = gr.Gallery(label="Attention Heads (Zoomable)").style(
88
- columns=2,
89
- object_fit="contain",
90
- height="auto"
91
- )
92
 
93
- run_btn.click(attention_viz, inputs=[model_box, text_box], outputs=gallery)
 
 
94
 
95
  demo.launch()
 
5
  from transformers import AutoModel, AutoTokenizer
6
 
7
  ############################################
8
+ # Activation Plot
9
  ############################################
10
 
11
  def activation_plot():
 
34
  return fig
35
 
36
  ############################################
37
+ # Attention Visualizer
38
  ############################################
39
 
40
  def attention_viz(model_name, text):
 
45
  with torch.no_grad():
46
  outputs = model(**tokens)
47
 
48
+ attn = outputs.attentions[-1][0] # (heads, seq, seq)
49
  num_heads = attn.shape[0]
50
 
51
  figs = []
 
65
  return figs
66
 
67
  ############################################
68
+ # UI
69
  ############################################
70
 
71
  with gr.Blocks(title="Transformer Attention Visualizer") as demo:
72
 
73
  gr.Markdown("# 🔥 Activation + Transformer Attention Visualizer")
74
+ gr.Markdown("### 🚀 Fully interactive Plotly (zoom / pan / hover)")
75
 
76
+ # Activation section
77
  gr.Markdown("## 1️⃣ Activation Functions")
78
+ gr.Plot(value=activation_plot())
79
 
80
+ # Attention section
81
+ gr.Markdown("## 2️⃣ Multi-Head Attention")
82
 
83
  model_box = gr.Textbox(value="bert-base-uncased", label="Model name")
84
+ text_box = gr.Textbox(value="Transformers are amazing.", label="Text")
85
 
86
  run_btn = gr.Button("Generate Attention")
 
 
 
 
 
87
 
88
+ gallery = gr.Gallery(label="Attention Heads (zoomable)", show_label=True)
89
+
90
+ run_btn.click(attention_viz, inputs=[model_box, text_box], outputs=[gallery])
91
 
92
  demo.launch()