File size: 3,827 Bytes
4336727
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import os
from collections import OrderedDict

def freeze(model):
    """Turn off gradient tracking for every parameter of *model*, in place."""
    for param in model.parameters():
        param.requires_grad_(False)

def unfreeze(model):
    """Turn gradient tracking back on for every parameter of *model*, in place."""
    for param in model.parameters():
        param.requires_grad_(True)

def is_frozen(model):
    """Report whether *model* has any parameter with gradients disabled.

    Returns True as soon as at least one parameter has ``requires_grad`` False
    (so a partially frozen model counts as frozen). A model with no parameters
    returns False.
    """
    return any(not param.requires_grad for param in model.parameters())

def save_checkpoint(model_dir, state, session):
    """Write *state* to ``<model_dir>/model_epoch_<epoch>_<session>.pth``.

    Args:
        model_dir: directory the file is written into.
        state: object passed to ``torch.save``; must support ``state['epoch']``,
            which is used only to build the file name.
        session: tag appended to the file name.
    """
    filename = "model_epoch_{}_{}.pth".format(state['epoch'], session)
    destination = os.path.join(model_dir, filename)
    torch.save(state, destination)

def load_checkpoint(model, weights):
    """Load weights from the checkpoint file *weights* into *model* (in place).

    First tries ``checkpoint["state_dict"]`` as-is. If that key is missing or
    its weights do not fit *model*, falls back to ``checkpoint["state_dict_G2"]``
    with any leading ``module.`` prefix removed from the keys.

    Args:
        model: module to load into.
        weights: path to a checkpoint written with ``torch.save``.

    Raises:
        KeyError: if the fallback key ``"state_dict_G2"`` is missing.
        RuntimeError: if the fallback weights still do not match *model*.
    """
    # Map to CPU so checkpoints saved on a GPU can be read anywhere;
    # load_state_dict copies the values into the model's own tensors.
    checkpoint = torch.load(weights, map_location="cpu")
    try:
        model.load_state_dict(checkpoint["state_dict"])
    except (KeyError, RuntimeError):
        # Narrow catch: a bare except would also swallow KeyboardInterrupt etc.
        print(checkpoint.keys())
        state_dict = checkpoint["state_dict_G2"]
        prefix = "module."
        # Strip the wrapper prefix only when present; slicing k[7:]
        # unconditionally would mangle keys saved from an unwrapped model.
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(new_state_dict)


def load_checkpoint_multigpu(model, weights):
    """Load ``checkpoint["state_dict"]`` into *model*, dropping ``module.`` prefixes.

    Intended for checkpoints saved from a wrapped (``module.``-prefixed) model and
    loaded into an unwrapped one. Keys without the prefix are kept unchanged
    rather than truncated.

    Args:
        model: module to load into (modified in place).
        weights: path to a checkpoint written with ``torch.save``.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: if the weights do not match *model*.
    """
    # Map to CPU so GPU-saved checkpoints also load on CPU-only machines.
    checkpoint = torch.load(weights, map_location="cpu")
    state_dict = checkpoint["state_dict"]
    prefix = "module."
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
    model.load_state_dict(new_state_dict)

def load_start_epoch(weights):
    """Return the value stored under ``"epoch"`` in the checkpoint file *weights*.

    Raises:
        KeyError: if the checkpoint has no ``"epoch"`` entry.
    """
    # Map to CPU: only a scalar is needed, so avoid restoring tensors onto the
    # GPU they were saved from (and fail-free on CPU-only hosts).
    checkpoint = torch.load(weights, map_location="cpu")
    return checkpoint["epoch"]

def load_optim(optimizer, weights):
    """Restore *optimizer* from the ``"optimizer"`` entry of checkpoint *weights*.

    The checkpoint is mapped to CPU when read, matching the other optimizer
    loaders in this module.

    Raises:
        KeyError: if the checkpoint has no ``"optimizer"`` entry.
    """
    checkpoint = torch.load(weights, map_location="cpu")
    optimizer.load_state_dict(checkpoint['optimizer'])


def load_checkpointG1(model, weights, strict=False):
    """Load ``checkpoint["state_dict_G1"]`` from *weights* into *model*.

    If *model* exposes a ``module`` attribute (e.g. a DDP wrapper), the weights
    are loaded into ``model.module``; otherwise into *model* itself.

    Args:
        model: target module, or a wrapper exposing ``.module``; modified in place.
        weights: path to a checkpoint written with ``torch.save``.
        strict: forwarded to ``load_state_dict``. With ``strict=False`` only the
            matching keys are loaded and the rest are skipped (useful when an old
            checkpoint initialises a newer model, e.g. one that lacks
            ``prior_encoder``/``fuse_conv`` weights in the checkpoint); the number
            of missing and unexpected keys is printed.
    """
    saved = torch.load(weights, map_location="cpu")["state_dict_G1"]
    # Normalise keys to have no leading "module." so they fit either
    # model.module or a single-device model.
    cleaned = OrderedDict()
    for key, tensor in saved.items():
        if key.startswith("module."):
            key = key[7:]
        cleaned[key] = tensor
    # Under DDP load into the wrapped module, otherwise into the model itself.
    target = getattr(model, "module", model)
    result = target.load_state_dict(cleaned, strict=strict)
    if strict:
        return
    if result.missing_keys:
        print("[load_checkpointG1] 未加载的 key(如 prior_encoder/fuse_conv 等): %d 个" % len(result.missing_keys))
    if result.unexpected_keys:
        print("[load_checkpointG1] ckpt 中多余被忽略的 key: %d 个" % len(result.unexpected_keys))

def load_checkpointG2(model, weights):
    """Load ``checkpoint["state_dict_G2"]`` from *weights* into *model* (in place).

    Tries the keys as saved first. If they do not fit *model*, retries with any
    leading ``module.`` prefix removed from the keys. Keys without the prefix
    are left intact.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict_G2"`` entry.
        RuntimeError: if neither form of the weights matches *model*.
    """
    # Map to CPU so GPU-saved checkpoints also load on CPU-only machines.
    checkpoint = torch.load(weights, map_location="cpu")
    state_dict = checkpoint["state_dict_G2"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Key/shape mismatch: retry with the wrapper prefix stripped.
        prefix = "module."
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(new_state_dict)

def load_optimG1(optimizer, weights):
    """Restore *optimizer* from the ``"optimizer_G1"`` entry of checkpoint *weights*.

    The checkpoint file is read with tensors mapped to CPU.
    """
    saved_state = torch.load(weights, map_location="cpu")["optimizer_G1"]
    optimizer.load_state_dict(saved_state)


def load_schedulerG1(scheduler, weights):
    """Restore *scheduler* from ``checkpoint["scheduler_G1"]`` when that entry exists.

    A checkpoint without the entry is not an error.

    Returns:
        True if the scheduler state was found and loaded, False otherwise.
    """
    checkpoint = torch.load(weights, map_location="cpu")
    if "scheduler_G1" not in checkpoint:
        return False
    scheduler.load_state_dict(checkpoint["scheduler_G1"])
    return True


def load_optimG2(optimizer, weights):
    """Restore *optimizer* from the ``"optimizer_G2"`` entry of checkpoint *weights*.

    The checkpoint is read with tensors mapped to CPU, and nothing is printed.

    Raises:
        KeyError: if the checkpoint has no ``"optimizer_G2"`` entry.
    """
    checkpoint = torch.load(weights, map_location="cpu")
    optimizer.load_state_dict(checkpoint['optimizer_G2'])