File size: 2,120 Bytes
05a82cf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 | import torch
import timm.models.vision_transformer as vit
def patch_timm_vit_return_attn_scores():
    """Monkey-patch timm's ViT classes so attention scores can be retrieved.

    After this call:

    * ``vit.Attention.forward`` and ``vit.Block.forward`` accept an optional
      ``return_attn_scores`` flag. When it is false (the default) the original
      timm implementation runs; when it is true they return the tuple
      ``(output, attn_scores)`` instead of just ``output``.
    * ``vit.VisionTransformer`` gains a ``get_attn_scores`` method that runs a
      full forward pass and additionally returns the last block's scores.

    ``attn_scores`` are the scaled ``q @ k^T`` logits *before* softmax and
    attention dropout, shaped ``(B, num_heads, N, N)``.

    The patch is installed on the classes themselves, so it affects every
    existing and future instance. Calling this function again is a no-op.
    """
    if getattr(vit.Attention.forward, "_returns_attn_scores", False):
        # Already patched; wrapping again would only stack wrappers.
        return

    _orig_attn_forward = vit.Attention.forward

    def attn_forward_patched(self, x, return_attn_scores=False, attn_mask=None):
        """Attention forward that can also return the pre-softmax scores.

        Args:
            x: Input tokens of shape ``(B, N, C)``.
            return_attn_scores: When true, return ``(output, attn_scores)``.
            attn_mask: Optional mask. NOTE(review): recent timm releases appear
                to pass this as a keyword from ``Block.forward``; without
                accepting it here every forward pass would raise ``TypeError``
                on those versions - confirm against the installed timm.
        """
        if not return_attn_scores:
            # Only forward the mask when one was given so that timm versions
            # whose original forward has no ``attn_mask`` parameter still work.
            if attn_mask is None:
                return _orig_attn_forward(self, x)
            return _orig_attn_forward(self, x, attn_mask=attn_mask)

        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)  # each (B, num_heads, N, head_dim)
        q, k = self.q_norm(q), self.k_norm(k)
        q = q * self.scale
        attn_scores = q @ k.transpose(-2, -1)  # (B, num_heads, N, N)
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Boolean masks: True means "may attend" (the convention of
                # torch.nn.functional.scaled_dot_product_attention).
                attn_scores = attn_scores.masked_fill(~attn_mask, float("-inf"))
            else:
                # Any other dtype is treated as an additive bias on the logits.
                attn_scores = attn_scores + attn_mask
        attn = attn_scores.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return (x, attn_scores)

    attn_forward_patched._returns_attn_scores = True
    vit.Attention.forward = attn_forward_patched

    # Patch Block.forward
    _orig_block_forward = vit.Block.forward

    def block_forward_patched(self, x, return_attn_scores=False, attn_mask=None):
        """Block forward that can also return its attention scores.

        Returns ``x`` alone by default, or ``(x, attn_scores)`` when
        ``return_attn_scores`` is true. ``attn_mask`` is passed on to the
        attention module only when it is not ``None``.
        """
        mask_kwargs = {} if attn_mask is None else {"attn_mask": attn_mask}
        if not return_attn_scores:
            return _orig_block_forward(self, x, **mask_kwargs)
        out, attn_scores = self.attn(
            self.norm1(x), return_attn_scores=True, **mask_kwargs
        )
        x = x + self.drop_path1(self.ls1(out))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return (x, attn_scores)

    vit.Block.forward = block_forward_patched

    def get_attn_scores(self, x, pre_logits: bool = False):
        """Run the model on ``x`` and also return the last block's scores.

        Args:
            x: Input batch, fed to ``self.patch_embed``.
            pre_logits: When true, skip the final classifier (``self.head``).

        Returns:
            ``(output, attn_scores)`` where ``attn_scores`` are the
            pre-softmax attention logits of the final transformer block.

        Raises:
            ValueError: If the model has no transformer blocks, since there
                would be no attention scores to return.
        """
        blocks = list(self.blocks)
        if not blocks:
            raise ValueError(
                "get_attn_scores requires a model with at least one block"
            )

        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        for blk in blocks[:-1]:
            x = blk(x)
        # Only the last block's scores are collected.
        x, attn_scores = blocks[-1](x, return_attn_scores=True)
        x = self.norm(x)
        # Delegate pooling / fc_norm / head_drop / head to the model's own
        # head so every pooling mode it supports is handled the way the
        # regular forward pass handles it, rather than a hand-copied subset.
        x = self.forward_head(x, pre_logits=pre_logits)
        return (x, attn_scores)

    vit.VisionTransformer.get_attn_scores = get_attn_scores
|