"""Merge a chunked model checkpoint into a single checkpoint file.

Reads ``<prefix>_model_chunk0.pth`` ... ``<prefix>_model_chunk{N-1}.pth`` plus
``<prefix>_others.pth`` and writes ``<prefix>.pth`` containing
``{"model": <union of all chunk dicts>, **others}``.
"""

import torch


def _merge_disjoint(parts):
    """Return the union of the mappings in *parts*, refusing duplicate keys.

    A plain ``{**a, **b}`` lets a later chunk silently overwrite an entry of an
    earlier chunk, which would yield a checkpoint with missing weights.

    Raises:
        ValueError: if a key appears in more than one mapping.
    """
    merged = {}
    for index, part in enumerate(parts):
        overlap = merged.keys() & part.keys()
        if overlap:
            shown = ", ".join(sorted(map(str, overlap))[:5])
            raise ValueError(
                f"model chunk {index} repeats {len(overlap)} key(s) from "
                f"earlier chunks (e.g. {shown})"
            )
        merged.update(part)
    return merged


def merge_checkpoint(prefix="./checkpoints_00335001", num_chunks=2,
                     output_path=None, map_location="cpu"):
    """Combine model chunk files and the "others" file into one checkpoint.

    Args:
        prefix: Common path prefix of the input files.
        num_chunks: Number of ``<prefix>_model_chunk{i}.pth`` files to merge.
        output_path: Destination file; defaults to ``<prefix>.pth``.
        map_location: Passed to ``torch.load``. The default ``"cpu"`` lets the
            merge run without the GPUs the checkpoint was saved from.

    Returns:
        The dict that was saved: ``{"model": merged_model, **others}``.

    Raises:
        ValueError: if ``num_chunks < 1``, if two chunks share a key, or if
            the others file already contains a ``"model"`` entry (which would
            otherwise silently replace the merged model).
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")

    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if the
    # others file pickles arbitrary objects, loading it may fail -- confirm.
    chunk_paths = [f"{prefix}_model_chunk{i}.pth" for i in range(num_chunks)]
    model = _merge_disjoint(
        torch.load(path, map_location=map_location) for path in chunk_paths
    )
    others = torch.load(f"{prefix}_others.pth", map_location=map_location)
    if "model" in others:
        raise ValueError(
            f"{prefix}_others.pth already has a 'model' entry; refusing to "
            "overwrite the merged model with it"
        )

    result = {"model": model, **others}
    if output_path is None:
        output_path = f"{prefix}.pth"
    torch.save(result, output_path)
    return result


if __name__ == "__main__":
    merge_checkpoint()