# File size: 8,257 Bytes
# 663494c
# (line-number gutter "1 ... 213" from the web file viewer omitted)
## Code modified from https://github.com/megvii-model/MOTR/blob/main/models/qim.py
from typing import List, Optional
import numpy as np
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from mmdet.models.utils.transformer import inverse_sigmoid
from .structures import Instances
def random_drop_tracks(
    track_instances: Instances, drop_probability: float
) -> Instances:
    """Randomly discard tracks from ``track_instances``.

    One uniform sample in ``[0, 1)`` is drawn per entry of
    ``track_instances.scores``; an entry is kept only when its sample is
    strictly greater than ``drop_probability``.

    Args:
        track_instances: Collection supporting ``len()``, boolean-mask
            indexing and a ``scores`` tensor attribute.
        drop_probability: Chance of dropping each individual track.

    Returns:
        The very same object when ``drop_probability`` is not positive or
        the collection is empty; otherwise the masked subset.
    """
    # Nothing to do: dropping disabled, or there are no tracks to drop.
    if not (drop_probability > 0) or len(track_instances) == 0:
        return track_instances

    survivors = torch.rand_like(track_instances.scores) > drop_probability
    return track_instances[survivors]
class QueryInteractionBase(nn.Module):
    """Skeleton shared by query-interaction modules.

    Construction stores ``args``, asks the subclass to create its layers
    through ``_build_layers`` and then re-initialises the weights with
    ``_reset_parameters``.  The three hook methods raise
    ``NotImplementedError`` here and are meant to be overridden.
    """

    def __init__(self, args, dim_in, hidden_dim, dim_out):
        """Store ``args``, build the layers, then reset their parameters."""
        super().__init__()
        self.args = args
        self._build_layers(args, dim_in, hidden_dim, dim_out)
        self._reset_parameters()

    def _build_layers(self, args, dim_in, hidden_dim, dim_out):
        """Hook: create the sub-modules.  Must be provided by subclasses."""
        raise NotImplementedError()

    def _reset_parameters(self):
        """Apply Xavier-uniform init to every parameter with more than one dim.

        Parameters with a single dimension (e.g. biases) are left untouched.
        """
        for weight in (p for p in self.parameters() if p.dim() > 1):
            nn.init.xavier_uniform_(weight)

    def _select_active_tracks(self, data: dict) -> Instances:
        """Hook: choose the tracks to update.  Must be provided by subclasses."""
        raise NotImplementedError()

    def _update_track_embedding(self, track_instances):
        """Hook: refresh the track queries.  Must be provided by subclasses."""
        raise NotImplementedError()
class FFN(nn.Module):
    """Two-layer feed-forward block with a residual connection and LayerNorm.

    Computes ``norm(x + dropout2(linear2(dropout1(relu(linear1(x))))))``.

    Args:
        d_model: Size of the last dimension of the input and the output.
        d_ffn: Width of the hidden layer.
        dropout: Dropout probability used for both dropout layers.
    """

    def __init__(self, d_model, d_ffn, dropout=0):
        super().__init__()
        # Sub-modules are created in a fixed order so parameter registration
        # (and any seeded initialisation) stays the same.
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = F.relu
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout2 = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, tgt):
        """Return the normalised residual update of ``tgt``."""
        hidden = self.activation(self.linear1(tgt))
        update = self.linear2(self.dropout1(hidden))
        return self.norm(tgt + self.dropout2(update))
class QueryInteractionModule(QueryInteractionBase):
    """Query interaction module that refreshes the queries of active tracks.

    ``forward`` selects the active tracks, runs their ``output_embedding``
    through self-attention and feed-forward layers, writes the result back
    into ``query`` and concatenates the tracks after the newly initialised
    ones.

    Keys read from ``args``: ``merger_dropout``, ``update_query_pos``,
    ``random_drop`` and ``fp_ratio``.
    """

    def __init__(self, args, dim_in, hidden_dim, dim_out):
        # The base constructor builds the layers (via ``_build_layers``) and
        # resets their parameters before the options below are stored.
        super().__init__(args, dim_in, hidden_dim, dim_out)
        self.random_drop = args["random_drop"]
        self.fp_ratio = args["fp_ratio"]
        self.update_query_pos = args["update_query_pos"]

    def _build_layers(self, args, dim_in, hidden_dim, dim_out):
        """Create the attention, feed-forward and projection layers.

        ``dim_out`` is accepted for interface compatibility but not used.
        The creation order below is kept fixed because it determines the
        parameter registration order.
        """
        drop_p = args["merger_dropout"]

        # Self-attention over the tracks followed by a shared feed-forward.
        # The head count is fixed at 8.
        self.self_attn = nn.MultiheadAttention(dim_in, num_heads=8, dropout=drop_p)
        self.linear1 = nn.Linear(dim_in, hidden_dim)
        self.dropout = nn.Dropout(drop_p)
        self.linear2 = nn.Linear(hidden_dim, dim_in)

        # Optional branch that also refreshes the positional half of the query.
        if args["update_query_pos"]:
            self.linear_pos1 = nn.Linear(dim_in, hidden_dim)
            self.linear_pos2 = nn.Linear(hidden_dim, dim_in)
            self.dropout_pos1 = nn.Dropout(drop_p)
            self.dropout_pos2 = nn.Dropout(drop_p)
            self.norm_pos = nn.LayerNorm(dim_in)

        # Branch that refreshes the feature half of the query.
        self.linear_feat1 = nn.Linear(dim_in, hidden_dim)
        self.linear_feat2 = nn.Linear(hidden_dim, dim_in)
        self.dropout_feat1 = nn.Dropout(drop_p)
        self.dropout_feat2 = nn.Dropout(drop_p)
        self.norm_feat = nn.LayerNorm(dim_in)

        self.norm1 = nn.LayerNorm(dim_in)
        self.norm2 = nn.LayerNorm(dim_in)
        self.dropout1 = nn.Dropout(drop_p)
        self.dropout2 = nn.Dropout(drop_p)
        self.activation = F.relu

    def _update_track_embedding(self, track_instances: Instances) -> Instances:
        """Refresh ``track_instances.query`` in place and return the instances.

        The first half of the query channels is treated as the positional
        part and the second half as the feature part.  An empty collection
        is returned untouched.
        """
        if len(track_instances) == 0:
            return track_instances

        half = track_instances.query.shape[1] // 2
        embed = track_instances.output_embedding
        pos_part = track_instances.query[:, :half]
        feat_part = track_instances.query[:, half:]

        # Self-attention: queries/keys carry the positional part, values do
        # not.  The inserted axis 1 makes the tracks the sequence axis with a
        # batch of one (MultiheadAttention default layout).
        # assumes output_embedding is (num_tracks, dim_in) -- TODO confirm
        attn_in = pos_part + embed
        attn_out, _ = self.self_attn(
            attn_in[:, None], attn_in[:, None], value=embed[:, None]
        )
        hidden = self.norm1(embed + self.dropout1(attn_out[:, 0]))

        # Feed-forward with residual connection.
        ffn_out = self.linear2(self.dropout(self.activation(self.linear1(hidden))))
        hidden = self.norm2(hidden + self.dropout2(ffn_out))

        if self.update_query_pos:
            # Residual update of the positional half, written back in place.
            pos_delta = self.linear_pos2(
                self.dropout_pos1(self.activation(self.linear_pos1(hidden)))
            )
            pos_part = self.norm_pos(pos_part + self.dropout_pos2(pos_delta))
            track_instances.query[:, :half] = pos_part

        # Residual update of the feature half, written back in place.
        feat_delta = self.linear_feat2(
            self.dropout_feat1(self.activation(self.linear_feat1(hidden)))
        )
        feat_part = self.norm_feat(feat_part + self.dropout_feat2(feat_delta))
        track_instances.query[:, half:] = feat_part

        # Reference points (``ref_pts``) are not touched here.
        return track_instances

    def _random_drop_tracks(self, track_instances: Instances) -> Instances:
        """Drop tracks at random using the configured ``random_drop`` rate."""
        return random_drop_tracks(track_instances, self.random_drop)

    def _add_fp_tracks(
        self, track_instances: Instances, active_track_instances: Instances
    ) -> Instances:
        """Append false-positive tracks to the active ones.

        ``self.fp_ratio`` controls num(added fp) / num(active): every active
        track triggers one false positive with that probability.  The false
        positives are the inactive tracks (``obj_idxes < 0``) with the
        highest scores.
        """
        inactive = track_instances[track_instances.obj_idxes < 0]

        # One Bernoulli draw per active track decides how many fps to add.
        draw_prob = torch.ones_like(active_track_instances.scores) * self.fp_ratio
        triggered = active_track_instances[torch.bernoulli(draw_prob).bool()]
        num_fp = len(triggered)

        if len(inactive) == 0 or num_fp == 0:
            return active_track_instances

        if num_fp < len(inactive):
            # v2: take the top-scoring inactive tracks instead of a random
            # subset (argsort is ascending, so the tail holds the best ones).
            top_idx = torch.argsort(inactive.scores)[-num_fp:]
            fp_tracks = inactive[top_idx]
        else:
            # Not enough inactive tracks to choose from: use all of them.
            fp_tracks = inactive

        return Instances.cat([active_track_instances, fp_tracks])

    def _select_active_tracks(self, data: dict) -> Instances:
        """Return the tracks of ``data["track_instances"]`` to be updated.

        Tracks with ``obj_idxes >= 0`` are considered active.  In training
        mode they are additionally thinned by random dropping and, when
        ``fp_ratio`` is positive, extended with false positives.
        """
        tracks: Instances = data["track_instances"]
        selected = tracks[tracks.obj_idxes >= 0]
        if not self.training:
            return selected

        # NOTE(review): an earlier comment here mentioned marking dropped
        # tracks with -2; no such assignment exists in this code.
        selected = self._random_drop_tracks(selected)
        if self.fp_ratio > 0:
            selected = self._add_fp_tracks(tracks, selected)
        return selected

    def forward(self, data) -> Instances:
        """Update the active tracks and prepend ``data["init_track_instances"]``."""
        updated = self._update_track_embedding(self._select_active_tracks(data))
        fresh: Instances = data["init_track_instances"]
        return Instances.cat([fresh, updated])
def build_qim(args, dim_in, hidden_dim, dim_out):
    """Instantiate the query interaction module named by ``args["qim_type"]``.

    Args:
        args: Mapping holding ``qim_type`` plus whatever the selected module
            reads from it.
        dim_in, hidden_dim, dim_out: Forwarded unchanged to the module
            constructor.

    Returns:
        The constructed module.

    Raises:
        AssertionError: If ``qim_type`` is not a registered module name.
    """
    registry = {
        "QIMBase": QueryInteractionModule,
    }
    requested = args["qim_type"]
    assert requested in registry, "invalid query interaction layer: {}".format(
        requested
    )
    module_cls = registry[requested]
    return module_cls(args, dim_in, hidden_dim, dim_out)
# (stray "|" separator from the web file viewer)