# DeMoE/archs/__init__.py — model factory and checkpoint-loading helpers.
import torch
from .DeMoE import DeMoE
def create_model(opt, device):
    """
    Instantiate the network described by the config and move it to `device`.

    Parameters
    ----------
    opt : dict
        The `network` section of the YAML config. Must contain 'name' plus
        the architecture hyperparameters read below.
    device : torch.device or str
        Target device for the instantiated model.

    Returns
    -------
    The constructed model, already moved to `device`.

    Raises
    ------
    NotImplementedError
        If `opt['name']` does not match a known architecture.
    """
    name = opt['name']
    if name == 'DeMoE':
        model = DeMoE(img_channel=opt['img_channels'],
                      width=opt['width'],
                      middle_blk_num=opt['middle_blk_num'],
                      enc_blk_nums=opt['enc_blk_nums'],
                      dec_blk_nums=opt['dec_blk_nums'],
                      num_exp=opt['num_experts'],
                      k_used=opt['k_used'])
    else:
        # Include the offending name so a config typo is immediately visible.
        raise NotImplementedError(f"Network '{name}' is not implemented")
    model.to(device)
    return model
# def load_weights(model, model_weights):
# '''
# Loads the weights of a pretrained model, picking only the weights that are
# in the new model.
# '''
# new_weights = model.state_dict()
# new_weights.update({k: v for k, v in model_weights.items() if k in new_weights})
# new_weights = {key.replace('module.', ''): value for key, value in new_weights.items()}
# print(new_weights.keys())
# print(model.state_dict().keys())
# model.load_state_dict(new_weights, strict= True)
# total_checkpoint_keys = len(model_weights)
# total_model_keys = len(new_weights)
# matching_keys = len(set(model_weights.keys()) & set(new_weights.keys()))
# print(f"Total keys in checkpoint: {total_checkpoint_keys}")
# print(f"Total keys in model state dict: {total_model_keys}")
# print(f"Number of matching keys: {matching_keys}")
# return model
def strip_prefixes(sd: dict, prefixes=("module.", "model.", "ema.", "net.", "netG.", "generator.")) -> dict:
    """Return a copy of state dict *sd* with known wrapper prefixes removed.

    At most one prefix (the first match, in `prefixes` order) is stripped
    from the front of each key; values are passed through unchanged.
    """
    stripped = {}
    for key, value in sd.items():
        new_key = next(
            (key[len(pref):] for pref in prefixes if key.startswith(pref)),
            key,
        )
        stripped[new_key] = value
    return stripped
# ===== DDP and local_rank removed; a single device is used =====
def load_model(model, path_weights: str, device: torch.device):
    """
    Load checkpoint weights into `model`, move it to `device`, and set eval mode.

    Parameters
    ----------
    model : torch.nn.Module
        Model instance whose parameters will be (partially) overwritten.
    path_weights : str
        Path to a checkpoint file produced by `torch.save`.
    device : torch.device
        Device the loaded model is moved to (weights cast to float32).

    Returns
    -------
    The same model, in eval mode on `device`.
    """
    # Always load onto CPU first, then move to the target device.
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # use with trusted checkpoint files.
    ckpt = torch.load(path_weights, map_location='cpu', weights_only=False)
    # Accept the common checkpoint layouts; otherwise treat the loaded
    # object itself as the state dict. Membership tests (instead of the
    # previous `.get(...) or ...` chain) ensure a present-but-empty entry
    # is used rather than silently skipped, and non-dict checkpoints
    # (e.g. a raw state dict saved directly) do not crash.
    sd = ckpt
    if isinstance(ckpt, dict):
        for key in ("params", "model_state_dict", "state_dict"):
            if key in ckpt:
                sd = ckpt[key]
                break
    sd = strip_prefixes(sd)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    print(f"[DeMoE] load_state: missing={len(missing)}, unexpected={len(unexpected)}")
    model = model.to(device=device, dtype=torch.float32).eval()
    return model
# def resume_model(model,
# path_model,
# device):
# '''
# Returns the loaded weights of model and optimizer if resume flag is True
# '''
# checkpoints = torch.load(path_model, map_location=device, weights_only=False)
# weights = checkpoints['params']
# model = load_weights(model, model_weights=weights)
# return model
__all__ = ['create_model', 'load_model']