iljung1106 committed on
Commit
39e77fe
·
1 Parent(s): 178daad

Add Grad-CAM visualization.

Browse files
Files changed (2) hide show
  1. app/visualization.py +301 -0
  2. webui_gradio.py +91 -0
app/visualization.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Visualization utilities for artist embedding model:
3
+ - Grad-CAM heatmaps
4
+ - View attention weights (whole/face/eyes)
5
+ - Branch attention weights (Gram/Cov/Spectrum/Stats)
6
+ """
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass
10
+ from typing import Dict, List, Optional, Tuple
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from PIL import Image
16
+
17
+
18
@dataclass
class ViewAnalysis:
    """Bundle of everything produced by a single analysis pass.

    Attributes:
        view_weights: Attention weight per view name (whole/face/eyes).
        branch_weights: Per-view mapping of branch name
            (Gram/Cov/Spectrum/Stats) to its gate weight.
        gradcam_heatmaps: Per-view Grad-CAM overlay image, or None when the
            view was absent or the overlay failed.
        original_images: Source PIL images, kept so callers can show them
            next to the heatmaps.
    """
    view_weights: Dict[str, float]
    branch_weights: Dict[str, Dict[str, float]]
    gradcam_heatmaps: Dict[str, Optional[Image.Image]]
    original_images: Dict[str, Optional[Image.Image]]
29
+
30
+
31
+ def _get_branch_weights(encoder, x: torch.Tensor) -> Dict[str, float]:
32
+ """
33
+ Extract branch attention weights from a ViewEncoder forward pass.
34
+ Returns dict with keys: gram, cov, spectrum, stats
35
+ """
36
+ # We need to do a partial forward to get the branch gate weights
37
+ with torch.no_grad():
38
+ x_lab = encoder._rgb_to_lab(x)
39
+ f0 = encoder.stem(x_lab)
40
+ f1 = encoder.b1(f0)
41
+ f2 = encoder.b2(f1)
42
+ f3 = encoder.b3(f2)
43
+ f4 = encoder.b4(f3)
44
+
45
+ g3 = encoder.h_gram3(f3)
46
+ c3 = encoder.h_cov3(f3)
47
+ sp3 = encoder.h_sp3(f3)
48
+ st3 = encoder.h_st3(f3)
49
+
50
+ g4 = encoder.h_gram4(f4)
51
+ c4 = encoder.h_cov4(f4)
52
+ sp4 = encoder.h_sp4(f4)
53
+ st4 = encoder.h_st4(f4)
54
+
55
+ b_gram = torch.cat([g3, g4], dim=1)
56
+ b_cov = torch.cat([c3, c4], dim=1)
57
+ b_sp = torch.cat([sp3, sp4], dim=1)
58
+ b_st = torch.cat([st3, st4], dim=1)
59
+
60
+ flat = torch.cat([b_gram, b_cov, b_sp, b_st], dim=1)
61
+ gate_logits = encoder.branch_gate(flat)
62
+ w = torch.softmax(gate_logits, dim=-1)
63
+
64
+ # w is [1, 4] for single image
65
+ w_np = w[0].cpu().numpy()
66
+ return {
67
+ "Gram": float(w_np[0]),
68
+ "Cov": float(w_np[1]),
69
+ "Spectrum": float(w_np[2]),
70
+ "Stats": float(w_np[3]),
71
+ }
72
+
73
+
74
def _compute_gradcam(
    encoder,
    x: torch.Tensor,
    target_layer_name: str = "b3",
) -> np.ndarray:
    """Produce a Grad-CAM heatmap for one ViewEncoder input.

    Hooks the requested backbone stage, backprops the L2 norm of the output
    embedding, and combines gradient-weighted activations into a heatmap in
    [0, 1], resized to the input resolution.

    Returns:
        float32 array of shape [H, W]; all zeros when no layer is found or
        the hooks captured nothing.
    """
    captured = {}  # "acts"/"grads" filled in by the hooks below

    def _save_acts(module, inputs, output):
        captured["acts"] = output.detach()

    def _save_grads(module, grad_in, grad_out):
        captured["grads"] = grad_out[0].detach()

    # Resolve the hook target; degrade gracefully toward earlier stages.
    layer = getattr(encoder, target_layer_name, None)
    if layer is None:
        for name in ("b2", "b1", "stem"):
            layer = getattr(encoder, name, None)
            if layer is not None:
                break
    if layer is None:
        return np.zeros((x.shape[2], x.shape[3]), dtype=np.float32)

    handles = [
        layer.register_forward_hook(_save_acts),
        layer.register_full_backward_hook(_save_grads),
    ]
    try:
        x.requires_grad_(True)
        embedding = encoder(x)

        # Scalar backward target: L2 norm of the embedding (no class logits
        # exist for an embedding model).
        encoder.zero_grad()
        embedding.norm(dim=1).sum().backward(retain_graph=True)

        acts = captured.get("acts")
        grads = captured.get("grads")
        if acts is None or grads is None:
            return np.zeros((x.shape[2], x.shape[3]), dtype=np.float32)

        # Classic Grad-CAM: channel weights = GAP of gradients, then a
        # ReLU'd weighted sum over channels (positive evidence only).
        channel_w = grads.mean(dim=(2, 3), keepdim=True)           # [B, C, 1, 1]
        cam = F.relu((channel_w * acts).sum(dim=1, keepdim=True))  # [B, 1, h, w]

        cam = cam[0, 0].cpu().numpy()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak

        # Upsample to the input size via a uint8 PIL round-trip.
        as_img = Image.fromarray((cam * 255).astype(np.uint8))
        as_img = as_img.resize((x.shape[3], x.shape[2]), Image.BILINEAR)
        return np.array(as_img).astype(np.float32) / 255.0
    finally:
        # Always unhook and restore the input tensor's grad flag.
        for handle in handles:
            handle.remove()
        x.requires_grad_(False)
150
+
151
+
152
def _overlay_heatmap(
    image: Image.Image,
    heatmap: np.ndarray,
    alpha: float = 0.5,
    colormap: str = "jet",
) -> Image.Image:
    """Blend a [0, 1] heatmap over *image* using a matplotlib colormap."""
    import matplotlib.pyplot as plt

    # Clamp, colorize (dropping the colormap's alpha channel), go to uint8 RGB.
    normalized = np.clip(heatmap, 0, 1)
    rgb = (plt.get_cmap(colormap)(normalized)[:, :, :3] * 255).astype(np.uint8)

    # Match the target image's size, then alpha-blend on top of it.
    overlay = Image.fromarray(rgb).resize(image.size, Image.BILINEAR)
    return Image.blend(image.convert("RGB"), overlay, alpha)
178
+
179
+
180
def analyze_views(
    model: torch.nn.Module,
    views: Dict[str, Optional[torch.Tensor]],
    original_images: Dict[str, Optional[Image.Image]],
    device: torch.device,
) -> ViewAnalysis:
    """Run the full interpretability pass for one sample.

    Collects the model's view attention weights, each view encoder's branch
    gate weights, and a Grad-CAM overlay per available view.
    """
    model.eval()

    view_order = ["whole", "face", "eyes"]

    # Batch each present view (batch size 1) and build presence masks.
    view_tensors: Dict[str, Optional[torch.Tensor]] = {}
    masks: Dict[str, torch.Tensor] = {}
    for name in view_order:
        tensor = views.get(name)
        if tensor is not None:
            view_tensors[name] = tensor.unsqueeze(0).to(device)
            masks[name] = torch.ones(1, dtype=torch.bool, device=device)
        else:
            view_tensors[name] = None
            masks[name] = torch.zeros(1, dtype=torch.bool, device=device)

    # Forward pass; W is [1, num_present_views] — attention over present views.
    with torch.no_grad():
        _, _, W = model(view_tensors, masks)
    weights_np = W[0].cpu().numpy()

    # Re-attach weights to view names; absent views get weight 0.
    present = [name for name in view_order if view_tensors[name] is not None]
    view_weights = {name: 0.0 for name in view_order}
    for idx, name in enumerate(present):
        view_weights[name] = float(weights_np[idx])

    # Per-view encoders (may be shared modules under the hood).
    encoders = {
        "whole": model.enc_whole,
        "face": model.enc_face,
        "eyes": model.enc_eyes,
    }

    branch_weights: Dict[str, Dict[str, float]] = {}
    gradcam_heatmaps: Dict[str, Optional[Image.Image]] = {}
    for name in view_order:
        x = view_tensors[name]
        if x is None:
            branch_weights[name] = {}
            gradcam_heatmaps[name] = None
            continue

        enc = encoders[name]

        # Best-effort: fall back to a uniform split if the probe fails.
        try:
            branch_weights[name] = _get_branch_weights(enc, x)
        except Exception:
            branch_weights[name] = {"Gram": 0.25, "Cov": 0.25, "Spectrum": 0.25, "Stats": 0.25}

        # Best-effort Grad-CAM; no overlay without an original image.
        try:
            heatmap = _compute_gradcam(enc, x.clone(), target_layer_name="b3")
            base = original_images.get(name)
            if base is not None:
                gradcam_heatmaps[name] = _overlay_heatmap(base, heatmap, alpha=0.5)
            else:
                gradcam_heatmaps[name] = None
        except Exception:
            gradcam_heatmaps[name] = None

    return ViewAnalysis(
        view_weights=view_weights,
        branch_weights=branch_weights,
        gradcam_heatmaps=gradcam_heatmaps,
        original_images={name: original_images.get(name) for name in view_order},
    )
262
+
263
+
264
def format_analysis_text(analysis: ViewAnalysis) -> str:
    """Render a ViewAnalysis as a markdown report with text bar charts."""

    def text_bar(weight: float, width: int, fill: str) -> str:
        # Unicode block bar: `fill` chars for the weighted part, ░ for the rest.
        filled = int(weight * width)
        return fill * filled + "░" * (width - filled)

    lines = ["## 📊 View & Branch Analysis\n"]

    # Section 1: how much each view contributed.
    lines.append("### View Attention Weights")
    lines.append("How much each view contributed to the final embedding:\n")
    for view in ("whole", "face", "eyes"):
        weight = analysis.view_weights.get(view, 0.0)
        lines.append(f"- **{view.capitalize()}**: `{text_bar(weight, 20, '█')}` {weight:.1%}")

    lines.append("")

    # Section 2: per-view branch gate weights.
    lines.append("### Branch Attention Weights (per view)")
    lines.append("Which style features were most important:\n")
    branch_desc = {
        "Gram": "texture patterns",
        "Cov": "color correlations",
        "Spectrum": "frequency content",
        "Stats": "mean/variance",
    }
    for view in ("whole", "face", "eyes"):
        weights = analysis.branch_weights.get(view, {})
        if not weights:
            continue
        lines.append(f"\n**{view.capitalize()}**:")
        for branch in ("Gram", "Cov", "Spectrum", "Stats"):
            weight = weights.get(branch, 0.0)
            lines.append(f" - {branch} ({branch_desc[branch]}): `{text_bar(weight, 15, '▓')}` {weight:.1%}")

    return "\n".join(lines)
301
+
webui_gradio.py CHANGED
@@ -166,6 +166,7 @@ _patch_gradio_client_bool_jsonschema()
166
  from app.model_io import LoadedModel, embed_triview, load_style_model
167
  from app.proto_db import PrototypeDB, load_prototype_db, topk_predictions_unique_labels
168
  from app.view_extractor import AnimeFaceEyeExtractor, ExtractorCfg
 
169
 
170
 
171
  ROOT = Path(__file__).resolve().parent
@@ -316,6 +317,65 @@ def classify(
316
  return "βœ… OK", rows, (face_pil if "face_pil" in locals() else None), (eyes_pil if "eyes_pil" in locals() else None)
317
 
318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  def _gallery_item_to_pil(item) -> Optional[Image.Image]:
320
  """Convert a Gradio gallery item to PIL Image (handles various formats)."""
321
  if item is None:
@@ -520,6 +580,37 @@ def build_ui() -> gr.Blocks:
520
  table = gr.Dataframe(headers=["label", "cosine_sim"], datatype=["str", "number"], interactive=False)
521
  run_btn.click(classify, inputs=[whole, topk], outputs=[out_status, table, face_prev, eyes_prev])
522
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
  with gr.Tab("Add prototype (temporary)"):
524
  gr.Markdown(
525
  "### ⚠️ Temporary Prototypes Only\n"
 
166
  from app.model_io import LoadedModel, embed_triview, load_style_model
167
  from app.proto_db import PrototypeDB, load_prototype_db, topk_predictions_unique_labels
168
  from app.view_extractor import AnimeFaceEyeExtractor, ExtractorCfg
169
+ from app.visualization import ViewAnalysis, analyze_views, format_analysis_text
170
 
171
 
172
  ROOT = Path(__file__).resolve().parent
 
317
  return "βœ… OK", rows, (face_pil if "face_pil" in locals() else None), (eyes_pil if "eyes_pil" in locals() else None)
318
 
319
 
320
def analyze_image(whole_img):
    """Run the Grad-CAM / attention analysis tab on a single input image.

    Returns a 7-tuple:
        (status, analysis_markdown, whole_gradcam, face_gradcam,
         eyes_gradcam, face_preview, eyes_preview)
    """
    if APP_STATE.lm is None:
        return "❌ Click **Load** first.", "", None, None, None, None, None

    lm = APP_STATE.lm
    extractor = APP_STATE.extractor

    def _as_pil(img):
        if img is None:
            return None
        return img if isinstance(img, Image.Image) else Image.fromarray(img)

    whole = _as_pil(whole_img)
    if whole is None:
        return "❌ Provide a whole image.", "", None, None, None, None, None

    try:
        # Crop face/eyes views when the extractor is available.
        face_pil = eyes_pil = None
        if extractor is not None:
            face_arr, eyes_arr = extractor.extract(np.array(whole.convert("RGB")))
            if face_arr is not None:
                face_pil = Image.fromarray(face_arr)
            if eyes_arr is not None:
                eyes_pil = Image.fromarray(eyes_arr)

        # Model-ready tensors; absent views stay None.
        views = {
            "whole": _pil_to_tensor(whole, lm.T_w),
            "face": _pil_to_tensor(face_pil, lm.T_f) if face_pil is not None else None,
            "eyes": _pil_to_tensor(eyes_pil, lm.T_e) if eyes_pil is not None else None,
        }
        originals = {"whole": whole, "face": face_pil, "eyes": eyes_pil}

        analysis = analyze_views(lm.model, views, originals, lm.device)
        return (
            "✅ Analysis complete",
            format_analysis_text(analysis),
            analysis.gradcam_heatmaps.get("whole"),
            analysis.gradcam_heatmaps.get("face"),
            analysis.gradcam_heatmaps.get("eyes"),
            face_pil,
            eyes_pil,
        )
    except Exception as e:  # surface the failure in the UI instead of crashing
        return f"❌ Analysis failed: {e}", "", None, None, None, None, None
377
+
378
+
379
  def _gallery_item_to_pil(item) -> Optional[Image.Image]:
380
  """Convert a Gradio gallery item to PIL Image (handles various formats)."""
381
  if item is None:
 
580
  table = gr.Dataframe(headers=["label", "cosine_sim"], datatype=["str", "number"], interactive=False)
581
  run_btn.click(classify, inputs=[whole, topk], outputs=[out_status, table, face_prev, eyes_prev])
582
 
583
+ with gr.Tab("Analyze (Grad-CAM)"):
584
+ gr.Markdown(
585
+ "### πŸ” View & Branch Analysis with Grad-CAM\n"
586
+ "Visualize which parts of the image and which style features the model focuses on.\n"
587
+ "- **View weights**: How much each view (whole/face/eyes) contributed\n"
588
+ "- **Branch weights**: Which style features (Gram/Cov/Spectrum/Stats) were important\n"
589
+ "- **Grad-CAM**: Spatial attention heatmaps showing where the model looked"
590
+ )
591
+ with gr.Row():
592
+ analyze_input = gr.Image(label="Whole image", type="pil")
593
+ analyze_btn = gr.Button("Analyze", variant="primary")
594
+ analyze_status = gr.Markdown("")
595
+ analyze_text = gr.Markdown("")
596
+
597
+ gr.Markdown("### Grad-CAM Heatmaps")
598
+ with gr.Row():
599
+ gcam_whole = gr.Image(label="Whole (Grad-CAM)", type="pil")
600
+ gcam_face = gr.Image(label="Face (Grad-CAM)", type="pil")
601
+ gcam_eyes = gr.Image(label="Eyes (Grad-CAM)", type="pil")
602
+
603
+ gr.Markdown("### Extracted Views")
604
+ with gr.Row():
605
+ analyze_face = gr.Image(label="Extracted Face", type="pil")
606
+ analyze_eyes = gr.Image(label="Extracted Eyes", type="pil")
607
+
608
+ analyze_btn.click(
609
+ analyze_image,
610
+ inputs=[analyze_input],
611
+ outputs=[analyze_status, analyze_text, gcam_whole, gcam_face, gcam_eyes, analyze_face, analyze_eyes],
612
+ )
613
+
614
  with gr.Tab("Add prototype (temporary)"):
615
  gr.Markdown(
616
  "### ⚠️ Temporary Prototypes Only\n"