|
|
import torch |
|
|
from torch import nn |
|
|
from torch.nn import functional as F |
|
|
|
|
|
from mmdet.registry import MODELS |
|
|
from .language_model import LanguageEncoder |
|
|
from .transformer_blocks import (MLP, Conv2d, CrossAttentionLayer, FFNLayer, |
|
|
PositionEmbeddingSine, SelfAttentionLayer) |
|
|
from .utils import is_lower_torch_version |
|
|
|
|
|
|
|
|
def vl_similarity(image_feat, text_feat, temperature=1):
    """Compute temperature-scaled similarity logits between two feature sets.

    The logits are ``image_feat @ text_feat.T`` multiplied by
    ``exp(temperature)``, where the multiplier is clamped to at most 100.

    Args:
        image_feat (torch.Tensor): Features whose last dimension matches the
            last dimension of ``text_feat``.
        text_feat (torch.Tensor): 2-D feature matrix; it is transposed with
            ``.t()`` before the matrix product.
        temperature (torch.Tensor | float | int): Log-scale temperature.
            Plain Python numbers (including the default ``1``) are converted
            to a tensor first, since they have no ``exp`` method.

    Returns:
        torch.Tensor: The scaled similarity logits.
    """
    logits = torch.matmul(image_feat, text_feat.t())
    if not torch.is_tensor(temperature):
        # The default value is a Python int; wrap it so ``.exp()`` works.
        temperature = torch.as_tensor(
            temperature, dtype=torch.float32, device=logits.device)
    scale = temperature.exp().clamp(max=100)
    return scale * logits
|
|
|
|
|
|
|
|
@MODELS.register_module()
class XDecoderTransformerDecoder(nn.Module):
    """Query-based transformer decoder with task-dependent prediction heads.

    A set of learned queries is refined by ``decoder_layers`` rounds of
    cross-attention (against one of three image feature levels, used in
    round-robin order), self-attention and a feed-forward layer. After every
    round the prediction heads produce mask logits which are thresholded
    into the cross-attention mask for the next round. The last learned query
    (index ``num_queries - 1``) is handled as a class token.

    Args:
        in_channels (int): Channels of every input feature map. A 1x1
            ``Conv2d`` projects to ``hidden_dim`` when the two differ.
        hidden_dim (int): Width of the queries and attention layers.
        dim_proj (int): Output width of the ``class_embed`` and
            ``caping_embed`` projection matrices.
        num_queries (int): Number of learned queries, class token included.
        max_token_num (int): Number of caption token positions; sizes
            ``pos_embed_caping`` and the caption part of ``self_attn_mask``.
        nheads (int): Number of attention heads.
        dim_feedforward (int): Hidden width passed to every ``FFNLayer``.
        decoder_layers (int): Number of decoder rounds.
        pre_norm (bool): Forwarded as ``normalize_before`` to every layer.
        mask_dim (int): Output width of the mask-embedding MLP.
        task (str): One of ``'semseg'``, ``'instance'``, ``'panoptic'``,
            ``'ref-seg'``, ``'retrieval'`` or ``'caption'``.
        captioning_step (int): Maximum number of caption decoding steps.
    """

    def __init__(
        self,
        in_channels: int = 512,
        hidden_dim: int = 512,
        dim_proj: int = 512,
        num_queries: int = 101,
        max_token_num: int = 77,
        nheads: int = 8,
        dim_feedforward: int = 2048,
        decoder_layers: int = 9,
        pre_norm: bool = False,
        mask_dim: int = 512,
        task: str = 'semseg',
        captioning_step: int = 50,
    ):
        super().__init__()

        # Positional encoding applied to the image feature maps.
        self.pe_layer = PositionEmbeddingSine(hidden_dim // 2, normalize=True)

        # Transformer decoder layers.
        self.num_heads = nheads
        self.num_layers = decoder_layers
        self.max_token_num = max_token_num
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        # The three layer kinds are built interleaved, one round at a time.
        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                ))

            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                ))

            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                ))

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # Learnable query features (initial decoder input).
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        # Learnable query positional embeddings.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # Level embedding; the decoder always consumes three feature levels.
        self.num_feature_levels = 3
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
        self.input_proj = nn.ModuleList()

        for _ in range(self.num_feature_levels):
            if in_channels != hidden_dim:
                self.input_proj.append(
                    Conv2d(in_channels, hidden_dim, kernel_size=1))
            else:
                # Identity projection when the widths already match.
                self.input_proj.append(nn.Sequential())

        self.task = task

        # Output heads.
        self.lang_encoder = LanguageEncoder()

        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
        self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))

        # Caption projection and caption positional embedding.
        self.caping_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
        self.pos_embed_caping = nn.Embedding(max_token_num, hidden_dim)
        self.captioning_step = captioning_step

        # Self-attention mask over [queries | caption tokens].
        # NOTE(review): assumes the attention layers follow the PyTorch
        # convention that ``True`` blocks attention -- confirm against
        # SelfAttentionLayer.
        total_len = num_queries + max_token_num
        self_attn_mask = torch.zeros((1, total_len, total_len)).bool()
        # Query rows are masked against every caption-token column.
        self_attn_mask[:, :num_queries, num_queries:] = True
        # Caption-token rows are masked against later caption tokens
        # (strict upper triangle).
        self_attn_mask[:, num_queries:, num_queries:] = torch.triu(
            torch.ones((1, max_token_num, max_token_num)), diagonal=1).bool()
        # The last query (class token) and the other queries are masked
        # against each other in both directions.
        self_attn_mask[:, :num_queries - 1, num_queries - 1:num_queries] = True
        self_attn_mask[:, num_queries - 1:num_queries, :num_queries - 1] = True
        self.register_buffer('self_attn_mask', self_attn_mask)

        # ``torch.empty`` leaves the projection matrices holding arbitrary
        # memory (possibly NaN/inf) unless a checkpoint overwrites them.
        # Give them a well-defined initial value. This runs last so that the
        # random initialisation of the sub-modules above is unaffected.
        for projection in (self.class_embed, self.caping_embed):
            nn.init.trunc_normal_(projection, std=0.02)

    def _prepare_multi_scale_inputs(self, x):
        """Project, flatten and position-encode the input feature levels.

        Args:
            x: Indexable collection of ``num_feature_levels`` 4-D feature
                maps (presumably ``(N, C, H, W)`` -- TODO confirm).

        Returns:
            tuple: ``(src, pos, size_list)``. ``src`` and ``pos`` are lists
            of 3-D tensors with the flattened spatial dimension first and
            the original first dimension second; ``size_list`` holds the
            last two dimensions of every input level.
        """
        assert len(x) == self.num_feature_levels
        src, pos, size_list = [], [], []

        for level in range(self.num_feature_levels):
            feat = x[level]
            size_list.append(feat.shape[-2:])
            level_pos = self.pe_layer(feat, None).flatten(2)
            level_src = (
                self.input_proj[level](feat).flatten(2) +
                self.level_embed.weight[level][None, :, None])
            # Move the flattened spatial dimension to the front.
            pos.append(level_pos.permute(2, 0, 1))
            src.append(level_src.permute(2, 0, 1))

        return src, pos, size_list

    def forward(self, x, mask_features, extra=None):
        """Run the decoder and the head that matches ``self.task``.

        Args:
            x: Indexable collection of ``num_feature_levels`` feature maps.
            mask_features (torch.Tensor): Features combined with the mask
                embeddings via ``einsum('bqc,bchw->bqhw')``.
            extra (dict, optional): Task inputs. ``'grounding_tokens'`` and
                ``'class_emb'`` are read for ``'ref-seg'``, ``'class_emb'``
                for ``'retrieval'``; see :meth:`forward_caption` for
                ``'caption'``.

        Returns:
            dict: For ``'caption'`` the result of :meth:`forward_caption`.
            Otherwise ``'pred_masks'`` and ``'pred_class_embed'`` from the
            last decoder round, plus ``'pred_logits'`` for the known tasks.
            For ``'ref-seg'``, ``'pred_masks'`` and ``'pred_logits'`` are
            lists with one entry per item of ``mask_features``.
        """
        if self.task == 'caption':
            return self.forward_caption(x, mask_features, extra)

        src, pos, size_list = self._prepare_multi_scale_inputs(x)
        bs = src[0].shape[1]
        num_q = self.num_queries
        mask_repeats = bs * self.num_heads
        is_ref_seg = self.task == 'ref-seg'

        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)

        # Only the query-to-query part of the registered mask is needed.
        self_tgt_mask = self.self_attn_mask[:, :num_q, :num_q].repeat(
            mask_repeats, 1, 1)

        # Number of non-grounding-token rows carried between layers.
        num_visual = num_q
        if is_ref_seg:
            grounding_tokens = extra['grounding_tokens']
            _grounding_tokens = grounding_tokens.detach().clone()
            # The queries are followed by a copy of all but the last query.
            num_visual = num_q + (num_q - 1)
            padded_len = num_visual + len(grounding_tokens)

            # Start fully masked, then open up the relevant regions.
            pad_tgt_mask = torch.ones(
                (mask_repeats, padded_len, padded_len),
                dtype=torch.bool,
                device=self_tgt_mask.device)
            pad_tgt_mask[:, :num_q, :num_q] = self_tgt_mask
            # Everything after the original queries is unmasked among
            # itself.
            pad_tgt_mask[:, num_q:, num_q:] = False
            self_tgt_mask = pad_tgt_mask

            output = torch.cat((output, output[:-1]), dim=0)
            query_embed = torch.cat((query_embed, query_embed[:-1]), dim=0)

        # Prediction heads on the initial query features.
        results = self.forward_prediction_heads(
            output, mask_features, attn_mask_target_size=size_list[0])
        attn_mask = results['attn_mask']

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            # A row that masks every location is reset to mask nothing.
            fully_masked = attn_mask.sum(-1) == attn_mask.shape[-1]
            attn_mask[fully_masked] = False

            # Cross-attention first; no key-padding mask is supplied.
            output, _ = self.transformer_cross_attention_layers[i](
                output,
                src[level_index],
                memory_mask=attn_mask,
                memory_key_padding_mask=None,
                pos=pos[level_index],
                query_pos=query_embed)

            if is_ref_seg:
                # The grounding tokens take part in self-attention only.
                output = torch.cat((output, _grounding_tokens), dim=0)
                query_embed = torch.cat((query_embed, grounding_tokens), dim=0)

            output = self.transformer_self_attention_layers[i](
                output,
                tgt_mask=self_tgt_mask,
                tgt_key_padding_mask=None,
                query_pos=query_embed)

            output = self.transformer_ffn_layers[i](output)

            if is_ref_seg:
                # Split at the fixed query boundary so that an empty set of
                # grounding tokens does not turn into a whole-tensor slice.
                _grounding_tokens = output[num_visual:]
                output = output[:num_visual]
                query_embed = query_embed[:num_visual]

            next_level = (i + 1) % self.num_feature_levels
            results = self.forward_prediction_heads(
                output,
                mask_features,
                attn_mask_target_size=size_list[next_level])
            attn_mask = results['attn_mask']

        # Only the predictions of the last round are returned.
        pred_masks = results['outputs_mask']
        pred_class_embed = results['class_embed']
        out = {
            'pred_masks': pred_masks,
            'pred_class_embed': pred_class_embed,
        }

        if is_ref_seg:
            t_emb = extra['class_emb']
            t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
            temperature = self.lang_encoder.logit_scale
            # Rows appended after the original queries.
            extra_rows = slice(num_q, 2 * num_q - 1)

            mask_pred_results = []
            outputs_class = []
            for idx in range(mask_features.shape[0]):
                pred_gmasks = pred_masks[idx, extra_rows]
                v_emb = pred_class_embed[idx, extra_rows]
                v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

                out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

                # For every text embedding pick the best matching row.
                matched_id = out_prob.max(0)[1]
                mask_pred_results.append(pred_gmasks[matched_id, :, :])
                outputs_class.append(out_prob[matched_id, :])
            out['pred_masks'] = mask_pred_results
            out['pred_logits'] = outputs_class
        elif self.task == 'retrieval':
            t_emb = extra['class_emb']
            temperature = self.lang_encoder.logit_scale
            # The last row is the recomputed class token.
            v_emb = out['pred_class_embed'][:, -1, :]
            v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
            out['pred_logits'] = vl_similarity(v_emb, t_emb, temperature)
        elif self.task in ['semseg', 'instance', 'panoptic']:
            out['pred_logits'] = self.lang_encoder.compute_similarity(
                out['pred_class_embed'])
        return out

    def forward_caption(self, x, mask_features, extra=None):
        """Greedily decode one caption per batch item.

        Args:
            x: Indexable collection of ``num_feature_levels`` feature maps.
            mask_features (torch.Tensor): Features used by the mask head.
            extra (dict): Must contain ``'start_token'`` (the initial token
                ids, repeated over the batch). An optional
                ``'grounding_mask'`` additionally masks the cross-attention
                of the caption tokens; it is only supported for a single
                image.

        Returns:
            dict: ``{'pred_caption': list[str]}`` with the decoded text up
            to the first ``'<|endoftext|>'`` marker.
        """
        src, pos, size_list = self._prepare_multi_scale_inputs(x)
        bs = src[0].shape[1]

        query_embed_ = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        query_feat = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
        lang_token = extra['start_token'].repeat(bs, 1)
        pos_embed = self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1)

        # Token embedding matrix used to map predictions back to token ids.
        token_embs = self.lang_encoder.lang_encoder.token_embedding.weight

        # Loop invariants: the self-attention mask and the grounding switch.
        self_tgt_mask = self.self_attn_mask.repeat(bs * self.num_heads, 1, 1)
        has_grounding = 'grounding_mask' in extra

        # Step ``k`` writes token ``k + 1``; never step past the buffer.
        num_steps = min(self.captioning_step, lang_token.shape[1] - 1)

        for cap_idx in range(num_steps):
            lang_embed = self.lang_encoder.forward_language(
                (lang_token, ), with_cls_embed=False)[1].transpose(0, 1)
            # Decoder input: learned queries followed by the caption tokens.
            output = torch.cat((query_feat, lang_embed), dim=0)
            # Out-of-place add so the encoder output is not mutated.
            query_embed = torch.cat((query_embed_, lang_embed + pos_embed),
                                    dim=0)

            # Prediction heads on the initial features.
            results = self.forward_prediction_heads(
                output, mask_features, attn_mask_target_size=size_list[0])
            attn_mask = results['attn_mask']

            for i in range(self.num_layers):
                level_index = i % self.num_feature_levels
                fully_masked = attn_mask.sum(-1) == attn_mask.shape[-1]
                attn_mask[fully_masked] = False
                # Append all-False rows for the caption tokens.
                attn_mask = torch.cat(
                    (attn_mask,
                     torch.zeros_like(attn_mask[:, :self.max_token_num, :])),
                    dim=1)

                if has_grounding:
                    num_batch_heads, nq, wh = attn_mask.shape
                    assert num_batch_heads == self.num_heads, \
                        'Only support single ' \
                        'image referring captioning.'
                    level_size = size_list[level_index]
                    attn_mask = attn_mask.reshape(num_batch_heads, nq,
                                                  level_size[0],
                                                  level_size[1])
                    grounding_mask = F.interpolate(
                        extra['grounding_mask'].float(),
                        level_size,
                        mode='nearest').bool()[0, 0]
                    # Mask the caption rows at the locations selected by
                    # the grounding mask.
                    attn_mask[:, self.num_queries:, grounding_mask] = True
                    attn_mask = attn_mask.reshape(num_batch_heads, nq, wh)

                # Cross-attention first; no key-padding mask is supplied.
                output, _ = self.transformer_cross_attention_layers[i](
                    output,
                    src[level_index],
                    memory_mask=attn_mask,
                    memory_key_padding_mask=None,
                    pos=pos[level_index],
                    query_pos=query_embed)

                output = self.transformer_self_attention_layers[i](
                    output,
                    tgt_mask=self_tgt_mask,
                    tgt_key_padding_mask=None,
                    query_pos=query_embed)

                output = self.transformer_ffn_layers[i](output)

                next_level = (i + 1) % self.num_feature_levels
                results = self.forward_prediction_heads(
                    output,
                    mask_features,
                    attn_mask_target_size=size_list[next_level])
                attn_mask = results['attn_mask']

            # Greedy choice: highest-scoring token for the current position.
            pred_captions = results['outputs_caption'] @ token_embs.t()
            lang_token[:, cap_idx + 1] = pred_captions[:, cap_idx].max(-1)[1]

        texts = self.lang_encoder.tokenizer.batch_decode(
            lang_token, skip_special_tokens=False)
        # Keep the text before the first end marker and drop the start
        # marker.
        captions = [
            text.split('<|endoftext|>')[0].replace('<|startoftext|>',
                                                   '').strip()
            for text in texts
        ]
        return {'pred_caption': captions}

    def forward_prediction_heads(self, output, mask_features,
                                 attn_mask_target_size):
        """Compute the head outputs and the next cross-attention mask.

        Args:
            output (torch.Tensor): Decoder state; its first two dimensions
                are swapped before the heads are applied.
            mask_features (torch.Tensor): Features combined with the mask
                embeddings via ``einsum('bqc,bchw->bqhw')``.
            attn_mask_target_size: Spatial size the mask logits are resized
                to before thresholding.

        Returns:
            dict: ``'attn_mask'`` (detached boolean mask, repeated once per
            attention head) and, for the ``'caption'`` task,
            ``'outputs_caption'``; otherwise ``'outputs_mask'`` and
            ``'class_embed'``.
        """
        num_q = self.num_queries
        is_caption = self.task == 'caption'

        decoder_output = self.decoder_norm(output).transpose(0, 1)

        if is_caption:
            # Caption tokens are the rows after the learned queries.
            outputs_caption = decoder_output[:, num_q:] @ self.caping_embed

        # Recompute the class token as a similarity-weighted sum of the
        # object tokens.
        norm_decoder_output = decoder_output / (
            decoder_output.norm(dim=-1, keepdim=True) + 1e-7)
        obj_token = norm_decoder_output[:, :num_q - 1]
        cls_token = norm_decoder_output[:, num_q - 1:num_q]

        sim = (cls_token @ obj_token.transpose(1, 2)).softmax(-1)
        sim = sim[:, 0, :, None]
        cls_token = (sim * decoder_output[:, :num_q - 1]).sum(
            dim=1, keepdim=True)

        pieces = [decoder_output[:, :num_q - 1], cls_token]
        if self.task == 'ref-seg':
            # Keep the rows appended after the original queries.
            pieces.append(decoder_output[:, num_q:2 * num_q - 1])
        decoder_output = torch.cat(pieces, dim=1)

        mask_embed = self.mask_embed(decoder_output)
        outputs_mask = torch.einsum('bqc,bchw->bqhw', mask_embed,
                                    mask_features)

        interpolate_kwargs = dict(
            size=attn_mask_target_size, mode='bicubic', align_corners=False)
        if not is_lower_torch_version():
            # ``antialias`` is only passed on newer torch versions.
            interpolate_kwargs['antialias'] = True
        attn_mask = F.interpolate(outputs_mask, **interpolate_kwargs)

        # Boolean mask, one copy per head: True where the mask probability
        # is below 0.5.
        attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(
            1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
        attn_mask = attn_mask.detach()

        # NOTE(review): this clears row ``num_queries``, one past the class
        # token row (``num_queries - 1``); when the mask has exactly
        # ``num_queries`` rows the slice is empty. Kept as-is -- confirm the
        # intended index.
        attn_mask[:, num_q:num_q + 1].fill_(False)

        if is_caption:
            return {
                'attn_mask': attn_mask,
                'outputs_caption': outputs_caption,
            }

        class_embed = decoder_output @ self.class_embed
        return {
            'outputs_mask': outputs_mask,
            'attn_mask': attn_mask,
            'class_embed': class_embed,
        }
|
|
|