from copy import deepcopy
import torch
import torch.nn as nn
from mmcv.runner import auto_fp16, force_fp32
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet3d.core.bbox.coders import build_bbox_coder
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from mmdet.models import DETECTORS, build_loss
from mmdet3d_plugin.core.bbox.util import denormalize_bbox, normalize_bbox
from .grid_mask import GridMask
from .memory_bank import build_memory_bank
from .qim import build_qim
from .radar_encoder import build_radar_encoder
from .structures import Instances
def inverse_sigmoid(x, eps=1e-5):
"""Inverse function of sigmoid.
Args:
x (Tensor): The tensor to do the
inverse.
eps (float): EPS avoid numerical
overflow. Defaults 1e-5.
Returns:
Tensor: The x has passed the inverse
function of sigmoid, has same
shape with input.
"""
x = x.clamp(min=0, max=1)
x1 = x.clamp(min=eps)
x2 = (1 - x).clamp(min=eps)
return torch.log(x1 / x2)
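# A minimal, hypothetical sanity check for `inverse_sigmoid` (not part of the
# original model; call it manually if desired): sigmoid(inverse_sigmoid(x))
# should recover x for values inside (eps, 1 - eps).
def _demo_inverse_sigmoid():
    x = torch.tensor([0.1, 0.5, 0.9])
    # round-tripping through sigmoid recovers the original values
    assert torch.allclose(inverse_sigmoid(x).sigmoid(), x, atol=1e-4)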
# this class is from MOTR
class RuntimeTrackerBase(object):
# code from https://github.com/megvii-model/MOTR/blob/main/models/motr.py#L303
def __init__(self, score_thresh=0.7, filter_score_thresh=0.6, miss_tolerance=5):
self.score_thresh = score_thresh
self.filter_score_thresh = filter_score_thresh
self.miss_tolerance = miss_tolerance
self.max_obj_id = 0
def clear(self):
self.max_obj_id = 0
def update(self, track_instances: Instances):
track_instances.disappear_time[track_instances.scores >= self.score_thresh] = 0
for i in range(len(track_instances)):
if (
track_instances.obj_idxes[i] == -1
and track_instances.scores[i] >= self.score_thresh
):
# new track
# print("track {} has score {}, assign obj_id {}".format(i, track_instances.scores[i], self.max_obj_id))
track_instances.obj_idxes[i] = self.max_obj_id
self.max_obj_id += 1
elif (
track_instances.obj_idxes[i] >= 0
and track_instances.scores[i] < self.filter_score_thresh
):
                # increase the sleep (disappear) time
                track_instances.disappear_time[i] += 1
                if track_instances.disappear_time[i] >= self.miss_tolerance:
                    # mark dead tracklets: set the obj_id to -1.
                    # TODO: remove them in the functions that follow;
                    # this track is then removed by TrackEmbeddingLayer.
                    track_instances.obj_idxes[i] = -1
def update_fix_label(self, track_instances: Instances, old_class_scores):
track_instances.disappear_time[track_instances.scores >= self.score_thresh] = 0
for i in range(len(track_instances)):
if (
track_instances.obj_idxes[i] == -1
and track_instances.scores[i] >= self.score_thresh
):
# new track
# print("track {} has score {}, assign obj_id {}".format(i, track_instances.scores[i], self.max_obj_id))
track_instances.obj_idxes[i] = self.max_obj_id
self.max_obj_id += 1
elif (
track_instances.obj_idxes[i] >= 0
and track_instances.scores[i] < self.filter_score_thresh
):
                # increase the sleep (disappear) time
                track_instances.disappear_time[i] += 1
                # keep class unchanged!
                track_instances.pred_logits[i] = old_class_scores[i]
                if track_instances.disappear_time[i] >= self.miss_tolerance:
                    # mark dead tracklets: set the obj_id to -1.
                    # TODO: remove them in the functions that follow;
                    # this track is then removed by TrackEmbeddingLayer.
                    track_instances.obj_idxes[i] = -1
elif (
track_instances.obj_idxes[i] >= 0
and track_instances.scores[i] >= self.filter_score_thresh
):
# keep class unchanged!
track_instances.pred_logits[i] = old_class_scores[i]
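# A hypothetical, self-contained sketch of the RuntimeTrackerBase life cycle
# (not part of the original model). `_FakeTracks` is a stand-in for the real
# `Instances` class, assumed here purely for illustration.
def _demo_runtime_tracker_base():
    class _FakeTracks:
        def __init__(self, n):
            self.scores = torch.zeros(n)
            self.obj_idxes = torch.full((n,), -1, dtype=torch.long)
            self.disappear_time = torch.zeros(n, dtype=torch.long)

        def __len__(self):
            return len(self.scores)

    tracker = RuntimeTrackerBase(score_thresh=0.7, filter_score_thresh=0.6)
    tracks = _FakeTracks(2)
    tracks.scores = torch.tensor([0.9, 0.3])
    tracker.update(tracks)
    # query 0 is confident, so it is born with a fresh id; query 1 stays free
    assert tracks.obj_idxes[0] == 0 and tracks.obj_idxes[1] == -1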
@DETECTORS.register_module()
class MUTRCamTracker(MVXTwoStageDetector):
"""Tracker which support image w, w/o radar."""
def __init__(
self,
embed_dims=256,
num_query=300,
num_classes=7,
bbox_coder=dict(
type="DETRTrack3DCoder",
post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
pc_range=[-51.2, -51.2, -5.0, 51.2, 51.2, 3.0],
max_num=300,
num_classes=7,
),
qim_args=dict(
qim_type="QIMBase",
merger_dropout=0,
update_query_pos=False,
fp_ratio=0.3,
random_drop=0.1,
),
mem_cfg=dict(
memory_bank_type="MemoryBank",
memory_bank_score_thresh=0.0,
memory_bank_len=4,
),
radar_encoder=None,
fix_feats=False,
score_thresh=0.2,
filter_score_thresh=0.1,
use_grid_mask=False,
pts_voxel_layer=None,
pts_voxel_encoder=None,
pts_middle_encoder=None,
pts_fusion_layer=None,
img_backbone=None,
pts_backbone=None,
img_neck=None,
pts_neck=None,
loss_cfg=None,
pts_bbox_head=None,
img_roi_head=None,
img_rpn_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
):
super(MUTRCamTracker, self).__init__(
pts_voxel_layer,
pts_voxel_encoder,
pts_middle_encoder,
pts_fusion_layer,
img_backbone,
pts_backbone,
img_neck,
pts_neck,
pts_bbox_head,
img_roi_head,
img_rpn_head,
train_cfg,
test_cfg,
pretrained,
)
self.grid_mask = GridMask(
True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7
)
self.use_grid_mask = use_grid_mask
self.num_classes = num_classes
self.bbox_coder = build_bbox_coder(bbox_coder)
self.pc_range = self.bbox_coder.pc_range
self.embed_dims = embed_dims # 256
self.num_query = num_query
self.fix_feats = fix_feats
if self.fix_feats:
self.img_backbone.eval()
self.img_neck.eval()
self.reference_points = nn.Linear(self.embed_dims, 3)
self.bbox_size_fc = nn.Linear(self.embed_dims, 3)
self.query_embedding = nn.Embedding(self.num_query, self.embed_dims * 2)
self.mem_bank_len = mem_cfg["memory_bank_len"]
self.memory_bank = None
self.track_base = RuntimeTrackerBase(
score_thresh=score_thresh,
filter_score_thresh=filter_score_thresh,
miss_tolerance=5,
) # hyper-param for removing inactive queries
self.query_interact = build_qim(
qim_args,
dim_in=embed_dims,
hidden_dim=embed_dims,
dim_out=embed_dims,
)
self.memory_bank = build_memory_bank(
args=mem_cfg,
dim_in=embed_dims,
hidden_dim=embed_dims,
dim_out=embed_dims,
)
self.mem_bank_len = (
0 if self.memory_bank is None else self.memory_bank.max_his_length
)
self.criterion = build_loss(loss_cfg)
self.test_track_instances = None
self.l2g_r_mat = None
self.l2g_t = None
self.radar_encoder = build_radar_encoder(radar_encoder)
def velo_update(
self, ref_pts, velocity, l2g_r1, l2g_t1, l2g_r2, l2g_t2, time_delta
):
"""
Args:
ref_pts (Tensor): (num_query, 3). in inevrse sigmoid space
velocity (Tensor): (num_query, 2). m/s
in lidar frame. vx, vy
global2lidar (np.Array) [4,4].
Outs:
ref_pts (Tensor): (num_query, 3). in inevrse sigmoid space
"""
time_delta = time_delta.type(torch.float)
num_query = ref_pts.size(0)
velo_pad_ = velocity.new_zeros((num_query, 1))
velo_pad = torch.cat((velocity, velo_pad_), dim=-1)
# unnormalize
reference_points = ref_pts.sigmoid().clone()
pc_range = self.pc_range
reference_points[..., 0:1] = (
reference_points[..., 0:1] * (pc_range[3] - pc_range[0]) + pc_range[0]
)
reference_points[..., 1:2] = (
reference_points[..., 1:2] * (pc_range[4] - pc_range[1]) + pc_range[1]
)
reference_points[..., 2:3] = (
reference_points[..., 2:3] * (pc_range[5] - pc_range[2]) + pc_range[2]
)
# motion model
reference_points = reference_points + velo_pad * time_delta
# coordinate transform
ref_pts = reference_points @ l2g_r1 + l2g_t1 - l2g_t2
g2l_r = torch.linalg.inv(l2g_r2).type(torch.float)
ref_pts = ref_pts @ g2l_r
        # normalize back to [0, 1]
ref_pts[..., 0:1] = (ref_pts[..., 0:1] - pc_range[0]) / (
pc_range[3] - pc_range[0]
)
ref_pts[..., 1:2] = (ref_pts[..., 1:2] - pc_range[1]) / (
pc_range[4] - pc_range[1]
)
ref_pts[..., 2:3] = (ref_pts[..., 2:3] - pc_range[2]) / (
pc_range[5] - pc_range[2]
)
ref_pts = inverse_sigmoid(ref_pts)
return ref_pts
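    # A hypothetical usage sketch for `velo_update` (kept as comments so the
    # module is unchanged): with identity ego poses, a point should simply
    # move by velocity * time_delta in lidar space, e.g.
    #   dummy = type("D", (), {"pc_range": [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]})()
    #   ref = inverse_sigmoid(torch.tensor([[0.5, 0.5, 0.5]]))
    #   velo = torch.tensor([[2.0, 0.0]])  # 2 m/s along x
    #   eye, zero = torch.eye(3), torch.zeros(3)
    #   out = MUTRCamTracker.velo_update(
    #       dummy, ref, velo, eye, zero, eye, zero, torch.tensor(0.5))
    #   # out.sigmoid()[0, 0] - 0.5 == (2.0 * 0.5) / 102.4 in normalized x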
def extract_pts_feat(self, pts, img_feats, img_metas):
"""Extract features of points."""
if not self.with_pts_bbox:
return None
voxels, num_points, coors = self.voxelize(pts)
voxel_features = self.pts_voxel_encoder(voxels, num_points, coors)
batch_size = coors[-1, 0] + 1
x = self.pts_middle_encoder(voxel_features, coors, batch_size)
x = self.pts_backbone(x)
if self.with_pts_neck:
x = self.pts_neck(x)
return x
def extract_img_feat(self, img, img_metas):
"""Extract features of images."""
B = img.size(0)
if self.with_img_backbone and img is not None:
input_shape = img.shape[-2:]
# update real input shape of each single img
for img_meta in img_metas:
img_meta.update(input_shape=input_shape)
if img.dim() == 5 and img.size(0) == 1:
img.squeeze_()
elif img.dim() == 5 and img.size(0) > 1:
B, N, C, H, W = img.size()
img = img.view(B * N, C, H, W)
if self.use_grid_mask:
img = self.grid_mask(img)
img_feats = self.img_backbone(img)
else:
return None
if self.with_img_neck:
img_feats = self.img_neck(img_feats)
img_feats_reshaped = []
for img_feat in img_feats:
BN, C, H, W = img_feat.size()
img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))
return img_feats_reshaped
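    # A hypothetical shape walk-through for `extract_img_feat` (comment-only),
    # assuming 6 cameras, batch size 1, and a multi-scale neck:
    #   img [1, 6, 3, H, W] -> squeeze to [6, 3, H, W] -> backbone -> neck
    #   returns a list of per-level tensors, each [1, 6, C_l, H_l, W_l]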
    @auto_fp16(apply_to=("img",), out_fp32=True)
def extract_feat(self, points, img, radar=None, img_metas=None):
"""Extract features from images and lidar points and radars."""
# lidar feature distabled. (param points not used )
radar = None # don't use radar feature
if radar is not None:
radar_feats = self.radar_encoder(radar)
else:
radar_feats = None
if self.fix_feats:
with torch.no_grad():
img_feats = self.extract_img_feat(img, img_metas)
else:
img_feats = self.extract_img_feat(img, img_metas)
return (img_feats, radar_feats, None)
def _targets_to_instances(
self,
gt_bboxes_3d=None,
gt_labels_3d=None,
instance_inds=None,
img_shape=(
1,
1,
),
):
gt_instances = Instances(tuple(img_shape))
gt_instances.boxes = gt_bboxes_3d
gt_instances.labels = gt_labels_3d
gt_instances.obj_ids = instance_inds
return gt_instances
def _generate_empty_tracks(self):
        # create an empty set of track instances (first frame / newborn queries)
track_instances = Instances((1, 1))
num_queries, dim = self.query_embedding.weight.shape # (N, 256 * 2)
device = self.query_embedding.weight.device
query = self.query_embedding.weight # N x 512
# convert the query embedding to a point in 3D
track_instances.ref_pts = self.reference_points(query[..., : dim // 2]) # N x 3
        # init boxes: cx, cy, w, l, cz, h, sin, cos, vx, vy
box_sizes = self.bbox_size_fc(query[..., : dim // 2]) # N x 3
pred_boxes_init = torch.zeros(
(len(track_instances), 10), dtype=torch.float, device=device
) # N x 10
pred_boxes_init[..., 2:4] = box_sizes[..., 0:2]
pred_boxes_init[..., 5:6] = box_sizes[..., 2:3]
# track instance class, add other properties
track_instances.query = query
track_instances.output_embedding = torch.zeros(
(num_queries, dim >> 1), device=device
) # N x 256
track_instances.obj_idxes = torch.full(
(len(track_instances),), -1, dtype=torch.long, device=device
        )  # (N, )
track_instances.matched_gt_idxes = torch.full(
(len(track_instances),), -1, dtype=torch.long, device=device
        )  # (N, )
track_instances.disappear_time = torch.zeros(
(len(track_instances),), dtype=torch.long, device=device
        )  # (N, )
track_instances.scores = torch.zeros(
(len(track_instances),), dtype=torch.float, device=device
)
track_instances.track_scores = torch.zeros(
(len(track_instances),), dtype=torch.float, device=device
)
        # cx, cy, w, l, cz, h, sin, cos, vx, vy
track_instances.pred_boxes = pred_boxes_init
track_instances.pred_logits = torch.zeros(
(len(track_instances), self.num_classes), dtype=torch.float, device=device
)
mem_bank_len = self.mem_bank_len
track_instances.mem_bank = torch.zeros(
(len(track_instances), mem_bank_len, dim // 2),
dtype=torch.float32,
device=device,
)
track_instances.mem_padding_mask = torch.ones(
(len(track_instances), mem_bank_len), dtype=torch.bool, device=device
)
track_instances.save_period = torch.zeros(
(len(track_instances),), dtype=torch.float32, device=device
)
return track_instances.to(self.query_embedding.weight.device)
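    # Hypothetical shape check for `_generate_empty_tracks` (comment-only):
    #   tracks = tracker._generate_empty_tracks()
    #   assert tracks.query.shape == (tracker.num_query, tracker.embed_dims * 2)
    #   assert tracks.ref_pts.shape == (tracker.num_query, 3)
    #   assert tracks.pred_boxes.shape == (tracker.num_query, 10)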
def _copy_tracks_for_loss(self, tgt_instances):
device = self.query_embedding.weight.device
track_instances = Instances((1, 1))
track_instances.obj_idxes = deepcopy(tgt_instances.obj_idxes)
track_instances.matched_gt_idxes = deepcopy(tgt_instances.matched_gt_idxes)
track_instances.disappear_time = deepcopy(tgt_instances.disappear_time)
track_instances.scores = torch.zeros(
(len(track_instances),), dtype=torch.float, device=device
)
track_instances.track_scores = torch.zeros(
(len(track_instances),), dtype=torch.float, device=device
)
track_instances.pred_boxes = torch.zeros(
(len(track_instances), 10), dtype=torch.float, device=device
)
track_instances.pred_logits = torch.zeros(
(len(track_instances), self.num_classes), dtype=torch.float, device=device
)
track_instances.save_period = deepcopy(tgt_instances.save_period)
return track_instances.to(self.query_embedding.weight.device)
@force_fp32(apply_to=("img", "points"))
def forward(self, return_loss=True, **kwargs):
"""Calls either forward_train or forward_test depending on whether
return_loss=True.
Note this setting will change the expected inputs. When
`return_loss=True`, img and img_metas are single-nested (i.e.
torch.Tensor and list[dict]), and when `resturn_loss=False`, img and
img_metas should be double nested (i.e. list[torch.Tensor],
list[list[dict]]), with the outer list indicating test time
augmentations.
"""
if return_loss:
return self.forward_train(**kwargs)
else:
return self.forward_test(**kwargs)
# @auto_fp16(apply_to=('img', 'radar'))
def _forward_single(
self,
points,
img,
radar,
img_metas,
track_instances,
l2g_r1=None,
l2g_t1=None,
l2g_r2=None,
l2g_t2=None,
time_delta=None,
):
"""
Perform forward only on one frame. Called in forward_train
Warnning: Only Support BS=1
Args:
img: shape [B, num_cam, 3, H, W]
if l2g_r2 is None or l2g_t2 is None:
it means this frame is the end of the training clip,
so no need to call velocity update
"""
# retrieve feature in the frame
B, num_cam, _, H, W = img.shape
img_feats, radar_feats, pts_feats = self.extract_feat(
points, img=img, radar=radar, img_metas=img_metas
)
        # img_feats: multi-scale, each level [B, num_cam, C, H, W]; pts_feats is None
ref_box_sizes = torch.cat(
[track_instances.pred_boxes[:, 2:4], track_instances.pred_boxes[:, 5:6]],
dim=1,
)
        # query for each reference point and get the outputs:
        #   output_classes: [nb_dec, B, num_query, num_classes]
        #   output_coords:  [nb_dec, B, num_query, 10]
        #   query_feats:    [B, num_query, embed_dims]
output_classes, output_coords, query_feats, last_ref_pts = self.pts_bbox_head(
img_feats,
radar_feats,
track_instances.query,
track_instances.ref_pts,
ref_box_sizes,
img_metas,
)
# initialize outputs from the last layer
out = {
"pred_logits": output_classes[-1],
"pred_boxes": output_coords[-1],
"ref_pts": last_ref_pts,
}
with torch.no_grad():
track_scores = output_classes[-1, 0, :].sigmoid().max(dim=-1).values
# Step-1 Update track instances with current prediction
# the track id will be assigned by the matcher.
# copy the track instances and initialize the loss for all decoder layers
nb_dec = output_classes.size(0)
track_instances_list = [
self._copy_tracks_for_loss(track_instances) for i in range(nb_dec - 1)
]
track_instances.output_embedding = query_feats[0] # [300, feat_dim]
# update the reference points based on the velocity,
# take it from the last layer of the outputs
        velo = output_coords[-1, 0, :, -2:]  # [num_query, 2]
if l2g_r2 is not None:
ref_pts = self.velo_update(
last_ref_pts[0],
velo,
l2g_r1,
l2g_t1,
l2g_r2,
l2g_t2,
time_delta=time_delta,
)
else:
ref_pts = last_ref_pts[0]
track_instances.ref_pts = ref_pts
# add the last frame of track instance data
track_instances_list.append(track_instances)
# update results for all decoder layers
for i in range(nb_dec):
track_instances = track_instances_list[i]
# track_scores = output_classes[i, 0, :].sigmoid().max(dim=-1).values
track_instances.scores = track_scores
track_instances.pred_logits = output_classes[i, 0] # [300, num_cls]
track_instances.pred_boxes = output_coords[i, 0] # [300, box_dim]
out["track_instances"] = track_instances
# compute loss for this layer of results
track_instances = self.criterion.match_for_single_frame(
out, i, if_step=(i == (nb_dec - 1))
)
if self.memory_bank is not None:
track_instances = self.memory_bank(track_instances)
# Step-2 query interaction
tmp = {}
tmp["init_track_instances"] = self._generate_empty_tracks()
tmp["track_instances"] = track_instances
out_track_instances = self.query_interact(tmp)
out["track_instances"] = out_track_instances
return out
def forward_train(
self,
points=None,
img=None,
radar=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
instance_inds=None,
l2g_r_mat=None,
l2g_t=None,
gt_bboxes_ignore=None,
timestamp=None,
):
"""Forward training function.
This function will call _forward_single in a for loop
Args:
points (list(list[torch.Tensor]), optional): B-T-sample
Points of each sample.
Defaults to None.
            img (torch.Tensor): of shape [B, T, num_cam, 3, H, W].
            radar (torch.Tensor): of shape [B, T, num_points, radar_dim].
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
                lidar2img = img_metas[bs]['lidar2img'], nested lists of
                4x4 arrays with shape [T, num_cam, 4, 4].
gt_bboxes_3d (list[list[:obj:`BaseInstance3DBoxes`]], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[list[torch.Tensor]], optional): Ground truth labels
of 3D boxes. Defaults to None.
            l2g_r_mat (list[Tensor]): element shape [T, 3, 3].
            l2g_t (list[Tensor]): element shape [T, 3].
                Normally one would compute points @ R_mat.T + T;
                here it is points @ R_mat + T.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
# [T+1, 3, 3]
l2g_r_mat = l2g_r_mat[0]
# change to [T+1, 1, 3]
l2g_t = l2g_t[0].unsqueeze(dim=1)
# img: 1 x T x N_cam x 3 x H x W
num_frame = img.size(1) - 1
track_instances = self._generate_empty_tracks()
# init gt instances!
gt_instances_list = []
for i in range(num_frame):
gt_instances = Instances((1, 1))
boxes = gt_bboxes_3d[0][i].tensor.to(img.device) # M x 9
# nuScenes format: xyz, wlh, rot(radian), vx, vy in world coordinate
# normalize gt bboxes
boxes = normalize_bbox(boxes, self.pc_range) # M x 10
gt_instances.boxes = boxes
gt_instances.labels = gt_labels_3d[0][i] # (M, )
gt_instances.obj_ids = instance_inds[0][i] # (M, )
gt_instances_list.append(gt_instances)
        # initialize the criterion with the gt instances of the whole clip
self.criterion.initialize_for_single_clip(gt_instances_list)
# for bs 1
lidar2img = img_metas[0]["lidar2img"] # [T, num_cam]
for i in range(num_frame):
points_single = [p_[i] for p_ in points]
img_single = torch.stack([img_[i] for img_ in img], dim=0)
radar_single = torch.stack([radar_[i] for radar_ in radar], dim=0)
img_metas_single = deepcopy(img_metas)
img_metas_single[0]["lidar2img"] = lidar2img[i]
if i == num_frame - 1:
l2g_r2 = None
l2g_t2 = None
time_delta = None
else:
l2g_r2 = l2g_r_mat[i + 1]
l2g_t2 = l2g_t[i + 1]
time_delta = timestamp[i + 1] - timestamp[i]
frame_res = self._forward_single(
points_single,
img_single,
radar_single,
img_metas_single,
track_instances,
l2g_r_mat[i],
l2g_t[i],
l2g_r2,
l2g_t2,
time_delta,
)
track_instances = frame_res["track_instances"]
outputs = self.criterion.losses_dict
return outputs
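    # A hypothetical call sketch for `forward_train` (comment-only), assuming
    # a clip of T + 1 = 3 frames, 6 cameras and batch size 1; ellipses stand
    # for dataset-provided values and `radar_dim` is an assumed placeholder:
    #   losses = tracker.forward_train(
    #       points=[[pts_0, pts_1, pts_2]],
    #       img=torch.rand(1, 3, 6, 3, 928, 1600),
    #       radar=torch.rand(1, 3, 1200, radar_dim),
    #       img_metas=img_metas, gt_bboxes_3d=..., gt_labels_3d=...,
    #       instance_inds=..., l2g_r_mat=..., l2g_t=..., timestamp=...)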
def _inference_single(
self,
points,
img,
radar,
img_metas,
track_instances,
l2g_r1=None,
l2g_t1=None,
l2g_r2=None,
l2g_t2=None,
time_delta=None,
):
"""
This function will be called at forward_test
Warnning: Only Support BS=1
img: shape [B, num_cam, 3, H, W]
"""
# velo update:
active_inst = track_instances[track_instances.obj_idxes >= 0]
other_inst = track_instances[track_instances.obj_idxes < 0]
if l2g_r2 is not None and len(active_inst) > 0 and l2g_r1 is not None:
ref_pts = active_inst.ref_pts
velo = active_inst.pred_boxes[:, -2:]
ref_pts = self.velo_update(
ref_pts, velo, l2g_r1, l2g_t1, l2g_r2, l2g_t2, time_delta=time_delta
)
active_inst.ref_pts = ref_pts
track_instances = Instances.cat([other_inst, active_inst])
B, num_cam, _, H, W = img.shape
img_feats, radar_feats, pts_feats = self.extract_feat(
points, img=img, radar=radar, img_metas=img_metas
)
img_feats = [a.clone() for a in img_feats]
# output_classes: [num_dec, B, num_query, num_classes]
# query_feats: [B, num_query, embed_dim]
ref_box_sizes = torch.cat(
[track_instances.pred_boxes[:, 2:4], track_instances.pred_boxes[:, 5:6]],
dim=1,
)
output_classes, output_coords, query_feats, last_ref_pts = self.pts_bbox_head(
img_feats,
radar_feats,
track_instances.query,
track_instances.ref_pts,
ref_box_sizes,
img_metas,
)
out = {
"pred_logits": output_classes[-1],
"pred_boxes": output_coords[-1],
"ref_pts": last_ref_pts,
}
# TODO: Why no max?
track_scores = output_classes[-1, 0, :].sigmoid().max(dim=-1).values
# Step-1 Update track instances with current prediction
# [nb_dec, bs, num_query, xxx]
        # each track will be assigned a unique global id by the track base.
track_instances.scores = track_scores
# track_instances.track_scores = track_scores # [300]
track_instances.pred_logits = output_classes[-1, 0] # [300, num_cls]
track_instances.pred_boxes = output_coords[-1, 0] # [300, box_dim]
track_instances.output_embedding = query_feats[0] # [300, feat_dim]
track_instances.ref_pts = last_ref_pts[0]
self.track_base.update(track_instances)
if self.memory_bank is not None:
track_instances = self.memory_bank(track_instances)
# Step-2 Update track instances using matcher
tmp = {}
tmp["init_track_instances"] = self._generate_empty_tracks()
tmp["track_instances"] = track_instances
out_track_instances = self.query_interact(tmp)
out["track_instances"] = out_track_instances
return out
def forward_test(
self,
points=None,
img=None,
radar=None,
img_metas=None,
timestamp=1e6,
l2g_r_mat=None,
l2g_t=None,
**kwargs,
):
"""Forward test function.
only support bs=1, single-gpu, num_frame=1 test
Args:
points (list(list[torch.Tensor]), optional): B-T-sample
Points of each sample.
Defaults to None.
img (Torch.Tensor) of shape [B, T, num_cam, 3, H, W]
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
lidar2img = img_metas[bs]['lidar2img'] of shape [3, 6, 4, 4]. list
of list of list of 4x4 array
gt_bboxes_3d (list[list[:obj:`BaseInstance3DBoxes`]], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[list[torch.Tensor]], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
# [3, 3]
l2g_r_mat = l2g_r_mat[0][0]
# change to [1, 3]
l2g_t = l2g_t[0].unsqueeze(dim=1)[0]
bs = img.size(0)
num_frame = img.size(1)
timestamp = timestamp[0]
if self.test_track_instances is None:
track_instances = self._generate_empty_tracks()
self.test_track_instances = track_instances
self.timestamp = timestamp[0]
# TODO: use scene tokens?
if timestamp[0] - self.timestamp > 10:
track_instances = self._generate_empty_tracks()
time_delta = None
l2g_r1 = None
l2g_t1 = None
l2g_r2 = None
l2g_t2 = None
else:
track_instances = self.test_track_instances
time_delta = timestamp[0] - self.timestamp
l2g_r1 = self.l2g_r_mat
l2g_t1 = self.l2g_t
l2g_r2 = l2g_r_mat
l2g_t2 = l2g_t
self.timestamp = timestamp[-1]
self.l2g_r_mat = l2g_r_mat
self.l2g_t = l2g_t
# for bs 1;
lidar2img = img_metas[0]["lidar2img"] # [T, num_cam]
for i in range(num_frame):
points_single = [p_[i] for p_ in points]
img_single = torch.stack([img_[i] for img_ in img], dim=0)
radar_single = torch.stack([radar_[i] for radar_ in radar], dim=0)
img_metas_single = deepcopy(img_metas)
img_metas_single[0]["lidar2img"] = lidar2img[i]
frame_res = self._inference_single(
points_single,
img_single,
radar_single,
img_metas_single,
track_instances,
l2g_r1,
l2g_t1,
l2g_r2,
l2g_t2,
time_delta,
)
track_instances = frame_res["track_instances"]
active_instances = self.query_interact._select_active_tracks(
dict(track_instances=track_instances)
)
self.test_track_instances = track_instances
results = self._active_instances2results(active_instances, img_metas)
return results
def _active_instances2results(self, active_instances, img_metas):
"""
Outs:
active_instances. keys:
- 'pred_logits':
- 'pred_boxes': normalized bboxes
- 'scores'
- 'obj_idxes'
out_dict. keys:
- boxes_3d (torch.Tensor): 3D boxes.
- scores (torch.Tensor): Prediction scores.
- labels_3d (torch.Tensor): Box labels.
- attrs_3d (torch.Tensor, optional): Box attributes.
- track_ids
- tracking_score
"""
        # filter out sleeping queries
active_idxes = active_instances.scores >= self.track_base.filter_score_thresh
active_instances = active_instances[active_idxes]
if active_instances.pred_logits.numel() == 0:
return [None]
bbox_dict = dict(
cls_scores=active_instances.pred_logits,
bbox_preds=active_instances.pred_boxes,
track_scores=active_instances.scores,
obj_idxes=active_instances.obj_idxes,
)
bboxes_dict = self.bbox_coder.decode(bbox_dict)[0]
bboxes = bboxes_dict["bboxes"]
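        # shift the box center from the gravity center to the bottom center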
bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
bboxes = img_metas[0]["box_type_3d"][0](bboxes, 9)
labels = bboxes_dict["labels"]
scores = bboxes_dict["scores"]
track_scores = bboxes_dict["track_scores"]
obj_idxes = bboxes_dict["obj_idxes"]
result_dict = dict(
boxes_3d=bboxes.to("cpu"),
scores_3d=scores.cpu(),
labels_3d=labels.cpu(),
track_scores=track_scores.cpu(),
track_ids=obj_idxes.cpu(),
)
return [result_dict]
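# A hypothetical sketch of the per-frame result layout produced by
# `_active_instances2results` (comment-only; the field types follow the code
# above, and the concrete box class depends on img_metas['box_type_3d']):
#   [
#       dict(
#           boxes_3d=<3D boxes, N x 9>,
#           scores_3d=<Tensor, (N,)>,
#           labels_3d=<Tensor, (N,)>,
#           track_scores=<Tensor, (N,)>,
#           track_ids=<Tensor, (N,)>,
#       )
#   ]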