|
|
import torch |
|
|
import torch.nn as nn |
|
|
from mmcv.cnn.bricks.registry import TRANSFORMER_LAYER_SEQUENCE |
|
|
from mmcv.cnn.bricks.transformer import ( |
|
|
TransformerLayerSequence, |
|
|
build_transformer_layer_sequence, |
|
|
) |
|
|
from mmcv.runner.base_module import BaseModule |
|
|
|
|
|
from mmdet.models.utils.builder import TRANSFORMER |
|
|
|
|
|
|
|
|
def inverse_sigmoid(x, eps=1e-5):
    """Compute the logit, i.e. the inverse of the sigmoid, of ``x``.

    The input is first clipped into ``[0, 1]``. Both ``x`` and ``1 - x`` are
    then floored at ``eps`` so the log-ratio stays finite at the boundaries.

    Args:
        x (Tensor): Values to map back through the sigmoid.
        eps (float): Lower bound applied to ``x`` and ``1 - x`` before the
            division, guarding against ``log(0)`` and division by zero.
            Defaults to 1e-5.

    Returns:
        Tensor: ``log(x / (1 - x))`` evaluated on the clipped input, with the
        same shape as ``x``.
    """
    probs = torch.clamp(x, min=0, max=1)
    # Keep numerator and denominator strictly positive.
    numer = torch.clamp(probs, min=eps)
    denom = torch.clamp(1 - probs, min=eps)
    return (numer / denom).log()
|
|
|
|
|
|
|
|
@TRANSFORMER.register_module()
class Detr3DCamTransformerPlus(BaseModule):
    """Camera-only DETR3D transformer: a thin wrapper around a decoder.

    There is no encoder. Learnable queries (positional embedding + content)
    and their 3D reference points are broadcast over the batch and handed to
    ``self.decoder`` together with the multi-level image features.

    Args:
        num_feature_levels (int): Number of feature maps from FPN. Only stored
            as an attribute; not read by this class. Default: 4.
        num_cams (int): Number of cameras. Only stored as an attribute; not
            read by this class. Default: 6.
        decoder (dict): Config passed to ``build_transformer_layer_sequence``.
            The built decoder must expose an ``embed_dims`` attribute.
        reference_points_aug (bool): If True, add standard-normal noise to the
            reference points (in inverse-sigmoid space) while training.
            Default: False.
        **kwargs: Forwarded to ``BaseModule``.
    """

    def __init__(
        self,
        num_feature_levels=4,
        num_cams=6,
        decoder=None,
        reference_points_aug=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.decoder = build_transformer_layer_sequence(decoder)
        # The query width is dictated by the decoder.
        self.embed_dims = self.decoder.embed_dims
        self.num_feature_levels = num_feature_levels
        self.num_cams = num_cams
        self.reference_points_aug = reference_points_aug
        self.init_layers()

    def init_layers(self):
        """Create extra layers. None are needed; kept as an extension hook."""
        pass

    def init_weights(self):
        """Xavier-uniform initialise every parameter with more than one dim.

        1-D parameters (e.g. biases, norm weights) are left untouched.

        NOTE(review): this overrides ``BaseModule.init_weights`` and applies
        to the decoder's parameters as well, so any custom initialisation the
        decoder's attention modules perform on their weight matrices may be
        overwritten. Confirm this is intended.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self, mlvl_feats, query_embed, reference_points, reg_branches=None, **kwargs
    ):
        """Run the decoder on batch-broadcast queries.

        Args:
            mlvl_feats (list[Tensor]): Multi-level image features. Only
                ``mlvl_feats[0].size(0)`` (the batch size) is read here; the
                list is forwarded unchanged to the decoder as ``value``.
            query_embed (Tensor): Query embedding of shape
                ``(num_query, 2 * embed_dims)``. It is split along dim 1 into
                ``query_pos`` (the first ``embed_dims`` channels) and
                ``query`` (the remaining ``embed_dims`` channels).
            reference_points (Tensor): 3D reference point per query, shape
                ``(num_query, 3)``, in inverse-sigmoid (logit) space. It is
                mapped to ``[0, 1]`` with a sigmoid before decoding.
            reg_branches (nn.ModuleList, optional): Per-layer regression
                heads, forwarded to the decoder. Default: None.
            **kwargs: Forwarded to the decoder.

        Returns:
            tuple[Tensor]: ``(inter_states, init_reference_out,
            inter_references_out)`` where

            - ``inter_states``: first output of the decoder.
            - ``init_reference_out``: sigmoid-space reference points fed to
              the decoder, shape ``(bs, num_query, 3)``. During training
              with ``reference_points_aug`` they include the noise.
            - ``inter_references_out``: second output of the decoder
              (refined reference points).
        """
        assert query_embed is not None
        bs = mlvl_feats[0].size(0)
        # Positional part first, content part second.
        query_pos, query = torch.split(query_embed, self.embed_dims, dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        query = query.unsqueeze(0).expand(bs, -1, -1)
        reference_points = reference_points.unsqueeze(dim=0).expand(bs, -1, -1)

        if self.training and self.reference_points_aug:
            # Jitter in logit space, before squashing into [0, 1].
            reference_points = reference_points + torch.randn_like(reference_points)
        reference_points = reference_points.sigmoid()
        init_reference_out = reference_points

        # Hand queries to the decoder as (num_query, bs, embed_dims).
        query = query.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=mlvl_feats,
            query_pos=query_pos,
            reference_points=reference_points,
            reg_branches=reg_branches,
            **kwargs
        )
        return inter_states, init_reference_out, inter_references
|
|
|
|
|
|
|
|
@TRANSFORMER.register_module()
class Detr3DCamTrackTransformer(BaseModule):
    """Camera-only DETR3D transformer specialised for tracking.

    Each query carries a 3D reference point (its xyz trajectory) and a box
    size ``ref_size``. Both are passed to the decoder, which may refine them
    layer by layer. Because the box size is tracked explicitly, it can be
    kept consistent across frames.

    Args:
        num_feature_levels (int): Number of feature maps from FPN. Only stored
            as an attribute; not read by this class. Default: 4.
        num_cams (int): Number of cameras. Only stored as an attribute; not
            read by this class. Default: 6.
        decoder (dict): Config passed to ``build_transformer_layer_sequence``.
            The built decoder must expose ``embed_dims`` and return three
            values: states, reference points and box sizes.
        reference_points_aug (bool): If True, add standard-normal noise to the
            reference points (in inverse-sigmoid space) while training.
            Default: False.
        **kwargs: Forwarded to ``BaseModule``.
    """

    def __init__(
        self,
        num_feature_levels=4,
        num_cams=6,
        decoder=None,
        reference_points_aug=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.decoder = build_transformer_layer_sequence(decoder)
        # The query width is dictated by the decoder.
        self.embed_dims = self.decoder.embed_dims
        self.num_feature_levels = num_feature_levels
        self.num_cams = num_cams
        self.reference_points_aug = reference_points_aug
        self.init_layers()

    def init_layers(self):
        """Create extra layers. None are needed; kept as an extension hook."""
        pass

    def init_weights(self):
        """Xavier-uniform initialise every parameter with more than one dim.

        1-D parameters (e.g. biases, norm weights) are left untouched.

        NOTE(review): this overrides ``BaseModule.init_weights`` and applies
        to the decoder's parameters as well, so any custom initialisation the
        decoder's attention modules perform on their weight matrices may be
        overwritten. Confirm this is intended.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        mlvl_feats,
        query_embed,
        reference_points,
        ref_size,
        reg_branches=None,
        **kwargs
    ):
        """Run the tracking decoder on batch-broadcast queries.

        Args:
            mlvl_feats (list[Tensor]): Multi-level image features. Only
                ``mlvl_feats[0].size(0)`` (the batch size) is read here; the
                list is forwarded unchanged to the decoder as ``value``.
            query_embed (Tensor): Query embedding of shape
                ``(num_query, 2 * embed_dims)``. It is split along dim 1 into
                ``query_pos`` (the first ``embed_dims`` channels) and
                ``query`` (the remaining ``embed_dims`` channels).
            reference_points (Tensor): 3D reference point per query, shape
                ``(num_query, 3)``, in inverse-sigmoid (logit) space. It is
                mapped to ``[0, 1]`` with a sigmoid before decoding.
            ref_size (Tensor): Box size (w, l, h) per query, shape
                ``(num_query, 3)``, expected in log space. It is broadcast to
                ``(bs, num_query, 3)`` and passed to the decoder as is.
            reg_branches (nn.ModuleList, optional): Per-layer regression
                heads, forwarded to the decoder so it can refine the
                reference points and box sizes. Default: None.
            **kwargs: Forwarded to the decoder.

        Returns:
            tuple[Tensor]: ``(inter_states, inter_references,
            inter_box_sizes)``, exactly as returned by the decoder. Unlike
            ``Detr3DCamTransformerPlus``, the initial reference points are
            not returned.
        """
        assert query_embed is not None
        bs = mlvl_feats[0].size(0)
        # Positional part first, content part second.
        query_pos, query = torch.split(query_embed, self.embed_dims, dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        query = query.unsqueeze(0).expand(bs, -1, -1)
        reference_points = reference_points.unsqueeze(dim=0).expand(bs, -1, -1)
        ref_size = ref_size.unsqueeze(dim=0).expand(bs, -1, -1)

        if self.training and self.reference_points_aug:
            # Jitter in logit space, before squashing into [0, 1].
            reference_points = reference_points + torch.randn_like(reference_points)
        reference_points = reference_points.sigmoid()

        # Hand queries to the decoder as (num_query, bs, embed_dims).
        query = query.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        inter_states, inter_references, inter_box_sizes = self.decoder(
            query=query,
            key=None,
            value=mlvl_feats,
            query_pos=query_pos,
            reference_points=reference_points,
            reg_branches=reg_branches,
            ref_size=ref_size,
            **kwargs
        )
        return inter_states, inter_references, inter_box_sizes
|
|
|
|
|
|
|
|
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class Detr3DCamTrackPlusTransformerDecoder(TransformerLayerSequence):
    """DETR3D decoder that refines both reference points and box sizes.

    After every layer, when ``reg_branches`` is given, that layer's
    regression output is split into two parts:

    - a center update (channels ``0:2`` and ``4:5``), added to the
      reference points in inverse-sigmoid space;
    - a size update (channels ``2:4`` and ``5:6``), added to ``ref_size``.

    No final normalisation layer is applied by this class.

    Args:
        return_intermediate (bool): If True, return the stacked per-layer
            outputs, reference points and box sizes. Otherwise return only
            those of the last layer. Default: True.
        *args, **kwargs: Forwarded to ``TransformerLayerSequence``.
    """

    def __init__(self, *args, return_intermediate=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate

    def forward(
        self,
        query,
        *args,
        reference_points=None,
        reg_branches=None,
        ref_size=None,
        **kwargs
    ):
        """Run all decoder layers, optionally refining points and sizes.

        Args:
            query (Tensor): Input query of shape
                ``(num_query, bs, embed_dims)``.
            *args: Forwarded positionally to every layer.
            reference_points (Tensor): 3D reference point per query, already
                in sigmoid space (values in ``[0, 1]``). The last dimension
                must be 3 when ``reg_branches`` is given. As passed by
                ``Detr3DCamTrackTransformer`` its shape is
                ``(bs, num_query, 3)``.
            reg_branches (nn.ModuleList, optional): One regression head per
                layer. When None, reference points and box sizes are passed
                through unchanged.
            ref_size (Tensor): Box size (w, l, h) per query, expected in log
                space, shape ``(bs, num_query, 3)``. Required when
                ``reg_branches`` is given, and when ``return_intermediate``
                is True (it is stacked).
            **kwargs: Forwarded to every layer.

        Returns:
            tuple[Tensor]: If ``return_intermediate`` is True, a tuple of:

            - stacked layer outputs, shape
              ``(num_layers, num_query, bs, embed_dims)``;
            - stacked refined reference points, one entry per layer;
            - stacked refined box sizes, one entry per layer.

            If ``return_intermediate`` is False, the tuple is
            ``(output, reference_points, ref_size)`` from the last layer.
        """
        output = query
        intermediate = []
        intermediate_reference_points = []
        intermediate_box_sizes = []
        for lid, layer in enumerate(self.layers):
            output = layer(
                output,
                *args,
                reference_points=reference_points,
                ref_size=ref_size,
                **kwargs
            )
            # Regression heads run on (bs, num_query, C); assumes the layer
            # returns (num_query, bs, C) like its input — TODO confirm.
            output = output.permute(1, 0, 2)

            if reg_branches is not None:
                tmp = reg_branches[lid](output)
                # Center channels 0:2 + 4:5, size channels 2:4 + 5:6
                # (presumably (x, y, w, l, z, h, ...) — confirm vs. the head).
                ref_pts_update = torch.cat([tmp[..., :2], tmp[..., 4:5]], dim=-1)
                ref_size_update = torch.cat([tmp[..., 2:4], tmp[..., 5:6]], dim=-1)

                assert reference_points.shape[-1] == 3
                # Refine centers in logit space, map back to [0, 1], and
                # detach so later layers do not backprop into earlier ones.
                new_reference_points = ref_pts_update + inverse_sigmoid(
                    reference_points
                )
                reference_points = new_reference_points.sigmoid().detach()

                # Sizes are refined additively. The first layer's size is
                # kept attached to the graph; all later sizes are detached.
                # NOTE(review): asymmetric with reference_points (always
                # detached) — confirm this is intentional.
                ref_size = ref_size + ref_size_update
                if lid > 0:
                    ref_size = ref_size.detach()

            # Back to (num_query, bs, C) for the next layer.
            output = output.permute(1, 0, 2)
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)
                intermediate_box_sizes.append(ref_size)

        if self.return_intermediate:
            return (
                torch.stack(intermediate),
                torch.stack(intermediate_reference_points),
                torch.stack(intermediate_box_sizes),
            )

        return output, reference_points, ref_size
|
|
|