File size: 4,854 Bytes
789eef1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import numpy as np
import torch
from mmengine.config import DictAction
from mmengine.logging import MMLogger
from mmpose.apis.inference import init_model
try:
from mmengine.analysis import get_model_complexity_info
from mmengine.analysis.print_helper import _format_size
except ImportError:
raise ImportError('Please upgrade mmengine >= 0.6.0')
def parse_args():
    """Parse the command-line options of the FLOPs tool.

    Returns:
        argparse.Namespace: options ``config``, ``device``, ``cfg_options``,
        ``input_shape``, ``batch_size`` and ``show_arch_info``.
    """
    cli = argparse.ArgumentParser(
        description='Get complexity information from a model config')

    cli.add_argument('config', help='train config file path')
    cli.add_argument(
        '--device',
        default='cpu',
        help='Device used for model initialization',
    )

    # DictAction gathers the ``key=value`` tokens into a dict of overrides.
    cfg_options_help = (
        'override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    cli.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help=cfg_options_help,
    )

    cli.add_argument(
        '--input-shape',
        type=int,
        nargs='+',
        default=[256, 192],
        help='input image size',
    )

    batch_size_help = (
        'Input batch size. If specified and greater than 1, it takes a '
        'callable method that generates a batch input. Otherwise, it will '
        'generate a random tensor with input shape to calculate FLOPs.')
    cli.add_argument(
        '--batch-size', '-b', type=int, default=1, help=batch_size_help)

    cli.add_argument(
        '--show-arch-info',
        '-s',
        action='store_true',
        help='Whether to show model arch information',
    )
    return cli.parse_args()
def batch_constructor(flops_model, batch_size, input_shape):
    """Generate a dummy input batch for FLOPs analysis.

    The tensor is placed on the device of the model's first parameter and
    cast to that parameter's dtype, so the forward pass does not hit a
    device/dtype mismatch. A model without parameters gets a CPU tensor of
    the default torch dtype.

    Args:
        flops_model (torch.nn.Module): Model whose first parameter decides
            the dtype and device of the generated tensor.
        batch_size (int): Size of the leading (batch) dimension.
        input_shape (Sequence[int]): Per-sample shape appended after the
            batch dimension, e.g. ``(3, H, W)``.

    Returns:
        dict: ``{'inputs': Tensor}`` with shape ``(batch_size, *input_shape)``.
    """
    # Look the parameter up once; ``None`` avoids a bare StopIteration for
    # parameter-less models.
    first_param = next(flops_model.parameters(), None)
    if first_param is None:
        dtype, device = torch.get_default_dtype(), torch.device('cpu')
    else:
        dtype, device = first_param.dtype, first_param.device

    # Real random values instead of ``new_empty`` (uninitialized memory that
    # may hold NaN/Inf). Sample in the default float dtype, then cast, so
    # dtypes without a normal-distribution kernel still work.
    inputs = torch.randn((batch_size, *input_shape), device=device).to(dtype)
    return {'inputs': inputs}
def inference(args, input_shape, logger):
    """Build the model from ``args.config`` and measure its complexity.

    Args:
        args (argparse.Namespace): Parsed CLI options; reads ``config``,
            ``device``, ``cfg_options``, ``batch_size`` and
            ``show_arch_info``.
        input_shape (tuple[int]): Per-sample input shape, e.g. ``(3, H, W)``.
        logger: Logger exposing ``info``; used to report the batch size.

    Returns:
        dict: The result of ``get_model_complexity_info``. Callers read the
        ``flops_str``, ``params_str``, ``out_table`` and ``out_arch`` keys.

    Raises:
        NotImplementedError: If the model has no ``_forward`` method.
    """
    model = init_model(
        args.config,
        checkpoint=None,
        device=args.device,
        cfg_options=args.cfg_options)

    if not hasattr(model, '_forward'):
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))
    # Calling the model during analysis should run ``_forward``
    # (presumably because the default ``forward`` needs extra arguments --
    # confirm against the model class).
    model.forward = model._forward

    if args.batch_size > 1:
        logger.info('Running get_flops with batch size specified as {}'.format(
            args.batch_size))
        batch = batch_constructor(model, args.batch_size, input_shape)
        # One pass is enough: analysing the same batch again would produce
        # the same count. Pass only ``inputs`` -- mmengine treats
        # ``input_shape`` and ``inputs`` as mutually exclusive.
        return get_model_complexity_info(
            model,
            input_shape=None,
            inputs=batch['inputs'],
            show_table=True,
            show_arch=args.show_arch_info)

    # Single-sample case: mmengine builds the dummy input from the shape.
    return get_model_complexity_info(
        model,
        input_shape,
        inputs=None,
        show_table=True,
        show_arch=args.show_arch_info)
def main():
    """Command-line entry point: print the FLOPs and parameter count.

    Raises:
        ValueError: If ``--input-shape`` does not have one or two values.
        RuntimeError: If a CUDA device is requested but CUDA is unavailable.
    """
    args = parse_args()
    logger = MMLogger.get_instance(name='MMLogger')

    # Expand the CLI size to (C, H, W); a single value means a square image.
    if len(args.input_shape) == 1:
        input_shape = (3, args.input_shape[0], args.input_shape[0])
    elif len(args.input_shape) == 2:
        input_shape = (3, ) + tuple(args.input_shape)
    else:
        raise ValueError('invalid input shape')

    # Cover every CUDA device string ('cuda', 'cuda:0', 'cuda:1', ...) and use
    # an explicit raise so the check is not stripped under ``python -O``.
    if args.device.startswith('cuda') and not torch.cuda.is_available():
        raise RuntimeError(
            'No valid cuda device detected, please double check...')

    outputs = inference(args, input_shape, logger)
    flops = outputs['flops_str']
    params = outputs['params_str']
    split_line = '=' * 30
    # Report the shape including the batch dimension that was analysed.
    batched_shape = (args.batch_size, ) + input_shape
    print(f'{split_line}\nInput shape: {batched_shape}\n'
          f'Flops: {flops}\nParams: {params}\n{split_line}')
    print(outputs['out_table'])
    if args.show_arch_info:
        print(outputs['out_arch'])
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')


if __name__ == '__main__':
    main()
|