| import torch |
| from argparse import Namespace |
| from torchvision.transforms import transforms |
|
|
| from configs import paths_config |
| from models.e4e.psp import pSp |
| from scripts.latent_creators.base_latent_creator import BaseLatentCreator |
| from utils.log_utils import log_image_from_w |
|
|
|
|
class E4ELatentCreator(BaseLatentCreator):
    """Latent creator that inverts images with a pretrained e4e (pSp) encoder.

    Input images are resized to 256x256, converted to tensors and normalized
    per channel with mean 0.5 / std 0.5 before being encoded by
    ``run_projection``.
    """

    def __init__(self, use_wandb=False):
        """Build the e4e preprocessing pipeline and load the e4e encoder.

        Args:
            use_wandb: Forwarded to ``BaseLatentCreator``. When true,
                ``run_projection`` logs an image rendered from each latent.
        """
        self.e4e_inversion_pre_process = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            # (x - 0.5) / 0.5 per channel: maps ToTensor's [0, 1] output to [-1, 1].
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

        super().__init__('e4e', self.e4e_inversion_pre_process, use_wandb=use_wandb)

        e4e_model_path = paths_config.e4e
        ckpt = torch.load(e4e_model_path, map_location='cpu')
        opts = ckpt['opts']
        # Only the training options are needed from this load. Drop the rest of
        # the checkpoint (including its state dict) before building pSp: pSp is
        # given `checkpoint_path` below and presumably re-reads the weights
        # itself, so keeping this copy alive would double peak CPU memory.
        del ckpt
        opts['batch_size'] = 1
        opts['checkpoint_path'] = e4e_model_path
        opts = Namespace(**opts)
        self.e4e_inversion_net = pSp(opts)
        self.e4e_inversion_net.eval()
        self.e4e_inversion_net = self.e4e_inversion_net.cuda()

    def run_projection(self, fname, image):
        """Encode ``image`` with e4e and return its latent code.

        Args:
            fname: Name of the source image. Not used by this implementation;
                presumably part of the ``BaseLatentCreator`` projection
                interface.
            image: Preprocessed image batch. NOTE(review): assumed to already
                be on the CUDA device the encoder was moved to; confirm in
                ``BaseLatentCreator``.

        Returns:
            The latent tensor that pSp returns alongside the reconstructed
            image when called with ``return_latents=True``.
        """
        # NOTE(review): no torch.no_grad() here; confirm the caller disables
        # autograd, otherwise a graph is built for every encoded image.
        _, e4e_image_latent = self.e4e_inversion_net(image, randomize_noise=False, return_latents=True,
                                                     resize=False,
                                                     input_code=False)

        if self.use_wandb:
            log_image_from_w(e4e_image_latent, self.old_G, 'First e4e inversion')

        return e4e_image_latent
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build the creator with wandb logging off and generate latents.
    E4ELatentCreator(use_wandb=False).create_latents()
|
|