# --------------------------------------------------------
# RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
# Github source: https://github.com/DingXiaoH/RepVGG
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
"""Command-line tool that converts a training-time RepVGG(plus) checkpoint.

Usage::

    python convert.py LOAD SAVE [-a ARCH]

The training-time model named by ``--arch`` is built and loaded from LOAD.
It is then converted to its deploy form and the resulting weights are
written to SAVE.
"""
import argparse
import os

import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed

from repvggplus import create_RepVGGplus_by_name, repvgg_model_convert

parser = argparse.ArgumentParser(description='RepVGG(plus) Conversion')
parser.add_argument('load', metavar='LOAD', help='path to the weights file')
parser.add_argument('save', metavar='SAVE',
                    help='path to write the converted (deploy) weights file')
parser.add_argument('-a', '--arch', metavar='ARCH', default='RepVGG-A0')

# Key prefix typically added by torch.nn.DataParallel / DistributedDataParallel.
_WRAPPER_PREFIX = 'module.'


def _strip_prefix(key, prefix=_WRAPPER_PREFIX):
    """Return *key* with every leading occurrence of *prefix* removed.

    Only the start of the key is touched. A parameter whose name merely
    contains *prefix* somewhere in the middle is left intact, whereas a plain
    ``str.replace`` would corrupt it. Repeated prefixes (e.g.
    ``module.module.`` from a doubly wrapped model) are all removed.
    """
    if not prefix:
        # Guard against an infinite loop on an empty prefix.
        return key
    while key.startswith(prefix):
        key = key[len(prefix):]
    return key


def _load_state_dict(path):
    """Load the checkpoint at *path* and return a state dict with clean keys.

    The checkpoint may be a bare state dict, or a dict wrapping it under a
    ``'state_dict'`` or ``'model'`` key (``'state_dict'`` is checked first).
    """
    print("=> loading checkpoint '{}'".format(path))
    # Load onto the CPU so checkpoints saved from GPU tensors also work on
    # CPU-only machines; load_state_dict copies the values into the model.
    checkpoint = torch.load(path, map_location='cpu')
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    elif 'model' in checkpoint:
        checkpoint = checkpoint['model']
    ckpt = {_strip_prefix(k): v for k, v in checkpoint.items()}  # strip the names
    print(ckpt.keys())
    return ckpt


def convert(argv=None):
    """Convert a training-time checkpoint into deploy-time weights and save them.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            parses ``sys.argv``, so the script's CLI behaviour is unchanged.

    Exits via ``parser.error`` (status 2) if the LOAD path is not a file.
    Otherwise a randomly initialised model would be "converted" and saved.
    """
    args = parser.parse_args(argv)

    if not os.path.isfile(args.load):
        parser.error("=> no checkpoint found at '{}'".format(args.load))

    train_model = create_RepVGGplus_by_name(args.arch, deploy=False)
    train_model.load_state_dict(_load_state_dict(args.load))

    if 'plus' in args.arch:
        # RepVGGplus: the model is switched to deploy mode in place, then its
        # state dict is saved here.
        train_model.switch_repvggplus_to_deploy()
        torch.save(train_model.state_dict(), args.save)
    else:
        # Plain RepVGG: the repvggplus helper performs the conversion and is
        # given the output path (presumably it writes the file itself).
        repvgg_model_convert(train_model, save_path=args.save)


if __name__ == '__main__':
    convert()