| | import pickle |
| |
|
| | import os |
| | import copy |
| | import numpy as np |
| | from skimage import io |
| | import torch |
| | import SharedArray |
| | import torch.distributed as dist |
| |
|
| | from ...ops.iou3d_nms import iou3d_nms_utils |
| | from ...utils import box_utils, common_utils |
| |
|
class DataBaseSampler(object):
    """Ground-truth database sampler used for copy-paste style augmentation.

    On construction the per-class object infos are loaded from the pickled
    files listed in ``sampler_cfg.DB_INFO_PATH`` and passed through the
    filters named in ``sampler_cfg.PREPARE``. Calling the instance on a
    ``data_dict`` samples objects per ``sampler_cfg.SAMPLE_GROUPS``, drops
    those whose BEV IoU with existing boxes (or with each other) is
    non-zero, and merges the remaining boxes / names / points into the scene.

    NOTE(review): ``sample_gt_boxes_2d_kitti`` and ``collect_image_crops_kitti``
    are called by this class but are not defined in the visible source —
    confirm they are provided elsewhere before enabling ``IMG_AUG_TYPE``.
    """

    def __init__(self, root_path, sampler_cfg, class_names, logger=None):
        """
        Args:
            root_path: path-like object supporting ``resolve()`` and ``/``
            sampler_cfg: config object supporting both attribute access and ``get``
            class_names: names of the classes to keep database infos for
            logger: optional logger; when None all log output is skipped
        """
        self.root_path = root_path
        self.class_names = class_names
        self.sampler_cfg = sampler_cfg

        self.img_aug_type = sampler_cfg.get('IMG_AUG_TYPE', None)
        self.img_aug_iou_thresh = sampler_cfg.get('IMG_AUG_IOU_THRESH', 0.5)

        self.logger = logger
        self.db_infos = {class_name: [] for class_name in class_names}

        self.use_shared_memory = sampler_cfg.get('USE_SHARED_MEMORY', False)

        for db_info_path in sampler_cfg.DB_INFO_PATH:
            db_info_path = self.root_path.resolve() / db_info_path
            if not db_info_path.exists():
                # Fall back to the backup database; only supported for a single info file.
                assert len(sampler_cfg.DB_INFO_PATH) == 1
                sampler_cfg.DB_INFO_PATH[0] = sampler_cfg.BACKUP_DB_INFO['DB_INFO_PATH']
                sampler_cfg.DB_DATA_PATH[0] = sampler_cfg.BACKUP_DB_INFO['DB_DATA_PATH']
                db_info_path = self.root_path.resolve() / sampler_cfg.DB_INFO_PATH[0]
                sampler_cfg.NUM_POINT_FEATURES = sampler_cfg.BACKUP_DB_INFO['NUM_POINT_FEATURES']

            # SECURITY: pickle.load executes arbitrary code from the file; only
            # point DB_INFO_PATH at database files you generated yourself.
            with open(str(db_info_path), 'rb') as f:
                infos = pickle.load(f)
            for cur_class in class_names:
                self.db_infos[cur_class].extend(infos[cur_class])

        # Each PREPARE entry names a method of this class, e.g. filter_by_min_points.
        for func_name, val in sampler_cfg.PREPARE.items():
            self.db_infos = getattr(self, func_name)(self.db_infos, val)

        self.gt_database_data_key = self.load_db_to_shared_memory() if self.use_shared_memory else None

        self.sample_groups = {}
        self.sample_class_num = {}
        self.limit_whole_scene = sampler_cfg.get('LIMIT_WHOLE_SCENE', False)

        for group in sampler_cfg.SAMPLE_GROUPS:
            class_name, sample_num = group.split(':')
            if class_name not in class_names:
                continue
            num_infos = len(self.db_infos[class_name])
            # sample_num is kept as a string; consumers convert it with int().
            self.sample_class_num[class_name] = sample_num
            self.sample_groups[class_name] = {
                'sample_num': sample_num,
                # pointer == len forces a fresh permutation on the first sampling call
                'pointer': num_infos,
                'indices': np.arange(num_infos)
            }

    def _log_info(self, message):
        """Log ``message`` at info level when a logger is attached, else do nothing."""
        logger = getattr(self, 'logger', None)
        if logger is not None:
            logger.info(message)

    def __getstate__(self):
        """Return the instance state without the logger (it is dropped before pickling)."""
        d = dict(self.__dict__)
        d.pop('logger', None)
        return d

    def __setstate__(self, d):
        """Restore the state; the logger removed by ``__getstate__`` comes back as None."""
        self.__dict__.update(d)
        # Without this, every later ``self.logger`` access would raise AttributeError.
        self.__dict__.setdefault('logger', None)

    def __del__(self):
        """Remove the GT database from shared memory if this sampler uses it."""
        # getattr: __del__ also runs on objects whose __init__ failed early.
        if not getattr(self, 'use_shared_memory', False):
            return

        self._log_info('Deleting GT database from shared memory')
        # NOTE(review): load_db_to_shared_memory calls get_dist_info with
        # return_gpu_per_machine=True and unpacks three values; here only two are
        # returned, so ``num_gpus`` may not mean the same thing — confirm.
        cur_rank, num_gpus = common_utils.get_dist_info()
        sa_key = self.sampler_cfg.DB_DATA_PATH[0]
        if cur_rank % num_gpus == 0 and os.path.exists(f"/dev/shm/{sa_key}"):
            SharedArray.delete(f"shm://{sa_key}")

        if num_gpus > 1:
            dist.barrier()
        self._log_info('GT database has been removed from shared memory')

    def load_db_to_shared_memory(self):
        """Load the single DB_DATA file into a shared-memory array.

        Returns:
            sa_key: the shared array key, equal to ``sampler_cfg.DB_DATA_PATH[0]``
        """
        self._log_info('Loading GT database to shared memory')
        cur_rank, world_size, num_gpus = common_utils.get_dist_info(return_gpu_per_machine=True)

        assert len(self.sampler_cfg.DB_DATA_PATH) == 1, 'Current only support single DB_DATA'
        db_data_path = self.root_path.resolve() / self.sampler_cfg.DB_DATA_PATH[0]
        sa_key = self.sampler_cfg.DB_DATA_PATH[0]

        # Only one rank per ``num_gpus`` group creates the array, and only if it is absent.
        if cur_rank % num_gpus == 0 and not os.path.exists(f"/dev/shm/{sa_key}"):
            gt_database_data = np.load(db_data_path)
            common_utils.sa_create(f"shm://{sa_key}", gt_database_data)

        if num_gpus > 1:
            dist.barrier()
        self._log_info('GT database has been saved to shared memory')
        return sa_key

    def filter_by_difficulty(self, db_infos, removed_difficulty):
        """
        Args:
            db_infos: dict mapping class name -> list of info dicts
            removed_difficulty: container of difficulty values to drop
        Returns:
            new_db_infos: new dict without infos whose 'difficulty' is in removed_difficulty
        """
        new_db_infos = {}
        for key, dinfos in db_infos.items():
            kept = [info for info in dinfos if info['difficulty'] not in removed_difficulty]
            new_db_infos[key] = kept
            self._log_info('Database filter by difficulty %s: %d => %d' % (key, len(dinfos), len(kept)))
        return new_db_infos

    def filter_by_min_points(self, db_infos, min_gt_points_list):
        """
        Args:
            db_infos: dict mapping class name -> list of info dicts (updated in place)
            min_gt_points_list: list of 'ClassName:min_num' strings
        Returns:
            db_infos: the same dict, keeping only infos with 'num_points_in_gt' >= min_num
        """
        for name_num in min_gt_points_list:
            name, min_num = name_num.split(':')
            min_num = int(min_num)
            if min_num <= 0 or name not in db_infos:
                continue

            filtered_infos = [info for info in db_infos[name] if info['num_points_in_gt'] >= min_num]
            self._log_info('Database filter by min points %s: %d => %d' %
                           (name, len(db_infos[name]), len(filtered_infos)))
            db_infos[name] = filtered_infos

        return db_infos

    def sample_with_fixed_number(self, class_name, sample_group):
        """
        Args:
            class_name: class to sample from
            sample_group: dict with 'sample_num', 'pointer' and 'indices'; 'pointer'
                and 'indices' are updated in place
        Returns:
            sampled_dict: list of up to sample_num infos (fewer when the current
                permutation runs out; empty when the class has no infos)
        """
        sample_num = int(sample_group['sample_num'])
        pointer, indices = sample_group['pointer'], sample_group['indices']
        class_infos = self.db_infos[class_name]
        if pointer >= len(class_infos):
            # Current permutation is exhausted: reshuffle and start over.
            indices = np.random.permutation(len(class_infos))
            pointer = 0

        sampled_dict = [class_infos[idx] for idx in indices[pointer: pointer + sample_num]]
        sample_group['pointer'] = pointer + sample_num
        sample_group['indices'] = indices
        return sampled_dict

    @staticmethod
    def put_boxes_on_road_planes(gt_boxes, road_planes, calib):
        """
        Only validate in KITTIDataset
        Args:
            gt_boxes: (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...], modified in place
            road_planes: [a, b, c, d]
            calib: object providing lidar_to_rect and rect_to_lidar

        Returns:
            gt_boxes: the input boxes with z shifted by -mv_height
            mv_height: (N) distance between each box bottom and the plane height
        """
        a, b, c, d = road_planes
        center_cam = calib.lidar_to_rect(gt_boxes[:, 0:3])
        # Solve the plane equation for the second rect coordinate at each box centre.
        cur_height_cam = (-d - a * center_cam[:, 0] - c * center_cam[:, 2]) / b
        center_cam[:, 1] = cur_height_cam
        cur_lidar_height = calib.rect_to_lidar(center_cam)[:, 2]
        mv_height = gt_boxes[:, 2] - gt_boxes[:, 5] / 2 - cur_lidar_height
        gt_boxes[:, 2] -= mv_height
        return gt_boxes, mv_height

    def copy_paste_to_image_kitti(self, data_dict, crop_feat, gt_number, point_idxes=None):
        """Paste image crops into ``data_dict['images']`` and filter the points.

        Args:
            data_dict: must hold 'images', 'gt_boxes', 'gt_boxes2d', 'calib', 'points'
            crop_feat: one image crop per row of data_dict['gt_boxes2d']
            gt_number: number of original boxes; rows with index < gt_number are
                treated as original foreground
            point_idxes: per-point sampled-object index, -1 for original points.
                Despite the None default it is required: it is added to gt_number.
        Returns:
            data_dict: with 'images' and 'points' updated and 'points_2d' added
        """
        kitti_img_aug_type = 'by_depth'
        kitti_img_aug_use_type = 'annotation'
        use_cover = 'cover' in kitti_img_aug_use_type

        image = data_dict['images']
        boxes3d = data_dict['gt_boxes']
        boxes2d = data_dict['gt_boxes2d']
        # Corners are only needed for the depth mask of the 'cover' mode.
        corners_lidar = box_utils.boxes_to_corners_3d(boxes3d) if use_cover else None
        if 'depth' in kitti_img_aug_type:
            # Paste in order of decreasing x of the 3D box.
            paste_order = boxes3d[:, 0].argsort()[::-1]
        else:
            paste_order = np.arange(len(boxes3d), dtype=int)

        if 'reverse' in kitti_img_aug_type:
            paste_order = paste_order[::-1]

        # np.int / np.float were removed in NumPy 1.24; the builtins are the same dtypes.
        img_hw = image.shape[:2]
        paste_mask = -255 * np.ones(img_hw, dtype=int)
        fg_mask = np.zeros(img_hw, dtype=int)
        overlap_mask = np.zeros(img_hw, dtype=int)
        depth_mask = np.zeros((*img_hw, 2), dtype=float)
        points_2d, _ = data_dict['calib'].lidar_to_img(data_dict['points'][:, :3])
        points_2d[:, 0] = np.clip(points_2d[:, 0], a_min=0, a_max=image.shape[1] - 1)
        points_2d[:, 1] = np.clip(points_2d[:, 1], a_min=0, a_max=image.shape[0] - 1)
        points_2d = points_2d.astype(int)
        for _order in paste_order:
            _box2d = boxes2d[_order]
            rows = slice(_box2d[1], _box2d[3])
            cols = slice(_box2d[0], _box2d[2])
            image[rows, cols] = crop_feat[_order]
            # NOTE(review): '> 0' does not count pixels already covered by paste
            # index 0; '>= 0' may have been intended — confirm before changing.
            overlap_mask[rows, cols] += (paste_mask[rows, cols] > 0).astype(int)
            paste_mask[rows, cols] = _order

            if use_cover:
                # HxWx2: min and max x of the box corners
                depth_mask[rows, cols, 0] = corners_lidar[_order, :, 0].min()
                depth_mask[rows, cols, 1] = corners_lidar[_order, :, 0].max()

            # foreground area of the original boxes in the image plane
            if _order < gt_number:
                fg_mask[rows, cols] = 1

        data_dict['images'] = image

        px, py = points_2d[:, 0], points_2d[:, 1]
        # A point is kept when the top-most crop at its pixel belongs to its own object ...
        new_mask = paste_mask[py, px] == (point_idxes + gt_number)
        # ... or when its pixel still shows an original box / untouched background.
        raw_fg = (fg_mask == 1) & (paste_mask >= 0) & (paste_mask < gt_number)
        raw_bg = (fg_mask == 0) & (paste_mask < 0)
        raw_mask = raw_fg[py, px] | raw_bg[py, px]
        keep_mask = new_mask | raw_mask
        data_dict['points_2d'] = points_2d

        if 'annotation' in kitti_img_aug_use_type:
            data_dict['points'] = data_dict['points'][keep_mask]
            data_dict['points_2d'] = data_dict['points_2d'][keep_mask]
        elif 'projection' in kitti_img_aug_use_type:
            overlap_mask[overlap_mask >= 1] = 1
            data_dict['overlap_mask'] = overlap_mask
            if use_cover:
                data_dict['depth_mask'] = depth_mask

        return data_dict

    def sample_gt_boxes_2d(self, data_dict, sampled_boxes, valid_mask):
        """Dispatch 2D box sampling by ``img_aug_type`` (only 'kitti' is handled).

        Returns:
            sampled_boxes2d, mv_height, ret_valid_mask as returned by the kitti variant
        Raises:
            NotImplementedError: for any other img_aug_type
        """
        if self.img_aug_type != 'kitti':
            raise NotImplementedError

        # NOTE(review): sample_gt_boxes_2d_kitti is not defined in the visible source.
        sampled_boxes2d, mv_height, ret_valid_mask = self.sample_gt_boxes_2d_kitti(
            data_dict, sampled_boxes, valid_mask
        )
        return sampled_boxes2d, mv_height, ret_valid_mask

    def initilize_image_aug_dict(self, data_dict, gt_boxes_mask):
        """Build the bookkeeping dict for image augmentation.

        Args:
            data_dict: must hold 'gt_boxes2d' and 'images' when img_aug_type is 'kitti'
            gt_boxes_mask: boolean mask selecting the boxes to keep
        Returns:
            img_aug_gt_dict: None when img_aug_type is None, else a dict with the
                crops and integer 2D boxes of the kept objects plus empty lists
                for the objects sampled later
        Raises:
            NotImplementedError: for an img_aug_type other than None / 'kitti'
        """
        if self.img_aug_type is None:
            return None
        if self.img_aug_type != 'kitti':
            raise NotImplementedError

        # Builtin int replaces np.int, which was removed in NumPy 1.24.
        gt_number = gt_boxes_mask.sum().astype(int)
        gt_boxes2d = data_dict['gt_boxes2d'][gt_boxes_mask].astype(int)
        gt_crops2d = [data_dict['images'][_x[1]:_x[3], _x[0]:_x[2]] for _x in gt_boxes2d]

        return {
            'obj_index_list': [],
            'gt_crops2d': gt_crops2d,
            'gt_boxes2d': gt_boxes2d,
            'gt_number': gt_number,
            'crop_boxes2d': []
        }

    def collect_image_crops(self, img_aug_gt_dict, info, data_dict, obj_points, sampled_gt_boxes, sampled_gt_boxes2d, idx):
        """Collect the image crop of one sampled object into ``img_aug_gt_dict``.

        Returns:
            img_aug_gt_dict: the same dict with one more crop box, crop and index entry
            obj_points: the object points as returned by the kitti variant
        Raises:
            NotImplementedError: for an img_aug_type other than 'kitti'
        """
        if self.img_aug_type != 'kitti':
            raise NotImplementedError

        # NOTE(review): collect_image_crops_kitti is not defined in the visible source.
        new_box, img_crop2d, obj_points, obj_idx = self.collect_image_crops_kitti(
            info, data_dict, obj_points, sampled_gt_boxes, sampled_gt_boxes2d, idx
        )
        img_aug_gt_dict['crop_boxes2d'].append(new_box)
        img_aug_gt_dict['gt_crops2d'].append(img_crop2d)
        img_aug_gt_dict['obj_index_list'].append(obj_idx)

        return img_aug_gt_dict, obj_points

    def copy_paste_to_image(self, img_aug_gt_dict, data_dict, points):
        """Paste all collected crops into the image (only 'kitti' is handled).

        Args:
            img_aug_gt_dict: dict filled by initilize_image_aug_dict / collect_image_crops
            data_dict: scene dict; 'gt_boxes2d' is replaced and 'road_plane' removed
            points: merged point array; only its length is used here
        Returns:
            data_dict: result of copy_paste_to_image_kitti
        Raises:
            NotImplementedError: for an img_aug_type other than 'kitti'
        """
        if self.img_aug_type != 'kitti':
            raise NotImplementedError

        obj_points_idx = np.concatenate(img_aug_gt_dict['obj_index_list'], axis=0)
        # The leading entries take the sampled-object indices; the rest stay -1.
        point_idxes = -1 * np.ones(len(points), dtype=int)
        point_idxes[:obj_points_idx.shape[0]] = obj_points_idx

        data_dict['gt_boxes2d'] = np.concatenate(
            [img_aug_gt_dict['gt_boxes2d'], np.array(img_aug_gt_dict['crop_boxes2d'])], axis=0
        )
        data_dict = self.copy_paste_to_image_kitti(
            data_dict, img_aug_gt_dict['gt_crops2d'], img_aug_gt_dict['gt_number'], point_idxes
        )
        if 'road_plane' in data_dict:
            data_dict.pop('road_plane')
        return data_dict

    def add_sampled_boxes_to_scene(self, data_dict, sampled_gt_boxes, total_valid_sampled_dict, mv_height=None, sampled_gt_boxes2d=None):
        """Merge the sampled objects (boxes, names, points) into the scene.

        Args:
            data_dict: must hold 'gt_boxes_mask', 'gt_boxes', 'gt_names', 'points'
                (plus 'road_plane' and 'calib' when USE_ROAD_PLANE needs them)
            sampled_gt_boxes: boxes of the sampled objects, one row per info
            total_valid_sampled_dict: list of database infos to paste
            mv_height: optional per-object height offsets; computed from the road
                plane when None and USE_ROAD_PLANE is set
            sampled_gt_boxes2d: optional 2D boxes forwarded to collect_image_crops
        Returns:
            data_dict: with 'gt_boxes', 'gt_names' and 'points' replaced
        """
        gt_boxes_mask = data_dict['gt_boxes_mask']
        gt_boxes = data_dict['gt_boxes'][gt_boxes_mask]
        gt_names = data_dict['gt_names'][gt_boxes_mask]
        points = data_dict['points']
        use_road_plane = self.sampler_cfg.get('USE_ROAD_PLANE', False)
        if use_road_plane and mv_height is None:
            sampled_gt_boxes, mv_height = self.put_boxes_on_road_planes(
                sampled_gt_boxes, data_dict['road_plane'], data_dict['calib']
            )
            data_dict.pop('calib')
            data_dict.pop('road_plane')

        obj_points_list = []

        img_aug_gt_dict = self.initilize_image_aug_dict(data_dict, gt_boxes_mask)

        if self.use_shared_memory:
            gt_database_data = SharedArray.attach(f"shm://{self.gt_database_data_key}")
            gt_database_data.setflags(write=0)
        else:
            gt_database_data = None

        num_point_features = self.sampler_cfg.NUM_POINT_FEATURES
        for idx, info in enumerate(total_valid_sampled_dict):
            if self.use_shared_memory:
                start_offset, end_offset = info['global_data_offset']
                # Copy: the shared array was made read-only above.
                obj_points = copy.deepcopy(gt_database_data[start_offset:end_offset])
            else:
                file_path = self.root_path / info['path']

                obj_points = np.fromfile(str(file_path), dtype=np.float32).reshape([-1, num_point_features])
                if obj_points.shape[0] != info['num_points_in_gt']:
                    # Point count mismatch: re-read the file as float64.
                    obj_points = np.fromfile(str(file_path), dtype=np.float64).reshape(-1, num_point_features)

            assert obj_points.shape[0] == info['num_points_in_gt']
            # Shift the stored points by the box centre.
            obj_points[:, :3] += info['box3d_lidar'][:3].astype(np.float32)

            if use_road_plane:
                # mv height
                obj_points[:, 2] -= mv_height[idx]

            if self.img_aug_type is not None:
                img_aug_gt_dict, obj_points = self.collect_image_crops(
                    img_aug_gt_dict, info, data_dict, obj_points, sampled_gt_boxes, sampled_gt_boxes2d, idx
                )

            obj_points_list.append(obj_points)

        obj_points = np.concatenate(obj_points_list, axis=0)
        sampled_gt_names = np.array([x['name'] for x in total_valid_sampled_dict])

        filter_by_time = self.sampler_cfg.get('FILTER_OBJ_POINTS_BY_TIMESTAMP', False)
        if filter_by_time or obj_points.shape[-1] != points.shape[-1]:
            if filter_by_time:
                min_time = min(self.sampler_cfg.TIME_RANGE[0], self.sampler_cfg.TIME_RANGE[1])
                max_time = max(self.sampler_cfg.TIME_RANGE[0], self.sampler_cfg.TIME_RANGE[1])
            else:
                assert obj_points.shape[-1] == points.shape[-1] + 1
                # One extra trailing column: keep only object points whose last value is ~0.
                min_time = max_time = 0.0

            time_mask = np.logical_and(obj_points[:, -1] < max_time + 1e-6, obj_points[:, -1] > min_time - 1e-6)
            obj_points = obj_points[time_mask]

        # Clear scene points inside the (enlarged) sampled boxes before pasting.
        large_sampled_gt_boxes = box_utils.enlarge_box3d(
            sampled_gt_boxes[:, 0:7], extra_width=self.sampler_cfg.REMOVE_EXTRA_WIDTH
        )
        points = box_utils.remove_points_in_boxes3d(points, large_sampled_gt_boxes)
        points = np.concatenate([obj_points[:, :points.shape[-1]], points], axis=0)
        gt_names = np.concatenate([gt_names, sampled_gt_names], axis=0)
        gt_boxes = np.concatenate([gt_boxes, sampled_gt_boxes], axis=0)
        data_dict['gt_boxes'] = gt_boxes
        data_dict['gt_names'] = gt_names
        data_dict['points'] = points

        if self.img_aug_type is not None:
            data_dict = self.copy_paste_to_image(img_aug_gt_dict, data_dict, points)

        return data_dict

    def __call__(self, data_dict):
        """
        Args:
            data_dict:
                gt_boxes: (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...]
                gt_names: class name of each box
                gt_boxes_mask: mask of boxes to keep; removed before returning

        Returns:
            data_dict: with the valid sampled objects added to the scene
        """
        gt_boxes = data_dict['gt_boxes']
        gt_names = data_dict['gt_names'].astype(str)
        existed_boxes = gt_boxes
        total_valid_sampled_dict = []
        sampled_mv_height = []
        sampled_gt_boxes2d = []

        for class_name, sample_group in self.sample_groups.items():
            if self.limit_whole_scene:
                # Only top the scene up to the configured number of objects per class.
                num_gt = np.sum(class_name == gt_names)
                sample_group['sample_num'] = str(int(self.sample_class_num[class_name]) - num_gt)
            if int(sample_group['sample_num']) <= 0:
                continue

            sampled_dict = self.sample_with_fixed_number(class_name, sample_group)
            if not sampled_dict:
                # The class has no database entries (e.g. all filtered out);
                # np.stack on an empty list would raise ValueError.
                continue

            sampled_boxes = np.stack([x['box3d_lidar'] for x in sampled_dict], axis=0).astype(np.float32)

            assert not self.sampler_cfg.get('DATABASE_WITH_FAKELIDAR', False), 'Please use latest codes to generate GT_DATABASE'

            iou1 = iou3d_nms_utils.boxes_bev_iou_cpu(sampled_boxes[:, 0:7], existed_boxes[:, 0:7])
            iou2 = iou3d_nms_utils.boxes_bev_iou_cpu(sampled_boxes[:, 0:7], sampled_boxes[:, 0:7])
            # Ignore the overlap of every sampled box with itself.
            diag = range(sampled_boxes.shape[0])
            iou2[diag, diag] = 0
            iou1 = iou1 if iou1.shape[1] > 0 else iou2
            # Valid = no BEV overlap with existing boxes nor with other samples.
            valid_mask = ((iou1.max(axis=1) + iou2.max(axis=1)) == 0)

            if self.img_aug_type is not None:
                sampled_boxes2d, mv_height, valid_mask = self.sample_gt_boxes_2d(data_dict, sampled_boxes, valid_mask)
                sampled_gt_boxes2d.append(sampled_boxes2d)
                if mv_height is not None:
                    sampled_mv_height.append(mv_height)

            valid_mask = valid_mask.nonzero()[0]
            valid_sampled_dict = [sampled_dict[x] for x in valid_mask]
            valid_sampled_boxes = sampled_boxes[valid_mask]

            existed_boxes = np.concatenate((existed_boxes, valid_sampled_boxes[:, :existed_boxes.shape[-1]]), axis=0)
            total_valid_sampled_dict.extend(valid_sampled_dict)

        sampled_gt_boxes = existed_boxes[gt_boxes.shape[0]:, :]

        if total_valid_sampled_dict:
            sampled_gt_boxes2d = np.concatenate(sampled_gt_boxes2d, axis=0) if len(sampled_gt_boxes2d) > 0 else None
            sampled_mv_height = np.concatenate(sampled_mv_height, axis=0) if len(sampled_mv_height) > 0 else None

            data_dict = self.add_sampled_boxes_to_scene(
                data_dict, sampled_gt_boxes, total_valid_sampled_dict, sampled_mv_height, sampled_gt_boxes2d
            )

        data_dict.pop('gt_boxes_mask')
        return data_dict
| |
|