| | |
| | import argparse |
| | import subprocess |
| |
|
| | import torch |
| | from mmengine.logging import print_log |
| |
|
| |
|
| | def parse_args(): |
| | parser = argparse.ArgumentParser( |
| | description='Process a checkpoint to be published') |
| | parser.add_argument('in_file', help='input checkpoint filename') |
| | parser.add_argument('out_file', help='output checkpoint filename') |
| | parser.add_argument( |
| | '--save-keys', |
| | nargs='+', |
| | type=str, |
| | default=['meta', 'state_dict'], |
| | help='keys to save in the published checkpoint') |
| | args = parser.parse_args() |
| | return args |
| |
|
| |
|
| | def process_checkpoint(in_file, out_file, save_keys=['meta', 'state_dict']): |
| | checkpoint = torch.load(in_file, map_location='cpu') |
| |
|
| | |
| | ckpt_keys = list(checkpoint.keys()) |
| | for k in ckpt_keys: |
| | if k not in save_keys: |
| | print_log( |
| | f'Key `{k}` will be removed because it is not in ' |
| | f'save_keys. If you want to keep it, ' |
| | f'please set --save-keys.', |
| | logger='current') |
| | checkpoint.pop(k, None) |
| |
|
| | |
| | |
| | if torch.__version__ >= '1.6': |
| | torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False) |
| | else: |
| | torch.save(checkpoint, out_file) |
| | sha = subprocess.check_output(['sha256sum', out_file]).decode() |
| | if out_file.endswith('.pth'): |
| | out_file_name = out_file[:-4] |
| | else: |
| | out_file_name = out_file |
| | final_file = out_file_name + f'-{sha[:8]}.pth' |
| | subprocess.Popen(['mv', out_file, final_file]) |
| | print_log( |
| | f'The published model is saved at {final_file}.', logger='current') |
| |
|
| |
|
| | def main(): |
| | args = parse_args() |
| | process_checkpoint(args.in_file, args.out_file, args.save_keys) |
| |
|
| |
|
| | if __name__ == '__main__': |
| | main() |
| |
|