import torch
from torch import nn

from editings.styleclip.mapper import latent_mappers


def get_keys(d, name):
    """Return the entries of a checkpoint that belong to submodule *name*.

    If *d* has a ``"state_dict"`` entry, that nested dict is used instead of
    *d* itself. Only keys of the form ``"<name>.<rest>"`` are kept, and they
    are re-keyed as ``"<rest>"``. The result can then be fed straight into
    the submodule's ``load_state_dict``.

    Args:
        d: A checkpoint mapping, either a state dict or a dict wrapping one
            under ``"state_dict"``.
        name: The submodule prefix, e.g. ``"mapper"``. An empty string keeps
            every key unchanged.

    Returns:
        A new dict mapping the stripped parameter names to their values.
    """
    if "state_dict" in d:
        d = d["state_dict"]
    # Match on "<name>." rather than the bare "<name>". A plain prefix test
    # would also capture sibling modules such as "mappers.*" and then cut a
    # real character off their keys.
    prefix = f"{name}." if name else ""
    return {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}


class StyleCLIPMapper(nn.Module):
    """Wrapper module holding a StyleCLIP latent mapper.

    Reads ``opts.mapper_type`` to choose the mapper implementation and
    ``opts.checkpoint_path`` to optionally load pretrained mapper weights.
    """

    def __init__(self, opts):
        """Build the mapper from *opts* and load weights if a checkpoint is set.

        Args:
            opts: Options object. It must provide ``mapper_type`` and
                ``checkpoint_path``, and it is also passed to the chosen
                mapper's constructor.
        """
        super().__init__()
        self.opts = opts

        # Define architecture.
        self.mapper = self.set_mapper()
        # NOTE(review): not used within this class; presumably used by callers
        # to resize generated images to 256x256. Confirm before removing.
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))

        # Load weights if needed.
        self.load_weights()

    def set_mapper(self):
        """Instantiate the mapper named by ``self.opts.mapper_type``.

        Returns:
            A ``latent_mappers.SingleMapper`` or ``latent_mappers.LevelsMapper``
            built from ``self.opts``.

        Raises:
            ValueError: If ``mapper_type`` is neither ``"SingleMapper"`` nor
                ``"LevelsMapper"``. This subclasses ``Exception``, so existing
                ``except Exception`` handlers still catch it.
        """
        if self.opts.mapper_type == "SingleMapper":
            mapper = latent_mappers.SingleMapper(self.opts)
        elif self.opts.mapper_type == "LevelsMapper":
            mapper = latent_mappers.LevelsMapper(self.opts)
        else:
            raise ValueError("{} is not a valid mapper".format(self.opts.mapper_type))
        return mapper

    def load_weights(self):
        """Load mapper weights from ``self.opts.checkpoint_path`` if it is set.

        The checkpoint is loaded onto the CPU. The entries prefixed with
        ``"mapper."`` are then loaded into ``self.mapper`` with
        ``strict=True``, so missing or unexpected keys raise.
        """
        if self.opts.checkpoint_path is not None:
            # SECURITY NOTE: torch.load unpickles the file, so only load
            # checkpoints from trusted sources.
            ckpt = torch.load(self.opts.checkpoint_path, map_location="cpu")
            self.mapper.load_state_dict(get_keys(ckpt, "mapper"), strict=True)