import torch
import os
from collections import OrderedDict
def freeze(model):
    """Turn off gradient tracking for every parameter of ``model``.

    Each object yielded by ``model.parameters()`` gets ``requires_grad``
    set to ``False``. The model is modified in place; nothing is returned.
    """
    for param in model.parameters():
        param.requires_grad = False
def unfreeze(model):
    """Turn gradient tracking back on for every parameter of ``model``.

    Each object yielded by ``model.parameters()`` gets ``requires_grad``
    set to ``True``. The model is modified in place; nothing is returned.
    """
    for param in model.parameters():
        param.requires_grad = True
def is_frozen(model):
    """Report whether ``model`` has at least one non-trainable parameter.

    Returns:
        ``True`` if any parameter has ``requires_grad`` set to a falsy
        value, ``False`` otherwise. A model without parameters is reported
        as not frozen. Note that a partially frozen model therefore counts
        as frozen.
    """
    return any(not param.requires_grad for param in model.parameters())
def save_checkpoint(model_dir, state, session):
    """Serialize ``state`` to ``<model_dir>/model_epoch_<epoch>_<session>.pth``.

    Args:
        model_dir: Directory the checkpoint file is written into.
        state: Mapping to save; must contain an ``'epoch'`` entry, which is
            used in the file name.
        session: Label appended to the file name.

    Raises:
        KeyError: If ``state`` has no ``'epoch'`` entry.
    """
    filename = f"model_epoch_{state['epoch']}_{session}.pth"
    torch.save(state, os.path.join(model_dir, filename))
def load_checkpoint(model, weights):
    """Load model weights from the checkpoint file ``weights`` into ``model``.

    The ``"state_dict"`` entry is tried first, as is. If that entry is
    missing or does not fit the model, the checkpoint keys are printed and
    a fallback entry is loaded instead: ``"state_dict_G2"`` when present,
    otherwise ``"state_dict"``. In the fallback, a leading ``module.`` is
    removed from every key that has one (keys without the prefix are kept
    unchanged).

    Args:
        model: Object exposing ``load_state_dict``.
        weights: Path (or file-like object) accepted by ``torch.load``.

    Raises:
        KeyError: If the checkpoint holds neither ``"state_dict"`` nor
            ``"state_dict_G2"``.
        RuntimeError: If the fallback state dict does not fit the model.
    """
    # Deserialize on CPU so a checkpoint written on a GPU machine can also
    # be read on a CPU-only host; load_state_dict copies the values into
    # the model's existing tensors, so the model's device is unaffected.
    checkpoint = torch.load(weights, map_location="cpu")
    try:
        model.load_state_dict(checkpoint["state_dict"])
    except (KeyError, RuntimeError):
        # KeyError: no "state_dict" entry. RuntimeError: key/shape mismatch.
        print(checkpoint.keys())
        fallback_key = (
            "state_dict_G2" if "state_dict_G2" in checkpoint else "state_dict"
        )
        prefix = "module."
        new_state_dict = OrderedDict()
        for key, value in checkpoint[fallback_key].items():
            # Strip the wrapper prefix only when it is really there.
            name = key[len(prefix):] if key.startswith(prefix) else key
            new_state_dict[name] = value
        model.load_state_dict(new_state_dict)
def load_checkpoint_multigpu(model, weights):
    """Load the ``"state_dict"`` entry of a checkpoint, dropping ``module.``.

    A leading ``module.`` is removed from every key that carries it; keys
    without the prefix are passed through unchanged, so the same function
    works for checkpoints saved with or without the wrapper prefix.

    Args:
        model: Object exposing ``load_state_dict``.
        weights: Path (or file-like object) accepted by ``torch.load``.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: If the resulting state dict does not fit the model.
    """
    # Deserialize on CPU so a GPU-saved checkpoint also loads on a CPU-only
    # host; load_state_dict copies values into the model's own tensors.
    checkpoint = torch.load(weights, map_location="cpu")
    prefix = "module."
    new_state_dict = OrderedDict()
    for key, value in checkpoint["state_dict"].items():
        # Only strip when the prefix is present; slicing blindly would
        # corrupt keys that were saved without it.
        name = key[len(prefix):] if key.startswith(prefix) else key
        new_state_dict[name] = value
    model.load_state_dict(new_state_dict)
def load_start_epoch(weights):
    """Return the ``"epoch"`` entry stored in the checkpoint file ``weights``.

    Args:
        weights: Path (or file-like object) accepted by ``torch.load``.

    Returns:
        The value saved under ``"epoch"``.

    Raises:
        KeyError: If the checkpoint has no ``"epoch"`` entry.
    """
    # Only one scalar entry is needed, so keep every tensor on the CPU:
    # this avoids allocating GPU memory for the whole checkpoint and lets
    # the call succeed on a CPU-only host.
    checkpoint = torch.load(weights, map_location="cpu")
    return checkpoint["epoch"]
def load_optim(optimizer, weights):
    """Restore ``optimizer`` from the ``'optimizer'`` entry of a checkpoint.

    Args:
        optimizer: Object exposing ``load_state_dict``.
        weights: Path (or file-like object) accepted by ``torch.load``.

    Raises:
        KeyError: If the checkpoint has no ``'optimizer'`` entry.
    """
    # Deserialize on CPU (same as load_optimG1) so a GPU-saved checkpoint
    # can be read on a CPU-only host.
    checkpoint = torch.load(weights, map_location="cpu")
    optimizer.load_state_dict(checkpoint['optimizer'])
def load_checkpointG1(model, weights, strict=False):
    """Load the ``"state_dict_G1"`` entry of a checkpoint into ``model``.

    If ``model`` has a ``module`` attribute (as a DDP-wrapped model does),
    the weights are loaded into ``model.module``; otherwise into ``model``
    itself. With ``strict=False`` only matching keys are loaded and the
    rest are skipped, which lets an older checkpoint be used with a newer
    model (e.g. one that added prior_encoder/fuse_conv); the number of
    missing and unexpected keys is printed in that case.

    Args:
        model: Plain model, or a wrapper exposing the model as ``.module``.
        weights: Path (or file-like object) accepted by ``torch.load``.
        strict: Forwarded to ``load_state_dict``.
    """
    raw_state = torch.load(weights, map_location="cpu")["state_dict_G1"]

    # Normalize to keys without the "module." prefix so the same dict fits
    # both the wrapped model's inner module and a single-device model.
    cleaned = OrderedDict(
        (key[7:] if key.startswith("module.") else key, tensor)
        for key, tensor in raw_state.items()
    )

    # Unwrap when a .module attribute exists, else use the model directly.
    destination = getattr(model, "module", model)
    result = destination.load_state_dict(cleaned, strict=strict)

    if strict:
        return
    if result.missing_keys:
        print("[load_checkpointG1] 未加载的 key(如 prior_encoder/fuse_conv 等): %d 个" % len(result.missing_keys))
    if result.unexpected_keys:
        print("[load_checkpointG1] ckpt 中多余被忽略的 key: %d 个" % len(result.unexpected_keys))
def load_checkpointG2(model, weights):
    """Load the ``"state_dict_G2"`` entry of a checkpoint into ``model``.

    The entry is first loaded as saved. If it does not fit the model, a
    second attempt is made after removing a leading ``module.`` from every
    key that has one (keys without the prefix are kept unchanged).

    Args:
        model: Object exposing ``load_state_dict``.
        weights: Path (or file-like object) accepted by ``torch.load``.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict_G2"`` entry.
        RuntimeError: If neither attempt fits the model.
    """
    # Deserialize on CPU so a GPU-saved checkpoint also loads on a CPU-only
    # host; load_state_dict copies values into the model's own tensors.
    checkpoint = torch.load(weights, map_location="cpu")
    state_dict = checkpoint["state_dict_G2"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Key or shape mismatch: retry with the wrapper prefix removed.
        prefix = "module."
        new_state_dict = OrderedDict()
        for key, value in state_dict.items():
            # Strip only a real prefix so unprefixed keys are not mangled.
            name = key[len(prefix):] if key.startswith(prefix) else key
            new_state_dict[name] = value
        model.load_state_dict(new_state_dict)
def load_optimG1(optimizer, weights):
    """Restore ``optimizer`` from the ``"optimizer_G1"`` entry of a checkpoint.

    The checkpoint file ``weights`` is deserialized onto the CPU before its
    ``"optimizer_G1"`` entry is passed to ``optimizer.load_state_dict``.
    """
    saved_state = torch.load(weights, map_location="cpu")["optimizer_G1"]
    optimizer.load_state_dict(saved_state)
def load_schedulerG1(scheduler, weights):
    """Restore ``scheduler`` from ``"scheduler_G1"`` if the checkpoint has it.

    A checkpoint without that entry is not an error.

    Returns:
        ``True`` when the entry was found and loaded, ``False`` when the
        checkpoint has no ``"scheduler_G1"`` entry (the scheduler is then
        left untouched).
    """
    checkpoint = torch.load(weights, map_location="cpu")
    if "scheduler_G1" not in checkpoint:
        return False
    scheduler.load_state_dict(checkpoint["scheduler_G1"])
    return True
def load_optimG2(optimizer, weights):
    """Restore ``optimizer`` from the ``'optimizer_G2'`` entry of a checkpoint.

    The checkpoint's top-level keys are printed before the entry is read.

    Args:
        optimizer: Object exposing ``load_state_dict``.
        weights: Path (or file-like object) accepted by ``torch.load``.

    Raises:
        KeyError: If the checkpoint has no ``'optimizer_G2'`` entry.
    """
    # Deserialize on CPU (same as load_optimG1) so a GPU-saved checkpoint
    # can be read on a CPU-only host.
    checkpoint = torch.load(weights, map_location="cpu")
    print(checkpoint.keys())
    optimizer.load_state_dict(checkpoint['optimizer_G2'])