"""Visualization utilities for an image-captioning model (vision encoder ->
projector -> T5): Grad-CAM, attention rollout, gradient maps, integrated
gradients, and decoder cross-attention overlays."""

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import resize
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.vit.modeling_vit import ViTModel


class GradCAM:
    """Grad-CAM on the last Conv2d layer of a CNN vision encoder."""

    def __init__(self, vision_encoder):
        self.model = vision_encoder.model
        self.target_layer = self._find_last_conv_layer()
        self.activations = None
        self.gradients = None
        self.target_layer.register_forward_hook(self._hook_forward)
        # register_backward_hook is deprecated; the full variant delivers
        # grad_out the same way for Conv2d modules.
        self.target_layer.register_full_backward_hook(self._hook_backward)

    def _find_last_conv_layer(self):
        for module in reversed(list(self.model.modules())):
            if isinstance(module, torch.nn.Conv2d):
                return module
        raise RuntimeError("No Conv2d layer found for Grad-CAM.")

    def _hook_forward(self, module, inp, out):
        self.activations = out.detach()

    def _hook_backward(self, module, grad_in, grad_out):
        self.gradients = grad_out[0].detach()

    def generate(self, image_tensor):
        self.model.zero_grad()
        out = self.model(image_tensor)  # (B, C, H, W) for CNN backbones
        if out.ndim == 4:
            pooled = out.mean(dim=[2, 3])  # (B, C)
        elif out.ndim == 3:
            pooled = out.mean(dim=1)
        else:
            pooled = out
        # No class logits available here, so the embedding norm serves as the
        # scalar score to backpropagate.
        score = pooled.norm()
        score.backward()
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1).squeeze()
        cam = F.relu(cam)
        cam -= cam.min()
        cam /= cam.max() + 1e-8
        return cam.cpu().numpy()

    def save(self, img_tensor, save_path):
        cam = self.generate(img_tensor)
        img_np = img_tensor[0].permute(1, 2, 0).cpu().numpy()
        img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
        cam_resized = cv2.resize(cam, (img_np.shape[1], img_np.shape[0]))
        plt.figure(figsize=(6, 6))
        plt.imshow(img_np)
        plt.imshow(cam_resized, cmap="inferno", alpha=0.45)
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.close()
        print(f"[GradCAM] Saved to {save_path}")


def get_vit_self_attention(model, image_tensor):
    """Return per-layer self-attention maps, or None for CNN encoders."""
    vision = model.vision_encoder
    if "Resnet" in type(vision).__name__:
        return None
    if hasattr(vision, "model"):
        # CLIP-style wrapper: the ViT lives under .vision_model.
        if hasattr(vision.model, "vision_model"):
            hf_vit = vision.model.vision_model
            outputs = hf_vit(
                pixel_values=image_tensor,
                output_attentions=True,
                return_dict=True,
            )
            return outputs.attentions
        # Plain HuggingFace ViT.
        if isinstance(vision.model, ViTModel):
            outputs = vision.model(
                pixel_values=image_tensor,
                output_attentions=True,
                return_dict=True,
            )
            return outputs.attentions
    raise ValueError("Vision encoder does not expose ViT attentions.")


# ATTENTION ROLLOUT (across layers)
def attention_rollout(attn_mats, discard_ratio=0.0):
    device = attn_mats[0].device
    eye = torch.eye(attn_mats[0].size(-1), device=device)
    result = eye
    for attn in attn_mats:
        attn = attn.mean(dim=0)  # average heads: (T, T)
        if discard_ratio > 0:
            # Zero out the weakest attention links before renormalizing.
            flat = attn.view(-1)
            k = max(int(flat.numel() * discard_ratio), 1)
            threshold = flat.topk(k, largest=False)[0].max()
            attn = torch.where(attn < threshold, torch.zeros_like(attn), attn)
        # Fold in the residual connection, as in Abnar & Zuidema (2020).
        attn = 0.5 * (attn + eye)
        attn = attn / attn.sum(dim=-1, keepdim=True)
        result = attn @ result
    return result


def rollout_to_image(rollout, image_size):
    tokens = rollout.size(0)
    num_patches = int((tokens - 1) ** 0.5)
    # Row 0 is the CLS token; its attention over the patches is the heatmap.
    spatial = rollout[0, 1:].reshape(num_patches, num_patches)
    spatial = (spatial - spatial.min()) / (spatial.max() - spatial.min() + 1e-8)
    spatial = resize(
        spatial.unsqueeze(0).unsqueeze(0), [image_size, image_size]
    )
    return spatial.squeeze().detach().cpu().numpy()
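
# Example usage (a minimal sketch, not part of the pipeline above): it assumes
# a wrapper `model` exposing `.vision_encoder` as in get_vit_self_attention,
# and an `image_tensor` already preprocessed, batched, and on the model's
# device. `_example_backbone_saliency` is an illustrative name, not an
# existing API.
def _example_backbone_saliency(model, image_tensor, save_path="saliency.png"):
    """Grad-CAM for CNN backbones, attention rollout for ViT backbones."""
    attns = get_vit_self_attention(model, image_tensor)
    if attns is None:
        # ResNet-style encoder: Grad-CAM on its last Conv2d layer.
        GradCAM(model.vision_encoder).save(image_tensor, save_path)
        return
    # ViT-style encoder: multiply attention across layers (batch index 0),
    # dropping the weakest 10% of links before renormalizing.
    rollout = attention_rollout([a[0] for a in attns], discard_ratio=0.1)
    heat = rollout_to_image(rollout, image_size=image_tensor.shape[-1])
    plot_attention_overlay(image_tensor[0], heat)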
def plot_attention_overlay(image, heatmap, alpha=0.45):
    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).cpu().numpy()
    image = (image - image.min()) / (image.max() - image.min() + 1e-8)
    plt.figure(figsize=(6, 6))
    plt.imshow(image)
    plt.imshow(heatmap, cmap="inferno", alpha=alpha)
    plt.axis("off")
    plt.show()


# GRADIENT MAP
def token_gradient_map(model, tokenizer, image_tensor, target_word, device="cuda"):
    """Pixel saliency for one decoder step's logit on `target_word`."""
    model.eval()
    image_tensor = image_tensor.to(device)
    image_tensor.requires_grad_(True)

    vision_out = model.vision_encoder(image_tensor)
    img_embeds = vision_out["image_embeds"]
    if img_embeds.dim() == 2:
        img_embeds = img_embeds.unsqueeze(1)
    projected = model.projector(img_embeds)
    encoder_outputs = BaseModelOutput(last_hidden_state=projected)

    start = model.t5.config.decoder_start_token_id
    decoder_input_ids = torch.tensor([[start]], device=device)
    outputs = model.t5(
        encoder_outputs=encoder_outputs,
        decoder_input_ids=decoder_input_ids,
        return_dict=True,
    )
    logits = outputs.logits[:, -1, :]
    target_id = tokenizer.convert_tokens_to_ids(target_word)
    logit = logits[0, target_id]
    logit.backward()

    grad = image_tensor.grad.abs().mean(dim=1).squeeze().cpu().numpy()
    grad = (grad - grad.min()) / (grad.max() - grad.min() + 1e-8)
    return grad


# ATTENTION x GRAD
def attngrad(model, tokenizer, image_tensor, target_word, image_size=224, device="cuda"):
    raw_attns = get_vit_self_attention(model, image_tensor.to(device))
    if raw_attns is None:
        raise ValueError("attngrad requires a ViT-style encoder with attentions.")
    attn_mats = [a[0] for a in raw_attns]
    rollout = attention_rollout(attn_mats)
    roll_map = rollout_to_image(rollout, image_size)
    grad_map = token_gradient_map(model, tokenizer, image_tensor, target_word, device)
    combined = roll_map * grad_map
    combined = (combined - combined.min()) / (combined.max() - combined.min() + 1e-8)
    return combined


def token_gradient_map_smooth(model, tokenizer, image_tensor, target_word, sigma=5, device="cuda"):
    """As token_gradient_map, but Gaussian-smoothed for cleaner overlays."""
    model.eval()
    image_tensor = image_tensor.to(device)
    image_tensor.requires_grad_(True)

    # Vision encoder
    vision_out = model.vision_encoder(image_tensor)
    img_embeds = vision_out["image_embeds"]
    if img_embeds.dim() == 2:
        img_embeds = img_embeds.unsqueeze(1)
    projected = model.projector(img_embeds)
    encoder_outputs = BaseModelOutput(last_hidden_state=projected)

    start_token = model.t5.config.decoder_start_token_id
    decoder_input_ids = torch.tensor(
        [[start_token]], device=device, dtype=torch.long
    )
    # One mask entry per projected image token, not a hard-coded [[1]].
    attention_mask = torch.ones(projected.shape[:2], device=device, dtype=torch.long)

    outputs = model.t5(
        encoder_outputs=encoder_outputs,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    )
    vocab_logits = outputs.logits[:, -1, :]
    target_id = tokenizer.convert_tokens_to_ids(target_word)
    logit = vocab_logits[0, target_id]
    logit.backward()

    grad = image_tensor.grad.abs().mean(dim=1).squeeze().cpu().numpy()
    grad = (grad - grad.min()) / (grad.max() - grad.min() + 1e-8)
    return smooth_heatmap(grad, sigma=sigma)
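
# Example usage (sketch; same `model`/`tokenizer` assumptions as above, and
# `word` should map to a single tokenizer token for the logit lookup to make
# sense). Shows the smoothed gradient map next to the attention-x-gradient
# combination; `_example_word_saliency` is an illustrative helper name.
def _example_word_saliency(model, tokenizer, image_tensor, word, device="cuda"):
    grad_map = token_gradient_map_smooth(
        model, tokenizer, image_tensor, word, device=device
    )
    plot_attention_overlay(image_tensor[0], grad_map)
    # The combined map only works for ViT-style encoders (see attngrad).
    combined = attngrad(model, tokenizer, image_tensor, word, device=device)
    plot_attention_overlay(image_tensor[0], combined)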
def integrated_gradients(
    model, tokenizer, image_tensor, target_word, steps=30, device="cuda"
):
    model.eval()
    device = torch.device(device)
    image_tensor = image_tensor.to(device)
    baseline = torch.zeros_like(image_tensor)
    target_id = tokenizer.convert_tokens_to_ids(target_word)
    total_grad = torch.zeros_like(image_tensor)

    for i in range(1, steps + 1):
        alpha = i / steps
        # Fresh leaf tensor at each interpolation step so grads flow to it.
        img = (baseline + alpha * (image_tensor - baseline)).detach().requires_grad_(True)

        vision_out = model.vision_encoder(img)
        img_embeds = vision_out["image_embeds"]
        if img_embeds.dim() == 2:
            img_embeds = img_embeds.unsqueeze(1)
        projected = model.projector(img_embeds)
        encoder_outputs = BaseModelOutput(last_hidden_state=projected)

        start_token = model.t5.config.decoder_start_token_id
        decoder_input_ids = torch.tensor([[start_token]], device=device)
        attention_mask = torch.ones(projected.shape[:2], device=device, dtype=torch.long)
        outputs = model.t5(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        vocab_logits = outputs.logits[:, -1, :]
        logit = vocab_logits[0, target_id]
        grads = torch.autograd.grad(
            outputs=logit,
            inputs=img,
            retain_graph=False,
            create_graph=False,
            allow_unused=True,
        )[0]
        if grads is None:
            raise RuntimeError(
                "Integrated gradients: grad is None; gradient path was broken."
            )
        total_grad += grads

    avg_grad = total_grad / steps
    # Scale by (input - baseline), per Sundararajan et al. (2017).
    attributions = (image_tensor - baseline) * avg_grad
    heat = attributions.abs().mean(dim=1).squeeze().cpu().numpy()
    heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
    return heat


def smooth_heatmap(hm, k=21, sigma=6):
    # Kernel size k must be odd for cv2.GaussianBlur.
    hm = cv2.GaussianBlur(hm, (k, k), sigma)
    hm = (hm - hm.min()) / (hm.max() - hm.min() + 1e-8)
    return hm


def get_cross_attention(model, encoder_outputs, decoder_input_ids, device="cuda"):
    model.eval()
    with torch.no_grad():
        outputs = model.t5(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids.to(device),
            output_attentions=True,
            return_dict=True,
        )
    # outputs.cross_attentions is a tuple over layers,
    # each (batch, heads, tgt_len, src_len).
    cross = outputs.cross_attentions
    attn_layers = [c[0] for c in cross]  # keep batch 0: (heads, tgt_len, src_len)
    return attn_layers


def cross_attention_to_image(attn):
    attn = torch.as_tensor(attn)
    if attn.numel() == 0:
        return np.zeros((14, 14), dtype=np.float32)
    if attn.dim() == 3:
        attn = attn.mean(dim=0)  # (heads, tgt, src) -> average over heads
    if attn.dim() == 2:
        attn_vec = attn[-1]  # use the last generated token
    elif attn.dim() == 1:
        attn_vec = attn
    else:
        raise ValueError(f"Unexpected attn shape: {attn.shape}")
    # Drop the CLS token (index 0): CLIP ViT-L/14 yields 197 tokens
    # but only 196 spatial patches.
    if attn_vec.size(0) == 197:
        attn_vec = attn_vec[1:]  # now length 196
    src_len = attn_vec.size(0)
    side = int(src_len ** 0.5)
    if side * side != src_len:
        # Pad or crop to the nearest square grid so the reshape succeeds.
        new_len = side * side
        padded = torch.zeros(new_len, device=attn_vec.device)
        padded[: min(new_len, src_len)] = attn_vec[: min(new_len, src_len)]
        attn_vec = padded
    attn_vec = attn_vec / (attn_vec.max() + 1e-8)
    return attn_vec.reshape(side, side).cpu().numpy()


def plot_cross_attention_overlay(image_tensor, heatmap, save_path=None, alpha=0.45):
    img = image_tensor[0].permute(1, 2, 0).cpu().numpy()
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
    # Upsample the patch-grid heatmap to image resolution so the two
    # imshow calls actually align.
    heatmap = cv2.resize(
        np.asarray(heatmap, dtype=np.float32), (img.shape[1], img.shape[0])
    )
    plt.figure(figsize=(6, 6))
    plt.imshow(img)
    plt.imshow(heatmap, cmap="inferno", alpha=alpha)
    plt.axis("off")
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.close()
        print(f"[CrossAttention] Saved to {save_path}")
    else:
        plt.show()
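
# Example usage (sketch; same model assumptions as above). Integrated
# gradients is the most expensive map here (`steps` forward/backward passes),
# so lower `steps` for quick checks. `_example_integrated_gradients` is an
# illustrative helper name.
def _example_integrated_gradients(model, tokenizer, image_tensor, word,
                                  steps=30, device="cuda"):
    heat = integrated_gradients(
        model, tokenizer, image_tensor, word, steps=steps, device=device
    )
    plot_cross_attention_overlay(image_tensor, smooth_heatmap(heat))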
def visualize_cross_attention(model, tokenizer, image_tensor, word, device="cuda"):
    device = torch.device(device)
    image_tensor = image_tensor.to(device)

    with torch.no_grad():
        vision_out = model.vision_encoder(image_tensor)
        img_embeds = vision_out["image_embeds"]
        if img_embeds.dim() == 2:
            img_embeds = img_embeds.unsqueeze(1)
        projected = model.projector(img_embeds)
    encoder_outputs = BaseModelOutput(last_hidden_state=projected)

    target_id = tokenizer.convert_tokens_to_ids(word)
    generated = [model.t5.config.decoder_start_token_id]
    attn_layers = None
    for _ in range(30):
        decoder_input_ids = torch.tensor([generated], device=device)
        attn_layers = get_cross_attention(
            model, encoder_outputs, decoder_input_ids, device=device
        )
        with torch.no_grad():
            logits = model.t5(
                encoder_outputs=encoder_outputs,
                decoder_input_ids=decoder_input_ids,
                return_dict=True,
            ).logits[:, -1, :]
        next_id = int(logits.argmax())
        generated.append(next_id)
        # Stop once the target word appears, or when decoding ends naturally.
        if next_id == target_id or next_id == model.t5.config.eos_token_id:
            break

    last_attn = attn_layers[-1]  # (heads, tgt_len, src_len)
    heat = cross_attention_to_image(last_attn)
    plot_cross_attention_overlay(image_tensor, heat)
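
# Example usage (sketch; same `model`/`tokenizer` assumptions as the rest of
# this module). visualize_cross_attention() greedily decodes and shows the
# figure; the hypothetical batch runner below instead teacher-forces each word
# and saves one overlay per word. `_example_save_cross_attention` and the
# attn_<word>.png naming are illustrative, not an existing API.
def _example_save_cross_attention(model, tokenizer, image_tensor, words, device="cuda"):
    device = torch.device(device)
    image_tensor = image_tensor.to(device)
    with torch.no_grad():
        vision_out = model.vision_encoder(image_tensor)
        img_embeds = vision_out["image_embeds"]
        if img_embeds.dim() == 2:
            img_embeds = img_embeds.unsqueeze(1)
        projected = model.projector(img_embeds)
    encoder_outputs = BaseModelOutput(last_hidden_state=projected)
    for word in words:
        # Condition the decoder on the start token plus the word's own id,
        # then map the last step's cross-attention back onto the image.
        ids = [model.t5.config.decoder_start_token_id,
               tokenizer.convert_tokens_to_ids(word)]
        attn_layers = get_cross_attention(
            model, encoder_outputs, torch.tensor([ids], device=device), device=device
        )
        heat = cross_attention_to_image(attn_layers[-1])
        plot_cross_attention_overlay(image_tensor, heat, save_path=f"attn_{word}.png")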