PraneshJs committed on
Commit 994e1b3 · verified · 1 Parent(s): 4d08ba2

Update app.py

Files changed (1)
  1. app.py +333 -311
app.py CHANGED
@@ -1,407 +1,429 @@
1
- # ==========================================================
2
- # Vision Transformer (ViT) Visualizer — HF Space, CPU, Gradio 5
3
- # - Model: google/vit-base-patch16-224
4
- # - Shows:
5
- # * Original + patch grid (tokens)
6
- # * Attention heatmap overlay (CLS -> patches)
7
- # * PCA of patch embeddings
8
- # * Top-5 predictions
9
- # * Simple vs technical explanation
10
- # - CPU friendly, uses only Gradio v5-safe features
11
  # ==========================================================
12
 
13
  import math
14
  import warnings
15
- from typing import Dict, Any, Optional, List, Tuple
16
 
17
  import gradio as gr
18
- import torch
19
  import numpy as np
20
- from PIL import Image, ImageDraw
21
- from transformers import AutoImageProcessor, ViTForImageClassification
 
22
  from sklearn.decomposition import PCA
23
- import matplotlib.pyplot as plt
 
24
 
25
  warnings.filterwarnings("ignore")
26
 
27
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
  MODEL_NAME = "google/vit-base-patch16-224"
 
29
 
30
- VIT_MODEL = None
31
- VIT_PROCESSOR = None
32
-
33
-
34
- # ---------------------- MODEL LOADING ----------------------
35
-
36
-
37
- def load_vit():
38
- """Load ViT + image processor once into global cache."""
39
- global VIT_MODEL, VIT_PROCESSOR
40
- if VIT_MODEL is not None and VIT_PROCESSOR is not None:
41
- return VIT_MODEL, VIT_PROCESSOR
42
 
43
- processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
44
- model = ViTForImageClassification.from_pretrained(MODEL_NAME)
45
 
46
- # ensure we get attentions + hidden states
47
- model.config.output_attentions = True
48
- model.config.output_hidden_states = True
 
 
49
 
50
- model.to(DEVICE)
51
- model.eval()
52
 
53
- VIT_MODEL = model
54
- VIT_PROCESSOR = processor
55
- return model, processor
 
57
 
58
- # ---------------------- VISUAL HELPERS ----------------------
 
 
59
 
60
 
61
- def make_patch_grid_image(pil_img: Image.Image, patch_size: int = 16) -> Image.Image:
62
- """
63
- Resize to 224x224 and draw a patch grid (ViT splits into 16x16 patches).
64
- """
65
- img = pil_img.convert("RGB").resize((224, 224))
66
  draw = ImageDraw.Draw(img)
67
  w, h = img.size
68
  for x in range(0, w, patch_size):
69
- draw.line((x, 0, x, h), fill=(0, 255, 0), width=1)
70
  for y in range(0, h, patch_size):
71
- draw.line((0, y, w, y), fill=(0, 255, 0), width=1)
72
  return img
73
 
74
 
75
- def make_attention_overlay(
76
- base_img: Image.Image, heatmap_grid: np.ndarray
77
- ) -> Image.Image:
78
  """
79
- Overlay a CLS->patch attention heatmap on top of the 224x224 image.
80
- heatmap_grid: (G, G) attention values.
81
  """
82
- base = base_img.convert("RGB").resize((224, 224))
83
- g = heatmap_grid.astype(np.float32)
84
-
85
- if not np.any(g):
86
- g = np.zeros_like(g, dtype=np.float32)
 
 
87
  else:
88
- g -= g.min()
89
- maxv = g.max()
90
- if maxv > 0:
91
- g /= maxv
92
-
93
- # upscale to image size
94
- H, W = g.shape
95
- heat_img = Image.fromarray((g * 255).astype("uint8"), mode="L")
96
- heat_img = heat_img.resize((224, 224), Image.BILINEAR)
97
- heat = np.array(heat_img).astype(np.float32) / 255.0 # 0..1
98
-
99
- # simple blue->red colormap overlay
100
  r = heat
101
- g_c = np.zeros_like(heat)
102
  b = 1.0 - heat
103
- cam = np.stack([r, g_c, b], axis=-1) # H,W,3
104
 
105
- base_np = np.array(base).astype(np.float32) / 255.0
106
- alpha = 0.45
107
- blended = (1 - alpha) * base_np + alpha * cam
108
  blended = np.clip(blended * 255.0, 0, 255).astype("uint8")
109
  return Image.fromarray(blended)
110
 
111
 
112
- def make_pca_plot(patch_embeddings: np.ndarray):
 
113
  """
114
- patch_embeddings: (N_patches, hidden_dim)
115
- Returns a Matplotlib figure showing patches in 2D PCA space.
 
 
116
  """
117
- if patch_embeddings.shape[0] < 2:
118
- return None
119
-
120
- pca = PCA(n_components=2)
121
- comps = pca.fit_transform(patch_embeddings) # (N,2)
122
-
123
- fig, ax = plt.subplots(figsize=(4, 4))
124
- ax.scatter(comps[:, 0], comps[:, 1], s=20, alpha=0.8)
125
- ax.set_title("Patches in 2D (PCA of embeddings)")
126
- ax.set_xlabel("PC1")
127
- ax.set_ylabel("PC2")
128
- ax.grid(True, alpha=0.3)
129
- fig.tight_layout()
130
  return fig
131
 
132
 
133
- # ---------------------- CORE ANALYSIS ----------------------
134
-
135
-
136
- def analyze_vit(img: Optional[Image.Image], simple: bool):
137
- """
138
- Main function called by gradio button.
139
- Returns:
140
- - patch_grid_image
141
- - attention_overlay (default: last layer, head 0)
142
- - PCA figure
143
- - predictions table
144
- - explanation markdown
145
- - state dict (for attention slider updates)
146
- """
147
  if img is None:
148
  return (
149
- None,
150
- None,
151
- None,
152
- [],
153
- "⬆️ Please upload an image (e.g., a dog, a car, a object).",
154
- {},
155
  )
156
 
157
- model, processor = load_vit()
158
 
159
- # 1) Preprocess
160
  img_resized = img.convert("RGB").resize((224, 224))
161
- patch_grid_img = make_patch_grid_image(img_resized)
162
-
163
- inputs = processor(images=img_resized, return_tensors="pt")
164
- inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
165
 
 
166
  with torch.no_grad():
167
- outputs = model(**inputs)
168
 
169
- # 2) Predictions (top-5)
170
- logits = outputs.logits[0].cpu().numpy()
171
- probs = np.exp(logits - logits.max())
172
- probs = probs / probs.sum()
173
- topk_idx = probs.argsort()[-5:][::-1]
174
- id2label = model.config.id2label
175
- preds_table = [
176
- [id2label[int(i)], float(probs[int(i)])] for i in topk_idx
177
- ]
178
-
179
- # 3) Patch embeddings from last hidden state
180
- # hidden_states[-1]: (batch, seq_len, hidden)
181
- hidden_last = outputs.hidden_states[-1][0].cpu().numpy() # (seq, hidden)
182
- # seq layout: [CLS] + patches
183
- patch_emb = hidden_last[1:, :] # (N_patches, hidden)
184
- pca_fig = make_pca_plot(patch_emb)
185
-
186
- # 4) Attention -> CLS to patches grid per layer/head
187
- attentions = outputs.attentions # list of (batch, heads, seq, seq)
188
- num_layers = len(attentions)
189
- num_heads = attentions[0].shape[1] if num_layers > 0 else 0
190
-
191
- # ViT-base: 14x14 = 196 patches
192
- seq_len = attentions[0].shape[-1] # 1 + N_patches
193
  n_patches = seq_len - 1
194
  grid_size = int(math.sqrt(n_patches))
195
  if grid_size * grid_size != n_patches:
196
- # fallback: approximate
197
  grid_size = int(round(math.sqrt(n_patches)))
198
 
199
- cls_to_patch = np.zeros(
200
- (num_layers, num_heads, grid_size, grid_size), dtype=np.float32
201
- )
202
-
203
- for l, att in enumerate(attentions):
204
- a = att[0].cpu().numpy() # (heads, seq, seq)
205
- # CLS token index = 0, patches = 1..N
206
- cls_vec = a[:, 0, 1:] # (heads, N_patches)
207
- # if shapes mismatch, pad/truncate
208
- if cls_vec.shape[1] != grid_size * grid_size:
209
- tmp = np.zeros((num_heads, grid_size * grid_size), dtype=np.float32)
210
- n_min = min(tmp.shape[1], cls_vec.shape[1])
211
- tmp[:, :n_min] = cls_vec[:, :n_min]
212
- cls_vec = tmp
213
- cls_grid = cls_vec.reshape(num_heads, grid_size, grid_size)
214
- cls_to_patch[l] = cls_grid
215
-
216
- # default attention overlay: last layer, head 0
217
- default_layer = num_layers - 1
218
  default_head = 0
219
- att_grid_default = cls_to_patch[default_layer, default_head]
220
- att_overlay = make_attention_overlay(img_resized, att_grid_default)
221
-
222
- # 5) Explanation
223
- explanation = build_explanation(simple, num_layers, num_heads, grid_size)
224
-
225
- # 6) State for slider updates
226
- state = {
227
- "cls_to_patch": cls_to_patch,
228
- "grid_size": grid_size,
229
- "num_layers": num_layers,
230
- "num_heads": num_heads,
231
- # we also keep a copy of the 224x224 base image in memory
232
- "base_image": img_resized,
233
- }
234
-
235
- return patch_grid_img, att_overlay, pca_fig, preds_table, explanation, state
 
237
 
238
- def build_explanation(
239
- simple: bool, num_layers: int, num_heads: int, grid_size: int
240
- ) -> str:
241
  if simple:
242
- return f"""
243
- ### 🧒 How a Vision Transformer (ViT) “sees” this image
244
-
245
- 1. **Cut into patches** The image is sliced into **{grid_size}×{grid_size} = {grid_size*grid_size}** small squares.
246
- 2. **Turn patches into tokens** Each patch becomes a little vector (like a word in a sentence).
247
- 3. **Add position info** The model remembers where each patch came from (top-left, bottom-right, etc.).
248
- 4. **Look around with attention** In each of the **{num_layers} layers**, the model lets every patch
249
- look at other patches using **self-attention** (with {num_heads} attention heads).
250
- 5. **Understand the whole image** – After many layers, ViT builds a global understanding of the scene
251
- and predicts what’s in the picture (top-5 shown on the right).
252
-
253
- The heatmap shows **where the special [CLS] token is looking** in the last layer.
254
  """
255
  else:
256
- return f"""
257
- ### 🔬 Vision Transformer internals (technical view)
258
-
259
- - The image is resized to 224×224 and split into **{grid_size}×{grid_size} = {grid_size*grid_size}** patches.
260
- - Each patch is linearly projected into an embedding and combined with a positional embedding,
261
- forming a sequence of tokens: `[CLS] + P₁ + P₂ + … + Pₙ`.
262
-
263
- - The ViT encoder has **{num_layers} transformer layers** with **{num_heads} attention heads** each.
264
- In every layer, **self-attention** mixes information across all patches, enabling long-range dependencies
265
- and global context.
266
-
267
- - The [CLS] token aggregates information across patches and is passed through a classification head to produce
268
- logits over ImageNet-1k classes (we show the top-5).
269
-
270
- - The attention heatmap we display is:
271
- - From **[CLS] → patch tokens**
272
- - For a selected `(layer, head)`
273
- - Reshaped into a `{grid_size}×{grid_size}` grid and upsampled to image resolution for overlay.
274
-
275
- - The PCA plot shows the **final-layer patch embeddings** projected to 2D, giving an intuition of how
276
- ViT places patches in a semantic space.
277
-
278
- Use the sliders to explore different layers and heads and see how the attention focus changes.
279
  """
280
 
281
 
282
- # ---------------------- ATTENTION SLIDER UPDATE ----------------------
 
283
 
284
 
285
- def update_attention_view(
286
- state: Dict[str, Any], layer_idx: int, head_idx: int
287
- ):
288
  """
289
- Called when user moves the layer/head sliders.
290
- Returns a new attention overlay image.
 
291
  """
292
- if not state or "cls_to_patch" not in state:
293
  return None
294
 
295
- cls_to_patch = state["cls_to_patch"]
296
  base_img = state["base_image"]
297
- num_layers = state["num_layers"]
298
- num_heads = state["num_heads"]
299
-
300
- # clamp indices safely
301
- l = max(0, min(int(layer_idx), num_layers - 1))
302
- h = max(0, min(int(head_idx), num_heads - 1))
303
-
304
- grid = cls_to_patch[l, h]
305
- overlay = make_attention_overlay(base_img, grid)
 
306
  return overlay
307
 
308
 
309
- # ---------------------- BUILD UI ----------------------
310
-
311
-
312
- with gr.Blocks(title="Vision Transformer (ViT) Visualizer") as demo:
313
- gr.Markdown(
314
- """
315
- # 🧠 Vision Transformer (ViT) How It Sees the World
316
-
317
- Upload an image and explore how a Vision Transformer:
 
318
 
319
- - Cuts it into patches (tokens)
320
- - Attends to different regions via self-attention
321
- - Embeds patches into a high-dimensional space
322
- - Predicts what’s in the image
323
 
324
- Toggle **simple / technical** explanation and move the sliders to change
325
- which layer/head's attention you’re seeing.
326
- """
327
- )
328
 
329
  with gr.Row():
330
  with gr.Column(scale=1):
331
- img_in = gr.Image(
332
- label="Upload image",
333
- type="pil",
334
- )
335
- simple_ck = gr.Checkbox(
336
- label="Simple explanation (for everyone)",
337
- value=True,
338
- )
339
- run_btn = gr.Button("Run ViT Analysis", variant="primary")
340
-
341
- gr.Markdown(
342
- "Try images like: animals, objects, scenes. This uses `google/vit-base-patch16-224` (ImageNet-1k)."
343
- )
344
-
345
- with gr.Column(scale=1):
346
- preds_df = gr.Dataframe(
347
- headers=["Label", "Probability"],
348
- datatype=["str", "number"],
349
- interactive=False,
350
- label="Top-5 predictions",
351
- )
352
- explanation_md = gr.Markdown(label="Explanation")
353
 
354
- gr.Markdown("## 🧩 Tokens & Attention")
 
355
 
356
- with gr.Row():
357
- patch_img = gr.Image(
358
- label="Patches (16×16) how ViT tokenizes the image",
359
- interactive=False,
360
- )
361
- attn_img = gr.Image(
362
- label="Attention heatmap (CLS → patches)",
363
- interactive=False,
364
- )
365
 
366
- with gr.Row():
367
- layer_slider = gr.Slider(
368
- minimum=0,
369
- maximum=11, # ViT-base has 12 layers (0-11)
370
- step=1,
371
- value=11,
372
- label="Layer (0 = shallow, 11 = deepest)",
373
- )
374
- head_slider = gr.Slider(
375
- minimum=0,
376
- maximum=11, # 12 attention heads
377
- step=1,
378
- value=0,
379
- label="Head index",
380
- )
381
-
382
- gr.Markdown("## 🌌 Patch embeddings in 2D (PCA)")
383
-
384
- pca_plot = gr.Plot(label="Patches in embedding space (last layer)")
385
 
386
  state = gr.State()
387
 
388
- # main button: run full analysis
389
  run_btn.click(
390
- fn=analyze_vit,
391
- inputs=[img_in, simple_ck],
392
- outputs=[patch_img, attn_img, pca_plot, preds_df, explanation_md, state],
393
  )
394
 
395
- # when sliders change: update attention overlay only
396
  layer_slider.change(
397
- fn=update_attention_view,
398
- inputs=[state, layer_slider, head_slider],
399
- outputs=[attn_img],
400
  )
401
  head_slider.change(
402
- fn=update_attention_view,
403
- inputs=[state, layer_slider, head_slider],
404
- outputs=[attn_img],
 
405
  )
406
 
407
  demo.launch()
 
1
+ # ViT Visualizer — Full Interpretability Suite (A + B + C)
2
+ # Model: google/vit-base-patch16-224
3
+ # Gradio 5 compatible, CPU-friendly
4
+ # Features:
5
+ # - Patch grid (16x16)
6
+ # - Patch attention (per layer / per head / query token)
7
+ # - Attention rollout (layer aggregated)
8
+ # - PCA of patch embeddings across selected layers
9
+ # - Top-5 predictions & simple/technical explanations
 
10
  # ==========================================================
11
 
12
  import math
13
  import warnings
14
+ from typing import Any, Dict, List, Optional, Tuple
15
 
16
  import gradio as gr
 
17
  import numpy as np
18
+ import torch
19
+ from PIL import Image, ImageDraw, ImageFont
20
+ from transformers import AutoImageProcessor, ViTModel, ViTForImageClassification
21
  from sklearn.decomposition import PCA
22
+ import plotly.express as px
23
+ import plotly.graph_objects as go
24
 
25
  warnings.filterwarnings("ignore")
26
 
 
27
  MODEL_NAME = "google/vit-base-patch16-224"
28
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
 
30
+ # global caches
31
+ VIT_BASE = None # ViTModel (encoder with hidden states & attentions)
32
+ VIT_CLF = None # ViTForImageClassification (classification head)
33
+ PROCESSOR = None
 
35
 
36
+ # ------------------ model loader with SDPA fix ------------------
37
+ def load_models():
38
+ global VIT_BASE, VIT_CLF, PROCESSOR
39
+ if VIT_BASE is not None and VIT_CLF is not None and PROCESSOR is not None:
40
+ return VIT_BASE, VIT_CLF, PROCESSOR
41
 
42
+ PROCESSOR = AutoImageProcessor.from_pretrained(MODEL_NAME)
 
43
 
44
+ # base ViT (encoder) - we need hidden_states & attentions
45
+ base = ViTModel.from_pretrained(MODEL_NAME, output_hidden_states=True, attn_implementation="eager")
46
+ # fix attn backend so we can access attentions
47
+ base.config.attn_implementation = "eager"
48
+ base.config.output_attentions = True
49
+ base.config.output_hidden_states = True
50
+ base.to(DEVICE)
51
+ base.eval()
52
 
53
+ # classifier head for top-k labels
54
+ clf = ViTForImageClassification.from_pretrained(MODEL_NAME)
55
+ clf.to(DEVICE)
56
+ clf.eval()
57
 
58
+ VIT_BASE = base
59
+ VIT_CLF = clf
60
+ return base, clf, PROCESSOR
61
 
62
 
63
+ # ------------------ helpers: patch grid & overlay ------------------
64
+ def make_patch_grid_image(pil: Image.Image, patch_size: int = 16, target_size: int = 224) -> Image.Image:
65
+ img = pil.convert("RGB").resize((target_size, target_size))
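+ # with the default 224×224 resize and 16-pixel patches this draws the 14×14 grid ViT-base tokenizes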
 
 
66
  draw = ImageDraw.Draw(img)
67
  w, h = img.size
68
  for x in range(0, w, patch_size):
69
+ draw.line((x, 0, x, h), fill=(0, 200, 0), width=1)
70
  for y in range(0, h, patch_size):
71
+ draw.line((0, y, w, y), fill=(0, 200, 0), width=1)
72
  return img
73
 
74
 
75
+ def make_attention_overlay(base_img: Image.Image, heat_grid: np.ndarray, cmap_alpha: float = 0.45) -> Image.Image:
 
 
76
  """
77
+ heat_grid: (G, G) values in any scale (we will normalize)
78
+ overlay on base_img (resized to 224x224)
79
  """
80
+ img = base_img.convert("RGB").resize((224, 224))
81
+ g = np.array(heat_grid, dtype=np.float32)
82
+ # normalize 0..1
83
+ if np.any(g):
84
+ g = g - g.min()
85
+ if g.max() > 0:
86
+ g = g / g.max()
87
  else:
88
+ g = np.zeros_like(g, dtype=np.float32)
89
+
90
+ # upsample
91
+ heat_img = Image.fromarray((g * 255).astype("uint8"), mode="L").resize((224, 224), Image.BILINEAR)
92
+ heat = np.array(heat_img).astype(np.float32) / 255.0
93
+
94
+ # simple colormap blue->red
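+ # red channel = strong attention, blue = weak (heat was normalized to 0..1 above)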
 
 
 
 
 
95
  r = heat
96
+ gch = np.zeros_like(heat)
97
  b = 1.0 - heat
98
+ cam = np.stack([r, gch, b], axis=-1)
99
 
100
+ base_np = np.array(img).astype(np.float32) / 255.0
101
+ blended = (1 - cmap_alpha) * base_np + cmap_alpha * cam
 
102
  blended = np.clip(blended * 255.0, 0, 255).astype("uint8")
103
  return Image.fromarray(blended)
104
 
105
 
106
+ # ------------------ attention rollout (Abnar & Zuidema) ------------------
107
+ def compute_attention_rollout(all_attentions: List[torch.Tensor]) -> np.ndarray:
108
  """
109
+ all_attentions: list length L of tensors (batch, heads, seq, seq)
110
+ We'll average heads per layer -> (seq, seq) and compute rollout:
111
+ R = prod_l (A_l_hat) where A_l_hat = A_l + I; rows normalized
112
+ Returns rollout matrix (seq, seq)
113
  """
114
+ # convert to np arrays averaged over heads
115
+ avg_mats = []
116
+ for a in all_attentions:
117
+ # a: (batch=1, heads, seq, seq)
118
+ mat = a[0].mean(dim=0).detach().cpu().numpy() # (seq, seq)
119
+ avg_mats.append(mat)
120
+
121
+ seq = avg_mats[0].shape[0]
122
+ # add identity & normalize rows
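+ # the added identity models the residual (skip) connection around each attention block;
+ # re-normalizing keeps every row a probability distribution, per Abnar & Zuidema (2020)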
123
+ aug = []
124
+ for A in avg_mats:
125
+ A_hat = A + np.eye(seq)
126
+ A_hat = A_hat / A_hat.sum(axis=-1, keepdims=True)
127
+ aug.append(A_hat)
128
+
129
+ # multiply (matrix product) in order
130
+ R = aug[0]
131
+ for A in aug[1:]:
132
+ R = A @ R
133
+ return R # (seq, seq)
134
+
135
+
136
+ # ------------------ PCA projection for multiple layers ------------------
137
+ def layers_pca_plot(hidden_states: List[torch.Tensor], layers: List[int]) -> Any:
138
+ """
139
+ hidden_states: list of tensors (batch, seq, hidden)
140
+ layers: list of indices within hidden_states to project
141
+ We'll remove CLS token and do PCA for each chosen layer;
142
+ plot patches from each layer with different colors on single plot.
143
+ """
144
+ pts_all = []
145
+ layer_labels = []
146
+ for li in layers:
147
+ hs = hidden_states[li][0].detach().cpu().numpy() # (seq, hidden)
148
+ patches = hs[1:, :] # remove CLS -> (N_patches, hidden)
149
+ # PCA to 2D
150
+ pca = PCA(n_components=2)
151
+ pts = pca.fit_transform(patches)
152
+ pts_all.append(pts)
153
+ layer_labels.append(np.array([li] * pts.shape[0]))
154
+
155
+ # combine
156
+ coords = np.vstack(pts_all)
157
+ labels = np.concatenate(layer_labels)
158
+ df = {"x": coords[:, 0], "y": coords[:, 1], "layer": labels.astype(str)}
159
+ fig = px.scatter(df, x="x", y="y", color="layer", title="Patch embeddings across layers (PCA)")
160
+ fig.update_traces(marker=dict(size=6))
161
+ fig.update_layout(height=480)
162
  return fig
163
 
164
 
165
+ # ------------------ core analyzer ------------------
166
+ def analyze_vit_full(img: Optional[Image.Image], simple: bool):
 
167
  if img is None:
168
  return (
169
+ None, None, None, None, "", "⬆️ Please upload an image.", {}
 
170
  )
171
 
172
+ base, clf, processor = load_models()
173
 
174
+ # preprocess to device
175
  img_resized = img.convert("RGB").resize((224, 224))
176
+ inputs = processor(images=img_resized, return_tensors="pt").to(DEVICE)
 
 
 
177
 
178
+ # forward pass through base model
179
  with torch.no_grad():
180
+ outputs = base(**inputs)
181
 
182
+ # outputs.attentions: list L tensors (batch=1, heads, seq, seq)
183
+ attentions = outputs.attentions # list length L
184
+ hidden_states = outputs.hidden_states # list of length L+1: embedding output followed by each encoder layer
185
+
186
+ L = len(attentions)
187
+ seq_len = attentions[0].shape[-1]
 
188
  n_patches = seq_len - 1
189
  grid_size = int(math.sqrt(n_patches))
190
  if grid_size * grid_size != n_patches:
191
+ # fallback: compute closest integer grid
192
  grid_size = int(round(math.sqrt(n_patches)))
193
 
194
+ # default selections
195
+ default_layer = L - 1
 
196
  default_head = 0
197
+ # default query token = 0 (CLS)
198
+ default_query = 0
199
+
200
+ # Build patch grid image
201
+ patch_grid = make_patch_grid_image(img.copy(), patch_size=16, target_size=224)
202
+
203
+ # Build per-layer per-head CLS->patch default overlay
204
+ # pick last layer, head 0, CLS query
205
+ att_np = attentions[default_layer][0].cpu().numpy() # (heads, seq, seq)
206
+ cls_to_patches = att_np[default_head, 0, 1:] # (n_patches,)
207
+ cls_grid = cls_to_patches.reshape(grid_size, grid_size)
208
+ attn_overlay = make_attention_overlay(img, cls_grid)
209
+
210
+ # Compute rollout
211
+ rollout_mat = compute_attention_rollout(attentions) # (seq, seq)
212
+ rollout_cls = rollout_mat[0, 1:]
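+ # row 0 of the rollout matrix is the CLS query: how much it draws, directly or indirectly, from each patch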
213
+ rollout_grid = rollout_cls.reshape(grid_size, grid_size)
214
+ rollout_overlay = make_attention_overlay(img, rollout_grid, cmap_alpha=0.5)
215
+
216
+ # PCA multi-layer: pick a few representative layers (start, quarter, half, three-quarters, last)
217
+ layers_to_show = sorted(
218
+ list({0, max(0, L // 4), max(0, L // 2), max(0, 3 * L // 4), L - 1})
219
+ )
220
+ pca_fig = layers_pca_plot(hidden_states, layers_to_show)
221
 
222
+ # Classification top-5
223
+ with torch.no_grad():
224
+ logits = clf(**inputs).logits[0].cpu().numpy()
225
+ probs = np.exp(logits - logits.max())
226
+ probs = probs / probs.sum()
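+ # the shifted exp + normalization above is a numerically stable softmax over the 1000 ImageNet-1k logits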
227
+ top5 = probs.argsort()[-5:][::-1]
228
+ labels = clf.config.id2label
229
+ preds_text = "\n".join([f"{labels[i]} — {probs[i]*100:.2f}%" for i in top5])
230
 
231
+ # Explanation
 
 
232
  if simple:
233
+ explain_md = f"""
234
+ ### 🧒 How ViT Sees the Image (Simple)
235
+ 1. Image is cut into {grid_size}×{grid_size} = {grid_size*grid_size} patches (16×16).
236
+ 2. Each patch becomes a token. The model learns what each patch "means".
237
+ 3. Attention tells each token which other patches matter to it.
238
+ 4. Rollout aggregates attention across layers to show the final "focus".
239
+ 5. PCA shows how patch features evolve across layers (from raw to object-aware).
 
240
  """
241
  else:
242
+ explain_md = f"""
243
+ ### 🔬 Technical Explanation
244
+ - Model: {MODEL_NAME}
245
+ - Transformer layers: {L}, patch grid: {grid_size}×{grid_size}
246
+ - We extract token attentions (heads) and hidden states for PCA.
247
+ - Patch attention visualization maps token attention back to the image grid.
248
+ - Attention rollout uses Abnar & Zuidema's method to accumulate attention paths across layers.
 
249
  """
250
 
251
+ # return many things + state necessary for interactive updates (layer/head/query)
252
+ state = {
253
+ "attentions": [a.cpu() for a in attentions], # store on CPU to allow slider updates
254
+ "hidden_states": [h.cpu() for h in hidden_states],
255
+ "grid_size": grid_size,
256
+ "num_layers": L,
257
+ "num_heads": attentions[0].shape[1],
258
+ "base_image": img, # original high-res image (we'll resize to 224 when overlaying)
259
+ }
260
 
261
+ return (
262
+ patch_grid,
263
+ attn_overlay,
264
+ rollout_overlay,
265
+ pca_fig,
266
+ preds_text,
267
+ explain_md,
268
+ state,
269
+ )
270
 
271
 
272
+ # ------------------ update functions for sliders / choices ------------------
273
+ def update_layer_head_query(state: Dict[str, Any], layer_idx: int, head_idx: int, query_token: int, mode: str):
 
274
  """
275
+ mode:
276
+ - "patch_attention": attention of query_token -> patches at (layer, head)
277
+ - "rollout": ignored (we will return rollout overlay)
278
  """
279
+ if not state:
280
  return None
281
 
 
282
  base_img = state["base_image"]
283
+ grid = state["grid_size"]
284
+ L = state["num_layers"]
285
+ H = state["num_heads"]
286
+
287
+ l = max(0, min(int(layer_idx), L - 1))
288
+ h = max(0, min(int(head_idx), H - 1))
289
+ q = max(0, min(int(query_token), grid * grid)) # q in 0..n_patches (0==CLS)
290
+
291
+ # load attention for layer l: stored in state as a CPU tensor, normally (1, heads, seq, seq)
292
+ att_tensor = state["attentions"][l] # batch dim kept when cached, so usually (1, heads, seq, seq)
293
+ # ensure shape (heads, seq, seq)
294
+ if att_tensor.ndim == 4: # sometimes shape might be (1, heads, seq, seq)
295
+ att_tensor = att_tensor[0]
296
+ att_np = att_tensor.numpy() # (heads, seq, seq)
297
+
298
+ # query q -> keys: if q == 0 it's CLS; keys positions 1..seq-1 are patches
299
+ seq = att_np.shape[-1]
300
+ n_patches = seq - 1
301
+ # column indices for keys: 1..seq-1 map to patches 0..n_patches-1
302
+ if q >= seq:
303
+ q = 0
304
+
305
+ # get attention vector for head h: att[h, q, 1:]
306
+ vec = att_np[h, q, 1:]
307
+ # if vec shorter/longer than grid^2, adjust
308
+ if vec.shape[0] != grid * grid:
309
+ # pad or trim
310
+ tmp = np.zeros(grid * grid, dtype=np.float32)
311
+ nmin = min(vec.shape[0], tmp.shape[0])
312
+ tmp[:nmin] = vec[:nmin]
313
+ vec = tmp
314
+
315
+ grid_map = vec.reshape(grid, grid)
316
+ overlay = make_attention_overlay(base_img, grid_map)
317
  return overlay
318
 
319
 
320
+ def get_rollout_overlay(state: Dict[str, Any]):
321
+ if not state:
322
+ return None
323
+ attentions = state["attentions"]
324
+ # attentions: list of CPU tensors, cached with the batch dim as (1, heads, seq, seq)
325
+ # convert to list of (1, heads, seq, seq) for compute_attention_rollout
326
+ mats = [a.unsqueeze(0) if a.ndim == 3 else a for a in attentions]
327
+ R = compute_attention_rollout(mats) # (seq, seq)
328
+ grid = state["grid_size"]
329
+ rollout_cls = R[0, 1:]
330
+ if rollout_cls.shape[0] != grid * grid:
331
+ tmp = np.zeros(grid * grid, dtype=np.float32)
332
+ nmin = min(rollout_cls.shape[0], tmp.shape[0])
333
+ tmp[:nmin] = rollout_cls[:nmin]
334
+ rollout_cls = tmp
335
+ rollout_grid = rollout_cls.reshape(grid, grid)
336
+ return make_attention_overlay(state["base_image"], rollout_grid, cmap_alpha=0.55)
337
+
338
+
339
+ def update_pca_layers(state: Dict[str, Any], selected_layers: List[int]):
340
+ if not state:
341
+ return None
342
+ # hidden_states stored as list of CPU tensors (batch, seq, hidden)
343
+ hs = state["hidden_states"]
344
+ # ensure layers within range
345
+ layers = [max(0, min(int(l), len(hs) - 1)) for l in selected_layers]
346
+ fig = layers_pca_plot(hs, layers)
347
+ return fig
348
 
 
 
 
 
349
 
350
+ # ------------------ GRADIO UI ------------------
351
+ with gr.Blocks(title="ViT Full Interpretability (A+B+C)") as demo:
352
+ gr.Markdown("# 🔍 ViT Visualizer — Patch Attention, Rollout & Layer PCA\n"
353
+ "Model: **google/vit-base-patch16-224** — explore patches, heads, layers, rollout and feature evolution.")
354
 
355
  with gr.Row():
356
  with gr.Column(scale=1):
357
+ img_in = gr.Image(label="Upload image (object/scene)", type="pil")
358
+ simple = gr.Checkbox(label="Simple explanation (kid-friendly)", value=True)
359
+ run_btn = gr.Button("Analyze ViT (full)")
 
360
 
361
+ gr.Markdown("**Patch Attention Controls**\nSelect layer, head and query token (0 = CLS, 1.. = patches left→right top→bottom).")
362
+ layer_slider = gr.Slider(minimum=0, maximum=11, step=1, value=11, label="Layer")
363
+ head_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Head")
364
+ query_slider = gr.Slider(minimum=0, maximum=196, step=1, value=0, label="Query token (0=CLS)")
365
 
366
+ gr.Markdown("**Attention Rollout & PCA**")
367
+ rollout_btn = gr.Button("Refresh Rollout Overlay")
368
+ # PCA layers selection: simple multi-select text entry allowed (comma separated)
369
+ pca_layers_txt = gr.Textbox(label="PCA layers (comma separated indices, e.g. 0,3,6,11)", value="0,3,6,11")
 
370
 
371
+ with gr.Column(scale=1):
372
+ gr.Markdown("### Outputs")
373
+ patch_grid_out = gr.Image(label="Patch grid (16×16)")
374
+ attn_overlay_out = gr.Image(label="Patch Attention Overlay (layer/head/query)")
375
+ rollout_overlay_out = gr.Image(label="Attention Rollout Overlay (aggregated)")
376
+ pca_out = gr.Plot(label="PCA: patch embeddings across selected layers")
377
+ preds_out = gr.Textbox(label="Top-5 predictions", lines=6)
378
+ explanation_out = gr.Markdown(label="Explanation")
 
379
 
380
  state = gr.State()
381
 
382
+ # main analysis
383
  run_btn.click(
384
+ fn=analyze_vit_full,
385
+ inputs=[img_in, simple],
386
+ outputs=[patch_grid_out, attn_overlay_out, rollout_overlay_out, pca_out, preds_out, explanation_out, state],
387
  )
388
 
389
+ # update attention overlay (layer/head/query)
390
  layer_slider.change(
391
+ fn=update_layer_head_query,
392
+ inputs=[state, layer_slider, head_slider, query_slider, gr.State("patch_attention")],
393
+ outputs=[attn_overlay_out],
394
  )
395
  head_slider.change(
396
+ fn=update_layer_head_query,
397
+ inputs=[state, layer_slider, head_slider, query_slider, gr.State("patch_attention")],
398
+ outputs=[attn_overlay_out],
399
+ )
400
+ query_slider.change(
401
+ fn=update_layer_head_query,
402
+ inputs=[state, layer_slider, head_slider, query_slider, gr.State("patch_attention")],
403
+ outputs=[attn_overlay_out],
404
+ )
405
+
406
+ # rollout refresh
407
+ rollout_btn.click(
408
+ fn=get_rollout_overlay,
409
+ inputs=[state],
410
+ outputs=[rollout_overlay_out],
411
+ )
412
+
413
+ # PCA layers (parse input text)
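+ # note: hidden_states[0] is the embedding output, so indices 1..num_layers pick encoder layers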
414
+ def parse_and_update_pca(state_obj, txt):
415
+ if not state_obj:
416
+ return None
417
+ try:
418
+ parts = [int(p.strip()) for p in txt.split(",") if p.strip() != ""]
419
+ except ValueError:
420
+ parts = [0, max(0, state_obj["num_layers"] - 1)]
421
+ return update_pca_layers(state_obj, parts)
422
+
423
+ pca_layers_txt.submit(
424
+ fn=parse_and_update_pca,
425
+ inputs=[state, pca_layers_txt],
426
+ outputs=[pca_out],
427
  )
428
 
429
  demo.launch()