File size: 1,442 Bytes
cb0ad2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import os
import torch
from .comm import get_rank, synchronize
def save_checkpoint(checkpoint, model, optimizer=None, best_metric=None, epoch=None):
    """Save model (and optionally optimizer/metric/epoch) state to *checkpoint*.

    Only rank 0 writes the file; every rank then waits at ``synchronize()`` so
    no process races ahead before the checkpoint exists on disk.

    Args:
        checkpoint: Destination file path for ``torch.save``.
        model: Module to snapshot; if wrapped in DistributedDataParallel the
            underlying ``.module`` is unwrapped so keys have no ``module.`` prefix.
        optimizer: Optional optimizer whose state is stored under ``'opt_param'``.
        best_metric: Optional metric value stored under ``'best_metric'``.
        epoch: Optional epoch number stored under ``'epoch'``.
    """
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module
    if get_rank() == 0:
        ckpt_dir = os.path.dirname(checkpoint)
        # exist_ok=True avoids the exists()/makedirs() race; the guard avoids
        # os.makedirs('') (FileNotFoundError) when saving to a bare filename.
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        infos = {'model_param': model.state_dict()}
        if optimizer is not None:
            infos['opt_param'] = optimizer.state_dict()
        if best_metric is not None:
            infos['best_metric'] = best_metric
        if epoch is not None:
            infos['epoch'] = epoch
        torch.save(infos, checkpoint)
    synchronize()
def load_checkpoint(checkpoint, model, optimizer=None):
    """Restore model (and optionally optimizer) state from *checkpoint*.

    The file is loaded onto CPU; callers move the model to the target device.
    Model weights are loaded with ``strict=False``, so missing/unexpected keys
    are tolerated. Optimizer state is restored only when both an optimizer is
    given and the file contains ``'opt_param'``.

    Args:
        checkpoint: Path to a file produced by ``save_checkpoint``.
        model: Module to load into; a DistributedDataParallel wrapper is
            unwrapped to its ``.module`` first.
        optimizer: Optional optimizer to restore.

    Returns:
        Tuple ``(best_metric, epoch)``; each is ``None`` when the checkpoint
        does not carry that entry.
    """
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module
    state = torch.load(checkpoint, map_location='cpu')
    model.load_state_dict(state['model_param'], strict=False)
    if optimizer is not None and 'opt_param' in state:
        optimizer.load_state_dict(state['opt_param'])
    # dict.get returns None for absent keys, matching the original fallbacks.
    return state.get('best_metric'), state.get('epoch')