Diggz10 commited on
Commit
6426d05
·
verified ·
1 Parent(s): 1bc4a45

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -31
app.py CHANGED
@@ -6,54 +6,111 @@ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
6
  import gradio as gr
7
  import numpy as np
8
  import torch
 
 
 
9
  import legacy
10
  import dnnlib
 
 
 
11
 
12
print("Loading and converting TensorFlow model...")

# CPU-only deployment; every tensor below is pinned to this device.
device = torch.device("cpu")

# legacy.load_network_pkl converts the TF-era pickle into PyTorch modules.
# "G_ema" is the exponential-moving-average generator snapshot.
with open("stylegan2-ffhq-config-f.pkl", "rb") as f:
    networks = legacy.load_network_pkl(f)
G = networks["G_ema"].to(device)
print("Model loaded and converted successfully.")

# Latent-space direction for the gender attribute (applied in W space).
print("Loading gender direction vector...")
gender_direction_np = np.load("stylegan2directions/gender.npy")
gender_direction = torch.from_numpy(gender_direction_np).to(torch.float32).to(device)
print("Vector loaded successfully.")
 
24
 
25
def edit_gender(seed, strength):
    """Generate a face from `seed` and shift it along the gender direction.

    Called by the Gradio interface. Returns an HxWx3 uint8 RGB array
    suitable for display.
    """
    rng = np.random.RandomState(int(seed))  # seed arrives from a slider; coerce to int
    latent = torch.from_numpy(rng.randn(1, G.z_dim)).to(torch.float32).to(device)

    # Map z -> W (intermediate latent space); truncation keeps faces plausible.
    ws = G.mapping(latent, None, truncation_psi=0.7)

    # Linear edit in W space; the slider sign selects the edit direction.
    ws = ws + gender_direction * strength

    out = G.synthesis(ws, noise_mode="const")

    # [-1, 1] float -> [0, 255] uint8, NCHW -> HWC for display.
    out = (out.clamp(-1, 1) + 1) * 127.5
    return out.permute(0, 2, 3, 1)[0].cpu().numpy().astype(np.uint8)
47
 
48
# Create and launch the Gradio interface
demo = gr.Interface(
    fn=edit_gender,
    inputs=[
        gr.Slider(0, 10000, step=1, value=1234, label="Random Seed"),
        gr.Slider(-5, 5, step=0.1, value=0, label="Gender Strength (← Feminine | Masculine →)"),
    ],
    outputs=gr.Image(label="Generated Face"),
    title="StyleGAN2 Gender Editor",
    description="Move the slider to change the gender expression of the generated face. Change the seed to get a new face.",
    allow_flagging="never",
)
demo.launch()
 
import gradio as gr
import numpy as np
import torch
from PIL import Image

# Backend libraries for StyleGAN and face detection
import legacy
import dnnlib
from training.networks import SynthesisNetwork
from projector import project
from facenet_pytorch import MTCNN

# --- Load All Models ---
print("Loading all models...")

# CPU-only deployment; every model below is pinned to this device.
device = torch.device("cpu")

# StyleGAN2 generator: legacy.py converts the TF-era pickle into PyTorch
# modules; "G_ema" is the exponential-moving-average generator snapshot.
with open("stylegan2-ffhq-config-f.pkl", "rb") as f:
    networks = legacy.load_network_pkl(f)
G = networks["G_ema"].to(device)
print("StyleGAN2 model loaded.")

# MTCNN face detector; keep_all=False returns only the highest-confidence face.
mtcnn = MTCNN(keep_all=False, device=device)
print("Face detector model loaded.")

# Latent-space direction for the gender attribute (applied in W space).
gender_direction = torch.from_numpy(
    np.load("stylegan2directions/gender.npy")
).to(torch.float32).to(device)
print("All models and vectors loaded successfully.")
# -----------------------------------
37
+
38
def edit_uploaded_face(uploaded_image, strength):
    """Detect a face in an uploaded image, project it into StyleGAN2's W
    space, apply the gender edit, and synthesize the edited face.

    Args:
        uploaded_image: PIL Image from the Gradio input, or None.
        strength: scalar multiplier for the gender direction
            (negative -> feminine, positive -> masculine).

    Returns:
        HxWx3 uint8 RGB array of the edited face.

    Raises:
        gr.Error: if no image was supplied or no face could be detected.
    """
    if uploaded_image is None:
        raise gr.Error("No image uploaded. Please upload an image containing a face.")

    print("Detecting face in the uploaded image...")
    # Drop any alpha channel (e.g. PNGs) — the detector expects 3 channels.
    input_image = uploaded_image.convert("RGB")

    # Detect the single best face; boxes is None when nothing is found.
    boxes, _ = mtcnn.detect(input_image)
    if boxes is None:
        raise gr.Error("Could not detect a face. Please try a clearer picture or one where the face is more prominent.")

    # boxes[0] holds [x1, y1, x2, y2] for the highest-confidence face.
    x1, y1, x2, y2 = boxes[0]

    # Pad the box by 20% per side so the whole head is included, clamping
    # to the image bounds. Padding is computed from the ORIGINAL box size.
    padding_x = (x2 - x1) * 0.2
    padding_y = (y2 - y1) * 0.2
    crop_box = (
        max(0, x1 - padding_x),
        max(0, y1 - padding_y),
        min(input_image.width, x2 + padding_x),
        min(input_image.height, y2 + padding_y),
    )
    cropped_face = input_image.crop(crop_box)
    print("Face detected and cropped.")

    # --- Run GAN inversion on the CROPPED face ---
    # Slow on CPU; num_steps=100 trades fidelity for latency in a web app.
    # NOTE(review): this assumes the local project() accepts a PIL image —
    # the official stylegan2-ada projector expects a uint8 CHW tensor at
    # G.img_resolution; confirm against projector.py.
    print("Projecting the face into the model's latent space...")
    projected_w = project(
        G,
        cropped_face,
        num_steps=100,
        device=device,
        verbose=False,  # set True for per-step projection logs
    )
    print("Image projected successfully.")

    # --- Apply edit and synthesize new face ---
    # The projector returns the latent at every optimization step; take the
    # LAST one (the converged latent). Index 0 is the starting point of the
    # optimization, not the projection result.
    w_to_edit = projected_w[-1]
    w_edited = (w_to_edit + gender_direction * strength).unsqueeze(0)

    print("Synthesizing new image...")
    img_out = G.synthesis(w_edited, noise_mode='const')

    # [-1, 1] float -> [0, 255] uint8, NCHW -> HWC for display.
    img_out = (img_out.clamp(-1, 1) + 1) * 127.5
    img_out = img_out.permute(0, 2, 3, 1)[0].cpu().numpy().astype(np.uint8)
    print("Processing complete.")

    return img_out
98
 
99
# --- Create the Gradio Interface ---
# This interface now has a robust backend ready to be used as an API.
demo = gr.Interface(
    fn=edit_uploaded_face,
    inputs=[
        gr.Image(label="Upload Image With Face", type="pil"),
        gr.Slider(-5, 5, step=0.1, value=0, label="Gender Strength (← Feminine | Masculine →)"),
    ],
    outputs=gr.Image(label="Edited Face"),
    title="Face Editor Backend",
    description="This engine detects a face in the uploaded image, then edits its gender expression. It is ready to be used as an API.",
    allow_flagging="never",
    examples=[
        ["stylegan2directions/obama.jpg", 0],
        ["stylegan2directions/obama.jpg", 3.5],
        ["stylegan2directions/obama.jpg", -3.5],
    ],
)
demo.launch()