Athagi commited on
Commit
11acff3
·
1 Parent(s): 6f17ff4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +262 -146
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import cv2
2
  import numpy as np
3
  import gradio as gr
@@ -5,169 +6,284 @@ from insightface.app import FaceAnalysis
5
  from insightface.model_zoo import get_model
6
  from PIL import Image
7
  import tempfile
8
- import os
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Initialize models
11
- face_analyzer = FaceAnalysis(
12
- name='buffalo_l',
13
- providers=['CPUExecutionProvider'],
14
- root=os.environ.get('MODEL_DIR', '.')
15
- )
16
- face_analyzer.prepare(ctx_id=0, det_size=(640, 640))
17
-
18
- # Load SimSwap model (256x256 version)
19
- simswap_model = get_model(
20
- 'models/simswap_256.onnx',
21
- download=False,
22
- download_zip=False
23
- )
24
-
25
- def get_faces(img):
26
- """Detect faces in image with enhanced error handling"""
27
  try:
28
- return face_analyzer.get(img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  except Exception as e:
30
- raise gr.Error(f"Face detection failed: {str(e)}") from e
 
 
 
 
31
 
32
- def draw_faces(img, faces):
33
- """Draw face bounding boxes with improved visualization"""
34
- img_with_boxes = img.copy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  for i, face in enumerate(faces):
36
  box = face.bbox.astype(int)
37
- color = (0, 255, 0) # Green
38
- thickness = max(1, int(min(box[2]-box[0], box[3]-box[1]) / 150))
39
-
40
- # Draw bounding box
41
- cv2.rectangle(img_with_boxes,
42
- (box[0], box[1]),
43
- (box[2], box[3]),
44
- color, thickness)
45
-
46
- # Draw label
47
- label = f"Face {i}"
48
- font_scale = max(0.5, min(box[2]-box[0], box[3]-box[1]) / 1000)
49
- cv2.putText(img_with_boxes, label,
50
- (box[0], box[1] - 10),
51
- cv2.FONT_HERSHEY_SIMPLEX,
52
- font_scale, color, thickness)
53
  return img_with_boxes
54
 
55
- def process_image(img):
56
- """Convert PIL/numpy to BGR numpy array"""
57
- if isinstance(img, np.ndarray):
58
- return img[:, :, ::-1] if img.shape[2] == 3 else img
59
- return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
60
-
61
- def face_swap(source_img, target_img, source_index, target_index):
62
- """Perform face swap using SimSwap 256 model"""
63
- # Convert inputs
64
- source_np = process_image(source_img)
65
- target_np = process_image(target_img)
66
-
67
- # Detect faces
68
- source_faces = get_faces(source_np)
69
- target_faces = get_faces(target_np)
70
-
71
- # Validate selections
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  if not source_faces:
73
- raise gr.Error("No faces found in source image!")
 
 
 
 
74
  if not target_faces:
75
- raise gr.Error("No faces found in target image!")
76
- if source_index >= len(source_faces):
77
- raise gr.Error(f"Source face index {source_index} out of range (max {len(source_faces)-1})")
78
- if target_index >= len(target_faces):
79
- raise gr.Error(f"Target face index {target_index} out of range (max {len(target_faces)-1})")
80
-
81
- # Perform swap with SimSwap
82
- result = simswap_model.get(
83
- target_np,
84
- target_faces[target_index],
85
- source_faces[source_index],
86
- paste_back=True
87
- )
88
-
89
- # Convert and save
90
- result_rgb = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
91
- with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
92
- Image.fromarray(result_rgb).save(tmp.name, quality=95)
93
- return result_rgb, tmp.name
94
-
95
- def update_faces_preview(img, image_type):
96
- """Update face preview with selection options"""
97
- if img is None:
98
- return None, 0, gr.Slider(visible=False)
99
-
100
- img_np = process_image(img)
101
- faces = get_faces(img_np)
102
-
103
- if not faces:
104
- return img, 0, gr.Slider(visible=False)
105
-
106
- preview_img = draw_faces(img_np, faces)
107
- preview_rgb = cv2.cvtColor(preview_img, cv2.COLOR_BGR2RGB)
108
-
109
- return (
110
- preview_rgb,
111
- len(faces),
112
- gr.Slider(
113
- maximum=max(0, len(faces)-1),
114
- visible=len(faces) > 0,
115
- label=f"Select {image_type} Face Index"
116
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  )
118
 
119
- with gr.Blocks(title="SimSwap Pro", css=".gradio-container {background-color: #f0f2f6}") as demo:
120
- gr.Markdown("""
121
- # 🔄 SimSwap Face Swapper (256x256)
122
- *Higher quality face swapping using SimSwap 256 model*
123
- """)
124
-
125
  with gr.Row():
126
- with gr.Column():
127
- gr.Markdown("## Source Face")
128
- source_img = gr.Image(label="Source Image", type="pil")
129
- source_preview_btn = gr.Button("Preview Source Faces")
130
- source_preview = gr.Image(label="Detected Faces", interactive=False)
131
- source_index = gr.Slider(visible=False, step=1)
132
-
133
- with gr.Column():
134
- gr.Markdown("## Target Image")
135
- target_img = gr.Image(label="Target Image", type="pil")
136
- target_preview_btn = gr.Button("Preview Target Faces")
137
- target_preview = gr.Image(label="Detected Faces", interactive=False)
138
- target_index = gr.Slider(visible=False, step=1)
139
-
140
- swap_btn = gr.Button("✨ Swap Faces", variant="primary")
141
- result_title = gr.Markdown("## 🔮 Swapped Result", visible=False)
142
-
 
 
 
 
 
143
  with gr.Row():
144
- output_img = gr.Image(label="Output", interactive=False)
145
- download_btn = gr.File(label="Download Result", visible=False)
146
-
147
- # Event handlers
148
- source_preview_btn.click(
149
- fn=lambda x: update_faces_preview(x, "Source"),
150
- inputs=source_img,
151
- outputs=[source_preview, source_index, source_index]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  )
153
-
154
- target_preview_btn.click(
155
- fn=lambda x: update_faces_preview(x, "Target"),
156
- inputs=target_img,
157
- outputs=[target_preview, target_index, target_index]
158
  )
159
-
160
- swap_btn.click(
161
- fn=face_swap,
162
- inputs=[source_img, target_img, source_index, target_index],
163
- outputs=[output_img, download_btn]
164
- ).then(
165
- fn=lambda: [
166
- gr.Markdown(visible=True),
167
- gr.File(visible=True)
168
  ],
169
- outputs=[result_title, download_btn]
 
 
 
 
170
  )
171
 
172
  if __name__ == "__main__":
173
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
  import cv2
3
  import numpy as np
4
  import gradio as gr
 
6
  from insightface.model_zoo import get_model
7
  from PIL import Image
8
  import tempfile
9
+ import logging
10
+
11
+ # --- Configuration & Setup ---
12
+ # Configure logging for better debugging
13
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
14
+
15
+ # Constants for model paths and settings
16
+ SWAPPER_MODEL_PATH = "models/inswapper_128.onnx" # Consider making this a script argument or env variable
17
+ FACE_ANALYZER_NAME = 'buffalo_l'
18
+ DETECTION_SIZE = (640, 640)
19
+ CPU_PROVIDERS = ['CPUExecutionProvider'] # Or ['CUDAExecutionProvider', 'CPUExecutionProvider'] if GPU is available
20
 
21
+ # --- Global Variables (Lazy Loaded) ---
22
+ face_analyzer = None
23
+ swapper = None
24
+
25
+ # --- Initialization Functions ---
26
+ def initialize_models():
27
+ """Initialize the face analyzer and swapper models."""
28
+ global face_analyzer, swapper
 
 
 
 
 
 
 
 
 
29
  try:
30
+ if face_analyzer is None:
31
+ logging.info(f"Initializing FaceAnalysis model: {FACE_ANALYZER_NAME}")
32
+ face_analyzer = FaceAnalysis(name=FACE_ANALYZER_NAME, providers=CPU_PROVIDERS)
33
+ face_analyzer.prepare(ctx_id=0, det_size=DETECTION_SIZE)
34
+ logging.info("FaceAnalysis model initialized.")
35
+
36
+ if swapper is None:
37
+ if not os.path.exists(SWAPPER_MODEL_PATH):
38
+ logging.error(f"Swapper model not found at {SWAPPER_MODEL_PATH}. Please download it.")
39
+ # You might want to raise an exception here or try to download it
40
+ # For now, let's assume get_model can handle download if 'download=True' was intended
41
+ # but the original code had 'download=False'
42
+ raise FileNotFoundError(f"Swapper model not found: {SWAPPER_MODEL_PATH}")
43
+ logging.info(f"Loading swapper model from: {SWAPPER_MODEL_PATH}")
44
+ swapper = get_model(SWAPPER_MODEL_PATH, download=False) # Set download=True if you want to auto-download
45
+ logging.info("Swapper model loaded.")
46
  except Exception as e:
47
+ logging.error(f"Error during model initialization: {e}")
48
+ raise # Re-raise the exception to stop the app if models can't load
49
+
50
+ # Call initialization at the start
51
+ initialize_models()
52
 
53
+ # --- Core Functions ---
54
+ def get_faces_from_image(img_np: np.ndarray):
55
+ """
56
+ Detects faces in a NumPy image array (BGR format).
57
+
58
+ Args:
59
+ img_np: NumPy array representing the image (BGR).
60
+
61
+ Returns:
62
+ A list of face objects detected by FaceAnalysis.
63
+ """
64
+ if face_analyzer is None:
65
+ raise gr.Error("Face analyzer not initialized. Please check logs.")
66
+ faces = face_analyzer.get(img_np)
67
+ return faces
68
+
69
+ def draw_detected_faces(img_np: np.ndarray, faces: list):
70
+ """
71
+ Draws bounding boxes and labels on faces in an image.
72
+
73
+ Args:
74
+ img_np: NumPy array representing the image (BGR).
75
+ faces: A list of face objects.
76
+
77
+ Returns:
78
+ NumPy array with faces drawn (BGR).
79
+ """
80
+ img_with_boxes = img_np.copy()
81
  for i, face in enumerate(faces):
82
  box = face.bbox.astype(int)
83
+ # Ensure coordinates are within image bounds
84
+ x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
85
+ cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), (0, 255, 0), 2)
86
+ label_position = (x1, max(0, y1 - 10)) # Ensure label is not drawn outside top
87
+ cv2.putText(img_with_boxes, f"Face {i}", label_position,
88
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (36, 255, 12), 2) # Changed color for visibility
 
 
 
 
 
 
 
 
 
 
89
  return img_with_boxes
90
 
91
+ def convert_pil_to_cv2(pil_image: Image.Image) -> np.ndarray:
92
+ """Converts a PIL Image to an OpenCV NumPy array (BGR)."""
93
+ return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
94
+
95
+ def convert_cv2_to_pil(cv2_image: np.ndarray) -> Image.Image:
96
+ """Converts an OpenCV NumPy array (BGR) to a PIL Image."""
97
+ return Image.fromarray(cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB))
98
+
99
+ # --- Gradio Interface Functions ---
100
+ def process_face_swap(source_pil_img: Image.Image, target_pil_img: Image.Image, target_face_index: int):
101
+ """
102
+ Performs the face swap operation.
103
+
104
+ Args:
105
+ source_pil_img: PIL Image of the source face.
106
+ target_pil_img: PIL Image of the target scene.
107
+ target_face_index: Index of the face in the target image to be swapped.
108
+
109
+ Returns:
110
+ Tuple: (PIL Image of the swapped result, path to temporary file for download)
111
+ or raises gr.Error on failure.
112
+ """
113
+ if swapper is None:
114
+ raise gr.Error("Swapper model not initialized. Please check logs.")
115
+ if source_pil_img is None:
116
+ raise gr.Error("Source image not provided.")
117
+ if target_pil_img is None:
118
+ raise gr.Error("Target image not provided.")
119
+
120
+ source_np = convert_pil_to_cv2(source_pil_img)
121
+ target_np = convert_pil_to_cv2(target_pil_img)
122
+
123
+ # Get face from source image
124
+ source_faces = get_faces_from_image(source_np)
125
  if not source_faces:
126
+ raise gr.Error("No face found in the source image. Please use a clear image of a face.")
127
+ source_face = source_faces[0] # Assuming the first detected face is the one to use
128
+
129
+ # Get faces from target image
130
+ target_faces = get_faces_from_image(target_np)
131
  if not target_faces:
132
+ raise gr.Error("No faces found in the target image.")
133
+ if not (0 <= target_face_index < len(target_faces)):
134
+ # This case should ideally be prevented by the slider's dynamic range
135
+ raise gr.Error(f"Selected face index ({target_face_index}) is out of range. "
136
+ f"Detected {len(target_faces)} faces (indices 0 to {len(target_faces)-1}).")
137
+
138
+ try:
139
+ logging.info(f"Swapping face from source to target face index {target_face_index}.")
140
+ # Ensure target_face_index is an integer for indexing
141
+ target_face_to_swap = target_faces[int(target_face_index)]
142
+ swapped_bgr_img = swapper.get(target_np, target_face_to_swap, source_face, paste_back=True)
143
+ except Exception as e:
144
+ logging.error(f"Error during face swapping: {e}")
145
+ raise gr.Error(f"An error occurred during the swap process: {str(e)}")
146
+
147
+ swapped_pil_img = convert_cv2_to_pil(swapped_bgr_img)
148
+
149
+ # Save to a temporary file for download
150
+ try:
151
+ with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
152
+ swapped_pil_img.save(tmp_file.name)
153
+ temp_file_path = tmp_file.name
154
+ logging.info(f"Swapped image saved to temporary file: {temp_file_path}")
155
+ except Exception as e:
156
+ logging.error(f"Error saving to temporary file: {e}")
157
+ raise gr.Error("Could not save the swapped image for download.")
158
+
159
+ return swapped_pil_img, temp_file_path
160
+
161
+ def preview_target_faces(target_pil_img: Image.Image):
162
+ """
163
+ Updates the preview of detected faces in the target image and adjusts the slider.
164
+
165
+ Args:
166
+ target_pil_img: PIL Image of the target scene.
167
+
168
+ Returns:
169
+ Tuple: (PIL Image with detected faces, Gradio Slider update)
170
+ """
171
+ if target_pil_img is None:
172
+ # Return a blank image and default slider if no image is provided
173
+ blank_image_pil = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color = 'lightgray')
174
+ return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
175
+
176
+ target_np = convert_pil_to_cv2(target_pil_img)
177
+ faces = get_faces_from_image(target_np)
178
+
179
+ preview_np_img = draw_detected_faces(target_np, faces)
180
+ preview_pil_img = convert_cv2_to_pil(preview_np_img)
181
+
182
+ num_faces = len(faces)
183
+ if num_faces > 0:
184
+ # Update slider: max index is num_faces - 1
185
+ slider_update = gr.Slider(minimum=0, maximum=num_faces - 1, value=0, step=1, interactive=True)
186
+ else:
187
+ # No faces, disable slider
188
+ slider_update = gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
189
+
190
+ return preview_pil_img, slider_update
191
+
192
+ # --- Gradio UI Definition ---
193
+ with gr.Blocks(title="Face Swap Pro 🔄", theme=gr.themes.Soft()) as demo:
194
+ gr.Markdown(
195
+ """
196
+ # 🎭 Face Swap Pro 🚀
197
+ Upload a source image with the face you want to use, and a target image where you want to swap a face.
198
+ Use the 'Preview Detected Faces' button to see faces in the target image and select which one to replace.
199
+ """
200
  )
201
 
 
 
 
 
 
 
202
  with gr.Row():
203
+ with gr.Column(scale=1):
204
+ source_image_input = gr.Image(label="👤 Source Face Image", type="pil", sources=["upload", "clipboard"])
205
+ with gr.Column(scale=1):
206
+ target_image_input = gr.Image(label="🖼️ Target Scene Image", type="pil", sources=["upload", "clipboard"])
207
+
208
+ with gr.Row():
209
+ preview_button = gr.Button("🔍 Preview Detected Faces in Target", variant="secondary")
210
+ face_index_slider = gr.Slider(
211
+ label="🎯 Select Target Face Index (0-indexed)",
212
+ minimum=0,
213
+ maximum=0, # Will be updated dynamically
214
+ step=1,
215
+ value=0,
216
+ interactive=False # Initially not interactive
217
+ )
218
+
219
+ target_faces_preview_output = gr.Image(label="👀 Detected Faces in Target", interactive=False)
220
+
221
+ gr.HTML("<hr>") # Visual separator
222
+
223
+ swap_button = gr.Button("🔁 SWAP FACES NOW!", variant="primary")
224
+
225
  with gr.Row():
226
+ swapped_image_output = gr.Image(label="✨ Swapped Result", interactive=False)
227
+ download_output_file = gr.File(label="⬇️ Download Swapped Image")
228
+
229
+ # --- Event Handlers ---
230
+ def on_target_image_change(target_img):
231
+ """Called when the target image is uploaded or cleared."""
232
+ if target_img is None:
233
+ # Reset preview and slider if target image is cleared
234
+ blank_image_pil = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color = 'lightgray')
235
+ return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
236
+ # If an image is uploaded, the preview button click will handle the update.
237
+ # This function primarily handles the reset case.
238
+ # Or, you could auto-trigger preview here:
239
+ # return preview_target_faces(target_img)
240
+ # For now, let's keep it explicit with the button.
241
+ return target_faces_preview_output.value, face_index_slider.value # Return current values
242
+
243
+
244
+ # Connect target image change to potentially reset/update UI elements
245
+ # (e.g., if user clears the target image)
246
+ target_image_input.change(
247
+ fn=on_target_image_change,
248
+ inputs=[target_image_input],
249
+ outputs=[target_faces_preview_output, face_index_slider]
250
+ )
251
+
252
+ preview_button.click(
253
+ fn=preview_target_faces,
254
+ inputs=[target_image_input],
255
+ outputs=[target_faces_preview_output, face_index_slider]
256
  )
257
+
258
+ swap_button.click(
259
+ fn=process_face_swap,
260
+ inputs=[source_image_input, target_image_input, face_index_slider],
261
+ outputs=[swapped_image_output, download_output_file]
262
  )
263
+
264
+ # --- Examples ---
265
+ gr.Examples(
266
+ examples=[
267
+ ["examples/source_face.jpg", "examples/target_group.jpg", 0], # Create these example files
268
+ ["examples/source_actor.png", "examples/target_scene.png", 1]
 
 
 
269
  ],
270
+ inputs=[source_image_input, target_image_input, face_index_slider],
271
+ outputs=[swapped_image_output, download_output_file],
272
+ fn=process_face_swap,
273
+ cache_examples=False, # Set to True if your examples are static and processing is slow
274
+ label="Example Face Swaps"
275
  )
276
 
277
  if __name__ == "__main__":
278
+ # Ensure 'models' and 'examples' directories exist or handle their absence
279
+ os.makedirs("models", exist_ok=True)
280
+ os.makedirs("examples", exist_ok=True)
281
+ # You'd typically place your 'inswapper_128.onnx' in the 'models' directory
282
+ # and example images in the 'examples' directory.
283
+
284
+ # Check if the swapper model exists before launching
285
+ if not os.path.exists(SWAPPER_MODEL_PATH):
286
+ print(f"ERROR: Swapper model not found at {SWAPPER_MODEL_PATH}")
287
+ print("Please download the 'inswapper_128.onnx' model and place it in the 'models' directory.")
288
+ else:
289
+ demo.launch()