| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| from mmdet.registry import MODELS |
| from .language_model import LanguageEncoder |
| from .transformer_blocks import (MLP, Conv2d, CrossAttentionLayer, FFNLayer, |
| PositionEmbeddingSine, SelfAttentionLayer) |
| from .utils import is_lower_torch_version |
|
|
|
|
def vl_similarity(image_feat, text_feat, temperature=1):
    """Compute temperature-scaled similarity logits between two embeddings.

    The similarity is the matrix product of ``image_feat`` with the
    transpose of ``text_feat``, multiplied by ``exp(temperature)`` where
    the exponentiated scale is capped at 100.

    Args:
        image_feat (Tensor): Visual embeddings whose last dimension matches
            the last dimension of ``text_feat``.
        text_feat (Tensor): Text embeddings; must be at most 2-D because it
            is transposed with ``Tensor.t()``.
        temperature (Tensor | float | int): Log of the logit scale. Plain
            Python numbers are accepted and converted to a tensor, so the
            default value is usable. Defaults to 1.

    Returns:
        Tensor: Scaled similarity logits.
    """
    logits = torch.matmul(image_feat, text_feat.t())
    # A Python scalar has no ``.exp()``; promote it to a 0-dim tensor on the
    # same device as the logits. Tensor inputs are used untouched.
    if not torch.is_tensor(temperature):
        temperature = torch.tensor(float(temperature), device=logits.device)
    return temperature.exp().clamp(max=100) * logits
|
|
|
|
@MODELS.register_module()
class XDecoderTransformerDecoder(nn.Module):
    """Query-based transformer decoder with mask, class and caption heads.

    Every decoder layer runs, in this order, a cross-attention layer
    against one level of the multi-scale image features, a self-attention
    layer over the queries and a feed-forward layer. The prediction heads
    are evaluated before the first layer and after every layer; the masks
    they predict are thresholded into the cross-attention mask used by the
    following layer.

    The last of the ``num_queries`` learnable queries is handled separately
    from the other ``num_queries - 1`` queries (see
    :meth:`forward_prediction_heads`).

    The behaviour of :meth:`forward` depends on ``task``: ``'caption'`` is
    delegated to :meth:`forward_caption`; ``'ref-seg'``, ``'retrieval'``,
    ``'semseg'``, ``'instance'`` and ``'panoptic'`` each add a
    ``'pred_logits'`` entry to the result.

    Args:
        in_channels (int): Channels of every input feature level. A 1x1
            convolution projects them when they differ from ``hidden_dim``.
            Defaults to 512.
        hidden_dim (int): Embedding size of queries and decoder layers.
            Defaults to 512.
        dim_proj (int): Output size of the class and caption projections.
            Defaults to 512.
        num_queries (int): Number of learnable queries. Defaults to 101.
        max_token_num (int): Number of caption token positions.
            Defaults to 77.
        nheads (int): Number of attention heads. Defaults to 8.
        dim_feedforward (int): Hidden size of the feed-forward layers.
            Defaults to 2048.
        decoder_layers (int): Number of decoder layers. Defaults to 9.
        pre_norm (bool): Passed as ``normalize_before`` to every layer.
            Defaults to False.
        mask_dim (int): Output size of the mask embedding MLP; must match
            the channel dimension of ``mask_features``. Defaults to 512.
        task (str): Task selector, see above. Defaults to ``'semseg'``.
        captioning_step (int): Maximum number of caption decoding steps.
            Defaults to 50.
    """

    def __init__(
        self,
        in_channels=512,
        hidden_dim: int = 512,
        dim_proj: int = 512,
        num_queries: int = 101,
        max_token_num: int = 77,
        nheads: int = 8,
        dim_feedforward: int = 2048,
        decoder_layers: int = 9,
        pre_norm: bool = False,
        mask_dim: int = 512,
        task: str = 'semseg',
        captioning_step: int = 50,
    ):
        super().__init__()

        # Positional encoding for the image features.
        self.pe_layer = PositionEmbeddingSine(hidden_dim // 2, normalize=True)

        # Transformer decoder layers. The registration order of all
        # sub-modules in this constructor is kept stable on purpose so that
        # state-dict keys and initialisation order do not change.
        self.num_heads = nheads
        self.num_layers = decoder_layers
        self.max_token_num = max_token_num
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                ))

            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                ))

            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                ))

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # Learnable query content features.
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        # Learnable query positional embeddings.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # One learnable embedding per feature level; the decoder always
        # consumes exactly three levels.
        self.num_feature_levels = 3
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
        self.input_proj = nn.ModuleList()

        for _ in range(self.num_feature_levels):
            if in_channels != hidden_dim:
                self.input_proj.append(
                    Conv2d(in_channels, hidden_dim, kernel_size=1))
            else:
                # Empty Sequential acts as an identity mapping.
                self.input_proj.append(nn.Sequential())

        self.task = task

        # Language side: provides the logit scale, similarity computation,
        # the tokenizer and the token embedding table used below.
        self.lang_encoder = LanguageEncoder()

        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
        # NOTE(review): created with ``torch.empty`` and never initialised
        # here -- presumably the values come from a checkpoint; confirm
        # before training from scratch.
        self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))

        # Caption projection (also left uninitialised, see note above) and
        # positional embedding of the caption tokens.
        self.caping_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
        self.pos_embed_caping = nn.Embedding(max_token_num, hidden_dim)
        self.captioning_step = captioning_step

        # Self-attention mask over [queries, caption tokens]. The mask is
        # boolean; it is assumed that ``True`` means "may not attend", as in
        # ``torch.nn.MultiheadAttention`` -- confirm against
        # ``SelfAttentionLayer``.
        total_len = num_queries + max_token_num
        self_attn_mask = torch.zeros((1, total_len, total_len)).bool()
        # Queries -> caption tokens.
        self_attn_mask[:, :num_queries, num_queries:] = True
        # Caption tokens -> later caption tokens (strict upper triangle).
        self_attn_mask[:, num_queries:, num_queries:] = torch.triu(
            torch.ones((1, max_token_num, max_token_num)), diagonal=1).bool()
        # First ``num_queries - 1`` queries -> last query.
        self_attn_mask[:, :num_queries - 1, num_queries - 1:num_queries] = True
        # Last query -> first ``num_queries - 1`` queries.
        self_attn_mask[:, num_queries - 1:num_queries, :num_queries - 1] = True
        self.register_buffer('self_attn_mask', self_attn_mask)

    def _flatten_multi_scale_inputs(self, x):
        """Project, flatten and position-encode the multi-scale features.

        Args:
            x (Sequence[Tensor]): ``num_feature_levels`` feature maps, each
                indexed as (N, C, H, W).

        Returns:
            tuple: ``(src, pos, size_list)`` where ``src`` and ``pos`` are
            lists of (H*W, N, C) tensors and ``size_list`` holds the (H, W)
            size of every level.
        """
        src = []
        pos = []
        size_list = []
        for i in range(self.num_feature_levels):
            feat = x[i]
            size_list.append(feat.shape[-2:])
            level_pos = self.pe_layer(feat, None).flatten(2)
            level_src = (
                self.input_proj[i](feat).flatten(2) +
                self.level_embed.weight[i][None, :, None])
            # Flatten NxCxHxW to HWxNxC.
            pos.append(level_pos.permute(2, 0, 1))
            src.append(level_src.permute(2, 0, 1))
        return src, pos, size_list

    def forward(self, x, mask_features, extra=None):
        """Run the decoder for the configured task.

        Args:
            x (Sequence[Tensor]): ``num_feature_levels`` image feature maps,
                each indexed as (N, C, H, W).
            mask_features (Tensor): Per-pixel features indexed as
                (N, mask_dim, H, W), combined with the mask embeddings.
            extra (dict, optional): Task-specific inputs. ``'ref-seg'``
                reads ``'grounding_tokens'`` and ``'class_emb'``;
                ``'retrieval'`` reads ``'class_emb'``; ``'caption'`` reads
                ``'start_token'`` and optionally ``'grounding_mask'``.

        Returns:
            dict: For ``'caption'`` the result of :meth:`forward_caption`.
            Otherwise ``'pred_masks'`` and ``'pred_class_embed'`` of the
            last decoder layer, plus ``'pred_logits'`` for the tasks listed
            in the class docstring. For ``'ref-seg'``, ``'pred_masks'`` and
            ``'pred_logits'`` are lists with one entry per batch element.
        """
        if self.task == 'caption':
            return self.forward_caption(x, mask_features, extra)

        assert len(x) == self.num_feature_levels
        src, pos, size_list = self._flatten_multi_scale_inputs(x)
        batch_size = src[0].shape[1]
        is_ref_seg = self.task == 'ref-seg'

        # QxNxC
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(
            1, batch_size, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, batch_size, 1)

        predictions_mask = []
        predictions_class_embed = []

        # Query-only part of the self-attention mask, one copy per
        # (image, head) pair.
        num_mask_copies = output.shape[1] * self.num_heads
        self_tgt_mask = self.self_attn_mask[
            :, :self.num_queries, :self.num_queries].repeat(
                num_mask_copies, 1, 1)

        if is_ref_seg:
            grounding_tokens = extra['grounding_tokens']
            _grounding_tokens = grounding_tokens.detach().clone()
            num_grounding = len(grounding_tokens)
            # Sequence layout inside self-attention:
            # [queries | copy of the first num_queries - 1 queries |
            #  grounding tokens]. Start with everything masked.
            padded_len = 2 * self.num_queries - 1 + num_grounding
            pad_tgt_mask = torch.ones(
                (1, padded_len, padded_len),
                device=self_tgt_mask.device).bool().repeat(
                    num_mask_copies, 1, 1)
            pad_tgt_mask[:, :self.num_queries, :self.
                         num_queries] = self_tgt_mask
            # Everything after the original queries may attend to each
            # other.
            pad_tgt_mask[:, self.num_queries:, self.num_queries:] = False
            self_tgt_mask = pad_tgt_mask
            output = torch.cat((output, output[:-1]), dim=0)
            # Pad the positional embedding in the same way.
            query_embed = torch.cat((query_embed, query_embed[:-1]), dim=0)

        # Prediction heads on the learnable query features.
        results = self.forward_prediction_heads(
            output, mask_features, attn_mask_target_size=size_list[0])
        attn_mask = results['attn_mask']
        predictions_class_embed.append(results['class_embed'])
        predictions_mask.append(results['outputs_mask'])

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            # Rows that are entirely True are reset to all False.
            attn_mask[torch.where(
                attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            # Cross-attention first; the second return value is unused.
            output, _ = self.transformer_cross_attention_layers[i](
                output,
                src[level_index],
                memory_mask=attn_mask,
                memory_key_padding_mask=None,
                pos=pos[level_index],
                query_pos=query_embed)

            if is_ref_seg:
                # Grounding tokens only take part in self-attention and the
                # feed-forward layer.
                output = torch.cat((output, _grounding_tokens), dim=0)
                query_embed = torch.cat((query_embed, grounding_tokens),
                                        dim=0)

            output = self.transformer_self_attention_layers[i](
                output,
                tgt_mask=self_tgt_mask,
                tgt_key_padding_mask=None,
                query_pos=query_embed)

            output = self.transformer_ffn_layers[i](output)

            if is_ref_seg:
                # Split the updated grounding tokens off again.
                _grounding_tokens = output[-num_grounding:]
                output = output[:-num_grounding]
                query_embed = query_embed[:-num_grounding]

            results = self.forward_prediction_heads(
                output,
                mask_features,
                attn_mask_target_size=size_list[(i + 1) %
                                                self.num_feature_levels])
            attn_mask = results['attn_mask']
            predictions_mask.append(results['outputs_mask'])
            predictions_class_embed.append(results['class_embed'])

        out = {
            'pred_masks': predictions_mask[-1],
            'pred_class_embed': predictions_class_embed[-1],
        }

        if is_ref_seg:
            # The text embeddings and the temperature are the same for every
            # image, so they are prepared once.
            t_emb = extra['class_emb']
            t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
            temperature = self.lang_encoder.logit_scale
            grounding_slice = slice(self.num_queries,
                                    2 * self.num_queries - 1)

            mask_pred_results = []
            outputs_class = []
            for idx in range(mask_features.shape[0]):
                pred_gmasks = out['pred_masks'][idx, grounding_slice]
                v_emb = predictions_class_embed[-1][idx, grounding_slice]
                v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

                out_prob = vl_similarity(
                    v_emb, t_emb, temperature=temperature)

                # For every text embedding pick the best matching query.
                matched_id = out_prob.max(0)[1]
                mask_pred_results.append(pred_gmasks[matched_id, :, :])
                outputs_class.append(out_prob[matched_id, :])
            out['pred_masks'] = mask_pred_results
            out['pred_logits'] = outputs_class
        elif self.task == 'retrieval':
            t_emb = extra['class_emb']
            temperature = self.lang_encoder.logit_scale
            # Embedding of the last query.
            v_emb = out['pred_class_embed'][:, -1, :]
            v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
            out['pred_logits'] = vl_similarity(v_emb, t_emb, temperature)
        elif self.task in ['semseg', 'instance', 'panoptic']:
            out['pred_logits'] = self.lang_encoder.compute_similarity(
                out['pred_class_embed'])
        return out

    def forward_caption(self, x, mask_features, extra=None):
        """Greedily decode a caption for every image in the batch.

        At every step the current token sequence is re-encoded, passed
        through all decoder layers together with the queries, and the
        arg-max token at the current position is written to the next
        position of the sequence.

        Args:
            x (Sequence[Tensor]): ``num_feature_levels`` image feature maps,
                each indexed as (N, C, H, W).
            mask_features (Tensor): Per-pixel features indexed as
                (N, mask_dim, H, W).
            extra (dict): Must contain ``'start_token'`` (the initial token
                ids, repeated along dim 0 for every image). If it contains
                ``'grounding_mask'``, the caption tokens are masked from the
                image positions that mask selects; only supported when the
                batch holds a single image.

        Returns:
            dict: ``{'pred_caption': list[str]}`` with one decoded caption
            per image; text is cut at the first ``<|endoftext|>`` and the
            ``<|startoftext|>`` marker is removed.
        """
        assert len(x) == self.num_feature_levels
        src, pos, size_list = self._flatten_multi_scale_inputs(x)
        batch_size = src[0].shape[1]

        # QxNxC
        query_embed_ = self.query_embed.weight.unsqueeze(1).repeat(
            1, batch_size, 1)
        query_feat = self.query_feat.weight.unsqueeze(1).repeat(
            1, batch_size, 1)
        lang_token = extra['start_token'].repeat(batch_size, 1)
        pos_embed = self.pos_embed_caping.weight.unsqueeze(1).repeat(
            1, batch_size, 1)

        # Token embedding table used to turn caption features into
        # vocabulary scores.
        token_embs = self.lang_encoder.lang_encoder.token_embedding.weight

        has_grounding = 'grounding_mask' in extra
        # The full self-attention mask does not change between steps or
        # layers, so it is expanded once.
        self_tgt_mask = self.self_attn_mask.repeat(
            batch_size * self.num_heads, 1, 1)

        # Step ``k`` writes token position ``k + 1``; never step past the
        # end of the token buffer.
        num_steps = min(self.captioning_step, lang_token.shape[1] - 1)

        for cap_idx in range(num_steps):
            lang_embed = self.lang_encoder.forward_language(
                (lang_token, ), with_cls_embed=False)[1].transpose(0, 1)
            # Concatenate the queries and the caption tokens. This must
            # happen before the in-place positional update below so that
            # ``output`` holds the embeddings without positions.
            output = torch.cat((query_feat, lang_embed), dim=0)
            lang_embed += pos_embed
            query_embed = torch.cat((query_embed_, lang_embed), dim=0)

            # Prediction heads on the learnable query features.
            results = self.forward_prediction_heads(
                output, mask_features, attn_mask_target_size=size_list[0])
            attn_mask = results['attn_mask']

            for i in range(self.num_layers):
                level_index = i % self.num_feature_levels
                # Rows that are entirely True are reset to all False.
                attn_mask[torch.where(
                    attn_mask.sum(-1) == attn_mask.shape[-1])] = False
                # Append all-False rows for the caption tokens.
                attn_mask = torch.cat(
                    (attn_mask,
                     torch.zeros_like(attn_mask[:, :self.max_token_num, :])),
                    dim=1)

                if has_grounding:
                    num_maps, num_rows, num_pixels = attn_mask.shape
                    assert num_maps == self.num_heads, \
                        'Only support single ' \
                        'image referring captioning.'
                    # ``attn_mask`` was produced at the resolution of the
                    # level consumed by this layer.
                    level_size = size_list[level_index]
                    attn_mask = attn_mask.reshape(num_maps, num_rows,
                                                  level_size[0],
                                                  level_size[1])
                    grounding_mask = F.interpolate(
                        extra['grounding_mask'].float(),
                        level_size,
                        mode='nearest').bool()[0, 0]
                    attn_mask[:, self.num_queries:, grounding_mask] = True
                    attn_mask = attn_mask.reshape(num_maps, num_rows,
                                                  num_pixels)

                # Cross-attention first; the second return value is unused.
                output, _ = self.transformer_cross_attention_layers[i](
                    output,
                    src[level_index],
                    memory_mask=attn_mask,
                    memory_key_padding_mask=None,
                    pos=pos[level_index],
                    query_pos=query_embed)

                output = self.transformer_self_attention_layers[i](
                    output,
                    tgt_mask=self_tgt_mask,
                    tgt_key_padding_mask=None,
                    query_pos=query_embed)

                output = self.transformer_ffn_layers[i](output)

                results = self.forward_prediction_heads(
                    output,
                    mask_features,
                    attn_mask_target_size=size_list[(i + 1) %
                                                    self.num_feature_levels])
                attn_mask = results['attn_mask']

            # Vocabulary scores of every caption position; keep the best
            # token of the current position as the next input token.
            pred_captions = results['outputs_caption'] @ token_embs.t()
            lang_token[:, cap_idx + 1] = pred_captions[:, cap_idx].max(-1)[1]

        texts = self.lang_encoder.tokenizer.batch_decode(
            lang_token, skip_special_tokens=False)

        captions = []
        for text in texts:
            # Everything from the first end marker onwards is dropped, so
            # no end marker can remain afterwards.
            text = text.split('<|endoftext|>')[0]
            text = text.replace('<|startoftext|>', '')
            captions.append(text.strip())

        return {'pred_caption': captions}

    def forward_prediction_heads(self, output, mask_features,
                                 attn_mask_target_size):
        """Compute masks, embeddings and the next cross-attention mask.

        Args:
            output (Tensor): Decoder tokens indexed as (Q, N, C); the first
                ``num_queries`` entries along dim 0 are the queries.
            mask_features (Tensor): Per-pixel features indexed as
                (N, mask_dim, H, W).
            attn_mask_target_size (tuple[int, int]): Spatial size the
                predicted masks are resized to for the attention mask.

        Returns:
            dict: For task ``'caption'``: ``'attn_mask'`` and
            ``'outputs_caption'``. Otherwise: ``'outputs_mask'``,
            ``'attn_mask'`` and ``'class_embed'``. ``'attn_mask'`` is a
            detached boolean tensor indexed as (N * num_heads, Q, H*W) that
            is True where the sigmoid of the resized mask is below 0.5.
        """
        is_caption = self.task == 'caption'
        num_obj = self.num_queries - 1

        decoder_output = self.decoder_norm(output)
        decoder_output = decoder_output.transpose(0, 1)

        if is_caption:
            # Caption tokens follow the queries along dim 1.
            outputs_caption = decoder_output[:, self.
                                             num_queries:] @ self.caping_embed

        # Replace the last query by a similarity-weighted sum of the other
        # queries: weights come from the softmax over the product of the
        # L2-normalised last query with the L2-normalised other queries.
        norm_decoder_output = decoder_output / (
            decoder_output.norm(dim=-1, keepdim=True) + 1e-7)
        obj_token = norm_decoder_output[:, :num_obj]
        cls_token = norm_decoder_output[:, num_obj:self.num_queries]

        sim = (cls_token @ obj_token.transpose(1, 2)).softmax(-1)
        sim = sim[:, 0, :, None]
        cls_token = (sim * decoder_output[:, :num_obj]).sum(
            dim=1, keepdim=True)

        if self.task == 'ref-seg':
            # Keep the duplicated queries that follow the original ones.
            grounding_queries = decoder_output[:, self.num_queries:2 *
                                               self.num_queries - 1]
            decoder_output = torch.cat(
                (decoder_output[:, :num_obj], cls_token, grounding_queries),
                dim=1)
        else:
            decoder_output = torch.cat(
                (decoder_output[:, :num_obj], cls_token), dim=1)

        mask_embed = self.mask_embed(decoder_output)
        outputs_mask = torch.einsum('bqc,bchw->bqhw', mask_embed,
                                    mask_features)

        # ``antialias`` is only passed when the helper reports a torch
        # version that is not "lower".
        interp_kwargs = dict(
            size=attn_mask_target_size, mode='bicubic', align_corners=False)
        if not is_lower_torch_version():
            interp_kwargs['antialias'] = True
        attn_mask = F.interpolate(outputs_mask, **interp_kwargs)

        # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, H*W]
        attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(
            1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
        attn_mask = attn_mask.detach()

        # NOTE(review): outside 'ref-seg' the mask has exactly
        # ``num_queries`` rows at this point, so this slice is empty and the
        # call is a no-op; in 'ref-seg' it clears the first duplicated
        # query. Kept unchanged -- confirm whether ``num_queries - 1`` was
        # intended before touching it.
        attn_mask[:, self.num_queries:self.num_queries + 1].fill_(False)

        if is_caption:
            return {
                'attn_mask': attn_mask,
                'outputs_caption': outputs_caption,
            }

        class_embed = decoder_output @ self.class_embed
        return {
            'outputs_mask': outputs_mask,
            'attn_mask': attn_mask,
            'class_embed': class_embed,
        }
|
|