|
|
|
|
|
import hashlib |
|
|
import logging |
|
|
import os |
|
|
import os.path as osp |
|
|
import warnings |
|
|
from argparse import ArgumentParser |
|
|
|
|
|
import requests |
|
|
from mmengine import Config |
|
|
|
|
|
from mmseg.apis import inference_model, init_model, show_result_pyplot |
|
|
from mmseg.utils import get_root_logger |
|
|
|
|
|
|
|
|
# Install a blanket 'ignore' filter at import time so that warnings do not
# clutter the benchmark output.
warnings.filterwarnings(action='ignore')
|
|
|
|
|
|
|
|
def download_checkpoint(checkpoint_name, model_name, config_name, collect_dir):
    """Download a checkpoint and verify it against the hash in its file name.

    The file is fetched from
    ``https://download.openmmlab.com/mmsegmentation/v0.5/<model_name>/<config_name>/<checkpoint_name>``
    and written to ``<collect_dir>/<checkpoint_name>``. The part of the file
    name after the last ``-`` (extension removed) is taken as the expected
    SHA-256 prefix and compared with the digest of the file on disk.

    Args:
        checkpoint_name (str): File name of the checkpoint, expected to look
            like ``<anything>-<sha256 prefix>.<ext>``.
        model_name (str): Model directory component of the download URL.
        config_name (str): Config directory component of the download URL.
        collect_dir (str): Existing directory the checkpoint is written to.

    Raises:
        AssertionError: If the server answers 403, or if the digest of the
            downloaded file does not start with the hash embedded in
            ``checkpoint_name``. In the latter case the file is deleted
            before raising, so a corrupt checkpoint is never left behind.
    """
    url = ('https://download.openmmlab.com/mmsegmentation/v0.5/'
           f'{model_name}/{config_name}/{checkpoint_name}')
    checkpoint_path = osp.join(collect_dir, checkpoint_name)

    r = requests.get(url)
    # Raised explicitly (instead of via ``assert``) so the check survives
    # ``python -O``; the exception type is kept for existing callers.
    if r.status_code == 403:
        raise AssertionError(f'{url} Access denied.')

    with open(checkpoint_path, 'wb') as code:
        code.write(r.content)

    # The hash is the LAST dash-separated field: names such as
    # 'fcn_r50-d8_..._192608-efe53f0d.pth' contain earlier dashes too.
    true_hash_code = osp.splitext(checkpoint_name)[0].rsplit('-', 1)[-1]

    # Hash what actually landed on disk.
    with open(checkpoint_path, 'rb') as fp:
        cur_hash_code = hashlib.sha256(fp.read()).hexdigest()

    # An empty expected hash would match any digest, so reject it as well.
    if not true_hash_code or not cur_hash_code.startswith(true_hash_code):
        # Clean up first, otherwise the next run would find the broken file
        # and skip the download.
        os.remove(checkpoint_path)
        raise AssertionError(
            f'{url} download failed, '
            'incomplete downloaded file or url invalid.')
|
|
|
|
|
|
|
|
def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``config``,
        ``checkpoint_root``, ``img``, ``aug``, ``model_name``, ``show`` and
        ``device``.
    """
    cli = ArgumentParser()

    # (names, keyword options) pairs, listed in the order they are registered.
    argument_table = (
        (('config', ), dict(help='test config file path')),
        (('checkpoint_root', ), dict(help='Checkpoint file root path')),
        (('-i', '--img'), dict(default='demo/demo.png', help='Image file')),
        (('-a', '--aug'), dict(action='store_true', help='aug test')),
        (('-m', '--model-name'), dict(help='model name to inference')),
        (('-s', '--show'), dict(action='store_true', help='show results')),
        (('-d', '--device'),
         dict(default='cuda:0', help='Device used for inference')),
    )
    for names, options in argument_table:
        cli.add_argument(*names, **options)

    return cli.parse_args()
|
|
|
|
|
|
|
|
def inference(config_name, checkpoint, args, logger=None):
    """Run inference on ``args.img`` for one config/checkpoint pair.

    Args:
        config_name: Path of the config file, loaded with
            ``Config.fromfile``.
        checkpoint: Checkpoint path passed to ``init_model``.
        args: Parsed options; ``aug``, ``device``, ``img`` and ``show`` are
            read from it.
        logger: Optional logger that receives the "unable to start aug test"
            message. When it is None the message is printed instead.

    Returns:
        The value returned by ``inference_model`` for ``args.img``.
    """
    cfg = Config.fromfile(config_name)

    if args.aug:
        # The second test-pipeline entry is the one carrying the
        # 'flip' / 'img_scale' settings that aug testing overrides.
        aug_step = cfg.data.test.pipeline[1]
        if 'flip' in aug_step and 'img_scale' in aug_step:
            aug_step.img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
            aug_step.flip = True
        else:
            # Inference still proceeds below, just without augmentation.
            message = f'{config_name}: unable to start aug test'
            if logger is None:
                print(message, flush=True)
            else:
                logger.error(message)

    model = init_model(cfg, checkpoint, device=args.device)
    result = inference_model(model, args.img)

    if args.show:
        show_result_pyplot(model, args.img, result)

    return result
|
|
|
|
|
|
|
|
|
|
|
def main(args):
    """Run the inference benchmark for one model or for every listed model.

    ``args.config`` is loaded with ``Config.fromfile``. Each top-level key is
    used as a model name whose value is either a single model-info mapping or
    a list of them; every mapping provides a ``'config'`` and a
    ``'checkpoint'`` entry.

    * If ``args.model_name`` is set, every entry of that model is run with
      the checkpoint taken from ``args.checkpoint_root`` (nothing is
      downloaded). A failing entry is reported on stdout and the remaining
      entries are still tried.
    * Otherwise every model in the config is run. Missing checkpoints are
      downloaded first, and download or inference failures are written to
      ``benchmark_inference_image.log``.

    Args:
        args: Parsed command-line options; ``config``, ``checkpoint_root``
            and ``model_name`` are read here, the rest is forwarded to
            ``inference``.

    Raises:
        RuntimeError: If ``args.model_name`` is given but is not a key of the
            loaded config.
    """
    config = Config.fromfile(args.config)

    if not os.path.exists(args.checkpoint_root):
        os.makedirs(args.checkpoint_root, 0o775)

    # ---- single-model mode -------------------------------------------------
    if args.model_name:
        if args.model_name not in config:
            raise RuntimeError('model name input error.')

        model_infos = config[args.model_name]
        if not isinstance(model_infos, list):
            model_infos = [model_infos]

        for model_info in model_infos:
            config_name = model_info['config'].strip()
            print(f'processing: {config_name}', flush=True)
            checkpoint = osp.join(args.checkpoint_root,
                                  model_info['checkpoint'].strip())
            try:
                inference(config_name, checkpoint, args)
            except Exception:
                # Best effort: report and carry on with the next entry.
                print(f'{config_name} test failed!')
        # Only leave once every entry of the selected model has been tried.
        return

    # ---- all-models mode ---------------------------------------------------
    logger = get_root_logger(
        log_file='benchmark_inference_image.log', log_level=logging.ERROR)

    for model_name in config:
        model_infos = config[model_name]
        if not isinstance(model_infos, list):
            model_infos = [model_infos]

        for model_info in model_infos:
            print('processing: ', model_info['config'], flush=True)
            config_path = model_info['config'].strip()
            # splitext() already removes the '.py' extension. No further
            # stripping is applied: str.rstrip('.py') would also eat
            # trailing 'p' / 'y' characters of the name itself.
            config_name = osp.splitext(osp.basename(config_path))[0]
            checkpoint_name = model_info['checkpoint'].strip()
            checkpoint = osp.join(args.checkpoint_root, checkpoint_name)

            # Make sure the checkpoint is available locally.
            if not osp.exists(checkpoint):
                try:
                    download_checkpoint(checkpoint_name, model_name,
                                        config_name, args.checkpoint_root)
                except Exception:
                    logger.error(f'{checkpoint_name} download error')
                    continue

            # Run inference; a failure is logged and the loop goes on.
            try:
                inference(config_path, checkpoint, args, logger)
            except Exception as e:
                logger.error(f'{config_path} " : {repr(e)}')
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the command line and run the benchmark.
    main(parse_args())
|
|
|