from __future__ import annotations

import argparse
import os
import pathlib
import subprocess
import sys
from typing import Callable

import dlib
import huggingface_hub
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torchvision.transforms as T

# On Hugging Face Spaces without a GPU, apply the bundled patches to the
# vendored encoder4editing and HairCLIP checkouts so they run on CPU.
if os.getenv("SYSTEM") == "spaces" and not torch.cuda.is_available():
    with open("patch.e4e") as f:
        subprocess.run("patch -p1".split(), cwd="encoder4editing", stdin=f)
    with open("patch.hairclip") as f:
        subprocess.run("patch -p1".split(), cwd="HairCLIP", stdin=f)

app_dir = pathlib.Path(__file__).parent

# Make the vendored repositories importable before importing from them.
e4e_dir = app_dir / "encoder4editing"
sys.path.insert(0, e4e_dir.as_posix())

from models.psp import pSp
from utils.alignment import align_face

hairclip_dir = app_dir / "HairCLIP"
mapper_dir = hairclip_dir / "mapper"
sys.path.insert(0, hairclip_dir.as_posix())
sys.path.insert(0, mapper_dir.as_posix())

from mapper.datasets.latents_dataset_inference import LatentsDatasetInference
from mapper.hairclip_mapper import HairCLIPMapper


class Model:
    def __init__(self):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.landmark_model = self._create_dlib_landmark_model()
        self.e4e = self._load_e4e()
        self.hairclip = self._load_hairclip()
        self.transform = self._create_transform()

    @staticmethod
    def _create_dlib_landmark_model():
        path = huggingface_hub.hf_hub_download(
            "public-data/dlib_face_landmark_model", "shape_predictor_68_face_landmarks.dat"
        )
        return dlib.shape_predictor(path)

    def _load_e4e(self) -> nn.Module:
        # Load the e4e (encoder4editing) inversion encoder and rebuild its
        # training-time opts from the checkpoint.
        ckpt_path = huggingface_hub.hf_hub_download("public-data/e4e", "e4e_ffhq_encode.pt")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        opts = ckpt["opts"]
        opts["device"] = self.device.type
        opts["checkpoint_path"] = ckpt_path
        opts = argparse.Namespace(**opts)
        model = pSp(opts)
        model.to(self.device)
        model.eval()
        return model

    def _load_hairclip(self) -> nn.Module:
        ckpt_path = huggingface_hub.hf_hub_download("public-data/HairCLIP", "hairclip.pt")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        opts = ckpt["opts"]
        opts["device"] = self.device.type
        opts["checkpoint_path"] = ckpt_path
        # Text-driven editing of both hairstyle and color. The hairstyle
        # prompts come from the bundled list file; the color prompt here is a
        # placeholder that generate() overwrites per request.
        opts["editing_type"] = "both"
        opts["input_type"] = "text"
        opts["hairstyle_description"] = "HairCLIP/mapper/hairstyle_list.txt"
        opts["color_description"] = "red"
        opts = argparse.Namespace(**opts)
        model = HairCLIPMapper(opts)
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _create_transform() -> Callable:
        # Preprocessing expected by the e4e encoder: 256x256 center crop,
        # values normalized to [-1, 1].
        transform = T.Compose(
            [
                T.Resize(256),
                T.CenterCrop(256),
                T.ToTensor(),
                T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        return transform

    def detect_and_align_face(self, image: str) -> PIL.Image.Image:
        image = align_face(filepath=image, predictor=self.landmark_model)
        return image

    @staticmethod
    def denormalize(tensor: torch.Tensor) -> torch.Tensor:
        # Map generator output from [-1, 1] to uint8 [0, 255].
        return torch.clamp((tensor + 1) / 2 * 255, 0, 255).to(torch.uint8)

    def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
        tensor = self.denormalize(tensor)
        return tensor.cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC

    @torch.inference_mode()
    def reconstruct_face(self, image: PIL.Image.Image) -> tuple[np.ndarray, torch.Tensor]:
        # Invert the aligned face with e4e, returning the reconstruction and
        # its latent code.
        input_data = self.transform(image).unsqueeze(0).to(self.device)
        reconstructed_images, latents = self.e4e(input_data, randomize_noise=False, return_latents=True)
        reconstructed = torch.clamp(reconstructed_images[0].detach(), -1, 1)
        reconstructed = self.postprocess(reconstructed)
        return reconstructed, latents[0]

    @torch.inference_mode()
    def generate(
        self, editing_type: str, hairstyle_index: int, color_description: str, latent: torch.Tensor
    ) -> np.ndarray:
        opts = self.hairclip.opts
        opts.editing_type = editing_type
        opts.color_description = color_description

        # Color-only edits ignore the hairstyle prompt, so pin the index to
        # the first entry.
        if editing_type == "color":
            hairstyle_index = 0

        device = torch.device(opts.device)

        dataset = LatentsDatasetInference(latents=latent.unsqueeze(0).cpu(), opts=opts)
        if len(dataset) == 0:
            raise ValueError("Dataset could not be created or is empty")
        w, hairstyle_text_inputs_list, color_text_inputs_list = dataset[0][:3]

        w = w.unsqueeze(0).to(device)
        hairstyle_text_inputs = hairstyle_text_inputs_list[hairstyle_index].unsqueeze(0).to(device)
        color_text_inputs = color_text_inputs_list[0].unsqueeze(0).to(device)

        # Reference-image inputs are unused in text-driven editing; pass
        # dummy tensors.
        hairstyle_tensor_hairmasked = torch.Tensor([0]).unsqueeze(0).to(device)
        color_tensor_hairmasked = torch.Tensor([0]).unsqueeze(0).to(device)

        # Step the latent by a small multiple of the mapper's predicted
        # offset, then decode it with the StyleGAN generator.
        w_hat = w + 0.1 * self.hairclip.mapper(
            w,
            hairstyle_text_inputs,
            color_text_inputs,
            hairstyle_tensor_hairmasked,
            color_tensor_hairmasked,
        )
        x_hat, _ = self.hairclip.decoder(
            [w_hat],
            input_is_latent=True,
            return_latents=True,
            randomize_noise=False,
            truncation=1,
        )
        res = torch.clamp(x_hat[0].detach(), -1, 1)
        return self.postprocess(res)
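# --- Usage sketch (hypothetical, not part of the app) ------------------------
# A minimal end-to-end example of the pipeline above: align a face photo,
# invert it with e4e, then edit hairstyle and color with HairCLIP. The paths
# "input.jpg" and "edited.jpg" are placeholders, and hairstyle_index / the
# color prompt are illustrative values, not defaults from the source.
if __name__ == "__main__":
    model = Model()
    aligned = model.detect_and_align_face("input.jpg")  # dlib landmark alignment
    _, latent = model.reconstruct_face(aligned)  # e4e inversion to a latent code
    edited = model.generate(
        editing_type="both",
        hairstyle_index=0,  # index into HairCLIP/mapper/hairstyle_list.txt
        color_description="red",
        latent=latent,
    )
    PIL.Image.fromarray(edited).save("edited.jpg")  # HWC uint8 array -> image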