# Spaces:
# Runtime error
# Runtime error
import os
import tempfile
from cog import BasePredictor, Input, Path
import shutil
from argparse import Namespace
import time
import sys
import pprint
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
import dlib

sys.path.append(".")
sys.path.append("..")

from datasets import augmentations
from utils.common import tensor2im, log_input_image
from models.psp import pSp
from scripts.align_all_parallel import align_face
class Predictor(BasePredictor):
    """Cog predictor for pSp (pixel2style2pixel) face-translation models.

    Supports frontalization, sketch-to-face, face super-resolution and
    toonification; the task is selected per request via the ``model`` input.
    """

    def setup(self):
        """Load the dlib landmark model, per-model options, and transforms.

        Only each checkpoint's stored ``opts`` dict is kept here; the weights
        themselves are re-read by pSp in predict() via ``checkpoint_path``.
        """
        # dlib 68-point landmark predictor used for face alignment.
        self.predictor = dlib.shape_predictor(
            "shape_predictor_68_face_landmarks.dat"
        )
        model_paths = {
            "ffhq_frontalize": "pretrained_models/psp_ffhq_frontalization.pt",
            "celebs_sketch_to_face": "pretrained_models/psp_celebs_sketch_to_face.pt",
            "celebs_super_resolution": "pretrained_models/psp_celebs_super_resolution.pt",
            "toonify": "pretrained_models/psp_ffhq_toonify.pt",
        }
        self.opts = {}
        for key, path in model_paths.items():
            # Load on CPU just to read the stored options, then let the
            # full checkpoint be garbage-collected immediately.
            ckpt = torch.load(path, map_location="cpu")
            opts = ckpt["opts"]
            opts["checkpoint_path"] = path
            # Older checkpoints may predate these options; fill in defaults.
            opts.setdefault("learn_in_w", False)
            opts.setdefault("output_size", 1024)
            self.opts[key] = opts
        self.transforms = {
            key: self._build_transform(key) for key in model_paths
        }

    @staticmethod
    def _build_transform(key):
        """Return the input preprocessing pipeline for the given model key."""
        if key in ("ffhq_frontalize", "toonify"):
            return transforms.Compose(
                [
                    transforms.Resize((256, 256)),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                ]
            )
        if key == "celebs_sketch_to_face":
            # Sketches stay un-normalized single-channel inputs.
            return transforms.Compose(
                [transforms.Resize((256, 256)), transforms.ToTensor()]
            )
        if key == "celebs_super_resolution":
            # Downsample by 16x then resize back up to simulate a
            # low-resolution input for the super-resolution model.
            return transforms.Compose(
                [
                    transforms.Resize((256, 256)),
                    augmentations.BilinearResize(factors=[16]),
                    transforms.Resize((256, 256)),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                ]
            )
        raise ValueError(f"Unknown model key: {key}")

    def predict(
        self,
        image: Path = Input(description="input image"),
        model: str = Input(
            choices=[
                "celebs_sketch_to_face",
                "ffhq_frontalize",
                "celebs_super_resolution",
                "toonify",
            ],
            description="choose model type",
        ),
    ) -> Path:
        """Run the selected pSp model on ``image`` and return a PNG path."""
        opts = Namespace(**self.opts[model])
        pprint.pprint(opts)
        net = pSp(opts)
        net.eval()
        net.cuda()
        print("Model successfully loaded!")

        original_image = Image.open(str(image))
        # label_nc == 0 means the model consumes RGB images; otherwise it
        # expects a single-channel input (e.g. a sketch).
        if opts.label_nc == 0:
            original_image = original_image.convert("RGB")
        else:
            original_image = original_image.convert("L")
        # NOTE: the original code called original_image.resize(...) here and
        # discarded the result (PIL's resize returns a new image), making it
        # a no-op; the dead statement has been removed. The per-model
        # transform below performs the actual input resizing.
        output_size = (
            self.opts[model]["output_size"],
            self.opts[model]["output_size"],
        )

        # Sketch/segmentation inputs are already in canonical pose; photo
        # inputs are aligned with dlib landmarks first.
        if model not in ["celebs_sketch_to_face", "celebs_seg_to_face"]:
            input_image = self.run_alignment(str(image))
        else:
            input_image = original_image

        transformed_image = self.transforms[model](input_image)

        # Conditional synthesis tasks inject a random style into the fine
        # layers (8-17) so each run can yield a different plausible face.
        if model in ["celebs_sketch_to_face", "celebs_seg_to_face"]:
            latent_mask = [8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
        else:
            latent_mask = None

        with torch.no_grad():
            result_image = run_on_batch(
                transformed_image.unsqueeze(0), net, latent_mask
            )[0]

        input_vis_image = log_input_image(transformed_image, opts)
        output_image = tensor2im(result_image)
        if model == "celebs_super_resolution":
            # Show the low-res input and super-resolved output side by side.
            res = np.concatenate(
                [
                    np.array(input_vis_image.resize(output_size)),
                    np.array(output_image.resize(output_size)),
                ],
                axis=1,
            )
        else:
            res = np.array(output_image.resize(output_size))

        out_path = Path(tempfile.mkdtemp()) / "out.png"
        # res is already an ndarray; no extra np.array() copy needed.
        Image.fromarray(res).save(str(out_path))
        return out_path

    def run_alignment(self, image_path):
        """Align the face in ``image_path`` with dlib; returns a PIL image."""
        aligned_image = align_face(filepath=image_path, predictor=self.predictor)
        print("Aligned image has shape: {}".format(aligned_image.size))
        return aligned_image
def run_on_batch(inputs, net, latent_mask=None, device="cuda"):
    """Run a batch of preprocessed images through a pSp network.

    Args:
        inputs: (N, C, H, W) tensor of model inputs.
        net: pSp network. Called with ``randomize_noise=False`` in the plain
            path, or with latent-injection arguments when ``latent_mask`` is
            given.
        latent_mask: optional list of style-layer indices; when given, a
            random latent is sampled per image and injected into those layers.
        device: torch device string for inference. Defaults to "cuda",
            matching the original behavior; pass "cpu" for CPU inference.

    Returns:
        Tensor batch of network outputs (concatenated along dim 0 in the
        latent-injection path).
    """
    if latent_mask is None:
        return net(inputs.to(device).float(), randomize_noise=False)

    results = []
    for input_image in inputs:
        # Sample a random z code and map it through the network to obtain
        # a latent to inject into the masked style layers.
        vec_to_inject = np.random.randn(1, 512).astype("float32")
        _, latent_to_inject = net(
            torch.from_numpy(vec_to_inject).to(device),
            input_code=True,
            return_latents=True,
        )
        # Generate the output image with the injected style vector.
        res = net(
            input_image.unsqueeze(0).to(device).float(),
            latent_mask=latent_mask,
            inject_latent=latent_to_inject,
            resize=False,
        )
        results.append(res)
    return torch.cat(results, dim=0)