Diggz10 commited on
Commit
ea56e06
·
verified ·
1 Parent(s): 9bf5755

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# --- Application setup: load every model and vector once at import time ---
import sys
import os

# Make sibling modules (e.g. `legacy`) importable regardless of the cwd.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import pickle
import lpips

from facenet_pytorch import MTCNN
import legacy

# --- Load All Models ---
print("Loading all models...")
device = torch.device("cpu")

# NOTE(review): pickle.load executes arbitrary code on load; the checkpoint
# must come from a trusted source (here it ships alongside the app).
with open("ffhq.pkl", "rb") as f:
    G = pickle.load(f)['G_ema'].to(device)
print("StyleGAN2 model loaded.")

mtcnn = MTCNN(keep_all=False, device=device)
print("Face detector model loaded.")

# Latent-space direction vector for the gender attribute, applied in W space.
gender_direction = np.load("stylegan2directions/gender.npy")
gender_direction = torch.from_numpy(gender_direction).to(torch.float32).to(device)
print("All models and vectors loaded successfully.")

lpips_loss = lpips.LPIPS(net='vgg').to(device).eval()
print("LPIPS model loaded.")

# Pre-compute the mean latent w_avg used as the projection starting point.
# Fix: run the 10k-sample mapping under no_grad() — the original built a large
# autograd graph that was never used; the numeric result is identical because
# project_face() detaches its copy of w_avg before optimizing.
print("Pre-calculating the 'average face' (w_avg)... This may take a moment.")
w_avg_samples = 10000
z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
with torch.no_grad():
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)
    w_avg = torch.mean(w_samples, 0, keepdim=True)
print("w_avg pre-calculation complete.")
# -----------------------------------
42
def project_face(G, target, w_avg, *, num_steps=100, initial_learning_rate=0.1, progress):
    """Optimize a latent code w so that G.synthesis(w) reproduces `target`.

    Starts from the mean latent `w_avg` and minimizes the LPIPS distance
    between the synthesized image and `target` (both NCHW in [-1, 1]),
    jointly optimizing the generator's per-layer noise buffers.
    Returns the optimized latent, detached from the autograd graph.
    """
    # Per-layer constant-noise buffers are optimized together with the latent.
    noise_params = [
        buf for name, buf in G.synthesis.named_buffers() if 'noise_const' in name
    ]
    latent = w_avg.clone().detach().requires_grad_(True)
    opt = torch.optim.Adam(
        [latent] + noise_params, betas=(0.9, 0.999), lr=initial_learning_rate
    )

    for _ in progress.tqdm(range(num_steps), desc="Projecting Face"):
        rendered = G.synthesis(latent, noise_mode='const')
        loss = torch.sum(lpips_loss(rendered, target))
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

    return latent.detach()
56
+
57
def edit_uploaded_face(uploaded_image, strength, progress=gr.Progress(track_tqdm=True)):
    """Detect a face in `uploaded_image`, project it into StyleGAN2's W space,
    shift it along the gender direction by `strength`, and return the edited
    face as an HxWx3 uint8 numpy array.

    Raises gr.Error when no image is supplied or no face is detected.
    """
    if uploaded_image is None:
        raise gr.Error("No image uploaded. Please upload an image containing a face.")
    progress(0, desc="Detecting Face")
    input_image = uploaded_image.convert("RGB")
    boxes, _ = mtcnn.detect(input_image)
    # Fix: also guard against an empty detection array, not just None.
    if boxes is None or len(boxes) == 0:
        raise gr.Error("Could not detect a face. Please try a clearer picture.")
    face_box = boxes[0]
    # Expand the detector's tight box by 20% per axis, clamped to the image.
    padding_x = (face_box[2] - face_box[0]) * 0.2
    padding_y = (face_box[3] - face_box[1]) * 0.2
    left = max(0, face_box[0] - padding_x)
    top = max(0, face_box[1] - padding_y)
    right = min(input_image.width, face_box[2] + padding_x)
    bottom = min(input_image.height, face_box[3] + padding_y)
    # Fix: pass an explicit (left, upper, right, lower) int tuple to crop
    # instead of a float numpy array, per the PIL.Image.crop contract.
    cropped_face = input_image.crop((int(left), int(top), int(right), int(bottom)))
    target_pil = cropped_face.resize((G.img_resolution, G.img_resolution), Image.LANCZOS)
    # HWC uint8 -> NCHW float in [-1, 1], matching G.synthesis output range.
    target_np = np.array(target_pil, dtype=np.uint8)
    target_tensor = torch.tensor(target_np.transpose([2, 0, 1]), device=device)
    target_tensor = target_tensor.to(torch.float32) / 127.5 - 1
    target_tensor = target_tensor.unsqueeze(0)
    projected_w = project_face(G, target_tensor, w_avg=w_avg, num_steps=100, progress=progress)
    progress(1, desc="Synthesizing Final Image")
    # Fix: the final edit/synthesis needs no gradients; no_grad() avoids
    # building an unused autograd graph (projected_w is already detached).
    with torch.no_grad():
        w_edited = projected_w + gender_direction * strength
        img_out = G.synthesis(w_edited, noise_mode='const')
        img_out = (img_out.clamp(-1, 1) + 1) * 127.5
        img_out = img_out.permute(0, 2, 3, 1)[0].cpu().numpy().astype(np.uint8)
    print("Processing complete.")
    return img_out
86
+
87
# Build the UI; queue() plus the standard launch() enables the API.
demo = gr.Interface(
    fn=edit_uploaded_face,
    inputs=[
        gr.Image(label="Upload Image With Face", type="pil"),
        gr.Slider(
            -5, 5, step=0.1, value=0,
            label="Gender Strength (← Feminine | Masculine →)",
        ),
    ],
    outputs=gr.Image(label="Edited Face"),
    title="Face Editor Backend",
    description=(
        "This engine detects a face in the uploaded image, then edits its "
        "gender expression. It is ready to be used as an API."
    ),
    flagging_options=None,
)
demo.queue().launch()