# src/interpretability.py
# Interpretability utilities for the coco-demo captioning model: Grad-CAM,
# ViT attention rollout, token-gradient maps, integrated gradients, and
# T5 cross-attention visualization.
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import resize
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.vit.modeling_vit import ViTModel
class GradCAM:
    """Grad-CAM over the last Conv2d layer of a CNN vision encoder."""

    def __init__(self, vision_encoder):
        self.model = vision_encoder.model
        self.target_layer = self._find_last_conv_layer()
        self.activations = None
        self.gradients = None
        # Capture activations on the forward pass and gradients on the backward pass.
        self.target_layer.register_forward_hook(self._hook_forward)
        self.target_layer.register_full_backward_hook(self._hook_backward)
    def _find_last_conv_layer(self):
        # Walk the modules in reverse registration order; the first Conv2d
        # encountered is the last one in the network.
        for module in reversed(list(self.model.modules())):
            if isinstance(module, torch.nn.Conv2d):
                return module
        raise RuntimeError("No Conv2d layer found for Grad-CAM.")

    def _hook_forward(self, module, inp, out):
        self.activations = out.detach()

    def _hook_backward(self, module, grad_in, grad_out):
        self.gradients = grad_out[0].detach()
    def generate(self, image_tensor):
        """Return a normalized (h, w) Grad-CAM heatmap for a single image."""
        self.model.zero_grad()
        out = self.model(image_tensor)  # (B, C, H, W) for CNN backbones
        # Reduce the output to a scalar score we can backpropagate from.
        if out.ndim == 4:
            pooled = out.mean(dim=[2, 3])  # (B, C)
        elif out.ndim == 3:
            pooled = out.mean(dim=1)
        else:
            pooled = out
        score = pooled.norm()
        score.backward()
        # Channel weights: global-average-pooled gradients, shape (B, C, 1, 1).
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1).squeeze()  # assumes batch size 1
        cam = F.relu(cam)
        cam -= cam.min()
        cam /= cam.max() + 1e-8
        return cam.cpu().numpy()
    def save(self, img_tensor, save_path):
        """Render the Grad-CAM heatmap over the input image and save to disk."""
        cam = self.generate(img_tensor)
        img_np = img_tensor[0].permute(1, 2, 0).cpu().numpy()
        img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
        # Upsample the low-resolution CAM to the input image size.
        cam_resized = cv2.resize(cam, (img_np.shape[1], img_np.shape[0]))
        plt.figure(figsize=(6, 6))
        plt.imshow(img_np)
        plt.imshow(cam_resized, cmap="inferno", alpha=0.45)
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.close()
        print(f"[GradCAM] Saved to {save_path}")
def get_vit_self_attention(model, image_tensor):
    """Return per-layer self-attention maps from a CLIP or ViT vision encoder."""
    vision = model.vision_encoder
    # CNN backbones have no self-attention to expose.
    if "Resnet" in type(vision).__name__:
        return None
    if hasattr(vision, "model"):
        # CLIP wraps its vision tower under `.vision_model`.
        if hasattr(vision.model, "vision_model"):
            hf_vit = vision.model.vision_model
            outputs = hf_vit(
                pixel_values=image_tensor,
                output_attentions=True,
                return_dict=True,
            )
            return outputs.attentions
        # Plain Hugging Face ViT.
        if isinstance(vision.model, ViTModel):
            outputs = vision.model(
                pixel_values=image_tensor,
                output_attentions=True,
                return_dict=True,
            )
            return outputs.attentions
    raise ValueError("Vision encoder does not expose ViT attentions.")
# ATTENTION ROLLOUT (across layers)
def attention_rollout(attn_mats, discard_ratio=0.0):
    """Multiply per-layer attention maps to trace attention flow across layers.

    `attn_mats` is a list of (heads, tokens, tokens) tensors, one per layer.
    """
    device = attn_mats[0].device
    num_tokens = attn_mats[0].size(-1)
    eye = torch.eye(num_tokens, device=device)
    result = eye
    for attn in attn_mats:
        attn = attn.mean(dim=0)  # average over heads
        if discard_ratio > 0:
            # Zero out the weakest attention weights before renormalizing.
            flat = attn.view(-1)
            threshold = flat.topk(int(flat.numel() * discard_ratio), largest=False)[0].max()
            attn = torch.where(attn < threshold, torch.zeros_like(attn), attn)
        # Add the identity to account for residual connections (Abnar & Zuidema, 2020).
        attn = attn + eye
        attn = attn / attn.sum(dim=-1, keepdim=True)
        result = attn @ result
    return result
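# Minimal sketch of feeding ViT attentions into the rollout (attention shapes
# follow Hugging Face conventions; `model` and `image_tensor` are assumed):
#
#   attns = get_vit_self_attention(model, image_tensor)        # tuple of (B, heads, T, T)
#   attn_mats = [a[0] for a in attns]                          # batch 0 -> (heads, T, T)
#   rollout = attention_rollout(attn_mats, discard_ratio=0.9)  # (T, T)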
def rollout_to_image(rollout, image_size):
    """Upsample the CLS-to-patch rollout row into an (image_size, image_size) map."""
    tokens = rollout.size(0)
    num_patches = int((tokens - 1) ** 0.5)  # drop the CLS token
    spatial = rollout[0, 1:].reshape(num_patches, num_patches)
    spatial = (spatial - spatial.min()) / (spatial.max() - spatial.min() + 1e-8)
    spatial = resize(
        spatial.unsqueeze(0).unsqueeze(0),
        (image_size, image_size),
    )
    return spatial.squeeze().detach().cpu().numpy()
def plot_attention_overlay(image, heatmap, alpha=0.45):
    """Show a heatmap blended over the (optionally tensor-valued) image."""
    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).cpu().numpy()
    image = (image - image.min()) / (image.max() - image.min() + 1e-8)
    plt.figure(figsize=(6, 6))
    plt.imshow(image)
    plt.imshow(heatmap, cmap="inferno", alpha=alpha)
    plt.axis("off")
    plt.show()
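# End-to-end rollout overlay, as a hedged example (a 224 px input is assumed,
# matching the default `image_size` used by `attngrad` below):
#
#   rollout = attention_rollout([a[0] for a in get_vit_self_attention(model, image_tensor)])
#   heat = rollout_to_image(rollout, image_size=224)
#   plot_attention_overlay(image_tensor[0], heat)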
# GRADIENT MAP
def token_gradient_map(model, tokenizer, image_tensor, target_word, device="cuda"):
    """Saliency of the target word's first-step logit w.r.t. the input pixels."""
    model.eval()
    image_tensor = image_tensor.to(device)
    image_tensor.requires_grad_(True)
    # Encode the image and project it into the T5 encoder space.
    vision_out = model.vision_encoder(image_tensor)
    img_embeds = vision_out["image_embeds"]
    if img_embeds.dim() == 2:
        img_embeds = img_embeds.unsqueeze(1)
    projected = model.projector(img_embeds)
    encoder_outputs = BaseModelOutput(last_hidden_state=projected)
    # Score the target word at the first decoding step.
    start = model.t5.config.decoder_start_token_id
    decoder_input_ids = torch.tensor([[start]], device=device)
    outputs = model.t5(
        encoder_outputs=encoder_outputs,
        decoder_input_ids=decoder_input_ids,
        return_dict=True,
    )
    logits = outputs.logits[:, -1, :]
    target_id = tokenizer.convert_tokens_to_ids(target_word)
    logit = logits[0, target_id]
    logit.backward()
    # Average absolute gradients over channels, then min-max normalize.
    grad = image_tensor.grad.abs().mean(dim=1).squeeze().cpu().numpy()
    grad = (grad - grad.min()) / (grad.max() - grad.min() + 1e-8)
    return grad
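# Example call (hedged: `target_word` must map to a single tokenizer token,
# e.g. "dog" here is illustrative; multi-token words would need their
# sub-tokens scored separately):
#
#   saliency = token_gradient_map(model, tokenizer, image_tensor, "dog")
#   plot_attention_overlay(image_tensor[0], smooth_heatmap(saliency))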
# ATTENTION x GRAD
def attngrad(model, tokenizer, image_tensor, target_word, image_size=224, device="cuda"):
    """Element-wise product of attention rollout and the token gradient map."""
    raw_attns = get_vit_self_attention(model, image_tensor.to(device))
    attn_mats = [a[0] for a in raw_attns]  # batch 0 -> (heads, tokens, tokens)
    rollout = attention_rollout(attn_mats)
    roll_map = rollout_to_image(rollout, image_size)
    grad_map = token_gradient_map(model, tokenizer, image_tensor, target_word, device)
    combined = roll_map * grad_map
    combined = (combined - combined.min()) / (combined.max() - combined.min() + 1e-8)
    return combined
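# Example (sketch): the product sharpens the rollout map by keeping only the
# regions that also carry gradient signal for the target word.
#
#   combined = attngrad(model, tokenizer, image_tensor, "dog")
#   plot_attention_overlay(image_tensor[0], combined)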
def token_gradient_map_smooth(model, tokenizer, image_tensor, target_word, sigma=5, device="cuda"):
    """Like `token_gradient_map`, but Gaussian-smoothed for readability."""
    model.eval()
    image_tensor = image_tensor.to(device)
    image_tensor.requires_grad_(True)
    # Vision encoder -> projector -> T5 encoder states.
    vision_out = model.vision_encoder(image_tensor)
    img_embeds = vision_out["image_embeds"]
    if img_embeds.dim() == 2:
        img_embeds = img_embeds.unsqueeze(1)
    projected = model.projector(img_embeds)
    encoder_outputs = BaseModelOutput(last_hidden_state=projected)
    start_token = model.t5.config.decoder_start_token_id
    decoder_input_ids = torch.tensor(
        [[start_token]], device=device, dtype=torch.long
    )
    # Encoder attention mask; must match the projected sequence length
    # (1 here, for a pooled single-token embedding).
    attention_mask = torch.tensor([[1]], device=device)
    outputs = model.t5(
        encoder_outputs=encoder_outputs,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    )
    vocab_logits = outputs.logits[:, -1, :]
    target_id = tokenizer.convert_tokens_to_ids(target_word)
    logit = vocab_logits[0, target_id]
    logit.backward()
    grad = image_tensor.grad.detach().abs().mean(dim=1).squeeze().cpu().numpy()
    grad = (grad - grad.min()) / (grad.max() - grad.min() + 1e-8)
    grad_smooth = smooth_heatmap(grad, sigma=sigma)
    return grad_smooth
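# Example (sketch, same single-token caveat as `token_gradient_map`; larger
# `sigma` trades spatial precision for readability):
#
#   smooth_sal = token_gradient_map_smooth(model, tokenizer, image_tensor, "dog", sigma=5)
#   plot_attention_overlay(image_tensor[0], smooth_sal)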
def integrated_gradients(
    model,
    tokenizer,
    image_tensor,
    target_word,
    steps=30,
    device="cuda",
):
    """Integrated gradients of the target word's first-step logit w.r.t. pixels."""
    model.eval()
    device = torch.device(device)
    image_tensor = image_tensor.to(device)
    baseline = torch.zeros_like(image_tensor)  # black-image baseline
    target_id = tokenizer.convert_tokens_to_ids(target_word)
    start_token = model.t5.config.decoder_start_token_id
    decoder_input_ids = torch.tensor([[start_token]], device=device)
    attention_mask = torch.tensor([[1]], device=device)
    total_grad = torch.zeros_like(image_tensor)
    for i in range(1, steps + 1):
        # Interpolate between the baseline and the input.
        alpha = i / steps
        img = baseline + alpha * (image_tensor - baseline)
        img.requires_grad_(True)
        vision_out = model.vision_encoder(img)
        img_embeds = vision_out["image_embeds"]
        if img_embeds.dim() == 2:
            img_embeds = img_embeds.unsqueeze(1)
        projected = model.projector(img_embeds)
        encoder_outputs = BaseModelOutput(last_hidden_state=projected)
        outputs = model.t5(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        vocab_logits = outputs.logits[:, -1, :]
        logit = vocab_logits[0, target_id]
        # Each step builds its own graph, so there is no need to retain it.
        grads = torch.autograd.grad(
            outputs=logit,
            inputs=img,
            allow_unused=True,
        )[0]
        if grads is None:
            raise RuntimeError("Integrated gradients: grad is None (gradient path was broken).")
        total_grad += grads
    # Integrated gradients: (x - baseline) * average gradient along the path.
    avg_grad = total_grad / steps
    attributions = (image_tensor - baseline) * avg_grad
    heat = attributions.abs().mean(dim=1).squeeze().detach().cpu().numpy()
    heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
    return heat
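# Example (sketch): more `steps` gives a better path-integral approximation at
# linear cost; 30 to 50 steps is a common budget.
#
#   ig_heat = integrated_gradients(model, tokenizer, image_tensor, "dog", steps=30)
#   plot_attention_overlay(image_tensor[0], smooth_heatmap(ig_heat))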
def smooth_heatmap(hm, k=21, sigma=6):
    """Gaussian-blur a heatmap (kernel size k must be odd) and renormalize."""
    hm = cv2.GaussianBlur(hm, (k, k), sigma)
    hm = (hm - hm.min()) / (hm.max() - hm.min() + 1e-8)
    return hm
def get_cross_attention(model, encoder_outputs, decoder_input_ids, device="cuda"):
    """Return per-layer decoder cross-attention maps for the given prefix."""
    model.eval()
    with torch.no_grad():
        outputs = model.t5(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids.to(device),
            output_attentions=True,
            return_dict=True,
        )
    # outputs.cross_attentions is a tuple of layers, each (batch, heads, tgt_len, src_len).
    cross = outputs.cross_attentions
    attn_layers = [c[0] for c in cross]  # use batch 0 -> (heads, tgt_len, src_len)
    return attn_layers
"""
def cross_attention_to_image(attn, image_size=224):
attn = attn.mean(dim=0) # (tgt_len, src_len)
attn = attn[-1] # (src_len,)
attn = attn[1:]
num_patches = int(attn.numel() ** 0.5)
heat = attn.reshape(num_patches, num_patches)
heat = heat - heat.min()
heat = heat / (heat.max() + 1e-8)
heat = resize(
heat.unsqueeze(0).unsqueeze(0),
(image_size, image_size)
).squeeze()
return heat.detach().cpu().numpy()
"""
def cross_attention_to_image(attn):
    """Reshape a cross-attention vector over image tokens into a square heatmap."""
    attn = torch.as_tensor(attn)
    if attn.numel() == 0:
        return np.zeros((14, 14), dtype=np.float32)
    if attn.dim() == 3:
        attn = attn.mean(dim=0)  # (heads, tgt_len, src_len) -> average over heads
    if attn.dim() == 2:
        attn_vec = attn[-1]  # attention of the last generated token
    elif attn.dim() == 1:
        attn_vec = attn
    else:
        raise ValueError(f"Unexpected attn shape: {attn.shape}")
    # Drop the CLS token (index 0): CLIP ViT-L/14 yields 197 tokens but 196 spatial patches.
    if attn_vec.size(0) == 197:
        attn_vec = attn_vec[1:]  # now length 196
    src_len = attn_vec.size(0)
    side = int(src_len ** 0.5)
    if side * side != src_len:
        # Truncate (or zero-pad) to the nearest square so the map can be reshaped.
        new_len = side * side
        padded = torch.zeros(new_len, device=attn_vec.device)
        padded[: min(new_len, src_len)] = attn_vec[: min(new_len, src_len)]
        attn_vec = padded
    attn_vec = attn_vec / (attn_vec.max() + 1e-8)
    heatmap = attn_vec.reshape(side, side).cpu().numpy()
    return heatmap
def plot_cross_attention_overlay(image_tensor, heatmap, save_path=None, alpha=0.45):
    """Overlay a cross-attention heatmap on the image; save if a path is given."""
    img = image_tensor[0].permute(1, 2, 0).cpu().numpy()
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
    plt.figure(figsize=(6, 6))
    plt.imshow(img)
    plt.imshow(heatmap, cmap="inferno", alpha=alpha)
    plt.axis("off")
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.close()
        print(f"[CrossAttention] Saved to {save_path}")
    else:
        plt.show()
def visualize_cross_attention(model, tokenizer, image_tensor, word, device="cuda"):
    """Greedily decode until `word` is produced, then overlay its cross-attention."""
    device = torch.device(device)
    image_tensor = image_tensor.to(device)
    vision_out = model.vision_encoder(image_tensor)
    img_embeds = vision_out["image_embeds"]
    if img_embeds.dim() == 2:
        img_embeds = img_embeds.unsqueeze(1)
    projected = model.projector(img_embeds)
    encoder_outputs = BaseModelOutput(last_hidden_state=projected)
    generated = [model.t5.config.decoder_start_token_id]
    target_id = tokenizer.convert_tokens_to_ids(word)
    attn_layers = None
    for _ in range(30):
        decoder_input_ids = torch.tensor([generated], device=device)
        attn_layers = get_cross_attention(
            model, encoder_outputs, decoder_input_ids
        )
        logits = model.t5(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            return_dict=True,
        ).logits[:, -1, :]
        next_id = int(logits.argmax())
        generated.append(next_id)
        # Stop at the target word (or at end of sequence).
        if next_id == target_id or next_id == model.t5.config.eos_token_id:
            break
    last_attn = attn_layers[-1]  # last decoder layer: (heads, tgt_len, src_len)
    heat = cross_attention_to_image(last_attn)
    plot_cross_attention_overlay(image_tensor, heat)
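# Putting it together (hedged: `load_model` and `preprocess` are hypothetical
# stand-ins for this repo's actual model-loading and preprocessing helpers):
#
#   model, tokenizer = load_model(device="cuda")
#   image_tensor = preprocess("example.jpg")          # (1, 3, 224, 224)
#   visualize_cross_attention(model, tokenizer, image_tensor, "dog")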