import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import cv2
from torchvision.transforms.functional import resize
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.vit.modeling_vit import ViTModel


class GradCAM:
    def __init__(self, vision_encoder):
        self.model = vision_encoder.model
        self.target_layer = self._find_last_conv_layer()

        self.activations = None
        self.gradients = None

        self.target_layer.register_forward_hook(self._hook_forward)
        self.target_layer.register_full_backward_hook(self._hook_backward)

    def _find_last_conv_layer(self):
        for module in reversed(list(self.model.modules())):
            if isinstance(module, torch.nn.Conv2d):
                return module
        raise RuntimeError("No Conv2d layer found for Grad-CAM.")

    def _hook_forward(self, module, inp, out):
        self.activations = out.detach()

    def _hook_backward(self, module, grad_in, grad_out):
        self.gradients = grad_out[0].detach()

    def generate(self, image_tensor):
        self.model.zero_grad()

        out = self.model(image_tensor)
        # HuggingFace backbones return a ModelOutput rather than a tensor.
        if hasattr(out, "last_hidden_state"):
            out = out.last_hidden_state

        if out.ndim == 4:
            pooled = out.mean(dim=[2, 3])
        elif out.ndim == 3:
            pooled = out.mean(dim=1)
        else:
            pooled = out

        score = pooled.norm()
        score.backward()

        # Grad-CAM weights: global-average-pool the gradients over space.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1).squeeze()

        cam = F.relu(cam)
        cam -= cam.min()
        cam /= cam.max() + 1e-8

        return cam.cpu().numpy()

    def save(self, img_tensor, save_path):
        cam = self.generate(img_tensor)

        img_np = img_tensor[0].permute(1, 2, 0).cpu().numpy()
        img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)

        cam_resized = cv2.resize(cam, (img_np.shape[1], img_np.shape[0]))

        plt.figure(figsize=(6, 6))
        plt.imshow(img_np)
        plt.imshow(cam_resized, cmap="inferno", alpha=0.45)
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.close()

        print(f"[GradCAM] Saved to {save_path}")


def get_vit_self_attention(model, image_tensor):
    vision = model.vision_encoder

    # CNN backbones have no self-attention maps to return.
    if "resnet" in type(vision).__name__.lower():
        return None

    if hasattr(vision, "model"):
        # CLIP-style wrappers nest the ViT under `.vision_model`.
        if hasattr(vision.model, "vision_model"):
            hf_vit = vision.model.vision_model
            outputs = hf_vit(
                pixel_values=image_tensor,
                output_attentions=True,
                return_dict=True,
            )
            return outputs.attentions

        # Bare HuggingFace ViT.
        if isinstance(vision.model, ViTModel):
            outputs = vision.model(
                pixel_values=image_tensor,
                output_attentions=True,
                return_dict=True,
            )
            return outputs.attentions

    raise ValueError("Vision encoder does not expose ViT attentions.")


def attention_rollout(attn_mats, discard_ratio=0.0):
    device = attn_mats[0].device
    num_tokens = attn_mats[0].size(-1)
    result = torch.eye(num_tokens, device=device)
    identity = torch.eye(num_tokens, device=device)

    for attn in attn_mats:
        # Average over heads: (heads, tokens, tokens) -> (tokens, tokens).
        attn = attn.mean(dim=0)

        if discard_ratio > 0:
            flat = attn.view(-1)
            threshold = flat.topk(int(flat.numel() * discard_ratio), largest=False)[0].max()
            attn = torch.where(attn < threshold, torch.zeros_like(attn), attn)

        # Add the identity to account for residual connections, then renormalize.
        attn = attn + identity
        attn = attn / (attn.sum(dim=-1, keepdim=True) + 1e-8)
        result = attn @ result

    return result
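
# `discard_ratio` zeroes the weakest fraction of attention weights before each
# layer's multiplication, which tends to sharpen the final map.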


def rollout_to_image(rollout, image_size):
    tokens = rollout.size(0)
    num_patches = int((tokens - 1) ** 0.5)

    # Row 0 is the CLS token; columns 1: are its attention over the patch grid.
    spatial = rollout[0, 1:].reshape(num_patches, num_patches)
    spatial = (spatial - spatial.min()) / (spatial.max() - spatial.min() + 1e-8)

    spatial = resize(
        spatial.unsqueeze(0).unsqueeze(0),
        (image_size, image_size),
    )

    return spatial.squeeze().detach().cpu().numpy()


def plot_attention_overlay(image, heatmap, alpha=0.45):
    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).cpu().numpy()

    image = (image - image.min()) / (image.max() - image.min() + 1e-8)

    plt.figure(figsize=(6, 6))
    plt.imshow(image)
    plt.imshow(heatmap, cmap="inferno", alpha=alpha)
    plt.axis("off")
    plt.show()
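
# Minimal end-to-end sketch for the rollout pipeline (assumes a 224x224 input
# and a ViT backbone with a leading CLS token):
#
#   attns = get_vit_self_attention(model, img)          # one tensor per layer
#   rollout = attention_rollout([a[0] for a in attns])  # drop the batch dim
#   heat = rollout_to_image(rollout, image_size=224)
#   plot_attention_overlay(img[0], heat)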


def token_gradient_map(model, tokenizer, image_tensor, target_word, device="cuda"):
    model.eval()

    image_tensor = image_tensor.to(device)
    image_tensor.requires_grad_(True)

    vision_out = model.vision_encoder(image_tensor)
    img_embeds = vision_out["image_embeds"]

    if img_embeds.dim() == 2:
        img_embeds = img_embeds.unsqueeze(1)

    projected = model.projector(img_embeds)

    # Feed the projected image embeddings to T5 as precomputed encoder states.
    encoder_outputs = BaseModelOutput(last_hidden_state=projected)

    start = model.t5.config.decoder_start_token_id
    decoder_input_ids = torch.tensor([[start]], device=device)

    outputs = model.t5(
        encoder_outputs=encoder_outputs,
        decoder_input_ids=decoder_input_ids,
        return_dict=True,
    )

    # Gradient of the target word's first-step logit w.r.t. the input pixels.
    logits = outputs.logits[:, -1, :]
    target_id = tokenizer.convert_tokens_to_ids(target_word)
    logit = logits[0, target_id]

    logit.backward()

    grad = image_tensor.grad.abs().mean(dim=1).squeeze().cpu().numpy()
    grad = (grad - grad.min()) / (grad.max() - grad.min() + 1e-8)

    return grad
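
# Sketch (assumes `target_word` maps to a single tokenizer id; multi-token
# words would need their logits handled token by token):
#
#   sal = token_gradient_map(model, tokenizer, img, "dog")
#   plot_attention_overlay(img[0], sal)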


def attngrad(model, tokenizer, image_tensor, target_word, image_size=224, device="cuda"):
    raw_attns = get_vit_self_attention(model, image_tensor.to(device))
    if raw_attns is None:
        raise ValueError("attngrad requires a ViT vision encoder with self-attention.")
    attn_mats = [a[0] for a in raw_attns]

    rollout = attention_rollout(attn_mats)
    roll_map = rollout_to_image(rollout, image_size)

    grad_map = token_gradient_map(model, tokenizer, image_tensor, target_word, device)

    # Combine: where the model looks (rollout) x what drives the word (gradients).
    combined = roll_map * grad_map
    combined = (combined - combined.min()) / (combined.max() - combined.min() + 1e-8)
    return combined
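
# Note: `roll_map` and `grad_map` are multiplied elementwise, so `image_size`
# must match the spatial resolution of `image_tensor` (224 by default).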


def token_gradient_map_smooth(model, tokenizer, image_tensor, target_word, sigma=5, device="cuda"):
    """Same pixel-gradient saliency as `token_gradient_map`, Gaussian-smoothed."""
    model.eval()

    image_tensor = image_tensor.to(device)
    image_tensor.requires_grad_(True)

    vision_out = model.vision_encoder(image_tensor)
    img_embeds = vision_out["image_embeds"]
    if img_embeds.dim() == 2:
        img_embeds = img_embeds.unsqueeze(1)

    projected = model.projector(img_embeds)

    encoder_outputs = BaseModelOutput(last_hidden_state=projected)

    start_token = model.t5.config.decoder_start_token_id
    decoder_input_ids = torch.tensor(
        [[start_token]], device=device, dtype=torch.long
    )
    # The encoder attention mask must cover every projected image token.
    attention_mask = torch.ones(projected.shape[:2], dtype=torch.long, device=device)

    outputs = model.t5(
        encoder_outputs=encoder_outputs,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    )

    vocab_logits = outputs.logits[:, -1, :]
    target_id = tokenizer.convert_tokens_to_ids(target_word)
    logit = vocab_logits[0, target_id]

    logit.backward()

    grad = image_tensor.grad.abs().mean(dim=1).squeeze().cpu().numpy()
    grad = (grad - grad.min()) / (grad.max() - grad.min() + 1e-8)

    grad_smooth = smooth_heatmap(grad, sigma=sigma)
    return grad_smooth


def integrated_gradients(
    model,
    tokenizer,
    image_tensor,
    target_word,
    steps=30,
    device="cuda",
):
    model.eval()
    device = torch.device(device)

    image_tensor = image_tensor.to(device)

    # All-zeros baseline, per the standard integrated-gradients formulation.
    baseline = torch.zeros_like(image_tensor)

    target_id = tokenizer.convert_tokens_to_ids(target_word)

    total_grad = torch.zeros_like(image_tensor)

    for i in range(1, steps + 1):
        alpha = i / steps

        # Detach so each interpolated image is a fresh leaf we can differentiate.
        img = (baseline + alpha * (image_tensor - baseline)).detach()
        img.requires_grad_(True)

        vision_out = model.vision_encoder(img)
        img_embeds = vision_out["image_embeds"]
        if img_embeds.dim() == 2:
            img_embeds = img_embeds.unsqueeze(1)

        projected = model.projector(img_embeds)
        encoder_outputs = BaseModelOutput(last_hidden_state=projected)

        start_token = model.t5.config.decoder_start_token_id
        decoder_input_ids = torch.tensor([[start_token]], device=device)
        attention_mask = torch.ones(projected.shape[:2], dtype=torch.long, device=device)

        outputs = model.t5(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        vocab_logits = outputs.logits[:, -1, :]
        logit = vocab_logits[0, target_id]

        grads = torch.autograd.grad(
            outputs=logit,
            inputs=img,
            create_graph=False,
            allow_unused=True,
        )[0]

        if grads is None:
            raise RuntimeError("Integrated gradients: grad is None — gradient path was broken.")

        total_grad += grads

    avg_grad = total_grad / steps
    # Integrated gradients scales the averaged gradient by the input delta.
    avg_grad = (image_tensor - baseline) * avg_grad

    heat = avg_grad.abs().mean(dim=1).squeeze().cpu().numpy()
    heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)

    return heat
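
# Sketch (more steps approximate the path integral more closely, at linear cost):
#
#   heat = integrated_gradients(model, tokenizer, img, "dog", steps=50)
#   plot_attention_overlay(img[0], heat)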


def smooth_heatmap(hm, k=21, sigma=6):
    # Kernel size `k` must be odd for cv2.GaussianBlur.
    hm = cv2.GaussianBlur(hm, (k, k), sigma)
    hm = (hm - hm.min()) / (hm.max() - hm.min() + 1e-8)
    return hm


def get_cross_attention(model, encoder_outputs, decoder_input_ids, device="cuda"):
    model.eval()
    with torch.no_grad():
        outputs = model.t5(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids.to(device),
            output_attentions=True,
            return_dict=True,
        )

    # One tensor per decoder layer; drop the batch dimension.
    cross = outputs.cross_attentions
    attn_layers = [c[0] for c in cross]
    return attn_layers
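
# Each returned tensor has shape (heads, tgt_len, src_len): decoder tokens
# attending over the projected image tokens.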

def cross_attention_to_image(attn):
    attn = torch.as_tensor(attn)

    if attn.numel() == 0:
        return np.zeros((14, 14), dtype=np.float32)

    # (heads, tgt_len, src_len): average over heads first.
    if attn.dim() == 3:
        attn = attn.mean(dim=0)

    if attn.dim() == 2:
        # Attention of the most recent decoder token over the image tokens.
        attn_vec = attn[-1]
    elif attn.dim() == 1:
        attn_vec = attn
    else:
        raise ValueError(f"Unexpected attn shape: {attn.shape}")

    # Drop the CLS token for the standard ViT grid (196 patches + 1 CLS).
    if attn_vec.size(0) == 197:
        attn_vec = attn_vec[1:]

    src_len = attn_vec.size(0)
    side = int(src_len ** 0.5)

    # If the token count is not a perfect square, pad or truncate to one.
    if side * side != src_len:
        new_len = side * side
        padded = torch.zeros(new_len, device=attn_vec.device)
        padded[: min(new_len, src_len)] = attn_vec[: min(new_len, src_len)]
        attn_vec = padded

    attn_vec = attn_vec / (attn_vec.max() + 1e-8)

    heatmap = attn_vec.reshape(side, side).cpu().numpy()

    return heatmap
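
# Example: a (heads, tgt_len, 197) attention from `get_cross_attention` yields
# a 14x14 patch grid after CLS removal.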


def plot_cross_attention_overlay(image_tensor, heatmap, save_path=None, alpha=0.45):
    img = image_tensor[0].permute(1, 2, 0).cpu().numpy()
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)

    plt.figure(figsize=(6, 6))
    plt.imshow(img)
    # Stretch the low-resolution heatmap over the full image extent.
    plt.imshow(
        heatmap,
        cmap="inferno",
        alpha=alpha,
        extent=(0, img.shape[1], img.shape[0], 0),
    )
    plt.axis("off")

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.close()
        print(f"[CrossAttention] Saved to {save_path}")
    else:
        plt.show()


def visualize_cross_attention(model, tokenizer, image_tensor, word, device="cuda"):
    device = torch.device(device)
    image_tensor = image_tensor.to(device)

    with torch.no_grad():
        vision_out = model.vision_encoder(image_tensor)
        img_embeds = vision_out["image_embeds"]
        if img_embeds.dim() == 2:
            img_embeds = img_embeds.unsqueeze(1)

        projected = model.projector(img_embeds)
        encoder_outputs = BaseModelOutput(last_hidden_state=projected)

        # Greedy decoding, keeping the cross-attention of each step.
        generated = [model.t5.config.decoder_start_token_id]
        target_id = tokenizer.convert_tokens_to_ids(word)

        for _ in range(30):
            decoder_input_ids = torch.tensor([generated], device=device)
            attn_layers = get_cross_attention(
                model, encoder_outputs, decoder_input_ids
            )

            logits = model.t5(
                encoder_outputs=encoder_outputs,
                decoder_input_ids=decoder_input_ids,
                return_dict=True,
            ).logits[:, -1, :]
            next_id = int(logits.argmax())
            generated.append(next_id)

            # Stop once the target word (or end-of-sequence) is produced.
            if next_id == target_id or next_id == tokenizer.eos_token_id:
                break

    last_attn = attn_layers[-1]
    heat = cross_attention_to_image(last_attn)

    plot_cross_attention_overlay(image_tensor, heat)
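
# End-to-end sketch (assumes the wrapper exposes `vision_encoder`, `projector`,
# and `t5`, and that `word` maps to a single tokenizer id):
#
#   visualize_cross_attention(model, tokenizer, img, "dog")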