import torch

from .DeMoE import DeMoE


def create_model(opt, device):
    """Create the network described by the ``network`` section of the YAML config.

    Args:
        opt: dict taken from the yaml config key ``network``; must contain
            ``name`` plus the architecture hyper-parameters read below.
        device: target ``torch.device`` the model is moved to.

    Returns:
        The instantiated model, already moved to ``device``.

    Raises:
        NotImplementedError: if ``opt['name']`` is not a known architecture.
    """
    name = opt['name']
    if name == 'DeMoE':
        model = DeMoE(img_channel=opt['img_channels'],
                      width=opt['width'],
                      middle_blk_num=opt['middle_blk_num'],
                      enc_blk_nums=opt['enc_blk_nums'],
                      dec_blk_nums=opt['dec_blk_nums'],
                      num_exp=opt['num_experts'],
                      k_used=opt['k_used'])
    else:
        raise NotImplementedError('This network is not implemented')
    model.to(device)
    return model


def strip_prefixes(sd: dict, prefixes=("module.", "model.", "ema.", "net.", "netG.", "generator.")) -> dict:
    """Return a copy of state dict ``sd`` with at most one known wrapper prefix stripped per key.

    Checkpoints saved from DDP / EMA / generator wrappers prepend prefixes such
    as ``module.`` to every parameter name; this normalizes the keys so they
    match a plain (unwrapped) model.
    """
    out = {}
    for key, value in sd.items():
        new_key = key
        for prefix in prefixes:
            if new_key.startswith(prefix):
                new_key = new_key[len(prefix):]
                break  # strip only the first matching prefix, once
        out[new_key] = value
    return out


# ===== no DDP / local_rank here: a single explicit device is used =====
def load_model(model, path_weights: str, device: torch.device):
    """Load checkpoint weights into ``model`` and return it on ``device`` in eval mode.

    Args:
        model: the (already constructed) network to receive the weights.
        path_weights: filesystem path of the checkpoint file.
        device: device the model is moved to after loading.

    Returns:
        ``model`` with weights loaded (non-strict), on ``device`` as float32,
        switched to ``eval()`` mode.
    """
    # Always load on CPU first, then move — avoids GPU OOM / device-mismatch
    # surprises when the checkpoint was saved on a different device.
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoints from trusted sources.
    ckpt = torch.load(path_weights, map_location='cpu', weights_only=False)

    # Try the usual wrapper keys; fall back to the raw object. Explicit
    # membership checks (rather than an `a or b or ...` chain) so that a
    # present-but-empty sub-dict is still honoured and a non-dict checkpoint
    # does not crash on `.get`.
    sd = ckpt
    if isinstance(ckpt, dict):
        for key in ("params", "model_state_dict", "state_dict"):
            if key in ckpt and isinstance(ckpt[key], dict):
                sd = ckpt[key]
                break

    sd = strip_prefixes(sd)
    # strict=False: report (not fail on) keys that don't line up.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    print(f"[DeMoE] load_state: missing={len(missing)}, unexpected={len(unexpected)}")
    model = model.to(device=device, dtype=torch.float32).eval()
    return model


__all__ = ['create_model', 'load_model']