# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse

from mmengine.fileio import dump, load
from mmengine.logging import print_log
from mmengine.utils import ProgressBar
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

from mmdet.models.utils import weighted_boxes_fusion
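
# Example invocation (the script name and file paths below are purely
# illustrative, not taken from this repository):
#   python fuse_results.py model_a.bbox.json model_b.bbox.json \
#       --annotation instances_val.json --weights 2 1 \
#       --fusion-iou-thr 0.55 --save-fusion-results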

def parse_args():
    parser = argparse.ArgumentParser(
        description='Fuse image prediction results from multiple models '
        'using Weighted Boxes Fusion.')
    parser.add_argument(
        'pred_results',
        type=str,
        nargs='+',
        help='prediction result files (json format) from multiple models.')
    parser.add_argument('--annotation', type=str, help='annotation file path')
    parser.add_argument(
        '--weights',
        type=float,
        nargs='*',
        default=None,
        help='weights for each model, in the same order as the '
        'prediction result files above.')
    parser.add_argument(
        '--fusion-iou-thr',
        type=float,
        default=0.55,
        help='IoU threshold for boxes to be considered a match in WBF.')
    parser.add_argument(
        '--skip-box-thr',
        type=float,
        default=0.0,
        help='exclude boxes with a score lower than this value in WBF.')
    parser.add_argument(
        '--conf-type',
        type=str,
        default='avg',
        help='how to compute the confidence of fused boxes in WBF.')
    parser.add_argument(
        '--eval-single',
        action='store_true',
        help='whether to evaluate each single model result.')
    parser.add_argument(
        '--save-fusion-results',
        action='store_true',
        help='whether to save the fusion results.')
    parser.add_argument(
        '--out-dir',
        type=str,
        default='outputs',
        help='output directory of images or prediction results.')
    args = parser.parse_args()
    return args
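
# Each prediction result file is expected to be in COCO detection result
# format: a json list of dicts with 'image_id', 'bbox' ([x, y, w, h]),
# 'score' and 'category_id' keys, loadable by pycocotools' loadRes().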

def main():
    args = parse_args()

    if args.weights is not None:
        assert len(args.weights) == len(args.pred_results), \
            'the number of weights must equal the number of prediction results'

    cocoGT = COCO(args.annotation)

    predicts_raw = []
    models_name = ['model_' + str(i) for i in range(len(args.pred_results))]
    for model_name, path in zip(models_name, args.pred_results):
        pred = load(path)
        predicts_raw.append(pred)

        if args.eval_single:
            print_log(f'Evaluate {model_name}...')
            cocoDt = cocoGT.loadRes(pred)
            coco_eval = COCOeval(cocoGT, cocoDt, iouType='bbox')
            coco_eval.evaluate()
            coco_eval.accumulate()
            coco_eval.summarize()

    # Group the raw predictions by image id: for every image keep one
    # bboxes/scores/labels list per model so they can be fused jointly.
    predict = {
        str(image_id): {
            'bboxes_list': [[] for _ in range(len(predicts_raw))],
            'scores_list': [[] for _ in range(len(predicts_raw))],
            'labels_list': [[] for _ in range(len(predicts_raw))]
        }
        for image_id in cocoGT.getImgIds()
    }

    for i, pred_single in enumerate(predicts_raw):
        for pred in pred_single:
            p = predict[str(pred['image_id'])]
            p['bboxes_list'][i].append(pred['bbox'])
            p['scores_list'][i].append(pred['score'])
            p['labels_list'][i].append(pred['category_id'])

    result = []
    prog_bar = ProgressBar(len(predict))
    for image_id, res in predict.items():
        # Fuse the per-model boxes of this image with Weighted Boxes Fusion.
        bboxes, scores, labels = weighted_boxes_fusion(
            res['bboxes_list'],
            res['scores_list'],
            res['labels_list'],
            weights=args.weights,
            iou_thr=args.fusion_iou_thr,
            skip_box_thr=args.skip_box_thr,
            conf_type=args.conf_type)

        for bbox, score, label in zip(bboxes, scores, labels):
            result.append({
                'bbox': bbox.numpy().tolist(),
                'category_id': int(label),
                'image_id': int(image_id),
                'score': float(score)
            })

        prog_bar.update()

    if args.save_fusion_results:
        out_file = args.out_dir + '/fusion_results.json'
        dump(result, file=out_file)
        print_log(
            f'Fusion results have been saved to {out_file}.', logger='current')

    print_log('Evaluate fusion results using WBF...')
    cocoDt = cocoGT.loadRes(result)
    coco_eval = COCOeval(cocoGT, cocoDt, iouType='bbox')
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()


if __name__ == '__main__':
    main()