| | |
| | from argparse import ArgumentParser, Namespace |
| | from pathlib import Path |
| | from tempfile import TemporaryDirectory |
| |
|
| | from mmengine.config import Config |
| | from mmengine.utils import mkdir_or_exist |
| |
|
| | try: |
| | from model_archiver.model_packaging import package_model |
| | from model_archiver.model_packaging_utils import ModelExportUtils |
| | except ImportError: |
| | package_model = None |
| |
|
| |
|
def mmdet2torchserve(
    config_file: str,
    checkpoint_file: str,
    output_folder: str,
    model_name: str,
    model_version: str = '1.0',
    force: bool = False,
):
    """Convert MMDetection model (config + checkpoint) to TorchServe `.mar`.

    The config is loaded, dumped into a temporary directory as `config.py`
    and handed to `torch-model-archiver` together with the checkpoint and
    the `mmdet_handler.py` file located next to this script.

    Args:
        config_file:
            In MMDetection config format.
            The contents vary for each task repository.
        checkpoint_file:
            In MMDetection checkpoint format.
            The contents vary for each task repository.
        output_folder:
            Folder where `{model_name}.mar` will be created.
            The file created will be in TorchServe archive format.
        model_name:
            Name used for the `{model_name}.mar` file that will be created
            under `output_folder`. If None (or empty),
            `{Path(checkpoint_file).stem}` is used instead.
        model_version:
            Model's version.
        force:
            If True, if there is an existing `{model_name}.mar`
            file under `output_folder` it will be overwritten.

    Raises:
        ImportError: If `torch-model-archiver` is not installed.
    """
    # The module-level optional import leaves `package_model` as None (and
    # `ModelExportUtils` undefined) when `torch-model-archiver` is missing.
    # Fail here with an actionable message instead of hitting a NameError
    # after the output folder has already been touched.
    if package_model is None:
        raise ImportError('`torch-model-archiver` is required. '
                          'Try: pip install torch-model-archiver')

    mkdir_or_exist(output_folder)

    config = Config.fromfile(config_file)

    with TemporaryDirectory() as tmpdir:
        # The archiver needs the config as a file on disk; the dumped copy
        # only has to live until packaging has finished.
        dumped_config = str(Path(tmpdir) / 'config.py')
        config.dump(dumped_config)

        # Mirrors the attributes of the namespace that the
        # `torch-model-archiver` command line would produce.
        args = Namespace(
            model_file=dumped_config,
            serialized_file=checkpoint_file,
            handler=str(Path(__file__).parent / 'mmdet_handler.py'),
            model_name=model_name or Path(checkpoint_file).stem,
            version=model_version,
            export_path=output_folder,
            force=force,
            requirements_file=None,
            extra_files=None,
            runtime='python',
            archive_format='default',
            # NOTE(review): newer torch-model-archiver releases appear to
            # read `config_file` from the namespace; None is harmless for
            # releases that do not know the attribute - confirm against the
            # pinned archiver version.
            config_file=None,
        )
        manifest = ModelExportUtils.generate_manifest_json(args)
        package_model(args, manifest)
| |
|
| |
|
def parse_args():
    """Parse the command-line arguments of the conversion script.

    Returns:
        argparse.Namespace: Parsed arguments exposing `config`,
        `checkpoint`, `output_folder`, `model_name`, `model_version`
        and `force`.
    """
    parser = ArgumentParser(
        description='Convert MMDetection models to TorchServe `.mar` format.')
    parser.add_argument('config', type=str, help='config file path')
    parser.add_argument('checkpoint', type=str, help='checkpoint file path')
    parser.add_argument(
        '--output-folder',
        type=str,
        required=True,
        help='Folder where `{model_name}.mar` will be created.')
    parser.add_argument(
        '--model-name',
        type=str,
        default=None,
        # Adjacent string literals are joined verbatim, so every fragment
        # carries its own trailing space to keep the sentences separated.
        help='If not None, used for naming the `{model_name}.mar` '
        'file that will be created under `output_folder`. '
        'If None, `{Path(checkpoint_file).stem}` will be used.')
    parser.add_argument(
        '--model-version',
        type=str,
        default='1.0',
        help='Number used for versioning.')
    parser.add_argument(
        '-f',
        '--force',
        action='store_true',
        help='overwrite the existing `{model_name}.mar`')
    args = parser.parse_args()

    return args
| |
|
| |
|
if __name__ == '__main__':
    # Script entry point: parse the CLI arguments and build the archive.
    args = parse_args()

    # `package_model` is None when the optional `model_archiver` import at
    # the top of the file failed.
    if package_model is None:
        # Trailing space keeps the two adjacent literals from running
        # together as "required.Try:".
        raise ImportError('`torch-model-archiver` is required. '
                          'Try: pip install torch-model-archiver')

    mmdet2torchserve(args.config, args.checkpoint, args.output_folder,
                     args.model_name, args.model_version, args.force)
| |
|