dianecy committed on
Commit c187b4b · verified · 1 Parent(s): 7e3a804

Upload folder using huggingface_hub
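For reference, the commit message names the huggingface_hub upload flow. A minimal sketch of how such a commit is typically produced; the repo id and local path below are placeholders, not values taken from this page:

from huggingface_hub import HfApi

# Hypothetical values; the real repository id is not shown on this commit page.
api = HfApi()
api.upload_folder(
    folder_path="./ASDA",            # local folder to mirror
    repo_id="dianecy/<repo-name>",   # placeholder target repository
    path_in_repo="ASDA",             # keep the same layout in the repo
    commit_message="Upload folder using huggingface_hub",
)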

ASDA/model/__pycache__/model.cpython-39.pyc ADDED
Binary file (12.6 kB).

ASDA/model/__pycache__/model_sbert_gref.cpython-39.pyc ADDED
Binary file (13 kB).

ASDA/model/__pycache__/modules.cpython-39.pyc ADDED
Binary file (8.28 kB).

ASDA/model/__pycache__/position_encoding.cpython-39.pyc ADDED
Binary file (2 kB).

ASDA/model/__pycache__/transformer.cpython-39.pyc ADDED
Binary file (8.29 kB).
ASDA/model/model.py ADDED
@@ -0,0 +1,466 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .modules import ConvBatchNormReLU, SFA
from .modules import *
from .position_encoding import *

import clip
import math
import sys

sys.path.append('../')
from utils.utils import *


class Simple_fusion(nn.Module):
    def __init__(self, visual_dim=1024, text_dim=768, proj_dim=1024, jemb_drop_out=0.1, leaky=True):
        super(Simple_fusion, self).__init__()
        self.proj_dim = proj_dim
        self.mapping_visu = ConvBatchNormReLU(visual_dim, proj_dim, 1, 1, 0, 1, leaky=leaky)
        self.lang_attn = nn.Sequential(
            nn.Linear(text_dim, text_dim),
            nn.Tanh(),
            nn.Dropout(jemb_drop_out),
            nn.Softmax(dim=1))

        self.lang_proj = nn.Sequential(
            nn.Linear(text_dim, proj_dim),
            nn.BatchNorm1d(proj_dim),
            nn.LeakyReLU(0.1))

        self.fusion = nn.Sequential(
            nn.BatchNorm2d(proj_dim),
            nn.LeakyReLU(0.1))

    def forward(self, visual_feat, lang_feat):
        # visual proj
        visual_feat_proj = self.mapping_visu(visual_feat)  # [bt, 1024, 13, 13]

        """
        # lang attn
        lang_feat_attn = self.lang_attn(lang_feat)  # [bt, 15, 768]
        lang_feat_new = lang_feat * lang_feat_attn
        lang_feat_new = lang_feat_new.sum(dim=1)  # [bt, 768]
        """

        lang_feat = lang_feat.squeeze(1)
        # lang proj
        # lang_feat_new = self.lang_proj(lang_feat_new)  # [bt, 1024]
        lang_feat_new = self.lang_proj(lang_feat)  # [bt, 1024]

        # fusion: tile the sentence embedding over the grid and gate the visual features
        h, w = visual_feat.shape[-2], visual_feat.shape[-1]
        lang_feat_new_tile = lang_feat_new.view(-1, self.proj_dim, 1, 1).repeat(1, 1, h, w)  # [bt, 1024, 13, 13]
        fusion_feat = lang_feat_new_tile * visual_feat_proj
        fusion_feat = self.fusion(fusion_feat)
        return fusion_feat


class up_proj_cat_proj(nn.Module):
    def __init__(self, input_1, input_2, do=512, leaky=True):
        super(up_proj_cat_proj, self).__init__()
        self.proj1 = ConvBatchNormReLU(input_2, input_2, 1, 1, 0, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(input_1 + input_2, do, 1, 1, 0, 1, leaky=leaky)

    def forward(self, x, y):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        y = self.proj1(y)
        out = torch.cat([x, y], dim=1)
        out = self.proj2(out)
        return out


class pool_proj_cat_proj(nn.Module):
    def __init__(self, input_1, input_2, do=512, leaky=True):
        super(pool_proj_cat_proj, self).__init__()
        self.downsample = nn.AvgPool2d(2, 2)
        self.proj1 = ConvBatchNormReLU(input_2, do // 2, 1, 1, 0, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(do // 2, do, 3, 1, 1, 1, leaky=leaky)
        self.proj3 = ConvBatchNormReLU(input_1 + do, do, 1, 1, 0, 1, leaky=leaky)

    def forward(self, x, y):
        y = self.downsample(y)
        y = self.proj1(y)
        y = self.proj2(y)
        output = self.proj3(torch.cat([x, y], dim=1))
        return output


class proj_cat_proj(nn.Module):
    def __init__(self, input_1, input_2, do=512, leaky=True):
        super(proj_cat_proj, self).__init__()
        self.proj1 = ConvBatchNormReLU(input_2, input_2, 1, 1, 0, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(input_1 + input_2, do, 1, 1, 0, 1, leaky=leaky)

    def forward(self, x, y):
        y = self.proj1(y)
        out = torch.cat([x, y], dim=1)
        out = self.proj2(out)
        return out


class proj_cat(nn.Module):
    def __init__(self, input_1, input_2, do=512, leaky=True):
        super(proj_cat, self).__init__()
        self.proj1 = ConvBatchNormReLU(input_1, do // 2, 1, 1, 0, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(do // 2, do, 3, 1, 1, 1, leaky=leaky)

    def forward(self, x, y):
        x = self.proj1(x)
        x = self.proj2(x)
        output = torch.cat([x, y], dim=1)
        return output


class mask_decoder(nn.Module):
    def __init__(self, input_1, seg_out_stride=2, leaky=True):
        super(mask_decoder, self).__init__()
        self.proj1 = ConvBatchNormReLU(input_1, input_1 // 2, 3, 1, 1, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(input_1 // 2, input_1 // 2, 3, 1, 1, 1, leaky=leaky)

        self.proj3 = ConvBatchNormReLU(input_1 // 2, input_1 // 2, 3, 1, 1, 1, leaky=leaky)
        self.proj4 = ConvBatchNormReLU(input_1 // 2, input_1 // 2, 3, 1, 1, 1, leaky=leaky)
        self.proj5 = ConvBatchNormReLU(input_1 // 2, input_1 // 2, 3, 1, 1, 1, leaky=leaky)
        # self.proj = nn.Conv2d(input_1, 1, 3, 1, 1, 1)
        self.proj = nn.Conv2d(input_1 // 2, 32, 3, 1, 1, 1)

    def forward(self, x, seg_out_stride):
        x = self.proj1(x)
        x = self.proj2(x)

        # progressively upsample by 2 until the requested output stride is reached
        if seg_out_stride <= 8:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = self.proj3(x)

        if seg_out_stride <= 4:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = self.proj4(x)

        if seg_out_stride <= 2:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = self.proj5(x)

        x = self.proj(x)

        return x


# class FeatureSelector(nn.Module):
#     def __init__(self, img_feature_dim, text_feature_dim, output_dim):
#         super(FeatureSelector, self).__init__()
#         # use nn.Sequential to keep the MLP definition compact
#         self.mlp = nn.Sequential(
#             nn.Linear(img_feature_dim * 3 + text_feature_dim * 3, 1024),
#             nn.ReLU(),
#             nn.Linear(1024, 256),
#             nn.ReLU(),
#             nn.Linear(256, output_dim)
#         )

#     def forward(self, img_features, text_feature):
#         # concatenate the image features with the text feature
#         combined_features = torch.cat(img_features + text_feature, dim=1)
#         # run the MLP to obtain the output scores
#         scores = self.mlp(combined_features)
#         return scores


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionblk(nn.Module):
    def __init__(self, clip_module):
        super().__init__()

        self.clip_module = clip_module

        self.selected_tokens = int(676 * 0.8)

        # self.norm = nn.LayerNorm(768)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None, lang_tokens=None, index=0):

        if lang_tokens is None:
            x = x + self.clip_module.attention(self.clip_module.ln_1(x))
        else:
            # if 4 <= index <= 7:
            #     self.selected_tokens = int(676 * 0.8)
            # elif 8 <= index <= 11:
            #     self.selected_tokens = int(676 * 0.5)
            # print(index)
            # print(self.selected_tokens)

            N, B, C = x.shape  # N x B x C
            cls_x = x[:1, :, :]  # 1 x B x C
            x = x[1:, :, :]  # M x B x C

            ### img_cls text_cls
            # x = torch.mul(x, cls_x)
            # x = self.norm(x.reshape((N - 1) * B, C))
            # x = x.reshape(N - 1, B, C)

            ### text eos token
            # score = torch.bmm(x.transpose(0, 1), lang_tokens).squeeze(-1)

            ### text features mean
            # rank patch tokens by similarity to the averaged text tokens
            score = torch.bmm(x.transpose(0, 1), lang_tokens.permute(1, 2, 0)).mean(dim=-1)  # B x N
            score = score.transpose(0, 1)  # N x B

            sorted_scores, sorted_indices = torch.sort(score, descending=True, dim=0)

            # high_mask = sorted_scores > sorted_scores[self.selected_tokens:self.selected_tokens + 1, :]
            high_mask = torch.ones_like(sorted_scores)
            for i in range(B):
                high_mask[sorted_indices[self.selected_tokens:, i], i] = 0
            high_mask = high_mask > 0.5

            delta_x = x[high_mask].reshape(-1, B, C)  # M x B x C
            low_x = x[~high_mask].reshape(-1, B, C)  # N-M x B x C
            low_score = score[~high_mask].reshape(-1, B, 1)  # N-M x B x 1

            # merge the low-scoring tokens into a single score-weighted token
            low_x = low_x * torch.softmax(low_score, dim=0)  # N-M x B x C
            low_x = low_x.sum(dim=0, keepdim=True)  # 1 x B x C

            delta_x = torch.cat([cls_x, delta_x, low_x], dim=0)  # M+2 x B x C
            delta_x = self.clip_module.attention(self.clip_module.ln_1(delta_x))

            # for i in range(B):
            #     x[high_mask[:, i], i, :] += delta_x[1:-1, i, :]
            #     x[~high_mask[:, i], i, :] += delta_x[-1:, i, :]
            #     cls_x[:, i] += delta_x[:1, i, :]
            temple = torch.zeros_like(x).type(delta_x.type())
            temple[high_mask] = delta_x[1:-1, :, :].reshape(-1, C)
            temple[~high_mask] = delta_x[-1:, :, :].reshape(-1, 1, C).repeat(1, 676 - self.selected_tokens, 1).reshape(-1, C)
            x = x + temple
            cls_x = cls_x + delta_x[:1, :, :]

            x = torch.cat([cls_x, x], dim=0)

        x = x + self.clip_module.mlp(self.clip_module.ln_2(x))
        return x


class Model(nn.Module):
    def __init__(self, clip_model='RN50', tunelang=False, fusion_dim=2048, num_query=16, do=512, leaky=True, length=17):
        super(Model, self).__init__()

        self.tunelang = tunelang
        self.length = length

        ## Init Encoders
        clip_models = clip.load(clip_model, jit=False, device=torch.device("cpu"))[0].cuda()

        self.visumodel = clip_models.visual
        self.visu_dim = 768

        self.cut_list = []
        self.visu_resblocks = nn.ModuleList([ResidualAttentionblk(self.visumodel.transformer.resblocks[i]) for i in range(12)])
        self.visu_proj = nn.ModuleList([nn.Linear(do, self.visu_dim) for _ in range(len(self.cut_list))])

        self.positional_embedding = nn.Parameter(torch.FloatTensor(1, 26 ** 2 + 1, 768))
        v = self.resize_pos_embed(self.visumodel.positional_embedding.data.unsqueeze(0), self.positional_embedding, 26, 26)
        self.positional_embedding.data.copy_(v)

        self.textmodel = clip_models.transformer
        self.textmodel_token_embedding = clip_models.token_embedding
        self.textmodel_pos_embed = nn.Parameter(clip_models.positional_embedding[:self.length, :].unsqueeze(0))
        self.textmodel_ln_final = clip_models.ln_final
        self.textdim = self.textmodel_pos_embed.shape[-1]
        for module in self.textmodel.resblocks:
            module.attn_mask = self.build_attention_mask()

        # vis select
        self.vis_select = nn.Linear(self.visu_dim, do, bias=False)

        ## Fusion

        # fusion with x12
        self.fusion = Simple_fusion(visual_dim=self.visu_dim, text_dim=self.textdim, proj_dim=fusion_dim)

        # fusion with x6
        self.up_proj_cat_proj_1 = proj_cat_proj(input_1=fusion_dim, input_2=self.visu_dim, do=fusion_dim)
        self.pool_proj_cat_proj_2 = proj_cat_proj(input_1=fusion_dim, input_2=self.visu_dim, do=do)

        # fusion with x9
        self.proj_cat = proj_cat(input_1=fusion_dim, input_2=do, do=do)
        self.up_proj_cat_2 = proj_cat_proj(input_1=fusion_dim, input_2=do * 2, do=do)
        self.proj_0 = ConvBatchNormReLU(do, do, 1, 1, 0, 1, leaky=leaky)

        self.fpn = SFA(in_channels=self.visu_dim, out_channels=do)

        ## Align dim
        f_dim = 512
        self.fc_2 = nn.Linear(f_dim, f_dim, bias=False)
        self.norm1 = nn.LayerNorm(f_dim)
        self.norm2 = nn.LayerNorm(f_dim)

        # visual branch
        self.pos_embedding = PositionEmbeddingSine(f_dim)
        encoder_layer = TransformerEncoderLayer(f_dim, nhead=8, dim_feedforward=f_dim,
                                                dropout=0.1, activation='relu', normalize_before=False)
        self.encoder = TransformerEncoder(encoder_layer, num_layers=2, norm=nn.LayerNorm(f_dim))

        ## Decoder
        self.mask_decoder = mask_decoder(f_dim, seg_out_stride=2)

        # text branch

        ## coef
        self.lang_tf_enc = lang_tf_enc(do, do, do, head_num=8)
        self.proj1 = ConvBatchNormReLU(do, do, 3, 1, 1, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(do, do, 3, 1, 1, 1, leaky=leaky)
        self.proj3 = nn.Conv2d(do, 32, 3, 1, 1, 1)
        self.projout = nn.Linear(26 * 26 * 32, 32, bias=False)

        self.feature_selector_l = nn.Linear(do, 1, bias=True)
        self.feature_selector_m = nn.Linear(do, 1, bias=True)

    def resize_pos_embed(self, posemb, posemb_new, hight, width):
        ntok_new = posemb_new.shape[1]

        posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:]
        ntok_new -= 1

        gs_old = int(math.sqrt(len(posemb_grid)))
        print('Resized position embedding from size: {} to size: {} with height: {} width: {}'.format(posemb.shape, posemb_new.shape, hight, width))
        posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
        posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear')
        posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1)
        posemb = torch.cat([posemb_token, posemb_grid], dim=1)
        return posemb

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.length, self.length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def forward(self, image, word_id, word_mask):
        ## Visual Module

        batch_size = image.size(0)

        # Extract features from vision
        x = self.visumodel.conv1(image)
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.visumodel.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.visumodel.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND

        raw_fword = self.textmodel_token_embedding(word_id).squeeze(1)
        raw_fword = raw_fword + self.textmodel_pos_embed
        raw_fword = raw_fword.permute(1, 0, 2)  # NLD -> LND

        visu_list_l = []
        visu_list_m = []

        scores_l = []
        scores_m = []

        for i, (blk_visu, blk_lang) in enumerate(zip(self.visu_resblocks, self.textmodel.resblocks)):
            x = blk_visu(x)  # [677, bs, 768]
            raw_fword = blk_lang(raw_fword)

            img_cls = self.vis_select(x[0, :, :])  # [B, C]
            tex_cls = raw_fword[word_id.argmax(dim=-1).reshape(-1), torch.arange(raw_fword.shape[1]), :]  # [B, C]
            score = img_cls * tex_cls  # [B, C]
            score = score.unsqueeze(1)  # [B, 1, C]

            if 3 <= i <= 5:
                visu_list_l.append(x)
                scores_l.append(score)

            if 6 <= i <= 8:
                visu_list_m.append(x)
                scores_m.append(score)

        scores_l = torch.cat(scores_l, dim=1)  # [B, 3, C]
        scores_m = torch.cat(scores_m, dim=1)  # [B, 3, C]

        scores_l = self.feature_selector_l(scores_l).squeeze(-1)  # [B, 3]
        scores_l = F.softmax(scores_l, dim=-1)
        scores_m = self.feature_selector_m(scores_m).squeeze(-1)  # [B, 3]
        scores_m = F.softmax(scores_m, dim=-1)

        visu_list_l = torch.cat(visu_list_l, dim=0).reshape(len(visu_list_l), -1, batch_size, self.visu_dim).permute(0, 2, 1, 3)
        visu_list_m = torch.cat(visu_list_m, dim=0).reshape(len(visu_list_m), -1, batch_size, self.visu_dim).permute(0, 2, 1, 3)

        # pick, per sample, the layer with the highest selector score
        x6 = visu_list_l[scores_l.argmax(dim=-1).reshape(-1), torch.arange(visu_list_l.shape[1]), :, :].permute(1, 0, 2)
        x9 = visu_list_m[scores_m.argmax(dim=-1).reshape(-1), torch.arange(visu_list_m.shape[1]), :, :].permute(1, 0, 2)

        x6 = x6.permute(1, 0, 2)[:, 1:, :].reshape(-1, 26, 26, self.visu_dim).permute(0, 3, 1, 2)
        x9 = x9.permute(1, 0, 2)[:, 1:, :].reshape(-1, 26, 26, self.visu_dim).permute(0, 3, 1, 2)
        x12 = x.permute(1, 0, 2)[:, 1:, :]
        x12 = x12.reshape(-1, 26, 26, self.visu_dim).permute(0, 3, 1, 2)  # [bs, 768, 26, 26]

        raw_fword = raw_fword.permute(1, 0, 2)
        raw_fword = self.textmodel_ln_final(raw_fword)

        if not self.tunelang:
            raw_fword = raw_fword.detach()

        eos_token = raw_fword[torch.arange(raw_fword.shape[0]), word_id.argmax(dim=-1).reshape(-1), :]

        F_g = self.fusion(x12, eos_token)
        F_tf = self.fpn([F_g, x9, x6])

        # Main body
        b, c, h, w = F_tf.shape

        flatten_length = h * w
        visu_feat = F_tf.reshape(b, c, flatten_length)
        visu_feat = F.relu(visu_feat)
        lang_feat = F.relu(self.fc_2(raw_fword))

        visu_feat = visu_feat.permute(0, 2, 1)
        pos_embed = self.pos_embedding(visu_feat)
        visu_feat = visu_feat.transpose(0, 1)
        pos_embed = pos_embed.transpose(0, 1)
        visu_feat = self.encoder(visu_feat, pos=pos_embed)
        # [HW B C]

        visu_feat_ = visu_feat.permute(1, 0, 2)

        # mask decoder
        visu_feat = visu_feat.reshape(h, w, b, c)
        visu_feat = visu_feat.permute(2, 3, 0, 1)
        proto_masks = self.mask_decoder(visu_feat, 2)

        # [B C H W]
        proto_masks = F.relu(proto_masks)

        # coef
        coef = self.lang_tf_enc(visu_feat_, lang_feat)
        coef = coef.view(b, h, w, c)
        coef = coef.permute(0, 3, 1, 2)

        coef = self.proj1(coef)
        coef = self.proj2(coef)
        coef = self.proj3(coef)
        coef = coef.permute(0, 2, 3, 1)
        coef = coef.contiguous().view(b, h * w * 32)
        # [b, 1, 32]
        coef = self.projout(coef).unsqueeze(-1)
        coef = torch.tanh(coef)

        # mask assemble
        proto_masks = proto_masks.permute(0, 2, 3, 1)
        proto_masks = proto_masks.view(b, -1, 32)
        # [B HW N] [32 208*208 32]

        mask_out = torch.bmm(proto_masks, coef)
        mask_out = mask_out.view(b, 208, 208, 1)
        mask_out = mask_out.permute(0, 3, 1, 2)
        return mask_out
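The last step of `Model.forward` assembles the output mask as a linear combination of 32 prototype masks weighted by a language-conditioned coefficient vector. A minimal standalone sketch of that assembly, with random tensors standing in for the `mask_decoder` and `projout` outputs:

import torch

b, hw, n = 2, 208 * 208, 32                # batch, flattened mask resolution, prototypes
proto_masks = torch.rand(b, hw, n)         # stand-in for the decoded prototypes, [B, HW, N]
coef = torch.tanh(torch.randn(b, n, 1))    # stand-in for the coefficient branch, [B, N, 1]

mask_out = torch.bmm(proto_masks, coef)    # [B, HW, 1]: weighted sum over prototypes
mask_out = mask_out.view(b, 208, 208, 1).permute(0, 3, 1, 2)  # [B, 1, 208, 208]
print(mask_out.shape)  # torch.Size([2, 1, 208, 208])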
ASDA/model/model_sbert_gref.py ADDED
@@ -0,0 +1,488 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .modules import ConvBatchNormReLU, SFA
from .modules import *
from .position_encoding import *

import clip
import math
import sys

sys.path.append('../')
from utils.utils import *


class Simple_fusion(nn.Module):
    def __init__(self, visual_dim=1024, text_dim=768, proj_dim=1024, jemb_drop_out=0.1, leaky=True):
        super(Simple_fusion, self).__init__()
        self.proj_dim = proj_dim
        self.mapping_visu = ConvBatchNormReLU(visual_dim, proj_dim, 1, 1, 0, 1, leaky=leaky)
        self.lang_attn = nn.Sequential(
            nn.Linear(text_dim, text_dim),
            nn.Tanh(),
            nn.Dropout(jemb_drop_out),
            nn.Softmax(dim=1))

        self.lang_proj = nn.Sequential(
            nn.Linear(text_dim, proj_dim),
            nn.BatchNorm1d(proj_dim),
            nn.LeakyReLU(0.1))

        self.fusion = nn.Sequential(
            nn.BatchNorm2d(proj_dim),
            nn.LeakyReLU(0.1))

    def forward(self, visual_feat, lang_feat):
        # visual proj
        visual_feat_proj = self.mapping_visu(visual_feat)  # [bt, 1024, 13, 13]

        """
        # lang attn
        lang_feat_attn = self.lang_attn(lang_feat)  # [bt, 15, 768]
        lang_feat_new = lang_feat * lang_feat_attn
        lang_feat_new = lang_feat_new.sum(dim=1)  # [bt, 768]
        """

        lang_feat = lang_feat.squeeze(1)
        # lang proj
        # lang_feat_new = self.lang_proj(lang_feat_new)  # [bt, 1024]
        lang_feat_new = self.lang_proj(lang_feat)  # [bt, 1024]

        # fusion: tile the sentence embedding over the grid and gate the visual features
        h, w = visual_feat.shape[-2], visual_feat.shape[-1]
        lang_feat_new_tile = lang_feat_new.view(-1, self.proj_dim, 1, 1).repeat(1, 1, h, w)  # [bt, 1024, 13, 13]
        fusion_feat = lang_feat_new_tile * visual_feat_proj
        fusion_feat = self.fusion(fusion_feat)
        return fusion_feat


class up_proj_cat_proj(nn.Module):
    def __init__(self, input_1, input_2, do=512, leaky=True):
        super(up_proj_cat_proj, self).__init__()
        self.proj1 = ConvBatchNormReLU(input_2, input_2, 1, 1, 0, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(input_1 + input_2, do, 1, 1, 0, 1, leaky=leaky)

    def forward(self, x, y):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        y = self.proj1(y)
        out = torch.cat([x, y], dim=1)
        out = self.proj2(out)
        return out


class pool_proj_cat_proj(nn.Module):
    def __init__(self, input_1, input_2, do=512, leaky=True):
        super(pool_proj_cat_proj, self).__init__()
        self.downsample = nn.AvgPool2d(2, 2)
        self.proj1 = ConvBatchNormReLU(input_2, do // 2, 1, 1, 0, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(do // 2, do, 3, 1, 1, 1, leaky=leaky)
        self.proj3 = ConvBatchNormReLU(input_1 + do, do, 1, 1, 0, 1, leaky=leaky)

    def forward(self, x, y):
        y = self.downsample(y)
        y = self.proj1(y)
        y = self.proj2(y)
        output = self.proj3(torch.cat([x, y], dim=1))
        return output


class proj_cat_proj(nn.Module):
    def __init__(self, input_1, input_2, do=512, leaky=True):
        super(proj_cat_proj, self).__init__()
        self.proj1 = ConvBatchNormReLU(input_2, input_2, 1, 1, 0, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(input_1 + input_2, do, 1, 1, 0, 1, leaky=leaky)

    def forward(self, x, y):
        y = self.proj1(y)
        out = torch.cat([x, y], dim=1)
        out = self.proj2(out)
        return out


class proj_cat(nn.Module):
    def __init__(self, input_1, input_2, do=512, leaky=True):
        super(proj_cat, self).__init__()
        self.proj1 = ConvBatchNormReLU(input_1, do // 2, 1, 1, 0, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(do // 2, do, 3, 1, 1, 1, leaky=leaky)

    def forward(self, x, y):
        x = self.proj1(x)
        x = self.proj2(x)
        output = torch.cat([x, y], dim=1)
        return output


class mask_decoder(nn.Module):
    def __init__(self, input_1, seg_out_stride=2, leaky=True):
        super(mask_decoder, self).__init__()
        self.proj1 = ConvBatchNormReLU(input_1, input_1 // 2, 3, 1, 1, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(input_1 // 2, input_1 // 2, 3, 1, 1, 1, leaky=leaky)

        self.proj3 = ConvBatchNormReLU(input_1 // 2, input_1 // 2, 3, 1, 1, 1, leaky=leaky)
        self.proj4 = ConvBatchNormReLU(input_1 // 2, input_1 // 2, 3, 1, 1, 1, leaky=leaky)
        self.proj5 = ConvBatchNormReLU(input_1 // 2, input_1 // 2, 3, 1, 1, 1, leaky=leaky)
        # self.proj = nn.Conv2d(input_1, 1, 3, 1, 1, 1)
        self.proj = nn.Conv2d(input_1 // 2, 32, 3, 1, 1, 1)

    def forward(self, x, seg_out_stride):
        x = self.proj1(x)
        x = self.proj2(x)

        # progressively upsample by 2 until the requested output stride is reached
        if seg_out_stride <= 8:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = self.proj3(x)

        if seg_out_stride <= 4:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = self.proj4(x)

        if seg_out_stride <= 2:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = self.proj5(x)

        x = self.proj(x)

        return x


# class FeatureSelector(nn.Module):
#     def __init__(self, img_feature_dim, text_feature_dim, output_dim):
#         super(FeatureSelector, self).__init__()
#         # use nn.Sequential to keep the MLP definition compact
#         self.mlp = nn.Sequential(
#             nn.Linear(img_feature_dim * 3 + text_feature_dim * 3, 1024),
#             nn.ReLU(),
#             nn.Linear(1024, 256),
#             nn.ReLU(),
#             nn.Linear(256, output_dim)
#         )

#     def forward(self, img_features, text_feature):
#         # concatenate the image features with the text feature
#         combined_features = torch.cat(img_features + text_feature, dim=1)
#         # run the MLP to obtain the output scores
#         scores = self.mlp(combined_features)
#         return scores


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionblk(nn.Module):
    def __init__(self, clip_module):
        super().__init__()

        self.clip_module = clip_module

        self.selected_tokens = int(676 * 0.8)

        # self.norm = nn.LayerNorm(768)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None, lang_tokens=None, index=0):

        if lang_tokens is None:
            x = x + self.clip_module.attention(self.clip_module.ln_1(x))
        else:
            # if 4 <= index <= 7:
            #     self.selected_tokens = int(676 * 0.8)
            # elif 8 <= index <= 11:
            #     self.selected_tokens = int(676 * 0.5)
            # print(index)
            # print(self.selected_tokens)

            N, B, C = x.shape  # N x B x C
            cls_x = x[:1, :, :]  # 1 x B x C
            x = x[1:, :, :]  # M x B x C

            ### img_cls text_cls
            # x = torch.mul(x, cls_x)
            # x = self.norm(x.reshape((N - 1) * B, C))
            # x = x.reshape(N - 1, B, C)

            ### text eos token
            # score = torch.bmm(x.transpose(0, 1), lang_tokens).squeeze(-1)

            ### text features mean
            # rank patch tokens by similarity to the averaged text tokens
            score = torch.bmm(x.transpose(0, 1), lang_tokens.permute(1, 2, 0)).mean(dim=-1)  # B x N
            score = score.transpose(0, 1)  # N x B

            sorted_scores, sorted_indices = torch.sort(score, descending=True, dim=0)

            # high_mask = sorted_scores > sorted_scores[self.selected_tokens:self.selected_tokens + 1, :]
            high_mask = torch.ones_like(sorted_scores)
            for i in range(B):
                high_mask[sorted_indices[self.selected_tokens:, i], i] = 0
            high_mask = high_mask > 0.5

            delta_x = x[high_mask].reshape(-1, B, C)  # M x B x C
            low_x = x[~high_mask].reshape(-1, B, C)  # N-M x B x C
            low_score = score[~high_mask].reshape(-1, B, 1)  # N-M x B x 1

            # merge the low-scoring tokens into a single score-weighted token
            low_x = low_x * torch.softmax(low_score, dim=0)  # N-M x B x C
            low_x = low_x.sum(dim=0, keepdim=True)  # 1 x B x C

            delta_x = torch.cat([cls_x, delta_x, low_x], dim=0)  # M+2 x B x C
            delta_x = self.clip_module.attention(self.clip_module.ln_1(delta_x))

            # for i in range(B):
            #     x[high_mask[:, i], i, :] += delta_x[1:-1, i, :]
            #     x[~high_mask[:, i], i, :] += delta_x[-1:, i, :]
            #     cls_x[:, i] += delta_x[:1, i, :]
            temple = torch.zeros_like(x).type(delta_x.type())
            temple[high_mask] = delta_x[1:-1, :, :].reshape(-1, C)
            temple[~high_mask] = delta_x[-1:, :, :].reshape(-1, 1, C).repeat(1, 676 - self.selected_tokens, 1).reshape(-1, C)
            x = x + temple
            cls_x = cls_x + delta_x[:1, :, :]

            x = torch.cat([cls_x, x], dim=0)

        x = x + self.clip_module.mlp(self.clip_module.ln_2(x))
        return x


class Model_CL(nn.Module):
    def __init__(self, clip_model='RN50', tunelang=False, fusion_dim=2048, num_query=16, do=512, leaky=True, length=17, fuse_mode='coarse', use_projections=False):
        super(Model_CL, self).__init__()

        self.tunelang = tunelang
        self.length = length

        ## Init Encoders
        clip_models = clip.load(clip_model, jit=False, device=torch.device("cpu"))[0].cuda()

        self.visumodel = clip_models.visual
        self.visu_dim = 768
        self.fuse_mode = fuse_mode

        self.cut_list = []
        self.visu_resblocks = nn.ModuleList([ResidualAttentionblk(self.visumodel.transformer.resblocks[i]) for i in range(12)])
        self.visu_proj = nn.ModuleList([nn.Linear(do, self.visu_dim) for _ in range(len(self.cut_list))])

        self.positional_embedding = nn.Parameter(torch.FloatTensor(1, 26 ** 2 + 1, 768))
        v = self.resize_pos_embed(self.visumodel.positional_embedding.data.unsqueeze(0), self.positional_embedding, 26, 26)
        self.positional_embedding.data.copy_(v)

        self.textmodel = clip_models.transformer
        self.textmodel_token_embedding = clip_models.token_embedding
        self.textmodel_pos_embed = nn.Parameter(clip_models.positional_embedding[:self.length, :].unsqueeze(0))
        self.textmodel_ln_final = clip_models.ln_final
        self.textdim = self.textmodel_pos_embed.shape[-1]
        for module in self.textmodel.resblocks:
            module.attn_mask = self.build_attention_mask()

        # vis select
        self.vis_select = nn.Linear(self.visu_dim, do, bias=False)

        ## Fusion
        # fusion with x12
        self.fusion = Simple_fusion(visual_dim=self.visu_dim, text_dim=self.textdim, proj_dim=fusion_dim)

        # fusion with x6
        self.up_proj_cat_proj_1 = proj_cat_proj(input_1=fusion_dim, input_2=self.visu_dim, do=fusion_dim)
        self.pool_proj_cat_proj_2 = proj_cat_proj(input_1=fusion_dim, input_2=self.visu_dim, do=do)

        # fusion with x9
        self.proj_cat = proj_cat(input_1=fusion_dim, input_2=do, do=do)
        self.up_proj_cat_2 = proj_cat_proj(input_1=fusion_dim, input_2=do * 2, do=do)
        self.proj_0 = ConvBatchNormReLU(do, do, 1, 1, 0, 1, leaky=leaky)

        self.fpn = SFA(in_channels=self.visu_dim, out_channels=do)

        ## use projections?
        self.use_projections = use_projections
        if self.use_projections:
            self.projection_1 = nn.Linear(512, 512, bias=True)
        else:
            self.projection_1 = None

        ## Align dim
        f_dim = 512
        self.fc_2 = nn.Linear(f_dim, f_dim, bias=False)
        self.norm1 = nn.LayerNorm(f_dim)
        self.norm2 = nn.LayerNorm(f_dim)

        # visual branch
        self.pos_embedding = PositionEmbeddingSine(f_dim)
        encoder_layer = TransformerEncoderLayer(f_dim, nhead=8, dim_feedforward=f_dim,
                                                dropout=0.1, activation='relu', normalize_before=False)
        self.encoder = TransformerEncoder(encoder_layer, num_layers=2, norm=nn.LayerNorm(f_dim))

        ## Decoder
        self.mask_decoder = mask_decoder(f_dim, seg_out_stride=2)

        # text branch

        ## coef
        self.lang_tf_enc = lang_tf_enc(do, do, do, head_num=8)
        self.proj1 = ConvBatchNormReLU(do, do, 3, 1, 1, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(do, do, 3, 1, 1, 1, leaky=leaky)
        self.proj3 = nn.Conv2d(do, 32, 3, 1, 1, 1)
        self.projout = nn.Linear(26 * 26 * 32, 32, bias=False)

        self.feature_selector_l = nn.Linear(do, 1, bias=True)
        self.feature_selector_m = nn.Linear(do, 1, bias=True)

    def resize_pos_embed(self, posemb, posemb_new, hight, width):
        ntok_new = posemb_new.shape[1]

        posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:]
        ntok_new -= 1

        gs_old = int(math.sqrt(len(posemb_grid)))
        print('Resized position embedding from size: {} to size: {} with height: {} width: {}'.format(posemb.shape, posemb_new.shape, hight, width))
        posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
        posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear')
        posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1)
        posemb = torch.cat([posemb_token, posemb_grid], dim=1)
        return posemb

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.length, self.length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def forward(self, image, word_id, word_mask):
        ## Visual Module

        batch_size = image.size(0)

        # Extract features from vision
        x = self.visumodel.conv1(image)
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.visumodel.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.visumodel.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND

        raw_fword = self.textmodel_token_embedding(word_id).squeeze(1)
        raw_fword = raw_fword + self.textmodel_pos_embed
        raw_fword = raw_fword.permute(1, 0, 2)  # NLD -> LND

        visu_list_l = []
        visu_list_m = []

        scores_l = []
        scores_m = []

        for i, (blk_visu, blk_lang) in enumerate(zip(self.visu_resblocks, self.textmodel.resblocks)):
            x = blk_visu(x)  # [677, bs, 768]
            raw_fword = blk_lang(raw_fword)

            img_cls = self.vis_select(x[0, :, :])  # [B, C]
            tex_cls = raw_fword[word_id.argmax(dim=-1).reshape(-1), torch.arange(raw_fword.shape[1]), :]  # [B, C]
            score = img_cls * tex_cls  # [B, C]
            score = score.unsqueeze(1)  # [B, 1, C]

            if 3 <= i <= 5:
                visu_list_l.append(x)
                scores_l.append(score)

            if 6 <= i <= 8:
                visu_list_m.append(x)
                scores_m.append(score)

        scores_l = torch.cat(scores_l, dim=1)  # [B, 3, C]
        scores_m = torch.cat(scores_m, dim=1)  # [B, 3, C]

        scores_l = self.feature_selector_l(scores_l).squeeze(-1)  # [B, 3]
        scores_l = F.softmax(scores_l, dim=-1)
        scores_m = self.feature_selector_m(scores_m).squeeze(-1)  # [B, 3]
        scores_m = F.softmax(scores_m, dim=-1)

        visu_list_l = torch.cat(visu_list_l, dim=0).reshape(len(visu_list_l), -1, batch_size, self.visu_dim).permute(0, 2, 1, 3)
        visu_list_m = torch.cat(visu_list_m, dim=0).reshape(len(visu_list_m), -1, batch_size, self.visu_dim).permute(0, 2, 1, 3)

        # pick, per sample, the layer with the highest selector score
        x6 = visu_list_l[scores_l.argmax(dim=-1).reshape(-1), torch.arange(visu_list_l.shape[1]), :, :].permute(1, 0, 2)
        x9 = visu_list_m[scores_m.argmax(dim=-1).reshape(-1), torch.arange(visu_list_m.shape[1]), :, :].permute(1, 0, 2)

        x6 = x6.permute(1, 0, 2)[:, 1:, :].reshape(-1, 26, 26, self.visu_dim).permute(0, 3, 1, 2)
        x9 = x9.permute(1, 0, 2)[:, 1:, :].reshape(-1, 26, 26, self.visu_dim).permute(0, 3, 1, 2)
        x12 = x.permute(1, 0, 2)[:, 1:, :]
        x12 = x12.reshape(-1, 26, 26, self.visu_dim).permute(0, 3, 1, 2)  # [bs, 768, 26, 26]

        raw_fword = raw_fword.permute(1, 0, 2)
        raw_fword = self.textmodel_ln_final(raw_fword)

        if not self.tunelang:
            raw_fword = raw_fword.detach()

        eos_token = raw_fword[torch.arange(raw_fword.shape[0]), word_id.argmax(dim=-1).reshape(-1), :]

        F_g = self.fusion(x12, eos_token)
        F_tf = self.fpn([F_g, x9, x6])

        # Main body
        b, c, h, w = F_tf.shape

        flatten_length = h * w
        visu_feat = F_tf.reshape(b, c, flatten_length)
        visu_feat = F.relu(visu_feat)
        lang_feat = F.relu(self.fc_2(raw_fword))

        visu_feat = visu_feat.permute(0, 2, 1)
        pos_embed = self.pos_embedding(visu_feat)
        visu_feat = visu_feat.transpose(0, 1)
        pos_embed = pos_embed.transpose(0, 1)
        visu_feat = self.encoder(visu_feat, pos=pos_embed)
        # [HW B C]

        visu_feat_ = visu_feat.permute(1, 0, 2)

        # mask decoder
        visu_feat = visu_feat.reshape(h, w, b, c)
        visu_feat = visu_feat.permute(2, 3, 0, 1)
        F_coarse_refined = visu_feat
        proto_masks = self.mask_decoder(visu_feat, 2)

        # [B C H W]
        proto_masks = F.relu(proto_masks)

        # coef
        coef = self.lang_tf_enc(visu_feat_, lang_feat)
        coef = coef.view(b, h, w, c)
        coef = coef.permute(0, 3, 1, 2)
        F_fine = coef

        coef = self.proj1(coef)
        coef = self.proj2(coef)
        coef = self.proj3(coef)
        coef = coef.permute(0, 2, 3, 1)
        coef = coef.contiguous().view(b, h * w * 32)
        # [b, 1, 32]
        coef = self.projout(coef).unsqueeze(-1)
        coef = torch.tanh(coef)

        # mask assemble
        proto_masks = proto_masks.permute(0, 2, 3, 1)
        proto_masks = proto_masks.view(b, -1, 32)
        # [B HW N] [32 208*208 32]

        mask_out = torch.bmm(proto_masks, coef)
        mask_out = mask_out.view(b, 208, 208, 1)
        mask_out = mask_out.permute(0, 3, 1, 2)

        if self.fuse_mode == 'coarse':
            metric_tensor = F_tf
        elif self.fuse_mode == 'refined_coarse':
            metric_tensor = F_coarse_refined
        elif self.fuse_mode == 'fine':
            metric_tensor = F_fine
        else:
            # guard against a silent NameError for unknown modes
            raise ValueError(f"unsupported fuse_mode: {self.fuse_mode}")

        if self.use_projections:
            metric_tensor = F.adaptive_avg_pool2d(metric_tensor, (1, 1)).view(metric_tensor.size(0), -1)
            metric_tensor = self.projection_1(metric_tensor).unsqueeze(-1).unsqueeze(-1)

        return mask_out, metric_tensor
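`Model_CL` differs from `Model` mainly in that it also returns a `metric_tensor` for metric or contrastive learning, optionally pooled and linearly projected. A minimal sketch of that projection head in isolation, with a random feature map standing in for whichever `fuse_mode` tensor is selected:

import torch
import torch.nn as nn
import torch.nn.functional as F

projection = nn.Linear(512, 512, bias=True)   # mirrors self.projection_1
metric_tensor = torch.randn(2, 512, 26, 26)   # stand-in for F_tf / F_coarse_refined / F_fine

pooled = F.adaptive_avg_pool2d(metric_tensor, (1, 1)).view(metric_tensor.size(0), -1)  # [B, 512]
embedded = projection(pooled).unsqueeze(-1).unsqueeze(-1)                              # [B, 512, 1, 1]
print(embedded.shape)  # torch.Size([2, 512, 1, 1])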
ASDA/model/modules.py ADDED
@@ -0,0 +1,391 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .transformer import lang_tf_enc, TransformerEncoderLayer, TransformerEncoder
from .position_encoding import PositionEmbeddingSine


class SFA(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factors=[1, 2, 4], fuse_type="sum"):
        super(SFA, self).__init__()
        self.stages = []
        for idx, scale in enumerate(scale_factors):
            out_dim = out_channels
            if scale == 4.0:
                layers = [
                    nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2),
                    nn.BatchNorm2d(
                        num_features=in_channels // 2, eps=1e-5, momentum=0.999, affine=True),
                    nn.GELU(),
                    nn.ConvTranspose2d(in_channels // 2, in_channels // 4, kernel_size=2, stride=2),
                ]
                out_dim = in_channels // 4
            elif scale == 2.0:
                layers = [nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)]
                out_dim = in_channels // 2
            elif scale == 1.0:
                layers = []
                out_dim = in_channels
            elif scale == 0.5:
                layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

            layers.extend(
                [
                    ConvBatchNormReLU(out_dim, out_channels, 1, 1, 0, 1, leaky=True),
                    ConvBatchNormReLU(out_channels, out_channels, 3, 1, 1, 1, leaky=True),
                ]
            )
            layers = nn.Sequential(*layers)
            self.stages.append(layers)

        self.stages = nn.ModuleList(self.stages)

        # assume all input feature maps share the same channel count
        self.lateral_convs = nn.ModuleList([
            ConvBatchNormReLU(out_channels, out_channels, 1, 1, 0, 1, leaky=True) for _ in range(3)
        ])

        self.output_convs = nn.ModuleList([
            ConvBatchNormReLU(out_channels, out_channels, 3, 1, 1, 1, leaky=True) for _ in range(3)
        ])

        self._fuse_type = fuse_type  # or "avg"

        self.downsample = nn.MaxPool2d(kernel_size=4, stride=4, padding=0)

    def forward(self, x):
        '''
        Args:
            x: list[Tensor], T feature maps with identical spatial size and channel count, [x12, x9, x6]
        '''
        # emulate a bottom-up pathway to obtain multi-scale feature maps
        multi_scale_features = []
        for idx, stage in enumerate(self.stages):
            multi_scale_features.append(stage(x[idx]))

        # top-down
        results = []
        prev_features = self.lateral_convs[0](multi_scale_features[0])

        for idx, (lateral_conv, output_conv) in enumerate(
            zip(self.lateral_convs, self.output_convs)
        ):
            # Slicing of ModuleList is not supported https://github.com/pytorch/pytorch/issues/47336
            # Therefore we loop over all modules but skip the first one
            if idx > 0:
                features = multi_scale_features[idx]
                top_down_features = F.interpolate(prev_features, scale_factor=2.0, mode="nearest")
                lateral_features = lateral_conv(features)  # 1x1 convolution
                prev_features = lateral_features + top_down_features
                if self._fuse_type == "avg":
                    prev_features /= 2
            results.insert(0, output_conv(prev_features))

        fused_features = self.downsample(results[0])  # 1/4 resolution; pool back to 1/16 resolution

        return fused_features


class ConvBatchNormReLU(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        leaky=False,
        relu=True,
        instance=False,
    ):
        super(ConvBatchNormReLU, self).__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False)
        # nn.init.kaiming_normal_(self.conv.weight, mode="fan_out", nonlinearity="leaky_relu" if leaky else "relu")

        if instance:
            self.bn = nn.InstanceNorm2d(num_features=out_channels)
        else:
            self.bn = nn.BatchNorm2d(
                num_features=out_channels, eps=1e-5, momentum=0.999, affine=True
            )

        if leaky:
            self.relu = nn.LeakyReLU(0.1)
        elif relu:
            self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


# class ConvBatchNormReLU(nn.Sequential):
#     def __init__(
#         self,
#         in_channels,
#         out_channels,
#         kernel_size,
#         stride,
#         padding,
#         dilation,
#         leaky=False,
#         relu=True,
#         instance=False,
#     ):
#         super(ConvBatchNormReLU, self).__init__()

#         conv = nn.Conv2d(
#             in_channels=in_channels,
#             out_channels=out_channels,
#             kernel_size=kernel_size,
#             stride=stride,
#             padding=padding,
#             dilation=dilation,
#             bias=False,
#         )
#         nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="leaky_relu" if leaky else "relu")

#         self.add_module("conv", conv)

#         if instance:
#             self.add_module(
#                 "bn",
#                 nn.InstanceNorm2d(num_features=out_channels),
#             )
#         else:
#             self.add_module(
#                 "bn",
#                 nn.BatchNorm2d(
#                     num_features=out_channels, eps=1e-5, momentum=0.999, affine=True
#                 ),
#             )

#         if leaky:
#             self.add_module("relu", nn.LeakyReLU(0.1))
#         elif relu:
#             self.add_module("relu", nn.ReLU())

#     def forward(self, x):
#         return super(ConvBatchNormReLU, self).forward(x)


def concat_coord(x):
    ins_feat = x  # [bt, c, h, w] [512, 26, 26]
    batch_size, c, h, w = x.size()

    float_h = float(h)
    float_w = float(w)

    # normalized coordinate grids in [-1, 1]
    y_range = torch.arange(0., float_h, dtype=torch.float32)
    y_range = 2.0 * y_range / (float_h - 1.0) - 1.0
    x_range = torch.arange(0., float_w, dtype=torch.float32)
    x_range = 2.0 * x_range / (float_w - 1.0) - 1.0
    x_range = x_range[None, :]
    y_range = y_range[:, None]
    x = x_range.repeat(h, 1)
    y = y_range.repeat(1, w)

    x = x[None, None, :, :]
    y = y[None, None, :, :]
    x = x.repeat(batch_size, 1, 1, 1)
    y = y.repeat(batch_size, 1, 1, 1)
    x = x.cuda()
    y = y.cuda()

    ins_feat_out = torch.cat((ins_feat, x, x, x, y, y, y), 1)

    return ins_feat_out


class query_generator(nn.Module):
    def __init__(self, input, output, leaky=True):
        super(query_generator, self).__init__()
        self.proj1 = ConvBatchNormReLU(input + 6, input + 6, 3, 1, 1, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(input + 6, input + 6, 3, 1, 1, 1, leaky=leaky)
        self.proj3 = ConvBatchNormReLU(input + 6, input + 6, 3, 1, 1, 1, leaky=leaky)
        self.proj = nn.Conv2d(input + 6, output, 1, 1, 0, 1)

    def forward(self, x):
        x = concat_coord(x)
        x = x + self.proj1(x)
        x = x + self.proj2(x)
        x = x + self.proj3(x)
        x = self.proj(x)
        return x


class KLM(nn.Module):
    def __init__(self, f_dim, feat_dim):
        super(KLM, self).__init__()
        self.lang_tf_enc = lang_tf_enc(f_dim, f_dim, f_dim, head_num=8)

        self.pos_embedding = PositionEmbeddingSine(f_dim)
        encoder_layer = TransformerEncoderLayer(f_dim, nhead=8, dim_feedforward=f_dim,
                                                dropout=0.1, activation='relu', normalize_before=False)
        self.encoder = TransformerEncoder(encoder_layer, num_layers=2, norm=nn.LayerNorm(f_dim))

        # self.catproj = nn.Linear(f_dim * 2, f_dim)

        self.fc_ker = nn.Linear(f_dim, feat_dim + feat_dim)
        self.fc_vis = nn.Linear(f_dim, feat_dim + feat_dim)
        self.ker_norm = nn.LayerNorm(feat_dim)
        self.vis_norm = nn.LayerNorm(feat_dim)

        self.channel_fc = nn.Linear(feat_dim, feat_dim)
        self.channel_norm = nn.LayerNorm(feat_dim)

        self.spatial_fc = nn.Linear(feat_dim, feat_dim)
        self.spatial_norm = nn.LayerNorm(feat_dim)

        self.out_fc = nn.Linear(feat_dim, f_dim)
        self.out_norm = nn.LayerNorm(f_dim)

        self.d_model = f_dim
        self.feat_dim = feat_dim
        self.resolution_size = 26

    def forward(self, kernel, lang_feat, visu_feat):
        # kernel B x N x C
        # lang_feat B x T x C
        # visu_feat B x C x HW
        kernel = self.lang_tf_enc(kernel, lang_feat)
        # B x N x C
        bs, c, hw = visu_feat.shape
        bq, nq, cq = kernel.shape
        bl, ll, cl = lang_feat.shape

        # Image Attention
        visu_feat = visu_feat.permute(0, 2, 1)
        # B x HW x C
        pos_embed = self.pos_embedding(visu_feat)
        # B x HW x C

        visu_feat = visu_feat.transpose(0, 1)
        pos_embed = pos_embed.transpose(0, 1)
        visu_feat_ = self.encoder(visu_feat, pos=pos_embed)  # HW x B x C
        visu_feat_ = visu_feat_.transpose(0, 1)  # B x HW x C

        # repeat visual feats
        visu_feat = visu_feat_.unsqueeze(dim=1)  # B x 1 x HW x C
        kernel = kernel.unsqueeze(dim=2)  # B x N x 1 x C
        lang_feat = lang_feat.unsqueeze(dim=2)  # B x Q x 1 x C

        # split the kernel and visual projections into gating and output halves
        kernel_in = self.fc_ker(kernel)
        kernel_out = kernel_in[:, :, :, self.feat_dim:]
        kernel_in = kernel_in[:, :, :, :self.feat_dim]

        vis_in = self.fc_vis(visu_feat)
        vis_out = vis_in[:, :, :, self.feat_dim:]
        vis_in = vis_in[:, :, :, :self.feat_dim]

        gate_feat = self.ker_norm(kernel_in) * self.vis_norm(vis_in)
        # [B N HW 64]

        channel_gate = self.channel_norm(self.channel_fc(gate_feat))
        channel_gate = channel_gate.mean(2, keepdim=True)
        channel_gate = torch.sigmoid(channel_gate)
        # B x N x 1 x C

        spatial_gate = self.spatial_norm(self.spatial_fc(gate_feat))
        # spatial_gate = spatial_gate.mean(3, keepdim=True)
        spatial_gate = torch.sigmoid(spatial_gate)
        # B x N x HW x C

        channel_gate = (1 + channel_gate) * kernel_out  # B x N x 1 x C
        channel_gate = channel_gate.squeeze(2)  # B x N x C

        spatial_gate = (1 + spatial_gate) * vis_out  # B x N x HW x C
        spatial_gate = spatial_gate.mean(2)  # B x N x C

        gate_feat = (channel_gate + spatial_gate) / 2
        # [B N 64]
        gate_feat = self.out_fc(gate_feat)
        gate_feat = self.out_norm(gate_feat)
        gate_feat = F.relu(gate_feat)
        # [B N C]

        # visu_feat_.transpose(1, 2) [B C HW]
        return gate_feat, visu_feat_.transpose(1, 2)


class KAM(nn.Module):
    def __init__(self, f_dim, num_query):
        super(KAM, self).__init__()

        self.k_size = 1

        self.proj = nn.Linear(26 * 26, f_dim)

        self.fc_k = nn.Linear(f_dim, f_dim)
        self.fc_m = nn.Linear(f_dim, f_dim)
        self.fc_fus = nn.Linear(f_dim * 2, f_dim)
        self.fc_out = nn.Linear(f_dim, 1)

        self.outproj = ConvBatchNormReLU(num_query, f_dim, 3, 1, 1, 1, leaky=True)
        self.maskproj = nn.Conv2d(f_dim, 1, 3, 1, 1, 1)

        self.bn = nn.BatchNorm2d(f_dim)

        self.mask_fcs = []
        for _ in range(3):
            self.mask_fcs.append(nn.Linear(f_dim, f_dim, bias=False))
            self.mask_fcs.append(nn.LayerNorm(f_dim))
            self.mask_fcs.append(nn.ReLU())
        self.mask_fcs = nn.Sequential(*self.mask_fcs)

    def forward(self, kernel, visu_feat):
        # kernel [B N C]
        # visu_feat [B C HW]
        kernel = self.mask_fcs(kernel)

        B, N, C = kernel.shape
        kernel_ = kernel
        kernel = kernel.reshape(B, N, -1, C).permute(0, 1, 3, 2)  # B x N x C x 1
        kernel = kernel.reshape(B, N, C, self.k_size, self.k_size)  # B x N x C x 1 x 1
        # [B N C K K]
        visu_feat_ = visu_feat
        visu_feat = visu_feat.reshape(B, C, 26, 26)  # B x C x H x W

        # use each query as a dynamic convolution kernel over the visual features
        masks = []
        for i in range(B):
            masks.append(F.conv2d(visu_feat[i: i + 1], kernel[i], padding=int(self.k_size // 2)))  # 1 x N x H x W
        masks = torch.cat(masks, dim=0)  # B x N x H x W

        feats = masks.reshape(B, N, -1)  # B x N x HW
        feats = self.proj(feats)  # B x N x C

        weights_kern = F.relu(self.fc_k(kernel_))
        weights_mask = F.relu(self.fc_m(feats))

        weights = torch.cat([weights_kern, weights_mask], dim=-1)  # B x N x 2C
        weights = F.relu(self.fc_fus(weights))  # B x N x C
        weights = self.fc_out(weights)  # B x N x 1
        weights = F.softmax(weights, dim=1)  # B x N x 1

        weights = weights.unsqueeze(-1)  # B x N x 1 x 1

        mask = weights * masks  # B x N x H x W
        mask = self.outproj(mask)  # B x C x H x W
        mask = self.maskproj(mask)
        mask = torch.sigmoid(mask)  # B x 1 x H x W

        visu_feat = visu_feat * mask  # B x C x H x W

        visu_feat = self.bn(visu_feat)
        visu_feat = visu_feat.reshape(B, C, -1) + visu_feat_
        visu_feat = F.relu(visu_feat)
        return visu_feat
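Given three 26 x 26 feature maps with matching channel counts, `SFA` upsamples two of them (x2 and x4), fuses them top-down FPN-style, and pools the finest result back to the 26 x 26 grid. A minimal shape check on CPU with dummy inputs, assuming the `ASDA.model.modules` import path resolves from the repo root:

import torch
from ASDA.model.modules import SFA

sfa = SFA(in_channels=768, out_channels=512).eval()
x12 = torch.randn(2, 768, 26, 26)  # deepest features, kept at scale 1
x9 = torch.randn(2, 768, 26, 26)   # mid-level features, upsampled x2 internally
x6 = torch.randn(2, 768, 26, 26)   # shallow features, upsampled x4 internally

with torch.no_grad():
    fused = sfa([x12, x9, x6])
print(fused.shape)  # torch.Size([2, 512, 26, 26])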
ASDA/model/position_encoding.py ADDED
@@ -0,0 +1,47 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Various positional encodings for the transformer.
"""

import math
import torch
from torch import nn


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats // 2
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, f_s):
        not_mask = torch.ones_like(f_s[:, :, 0].reshape(-1, 26, 26))
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=f_s.device)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).reshape_as(f_s)
        return pos
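Note that this variant of `PositionEmbeddingSine` is hard-wired to a 26 x 26 grid: it reshapes the token dimension to 26 x 26, builds separate sine/cosine codes for x and y, and returns them in the input's shape. A small sketch, assuming a flattened 676-token feature map:

import torch
from ASDA.model.position_encoding import PositionEmbeddingSine

pe = PositionEmbeddingSine(512)      # 256 features for y + 256 for x
f_s = torch.randn(2, 26 * 26, 512)   # [B, HW, C] flattened 26x26 feature map
pos = pe(f_s)
print(pos.shape)  # torch.Size([2, 676, 512])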
ASDA/model/transformer.py ADDED
@@ -0,0 +1,251 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
    * positional encodings are passed in MHattention
    * extra LN at the end of encoder is removed
    * decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from .position_encoding import *


class lang_tf_enc(nn.Module):

    def __init__(self, input_1, input_2, hidden_dim, head_num, dropout=0.1):
        super(lang_tf_enc, self).__init__()
        self.pos_embedding_1 = PositionEmbeddingSine(input_2, normalize=True)
        self.pos_embedding_2 = PositionEmbeddingSine(input_1, normalize=True)
        self.dense_q = nn.Linear(input_1, hidden_dim)
        self.dense_k = nn.Linear(input_2, hidden_dim)
        self.dense_v = nn.Linear(input_2, hidden_dim)
        self.self_attn = nn.MultiheadAttention(hidden_dim, head_num, dropout=dropout)

        self.forward_dim = 2048
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.linear1 = nn.Linear(hidden_dim, self.forward_dim)
        self.linear2 = nn.Linear(self.forward_dim, hidden_dim)
        self.activation = _get_activation("relu")
        self.dropout = nn.Dropout(dropout)

    # @get_local("weights")
    def forward(self, vision_input, lang_input):
        # cross-attention: visual tokens query the language tokens
        decoder_embed_lang = lang_input
        decoder_embed_vis = vision_input
        q_inp = F.relu(self.dense_q(decoder_embed_vis).permute(1, 0, 2))
        k_inp = F.relu(self.dense_k(decoder_embed_lang).permute(1, 0, 2))
        v_inp = F.relu(self.dense_v(decoder_embed_lang).permute(1, 0, 2))
        lang_input = lang_input.permute(1, 0, 2)
        decoded_layer, weights = self.self_attn(q_inp, k_inp, v_inp)

        decoded_layer = decoded_layer.permute(1, 0, 2)
        add_layer = decoded_layer + vision_input

        add_layer = self.norm1(add_layer)
        add_layer2 = self.linear2(self.dropout(self.activation(self.linear1(add_layer))))
        add_layer = add_layer + self.dropout(add_layer2)
        add_layer = self.norm2(add_layer)

        return add_layer


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation(activation):
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")


class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, pos: Optional[Tensor] = None):
        output = src

        for layer in self.layers:
            output = layer(output, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output


class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    # @get_local("weights")
    def forward_post(self, src, pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2, weights = self.self_attn(q, k, value=src, need_weights=False)

        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src, pos: Optional[Tensor] = None):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2, weights = self.self_attn(q, k, value=src2)
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src, pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, pos)
        return self.forward_post(src, pos)


class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2, weights = self.self_attn(q, k, value=tgt)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2, weights = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                            key=self.with_pos_embed(memory, pos),
                                            value=memory)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory, pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2, weights = self.self_attn(q, k, value=tgt2)
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2, weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                            key=self.with_pos_embed(memory, pos),
                                            value=memory)
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory, pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, pos, query_pos)
        return self.forward_post(tgt, memory, pos, query_pos)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
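The encoder stack follows DETR's convention of sequence-first tensors, with the positional encoding added inside each attention call rather than once at the input. A minimal sketch wiring it to the sine embedding above, under the same 26 x 26 assumption and the same repo-root import paths:

import torch
import torch.nn as nn
from ASDA.model.position_encoding import PositionEmbeddingSine
from ASDA.model.transformer import TransformerEncoder, TransformerEncoderLayer

layer = TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=512,
                                dropout=0.1, activation='relu', normalize_before=False)
encoder = TransformerEncoder(layer, num_layers=2, norm=nn.LayerNorm(512))

feat = torch.randn(2, 26 * 26, 512)                           # [B, HW, C]
pos = PositionEmbeddingSine(512)(feat)                        # [B, HW, C]
out = encoder(feat.transpose(0, 1), pos=pos.transpose(0, 1))  # [HW, B, C]
print(out.shape)  # torch.Size([676, 2, 512])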