import queue
import torch
import numpy as np
from mmcv.runner import force_fp32, auto_fp16
from mmcv.runner import get_dist_info
from mmcv.runner.fp16_utils import cast_tensor_type
from mmdet.models import DETECTORS
from mmdet3d.core import bbox3d2result
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from .utils import GridMask, pad_multiple, GpuPhotoMetricDistortion


@DETECTORS.register_module()
class SparseBEV(MVXTwoStageDetector):
    """SparseBEV multi-view, multi-frame camera 3D detector.

    Input images are handled as consecutive groups of ``_NUM_CAMS`` views,
    one group per frame (see the ``reshape`` calls below).

    Args:
        data_aug (dict, optional): GPU-side preprocessing config. Keys read
            here: ``img_color_aug`` (applied only while training),
            ``img_norm_cfg`` (``mean``, ``std``, ``to_rgb``) and
            ``img_pad_cfg`` (``size_divisor``).
        stop_prev_grad (int): While training, only the first
            ``stop_prev_grad`` frames are encoded with gradients; the
            remaining frames are encoded under ``torch.no_grad()`` with the
            model in eval mode. ``0`` disables this.
        The remaining arguments are forwarded unchanged to
        :class:`MVXTwoStageDetector`.
    """

    # Number of camera views per frame; inputs are reshaped in groups of this.
    _NUM_CAMS = 6
    # The online feature cache is trimmed while the FIFO reaches this size.
    _MEMORY_LIMIT = 16

    def __init__(self,
                 data_aug=None,
                 stop_prev_grad=0,
                 pts_voxel_layer=None,
                 pts_voxel_encoder=None,
                 pts_middle_encoder=None,
                 pts_fusion_layer=None,
                 img_backbone=None,
                 pts_backbone=None,
                 img_neck=None,
                 pts_neck=None,
                 pts_bbox_head=None,
                 img_roi_head=None,
                 img_rpn_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(SparseBEV, self).__init__(pts_voxel_layer, pts_voxel_encoder,
                                        pts_middle_encoder, pts_fusion_layer,
                                        img_backbone, pts_backbone, img_neck, pts_neck,
                                        pts_bbox_head, img_roi_head, img_rpn_head,
                                        train_cfg, test_cfg, pretrained)
        self.data_aug = data_aug
        self.stop_prev_grad = stop_prev_grad
        self.color_aug = GpuPhotoMetricDistortion()
        self.grid_mask = GridMask(ratio=0.5, prob=0.7)
        self.use_grid_mask = True

        # Online-inference feature cache (see ``simple_test_online``):
        # filename of a frame's first view -> that frame's features.
        self.memory = {}
        # FIFO of cache keys in insertion order, used for eviction.
        self.queue = queue.Queue()

    # NOTE: ``apply_to`` must be a tuple; a bare string only worked because
    # mmcv's membership test degenerated into a substring check.
    @auto_fp16(apply_to=('img',), out_fp32=True)
    def extract_img_feat(self, img):
        """Run GridMask (if enabled), the image backbone and the neck.

        Args:
            img (torch.Tensor): Flattened images of shape (B*N, C, H, W).

        Returns:
            list[torch.Tensor] | tuple[torch.Tensor]: Feature maps returned
            by the neck (or backbone when there is no neck); a dict output
            from the backbone is converted to a list of its values.
        """
        if self.use_grid_mask:
            img = self.grid_mask(img)

        img_feats = self.img_backbone(img)

        if isinstance(img_feats, dict):
            img_feats = list(img_feats.values())

        if self.with_img_neck:
            img_feats = self.img_neck(img_feats)

        return img_feats

    def extract_feat(self, img, img_metas):
        """Preprocess images on the GPU and extract per-view features.

        Args:
            img (torch.Tensor | list[torch.Tensor]): Images of shape
                (B, N, C, H, W), or a list of tensors stacked along a new
                leading batch dimension into that shape.
            img_metas (list[dict]): Per-sample meta info. Updated in place:
                ``img_shape``/``ori_shape`` when ``data_aug`` is set, and
                ``input_shape`` always.

        Returns:
            list[torch.Tensor]: One tensor per feature level, each reshaped
            to (B, N, C', H', W').
        """
        if isinstance(img, list):
            img = torch.stack(img, dim=0)

        assert img.dim() == 5

        B, N, C, H, W = img.size()
        img = img.view(B * N, C, H, W).float()

        # move some augmentations to GPU
        if self.data_aug is not None:
            if 'img_color_aug' in self.data_aug and self.data_aug['img_color_aug'] and self.training:
                img = self.color_aug(img)

            if 'img_norm_cfg' in self.data_aug:
                img_norm_cfg = self.data_aug['img_norm_cfg']

                norm_mean = torch.tensor(img_norm_cfg['mean'], device=img.device)
                norm_std = torch.tensor(img_norm_cfg['std'], device=img.device)

                if img_norm_cfg['to_rgb']:
                    img = img[:, [2, 1, 0], :, :]  # BGR to RGB

                img = img - norm_mean.reshape(1, 3, 1, 1)
                img = img / norm_std.reshape(1, 3, 1, 1)

            # Record the (H, W, C) shape of the images after preprocessing.
            img_shape = (img.shape[2], img.shape[3], img.shape[1])
            for b in range(B):
                img_metas[b]['img_shape'] = [img_shape for _ in range(N)]
                img_metas[b]['ori_shape'] = [img_shape for _ in range(N)]

            if 'img_pad_cfg' in self.data_aug:
                img_pad_cfg = self.data_aug['img_pad_cfg']
                # NOTE(review): pad_multiple receives img_metas and presumably
                # records the padded shape there — confirm in .utils.
                img = pad_multiple(img, img_metas, size_divisor=img_pad_cfg['size_divisor'])

        input_shape = img.shape[-2:]
        # update real input shape of each single img
        for img_meta in img_metas:
            img_meta.update(input_shape=input_shape)

        if self.training and self.stop_prev_grad > 0:
            num_cams = self._NUM_CAMS
            H, W = input_shape
            # (B, num_frames, num_cams, C, H, W)
            img = img.reshape(B, -1, num_cams, C, H, W)

            img_grad = img[:, :self.stop_prev_grad]
            img_nograd = img[:, self.stop_prev_grad:]

            all_img_feats = [self.extract_img_feat(img_grad.reshape(-1, C, H, W))]

            # Older frames are encoded one at a time, without gradients and in
            # eval mode. Train mode is restored even if encoding raises, so a
            # failure here cannot leave the model stuck in eval mode.
            with torch.no_grad():
                self.eval()
                try:
                    for k in range(img_nograd.shape[1]):
                        all_img_feats.append(
                            self.extract_img_feat(img_nograd[:, k].reshape(-1, C, H, W)))
                finally:
                    self.train()

            # Re-assemble every level in the original frame order.
            img_feats = []
            for lvl in range(len(all_img_feats[0])):
                C_l, H_l, W_l = all_img_feats[0][lvl].shape[1:]
                img_feat = torch.cat(
                    [feat[lvl].reshape(B, -1, num_cams, C_l, H_l, W_l) for feat in all_img_feats],
                    dim=1)
                img_feats.append(img_feat.reshape(-1, C_l, H_l, W_l))
        else:
            img_feats = self.extract_img_feat(img)

        img_feats_reshaped = []
        for img_feat in img_feats:
            BN, C, H, W = img_feat.size()
            img_feats_reshaped.append(img_feat.view(B, BN // B, C, H, W))

        return img_feats_reshaped

    def forward_pts_train(self,
                          pts_feats,
                          gt_bboxes_3d,
                          gt_labels_3d,
                          img_metas,
                          gt_bboxes_ignore=None):
        """Forward function for the detection head during training.

        Args:
            pts_feats (list[torch.Tensor]): Features fed to ``pts_bbox_head``
                (in this detector, ``forward_train`` passes image features).
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes for each sample.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels for
                boxes of each sample.
            img_metas (list[dict]): Meta information of samples.
            gt_bboxes_ignore (list[torch.Tensor], optional): Accepted for
                API compatibility; not used here. Defaults to None.

        Returns:
            dict: Losses of each branch.
        """
        outs = self.pts_bbox_head(pts_feats, img_metas)
        loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs]
        losses = self.pts_bbox_head.loss(*loss_inputs)

        return losses

    @force_fp32(apply_to=('img', 'points'))
    def forward(self, return_loss=True, **kwargs):
        """Call either forward_train or forward_test depending on return_loss.

        Note this setting will change the expected inputs. When
        ``return_loss=True``, img and img_metas are single-nested (i.e.
        torch.Tensor and list[dict]), and when ``return_loss=False``, img and
        img_metas should be double nested (i.e. list[torch.Tensor],
        list[list[dict]]), with the outer list indicating test time
        augmentations.
        """
        if return_loss:
            return self.forward_train(**kwargs)
        else:
            return self.forward_test(**kwargs)

    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      img_depth=None,
                      img_mask=None):
        """Forward training function.

        Only ``img``, ``img_metas``, ``gt_bboxes_3d``, ``gt_labels_3d`` and
        ``gt_bboxes_ignore`` are used; the other arguments are accepted for
        interface compatibility.

        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None.
            img_metas (list[dict], optional): Meta information of each sample.
                Each dict is updated in place with its ``gt_bboxes_3d`` and
                ``gt_labels_3d``. Defaults to None.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
                Ground truth 3D boxes. Defaults to None.
            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
                of 3D boxes. Defaults to None.
            gt_labels (list[torch.Tensor], optional): Ground truth labels
                of 2D boxes in images. Defaults to None.
            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
                images. Defaults to None.
            img (torch.Tensor, optional): Images of each sample with shape
                (B, N, C, H, W). Defaults to None.
            proposals (list[torch.Tensor], optional): Predicted proposals
                used for training Fast RCNN. Defaults to None.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                2D boxes in images to be ignored. Defaults to None.

        Returns:
            dict: Losses of different branches.
        """
        img_feats = self.extract_feat(img, img_metas)

        # Expose the ground truth to the head through the meta dicts as well.
        for i, img_meta in enumerate(img_metas):
            img_meta['gt_bboxes_3d'] = gt_bboxes_3d[i]
            img_meta['gt_labels_3d'] = gt_labels_3d[i]

        losses = self.forward_pts_train(img_feats, gt_bboxes_3d, gt_labels_3d, img_metas, gt_bboxes_ignore)

        return losses

    def forward_test(self, img_metas, img=None, **kwargs):
        """Unwrap the test-time-augmentation nesting and run the first one.

        Raises:
            TypeError: If ``img_metas`` is not a list.
        """
        if not isinstance(img_metas, list):
            raise TypeError('{} must be a list, but got {}'.format(
                'img_metas', type(img_metas)))

        img = [img] if img is None else img
        return self.simple_test(img_metas[0], img[0], **kwargs)

    def simple_test_pts(self, x, img_metas, rescale=False):
        """Run the head and convert its predictions to bbox3d results.

        Note that ``get_bboxes`` receives only ``img_metas[0]``.
        """
        outs = self.pts_bbox_head(x, img_metas)
        bbox_list = self.pts_bbox_head.get_bboxes(
            outs, img_metas[0], rescale=rescale)

        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]

        return bbox_results

    def simple_test(self, img_metas, img=None, rescale=False):
        """Use cached online inference on one process, offline otherwise."""
        world_size = get_dist_info()[1]
        if world_size == 1:  # online
            return self.simple_test_online(img_metas, img, rescale)
        else:  # offline
            return self.simple_test_offline(img_metas, img, rescale)

    def simple_test_offline(self, img_metas, img=None, rescale=False):
        """Extract features for all frames at once and run the head."""
        img_feats = self.extract_feat(img=img, img_metas=img_metas)

        bbox_list = [dict() for _ in range(len(img_metas))]
        bbox_pts = self.simple_test_pts(img_feats, img_metas, rescale=rescale)
        for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
            result_dict['pts_bbox'] = pts_bbox

        return bbox_list

    def simple_test_online(self, img_metas, img=None, rescale=False):
        """Frame-by-frame inference with a FIFO feature cache (batch size 1).

        ``img`` is split into frames of ``_NUM_CAMS`` views. Each frame's
        features are cached in ``self.memory`` under the filename of that
        frame's first view, so frames already seen in earlier calls are not
        re-encoded. The cache is trimmed FIFO-style once ``self.queue``
        reaches ``_MEMORY_LIMIT`` entries.

        Side effects: sets ``self.fp16_enabled = False`` and overwrites
        ``img_shape``/``ori_shape``/``pad_shape`` in ``img_metas[0]``.
        """
        self.fp16_enabled = False
        assert len(img_metas) == 1  # batch_size = 1

        num_cams = self._NUM_CAMS
        B, N, C, H, W = img.shape
        img = img.reshape(B, N // num_cams, num_cams, C, H, W)

        img_filenames = img_metas[0]['filename']
        num_frames = len(img_filenames) // num_cams
        # assert num_frames == img.shape[1]

        num_views = len(img_filenames)
        img_shape = (H, W, C)
        img_metas[0]['img_shape'] = [img_shape for _ in range(num_views)]
        img_metas[0]['ori_shape'] = [img_shape for _ in range(num_views)]
        img_metas[0]['pad_shape'] = [img_shape for _ in range(num_views)]

        img_feats_list, img_metas_list = [], []

        # extract feature frame by frame
        for frame_idx in range(num_frames):
            view_indices = range(frame_idx * num_cams, (frame_idx + 1) * num_cams)

            # Slice every list-valued meta entry down to this frame's views;
            # non-list entries are not carried over.
            img_metas_curr = [{
                k: [v[idx] for idx in view_indices]
                for k, v in img_metas[0].items()
                if isinstance(v, list)
            }]

            cache_key = img_filenames[view_indices[0]]
            if cache_key in self.memory:
                # found in memory
                img_feats_curr = self.memory[cache_key]
            else:
                # extract feature and put into memory
                img_feats_curr = self.extract_feat(img[:, frame_idx], img_metas_curr)
                self.memory[cache_key] = img_feats_curr
                self.queue.put(cache_key)
                while self.queue.qsize() >= self._MEMORY_LIMIT:  # avoid OOM
                    self.memory.pop(self.queue.get())

            img_feats_list.append(img_feats_curr)
            img_metas_list.append(img_metas_curr)

        # Reorganize: concatenate frames per level, then merge the frame and
        # view dims and add a leading dim of size 1.
        img_feats = []
        for lvl in range(len(img_feats_list[0])):
            feat_l = torch.cat([frame_feats[lvl] for frame_feats in img_feats_list], dim=0)
            img_feats.append(feat_l.flatten(0, 1)[None, ...])

        # Merge per-frame metas: the first frame's dict is extended in place
        # with the list entries of every later frame.
        img_metas = img_metas_list[0]
        for frame_metas in img_metas_list[1:]:
            for k, v in frame_metas[0].items():
                if isinstance(v, list):
                    img_metas[0][k].extend(v)

        img_feats = cast_tensor_type(img_feats, torch.half, torch.float32)

        # run detector
        bbox_list = [dict()]
        bbox_pts = self.simple_test_pts(img_feats, img_metas, rescale=rescale)
        for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
            result_dict['pts_bbox'] = pts_bbox

        return bbox_list