| |
| |
| import os.path as osp |
| from argparse import ArgumentParser |
|
|
| import mmcv |
| from mmcv.utils import ProgressBar |
|
|
| from mmocr.apis import init_detector, model_inference |
| from mmocr.models import build_detector |
| from mmocr.utils import list_from_file, list_to_file |
|
|
|
|
def gen_target_path(target_root_path, src_name, suffix):
    """Build the output path for a source file inside a target directory.

    Any directory part and the final extension of ``src_name`` are dropped;
    ``suffix`` is appended to the remaining stem and the result is placed
    under ``target_root_path``.

    Args:
        target_root_path (str): Directory the target file belongs to.
        src_name (str): Source file name or path.
        suffix (str): Suffix (e.g. an extension such as ``'.txt'``) for the
            target file.

    Returns:
        str: ``target_root_path`` joined with ``<stem><suffix>``.
    """
    for arg in (target_root_path, src_name, suffix):
        assert isinstance(arg, str)

    stem, _ext = osp.splitext(osp.basename(src_name))
    return osp.join(target_root_path, f'{stem}{suffix}')
|
|
|
|
def save_results(result, out_dir, img_name, score_thr=0.3):
    """Write the confident detected boundaries of one image to a txt file.

    A boundary is kept when its last value is strictly greater than
    ``score_thr``. Each kept boundary becomes one line of comma-separated
    values, every value passed through ``round()``.

    Args:
        result (dict): Text detection result for one image; must contain
            the key ``'boundary_result'``.
        out_dir (str): Directory in which the txt file is written.
        img_name (str): Image file name; its stem names the txt file.
        score_thr (float, optional): Score threshold used to filter
            boundaries; must lie strictly between 0 and 1. Defaults to 0.3.
    """
    assert 'boundary_result' in result
    assert 0 < score_thr < 1

    txt_file = gen_target_path(out_dir, img_name, '.txt')

    # Filter first, then format, so all filtering happens before any rounding.
    confident = [b for b in result['boundary_result'] if b[-1] > score_thr]

    lines = []
    for boundary in confident:
        # NOTE(review): round() is also applied to the last value (the one
        # compared against score_thr), so its fractional part is lost in
        # the written file - confirm this is intended.
        lines.append(','.join(map(str, map(round, boundary))))

    list_to_file(txt_file, lines)
|
|
|
|
def _parse_args():
    """Parse and validate the command-line arguments of this script."""
    parser = ArgumentParser()
    parser.add_argument('img_root', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument(
        '--score-thr', type=float, default=0.5, help='Bbox score threshold')
    parser.add_argument(
        '--out-dir',
        type=str,
        default='./results',
        help='Dir to save '
        'visualize images '
        'and bbox')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts are stripped
    # under ``python -O`` and give the user no explanation.
    if not 0 < args.score_thr < 1:
        parser.error('--score-thr must be in the open interval (0, 1)')
    return args


def main():
    """Run text detection over a list of images and save the results.

    Each non-blank entry of ``img_list`` is resolved against ``img_root``.
    For each image:

    * boundaries scoring above ``--score-thr`` are written as a txt file to
      ``<out-dir>/out_txt_dir``;
    * a rendering of the result is written to ``<out-dir>/out_vis_dir``
      under the image's base name.

    Raises:
        FileNotFoundError: If a listed image path is not an existing file.
    """
    args = _parse_args()

    model = init_detector(args.config, args.checkpoint, device=args.device)
    # If the model is wrapped (it exposes a ``.module`` attribute), unwrap it
    # so the later ``show_result`` call goes to the inner model.
    if hasattr(model, 'module'):
        model = model.module

    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)
    out_txt_dir = osp.join(args.out_dir, 'out_txt_dir')
    mmcv.mkdir_or_exist(out_txt_dir)

    # Skip blank entries: joining img_root with '' yields img_root itself,
    # a directory, which would otherwise be fed to the detector.
    rel_paths = [line.strip() for line in list_from_file(args.img_list)]
    rel_paths = [path for path in rel_paths if path]

    progressbar = ProgressBar(task_num=len(rel_paths))
    for rel_path in rel_paths:
        img_path = osp.join(args.img_root, rel_path)
        if not osp.isfile(img_path):
            raise FileNotFoundError(img_path)

        result = model_inference(model, img_path)
        # NOTE(review): outputs are named by base name only, so two listed
        # images with the same file name in different sub-directories
        # overwrite each other's results.
        img_name = osp.basename(img_path)

        save_results(result, out_txt_dir, img_name, score_thr=args.score_thr)

        out_file = osp.join(out_vis_dir, img_name)
        model.show_result(
            img_path,
            result,
            score_thr=args.score_thr,
            show=False,
            out_file=out_file)

        # Advance only once the image has actually been processed.
        progressbar.update()

    print(f'\nInference done, and results saved in {args.out_dir}\n')
|
|
|
|
if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script,
    # not when this module is imported.
    main()
|
|