| | import torch |
| | from torch import nn |
| | from models.StyleCLIP.mapper import latent_mappers |
| | from models.StyleCLIP.models.stylegan2.model import Generator |
| |
|
| |
|
def get_keys(d, name):
    """Extract the sub-state-dict belonging to the submodule ``name``.

    Args:
        d: A checkpoint mapping. If it contains a ``'state_dict'`` entry, that
            nested mapping is searched instead of ``d`` itself.
        name: Submodule prefix, e.g. ``'mapper'``.

    Returns:
        A new dict holding only the entries whose key starts with
        ``name + '.'``, with that prefix stripped (``'mapper.w'`` -> ``'w'``).
    """
    if 'state_dict' in d:
        d = d['state_dict']
    # Match on the dotted prefix so sibling entries that merely share leading
    # characters (e.g. 'mapper_extra.*' when name == 'mapper') are excluded
    # instead of being mangled into bogus keys.
    prefix = name + '.'
    return {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}
| |
|
| |
|
class StyleCLIPMapper(nn.Module):
    """Latent mapper paired with an externally supplied generator.

    The mapper is built from ``opts.mapper_type`` and optionally initialised
    from ``opts.checkpoint_path``. The generator used for synthesis is NOT
    created here: it must be attached with :meth:`set_G` before
    :meth:`forward` is called, otherwise ``self.decoder`` does not exist.
    """

    def __init__(self, opts, run_id):
        """Build the mapper and load its weights.

        Args:
            opts: Options object. This class reads ``opts.mapper_type`` and
                ``opts.checkpoint_path``; the whole object is also passed to
                the mapper constructor.
            run_id: Stored as ``self.run_id``; not used by this class itself.
        """
        super(StyleCLIPMapper, self).__init__()
        self.opts = opts
        self.mapper = self.set_mapper()
        self.run_id = run_id
        # Pools generator output to a fixed 256x256 spatial size when
        # forward(resize=True).
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        self.load_weights()

    def set_mapper(self):
        """Instantiate the latent mapper selected by ``opts.mapper_type``.

        Returns:
            A ``SingleMapper`` or ``LevelsMapper`` built from ``self.opts``.

        Raises:
            ValueError: If ``opts.mapper_type`` names neither mapper. This
                subclasses ``Exception``, so existing ``except Exception``
                handlers still catch it.
        """
        if self.opts.mapper_type == 'SingleMapper':
            return latent_mappers.SingleMapper(self.opts)
        if self.opts.mapper_type == 'LevelsMapper':
            return latent_mappers.LevelsMapper(self.opts)
        raise ValueError('{} is not a valid mapper'.format(self.opts.mapper_type))

    def load_weights(self):
        """Load mapper weights from ``opts.checkpoint_path`` when it is set.

        The entries selected by ``get_keys(ckpt, 'mapper')`` are loaded with
        strict key matching. This is a no-op when ``checkpoint_path`` is None.
        """
        if self.opts.checkpoint_path is not None:
            print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path))
            # NOTE: torch.load reads a pickle; only load trusted checkpoints.
            ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
            self.mapper.load_state_dict(get_keys(ckpt, 'mapper'), strict=True)

    def set_G(self, new_G):
        """Attach the generator used by :meth:`forward`.

        ``new_G`` must expose ``synthesis(codes, noise_mode=...)``. Assigning
        an ``nn.Module`` registers it as the ``decoder`` submodule.
        """
        self.decoder = new_G

    def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
                inject_latent=None, return_latents=False, alpha=None):
        """Map (or pass through) latent codes and synthesise images.

        Args:
            x: Mapper input, or the latent codes themselves when
                ``input_code`` is True.
            resize: If True, pool the images with ``self.face_pool``.
            latent_mask: Optional iterable of indices along dim 1 of the codes.
                Each selected slice is replaced by ``inject_latent[:, i]``
                (blended with ``alpha`` when given), or zeroed when
                ``inject_latent`` is None.
            input_code: If True, ``x`` is used as the codes and the mapper is
                skipped.
            randomize_noise: Accepted for interface compatibility; unused,
                because synthesis always runs with ``noise_mode='const'``.
            inject_latent: Codes to copy from at the ``latent_mask`` indices.
            return_latents: If True, return ``(images, None)``.
            alpha: Optional blend weight:
                ``alpha * inject_latent + (1 - alpha) * codes``.

        Returns:
            ``images``, or ``(images, None)`` when ``return_latents`` is True.
            No latents are currently produced.
        """
        codes = x if input_code else self.mapper(x)

        if latent_mask is not None:
            # Masking writes in place. Work on a copy so a caller-supplied
            # tensor (input_code=True) is never mutated as a side effect.
            codes = codes.clone()
            for i in latent_mask:
                if inject_latent is None:
                    codes[:, i] = 0
                elif alpha is None:
                    codes[:, i] = inject_latent[:, i]
                else:
                    codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]

        images = self.decoder.synthesis(codes, noise_mode='const')
        if resize:
            images = self.face_pool(images)

        if return_latents:
            # Nothing is computed for the latents; None keeps the tuple shape
            # callers expect.
            return images, None
        return images
| |
|