Spaces:
Configuration error
Configuration error
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from typing import Iterable, Tuple, Union | |
| from pathlib import Path | |
| from torchvision import transforms | |
| import kornia | |
| from omegaconf import DictConfig | |
| from src.FaceDetector.face_detector import Detection | |
| from src.FaceAlign.face_align import align_face, inverse_transform_batch | |
| from src.PostProcess.utils import SoftErosion | |
| from src.model_loader import get_model | |
| from src.Misc.types import CheckpointType, FaceAlignmentType | |
| from src.Misc.utils import tensor2img | |
class SimSwap:
    """Face-swapping pipeline: detect, align, swap, mask and blend faces.

    An instance holds the identity (source) image, an optional "specific"
    image used to pick which face of the target frame gets replaced, the
    models obtained through ``get_model`` and the tunable parameters read
    from ``config``. Calling the instance with a target image returns that
    image with the swapped face(s) blended in.
    """

    def __init__(
        self,
        config: "DictConfig",
        id_image: Union[np.ndarray, None] = None,
        specific_image: Union[np.ndarray, None] = None,
    ):
        """Read parameters from ``config`` and load every model.

        Args:
            config: Configuration object. The attributes read here are
                ``device``, the ``*_weights`` paths, ``enhance_output`` and
                the parameters consumed by :meth:`set_parameters`.
            id_image: Image of the face whose identity is transferred.
            specific_image: Optional image of the face that should be
                replaced in the target; when given, only the best matching
                face of the target is swapped.

        Raises:
            ValueError: If one of the configured parameters is invalid.
        """
        self.id_image: Union[np.ndarray, None] = id_image
        # Identity embeddings are computed lazily on the first __call__.
        self.id_latent: Union[torch.Tensor, None] = None
        self.specific_id_image: Union[np.ndarray, None] = specific_image
        self.specific_latent: Union[torch.Tensor, None] = None

        self.use_mask: Union[bool, None] = True
        self.crop_size: Union[int, None] = None
        self.checkpoint_type: Union[CheckpointType, None] = None
        self.face_alignment_type: Union[FaceAlignmentType, None] = None

        self.smooth_mask_iter: Union[int, None] = None
        self.smooth_mask_kernel_size: Union[int, None] = None
        self.smooth_mask_threshold: Union[float, None] = None

        self.face_detector_threshold: Union[float, None] = None
        self.specific_latent_match_threshold: Union[float, None] = None

        # The device must be known before set_parameters(): the smooth-mask
        # setters build a module and move it to self.device.
        self.device = torch.device(config.device)

        # Parameters must be set before the models are loaded: the detector
        # needs its threshold, and the generator choice depends on crop_size
        # and checkpoint_type.
        self.set_parameters(config)

        # For BiSeNet and for official_224 SimSwap
        self.to_tensor_normalize = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        # For SimSwap models trained with the updated code
        self.to_tensor = transforms.ToTensor()

        self.face_detector = get_model(
            "face_detector",
            device=self.device,
            load_state_dice=False,
            model_path=Path(config.face_detector_weights),
            det_thresh=self.face_detector_threshold,
            det_size=(640, 640),
            mode="ffhq",
        )

        self.face_id_net = get_model(
            "arcface",
            device=self.device,
            load_state_dice=False,
            model_path=Path(config.face_id_weights),
        )

        self.bise_net = get_model(
            "parsing_model",
            device=self.device,
            load_state_dice=True,
            model_path=Path(config.parsing_model_weights),
            n_classes=19,
        )

        # A crop size of 512 selects the 512 generator (built with deep=True);
        # every other crop size uses the 224 generator.
        is_512 = self.crop_size == 512
        gen_model = "generator_512" if is_512 else "generator_224"
        self.simswap_net = get_model(
            gen_model,
            device=self.device,
            load_state_dice=True,
            model_path=Path(config.simswap_weights),
            input_nc=3,
            output_nc=3,
            latent_size=512,
            n_blocks=9,
            deep=is_512,
            use_last_act=self.checkpoint_type == CheckpointType.OFFICIAL_224,
        )

        self.blend = get_model(
            "blend_module",
            device=self.device,
            load_state_dice=False,
            model_path=Path(config.blend_module_weights),
        )

        # The enhancement model is only loaded when enhancement is enabled.
        self.enhance_output = config.enhance_output
        if self.enhance_output:
            self.gfpgan_net = get_model(
                "gfpgan",
                device=self.device,
                load_state_dice=True,
                model_path=Path(config.gfpgan_weights),
            )

    def set_parameters(self, config) -> None:
        """Validate and store every tunable parameter found in ``config``.

        Raises:
            ValueError: If any parameter fails its setter's validation.
        """
        self.set_crop_size(config.crop_size)
        self.set_checkpoint_type(config.checkpoint_type)
        self.set_face_alignment_type(config.face_alignment_type)
        self.set_face_detector_threshold(config.face_detector_threshold)
        self.set_specific_latent_match_threshold(config.specific_latent_match_threshold)
        self.set_smooth_mask_kernel_size(config.smooth_mask_kernel_size)
        self.set_smooth_mask_threshold(config.smooth_mask_threshold)
        self.set_smooth_mask_iter(config.smooth_mask_iter)

    def set_crop_size(self, crop_size: int) -> None:
        """Store the side length used for aligned face crops.

        Raises:
            ValueError: If ``crop_size`` is negative.
        """
        if crop_size < 0:
            raise ValueError("Invalid crop_size! Must be a positive value.")

        self.crop_size = crop_size

    def set_checkpoint_type(self, checkpoint_type: str) -> None:
        """Convert ``checkpoint_type`` to a ``CheckpointType`` and store it.

        Raises:
            ValueError: If the converted value is neither ``OFFICIAL_224``
                nor ``UNOFFICIAL``.
        """
        parsed_type = CheckpointType(checkpoint_type)
        if parsed_type not in (CheckpointType.OFFICIAL_224, CheckpointType.UNOFFICIAL):
            raise ValueError(
                "Invalid checkpoint_type! Must be one of the predefined values."
            )

        self.checkpoint_type = parsed_type

    def set_face_alignment_type(self, face_alignment_type: str) -> None:
        """Convert ``face_alignment_type`` to a ``FaceAlignmentType`` and store it.

        Raises:
            ValueError: If the converted value is neither ``FFHQ`` nor
                ``DEFAULT``.
        """
        parsed_type = FaceAlignmentType(face_alignment_type)
        if parsed_type not in (
            FaceAlignmentType.FFHQ,
            FaceAlignmentType.DEFAULT,
        ):
            raise ValueError(
                "Invalid face_alignment_type! Must be one of the predefined values."
            )

        self.face_alignment_type = parsed_type

    def set_face_detector_threshold(self, face_detector_threshold: float) -> None:
        """Store the face detector score threshold.

        The stored value is passed to the detector as ``det_thresh`` when the
        detector is created in ``__init__``.

        Raises:
            ValueError: If the threshold lies outside ``[0.0, 1.0]``.
        """
        if face_detector_threshold < 0.0 or face_detector_threshold > 1.0:
            raise ValueError(
                "Invalid face_detector_threshold! Must be a positive value in range [0.0...1.0]."
            )

        self.face_detector_threshold = face_detector_threshold

    def set_specific_latent_match_threshold(
        self, specific_latent_match_threshold: float
    ) -> None:
        """Store the maximum latent distance accepted as a "specific" match.

        Raises:
            ValueError: If the threshold is negative.
        """
        if specific_latent_match_threshold < 0.0:
            raise ValueError(
                "Invalid specific_latent_match_th! Must be a positive value."
            )

        self.specific_latent_match_threshold = specific_latent_match_threshold

    def re_initialize_soft_mask(self):
        """Rebuild the ``SoftErosion`` module from the current mask parameters.

        Called by each smooth-mask setter. While ``set_parameters`` is still
        running, parameters that have not been assigned yet are ``None``; the
        module built by the last setter call is the one used afterwards.
        """
        self.smooth_mask = SoftErosion(
            kernel_size=self.smooth_mask_kernel_size,
            threshold=self.smooth_mask_threshold,
            iterations=self.smooth_mask_iter,
        ).to(self.device)

    def set_smooth_mask_kernel_size(self, smooth_mask_kernel_size: int) -> None:
        """Store the mask smoothing kernel size and rebuild the soft mask.

        Even sizes are bumped to the next odd number before being stored.

        Raises:
            ValueError: If ``smooth_mask_kernel_size`` is negative.
        """
        if smooth_mask_kernel_size < 0:
            raise ValueError(
                "Invalid smooth_mask_kernel_size! Must be a positive value."
            )

        if smooth_mask_kernel_size % 2 == 0:
            smooth_mask_kernel_size += 1

        self.smooth_mask_kernel_size = smooth_mask_kernel_size
        self.re_initialize_soft_mask()

    def set_smooth_mask_threshold(self, smooth_mask_threshold: float) -> None:
        """Store the mask smoothing threshold and rebuild the soft mask.

        Raises:
            ValueError: If the threshold lies outside ``[0, 1]``.
        """
        if smooth_mask_threshold < 0 or smooth_mask_threshold > 1.0:
            raise ValueError(
                "Invalid smooth_mask_threshold! Must be within 0...1 range."
            )

        self.smooth_mask_threshold = smooth_mask_threshold
        self.re_initialize_soft_mask()

    def set_smooth_mask_iter(self, smooth_mask_iter: int) -> None:
        """Store the number of mask smoothing iterations and rebuild the soft mask.

        Raises:
            ValueError: If ``smooth_mask_iter`` is negative.
        """
        if smooth_mask_iter < 0:
            raise ValueError("Invalid smooth_mask_iter! Must be a positive value..")

        self.smooth_mask_iter = smooth_mask_iter
        self.re_initialize_soft_mask()

    def run_detect_align(
        self, image: np.ndarray, for_id: bool = False
    ) -> Tuple[
        Union[Iterable[np.ndarray], None],
        Union[Iterable[np.ndarray], None],
        np.ndarray,
    ]:
        """Detect faces in ``image`` and produce aligned crops.

        Args:
            image: Image handed to the face detector.
            for_id: When true, only the highest scoring detection is aligned
                and a missing face is treated as an error.

        Returns:
            ``(aligned_images, transforms, detection_score)`` as returned by
            ``align_face`` together with the detector scores. The first two
            items are ``None`` when no face was found and ``for_id`` is false.

        Raises:
            ValueError: If ``for_id`` is true and no face was detected.
        """
        detection: Detection = self.face_detector(image)

        if detection.bbox is None:
            if for_id:
                raise ValueError("Can't detect a face! Please change the ID image!")
            return None, None, detection.score

        kps = detection.key_points

        if for_id:
            # Keep only the best detection, restoring the leading axis so the
            # key points still look like a batch of one.
            max_score_ind = np.argmax(detection.score, axis=0)
            kps = detection.key_points[max_score_ind]
            kps = kps[None, ...]

        align_mode = (
            "ffhq" if self.face_alignment_type == FaceAlignmentType.FFHQ else "none"
        )
        align_imgs, align_transforms = align_face(
            image,
            kps,
            crop_size=self.crop_size,
            mode=align_mode,
        )

        return align_imgs, align_transforms, detection.score

    def __call__(self, att_image: np.ndarray) -> np.ndarray:
        """Swap the identity face into ``att_image``.

        Args:
            att_image: Target (attribute) image whose face(s) are replaced.

        Returns:
            The blended result converted by ``tensor2img``. ``att_image`` is
            returned unchanged when no face is detected, or when a specific
            face was requested and no detected face is close enough to it.

        Raises:
            ValueError: If no face can be detected in the identity image or
                in the specific image.
        """
        # Identity embeddings are computed once and cached on the instance.
        if self.id_latent is None:
            align_id_imgs, _, _ = self.run_detect_align(self.id_image, for_id=True)
            # normalize=True, because official SimSwap model trained with normalized id_latent
            self.id_latent = self.face_id_net(align_id_imgs, normalize=True)

        if self.specific_id_image is not None and self.specific_latent is None:
            align_specific_imgs, _, _ = self.run_detect_align(
                self.specific_id_image, for_id=True
            )
            self.specific_latent = self.face_id_net(
                align_specific_imgs, normalize=False
            )

        # for_id=False, because we want to get all faces
        align_att_imgs, att_transforms, att_detection_score = self.run_detect_align(
            att_image, for_id=False
        )

        if align_att_imgs is None and att_transforms is None:
            return att_image

        # Select specific crop from the target image
        if self.specific_latent is not None:
            att_latent: torch.Tensor = self.face_id_net(align_att_imgs, normalize=False)
            # Per-face mean squared distance to the specific latent.
            latent_dist = torch.mean(
                F.mse_loss(
                    att_latent,
                    self.specific_latent.repeat(att_latent.shape[0], 1),
                    reduction="none",
                ),
                dim=-1,
            )

            att_detection_score = torch.tensor(
                att_detection_score, device=latent_dist.device
            )

            # The candidate is chosen on the score-weighted distance, but the
            # threshold is applied to the raw distance. A plain int is used so
            # the index works on any Python sequence.
            min_index = int(torch.argmin(latent_dist * att_detection_score))
            min_value = latent_dist[min_index]

            if min_value < self.specific_latent_match_threshold:
                align_att_imgs = [align_att_imgs[min_index]]
                att_transforms = [att_transforms[min_index]]
            else:
                return att_image

        swapped_img: torch.Tensor = self.simswap_net(align_att_imgs, self.id_latent)

        if self.enhance_output:
            swapped_img = self.gfpgan_net.enhance(swapped_img, weight=0.5)

        # Put all crops/transformations into a batch
        parsing_batch: torch.Tensor = torch.stack(
            [self.to_tensor_normalize(x) for x in align_att_imgs], dim=0
        )
        parsing_batch = parsing_batch.to(self.device)

        att_transform_batch: torch.Tensor = torch.stack(
            [torch.tensor(x).float() for x in att_transforms], dim=0
        )
        att_transform_batch = att_transform_batch.to(self.device, non_blocking=True)

        align_att_img_batch: torch.Tensor = torch.stack(
            [self.to_tensor(x) for x in align_att_imgs], dim=0
        )
        align_att_img_batch = align_att_img_batch.to(self.device, non_blocking=True)

        # Get face masks for the attribute image
        face_mask, ignore_mask_ids = self.bise_net.get_mask(
            parsing_batch, self.crop_size
        )

        inv_att_transforms: torch.Tensor = inverse_transform_batch(att_transform_batch)

        soft_face_mask, _ = self.smooth_mask(face_mask)

        # Crops reported in ignore_mask_ids keep their original pixels.
        swapped_img[ignore_mask_ids, ...] = align_att_img_batch[ignore_mask_ids, ...]

        frame_size = (att_image.shape[0], att_image.shape[1])

        att_image = (
            self.to_tensor(att_image).to(self.device, non_blocking=True).unsqueeze(0)
        )

        # Warp the swapped crops and their masks back into the frame using
        # the inverted alignment transforms.
        target_image = kornia.geometry.transform.warp_affine(
            swapped_img,
            inv_att_transforms,
            frame_size,
            mode="bilinear",
            padding_mode="border",
            align_corners=True,
            fill_value=torch.zeros(3),
        )

        soft_face_mask = kornia.geometry.transform.warp_affine(
            soft_face_mask,
            inv_att_transforms,
            frame_size,
            mode="bilinear",
            padding_mode="zeros",
            align_corners=True,
            fill_value=torch.zeros(3),
        )

        result = self.blend(target_image, soft_face_mask, att_image)

        return tensor2img(result)