Athagi committed on
Commit
4d077cb
·
1 Parent(s): 5b9c4dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +317 -263
app.py CHANGED
@@ -2,8 +2,8 @@ import os
2
  import cv2
3
  import numpy as np
4
  import gradio as gr
5
- from insightface.app import FaceAnalysis
6
- from insightface.model_zoo import get_model
7
  from PIL import Image
8
  import tempfile
9
  import logging
@@ -12,66 +12,71 @@ import onnxruntime
12
  # --- Configuration & Setup ---
13
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
14
 
15
- SWAPPER_MODEL_PATH = "models/inswapper_128.onnx"
16
- RESTORATION_MODEL_PATH = "models/gfpgan_1.4.onnx" # Path to your GFPGAN or chosen restoration ONNX model
 
 
17
 
18
- FACE_ANALYZER_NAME = 'buffalo_l' # InsightFace detection model
19
- DETECTION_SIZE = (640, 640) # Input size for face detection
20
- EXECUTION_PROVIDERS = ['CPUExecutionProvider'] # Or ['CUDAExecutionProvider', 'CPUExecutionProvider'] for GPU
21
 
22
  # --- Global Variables (Lazy Loaded by initialize_models) ---
23
  face_analyzer = None
24
- swapper = None
25
  face_restorer = None
 
26
 
27
  # --- Initialization Functions ---
28
  def initialize_models():
29
- global face_analyzer, swapper, face_restorer
30
  try:
31
- # Initialize FaceAnalysis model
32
  if face_analyzer is None:
33
  logging.info(f"Initializing FaceAnalysis model: {FACE_ANALYZER_NAME}")
34
  face_analyzer = FaceAnalysis(name=FACE_ANALYZER_NAME, providers=EXECUTION_PROVIDERS)
35
- face_analyzer.prepare(ctx_id=0, det_size=DETECTION_SIZE) # ctx_id=0 for CPU, -1 for auto
36
  logging.info("FaceAnalysis model initialized successfully.")
37
 
38
- # Initialize Swapper model
39
- if swapper is None:
40
  if not os.path.exists(SWAPPER_MODEL_PATH):
41
- logging.error(f"Swapper model FILE NOT FOUND at {SWAPPER_MODEL_PATH}. Swapping will fail.")
42
  else:
43
- logging.info(f"Loading swapper model from: {SWAPPER_MODEL_PATH}")
44
  try:
45
- # Pass providers to get_model if supported
46
- swapper = get_model(SWAPPER_MODEL_PATH, download=False, providers=EXECUTION_PROVIDERS)
47
- except TypeError:
48
- logging.warning(f"Failed to pass 'providers' to swapper model {SWAPPER_MODEL_PATH}. Retrying without 'providers' argument.")
49
- swapper = get_model(SWAPPER_MODEL_PATH, download=False) # Fallback
50
- logging.info("Swapper model loaded successfully.")
51
-
52
- # Initialize Face Restoration Model
53
- if face_restorer is None: # Only attempt to load if not already loaded/failed
54
  if not os.path.exists(RESTORATION_MODEL_PATH):
55
- logging.error(f"Face restoration model FILE NOT FOUND at: {RESTORATION_MODEL_PATH}. Enhancement feature will be disabled.")
56
  else:
 
57
  logging.info(f"Attempting to load face restoration model from: {RESTORATION_MODEL_PATH}")
58
  try:
59
  face_restorer = onnxruntime.InferenceSession(RESTORATION_MODEL_PATH, providers=EXECUTION_PROVIDERS)
60
  logging.info("Face restoration model loaded successfully.")
61
  except Exception as e:
62
  logging.error(f"Error loading face restoration model from {RESTORATION_MODEL_PATH}: {e}. Enhancement feature will be disabled.", exc_info=True)
63
- face_restorer = None # Ensure it's None if loading failed
 
 
 
64
  except Exception as e:
65
  logging.error(f"A critical error occurred during model initialization: {e}", exc_info=True)
66
- # face_analyzer, swapper, or face_restorer might be None. Subsequent checks will handle this.
67
 
68
  # --- Call Initialization Early ---
69
  initialize_models()
70
- # Global flags based on model loading status, used for UI conditional rendering
71
- core_models_loaded_successfully = face_analyzer is not None and swapper is not None
72
  restoration_model_loaded_successfully = face_restorer is not None
 
73
 
74
- # --- Image Utility Functions ---
75
  def convert_pil_to_cv2(pil_image: Image.Image) -> np.ndarray | None:
76
  if pil_image is None: return None
77
  try:
@@ -88,139 +93,132 @@ def convert_cv2_to_pil(cv2_image: np.ndarray) -> Image.Image | None:
88
  logging.error(f"Error converting CV2 to PIL: {e}")
89
  return None
90
 
91
- # --- Core AI & Image Processing Functions ---
 
 
 
 
92
  def get_faces_from_image(img_np: np.ndarray) -> list:
93
- if not core_models_loaded_successfully or face_analyzer is None:
94
- # This condition should ideally be caught before calling this function if core models failed.
95
  logging.error("Face analyzer not available for get_faces_from_image.")
96
  return []
97
  if img_np is None: return []
98
  try:
99
- return face_analyzer.get(img_np)
100
  except Exception as e:
101
  logging.error(f"Error during face detection: {e}", exc_info=True)
102
  return []
103
 
104
- def draw_detected_faces(img_np: np.ndarray, faces: list) -> np.ndarray:
 
105
  img_with_boxes = img_np.copy()
106
  for i, face in enumerate(faces):
107
- box = face.bbox.astype(int) # InsightFace bbox is [x1, y1, x2, y2]
108
  x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
109
  cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), (0, 255, 0), 2)
110
- label_position = (x1, max(0, y1 - 10)) # Ensure label is not drawn outside top
111
  cv2.putText(img_with_boxes, f"Face {i}", label_position,
112
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (36, 255, 12), 2)
113
  return img_with_boxes
114
 
115
- def enhance_cropped_face(face_crop_bgr: np.ndarray) -> np.ndarray:
116
  if not restoration_model_loaded_successfully or face_restorer is None:
117
  logging.warning("Face restorer model not available. Skipping enhancement for crop.")
118
  return face_crop_bgr
119
  if face_crop_bgr is None or face_crop_bgr.shape[0] == 0 or face_crop_bgr.shape[1] == 0:
120
  logging.warning("Received empty or invalid face crop for enhancement.")
121
  return face_crop_bgr if face_crop_bgr is not None else np.array([])
122
-
123
-
124
  logging.info(f"Applying face restoration to crop of size {face_crop_bgr.shape[:2]}...")
125
  crop_height, crop_width = face_crop_bgr.shape[:2]
126
-
127
  try:
128
  img_rgb = cv2.cvtColor(face_crop_bgr, cv2.COLOR_BGR2RGB)
129
-
130
- # GFPGAN typically expects 512x512 input.
131
- restorer_input_size = (512, 512) # Check your specific model's expected input size
132
  img_resized_for_model = cv2.resize(img_rgb, restorer_input_size, interpolation=cv2.INTER_AREA)
133
-
134
- img_normalized = (img_resized_for_model / 255.0).astype(np.float32) # Normalize to [0, 1]
135
- img_chw = np.transpose(img_normalized, (2, 0, 1)) # HWC to CHW
136
- input_tensor = np.expand_dims(img_chw, axis=0) # Add batch dimension
137
-
138
  input_name = face_restorer.get_inputs()[0].name
139
  output_name = face_restorer.get_outputs()[0].name
140
-
141
  restored_output_model_size = face_restorer.run([output_name], {input_name: input_tensor})[0]
142
-
143
  restored_img_chw = np.squeeze(restored_output_model_size, axis=0)
144
- restored_img_hwc_model_size = np.transpose(restored_img_chw, (1, 2, 0)) # CHW to HWC
145
  restored_img_uint8_model_size = np.clip(restored_img_hwc_model_size * 255.0, 0, 255).astype(np.uint8)
146
-
147
- # Resize back to the original crop's dimensions
148
  restored_crop_rgb = cv2.resize(restored_img_uint8_model_size, (crop_width, crop_height), interpolation=cv2.INTER_LANCZOS4)
149
  restored_crop_bgr = cv2.cvtColor(restored_crop_rgb, cv2.COLOR_RGB2BGR)
150
-
151
  logging.info("Cropped face restoration complete.")
152
  return restored_crop_bgr
153
  except Exception as e:
154
  logging.error(f"Error during face restoration for crop: {e}", exc_info=True)
155
- return face_crop_bgr # Return original crop on error
156
 
157
def histogram_match_channel(source_channel: np.ndarray, target_channel: np.ndarray) -> np.ndarray:
    """Match the histogram of a single uint8 channel to that of a target channel.

    Builds a 256-entry lookup table mapping each source intensity to the
    target intensity whose cumulative probability first reaches the source
    intensity's cumulative probability, then applies the table.

    Fix vs. the original: both CDFs are normalized by their own totals (to
    [0, 1]) instead of rescaling the source CDF by the ratio of peak
    histogram bins, which compared incommensurable quantities whenever the
    two images differed in pixel count or bin peaks.

    Args:
        source_channel: uint8 array of any shape; the channel to remap.
        target_channel: uint8 array; the reference distribution.

    Returns:
        uint8 array with the same shape as ``source_channel``.
    """
    source_shape = source_channel.shape
    source_hist, _ = np.histogram(source_channel.flatten(), 256, [0, 256])
    target_hist, _ = np.histogram(target_channel.flatten(), 256, [0, 256])

    source_cdf = source_hist.cumsum().astype(np.float64)
    target_cdf = target_hist.cumsum().astype(np.float64)
    # Normalize each CDF to [0, 1]; guard against an all-zero (empty) histogram.
    if source_cdf[-1] > 0:
        source_cdf /= source_cdf[-1]
    if target_cdf[-1] > 0:
        target_cdf /= target_cdf[-1]

    lookup_table = np.zeros(256, dtype='uint8')
    gj = 0
    for gi in range(256):
        # Advance gj to the first target level whose CDF reaches this source level's CDF.
        while gj < 256 and target_cdf[gj] < source_cdf[gi]:
            gj += 1
        if gj == 256:  # safety clamp for float round-off at the top end
            gj = 255
        lookup_table[gi] = gj

    # np.take is equivalent to cv2.LUT for uint8 data and drops the cv2 dependency here.
    matched_channel = np.take(lookup_table, source_channel)
    return matched_channel.reshape(source_shape)
187
 
188
-
189
def histogram_match_color(source_img: np.ndarray, target_img: np.ndarray) -> np.ndarray:
    """Histogram-match a color (BGR) image to a target image, channel by channel.

    Args:
        source_img: H x W x C uint8 image to remap.
        target_img: H x W x C uint8 reference image (sizes may differ).

    Returns:
        The matched image; the source unchanged on any failure, or an empty
        array when the source itself is None.
    """
    if source_img is None or target_img is None or source_img.size == 0 or target_img.size == 0:
        logging.warning("Cannot perform histogram matching on empty or None images.")
        return source_img if source_img is not None else np.array([])

    matched_img = np.zeros_like(source_img)
    try:
        # Each of the B, G, R planes is matched independently.
        for i in range(source_img.shape[2]):
            matched_img[:, :, i] = histogram_match_channel(source_img[:, :, i], target_img[:, :, i])
        return matched_img
    except Exception as e:
        logging.error(f"Error during histogram matching: {e}", exc_info=True)
        return source_img  # best-effort: original image on failure
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
- # --- Gradio Interface Functions ---
208
  def process_face_swap(source_pil_img: Image.Image, target_pil_img: Image.Image,
209
- target_face_index: int, apply_enhancement: bool, apply_color_correction: bool,
 
 
210
  progress=gr.Progress(track_tqdm=True)):
 
211
  progress(0, desc="Initializing process...")
212
-
213
- if not core_models_loaded_successfully:
214
- gr.Error("CRITICAL: Core models (Face Analyzer or Swapper) not loaded. Cannot proceed.")
215
- return Image.new('RGB', (100,100), color='lightgrey'), None # Placeholder for output
216
 
217
  if source_pil_img is None: raise gr.Error("Source image not provided.")
218
  if target_pil_img is None: raise gr.Error("Target image not provided.")
219
 
220
  progress(0.05, desc="Converting images...")
221
- source_np = convert_pil_to_cv2(source_pil_img)
222
- target_np = convert_pil_to_cv2(target_pil_img)
223
-
224
  if source_np is None or target_np is None:
225
  raise gr.Error("Image conversion failed. Please try different images.")
226
 
@@ -228,98 +226,204 @@ def process_face_swap(source_pil_img: Image.Image, target_pil_img: Image.Image,
228
 
229
  progress(0.15, desc="Detecting face in source image...")
230
  source_faces = get_faces_from_image(source_np)
231
- if not source_faces: raise gr.Error("No face found in the source image. Please use a clear image of a single face.")
232
- source_face = source_faces[0] # Assuming the first detected face is the one to use
233
 
234
  progress(0.25, desc="Detecting faces in target image...")
235
  target_faces = get_faces_from_image(target_np)
236
  if not target_faces: raise gr.Error("No faces found in the target image.")
237
  if not (0 <= target_face_index < len(target_faces)):
238
- # This case should ideally be prevented by the slider's dynamic range
239
- raise gr.Error(f"Selected target face index ({target_face_index}) is out of range. "
240
- f"Detected {len(target_faces)} faces (indices 0 to {len(target_faces)-1}).")
241
-
242
- target_face_to_swap_info = target_faces[int(target_face_index)] # Face object from InsightFace
243
 
244
- swapped_bgr_img = target_np.copy() # Start with a copy of target
245
  try:
246
- progress(0.4, desc="Performing face swap...")
247
- # swapper.get returns the modified target_np with the face pasted back
248
- swapped_bgr_img = swapper.get(target_np, target_face_to_swap_info, source_face, paste_back=True)
249
-
250
- # Define bounding box for post-processing (enhancement and color correction)
251
- # Use the bbox of the target face where the swap occurred.
252
- # InsightFace bbox is [x1, y1, x2, y2]
253
- bbox_coords = target_face_to_swap_info.bbox.astype(int)
254
 
255
- if apply_enhancement:
256
- if restoration_model_loaded_successfully:
257
- progress(0.6, desc="Applying selective face enhancement...")
258
- padding_enh = 20 # Pixels of padding around the bbox for enhancement context
259
-
260
- enh_x1 = max(0, bbox_coords[0] - padding_enh)
261
- enh_y1 = max(0, bbox_coords[1] - padding_enh)
262
- enh_x2 = min(target_w, bbox_coords[2] + padding_enh)
263
- enh_y2 = min(target_h, bbox_coords[3] + padding_enh)
264
-
265
- if enh_x1 < enh_x2 and enh_y1 < enh_y2: # Check if the crop is valid
266
- face_crop_to_enhance = swapped_bgr_img[enh_y1:enh_y2, enh_x1:enh_x2]
267
- enhanced_crop = enhance_cropped_face(face_crop_to_enhance.copy()) # Use .copy()
268
- if enhanced_crop.shape == face_crop_to_enhance.shape: # Ensure enhanced crop is valid and same size
269
- swapped_bgr_img[enh_y1:enh_y2, enh_x1:enh_x2] = enhanced_crop
270
- else:
271
- logging.warning("Enhanced crop size mismatch. Skipping paste-back for enhancement.")
272
- else:
273
- logging.warning("Skipping enhancement, invalid crop dimensions after padding.")
274
- else:
275
- gr.Info("Face restoration model not available, enhancement skipped by process.")
276
- logging.warning("Enhancement requested at runtime but face restorer is not available.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
- if apply_color_correction:
279
- progress(0.75, desc="Applying color correction...")
280
- # For color correction, we use the original target face region from `target_np` as the reference.
281
- # The region to correct is the same bounding box in `swapped_bgr_img`.
282
- cc_x1 = bbox_coords[0]
283
- cc_y1 = bbox_coords[1]
284
- cc_x2 = bbox_coords[2]
285
- cc_y2 = bbox_coords[3]
286
-
287
- if cc_x1 < cc_x2 and cc_y1 < cc_y2: # Valid bbox
288
- target_face_region_for_color = target_np[cc_y1:cc_y2, cc_x1:cc_x2]
289
- swapped_face_region_to_correct = swapped_bgr_img[cc_y1:cc_y2, cc_x1:cc_x2]
290
-
291
- if target_face_region_for_color.size > 0 and swapped_face_region_to_correct.size > 0:
292
- corrected_swapped_region = histogram_match_color(swapped_face_region_to_correct.copy(), target_face_region_for_color.copy())
293
- if corrected_swapped_region.shape == swapped_face_region_to_correct.shape:
294
- swapped_bgr_img[cc_y1:cc_y2, cc_x1:cc_x2] = corrected_swapped_region
295
- else:
296
- logging.warning("Color corrected region size mismatch. Skipping paste-back for color correction.")
297
  else:
298
- logging.warning("Skipping color correction, empty face region(s) for matching.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  else:
300
- logging.warning("Skipping color correction, invalid bounding box for region extraction.")
 
 
 
 
 
301
 
302
 
 
 
303
  except Exception as e:
304
- logging.error(f"Error during face swapping or post-processing: {e}", exc_info=True)
305
- # Return the current state of swapped_bgr_img to avoid losing partial work if possible
306
- swapped_pil_img_on_error = convert_cv2_to_pil(swapped_bgr_img)
307
- if swapped_pil_img_on_error is None: # Fallback if conversion itself fails
 
308
  swapped_pil_img_on_error = Image.new('RGB', (target_w, target_h), color='lightgrey')
309
  raise gr.Error(f"An error occurred: {str(e)}")
310
- # return swapped_pil_img_on_error, None # Alternative for UI stability
311
 
312
  progress(0.9, desc="Finalizing image...")
313
  swapped_pil_img = convert_cv2_to_pil(swapped_bgr_img)
314
- if swapped_pil_img is None: # Handle case where final conversion fails
315
  gr.Error("Failed to convert final image to display format.")
316
- swapped_pil_img = Image.new('RGB', (target_w, target_h), color='lightgrey') # Placeholder
317
 
318
  temp_file_path = None
319
  try:
320
- # Using 'with' ensures the file is closed before Gradio tries to use it
321
  with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False, mode="wb") as tmp_file:
322
- swapped_pil_img.save(tmp_file, format="JPEG") # Specify format for PIL save
323
  temp_file_path = tmp_file.name
324
  logging.info(f"Swapped image saved to temporary file: {temp_file_path}")
325
  except Exception as e:
@@ -329,181 +433,131 @@ def process_face_swap(source_pil_img: Image.Image, target_pil_img: Image.Image,
329
  progress(1.0, desc="Processing complete!")
330
  return swapped_pil_img, temp_file_path
331
 
332
def preview_target_faces(target_pil_img: Image.Image):
    """Annotate detected faces in the target image for the Gradio preview.

    Returns:
        A (preview PIL image, gr.Slider update) pair.  The slider spans
        0..num_faces-1 and is interactive only when at least one face was
        detected; on missing input, failed conversion, or no detections a
        light-gray placeholder and a disabled slider are returned.
    """
    if target_pil_img is None:
        blank_image_pil = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
        return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)

    target_np = convert_pil_to_cv2(target_pil_img)
    if target_np is None:  # PIL -> OpenCV conversion failed
        blank_image_pil = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
        return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)

    faces = get_faces_from_image(target_np)
    if not faces:
        # No detections: show the unannotated image so the user can retry.
        preview_pil_img = convert_cv2_to_pil(target_np)
        if preview_pil_img is None:
            preview_pil_img = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
        return preview_pil_img, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)

    preview_np_img = draw_detected_faces(target_np, faces)
    preview_pil_img = convert_cv2_to_pil(preview_np_img)
    if preview_pil_img is None:  # fallback placeholder if conversion fails
        preview_pil_img = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')

    num_faces = len(faces)
    slider_update = gr.Slider(minimum=0, maximum=max(0, num_faces - 1), value=0, step=1, interactive=(num_faces > 0))
    return preview_pil_img, slider_update
358
 
359
-
360
- # --- Gradio UI Definition ---
361
- with gr.Blocks(title="Ultimate Face Swap AI πŸ’Ž v3", theme=gr.themes.Soft()) as demo: # Changed theme for variety
362
  gr.Markdown(
363
  """
364
  <div style="text-align: center;">
365
- <h1>πŸ’Ž Ultimate Face Swap AI 🌠 v3</h1>
366
- <p>Upload a source face, a target image, and let the AI work its magic!</p>
367
- <p>Optionally enhance with face restoration and color correction for stunning realism.</p>
368
  </div>
369
  """
370
  )
371
 
372
- if not core_models_loaded_successfully:
373
- gr.Error("CRITICAL ERROR: Core models (Face Analyzer or Swapper) failed to load. Application will not function correctly. Please check the console logs for details (e.g., model file paths) and restart the application after fixing.")
374
 
375
  with gr.Row():
376
  with gr.Column(scale=1):
377
- source_image_input = gr.Image(label="πŸ‘€ Source Face Image (Clear, single face)", type="pil", sources=["upload", "clipboard"], height=350)
378
  with gr.Column(scale=1):
379
  target_image_input = gr.Image(label="πŸ–ΌοΈ Target Scene Image", type="pil", sources=["upload", "clipboard"], height=350)
380
 
381
  with gr.Row(equal_height=True):
382
  preview_button = gr.Button("πŸ” Preview & Select Target Face", variant="secondary")
383
- face_index_slider = gr.Slider(
384
- label="🎯 Select Target Face (0-indexed)",
385
- minimum=0, maximum=0, step=1, value=0, interactive=False
386
- )
387
 
388
  target_faces_preview_output = gr.Image(label="πŸ‘€ Detected Faces in Target", interactive=False, height=350)
389
-
390
- gr.HTML("<hr style='margin-top: 15px; margin-bottom: 15px;'>") # Styled HR
391
 
392
  with gr.Row():
393
  with gr.Column(scale=1):
394
- enhance_checkbox_label = "✨ Apply Selective Face Restoration"
395
- if not restoration_model_loaded_successfully:
396
- enhance_checkbox_label += " (Model N/A)"
397
- enhance_checkbox = gr.Checkbox(
398
- label=enhance_checkbox_label,
399
- value=restoration_model_loaded_successfully, # Default to True only if available
400
- interactive=restoration_model_loaded_successfully # Disable if not available
401
- )
402
- if not restoration_model_loaded_successfully:
403
- gr.Markdown("<p style='color: orange; font-size:0.8em;'>⚠️ Face restoration model not loaded. Feature disabled. Check logs.</p>")
404
 
405
  with gr.Column(scale=1):
406
- color_correction_checkbox = gr.Checkbox(label="🎨 Apply Color Correction (Histogram Matching)", value=True) # Default to True
407
 
408
-
409
  with gr.Row():
410
- swap_button = gr.Button("πŸš€ GENERATE SWAP!", variant="primary", scale=3, interactive=core_models_loaded_successfully) # Disable if core models failed
411
  clear_button = gr.Button("🧹 Clear All", variant="stop", scale=1)
412
 
413
-
414
  with gr.Row():
415
  swapped_image_output = gr.Image(label="✨ Swapped Result", interactive=False, height=450)
416
  download_output_file = gr.File(label="⬇️ Download Swapped Image")
417
 
418
- # --- Event Handlers ---
419
  def on_target_image_change_or_clear(target_img_pil):
420
- if target_img_pil is None: # Cleared
421
- blank_image_pil = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color = 'lightgray')
422
  return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
423
- # If an image is uploaded, preview_target_faces will be called by the button click.
424
- # This function is mainly for handling the clear/reset case or an initial state.
425
- # For safety, return current values if not cleared, or trigger preview.
426
- # Let's just return blank on clear, and let preview button handle new image.
427
- return target_faces_preview_output.value, face_index_slider.value # Keep current state on new image upload until preview
428
-
429
- target_image_input.change(
430
- fn=on_target_image_change_or_clear,
431
- inputs=[target_image_input],
432
- outputs=[target_faces_preview_output, face_index_slider],
433
- queue=False
434
- )
435
-
436
- preview_button.click(
437
- fn=preview_target_faces,
438
- inputs=[target_image_input],
439
- outputs=[target_faces_preview_output, face_index_slider]
440
- )
441
 
442
  swap_button.click(
443
  fn=process_face_swap,
444
- inputs=[source_image_input, target_image_input, face_index_slider, enhance_checkbox, color_correction_checkbox],
 
 
 
 
445
  outputs=[swapped_image_output, download_output_file]
446
  )
447
 
448
  def clear_all_inputs_outputs():
449
  blank_preview = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color = 'lightgray')
450
- # Also reset the slider and preview output properly
451
- return (
452
- None, # source_image_input
453
- None, # target_image_input
454
- blank_preview, # target_faces_preview_output
455
- gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False), # face_index_slider
456
- None, # swapped_image_output
457
- None # download_output_file
458
- )
459
-
460
- clear_button.click(
461
- fn=clear_all_inputs_outputs,
462
- inputs=None, # No inputs for this function
463
- outputs=[
464
- source_image_input, target_image_input,
465
- target_faces_preview_output, face_index_slider,
466
- swapped_image_output, download_output_file
467
- ],
468
- queue=False # No need to queue a simple clear
469
- )
470
 
471
  gr.Examples(
472
  examples=[
473
- ["examples/source_face.jpg", "examples/target_group.jpg", 0, True, True],
474
- ["examples/source_actor.png", "examples/target_scene.png", 1, True, True],
475
- ["examples/source_face.jpg", "examples/target_group.jpg", 0, False, True], # No enhancement, with color correction
476
- ["examples/source_actor.png", "examples/target_scene.png", 0, True, False], # With enhancement, no color correction
477
  ],
478
- inputs=[source_image_input, target_image_input, face_index_slider, enhance_checkbox, color_correction_checkbox],
479
  outputs=[swapped_image_output, download_output_file],
480
- fn=process_face_swap, # This will also show progress for examples
481
- cache_examples=False, # Set to "lazy" or True if processing is slow and examples are static
482
  label="Example Face Swaps (Click to run)"
483
  )
484
 
485
  # --- Main Execution Block ---
486
  if __name__ == "__main__":
487
  os.makedirs("models", exist_ok=True)
488
- os.makedirs("examples", exist_ok=True) # Create if you intend to use example images
489
 
490
- # Console messages about model loading status
491
  print("\n" + "="*60)
492
- print("πŸš€ ULTIMATE FACE SWAP AI - STARTUP STATUS πŸš€")
493
  print("="*60)
494
  if not core_models_loaded_successfully:
495
- print("πŸ”΄ CRITICAL ERROR: Core models (Face Analyzer or Swapper) FAILED to load.")
496
  print(f" - Face Analyzer: '{FACE_ANALYZER_NAME}' (Status: {'Loaded' if face_analyzer else 'Failed'})")
497
- print(f" - Swapper Model: '{SWAPPER_MODEL_PATH}' (Status: {'Loaded' if swapper else 'Failed'})")
498
  print(" The application UI will load but will be NON-FUNCTIONAL.")
499
- print(" Please check model paths and ensure files exist and are not corrupted.")
500
  else:
501
- print("🟒 Core models (Face Analyzer & Swapper) loaded successfully.")
502
 
503
  if not restoration_model_loaded_successfully:
504
- print(f"🟑 INFO: Face Restoration model ('{RESTORATION_MODEL_PATH}') not loaded.")
505
- print(" The 'Face Restoration' feature will be disabled.")
506
- print(" To enable, ensure the model file exists at the specified path and is valid.")
507
  else:
508
  print("🟒 Face Restoration model loaded successfully.")
509
  print("="*60 + "\n")
 
2
  import cv2
3
  import numpy as np
4
  import gradio as gr
5
+ from insightface.app import FaceAnalysis # Still needed for detection and embeddings
6
+ # from insightface.model_zoo import get_model # No longer using this for the swapper
7
  from PIL import Image
8
  import tempfile
9
  import logging
 
12
  # --- Configuration & Setup ---
13
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
14
 
15
+ # MODIFIED: Path to the new swapper model
16
+ SWAPPER_MODEL_PATH = "models/reswapper_256.onnx" # <--- USING RESWAPPER_256
17
+ RESTORATION_MODEL_PATH = "models/gfpgan_1.4.onnx" # Or your chosen restorer, e.g., restoreformer++.onnx
18
+ # UPSCALER_MODEL_PATH = "models/RealESRGANx4plus.onnx" # Path for AI Upscaler if you were to add it
19
 
20
+ FACE_ANALYZER_NAME = 'buffalo_l' # InsightFace detection and embedding model
21
+ DETECTION_SIZE = (640, 640)
22
+ EXECUTION_PROVIDERS = ['CPUExecutionProvider'] # Or ['CUDAExecutionProvider', 'CPUExecutionProvider']
23
 
24
  # --- Global Variables (Lazy Loaded by initialize_models) ---
25
  face_analyzer = None
26
+ reswapper_session = None # MODIFIED: Will hold the onnxruntime session for reswapper
27
  face_restorer = None
28
+ # face_upscaler = None # If you add AI upscaling
29
 
30
  # --- Initialization Functions ---
31
def initialize_models():
    """Load the face analyzer, ReSwapper, and face-restoration models into module globals.

    Safe to call more than once: each model is loaded only when its global is
    still None.  Every failure is logged and leaves the corresponding global
    as None so downstream code can feature-detect instead of crashing.
    """
    global face_analyzer, reswapper_session, face_restorer
    try:
        # FaceAnalysis supplies both detection and the embeddings the swapper needs.
        if face_analyzer is None:
            logging.info(f"Initializing FaceAnalysis model: {FACE_ANALYZER_NAME}")
            face_analyzer = FaceAnalysis(name=FACE_ANALYZER_NAME, providers=EXECUTION_PROVIDERS)
            face_analyzer.prepare(ctx_id=0, det_size=DETECTION_SIZE)
            logging.info("FaceAnalysis model initialized successfully.")

        # ReSwapper is loaded directly as an onnxruntime session (no insightface model_zoo wrapper).
        if reswapper_session is None:
            if not os.path.exists(SWAPPER_MODEL_PATH):
                logging.error(f"ReSwapper model FILE NOT FOUND at {SWAPPER_MODEL_PATH}. Swapping will fail.")
            else:
                logging.info(f"Loading ReSwapper model from: {SWAPPER_MODEL_PATH}")
                try:
                    reswapper_session = onnxruntime.InferenceSession(SWAPPER_MODEL_PATH, providers=EXECUTION_PROVIDERS)
                    logging.info(f"ReSwapper model ({SWAPPER_MODEL_PATH}) loaded successfully with onnxruntime.")
                except Exception as e:
                    logging.error(f"Error loading ReSwapper model {SWAPPER_MODEL_PATH} with onnxruntime: {e}", exc_info=True)
                    reswapper_session = None  # ensure feature-detection sees the failure

        # Optional face-restoration model (e.g. GFPGAN); absence only disables enhancement.
        if face_restorer is None:
            if not os.path.exists(RESTORATION_MODEL_PATH):
                logging.error(f"Face restoration model FILE NOT FOUND at: {RESTORATION_MODEL_PATH}. Enhancement will be disabled.")
            else:
                logging.info(f"Attempting to load face restoration model from: {RESTORATION_MODEL_PATH}")
                try:
                    face_restorer = onnxruntime.InferenceSession(RESTORATION_MODEL_PATH, providers=EXECUTION_PROVIDERS)
                    logging.info("Face restoration model loaded successfully.")
                except Exception as e:
                    logging.error(f"Error loading face restoration model from {RESTORATION_MODEL_PATH}: {e}. Enhancement feature will be disabled.", exc_info=True)
                    face_restorer = None

    except Exception as e:
        # Individual globals may remain None; callers check the *_loaded flags.
        logging.error(f"A critical error occurred during model initialization: {e}", exc_info=True)
 
72
 
73
  # --- Call Initialization Early ---
74
  initialize_models()
75
+ core_models_loaded_successfully = face_analyzer is not None and reswapper_session is not None # MODIFIED
 
76
  restoration_model_loaded_successfully = face_restorer is not None
77
+ # upscaler_model_loaded_successfully = face_upscaler is not None # If adding upscaler
78
 
79
+ # --- Image Utility Functions (convert_pil_to_cv2, convert_cv2_to_pil - same as before) ---
80
  def convert_pil_to_cv2(pil_image: Image.Image) -> np.ndarray | None:
81
  if pil_image is None: return None
82
  try:
 
93
  logging.error(f"Error converting CV2 to PIL: {e}")
94
  return None
95
 
96
+ # --- Core AI & Image Processing Functions (get_faces_from_image, draw_detected_faces, enhance_cropped_face, histogram_match_color, apply_naturalness_filters - mostly same) ---
97
+ # These functions (except the main swap logic) should largely remain the same.
98
+ # Make sure they are copied from your previous complete version.
99
+ # For brevity here, I'll assume they are present.
100
+ # Example of get_faces_from_image:
101
def get_faces_from_image(img_np: np.ndarray) -> list:
    """Run face detection on a BGR image and return InsightFace Face objects.

    Returns [] when the analyzer is unavailable, the image is None, or
    detection raises — callers treat "no faces" and "failure" the same way.
    """
    if face_analyzer is None:  # module-global, set by initialize_models()
        logging.error("Face analyzer not available for get_faces_from_image.")
        return []
    if img_np is None:
        return []
    try:
        # Face objects carry bbox plus the embedding used by the swapper.
        return face_analyzer.get(img_np)
    except Exception as e:
        logging.error(f"Error during face detection: {e}", exc_info=True)
        return []
111
 
112
+ # ... (draw_detected_faces, enhance_cropped_face, histogram_match_color, apply_naturalness_filters functions from previous full code)
113
def draw_detected_faces(img_np: np.ndarray, faces: list) -> np.ndarray:
    """Return a copy of the image with a green box and "Face i" label per detection.

    Args:
        img_np: BGR image to annotate (left unmodified).
        faces: InsightFace Face objects with a float ``bbox`` attribute.
    """
    img_with_boxes = img_np.copy()
    for i, face in enumerate(faces):
        # InsightFace bbox layout is [x1, y1, x2, y2] as floats; cast for drawing.
        box = face.bbox.astype(int)
        x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
        cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Clamp the label so it is never drawn above the top edge.
        label_position = (x1, max(0, y1 - 10))
        cv2.putText(img_with_boxes, f"Face {i}", label_position,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (36, 255, 12), 2)
    return img_with_boxes
123
 
124
def enhance_cropped_face(face_crop_bgr: np.ndarray) -> np.ndarray:
    """Run the ONNX face-restoration model on a BGR crop, returned at the crop's original size.

    Falls back to the unmodified crop when the restorer is unavailable, the
    crop is empty, or inference fails, so callers can always paste the
    result back in place.
    """
    if not restoration_model_loaded_successfully or face_restorer is None:
        logging.warning("Face restorer model not available. Skipping enhancement for crop.")
        return face_crop_bgr
    if face_crop_bgr is None or face_crop_bgr.shape[0] == 0 or face_crop_bgr.shape[1] == 0:
        logging.warning("Received empty or invalid face crop for enhancement.")
        return face_crop_bgr if face_crop_bgr is not None else np.array([])

    logging.info(f"Applying face restoration to crop of size {face_crop_bgr.shape[:2]}...")
    crop_height, crop_width = face_crop_bgr.shape[:2]
    try:
        img_rgb = cv2.cvtColor(face_crop_bgr, cv2.COLOR_BGR2RGB)
        # GFPGAN-style restorers expect a fixed 512x512 input — TODO confirm for other restorer models.
        restorer_input_size = (512, 512)
        img_resized_for_model = cv2.resize(img_rgb, restorer_input_size, interpolation=cv2.INTER_AREA)
        img_normalized = (img_resized_for_model / 255.0).astype(np.float32)  # scale to [0, 1]
        img_chw = np.transpose(img_normalized, (2, 0, 1))  # HWC -> CHW
        input_tensor = np.expand_dims(img_chw, axis=0)  # add batch dimension

        input_name = face_restorer.get_inputs()[0].name
        output_name = face_restorer.get_outputs()[0].name
        restored_output_model_size = face_restorer.run([output_name], {input_name: input_tensor})[0]

        restored_img_chw = np.squeeze(restored_output_model_size, axis=0)
        restored_img_hwc_model_size = np.transpose(restored_img_chw, (1, 2, 0))  # CHW -> HWC
        restored_img_uint8_model_size = np.clip(restored_img_hwc_model_size * 255.0, 0, 255).astype(np.uint8)

        # Resize back so the caller can paste the crop into place unchanged.
        restored_crop_rgb = cv2.resize(restored_img_uint8_model_size, (crop_width, crop_height), interpolation=cv2.INTER_LANCZOS4)
        restored_crop_bgr = cv2.cvtColor(restored_crop_rgb, cv2.COLOR_RGB2BGR)
        logging.info("Cropped face restoration complete.")
        return restored_crop_bgr
    except Exception as e:
        logging.error(f"Error during face restoration for crop: {e}", exc_info=True)
        return face_crop_bgr  # best-effort: return the original crop on failure
153
 
154
def histogram_match_channel(source_channel: np.ndarray, target_channel: np.ndarray) -> np.ndarray:
    """Match a single uint8 channel's histogram to a target channel's histogram.

    Builds a 256-entry lookup table from the two normalized CDFs and maps the
    source channel through it.
    """
    original_shape = source_channel.shape
    # density=True yields normalized histograms so the cumsums are proper CDFs.
    src_hist, _ = np.histogram(source_channel.flatten(), 256, [0, 256], density=True)
    tgt_hist, _ = np.histogram(target_channel.flatten(), 256, [0, 256], density=True)

    src_cdf = src_hist.cumsum()
    tgt_cdf = tgt_hist.cumsum()
    # Scale both CDFs into [0, 255] so they can be compared as gray levels.
    src_cdf_norm = (src_cdf * 255 / src_cdf[-1]).astype(np.uint8)
    tgt_cdf_norm = (tgt_cdf * 255 / tgt_cdf[-1]).astype(np.uint8)

    # For each source level, the first target level whose CDF reaches it
    # (both CDFs are non-decreasing, so searchsorted matches the scan).
    lookup_table = np.minimum(
        np.searchsorted(tgt_cdf_norm, src_cdf_norm, side="left"), 255
    ).astype(np.uint8)

    remapped = cv2.LUT(source_channel, lookup_table)
    return remapped.reshape(original_shape)
173
 
174
def histogram_match_color(source_img: np.ndarray, target_img: np.ndarray) -> np.ndarray:
    """Match every color channel of source_img to target_img's distribution.

    Returns source_img unchanged when the inputs are unusable or matching
    fails; returns np.array([]) when source_img is None.
    """
    if source_img is None or target_img is None or source_img.size == 0 or target_img.size == 0:
        logging.warning("Cannot perform histogram matching on empty or None images.")
        return source_img if source_img is not None else np.array([])

    matched = np.zeros_like(source_img)
    try:
        for channel in range(source_img.shape[2]):
            matched[:, :, channel] = histogram_match_channel(
                source_img[:, :, channel], target_img[:, :, channel])
        return matched
    except Exception as e:
        logging.error(f"Error during histogram matching: {e}", exc_info=True)
        return source_img
186
 
187
def apply_naturalness_filters(face_region_bgr: np.ndarray, noise_level: int = 3) -> np.ndarray:
    """Soften swap artifacts: light median blur plus optional Gaussian noise.

    Returns the input unchanged on empty input or on any processing error.
    """
    if face_region_bgr is None or face_region_bgr.size == 0:
        logging.warning("Cannot apply naturalness filters to empty or None region.")
        return face_region_bgr if face_region_bgr is not None else np.array([])

    try:
        filtered = cv2.medianBlur(face_region_bgr.copy(), 3)
        if noise_level > 0:
            # Work in int16 so added noise can go negative without wrapping.
            noise = np.random.normal(0, noise_level, filtered.shape).astype(np.int16)
            filtered = np.clip(filtered.astype(np.int16) + noise, 0, 255).astype(np.uint8)
        logging.info("Applied naturalness filters to face region.")
        return filtered
    except Exception as e:
        logging.error(f"Error applying naturalness filters: {e}", exc_info=True)
        return face_region_bgr
203
 
204
+ # --- Main Processing Function ---
205
def _run_reswapper_inference(source_face, target_np: np.ndarray,
                             target_bbox: np.ndarray) -> np.ndarray:
    """Run the ReSwapper ONNX session; return the swapped image as BGR at target size.

    NOTE(review): the input tensor order (target image first, source embedding
    second, optional bbox third), RGB channel order and [0, 1] normalization are
    assumptions about reswapper_256.onnx inherited from the original code --
    verify against the actual model (e.g. with Netron).
    """
    target_h, target_w = target_np.shape[:2]

    # Source identity: (1, 512) float32 embedding from InsightFace.
    source_embedding_tensor = np.expand_dims(source_face.embedding.astype(np.float32), axis=0)

    # Target image: RGB, float32 in [0, 1], NCHW.
    target_rgb = cv2.cvtColor(target_np, cv2.COLOR_BGR2RGB)
    target_image_tensor = np.expand_dims(
        np.transpose((target_rgb / 255.0).astype(np.float32), (2, 0, 1)), axis=0)

    # Optional third input: the selected face's bounding box, (1, 4) float32.
    target_bbox_tensor = np.expand_dims(target_bbox.astype(np.float32), axis=0)

    model_inputs = reswapper_session.get_inputs()
    input_feed = {
        model_inputs[0].name: target_image_tensor,
        model_inputs[1].name: source_embedding_tensor,
    }
    if len(model_inputs) > 2:
        input_feed[model_inputs[2].name] = target_bbox_tensor

    output_name = reswapper_session.get_outputs()[0].name
    swapped_tensor = reswapper_session.run([output_name], input_feed)[0]

    # Undo preprocessing: drop batch dim, CHW -> HWC, denormalize, RGB -> BGR.
    if swapped_tensor.ndim == 4:
        swapped_tensor = np.squeeze(swapped_tensor, axis=0)
    swapped_hwc = np.transpose(swapped_tensor, (1, 2, 0))
    swapped_uint8 = np.clip(swapped_hwc * 255.0, 0, 255).astype(np.uint8)
    swapped_bgr = cv2.cvtColor(swapped_uint8, cv2.COLOR_RGB2BGR)

    # Some model variants emit a fixed-size image; restore original resolution.
    if swapped_bgr.shape[0] != target_h or swapped_bgr.shape[1] != target_w:
        logging.info(f"Resizing ReSwapper output from {swapped_bgr.shape[:2]} to {(target_h, target_w)}")
        swapped_bgr = cv2.resize(swapped_bgr, (target_w, target_h), interpolation=cv2.INTER_LANCZOS4)
    return swapped_bgr


def _postprocess_swapped_face(swapped_bgr_img: np.ndarray, target_np: np.ndarray,
                              bbox, apply_enhancement: bool,
                              apply_color_correction: bool, apply_naturalness: bool,
                              progress) -> np.ndarray:
    """Apply optional restoration / color matching / naturalness filters to the face ROI."""
    img_h, img_w = swapped_bgr_img.shape[:2]
    # Clamp the bbox to the image: InsightFace boxes can extend slightly outside
    # the frame, and negative slice indices would silently wrap around in numpy.
    roi_x1, roi_y1 = max(0, int(bbox[0])), max(0, int(bbox[1]))
    roi_x2, roi_y2 = min(img_w, int(bbox[2])), min(img_h, int(bbox[3]))

    result = swapped_bgr_img.copy()
    face_roi = result[roi_y1:roi_y2, roi_x1:roi_x2].copy()
    if face_roi.size == 0:
        logging.warning("ROI for post-processing is empty. Skipping post-processing steps.")
        return result

    if apply_enhancement:
        if restoration_model_loaded_successfully:
            progress(0.7, desc="Applying face restoration...")
            face_roi = enhance_cropped_face(face_roi)
        else:
            gr.Info("Face restoration model N/A, enhancement skipped.")

    if apply_color_correction:
        progress(0.8, desc="Applying color correction...")
        # Match colors against the same region of the *original* target image.
        reference_region = target_np[roi_y1:roi_y2, roi_x1:roi_x2]
        if reference_region.size > 0:
            face_roi = histogram_match_color(face_roi, reference_region.copy())

    if apply_naturalness:
        progress(0.85, desc="Applying naturalness filters...")
        face_roi = apply_naturalness_filters(face_roi)

    # Paste the processed ROI back, resizing if a step changed its dimensions.
    if face_roi.shape[0] == (roi_y2 - roi_y1) and face_roi.shape[1] == (roi_x2 - roi_x1):
        result[roi_y1:roi_y2, roi_x1:roi_x2] = face_roi
    else:
        logging.warning("Processed ROI size mismatch after post-processing. Attempting resize for paste.")
        try:
            result[roi_y1:roi_y2, roi_x1:roi_x2] = cv2.resize(
                face_roi, (roi_x2 - roi_x1, roi_y2 - roi_y1), interpolation=cv2.INTER_LANCZOS4)
        except Exception as e_resize:
            logging.error(f"Failed to resize and paste processed ROI: {e_resize}")
    return result


def process_face_swap(source_pil_img: Image.Image, target_pil_img: Image.Image,
                      target_face_index: int,
                      apply_enhancement: bool, apply_color_correction: bool,
                      apply_naturalness: bool,
                      progress=gr.Progress(track_tqdm=True)):
    """Swap the first detected source face onto the selected target face.

    Returns (swapped PIL image, path to a temporary JPEG for download); the
    path is None when saving the temp file fails. Raises gr.Error for
    user-facing failures (missing inputs, no faces, pipeline errors).
    """
    progress(0, desc="Initializing process...")
    if not core_models_loaded_successfully:
        # Fix: the original constructed gr.Error(...) without raising it (a
        # silent no-op) and returned a placeholder; raising surfaces the error.
        raise gr.Error("CRITICAL: Core models (Face Analyzer or ReSwapper) not loaded. Cannot proceed.")
    if source_pil_img is None:
        raise gr.Error("Source image not provided.")
    if target_pil_img is None:
        raise gr.Error("Target image not provided.")

    progress(0.05, desc="Converting images...")
    source_np = convert_pil_to_cv2(source_pil_img)  # BGR
    target_np = convert_pil_to_cv2(target_pil_img)  # BGR
    if source_np is None or target_np is None:
        raise gr.Error("Image conversion failed. Please try different images.")

    # Fix: target_h/target_w were referenced later but never assigned.
    target_h, target_w = target_np.shape[:2]

    progress(0.15, desc="Detecting face in source image...")
    source_faces = get_faces_from_image(source_np)
    if not source_faces:
        raise gr.Error("No face found in the source image.")
    source_face = source_faces[0]  # InsightFace Face object (provides .embedding)

    progress(0.25, desc="Detecting faces in target image...")
    target_faces = get_faces_from_image(target_np)
    if not target_faces:
        raise gr.Error("No faces found in the target image.")
    if not (0 <= target_face_index < len(target_faces)):
        raise gr.Error(f"Selected target face index ({target_face_index}) is out of range.")
    target_face_to_swap_info = target_faces[int(target_face_index)]

    try:
        progress(0.4, desc="Preparing inputs for ReSwapper...")
        progress(0.5, desc="Running ReSwapper model...")
        swapped_bgr_img = _run_reswapper_inference(
            source_face, target_np, target_face_to_swap_info.bbox)

        progress(0.6, desc="Post-processing ReSwapper output...")
        swapped_bgr_img = _postprocess_swapped_face(
            swapped_bgr_img, target_np, target_face_to_swap_info.bbox,
            apply_enhancement, apply_color_correction, apply_naturalness, progress)
    except Exception as e:
        # The original also built a fallback PIL image here that was never
        # returned (dead code, removed); re-raising is the only effect.
        logging.error(f"Error during main processing pipeline: {e}", exc_info=True)
        raise gr.Error(f"An error occurred: {str(e)}")

    progress(0.9, desc="Finalizing image...")
    swapped_pil_img = convert_cv2_to_pil(swapped_bgr_img)
    if swapped_pil_img is None:
        # gr.Warning shows a toast without aborting; gr.Error must be raised
        # to have any effect (the original's bare gr.Error(...) was a no-op).
        gr.Warning("Failed to convert final image to display format.")
        swapped_pil_img = Image.new('RGB', (target_w, target_h), color='lightgrey')

    temp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False, mode="wb") as tmp_file:
            swapped_pil_img.save(tmp_file, format="JPEG")
            temp_file_path = tmp_file.name
        logging.info(f"Swapped image saved to temporary file: {temp_file_path}")
    except Exception as e:
        logging.error(f"Failed to save swapped image to a temporary file: {e}", exc_info=True)

    progress(1.0, desc="Processing complete!")
    return swapped_pil_img, temp_file_path
435
 
436
# --- Gradio Preview Function ---
437
def preview_target_faces(target_pil_img: Image.Image):
    """Annotate detected faces in the target image and configure the index slider.

    Returns (preview PIL image, gr.Slider update). The slider is disabled
    whenever no faces are available.
    """
    def blank_preview():
        return Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')

    def disabled_slider():
        return gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)

    if target_pil_img is None:
        return blank_preview(), disabled_slider()

    target_np = convert_pil_to_cv2(target_pil_img)
    if target_np is None:
        return blank_preview(), disabled_slider()

    faces = get_faces_from_image(target_np)
    if not faces:
        # No faces: show the plain target image with the slider disabled.
        preview = convert_cv2_to_pil(target_np)
        if preview is None:
            preview = blank_preview()
        return preview, disabled_slider()

    preview = convert_cv2_to_pil(draw_detected_faces(target_np, faces))
    if preview is None:
        preview = blank_preview()

    face_count = len(faces)
    slider_update = gr.Slider(minimum=0, maximum=max(0, face_count - 1), value=0,
                              step=1, interactive=(face_count > 0))
    return preview, slider_update
458
 
459
# --- Gradio UI Definition ---
460
with gr.Blocks(title="ReSwapper AI Face Swap 💎", theme=gr.themes.Soft()) as demo:
    # Fix: all user-facing emoji below were mojibake (UTF-8 bytes decoded as
    # cp1252, e.g. "πŸ’Ž"); repaired to the intended characters.
    gr.Markdown(
        """
        <div style="text-align: center;">
        <h1>💎 ReSwapper AI Face Swap 🌠</h1>
        <p>Using ReSwapper_256.onnx. Upload a source face, a target image, and let the AI work!</p>
        </div>
        """
    )

    if not core_models_loaded_successfully:
        # Fix: gr.Error(...) only works when *raised* inside an event handler;
        # at build time it was a silent no-op. Show a visible banner instead.
        gr.Markdown(
            "<p style='color: red; font-weight: bold;'>CRITICAL ERROR: Core models "
            "(Face Analyzer or ReSwapper) failed to load. Application will not "
            "function correctly. Please check console logs and restart.</p>"
        )

    with gr.Row():
        with gr.Column(scale=1):
            source_image_input = gr.Image(label="👀 Source Face Image", type="pil",
                                          sources=["upload", "clipboard"], height=350)
        with gr.Column(scale=1):
            target_image_input = gr.Image(label="🖼️ Target Scene Image", type="pil",
                                          sources=["upload", "clipboard"], height=350)

    with gr.Row(equal_height=True):
        preview_button = gr.Button("🔍 Preview & Select Target Face", variant="secondary")
        face_index_slider = gr.Slider(label="🎯 Select Target Face (0-indexed)",
                                      minimum=0, maximum=0, step=1, value=0, interactive=False)

    target_faces_preview_output = gr.Image(label="👀 Detected Faces in Target",
                                           interactive=False, height=350)
    gr.HTML("<hr style='margin-top: 15px; margin-bottom: 15px;'>")

    with gr.Row():
        with gr.Column(scale=1):
            enhance_checkbox_label = "✨ Apply Face Restoration"
            if not restoration_model_loaded_successfully:
                enhance_checkbox_label += " (Model N/A)"
            enhance_checkbox = gr.Checkbox(label=enhance_checkbox_label,
                                           value=restoration_model_loaded_successfully,
                                           interactive=restoration_model_loaded_successfully)
            if not restoration_model_loaded_successfully:
                gr.Markdown("<p style='color: orange; font-size:0.8em;'>⚠️ Restoration model N/A.</p>")
        with gr.Column(scale=1):
            color_correction_checkbox = gr.Checkbox(label="🎨 Apply Color Correction", value=True)
        with gr.Column(scale=1):
            naturalness_checkbox = gr.Checkbox(label="🌿 Apply Subtle Naturalness Filters", value=False)

    with gr.Row():
        swap_button = gr.Button("🚀 GENERATE SWAP!", variant="primary", scale=3,
                                interactive=core_models_loaded_successfully)
        clear_button = gr.Button("🧹 Clear All", variant="stop", scale=1)

    with gr.Row():
        swapped_image_output = gr.Image(label="✨ Swapped Result", interactive=False, height=450)
        download_output_file = gr.File(label="⬇️ Download Swapped Image")

    # Event handlers
    def on_target_image_change_or_clear(target_img_pil):
        """Reset preview and slider when the target image is cleared; no-op otherwise."""
        if target_img_pil is None:
            blank_image_pil = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
            return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
        # Fix: the original returned the components' static .value attributes
        # (their *initial* values), which silently reset the outputs;
        # gr.update() leaves them untouched.
        return gr.update(), gr.update()

    target_image_input.change(fn=on_target_image_change_or_clear,
                              inputs=[target_image_input],
                              outputs=[target_faces_preview_output, face_index_slider],
                              queue=False)
    preview_button.click(fn=preview_target_faces, inputs=[target_image_input],
                         outputs=[target_faces_preview_output, face_index_slider])

    swap_button.click(
        fn=process_face_swap,
        inputs=[source_image_input, target_image_input, face_index_slider,
                enhance_checkbox, color_correction_checkbox, naturalness_checkbox],
        outputs=[swapped_image_output, download_output_file],
    )

    def clear_all_inputs_outputs():
        """Return reset values for every input/output component."""
        blank_preview = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
        return (None, None, blank_preview,
                gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False),
                None, None)

    clear_button.click(fn=clear_all_inputs_outputs, inputs=None,
                       outputs=[source_image_input, target_image_input,
                                target_faces_preview_output, face_index_slider,
                                swapped_image_output, download_output_file],
                       queue=False)

    gr.Examples(
        examples=[
            ["examples/source_face.jpg", "examples/target_group.jpg", 0, True, True, True],
        ],
        inputs=[source_image_input, target_image_input, face_index_slider,
                enhance_checkbox, color_correction_checkbox, naturalness_checkbox],
        outputs=[swapped_image_output, download_output_file],
        fn=process_face_swap, cache_examples=False,
        label="Example Face Swaps (Click to run)"
    )
542
 
543
  # --- Main Execution Block ---
544
  if __name__ == "__main__":
545
  os.makedirs("models", exist_ok=True)
546
+ os.makedirs("examples", exist_ok=True)
547
 
 
548
  print("\n" + "="*60)
549
+ print("πŸš€ ReSwapper AI FACE SWAP - STARTUP STATUS πŸš€")
550
  print("="*60)
551
  if not core_models_loaded_successfully:
552
+ print(f"πŸ”΄ CRITICAL ERROR: Core models FAILED to load.")
553
  print(f" - Face Analyzer: '{FACE_ANALYZER_NAME}' (Status: {'Loaded' if face_analyzer else 'Failed'})")
554
+ print(f" - ReSwapper Model: '{SWAPPER_MODEL_PATH}' (Status: {'Loaded' if reswapper_session else 'Failed'})") # MODIFIED
555
  print(" The application UI will load but will be NON-FUNCTIONAL.")
 
556
  else:
557
+ print("🟒 Core models (Face Analyzer & ReSwapper) loaded successfully.")
558
 
559
  if not restoration_model_loaded_successfully:
560
+ print(f"🟑 INFO: Face Restoration model ('{RESTORATION_MODEL_PATH}') not loaded. Enhancement feature disabled.")
 
 
561
  else:
562
  print("🟒 Face Restoration model loaded successfully.")
563
  print("="*60 + "\n")