File size: 4,854 Bytes
789eef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import numpy as np
import torch
from mmengine.config import DictAction
from mmengine.logging import MMLogger

from mmpose.apis.inference import init_model

try:
    from mmengine.analysis import get_model_complexity_info
    from mmengine.analysis.print_helper import _format_size
except ImportError:
    raise ImportError('Please upgrade mmengine >= 0.6.0')


def parse_args():
    """Parse the command-line arguments for the model-complexity tool."""
    arg_parser = argparse.ArgumentParser(
        description='Get complexity information from a model config')
    arg_parser.add_argument('config', help='train config file path')
    arg_parser.add_argument(
        '--device', default='cpu', help='Device used for model initialization')
    arg_parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help=('override some settings in the used config, the key-value pair '
              'in xxx=yyy format will be merged into config file. For '
              'example, '
              "'--cfg-options model.backbone.depth=18 "
              "model.backbone.with_cp=True'"))
    arg_parser.add_argument(
        '--input-shape',
        type=int,
        nargs='+',
        default=[256, 192],
        help='input image size')
    arg_parser.add_argument(
        '--batch-size',
        '-b',
        type=int,
        default=1,
        help=('Input batch size. If specified and greater than 1, it takes a '
              'callable method that generates a batch input. Otherwise, it '
              'will '
              'generate a random tensor with input shape to calculate FLOPs.'))
    arg_parser.add_argument(
        '--show-arch-info',
        '-s',
        action='store_true',
        help='Whether to show model arch information')
    return arg_parser.parse_args()


def batch_constructor(flops_model, batch_size, input_shape):
    """Generate a batch of random input tensors for the model.

    Args:
        flops_model: model whose first parameter supplies the dtype and
            device for the generated batch.
        batch_size (int): leading batch dimension.
        input_shape (tuple): per-sample input shape, e.g. ``(3, H, W)``.

    Returns:
        dict: ``{'inputs': Tensor}`` with shape ``(batch_size, *input_shape)``.
    """
    # Read dtype/device once instead of calling next(...) twice.
    ref_param = next(flops_model.parameters())

    # BUG FIX: the original allocated torch.randn(...) and immediately
    # discarded it via .new_empty(...), producing an *uninitialized* tensor
    # and wasting an allocation. Allocate the random batch directly with the
    # model's dtype/device.
    inputs = torch.randn(
        batch_size,
        *input_shape,
        dtype=ref_param.dtype,
        device=ref_param.device)

    return {'inputs': inputs}


def inference(args, input_shape, logger):
    """Build the model from config and compute its complexity statistics.

    Args:
        args: parsed CLI namespace providing ``config``, ``device``,
            ``cfg_options``, ``batch_size`` and ``show_arch_info``.
        input_shape (tuple): per-sample input shape ``(C, H, W)``.
        logger: logger used for progress messages.

    Returns:
        dict: results containing ``flops_str``, ``params_str``,
        ``out_table`` and ``out_arch``.

    Raises:
        NotImplementedError: if the model does not expose ``_forward``.
    """
    model = init_model(
        args.config,
        checkpoint=None,
        device=args.device,
        cfg_options=args.cfg_options)

    # The complexity counter needs a plain tensor-in/tensor-out forward;
    # mmpose models expose that as `_forward`.
    if hasattr(model, '_forward'):
        model.forward = model._forward
    else:
        # BUG FIX: original message read "currently not currently supported".
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))

    if args.batch_size > 1:
        outputs = {}
        avg_flops = []
        logger.info('Running get_flops with batch size specified as {}'.format(
            args.batch_size))
        batch = batch_constructor(model, args.batch_size, input_shape)
        # Average the FLOPs estimate over batch_size repeated measurements
        # of the same batched input.
        for _ in range(args.batch_size):
            result = get_model_complexity_info(
                model,
                input_shape,
                inputs=batch['inputs'],
                show_table=True,
                show_arch=args.show_arch_info)
            avg_flops.append(result['flops'])
        mean_flops = _format_size(int(np.average(avg_flops)))
        outputs['flops_str'] = mean_flops
        # Params/table/arch do not depend on the batch, so the last result
        # is representative.
        outputs['params_str'] = result['params_str']
        outputs['out_table'] = result['out_table']
        outputs['out_arch'] = result['out_arch']
    else:
        outputs = get_model_complexity_info(
            model,
            input_shape,
            inputs=None,
            show_table=True,
            show_arch=args.show_arch_info)
    return outputs


def main():
    """Entry point: parse args, compute and print model complexity."""
    args = parse_args()
    logger = MMLogger.get_instance(name='MMLogger')

    # Expand a single size to a square (3, S, S) image, or an (H, W) pair
    # to (3, H, W).
    if len(args.input_shape) == 1:
        input_shape = (3, args.input_shape[0], args.input_shape[0])
    elif len(args.input_shape) == 2:
        input_shape = (3, ) + tuple(args.input_shape)
    else:
        raise ValueError('invalid input shape')

    # BUG FIX: the original only validated the exact string 'cuda:0', so
    # 'cuda' or 'cuda:1' skipped the check; it also used `assert`, which is
    # stripped under `python -O`. Check any cuda device and raise explicitly.
    if args.device.startswith('cuda') and not torch.cuda.is_available():
        raise RuntimeError(
            'No valid cuda device detected, please double check...')

    outputs = inference(args, input_shape, logger)
    flops = outputs['flops_str']
    params = outputs['params_str']
    split_line = '=' * 30
    input_shape = (args.batch_size, ) + input_shape
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n{split_line}')
    print(outputs['out_table'])
    if args.show_arch_info:
        print(outputs['out_arch'])
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')


# Script entry point: run the FLOPs/params report when executed directly.
if __name__ == '__main__':
    main()