# face / app.py
# Diggz10's picture
# Update app.py
# 3d49fbf verified
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import pickle
import lpips
from facenet_pytorch import MTCNN
import legacy
# --- THEME DEFINITION ---
# Custom Gradio theme mirroring the website's CSS palette and Poppins font.
_PRIMARY_COLOR = "#4A90E2"    # site --primary-color / --slider-thumb color
_SECONDARY_COLOR = "#357ABD"  # site --secondary-color

theme = gr.themes.Base(
    primary_hue=gr.themes.colors.blue,  # base blue hue
    font=[gr.themes.GoogleFont("Poppins"), "sans-serif"],
)
# .set() mutates and returns the same theme object, so splitting the fluent
# chain leaves `theme` bound to an identically-configured instance.
theme.set(
    button_primary_background_fill=_PRIMARY_COLOR,
    button_primary_background_fill_hover=_SECONDARY_COLOR,
    slider_color=_PRIMARY_COLOR,
    slider_color_dark=_PRIMARY_COLOR,
)
# -------------------------
# --- Load All Models ---
# One-time startup: populates the module-level globals (G, mtcnn,
# gender_direction, lpips_loss, w_avg) used by the handlers below.
print("Loading all models...")
device = torch.device("cpu")  # CPU-only deployment; no CUDA assumed

# StyleGAN2 generator pickle: dict whose 'G_ema' entry is the EMA generator.
# SECURITY: pickle.load executes arbitrary code — ffhq.pkl must be trusted.
# NOTE(review): `legacy` is imported above but unused; legacy.load_network_pkl
# is the usual loader for StyleGAN pkls — confirm plain pickle.load suffices
# for this particular file.
with open("ffhq.pkl", "rb") as f:
    G = pickle.load(f)['G_ema'].to(device)
print("StyleGAN2 model loaded.")

# MTCNN face detector; keep_all=False so detect() reports a single best face.
mtcnn = MTCNN(keep_all=False, device=device)
print("Face detector model loaded.")

# Latent-space direction used for the gender edit, converted to a float32
# torch tensor on the target device.
gender_direction = np.load("stylegan2directions/gender.npy")
gender_direction = torch.from_numpy(gender_direction).to(torch.float32).to(device)
print("All models and vectors loaded successfully.")

# LPIPS perceptual distance (VGG backbone) — the projection loss.
lpips_loss = lpips.LPIPS(net='vgg').to(device).eval()
print("LPIPS model loaded.")

# Mean latent w_avg: average of 10k mapped z samples; used as the starting
# point for projection. Fixed seed 123 keeps it deterministic across restarts.
print("Pre-calculating the 'average face' (w_avg)... This may take a moment.")
w_avg_samples = 10000
z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)
w_avg = torch.mean(w_samples, 0, keepdim=True)  # shape (1, num_ws, w_dim)
print("w_avg pre-calculation complete.")
# -----------------------------------
def project_face(G, target, w_avg, *, num_steps=100, initial_learning_rate=0.1, progress=None):
    """Project a target image into the generator's W latent space.

    Starting from ``w_avg``, runs Adam for ``num_steps`` steps minimizing the
    LPIPS perceptual distance (module-level ``lpips_loss``) between the
    synthesized image and ``target``.

    Args:
        G: StyleGAN2 generator with a ``synthesis`` sub-module.
        target: Image tensor shaped like the generator output, values in [-1, 1].
        w_avg: Mean latent of shape (1, num_ws, w_dim); the starting point.
        num_steps: Number of optimization steps.
        initial_learning_rate: Adam learning rate (constant; no schedule).
        progress: Optional Gradio ``Progress`` object; when given, its
            ``tqdm`` wraps the loop for UI feedback. Previously this was a
            required keyword-only argument, which made the function unusable
            outside Gradio; ``None`` now falls back to a plain loop.

    Returns:
        The optimized latent, detached, same shape as ``w_avg``.
    """
    # Per-layer noise buffers, handed to the optimizer alongside the latent.
    # NOTE(review): these buffers are not flagged requires_grad here, so their
    # gradients stay None and Adam skips them — only w_opt actually updates.
    # Confirm whether optimizing noise (as in the reference projector) was
    # intended; enabling it would mutate G's global state.
    noise_bufs = {
        name: buf
        for name, buf in G.synthesis.named_buffers()
        if 'noise_const' in name
    }
    w_opt = w_avg.clone().detach().requires_grad_(True)
    optimizer = torch.optim.Adam(
        [w_opt] + list(noise_bufs.values()),
        betas=(0.9, 0.999),
        lr=initial_learning_rate,
    )

    steps = range(num_steps)
    if progress is not None:
        steps = progress.tqdm(steps, desc="Projecting Face")

    for _ in steps:
        synth_images = G.synthesis(w_opt, noise_mode='const')
        # LPIPS distance between current render and the target image.
        dist = lpips_loss(synth_images, target)
        loss = torch.sum(dist)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    return w_opt.detach()
def edit_uploaded_face(uploaded_image, strength, progress=gr.Progress(track_tqdm=True)):
    """Detect a face, project it into StyleGAN2 latent space, and edit gender.

    Args:
        uploaded_image: PIL image from the Gradio input (or None).
        strength: Slider value scaling the gender latent direction
            (negative = more feminine, positive = more masculine).
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        The edited face as an (H, W, 3) uint8 numpy array.

    Raises:
        gr.Error: If no image was uploaded or no face could be detected.
    """
    if uploaded_image is None:
        raise gr.Error("No image uploaded. Please upload an image containing a face.")

    progress(0, desc="Detecting Face")
    input_image = uploaded_image.convert("RGB")
    boxes, _ = mtcnn.detect(input_image)
    if boxes is None:
        raise gr.Error("Could not detect a face. Please try a clearer picture.")

    # Pad the first detected box by 20% per side, clamped to the image bounds.
    # Build an explicit tuple instead of mutating the detector's array and
    # passing it raw to PIL.Image.crop.
    x0, y0, x1, y1 = boxes[0]
    padding_x = (x1 - x0) * 0.2
    padding_y = (y1 - y0) * 0.2
    crop_box = (
        max(0, x0 - padding_x),
        max(0, y0 - padding_y),
        min(input_image.width, x1 + padding_x),
        min(input_image.height, y1 + padding_y),
    )
    cropped_face = input_image.crop(crop_box)

    # Resize to the generator's resolution; normalize to [-1, 1], NCHW layout.
    target_pil = cropped_face.resize((G.img_resolution, G.img_resolution), Image.LANCZOS)
    target_np = np.array(target_pil, dtype=np.uint8)
    target_tensor = torch.tensor(target_np.transpose([2, 0, 1]), device=device)
    target_tensor = target_tensor.to(torch.float32) / 127.5 - 1
    target_tensor = target_tensor.unsqueeze(0)

    projected_w = project_face(G, target_tensor, w_avg=w_avg, num_steps=100, progress=progress)

    progress(1, desc="Synthesizing Final Image")
    # Inference only: no_grad avoids building an autograd graph for the final
    # synthesis (the original tracked gradients here for no benefit).
    with torch.no_grad():
        w_edited = projected_w + gender_direction * strength
        img_out = G.synthesis(w_edited, noise_mode='const')

    # Denormalize [-1, 1] -> [0, 255] uint8 and convert NCHW -> HWC.
    img_out = (img_out.clamp(-1, 1) + 1) * 127.5
    img_out = img_out.permute(0, 2, 3, 1)[0].cpu().numpy().astype(np.uint8)
    print("Processing complete.")
    return img_out
# Build and launch the Gradio UI: an image upload plus a gender-strength
# slider wired to edit_uploaded_face, styled with the custom site theme.
demo = gr.Interface(
    fn=edit_uploaded_face,
    inputs=[
        gr.Image(label="Upload Image With Face", type="pil"),
        gr.Slider(-5, 5, step=0.1, value=0, label="Gender Strength (← Feminine | Masculine →)"),
    ],
    outputs=gr.Image(label="Edited Face"),
    theme=theme,  # apply the site-matched custom theme
    title="Face Editor Backend",
    description="This engine detects a face in the uploaded image, then edits its gender expression.",
)
demo.queue().launch()