File size: 4,132 Bytes
3dcfb26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import shutil
import torch

def save_checkpoint(state, is_best, args, filename='default'):
    """Serialize ``state`` under ./saved_models and optionally keep a best copy.

    Args:
        state: object handed to ``torch.save`` (presumably a dict of training
            state — confirm against callers).
        is_best: when true, the freshly written checkpoint file is also copied
            to ``./saved_models/<filename>_model_best.pth.tar``.
        args: namespace providing ``dataset`` and ``samples_per_gpu``; only
            read when ``filename`` is ``'default'``.
        filename: stem for the output files. The sentinel ``'default'``
            derives ``mcn_<dataset>_batch<samples_per_gpu>``.
    """
    if filename == 'default':
        filename = 'mcn_%s_batch%d' % (args.dataset, args.samples_per_gpu)
    print("=> saving checkpoint '{}'".format(filename))
    # exist_ok=True removes the exists()/makedirs() race: when several
    # processes save at the same moment, the slower one would otherwise
    # crash with FileExistsError.
    os.makedirs('./saved_models', exist_ok=True)
    checkpoint_name = './saved_models/%s_checkpoint.pth.tar' % (filename)
    best_name = './saved_models/%s_model_best.pth.tar' % (filename)
    torch.save(state, checkpoint_name)
    if is_best:
        print("=> saving best model '{}'".format(best_name))
        # Copy the file just written instead of serializing ``state`` twice.
        shutil.copyfile(checkpoint_name, best_name)

def load_pretrain(model, args, logging, rank):
    """Initialise ``model`` from the pretrained checkpoint at ``args.pretrain``.

    Only checkpoint entries whose names already exist in the model's state
    dict are copied. Other checkpoint entries are ignored, and model entries
    missing from the checkpoint keep their current values. When the model
    exposes a ``.module`` attribute (a wrapper), the inner module's state dict
    is used for key matching and loading.

    Args:
        model: torch module, possibly a wrapper exposing ``.module``.
        args: namespace providing ``pretrain`` (path to the checkpoint file).
        logging: logger-like object with an ``info`` method.
        rank: process rank; only rank 0 writes to ``logging``.

    Returns:
        The same ``model`` object (updated in place when the file exists).

    Raises:
        KeyError: if the checkpoint has no ``'state_dict'`` entry.
        ValueError: if no checkpoint key matches any key of the model.
    """
    if not os.path.isfile(args.pretrain):
        print("=> no pretrained file found at '{}'".format(args.pretrain))
        if rank == 0:
            logging.info("=> no pretrained file found at '{}'".format(args.pretrain))
        return model

    # Deserialize onto the CPU: load_state_dict copies values into the model's
    # own tensors wherever they live. Without map_location every rank would
    # restore the tensors onto the device they were saved from.
    checkpoint = torch.load(args.pretrain, map_location='cpu')
    target = model.module if hasattr(model, 'module') else model
    model_dict = target.state_dict()
    pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items() if k in model_dict}
    # Explicit check (not assert) so it still fires under ``python -O``.
    if not pretrained_dict:
        raise ValueError(
            "no parameter of the model matches the checkpoint at '{}'".format(args.pretrain))
    model_dict.update(pretrained_dict)
    target.load_state_dict(model_dict)
    print("=> loaded pretrain model at {}"
          .format(args.pretrain))
    if rank == 0:
        logging.info("=> loaded pretrain model at {}"
            .format(args.pretrain))
    del checkpoint  # drop the reference before asking CUDA to release cached blocks
    torch.cuda.empty_cache()
    return model

def load_pretrain_ddp(model, args):
    """Load matching weights from ``args.pretrain`` into ``model`` without unwrapping.

    Checkpoint keys are matched against ``model.state_dict()`` of the object
    passed in. For a wrapper holding the network as ``.module``, those keys
    carry the ``module.`` prefix, so the checkpoint is expected to have been
    saved from an identically wrapped model. Unmatched checkpoint entries are
    ignored, and model entries absent from the checkpoint keep their values.

    Args:
        model: torch module (wrapped or not) to update in place.
        args: namespace providing ``pretrain`` (path to the checkpoint file).

    Returns:
        The same ``model`` object.

    Raises:
        KeyError: if the checkpoint has no ``'state_dict'`` entry.
        ValueError: if no checkpoint key matches any key of the model.
    """
    if not os.path.isfile(args.pretrain):
        print("=> no pretrained file found at '{}'".format(args.pretrain))
        return model

    # Deserialize onto the CPU; load_state_dict copies into the model's tensors.
    checkpoint = torch.load(args.pretrain, map_location='cpu')
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items() if k in model_dict}
    # Explicit check (not assert) so it still fires under ``python -O``.
    if not pretrained_dict:
        raise ValueError(
            "no parameter of the model matches the checkpoint at '{}'".format(args.pretrain))
    model_dict.update(pretrained_dict)
    # Load into the same object whose keys were matched above. Pushing these
    # wrapper-level (``module.``-prefixed) keys into ``model.module`` would be
    # rejected by strict loading as missing/unexpected keys.
    model.load_state_dict(model_dict)
    print("=> loaded pretrain model at {}"
          .format(args.pretrain))
    del checkpoint  # drop the reference before asking CUDA to release cached blocks
    torch.cuda.empty_cache()
    return model


def load_resume(model, optimizer, args, logging, rank):
    """Restore model weights, optimizer state and progress from ``args.resume``.

    On success, ``args.start_epoch`` and ``args.best_iou`` are overwritten
    with the checkpoint's ``'epoch'`` and ``'best_iou'`` values. Model entries
    whose names appear in the checkpoint's ``'state_dict'`` are replaced, and
    checkpoint entries unknown to the model are ignored. A model exposing
    ``.module`` is loaded through that inner module. If the file does not
    exist, a notice is emitted and nothing is modified.

    Args:
        model: torch module, possibly a wrapper exposing ``.module``.
        optimizer: optimizer whose state is restored from ``'optimizer'``.
        args: mutable namespace providing ``resume`` (checkpoint path).
        logging: logger-like object with an ``info`` method.
        rank: process rank; only rank 0 writes to ``logging``.

    Returns:
        The same ``model`` object.
    """
    if not os.path.isfile(args.resume):
        not_found = "=> no checkpoint found at '{}'".format(args.resume)
        print(not_found)
        if rank == 0:
            logging.info(not_found)
        return model

    announce = "=> loading checkpoint '{}'".format(args.resume)
    print(announce)
    if rank == 0:
        logging.info(announce)
    ckpt = torch.load(args.resume, map_location='cpu')

    # Training progress is written back onto ``args`` for the caller.
    args.start_epoch = ckpt['epoch']
    print("epoch: ", args.start_epoch)
    args.best_iou = ckpt['best_iou']
    print("best iou: ", args.best_iou)

    saved_weights = ckpt['state_dict']
    holder = model.module if hasattr(model, 'module') else model
    current = holder.state_dict()
    overlap = {}
    for name, value in saved_weights.items():
        if name in current:
            overlap[name] = value
    current.update(overlap)
    holder.load_state_dict(current)

    optimizer.load_state_dict(ckpt['optimizer'])
    del ckpt  # drop the reference before asking CUDA to release cached blocks
    torch.cuda.empty_cache()
    print("load successfully!")
    return model