| import torch |
| import os |
| from collections import OrderedDict |
|
|
def freeze(model):
    """Disable gradient tracking on every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``);
            each yielded parameter gets ``requires_grad = False``.
    """
    for param in model.parameters():
        param.requires_grad = False
|
|
def unfreeze(model):
    """Re-enable gradient tracking on every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``);
            each yielded parameter gets ``requires_grad = True``.
    """
    for param in model.parameters():
        param.requires_grad = True
|
|
def is_frozen(model):
    """Return True if at least one parameter of ``model`` has gradients disabled.

    A partially frozen model therefore counts as frozen. A model with no
    parameters is reported as not frozen.
    """
    return any(not param.requires_grad for param in model.parameters())
|
|
def save_checkpoint(model_dir, state, session):
    """Serialize ``state`` with ``torch.save`` into ``model_dir``.

    The file is named ``model_epoch_{epoch}_{session}.pth``, where ``epoch``
    is read from ``state['epoch']`` (a missing key raises ``KeyError``).

    Args:
        model_dir: Directory to write into (it is not created here).
        state: Picklable mapping to save; must contain ``'epoch'``.
        session: Value formatted into the file name after the epoch.
    """
    filename = "model_epoch_{}_{}.pth".format(state['epoch'], session)
    torch.save(state, os.path.join(model_dir, filename))
|
|
def load_checkpoint(model, weights):
    """Load model weights from the checkpoint file ``weights`` into ``model``.

    First tries ``checkpoint["state_dict"]`` as saved. If that key is missing,
    or ``load_state_dict`` rejects it, the checkpoint's top-level keys are
    printed and ``checkpoint["state_dict_G2"]`` is loaded instead. Before that
    load, a leading ``"module."`` is removed from each key (the prefix
    DataParallel/DDP wrappers add); keys without the prefix are kept as-is.

    Args:
        model: Module to load into (``load_state_dict`` is called on it).
        weights: Path to a checkpoint written with ``torch.save``.

    Raises:
        KeyError: if ``"state_dict"`` fails and ``"state_dict_G2"`` is absent.
        RuntimeError: if the fallback state dict does not match ``model``.
    """
    # Load onto CPU so GPU-saved checkpoints also open on CPU-only hosts;
    # load_state_dict copies the values onto the model's own parameters.
    checkpoint = torch.load(weights, map_location="cpu")
    try:
        model.load_state_dict(checkpoint["state_dict"])
    except (KeyError, RuntimeError):
        # Narrow handler: a missing key or a key/shape mismatch. A bare
        # except would also swallow KeyboardInterrupt/SystemExit.
        print(checkpoint.keys())
        # NOTE(review): the fallback reads "state_dict_G2" rather than a
        # prefix-stripped "state_dict" -- confirm this is intended.
        state_dict = checkpoint["state_dict_G2"]
        prefix = "module."
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(new_state_dict)
|
|
|
|
def load_checkpoint_multigpu(model, weights):
    """Load ``checkpoint["state_dict"]`` into an unwrapped ``model``.

    State dicts saved from a DataParallel/DDP-wrapped model carry a
    ``"module."`` prefix on every key. That prefix is stripped where present;
    keys without it are kept unchanged, so unwrapped checkpoints load too.

    Args:
        model: Module to load into (``load_state_dict`` is called on it).
        weights: Path to a checkpoint written with ``torch.save``.
    """
    # CPU load: works on CPU-only hosts; load_state_dict copies values onto
    # the model's own parameters.
    checkpoint = torch.load(weights, map_location="cpu")
    prefix = "module."
    new_state_dict = OrderedDict()
    for k, v in checkpoint["state_dict"].items():
        # Only strip a real prefix; blind k[7:] would mangle plain keys.
        new_state_dict[k[len(prefix):] if k.startswith(prefix) else k] = v
    model.load_state_dict(new_state_dict)
|
|
def load_start_epoch(weights):
    """Return the ``"epoch"`` entry stored in the checkpoint file ``weights``.

    Only that one value is needed, so the checkpoint is loaded onto the CPU.
    This avoids materializing its tensors on a GPU and keeps the call working
    on CPU-only machines.

    Raises:
        KeyError: if the checkpoint has no ``"epoch"`` entry.
    """
    checkpoint = torch.load(weights, map_location="cpu")
    return checkpoint["epoch"]
|
|
def load_optim(optimizer, weights):
    """Restore ``optimizer`` from the ``'optimizer'`` entry of checkpoint ``weights``.

    Raises:
        KeyError: if the checkpoint has no ``'optimizer'`` entry.
    """
    optimizer_state = torch.load(weights)['optimizer']
    optimizer.load_state_dict(optimizer_state)
| |
| |
def load_checkpointG1(model, weights, strict=False):
    """Load ``checkpoint["state_dict_G1"]`` into ``model``.

    If ``model`` has a ``.module`` attribute (e.g. a DDP wrapper), the weights
    go into ``model.module``. A leading ``"module."`` is removed from every
    saved key that has it.

    Args:
        model: Target module, optionally wrapped (anything exposing ``.module``).
        weights: Path to a checkpoint written with ``torch.save``; it is
            loaded onto the CPU.
        strict: Forwarded to ``load_state_dict``. With ``strict=False`` only
            matching keys are loaded and non-matching ones are skipped. This is
            for loading old checkpoints into newer models, e.g. ones missing
            prior_encoder/fuse_conv. The counts of missing and unexpected keys
            are then printed.
    """
    checkpoint = torch.load(weights, map_location="cpu")
    prefix = "module."
    cleaned = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in checkpoint["state_dict_G1"].items()
    )

    target = getattr(model, "module", model)
    result = target.load_state_dict(cleaned, strict=strict)
    if strict:
        return
    if result.missing_keys:
        print("[load_checkpointG1] 未加载的 key(如 prior_encoder/fuse_conv 等): %d 个" % len(result.missing_keys))
    if result.unexpected_keys:
        print("[load_checkpointG1] ckpt 中多余被忽略的 key: %d 个" % len(result.unexpected_keys))
|
|
def load_checkpointG2(model, weights):
    """Load ``checkpoint["state_dict_G2"]`` from the file ``weights`` into ``model``.

    The state dict is first loaded as saved. If ``load_state_dict`` rejects it
    (for instance when the keys carry a ``"module."`` prefix from a
    DataParallel/DDP wrapper but ``model`` is unwrapped), the prefix is
    stripped from the keys that have it and loading is retried.

    Args:
        model: Module to load into (``load_state_dict`` is called on it).
        weights: Path to a checkpoint written with ``torch.save``.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict_G2"`` entry.
        RuntimeError: if the keys still do not match after stripping.
    """
    # CPU load: works on CPU-only hosts; load_state_dict copies values onto
    # the model's own parameters.
    checkpoint = torch.load(weights, map_location="cpu")
    # Look the key up outside the try so a missing entry surfaces directly
    # instead of being caught and re-raised by the fallback.
    state_dict = checkpoint["state_dict_G2"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        prefix = "module."
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(new_state_dict)
|
|
def load_optimG1(optimizer, weights):
    """Restore ``optimizer`` from the ``"optimizer_G1"`` entry of checkpoint ``weights``.

    The checkpoint file is loaded with ``map_location="cpu"``.

    Raises:
        KeyError: if the checkpoint has no ``"optimizer_G1"`` entry.
    """
    saved_state = torch.load(weights, map_location="cpu")["optimizer_G1"]
    optimizer.load_state_dict(saved_state)
|
|
|
|
def load_schedulerG1(scheduler, weights):
    """Restore ``scheduler`` from ``checkpoint["scheduler_G1"]`` when that entry exists.

    The checkpoint file is loaded with ``map_location="cpu"``.

    Returns:
        True if the scheduler state was found and loaded. False if the
        checkpoint has no ``"scheduler_G1"`` entry; no error is raised in
        that case.
    """
    checkpoint = torch.load(weights, map_location="cpu")
    if "scheduler_G1" not in checkpoint:
        return False
    scheduler.load_state_dict(checkpoint["scheduler_G1"])
    return True
def load_optimG2(optimizer, weights):
    """Restore ``optimizer`` from the ``'optimizer_G2'`` entry of checkpoint ``weights``.

    The checkpoint's top-level keys are printed before loading (debug output).

    Raises:
        KeyError: if the checkpoint has no ``'optimizer_G2'`` entry.
    """
    checkpoint = torch.load(weights)
    print(checkpoint.keys())
    saved_state = checkpoint['optimizer_G2']
    optimizer.load_state_dict(saved_state)