PraneshJs committed
Commit 50cbd61 · verified · 1 Parent(s): 61c8281

Update app.py

Files changed (1)
  1. app.py +69 -77
app.py CHANGED
@@ -5,7 +5,7 @@
  # - Patch grid (16x16)
  # - Patch attention (per layer / per head / query token)
  # - Attention rollout (layer aggregated)
- # - PCA of patch embeddings across selected layers
+ # - PCA of patch embeddings across layers
  # - Top-5 predictions & simple/technical explanations
  # ==========================================================
 
@@ -16,11 +16,15 @@ from typing import Any, Dict, List, Optional, Tuple
  import gradio as gr
  import numpy as np
  import torch
- from PIL import Image, ImageDraw, ImageFont
- from transformers import AutoImageProcessor, ViTModel, ViTForImageClassification
+ from PIL import Image, ImageDraw
+ from transformers import (
+     AutoImageProcessor,
+     ViTModel,
+     ViTForImageClassification,
+     AutoConfig,
+ )
  from sklearn.decomposition import PCA
  import plotly.express as px
- import plotly.graph_objects as go
 
  warnings.filterwarnings("ignore")
 
@@ -33,24 +37,40 @@ VIT_CLF = None # ViTForImageClassification (classification head)
  PROCESSOR = None
 
 
- # ------------------ model loader with SDPA fix ------------------
+ # ------------------ model loader with SDPA -> eager fix ------------------
  def load_models():
+     """
+     Load processor + ViT base + classification head.
+     Important: create config first, set attn_implementation='eager'
+     before enabling output_attentions/output_hidden_states, then load models.
+     """
      global VIT_BASE, VIT_CLF, PROCESSOR
      if VIT_BASE is not None and VIT_CLF is not None and PROCESSOR is not None:
          return VIT_BASE, VIT_CLF, PROCESSOR
 
      PROCESSOR = AutoImageProcessor.from_pretrained(MODEL_NAME)
 
-     # base ViT (encoder) - we need hidden_states & attentions
-     base = ViTModel.from_pretrained(MODEL_NAME, output_hidden_states=True)
-     # fix attn backend so we can access attentions
-     base.config.attn_implementation = "eager"
-     base.config.output_attentions = True
-     base.config.output_hidden_states = True
+     # load config, modify before creating model
+     cfg = None
+     try:
+         cfg = AutoConfig.from_pretrained(MODEL_NAME)
+     except Exception:
+         # fallback: load a default config and set minimal fields
+         from transformers import ViTConfig
+         cfg = ViTConfig.from_pretrained(MODEL_NAME)
+
+     # FORCE eager attention backend so we can extract attentions
+     # (must set attn_implementation before enabling output_attentions)
+     cfg.attn_implementation = "eager"
+     cfg.output_attentions = True
+     cfg.output_hidden_states = True
+
+     # now load the base encoder with the modified config
+     base = ViTModel.from_pretrained(MODEL_NAME, config=cfg)
      base.to(DEVICE)
      base.eval()
 
-     # classifier head for top-k labels
+     # load classifier separately (we can use default config for classifier)
      clf = ViTForImageClassification.from_pretrained(MODEL_NAME)
      clf.to(DEVICE)
      clf.eval()
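For context on the hunk above: SDPA-style attention backends may not return per-head attention weights, which is why the commit forces the eager backend before loading. A minimal standalone sketch of the same idea, assuming transformers >= 4.36 and using google/vit-base-patch16-224 as a stand-in for MODEL_NAME (the actual constant is defined elsewhere in app.py):

import torch
from PIL import Image
from transformers import AutoImageProcessor, ViTModel

name = "google/vit-base-patch16-224"  # assumed checkpoint, not taken from this diff
processor = AutoImageProcessor.from_pretrained(name)
# attn_implementation="eager" keeps attention probabilities available in the outputs
model = ViTModel.from_pretrained(name, attn_implementation="eager").eval()

img = Image.new("RGB", (224, 224), "gray")  # placeholder image
inputs = processor(images=img, return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_attentions=True, output_hidden_states=True)

# For this checkpoint: 12 attention maps of shape (1, 12, 197, 197)
# and 13 hidden states of shape (1, 197, 768).
print(len(out.attentions), tuple(out.attentions[0].shape))
print(len(out.hidden_states), tuple(out.hidden_states[0].shape))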
@@ -74,8 +94,7 @@ def make_patch_grid_image(pil: Image.Image, patch_size: int = 16, target_size: i
 
  def make_attention_overlay(base_img: Image.Image, heat_grid: np.ndarray, cmap_alpha: float = 0.45) -> Image.Image:
      """
-     heat_grid: (G, G) values in any scale (we will normalize)
-     overlay on base_img (resized to 224x224)
+     heat_grid: (G, G) values (any scale) -> normalized then overlaid on base_img (resized to 224x224)
      """
      img = base_img.convert("RGB").resize((224, 224))
      g = np.array(heat_grid, dtype=np.float32)
@@ -87,11 +106,11 @@ def make_attention_overlay(base_img: Image.Image, heat_grid: np.ndarray, cmap_al
      else:
          g = np.zeros_like(g, dtype=np.float32)
 
-     # upsample
+     # upsample to image resolution
      heat_img = Image.fromarray((g * 255).astype("uint8"), mode="L").resize((224, 224), Image.BILINEAR)
      heat = np.array(heat_img).astype(np.float32) / 255.0
 
-     # simple colormap blue->red
+     # simple blue->red colormap
      r = heat
      gch = np.zeros_like(heat)
      b = 1.0 - heat
@@ -111,7 +130,6 @@ def compute_attention_rollout(all_attentions: List[torch.Tensor]) -> np.ndarray:
      R = prod_l (A_l_hat) where A_l_hat = A_l + I; rows normalized
      Returns rollout matrix (seq, seq)
      """
-     # convert to np arrays averaged over heads
      avg_mats = []
      for a in all_attentions:
          # a: (batch=1, heads, seq, seq)
@@ -119,14 +137,16 @@ def compute_attention_rollout(all_attentions: List[torch.Tensor]) -> np.ndarray:
          avg_mats.append(mat)
 
      seq = avg_mats[0].shape[0]
-     # add identity & normalize rows
      aug = []
      for A in avg_mats:
          A_hat = A + np.eye(seq)
-         A_hat = A_hat / A_hat.sum(axis=-1, keepdims=True)
+         # normalize rows (sum over last dim)
+         row_sums = A_hat.sum(axis=-1, keepdims=True)
+         # avoid division by zero
+         row_sums[row_sums == 0] = 1.0
+         A_hat = A_hat / row_sums
          aug.append(A_hat)
 
-     # multiply (matrix product) in order
      R = aug[0]
      for A in aug[1:]:
          R = A @ R
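The rollout recurrence referenced in the docstring (R = prod_l (A_l + I), rows normalized; Abnar & Zuidema, 2020) can be sketched standalone with synthetic data; the shapes below are illustrative, in the app the matrices come from outputs.attentions:

import numpy as np

rng = np.random.default_rng(0)
num_layers, num_heads, seq = 12, 12, 197
attn = rng.random((num_layers, num_heads, seq, seq))
attn = attn / attn.sum(-1, keepdims=True)      # rows are attention distributions

rollout = np.eye(seq)
for layer in attn:
    a_hat = layer.mean(0) + np.eye(seq)        # average heads, add residual identity
    row_sums = a_hat.sum(-1, keepdims=True)
    row_sums[row_sums == 0] = 1.0              # same zero-row guard as in the hunk above
    rollout = (a_hat / row_sums) @ rollout     # later layers multiply on the left

cls_to_patches = rollout[0, 1:]                # CLS row over patch tokens
print(cls_to_patches.shape)                    # (196,)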
@@ -146,13 +166,11 @@ def layers_pca_plot(hidden_states: List[torch.Tensor], layers: List[int]) -> Any
      for li in layers:
          hs = hidden_states[li][0].detach().cpu().numpy() # (seq, hidden)
          patches = hs[1:, :] # remove CLS -> (N_patches, hidden)
-         # PCA to 2D
          pca = PCA(n_components=2)
          pts = pca.fit_transform(patches)
          pts_all.append(pts)
          layer_labels.append(np.array([li] * pts.shape[0]))
 
-     # combine
      coords = np.vstack(pts_all)
      labels = np.concatenate(layer_labels)
      df = {"x": coords[:, 0], "y": coords[:, 1], "layer": labels.astype(str)}
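For reference, layers_pca_plot boils down to a per-layer 2-D PCA of the patch embeddings fed into a plotly express scatter; a toy sketch with random stand-in embeddings (ViT-base sizes assumed, not taken from this diff):

import numpy as np
import plotly.express as px
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
pts_all, labels = [], []
for li in [0, 6, 11]:                          # example layer indices
    patches = rng.normal(size=(196, 768))      # stand-in for hidden_states[li][0][1:]
    pts_all.append(PCA(n_components=2).fit_transform(patches))
    labels.extend([str(li)] * 196)

coords = np.vstack(pts_all)
fig = px.scatter(x=coords[:, 0], y=coords[:, 1], color=labels,
                 labels={"color": "layer"}, title="Patch embeddings, PCA per layer")
# fig.show()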
@@ -165,9 +183,7 @@
  # ------------------ core analyzer ------------------
  def analyze_vit_full(img: Optional[Image.Image], simple: bool):
      if img is None:
-         return (
-             None, None, None, None, None, "", {}, {}
-         )
+         return (None, None, None, None, None, "", {})
 
      base, clf, processor = load_models()
 
@@ -180,43 +196,46 @@ def analyze_vit_full(img: Optional[Image.Image], simple: bool):
          outputs = base(**inputs)
 
      # outputs.attentions: list L tensors (batch=1, heads, seq, seq)
-     attentions = outputs.attentions # list length L
-     hidden_states = outputs.hidden_states # list length L+1 (including embeddings) usually
+     attentions = outputs.attentions
+     hidden_states = outputs.hidden_states
 
      L = len(attentions)
      seq_len = attentions[0].shape[-1]
      n_patches = seq_len - 1
      grid_size = int(math.sqrt(n_patches))
      if grid_size * grid_size != n_patches:
-         # fallback: compute closest integer grid
          grid_size = int(round(math.sqrt(n_patches)))
 
-     # default selections
-     default_layer = L - 1
-     default_head = 0
-     # default query token = 0 (CLS)
-     default_query = 0
-
      # Build patch grid image
      patch_grid = make_patch_grid_image(img.copy(), patch_size=16, target_size=224)
 
-     # Build per-layer per-head CLS->patch default overlay
-     # pick last layer, head 0, CLS query
-     att_np = attentions[default_layer][0].cpu().numpy() # (heads, seq, seq)
-     cls_to_patches = att_np[default_head, 0, 1:] # (n_patches,)
+     # default overlay: last layer, head 0, CLS query
+     last_layer = L - 1
+     head0 = 0
+     # attentions[last_layer]: shape (batch=1, heads, seq, seq)
+     att_np = attentions[last_layer][0].cpu().numpy() # (heads, seq, seq)
+     cls_to_patches = att_np[head0, 0, 1:] # (n_patches,)
+     if cls_to_patches.shape[0] != grid_size * grid_size:
+         tmp = np.zeros(grid_size * grid_size, dtype=np.float32)
+         nmin = min(cls_to_patches.shape[0], tmp.shape[0])
+         tmp[:nmin] = cls_to_patches[:nmin]
+         cls_to_patches = tmp
      cls_grid = cls_to_patches.reshape(grid_size, grid_size)
      attn_overlay = make_attention_overlay(img, cls_grid)
 
-     # Compute rollout
+     # Compute rollout overlay (CLS)
      rollout_mat = compute_attention_rollout(attentions) # (seq, seq)
      rollout_cls = rollout_mat[0, 1:]
+     if rollout_cls.shape[0] != grid_size * grid_size:
+         tmp = np.zeros(grid_size * grid_size, dtype=np.float32)
+         nmin = min(rollout_cls.shape[0], tmp.shape[0])
+         tmp[:nmin] = rollout_cls[:nmin]
+         rollout_cls = tmp
      rollout_grid = rollout_cls.reshape(grid_size, grid_size)
-     rollout_overlay = make_attention_overlay(img, rollout_grid, cmap_alpha=0.5)
+     rollout_overlay = make_attention_overlay(img, rollout_grid, cmap_alpha=0.55)
 
-     # PCA multi-layer: pick a few representative layers (start, quarter, half, three-quarters, last)
-     layers_to_show = sorted(
-         list({0, max(0, L // 4), max(0, L // 2), max(0, 3 * L // 4), L - 1})
-     )
+     # PCA multi-layer: choose representative layers
+     layers_to_show = sorted(list({0, max(0, L // 4), max(0, L // 2), max(0, 3 * L // 4), L - 1}))
      pca_fig = layers_pca_plot(hidden_states, layers_to_show)
 
      # Classification top-5
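A quick worked example of the grid arithmetic in this hunk, assuming a 224x224 input and 16-pixel patches (as in ViT-base/16): seq_len = 197 (1 CLS + 196 patches), so the CLS attention row reshapes to a 14x14 grid:

import numpy as np

seq_len = 197                                   # 1 CLS token + 196 patch tokens
n_patches = seq_len - 1
grid_size = int(round(np.sqrt(n_patches)))      # 14

cls_row = np.random.rand(n_patches)             # stand-in for att_np[head, 0, 1:]
heat = cls_row.reshape(grid_size, grid_size)    # (14, 14), passed to make_attention_overlay
print(grid_size, heat.shape)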
@@ -248,34 +267,20 @@ def analyze_vit_full(img: Optional[Image.Image], simple: bool):
      - Attention rollout uses Abnar & Zuidema's method to accumulate attention paths across layers.
      """
 
-     # return many things + state necessary for interactive updates (layer/head/query)
      state = {
-         "attentions": [a.cpu() for a in attentions], # store on CPU to allow slider updates
+         "attentions": [a.cpu() for a in attentions], # move to CPU for interactive updates
          "hidden_states": [h.cpu() for h in hidden_states],
          "grid_size": grid_size,
          "num_layers": L,
          "num_heads": attentions[0].shape[1],
-         "base_image": img, # original high-res image (we'll resize to 224 when overlaying)
+         "base_image": img,
      }
 
-     return (
-         patch_grid,
-         attn_overlay,
-         rollout_overlay,
-         pca_fig,
-         preds_text,
-         explain_md,
-         state,
-     )
+     return patch_grid, attn_overlay, rollout_overlay, pca_fig, preds_text, explain_md, state
 
 
  # ------------------ update functions for sliders / choices ------------------
  def update_layer_head_query(state: Dict[str, Any], layer_idx: int, head_idx: int, query_token: int, mode: str):
-     """
-     mode:
-     - "patch_attention": attention of query_token -> patches at (layer, head)
-     - "rollout": ignored (we will return rollout overlay)
-     """
      if not state:
          return None
 
@@ -288,25 +293,17 @@ def update_layer_head_query(state: Dict[str, Any], layer_idx: int, head_idx: int
      h = max(0, min(int(head_idx), H - 1))
      q = max(0, min(int(query_token), grid * grid)) # q in 0..n_patches (0==CLS)
 
-     # load attention for layer l: it's a CPU tensor (heads, seq, seq) already stored as state
-     att_tensor = state["attentions"][l] # shape (heads, seq, seq) because we saved a[0] earlier
-     # ensure shape (heads, seq, seq)
-     if att_tensor.ndim == 4: # sometimes shape might be (1, heads, seq, seq)
+     att_tensor = state["attentions"][l] # shape (heads, seq, seq) or (1,heads,seq,seq)
+     if att_tensor.ndim == 4:
          att_tensor = att_tensor[0]
      att_np = att_tensor.numpy() # (heads, seq, seq)
 
-     # query q -> keys: if q == 0 it's CLS; keys positions 1..seq-1 are patches
      seq = att_np.shape[-1]
-     n_patches = seq - 1
-     # column indices for keys: 1..seq-1 map to patches 0..n_patches-1
      if q >= seq:
          q = 0
 
-     # get attention vector for head h: att[h, q, 1:]
      vec = att_np[h, q, 1:]
-     # if vec shorter/longer than grid^2, adjust
      if vec.shape[0] != grid * grid:
-         # pad or trim
          tmp = np.zeros(grid * grid, dtype=np.float32)
          nmin = min(vec.shape[0], tmp.shape[0])
          tmp[:nmin] = vec[:nmin]
@@ -321,8 +318,6 @@ def get_rollout_overlay(state: Dict[str, Any]):
      if not state:
          return None
      attentions = state["attentions"]
-     # attentions list of tensors (heads, seq, seq)
-     # convert to list of (1, heads, seq, seq) for compute_attention_rollout
      mats = [a.unsqueeze(0) if a.ndim == 3 else a for a in attentions]
      R = compute_attention_rollout(mats) # (seq, seq)
      grid = state["grid_size"]
@@ -339,9 +334,7 @@
  def update_pca_layers(state: Dict[str, Any], selected_layers: List[int]):
      if not state:
          return None
-     # hidden_states stored as list of CPU tensors (batch, seq, hidden)
      hs = state["hidden_states"]
-     # ensure layers within range
      layers = [max(0, min(int(l), len(hs) - 1)) for l in selected_layers]
      fig = layers_pca_plot(hs, layers)
      return fig
@@ -365,8 +358,7 @@ with gr.Blocks(title="ViT Full Interpretability (A+B+C)") as demo:
 
              gr.Markdown("**Attention Rollout & PCA**")
              rollout_btn = gr.Button("Refresh Rollout Overlay")
-             # PCA layers selection: simple multi-select text entry allowed (comma separated)
-             pca_layers_txt = gr.Textbox(label="PCA layers (comma separated indices, e.g. 0,3,6,11)", value="0,3,6,11,11")
+             pca_layers_txt = gr.Textbox(label="PCA layers (comma separated indices, e.g. 0,3,6,11)", value="0,3,6,11")
 
          with gr.Column(scale=1):
              gr.Markdown("### Outputs")
 