import os
from collections import OrderedDict

import torch

# Prefix that nn.DataParallel / DistributedDataParallel wrappers add to every
# state_dict key (the wrapped network lives under the `.module` attribute).
_DDP_PREFIX = "module."


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from keys.

    Keys that do not start with ``module.`` are kept unchanged (they are not
    blindly truncated), so checkpoints saved with or without a DDP wrapper
    are both handled.
    """
    return OrderedDict(
        (k[len(_DDP_PREFIX):] if k.startswith(_DDP_PREFIX) else k, v)
        for k, v in state_dict.items()
    )


def freeze(model):
    """Disable gradient computation for every parameter of *model*."""
    for param in model.parameters():
        param.requires_grad = False


def unfreeze(model):
    """Enable gradient computation for every parameter of *model*."""
    for param in model.parameters():
        param.requires_grad = True


def is_frozen(model):
    """Return True if at least one parameter of *model* has ``requires_grad=False``.

    NOTE: this means "any parameter frozen", not "all parameters frozen".
    A model with no parameters reports False.
    """
    return not all(p.requires_grad for p in model.parameters())


def save_checkpoint(model_dir, state, session):
    """Save *state* to ``<model_dir>/model_epoch_<epoch>_<session>.pth``.

    Args:
        model_dir: Directory to write into (must already exist).
        state: Dict passed to ``torch.save``; must contain an ``'epoch'`` entry.
        session: Suffix appended to the file name.
    """
    epoch = state['epoch']
    model_out_path = os.path.join(model_dir, "model_epoch_{}_{}.pth".format(epoch, session))
    torch.save(state, model_out_path)


def load_checkpoint(model, weights):
    """Load ``checkpoint["state_dict"]`` from the file *weights* into *model*.

    If that key is missing or does not fit the model, print the checkpoint's
    keys and fall back to ``checkpoint["state_dict_G2"]``, with any ``module.``
    prefix stripped from its keys.
    """
    # Load on CPU; load_state_dict copies values onto the model's own device.
    checkpoint = torch.load(weights, map_location="cpu")
    try:
        model.load_state_dict(checkpoint["state_dict"])
    except (KeyError, RuntimeError):
        # KeyError: no "state_dict" entry; RuntimeError: key/shape mismatch.
        print(checkpoint.keys())
        model.load_state_dict(_strip_module_prefix(checkpoint["state_dict_G2"]))


def load_checkpoint_multigpu(model, weights):
    """Load ``checkpoint["state_dict"]`` into *model*, stripping any ``module.`` key prefix."""
    checkpoint = torch.load(weights, map_location="cpu")
    model.load_state_dict(_strip_module_prefix(checkpoint["state_dict"]))


def load_start_epoch(weights):
    """Return the ``'epoch'`` value stored in the checkpoint file *weights*."""
    checkpoint = torch.load(weights, map_location="cpu")
    return checkpoint["epoch"]


def load_optim(optimizer, weights):
    """Restore *optimizer* from ``checkpoint['optimizer']``."""
    # Optimizer.load_state_dict moves state tensors to the parameters' device.
    checkpoint = torch.load(weights, map_location="cpu")
    optimizer.load_state_dict(checkpoint['optimizer'])


def load_checkpointG1(model, weights, strict=False):
    """Load ``checkpoint["state_dict_G1"]`` into *model* (or ``model.module`` under DDP).

    Args:
        model: Target network, optionally wrapped so the real network is ``model.module``.
        weights: Path of the checkpoint file.
        strict: If False, only matching keys are loaded and the rest are skipped.
            This lets older checkpoints load into a newer model that has extra
            layers (e.g. prior_encoder / fuse_conv). The counts of missing and
            unexpected keys are printed.
    """
    checkpoint = torch.load(weights, map_location="cpu")
    # Normalise keys to the un-prefixed form so they fit either model.module or
    # a single-device model.
    new_state_dict = _strip_module_prefix(checkpoint["state_dict_G1"])

    # Under DDP load into model.module, otherwise into model itself.
    target = model.module if hasattr(model, "module") else model
    load_ret = target.load_state_dict(new_state_dict, strict=strict)
    if not strict:
        if load_ret.missing_keys:
            print("[load_checkpointG1] 未加载的 key(如 prior_encoder/fuse_conv 等): %d 个" % len(load_ret.missing_keys))
        if load_ret.unexpected_keys:
            print("[load_checkpointG1] ckpt 中多余被忽略的 key: %d 个" % len(load_ret.unexpected_keys))


def load_checkpointG2(model, weights):
    """Load ``checkpoint["state_dict_G2"]`` into *model*.

    The keys are tried as stored first. On a mismatch the load is retried
    with any ``module.`` prefix stripped.
    """
    checkpoint = torch.load(weights, map_location="cpu")
    state_dict = checkpoint["state_dict_G2"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        model.load_state_dict(_strip_module_prefix(state_dict))


def load_optimG1(optimizer, weights):
    """Restore *optimizer* from ``checkpoint["optimizer_G1"]``."""
    checkpoint = torch.load(weights, map_location="cpu")
    optimizer.load_state_dict(checkpoint["optimizer_G1"])


def load_schedulerG1(scheduler, weights):
    """Restore *scheduler* from ``checkpoint["scheduler_G1"]`` if present.

    Returns:
        True if the scheduler state was loaded. False if the checkpoint has no
        such entry; in that case *scheduler* is untouched and no error is raised.
    """
    checkpoint = torch.load(weights, map_location="cpu")
    if "scheduler_G1" in checkpoint:
        scheduler.load_state_dict(checkpoint["scheduler_G1"])
        return True
    return False


def load_optimG2(optimizer, weights):
    """Print the checkpoint's keys, then restore *optimizer* from ``checkpoint['optimizer_G2']``."""
    checkpoint = torch.load(weights, map_location="cpu")
    print(checkpoint.keys())
    optimizer.load_state_dict(checkpoint['optimizer_G2'])