import cv2
import numpy as np
import torch
from PIL import Image


def overlay_attn(original_image, mask):
    '''Overlay an attention mask on an image; returns the blend as a PIL image.'''
    # cv2.resize takes (width, height); numpy images are (height, width, channels).
    h, w = original_image.shape[0], original_image.shape[1]
    mask = cv2.resize(mask / mask.max(), (w, h))[..., np.newaxis]

    # Map the normalized mask to a viridis heatmap. OpenCV colormaps are BGR,
    # so convert to RGB before blending (this assumes original_image is RGB).
    cmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_VIRIDIS)
    cmap = cv2.cvtColor(cmap, cv2.COLOR_BGR2RGB)

    # Alpha-blend: 40% original image, 60% attention heatmap.
    alpha_blended = cv2.addWeighted(np.uint8(original_image), 0.4, cmap, 0.6, 0)
    return Image.fromarray(alpha_blended)


class VITAttentionGradRollout:
    '''
    Attention rollout over the blocks of a timm ViT transformer model.
    Adapted from https://github.com/samiraabnar/attention_flow
    '''
    def __init__(self, model, head_fusion='min', discard_ratio=0):
        self.model = model
        self.head_fusion = head_fusion
        self.discard_ratio = discard_ratio

        # Hook every transformer block so its attention map is cached in
        # self.attentions under the key "attn{idx}" on each forward pass.
        self.attentions = {}
        for idx, module in enumerate(model.blocks.children()):
            module.attn.register_forward_hook(self.get_attention(f"attn{idx}"))

    def get_attention(self, name):
        def hook(module, input, output):
            # Recompute the attention weights from the block's qkv projection,
            # since timm's Attention module does not expose them in its output.
            with torch.no_grad():
                input = input[0]
                B, N, C = input.shape
                qkv = (
                    module.qkv(input)
                    .detach()
                    .reshape(B, N, 3, module.num_heads, C // module.num_heads)
                    .permute(2, 0, 3, 1, 4)
                )
                q, k, _ = qkv[0], qkv[1], qkv[2]
                attn = (q @ k.transpose(-2, -1)) * module.scale
                attn = attn.softmax(dim=-1)
                self.attentions[name] = attn
        return hook

    def get_attn_mask(self, start_layer=0):
        # Roll attention out across blocks from `start_layer` onward,
        # starting from the identity matrix.
        result = torch.eye(self.attentions['attn0'].size(-1)).to(self.attentions['attn0'].device)

        with torch.no_grad():
            for idx in range(start_layer, len(self.attentions)):
                attention = self.attentions[f'attn{idx}']
                # Fuse the heads into a single attention map per layer.
                if self.head_fusion == "mean":
                    attention_heads_fused = attention.mean(dim=1)
                elif self.head_fusion == "max":
                    attention_heads_fused = attention.max(dim=1)[0]
                elif self.head_fusion == "min":
                    attention_heads_fused = attention.min(dim=1)[0]
                else:
                    raise ValueError(f"Attention head fusion type '{self.head_fusion}' not supported")

                # Zero out the lowest-attended entries (a discard_ratio fraction),
                # but keep flat index 0, the class-token entry. `flat` is a view,
                # so this edits attention_heads_fused in place (assumes batch size 1).
                flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
                _, indices = flat.topk(int(flat.size(-1) * self.discard_ratio), -1, False)
                indices = indices[indices != 0]
                flat[0, indices] = 0

                # Model the residual connection as an identity term, renormalize
                # the rows, and fold this layer into the running product.
                I = torch.eye(attention_heads_fused.size(-1)).to(attention_heads_fused.device)
                a = (attention_heads_fused + 1.0 * I) / 2
                a = a / a.sum(dim=-1).unsqueeze(-1)
                result = torch.matmul(a, result)

        # Attention from the class token to each patch token, reshaped to the
        # square patch grid and normalized to [0, 1].
        mask = result[0, 0, 1:]
        width = int(mask.size(-1) ** 0.5)
        mask = mask.reshape(width, width).detach().cpu().numpy()
        mask = mask / np.max(mask)
        return mask
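

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal end-to-end example, assuming a timm ViT checkpoint: the model name,
# file paths, and normalization constants below are placeholder choices.
if __name__ == "__main__":
    import timm
    from torchvision import transforms

    model = timm.create_model("vit_base_patch16_224", pretrained=True).eval()
    rollout = VITAttentionGradRollout(model, head_fusion="min", discard_ratio=0.9)

    image = Image.open("example.jpg").convert("RGB").resize((224, 224))  # hypothetical input
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    with torch.no_grad():
        model(preprocess(image).unsqueeze(0))  # forward pass fills rollout.attentions

    mask = rollout.get_attn_mask()
    overlay_attn(np.array(image), mask).save("attn_overlay.jpg")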