Update app.py
app.py (CHANGED)

[The "before" pane of this diff did not survive extraction: most deleted lines are truncated mid-statement. What is recoverable: the previous app.py (421 lines) was a single-page ViT attention explorer built around analyze_vit_full(img, simple), with make_patch_grid_image(), make_attention_overlay(), compute_attention_rollout(), layers_pca_plot(), a top-5 prediction readout, and a simple/technical explanation toggle. The full "after" version of the file follows.]

# ==========================================================
# ViT Visualizer - Simple (comic-style) + Advanced Mode
# Model: google/vit-base-patch16-224
# Gradio 5 compatible; CPU-friendly
#
# Features:
# - Simple mode (4-step, non-technical, kid-friendly)
#     Step 1: Patch grid
#     Step 2: Patch clustering (colored blocks)
#     Step 3: Patch-to-patch arrows (simplified attention)
#     Step 4: Focus map (rollout) + Top-5 predictions
# - Advanced mode (attention maps per layer/head, rollout, PCA)
# - SDPA -> eager fix for attention extraction
# ==========================================================

import math
import warnings
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from transformers import (
    AutoImageProcessor,
    ViTModel,
    ViTForImageClassification,
    AutoConfig,
)
import plotly.express as px

warnings.filterwarnings("ignore")

MODEL_NAME = "google/vit-base-patch16-224"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Globals (populated lazily by load_models)
BASE_MODEL = None
CLF_MODEL = None
PROCESSOR = None


# ------------------- model loader with SDPA -> eager fix -------------------
def load_models():
    global BASE_MODEL, CLF_MODEL, PROCESSOR
    if BASE_MODEL is not None and CLF_MODEL is not None and PROCESSOR is not None:
        return BASE_MODEL, CLF_MODEL, PROCESSOR

    PROCESSOR = AutoImageProcessor.from_pretrained(MODEL_NAME)

    # load the config first and set attn_implementation BEFORE enabling
    # attentions: the default SDPA backend does not expose attention weights
    cfg = AutoConfig.from_pretrained(MODEL_NAME)
    cfg.attn_implementation = "eager"  # must be set first
    cfg.output_attentions = True
    cfg.output_hidden_states = True

    # load base encoder with modified config (we'll extract hidden states & attentions)
    BASE_MODEL = ViTModel.from_pretrained(MODEL_NAME, config=cfg)
    BASE_MODEL.to(DEVICE).eval()

    # classifier head (for top-5 predictions)
    CLF_MODEL = ViTForImageClassification.from_pretrained(MODEL_NAME)
    CLF_MODEL.to(DEVICE).eval()

    return BASE_MODEL, CLF_MODEL, PROCESSOR

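
# Sanity check (editor's sketch, not called by the app): with the eager
# backend forced above, outputs.attentions should be a 12-tuple of
# (1, 12, 197, 197) tensors for ViT-Base at 224x224 (196 patches + CLS).
def _check_attentions_available():
    base, _, processor = load_models()
    dummy = Image.new("RGB", (224, 224), (128, 128, 128))
    inputs = processor(images=dummy, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        out = base(**inputs)
    assert out.attentions is not None and out.attentions[0].shape[-1] == 197
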
# ------------------- utility: patch grid positions -------------------
def patch_grid_info(image_size: int = 224, patch_size: int = 16):
    grid_size = image_size // patch_size
    positions = []
    for i in range(grid_size):
        for j in range(grid_size):
            # center coordinates of patch (x, y)
            cx = int((j + 0.5) * patch_size)
            cy = int((i + 0.5) * patch_size)
            positions.append((cx, cy))
    return grid_size, positions

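
# Quick check (sketch): the defaults give a 14x14 grid of 196 patch centers
# in row-major order, from (8, 8) at the top-left to (216, 216).
def _check_patch_grid():
    g, pos = patch_grid_info()
    assert g == 14 and len(pos) == 196
    assert pos[0] == (8, 8) and pos[-1] == (216, 216)
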
# ------------------- visual helpers -------------------
def draw_patch_grid(img: Image.Image, patch_size: int = 16, outline=(0, 180, 0)) -> Image.Image:
    img = img.convert("RGB").resize((224, 224))
    draw = ImageDraw.Draw(img)
    w, h = img.size
    for x in range(0, w, patch_size):
        draw.line([(x, 0), (x, h)], fill=outline, width=1)
    for y in range(0, h, patch_size):
        draw.line([(0, y), (w, y)], fill=outline, width=1)
    return img

def draw_cluster_blocks(img: Image.Image, labels: np.ndarray, n_clusters: int = 4, patch_size: int = 16):
    """
    labels: (n_patches,) cluster labels assigned to each patch index
    (left-to-right, top-to-bottom)
    """
    img = img.convert("RGB").resize((224, 224))
    draw = ImageDraw.Draw(img, "RGBA")
    grid_size, positions = patch_grid_info()
    colors = [
        (255, 99, 71, 140),
        (60, 179, 113, 140),
        (65, 105, 225, 140),
        (255, 215, 0, 140),
        (199, 21, 133, 140),
        (0, 206, 209, 140),
    ]
    for idx, lab in enumerate(labels):
        i = idx // grid_size
        j = idx % grid_size
        x0 = j * patch_size
        y0 = i * patch_size
        x1 = x0 + patch_size
        y1 = y0 + patch_size
        col = colors[int(lab) % len(colors)]
        draw.rectangle([x0, y0, x1, y1], fill=col)
    return img

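
# Usage note (sketch): analyze_all() feeds this the KMeans labels computed on
# the (196, 768) last-layer patch embeddings; label index 0 colors the
# top-left patch. With more than 6 clusters the palette wraps around
# (the modulo above).
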
def draw_attention_arrows(img: Image.Image, att_matrix: np.ndarray, top_k: int = 3, query_idx: Optional[int] = None):
    """
    att_matrix: (n_patches, n_patches) attention from query->keys (already preprocessed)
    If query_idx is None -> fall back to the center patch, else 0..n_patches-1.
    We draw arrows from the query patch center to the top-k key patch centers.
    """
    img = img.convert("RGB").resize((224, 224))
    draw = ImageDraw.Draw(img, "RGBA")
    grid_size, positions = patch_grid_info()
    # pick a query: if None, choose the center patch
    if query_idx is None:
        query_idx = (grid_size * grid_size) // 2
    qpos = positions[query_idx]
    # find the top_k keys attended to by this query
    vec = att_matrix[query_idx]  # length n_patches
    top_idx = vec.argsort()[-top_k:][::-1]
    for t in top_idx:
        kpos = positions[t]
        # draw line + arrowhead
        draw.line([qpos, kpos], fill=(255, 0, 0, 200), width=3)
        # arrowhead: small triangle at the key end
        dx = kpos[0] - qpos[0]
        dy = kpos[1] - qpos[1]
        ang = math.atan2(dy, dx)
        ah = 8  # arrowhead size
        p1 = (kpos[0] - ah * math.cos(ang - 0.3), kpos[1] - ah * math.sin(ang - 0.3))
        p2 = (kpos[0] - ah * math.cos(ang + 0.3), kpos[1] - ah * math.sin(ang + 0.3))
        draw.polygon([kpos, p1, p2], fill=(255, 0, 0, 200))
    # highlight the query patch with a blue circle
    r = 10
    draw.ellipse([qpos[0] - r, qpos[1] - r, qpos[0] + r, qpos[1] + r], outline=(0, 0, 255, 220), width=2)
    return img

def make_focus_overlay(img: Image.Image, heat_grid: np.ndarray, alpha: float = 0.6):
    """
    heat_grid: (G, G) float map
    Overlay a colored transparency on the image where the heat is high.
    """
    img = img.convert("RGB").resize((224, 224))
    g = np.array(heat_grid, dtype=np.float32)
    # normalize to 0..1
    if np.any(g):
        g = g - g.min()
        if g.max() > 0:
            g = g / g.max()
    else:
        g = np.zeros_like(g)
    # upsample the patch grid to image resolution
    heat_img = Image.fromarray((g * 255).astype("uint8"), mode="L").resize((224, 224), Image.BILINEAR)
    heat = np.array(heat_img).astype(np.float32) / 255.0
    draw = ImageDraw.Draw(img, "RGBA")
    # simple color mapping: yellow -> red
    H, W = heat.shape
    for y in range(H):
        for x in range(W):
            v = heat[y, x]
            if v > 0.05:
                r = int(255 * v)
                gcol = int(200 * (1 - v))
                draw.point((x, y), fill=(r, gcol, 40, int(255 * alpha * v)))
    return img

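
# Equivalent vectorized overlay (editor's sketch): the per-pixel draw.point
# loop above issues ~50k calls per image; the same yellow->red ramp can be
# composited in one NumPy pass (the v > 0.05 cutoff is dropped here, which
# only affects near-zero pixels).
def _focus_overlay_vectorized(img: Image.Image, heat_grid: np.ndarray, alpha: float = 0.6) -> Image.Image:
    base = np.asarray(img.convert("RGB").resize((224, 224)), dtype=np.float32)
    g = np.array(heat_grid, dtype=np.float32)
    if np.any(g):
        g = g - g.min()
        if g.max() > 0:
            g = g / g.max()
    heat_l = Image.fromarray((g * 255).astype("uint8"), mode="L").resize((224, 224), Image.BILINEAR)
    heat = np.asarray(heat_l, dtype=np.float32) / 255.0
    color = np.stack([255 * heat, 200 * (1 - heat), np.full_like(heat, 40.0)], axis=-1)
    a = (alpha * heat)[..., None]  # per-pixel blend weight, matching the loop
    out = (1.0 - a) * base + a * color
    return Image.fromarray(out.astype("uint8"))
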
# ------------------- attention rollout (Abnar & Zuidema) -------------------
def compute_attention_rollout(all_attentions: List[torch.Tensor]) -> np.ndarray:
    """
    all_attentions: list (length L) of tensors (batch, heads, seq, seq).
    Average heads per layer -> (seq, seq), then compute the rollout
    R = prod_l A_l_hat, where A_l_hat = A_l + I with rows renormalized.
    Returns the rollout matrix (seq, seq).
    """
    avg_mats = []
    for a in all_attentions:
        # a: (batch=1, heads, seq, seq)
        mat = a[0].mean(dim=0).detach().cpu().numpy()  # (seq, seq)
        avg_mats.append(mat)
    seq = avg_mats[0].shape[0]
    aug = []
    for A in avg_mats:
        A_hat = A + np.eye(seq)
        # normalize rows (sum over last dim), avoiding division by zero
        row_sums = A_hat.sum(axis=-1, keepdims=True)
        row_sums[row_sums == 0] = 1.0
        A_hat = A_hat / row_sums
        aug.append(A_hat)
    R = aug[0]
    for A in aug[1:]:
        R = A @ R
    return R  # (seq, seq)

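
# Toy check of the recurrence (sketch): if every layer's attention is the
# identity, A_hat = (A + I) renormalized is the identity again, so the
# rollout must also be the identity.
def _rollout_identity_check():
    eye = torch.eye(4).repeat(1, 2, 1, 1)  # (batch=1, heads=2, seq=4, seq=4)
    R = compute_attention_rollout([eye, eye, eye])
    assert np.allclose(R, np.eye(4))
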
# ------------------- PCA helper for advanced -------------------
def pca_plot_from_hidden(hidden_states: List[torch.Tensor], layers: List[int]):
    """
    hidden_states: list of tensors (batch, seq, hidden)
    layers: indices within hidden_states to project.
    Drop the CLS token, run a 2-component PCA per chosen layer, and plot all
    layers' patches on a single scatter, colored by layer.
    """
    pts_all = []
    labels = []
    for li in layers:
        hs = hidden_states[li][0].detach().cpu().numpy()
        patches = hs[1:, :]  # drop CLS
        pca = PCA(n_components=2)
        pts = pca.fit_transform(patches)
        pts_all.append(pts)
        labels.append(np.array([li] * pts.shape[0]))
    coords = np.vstack(pts_all)
    layer_labels = np.concatenate(labels)
    df = {"x": coords[:, 0], "y": coords[:, 1], "layer": layer_labels.astype(str)}
    fig = px.scatter(df, x="x", y="y", color="layer", title="Patch embeddings across layers (PCA)")
    fig.update_traces(marker=dict(size=6))
    fig.update_layout(height=480)
    return fig

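
# Caveat (editor's note): the PCA is fit independently for each layer, so the
# x/y coordinates are not comparable across layers; the combined scatter shows
# per-layer structure, not a shared embedding space.
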
# ------------------- main analyzer (both modes) -------------------
def analyze_all(img: Optional[Image.Image], mode_simple: bool):
    # NOTE: mode_simple is currently unused; both simple and advanced outputs
    # are always computed and returned.
    if img is None:
        # return placeholders for all 11 outputs
        return None, None, None, None, "", "", None, None, "", "", None

    base, clf, processor = load_models()

    # preprocess
    img224 = img.convert("RGB").resize((224, 224))
    inputs = processor(images=img224, return_tensors="pt").to(DEVICE)

    # forward through base model to get attentions & hidden states
    with torch.no_grad():
        outputs = base(**inputs)

    attentions = outputs.attentions  # list of L tensors (1, heads, seq, seq)
    hidden_states = outputs.hidden_states

    # build grid & info
    grid_size, positions = patch_grid_info()
    seq_len = attentions[0].shape[-1]
    n_patches = seq_len - 1

    # Step 1: patch grid image
    patch_grid_img = draw_patch_grid(img224.copy())

    # Step 2: cluster patches using last hidden layer embeddings
    last_hidden = hidden_states[-1][0].detach().cpu().numpy()  # (seq, hidden)
    patch_embeddings = last_hidden[1:, :]  # remove CLS
    # KMeans with a small number of clusters (4)
    n_clusters = 4
    try:
        kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(patch_embeddings)
        cluster_labels = kmeans.labels_
    except Exception:
        # fallback: one uniform cluster
        cluster_labels = np.zeros(n_patches, dtype=int)

    cluster_img = draw_cluster_blocks(img224.copy(), cluster_labels, n_clusters=n_clusters)

    # Step 3: simplified arrows using average last-layer attention across heads
    last_att = attentions[-1][0].mean(dim=0).cpu().numpy()  # (seq, seq), heads averaged
    # We want patch->patch attention (excluding the CLS index): map token
    # indices 1.. to patch indices 0.., giving an (n_patches, n_patches)
    # matrix where row q corresponds to query patch q
    if last_att.shape[0] >= n_patches + 1:
        patch_to_patch = last_att[1:, 1:]  # (n_patches, n_patches)
    else:
        # fallback zeros
        patch_to_patch = np.zeros((n_patches, n_patches))
    # draw arrows for a central query
    arrow_img = draw_attention_arrows(img224.copy(), patch_to_patch, top_k=4, query_idx=(n_patches // 2))

    # Step 4: rollout focus map (CLS rollout)
    rollout = compute_attention_rollout(attentions)  # (seq, seq)
    # take the CLS row -> keys 1.. are the patches
    rollout_cls = rollout[0, 1:]
    if rollout_cls.shape[0] != grid_size * grid_size:
        tmp = np.zeros(grid_size * grid_size, dtype=float)
        nmin = min(len(rollout_cls), tmp.shape[0])
        tmp[:nmin] = rollout_cls[:nmin]
        rollout_cls = tmp
    rollout_grid = rollout_cls.reshape(grid_size, grid_size)
    focus_img = make_focus_overlay(img224.copy(), rollout_grid, alpha=0.6)

    # Top-5 predictions from classifier head
    with torch.no_grad():
        logits = clf(**inputs).logits[0].cpu().numpy()
    probs = np.exp(logits - logits.max())
    probs = probs / probs.sum()  # softmax
    top5 = probs.argsort()[-5:][::-1]
    labels = clf.config.id2label
    preds_text = "\n".join([f"{labels[i]} → {probs[i]*100:.2f}%" for i in top5])

    # Advanced outputs: PCA fig and default attention overlay (last layer, CLS->patch)
    pca_fig = pca_plot_from_hidden(hidden_states, [0, max(0, len(hidden_states) // 2), len(hidden_states) - 1])

    # Attention overlay for the advanced default (last layer, heads averaged, CLS->patch)
    att_np = attentions[-1][0].cpu().numpy()  # (heads, seq, seq)
    cls_to_patches = att_np.mean(axis=0)[0, 1:]
    if cls_to_patches.shape[0] != grid_size * grid_size:
        tmp = np.zeros(grid_size * grid_size, dtype=float)
        nmin = min(len(cls_to_patches), tmp.shape[0])
        tmp[:nmin] = cls_to_patches[:nmin]
        cls_to_patches = tmp
    cls_grid = cls_to_patches.reshape(grid_size, grid_size)
    focus_overlay_default = make_focus_overlay(img224.copy(), cls_grid, alpha=0.5)

    # state for interactive advanced controls (moved to CPU to save GPU mem)
    state = {
        "attentions": [a.cpu() for a in attentions],
        "hidden_states": [h.cpu() for h in hidden_states],
        "grid_size": grid_size,
        "num_layers": len(attentions),
        "num_heads": attentions[0].shape[1],
        "base_image": img,
    }

    # Return values:
    # - Simple view: patch_grid_img, cluster_img, arrow_img, focus_img, preds_text, simple_explain
    # - Advanced view: focus_overlay_default, pca_fig, preds_text, advanced_explain, state
    simple_explain = """
**How ViT Sees - Simple Steps**

1) **Chop** - The image is chopped into small square tiles (patches) like LEGO pieces.
2) **Understand** - Each piece gets a code that describes colors/edges. Pieces that look similar are grouped.
3) **Talk** - Pieces tell each other what they see (we draw arrows to show that).
4) **Focus & Guess** - The model merges clues and focuses on important areas, then guesses what the image shows.
"""

    advanced_explain = """
**Advanced View:** Explore attention per layer/head, the PCA of patch embeddings, and the model's internal focus.
Use the sliders to change layer/head and see how ViT's attention evolves.
"""

    return (
        patch_grid_img,
        cluster_img,
        arrow_img,
        focus_img,
        preds_text,
        simple_explain,
        focus_overlay_default,
        pca_fig,
        preds_text,
        advanced_explain,
        state,
    )

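
# Smoke test (editor's sketch, not wired into the UI): the click handlers
# below expect exactly the 11 outputs returned above, in that order.
def _smoke_test_analyze_all():
    outs = analyze_all(Image.new("RGB", (256, 256), (180, 160, 120)), True)
    assert len(outs) == 11
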
# ------------------- interactive advanced helpers -------------------
def advanced_update_attention(state: Dict[str, Any], layer_idx: int, head_idx: int):
    if not state:
        return None
    l = max(0, min(int(layer_idx), state["num_layers"] - 1))
    h = max(0, min(int(head_idx), state["num_heads"] - 1))
    att_tensor = state["attentions"][l]  # (1, heads, seq, seq) or (heads, seq, seq)
    if att_tensor.ndim == 4:
        att_tensor = att_tensor[0]
    att_np = att_tensor.numpy()  # (heads, seq, seq)
    # take CLS->patches for the selected head
    vec = att_np[h, 0, 1:]
    grid = state["grid_size"]
    if vec.shape[0] != grid * grid:
        tmp = np.zeros(grid * grid, dtype=float)
        nmin = min(vec.shape[0], tmp.shape[0])
        tmp[:nmin] = vec[:nmin]
        vec = tmp
    grid_map = vec.reshape(grid, grid)
    return make_focus_overlay(state["base_image"].convert("RGB"), grid_map, alpha=0.55)


def advanced_update_rollout(state: Dict[str, Any]):
    if not state:
        return None
    mats = [a.unsqueeze(0) if a.ndim == 3 else a for a in state["attentions"]]
    R = compute_attention_rollout(mats)  # (seq, seq)
    grid = state["grid_size"]
    rollout_cls = R[0, 1:]
    if rollout_cls.shape[0] != grid * grid:
        tmp = np.zeros(grid * grid, dtype=float)
        nmin = min(len(rollout_cls), tmp.shape[0])
        tmp[:nmin] = rollout_cls[:nmin]
        rollout_cls = tmp
    rollout_grid = rollout_cls.reshape(grid, grid)
    return make_focus_overlay(state["base_image"].convert("RGB"), rollout_grid, alpha=0.6)


def advanced_update_pca(state: Dict[str, Any], txt: str):
    if not state:
        return None
    try:
        layers = [int(x.strip()) for x in txt.split(",") if x.strip() != ""]
    except Exception:
        layers = [0, max(0, state["num_layers"] - 1)]
    # clamp user-supplied indices into the valid range
    layers = [max(0, min(li, len(state["hidden_states"]) - 1)) for li in layers]
    return pca_plot_from_hidden(state["hidden_states"], layers)


# ------------------- GRADIO UI -------------------
with gr.Blocks(title="ViT Visualizer - Simple + Advanced") as demo:
    gr.Markdown("# How Vision Transformers (ViT) See Images\n"
                "Simple mode (story-style) + Advanced mode (inspect internals). Model: **google/vit-base-patch16-224**")

    with gr.Tabs():
        with gr.TabItem("Simple (for everyone)"):
            with gr.Row():
                with gr.Column(scale=1):
                    img_input = gr.Image(label="Upload an image (photo / object)", type="pil")
                    run_btn = gr.Button("Explain simply")
                    gr.Markdown("Tip: use clear images of objects, animals, scenes for best examples.")
                with gr.Column(scale=1):
                    pass

            gr.Markdown("### Step 1 - Chopped into patches")
            step1 = gr.Image(label="Patch Grid (ViT chops the image into 16×16 patches)")

            gr.Markdown("### Step 2 - The model groups similar patches")
            step2 = gr.Image(label="Clustered patches (colored blocks)")

            gr.Markdown("### Step 3 - Patches talk to each other (simplified)")
            step3 = gr.Image(label="Patch-to-Patch arrows")

            gr.Markdown("### Step 4 - Model focus map and guess")
            with gr.Row():
                step4 = gr.Image(label="Focus map (where the model looked most)")
                preds_simple = gr.Textbox(label="Model guesses (Top-5)", lines=4)

            explanation_simple = gr.Markdown()

            run_btn.click(
                fn=analyze_all,
                inputs=[img_input, gr.State(True)],
                # unrendered placeholder components absorb the advanced outputs
                outputs=[step1, step2, step3, step4, preds_simple, explanation_simple,
                         gr.State(), gr.Plot(), gr.Textbox(), gr.Markdown(), gr.State()],
            )

        with gr.TabItem("Advanced (inspect internals)"):
            with gr.Row():
                with gr.Column(scale=1):
                    img_adv = gr.Image(label="Upload image for advanced view", type="pil")
                    run_adv = gr.Button("Analyze (advanced)")
                    gr.Markdown("Use the sliders to explore attention per layer and head.")
                    layer_slider = gr.Slider(0, 11, value=11, step=1, label="Layer (0=shallow)")
                    head_slider = gr.Slider(0, 11, value=0, step=1, label="Head index")
                    rollout_btn = gr.Button("Refresh Rollout Overlay")
                    pca_txt = gr.Textbox(label="PCA layers (comma separated)", value="0,6,11")
                    pca_btn = gr.Button("Update PCA")
                with gr.Column(scale=1):
                    adv_attn = gr.Image(label="Attention overlay (layer/head CLS->patch)")
                    adv_rollout = gr.Image(label="Attention rollout overlay (aggregated)")
                    adv_pca = gr.Plot(label="PCA of patch embeddings")
                    adv_preds = gr.Textbox(label="Top-5 predictions", lines=5)
                    adv_explain = gr.Markdown()

            state_box = gr.State()

            # run advanced analysis
            run_adv.click(
                fn=analyze_all,
                inputs=[img_adv, gr.State(False)],
                # unrendered placeholders absorb the simple-mode outputs
                outputs=[gr.Image(), gr.Image(), gr.Image(), gr.Image(), adv_preds, gr.Markdown(),
                         adv_attn, adv_pca, adv_preds, adv_explain, state_box],
            )

            # update the attention overlay from the sliders
            layer_slider.change(
                fn=advanced_update_attention,
                inputs=[state_box, layer_slider, head_slider],
                outputs=[adv_attn],
            )
            head_slider.change(
                fn=advanced_update_attention,
                inputs=[state_box, layer_slider, head_slider],
                outputs=[adv_attn],
            )

            rollout_btn.click(
                fn=advanced_update_rollout,
                inputs=[state_box],
                outputs=[adv_rollout],
            )

            pca_btn.click(
                fn=advanced_update_pca,
                inputs=[state_box, pca_txt],
                outputs=[adv_pca],
            )

demo.launch()