# Deepfake-Detector / tools/convert/convert_recognizer.py
# (Hugging Face upload residue: "Upload 1039 files", commit d670799, verified)
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from mmengine.config import Config
def convert(config_path, output_config_path):
    """Convert an action recognizer config from OpenMMLab v1.0 to v2.0.

    Loads the v1.0 config, rewrites its dataset, model, pipeline,
    dataloader, evaluation, optimizer, schedule and runtime sections into
    the MMEngine-style v2.0 layout, then dumps the converted config.

    Args:
        config_path (str): Path of the v1.0 config file to read.
        output_config_path (str): Path the converted config is written to.

    Raises:
        ValueError: If ``lr_config.policy`` is neither ``'step'`` nor
            ``'CosineAnnealing'``.
    """
    print('start convert')
    cfg = Config.fromfile(config_path)
    origin_dataset_type = cfg.dataset_type

    # dataset: force VideoDataset and point at the kinetics400 lists.
    if origin_dataset_type != 'VideoDataset':
        cfg.dataset_type = 'VideoDataset'
        cfg.data_root = 'data/kinetics400/rawframes_train'
        cfg.data_root_val = 'data/kinetics400/rawframes_val'
        cfg.ann_file_train = \
            'data/kinetics400/kinetics400_train_list_rawframes.txt'
        cfg.ann_file_val = \
            'data/kinetics400/kinetics400_val_list_rawframes.txt'
        cfg.ann_file_test = \
            'data/kinetics400/kinetics400_val_list_rawframes.txt'

    # model: fold img_norm_cfg plus the pipeline's FormatShape input_format
    # into a data_preprocessor, where v2.0 performs input normalization.
    preprocess_cfg = cfg.img_norm_cfg
    formatshape = None
    for trans in cfg.train_pipeline:
        if trans.type == 'FormatShape':
            formatshape = trans.input_format
    preprocess_cfg['input_format'] = formatshape
    cfg.preprocess_cfg = preprocess_cfg
    cfg.model.data_preprocessor = dict(
        type='ActionDataPreprocessor', **dict(preprocess_cfg))
    cfg.pop('img_norm_cfg')
    # average_clips moved from model.test_cfg to the cls_head in v2.0.
    if (cfg.model.test_cfg is not None) and ('average_clips'
                                             in cfg.model.test_cfg):
        cfg.model.cls_head.average_clips = cfg.model.test_cfg.average_clips
        cfg.model.test_cfg.pop('average_clips')
        if len(cfg.model.test_cfg) == 0:
            cfg.model.test_cfg = None

    # pipeline: drop v1.0-only transforms and end with PackActionInputs.
    pipelines = [cfg.train_pipeline, cfg.val_pipeline, cfg.test_pipeline]
    if origin_dataset_type == 'VideoDataset':
        for pipeline in pipelines:
            new_pipeline = [
                trans for trans in pipeline
                if trans.type not in ['Normalize', 'Collect', 'ToTensor']
            ]
            new_pipeline.append(dict(type='PackActionInputs'))
            # Fix: list.clear() returns None, so the original chained
            # `pipeline.clear().extend(new_pipeline)` raised AttributeError.
            pipeline.clear()
            pipeline.extend(new_pipeline)
    elif origin_dataset_type == 'RawframeDataset':
        # Rawframe decoding is replaced by Decord video decoding.
        for pipeline in pipelines:
            new_pipeline = [
                trans for trans in pipeline if trans.type not in
                ['RawFrameDecode', 'Normalize', 'Collect', 'ToTensor']
            ]
            new_pipeline.insert(0, dict(type='DecordInit'))
            new_pipeline.insert(2, dict(type='DecordDecode'))
            new_pipeline.append(dict(type='PackActionInputs'))
            pipeline.clear()
            pipeline.extend(new_pipeline)

    # dataloader: expand cfg.data into three explicit v2.0 dataloaders.
    cfg.data.train.update(
        dict(
            type=cfg.dataset_type,
            ann_file=cfg.ann_file_train,
            data_prefix=dict(video=cfg.data_root),
            pipeline=cfg.train_pipeline))
    cfg.train_dataloader = dict(
        batch_size=cfg.data.videos_per_gpu,
        num_workers=cfg.data.workers_per_gpu,
        persistent_workers=True,
        sampler=dict(type='DefaultSampler', shuffle=True),
        dataset=cfg.data.train)
    val_batchsize = cfg.data.val_dataloader.videos_per_gpu \
        if 'val_dataloader' in cfg.data else cfg.data.videos_per_gpu
    cfg.data.val.update(
        dict(
            type=cfg.dataset_type,
            ann_file=cfg.ann_file_val,
            data_prefix=dict(video=cfg.data_root_val),
            pipeline=cfg.val_pipeline))
    # NOTE(review): shuffle=True on the val/test samplers mirrors the
    # original code, but evaluation loaders are normally unshuffled —
    # confirm this is intended.
    cfg.val_dataloader = dict(
        batch_size=val_batchsize,
        num_workers=cfg.data.workers_per_gpu,
        persistent_workers=True,
        sampler=dict(type='DefaultSampler', shuffle=True),
        dataset=cfg.data.val)
    test_batchsize = cfg.data.test_dataloader.videos_per_gpu \
        if 'test_dataloader' in cfg.data else cfg.data.videos_per_gpu
    cfg.data.test.update(
        dict(
            type=cfg.dataset_type,
            ann_file=cfg.ann_file_test,
            data_prefix=dict(video=cfg.data_root_val),
            pipeline=cfg.test_pipeline))
    cfg.test_dataloader = dict(
        batch_size=test_batchsize,
        num_workers=cfg.data.workers_per_gpu,
        persistent_workers=True,
        sampler=dict(type='DefaultSampler', shuffle=True),
        dataset=cfg.data.test)
    cfg.pop('data')

    # eval
    cfg.val_evaluator = dict(type='AccMetric')
    cfg.test_evaluator = cfg.val_evaluator
    cfg.val_cfg = dict(interval=cfg.evaluation.interval)
    cfg.test_cfg = dict()
    cfg.pop('evaluation')

    # optimizer: split v1.0 optimizer/optimizer_config into a wrapper dict;
    # paramwise_cfg/constructor stay at wrapper level, grad_clip is renamed.
    # NOTE(review): MMEngine v2.0 usually names this key `optim_wrapper`;
    # the original wrote `optimizer_wrapper` and that is preserved here —
    # confirm against the consuming framework.
    optimizer_wrapper = dict(optimizer=dict())
    for k, v in cfg.optimizer.items():
        if k not in ['paramwise_cfg', 'constructor']:
            optimizer_wrapper['optimizer'].update({k: v})
        else:
            optimizer_wrapper.update({k: v})
    for k, v in cfg.optimizer_config.items():
        if k == 'grad_clip':
            k = 'clip_grad'
        optimizer_wrapper.update({k: v})
    cfg.optimizer_wrapper = optimizer_wrapper
    cfg.pop('optimizer')
    cfg.pop('optimizer_config')

    # train_cfg
    cfg.train_cfg = dict(by_epoch=True, max_epochs=cfg.total_epochs)
    cfg.pop('total_epochs')

    # schedule: translate lr_config into param_scheduler entries.
    cfg.param_scheduler = []
    warmup_epoch = 0
    if 'warmup' in cfg.lr_config:
        warmup_ratio = 0.1 \
            if 'warmup_ratio' not in cfg.lr_config \
            else cfg.lr_config.warmup_ratio
        warmup_epoch = cfg.lr_config.warmup_iters
        cfg.param_scheduler.append(
            dict(
                type='LinearLR',
                # Fix: keyword was misspelled `bengin` in the original,
                # which put a bogus key in the scheduler dict.
                begin=0,
                start_factor=warmup_ratio,
                end=cfg.lr_config.warmup_iters,
                by_epoch=cfg.lr_config.warmup_by_epoch))
    if cfg.lr_config.policy == 'step':
        cfg.param_scheduler.append(
            dict(
                type='MultiStepLR',
                milestones=cfg.lr_config.step,
                by_epoch=cfg.train_cfg.by_epoch,
                begin=0,
                end=cfg.train_cfg.max_epochs,
                gamma=0.1
                if 'gamma' not in cfg.lr_config else cfg.lr_config.gamma))
    elif cfg.lr_config.policy == 'CosineAnnealing':
        cfg.param_scheduler.append(
            dict(
                type='CosineAnnealingLR',
                eta_min=cfg.lr_config.min_lr,
                by_epoch=cfg.train_cfg.by_epoch,
                begin=warmup_epoch,
                end=cfg.train_cfg.max_epochs,
                T_max=cfg.train_cfg.max_epochs - warmup_epoch))
    else:
        raise ValueError(f'Not support convert {cfg.lr_config.policy}')
    cfg.pop('lr_config')

    # runtime
    cfg.default_scope = 'mmaction'
    cfg.default_hooks = dict(
        runtime_info=dict(type='RuntimeInfoHook'),
        timer=dict(type='IterTimerHook'),
        logger=dict(type='LoggerHook', interval=20),
        param_scheduler=dict(type='ParamSchedulerHook'),
        checkpoint=dict(type='CheckpointHook', **cfg.checkpoint_config),
        sampler_seed=dict(type='DistSamplerSeedHook'),
    )
    cfg.env_cfg = dict(
        cudnn_benchmark=False,
        mp_cfg=dict(
            mp_start_method=cfg.mp_start_method,
            opencv_num_threads=cfg.opencv_num_threads),
        dist_cfg=dict(**cfg.dist_params),
    )
    cfg.log_level = 'INFO'
    cfg.load_from = None
    cfg.resume = False
    # Drop the v1.0-only keys whose content was migrated above.
    cfg.pop('workflow')
    cfg.pop('mp_start_method')
    cfg.pop('opencv_num_threads')
    cfg.pop('log_config')
    cfg.pop('dist_params')
    cfg.pop('checkpoint_config')
    cfg.pop('work_dir')
    cfg.dump(output_config_path)
    print('Successful')
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        # Fix: the original built the description with a backslash
        # continuation inside the string literal, injecting a run of
        # stray spaces into --help output; also 'OpenMMLAb' typo.
        description='Convert an action recognizer config file '
        'from OpenMMLab framework v1.0 to v2.0')
    parser.add_argument('config', help='The input config file path')
    # Fix: help text wrongly duplicated 'The config file path'.
    parser.add_argument(
        'output_config', help='The output config file path')
    args = parser.parse_args()
    convert(args.config, args.output_config)