Spaces:
Runtime error
Runtime error
| from argparse import Namespace | |
| import time | |
| import torch | |
| import torchvision.transforms as transforms | |
| import dlib | |
| import numpy as np | |
| from PIL import Image | |
| from pixel2style2pixel.utils.common import tensor2im | |
| from pixel2style2pixel.models.psp import pSp | |
| from pixel2style2pixel.scripts.align_all_parallel import align_face | |
class InversionModel:
    """Invert a face photo into latent codes using the pSp network.

    The model aligns the face with a dlib landmark predictor, feeds the
    aligned 256x256 image through pSp and returns both the reconstructed
    image and the latent codes produced by the network.

    The network is moved to CUDA on construction, so a CUDA device is
    required.
    """

    def __init__(self, checkpoint_path: str, dlib_path: str) -> None:
        """Load the landmark predictor and the pSp network.

        Args:
            checkpoint_path: Path to a checkpoint loadable by ``torch.load``
                that contains an ``"opts"`` mapping of pSp options.
            dlib_path: Path to the dlib shape-predictor model file.
        """
        self.dlib_path = dlib_path
        self.dlib_predictor = dlib.shape_predictor(dlib_path)

        # NOTE: the attribute name is misspelled ("tranform"); it is kept
        # unchanged so that existing code reading this attribute still works.
        self.tranform_image = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only load checkpoints that come from a trusted source.
        ckpt = torch.load(checkpoint_path, map_location="cpu")

        # Work on a copy of the stored options and force the values this
        # wrapper relies on, regardless of what the checkpoint recorded.
        opts = dict(ckpt["opts"])
        opts.update(
            checkpoint_path=checkpoint_path,
            learn_in_w=False,
            output_size=1024,
        )
        self.opts = Namespace(**opts)

        self.net = pSp(self.opts)
        self.net.eval()
        self.net.cuda()
        print("Model successfully loaded!")

    def run_alignment(self, image_path: str):
        """Align the face found in ``image_path``.

        Args:
            image_path: Path to the input image file.

        Returns:
            Whatever ``align_face`` returns for the image; the rest of this
            class uses its ``size`` attribute and ``resize`` method
            (presumably a ``PIL.Image.Image`` -- confirm against
            ``align_face``).
        """
        aligned_image = align_face(filepath=image_path, predictor=self.dlib_predictor)
        print(f"Aligned image has shape: {aligned_image.size}")
        return aligned_image

    def inference(self, image_path: str):
        """Run alignment and pSp inversion on a single image.

        Args:
            image_path: Path to the input image file.

        Returns:
            A ``(res_image, latents)`` tuple. ``res_image`` is the result of
            ``tensor2im`` applied to the first reconstructed image.
            ``latents`` is a dict with the keys ``"w1"`` and
            ``"w1_initial"``; both hold the same values as NumPy arrays, but
            in separate buffers, so modifying one does not affect the other.
        """
        aligned = self.run_alignment(image_path)
        # The image is already 256x256 when it reaches the Resize step of
        # ``self.tranform_image``; both resizes are kept to preserve the
        # original preprocessing exactly.
        aligned = aligned.resize((256, 256))
        transformed_image = self.tranform_image(aligned)

        with torch.no_grad():
            tic = time.perf_counter()
            result_image, latents = self.net(
                transformed_image.unsqueeze(0).to("cuda").float(),
                return_latents=True,
                randomize_noise=False,
            )
            # CUDA kernels are launched asynchronously; wait for them to
            # finish so the reported duration covers the actual GPU work.
            torch.cuda.synchronize()
            toc = time.perf_counter()
        print(f"Inference took {toc - tic:.4f} seconds.")

        res_image = tensor2im(result_image[0])

        # Transfer the latents to host memory once, then duplicate on the CPU
        # instead of performing a second device-to-host copy.
        latents_np = latents.detach().cpu().numpy()
        return (
            res_image,
            {
                "w1": latents_np,
                "w1_initial": latents_np.copy(),
            },
        )