File size: 4,052 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# Copyright (c) OpenMMLab. All rights reserved.
import shutil
from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory

from mmengine.config import Config
from mmengine.utils import mkdir_or_exist

# `torch-model-archiver` is an optional dependency that this script cannot
# work without, so fail at import time with an actionable install hint.
try:
    from model_archiver.model_packaging import package_model
    from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError as err:
    # Adjacent literals are concatenated verbatim, so the separating space
    # must be written explicitly; chain the cause to keep the original error.
    raise ImportError('`torch-model-archiver` is required. '
                      'Try: pip install torch-model-archiver') from err


def mmaction2torchserve(
    config_file: str,
    checkpoint_file: str,
    output_folder: str,
    model_name: str,
    label_file: str,
    model_version: str = '1.0',
    force: bool = False,
):
    """Converts MMAction2 model (config + checkpoint) to TorchServe `.mar`.

    The config is loaded and dumped to a temporary directory together with a
    copy of the label file; both staged files are handed to
    `torch-model-archiver` alongside the checkpoint and the
    `mmaction_handler.py` located next to this script.

    Args:
        config_file (str): In MMAction2 config format.
        checkpoint_file (str): In MMAction2 checkpoint format.
        output_folder (str): Folder where `{model_name}.mar` will be created.
            The file created will be in TorchServe archive format.
        model_name (str | None): If not None, used for naming the
            `{model_name}.mar` file that will be created under `output_folder`.
            If None (or empty), `{Path(checkpoint_file).stem}` will be used.
        label_file (str): A txt file which contains the action category names.
            It is archived under the name `label_map.txt`.
        model_version (str): Model's version.
        force (bool): If True, if there is an existing `{model_name}.mar` file
            under `output_folder` it will be overwritten.
    """
    mkdir_or_exist(output_folder)

    config = Config.fromfile(config_file)

    # Packaging happens inside the `with` block so the staged files still
    # exist while the archiver reads them.
    with TemporaryDirectory() as tmpdir:
        config.dump(f'{tmpdir}/config.py')
        shutil.copy(label_file, f'{tmpdir}/label_map.txt')

        # Mimic the namespace that the `torch-model-archiver` CLI would build.
        args = Namespace(
            **{
                'model_file': f'{tmpdir}/config.py',
                'serialized_file': checkpoint_file,
                'handler': f'{Path(__file__).parent}/mmaction_handler.py',
                'model_name': model_name or Path(checkpoint_file).stem,
                'version': model_version,
                'export_path': output_folder,
                'force': force,
                'requirements_file': None,
                'extra_files': f'{tmpdir}/label_map.txt',
                'runtime': 'python',
                'archive_format': 'default',
                # NOTE(review): newer torch-model-archiver releases appear to
                # read `args.config_file` (the `--config-file` CLI option) and
                # would raise AttributeError without it; None means "no
                # model config". Older releases should ignore the extra
                # attribute. Confirm against the installed version.
                'config_file': None,
            })
        manifest = ModelExportUtils.generate_manifest_json(args)
        package_model(args, manifest)


def parse_args():
    """Parse the command-line arguments of the conversion script.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes `config`,
            `checkpoint`, `output_folder`, `model_name`, `label_file`,
            `model_version` and `force`.
    """
    parser = ArgumentParser(
        description='Convert MMAction2 models to TorchServe `.mar` format.')
    parser.add_argument('config', type=str, help='config file path')
    parser.add_argument('checkpoint', type=str, help='checkpoint file path')
    parser.add_argument(
        '--output-folder',
        type=str,
        required=True,
        help='Folder where `{model_name}.mar` will be created.')
    parser.add_argument(
        '--model-name',
        type=str,
        default=None,
        # Adjacent literals are joined verbatim, so each fragment carries its
        # own trailing space.
        help='If not None, used for naming the `{model_name}.mar` '
        'file that will be created under `output_folder`. '
        'If None, `{Path(checkpoint_file).stem}` will be used.')
    parser.add_argument(
        '--label-file',
        type=str,
        # The converter copies this file unconditionally, so a missing value
        # would only surface later as a TypeError inside `shutil.copy`;
        # reject it up front with a proper usage error instead.
        required=True,
        help='A txt file which contains the action category names.')
    parser.add_argument(
        '--model-version',
        type=str,
        default='1.0',
        help='Number used for versioning.')
    parser.add_argument(
        '-f',
        '--force',
        action='store_true',
        help='overwrite the existing `{model_name}.mar`')
    args = parser.parse_args()

    return args


if __name__ == '__main__':
    # Script entry point: read the CLI options and forward each one to the
    # converter by parameter name.
    cli = parse_args()

    mmaction2torchserve(
        config_file=cli.config,
        checkpoint_file=cli.checkpoint,
        output_folder=cli.output_folder,
        model_name=cli.model_name,
        label_file=cli.label_file,
        model_version=cli.model_version,
        force=cli.force,
    )