import os

import torch

from .comm import get_rank, synchronize


def save_checkpoint(checkpoint, model, optimizer=None, best_metric=None, epoch=None):
    """Save model state (plus optional optimizer/metric/epoch) to `checkpoint`."""
    # Unwrap DDP so the saved state_dict keys carry no 'module.' prefix.
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module
    # Only rank 0 writes the file; every rank hits the barrier below so no
    # process races ahead before the checkpoint exists on disk.
    if get_rank() == 0:
        checkpoint_dir = os.path.dirname(checkpoint)
        # Guard against a bare filename: dirname is then '' and
        # os.makedirs('') would raise.
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        infos = dict()
        infos['model_param'] = model.state_dict()
        if optimizer is not None:
            infos['opt_param'] = optimizer.state_dict()
        if best_metric is not None:
            infos['best_metric'] = best_metric
        if epoch is not None:
            infos['epoch'] = epoch
        torch.save(infos, checkpoint)
    synchronize()
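
# For reference, `synchronize` in `.comm` is commonly a thin wrapper over a
# distributed barrier. A minimal sketch under that assumption (not
# necessarily this repo's implementation):
#
#   import torch.distributed as dist
#
#   def synchronize():
#       if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
#           dist.barrier()
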
def load_checkpoint(checkpoint, model, optimizer=None):
    """Load state from `checkpoint` into `model` (and `optimizer` if given).

    Returns (best_metric, epoch); either is None if absent from the file.
    """
    # Unwrap DDP so state_dict keys line up with those written by
    # save_checkpoint.
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module
    # map_location='cpu' avoids deserializing tensors onto the GPU that
    # saved them; the caller moves the model to its own device afterwards.
    checkpoint = torch.load(checkpoint, map_location='cpu')
    # strict=False tolerates missing/unexpected keys (e.g. a swapped head).
    model.load_state_dict(checkpoint['model_param'], strict=False)
    if (optimizer is not None) and ('opt_param' in checkpoint):
        optimizer.load_state_dict(checkpoint['opt_param'])
    best_metric = checkpoint.get('best_metric')
    epoch = checkpoint.get('epoch')
    return best_metric, epoch
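
# A minimal usage sketch (hypothetical training loop; build_model,
# train_one_epoch, evaluate, and num_epochs are placeholders, not part
# of this module):
#
#   model = torch.nn.parallel.DistributedDataParallel(build_model())
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   best_metric, last_epoch = load_checkpoint('ckpt/last.pth', model, optimizer)
#   start_epoch = 0 if last_epoch is None else last_epoch + 1
#   for epoch in range(start_epoch, num_epochs):
#       train_one_epoch(model, optimizer)
#       metric = evaluate(model)
#       if best_metric is None or metric > best_metric:
#           best_metric = metric
#           save_checkpoint('ckpt/best.pth', model, optimizer, best_metric, epoch)
#       save_checkpoint('ckpt/last.pth', model, optimizer, best_metric, epoch)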