File size: 3,504 Bytes
997b0b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import torch
from torch import nn
from einops import rearrange, repeat
from torch import einsum

class PerceiverAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head=64,
        heads=8
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads  # 512

        self.norm_media = nn.LayerNorm(dim)
        self.norm_learns = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False) # 4096×512
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) # 4096×1024
        self.to_out = nn.Linear(inner_dim, dim, bias=False) # 512×4096

    def forward(self, x, learns): # x(b, 256, 4096), learns(b, 3, 4096)
        x = self.norm_media(x)
        learns = self.norm_learns(learns)

        b, n, h = *x.shape[:2], self.heads

        q = self.to_q(learns) # q(b, 3, 512)

        # 注意:在PerceiverResampler中,将输入和learns拼接后进行attention计算
        kv_input = torch.cat((x, learns), dim=-2) # kv_input(b, 259, 4096)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1) # (b, 259, 1024)->k, v(b, 259, 512)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) # q(b, 8, 3, 64)   k, v(b, 8, 259, 64)

        q = q * self.scale

        # attention计算
        sim = einsum('b h i d, b h j d -> b h i j', q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1) # sim, attn(b, 8, 3, 259)

        out = einsum('b h i j, b h j d -> b h i d', attn, v) # out(b, 8, 3, 64)
        out = rearrange(out, 'b h n d -> b n (h d)') # out(b, 3, 512)
        return self.to_out(out) # return(b, 3, 4096)


class PerceiverResampler(nn.Module):
    def __init__(
        self,
        *,
        dim,  # 4096
        depth=2,
        dim_head=64,
        heads=8,
        num_learns=3,  # 修改为3个learned queries
        ff_mult=4,
    ):
        super().__init__()
        self.learns = nn.Parameter(torch.randn(num_learns, dim)) # 3×4096

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, 256, 4096)
        Returns:
            shape (b, 3, 4096) where 3 is self.num_learns
        """
        b, n, d = x.shape  # (b, 256, 4096)

        # 将learned queries广播到batch size
        learns = repeat(self.learns, "n d -> b n d", b=b)

        # 通过多层PerceiverAttention和FeedForward模块处理输入
        for attn, ff in self.layers:
            # learns = attn(learns, x) + learns
            learns = attn(x, learns) + learns
            learns = ff(learns) + learns

        return self.norm(learns)

# 用于前向传播的FeedForward模块
class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim),
        )

    def forward(self, x):
        return self.net(x)