|
|
import sys |
|
|
import os |
|
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
|
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from PIL import Image |
|
|
import pickle |
|
|
import lpips |
|
|
|
|
|
from facenet_pytorch import MTCNN |
|
|
import legacy |
|
|
|
|
|
|
|
|
|
|
|
# Shared Gradio theme: blue accent palette with the Poppins font.
_ACCENT_BLUE = "#4A90E2"
theme = gr.themes.Base(
    primary_hue=gr.themes.colors.blue,
    font=[gr.themes.GoogleFont("Poppins"), "sans-serif"],
).set(
    button_primary_background_fill=_ACCENT_BLUE,
    button_primary_background_fill_hover="#357ABD",
    slider_color=_ACCENT_BLUE,
    slider_color_dark=_ACCENT_BLUE,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("Loading all models...")
# All inference runs on CPU; nothing below assumes CUDA.
device = torch.device("cpu")

# NOTE(security): pickle.load executes arbitrary code embedded in the file.
# Only load checkpoints (ffhq.pkl) obtained from a trusted source.
with open("ffhq.pkl", "rb") as f:
    G = pickle.load(f)['G_ema'].to(device)
print("StyleGAN2 model loaded.")

# MTCNN face detector; keep_all=False keeps only the most confident face.
mtcnn = MTCNN(keep_all=False, device=device)
print("Face detector model loaded.")

# Latent-space direction for the gender attribute (numpy -> float32 tensor).
gender_direction = np.load("stylegan2directions/gender.npy")
gender_direction = torch.from_numpy(gender_direction).to(torch.float32).to(device)
print("All models and vectors loaded successfully.")

# Perceptual (LPIPS) distance used as the projection loss.
lpips_loss = lpips.LPIPS(net='vgg').to(device).eval()
print("LPIPS model loaded.")

# Pre-compute w_avg, the mean latent ("average face"), used as the starting
# point for projection. Fixed seed keeps the result reproducible across runs.
print("Pre-calculating the 'average face' (w_avg)... This may take a moment.")
w_avg_samples = 10000
z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
# no_grad: this is pure statistics -- do not build an autograd graph over
# 10k mapping-network passes (saves a large amount of memory and time).
with torch.no_grad():
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)
    w_avg = torch.mean(w_samples, 0, keepdim=True)
print("w_avg pre-calculation complete.")
|
|
|
|
|
|
|
|
def project_face(G, target, w_avg, *, num_steps=100, initial_learning_rate=0.1, progress=None):
    """Project a target image into the generator's W latent space.

    Runs Adam on a latent (and the per-layer noise buffers), minimizing the
    LPIPS perceptual distance between G's output and ``target``.

    Args:
        G: StyleGAN2 generator (exposes a ``.synthesis`` sub-network).
        target: image tensor in [-1, 1], shaped (1, C, H, W) to match G's
            output resolution (assumed G.img_resolution -- caller resizes).
        w_avg: starting latent, shape (1, num_ws, w_dim); optimization
            begins from this "average face".
        num_steps: number of Adam optimization steps.
        initial_learning_rate: Adam learning rate.
        progress: optional Gradio ``Progress`` object. When None, the loop
            runs without progress reporting, making the function usable
            outside a Gradio callback (previously this argument was
            required, so standalone use crashed).

    Returns:
        torch.Tensor: the optimized latent, detached, same shape as w_avg.
    """
    # Per-layer noise buffers are optimized jointly with the latent.
    noise_bufs = {name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name}
    w_opt = w_avg.clone().detach().requires_grad_(True)
    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)

    # Fall back to a plain range when no progress tracker is supplied.
    if progress is None:
        step_iter = range(num_steps)
    else:
        step_iter = progress.tqdm(range(num_steps), desc="Projecting Face")

    for step in step_iter:
        synth_images = G.synthesis(w_opt, noise_mode='const')
        # Perceptual distance between synthesized and target images.
        dist = lpips_loss(synth_images, target)
        loss = torch.sum(dist)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    return w_opt.detach()
|
|
|
|
|
def edit_uploaded_face(uploaded_image, strength, progress=gr.Progress(track_tqdm=True)):
    """Detect a face in the upload, project it into W space, and edit it.

    Args:
        uploaded_image: PIL image from the Gradio Image input (or None).
        strength: scalar multiplier for the gender latent direction
            (negative = feminine, positive = masculine).
        progress: Gradio progress tracker (injected by Gradio; the mutable
            default is Gradio's documented pattern, not a bug).

    Returns:
        Edited face as an HxWx3 uint8 numpy array.

    Raises:
        gr.Error: when no image is given or no face is detected.
    """
    if uploaded_image is None:
        raise gr.Error("No image uploaded. Please upload an image containing a face.")
    progress(0, desc="Detecting Face")
    input_image = uploaded_image.convert("RGB")

    boxes, _ = mtcnn.detect(input_image)
    if boxes is None:
        raise gr.Error("Could not detect a face. Please try a clearer picture.")

    # Expand the most confident detection by 20% per axis, clamped to the
    # image bounds, and crop with an explicit (left, top, right, bottom)
    # tuple as PIL expects (previously a mutated numpy array was passed).
    face_box = boxes[0]
    padding_x = (face_box[2] - face_box[0]) * 0.2
    padding_y = (face_box[3] - face_box[1]) * 0.2
    left = max(0, face_box[0] - padding_x)
    top = max(0, face_box[1] - padding_y)
    right = min(input_image.width, face_box[2] + padding_x)
    bottom = min(input_image.height, face_box[3] + padding_y)
    cropped_face = input_image.crop((left, top, right, bottom))

    # Resize to the generator's resolution and map to a [-1, 1] CHW tensor.
    target_pil = cropped_face.resize((G.img_resolution, G.img_resolution), Image.LANCZOS)
    target_np = np.array(target_pil, dtype=np.uint8)
    target_tensor = torch.tensor(target_np.transpose([2, 0, 1]), device=device)
    target_tensor = target_tensor.to(torch.float32) / 127.5 - 1
    target_tensor = target_tensor.unsqueeze(0)

    projected_w = project_face(G, target_tensor, w_avg=w_avg, num_steps=100, progress=progress)

    progress(1, desc="Synthesizing Final Image")
    # no_grad: inference only -- the original built an unused autograd
    # graph through the synthesis network here, wasting memory.
    with torch.no_grad():
        w_edited = projected_w + gender_direction * strength
        img_out = G.synthesis(w_edited, noise_mode='const')
    # [-1, 1] CHW float -> [0, 255] HWC uint8 for display.
    img_out = (img_out.clamp(-1, 1) + 1) * 127.5
    img_out = img_out.permute(0, 2, 3, 1)[0].cpu().numpy().astype(np.uint8)
    print("Processing complete.")
    return img_out
|
|
|
|
|
|
|
|
# Assemble the Gradio app (image + strength slider in, edited image out)
# and launch it with request queueing enabled.
demo = gr.Interface(
    fn=edit_uploaded_face,
    inputs=[
        gr.Image(label="Upload Image With Face", type="pil"),
        gr.Slider(-5, 5, step=0.1, value=0, label="Gender Strength (← Feminine | Masculine →)"),
    ],
    outputs=gr.Image(label="Edited Face"),
    theme=theme,
    title="Face Editor Backend",
    description="This engine detects a face in the uploaded image, then edits its gender expression.",
)
demo.queue().launch()