File size: 4,458 Bytes
ed1622f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import torch
from torch import nn
from einops import rearrange, repeat
from torch import einsum


class PerceiverAttention(nn.Module):
    """Multi-head cross-attention from latent vectors onto media features.

    Queries are computed from the latents only. Keys and values are computed
    from the media features concatenated with the latents along the sequence
    axis, so every latent can attend to the media and to all latents.

    Args:
        dim: Size of the last axis of both inputs and of the output.
        dim_head: Width of a single attention head.
        heads: Number of attention heads.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head  # 512 with the default 8 x 64

        self.norm_media = nn.LayerNorm(dim)
        self.norm_learns = nn.LayerNorm(dim)

        # All projections are bias-free.
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, 2 * inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, learns):
        """Attend from ``learns`` onto ``x`` concatenated with ``learns``.

        Args:
            x: Media features, ``(batch, n_media, dim)``,
                e.g. ``(b, 256, 4096)``.
            learns: Latent vectors, ``(batch, n_latents, dim)``,
                e.g. ``(b, 3, 4096)``.

        Returns:
            Tensor of shape ``(batch, n_latents, dim)``.
        """
        media = self.norm_media(x)
        latents = self.norm_learns(learns)

        def split_heads(t):
            # (b, n, heads * d) -> (b, heads, n, d)
            return rearrange(t, 'b n (h d) -> b h n d', h=self.heads)

        # Keys/values see the media tokens followed by the latents themselves.
        context = torch.cat((media, latents), dim=-2)
        key_proj, value_proj = self.to_kv(context).chunk(2, dim=-1)

        queries = split_heads(self.to_q(latents)) * self.scale
        keys = split_heads(key_proj)
        values = split_heads(value_proj)

        # Scaled dot-product attention; the row maximum is subtracted before
        # the softmax for numerical stability and kept out of the graph.
        scores = einsum('b h i d, b h j d -> b h i j', queries, keys)
        scores = scores - scores.amax(dim=-1, keepdim=True).detach()
        weights = scores.softmax(dim=-1)  # (b, heads, n_latents, n_media + n_latents)

        mixed = einsum('b h i j, b h j d -> b h i d', weights, values)
        merged = rearrange(mixed, 'b h n d -> b n (h d)')  # (b, n_latents, heads * d)
        return self.to_out(merged)


class PerceiverResampler(nn.Module):
    """Resample a feature sequence into a fixed number of learned latents.

    A learned ``(num_learns, dim)`` parameter is broadcast over the batch and
    refined by ``depth`` rounds of cross-attention onto the input followed by
    a feed-forward block, each wrapped in a residual connection. A final
    LayerNorm is applied to the result.

    Args:
        dim: Feature size of the input and of the latents.
        depth: Number of (attention, feed-forward) layer pairs.
        dim_head: Width of each attention head.
        heads: Number of attention heads.
        num_learns: Number of learned latent vectors (output sequence length).
        ff_mult: Hidden-size multiplier of the feed-forward blocks.
    """

    def __init__(
        self,
        *,
        dim,
        depth=6,
        dim_head=64,
        heads=8,
        num_learns=3,
        ff_mult=4,
    ):
        super().__init__()
        self.learns = nn.Parameter(torch.randn(num_learns, dim))

        # One (cross-attention, feed-forward) pair per layer.
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
                for _ in range(depth)
            ]
        )

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Resample ``x`` into the learned latents.

        Args:
            x (torch.Tensor): image features of shape ``(b, n, dim)``,
                e.g. ``(b, 256, 4096)``.

        Returns:
            Tensor of shape ``(b, num_learns, dim)``, e.g. ``(b, 3, 4096)``.
        """
        # Three-way unpack: only the batch size is used, but input that is
        # not 3-D is rejected here.
        batch, _, _ = x.shape

        latents = repeat(self.learns, "n d -> b n d", b=batch)

        for cross_attn, feed_forward in self.layers:
            attended = cross_attn(x, latents)
            latents = attended + latents
            transformed = feed_forward(latents)
            latents = transformed + latents

        return self.norm(latents)
    
class MLP(nn.Module):
    """Two stacked feed-forward blocks applied along axis 1 of a 3-D tensor.

    The input is permuted so that axis 1 becomes the last axis, passed through
    ``input_dim -> input_dim`` and ``input_dim -> output_dim`` feed-forward
    blocks, and permuted back. The result therefore has ``output_dim`` entries
    on axis 1 and an unchanged last axis.

    Args:
        input_dim: Size of axis 1 of the input (e.g. 256).
        hidden_mult: Hidden-size multiplier of both feed-forward blocks.
        output_dim: Size of axis 1 of the output. Defaults to 3, the value
            that was previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, input_dim, hidden_mult=4, output_dim=3):
        super().__init__()
        self.ff1 = FeedForward_2(input_dim, input_dim, hidden_mult)
        self.ff2 = FeedForward_2(input_dim, output_dim, hidden_mult)

    def forward(self, x):
        """Map ``(b, input_dim, d)`` to ``(b, output_dim, d)``.

        Args:
            x: 3-D tensor whose axis 1 has size ``input_dim``.

        Returns:
            Tensor of shape ``(b, output_dim, d)``. It is a permuted view and
            is generally not contiguous in memory.
        """
        # Bring axis 1 to the end so the feed-forward blocks act on it.
        x = x.permute(0, 2, 1)
        x = self.ff1(x)
        x = self.ff2(x)

        # Restore the original axis order.
        x = x.permute(0, 2, 1)
        return x
    
class MLP_6763(nn.Module):
    """Flatten each sample and map it through two feed-forward blocks.

    Args:
        input_dim: Number of features per sample after flattening, i.e. the
            product of all non-batch axes of the input.
        output_dim: Size of the output feature vector.
        hidden_mult: Hidden-size multiplier of both feed-forward blocks.
    """

    def __init__(self, input_dim, output_dim, hidden_mult=2):
        super().__init__()
        self.ff1 = FeedForward_2(input_dim, output_dim, hidden_mult)
        self.ff2 = FeedForward_2(output_dim, output_dim, hidden_mult)

    def forward(self, x):
        """Map ``(b, n, d)`` to ``(b, output_dim)``.

        Args:
            x: Tensor with the batch on axis 0; all remaining axes are
                flattened and must multiply to ``input_dim``.

        Returns:
            Tensor of shape ``(b, output_dim)``.
        """
        # flatten() instead of view(b, -1): view raises on non-contiguous
        # input (e.g. the permuted tensor returned by a permute-based module)
        # and cannot infer -1 for an empty batch. flatten copies only when
        # a view is impossible.
        x = x.flatten(start_dim=1)
        x = self.ff1(x)
        x = self.ff2(x)
        return x

class FeedForward(nn.Module):
    """Pre-normalised two-layer feed-forward block that preserves ``dim``.

    Layout: LayerNorm -> Linear(dim, dim * mult) -> GELU -> Linear(dim * mult, dim).

    Args:
        dim: Size of the last axis of the input and of the output.
        mult: Multiplier giving the hidden width ``dim * mult``.
    """

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden_dim = dim * mult
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block to the last axis of ``x``; the shape is unchanged."""
        out = self.net(x)
        return out
    
class FeedForward_2(nn.Module):
    """Pre-normalised two-layer feed-forward block with a separate output size.

    Layout: LayerNorm -> Linear(input_dim, input_dim * mult) -> GELU
    -> Linear(input_dim * mult, output_dim).

    Args:
        input_dim: Size of the last axis of the input.
        output_dim: Size of the last axis of the output.
        mult: Multiplier giving the hidden width ``input_dim * mult``.
    """

    def __init__(self, input_dim, output_dim, mult=4):
        super().__init__()
        hidden_dim = input_dim * mult
        stages = [
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Map the last axis of ``x`` from ``input_dim`` to ``output_dim``."""
        out = self.net(x)
        return out