Smile_Changer / editings /styleclip /mapper /styleclip_mapper.py
LogicGoInfotechSpaces's picture
Bundle StyleFeatureEditor code packages in Space to fix ModuleNotFoundError
95b1715
import torch
from editings.styleclip.mapper import latent_mappers
from torch import nn
def get_keys(d, name):
if "state_dict" in d:
d = d["state_dict"]
d_filt = {k[len(name) + 1 :]: v for k, v in d.items() if k[: len(name)] == name}
return d_filt
class StyleCLIPMapper(nn.Module):
def __init__(self, opts):
super(StyleCLIPMapper, self).__init__()
self.opts = opts
# Define architecture
self.mapper = self.set_mapper()
self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
# Load weights if needed
self.load_weights()
def set_mapper(self):
if self.opts.mapper_type == "SingleMapper":
mapper = latent_mappers.SingleMapper(self.opts)
elif self.opts.mapper_type == "LevelsMapper":
mapper = latent_mappers.LevelsMapper(self.opts)
else:
raise Exception("{} is not a valid mapper".format(self.opts.mapper_type))
return mapper
def load_weights(self):
if self.opts.checkpoint_path is not None:
ckpt = torch.load(self.opts.checkpoint_path, map_location="cpu")
self.mapper.load_state_dict(get_keys(ckpt, "mapper"), strict=True)