File size: 5,447 Bytes
ea1014e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
# Copyright (c) OpenMMLab. All rights reserved.
import hashlib
import logging
import os
import os.path as osp
import warnings
from argparse import ArgumentParser
import requests
from mmengine import Config
from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import get_root_logger
# Install a catch-all "ignore" filter so that warnings raised while the
# segmentors run inference are not printed.
warnings.simplefilter('ignore')
def download_checkpoint(checkpoint_name, model_name, config_name, collect_dir):
    """Download a checkpoint and verify it against the hash in its filename.

    The checkpoint name is expected to end in ``-<hash>.<ext>``, where
    ``<hash>`` must equal the first 8 hex digits of the SHA-256 digest of the
    downloaded bytes. The file is written to ``collect_dir`` only after the
    hash has been verified, so a corrupt download never lands on disk (a
    leftover file would otherwise be treated as a valid checkpoint later).

    Args:
        checkpoint_name (str): Checkpoint file name, ``<prefix>-<hash>.<ext>``.
        model_name (str): Model directory on the download server.
        config_name (str): Config directory on the download server.
        collect_dir (str): Local directory the checkpoint is saved into.

    Raises:
        RuntimeError: If the server answers 403 or the hash does not match.
        requests.HTTPError: For any other unsuccessful HTTP status.
    """
    url = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{checkpoint_name}'  # noqa
    r = requests.get(url)
    if r.status_code == 403:
        raise RuntimeError(f'{url} Access denied.')
    # Reject every other non-2xx answer instead of saving an error page.
    r.raise_for_status()

    # The hash is the LAST '-'-separated field of the file stem: the prefix
    # may itself contain dashes, so the second field is not necessarily it.
    true_hash_code = osp.splitext(checkpoint_name)[0].rsplit('-', 1)[-1]
    cur_hash_code = hashlib.sha256(r.content).hexdigest()[:8]
    if cur_hash_code != true_hash_code:
        raise RuntimeError(f'{url} download failed, '
                           'incomplete downloaded file or url invalid.')

    with open(osp.join(collect_dir, checkpoint_name), 'wb') as code:
        code.write(r.content)
def parse_args():
    """Parse the command-line options of the inference benchmark.

    Positional arguments are the model-list config file and the checkpoint
    root directory. Optional flags choose the test image, enable augmented
    testing, restrict the run to one model, toggle result display and pick
    the inference device.

    Returns:
        argparse.Namespace: The parsed options.
    """
    cli = ArgumentParser()
    cli.add_argument('config', help='test config file path')
    cli.add_argument('checkpoint_root', help='Checkpoint file root path')
    # (flags, keyword options) for each optional argument, in --help order.
    optional_specs = (
        (('-i', '--img'), {'default': 'demo/demo.png', 'help': 'Image file'}),
        (('-a', '--aug'), {'action': 'store_true', 'help': 'aug test'}),
        (('-m', '--model-name'), {'help': 'model name to inference'}),
        (('-s', '--show'), {'action': 'store_true', 'help': 'show results'}),
        (('-d', '--device'),
         {'default': 'cuda:0', 'help': 'Device used for inference'}),
    )
    for flags, options in optional_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
def inference(config_name, checkpoint, args, logger=None):
    """Build a segmentor from a config/checkpoint pair and run it on one image.

    Args:
        config_name (str): Path of the model config file.
        checkpoint (str): Path of the checkpoint to load.
        args (argparse.Namespace): CLI options; ``aug``, ``device``, ``img``
            and ``show`` are read.
        logger (logging.Logger, optional): When given, an aug-test setup
            failure is reported with ``logger.error``; otherwise it is
            printed.

    Returns:
        Whatever ``inference_model`` returns for ``args.img``.
    """
    cfg = Config.fromfile(config_name)
    if args.aug:
        # Presumably pipeline[1] is the multi-scale / flip test-time
        # augmentation step (TODO confirm); it is only patched when it
        # exposes both the 'flip' and 'img_scale' keys.
        tta_step = cfg.data.test.pipeline[1]
        if 'flip' in tta_step and 'img_scale' in tta_step:
            tta_step.img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
            tta_step.flip = True
        else:
            message = f'{config_name}: unable to start aug test'
            if logger is None:
                print(message, flush=True)
            else:
                logger.error(message)

    model = init_model(cfg, checkpoint, device=args.device)
    result = inference_model(model, args.img)
    if args.show:
        show_result_pyplot(model, args.img, result)
    return result
# Sample test whether the inference code is correct
def main(args):
    """Smoke-test image inference for the models listed in ``args.config``.

    ``args.config`` is loaded with ``Config.fromfile``. Each top-level key is
    treated as a model name whose value is one model-info dict, or a list of
    them, holding ``'config'`` (model config path) and ``'checkpoint'``
    (checkpoint file name inside ``args.checkpoint_root``).

    With ``args.model_name`` set, only that entry is tested and failures are
    printed. Otherwise every entry is tested, missing checkpoints are
    downloaded first, and failures go to the logger from ``get_root_logger``
    (configured with ``log_file='benchmark_inference_image.log'``).

    Args:
        args (argparse.Namespace): Options produced by ``parse_args``.

    Raises:
        RuntimeError: If ``args.model_name`` is not a key of the config.
    """
    config = Config.fromfile(args.config)
    # exist_ok avoids the race of a separate exists()-then-makedirs() check.
    os.makedirs(args.checkpoint_root, 0o775, exist_ok=True)

    if args.model_name:
        _test_single_model(config, args)
    else:
        _test_all_models(config, args)


def _as_list(model_infos):
    """Return ``model_infos`` unchanged if it is a list, else wrap it in one."""
    return model_infos if isinstance(model_infos, list) else [model_infos]


def _test_single_model(config, args):
    """Run inference for each entry under ``args.model_name``; print failures."""
    if args.model_name not in config:
        raise RuntimeError('model name input error.')
    for model_info in _as_list(config[args.model_name]):
        config_name = model_info['config'].strip()
        print(f'processing: {config_name}', flush=True)
        checkpoint = osp.join(args.checkpoint_root,
                              model_info['checkpoint'].strip())
        try:
            # build the model from a config file and a checkpoint file
            inference(config_name, checkpoint, args)
        except Exception as e:
            # Keep testing the remaining entries, but report why this failed.
            print(f'{config_name} test failed! {e!r}')


def _test_all_models(config, args):
    """Test every model in ``config``, downloading missing checkpoints and
    logging failures instead of stopping at the first one."""
    logger = get_root_logger(
        log_file='benchmark_inference_image.log', log_level=logging.ERROR)
    for model_name in config:
        for model_info in _as_list(config[model_name]):
            print('processing: ', model_info['config'], flush=True)
            config_path = model_info['config'].strip()
            # Base name without extension. It is already extension-free, so
            # no further stripping is done: str.rstrip('.py') removes any
            # trailing '.', 'p' or 'y' characters, not the '.py' suffix.
            config_name = osp.splitext(osp.basename(config_path))[0]
            checkpoint_name = model_info['checkpoint'].strip()
            checkpoint = osp.join(args.checkpoint_root, checkpoint_name)
            # ensure checkpoint exists
            if not osp.exists(checkpoint):
                try:
                    download_checkpoint(checkpoint_name, model_name,
                                        config_name, args.checkpoint_root)
                except Exception as e:
                    logger.error(f'{checkpoint_name} download error: {e!r}')
                    continue
            # test model inference with checkpoint
            try:
                inference(config_path, checkpoint, args, logger)
            except Exception as e:
                logger.error(f'{config_path} : {e!r}')
if __name__ == '__main__':
    # Script entry point: parse the CLI and hand the options to main().
    main(parse_args())
|