Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import argparse | |
| import os | |
| import os.path as osp | |
| import sys | |
| import warnings | |
| import mmcv | |
| import numpy as np | |
| import torch | |
| from mmengine import ProgressBar | |
| from mmengine.config import Config, DictAction | |
| from mmengine.dataset import COLLATE_FUNCTIONS | |
| from mmengine.runner.checkpoint import load_checkpoint | |
| from numpy import random | |
| from mmyolo.registry import DATASETS, MODELS | |
| from mmyolo.utils import register_all_modules | |
| from projects.assigner_visualization.dense_heads import (RTMHeadAssigner, | |
| YOLOv5HeadAssigner, | |
| YOLOv7HeadAssigner, | |
| YOLOv8HeadAssigner) | |
| from projects.assigner_visualization.visualization import \ | |
| YOLOAssignerVisualizer | |
| def parse_args(): | |
| parser = argparse.ArgumentParser( | |
| description='MMYOLO show the positive sample assigning' | |
| ' results.') | |
| parser.add_argument('config', help='config file path') | |
| parser.add_argument('--checkpoint', '-c', type=str, help='checkpoint file') | |
| parser.add_argument( | |
| '--show-number', | |
| '-n', | |
| type=int, | |
| default=sys.maxsize, | |
| help='number of images selected to save, ' | |
| 'must bigger than 0. if the number is bigger than length ' | |
| 'of dataset, show all the images in dataset; ' | |
| 'default "sys.maxsize", show all images in dataset') | |
| parser.add_argument( | |
| '--output-dir', | |
| default='assigned_results', | |
| type=str, | |
| help='The name of the folder where the image is saved.') | |
| parser.add_argument( | |
| '--device', default='cuda:0', help='Device used for inference.') | |
| parser.add_argument( | |
| '--show-prior', | |
| default=False, | |
| action='store_true', | |
| help='Whether to show prior on image.') | |
| parser.add_argument( | |
| '--not-show-label', | |
| default=False, | |
| action='store_true', | |
| help='Whether to show label on image.') | |
| parser.add_argument('--seed', default=-1, type=int, help='random seed') | |
| parser.add_argument( | |
| '--cfg-options', | |
| nargs='+', | |
| action=DictAction, | |
| help='override some settings in the used config, the key-value pair ' | |
| 'in xxx=yyy format will be merged into config file. If the value to ' | |
| 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' | |
| 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' | |
| 'Note that the quotation marks are necessary and that no white space ' | |
| 'is allowed.') | |
| args = parser.parse_args() | |
| return args | |
| def main(): | |
| args = parse_args() | |
| register_all_modules() | |
| # set random seed | |
| seed = int(args.seed) | |
| if seed != -1: | |
| print(f'Set the global seed: {seed}') | |
| random.seed(int(args.seed)) | |
| cfg = Config.fromfile(args.config) | |
| if args.cfg_options is not None: | |
| cfg.merge_from_dict(args.cfg_options) | |
| # build model | |
| model = MODELS.build(cfg.model) | |
| if args.checkpoint is not None: | |
| load_checkpoint(model, args.checkpoint) | |
| elif isinstance(model.bbox_head, (YOLOv7HeadAssigner, RTMHeadAssigner)): | |
| warnings.warn( | |
| 'if you use dynamic_assignment methods such as YOLOv7 or ' | |
| 'YOLOv8 or RTMDet assigner, please load the checkpoint.') | |
| assert isinstance(model.bbox_head, (YOLOv5HeadAssigner, | |
| YOLOv7HeadAssigner, | |
| YOLOv8HeadAssigner, | |
| RTMHeadAssigner)), \ | |
| 'Now, this script only support YOLOv5, YOLOv7, YOLOv8 and RTMdet, ' \ | |
| 'and bbox_head must use ' \ | |
| '`YOLOv5HeadAssigner or YOLOv7HeadAssigne or YOLOv8HeadAssigner ' \ | |
| 'or RTMHeadAssigner`. Please use `' \ | |
| 'yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py' \ | |
| 'or yolov7_tiny_syncbn_fast_8x16b-300e_coco_assignervisualization.py' \ | |
| 'or yolov8_s_syncbn_fast_8xb16-500e_coco_assignervisualization.py' \ | |
| 'or rtmdet_s_syncbn_fast_8xb32-300e_coco_assignervisualization.py' \ | |
| """` as config file.""" | |
| model.eval() | |
| model.to(args.device) | |
| # build dataset | |
| dataset_cfg = cfg.get('train_dataloader').get('dataset') | |
| dataset = DATASETS.build(dataset_cfg) | |
| # get collate_fn | |
| collate_fn_cfg = cfg.get('train_dataloader').pop( | |
| 'collate_fn', dict(type='pseudo_collate')) | |
| collate_fn_type = collate_fn_cfg.pop('type') | |
| collate_fn = COLLATE_FUNCTIONS.get(collate_fn_type) | |
| # init visualizer | |
| visualizer = YOLOAssignerVisualizer( | |
| vis_backends=[{ | |
| 'type': 'LocalVisBackend' | |
| }], name='visualizer') | |
| visualizer.dataset_meta = dataset.metainfo | |
| # need priors size to draw priors | |
| if hasattr(model.bbox_head.prior_generator, 'base_anchors'): | |
| visualizer.priors_size = model.bbox_head.prior_generator.base_anchors | |
| # make output dir | |
| os.makedirs(args.output_dir, exist_ok=True) | |
| print('Results will save to ', args.output_dir) | |
| # init visualization image number | |
| assert args.show_number > 0 | |
| display_number = min(args.show_number, len(dataset)) | |
| progress_bar = ProgressBar(display_number) | |
| for ind_img in range(display_number): | |
| data = dataset.prepare_data(ind_img) | |
| if data is None: | |
| print('Unable to visualize {} due to strong data augmentations'. | |
| format(dataset[ind_img]['data_samples'].img_path)) | |
| continue | |
| # convert data to batch format | |
| batch_data = collate_fn([data]) | |
| with torch.no_grad(): | |
| assign_results = model.assign(batch_data) | |
| img = data['inputs'].cpu().numpy().astype(np.uint8).transpose( | |
| (1, 2, 0)) | |
| # bgr2rgb | |
| img = mmcv.bgr2rgb(img) | |
| gt_instances = data['data_samples'].gt_instances | |
| img_show = visualizer.draw_assign(img, assign_results, gt_instances, | |
| args.show_prior, args.not_show_label) | |
| if hasattr(data['data_samples'], 'img_path'): | |
| filename = osp.basename(data['data_samples'].img_path) | |
| else: | |
| # some dataset have not image path | |
| filename = f'{ind_img}.jpg' | |
| out_file = osp.join(args.output_dir, filename) | |
| # convert rgb 2 bgr and save img | |
| mmcv.imwrite(mmcv.rgb2bgr(img_show), out_file) | |
| progress_bar.update() | |
| if __name__ == '__main__': | |
| main() | |