|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
from .modules import ConvBatchNormReLU, SFA |
|
|
from .modules import * |
|
|
from .position_encoding import * |
|
|
|
|
|
import clip |
|
|
import math |
|
|
import sys |
|
|
|
|
|
sys.path.append('../') |
|
|
from utils.utils import * |
|
|
|
|
|
|
|
|
class Simple_fusion(nn.Module):
    """Fuse a visual feature map with a sentence-level language feature.

    The language vector is linearly projected to ``proj_dim``, tiled over the
    spatial grid of the (1x1-conv projected) visual feature map, and combined
    by element-wise multiplication followed by BatchNorm + LeakyReLU.
    """

    def __init__(self, visual_dim=1024, text_dim=768, proj_dim=1024, jemb_drop_out=0.1, leaky=True):
        super(Simple_fusion, self).__init__()
        self.proj_dim = proj_dim
        # 1x1 conv projecting the visual feature into the joint dimension.
        self.mapping_visu = ConvBatchNormReLU(visual_dim, proj_dim, 1, 1, 0, 1, leaky=leaky)

        # NOTE(review): registered but unused in forward() — the per-word
        # attention pooling path was disabled in favour of a pre-pooled
        # sentence vector.  Kept so existing checkpoints still load.
        self.lang_attn = nn.Sequential(
            nn.Linear(text_dim, text_dim),
            nn.Tanh(),
            nn.Dropout(jemb_drop_out),
            nn.Softmax(dim=1))

        # Project the sentence vector into the joint embedding space.
        self.lang_proj = nn.Sequential(
            nn.Linear(text_dim, proj_dim),
            nn.BatchNorm1d(proj_dim),
            nn.LeakyReLU(0.1))

        # Normalisation applied after the multiplicative fusion.
        self.fusion = nn.Sequential(
            nn.BatchNorm2d(proj_dim),
            nn.LeakyReLU(0.1))

    def forward(self, visual_feat, lang_feat):
        """Return the fused feature map.

        Args:
            visual_feat: [B, visual_dim, H, W] feature map.
            lang_feat: [B, text_dim] or [B, 1, text_dim] sentence feature.

        Returns:
            [B, proj_dim, H, W] fused feature map.
        """
        visual_feat_proj = self.mapping_visu(visual_feat)

        # lang_feat may arrive as [B, 1, text_dim]; collapse to [B, text_dim]
        # (squeeze(1) is a no-op when the feature is already 2-D).
        lang_feat = lang_feat.squeeze(1)
        lang_feat_new = self.lang_proj(lang_feat)

        # Tile the language vector over the spatial grid and fuse.
        h, w = visual_feat.shape[-2], visual_feat.shape[-1]
        lang_feat_new_tile = lang_feat_new.view(-1, self.proj_dim, 1, 1).repeat(1, 1, h, w)
        fusion_feat = lang_feat_new_tile * visual_feat_proj
        fusion_feat = self.fusion(fusion_feat)
        return fusion_feat
|
|
|
|
|
class up_proj_cat_proj(nn.Module):
    """Upsample ``x`` 2x, project ``y``, concatenate both, and project down to ``do``."""

    def __init__(self, input_1, input_2, do=512, leaky=True):
        super(up_proj_cat_proj, self).__init__()
        # 1x1 conv applied to the skip branch before concatenation.
        self.proj1 = ConvBatchNormReLU(input_2, input_2, 1, 1, 0, 1, leaky=leaky)
        # 1x1 conv mixing the concatenated channels down to `do`.
        self.proj2 = ConvBatchNormReLU(input_1 + input_2, do, 1, 1, 0, 1, leaky=leaky)

    def forward(self, x, y):
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        skip = self.proj1(y)
        merged = torch.cat([upsampled, skip], dim=1)
        return self.proj2(merged)
|
|
|
|
|
class pool_proj_cat_proj(nn.Module):
    """Downsample ``y`` 2x, project it in two stages, concat with ``x``, project to ``do``."""

    def __init__(self, input_1, input_2, do=512, leaky=True):
        super(pool_proj_cat_proj, self).__init__()
        # 2x average-pool to match the spatial size of `x`.
        self.downsample = nn.AvgPool2d(2, 2)
        # Two-stage channel projection of the pooled branch: 1x1 then 3x3.
        self.proj1 = ConvBatchNormReLU(input_2, do // 2, 1, 1, 0, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(do // 2, do, 3, 1, 1, 1, leaky=leaky)
        # Final 1x1 conv over the concatenated features.
        self.proj3 = ConvBatchNormReLU(input_1 + do, do, 1, 1, 0, 1, leaky=leaky)

    def forward(self, x, y):
        pooled = self.downsample(y)
        pooled = self.proj2(self.proj1(pooled))
        return self.proj3(torch.cat([x, pooled], dim=1))
|
|
|
|
|
class proj_cat_proj(nn.Module):
    """Project ``y`` (1x1 conv), concatenate with ``x``, and project down to ``do``."""

    def __init__(self, input_1, input_2, do=512, leaky=True):
        super(proj_cat_proj, self).__init__()
        # 1x1 conv applied to the second branch before concatenation.
        self.proj1 = ConvBatchNormReLU(input_2, input_2, 1, 1, 0, 1, leaky=leaky)
        # 1x1 conv mixing the concatenated channels down to `do`.
        self.proj2 = ConvBatchNormReLU(input_1 + input_2, do, 1, 1, 0, 1, leaky=leaky)

    def forward(self, x, y):
        projected = self.proj1(y)
        fused = torch.cat([x, projected], dim=1)
        return self.proj2(fused)
|
|
|
|
|
class proj_cat(nn.Module):
    """Project ``x`` through two convs, then concatenate it with ``y`` (no final projection)."""

    def __init__(self, input_1, input_2, do=512, leaky=True):
        super(proj_cat, self).__init__()
        # Two-stage projection of the first branch: 1x1 squeeze then 3x3 expand.
        self.proj1 = ConvBatchNormReLU(input_1, do // 2, 1, 1, 0, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(do // 2, do, 3, 1, 1, 1, leaky=leaky)

    def forward(self, x, y):
        projected = self.proj2(self.proj1(x))
        return torch.cat([projected, y], dim=1)
|
|
|
|
|
class mask_decoder(nn.Module):
    """Progressively upsample a feature map into 32 mask-prototype channels.

    Starting from an input at some stride, the decoder repeatedly 2x
    nearest-neighbour upsamples and refines with a 3x3 conv until the
    requested output stride is reached, then maps to 32 channels.

    Fix over the original: the constructor's ``seg_out_stride`` argument was
    accepted but never stored or used; it is now kept as the default for
    ``forward`` (callers that pass the stride explicitly are unaffected).
    """

    def __init__(self, input_1, seg_out_stride=2, leaky=True):
        super(mask_decoder, self).__init__()
        # Default output stride used when forward() is not given one.
        self.seg_out_stride = seg_out_stride
        # Initial refinement at the input resolution.
        self.proj1 = ConvBatchNormReLU(input_1, input_1 // 2, 3, 1, 1, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(input_1 // 2, input_1 // 2, 3, 1, 1, 1, leaky=leaky)

        # One refinement conv per upsampling step (stride 8 -> 4 -> 2).
        self.proj3 = ConvBatchNormReLU(input_1 // 2, input_1 // 2, 3, 1, 1, 1, leaky=leaky)
        self.proj4 = ConvBatchNormReLU(input_1 // 2, input_1 // 2, 3, 1, 1, 1, leaky=leaky)
        self.proj5 = ConvBatchNormReLU(input_1 // 2, input_1 // 2, 3, 1, 1, 1, leaky=leaky)

        # Final projection to the 32 mask-prototype channels.
        self.proj = nn.Conv2d(input_1 // 2, 32, 3, 1, 1, 1)

    def forward(self, x, seg_out_stride=None):
        """Decode ``x`` to 32 prototype channels at the requested stride.

        Args:
            x: [B, input_1, H, W] feature map.
            seg_out_stride: target output stride (2, 4, or 8); defaults to the
                value given at construction time.

        Returns:
            [B, 32, H', W'] prototype maps, upsampled per ``seg_out_stride``.
        """
        if seg_out_stride is None:
            seg_out_stride = self.seg_out_stride

        x = self.proj1(x)
        x = self.proj2(x)

        # Each block doubles resolution; smaller strides take more steps.
        if seg_out_stride <= 8:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = self.proj3(x)

        if seg_out_stride <= 4:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = self.proj4(x)

        if seg_out_stride <= 2:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = self.proj5(x)

        x = self.proj(x)

        return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation used by CLIP: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(x * 1.702)
        return gate * x
|
|
|
|
|
class ResidualAttentionblk(nn.Module):
    """Wrapper around a CLIP residual attention block with optional
    language-guided token pruning.

    When ``lang_tokens`` is given, only the patch tokens most relevant to the
    language features attend individually; the remaining low-score tokens are
    fused into a single weighted summary token before attention, and the
    attention output is scattered back afterwards.
    """

    def __init__(self, clip_module):
        super().__init__()

        # The wrapped CLIP transformer block (provides ln_1, attention, ln_2, mlp).
        self.clip_module = clip_module

        # Number of patch tokens kept per image; 676 is presumably the 26x26
        # patch grid of the visual transformer — TODO confirm against input size.
        self.selected_tokens = int(676 * 0.8)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None, lang_tokens=None, index=0):
        # NOTE(review): `attn_mask` and `index` are accepted but never used here.
        if lang_tokens is None:
            # Plain CLIP residual block behaviour: full self-attention.
            x = x + self.clip_module.attention(self.clip_module.ln_1(x))
        else:
            # Language-guided pruning path.  x is token-first: [N, B, C].
            N, B, C = x.shape
            # Separate the CLS token from the patch tokens.
            cls_x = x[:1, :, :]
            x = x[1:, :, :]

            # Relevance score of each patch token: mean dot product with the
            # language tokens.  Result is [N-1, B] after the transpose.
            score = torch.bmm(x.transpose(0, 1), lang_tokens.permute(1, 2, 0)).mean(dim=-1)
            score = score.transpose(0, 1)

            sorted_scores, sorted_indices = torch.sort(score, descending=True, dim=0)

            # Boolean mask selecting the top `selected_tokens` tokens per
            # batch element (built per-column from the sort indices).
            high_mask = torch.ones_like(sorted_scores)
            for i in range(B):
                high_mask[sorted_indices[self.selected_tokens:, i], i] = 0
            high_mask = high_mask > 0.5

            # High-score tokens attend individually; low-score tokens are
            # compressed into one score-weighted summary token.
            delta_x = x[high_mask].reshape(-1, B, C)
            low_x = x[~high_mask].reshape(-1, B, C)
            low_score = score[~high_mask].reshape(-1, B, 1)

            low_x = low_x * torch.softmax(low_score, dim=0)
            low_x = low_x.sum(dim=0, keepdim=True)

            # Attention over [CLS, kept tokens, summary token].
            delta_x = torch.cat([cls_x, delta_x, low_x], dim=0)
            delta_x = self.clip_module.attention(self.clip_module.ln_1(delta_x))

            # Scatter the attention updates back to their original positions:
            # kept tokens get their own update, pruned tokens all share the
            # summary token's update.  676 - selected_tokens pruned per image.
            temple = torch.zeros_like(x).type(delta_x.type())
            temple[high_mask] = delta_x[1:-1, :, :].reshape(-1, C)
            temple[~high_mask] = delta_x[-1:, :, :].reshape(-1, 1, C).repeat(1, 676 - self.selected_tokens, 1).reshape(-1, C)
            x = x + temple
            cls_x = cls_x + delta_x[:1, :, :]

            # Re-attach the CLS token.
            x = torch.cat([cls_x, x], dim=0)

        # Standard residual MLP sub-block.
        x = x + self.clip_module.mlp(self.clip_module.ln_2(x))
        return x
|
|
|
|
|
class Model(nn.Module):
    """CLIP-based referring image segmentation model.

    A CLIP ViT visual backbone (wrapped in :class:`ResidualAttentionblk`) and
    the CLIP text transformer are run layer-by-layer in lockstep; intermediate
    visual features from layers 3-5 and 6-8 are selected by learned
    image-text relevance scores, fused with the sentence feature, refined by
    an SFA FPN and a transformer encoder, and decoded into 32 mask prototypes
    that are linearly combined by language-derived coefficients into the
    final mask.
    """

    def __init__(self, clip_model='RN50', tunelang=False, fusion_dim=2048, num_query=16, do=512, leaky=True, length=17):
        # NOTE(review): `num_query` is accepted but never used in this class.
        super(Model, self).__init__()

        self.tunelang = tunelang
        # Maximum text token length (used for the causal attention mask).
        self.length = length

        # Load CLIP on CPU (jit disabled) then move to GPU.
        clip_models = clip.load(clip_model, jit=False, device=torch.device("cpu"))[0].cuda()

        self.visumodel = clip_models.visual
        # Transformer width of the visual backbone — presumably ViT-L/14's 768;
        # TODO confirm this matches the chosen `clip_model`.
        self.visu_dim = 768

        # `cut_list` is empty, so `visu_proj` ends up as an empty ModuleList.
        self.cut_list = []
        # Wrap each of the 12 CLIP visual blocks for (optional) token pruning.
        self.visu_resblocks = nn.ModuleList([ResidualAttentionblk(self.visumodel.transformer.resblocks[i]) for i in range(12)])
        self.visu_proj = nn.ModuleList([nn.Linear(do, self.visu_dim) for _ in range(len(self.cut_list))])

        # Resize CLIP's positional embedding to a 26x26 patch grid (+ CLS).
        self.positional_embedding = nn.Parameter(torch.FloatTensor(1, 26 ** 2 + 1, 768))
        v = self.resize_pos_embed(self.visumodel.positional_embedding.data.unsqueeze(0), self.positional_embedding, 26, 26)
        self.positional_embedding.data.copy_(v)

        # CLIP text encoder pieces, with the positional embedding truncated
        # to `length` tokens.
        self.textmodel = clip_models.transformer
        self.textmodel_token_embedding = clip_models.token_embedding
        self.textmodel_pos_embed = nn.Parameter(clip_models.positional_embedding[:self.length, :].unsqueeze(0))
        self.textmodel_ln_final = clip_models.ln_final
        self.textdim = self.textmodel_pos_embed.shape[-1]
        # Rebuild each text block's causal mask for the truncated length.
        for module in self.textmodel.resblocks:
            module.attn_mask = self.build_attention_mask()

        # Projects the visual CLS token into the scoring space.
        self.vis_select = nn.Linear(self.visu_dim, do, bias=False)

        # Vision-language fusion of the final visual feature map.
        self.fusion = Simple_fusion(visual_dim=self.visu_dim, text_dim=self.textdim, proj_dim=fusion_dim)

        # NOTE(review): the four modules below are registered but not used in
        # forward() — presumably kept for checkpoint compatibility; verify.
        self.up_proj_cat_proj_1 = proj_cat_proj(input_1=fusion_dim, input_2=self.visu_dim, do=fusion_dim)
        self.pool_proj_cat_proj_2 = proj_cat_proj(input_1=fusion_dim, input_2=self.visu_dim, do=do)

        self.proj_cat = proj_cat(input_1=fusion_dim, input_2=do, do=do)
        self.up_proj_cat_2 = proj_cat_proj(input_1=fusion_dim, input_2=do * 2, do=do)
        self.proj_0 = ConvBatchNormReLU(do, do, 1, 1, 0, 1, leaky=leaky)

        # Multi-level feature aggregation over [fused, x9, x6].
        self.fpn = SFA(in_channels=self.visu_dim, out_channels=do)

        f_dim = 512
        self.fc_2 = nn.Linear(f_dim, f_dim, bias=False)
        # NOTE(review): norm1/norm2 are unused in forward() — verify intent.
        self.norm1 = nn.LayerNorm(f_dim)
        self.norm2 = nn.LayerNorm(f_dim)

        # Transformer encoder over the flattened fused feature map.
        self.pos_embedding = PositionEmbeddingSine(f_dim)
        encoder_layer = TransformerEncoderLayer(f_dim, nhead=8, dim_feedforward=f_dim,
                                                dropout=0.1, activation='relu', normalize_before=False)
        self.encoder = TransformerEncoder(encoder_layer, num_layers=2, norm=nn.LayerNorm(f_dim))

        # Decodes the encoded features into 32 mask prototypes at stride 2.
        self.mask_decoder = mask_decoder(f_dim, seg_out_stride=2)

        # Language-conditioned branch that produces the 32 mixing coefficients.
        self.lang_tf_enc = lang_tf_enc(do, do, do, head_num=8)
        self.proj1 = ConvBatchNormReLU(do, do, 3, 1, 1, 1, leaky=leaky)
        self.proj2 = ConvBatchNormReLU(do, do, 3, 1, 1, 1, leaky=leaky)
        self.proj3 = nn.Conv2d(do, 32, 3, 1, 1, 1)
        self.projout = nn.Linear(26*26*32, 32, bias=False)

        # Learned selectors over the per-layer relevance scores.
        self.feature_selector_l = nn.Linear(do, 1, bias=True)
        self.feature_selector_m = nn.Linear(do, 1, bias=True)

    def resize_pos_embed(self, posemb, posemb_new, hight, width):
        """Bilinearly resize a ViT positional embedding grid to (hight, width).

        The CLS position is kept as-is; only the patch grid is interpolated.
        Returns a tensor shaped like ``posemb_new``: [1, hight*width + 1, C].
        """
        ntok_new = posemb_new.shape[1]

        # Split off the CLS position; the rest is the square patch grid.
        posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:]
        ntok_new -= 1

        gs_old = int(math.sqrt(len(posemb_grid)))
        print('Resized position embedding from size:{} to size: {} with height:{} width: {}'.format(posemb.shape, posemb_new.shape, hight, width))
        # [L, C] -> [1, C, gs_old, gs_old] for spatial interpolation.
        posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
        posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear')
        posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1)
        posemb = torch.cat([posemb_token, posemb_grid], dim=1)
        return posemb

    def build_attention_mask(self):
        """Return a [length, length] causal mask (-inf above the diagonal),
        matching CLIP's text transformer masking."""
        mask = torch.empty(self.length, self.length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward(self, image, word_id, word_mask):
        """Predict a segmentation mask for the expression over the image.

        Args:
            image: input image batch; presumably [B, 3, 416, 416] given the
                26x26 patch grid and the 208x208 output — TODO confirm.
            word_id: tokenized expression ids; argmax along the last dim is
                used to locate the EOS token (CLIP convention).
            word_mask: unused here — kept for caller interface compatibility.

        Returns:
            mask_out: [B, 1, 208, 208] mask prediction.
        """
        batch_size = image.size(0)

        # ---- CLIP ViT stem: patchify, add CLS + positional embedding ----
        x = self.visumodel.conv1(image)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)
        x = torch.cat([self.visumodel.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
        x = x + self.positional_embedding.to(x.dtype)
        x = self.visumodel.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND (token-first for the blocks)

        # ---- text embedding (token-first) ----
        raw_fword = self.textmodel_token_embedding(word_id).squeeze(1)
        raw_fword = raw_fword + self.textmodel_pos_embed
        raw_fword = raw_fword.permute(1, 0, 2)

        # Intermediate visual features and their relevance scores for the
        # "low" (layers 3-5) and "middle" (layers 6-8) stages.
        visu_list_l = []
        visu_list_m = []

        scores_l = []
        scores_m = []

        # ---- run the 12 visual/text transformer layers in lockstep ----
        for i, [blk_visu, blk_lang] in enumerate(zip(self.visu_resblocks, self.textmodel.resblocks)):
            x = blk_visu(x)
            raw_fword = blk_lang(raw_fword)

            # Per-layer relevance: projected image CLS * text EOS token.
            img_cls = self.vis_select(x[0, :, :])
            tex_cls = raw_fword[word_id.argmax(dim=-1).reshape(-1), torch.arange(raw_fword.shape[1]), :]
            score = img_cls * tex_cls
            score = score.unsqueeze(1)

            if i >=3 and i <= 5:
                visu_list_l.append(x)
                scores_l.append(score)

            if i>=6 and i <=8:
                visu_list_m.append(x)
                scores_m.append(score)

        # ---- pick one layer per stage via softmax-scored selection ----
        scores_l = torch.cat(scores_l, dim=1)
        scores_m = torch.cat(scores_m, dim=1)

        scores_l = self.feature_selector_l(scores_l).squeeze(-1)
        scores_l = F.softmax(scores_l, dim=-1)
        scores_m = self.feature_selector_m(scores_m).squeeze(-1)
        scores_m = F.softmax(scores_m, dim=-1)

        # Stack the candidate features: [num_layers, B, tokens, C].
        visu_list_l = torch.cat(visu_list_l, dim=0).reshape(len(visu_list_l), -1, batch_size, self.visu_dim).permute(0,2,1,3)
        visu_list_m = torch.cat(visu_list_m, dim=0).reshape(len(visu_list_m), -1, batch_size, self.visu_dim).permute(0,2,1,3)

        # Hard argmax selection of the best layer per batch element.
        x6 = visu_list_l[scores_l.argmax(dim=-1).reshape(-1), torch.arange(visu_list_l.shape[1]), :, :].permute(1,0,2)
        x9 = visu_list_m[scores_m.argmax(dim=-1).reshape(-1), torch.arange(visu_list_m.shape[1]), :, :].permute(1,0,2)

        # Drop CLS and reshape patch tokens to [B, C, 26, 26] feature maps.
        x6 = x6.permute(1, 0, 2)[:, 1:, :].reshape(-1, 26, 26, self.visu_dim).permute(0, 3, 1, 2)
        x9 = x9.permute(1, 0, 2)[:, 1:, :].reshape(-1, 26, 26, self.visu_dim).permute(0, 3, 1, 2)
        x12 = x.permute(1, 0, 2)[:, 1:, :]
        x12 = x12.reshape(-1, 26, 26, self.visu_dim).permute(0, 3, 1, 2)

        # ---- finalize text features ----
        raw_fword = raw_fword.permute(1, 0, 2)
        raw_fword = self.textmodel_ln_final(raw_fword)

        # Optionally freeze the language branch's gradients.
        if not self.tunelang:
            raw_fword = raw_fword.detach()

        # Sentence feature = hidden state at the EOS token (CLIP convention).
        eos_token = raw_fword[torch.arange(raw_fword.shape[0]), word_id.argmax(dim=-1).reshape(-1), :]

        # ---- fuse, aggregate multi-level features ----
        F_g = self.fusion(x12, eos_token)
        F_tf = self.fpn([F_g, x9, x6])

        b, c, h, w = F_tf.shape

        flatten_length = h*w
        visu_feat = F_tf.reshape(b, c, flatten_length)
        visu_feat = F.relu(visu_feat)
        lang_feat = F.relu(self.fc_2(raw_fword))

        # ---- transformer encoder over the flattened spatial tokens ----
        visu_feat = visu_feat.permute(0, 2, 1)
        pos_embed = self.pos_embedding(visu_feat)
        visu_feat = visu_feat.transpose(0, 1)
        pos_embed = pos_embed.transpose(0, 1)
        visu_feat = self.encoder(visu_feat, pos=pos_embed)

        # Batch-first copy kept for the coefficient branch below.
        visu_feat_ = visu_feat.permute(1,0,2)

        # ---- prototype branch: decode to 32 mask prototypes ----
        visu_feat = visu_feat.reshape(h, w, b, c)
        visu_feat = visu_feat.permute(2,3,0,1)
        proto_masks = self.mask_decoder(visu_feat, 2)

        proto_masks = F.relu(proto_masks)

        # ---- coefficient branch: language-conditioned mixing weights ----
        coef = self.lang_tf_enc(visu_feat_, lang_feat)
        coef = coef.view(b, h, w, c)
        coef = coef.permute(0, 3, 1, 2)

        coef = self.proj1(coef)
        coef = self.proj2(coef)
        coef = self.proj3(coef)
        coef = coef.permute(0, 2, 3, 1)
        coef = coef.contiguous().view(b, h*w*32)

        coef = self.projout(coef).unsqueeze(-1)
        # NOTE(review): F.tanh is deprecated; torch.tanh is the modern spelling.
        coef = F.tanh(coef)

        # ---- combine prototypes with coefficients (YOLACT-style) ----
        proto_masks = proto_masks.permute(0, 2, 3, 1)
        proto_masks = proto_masks.view(b, -1, 32)

        mask_out = torch.bmm(proto_masks, coef, out=None)
        mask_out = mask_out.view(b, 208, 208, 1)
        mask_out = mask_out.permute(0, 3, 1, 2)
        return mask_out
|
|
|