| import torch |
| import glob |
|
|
|
|
| |
# Every *.pth checkpoint found recursively beneath the current working
# directory; materialised eagerly from the lazy iterator form of glob.
pth_files = list(glob.iglob('**/*.pth', recursive=True))
|
|
| |
|
|
def overwirte_pth(pth_file, map_location='cpu'):
    """Strip training-only state from a checkpoint file, rewriting it in place.

    Loads ``pth_file`` with ``torch.load``. If the result is a dict that has
    an ``'optimizer'`` entry, that entry is removed, along with
    ``'param_schedulers'`` and ``'message_hub'`` when present, and the dict is
    saved back to the same path. Checkpoints without an ``'optimizer'`` key,
    or that are not dicts at all, are left untouched.

    The new file is written to a temporary sibling and then atomically swapped
    in, so an interrupted save cannot leave a truncated checkpoint behind.

    Args:
        pth_file: Path of the checkpoint to rewrite.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so
            that checkpoints containing GPU tensors can be processed on
            machines without a GPU.

    Returns:
        None. Progress is reported via ``print``.
    """
    import os
    import shutil
    import tempfile

    print(f'Overwriting {pth_file}')
    # NOTE(review): torch.load unpickles arbitrary objects, so only run this on
    # trusted checkpoints. Newer torch releases default to weights_only=True,
    # which may refuse checkpoints holding non-tensor objects (e.g. whatever
    # 'message_hub' stores). Confirm against the installed torch version.
    checkpoint = torch.load(pth_file, map_location=map_location)

    # A checkpoint may be a pickled module or tensor rather than a dict;
    # there is nothing to strip in that case.
    if not isinstance(checkpoint, dict):
        print(f'Checkpoint is a {type(checkpoint).__name__}, not a dict; skipping')
        return

    print(checkpoint.keys())

    # A missing optimizer is taken to mean the file was already stripped.
    if 'optimizer' not in checkpoint:
        print('No optimizer found in the checkpoint')
        return

    del checkpoint['optimizer']
    checkpoint.pop('param_schedulers', None)
    checkpoint.pop('message_hub', None)

    # Save next to the original (same filesystem, so os.replace is atomic),
    # then swap it in. The '.tmp' suffix keeps it out of any '*.pth' glob.
    directory = os.path.dirname(os.path.abspath(pth_file))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix='.pth.tmp')
    os.close(fd)
    try:
        torch.save(checkpoint, tmp_path)
        # mkstemp creates the file with mode 0600; keep the original's mode.
        shutil.copymode(pth_file, tmp_path)
        os.replace(tmp_path, pth_file)
    except BaseException:
        # Do not leave a stray partial temp file behind on failure or Ctrl-C.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f'Overwritten {pth_file} successfully')


# Correctly spelled alias; the original (misspelled) name is kept for callers.
overwrite_pth = overwirte_pth
| |
if __name__ == '__main__':
    # Only rewrite checkpoints when executed as a script: importing this
    # module must not silently modify every .pth file under the working dir.
    if not pth_files:
        print('No .pth files found')
    for pth_file in pth_files:
        overwirte_pth(pth_file)
|
|
|
|