"""Reduce a saved checkpoint to just its model ``state_dict``.

When run as a script, the checkpoint at ``CHECKPOINT_PATH`` is loaded onto the
CPU. If it is a dict with a ``'state_dict'`` entry, only that entry is kept;
otherwise the loaded object is kept as-is. The kept object is then saved back
to the same path, replacing the original file. Everything else stored in the
checkpoint is discarded.
"""
import contextlib
import os

import torch

# Presumably a PyTorch Lightning checkpoint (it lives under lightning_logs/).
CHECKPOINT_PATH = '/workspace/train-wefadoor-master/anydoor/lightning_logs/version_4/checkpoints/300k_u.ckpt'


def extract_state_dict(checkpoint):
    """Return ``checkpoint['state_dict']`` if present, else ``checkpoint`` itself.

    Objects that are not dicts (or dicts without a ``'state_dict'`` key) are
    returned unchanged, rather than raising ``TypeError`` from the ``in`` test
    on a non-container.
    """
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        return checkpoint['state_dict']
    return checkpoint


def strip_checkpoint(src, dst=None):
    """Load ``src`` and save only its state dict to ``dst``.

    Args:
        src: Path of the checkpoint to read.
        dst: Path to write the stripped checkpoint to. Defaults to ``src``,
            i.e. the checkpoint is replaced in place.

    Returns:
        The object that was saved (the extracted state dict).

    The output is first written to ``<dst>.tmp`` and then moved over ``dst``
    with ``os.replace``. An interrupted or failed save therefore cannot leave a
    truncated file where the original checkpoint used to be.
    """
    dst = src if dst is None else dst
    # map_location='cpu' loads tensors onto the CPU, so no GPU is required.
    # NOTE(review): newer torch releases default torch.load to weights_only=True,
    # which may reject checkpoints containing arbitrary pickled objects; passing
    # weights_only=False is only appropriate for trusted files — confirm for
    # the torch version in use.
    checkpoint = torch.load(src, map_location='cpu')
    state_dict = extract_state_dict(checkpoint)

    tmp_path = os.fspath(dst) + '.tmp'
    try:
        torch.save(state_dict, tmp_path)
        # Atomic swap: dst holds either the old file or the complete new one.
        os.replace(tmp_path, dst)
    except BaseException:
        # Drop the partial temporary file; the file at dst is left untouched.
        with contextlib.suppress(FileNotFoundError):
            os.remove(tmp_path)
        raise
    return state_dict


if __name__ == "__main__":
    strip_checkpoint(CHECKPOINT_PATH)