## Code modified from https://github.com/megvii-model/MOTR/blob/main/models/qim.py
from typing import List, Optional

import numpy as np
import torch
from torch import Tensor, nn
from torch.nn import functional as F

from mmdet.models.utils.transformer import inverse_sigmoid

from .structures import Instances


def random_drop_tracks(
    track_instances: Instances, drop_probability: float
) -> Instances:
    """Randomly discard a subset of ``track_instances``.

    One uniform sample is drawn per entry of ``track_instances.scores``; an
    entry is kept when its sample is strictly greater than
    ``drop_probability``.

    Args:
        track_instances: Container to subsample. It must support ``len``,
            mask indexing and expose a ``scores`` tensor.
        drop_probability: Per-track drop probability. Values that are not
            strictly positive disable dropping.

    Returns:
        The input object itself when nothing is dropped (probability not
        positive, or empty input); otherwise the mask-indexed subset.
    """
    # The probability is tested first so that ``len`` is never evaluated when
    # dropping is disabled.
    if not (drop_probability > 0):
        return track_instances
    if len(track_instances) == 0:
        return track_instances

    keep_mask = torch.rand_like(track_instances.scores) > drop_probability
    return track_instances[keep_mask]


class QueryInteractionBase(nn.Module):
    """Skeleton for query interaction modules.

    Construction stores ``args``, asks the subclass to create its layers via
    ``_build_layers`` and then re-initialises the parameters. Subclasses must
    provide ``_build_layers``, ``_select_active_tracks`` and
    ``_update_track_embedding``.
    """

    def __init__(self, args, dim_in, hidden_dim, dim_out):
        super().__init__()
        self.args = args
        # Layers must exist before the parameter reset below can see them.
        self._build_layers(args, dim_in, hidden_dim, dim_out)
        self._reset_parameters()

    def _build_layers(self, args, dim_in, hidden_dim, dim_out):
        """Create the sub-modules; must be overridden."""
        raise NotImplementedError()

    def _reset_parameters(self):
        """Apply Xavier-uniform init to every parameter with more than one dim."""
        matrices = (param for param in self.parameters() if param.dim() > 1)
        for param in matrices:
            nn.init.xavier_uniform_(param)

    def _select_active_tracks(self, data: dict) -> Instances:
        """Pick the tracks that take part in the interaction; must be overridden."""
        raise NotImplementedError()

    def _update_track_embedding(self, track_instances):
        """Refresh the track queries; must be overridden."""
        raise NotImplementedError()


class FFN(nn.Module):
    """Two-layer feed-forward block with a residual connection and LayerNorm.

    Args:
        d_model: Size of the input and output features.
        d_ffn: Size of the hidden layer.
        dropout: Dropout probability used after the activation and on the
            residual branch.
    """

    def __init__(self, d_model, d_ffn, dropout=0):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = F.relu
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout2 = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, tgt):
        """Return ``norm(tgt + dropout2(linear2(dropout1(relu(linear1(tgt))))))``."""
        hidden = self.activation(self.linear1(tgt))
        update = self.linear2(self.dropout1(hidden))
        return self.norm(tgt + self.dropout2(update))


class QueryInteractionModule(QueryInteractionBase):
    """Query interaction module that refreshes track queries between frames.

    ``args`` is indexed by key and must provide ``"merger_dropout"`` and
    ``"update_query_pos"`` (read while building layers) as well as
    ``"random_drop"`` and ``"fp_ratio"``.
    """

    def __init__(self, args, dim_in, hidden_dim, dim_out):
        super().__init__(args, dim_in, hidden_dim, dim_out)
        self.random_drop = args["random_drop"]
        self.fp_ratio = args["fp_ratio"]
        self.update_query_pos = args["update_query_pos"]

    def _build_layers(self, args, dim_in, hidden_dim, dim_out):
        """Create the attention, FFN and query-update layers.

        Attribute names and creation order are kept stable on purpose: they
        determine the ``state_dict`` keys and the order in which random
        initialisation consumes the RNG.
        """
        drop_p = args["merger_dropout"]

        # Attention over the track embeddings (8 heads, hard-coded).
        self.self_attn = nn.MultiheadAttention(dim_in, 8, drop_p)
        self.linear1 = nn.Linear(dim_in, hidden_dim)
        self.dropout = nn.Dropout(drop_p)
        self.linear2 = nn.Linear(hidden_dim, dim_in)

        # Optional branch that also refreshes the positional half of the query.
        if args["update_query_pos"]:
            self.linear_pos1 = nn.Linear(dim_in, hidden_dim)
            self.linear_pos2 = nn.Linear(hidden_dim, dim_in)
            self.dropout_pos1 = nn.Dropout(drop_p)
            self.dropout_pos2 = nn.Dropout(drop_p)
            self.norm_pos = nn.LayerNorm(dim_in)

        # Branch that refreshes the feature half of the query.
        self.linear_feat1 = nn.Linear(dim_in, hidden_dim)
        self.linear_feat2 = nn.Linear(hidden_dim, dim_in)
        self.dropout_feat1 = nn.Dropout(drop_p)
        self.dropout_feat2 = nn.Dropout(drop_p)
        self.norm_feat = nn.LayerNorm(dim_in)

        self.norm1 = nn.LayerNorm(dim_in)
        self.norm2 = nn.LayerNorm(dim_in)
        self.dropout1 = nn.Dropout(drop_p)
        self.dropout2 = nn.Dropout(drop_p)
        self.activation = F.relu

    def _update_track_embedding(self, track_instances: Instances) -> Instances:
        """Rewrite ``track_instances.query`` in place from the output embedding.

        ``query`` is split along axis 1 into a first half (positional part)
        and a second half (feature part). The output embedding goes through
        attention and an FFN; the result drives residual updates of the
        feature half and, when ``update_query_pos`` is set, the positional
        half.

        Assumes ``output_embedding`` and each query half have ``dim_in``
        channels -- TODO confirm against the caller.

        Returns:
            The same ``track_instances`` object (returned untouched when empty).
        """
        if len(track_instances) == 0:
            return track_instances

        half = track_instances.query.shape[1] // 2
        out_embed = track_instances.output_embedding
        # Both halves are views taken before any in-place write below.
        pos_part = track_instances.query[:, :half]
        feat_part = track_instances.query[:, half:]

        # Attention: a singleton axis is inserted at position 1 for the call
        # and stripped from the result again.
        attn_query = attn_key = pos_part + out_embed
        attn_out = self.self_attn(
            attn_query[:, None], attn_key[:, None], value=out_embed[:, None]
        )[0]
        fused = self.norm1(out_embed + self.dropout1(attn_out[:, 0]))

        # Feed-forward network with residual connection.
        ffn_hidden = self.activation(self.linear1(fused))
        ffn_out = self.linear2(self.dropout(ffn_hidden))
        fused = self.norm2(fused + self.dropout2(ffn_out))

        if self.update_query_pos:
            # Residual update of the positional half.
            pos_hidden = self.activation(self.linear_pos1(fused))
            pos_delta = self.linear_pos2(self.dropout_pos1(pos_hidden))
            track_instances.query[:, :half] = self.norm_pos(
                pos_part + self.dropout_pos2(pos_delta)
            )

        # Residual update of the feature half.
        feat_hidden = self.activation(self.linear_feat1(fused))
        feat_delta = self.linear_feat2(self.dropout_feat1(feat_hidden))
        track_instances.query[:, half:] = self.norm_feat(
            feat_part + self.dropout_feat2(feat_delta)
        )

        # NOTE(review): upstream carried a disabled reference-point refresh
        # here, kept for reference only:
        # track_instances.ref_pts = inverse_sigmoid(track_instances.pred_boxes[:, :2].detach().clone())
        return track_instances

    def _random_drop_tracks(self, track_instances: Instances) -> Instances:
        """Drop tracks at random with probability ``self.random_drop``."""
        return random_drop_tracks(track_instances, self.random_drop)

    def _add_fp_tracks(
        self, track_instances: Instances, active_track_instances: Instances
    ) -> Instances:
        """Append false-positive tracks to the active set.

        ``self.fp_ratio`` controls num(added fp) / num(active): one Bernoulli
        trial with that probability is run per active track, and the number
        of successes is the number of false positives requested. They are
        taken from the instances whose ``obj_idxes`` is negative, preferring
        the highest ``scores``.

        Returns:
            ``active_track_instances`` unchanged when no false positive can be
            added, otherwise the concatenation of the active tracks and the
            chosen false positives.
        """
        fp_pool = track_instances[track_instances.obj_idxes < 0]

        # One Bernoulli draw per active track decides how many fps to add.
        trial_prob = torch.ones_like(active_track_instances.scores) * self.fp_ratio
        hits = active_track_instances[torch.bernoulli(trial_prob).bool()]
        num_fp = len(hits)

        pool_size = len(fp_pool)
        if pool_size == 0 or num_fp == 0:
            return active_track_instances

        if num_fp >= pool_size:
            # Not enough candidates: take all of them.
            chosen_fp = fp_pool
        else:
            # v2: take the top-scoring candidates (ascending sort, keep the
            # tail) rather than the random permutation used upstream before.
            top_idx = torch.argsort(fp_pool.scores)[-num_fp:]
            chosen_fp = fp_pool[top_idx]

        return Instances.cat([active_track_instances, chosen_fp])

    def _select_active_tracks(self, data: dict) -> Instances:
        """Return the tracks from ``data["track_instances"]`` to propagate.

        Tracks with non-negative ``obj_idxes`` are always the starting set.
        In training mode they are additionally subsampled by random drop and,
        when ``fp_ratio`` is positive, extended with false-positive tracks.
        """
        track_instances: Instances = data["track_instances"]

        # Matched tracks only. Upstream also had a disabled variant that
        # additionally required ``track_instances.iou > 0.5``.
        matched = track_instances[track_instances.obj_idxes >= 0]
        if not self.training:
            return matched

        # NOTE(review): the upstream comment here reads "set -2 instead of -1
        # to ensure that these tracks will not be selected in matching", but
        # no such assignment is visible in this method -- confirm.
        matched = self._random_drop_tracks(matched)
        if self.fp_ratio > 0:
            matched = self._add_fp_tracks(track_instances, matched)
        return matched

    def forward(self, data) -> Instances:
        """Select and refresh the active tracks, then prepend the initial ones.

        Args:
            data: Mapping providing ``"track_instances"`` and
                ``"init_track_instances"``.

        Returns:
            ``Instances.cat([init_track_instances, updated_active_tracks])``.
        """
        survivors = self._select_active_tracks(data)
        refreshed = self._update_track_embedding(survivors)
        init_track_instances: Instances = data["init_track_instances"]
        return Instances.cat([init_track_instances, refreshed])


def build_qim(args, dim_in, hidden_dim, dim_out):
    """Instantiate the query interaction module named by ``args["qim_type"]``.

    Raises:
        AssertionError: If ``args["qim_type"]`` is not a registered name.
    """
    registry = {
        "QIMBase": QueryInteractionModule,
    }
    qim_type = args["qim_type"]
    assert qim_type in registry, "invalid query interaction layer: {}".format(
        qim_type
    )
    layer_cls = registry[qim_type]
    return layer_cls(args, dim_in, hidden_dim, dim_out)