PraneshJs committed on
Commit e4963e3 · verified · 1 Parent(s): 50cbd61

Update app.py

Files changed (1):
  1. app.py +348 -270

app.py CHANGED
@@ -1,12 +1,16 @@
-# ViT Visualizer — Full Interpretability Suite (A + B + C)
 # Model: google/vit-base-patch16-224
-# Gradio 5 compatible, CPU-friendly
 # Features:
-# - Patch grid (16x16)
-# - Patch attention (per layer / per head / query token)
-# - Attention rollout (layer aggregated)
-# - PCA of patch embeddings across layers
-# - Top-5 predictions & simple/technical explanations
 # ==========================================================

 import math
@@ -16,14 +20,15 @@ from typing import Any, Dict, List, Optional, Tuple
 import gradio as gr
 import numpy as np
 import torch
-from PIL import Image, ImageDraw
 from transformers import (
     AutoImageProcessor,
     ViTModel,
     ViTForImageClassification,
     AutoConfig,
 )
-from sklearn.decomposition import PCA
 import plotly.express as px

 warnings.filterwarnings("ignore")
@@ -31,214 +36,261 @@ warnings.filterwarnings("ignore")
 MODEL_NAME = "google/vit-base-patch16-224"
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-# global caches
-VIT_BASE = None  # ViTModel (encoder with hidden states & attentions)
-VIT_CLF = None  # ViTForImageClassification (classification head)
 PROCESSOR = None


-# ------------------ model loader with SDPA -> eager fix ------------------
 def load_models():
-    """
-    Load processor + ViT base + classification head.
-    Important: create config first, set attn_implementation='eager'
-    before enabling output_attentions/output_hidden_states, then load models.
-    """
-    global VIT_BASE, VIT_CLF, PROCESSOR
-    if VIT_BASE is not None and VIT_CLF is not None and PROCESSOR is not None:
-        return VIT_BASE, VIT_CLF, PROCESSOR

     PROCESSOR = AutoImageProcessor.from_pretrained(MODEL_NAME)

-    # load config, modify before creating model
-    cfg = AutoConfig = None
-    try:
-        cfg = AutoConfig.from_pretrained(MODEL_NAME)
-    except Exception:
-        # fallback: load a default config and set minimal fields
-        from transformers import ViTConfig
-        cfg = ViTConfig.from_pretrained(MODEL_NAME)
-
-    # FORCE eager attention backend so we can extract attentions
-    # (must set attn_implementation before enabling output_attentions)
-    cfg.attn_implementation = "eager"
     cfg.output_attentions = True
     cfg.output_hidden_states = True

-    # now load the base encoder with the modified config
-    base = ViTModel.from_pretrained(MODEL_NAME, config=cfg)
-    base.to(DEVICE)
-    base.eval()

-    # load classifier separately (we can use default config for classifier)
-    clf = ViTForImageClassification.from_pretrained(MODEL_NAME)
-    clf.to(DEVICE)
-    clf.eval()

-    VIT_BASE = base
-    VIT_CLF = clf
-    return base, clf, PROCESSOR
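Why the eager fix matters: PyTorch's SDPA attention kernel never materializes the full attention matrix, so `output_attentions=True` yields no usable maps under it (depending on the transformers version it warns, falls back, or errors). A minimal standalone check, as a sketch only, assuming transformers >= 4.36 (which accepts `attn_implementation` directly in `from_pretrained`; older versions need the config route used in this diff):

```python
# Sanity-check that eager attention actually yields attention maps.
import torch
from transformers import ViTModel

model = ViTModel.from_pretrained(
    "google/vit-base-patch16-224",
    attn_implementation="eager",  # SDPA would not return attention weights
).eval()

pixel_values = torch.randn(1, 3, 224, 224)  # stand-in for a processed image
with torch.no_grad():
    out = model(pixel_values, output_attentions=True, output_hidden_states=True)

assert len(out.attentions) == 12                     # one map per layer (ViT-Base)
assert out.attentions[0].shape == (1, 12, 197, 197)  # (batch, heads, CLS+196, CLS+196)
```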
 

-# ------------------ helpers: patch grid & overlay ------------------
-def make_patch_grid_image(pil: Image.Image, patch_size: int = 16, target_size: int = 224) -> Image.Image:
-    img = pil.convert("RGB").resize((target_size, target_size))
     draw = ImageDraw.Draw(img)
     w, h = img.size
     for x in range(0, w, patch_size):
-        draw.line((x, 0, x, h), fill=(0, 200, 0), width=1)
     for y in range(0, h, patch_size):
-        draw.line((0, y, w, y), fill=(0, 200, 0), width=1)
     return img


-def make_attention_overlay(base_img: Image.Image, heat_grid: np.ndarray, cmap_alpha: float = 0.45) -> Image.Image:
     """
-    heat_grid: (G, G) values (any scale) -> normalized then overlaid on base_img (resized to 224x224)
     """
-    img = base_img.convert("RGB").resize((224, 224))
     g = np.array(heat_grid, dtype=np.float32)
-    # normalize 0..1
     if np.any(g):
         g = g - g.min()
         if g.max() > 0:
             g = g / g.max()
         else:
-            g = np.zeros_like(g, dtype=np.float32)
-
-    # upsample to image resolution
     heat_img = Image.fromarray((g * 255).astype("uint8"), mode="L").resize((224, 224), Image.BILINEAR)
     heat = np.array(heat_img).astype(np.float32) / 255.0
-
-    # simple blue->red colormap
-    r = heat
-    gch = np.zeros_like(heat)
-    b = 1.0 - heat
-    cam = np.stack([r, gch, b], axis=-1)
-
-    base_np = np.array(img).astype(np.float32) / 255.0
-    blended = (1 - cmap_alpha) * base_np + cmap_alpha * cam
-    blended = np.clip(blended * 255.0, 0, 255).astype("uint8")
-    return Image.fromarray(blended)


-# ------------------ attention rollout (Abnar & Zuidema) ------------------
 def compute_attention_rollout(all_attentions: List[torch.Tensor]) -> np.ndarray:
-    """
-    all_attentions: list length L of tensors (batch, heads, seq, seq)
-    We'll average heads per layer -> (seq, seq) and compute rollout:
-    R = prod_l (A_l_hat) where A_l_hat = A_l + I; rows normalized
-    Returns rollout matrix (seq, seq)
-    """
     avg_mats = []
     for a in all_attentions:
-        # a: (batch=1, heads, seq, seq)
         mat = a[0].mean(dim=0).detach().cpu().numpy()  # (seq, seq)
         avg_mats.append(mat)
-
     seq = avg_mats[0].shape[0]
     aug = []
     for A in avg_mats:
         A_hat = A + np.eye(seq)
-        # normalize rows (sum over last dim)
         row_sums = A_hat.sum(axis=-1, keepdims=True)
-        # avoid division by zero
         row_sums[row_sums == 0] = 1.0
         A_hat = A_hat / row_sums
         aug.append(A_hat)
-
     R = aug[0]
     for A in aug[1:]:
         R = A @ R
     return R  # (seq, seq)
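For intuition, the rollout recurrence above can be checked on a toy example: with `A_hat = row_normalize(A + I)`, the product `R = A_hat_L @ ... @ A_hat_1` keeps rows stochastic, which is why `R[0, 1:]` reads as a distribution of CLS attention over the patches. A self-contained numpy sketch:

```python
# Toy check of attention rollout (Abnar & Zuidema, 2020) on two fake layers.
import numpy as np

rng = np.random.default_rng(0)
layers = [rng.random((4, 4)) for _ in range(2)]          # head-averaged attention
layers = [a / a.sum(-1, keepdims=True) for a in layers]  # make rows stochastic

R = np.eye(4)
for A in layers:
    A_hat = A + np.eye(4)              # add the residual connection
    A_hat /= A_hat.sum(-1, keepdims=True)
    R = A_hat @ R                      # later layers multiply on the left

assert np.allclose(R.sum(-1), 1.0)     # rows remain probability distributions
```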


-# ------------------ PCA projection for multiple layers ------------------
-def layers_pca_plot(hidden_states: List[torch.Tensor], layers: List[int]) -> Any:
-    """
-    hidden_states: list of tensors (batch, seq, hidden)
-    layers: list of indices within hidden_states to project
-    We'll remove CLS token and do PCA for each chosen layer;
-    plot patches from each layer with different colors on single plot.
-    """
     pts_all = []
-    layer_labels = []
     for li in layers:
-        hs = hidden_states[li][0].detach().cpu().numpy()  # (seq, hidden)
-        patches = hs[1:, :]  # remove CLS -> (N_patches, hidden)
         pca = PCA(n_components=2)
         pts = pca.fit_transform(patches)
         pts_all.append(pts)
-        layer_labels.append(np.array([li] * pts.shape[0]))
-
     coords = np.vstack(pts_all)
-    labels = np.concatenate(layer_labels)
-    df = {"x": coords[:, 0], "y": coords[:, 1], "layer": labels.astype(str)}
     fig = px.scatter(df, x="x", y="y", color="layer", title="Patch embeddings across layers (PCA)")
     fig.update_traces(marker=dict(size=6))
     fig.update_layout(height=480)
     return fig
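The per-layer projection step in isolation, with stand-in data shaped like a ViT-Base layer (197 tokens x 768 dims); note that each layer gets its own PCA fit, so coordinates are comparable within a layer but not across layers:

```python
# One layer's patch tokens projected to 2-D, as layers_pca_plot does per layer.
import numpy as np
from sklearn.decomposition import PCA

hidden = np.random.randn(197, 768).astype(np.float32)  # stand-in: CLS + 196 patches
patches = hidden[1:]                                   # drop CLS -> (196, 768)
pts = PCA(n_components=2).fit_transform(patches)       # (196, 2)
print(pts.shape)  # each patch becomes one point, colored by layer in the app
```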


-# ------------------ core analyzer ------------------
-def analyze_vit_full(img: Optional[Image.Image], simple: bool):
     if img is None:
-        return (None, None, None, None, None, "", {})

     base, clf, processor = load_models()

-    # preprocess to device
-    img_resized = img.convert("RGB").resize((224, 224))
-    inputs = processor(images=img_resized, return_tensors="pt").to(DEVICE)

-    # forward pass through base model
     with torch.no_grad():
         outputs = base(**inputs)

-    # outputs.attentions: list L tensors (batch=1, heads, seq, seq)
-    attentions = outputs.attentions
     hidden_states = outputs.hidden_states

-    L = len(attentions)
     seq_len = attentions[0].shape[-1]
     n_patches = seq_len - 1
-    grid_size = int(math.sqrt(n_patches))
-    if grid_size * grid_size != n_patches:
-        grid_size = int(round(math.sqrt(n_patches)))
-
-    # Build patch grid image
-    patch_grid = make_patch_grid_image(img.copy(), patch_size=16, target_size=224)
-
-    # default overlay: last layer, head 0, CLS query
-    last_layer = L - 1
-    head0 = 0
-    # attentions[last_layer]: shape (batch=1, heads, seq, seq)
-    att_np = attentions[last_layer][0].cpu().numpy()  # (heads, seq, seq)
-    cls_to_patches = att_np[head0, 0, 1:]  # (n_patches,)
-    if cls_to_patches.shape[0] != grid_size * grid_size:
-        tmp = np.zeros(grid_size * grid_size, dtype=np.float32)
-        nmin = min(cls_to_patches.shape[0], tmp.shape[0])
-        tmp[:nmin] = cls_to_patches[:nmin]
-        cls_to_patches = tmp
-    cls_grid = cls_to_patches.reshape(grid_size, grid_size)
-    attn_overlay = make_attention_overlay(img, cls_grid)

-    # Compute rollout overlay (CLS)
-    rollout_mat = compute_attention_rollout(attentions)  # (seq, seq)
-    rollout_cls = rollout_mat[0, 1:]
     if rollout_cls.shape[0] != grid_size * grid_size:
-        tmp = np.zeros(grid_size * grid_size, dtype=np.float32)
-        nmin = min(rollout_cls.shape[0], tmp.shape[0])
         tmp[:nmin] = rollout_cls[:nmin]
         rollout_cls = tmp
     rollout_grid = rollout_cls.reshape(grid_size, grid_size)
-    rollout_overlay = make_attention_overlay(img, rollout_grid, cmap_alpha=0.55)
-
-    # PCA multi-layer: choose representative layers
-    layers_to_show = sorted(list({0, max(0, L // 4), max(0, L // 2), max(0, 3 * L // 4), L - 1}))
-    pca_fig = layers_pca_plot(hidden_states, layers_to_show)

-    # Classification top-5
     with torch.no_grad():
         logits = clf(**inputs).logits[0].cpu().numpy()
         probs = np.exp(logits - logits.max())
@@ -247,175 +299,201 @@ def analyze_vit_full(img: Optional[Image.Image], simple: bool):
     labels = clf.config.id2label
     preds_text = "\n".join([f"{labels[i]} — {probs[i]*100:.2f}%" for i in top5])
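The prediction block relies on the exp-normalize trick: subtracting the max logit before `exp` prevents overflow without changing the softmax result. (The `probs /= probs.sum()` and `top5` lines are unchanged context elided between the two hunks.) A self-contained sketch of the same pattern:

```python
# Numerically stable softmax + top-5, mirroring the pattern in the diff.
import numpy as np

logits = np.array([2.0, 1.0, 0.1, 5.0, 3.0, -1.0])
probs = np.exp(logits - logits.max())  # shift by the max: exp() cannot overflow
probs /= probs.sum()                   # normalize to probabilities

top5 = probs.argsort()[-5:][::-1]      # indices of the 5 largest, descending
print(top5, probs[top5])
```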

-    # Explanation
-    if simple:
-        explain_md = f"""
-        ### 🧒 How ViT Sees the Image (Simple)
-        1. Image is cut into {grid_size}×{grid_size} = {grid_size*grid_size} patches (16×16).
-        2. Each patch becomes a token. The model learns what each patch "means".
-        3. Attention tells each token which other patches matter to it.
-        4. Rollout aggregates attention across layers to show the final "focus".
-        5. PCA shows how patch features evolve across layers (from raw to object-aware).
-        """
-    else:
-        explain_md = f"""
-        ### 🔬 Technical Explanation
-        - Model: {MODEL_NAME}
-        - Transformer layers: {L}, patch grid: {grid_size}×{grid_size}
-        - We extract token attentions (heads) and hidden states for PCA.
-        - Patch attention visualization maps token attention back to the image grid.
-        - Attention rollout uses Abnar & Zuidema's method to accumulate attention paths across layers.
-        """

     state = {
-        "attentions": [a.cpu() for a in attentions],  # move to CPU for interactive updates
         "hidden_states": [h.cpu() for h in hidden_states],
         "grid_size": grid_size,
-        "num_layers": L,
         "num_heads": attentions[0].shape[1],
         "base_image": img,
     }

-    return patch_grid, attn_overlay, rollout_overlay, pca_fig, preds_text, explain_md, state


-# ------------------ update functions for sliders / choices ------------------
-def update_layer_head_query(state: Dict[str, Any], layer_idx: int, head_idx: int, query_token: int, mode: str):
-    if not state:
-        return None

-    base_img = state["base_image"]
-    grid = state["grid_size"]
-    L = state["num_layers"]
-    H = state["num_heads"]

-    l = max(0, min(int(layer_idx), L - 1))
-    h = max(0, min(int(head_idx), H - 1))
-    q = max(0, min(int(query_token), grid * grid))  # q in 0..n_patches (0==CLS)

-    att_tensor = state["attentions"][l]  # shape (heads, seq, seq) or (1,heads,seq,seq)
     if att_tensor.ndim == 4:
         att_tensor = att_tensor[0]
     att_np = att_tensor.numpy()  # (heads, seq, seq)
-
-    seq = att_np.shape[-1]
-    if q >= seq:
-        q = 0
-
-    vec = att_np[h, q, 1:]
     if vec.shape[0] != grid * grid:
-        tmp = np.zeros(grid * grid, dtype=np.float32)
         nmin = min(vec.shape[0], tmp.shape[0])
         tmp[:nmin] = vec[:nmin]
         vec = tmp
-
     grid_map = vec.reshape(grid, grid)
-    overlay = make_attention_overlay(base_img, grid_map)
-    return overlay
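The query-token slider below uses the convention the UI text describes: token 0 is CLS, tokens 1..196 scan the grid row by row. A small sketch of that mapping (the helper names are illustrative, not part of the file):

```python
# Token index <-> patch (row, col) for the 14x14 grid of vit-base-patch16-224.
GRID = 14  # 224 / 16

def token_to_rowcol(q: int) -> tuple:
    assert 1 <= q <= GRID * GRID, "token 0 is CLS; patches start at 1"
    p = q - 1
    return p // GRID, p % GRID

def rowcol_to_token(row: int, col: int) -> int:
    return 1 + row * GRID + col

print(token_to_rowcol(1))       # (0, 0): top-left patch
print(rowcol_to_token(13, 13))  # 196: bottom-right patch
```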


-def get_rollout_overlay(state: Dict[str, Any]):
     if not state:
         return None
-    attentions = state["attentions"]
-    mats = [a.unsqueeze(0) if a.ndim == 3 else a for a in attentions]
-    R = compute_attention_rollout(mats)  # (seq, seq)
     grid = state["grid_size"]
     rollout_cls = R[0, 1:]
     if rollout_cls.shape[0] != grid * grid:
-        tmp = np.zeros(grid * grid, dtype=np.float32)
-        nmin = min(rollout_cls.shape[0], tmp.shape[0])
         tmp[:nmin] = rollout_cls[:nmin]
         rollout_cls = tmp
     rollout_grid = rollout_cls.reshape(grid, grid)
-    return make_attention_overlay(state["base_image"], rollout_grid, cmap_alpha=0.55)


-def update_pca_layers(state: Dict[str, Any], selected_layers: List[int]):
     if not state:
         return None
-    hs = state["hidden_states"]
-    layers = [max(0, min(int(l), len(hs) - 1)) for l in selected_layers]
-    fig = layers_pca_plot(hs, layers)
-    return fig
-
-
-# ------------------ GRADIO UI ------------------
-with gr.Blocks(title="ViT Full Interpretability (A+B+C)") as demo:
-    gr.Markdown("# 🔍 ViT Visualizer — Patch Attention, Rollout & Layer PCA\n"
-                "Model: **google/vit-base-patch16-224** — explore patches, heads, layers, rollout and feature evolution.")
-
-    with gr.Row():
-        with gr.Column(scale=1):
-            img_in = gr.Image(label="Upload image (object/scene)", type="pil")
-            simple = gr.Checkbox(label="Simple explanation (kid-friendly)", value=True)
-            run_btn = gr.Button("Analyze ViT (full)")
-
-            gr.Markdown("**Patch Attention Controls**\nSelect layer, head and query token (0 = CLS, 1.. = patches left→right top→bottom).")
-            layer_slider = gr.Slider(minimum=0, maximum=11, step=1, value=11, label="Layer")
-            head_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Head")
-            query_slider = gr.Slider(minimum=0, maximum=196, step=1, value=0, label="Query token (0=CLS)")
-
-            gr.Markdown("**Attention Rollout & PCA**")
-            rollout_btn = gr.Button("Refresh Rollout Overlay")
-            pca_layers_txt = gr.Textbox(label="PCA layers (comma separated indices, e.g. 0,3,6,11)", value="0,3,6,11")
-
-        with gr.Column(scale=1):
-            gr.Markdown("### Outputs")
-            patch_grid_out = gr.Image(label="Patch grid (16×16)")
-            attn_overlay_out = gr.Image(label="Patch Attention Overlay (layer/head/query)")
-            rollout_overlay_out = gr.Image(label="Attention Rollout Overlay (aggregated)")
-            pca_out = gr.Plot(label="PCA: patch embeddings across selected layers")
-            preds_out = gr.Textbox(label="Top-5 predictions", lines=6)
-            explanation_out = gr.Markdown(label="Explanation")
-
-    state = gr.State()
-
-    # main analysis
-    run_btn.click(
-        fn=analyze_vit_full,
-        inputs=[img_in, simple],
-        outputs=[patch_grid_out, attn_overlay_out, rollout_overlay_out, pca_out, preds_out, explanation_out, state],
-    )
-
-    # update attention overlay (layer/head/query)
-    layer_slider.change(
-        fn=update_layer_head_query,
-        inputs=[state, layer_slider, head_slider, query_slider, gr.State("patch_attention")],
-        outputs=[attn_overlay_out],
-    )
-    head_slider.change(
-        fn=update_layer_head_query,
-        inputs=[state, layer_slider, head_slider, query_slider, gr.State("patch_attention")],
-        outputs=[attn_overlay_out],
-    )
-    query_slider.change(
-        fn=update_layer_head_query,
-        inputs=[state, layer_slider, head_slider, query_slider, gr.State("patch_attention")],
-        outputs=[attn_overlay_out],
-    )
-
-    # rollout refresh
-    rollout_btn.click(
-        fn=get_rollout_overlay,
-        inputs=[state],
-        outputs=[rollout_overlay_out],
-    )
-
-    # PCA layers (parse input text)
-    def parse_and_update_pca(state_obj, txt):
-        if not state_obj:
-            return None
-        try:
-            parts = [int(p.strip()) for p in txt.split(",") if p.strip() != ""]
-        except:
-            parts = [0, max(0, state_obj["num_layers"] - 1)]
-        return update_pca_layers(state_obj, parts)
-
-    pca_layers_txt.submit(
-        fn=parse_and_update_pca,
-        inputs=[state, pca_layers_txt],
-        outputs=[pca_out],
-    )

 demo.launch()
+# ==========================================================
+# ViT Visualizer — Simple (comic-style) + Advanced Mode
 # Model: google/vit-base-patch16-224
+# Gradio 5 compatible; CPU-friendly
+#
 # Features:
+#   - Simple mode (4-step, non-technical, kid-friendly)
+#       Step1: Patch grid
+#       Step2: Patch clustering (colored blocks)
+#       Step3: Patch-to-patch arrows (simplified attention)
+#       Step4: Focus map (rollout) + Top-5 predictions
+#   - Advanced mode (attention maps per layer/head, rollout, PCA)
+#   - SDPA -> eager fix for attention extraction
 # ==========================================================

 import math

 import gradio as gr
 import numpy as np
 import torch
+from PIL import Image, ImageDraw, ImageFont
+from sklearn.cluster import KMeans
+from sklearn.decomposition import PCA
 from transformers import (
     AutoImageProcessor,
     ViTModel,
     ViTForImageClassification,
     AutoConfig,
 )
 import plotly.express as px

 warnings.filterwarnings("ignore")
 
 MODEL_NAME = "google/vit-base-patch16-224"
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+# Globals
+BASE_MODEL = None
+CLF_MODEL = None
 PROCESSOR = None


+# ------------------- model loader with SDPA -> eager fix -------------------
 def load_models():
+    global BASE_MODEL, CLF_MODEL, PROCESSOR
+    if BASE_MODEL is not None and CLF_MODEL is not None and PROCESSOR is not None:
+        return BASE_MODEL, CLF_MODEL, PROCESSOR

     PROCESSOR = AutoImageProcessor.from_pretrained(MODEL_NAME)

+    # load config first, set attn_implementation BEFORE enabling attentions
+    cfg = AutoConfig.from_pretrained(MODEL_NAME)
+    cfg.attn_implementation = "eager"  # << must set this first
     cfg.output_attentions = True
     cfg.output_hidden_states = True

+    # load base encoder with modified config (we'll extract hidden states & attentions)
+    BASE_MODEL = ViTModel.from_pretrained(MODEL_NAME, config=cfg)
+    BASE_MODEL.to(DEVICE).eval()
+
+    # classifier head (for top-5 predictions)
+    CLF_MODEL = ViTForImageClassification.from_pretrained(MODEL_NAME)
+    CLF_MODEL.to(DEVICE).eval()
+
+    return BASE_MODEL, CLF_MODEL, PROCESSOR
+ return BASE_MODEL, CLF_MODEL, PROCESSOR
68
 
 
 
 
 
69
 
70
+ # ------------------- utility: patch grid positions -------------------
71
+ def patch_grid_info(image_size: int = 224, patch_size: int = 16):
72
+ grid_size = image_size // patch_size
73
+ positions = []
74
+ for i in range(grid_size):
75
+ for j in range(grid_size):
76
+ # center coordinates of patch (x,y)
77
+ cx = int((j + 0.5) * patch_size)
78
+ cy = int((i + 0.5) * patch_size)
79
+ positions.append((cx, cy))
80
+ return grid_size, positions
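A quick usage check of this helper; the values follow directly from the 224/16 defaults:

```python
# patch_grid_info() geometry for a 224x224 input with 16x16 patches.
grid_size, positions = patch_grid_info(224, 16)
assert grid_size == 14
assert len(positions) == 196        # 14 x 14 patch centers
assert positions[0] == (8, 8)       # center of the top-left patch
assert positions[-1] == (216, 216)  # center of the bottom-right patch
```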


+# ------------------- visual helpers -------------------
+def draw_patch_grid(img: Image.Image, patch_size: int = 16, outline=(0, 180, 0)) -> Image.Image:
+    img = img.convert("RGB").resize((224, 224))
     draw = ImageDraw.Draw(img)
     w, h = img.size
     for x in range(0, w, patch_size):
+        draw.line([(x, 0), (x, h)], fill=outline, width=1)
     for y in range(0, h, patch_size):
+        draw.line([(0, y), (w, y)], fill=outline, width=1)
     return img


+def draw_cluster_blocks(img: Image.Image, labels: np.ndarray, n_clusters: int = 4, patch_size: int = 16):
     """
+    labels: (n_patches,) cluster labels assigned to each patch index (left→right, top→bottom)
+    """
+    img = img.convert("RGB").resize((224, 224))
+    draw = ImageDraw.Draw(img, "RGBA")
+    grid_size, positions = patch_grid_info()
+    colors = [
+        (255, 99, 71, 140),
+        (60, 179, 113, 140),
+        (65, 105, 225, 140),
+        (255, 215, 0, 140),
+        (199, 21, 133, 140),
+        (0, 206, 209, 140),
+    ]
+    for idx, lab in enumerate(labels):
+        i = idx // grid_size
+        j = idx % grid_size
+        x0 = j * patch_size
+        y0 = i * patch_size
+        x1 = x0 + patch_size
+        y1 = y0 + patch_size
+        col = colors[int(lab) % len(colors)]
+        draw.rectangle([x0, y0, x1, y1], fill=col)
+    return img
+
+
+def draw_attention_arrows(img: Image.Image, att_matrix: np.ndarray, top_k: int = 3, query_idx: Optional[int] = None):
+    """
+    att_matrix: (n_patches, n_patches) attention from query->keys (already preprocessed)
+    If query_idx is None -> use CLS (not plotted as patch), else 0..n_patches-1
+    We'll draw arrows from query patch centers to top-k key patch centers.
+    """
+    img = img.convert("RGB").resize((224, 224))
+    draw = ImageDraw.Draw(img, "RGBA")
+    grid_size, positions = patch_grid_info()
+    # pick a query: if None, choose center patch
+    if query_idx is None:
+        query_idx = (grid_size * grid_size) // 2
+    qpos = positions[query_idx]
+    # find top_k keys attended by this query
+    vec = att_matrix[query_idx]  # length n_patches
+    top_idx = vec.argsort()[-top_k:][::-1]
+    for t in top_idx:
+        kpos = positions[t]
+        # draw line + arrowhead
+        draw.line([qpos, kpos], fill=(255, 0, 0, 200), width=3)
+        # arrowhead: small triangle
+        dx = kpos[0] - qpos[0]
+        dy = kpos[1] - qpos[1]
+        ang = math.atan2(dy, dx)
+        # size proportional
+        ah = 8
+        p1 = (kpos[0] - ah * math.cos(ang - 0.3), kpos[1] - ah * math.sin(ang - 0.3))
+        p2 = (kpos[0] - ah * math.cos(ang + 0.3), kpos[1] - ah * math.sin(ang + 0.3))
+        draw.polygon([kpos, p1, p2], fill=(255, 0, 0, 200))
+    # highlight query patch with blue circle
+    r = 10
+    draw.ellipse([qpos[0] - r, qpos[1] - r, qpos[0] + r, qpos[1] + r], outline=(0, 0, 255, 220), width=2)
+    return img
+ return img
155
+
156
+
157
+ def make_focus_overlay(img: Image.Image, heat_grid: np.ndarray, alpha: float = 0.6):
158
  """
159
+ heat_grid: (G,G) float map
160
+ overlay colored transparency on image where heat is high
161
+ """
162
+ img = img.convert("RGB").resize((224, 224))
163
  g = np.array(heat_grid, dtype=np.float32)
 
164
  if np.any(g):
165
  g = g - g.min()
166
  if g.max() > 0:
167
  g = g / g.max()
168
  else:
169
+ g = np.zeros_like(g)
 
 
170
  heat_img = Image.fromarray((g * 255).astype("uint8"), mode="L").resize((224, 224), Image.BILINEAR)
171
  heat = np.array(heat_img).astype(np.float32) / 255.0
172
+ draw = ImageDraw.Draw(img, "RGBA")
173
+ # color mapping simple: yellow -> red
174
+ H, W = heat.shape
175
+ for y in range(H):
176
+ for x in range(W):
177
+ v = heat[y, x]
178
+ if v > 0.05:
179
+ # map to color
180
+ r = int(255 * v)
181
+ gcol = int(200 * (1 - v))
182
+ draw.point((x, y), fill=(r, gcol, 40, int(255 * alpha * v)))
183
+ return img
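The per-pixel `draw.point` loop above is easy to read but slow (up to 50k calls per image). A vectorized equivalent is possible; a sketch (not part of the commit), taking the already-upsampled 224x224 heat array in [0, 1]:

```python
# Vectorized version of the yellow->red blend above (same math, done in numpy).
import numpy as np
from PIL import Image

def make_focus_overlay_fast(img: Image.Image, heat: np.ndarray, alpha: float = 0.6) -> Image.Image:
    base = np.asarray(img.convert("RGB").resize((224, 224)), dtype=np.float32)
    v = heat[..., None]                                   # (224, 224, 1), values in [0, 1]
    color = np.concatenate([255 * v, 200 * (1 - v), np.full_like(v, 40.0)], axis=-1)
    a = alpha * v * (v > 0.05)                            # per-pixel blend weight
    out = (1 - a) * base + a * color                      # alpha-composite onto the photo
    return Image.fromarray(out.astype(np.uint8))
```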


+# ------------------- attention rollout (Abnar & Zuidema) -------------------
 def compute_attention_rollout(all_attentions: List[torch.Tensor]) -> np.ndarray:
     avg_mats = []
     for a in all_attentions:
         mat = a[0].mean(dim=0).detach().cpu().numpy()  # (seq, seq)
         avg_mats.append(mat)
     seq = avg_mats[0].shape[0]
     aug = []
     for A in avg_mats:
         A_hat = A + np.eye(seq)
         row_sums = A_hat.sum(axis=-1, keepdims=True)
         row_sums[row_sums == 0] = 1.0
         A_hat = A_hat / row_sums
         aug.append(A_hat)
     R = aug[0]
     for A in aug[1:]:
         R = A @ R
     return R  # (seq, seq)


+# ------------------- PCA helper for advanced -------------------
+def pca_plot_from_hidden(hidden_states: List[torch.Tensor], layers: List[int]):
     pts_all = []
+    labels = []
     for li in layers:
+        hs = hidden_states[li][0].detach().cpu().numpy()
+        patches = hs[1:, :]
         pca = PCA(n_components=2)
         pts = pca.fit_transform(patches)
         pts_all.append(pts)
+        labels.append(np.array([li] * pts.shape[0]))
     coords = np.vstack(pts_all)
+    layer_labels = np.concatenate(labels)
+    df = {"x": coords[:, 0], "y": coords[:, 1], "layer": layer_labels.astype(str)}
     fig = px.scatter(df, x="x", y="y", color="layer", title="Patch embeddings across layers (PCA)")
     fig.update_traces(marker=dict(size=6))
     fig.update_layout(height=480)
     return fig


+# ------------------- main analyzer (both modes) -------------------
+def analyze_all(img: Optional[Image.Image], mode_simple: bool):
     if img is None:
+        # return one placeholder per output (11 values, matching the return below)
+        return None, None, None, None, "", None, None, None, "", None, None

     base, clf, processor = load_models()

+    # preprocess
+    img224 = img.convert("RGB").resize((224, 224))
+    inputs = processor(images=img224, return_tensors="pt").to(DEVICE)

+    # forward through base model to get attentions & hidden states
     with torch.no_grad():
         outputs = base(**inputs)

+    attentions = outputs.attentions  # list L of (1, heads, seq, seq)
     hidden_states = outputs.hidden_states

+    # build grid & info
+    grid_size, positions = patch_grid_info()
     seq_len = attentions[0].shape[-1]
     n_patches = seq_len - 1

+    # Step1: patch grid image
+    patch_grid_img = draw_patch_grid(img224.copy())
+
+    # Step2: cluster patches using last hidden layer embeddings
+    last_hidden = hidden_states[-1][0].detach().cpu().numpy()  # (seq, hidden)
+    patch_embeddings = last_hidden[1:, :]  # remove CLS
+    # KMeans small number clusters (4)
+    n_clusters = 4
+    try:
+        kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(patch_embeddings)
+        cluster_labels = kmeans.labels_
+    except Exception:
+        # fallback uniform
+        cluster_labels = np.zeros(n_patches, dtype=int)
+
+    cluster_img = draw_cluster_blocks(img224.copy(), cluster_labels, n_clusters=n_clusters)
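Step 2 in isolation: KMeans over the 196 patch embeddings. A sketch with stand-in data (`n_init` set explicitly, since its default changed across scikit-learn versions):

```python
# Cluster patch embeddings into 4 groups, one label per 16x16 patch.
import numpy as np
from sklearn.cluster import KMeans

patch_embeddings = np.random.randn(196, 768).astype(np.float32)  # stand-in for last_hidden[1:]
labels = KMeans(n_clusters=4, random_state=0, n_init=10).fit(patch_embeddings).labels_
print(labels.shape)  # (196,) -> reshaped onto the 14x14 grid for coloring
```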
+
+    # Step3: simplified arrows using average last-layer attention across heads
+    last_att = attentions[-1][0].mean(dim=0).cpu().numpy()  # (seq, seq) averaged heads
+    # We want patch->patch attention (exclude CLS index in mapping)
+    # Map token indices 1.. to patch indices 0..
+    # Make an (n_patches, n_patches) matrix where row q corresponds to query patch q
+    if last_att.shape[0] >= n_patches + 1:
+        patch_to_patch = last_att[1:, 1:]  # (n_patches, n_patches)
+    else:
+        # fallback zeros
+        patch_to_patch = np.zeros((n_patches, n_patches))
+    # draw arrows for a central query
+    arrow_img = draw_attention_arrows(img224.copy(), patch_to_patch, top_k=4, query_idx=(n_patches // 2))
+
+    # Step4: rollout focus map (CLS rollout)
+    rollout = compute_attention_rollout(attentions)  # (seq, seq)
+    # take CLS row -> keys 1.. = patches
+    rollout_cls = rollout[0, 1:]
     if rollout_cls.shape[0] != grid_size * grid_size:
+        tmp = np.zeros(grid_size * grid_size, dtype=float)
+        nmin = min(len(rollout_cls), tmp.shape[0])
         tmp[:nmin] = rollout_cls[:nmin]
         rollout_cls = tmp
     rollout_grid = rollout_cls.reshape(grid_size, grid_size)
+    focus_img = make_focus_overlay(img224.copy(), rollout_grid, alpha=0.6)

+    # Top-5 predictions from classifier head
     with torch.no_grad():
         logits = clf(**inputs).logits[0].cpu().numpy()
         probs = np.exp(logits - logits.max())
     labels = clf.config.id2label
     preds_text = "\n".join([f"{labels[i]} — {probs[i]*100:.2f}%" for i in top5])

+    # Advanced outputs: PCA fig and default attention overlay (last layer head 0 CLS->patch)
+    pca_fig = pca_plot_from_hidden(hidden_states, [0, max(0, len(hidden_states) // 2), len(hidden_states) - 1])

+    # Attention overlay for advanced default (last layer head0 CLS->patch)
+    att_np = attentions[-1][0].cpu().numpy()  # (heads, seq, seq)
+    # average heads for simplicity
+    cls_to_patches = att_np.mean(axis=0)[0, 1:]
+    if cls_to_patches.shape[0] != grid_size * grid_size:
+        tmp = np.zeros(grid_size * grid_size, dtype=float)
+        nmin = min(len(cls_to_patches), tmp.shape[0])
+        tmp[:nmin] = cls_to_patches[:nmin]
+        cls_to_patches = tmp
+    cls_grid = cls_to_patches.reshape(grid_size, grid_size)
+    # create overlay
+    from PIL import Image  # ensure imported
+    focus_overlay_default = make_focus_overlay(img224.copy(), cls_grid, alpha=0.5)
+
+    # make state for interactive advanced controls (move to CPU to save GPU mem)
     state = {
+        "attentions": [a.cpu() for a in attentions],
         "hidden_states": [h.cpu() for h in hidden_states],
         "grid_size": grid_size,
+        "num_layers": len(attentions),
         "num_heads": attentions[0].shape[1],
         "base_image": img,
     }

+    # Return values:
+    #   Simple view: patch_grid_img, cluster_img, arrow_img, focus_img, preds_text, simple_explain
+    #   Advanced view: focus_overlay_default, pca_fig, preds_text, advanced_explain, state
+    simple_explain = """
+**How ViT Sees — Simple Steps**
+
+1) **Chop** — The image is chopped into small square tiles (patches) like LEGO pieces.
+2) **Understand** — Each piece gets a code that describes colors/edges. Pieces that look similar are grouped.
+3) **Talk** — Pieces tell each other what they see (we draw arrows to show that).
+4) **Focus & Guess** — The model merges clues and focuses on important areas, then guesses what the image shows.
+"""
+
+    advanced_explain = """
+**Advanced View:** Explore attention per layer/head, the PCA of patch embeddings, and the model's internal focus.
+Use sliders to change layer/head and see how ViT's attention evolves.
+"""
+
+    return (
+        patch_grid_img,
+        cluster_img,
+        arrow_img,
+        focus_img,
+        preds_text,
+        simple_explain,
+        focus_overlay_default,
+        pca_fig,
+        preds_text,
+        advanced_explain,
+        state,
+    )


+# ------------------- interactive advanced helpers -------------------
+def advanced_update_attention(state: Dict[str, Any], layer_idx: int, head_idx: int):
+    if not state:
+        return None
+    l = max(0, min(int(layer_idx), state["num_layers"] - 1))
+    h = max(0, min(int(head_idx), state["num_heads"] - 1))
+    att_tensor = state["attentions"][l]  # (1, heads, seq, seq) or (heads, seq, seq)
     if att_tensor.ndim == 4:
         att_tensor = att_tensor[0]
     att_np = att_tensor.numpy()  # (heads, seq, seq)
+    # take CLS->patches for selected head
+    vec = att_np[h, 0, 1:]
+    grid = state["grid_size"]
     if vec.shape[0] != grid * grid:
+        tmp = np.zeros(grid * grid, dtype=float)
         nmin = min(vec.shape[0], tmp.shape[0])
         tmp[:nmin] = vec[:nmin]
         vec = tmp
     grid_map = vec.reshape(grid, grid)
+    return make_focus_overlay(state["base_image"].convert("RGB"), grid_map, alpha=0.55)


+def advanced_update_rollout(state: Dict[str, Any]):
     if not state:
         return None
+    mats = [a.unsqueeze(0) if a.ndim == 3 else a for a in state["attentions"]]
+    R = compute_attention_rollout(mats)
     grid = state["grid_size"]
     rollout_cls = R[0, 1:]
     if rollout_cls.shape[0] != grid * grid:
+        tmp = np.zeros(grid * grid, dtype=float)
+        nmin = min(len(rollout_cls), tmp.shape[0])
         tmp[:nmin] = rollout_cls[:nmin]
         rollout_cls = tmp
     rollout_grid = rollout_cls.reshape(grid, grid)
+    return make_focus_overlay(state["base_image"].convert("RGB"), rollout_grid, alpha=0.6)


+def advanced_update_pca(state: Dict[str, Any], txt: str):
     if not state:
         return None
+    try:
+        layers = [int(x.strip()) for x in txt.split(",") if x.strip() != ""]
+    except Exception:
+        layers = [0, max(0, state["num_layers"] - 1)]
+    return pca_plot_from_hidden(state["hidden_states"], layers)
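One caveat: unlike the old `update_pca_layers`, this version does not clamp the parsed indices, so an out-of-range entry such as `99` would raise inside `pca_plot_from_hidden`. A more defensive parser (hypothetical helper, not in the commit):

```python
# Clamp each requested layer index into range instead of failing on bad input.
def parse_layer_indices(txt: str, n_states: int) -> list:
    try:
        raw = [int(x) for x in txt.split(",") if x.strip()]
    except ValueError:
        raw = [0, n_states - 1]
    return sorted({max(0, min(i, n_states - 1)) for i in raw}) or [0]
```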
+
+
+# ------------------- GRADIO UI -------------------
+with gr.Blocks(title="ViT Visualizer — Simple + Advanced") as demo:
+    gr.Markdown("# 👀 How Vision Transformers (ViT) See Images\n"
+                "Simple mode (story-style) + Advanced mode (inspect internals). Model: **google/vit-base-patch16-224**")
+
+    with gr.Tabs():
+        with gr.TabItem("Simple (for everyone)"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    img_input = gr.Image(label="Upload an image (photo / object)", type="pil")
+                    run_btn = gr.Button("🔎 Explain simply")
+                    gr.Markdown("Tip: use clear images of objects, animals, scenes for best examples.")
+                with gr.Column(scale=1):
+                    pass
+
+            gr.Markdown("### Step 1 — Chopped into patches")
+            step1 = gr.Image(label="Patch Grid (ViT chops image into 16×16 patches)")
+
+            gr.Markdown("### Step 2 — The model groups similar patches")
+            step2 = gr.Image(label="Clustered patches (colored blocks)")
+
+            gr.Markdown("### Step 3 — Patches talk to each other (simplified)")
+            step3 = gr.Image(label="Patch-to-Patch arrows")
+
+            gr.Markdown("### Step 4 — Model focus map and guess")
+            with gr.Row():
+                step4 = gr.Image(label="Focus map (where model looked most)")
+                preds_simple = gr.Textbox(label="Model guesses (Top-5)", lines=4)
+
+            explanation_simple = gr.Markdown()
+
+            run_btn.click(
+                fn=analyze_all,
+                inputs=[img_input, gr.State(True)],
+                outputs=[step1, step2, step3, step4, preds_simple, explanation_simple,
+                         gr.State(), gr.Plot(), gr.Textbox(), gr.Markdown(), gr.State()],
+            )
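Note that the five trailing outputs here are freshly constructed components acting as sinks for the advanced-mode values. A thin wrapper would avoid that (hypothetical alternative, not in the commit):

```python
# Let the simple tab list only the six components it actually displays.
def analyze_simple(img):
    out = analyze_all(img, True)
    return out[:6]  # patch grid, clusters, arrows, focus map, preds, explanation

# run_btn.click(fn=analyze_simple, inputs=[img_input],
#               outputs=[step1, step2, step3, step4, preds_simple, explanation_simple])
```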
+
+        with gr.TabItem("Advanced (inspect internals)"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    img_adv = gr.Image(label="Upload image for advanced view", type="pil")
+                    run_adv = gr.Button("Analyze (advanced)")
+                    gr.Markdown("Use the sliders to explore attention per layer and head.")
+                    layer_slider = gr.Slider(0, 11, value=11, step=1, label="Layer (0=shallow)")
+                    head_slider = gr.Slider(0, 11, value=0, step=1, label="Head index")
+                    rollout_btn = gr.Button("Refresh Rollout Overlay")
+                    pca_txt = gr.Textbox(label="PCA layers (comma separated)", value="0,6,11")
+                    pca_btn = gr.Button("Update PCA")
+                with gr.Column(scale=1):
+                    adv_attn = gr.Image(label="Attention overlay (layer/head CLS->patch)")
+                    adv_rollout = gr.Image(label="Attention rollout overlay (aggregated)")
+                    adv_pca = gr.Plot(label="PCA of patch embeddings")
+                    adv_preds = gr.Textbox(label="Top-5 predictions", lines=5)
+                    adv_explain = gr.Markdown()
+
+            state_box = gr.State()
+
+            # run advanced analysis
+            run_adv.click(
+                fn=analyze_all,
+                inputs=[img_adv, gr.State(False)],
+                outputs=[gr.Image(), gr.Image(), gr.Image(), gr.Image(), adv_preds, gr.Markdown(),
+                         adv_attn, adv_pca, adv_preds, adv_explain, state_box],
+            )
+
+            # update attention overlay with sliders
+            layer_slider.change(
+                fn=advanced_update_attention,
+                inputs=[state_box, layer_slider, head_slider],
+                outputs=[adv_attn],
+            )
+            head_slider.change(
+                fn=advanced_update_attention,
+                inputs=[state_box, layer_slider, head_slider],
+                outputs=[adv_attn],
+            )
+
+            rollout_btn.click(
+                fn=advanced_update_rollout,
+                inputs=[state_box],
+                outputs=[adv_rollout],
+            )
+
+            pca_btn.click(
+                fn=advanced_update_pca,
+                inputs=[state_box, pca_txt],
+                outputs=[adv_pca],
+            )

 demo.launch()