""" |
|
|
heatmap.py |
|
|
---------- |
|
|
|
|
|
Grad-ECLIP visual explanations for CLIP/PaintingCLIP models. |
|
|
Generates heatmap overlays showing which image regions contribute to image-text similarity. |
|
|
|
|
|
Based on "Gradient-based Visual Explanation for Transformer-based CLIP" |
|
|
by Zhao et al. (ICML 2024) |
|
|
|
|
|
Public entry point: |
|
|
------------------ |
|
|
generate_heatmap( |
|
|
image, # str | PIL.Image.Image |
|
|
sentence, # caption text |
|
|
model, # CLIPModel or PEFT-wrapped model |
|
|
processor, # CLIPProcessor |
|
|
device, # torch.device |
|
|
*, |
|
|
layer_idx: int = -1, # which visual transformer block to explain |
|
|
alpha: float = 0.45, # overlay opacity |
|
|
colormap: int = cv2.COLORMAP_JET, |
|
|
resize: Optional[Tuple[int, int]] = None, |
|
|
) -> PIL.Image.Image # RGB overlay for display |
|
|
""" |
from __future__ import annotations

from typing import Any, Dict, Optional, Tuple, Union

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


class _GradECLIPHooks:
    """
    Context manager that registers forward/backward hooks on one vision
    transformer layer and captures the tensors Grad-ECLIP needs: Q, K, V,
    attention weights, and the gradient of the layer output.

    See the usage sketch after this class definition.
    """

    def __init__(self, model: CLIPModel, layer_idx: int):
        self.model = model
        self.layer_idx = layer_idx
        self.captures: Dict[str, Any] = {}
        self.handles = []

    def __enter__(self):
        # Resolve negative indices (e.g. -1 for the last block).
        vision_layers = self.model.vision_model.encoder.layers
        if self.layer_idx < 0:
            self.layer_idx = len(vision_layers) + self.layer_idx
        self.target_layer = vision_layers[self.layer_idx]

        self._register_forward_hook()
        self._register_backward_hook()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for handle in self.handles:
            handle.remove()
        self.handles.clear()

    def _register_forward_hook(self):
        """Register forward hook to capture Q, K, V and attention weights."""

        def forward_hook(module, input, output):
            if len(input) == 0:
                return
            hidden_states = input[0]

            # CLIPEncoderLayer applies layer_norm1 before self-attention, so
            # normalize here to match what the attention block actually sees.
            x = hidden_states
            if hasattr(module, "layer_norm1"):
                x = module.layer_norm1(x)

            if not hasattr(module.self_attn, "q_proj"):
                return
            batch_size, seq_len, hidden_dim = x.shape

            # Recompute the attention projections from the normalized input.
            Q = module.self_attn.q_proj(x)
            K = module.self_attn.k_proj(x)
            V = module.self_attn.v_proj(x)

            self.captures["V"] = V
            self.captures["hidden_states_pre"] = hidden_states

            num_heads = module.self_attn.num_heads
            head_dim = hidden_dim // num_heads

            # Reshape to (batch, heads, seq, head_dim) for per-head attention.
            Q_heads = Q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
            K_heads = K.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

            # Scaled dot-product attention, softmax over key positions.
            scale = head_dim**-0.5
            attn_weights = torch.matmul(Q_heads, K_heads.transpose(-2, -1)) * scale
            attn_weights = torch.softmax(attn_weights, dim=-1)

            self.captures["Q"] = Q_heads
            self.captures["K"] = K_heads
            # Average over heads: (batch, seq, seq).
            self.captures["attn_weights"] = attn_weights.mean(dim=1)

        handle = self.target_layer.register_forward_hook(forward_hook)
        self.handles.append(handle)

    def _register_backward_hook(self):
        """Register backward hook to capture the gradient of the layer output."""

        def backward_hook(module, grad_input, grad_output):
            if len(grad_output) > 0:
                self.captures["grad_attn"] = grad_output[0]

        handle = self.target_layer.register_full_backward_hook(backward_hook)
        self.handles.append(handle)

    def get_captures(self) -> Dict[str, torch.Tensor]:
        """Return captured tensors."""
        return self.captures
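

# Typical capture flow (illustrative; mirrors how generate_heatmap drives it):
#
#     with _GradECLIPHooks(model, layer_idx=-1) as hooks:
#         outputs = model(**inputs)      # forward hook fires here
#         similarity.backward()          # backward hook fires here
#     captures = hooks.get_captures()    # Q, K, V, attn_weights, grad_attn, ...
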
def _compute_gradeclip_importance(
    captures: Dict[str, torch.Tensor],
    use_k_similarity: bool = True,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Compute Grad-ECLIP importance scores from captured tensors.

    Args:
        captures: Dictionary of tensors captured by _GradECLIPHooks
        use_k_similarity: Whether to weight patches by Q-K similarity
        device: Computation device (defaults to the device of V)

    Returns:
        Importance scores for each image patch (excluding the CLS token)
    """
    V = captures.get("V")
    grad_attn = captures.get("grad_attn")
    attn_weights = captures.get("attn_weights")

    if V is None or grad_attn is None:
        raise ValueError("Missing required captures for Grad-ECLIP computation")

    # Gradient of the similarity w.r.t. the CLS token of the layer output.
    grad_cls = grad_attn[0, 0, :]

    # Value vectors for the image patches (index 0 is the CLS token).
    V_patches = V[0, 1:, :]
    num_patches = V_patches.shape[0]

    if attn_weights is not None:
        # Attention paid by the CLS query to each patch key.
        cls_attn = attn_weights[0, 0, 1 : num_patches + 1]
    else:
        # Fall back to uniform attention if weights were not captured.
        cls_attn = torch.ones(num_patches, device=device or V.device) / num_patches

    if use_k_similarity and "Q" in captures and "K" in captures:
        Q = captures["Q"]
        K = captures["K"]

        # Head-averaged CLS query and patch keys:
        # (batch, 1, head_dim) and (batch, num_patches, head_dim).
        q_cls = Q[:, :, 0:1, :].mean(dim=1)
        k_patches = K[:, :, 1:, :].mean(dim=1)

        q_cls = F.normalize(q_cls, dim=-1)
        k_patches = F.normalize(k_patches, dim=-1)

        # Cosine similarity between the CLS query and each patch key.
        k_similarity = torch.matmul(q_cls, k_patches.transpose(-2, -1)).squeeze()

        # Min-max normalize to [0, 1].
        k_similarity = (k_similarity - k_similarity.min()) / (
            k_similarity.max() - k_similarity.min() + 1e-8
        )

        cls_attn = cls_attn * k_similarity[:num_patches]

    # Channel-wise gradient * value, summed over channels, then spatially
    # weighted by the (loosened) attention and rectified.
    importance = (grad_cls * V_patches).sum(dim=-1)
    importance = importance * cls_attn
    importance = torch.relu(importance)

    return importance


def generate_heatmap(
    image: Union[str, Image.Image],
    sentence: str,
    model: CLIPModel,
    processor: CLIPProcessor,
    device: torch.device,
    *,
    layer_idx: int = -1,
    alpha: float = 0.45,
    colormap: int = cv2.COLORMAP_JET,
    resize: Optional[Tuple[int, int]] = None,
) -> Image.Image:
    """
    Generate a Grad-ECLIP heatmap overlay for an image-text pair. A runnable
    smoke-test sketch appears at the bottom of this module.

    Parameters
    ----------
    image : str or PIL.Image.Image
        Input image path or PIL Image object
    sentence : str
        Text description to explain
    model : CLIPModel
        Pre-loaded CLIP model (possibly with a LoRA adapter)
    processor : CLIPProcessor
        CLIP processor for preprocessing
    device : torch.device
        Computation device
    layer_idx : int, optional
        Vision transformer layer to analyze (default: -1, the last layer)
    alpha : float, optional
        Heatmap overlay opacity (default: 0.45)
    colormap : int, optional
        OpenCV colormap for visualization (default: cv2.COLORMAP_JET)
    resize : tuple of int, optional
        Target (width, height) for the output image

    Returns
    -------
    PIL.Image.Image
        RGB image with heatmap overlay
    """
    # Load and normalize the input image.
    if isinstance(image, str):
        pil_image = Image.open(image).convert("RGB")
    else:
        pil_image = image.convert("RGB")

    orig_size = pil_image.size

    if resize:
        display_image = pil_image.resize(resize, Image.Resampling.BICUBIC)
    else:
        display_image = pil_image

    # Preprocess the image-text pair for CLIP.
    inputs = processor(
        images=pil_image, text=sentence, return_tensors="pt", padding=True
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Temporarily enable gradients everywhere (frozen or LoRA-adapted models
    # often have requires_grad=False), remembering the original flags.
    model_requires_grad = [p.requires_grad for p in model.parameters()]
    for param in model.parameters():
        param.requires_grad = True

    try:
        with torch.set_grad_enabled(True):
            with _GradECLIPHooks(model, layer_idx) as hooks:
                outputs = model(**inputs, output_attentions=False)

                image_embeds = F.normalize(outputs.image_embeds, dim=-1)
                text_embeds = F.normalize(outputs.text_embeds, dim=-1)

                # Cosine similarity between image and text embeddings.
                similarity = (image_embeds @ text_embeds.T).squeeze()

                # Backpropagate the similarity to populate the backward hook.
                model.zero_grad()
                similarity.backward(retain_graph=False)

            captures = hooks.get_captures()

            importance = _compute_gradeclip_importance(
                captures, use_k_similarity=True, device=device
            )

        # Reshape patch importances into the ViT patch grid
        # (assumes a square grid, as in standard CLIP vision towers).
        num_patches = importance.shape[0]
        grid_size = int(np.sqrt(num_patches))
        importance_map = importance.reshape(grid_size, grid_size)

        # Min-max normalize to [0, 1].
        saliency_map = importance_map.detach().cpu().numpy()
        saliency_map = saliency_map - saliency_map.min()
        saliency_map = saliency_map / (saliency_map.max() + 1e-8)

        # Upsample the patch grid to the display resolution.
        saliency_resized = cv2.resize(
            saliency_map,
            display_image.size,
            interpolation=cv2.INTER_CUBIC,
        )

        # Colorize (OpenCV colormaps are BGR, so convert back to RGB).
        heatmap_uint8 = (saliency_resized * 255).astype(np.uint8)
        heatmap_bgr = cv2.applyColorMap(heatmap_uint8, colormap)
        heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)

        # Alpha-blend the heatmap over the display image.
        img_array = np.array(display_image).astype(np.float32)
        overlay = (1 - alpha) * img_array + alpha * heatmap_rgb
        overlay = np.clip(overlay, 0, 255).astype(np.uint8)

        return Image.fromarray(overlay, mode="RGB")

    finally:
        # Restore the original requires_grad flags.
        for param, requires_grad in zip(model.parameters(), model_requires_grad):
            param.requires_grad = requires_grad


__all__ = ["generate_heatmap"] |
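

# Minimal smoke test (an illustrative sketch, not part of the public API).
# Assumes network access for the checkpoint download and a local image path:
#
#     python heatmap.py path/to/image.jpg "a caption to explain"
if __name__ == "__main__":
    import sys

    _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(_device)
    _processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    _overlay = generate_heatmap(sys.argv[1], sys.argv[2], _model, _processor, _device)
    _overlay.save("heatmap_overlay.png")
    print("Saved heatmap_overlay.png")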
|
|
|