# ViTP / ckpts / reduct_pth.py
# Uploaded by GreatBird ("Upload 125 files", commit af20dda, verified).
import glob
import os

import torch
# Collect every .pth file under the current working directory, recursing
# into sub-directories ('**' with recursive=True also matches files that
# sit directly in the starting directory).
pth_files = list(glob.iglob('**/*.pth', recursive=True))
# Overwrite each .pth file in place without the 'optimizer' key (and without 'param_schedulers' / 'message_hub' when present) to shrink it.
def overwirte_pth(pth_file, keys_to_remove=('optimizer', 'param_schedulers', 'message_hub')):
    """Strip training-only state from a checkpoint file, rewriting it in place.

    The checkpoint is loaded onto the CPU. Every key in ``keys_to_remove``
    that is present is deleted, and the result is written back to
    ``pth_file``. The new contents are first saved to ``<pth_file>.tmp``
    and then moved over the original with ``os.replace``. This way an
    interrupted save cannot leave a truncated checkpoint behind.

    Files that do not load as a dict, or that contain none of the keys in
    ``keys_to_remove``, are left untouched.

    Args:
        pth_file: Path of the ``.pth`` file to rewrite.
        keys_to_remove: Top-level checkpoint keys to delete. Defaults to
            ``('optimizer', 'param_schedulers', 'message_hub')``.
    """
    print(f'Overwriting {pth_file}')
    # map_location='cpu' lets checkpoints that hold GPU tensors load on
    # machines without the same devices, and avoids allocating GPU memory.
    # NOTE(review): torch.load unpickles the file, so only run this on
    # trusted checkpoints. On torch versions where weights_only defaults to
    # True, checkpoints carrying arbitrary Python objects may need
    # weights_only=False -- confirm for your checkpoints.
    checkpoint = torch.load(pth_file, map_location='cpu')
    if not isinstance(checkpoint, dict):
        # e.g. a pickled nn.Module or a bare tensor: nothing to strip.
        print(f'Skipping {pth_file}: not a dict checkpoint ({type(checkpoint).__name__})')
        return
    # print the keys of the checkpoint, e.g. dict_keys(['meta', 'state_dict', 'optimizer'])
    print(checkpoint.keys())
    present = [key for key in keys_to_remove if key in checkpoint]
    if not present:
        print(f'No removable keys {list(keys_to_remove)} found in the checkpoint')
        return
    for key in present:
        del checkpoint[key]
    # Save to a sibling temp file first and then swap it in. A crash
    # mid-save then never leaves the original checkpoint half-written.
    tmp_path = f'{pth_file}.tmp'
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, pth_file)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f'Overwritten {pth_file} successfully')
# Rewrite every checkpoint found by the glob above, one file at a time.
for checkpoint_path in pth_files:
    overwirte_pth(checkpoint_path)