Spaces:
Sleeping
Sleeping
File size: 13,193 Bytes
95b1715 600df15 95b1715 4a63fe7 95b1715 5fd5d22 95b1715 5fd5d22 95b1715 5fd5d22 95b1715 600df15 95b1715 600df15 95b1715 0154081 95b1715 0154081 95b1715 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 |
import os
import cv2
import PIL
import torch
import subprocess
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from pathlib import Path
from omegaconf import OmegaConf
from utils.common_utils import tensor2im, tensor2im_no_tfm, MaskerCantFindFaceError
from datasets.transforms import transforms_registry
from runners.inference_runners import FSEInferenceRunner
def extract_mask(image_path, save_dir_path, trash=0.995):
    """Save a background mask (plus background/face cut-outs) for a face image.

    Args:
        image_path: Path to the input image.
        save_dir_path: Directory that receives ``<stem>_mask.jpg``,
            ``<stem>_back.jpg`` and ``<stem>_face.jpg``.
        trash: Threshold applied to the sigmoid of channel 0 of the parser
            logits; pixels at or above it are marked in the saved mask.

    Returns:
        Path of the saved ``<stem>_mask.jpg`` file.

    Raises:
        MaskerCantFindFaceError: If no face is detected at any of the
            detector thresholds tried.
    """
    try:
        from models.farl.farl import Masker
    except ImportError:
        print("Warning: facer module not available, skipping background mask extraction")
        Masker = None

    out_dir = Path(save_dir_path)
    src = Path(image_path)
    mask_path = out_dir / f"{src.stem}_mask.jpg"

    if Masker is None:
        # No masker available: write a dummy all-white 1024x1024 mask instead.
        Image.new("1", (1024, 1024), 1).save(mask_path)
        return mask_path

    pixels = transforms.ToTensor()(Image.open(src).convert("RGB"))
    batch = (pixels.unsqueeze(0) * 255).long().cuda()

    detector_thresholds = (0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.01)
    with torch.inference_mode():
        # Relax the detector threshold step by step until a face is found.
        for threshold in detector_thresholds:
            masker = Masker(trash=threshold)
            faces = masker.face_detector(batch)
            if len(faces["image_ids"]) > 0:
                break
        else:
            raise MaskerCantFindFaceError("Masker's face detector can't find face in your image 😢")

        faces = masker.face_parser(batch, faces)
        # Channel 0 of the logits is used as the background score
        # (presumably the parser's background class; verify against Masker).
        background_prob = F.sigmoid(faces["seg"]["logits"][:, 0])
        background_mask = (background_prob[0].unsqueeze(0) >= trash).cpu()

    mask_pixels = (background_mask[0] * 255).long().numpy().astype(np.uint8)
    Image.fromarray(mask_pixels).convert("1").save(mask_path)

    rgb = batch[0].cpu() / 255

    keep_background = background_mask.float().repeat(3, 1, 1)
    tensor2im_no_tfm(rgb * keep_background).save(out_dir / f"{src.stem}_back.jpg")

    keep_face = (1 - background_mask.float()).repeat(3, 1, 1)
    tensor2im_no_tfm(rgb * keep_face).save(out_dir / f"{src.stem}_face.jpg")

    return mask_path
def run_alignment(image_path):
    """Align the face in ``image_path`` using dlib landmarks.

    Loads the 68-landmark shape predictor from ``pretrained_models/`` and
    delegates to ``align_face``.

    Returns:
        The ``(aligned_image, unalign_dict)`` pair produced by ``align_face``.
    """
    import dlib
    from scripts.align_all_parallel import align_face

    predictor_file = "pretrained_models/shape_predictor_68_face_landmarks.dat"
    print(f"Loading dlib shape predictor from: {predictor_file}")
    landmark_predictor = dlib.shape_predictor(predictor_file)

    print(f"Running face alignment on: {image_path}")
    aligned_image, unalign_dict = align_face(filepath=image_path, predictor=landmark_predictor)

    print("Face alignment completed successfully")
    return aligned_image, unalign_dict
def unalign(edited_image, unalign_dict, orig_img_pth, unaligned_path):
    """Warp an edited aligned face back into the original photo and blend it.

    The steps recorded in ``unalign_dict`` (perspective warp, optional
    blur/pad, optional crop, optional shrink) are undone in turn, then the
    result is composited over the original image. The composite and a blend
    mask are written to ``edited.png`` and ``mask.jpg`` in the current working
    directory and handed to the external ``fpie`` tool. If ``fpie`` is not
    installed, the plain composite is saved to ``unaligned_path`` instead.

    Args:
        edited_image: PIL image of the edited, aligned face.
        unalign_dict: Mapping with ``"quad"`` and ``"pretrans_size"`` and,
            optionally, ``"blur1"``/``"blur2"``/``"pad"``, ``"crop"`` and
            ``"shrink"`` (a size accepted by ``PIL.Image.resize``).
        orig_img_pth: Path to the original, unaligned image.
        unaligned_path: Destination path for the final result.

    Raises:
        subprocess.CalledProcessError: If ``fpie`` runs but exits non-zero.
    """
    quad = unalign_dict["quad"]
    source_quad = [(0, 0), (1024, 0), (1024, 1024), (0, 1024)]
    dest_quad = np.array([quad[3], quad[0], quad[1], quad[2]])
    M = cv2.getPerspectiveTransform(
        dest_quad.astype(np.float32), np.array(source_quad).astype(np.float32)
    )
    unaligned = edited_image.transpose(PIL.Image.FLIP_LEFT_RIGHT).transform(
        unalign_dict["pretrans_size"], PIL.Image.PERSPECTIVE, M.reshape(-1), PIL.Image.BILINEAR
    )

    # A pixel belongs to the pasted region if any channel is non-zero after the warp.
    mask = np.asarray(unaligned) > 0
    mask = np.stack([mask[:, :, 0] | mask[:, :, 1] | mask[:, :, 2]] * 3, axis=-1)

    if "blur1" in unalign_dict:
        # Explicit array conversion; the original relied on ndarray.__rsub__
        # being picked up for `PIL.Image -= ndarray`.
        unaligned = np.asarray(unaligned) - unalign_dict["blur2"]
        unaligned = unaligned - unalign_dict["blur1"]

        pad = unalign_dict["pad"]
        height, width = unaligned.shape[:2]
        # NOTE(review): the crop box below and the mask slice after it index
        # `pad` differently (image: left=pad[1], top=pad[0], right=pad[3],
        # bottom=pad[2]; mask: top=pad[0], bottom=pad[1], left=pad[2],
        # right=pad[3]). They only agree when pad[1] == pad[2]. Confirm the
        # pad layout produced by align_face and unify the two.
        unaligned = PIL.Image.fromarray(
            np.uint8(np.clip(np.rint(unaligned), 0, 255)), 'RGB'
        ).crop([pad[1], pad[0], width - pad[3], height - pad[2]])
        mask = mask[pad[0]:mask.shape[0] - pad[1], pad[2]:mask.shape[1] - pad[3]]

    img_orig = PIL.Image.open(orig_img_pth).convert("RGB")

    if "crop" in unalign_dict:
        crop = unalign_dict["crop"]
        pad_width = (
            (crop[1], img_orig.size[1] - crop[3]),
            (crop[0], img_orig.size[0] - crop[2]),
            (0, 0),
        )
        unaligned = np.pad(np.float32(unaligned), pad_width)
        mask = np.pad(np.float32(mask), pad_width)
        unaligned = PIL.Image.fromarray(np.uint8(np.clip(np.rint(unaligned), 0, 255)), 'RGB')

    if "shrink" in unalign_dict:
        target_size = unalign_dict["shrink"]
        unaligned = unaligned.resize(target_size)
        # ndarray.resize() works in place and returns None, so the previous
        # `mask = mask.resize(...)` destroyed the mask. Resize it as an image
        # instead (nearest-neighbour keeps it binary) to match `unaligned`.
        mask_img = PIL.Image.fromarray(np.uint8(np.asarray(mask) > 0) * 255, 'RGB')
        mask = np.float32(np.asarray(mask_img.resize(target_size, PIL.Image.NEAREST)) > 0)

    # Normalise the mask to [0, 1]; an all-zero mask keeps the original image
    # instead of producing NaNs from a division by zero.
    peak = mask.max()
    if peak > 0:
        weight = mask / peak
    else:
        weight = np.zeros(mask.shape, dtype=np.float32)
    blended = np.asarray(img_orig) * (1 - weight) + np.asarray(unaligned) * weight

    PIL.Image.fromarray(blended.astype('uint8'), 'RGB').save("edited.png")
    PIL.Image.fromarray(
        np.uint8(np.clip(np.rint((1 - mask) * 255), 0, 255)), 'RGB'
    ).save("mask.jpg")

    try:
        subprocess.run(
            ["fpie", "-s", orig_img_pth, "-m", "mask.jpg", "-t", "edited.png", "-o", unaligned_path, "-n",
             "5000", "-b", "taichi-gpu", "-g", "src"],
            check=True
        )
    except FileNotFoundError:
        print("Warning: fpie command not available, skipping unalign step")
        # Just copy the edited image as the final result
        PIL.Image.open("edited.png").save(unaligned_path)
class SimpleRunner:
    """Thin wrapper around ``FSEInferenceRunner`` for editing a single image."""

    def __init__(
        self,
        editor_ckpt_pth: str,
        simple_config_pth: str = "configs/simple_inference.yaml"
    ):
        """Load the config, build the inference runner and put it in eval mode.

        Args:
            editor_ckpt_pth: Checkpoint path written into
                ``config.model.checkpoint_path``.
            simple_config_pth: Path to the OmegaConf YAML config to load.

        Raises:
            Exception: Any error raised during setup is logged and re-raised.
        """
        print(f"Initializing SimpleRunner with checkpoint: {editor_ckpt_pth}")

        try:
            config = OmegaConf.load(simple_config_pth)
            config.model.checkpoint_path = editor_ckpt_pth
            config.methods_args.fse_full = {}
            print("Configuration loaded successfully")

            self.inference_runner = FSEInferenceRunner(config)
            print("FSEInferenceRunner created")

            self.inference_runner.setup()
            print("Inference runner setup completed")

            self.inference_runner.method.eval()
            print("Model set to evaluation mode")

            self.inference_runner.method.decoder = self.inference_runner.method.decoder.float()
            print("Decoder converted to float precision")

            print("SimpleRunner initialization completed successfully")
        except Exception as e:
            print(f"Error during SimpleRunner initialization: {e}")
            raise

    def edit(
        self,
        orig_img_pth: str,
        editing_name: str,
        edited_power: float,
        save_pth: str,
        align: bool = False,
        use_mask: bool = False,
        mask_trashold=0.995,
        mask_path: str = None,
        save_e4e=False,
        save_inversion=False
    ):
        """Invert an image, apply one edit and save the result.

        Alignment, mask extraction and un-alignment are best-effort: if one of
        them fails, the failure is printed and the edit continues without it.

        Args:
            orig_img_pth: Path to the input image.
            editing_name: Name of the editing direction to apply.
            edited_power: Edit strength, passed as the only editing degree.
            save_pth: Output path for the edited image; auxiliary files are
                written next to it using its stem.
            align: Run face alignment first and try to un-align afterwards.
            use_mask: Pass a mask to the editing step.
            mask_trashold: Threshold forwarded to ``extract_mask``.
            mask_path: Existing mask file; when ``None`` and ``use_mask`` is
                set, a mask is extracted from the (aligned) image.
            save_e4e: Also request and save the e4e inversion and e4e edit.
            save_inversion: Also save the inversion image.

        Returns:
            The edited (aligned) image as returned by ``tensor2im``.

        Raises:
            Exception: Any error outside the best-effort steps is logged with
                a traceback and re-raised.
        """
        try:
            print(f"Starting edit: {editing_name} with power {edited_power}")
            print(f"Input image: {orig_img_pth}")
            print(f"Output path: {save_pth}")
            print(f"Face alignment: {align}")
            print(f"Use mask: {use_mask}")

            save_pth = Path(save_pth)
            save_pth_dir = save_pth.parents[0]
            save_pth_dir.mkdir(parents=True, exist_ok=True)

            aligned_image_pth = orig_img_pth
            # Only filled when alignment succeeds; `align` is reset to False otherwise.
            unalign_dict = None
            if align:
                print(f"Running face alignment on {orig_img_pth}")
                try:
                    aligned_image, unalign_dict = run_alignment(orig_img_pth)
                    save_align_pth = save_pth_dir / (save_pth.stem + "_aligned.jpg")
                    print(f"Save aligned image to {save_align_pth}")
                    aligned_image.convert('RGB').save(save_align_pth)
                    aligned_image_pth = save_align_pth
                    print(f"Face alignment completed. Using aligned image: {aligned_image_pth}")
                except Exception as e:
                    print(f"Face alignment failed: {e}")
                    print("Continuing without alignment...")
                    align = False

            if use_mask and mask_path is None:
                print("Preparing mask")
                try:
                    mask_path = extract_mask(aligned_image_pth, save_pth_dir, trash=mask_trashold)
                    print("Mask extraction completed")
                except Exception as e:
                    print(f"Mask extraction failed: {e}")
                    print("Continuing without mask...")
                    use_mask = False

            if use_mask and mask_path is not None:
                print(f"Using mask from {mask_path}")
                mask = Image.open(mask_path).convert("RGB")
                transform = transforms.ToTensor()
                mask = transform(mask).unsqueeze(0).to(self.inference_runner.device)
            else:
                mask = None

            print("Loading and preprocessing image")
            orig_img = Image.open(aligned_image_pth).convert("RGB")
            transform_dict = transforms_registry["face_1024"]().get_transforms()
            orig_img = transform_dict["test"](orig_img).unsqueeze(0)

            device = self.inference_runner.device
            print(f"Using device: {device}")

            print("Running image inversion")
            inv_images, inversion_results = self.inference_runner._run_on_batch(orig_img.to(device))
            print("Image inversion completed")

            print(f"Running editing: {editing_name}")
            edited_image = self.inference_runner._run_editing_on_batch(
                method_res_batch=inversion_results,
                editing_name=editing_name,
                editing_degrees=[edited_power],
                mask=mask,
                return_e4e=save_e4e
            )
            print("Editing completed")

            if save_inversion:
                save_inv_pth = save_pth_dir / (save_pth.stem + "_inversion.jpg")
                inv_image = tensor2im(inv_images[0].cpu())
                inv_image.save(save_inv_pth)

            if save_e4e:
                # With return_e4e set, the runner's result is unpacked into three parts.
                edited_image, e4e_inv, e4e_edit = edited_image

                save_e4e_inv_pth = save_pth_dir / (save_pth.stem + "_e4e_inversion.jpg")
                e4e_inv_image = tensor2im(e4e_inv[0].cpu())
                e4e_inv_image.save(save_e4e_inv_pth)

                save_e4e_edit_pth = save_pth_dir / (save_pth.stem + "_e4e_edit.jpg")
                e4e_edit_image = tensor2im(e4e_edit[0].cpu())
                e4e_edit_image.save(save_e4e_edit_pth)

            print("Converting and saving final result")
            edited_image = tensor2im(edited_image[0][0].cpu())
            edited_image.save(save_pth)
            print(f"Final result saved to: {save_pth}")

            if align:
                try:
                    unaligned_path = save_pth_dir / (save_pth.stem + "_unaligned.jpg")
                    unalign(edited_image, unalign_dict, orig_img_pth, unaligned_path)
                    print("Unalign completed")
                except Exception as e:
                    print(f"Unalign failed: {e}")
                    print("Using aligned result as final output")

            print("Edit process completed successfully")
            return edited_image

        except Exception as e:
            print(f"Error during edit process: {e}")
            import traceback
            traceback.print_exc()
            raise

    def available_editings(self):
        """Print and return the editing directions exposed by the latent editor.

        Every attribute of ``latent_editor`` whose underscore-separated name
        contains the word ``directions`` is treated as a mapping keyed by
        direction name.

        Returns:
            Dict mapping each such attribute name to the list of its keys.
        """
        latent_editor = self.inference_runner.latent_editor
        edits_types = [
            field for field in dir(latent_editor) if "directions" in field.split("_")
        ]

        print("This code handles the following editing directions for following methods:")
        available_directions = {}
        for edit_type in edits_types:
            print(edit_type + ":")
            directions = list(getattr(latent_editor, edit_type).keys())
            available_directions[edit_type] = directions
            for direction in directions:
                print("\t" + direction)

        print(GLOBAL_DIRECTIONS_DESC)
        return available_directions
# Help text printed at the end of SimpleRunner.available_editings(); it
# describes the naming scheme for text-prompt (StyleClip Global Mapper) directions.
GLOBAL_DIRECTIONS_DESC ="""
You can alse use directions from text prompts via StyleClip Global Mapper (https://arxiv.org/abs/2103.17249).
Such directions look as follows: "styleclip_global_{neutral prompt}_{target prompt}_{disentanglement}" where
neutral prompt -- some neutral description of the original image (e.g. "a face")
target prompt -- text that contains the desired edit (e.g. "a smilling face")
disentanglement -- positive number, the more this attribute - the more related attributes will also be changed (e.g.
for grey hair editing, wrinkle, skin colour and glasses may also be edited)
Example: "styleclip_global_face with hair_face with black hair_0.18"
More information about the purpose of directions and their approximate power range can be found in available_directions.txt.
"""
|