| | import pickle |
| | import time |
| |
|
| | import numpy as np |
| | import torch |
| | import tqdm |
| |
|
| | from pcdet.models import load_data_to_gpu |
| | from pcdet.utils import common_utils |
| |
|
| |
|
def statistics_info(cfg, ret_dict, metric, disp_dict):
    """Accumulate one batch's recall counters into ``metric`` and refresh ``disp_dict``.

    Args:
        cfg: Config object; ``cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST`` lists
            the thresholds whose recall counters are tracked.
        ret_dict: Per-batch dict returned by the model. The keys ``'roi_<t>'``,
            ``'rcnn_<t>'`` (for each threshold ``t``) and ``'gt'`` are read; a
            missing key contributes 0.
        metric: Running totals, updated in place. Must already contain
            ``'gt_num'`` plus ``'recall_roi_<t>'`` / ``'recall_rcnn_<t>'`` for
            every configured threshold.
        disp_dict: Progress-bar postfix dict, updated in place with a
            ``'recall_<t0>'`` entry for the first configured threshold ``t0``,
            formatted as ``'(roi_hits, rcnn_hits) / gt_num'``.
    """
    thresh_list = cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST
    for cur_thresh in thresh_list:
        thresh_key = str(cur_thresh)
        metric['recall_roi_%s' % thresh_key] += ret_dict.get('roi_%s' % thresh_key, 0)
        metric['recall_rcnn_%s' % thresh_key] += ret_dict.get('rcnn_%s' % thresh_key, 0)
    metric['gt_num'] += ret_dict.get('gt', 0)

    # With no thresholds configured there is nothing to display; indexing [0]
    # would otherwise raise IndexError and abort the evaluation loop.
    if not thresh_list:
        return

    # Only the first threshold is shown in the progress bar (historically named
    # "min_thresh"; presumably the list is sorted ascending — not verified here).
    min_thresh = str(thresh_list[0])
    disp_dict['recall_%s' % min_thresh] = '(%d, %d) / %d' % (
        metric['recall_roi_%s' % min_thresh],
        metric['recall_rcnn_%s' % min_thresh],
        metric['gt_num'],
    )
| |
|
| |
|
def eval_one_epoch(cfg, args, model, dataloader, epoch_id, logger, dist_test=False, result_dir=None):
    """Run ``model`` over ``dataloader`` once, then aggregate and evaluate predictions.

    Args:
        cfg: Config; reads ``cfg.LOCAL_RANK`` and ``cfg.MODEL.POST_PROCESSING``
            (``RECALL_THRESH_LIST``, ``EVAL_METRIC``).
        args: Command-line namespace; reads ``save_to_file`` and the optional
            ``infer_time`` flag (per-batch inference timing).
        model: Detector called as ``model(batch_dict)`` and expected to return
            ``(pred_dicts, ret_dict)``.
        dataloader: Loader whose ``dataset`` provides ``class_names``,
            ``generate_prediction_dicts`` and ``evaluation``.
        epoch_id: Label used only in log messages.
        logger: Logger receiving progress and result messages.
        dist_test: If True, wrap the model in ``DistributedDataParallel`` and
            gather predictions and recall counters from all ranks.
        result_dir: Path-like output directory. Despite the ``None`` default it
            is required (``.mkdir`` is called on it). ``result.pkl`` and
            ``final_result/data`` are written under it.

    Returns:
        dict: On rank 0, recall values (``'recall/roi_<t>'``,
        ``'recall/rcnn_<t>'``) merged with the dict returned by
        ``dataset.evaluation``. On every other rank, an empty dict.
    """
    result_dir.mkdir(parents=True, exist_ok=True)

    final_output_dir = result_dir / 'final_result' / 'data'
    if args.save_to_file:
        final_output_dir.mkdir(parents=True, exist_ok=True)

    recall_thresh_list = cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST
    metric = {
        'gt_num': 0,
    }
    for cur_thresh in recall_thresh_list:
        metric['recall_roi_%s' % str(cur_thresh)] = 0
        metric['recall_rcnn_%s' % str(cur_thresh)] = 0

    dataset = dataloader.dataset
    class_names = dataset.class_names
    det_annos = []

    measure_infer_time = getattr(args, 'infer_time', False)
    if measure_infer_time:
        # The first 10% of iterations are treated as warm-up and excluded from
        # the inference-time average so one-off start-up costs don't skew it.
        start_iter = int(len(dataloader) * 0.1)
        infer_time_meter = common_utils.AverageMeter()

    logger.info('*************** EPOCH %s EVALUATION *****************' % epoch_id)
    if dist_test:
        num_gpus = torch.cuda.device_count()
        local_rank = cfg.LOCAL_RANK % num_gpus
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False
        )
    model.eval()

    if cfg.LOCAL_RANK == 0:
        progress_bar = tqdm.tqdm(total=len(dataloader), leave=True, desc='eval', dynamic_ncols=True)
    # Wall-clock start of the whole evaluation loop (used for sec_per_example).
    # Per-batch timing uses its own variable so this value is never overwritten.
    start_time = time.time()
    for i, batch_dict in enumerate(dataloader):
        load_data_to_gpu(batch_dict)

        if measure_infer_time:
            batch_start_time = time.time()

        with torch.no_grad():
            pred_dicts, ret_dict = model(batch_dict)

        disp_dict = {}

        if measure_infer_time:
            # NOTE(review): CUDA work is asynchronous; without
            # torch.cuda.synchronize() this may under-measure GPU time — confirm.
            inference_time = time.time() - batch_start_time
            if i >= start_iter:
                # use ms to measure inference time
                infer_time_meter.update(inference_time * 1000)
            disp_dict['infer_time'] = f'{infer_time_meter.val:.2f}({infer_time_meter.avg:.2f})'

        statistics_info(cfg, ret_dict, metric, disp_dict)
        annos = dataset.generate_prediction_dicts(
            batch_dict, pred_dicts, class_names,
            output_path=final_output_dir if args.save_to_file else None
        )
        det_annos += annos
        if cfg.LOCAL_RANK == 0:
            progress_bar.set_postfix(disp_dict)
            progress_bar.update()

    if cfg.LOCAL_RANK == 0:
        progress_bar.close()

    if dist_test:
        _, world_size = common_utils.get_dist_info()
        det_annos = common_utils.merge_results_dist(det_annos, len(dataset), tmpdir=result_dir / 'tmpdir')
        # Each rank contributes a one-element list, so this yields one metric
        # dict per rank.
        metric = common_utils.merge_results_dist([metric], world_size, tmpdir=result_dir / 'tmpdir')

    logger.info('*************** Performance of EPOCH %s *****************' % epoch_id)
    sec_per_example = (time.time() - start_time) / len(dataloader.dataset)
    logger.info('Generate label finished(sec_per_example: %.4f second).' % sec_per_example)

    if cfg.LOCAL_RANK != 0:
        return {}

    ret_dict = {}
    if dist_test:
        # Reduce the per-rank counter dicts into one by summing each key.
        metric = {key: sum(metric[k][key] for k in range(world_size)) for key in metric[0]}

    # max(..., 1) avoids division by zero when no ground-truth boxes were seen.
    gt_num_cnt = metric['gt_num']
    for cur_thresh in recall_thresh_list:
        cur_roi_recall = metric['recall_roi_%s' % str(cur_thresh)] / max(gt_num_cnt, 1)
        cur_rcnn_recall = metric['recall_rcnn_%s' % str(cur_thresh)] / max(gt_num_cnt, 1)
        logger.info('recall_roi_%s: %f' % (cur_thresh, cur_roi_recall))
        logger.info('recall_rcnn_%s: %f' % (cur_thresh, cur_rcnn_recall))
        ret_dict['recall/roi_%s' % str(cur_thresh)] = cur_roi_recall
        ret_dict['recall/rcnn_%s' % str(cur_thresh)] = cur_rcnn_recall

    total_pred_objects = sum(len(anno['name']) for anno in det_annos)
    logger.info('Average predicted number of objects(%d samples): %.3f'
                % (len(det_annos), total_pred_objects / max(1, len(det_annos))))

    with open(result_dir / 'result.pkl', 'wb') as f:
        pickle.dump(det_annos, f)

    # Diagnostics are routed through the logger rather than stdout.
    logger.info(f"length of det_annos: {len(det_annos)}")
    logger.info('%s', dataset)
    result_str, result_dict = dataset.evaluation(
        det_annos, class_names,
        eval_metric=cfg.MODEL.POST_PROCESSING.EVAL_METRIC,
        output_path=final_output_dir
    )
    logger.info(f"result_dict: {result_dict.keys()}")
    logger.info(result_str)
    ret_dict.update(result_dict)
    logger.info('Result is saved to %s' % result_dir)
    logger.info('****************Evaluation done.*****************')
    return ret_dict
| |
|
| |
|
if __name__ == "__main__":
    # Executing this file directly is intentionally a no-op: the module only
    # defines evaluation helpers meant to be imported.
    ...
| |
|