import os
import shutil
import torch

def save_checkpoint(state, is_best, args, filename='default'):
    """Save a training checkpoint and, when it is the best so far, copy it aside."""
    if filename == 'default':
        filename = 'mcn_%s_batch%d' % (args.dataset, args.samples_per_gpu)
    print("=> saving checkpoint '{}'".format(filename))
    if not os.path.exists('./saved_models'):
        os.makedirs('./saved_models')
    checkpoint_name = './saved_models/%s_checkpoint.pth.tar' % filename
    best_name = './saved_models/%s_model_best.pth.tar' % filename
    torch.save(state, checkpoint_name)
    if is_best:
        print("=> saving best model '{}'".format(best_name))
        shutil.copyfile(checkpoint_name, best_name)
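
# A minimal usage sketch, not part of the original module: the keys of the
# `state` dict mirror what load_resume() below reads back. `epoch`, `best_iou`,
# `is_best`, `model`, `optimizer`, and `args` are assumed to come from the
# caller's training loop.
#
#   save_checkpoint({'epoch': epoch + 1,
#                    'state_dict': model.state_dict(),
#                    'best_iou': best_iou,
#                    'optimizer': optimizer.state_dict()},
#                   is_best, args)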

def load_pretrain(model, args, logging, rank):
    """Load pretrained weights from args.pretrain, keeping only keys the model knows."""
    if os.path.isfile(args.pretrain):
        checkpoint = torch.load(args.pretrain)
        pretrained_dict = checkpoint['state_dict']
        # Unwrap DataParallel/DistributedDataParallel wrappers if present.
        if hasattr(model, 'module'):
            model_dict = model.module.state_dict()
        else:
            model_dict = model.state_dict()
        # Keep only the pretrained weights whose keys exist in the current model.
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        assert len(pretrained_dict) != 0, 'no matching keys between checkpoint and model'
        model_dict.update(pretrained_dict)
        if hasattr(model, 'module'):
            model.module.load_state_dict(model_dict)
        else:
            model.load_state_dict(model_dict)
        print("=> loaded pretrain model at {}".format(args.pretrain))
        if rank == 0:
            logging.info("=> loaded pretrain model at {}".format(args.pretrain))
        del checkpoint  # dereference seems crucial for freeing GPU memory
        torch.cuda.empty_cache()
    else:
        print("=> no pretrained file found at '{}'".format(args.pretrain))
        if rank == 0:
            logging.info("=> no pretrained file found at '{}'".format(args.pretrain))
    return model
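
# Usage sketch (hypothetical caller; assumes args.pretrain points at a
# .pth.tar file with a 'state_dict' entry, e.g. one written by
# save_checkpoint above, and that rank comes from the distributed setup):
#
#   model = load_pretrain(model, args, logging, rank)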

def load_pretrain_ddp(model, args):
    """Variant of load_pretrain for runs that carry no logger or rank."""
    if os.path.isfile(args.pretrain):
        checkpoint = torch.load(args.pretrain)
        pretrained_dict = checkpoint['state_dict']
        # Compare keys against the unwrapped model so DDP-wrapped models match too.
        if hasattr(model, 'module'):
            model_dict = model.module.state_dict()
        else:
            model_dict = model.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        assert len(pretrained_dict) != 0, 'no matching keys between checkpoint and model'
        model_dict.update(pretrained_dict)
        if hasattr(model, 'module'):
            model.module.load_state_dict(model_dict)
        else:
            model.load_state_dict(model_dict)
        print("=> loaded pretrain model at {}".format(args.pretrain))
        del checkpoint  # dereference seems crucial for freeing GPU memory
        torch.cuda.empty_cache()
    else:
        print("=> no pretrained file found at '{}'".format(args.pretrain))
    return model
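
# Usage sketch for the DDP variant (same checkpoint format, no logger/rank):
#
#   model = load_pretrain_ddp(model, args)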

def load_resume(model, optimizer, args, logging, rank):
    """Resume training from args.resume: restore weights, optimizer state, epoch, and best IoU."""
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        if rank == 0:
            logging.info("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume, map_location='cpu')
        args.start_epoch = checkpoint['epoch']
        print("epoch: ", args.start_epoch)
        args.best_iou = checkpoint['best_iou']
        print("best iou: ", args.best_iou)
        state_dict = checkpoint['state_dict']
        if hasattr(model, 'module'):
            model_dict = model.module.state_dict()
        else:
            model_dict = model.state_dict()
        # Keep only keys the current model knows, then merge and load.
        new_state_dict = {k: v for k, v in state_dict.items() if k in model_dict}
        model_dict.update(new_state_dict)
        if hasattr(model, 'module'):
            model.module.load_state_dict(model_dict)
        else:
            model.load_state_dict(model_dict)
        optimizer.load_state_dict(checkpoint['optimizer'])
        del checkpoint  # dereference seems crucial for freeing GPU memory
        torch.cuda.empty_cache()
        print("=> checkpoint loaded successfully")
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))
        if rank == 0:
            logging.info("=> no checkpoint found at '{}'".format(args.resume))
    return model
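
# Resume sketch (hypothetical training script; args.epochs and
# train_one_epoch are assumed names, while args.start_epoch and
# args.best_iou are set by load_resume above):
#
#   model = load_resume(model, optimizer, args, logging, rank)
#   for epoch in range(args.start_epoch, args.epochs):
#       train_one_epoch(model, optimizer, epoch)  # caller-defined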