Spaces:
Runtime error
Runtime error
| # ------------------------------------------------------------------------ | |
| # DINO | |
| # Copyright (c) 2022 IDEA. All Rights Reserved. | |
| # Licensed under the Apache License, Version 2.0 [see LICENSE for details] | |
| # ------------------------------------------------------------------------ | |
| # Modified from DINO https://github.com/IDEA-Research/DINO by Feng Li and Hao Zhang. | |
| # ------------------------------------------------------------------------ | |
| from typing import Optional, List, Union | |
| import torch | |
| from torch import nn, Tensor | |
| from torch.cuda.amp import autocast | |
| from ...utils.utils import MLP, _get_clones, _get_activation_fn, gen_sineembed_for_position, inverse_sigmoid | |
| from ..pixel_decoder.ops.modules import MSDeformAttn | |
class TransformerDecoder(nn.Module):
    """Stack of decoder layers with iterative reference-point refinement.

    Each layer receives sine/MLP positional queries derived from the current
    reference points. When ``bbox_embed`` has been attached by the owning
    module (it is ``None`` after construction and is indexed per layer), the
    layer output is turned into a delta in inverse-sigmoid space that refines
    the reference points for the next layer. With ``cross_track_layer=True`` an
    extra multi-scale deformable cross-attention is applied to the final
    output to produce ``tgt_track``.
    """

    def __init__(self, decoder_layer, num_layers, norm=None,
                 return_intermediate=False,
                 d_model=256, query_dim=4,
                 modulate_hw_attn=True,
                 num_feature_levels=1,
                 deformable_decoder=True,
                 decoder_query_perturber=None,
                 dec_layer_number=None,  # number of queries each layer in decoder
                 rm_dec_query_scale=True,
                 dec_layer_share=False,
                 dec_layer_dropout_prob=None,
                 cross_track_layer=False,
                 n_levels=None,
                 n_heads=None,
                 n_points=None,
                 ):
        super().__init__()
        # NOTE: registration order below is kept as-is because it determines
        # parameter order (and therefore xavier init RNG consumption and
        # state_dict ordering).
        if num_layers > 0:
            self.layers = _get_clones(decoder_layer, num_layers, layer_share=dec_layer_share)
        else:
            self.layers = nn.ModuleList()
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate
        assert return_intermediate, "support return_intermediate only"
        self.query_dim = query_dim
        assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
        self.num_feature_levels = num_feature_levels

        # maps the sine embedding of a 2/4-d reference point to a query position
        self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
        if not deformable_decoder:
            self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)
        else:
            self.query_pos_sine_scale = None

        if not rm_dec_query_scale:
            # a learned per-query positional scale is not supported here
            raise NotImplementedError
        self.query_scale = None

        self.bbox_embed = None   # attached externally; indexed by layer id in forward()
        self.class_embed = None
        self.d_model = d_model
        self.modulate_hw_attn = modulate_hw_attn
        self.deformable_decoder = deformable_decoder

        if not deformable_decoder and modulate_hw_attn:
            self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
        else:
            self.ref_anchor_head = None

        self.decoder_query_perturber = decoder_query_perturber
        self.box_pred_damping = None

        self.dec_layer_number = dec_layer_number
        if dec_layer_number is not None:
            assert isinstance(dec_layer_number, list)
            assert len(dec_layer_number) == num_layers

        self.dec_layer_dropout_prob = dec_layer_dropout_prob
        if dec_layer_dropout_prob is not None:
            assert isinstance(dec_layer_dropout_prob, list)
            assert len(dec_layer_dropout_prob) == num_layers
            for i in dec_layer_dropout_prob:
                assert 0.0 <= i <= 1.0

        if cross_track_layer:
            # add a cross-attention layer before the track FFN head
            self.cross_track_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
            self.cross_track = True
        else:
            self.cross_track = False

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-init every weight matrix, then let MSDeformAttn modules re-init themselves."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                refpoints_unsigmoid: Optional[Tensor] = None,  # num_queries, bs, 2
                # for memory
                level_start_index: Optional[Tensor] = None,  # num_levels
                spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
                valid_ratios: Optional[Tensor] = None,
                task=None,
                extra=None,
                ):
        """Run all decoder layers.

        Input:
            - tgt: nq, bs, d_model
            - memory: hw, bs, d_model
            - pos: hw, bs, d_model
            - refpoints_unsigmoid: nq, bs, 2/4 (pre-sigmoid reference points)
            - valid_ratios/spatial_shapes: bs, nlevel, 2
            - task / extra: forwarded unchanged to every layer

        Returns:
            A 3-element list:
            ``[intermediate_outputs, reference_points, tgt_track]`` where the
            first two are lists of tensors transposed to batch-first, the
            reference list starts with the initial (sigmoided) points and gains
            one entry per layer when ``bbox_embed`` is set, and ``tgt_track``
            is the batch-first cross-track feature or None.
        """
        output = tgt
        device = tgt.device

        intermediate = []
        reference_points = refpoints_unsigmoid.sigmoid().to(device)
        ref_points = [reference_points]

        for layer_id, layer in enumerate(self.layers):
            # preprocess ref points
            if self.training and self.decoder_query_perturber is not None and layer_id != 0:
                reference_points = self.decoder_query_perturber(reference_points)

            # Rescale normalized reference points into each level's valid region.
            # Boxes (4-d) need the (w, h) ratios twice; points (2-d) need them once.
            if reference_points.shape[-1] == 4:
                ratio = torch.cat([valid_ratios, valid_ratios], -1)
            else:
                ratio = valid_ratios
            reference_points_input = reference_points[:, :, None] * ratio[None, :]  # nq, bs, nlevel, 2/4

            query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :])  # nq, bs, 256*2
            raw_query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, 256
            pos_scale = self.query_scale(output) if self.query_scale is not None else 1
            query_pos = pos_scale * raw_query_pos

            output = layer(
                tgt=output,
                tgt_query_pos=query_pos,
                tgt_query_sine_embed=query_sine_embed,
                tgt_key_padding_mask=tgt_key_padding_mask,
                tgt_reference_points=reference_points_input,

                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
                memory_level_start_index=level_start_index,
                memory_spatial_shapes=spatial_shapes,
                memory_pos=pos,

                self_attn_mask=tgt_mask,
                cross_attn_mask=memory_mask,
                task=task,
                extra=extra,
                layer_id=layer_id,
            )

            # iterative box refinement in inverse-sigmoid space
            if self.bbox_embed is not None:
                reference_before_sigmoid = inverse_sigmoid(reference_points)
                delta_unsig = self.bbox_embed[layer_id](output).to(device)
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid()

                # the next layer sees detached points; the returned list keeps the graph
                reference_points = new_reference_points.detach()
                ref_points.append(new_reference_points)

            intermediate.append(self.norm(output) if self.norm is not None else output)

        if self.cross_track:
            # uses the last layer's query_pos / reference_points_input
            tgt_track = self.cross_track_attn(
                self.with_pos_embed(output, query_pos).transpose(0, 1),
                reference_points_input.transpose(0, 1).contiguous(),
                memory.transpose(0, 1), spatial_shapes, level_start_index,
                memory_key_padding_mask).transpose(0, 1)
            tgt_track = tgt_track + output
            tgt_track = tgt_track.transpose(0, 1)
        else:
            tgt_track = None

        return [
            [itm_out.transpose(0, 1) for itm_out in intermediate],
            [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
            tgt_track,
        ]
class DeformableTransformerDecoderLayer(nn.Module):
    """Decoder layer: self-attention, multi-scale deformable cross-attention, FFN.

    For ``task`` in ``('grounding', 'rvos')``, or whenever ``extra`` contains
    ``'visual_prompt_tokens'``, prompt tokens from ``extra`` are appended to the
    queries for the self-attention step only, then stripped again, so queries
    and prompts can exchange information before cross-attention.
    """

    # Visual prompt tokens are provided per feature level; layer ``layer_id``
    # reads level ``layer_id % _NUM_PROMPT_LEVELS``.
    _NUM_PROMPT_LEVELS = 3

    def __init__(self, d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4,
                 use_deformable_box_attn=False,
                 key_aware_type=None,
                 ):
        super().__init__()
        self.n_heads = n_heads

        # cross attention
        if use_deformable_box_attn:
            raise NotImplementedError
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

        self.key_aware_type = key_aware_type
        # 'proj_mean' needs a projection of the memory; leaving it None made
        # forward() fail with "'NoneType' object is not callable".
        self.key_aware_proj = nn.Linear(d_model, d_model) if key_aware_type == 'proj_mean' else None

    def rm_self_attn_modules(self):
        """Drop the self-attention sub-modules; forward() then skips self-attention."""
        self.self_attn = None
        self.dropout2 = None
        self.norm2 = None

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        """Position-wise feed-forward block with residual connection and LayerNorm."""
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def _get_prompt_tokens(self, extra, layer_id):
        """Fetch ``(prompt_tokens, prompt_pos, prompt_mask)`` for this layer from ``extra``.

        Visual prompts take precedence over grounding tokens. The prompt tokens
        also serve as their own positional embedding (a detached copy).
        """
        if 'visual_prompt_tokens' in extra:  # has visual prompt
            level_index = layer_id % self._NUM_PROMPT_LEVELS
            prompt_tokens = extra['visual_prompt_tokens'][level_index]
            prompt_mask = extra['visual_prompt_nonzero_mask'][level_index]
        else:  # grounding
            prompt_tokens = extra['grounding_tokens']
            prompt_mask = extra['grounding_nonzero_mask']
        prompt_pos = prompt_tokens.detach().clone()
        return prompt_tokens, prompt_pos, prompt_mask

    def _build_prompt_self_attn_mask(self, self_attn_mask, prompt_mask, num_queries,
                                     num_prompts, batch_size, device):
        """Build the per-head boolean self-attention mask over ``[queries; prompts]``.

        ``True`` means "may not attend" (nn.MultiheadAttention bool-mask
        semantics). ``prompt_mask`` is presumably a (bs, num_prompts) bool
        tensor; each True entry blocks attention between that prompt and all
        queries, in both directions. Prompt-to-prompt attention is never masked.
        Returns a (bs * n_heads, L, L) mask with L = num_queries + num_prompts.
        """
        total = num_queries + num_prompts
        mask = torch.zeros((batch_size, total, total), dtype=torch.bool, device=device)
        if self_attn_mask is not None:
            # training with denoising queries: keep the denoising/matching mask
            mask[:, :num_queries, :num_queries] = self_attn_mask.unsqueeze(0).repeat(batch_size, 1, 1)
        # query -> prompt and prompt -> query blocks (in-place writes through views)
        mask[:, :num_queries, num_queries:].transpose(1, 2)[prompt_mask] = True
        mask[:, num_queries:, :num_queries][prompt_mask] = True
        # nn.MultiheadAttention expects a batch-major (bs * n_heads, L, S) layout
        return mask.repeat_interleave(self.n_heads, dim=0)

    def _prompt_self_attention(self, tgt, tgt_query_pos, self_attn_mask, extra, layer_id):
        """Self-attention over queries concatenated with prompt tokens; returns query rows only."""
        prompt_tokens, prompt_pos, prompt_mask = self._get_prompt_tokens(extra, layer_id)
        num_queries = tgt.shape[0]
        attn_mask = self._build_prompt_self_attn_mask(
            self_attn_mask, prompt_mask, num_queries, prompt_tokens.shape[0],
            tgt.shape[1], tgt.device)

        tgt = torch.cat([tgt, prompt_tokens], dim=0)
        tgt_query_pos = torch.cat([tgt_query_pos, prompt_pos], dim=0)
        q = k = self.with_pos_embed(tgt, tgt_query_pos)
        tgt2 = self.self_attn(q, k, tgt, attn_mask=attn_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        return tgt[:num_queries], tgt_query_pos[:num_queries]

    def forward(self,
                # for tgt
                tgt: Optional[Tensor],  # nq, bs, d_model
                tgt_query_pos: Optional[Tensor] = None,  # pos for query. MLP(Sine(pos))
                tgt_query_sine_embed: Optional[Tensor] = None,  # pos for query. Sine(pos)
                tgt_key_padding_mask: Optional[Tensor] = None,
                tgt_reference_points: Optional[Tensor] = None,  # nq, bs, 4

                # for memory
                memory: Optional[Tensor] = None,  # hw, bs, d_model
                memory_key_padding_mask: Optional[Tensor] = None,
                memory_level_start_index: Optional[Tensor] = None,  # num_levels
                memory_spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
                memory_pos: Optional[Tensor] = None,  # pos for memory

                # sa
                self_attn_mask: Optional[Tensor] = None,  # mask used for self-attention
                cross_attn_mask: Optional[Tensor] = None,  # mask used for cross-attention
                task=None,
                extra=None,
                layer_id=None,
                ):
        """Apply self-attention, deformable cross-attention and the FFN.

        Input:
            - tgt/tgt_query_pos: nq, bs, d_model
            - tgt_reference_points: nq, bs, nlevel, 2/4 (as built by the decoder)
            - memory: hw, bs, d_model
            - extra: optional dict carrying prompt tokens and their masks

        Returns:
            Updated ``tgt`` with the same leading (nq, bs) layout.
        """
        # self attention
        uses_prompts = task in ['grounding', 'rvos'] or (
            extra is not None and 'visual_prompt_tokens' in extra)
        if self.self_attn is not None:
            if uses_prompts:
                tgt, tgt_query_pos = self._prompt_self_attention(
                    tgt, tgt_query_pos, self_attn_mask, extra, layer_id)
            else:
                q = k = self.with_pos_embed(tgt, tgt_query_pos)
                tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
                tgt = tgt + self.dropout2(tgt2)
                tgt = self.norm2(tgt)

        # cross attention
        if self.key_aware_type is not None:
            if self.key_aware_type == 'mean':
                tgt = tgt + memory.mean(0, keepdim=True)
            elif self.key_aware_type == 'proj_mean':
                tgt = tgt + self.key_aware_proj(memory).mean(0, keepdim=True)
            else:
                raise NotImplementedError("Unknown key_aware_type: {}".format(self.key_aware_type))
        tgt2 = self.cross_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
                               tgt_reference_points.transpose(0, 1).contiguous(),
                               memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index,
                               memory_key_padding_mask).transpose(0, 1)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # ffn
        tgt = self.forward_ffn(tgt)

        return tgt