# Copyright (c) OpenMMLab. All rights reserved. from __future__ import division import argparse import os import torch from mmcv import Config, DictAction from mmcv.runner.checkpoint import save_checkpoint from mmdet import __version__ as mmdet_version from mmdet3d import __version__ as mmdet3d_version from mmdet3d.models import build_model try: from mmcv.cnn import get_model_complexity_info except ImportError: raise ImportError('Please upgrade mmcv to >0.6.2') def parse_args(): parser = argparse.ArgumentParser(description='Train a detector') parser.add_argument('config', help='train config file path') parser.add_argument( '--shape', type=int, nargs='+', default=[40000, 4], help='input point cloud size') parser.add_argument( '--modality', type=str, default='point', choices=['point', 'image', 'multi', 'multiview'], help='input data modality') parser.add_argument( '--cfg-options', nargs='+', action=DictAction, help='override some settings in the used config, the key-value pair ' 'in xxx=yyy format will be merged into config file. If the value to ' 'be overwritten is a list, piit should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 'Note that the quotation marks are necessary and that no white space ' 'is allowed.') args = parser.parse_args() return args def main(): args = parse_args() if args.modality == 'point': assert len(args.shape) == 2, 'invalid input shape' input_shape = tuple(args.shape) elif args.modality == 'image': if len(args.shape) == 1: input_shape = (3, args.shape[0], args.shape[0]) elif len(args.shape) == 2: input_shape = (3, ) + tuple(args.shape) else: raise ValueError('invalid input shape') elif args.modality == 'multi': raise NotImplementedError( 'FLOPs counter is currently not supported for models with ' 'multi-modality input') elif args.modality == 'multiview': input_shape = (1, 6, 3, 928, 1600) cfg = Config.fromfile(args.config) if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) # import modules from string list. if cfg.get('custom_imports', None): from mmcv.utils import import_modules_from_strings import_modules_from_strings(**cfg['custom_imports']) # import modules from plguin/xx, registry will be updated if hasattr(cfg, 'plugin'): if cfg.plugin: import importlib if hasattr(cfg, 'plugin_dir'): plugin_dir = cfg.plugin_dir _module_dir = os.path.dirname(plugin_dir) _module_dir = _module_dir.split('/') _module_path = _module_dir[0] for m in _module_dir[1:]: _module_path = _module_path + '.' + m # print(_module_path) plg_lib = importlib.import_module(_module_path) else: # import dir is the dirpath for the config file _module_dir = os.path.dirname(args.config) _module_dir = _module_dir.split('/') _module_path = _module_dir[0] for m in _module_dir[1:]: _module_path = _module_path + '.' + m # print(_module_path) plg_lib = importlib.import_module(_module_path) try: from mmdet3d_plugin.uniad.apis.train import custom_train_model except: from mmdet3d_plugin.e2e.apis.train import custom_train_model # set cudnn_benchmark if cfg.get('cudnn_benchmark', False): torch.backends.cudnn.benchmark = True model = build_model( cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg')) if torch.cuda.is_available(): model.cuda() model.eval() if hasattr(model, 'forward_dummy'): model.forward = model.forward_dummy else: raise NotImplementedError( 'FLOPs counter is currently not supported for {}'.format( model.__class__.__name__)) flops, params = get_model_complexity_info(model, input_shape) split_line = '=' * 30 print(f'{split_line}\nInput shape: {input_shape}\n' f'Flops: {flops}\nParams: {params}\n{split_line}') print('!!!Please be cautious if you use the results in papers. ' 'You may need to check if all ops are supported and verify that the ' 'flops computation is correct.') # save models for profiling # save_path = '/home/xweng/models/cosmos_paradrive.pth' save_path = '/lustre/fsw/portfolios/nvr/users/xweng/tmp/cosmos_paradrive.pth' save_checkpoint(model, save_path) if __name__ == '__main__': main()