import os
import pickle
import sys

# Make this script's directory importable (e.g. for the local `legacy` module
# imported below) before any third-party/local imports run.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import lpips
from facenet_pytorch import MTCNN

import legacy

# --- THEME DEFINITION ---
# This theme uses the colors and fonts from your website's CSS.
theme = gr.themes.Base(
    primary_hue=gr.themes.colors.blue,  # A base blue hue
    font=[gr.themes.GoogleFont("Poppins"), "sans-serif"],
).set(
    button_primary_background_fill="#4A90E2",  # Your --primary-color
    button_primary_background_fill_hover="#357ABD",  # Your --secondary-color
    slider_color="#4A90E2",  # Your --slider-thumb color
    slider_color_dark="#4A90E2",
)
# -------------------------

# --- Load All Models ---
print("Loading all models...")
device = torch.device("cpu")

# NOTE: pickle.load can execute arbitrary code; only load checkpoint files you trust.
with open("ffhq.pkl", "rb") as f:
    G = pickle.load(f)['G_ema'].to(device)
# Freeze the generator: projection only optimizes the latent. If the
# checkpoint's weights still had requires_grad=True, every request would
# otherwise compute (and accumulate, since only w_opt is zeroed) gradients for
# all of G's weights, and .numpy() on the final image would fail.
G = G.eval().requires_grad_(False)
print("StyleGAN2 model loaded.")

mtcnn = MTCNN(keep_all=False, device=device)
print("Face detector model loaded.")

# Latent-space direction added (scaled by the UI slider) to the projected latent.
# NOTE(review): shape is whatever gender.npy holds; it must broadcast against the
# projected latent (presumably (num_ws, w_dim) or (w_dim,)) — confirm.
gender_direction = np.load("stylegan2directions/gender.npy")
gender_direction = torch.from_numpy(gender_direction).to(torch.float32).to(device)

# LPIPS perceptual distance used as the projection loss.
lpips_loss = lpips.LPIPS(net='vgg').to(device).eval()
print("LPIPS model loaded.")
print("All models and vectors loaded successfully.")

print("Pre-calculating the 'average face' (w_avg)... This may take a moment.")
w_avg_samples = 10000


def _compute_w_avg(generator, num_samples, seed=123):
    """Return the mean of `generator.mapping` over `num_samples` random z vectors.

    Sampling is seeded, so the "average face" is identical across restarts.
    The per-sample latents are local and are released on return rather than
    being kept alive as module-level globals for the life of the process.

    Args:
        generator: network exposing `z_dim` and `mapping(z, c)`.
        num_samples: number of z vectors to average over.
        seed: seed for numpy's RandomState.

    Returns:
        Mean latent with a leading batch dimension of 1 (mean over dim 0, keepdim).
    """
    z = np.random.RandomState(seed).randn(num_samples, generator.z_dim)
    with torch.no_grad():  # inference only; no autograd graph needed
        w = generator.mapping(torch.from_numpy(z).to(device), None)
        return torch.mean(w, 0, keepdim=True)


w_avg = _compute_w_avg(G, w_avg_samples)
print("w_avg pre-calculation complete.")
# -----------------------------------


def project_face(G, target, w_avg, *, num_steps=100, initial_learning_rate=0.1, progress=None):
    """Find a latent whose synthesis is perceptually close to `target`.

    Starting from `w_avg`, Adam minimizes the LPIPS distance between
    `G.synthesis(w)` and `target`. Only the latent is optimized. G's
    `noise_const` buffers are deliberately left untouched: buffers do not
    require grad, and updating them in place would leak state between requests.

    Args:
        G: frozen generator whose `synthesis(ws, noise_mode='const')` renders images.
        target: image tensor scaled to [-1, 1], matching G's output size
            (the caller in this file builds it as (1, 3, res, res)).
        w_avg: starting latent; it is cloned, not modified.
        num_steps: number of optimizer steps.
        initial_learning_rate: Adam learning rate (held constant; no schedule).
        progress: optional `gr.Progress`; when given, steps are reported through
            its `tqdm` wrapper.

    Returns:
        The optimized latent, detached from the autograd graph.
    """
    w_opt = w_avg.clone().detach().requires_grad_(True)
    optimizer = torch.optim.Adam([w_opt], betas=(0.9, 0.999), lr=initial_learning_rate)

    steps = range(num_steps)
    if progress is not None:
        steps = progress.tqdm(steps, desc="Projecting Face")

    for _ in steps:
        synth_images = G.synthesis(w_opt, noise_mode='const')
        loss = torch.sum(lpips_loss(synth_images, target))
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
    return w_opt.detach()


def _square_face_box(box, img_w, img_h, pad_frac=0.2):
    """Expand a face box into a padded, square, in-bounds integer crop box.

    The side length is the box's larger side plus `pad_frac` of it on each
    side. The square is centred on the box, shrunk if it is larger than the
    image, and shifted (not clipped) to lie inside the image. A square crop
    avoids stretching the face when it is resized to the generator's square
    resolution.

    Args:
        box: (x0, y0, x1, y1) face box; may be fractional or partly outside the image.
        img_w: image width in pixels.
        img_h: image height in pixels.
        pad_frac: padding per side, as a fraction of the box's larger side.

    Returns:
        Integer (left, top, right, bottom) tuple for `PIL.Image.crop`, always
        inside [0, img_w] x [0, img_h].
    """
    x0, y0, x1, y1 = (float(v) for v in box)
    side = max(x1 - x0, y1 - y0) * (1.0 + 2.0 * pad_frac)
    side = min(side, img_w, img_h)
    left = min(max((x0 + x1) / 2.0 - side / 2.0, 0.0), img_w - side)
    top = min(max((y0 + y1) / 2.0 - side / 2.0, 0.0), img_h - side)
    # Flooring both the corner and the side keeps right/bottom within the image.
    size = max(1, int(side))
    left, top = int(left), int(top)
    return (left, top, left + size, top + size)


def edit_uploaded_face(uploaded_image, strength, progress=gr.Progress(track_tqdm=True)):
    """Detect a face, project it into G's latent space, and shift it along the gender direction.

    Args:
        uploaded_image: PIL image from the Gradio upload component (or None).
        strength: multiplier for `gender_direction` (the UI slider, -5..5).
        progress: Gradio progress tracker, injected by Gradio via this default.

    Returns:
        The edited face as an (H, W, 3) uint8 numpy array.

    Raises:
        gr.Error: if no image was uploaded or no face was detected.
    """
    if uploaded_image is None:
        raise gr.Error("No image uploaded. Please upload an image containing a face.")

    progress(0, desc="Detecting Face")
    input_image = uploaded_image.convert("RGB")
    boxes, _ = mtcnn.detect(input_image)
    if boxes is None:
        raise gr.Error("Could not detect a face. Please try a clearer picture.")

    # Use the first box MTCNN returns (presumably the largest face under
    # MTCNN's default select_largest=True — confirm if that default changes).
    crop_box = _square_face_box(boxes[0], input_image.width, input_image.height)
    cropped_face = input_image.crop(crop_box)

    res = G.img_resolution
    target_pil = cropped_face.resize((res, res), Image.LANCZOS)
    target_np = np.array(target_pil, dtype=np.uint8)
    # HWC uint8 -> (1, 3, H, W) float in [-1, 1], the range LPIPS expects.
    target_tensor = torch.tensor(target_np.transpose([2, 0, 1]), device=device)
    target_tensor = (target_tensor.to(torch.float32) / 127.5 - 1).unsqueeze(0)

    projected_w = project_face(G, target_tensor, w_avg=w_avg, num_steps=100, progress=progress)

    progress(1, desc="Synthesizing Final Image")
    w_edited = projected_w + gender_direction * strength
    with torch.no_grad():  # pure inference; no graph needed for the final render
        img_out = G.synthesis(w_edited, noise_mode='const')
    img_out = (img_out.clamp(-1, 1) + 1) * 127.5
    img_out = img_out.permute(0, 2, 3, 1)[0].cpu().numpy().astype(np.uint8)
    print("Processing complete.")
    return img_out


# Create the Gradio Interface with the custom theme.
# Bound to `demo` (Gradio's conventional name, used e.g. by `gradio app.py` reload mode).
demo = gr.Interface(
    fn=edit_uploaded_face,
    inputs=[
        gr.Image(label="Upload Image With Face", type="pil"),
        gr.Slider(-5, 5, step=0.1, value=0, label="Gender Strength (← Feminine | Masculine →)"),
    ],
    outputs=gr.Image(label="Edited Face"),
    theme=theme,  # <--- APPLYING YOUR CUSTOM THEME
    title="Face Editor Backend",
    description="This engine detects a face in the uploaded image, then edits its gender expression.",
).queue()

if __name__ == "__main__":
    demo.launch()