File size: 4,570 Bytes
ea56e06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d49fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
ea56e06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d49fbf
ea56e06
 
 
 
 
 
 
3d49fbf
ea56e06
3d49fbf
ea56e06
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import pickle
import lpips

from facenet_pytorch import MTCNN
import legacy

# --- THEME DEFINITION ---
# Brand styling mirrored from the website's CSS: Poppins typeface over a
# blue primary hue.
theme = gr.themes.Base(
    primary_hue=gr.themes.colors.blue,
    font=[gr.themes.GoogleFont("Poppins"), "sans-serif"],
)
# Pin the exact brand hex values on top of the generated blue palette:
# "#4A90E2" is the website's --primary-color / slider-thumb colour and
# "#357ABD" its --secondary-color (used for the primary-button hover).
theme = theme.set(
    button_primary_background_fill="#4A90E2",
    button_primary_background_fill_hover="#357ABD",
    slider_color="#4A90E2",
    slider_color_dark="#4A90E2",
)
# -------------------------


# --- Load All Models ---
print("Loading all models...")
device = torch.device("cpu")

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only ever point this at a checkpoint from a trusted source.
with open("ffhq.pkl", "rb") as f:
    G = pickle.load(f)['G_ema'].to(device)
print("StyleGAN2 model loaded.")

mtcnn = MTCNN(keep_all=False, device=device)
print("Face detector model loaded.")

gender_direction = np.load("stylegan2directions/gender.npy")
gender_direction = torch.from_numpy(gender_direction).to(device=device, dtype=torch.float32)

lpips_loss = lpips.LPIPS(net='vgg').to(device).eval()
print("LPIPS model loaded.")

print("Pre-calculating the 'average face' (w_avg)... This may take a moment.")
w_avg_samples = 10000
# Fixed seed so the starting latent for every projection is reproducible.
z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
# Pure inference: no_grad guarantees no autograd graph is recorded for the
# 10k-sample mapping pass regardless of the checkpoint's requires_grad flags.
# z is cast to float32 explicitly (np.random produces float64).
with torch.no_grad():
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device=device, dtype=torch.float32), None)
    w_avg = torch.mean(w_samples, 0, keepdim=True)
print("w_avg pre-calculation complete.")
# Reported last so the message is only printed once everything is ready.
print("All models and vectors loaded successfully.")
# -----------------------------------

def project_face(G, target, w_avg, *, num_steps=100, initial_learning_rate=0.1, progress=None):
    """Fit a latent so that ``G.synthesis`` reproduces *target*.

    Optimization starts from a copy of *w_avg*. It minimizes the summed LPIPS
    distance (module-level ``lpips_loss``) between the synthesized image and
    *target*, using Adam at a constant learning rate.

    Args:
        G: StyleGAN2 generator; only ``G.synthesis`` is called.
        target: Image tensor compared against the synthesis output. The
            caller in this file passes a [1, 3, H, W] float tensor in [-1, 1].
        w_avg: Starting latent. It is cloned, never modified.
        num_steps: Number of Adam steps.
        initial_learning_rate: Adam learning rate (held constant).
        progress: Optional ``gr.Progress``. When given, its ``tqdm`` wrapper
            reports per-step progress; when ``None`` the loop runs silently.

    Returns:
        The optimized latent, detached from the autograd graph.
    """
    w_opt = w_avg.clone().detach().requires_grad_(True)
    # Only w_opt is optimized. The previous version also handed G's
    # 'noise_const' buffers to Adam. Nothing in this file enables
    # requires_grad on them, so (assuming the checkpoint doesn't either)
    # they never received a gradient and Adam skipped them. Making them
    # trainable would also mutate the shared module-level G across requests.
    optimizer = torch.optim.Adam([w_opt], betas=(0.9, 0.999), lr=initial_learning_rate)

    steps = range(num_steps)
    if progress is not None:
        steps = progress.tqdm(steps, desc="Projecting Face")

    for _ in steps:
        synth_images = G.synthesis(w_opt, noise_mode='const')
        loss = torch.sum(lpips_loss(synth_images, target))
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    return w_opt.detach()

def _square_face_box(box, width, height, pad=0.2):
    """Return an integer (left, top, right, bottom) square crop around *box*.

    The square's side is the longer box side grown by *pad* on each edge,
    capped at the shorter image side. The square is shifted (not clipped) to
    stay inside the image, so the later square resize does not distort it.
    """
    x0, y0, x1, y1 = (float(v) for v in box)
    side = max(x1 - x0, y1 - y0) * (1 + 2 * pad)
    side = max(1, int(min(side, width, height)))
    cx = (x0 + x1) / 2
    cy = (y0 + y1) / 2
    # Clamp the top-left corner so the whole square lies within the image.
    left = int(min(max(cx - side / 2, 0), width - side))
    top = int(min(max(cy - side / 2, 0), height - side))
    return left, top, left + side, top + side


def edit_uploaded_face(uploaded_image, strength, progress=gr.Progress(track_tqdm=True)):
    """Detect a face, project it into the generator's latent space, and edit it.

    Args:
        uploaded_image: PIL image from the Gradio upload widget, or ``None``.
        strength: Multiplier applied to ``gender_direction`` before it is
            added to the projected latent (negative and positive values move
            in opposite directions).
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        The edited face as an (H, W, C) uint8 numpy array (C presumably 3/RGB).

    Raises:
        gr.Error: If no image is uploaded or MTCNN detects no face.
    """
    if uploaded_image is None:
        raise gr.Error("No image uploaded. Please upload an image containing a face.")

    progress(0, desc="Detecting Face")
    input_image = uploaded_image.convert("RGB")
    boxes, _ = mtcnn.detect(input_image)
    if boxes is None:
        raise gr.Error("Could not detect a face. Please try a clearer picture.")

    # Square crop: the target is resized to a square resolution below, so a
    # non-square crop would stretch the face before projection.
    crop_box = _square_face_box(boxes[0], input_image.width, input_image.height)
    cropped_face = input_image.crop(crop_box)
    target_pil = cropped_face.resize((G.img_resolution, G.img_resolution), Image.LANCZOS)

    # HWC uint8 -> NCHW float in [-1, 1], matching the synthesis output range.
    target_np = np.array(target_pil, dtype=np.uint8)
    target_tensor = torch.tensor(target_np.transpose([2, 0, 1]), device=device)
    target_tensor = (target_tensor.to(torch.float32) / 127.5 - 1).unsqueeze(0)

    projected_w = project_face(G, target_tensor, w_avg=w_avg, num_steps=100, progress=progress)

    progress(1, desc="Synthesizing Final Image")
    # Inference only: no autograd graph is needed for the edited render.
    with torch.no_grad():
        w_edited = projected_w + gender_direction * strength
        img_out = G.synthesis(w_edited, noise_mode='const')
    # Map [-1, 1] -> [0, 255] and round (astype alone would truncate).
    img_out = ((img_out.clamp(-1, 1) + 1) * 127.5).round()
    img_out = img_out.permute(0, 2, 3, 1)[0].cpu().numpy().astype(np.uint8)
    print("Processing complete.")
    return img_out

# Build the Gradio UI with the custom theme and start the server.
# The interface is bound to `demo`, which is the variable name Gradio's
# reload mode (`gradio <file>`) looks for by default; it also lets the
# Interface object be referenced after construction.
demo = gr.Interface(
    fn=edit_uploaded_face,
    inputs=[
        gr.Image(label="Upload Image With Face", type="pil"),
        gr.Slider(-5, 5, step=0.1, value=0, label="Gender Strength (← Feminine | Masculine →)"),
    ],
    outputs=gr.Image(label="Edited Face"),
    theme=theme,
    title="Face Editor Backend",
    description="This engine detects a face in the uploaded image, then edits its gender expression.",
)
# queue() routes requests through Gradio's queue, which the gr.Progress
# updates used by edit_uploaded_face rely on.
demo.queue().launch()