PraneshJs committed
Commit 4d08ba2 · verified · 1 Parent(s): 0deccad

Create app.py

Files changed (1)
  1. app.py +407 -0
app.py ADDED
@@ -0,0 +1,407 @@
# ==========================================================
# Vision Transformer (ViT) Visualizer — HF Space, CPU, Gradio 5
# - Model: google/vit-base-patch16-224
# - Shows:
#     * Original + patch grid (tokens)
#     * Attention heatmap overlay (CLS -> patches)
#     * PCA of patch embeddings
#     * Top-5 predictions
#     * Simple vs technical explanation
# - CPU friendly, uses only Gradio v5-safe features
# ==========================================================

import math
import warnings
from typing import Dict, Any, Optional, List, Tuple

import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageDraw
from transformers import AutoImageProcessor, ViTForImageClassification
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

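# Assumed Space dependencies (a guess, not pinned or shipped in this commit):
# a requirements.txt along the lines of gradio, torch, transformers,
# scikit-learn, matplotlib, Pillow, numpy.
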
warnings.filterwarnings("ignore")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "google/vit-base-patch16-224"

VIT_MODEL = None
VIT_PROCESSOR = None


# ---------------------- MODEL LOADING ----------------------


def load_vit():
    """Load ViT + image processor once into global cache."""
    global VIT_MODEL, VIT_PROCESSOR
    if VIT_MODEL is not None and VIT_PROCESSOR is not None:
        return VIT_MODEL, VIT_PROCESSOR

    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = ViTForImageClassification.from_pretrained(MODEL_NAME)

    # ensure we get attentions + hidden states
    model.config.output_attentions = True
    model.config.output_hidden_states = True

    model.to(DEVICE)
    model.eval()

    VIT_MODEL = model
    VIT_PROCESSOR = processor
    return model, processor

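# For orientation (facts about the default google/vit-base-patch16-224
# checkpoint, not code from this commit): 12 encoder layers, 12 heads,
# hidden size 768; 224/16 = 14, so 14*14 = 196 patch tokens plus [CLS]
# gives a sequence length of 197. With the flags set in load_vit(), a
# forward pass returns outputs.attentions as 12 tensors of shape
# (batch, 12, 197, 197) and outputs.hidden_states as 13 tensors
# (embedding output + 12 layers).
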

# ---------------------- VISUAL HELPERS ----------------------


def make_patch_grid_image(pil_img: Image.Image, patch_size: int = 16) -> Image.Image:
    """
    Resize to 224x224 and draw the patch grid (ViT cuts the image into
    16x16-pixel patches, i.e. a 14x14 grid at 224x224).
    """
    img = pil_img.convert("RGB").resize((224, 224))
    draw = ImageDraw.Draw(img)
    w, h = img.size
    for x in range(0, w, patch_size):
        draw.line((x, 0, x, h), fill=(0, 255, 0), width=1)
    for y in range(0, h, patch_size):
        draw.line((0, y, w, y), fill=(0, 255, 0), width=1)
    return img


def make_attention_overlay(
    base_img: Image.Image, heatmap_grid: np.ndarray
) -> Image.Image:
    """
    Overlay a CLS->patch attention heatmap on top of the 224x224 image.
    heatmap_grid: (G, G) attention values.
    """
    base = base_img.convert("RGB").resize((224, 224))
    g = heatmap_grid.astype(np.float32)

    if not np.any(g):
        g = np.zeros_like(g, dtype=np.float32)
    else:
        g -= g.min()
        maxv = g.max()
        if maxv > 0:
            g /= maxv

    # upscale to image size
    heat_img = Image.fromarray((g * 255).astype("uint8"), mode="L")
    heat_img = heat_img.resize((224, 224), Image.BILINEAR)
    heat = np.array(heat_img).astype(np.float32) / 255.0  # 0..1

    # simple blue->red colormap overlay
    r = heat
    g_c = np.zeros_like(heat)
    b = 1.0 - heat
    cam = np.stack([r, g_c, b], axis=-1)  # H,W,3

    base_np = np.array(base).astype(np.float32) / 255.0
    alpha = 0.45
    blended = (1 - alpha) * base_np + alpha * cam
    blended = np.clip(blended * 255.0, 0, 255).astype("uint8")
    return Image.fromarray(blended)

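# The blend above is a hand-rolled blue->red ramp. A perceptual matplotlib
# colormap would be a drop-in alternative (hedged sketch, not what the app does):
#   cam = plt.get_cmap("viridis")(heat)[..., :3]   # (224, 224, 3) RGB in 0..1
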

def make_pca_plot(patch_embeddings: np.ndarray):
    """
    patch_embeddings: (N_patches, hidden_dim)
    Returns a Matplotlib figure showing patches in 2D PCA space.
    """
    if patch_embeddings.shape[0] < 2:
        return None

    pca = PCA(n_components=2)
    comps = pca.fit_transform(patch_embeddings)  # (N, 2)

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.scatter(comps[:, 0], comps[:, 1], s=20, alpha=0.8)
    ax.set_title("Patches in 2D (PCA of embeddings)")
    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    return fig

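# Colouring the scatter points by patch index would make the spatial layout
# visible in the PCA. Hedged sketch (the variable `n` is hypothetical, it would
# be patch_embeddings.shape[0]; this line is not part of the app):
#   ax.scatter(comps[:, 0], comps[:, 1], c=np.arange(n), cmap="viridis", s=20)
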

# ---------------------- CORE ANALYSIS ----------------------


def analyze_vit(img: Optional[Image.Image], simple: bool):
    """
    Main function called by the Gradio button.
    Returns:
        - patch_grid_image
        - attention_overlay (default: last layer, head 0)
        - PCA figure
        - predictions table
        - explanation markdown
        - state dict (for attention slider updates)
    """
    if img is None:
        return (
            None,
            None,
            None,
            [],
            "⬆️ Please upload an image (e.g., a dog, a car, an object).",
            {},
        )

    model, processor = load_vit()

    # 1) Preprocess
    img_resized = img.convert("RGB").resize((224, 224))
    patch_grid_img = make_patch_grid_image(img_resized)

    inputs = processor(images=img_resized, return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # 2) Predictions (top-5) via a numerically stable softmax
    logits = outputs.logits[0].cpu().numpy()
    probs = np.exp(logits - logits.max())
    probs = probs / probs.sum()
    topk_idx = probs.argsort()[-5:][::-1]
    id2label = model.config.id2label
    preds_table = [
        [id2label[int(i)], float(probs[int(i)])] for i in topk_idx
    ]

    # 3) Patch embeddings from the last hidden state
    #    hidden_states[-1]: (batch, seq_len, hidden)
    hidden_last = outputs.hidden_states[-1][0].cpu().numpy()  # (seq, hidden)
    # seq layout: [CLS] + patches
    patch_emb = hidden_last[1:, :]  # (N_patches, hidden)
    pca_fig = make_pca_plot(patch_emb)

    # 4) Attention -> CLS-to-patch grid per layer/head
    attentions = outputs.attentions  # tuple of (batch, heads, seq, seq)
    num_layers = len(attentions)
    num_heads = attentions[0].shape[1] if num_layers > 0 else 0

    # ViT-base: 14x14 = 196 patches
    seq_len = attentions[0].shape[-1]  # 1 + N_patches
    n_patches = seq_len - 1
    grid_size = int(math.sqrt(n_patches))
    if grid_size * grid_size != n_patches:
        # fallback: approximate
        grid_size = int(round(math.sqrt(n_patches)))

    cls_to_patch = np.zeros(
        (num_layers, num_heads, grid_size, grid_size), dtype=np.float32
    )

    for l, att in enumerate(attentions):
        a = att[0].cpu().numpy()  # (heads, seq, seq)
        # CLS token index = 0, patches = 1..N
        cls_vec = a[:, 0, 1:]  # (heads, N_patches)
        # if shapes mismatch, pad/truncate
        if cls_vec.shape[1] != grid_size * grid_size:
            tmp = np.zeros((num_heads, grid_size * grid_size), dtype=np.float32)
            n_min = min(tmp.shape[1], cls_vec.shape[1])
            tmp[:, :n_min] = cls_vec[:, :n_min]
            cls_vec = tmp
        cls_grid = cls_vec.reshape(num_heads, grid_size, grid_size)
        cls_to_patch[l] = cls_grid

    # default attention overlay: last layer, head 0
    default_layer = num_layers - 1
    default_head = 0
    att_grid_default = cls_to_patch[default_layer, default_head]
    att_overlay = make_attention_overlay(img_resized, att_grid_default)

    # 5) Explanation
    explanation = build_explanation(simple, num_layers, num_heads, grid_size)

    # 6) State for slider updates
    state = {
        "cls_to_patch": cls_to_patch,
        "grid_size": grid_size,
        "num_layers": num_layers,
        "num_heads": num_heads,
        # we also keep a copy of the 224x224 base image in memory
        "base_image": img_resized,
    }

    return patch_grid_img, att_overlay, pca_fig, preds_table, explanation, state

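# Note: the overlays above use raw CLS->patch attention from one layer/head.
# A common alternative is "attention rollout" (Abnar & Zuidema, 2020), which
# accumulates attention across layers. Hedged sketch, not used by this app;
# `attentions`, `seq_len` and `grid_size` refer to the locals in analyze_vit():
#
#   rollout = np.eye(seq_len)
#   for att in attentions:
#       a = att[0].mean(0).cpu().numpy()        # average heads -> (seq, seq)
#       a = a + np.eye(seq_len)                 # account for residual connections
#       a = a / a.sum(axis=-1, keepdims=True)   # renormalise rows
#       rollout = a @ rollout
#   cls_rollout = rollout[0, 1:].reshape(grid_size, grid_size)
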

def build_explanation(
    simple: bool, num_layers: int, num_heads: int, grid_size: int
) -> str:
    if simple:
        return f"""
### 🧒 How a Vision Transformer (ViT) “sees” this image

1. **Cut into patches** – The image is sliced into **{grid_size}×{grid_size} = {grid_size * grid_size}** small squares.
2. **Turn patches into tokens** – Each patch becomes a little vector (like a word in a sentence).
3. **Add position info** – The model remembers where each patch came from (top-left, bottom-right, etc.).
4. **Look around with attention** – In each of the **{num_layers} layers**, the model lets every patch
   look at other patches using **self-attention** (with {num_heads} attention heads).
5. **Understand the whole image** – After many layers, ViT builds a global understanding of the scene
   and predicts what’s in the picture (top-5 shown on the right).

The heatmap shows **where the special [CLS] token is looking** in the last layer.
"""
    else:
        return f"""
### 🔬 Vision Transformer internals (technical view)

- The image is resized to 224×224 and split into **{grid_size}×{grid_size} = {grid_size * grid_size}** patches.
- Each patch is linearly projected into an embedding and combined with a positional embedding,
  forming a sequence of tokens: `[CLS] + P₁ + P₂ + … + Pₙ`.

- The ViT encoder has **{num_layers} transformer layers** with **{num_heads} attention heads** each.
  In every layer, **self-attention** mixes information across all patches, enabling long-range dependencies
  and global context.

- The [CLS] token aggregates information across patches and is passed through a classification head to produce
  logits over ImageNet-1k classes (we show the top-5).

- The attention heatmap we display is:
  - from **[CLS] → patch tokens**,
  - for a selected `(layer, head)`,
  - reshaped into a `{grid_size}×{grid_size}` grid and upsampled to image resolution for the overlay.

- The PCA plot shows the **final-layer patch embeddings** projected to 2D, giving an intuition of how
  ViT places patches in a semantic space.

Use the sliders to explore different layers and heads and see how the attention focus changes.
"""


# ---------------------- ATTENTION SLIDER UPDATE ----------------------


def update_attention_view(
    state: Dict[str, Any], layer_idx: int, head_idx: int
):
    """
    Called when the user moves the layer/head sliders.
    Returns a new attention overlay image.
    """
    if not state or "cls_to_patch" not in state:
        return None

    cls_to_patch = state["cls_to_patch"]
    base_img = state["base_image"]
    num_layers = state["num_layers"]
    num_heads = state["num_heads"]

    # clamp indices safely
    l = max(0, min(int(layer_idx), num_layers - 1))
    h = max(0, min(int(head_idx), num_heads - 1))

    grid = cls_to_patch[l, h]
    overlay = make_attention_overlay(base_img, grid)
    return overlay

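# A head-averaged view is sometimes easier to read than a single head.
# Hedged sketch of a variant (not wired into the UI):
#   grid = state["cls_to_patch"][l].mean(axis=0)   # average over heads -> (G, G)
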

# ---------------------- BUILD UI ----------------------


with gr.Blocks(title="Vision Transformer (ViT) Visualizer") as demo:
    gr.Markdown(
        """
# 🧠 Vision Transformer (ViT) — How It Sees the World

Upload an image and explore how a Vision Transformer:

- Cuts it into patches (tokens)
- Attends to different regions via self-attention
- Embeds patches into a high-dimensional space
- Predicts what’s in the image

Toggle the **simple / technical** explanation and move the sliders to change
which layer/head's attention you’re seeing.
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            img_in = gr.Image(
                label="Upload image",
                type="pil",
            )
            simple_ck = gr.Checkbox(
                label="Simple explanation (for everyone)",
                value=True,
            )
            run_btn = gr.Button("Run ViT Analysis", variant="primary")

            gr.Markdown(
                "Try images of animals, objects, or scenes. This uses `google/vit-base-patch16-224` (ImageNet-1k)."
            )

        with gr.Column(scale=1):
            preds_df = gr.Dataframe(
                headers=["Label", "Probability"],
                datatype=["str", "number"],
                interactive=False,
                label="Top-5 predictions",
            )
            explanation_md = gr.Markdown(label="Explanation")

    gr.Markdown("## 🧩 Tokens & Attention")

    with gr.Row():
        patch_img = gr.Image(
            label="Patches (16×16) — how ViT tokenizes the image",
            interactive=False,
        )
        attn_img = gr.Image(
            label="Attention heatmap (CLS → patches)",
            interactive=False,
        )

    with gr.Row():
        layer_slider = gr.Slider(
            minimum=0,
            maximum=11,  # ViT-base has 12 layers (0-11)
            step=1,
            value=11,
            label="Layer (0 = shallow, 11 = deepest)",
        )
        head_slider = gr.Slider(
            minimum=0,
            maximum=11,  # 12 attention heads (0-11)
            step=1,
            value=0,
            label="Head index",
        )

    gr.Markdown("## 🌌 Patch embeddings in 2D (PCA)")

    pca_plot = gr.Plot(label="Patches in embedding space (last layer)")

    state = gr.State()

    # main button: run the full analysis
    run_btn.click(
        fn=analyze_vit,
        inputs=[img_in, simple_ck],
        outputs=[patch_img, attn_img, pca_plot, preds_df, explanation_md, state],
    )

    # when the sliders change: update only the attention overlay
    layer_slider.change(
        fn=update_attention_view,
        inputs=[state, layer_slider, head_slider],
        outputs=[attn_img],
    )
    head_slider.change(
        fn=update_attention_view,
        inputs=[state, layer_slider, head_slider],
        outputs=[attn_img],
    )

demo.launch()
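# On a shared CPU Space, enabling the request queue is a common tweak
# (a suggestion, not part of this file as committed): demo.queue().launch()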