import os import cv2 import PIL import torch import subprocess import numpy as np import torch.nn.functional as F import torchvision.transforms as transforms from PIL import Image from pathlib import Path from omegaconf import OmegaConf from utils.common_utils import tensor2im, tensor2im_no_tfm, MaskerCantFindFaceError from datasets.transforms import transforms_registry from runners.inference_runners import FSEInferenceRunner def extract_mask(image_path, save_dir_path, trash=0.995): try: from models.farl.farl import Masker except ImportError: print("Warning: facer module not available, skipping background mask extraction") # Return a dummy mask path save_dir_path = Path(save_dir_path) image_path = Path(image_path) mask_path = save_dir_path / (image_path.stem + "_mask.jpg") # Create a simple white mask (no masking) mask = Image.new("1", (1024, 1024), 1) mask.save(mask_path) return mask_path save_dir_path = Path(save_dir_path) image_path = Path(image_path) orig_img = Image.open(image_path).convert("RGB") transform = transforms.ToTensor() orig_img_tensor = transform(orig_img) orig_img_tensor = (orig_img_tensor.unsqueeze(0) * 255).long().cuda() with torch.inference_mode(): # try to find trashhlod for detecting face for detector_trash in [0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.01]: masker = Masker(trash=detector_trash) faces = masker.face_detector(orig_img_tensor) if len(faces['image_ids']) != 0: break if len(faces['image_ids']) == 0: raise MaskerCantFindFaceError("Masker's face detector can't find face in your image 😢") faces = masker.face_parser(orig_img_tensor, faces) background_mask = F.sigmoid(faces['seg']['logits'][:, 0]) background_mask = background_mask[0].unsqueeze(0) background_mask = (background_mask >= trash).cpu() mask_path = save_dir_path / (image_path.stem + "_mask.jpg") to_save = (background_mask[0] * 255).long().numpy() mask = Image.fromarray(to_save.astype(np.uint8)).convert("1") mask.save(mask_path) backfround_tens = orig_img_tensor[0].cpu() / 255 * background_mask.float().repeat(3, 1, 1) background = tensor2im_no_tfm(backfround_tens) back_path = save_dir_path / (image_path.stem + "_back.jpg") background.save(back_path) face_tens = orig_img_tensor[0].cpu() / 255 * (1 - background_mask.float()).repeat(3, 1, 1) face = tensor2im_no_tfm(face_tens) face_path = save_dir_path / (image_path.stem + "_face.jpg") face.save(face_path) return mask_path def run_alignment(image_path): import dlib from scripts.align_all_parallel import align_face print(f"Loading dlib shape predictor from: pretrained_models/shape_predictor_68_face_landmarks.dat") predictor = dlib.shape_predictor("pretrained_models/shape_predictor_68_face_landmarks.dat") print(f"Running face alignment on: {image_path}") aligned_image, unalign_dict = align_face(filepath=image_path, predictor=predictor) print(f"Face alignment completed successfully") return aligned_image, unalign_dict def unalign(edited_image, unalign_dict, orig_img_pth, unaligned_path): quad = unalign_dict["quad"] source_quad = [(0, 0), (1024, 0), (1024, 1024), (0, 1024)] dest_quad = np.array([quad[3], quad[0], quad[1], quad[2]]) M = cv2.getPerspectiveTransform(dest_quad.astype(np.float32), np.array(source_quad).astype(np.float32)) unaligned = edited_image.transpose(PIL.Image.FLIP_LEFT_RIGHT).transform(unalign_dict["pretrans_size"], PIL.Image.PERSPECTIVE, M.reshape(-1), PIL.Image.BILINEAR) mask = np.asarray(unaligned) > 0 mask = np.stack([mask[:,:,0] | mask[:,:,1] | mask[:,:,2]] * 3, axis=-1) if "blur1" in unalign_dict: unaligned -= unalign_dict["blur2"] unaligned -= unalign_dict["blur1"] pad = unalign_dict["pad"] unaligned = PIL.Image.fromarray(np.uint8(np.clip(np.rint(unaligned), 0, 255)), 'RGB').crop([pad[1], pad[0], unaligned.shape[1] - pad[3], unaligned.shape[0] - pad[2]]) mask = mask[pad[0]:mask.shape[0]-pad[1], pad[2]:mask.shape[1]-pad[3]] img_orig = PIL.Image.open(orig_img_pth).convert("RGB") if "crop" in unalign_dict: crop = unalign_dict["crop"] unaligned = np.pad(np.float32(unaligned), ((crop[1], img_orig.size[1] - crop[3]), (crop[0], img_orig.size[0] - crop[2]), (0, 0))) mask = np.pad(np.float32(mask), ((crop[1], img_orig.size[1] - crop[3]), (crop[0], img_orig.size[0] - crop[2]), (0, 0))) unaligned = PIL.Image.fromarray(np.uint8(np.clip(np.rint(unaligned), 0, 255)), 'RGB') if "shrink" in unalign_dict: unaligned = unaligned.resize(unalign_dict["shrink"]) mask = mask.resize(unalign_dict["shrink"]) unaligned = np.asarray(img_orig) * (1 - mask / mask.max()) + np.asarray(unaligned) * mask / mask.max() PIL.Image.fromarray(unaligned.astype('uint8'), 'RGB').save("edited.png") PIL.Image.fromarray(np.uint8(np.clip(np.rint((1 - mask) * 255), 0, 255)), 'RGB').save("mask.jpg") try: subprocess.run( ["fpie", "-s", orig_img_pth, "-m", "mask.jpg", "-t", "edited.png", "-o", unaligned_path, "-n", "5000", "-b", "taichi-gpu", "-g", "src"], check=True ) except FileNotFoundError: print("Warning: fpie command not available, skipping unalign step") # Just copy the edited image as the final result PIL.Image.open("edited.png").save(unaligned_path) class SimpleRunner: def __init__( self, editor_ckpt_pth: str, simple_config_pth: str = "configs/simple_inference.yaml" ): print(f"Initializing SimpleRunner with checkpoint: {editor_ckpt_pth}") try: config = OmegaConf.load(simple_config_pth) config.model.checkpoint_path = editor_ckpt_pth config.methods_args.fse_full = {} print("Configuration loaded successfully") self.inference_runner = FSEInferenceRunner(config) print("FSEInferenceRunner created") self.inference_runner.setup() print("Inference runner setup completed") self.inference_runner.method.eval() print("Model set to evaluation mode") self.inference_runner.method.decoder = self.inference_runner.method.decoder.float() print("Decoder converted to float precision") print("SimpleRunner initialization completed successfully") except Exception as e: print(f"Error during SimpleRunner initialization: {e}") raise def edit( self, orig_img_pth: str, editing_name: str, edited_power: float, save_pth: str, align: bool = False, use_mask: bool = False, mask_trashold=0.995, mask_path: str = None, save_e4e=False, save_inversion=False ): try: print(f"Starting edit: {editing_name} with power {edited_power}") print(f"Input image: {orig_img_pth}") print(f"Output path: {save_pth}") print(f"Face alignment: {align}") print(f"Use mask: {use_mask}") save_pth = Path(save_pth) save_pth_dir = save_pth.parents[0] save_pth_dir.mkdir(parents=True, exist_ok=True) aligned_image_pth = orig_img_pth if align: print(f"Running face alignment on {orig_img_pth}") try: aligned_image, unalign_dict = run_alignment(orig_img_pth) save_align_pth = save_pth.parents[0] / (save_pth.stem + "_aligned.jpg") print(f"Save aligned image to {save_align_pth}") aligned_image.convert('RGB').save(save_align_pth) aligned_image_pth = save_align_pth print(f"Face alignment completed. Using aligned image: {aligned_image_pth}") except Exception as e: print(f"Face alignment failed: {e}") print("Continuing without alignment...") align = False if use_mask and mask_path is None: print("Preparing mask") try: mask_path = extract_mask(aligned_image_pth, save_pth.parents[0], trash=mask_trashold) print("Mask extraction completed") except Exception as e: print(f"Mask extraction failed: {e}") print("Continuing without mask...") use_mask = False if use_mask and mask_path is not None: print(f"Using mask from {mask_path}") mask = Image.open(mask_path).convert("RGB") transform = transforms.ToTensor() mask = transform(mask).unsqueeze(0).to(self.inference_runner.device) else: mask = None print("Loading and preprocessing image") orig_img = Image.open(aligned_image_pth).convert("RGB") transform_dict = transforms_registry["face_1024"]().get_transforms() orig_img = transform_dict["test"](orig_img).unsqueeze(0) device = self.inference_runner.device print(f"Using device: {device}") print("Running image inversion") inv_images, inversion_results = self.inference_runner._run_on_batch(orig_img.to(device)) print("Image inversion completed") print(f"Running editing: {editing_name}") edited_image = self.inference_runner._run_editing_on_batch( method_res_batch=inversion_results, editing_name=editing_name, editing_degrees=[edited_power], mask=mask, return_e4e=save_e4e ) print("Editing completed") if save_inversion: save_inv_pth = save_pth.parents[0] / (save_pth.stem + "_inversion.jpg") inv_image = tensor2im(inv_images[0].cpu()) inv_image.save(save_inv_pth) if save_e4e: edited_image, e4e_inv, e4e_edit = edited_image save_e4e_inv_pth = save_pth.parents[0] / (save_pth.stem + "_e4e_inversion.jpg") e4e_inv_image = tensor2im(e4e_inv[0].cpu()) e4e_inv_image.save(save_e4e_inv_pth) save_e4e_edit_pth = save_pth.parents[0] / (save_pth.stem + "_e4e_edit.jpg") e4e_edit_image = tensor2im(e4e_edit[0].cpu()) e4e_edit_image.save(save_e4e_edit_pth) print("Converting and saving final result") edited_image = tensor2im(edited_image[0][0].cpu()) edited_image.save(save_pth) print(f"Final result saved to: {save_pth}") if align: try: unaligned_path = save_pth.parents[0] / (save_pth.stem + "_unaligned.jpg") unalign(edited_image, unalign_dict, orig_img_pth, unaligned_path) print("Unalign completed") except Exception as e: print(f"Unalign failed: {e}") print("Using aligned result as final output") print("Edit process completed successfully") return edited_image except Exception as e: print(f"Error during edit process: {e}") import traceback traceback.print_exc() raise def available_editings(self): edits_types = [] for field in dir(self.inference_runner.latent_editor): if "directions" in field.split("_"): edits_types.append(field) print("This code handles the following editing directions for following methods:") available_directions = {} for edit_type in edits_types: print(edit_type + ":") edit_type_directions = getattr(self.inference_runner.latent_editor, edit_type, None).keys() for direction in edit_type_directions: print("\t" + direction) print(GLOBAL_DIRECTIONS_DESC) GLOBAL_DIRECTIONS_DESC =""" You can alse use directions from text prompts via StyleClip Global Mapper (https://arxiv.org/abs/2103.17249). Such directions look as follows: "styleclip_global_{neutral prompt}_{target prompt}_{disentanglement}" where neutral prompt -- some neutral description of the original image (e.g. "a face") target prompt -- text that contains the desired edit (e.g. "a smilling face") disentanglement -- positive number, the more this attribute - the more related attributes will also be changed (e.g. for grey hair editing, wrinkle, skin colour and glasses may also be edited) Example: "styleclip_global_face with hair_face with black hair_0.18" More information about the purpose of directions and their approximate power range can be found in available_directions.txt. """