File size: 2,120 Bytes
05a82cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import timm.models.vision_transformer as vit

def patch_timm_vit_return_attn_scores():
    """Monkey-patch timm's ViT classes so attention scores can be retrieved.

    After this call:

    * ``vit.Attention.forward`` and ``vit.Block.forward`` accept an extra
      ``return_attn_scores`` argument. When it is falsy they delegate to the
      original implementation (forwarding any extra keyword arguments such as
      ``attn_mask``); when it is truthy they return ``(output, attn_scores)``.
    * ``vit.VisionTransformer`` gains a ``get_attn_scores(x, pre_logits=False)``
      method that returns ``(output, attn_scores)`` where the scores come from
      the last block only.

    ``attn_scores`` are the scaled ``q @ k^T`` logits *before* softmax and
    before attention dropout, shaped ``(B, num_heads, N, N)``.

    The patch is applied at class level (it affects every existing and future
    instance) and is idempotent: calling this function again is a no-op.
    """
    # Re-patching would wrap the already patched forwards a second time and
    # add a useless extra call layer on every forward pass.
    if getattr(vit.Attention.forward, "_returns_attn_scores", False):
        return

    _orig_attn_forward = vit.Attention.forward

    def attn_forward_patched(self, x, return_attn_scores=False, **kwargs):
        """Attention forward that can also return the pre-softmax scores.

        Extra keyword arguments are passed through untouched to the original
        forward so callers that supply e.g. ``attn_mask=...`` keep working.
        """
        if not return_attn_scores:
            return _orig_attn_forward(self, x, **kwargs)

        # Only ``attn_mask`` is understood by the explicit path below; fail
        # loudly on anything else instead of silently ignoring it.
        attn_mask = kwargs.pop("attn_mask", None)
        if kwargs:
            raise TypeError(
                "unsupported keyword arguments with return_attn_scores=True: "
                + ", ".join(sorted(kwargs))
            )

        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        q = q * self.scale
        attn_scores = q @ k.transpose(-2, -1)
        if attn_mask is not None:
            # NOTE(review): the mask is added to the logits, so it is assumed
            # to be an additive float mask broadcastable to (B, heads, N, N);
            # boolean masks are not converted here - confirm against callers.
            attn_scores = attn_scores + attn_mask
        attn = attn_scores.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return (x, attn_scores)

    # Marker read by the idempotency guard at the top of this function.
    attn_forward_patched._returns_attn_scores = True
    vit.Attention.forward = attn_forward_patched

    # Patch Block.forward
    _orig_block_forward = vit.Block.forward

    def block_forward_patched(self, x, return_attn_scores=False, **kwargs):
        """Block forward that can also return its attention layer's scores.

        Extra keyword arguments go to the original block forward in the
        default path, and to the attention layer in the score-returning path.
        """
        if not return_attn_scores:
            return _orig_block_forward(self, x, **kwargs)

        out, attn_scores = self.attn(self.norm1(x), return_attn_scores=True, **kwargs)
        x = x + self.drop_path1(self.ls1(out))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return (x, attn_scores)

    block_forward_patched._returns_attn_scores = True
    vit.Block.forward = block_forward_patched

    def get_attn_scores(self, x, pre_logits: bool = False):
        """Run the model and also return the last block's attention scores.

        Args:
            x: Model input, fed to ``self.patch_embed``.
            pre_logits: When true, skip ``self.head`` and return the pooled
                features instead of the classifier output.

        Returns:
            ``(output, attn_scores)``. ``attn_scores`` is ``None`` when the
            model has no blocks.
        """
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)

        # Initialised up front so a model with zero blocks returns None rather
        # than raising UnboundLocalError at the return statement.
        attn_scores = None
        last_index = len(self.blocks) - 1
        for i, blk in enumerate(self.blocks):
            if i == last_index:
                # Only the final block takes the explicit score-returning path.
                x, attn_scores = blk(x, return_attn_scores=True)
            else:
                x = blk(x)

        x = self.norm(x)
        if self.global_pool:
            # 'avg' averages the non-prefix tokens; every other truthy pooling
            # mode falls back to the first token.
            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        x = self.head_drop(x)

        if not pre_logits:
            x = self.head(x)

        return (x, attn_scores)

    vit.VisionTransformer.get_attn_scores = get_attn_scores