Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import PIL | |
| import torch | |
| import subprocess | |
| import numpy as np | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| from pathlib import Path | |
| from omegaconf import OmegaConf | |
| from utils.common_utils import tensor2im, tensor2im_no_tfm, MaskerCantFindFaceError | |
| from datasets.transforms import transforms_registry | |
| from runners.inference_runners import FSEInferenceRunner | |
def extract_mask(image_path, save_dir_path, trash=0.995):
    """Extract a background mask for a face image using the FARL face parser.

    Writes three files into ``save_dir_path``: ``<stem>_mask.jpg`` (binary
    background mask, also the return value), ``<stem>_back.jpg`` (background
    pixels only) and ``<stem>_face.jpg`` (face pixels only).

    Args:
        image_path: path to the (aligned) input image.
        save_dir_path: directory where the mask and debug images are written.
        trash: sigmoid threshold above which a pixel counts as background
            (kept as ``trash`` for backward compatibility; read "threshold").

    Returns:
        Path to the saved binary background-mask image.

    Raises:
        MaskerCantFindFaceError: if no face is detected at any detector
            threshold.
    """
    try:
        from models.farl.farl import Masker
    except ImportError:
        print("Warning: facer module not available, skipping background mask extraction")
        # Fallback: emit an all-white 1024x1024 dummy mask (mode "1", every
        # pixel set) so downstream code still finds a mask file to load.
        save_dir_path = Path(save_dir_path)
        image_path = Path(image_path)
        mask_path = save_dir_path / (image_path.stem + "_mask.jpg")
        mask = Image.new("1", (1024, 1024), 1)
        mask.save(mask_path)
        return mask_path

    save_dir_path = Path(save_dir_path)
    image_path = Path(image_path)

    orig_img = Image.open(image_path).convert("RGB")
    transform = transforms.ToTensor()
    # The masker expects a long tensor with values in [0, 255],
    # shape (1, C, H, W), on GPU.
    orig_img_tensor = (transform(orig_img).unsqueeze(0) * 255).long().cuda()

    with torch.inference_mode():
        # Progressively lower the detector confidence threshold until a face
        # is found; give up only after the most permissive setting (0.01).
        faces = {'image_ids': []}
        for detector_trash in [0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.01]:
            masker = Masker(trash=detector_trash)
            faces = masker.face_detector(orig_img_tensor)
            if len(faces['image_ids']) != 0:
                break
        if len(faces['image_ids']) == 0:
            raise MaskerCantFindFaceError("Masker's face detector can't find face in your image 😢")

        faces = masker.face_parser(orig_img_tensor, faces)

        # Channel 0 of the parser logits is the background class; threshold
        # its sigmoid to get a hard background mask.  torch.sigmoid replaces
        # the deprecated torch.nn.functional.sigmoid used previously.
        background_mask = torch.sigmoid(faces['seg']['logits'][:, 0])
        background_mask = background_mask[0].unsqueeze(0)
        background_mask = (background_mask >= trash).cpu()

    mask_path = save_dir_path / (image_path.stem + "_mask.jpg")
    to_save = (background_mask[0] * 255).long().numpy()
    mask = Image.fromarray(to_save.astype(np.uint8)).convert("1")
    mask.save(mask_path)

    # Debug image: background pixels only (face region blacked out).
    background_tens = orig_img_tensor[0].cpu() / 255 * background_mask.float().repeat(3, 1, 1)
    background = tensor2im_no_tfm(background_tens)
    back_path = save_dir_path / (image_path.stem + "_back.jpg")
    background.save(back_path)

    # Debug image: face pixels only (background blacked out).
    face_tens = orig_img_tensor[0].cpu() / 255 * (1 - background_mask.float()).repeat(3, 1, 1)
    face = tensor2im_no_tfm(face_tens)
    face_path = save_dir_path / (image_path.stem + "_face.jpg")
    face.save(face_path)

    return mask_path
def run_alignment(image_path):
    """Run dlib-based FFHQ-style face alignment on *image_path*.

    Returns a tuple ``(aligned_image, unalign_dict)``; the dict records the
    transform parameters needed to later undo the alignment (see ``unalign``).
    """
    # Imported lazily so the module loads even when dlib is absent.
    import dlib
    from scripts.align_all_parallel import align_face

    print("Loading dlib shape predictor from: pretrained_models/shape_predictor_68_face_landmarks.dat")
    landmarks_predictor = dlib.shape_predictor("pretrained_models/shape_predictor_68_face_landmarks.dat")

    print(f"Running face alignment on: {image_path}")
    aligned, unalign_dict = align_face(filepath=image_path, predictor=landmarks_predictor)
    print("Face alignment completed successfully")
    return aligned, unalign_dict
def unalign(edited_image, unalign_dict, orig_img_pth, unaligned_path):
    """Paste an edited, aligned 1024x1024 face back into the original photo.

    Inverts the crop/pad/blur/shrink steps recorded in ``unalign_dict`` by the
    alignment step, blends the warped edit into the original image, then runs
    the ``fpie`` Poisson-blending CLI for a seamless final composite.

    Side effects: writes intermediate "edited.png" and "mask.jpg" into the
    current working directory and the final result to *unaligned_path*.
    """
    quad = unalign_dict["quad"]
    # The aligned image is a 1024x1024 square; compute the perspective
    # transform mapping its corners back onto the (reordered) source quad.
    source_quad = [(0, 0), (1024, 0), (1024, 1024), (0, 1024)]
    dest_quad = np.array([quad[3], quad[0], quad[1], quad[2]])
    M = cv2.getPerspectiveTransform(dest_quad.astype(np.float32), np.array(source_quad).astype(np.float32))
    # The alignment mirrors the crop, so flip back before un-warping into the
    # pre-transform canvas size.
    unaligned = edited_image.transpose(PIL.Image.FLIP_LEFT_RIGHT).transform(unalign_dict["pretrans_size"], PIL.Image.PERSPECTIVE, M.reshape(-1), PIL.Image.BILINEAR)
    # Non-black pixels of the warped image mark where the edit landed.
    mask = np.asarray(unaligned) > 0
    mask = np.stack([mask[:,:,0] | mask[:,:,1] | mask[:,:,2]] * 3, axis=-1)
    if "blur1" in unalign_dict:
        # Undo the border blurs that alignment added over the padded region.
        # NOTE(review): the in-place subtraction and the `.shape` access below
        # assume `unaligned` behaves like an ndarray here even though it was
        # produced as a PIL image above — presumably the stored blur arrays /
        # an upstream conversion make this work; confirm against align_face.
        unaligned -= unalign_dict["blur2"]
        unaligned -= unalign_dict["blur1"]
        pad = unalign_dict["pad"]
        # Remove the reflection padding applied during alignment.
        unaligned = PIL.Image.fromarray(np.uint8(np.clip(np.rint(unaligned), 0, 255)), 'RGB').crop([pad[1], pad[0], unaligned.shape[1] - pad[3], unaligned.shape[0] - pad[2]])
        mask = mask[pad[0]:mask.shape[0]-pad[1], pad[2]:mask.shape[1]-pad[3]]
    img_orig = PIL.Image.open(orig_img_pth).convert("RGB")
    if "crop" in unalign_dict:
        # Re-pad the edit (and its mask) up to the original photo size around
        # the recorded crop box.
        crop = unalign_dict["crop"]
        unaligned = np.pad(np.float32(unaligned), ((crop[1], img_orig.size[1] - crop[3]), (crop[0], img_orig.size[0] - crop[2]), (0, 0)))
        mask = np.pad(np.float32(mask), ((crop[1], img_orig.size[1] - crop[3]), (crop[0], img_orig.size[0] - crop[2]), (0, 0)))
        unaligned = PIL.Image.fromarray(np.uint8(np.clip(np.rint(unaligned), 0, 255)), 'RGB')
    if "shrink" in unalign_dict:
        # Undo the downscaling applied to very large inputs during alignment.
        unaligned = unaligned.resize(unalign_dict["shrink"])
        # NOTE(review): `mask` is a NumPy array at this point, and
        # ndarray.resize has very different semantics from PIL's resize —
        # verify this branch is actually exercised / behaves as intended.
        mask = mask.resize(unalign_dict["shrink"])
    # Alpha-blend the warped edit over the original photo using the
    # (normalized) coverage mask.
    unaligned = np.asarray(img_orig) * (1 - mask / mask.max()) + np.asarray(unaligned) * mask / mask.max()
    PIL.Image.fromarray(unaligned.astype('uint8'), 'RGB').save("edited.png")
    # fpie expects the mask inverted: white where the ORIGINAL should be kept.
    PIL.Image.fromarray(np.uint8(np.clip(np.rint((1 - mask) * 255), 0, 255)), 'RGB').save("mask.jpg")
    try:
        # Poisson image editing (fpie CLI) for seamless blending of the pasted
        # region: 5000 iterations, taichi-gpu backend, gradient from source.
        subprocess.run(
            ["fpie", "-s", orig_img_pth, "-m", "mask.jpg", "-t", "edited.png", "-o", unaligned_path, "-n",
             "5000", "-b", "taichi-gpu", "-g", "src"],
            check=True
        )
    except FileNotFoundError:
        print("Warning: fpie command not available, skipping unalign step")
        # Just copy the edited image as the final result
        PIL.Image.open("edited.png").save(unaligned_path)
class SimpleRunner:
    """High-level convenience wrapper around ``FSEInferenceRunner``.

    Loads a trained editor checkpoint once at construction and exposes a
    single ``edit`` call that handles optional face alignment, optional
    background masking, GAN inversion, latent editing and saving of results.
    """

    def __init__(
        self,
        editor_ckpt_pth: str,
        simple_config_pth: str = "configs/simple_inference.yaml"
    ):
        # Build the underlying inference runner from the simple config,
        # pointing it at the supplied checkpoint.
        print(f"Initializing SimpleRunner with checkpoint: {editor_ckpt_pth}")
        try:
            config = OmegaConf.load(simple_config_pth)
            config.model.checkpoint_path = editor_ckpt_pth
            # Enable the "fse_full" method with default arguments.
            config.methods_args.fse_full = {}
            print("Configuration loaded successfully")
            self.inference_runner = FSEInferenceRunner(config)
            print("FSEInferenceRunner created")
            self.inference_runner.setup()
            print("Inference runner setup completed")
            self.inference_runner.method.eval()
            print("Model set to evaluation mode")
            # Decoder is forced to full (float32) precision for inference.
            self.inference_runner.method.decoder = self.inference_runner.method.decoder.float()
            print("Decoder converted to float precision")
            print("SimpleRunner initialization completed successfully")
        except Exception as e:
            # Log and re-raise so callers still see the original failure.
            print(f"Error during SimpleRunner initialization: {e}")
            raise

    def edit(
        self,
        orig_img_pth: str,
        editing_name: str,
        edited_power: float,
        save_pth: str,
        align: bool = False,
        use_mask: bool = False,
        mask_trashold=0.995,  # background-mask sigmoid threshold ("threshold"; name kept for compatibility)
        mask_path: str = None,
        save_e4e=False,
        save_inversion=False
    ):
        """Invert *orig_img_pth*, apply *editing_name* at *edited_power* and save to *save_pth*.

        Args:
            orig_img_pth: path to the input photo.
            editing_name: name of the editing direction (see ``available_editings``).
            edited_power: strength of the edit.
            save_pth: where to save the edited image; side outputs
                (aligned / inversion / e4e / unaligned images) are written
                next to it with suffixed names.
            align: run dlib face alignment first; on success the result is
                also mapped back onto the original photo ("unaligned").
            use_mask: restrict the edit using a background mask (extracted
                automatically unless *mask_path* is given).
            mask_trashold: threshold passed to ``extract_mask``.
            mask_path: optional precomputed mask image path.
            save_e4e: additionally save the e4e inversion and e4e edit.
            save_inversion: additionally save the FSE inversion.

        Returns:
            The edited image as a PIL image.
        """
        try:
            print(f"Starting edit: {editing_name} with power {edited_power}")
            print(f"Input image: {orig_img_pth}")
            print(f"Output path: {save_pth}")
            print(f"Face alignment: {align}")
            print(f"Use mask: {use_mask}")

            save_pth = Path(save_pth)
            save_pth_dir = save_pth.parents[0]
            save_pth_dir.mkdir(parents=True, exist_ok=True)

            # Default to the raw input; replaced below if alignment succeeds.
            aligned_image_pth = orig_img_pth
            if align:
                print(f"Running face alignment on {orig_img_pth}")
                try:
                    aligned_image, unalign_dict = run_alignment(orig_img_pth)
                    save_align_pth = save_pth.parents[0] / (save_pth.stem + "_aligned.jpg")
                    print(f"Save aligned image to {save_align_pth}")
                    aligned_image.convert('RGB').save(save_align_pth)
                    aligned_image_pth = save_align_pth
                    print(f"Face alignment completed. Using aligned image: {aligned_image_pth}")
                except Exception as e:
                    # Best-effort: fall back to the unaligned input.  Setting
                    # align=False also skips the unalign step later, so
                    # unalign_dict is never read when alignment failed.
                    print(f"Face alignment failed: {e}")
                    print("Continuing without alignment...")
                    align = False

            if use_mask and mask_path is None:
                print("Preparing mask")
                try:
                    mask_path = extract_mask(aligned_image_pth, save_pth.parents[0], trash=mask_trashold)
                    print("Mask extraction completed")
                except Exception as e:
                    # Best-effort: proceed with an unmasked edit.
                    print(f"Mask extraction failed: {e}")
                    print("Continuing without mask...")
                    use_mask = False

            if use_mask and mask_path is not None:
                print(f"Using mask from {mask_path}")
                mask = Image.open(mask_path).convert("RGB")
                transform = transforms.ToTensor()
                mask = transform(mask).unsqueeze(0).to(self.inference_runner.device)
            else:
                mask = None

            print("Loading and preprocessing image")
            orig_img = Image.open(aligned_image_pth).convert("RGB")
            # "face_1024" transform: the model operates on 1024x1024 faces.
            transform_dict = transforms_registry["face_1024"]().get_transforms()
            orig_img = transform_dict["test"](orig_img).unsqueeze(0)

            device = self.inference_runner.device
            print(f"Using device: {device}")

            print("Running image inversion")
            inv_images, inversion_results = self.inference_runner._run_on_batch(orig_img.to(device))
            print("Image inversion completed")

            print(f"Running editing: {editing_name}")
            edited_image = self.inference_runner._run_editing_on_batch(
                method_res_batch=inversion_results,
                editing_name=editing_name,
                editing_degrees=[edited_power],
                mask=mask,
                return_e4e=save_e4e
            )
            print("Editing completed")

            if save_inversion:
                save_inv_pth = save_pth.parents[0] / (save_pth.stem + "_inversion.jpg")
                inv_image = tensor2im(inv_images[0].cpu())
                inv_image.save(save_inv_pth)

            if save_e4e:
                # With return_e4e=True the editing call returns a 3-tuple.
                edited_image, e4e_inv, e4e_edit = edited_image
                save_e4e_inv_pth = save_pth.parents[0] / (save_pth.stem + "_e4e_inversion.jpg")
                e4e_inv_image = tensor2im(e4e_inv[0].cpu())
                e4e_inv_image.save(save_e4e_inv_pth)
                save_e4e_edit_pth = save_pth.parents[0] / (save_pth.stem + "_e4e_edit.jpg")
                e4e_edit_image = tensor2im(e4e_edit[0].cpu())
                e4e_edit_image.save(save_e4e_edit_pth)

            print("Converting and saving final result")
            # First editing degree, first image in the batch.
            edited_image = tensor2im(edited_image[0][0].cpu())
            edited_image.save(save_pth)
            print(f"Final result saved to: {save_pth}")

            if align:
                try:
                    # Map the edited aligned face back onto the original photo.
                    unaligned_path = save_pth.parents[0] / (save_pth.stem + "_unaligned.jpg")
                    unalign(edited_image, unalign_dict, orig_img_pth, unaligned_path)
                    print("Unalign completed")
                except Exception as e:
                    print(f"Unalign failed: {e}")
                    print("Using aligned result as final output")

            print("Edit process completed successfully")
            return edited_image
        except Exception as e:
            # Log with traceback, then re-raise unchanged.
            print(f"Error during edit process: {e}")
            import traceback
            traceback.print_exc()
            raise

    def available_editings(self):
        """Print every editing direction exposed by the latent editor."""
        # Attributes named like "*_directions_*" hold direction dictionaries.
        edits_types = []
        for field in dir(self.inference_runner.latent_editor):
            if "directions" in field.split("_"):
                edits_types.append(field)
        print("This code handles the following editing directions for following methods:")
        # NOTE(review): available_directions is assigned but never used.
        available_directions = {}
        for edit_type in edits_types:
            print(edit_type + ":")
            edit_type_directions = getattr(self.inference_runner.latent_editor, edit_type, None).keys()
            for direction in edit_type_directions:
                print("\t" + direction)
        # GLOBAL_DIRECTIONS_DESC is defined at module level after this class;
        # the name is resolved at call time, so this is safe.
        print(GLOBAL_DIRECTIONS_DESC)
# Help text printed by SimpleRunner.available_editings() describing the
# StyleClip global-direction naming scheme.  Typos fixed: "alse" -> "also",
# "smilling" -> "smiling".
GLOBAL_DIRECTIONS_DESC = """
You can also use directions from text prompts via StyleClip Global Mapper (https://arxiv.org/abs/2103.17249).
Such directions look as follows: "styleclip_global_{neutral prompt}_{target prompt}_{disentanglement}" where
    neutral prompt -- some neutral description of the original image (e.g. "a face")
    target prompt -- text that contains the desired edit (e.g. "a smiling face")
    disentanglement -- positive number, the more this attribute - the more related attributes will also be changed (e.g.
        for grey hair editing, wrinkle, skin colour and glasses may also be edited)

Example: "styleclip_global_face with hair_face with black hair_0.18"

More information about the purpose of directions and their approximate power range can be found in available_directions.txt.
"""