| import torch | |
| from .DeMoE import DeMoE | |
def create_model(opt, device):
    """Build the network described by the config and move it to *device*.

    Parameters
    ----------
    opt : dict
        The ``network`` section of the YAML config. Must contain ``name``
        plus the architecture hyper-parameters read below
        (``img_channels``, ``width``, ``middle_blk_num``, ``enc_blk_nums``,
        ``dec_blk_nums``, ``num_experts``, ``k_used``).
    device : torch.device or str
        Device the instantiated model is moved to.

    Returns
    -------
    torch.nn.Module
        The instantiated model, already placed on *device*.

    Raises
    ------
    NotImplementedError
        If ``opt['name']`` does not match a known architecture.
    """
    name = opt['name']
    if name == 'DeMoE':
        model = DeMoE(img_channel=opt['img_channels'],
                      width=opt['width'],
                      middle_blk_num=opt['middle_blk_num'],
                      enc_blk_nums=opt['enc_blk_nums'],
                      dec_blk_nums=opt['dec_blk_nums'],
                      num_exp=opt['num_experts'],
                      k_used=opt['k_used'])
    else:
        # Include the offending value so config typos are easy to spot.
        raise NotImplementedError(f"Network '{name}' is not implemented")
    model.to(device)
    return model
| # def load_weights(model, model_weights): | |
| # ''' | |
| # Loads the weights of a pretrained model, picking only the weights that are | |
| # in the new model. | |
| # ''' | |
| # new_weights = model.state_dict() | |
| # new_weights.update({k: v for k, v in model_weights.items() if k in new_weights}) | |
| # new_weights = {key.replace('module.', ''): value for key, value in new_weights.items()} | |
| # print(new_weights.keys()) | |
| # print(model.state_dict().keys()) | |
| # model.load_state_dict(new_weights, strict= True) | |
| # total_checkpoint_keys = len(model_weights) | |
| # total_model_keys = len(new_weights) | |
| # matching_keys = len(set(model_weights.keys()) & set(new_weights.keys())) | |
| # print(f"Total keys in checkpoint: {total_checkpoint_keys}") | |
| # print(f"Total keys in model state dict: {total_model_keys}") | |
| # print(f"Number of matching keys: {matching_keys}") | |
| # return model | |
def strip_prefixes(sd: dict, prefixes=("module.", "model.", "ema.", "net.", "netG.", "generator.")) -> dict:
    """Return a copy of *sd* with common wrapper prefixes removed from its keys.

    At most one prefix is stripped per key — the search stops at the first
    match in *prefixes* — so e.g. ``'module.model.w'`` becomes ``'model.w'``.
    Values are carried over unchanged.
    """
    cleaned = {}
    for key, value in sd.items():
        # First matching prefix wins; fall back to the original key.
        new_key = next(
            (key[len(prefix):] for prefix in prefixes if key.startswith(prefix)),
            key,
        )
        cleaned[new_key] = value
    return cleaned
# ===== DDP and local_rank removed: a single explicit device is used =====
def load_model(model, path_weights: str, device: torch.device):
    """Load checkpoint weights into *model*, move it to *device*, set eval mode.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose parameters are (partially) overwritten; loading uses
        ``strict=False``, so mismatched keys are reported, not fatal.
    path_weights : str
        Path to a checkpoint saved with ``torch.save``. It may be a raw
        state dict, or a dict wrapping one under ``'params'``,
        ``'model_state_dict'`` or ``'state_dict'``.
    device : torch.device
        Target device; the model is also cast to float32.

    Returns
    -------
    torch.nn.Module
        The model in eval mode on *device*.
    """
    # Always load onto CPU first, then move the whole model.
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # use it on trusted checkpoints.
    ckpt = torch.load(path_weights, map_location='cpu', weights_only=False)

    # Try the usual wrapper keys explicitly; the previous
    # `ckpt.get("params") or ... or ckpt` chain silently skipped a wrapped
    # state dict that happened to be empty/falsy.
    sd = ckpt
    if isinstance(ckpt, dict):
        for key in ("params", "model_state_dict", "state_dict"):
            if key in ckpt and ckpt[key] is not None:
                sd = ckpt[key]
                break

    sd = strip_prefixes(sd)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    print(f"[DeMoE] load_state: missing={len(missing)}, unexpected={len(unexpected)}")
    model = model.to(device=device, dtype=torch.float32).eval()
    return model
| # def resume_model(model, | |
| # path_model, | |
| # device): | |
| # ''' | |
| # Returns the loaded weights of model and optimizer if resume flag is True | |
| # ''' | |
| # checkpoints = torch.load(path_model, map_location=device, weights_only=False) | |
| # weights = checkpoints['params'] | |
| # model = load_weights(model, model_weights=weights) | |
| # return model | |
# Public API: model construction and checkpoint loading (strip_prefixes is a helper).
__all__ = ['create_model', 'load_model']