| |
|
| | import argparse
|
| | import os
|
| | import os.path as osp
|
| |
|
| | from mmengine import dump, list_from_file, load
|
| | from mmengine.config import Config, DictAction
|
| | from mmengine.runner import Runner
|
| |
|
| |
|
def parse_args():
    """Parse the command line for MMAction2 feature extraction.

    As a side effect, the ``--local-rank`` value is exported as the
    ``LOCAL_RANK`` environment variable when that variable is not already
    set.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description='MMAction2 feature extraction')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('output_prefix', type=str, help='output prefix')
    parser.add_argument(
        '--video-list', type=str, default=None, help='video file list')
    parser.add_argument(
        '--video-root', type=str, default=None, help='video root directory')
    parser.add_argument(
        '--spatial-type',
        type=str,
        default='avg',
        choices=['avg', 'max', 'keep'],
        help='Pooling type in spatial dimension')
    parser.add_argument(
        '--temporal-type',
        type=str,
        default='avg',
        choices=['avg', 'max', 'keep'],
        help='Pooling type in temporal dimension')
    parser.add_argument(
        '--long-video-mode',
        action='store_true',
        help='Perform long video inference to get a feature list from a video')
    parser.add_argument(
        '--clip-interval',
        type=int,
        default=None,
        help='Interval between the centers of adjacent sampled clips, used '
        'for long video inference')
    parser.add_argument(
        '--frame-interval',
        type=int,
        default=None,
        help='Temporal interval of adjacent sampled frames, used for long '
        'video inference')
    parser.add_argument(
        '--multi-view',
        action='store_true',
        help='Perform multi view inference')
    parser.add_argument(
        '--dump-score',
        action='store_true',
        help='Dump predict scores rather than features')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # Both spellings are accepted for the local rank flag.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    args = parser.parse_args()
    # Do not override a LOCAL_RANK already provided by the environment.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
|
| |
|
| |
|
def _tsm_clip_len(test_pipeline):
    """Return the clip length a TSM backbone receives from the test pipeline.

    ``UntrimmedSampleFrames`` contributes its ``clip_len`` and
    ``SampleFrames`` its ``num_clips``; if several samplers are present the
    last one in the pipeline wins.

    Raises:
        ValueError: If the pipeline contains neither sampler.
    """
    clip_len = None
    for transform in test_pipeline:
        if transform.type == 'UntrimmedSampleFrames':
            clip_len = transform['clip_len']
        elif transform.type == 'SampleFrames':
            clip_len = transform['num_clips']
    if clip_len is None:
        raise ValueError(
            'cannot infer the TSM clip length: the test pipeline has no '
            'SampleFrames or UntrimmedSampleFrames transform')
    return clip_len


def _set_feature_head(cfg, args, test_pipeline):
    """Replace ``cfg.model.cls_head`` with a ``FeatureHead`` config."""
    backbone_type2name = dict(
        ResNet3dSlowFast='slowfast',
        MobileNetV2TSM='tsm',
        ResNetTSM='tsm',
    )
    if cfg.model.type == 'RecognizerGCN':
        backbone_name = 'gcn'
    else:
        # Backbones not listed above map to None.
        backbone_name = backbone_type2name.get(cfg.model.backbone.type)

    num_segments = None
    if backbone_name == 'tsm':
        clip_len = _tsm_clip_len(test_pipeline)
        num_segments = cfg.model.backbone.get('num_segments', 8)
        assert num_segments == clip_len, \
            f'num_segments and clip length must same for TSM, but got ' \
            f'num_segments {num_segments} clip_len {clip_len}'
        if cfg.model.test_cfg is not None:
            max_testing_views = cfg.model.test_cfg.get(
                'max_testing_views', num_segments)
            assert max_testing_views % num_segments == 0, \
                'tsm needs to infer with batchsize of multiple ' \
                'of num_segments.'

    def pooling(mode):
        # The CLI choice 'keep' is passed to FeatureHead as None
        # (presumably "no pooling along this dimension").
        return None if mode == 'keep' else mode

    cfg.model.cls_head = dict(
        type='FeatureHead',
        spatial_type=pooling(args.spatial_type),
        temporal_type=pooling(args.temporal_type),
        backbone_name=backbone_name,
        num_segments=num_segments)


def _set_single_view(cfg, test_pipeline):
    """Collapse multi-view testing into a single view per video."""
    cfg.model.cls_head['average_clips'] = 'score'
    is_recognizer3d = cfg.model.type == 'Recognizer3D'
    for transform in test_pipeline:
        if transform.type == 'SampleFrames':
            # Only 3D recognizers are reduced to a single sampled clip.
            if is_recognizer3d:
                transform['num_clips'] = 1
            transform['twice_sample'] = False
        if transform.type == 'TenCrop':
            transform.type = 'CenterCrop'


def _set_data_source(cfg, args):
    """Apply ``--video-list``/``--video-root`` and sync them back to args.

    ``--video-root`` is only applied to ``VideoDataset`` and
    ``RawframeDataset``; other dataset types keep their configured prefix.
    """
    dataset_cfg = cfg.test_dataloader.dataset
    if args.video_list is not None:
        dataset_cfg.ann_file = args.video_list
    if args.video_root is not None:
        if dataset_cfg.type == 'VideoDataset':
            dataset_cfg.data_prefix = dict(video=args.video_root)
        elif dataset_cfg.type == 'RawframeDataset':
            dataset_cfg.data_prefix = dict(img=args.video_root)
    # split_feats later reads args.video_list, so record the effective value.
    args.video_list = dataset_cfg.ann_file
    args.video_root = dataset_cfg.data_prefix


def _register_long_video_preprocessor():
    """Register ``LongVideoDataPreprocessor`` in mmaction's MODELS registry."""
    from mmaction.models import ActionDataPreprocessor
    from mmaction.registry import MODELS

    # force=True lets merge_args run more than once in a process instead of
    # failing because the name is already registered.
    @MODELS.register_module(force=True)
    class LongVideoDataPreprocessor(ActionDataPreprocessor):
        """DataPreprocessor for 2D recognizer to infer on long video.

        It stacks the clips into the batch dimension so that the feature of
        each clip is preserved (no averaging among clips).
        """

        def __init__(self, num_frames=8, **kwargs) -> None:
            super().__init__(**kwargs)
            self.num_frames = num_frames

        def preprocess(self, inputs, data_samples, training=False):
            batch_inputs, data_samples = super().preprocess(
                inputs, data_samples, training)
            # Regroup the first two dims into chunks of num_frames frames;
            # assumes dim 1 is a multiple of num_frames -- TODO confirm.
            nclip_batch_inputs = batch_inputs.view(
                (-1, self.num_frames) + batch_inputs.shape[2:])
            return nclip_batch_inputs, data_samples


def _set_long_video_mode(cfg, args, test_pipeline):
    """Switch sampling to ``UntrimmedSampleFrames`` for long-video inference.

    Raises:
        AssertionError: If ``--clip-interval`` (or a usable frame interval)
            is missing when a sampler has to be replaced.
        ValueError: For a 2D recognizer whose pipeline has no sampler to
            derive the clip length from.
    """
    # Keep one feature per clip instead of averaging them.
    cfg.model.cls_head['average_clips'] = None
    cfg.test_dataloader.batch_size = 1
    is_recognizer2d = cfg.model.type == 'Recognizer2D'

    clip_len = None
    frame_interval = args.frame_interval
    for idx, transform in enumerate(test_pipeline):
        if transform.type == 'UntrimmedSampleFrames':
            clip_len = transform['clip_len']
        elif transform.type in ['SampleFrames', 'UniformSample']:
            assert args.clip_interval is not None, \
                'please specify clip interval for long video inference'
            if is_recognizer2d:
                # For 2D recognizers the pipeline's num_clips becomes the
                # number of frames per untrimmed clip.
                clip_len = transform['num_clips']
            else:
                clip_len = transform['clip_len']
                if frame_interval is None:
                    # Fall back to the sampler's own frame_interval.
                    frame_interval = transform.get('frame_interval')
            assert frame_interval is not None, \
                'please specify frame interval for long video ' \
                'inference when use UniformSample or 2D Recognizer'
            test_pipeline[idx] = dict(
                type='UntrimmedSampleFrames',
                clip_len=clip_len,
                clip_interval=args.clip_interval,
                frame_interval=frame_interval)

    if cfg.test_dataloader.dataset.get('modality') == 'Flow':
        clip_len = 1

    if is_recognizer2d:
        if clip_len is None:
            raise ValueError(
                'cannot infer the clip length for long video inference: the '
                'test pipeline has no SampleFrames, UniformSample or '
                'UntrimmedSampleFrames transform')
        _register_long_video_preprocessor()
        preprocessor_cfg = cfg.model.data_preprocessor
        preprocessor_cfg.type = 'LongVideoDataPreprocessor'
        preprocessor_cfg['num_frames'] = clip_len


def merge_args(cfg, args):
    """Merge CLI arguments to config.

    The config is adapted in this order:

    1. Unless ``--dump-score``, the classification head is replaced by a
       ``FeatureHead`` built from the pooling options.
    2. Unless ``--multi-view``, testing is reduced to a single view.
    3. ``--video-list`` / ``--video-root`` are applied, and the effective
       values are written back to ``args.video_list`` / ``args.video_root``.
    4. With ``--long-video-mode``, samplers become ``UntrimmedSampleFrames``.
    5. Predictions are routed to a ``DumpResults`` evaluator writing
       ``<output_prefix>/total_feats.pkl`` (also stored as ``args.dump``).

    Args:
        cfg (Config): Test config; modified in place.
        args (argparse.Namespace): Parsed CLI arguments; ``video_list``,
            ``video_root`` and ``dump`` are (re)assigned.

    Returns:
        Config: The same ``cfg`` object.
    """
    test_pipeline = cfg.test_dataloader.dataset.pipeline

    if not args.dump_score:
        _set_feature_head(cfg, args, test_pipeline)
    if not args.multi_view:
        _set_single_view(cfg, test_pipeline)
    _set_data_source(cfg, args)
    if args.long_video_mode:
        _set_long_video_mode(cfg, args, test_pipeline)

    args.dump = osp.join(args.output_prefix, 'total_feats.pkl')
    dump_metric = dict(type='DumpResults', out_file_path=args.dump)
    cfg.test_evaluator = [dump_metric]
    cfg.work_dir = osp.join(args.output_prefix, 'work_dir')

    return cfg
|
| |
|
| |
|
def split_feats(args):
    """Split the combined dump file into one pickle file per video.

    Loads ``args.dump`` and pairs its entries, in order, with the names in
    ``args.video_list``. Each entry is written to
    ``<args.output_prefix>/<video_name>.pkl``, and then the combined dump is
    deleted. With ``args.dump_score`` each entry is reduced to
    ``sample['pred_scores']['item']`` first.

    Args:
        args (argparse.Namespace): Needs ``dump``, ``dump_score``,
            ``video_list`` and ``output_prefix``.

    Raises:
        ValueError: If the number of dumped results differs from the number
            of lines in the video list. The combined dump is kept in that
            case so nothing is lost.
    """
    total_feats = load(args.dump)
    if args.dump_score:
        total_feats = [sample['pred_scores']['item'] for sample in total_feats]

    # The video name is the first space-separated field of each line; any
    # further fields are ignored.
    video_names = [
        line.split(' ')[0] for line in list_from_file(args.video_list)
    ]
    # zip() would silently drop the surplus on either side; refuse instead.
    if len(video_names) != len(total_feats):
        raise ValueError(
            f'{args.dump} holds {len(total_feats)} results but '
            f'{args.video_list} lists {len(video_names)} videos; '
            'refusing to split (the combined dump is kept)')

    for video_name, feature in zip(video_names, total_feats):
        dump(feature, osp.join(args.output_prefix, video_name + '.pkl'))
    os.remove(args.dump)
|
| |
|
| |
|
def main():
    """Extract features: build the config, run testing, split the dump."""
    from mmengine.dist import is_main_process

    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    cfg = merge_args(cfg, args)
    cfg.launcher = args.launcher
    cfg.load_from = args.checkpoint

    runner = Runner.from_cfg(cfg)
    runner.test()

    # Every rank returns from runner.test(), but there is a single combined
    # dump file; letting all ranks split it would race on load() and on
    # os.remove(args.dump). Only the main process does it.
    if is_main_process():
        split_feats(args)
|
| |
|
| |
|
# Run the extraction only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|
| |
|