gtang666 commited on
Commit
997b0b9
·
verified ·
1 Parent(s): c5ea45f

Upload InternVL/perceiver_resampler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. InternVL/perceiver_resampler.py +113 -0
InternVL/perceiver_resampler.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch import nn
4
+ from einops import rearrange, repeat
5
+ from torch import einsum
6
+
7
+ class PerceiverAttention(nn.Module):
8
+ def __init__(
9
+ self,
10
+ *,
11
+ dim,
12
+ dim_head=64,
13
+ heads=8
14
+ ):
15
+ super().__init__()
16
+ self.scale = dim_head ** -0.5
17
+ self.heads = heads
18
+ inner_dim = dim_head * heads # 512
19
+
20
+ self.norm_media = nn.LayerNorm(dim)
21
+ self.norm_learns = nn.LayerNorm(dim)
22
+
23
+ self.to_q = nn.Linear(dim, inner_dim, bias=False) # 4096×512
24
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) # 4096×1024
25
+ self.to_out = nn.Linear(inner_dim, dim, bias=False) # 512×4096
26
+
27
+ def forward(self, x, learns): # x(b, 256, 4096), learns(b, 3, 4096)
28
+ x = self.norm_media(x)
29
+ learns = self.norm_learns(learns)
30
+
31
+ b, n, h = *x.shape[:2], self.heads
32
+
33
+ q = self.to_q(learns) # q(b, 3, 512)
34
+
35
+ # 注意:在PerceiverResampler中,将输入和learns拼接后进行attention计算
36
+ kv_input = torch.cat((x, learns), dim=-2) # kv_input(b, 259, 4096)
37
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1) # (b, 259, 1024)->k, v(b, 259, 512)
38
+
39
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) # q(b, 8, 3, 64) k, v(b, 8, 259, 64)
40
+
41
+ q = q * self.scale
42
+
43
+ # attention计算
44
+ sim = einsum('b h i d, b h j d -> b h i j', q, k)
45
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
46
+ attn = sim.softmax(dim=-1) # sim, attn(b, 8, 3, 259)
47
+
48
+ out = einsum('b h i j, b h j d -> b h i d', attn, v) # out(b, 8, 3, 64)
49
+ out = rearrange(out, 'b h n d -> b n (h d)') # out(b, 3, 512)
50
+ return self.to_out(out) # return(b, 3, 4096)
51
+
52
+
53
+ class PerceiverResampler(nn.Module):
54
+ def __init__(
55
+ self,
56
+ *,
57
+ dim, # 4096
58
+ depth=2,
59
+ dim_head=64,
60
+ heads=8,
61
+ num_learns=3, # 修改为3个learned queries
62
+ ff_mult=4,
63
+ ):
64
+ super().__init__()
65
+ self.learns = nn.Parameter(torch.randn(num_learns, dim)) # 3×4096
66
+
67
+ self.layers = nn.ModuleList([])
68
+ for _ in range(depth):
69
+ self.layers.append(
70
+ nn.ModuleList(
71
+ [
72
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
73
+ FeedForward(dim=dim, mult=ff_mult),
74
+ ]
75
+ )
76
+ )
77
+
78
+ self.norm = nn.LayerNorm(dim)
79
+
80
+ def forward(self, x):
81
+ """
82
+ Args:
83
+ x (torch.Tensor): image features
84
+ shape (b, 256, 4096)
85
+ Returns:
86
+ shape (b, 3, 4096) where 3 is self.num_learns
87
+ """
88
+ b, n, d = x.shape # (b, 256, 4096)
89
+
90
+ # 将learned queries广播到batch size
91
+ learns = repeat(self.learns, "n d -> b n d", b=b)
92
+
93
+ # 通过多层PerceiverAttention和FeedForward模块处理输入
94
+ for attn, ff in self.layers:
95
+ # learns = attn(learns, x) + learns
96
+ learns = attn(x, learns) + learns
97
+ learns = ff(learns) + learns
98
+
99
+ return self.norm(learns)
100
+
101
+ # 用于前向传播的FeedForward模块
102
+ class FeedForward(nn.Module):
103
+ def __init__(self, dim, mult=4):
104
+ super().__init__()
105
+ self.net = nn.Sequential(
106
+ nn.LayerNorm(dim),
107
+ nn.Linear(dim, dim * mult),
108
+ nn.GELU(),
109
+ nn.Linear(dim * mult, dim),
110
+ )
111
+
112
+ def forward(self, x):
113
+ return self.net(x)