Commit 39e77fe · iljung1106 committed
1 Parent(s): 178daad

Add Grad-CAM visualization.

Files changed:
- app/visualization.py  +301 -0
- webui_gradio.py       +91 -0
app/visualization.py
ADDED
@@ -0,0 +1,301 @@

```python
"""
Visualization utilities for artist embedding model:
- Grad-CAM heatmaps
- View attention weights (whole/face/eyes)
- Branch attention weights (Gram/Cov/Spectrum/Stats)
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image


@dataclass
class ViewAnalysis:
    """Analysis results for a single inference."""
    # View attention weights [3] for whole/face/eyes
    view_weights: Dict[str, float]
    # Branch attention weights per view {view_name: {branch_name: weight}}
    branch_weights: Dict[str, Dict[str, float]]
    # Grad-CAM heatmaps per view (PIL Images)
    gradcam_heatmaps: Dict[str, Optional[Image.Image]]
    # Original images for overlay
    original_images: Dict[str, Optional[Image.Image]]


def _get_branch_weights(encoder, x: torch.Tensor) -> Dict[str, float]:
    """
    Extract branch attention weights from a ViewEncoder forward pass.
    Returns dict with keys: Gram, Cov, Spectrum, Stats
    """
    # We need to do a partial forward to get the branch gate weights
    with torch.no_grad():
        x_lab = encoder._rgb_to_lab(x)
        f0 = encoder.stem(x_lab)
        f1 = encoder.b1(f0)
        f2 = encoder.b2(f1)
        f3 = encoder.b3(f2)
        f4 = encoder.b4(f3)

        g3 = encoder.h_gram3(f3)
        c3 = encoder.h_cov3(f3)
        sp3 = encoder.h_sp3(f3)
        st3 = encoder.h_st3(f3)

        g4 = encoder.h_gram4(f4)
        c4 = encoder.h_cov4(f4)
        sp4 = encoder.h_sp4(f4)
        st4 = encoder.h_st4(f4)

        b_gram = torch.cat([g3, g4], dim=1)
        b_cov = torch.cat([c3, c4], dim=1)
        b_sp = torch.cat([sp3, sp4], dim=1)
        b_st = torch.cat([st3, st4], dim=1)

        flat = torch.cat([b_gram, b_cov, b_sp, b_st], dim=1)
        gate_logits = encoder.branch_gate(flat)
        w = torch.softmax(gate_logits, dim=-1)

    # w is [1, 4] for a single image
    w_np = w[0].cpu().numpy()
    return {
        "Gram": float(w_np[0]),
        "Cov": float(w_np[1]),
        "Spectrum": float(w_np[2]),
        "Stats": float(w_np[3]),
    }


def _compute_gradcam(
    encoder,
    x: torch.Tensor,
    target_layer_name: str = "b3",
) -> np.ndarray:
    """
    Compute Grad-CAM heatmap for a ViewEncoder.
    Uses gradients of the output w.r.t. an intermediate feature map.
    Returns a heatmap as numpy array [H, W] normalized to [0, 1].
    """
    # Storage for activations and gradients
    activations = {}
    gradients = {}

    def forward_hook(module, input, output):
        activations["value"] = output.detach()

    def backward_hook(module, grad_input, grad_output):
        gradients["value"] = grad_output[0].detach()

    # Get the target layer
    target_layer = getattr(encoder, target_layer_name, None)
    if target_layer is None:
        # Fall back to b2, b1, or the stem
        for fallback in ["b2", "b1", "stem"]:
            target_layer = getattr(encoder, fallback, None)
            if target_layer is not None:
                break

    if target_layer is None:
        return np.zeros((x.shape[2], x.shape[3]), dtype=np.float32)

    # Register hooks
    fwd_handle = target_layer.register_forward_hook(forward_hook)
    bwd_handle = target_layer.register_full_backward_hook(backward_hook)

    try:
        # Forward pass
        x.requires_grad_(True)
        output = encoder(x)

        # Backward pass: use the L2 norm of the output as the target
        target = output.norm(dim=1).sum()
        encoder.zero_grad()
        target.backward(retain_graph=True)

        # Get activations and gradients
        acts = activations.get("value")
        grads = gradients.get("value")

        if acts is None or grads is None:
            return np.zeros((x.shape[2], x.shape[3]), dtype=np.float32)

        # Compute Grad-CAM weights (global average pooling of gradients)
        weights = grads.mean(dim=(2, 3), keepdim=True)  # [B, C, 1, 1]

        # Weighted combination of activations
        cam = (weights * acts).sum(dim=1, keepdim=True)  # [B, 1, H, W]
        cam = F.relu(cam)  # Only positive contributions

        # Normalize
        cam = cam[0, 0].cpu().numpy()
        if cam.max() > 0:
            cam = cam / cam.max()

        # Resize to input size
        cam_pil = Image.fromarray((cam * 255).astype(np.uint8))
        cam_pil = cam_pil.resize((x.shape[3], x.shape[2]), Image.BILINEAR)
        cam = np.array(cam_pil).astype(np.float32) / 255.0

        return cam

    finally:
        fwd_handle.remove()
        bwd_handle.remove()
        x.requires_grad_(False)


def _overlay_heatmap(
    image: Image.Image,
    heatmap: np.ndarray,
    alpha: float = 0.5,
    colormap: str = "jet",
) -> Image.Image:
    """Overlay a heatmap on an image."""
    import matplotlib.pyplot as plt

    # Ensure heatmap is 2D and normalized
    heatmap = np.clip(heatmap, 0, 1)

    # Get colormap
    cmap = plt.get_cmap(colormap)
    heatmap_colored = cmap(heatmap)[:, :, :3]  # RGB only, no alpha
    heatmap_colored = (heatmap_colored * 255).astype(np.uint8)

    # Resize heatmap to match image
    heatmap_pil = Image.fromarray(heatmap_colored)
    heatmap_pil = heatmap_pil.resize(image.size, Image.BILINEAR)

    # Blend
    image_rgb = image.convert("RGB")
    blended = Image.blend(image_rgb, heatmap_pil, alpha)

    return blended


def analyze_views(
    model: torch.nn.Module,
    views: Dict[str, Optional[torch.Tensor]],
    original_images: Dict[str, Optional[Image.Image]],
    device: torch.device,
) -> ViewAnalysis:
    """
    Perform full analysis on a set of views.
    Returns view weights, branch weights per view, and Grad-CAM heatmaps.
    """
    model.eval()

    # Prepare masks
    masks = {}
    view_tensors = {}
    for k in ("whole", "face", "eyes"):
        if views.get(k) is not None:
            view_tensors[k] = views[k].unsqueeze(0).to(device)
            masks[k] = torch.ones(1, dtype=torch.bool, device=device)
        else:
            view_tensors[k] = None
            masks[k] = torch.zeros(1, dtype=torch.bool, device=device)

    # Get view attention weights from a forward pass
    with torch.no_grad():
        _, _, W = model(view_tensors, masks)

    # W is [1, num_present_views]
    W_np = W[0].cpu().numpy()

    # Map W back to view names (only present views have weights)
    view_order = ["whole", "face", "eyes"]
    present_views = [k for k in view_order if view_tensors[k] is not None]

    view_weights = {}
    for i, k in enumerate(present_views):
        view_weights[k] = float(W_np[i])
    for k in view_order:
        if k not in view_weights:
            view_weights[k] = 0.0

    # Get branch weights and Grad-CAM for each view
    branch_weights = {}
    gradcam_heatmaps = {}

    # Get encoders (shared or separate)
    enc_whole = model.enc_whole
    enc_face = model.enc_face
    enc_eyes = model.enc_eyes

    encoders = {"whole": enc_whole, "face": enc_face, "eyes": enc_eyes}

    for k in view_order:
        if view_tensors[k] is not None:
            enc = encoders[k]
            x = view_tensors[k]

            # Branch weights
            try:
                branch_weights[k] = _get_branch_weights(enc, x)
            except Exception:
                branch_weights[k] = {"Gram": 0.25, "Cov": 0.25, "Spectrum": 0.25, "Stats": 0.25}

            # Grad-CAM
            try:
                heatmap = _compute_gradcam(enc, x.clone(), target_layer_name="b3")
                if original_images.get(k) is not None:
                    gradcam_heatmaps[k] = _overlay_heatmap(original_images[k], heatmap, alpha=0.5)
                else:
                    gradcam_heatmaps[k] = None
            except Exception:
                gradcam_heatmaps[k] = None
        else:
            branch_weights[k] = {}
            gradcam_heatmaps[k] = None

    return ViewAnalysis(
        view_weights=view_weights,
        branch_weights=branch_weights,
        gradcam_heatmaps=gradcam_heatmaps,
        original_images={k: original_images.get(k) for k in view_order},
    )


def format_analysis_text(analysis: ViewAnalysis) -> str:
    """Format analysis results as markdown text."""
    lines = ["## 🔍 View & Branch Analysis\n"]

    # View weights
    lines.append("### View Attention Weights")
    lines.append("How much each view contributed to the final embedding:\n")
    for k in ("whole", "face", "eyes"):
        w = analysis.view_weights.get(k, 0.0)
        bar_len = int(w * 20)
        bar = "█" * bar_len + "░" * (20 - bar_len)
        lines.append(f"- **{k.capitalize()}**: `{bar}` {w:.1%}")

    lines.append("")

    # Branch weights per view
    lines.append("### Branch Attention Weights (per view)")
    lines.append("Which style features were most important:\n")
    branch_names = ["Gram", "Cov", "Spectrum", "Stats"]
    branch_desc = {
        "Gram": "texture patterns",
        "Cov": "color correlations",
        "Spectrum": "frequency content",
        "Stats": "mean/variance",
    }

    for view_name in ("whole", "face", "eyes"):
        bw = analysis.branch_weights.get(view_name, {})
        if bw:
            lines.append(f"\n**{view_name.capitalize()}**:")
            for b in branch_names:
                w = bw.get(b, 0.0)
                bar_len = int(w * 15)
                bar = "█" * bar_len + "░" * (15 - bar_len)
                lines.append(f"  - {b} ({branch_desc[b]}): `{bar}` {w:.1%}")

    return "\n".join(lines)
```
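For orientation, here is a minimal sketch of driving the new module from a script rather than through the UI. Anything below that is not defined in this commit is an assumption: `load_style_model`'s signature and the `model`/`device` fields on its result are inferred from how `webui_gradio.py` uses them, and the checkpoint path and 256×256 resize are placeholders.

```python
# Hedged usage sketch: load_style_model's exact signature, the checkpoint
# path, and the resize are assumptions; analyze_views / format_analysis_text
# come from this commit.
from PIL import Image
from torchvision import transforms

from app.model_io import load_style_model
from app.visualization import analyze_views, format_analysis_text

lm = load_style_model("checkpoints/style_model.pt")  # hypothetical path

to_tensor = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

whole = Image.open("sample.png").convert("RGB")
views = {"whole": to_tensor(whole), "face": None, "eyes": None}  # [C, H, W] tensors
originals = {"whole": whole, "face": None, "eyes": None}

analysis = analyze_views(lm.model, views, originals, lm.device)
print(format_analysis_text(analysis))  # markdown report of view/branch weights

overlay = analysis.gradcam_heatmaps["whole"]
if overlay is not None:
    overlay.save("gradcam_whole.png")  # heatmap blended over the original image
```

Note that `analyze_views` swallows per-view failures, falling back to uniform branch weights and a `None` heatmap, so callers should check for `None` before saving.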
webui_gradio.py
CHANGED

```diff
@@ -166,6 +166,7 @@ _patch_gradio_client_bool_jsonschema()
 from app.model_io import LoadedModel, embed_triview, load_style_model
 from app.proto_db import PrototypeDB, load_prototype_db, topk_predictions_unique_labels
 from app.view_extractor import AnimeFaceEyeExtractor, ExtractorCfg
+from app.visualization import ViewAnalysis, analyze_views, format_analysis_text
 
 
 ROOT = Path(__file__).resolve().parent
@@ -316,6 +317,65 @@ def classify(
     return "✅ OK", rows, (face_pil if "face_pil" in locals() else None), (eyes_pil if "eyes_pil" in locals() else None)
 
 
+def analyze_image(whole_img):
+    """
+    Analyze an image showing view weights, branch weights, and Grad-CAM.
+    Returns: status, analysis_text, whole_gradcam, face_gradcam, eyes_gradcam, face_preview, eyes_preview
+    """
+    if APP_STATE.lm is None:
+        return "❌ Click **Load** first.", "", None, None, None, None, None
+
+    lm = APP_STATE.lm
+    ex = APP_STATE.extractor
+
+    def _to_pil(x):
+        if x is None:
+            return None
+        if isinstance(x, Image.Image):
+            return x
+        return Image.fromarray(x)
+
+    w = _to_pil(whole_img)
+    if w is None:
+        return "❌ Provide a whole image.", "", None, None, None, None, None
+
+    try:
+        # Extract face and eyes
+        face_pil = None
+        eyes_pil = None
+        if ex is not None:
+            rgb = np.array(w.convert("RGB"))
+            face_rgb, eyes_rgb = ex.extract(rgb)
+            if face_rgb is not None:
+                face_pil = Image.fromarray(face_rgb)
+            if eyes_rgb is not None:
+                eyes_pil = Image.fromarray(eyes_rgb)
+
+        # Prepare tensors
+        wt = _pil_to_tensor(w, lm.T_w)
+        ft = _pil_to_tensor(face_pil, lm.T_f) if face_pil is not None else None
+        et = _pil_to_tensor(eyes_pil, lm.T_e) if eyes_pil is not None else None
+
+        views = {"whole": wt, "face": ft, "eyes": et}
+        original_images = {"whole": w, "face": face_pil, "eyes": eyes_pil}
+
+        # Run analysis
+        analysis = analyze_views(lm.model, views, original_images, lm.device)
+        analysis_text = format_analysis_text(analysis)
+
+        return (
+            "✅ Analysis complete",
+            analysis_text,
+            analysis.gradcam_heatmaps.get("whole"),
+            analysis.gradcam_heatmaps.get("face"),
+            analysis.gradcam_heatmaps.get("eyes"),
+            face_pil,
+            eyes_pil,
+        )
+    except Exception as e:
+        return f"❌ Analysis failed: {e}", "", None, None, None, None, None
+
+
 def _gallery_item_to_pil(item) -> Optional[Image.Image]:
     """Convert a Gradio gallery item to PIL Image (handles various formats)."""
     if item is None:
@@ -520,6 +580,37 @@ def build_ui() -> gr.Blocks:
         table = gr.Dataframe(headers=["label", "cosine_sim"], datatype=["str", "number"], interactive=False)
         run_btn.click(classify, inputs=[whole, topk], outputs=[out_status, table, face_prev, eyes_prev])
 
+    with gr.Tab("Analyze (Grad-CAM)"):
+        gr.Markdown(
+            "### 🔍 View & Branch Analysis with Grad-CAM\n"
+            "Visualize which parts of the image and which style features the model focuses on.\n"
+            "- **View weights**: How much each view (whole/face/eyes) contributed\n"
+            "- **Branch weights**: Which style features (Gram/Cov/Spectrum/Stats) were important\n"
+            "- **Grad-CAM**: Spatial attention heatmaps showing where the model looked"
+        )
+        with gr.Row():
+            analyze_input = gr.Image(label="Whole image", type="pil")
+        analyze_btn = gr.Button("Analyze", variant="primary")
+        analyze_status = gr.Markdown("")
+        analyze_text = gr.Markdown("")
+
+        gr.Markdown("### Grad-CAM Heatmaps")
+        with gr.Row():
+            gcam_whole = gr.Image(label="Whole (Grad-CAM)", type="pil")
+            gcam_face = gr.Image(label="Face (Grad-CAM)", type="pil")
+            gcam_eyes = gr.Image(label="Eyes (Grad-CAM)", type="pil")
+
+        gr.Markdown("### Extracted Views")
+        with gr.Row():
+            analyze_face = gr.Image(label="Extracted Face", type="pil")
+            analyze_eyes = gr.Image(label="Extracted Eyes", type="pil")
+
+        analyze_btn.click(
+            analyze_image,
+            inputs=[analyze_input],
+            outputs=[analyze_status, analyze_text, gcam_whole, gcam_face, gcam_eyes, analyze_face, analyze_eyes],
+        )
+
     with gr.Tab("Add prototype (temporary)"):
         gr.Markdown(
             "### ⚠️ Temporary Prototypes Only\n"
```
|