"""Query interaction module (QIM): refreshes track queries between frames and
merges them with the fresh detection queries of the current frame."""

import torch
from torch import nn
from torch.nn import functional as F

from .structures import Instances
|
|
def random_drop_tracks(
    track_instances: Instances, drop_probability: float
) -> Instances:
    """Randomly drop tracks with probability `drop_probability` (a training-time augmentation)."""
    if drop_probability > 0 and len(track_instances) > 0:
        keep_idxes = torch.rand_like(track_instances.scores) > drop_probability
        track_instances = track_instances[keep_idxes]
    return track_instances
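
# Usage sketch (hypothetical; assumes the detectron2-style `Instances` API used
# throughout this file, with a float `.scores` field):
#
#   tracks = Instances(image_size)
#   tracks.scores = torch.rand(10)
#   tracks = random_drop_tracks(tracks, drop_probability=0.1)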
|
|
class QueryInteractionBase(nn.Module):
    """Base class for query interaction layers; subclasses define how track
    queries are selected and updated between frames."""

    def __init__(self, args, dim_in, hidden_dim, dim_out):
        super().__init__()
        self.args = args
        self._build_layers(args, dim_in, hidden_dim, dim_out)
        self._reset_parameters()

    def _build_layers(self, args, dim_in, hidden_dim, dim_out):
        raise NotImplementedError()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _select_active_tracks(self, data: dict) -> Instances:
        raise NotImplementedError()

    def _update_track_embedding(self, track_instances):
        raise NotImplementedError()
|
|
class FFN(nn.Module):
    """Feed-forward block with a residual connection and LayerNorm."""

    def __init__(self, d_model, d_ffn, dropout=0):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = F.relu
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout2 = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, tgt):
        tgt2 = self.linear2(self.dropout1(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm(tgt)
        return tgt
|
|
class QueryInteractionModule(QueryInteractionBase):
    """Default query interaction layer: self-attention plus FFNs over the
    active track queries, with random-drop and false-positive augmentation
    during training."""

    def __init__(self, args, dim_in, hidden_dim, dim_out):
        super().__init__(args, dim_in, hidden_dim, dim_out)
        self.random_drop = args["random_drop"]
        self.fp_ratio = args["fp_ratio"]
        self.update_query_pos = args["update_query_pos"]

    def _build_layers(self, args, dim_in, hidden_dim, dim_out):
        dropout = args["merger_dropout"]

        self.self_attn = nn.MultiheadAttention(dim_in, 8, dropout)
        self.linear1 = nn.Linear(dim_in, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(hidden_dim, dim_in)

        if args["update_query_pos"]:
            self.linear_pos1 = nn.Linear(dim_in, hidden_dim)
            self.linear_pos2 = nn.Linear(hidden_dim, dim_in)
            self.dropout_pos1 = nn.Dropout(dropout)
            self.dropout_pos2 = nn.Dropout(dropout)
            self.norm_pos = nn.LayerNorm(dim_in)

        self.linear_feat1 = nn.Linear(dim_in, hidden_dim)
        self.linear_feat2 = nn.Linear(hidden_dim, dim_in)
        self.dropout_feat1 = nn.Dropout(dropout)
        self.dropout_feat2 = nn.Dropout(dropout)
        self.norm_feat = nn.LayerNorm(dim_in)

        self.norm1 = nn.LayerNorm(dim_in)
        self.norm2 = nn.LayerNorm(dim_in)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = F.relu

    def _update_track_embedding(self, track_instances: Instances) -> Instances:
        """Refresh each track's query with self-attention and FFNs, conditioned
        on its output embedding from the current frame."""
        if len(track_instances) == 0:
            return track_instances
        dim = track_instances.query.shape[1]
        out_embed = track_instances.output_embedding
        # The first half of each query is its positional part, the second half
        # its content part.
        query_pos = track_instances.query[:, : dim // 2]
        query_feat = track_instances.query[:, dim // 2 :]
        q = k = query_pos + out_embed

        # Self-attention across tracks; inputs are shaped (num_tracks, 1, dim)
        # so the tracks form the sequence and the batch size is 1.
        tgt = out_embed
        tgt2 = self.self_attn(q[:, None], k[:, None], value=tgt[:, None])[0][:, 0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # Feed-forward block with residual connection.
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        if self.update_query_pos:
            query_pos2 = self.linear_pos2(
                self.dropout_pos1(self.activation(self.linear_pos1(tgt)))
            )
            query_pos = query_pos + self.dropout_pos2(query_pos2)
            query_pos = self.norm_pos(query_pos)
            track_instances.query[:, : dim // 2] = query_pos

        query_feat2 = self.linear_feat2(
            self.dropout_feat1(self.activation(self.linear_feat1(tgt)))
        )
        query_feat = query_feat + self.dropout_feat2(query_feat2)
        query_feat = self.norm_feat(query_feat)
        track_instances.query[:, dim // 2 :] = query_feat

        return track_instances

    def _random_drop_tracks(self, track_instances: Instances) -> Instances:
        return random_drop_tracks(track_instances, self.random_drop)

    def _add_fp_tracks(
        self, track_instances: Instances, active_track_instances: Instances
    ) -> Instances:
        """Inject false-positive tracks during training; `self.fp_ratio`
        controls the expected ratio num(added_fp) / num(active)."""
        inactive_instances = track_instances[track_instances.obj_idxes < 0]

        # Each active track spawns a false positive with probability fp_ratio;
        # only the number of sampled tracks is used below.
        fp_prob = torch.ones_like(active_track_instances.scores) * self.fp_ratio
        selected_active_track_instances = active_track_instances[
            torch.bernoulli(fp_prob).bool()
        ]
        num_fp = len(selected_active_track_instances)

        if len(inactive_instances) > 0 and num_fp > 0:
            if num_fp >= len(inactive_instances):
                fp_track_instances = inactive_instances
            else:
                # Take the num_fp highest-scoring inactive tracks as hard
                # false positives.
                fp_indexes = torch.argsort(inactive_instances.scores)[-num_fp:]
                fp_track_instances = inactive_instances[fp_indexes]

            merged_track_instances = Instances.cat(
                [active_track_instances, fp_track_instances]
            )
            return merged_track_instances

        return active_track_instances

    def _select_active_tracks(self, data: dict) -> Instances:
        track_instances: Instances = data["track_instances"]
        if self.training:
            # Keep tracks matched to an object identity, then apply
            # random-drop and false-positive augmentation.
            active_idxes = track_instances.obj_idxes >= 0
            active_track_instances = track_instances[active_idxes]

            active_track_instances = self._random_drop_tracks(active_track_instances)
            if self.fp_ratio > 0:
                active_track_instances = self._add_fp_tracks(
                    track_instances, active_track_instances
                )
        else:
            active_track_instances = track_instances[track_instances.obj_idxes >= 0]

        return active_track_instances

    def forward(self, data) -> Instances:
        active_track_instances = self._select_active_tracks(data)
        active_track_instances = self._update_track_embedding(active_track_instances)
        init_track_instances: Instances = data["init_track_instances"]
        merged_track_instances = Instances.cat(
            [init_track_instances, active_track_instances]
        )
        return merged_track_instances
|
|
def build_qim(args, dim_in, hidden_dim, dim_out):
    """Build the query interaction layer named by `args["qim_type"]`."""
    qim_type = args["qim_type"]
    interaction_layers = {
        "QIMBase": QueryInteractionModule,
    }
    assert qim_type in interaction_layers, "invalid query interaction layer: {}".format(
        qim_type
    )
    return interaction_layers[qim_type](args, dim_in, hidden_dim, dim_out)
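
# Minimal construction sketch. The config keys below are exactly the ones this
# module reads; the numeric values are illustrative assumptions, not defaults
# from any particular config:
#
#   args = {
#       "qim_type": "QIMBase",
#       "merger_dropout": 0.0,
#       "update_query_pos": True,
#       "random_drop": 0.1,
#       "fp_ratio": 0.3,
#   }
#   qim = build_qim(args, dim_in=256, hidden_dim=1024, dim_out=256)
#
# At runtime, qim(data) expects data["track_instances"] (tracks carried over
# from the previous frame) and data["init_track_instances"] (fresh detection
# queries), and returns their concatenation with refreshed track queries.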