Athagi commited on
Commit
2e326a6
Β·
1 Parent(s): 4d077cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +215 -327
app.py CHANGED
@@ -3,7 +3,6 @@ import cv2
3
  import numpy as np
4
  import gradio as gr
5
  from insightface.app import FaceAnalysis # Still needed for detection and embeddings
6
- # from insightface.model_zoo import get_model # No longer using this for the swapper
7
  from PIL import Image
8
  import tempfile
9
  import logging
@@ -12,33 +11,28 @@ import onnxruntime
12
  # --- Configuration & Setup ---
13
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
14
 
15
- # MODIFIED: Path to the new swapper model
16
- SWAPPER_MODEL_PATH = "models/reswapper_256.onnx" # <--- USING RESWAPPER_256
17
  RESTORATION_MODEL_PATH = "models/gfpgan_1.4.onnx" # Or your chosen restorer, e.g., restoreformer++.onnx
18
- # UPSCALER_MODEL_PATH = "models/RealESRGANx4plus.onnx" # Path for AI Upscaler if you were to add it
19
 
20
- FACE_ANALYZER_NAME = 'buffalo_l' # InsightFace detection and embedding model
21
  DETECTION_SIZE = (640, 640)
22
- EXECUTION_PROVIDERS = ['CPUExecutionProvider'] # Or ['CUDAExecutionProvider', 'CPUExecutionProvider']
23
 
24
- # --- Global Variables (Lazy Loaded by initialize_models) ---
25
  face_analyzer = None
26
- reswapper_session = None # MODIFIED: Will hold the onnxruntime session for reswapper
27
  face_restorer = None
28
- # face_upscaler = None # If you add AI upscaling
29
 
30
  # --- Initialization Functions ---
31
  def initialize_models():
32
- global face_analyzer, reswapper_session, face_restorer # MODIFIED: reswapper_session
33
  try:
34
- # Initialize FaceAnalysis model (for detection and embeddings)
35
  if face_analyzer is None:
36
  logging.info(f"Initializing FaceAnalysis model: {FACE_ANALYZER_NAME}")
37
  face_analyzer = FaceAnalysis(name=FACE_ANALYZER_NAME, providers=EXECUTION_PROVIDERS)
38
  face_analyzer.prepare(ctx_id=0, det_size=DETECTION_SIZE)
39
  logging.info("FaceAnalysis model initialized successfully.")
40
 
41
- # MODIFIED: Initialize ReSwapper model directly with onnxruntime
42
  if reswapper_session is None:
43
  if not os.path.exists(SWAPPER_MODEL_PATH):
44
  logging.error(f"ReSwapper model FILE NOT FOUND at {SWAPPER_MODEL_PATH}. Swapping will fail.")
@@ -49,14 +43,12 @@ def initialize_models():
49
  logging.info(f"ReSwapper model ({SWAPPER_MODEL_PATH}) loaded successfully with onnxruntime.")
50
  except Exception as e:
51
  logging.error(f"Error loading ReSwapper model {SWAPPER_MODEL_PATH} with onnxruntime: {e}", exc_info=True)
52
- reswapper_session = None # Ensure it's None if loading failed
53
-
54
- # Initialize Face Restoration Model (same as before)
55
  if face_restorer is None:
56
  if not os.path.exists(RESTORATION_MODEL_PATH):
57
- logging.error(f"Face restoration model FILE NOT FOUND at: {RESTORATION_MODEL_PATH}. Enhancement will be disabled.")
58
  else:
59
- # (loading logic for face_restorer as in previous complete code)
60
  logging.info(f"Attempting to load face restoration model from: {RESTORATION_MODEL_PATH}")
61
  try:
62
  face_restorer = onnxruntime.InferenceSession(RESTORATION_MODEL_PATH, providers=EXECUTION_PROVIDERS)
@@ -64,19 +56,14 @@ def initialize_models():
64
  except Exception as e:
65
  logging.error(f"Error loading face restoration model from {RESTORATION_MODEL_PATH}: {e}. Enhancement feature will be disabled.", exc_info=True)
66
  face_restorer = None
67
-
68
- # (Initialize AI Face Upscaler Model - if you were adding it)
69
-
70
  except Exception as e:
71
  logging.error(f"A critical error occurred during model initialization: {e}", exc_info=True)
72
 
73
- # --- Call Initialization Early ---
74
  initialize_models()
75
- core_models_loaded_successfully = face_analyzer is not None and reswapper_session is not None # MODIFIED
76
  restoration_model_loaded_successfully = face_restorer is not None
77
- # upscaler_model_loaded_successfully = face_upscaler is not None # If adding upscaler
78
 
79
- # --- Image Utility Functions (convert_pil_to_cv2, convert_cv2_to_pil - same as before) ---
80
  def convert_pil_to_cv2(pil_image: Image.Image) -> np.ndarray | None:
81
  if pil_image is None: return None
82
  try:
@@ -93,123 +80,109 @@ def convert_cv2_to_pil(cv2_image: np.ndarray) -> Image.Image | None:
93
  logging.error(f"Error converting CV2 to PIL: {e}")
94
  return None
95
 
96
- # --- Core AI & Image Processing Functions (get_faces_from_image, draw_detected_faces, enhance_cropped_face, histogram_match_color, apply_naturalness_filters - mostly same) ---
97
- # These functions (except the main swap logic) should largely remain the same.
98
- # Make sure they are copied from your previous complete version.
99
- # For brevity here, I'll assume they are present.
100
- # Example of get_faces_from_image:
101
  def get_faces_from_image(img_np: np.ndarray) -> list:
102
- if face_analyzer is None: # Check against the global face_analyzer
103
  logging.error("Face analyzer not available for get_faces_from_image.")
104
  return []
105
  if img_np is None: return []
106
  try:
107
- return face_analyzer.get(img_np) # This provides Face objects with embeddings
108
  except Exception as e:
109
  logging.error(f"Error during face detection: {e}", exc_info=True)
110
  return []
111
 
112
- # ... (draw_detected_faces, enhance_cropped_face, histogram_match_color, apply_naturalness_filters functions from previous full code)
113
- def draw_detected_faces(img_np: np.ndarray, faces: list) -> np.ndarray: # Copied for completeness
114
  img_with_boxes = img_np.copy()
115
  for i, face in enumerate(faces):
116
  box = face.bbox.astype(int)
117
  x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
118
  cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), (0, 255, 0), 2)
119
  label_position = (x1, max(0, y1 - 10))
120
- cv2.putText(img_with_boxes, f"Face {i}", label_position,
121
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (36, 255, 12), 2)
122
  return img_with_boxes
123
 
124
- def enhance_cropped_face(face_crop_bgr: np.ndarray) -> np.ndarray: # Copied for completeness
125
  if not restoration_model_loaded_successfully or face_restorer is None:
126
- logging.warning("Face restorer model not available. Skipping enhancement for crop.")
127
  return face_crop_bgr
128
- if face_crop_bgr is None or face_crop_bgr.shape[0] == 0 or face_crop_bgr.shape[1] == 0:
129
- logging.warning("Received empty or invalid face crop for enhancement.")
130
- return face_crop_bgr if face_crop_bgr is not None else np.array([])
131
  logging.info(f"Applying face restoration to crop of size {face_crop_bgr.shape[:2]}...")
132
- crop_height, crop_width = face_crop_bgr.shape[:2]
133
  try:
134
  img_rgb = cv2.cvtColor(face_crop_bgr, cv2.COLOR_BGR2RGB)
135
- restorer_input_size = (512, 512)
136
- img_resized_for_model = cv2.resize(img_rgb, restorer_input_size, interpolation=cv2.INTER_AREA)
137
- img_normalized = (img_resized_for_model / 255.0).astype(np.float32)
138
- img_chw = np.transpose(img_normalized, (2, 0, 1))
139
  input_tensor = np.expand_dims(img_chw, axis=0)
140
  input_name = face_restorer.get_inputs()[0].name
141
  output_name = face_restorer.get_outputs()[0].name
142
- restored_output_model_size = face_restorer.run([output_name], {input_name: input_tensor})[0]
143
- restored_img_chw = np.squeeze(restored_output_model_size, axis=0)
144
- restored_img_hwc_model_size = np.transpose(restored_img_chw, (1, 2, 0))
145
- restored_img_uint8_model_size = np.clip(restored_img_hwc_model_size * 255.0, 0, 255).astype(np.uint8)
146
- restored_crop_rgb = cv2.resize(restored_img_uint8_model_size, (crop_width, crop_height), interpolation=cv2.INTER_LANCZOS4)
147
- restored_crop_bgr = cv2.cvtColor(restored_crop_rgb, cv2.COLOR_RGB2BGR)
148
- logging.info("Cropped face restoration complete.")
149
- return restored_crop_bgr
150
  except Exception as e:
151
  logging.error(f"Error during face restoration for crop: {e}", exc_info=True)
152
  return face_crop_bgr
153
 
154
- def histogram_match_channel(source_channel: np.ndarray, target_channel: np.ndarray) -> np.ndarray: # Copied
155
- source_shape = source_channel.shape
156
- source_channel_flat = source_channel.flatten()
157
- target_channel_flat = target_channel.flatten()
158
- source_hist, bins = np.histogram(source_channel_flat, 256, [0,256], density=True) # Use density for CDF
159
- target_hist, bins = np.histogram(target_channel_flat, 256, [0,256], density=True)
160
- source_cdf = source_hist.cumsum()
161
- target_cdf = target_hist.cumsum()
162
- # Normalize CDFs to range [0, 255] for lookup table
163
- source_cdf_norm = (source_cdf * 255 / source_cdf[-1]).astype(np.uint8) # Ensure last val is 255
164
- target_cdf_norm = (target_cdf * 255 / target_cdf[-1]).astype(np.uint8)
165
- lookup_table = np.zeros(256, dtype='uint8')
166
- idx = 0
167
  for i in range(256):
168
- while idx < 255 and target_cdf_norm[idx] < source_cdf_norm[i]:
169
- idx += 1
170
- lookup_table[i] = idx
171
- matched_channel_flat = cv2.LUT(source_channel, lookup_table)
172
- return matched_channel_flat.reshape(source_shape)
173
 
174
- def histogram_match_color(source_img: np.ndarray, target_img: np.ndarray) -> np.ndarray: # Copied
175
  if source_img is None or target_img is None or source_img.size == 0 or target_img.size == 0:
176
- logging.warning("Cannot perform histogram matching on empty or None images.")
177
  return source_img if source_img is not None else np.array([])
178
- matched_img = np.zeros_like(source_img)
179
  try:
180
  for i in range(source_img.shape[2]):
181
- matched_img[:,:,i] = histogram_match_channel(source_img[:,:,i], target_img[:,:,i])
182
- return matched_img
183
  except Exception as e:
184
  logging.error(f"Error during histogram matching: {e}", exc_info=True)
185
  return source_img
186
 
187
- def apply_naturalness_filters(face_region_bgr: np.ndarray, noise_level: int = 3) -> np.ndarray: # Copied
188
  if face_region_bgr is None or face_region_bgr.size == 0:
189
- logging.warning("Cannot apply naturalness filters to empty or None region.")
190
  return face_region_bgr if face_region_bgr is not None else np.array([])
191
- processed_region = face_region_bgr.copy()
192
  try:
193
- processed_region = cv2.medianBlur(processed_region, 3)
194
  if noise_level > 0:
195
- gaussian_noise = np.random.normal(0, noise_level, processed_region.shape).astype(np.int16)
196
- noisy_region_int16 = processed_region.astype(np.int16) + gaussian_noise
197
- processed_region = np.clip(noisy_region_int16, 0, 255).astype(np.uint8)
198
- logging.info("Applied naturalness filters to face region.")
199
  except Exception as e:
200
  logging.error(f"Error applying naturalness filters: {e}", exc_info=True)
201
  return face_region_bgr
202
- return processed_region
203
 
204
  # --- Main Processing Function ---
205
  def process_face_swap(source_pil_img: Image.Image, target_pil_img: Image.Image,
206
- target_face_index: int, # apply_ai_upscaling: bool, (Removed for this example to focus on ReSwapper)
207
  apply_enhancement: bool, apply_color_correction: bool,
208
  apply_naturalness: bool,
209
  progress=gr.Progress(track_tqdm=True)):
210
 
211
  progress(0, desc="Initializing process...")
212
- if not core_models_loaded_successfully: # Checks face_analyzer and reswapper_session
213
  gr.Error("CRITICAL: Core models (Face Analyzer or ReSwapper) not loaded. Cannot proceed.")
214
  return Image.new('RGB', (100,100), color='lightgrey'), None
215
 
@@ -217,350 +190,265 @@ def process_face_swap(source_pil_img: Image.Image, target_pil_img: Image.Image,
217
  if target_pil_img is None: raise gr.Error("Target image not provided.")
218
 
219
  progress(0.05, desc="Converting images...")
220
- source_np = convert_pil_to_cv2(source_pil_img) # BGR
221
- target_np = convert_pil_to_cv2(target_pil_img) # BGR
222
  if source_np is None or target_np is None:
223
- raise gr.Error("Image conversion failed. Please try different images.")
224
 
225
- target_h, target_w = target_np.shape[:2]
226
 
227
- progress(0.15, desc="Detecting face in source image...")
228
  source_faces = get_faces_from_image(source_np)
229
- if not source_faces: raise gr.Error("No face found in the source image.")
230
- source_face = source_faces[0] # InsightFace Face object
231
 
232
- progress(0.25, desc="Detecting faces in target image...")
233
  target_faces = get_faces_from_image(target_np)
234
- if not target_faces: raise gr.Error("No faces found in the target image.")
235
  if not (0 <= target_face_index < len(target_faces)):
236
- raise gr.Error(f"Selected target face index ({target_face_index}) is out of range.")
237
- target_face_to_swap_info = target_faces[int(target_face_index)] # InsightFace Face object
238
 
239
- swapped_bgr_img = target_np.copy() # Initialize with a copy
240
 
241
  try:
242
- progress(0.4, desc="Preparing inputs for ReSwapper...")
243
 
244
- ########## USER VERIFICATION NEEDED FOR RESWAPPER_256.ONNX ##########
245
- # 1. Source Face Embedding Preparation
246
- source_embedding = source_face.embedding.astype(np.float32) # (512,)
247
- # Most ONNX models expect a batch dimension.
248
  source_embedding_tensor = np.expand_dims(source_embedding, axis=0) # (1, 512)
249
 
250
  # 2. Target Image Preparation
251
- # ASSUMPTION: reswapper_256.onnx might expect a fixed size input, e.g., 256x256.
252
- # If it does, you MUST resize target_np to that size.
253
- # For this example, let's assume it takes the target image and processes internally
254
- # or expects a specific format. A common pattern is RGB, normalized, CHW.
255
- # Let's assume input size of 256x256 for the target image input to the model.
256
- # We will process the whole target_np, let the model do its work, and it should return the modified whole target_np.
257
 
258
- target_img_for_onnx = cv2.cvtColor(target_np, cv2.COLOR_BGR2RGB) # Often RGB for ONNX
 
 
 
 
259
 
260
- # RESIZE THE TARGET IMAGE IF THE MODEL EXPECTS A FIXED INPUT SIZE
261
- # E.g., if reswapper_256.onnx expects 256x256 input for the *target image itself*:
262
- # target_img_for_onnx = cv2.resize(target_img_for_onnx, (256, 256), interpolation=cv2.INTER_AREA)
263
- # This is a CRITICAL assumption. If it operates on full-res target, no resize needed here.
264
- # If the "256" in reswapper_256 refers to an internal patch size, it might handle varied input.
265
- # For now, let's NOT resize the full target image, assuming the model handles it.
266
-
267
- target_img_normalized = (target_img_for_onnx / 255.0).astype(np.float32) # Normalize [0, 1]
268
- # Or normalize to [-1, 1]: (target_img_for_onnx / 127.5 - 1.0).astype(np.float32)
269
 
270
  target_img_chw = np.transpose(target_img_normalized, (2, 0, 1)) # HWC to CHW
271
  target_image_tensor = np.expand_dims(target_img_chw, axis=0) # Add batch dimension
272
 
273
- # 3. Target Face Bounding Box (Optional, but many swappers need it)
274
- # The InsightFace Face object (target_face_to_swap_info) contains .bbox
275
- # bbox is [x1, y1, x2, y2]. Normalize if model expects normalized bbox.
276
- target_bbox_tensor = target_face_to_swap_info.bbox.astype(np.float32)
277
- # Reshape if needed, e.g., (1, 4)
278
- target_bbox_tensor = np.expand_dims(target_bbox_tensor, axis=0)
279
-
280
-
281
- # Placeholder names for ONNX model inputs. YOU MUST FIND THE REAL NAMES.
282
- # Common names could be "target", "source_embed", "target_face_box"
283
- # Check model's .get_inputs() if loaded, or use Netron app to view ONNX.
284
- onnx_input_names = [inp.name for inp in reswapper_session.get_inputs()]
285
- # Crude assumption based on common patterns (highly likely to be WRONG):
286
- # This is where you absolutely need the model's documentation.
287
- # Let's assume: input0=target_image, input1=source_embedding, input2=target_bbox (OPTIONAL)
288
- # This is a GUESS.
 
 
 
 
 
 
 
 
289
 
290
- # For this example, let's try a simpler input signature if possible,
291
- # like what an insightface swapper's .get() method might abstract.
292
- # A common pattern for standalone ONNX swappers is:
293
- # Input 1: Target image tensor (e.g., BCHW, float32, normalized)
294
- # Input 2: Source embedding tensor (e.g., B x EmbDim, float32)
295
- # Some might also take target face landmarks or a mask.
296
-
297
- # Let's assume these input names based on typical conventions.
298
- # YOU MUST VERIFY THESE using Netron or the model's documentation.
299
- # input_feed = {
300
- # "target_image": target_image_tensor, # Placeholder name
301
- # "source_embedding": source_embedding_tensor # Placeholder name
302
- # }
303
- # If target_bbox is also an input:
304
- # input_feed["target_bbox"] = target_bbox_tensor # Placeholder name
305
-
306
- # A more robust way is to check input names from the model:
307
- # input_name_target_img = reswapper_session.get_inputs()[0].name
308
- # input_name_source_emb = reswapper_session.get_inputs()[1].name
309
- # input_feed = {
310
- # input_name_target_img: target_image_tensor,
311
- # input_name_source_emb: source_embedding_tensor
312
- # }
313
- # if len(reswapper_session.get_inputs()) > 2:
314
- # input_name_target_bbox = reswapper_session.get_inputs()[2].name
315
- # input_feed[input_name_target_bbox] = target_bbox_tensor
316
-
317
- # For this illustrative code, I will make up plausible input names.
318
- # ***** REPLACE "actual_target_input_name", "actual_source_embed_input_name" etc. *****
319
- # ***** WITH THE REAL TENSOR NAMES FROM reswapper_256.onnx. *****
320
- # ***** Common names are often like 'input', 'target', 'source', 'embedding'.*****
321
- # ***** Use Netron app to inspect your reswapper_256.onnx file! *****
322
- input_feed = {
323
- # Example, assuming first input is target image, second is source embedding
324
- reswapper_session.get_inputs()[0].name: target_image_tensor,
325
- reswapper_session.get_inputs()[1].name: source_embedding_tensor
326
- }
327
- # If your model takes a third input like target_bbox (very model specific):
328
- if len(reswapper_session.get_inputs()) > 2: # Check if there's a third input
329
- input_feed[reswapper_session.get_inputs()[2].name] = target_bbox_tensor
330
-
331
-
332
- # Placeholder for output tensor name. YOU MUST FIND THE REAL NAME.
333
- # output_name_placeholder = "output_image"
334
- # More robust:
335
- output_name = reswapper_session.get_outputs()[0].name
336
- ########## END OF USER VERIFICATION NEEDED ##########
337
 
338
  progress(0.5, desc="Running ReSwapper model...")
339
  onnx_outputs = reswapper_session.run([output_name], input_feed)
340
- swapped_image_tensor = onnx_outputs[0] # Assuming first output is the image
341
 
342
  progress(0.6, desc="Post-processing ReSwapper output...")
343
- # Post-process the output tensor back to a BGR image
344
- # 1. Remove batch dimension if present
345
- if swapped_image_tensor.ndim == 4:
346
- swapped_image_tensor = np.squeeze(swapped_image_tensor, axis=0) # (C, H, W)
347
- # 2. Transpose CHW to HWC
348
- swapped_image_hwc = np.transpose(swapped_image_tensor, (1, 2, 0)) # (H, W, C)
349
- # 3. Denormalize (inverse of the normalization applied earlier)
350
  # If normalized to [0, 1]:
351
  swapped_image_denormalized = swapped_image_hwc * 255.0
352
  # If normalized to [-1, 1]:
353
  # swapped_image_denormalized = (swapped_image_hwc + 1.0) * 127.5
354
 
355
- swapped_image_uint8 = np.clip(swapped_image_denormalized, 0, 255).astype(np.uint8)
356
 
357
- # 4. Convert RGB to BGR if model outputs RGB
358
- swapped_bgr_img = cv2.cvtColor(swapped_image_uint8, cv2.COLOR_RGB2BGR)
359
-
360
- # IF `reswapper_256.onnx` returned a resized image (e.g. 256x256),
361
- # you might need to resize `swapped_bgr_img` back to original target_h, target_w.
362
- if swapped_bgr_img.shape[0] != target_h or swapped_bgr_img.shape[1] != target_w:
363
- logging.info(f"Resizing ReSwapper output from {swapped_bgr_img.shape[:2]} to {(target_h, target_w)}")
364
- swapped_bgr_img = cv2.resize(swapped_bgr_img, (target_w, target_h), interpolation=cv2.INTER_LANCZOS4)
 
365
 
366
  # --- Post-processing pipeline (Restoration, Color Correction, Naturalness) ---
367
- # These will now operate on `swapped_bgr_img` which came from reswapper.
368
- current_processed_image = swapped_bgr_img.copy()
369
- bbox_coords = target_face_to_swap_info.bbox.astype(int) # [x1, y1, x2, y2] for ROI processing
370
- roi_x1, roi_y1, roi_x2, roi_y2 = bbox_coords[0], bbox_coords[1], bbox_coords[2], bbox_coords[3]
371
 
372
  face_roi_for_postprocessing = current_processed_image[roi_y1:roi_y2, roi_x1:roi_x2].copy()
373
 
374
- if face_roi_for_postprocessing.size == 0:
375
- logging.warning("ROI for post-processing is empty. Skipping post-processing steps.")
376
- else:
377
- if apply_enhancement:
378
- if restoration_model_loaded_successfully:
379
- progress(0.7, desc="Applying face restoration...")
380
- face_roi_for_postprocessing = enhance_cropped_face(face_roi_for_postprocessing)
381
- else:
382
- gr.Info("Face restoration model N/A, enhancement skipped.")
383
 
384
  if apply_color_correction:
385
  progress(0.8, desc="Applying color correction...")
386
- target_face_region_for_color = target_np[roi_y1:roi_y2, roi_x1:roi_x2]
387
- if target_face_region_for_color.size > 0:
388
- face_roi_for_postprocessing = histogram_match_color(face_roi_for_postprocessing, target_face_region_for_color.copy())
389
 
390
  if apply_naturalness:
391
  progress(0.85, desc="Applying naturalness filters...")
392
  face_roi_for_postprocessing = apply_naturalness_filters(face_roi_for_postprocessing)
393
 
394
  # Paste the processed ROI back
395
- if face_roi_for_postprocessing.shape[0] == (roi_y2 - roi_y1) and face_roi_for_postprocessing.shape[1] == (roi_x2 - roi_x1):
 
396
  current_processed_image[roi_y1:roi_y2, roi_x1:roi_x2] = face_roi_for_postprocessing
397
- else:
398
- logging.warning("Processed ROI size mismatch after post-processing. Attempting resize for paste.")
399
  try:
400
- resized_processed_roi = cv2.resize(face_roi_for_postprocessing, (roi_x2-roi_x1, roi_y2-roi_y1), interpolation=cv2.INTER_LANCZOS4)
401
- current_processed_image[roi_y1:roi_y2, roi_x1:roi_x2] = resized_processed_roi
402
- except Exception as e_resize:
403
- logging.error(f"Failed to resize and paste processed ROI: {e_resize}")
404
-
 
405
 
406
- swapped_bgr_img = current_processed_image # Final image after all steps
407
 
408
  except Exception as e:
409
  logging.error(f"Error during main processing pipeline: {e}", exc_info=True)
410
- # Fallback to display the original target image or last good state if error occurs mid-process
411
- img_to_show_on_error = swapped_bgr_img # Use current state which might be partially processed
412
- swapped_pil_img_on_error = convert_cv2_to_pil(img_to_show_on_error)
413
  if swapped_pil_img_on_error is None:
414
- swapped_pil_img_on_error = Image.new('RGB', (target_w, target_h), color='lightgrey')
415
- raise gr.Error(f"An error occurred: {str(e)}")
416
 
417
  progress(0.9, desc="Finalizing image...")
418
- swapped_pil_img = convert_cv2_to_pil(swapped_bgr_img)
419
  if swapped_pil_img is None:
420
- gr.Error("Failed to convert final image to display format.")
421
- swapped_pil_img = Image.new('RGB', (target_w, target_h), color='lightgrey')
422
 
423
  temp_file_path = None
424
  try:
425
  with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False, mode="wb") as tmp_file:
426
  swapped_pil_img.save(tmp_file, format="JPEG")
427
  temp_file_path = tmp_file.name
428
- logging.info(f"Swapped image saved to temporary file: {temp_file_path}")
429
  except Exception as e:
430
- logging.error(f"Error saving to temporary file: {e}", exc_info=True)
431
- gr.Warning("Could not save the swapped image for download.")
432
 
433
  progress(1.0, desc="Processing complete!")
434
  return swapped_pil_img, temp_file_path
435
 
436
- # --- Gradio Preview Function (preview_target_faces - same as before) ---
437
- def preview_target_faces(target_pil_img: Image.Image): # Copied for completeness
438
  if target_pil_img is None:
439
- blank_image_pil = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
440
- return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
441
  target_np = convert_pil_to_cv2(target_pil_img)
442
  if target_np is None:
443
- blank_image_pil = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
444
- return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
445
  faces = get_faces_from_image(target_np)
446
  if not faces:
447
- preview_pil_img = convert_cv2_to_pil(target_np)
448
- if preview_pil_img is None:
449
- preview_pil_img = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
450
- return preview_pil_img, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
451
- preview_np_img = draw_detected_faces(target_np, faces)
452
- preview_pil_img = convert_cv2_to_pil(preview_np_img)
453
- if preview_pil_img is None:
454
- preview_pil_img = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
455
- num_faces = len(faces)
456
- slider_update = gr.Slider(minimum=0, maximum=max(0, num_faces - 1), value=0, step=1, interactive=(num_faces > 0))
457
- return preview_pil_img, slider_update
458
-
459
- # --- Gradio UI Definition (Largely same, ensure inputs to process_face_swap are correct) ---
460
- with gr.Blocks(title="ReSwapper AI Face Swap πŸ’Ž", theme=gr.themes.Soft()) as demo:
461
  gr.Markdown(
462
  """
463
  <div style="text-align: center;">
464
- <h1>πŸ’Ž ReSwapper AI Face Swap 🌠</h1>
465
- <p>Using ReSwapper_256.onnx. Upload a source face, a target image, and let the AI work!</p>
466
  </div>
467
  """
468
  )
469
-
470
- if not core_models_loaded_successfully: # This now checks face_analyzer and reswapper_session
471
- gr.Error("CRITICAL ERROR: Core models (Face Analyzer or ReSwapper) failed to load. Application will not function correctly. Please check console logs and restart.")
472
 
473
  with gr.Row():
474
  with gr.Column(scale=1):
475
- source_image_input = gr.Image(label="πŸ‘€ Source Face Image", type="pil", sources=["upload", "clipboard"], height=350)
476
  with gr.Column(scale=1):
477
- target_image_input = gr.Image(label="πŸ–ΌοΈ Target Scene Image", type="pil", sources=["upload", "clipboard"], height=350)
478
-
479
  with gr.Row(equal_height=True):
480
  preview_button = gr.Button("πŸ” Preview & Select Target Face", variant="secondary")
481
- face_index_slider = gr.Slider(label="🎯 Select Target Face (0-indexed)", minimum=0, maximum=0, step=1, value=0, interactive=False)
482
-
483
  target_faces_preview_output = gr.Image(label="πŸ‘€ Detected Faces in Target", interactive=False, height=350)
484
- gr.HTML("<hr style='margin-top: 15px; margin-bottom: 15px;'>")
485
-
486
  with gr.Row():
487
  with gr.Column(scale=1):
488
- enhance_checkbox_label = "✨ Apply Face Restoration"
489
- if not restoration_model_loaded_successfully: enhance_checkbox_label += " (Model N/A)"
490
- enhance_checkbox = gr.Checkbox(label=enhance_checkbox_label, value=restoration_model_loaded_successfully, interactive=restoration_model_loaded_successfully)
491
- if not restoration_model_loaded_successfully: gr.Markdown("<p style='color: orange; font-size:0.8em;'>⚠️ Restoration model N/A.</p>")
492
-
493
  with gr.Column(scale=1):
494
- color_correction_checkbox = gr.Checkbox(label="🎨 Apply Color Correction", value=True)
495
-
496
  with gr.Column(scale=1):
497
- naturalness_checkbox = gr.Checkbox(label="🌿 Apply Subtle Naturalness Filters", value=False)
498
-
499
  with gr.Row():
500
  swap_button = gr.Button("πŸš€ GENERATE SWAP!", variant="primary", scale=3, interactive=core_models_loaded_successfully)
501
  clear_button = gr.Button("🧹 Clear All", variant="stop", scale=1)
502
-
503
  with gr.Row():
504
  swapped_image_output = gr.Image(label="✨ Swapped Result", interactive=False, height=450)
505
  download_output_file = gr.File(label="⬇️ Download Swapped Image")
506
 
507
  # Event Handlers
508
- def on_target_image_change_or_clear(target_img_pil):
509
- if target_img_pil is None:
510
- blank_image_pil = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
511
- return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
512
- return target_faces_preview_output.value, face_index_slider.value
513
-
514
- target_image_input.change(fn=on_target_image_change_or_clear, inputs=[target_image_input], outputs=[target_faces_preview_output, face_index_slider], queue=False)
515
  preview_button.click(fn=preview_target_faces, inputs=[target_image_input], outputs=[target_faces_preview_output, face_index_slider])
516
-
517
- swap_button.click(
518
- fn=process_face_swap,
519
- inputs=[
520
- source_image_input, target_image_input, face_index_slider,
521
- enhance_checkbox, color_correction_checkbox, naturalness_checkbox
522
- # Note: Removed apply_ai_upscaling from inputs list for this example as we focused on ReSwapper integration
523
- ],
524
- outputs=[swapped_image_output, download_output_file]
525
- )
526
-
527
- def clear_all_inputs_outputs():
528
- blank_preview = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color = 'lightgray')
529
- return (None, None, blank_preview, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False), None, None)
530
-
531
- clear_button.click(fn=clear_all_inputs_outputs, inputs=None, outputs=[source_image_input, target_image_input, target_faces_preview_output, face_index_slider, swapped_image_output, download_output_file], queue=False)
532
-
533
  gr.Examples(
534
- examples=[
535
- ["examples/source_face.jpg", "examples/target_group.jpg", 0, True, True, True],
536
- ],
537
- inputs=[source_image_input, target_image_input, face_index_slider, enhance_checkbox, color_correction_checkbox, naturalness_checkbox],
538
- outputs=[swapped_image_output, download_output_file],
539
- fn=process_face_swap, cache_examples=False,
540
- label="Example Face Swaps (Click to run)"
541
  )
542
 
543
  # --- Main Execution Block ---
544
  if __name__ == "__main__":
545
  os.makedirs("models", exist_ok=True)
546
  os.makedirs("examples", exist_ok=True)
547
-
548
- print("\n" + "="*60)
549
- print("πŸš€ ReSwapper AI FACE SWAP - STARTUP STATUS πŸš€")
550
- print("="*60)
551
  if not core_models_loaded_successfully:
552
- print(f"πŸ”΄ CRITICAL ERROR: Core models FAILED to load.")
553
- print(f" - Face Analyzer: '{FACE_ANALYZER_NAME}' (Status: {'Loaded' if face_analyzer else 'Failed'})")
554
- print(f" - ReSwapper Model: '{SWAPPER_MODEL_PATH}' (Status: {'Loaded' if reswapper_session else 'Failed'})") # MODIFIED
555
- print(" The application UI will load but will be NON-FUNCTIONAL.")
556
- else:
557
- print("🟒 Core models (Face Analyzer & ReSwapper) loaded successfully.")
558
-
559
- if not restoration_model_loaded_successfully:
560
- print(f"🟑 INFO: Face Restoration model ('{RESTORATION_MODEL_PATH}') not loaded. Enhancement feature disabled.")
561
- else:
562
- print("🟒 Face Restoration model loaded successfully.")
563
- print("="*60 + "\n")
564
-
565
- print("Launching Gradio Interface...")
566
  demo.launch()
 
3
  import numpy as np
4
  import gradio as gr
5
  from insightface.app import FaceAnalysis # Still needed for detection and embeddings
 
6
  from PIL import Image
7
  import tempfile
8
  import logging
 
11
  # --- Configuration & Setup ---
12
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
13
 
14
+ SWAPPER_MODEL_PATH = "models/reswapper_256.onnx" # Using ReSwapper
 
15
  RESTORATION_MODEL_PATH = "models/gfpgan_1.4.onnx" # Or your chosen restorer, e.g., restoreformer++.onnx
 
16
 
17
+ FACE_ANALYZER_NAME = 'buffalo_l'
18
  DETECTION_SIZE = (640, 640)
19
+ EXECUTION_PROVIDERS = ['CPUExecutionProvider']
20
 
21
+ # --- Global Variables ---
22
  face_analyzer = None
23
+ reswapper_session = None # For reswapper_256.onnx
24
  face_restorer = None
 
25
 
26
  # --- Initialization Functions ---
27
  def initialize_models():
28
+ global face_analyzer, reswapper_session, face_restorer
29
  try:
 
30
  if face_analyzer is None:
31
  logging.info(f"Initializing FaceAnalysis model: {FACE_ANALYZER_NAME}")
32
  face_analyzer = FaceAnalysis(name=FACE_ANALYZER_NAME, providers=EXECUTION_PROVIDERS)
33
  face_analyzer.prepare(ctx_id=0, det_size=DETECTION_SIZE)
34
  logging.info("FaceAnalysis model initialized successfully.")
35
 
 
36
  if reswapper_session is None:
37
  if not os.path.exists(SWAPPER_MODEL_PATH):
38
  logging.error(f"ReSwapper model FILE NOT FOUND at {SWAPPER_MODEL_PATH}. Swapping will fail.")
 
43
  logging.info(f"ReSwapper model ({SWAPPER_MODEL_PATH}) loaded successfully with onnxruntime.")
44
  except Exception as e:
45
  logging.error(f"Error loading ReSwapper model {SWAPPER_MODEL_PATH} with onnxruntime: {e}", exc_info=True)
46
+ reswapper_session = None
47
+
 
48
  if face_restorer is None:
49
  if not os.path.exists(RESTORATION_MODEL_PATH):
50
+ logging.error(f"Face restoration model FILE NOT FOUND at: {RESTORATION_MODEL_PATH}. Enhancement feature will be disabled.")
51
  else:
 
52
  logging.info(f"Attempting to load face restoration model from: {RESTORATION_MODEL_PATH}")
53
  try:
54
  face_restorer = onnxruntime.InferenceSession(RESTORATION_MODEL_PATH, providers=EXECUTION_PROVIDERS)
 
56
  except Exception as e:
57
  logging.error(f"Error loading face restoration model from {RESTORATION_MODEL_PATH}: {e}. Enhancement feature will be disabled.", exc_info=True)
58
  face_restorer = None
 
 
 
59
  except Exception as e:
60
  logging.error(f"A critical error occurred during model initialization: {e}", exc_info=True)
61
 
 
62
  initialize_models()
63
+ core_models_loaded_successfully = face_analyzer is not None and reswapper_session is not None
64
  restoration_model_loaded_successfully = face_restorer is not None
 
65
 
66
+ # --- Image Utility Functions ---
67
  def convert_pil_to_cv2(pil_image: Image.Image) -> np.ndarray | None:
68
  if pil_image is None: return None
69
  try:
 
80
  logging.error(f"Error converting CV2 to PIL: {e}")
81
  return None
82
 
83
+ # --- Core AI & Image Processing Functions ---
 
 
 
 
84
def get_faces_from_image(img_np: np.ndarray) -> list:
    """Detect faces in a BGR image using the global analyzer.

    Returns the analyzer's face objects (bbox + embedding per face), or an
    empty list when the analyzer is unavailable, the image is None, or
    detection raises.
    """
    if face_analyzer is None:
        logging.error("Face analyzer not available for get_faces_from_image.")
        return []
    if img_np is None:
        return []
    try:
        detected = face_analyzer.get(img_np)
    except Exception as exc:
        logging.error(f"Error during face detection: {exc}", exc_info=True)
        return []
    return detected
94
 
95
def draw_detected_faces(img_np: np.ndarray, faces: list) -> np.ndarray:
    """Return a copy of the image with a green box and "Face {i}" label per detection."""
    annotated = img_np.copy()
    for idx, detection in enumerate(faces):
        x1, y1, x2, y2 = detection.bbox.astype(int)
        cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Keep the label inside the frame when the box touches the top edge.
        anchor = (x1, max(0, y1 - 10))
        cv2.putText(annotated, f"Face {idx}", anchor, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (36, 255, 12), 2)
    return annotated
104
 
105
def enhance_cropped_face(face_crop_bgr: np.ndarray) -> np.ndarray:
    """Run the ONNX face-restoration model on a BGR face crop.

    The crop is returned untouched when the restorer is unavailable, the crop
    is empty, or inference fails; on success the restored face is resized back
    to the crop's original dimensions.
    """
    if not restoration_model_loaded_successfully or face_restorer is None:
        return face_crop_bgr
    if face_crop_bgr is None or face_crop_bgr.size == 0:
        return face_crop_bgr if face_crop_bgr is not None else np.array([])
    logging.info(f"Applying face restoration to crop of size {face_crop_bgr.shape[:2]}...")
    crop_h, crop_w = face_crop_bgr.shape[:2]
    try:
        # Pre-processing: RGB, 512x512, [0, 1] float32, NCHW with batch dim.
        # NOTE(review): many GFPGAN ONNX exports expect [-1, 1] input instead
        # of [0, 1] — confirm against this specific model file.
        rgb = cv2.cvtColor(face_crop_bgr, cv2.COLOR_BGR2RGB)
        model_hw = (512, 512)  # Common for GFPGAN-like models
        scaled = cv2.resize(rgb, model_hw, interpolation=cv2.INTER_AREA)
        batch = np.expand_dims(
            np.transpose((scaled / 255.0).astype(np.float32), (2, 0, 1)), axis=0)
        in_name = face_restorer.get_inputs()[0].name
        out_name = face_restorer.get_outputs()[0].name
        raw = face_restorer.run([out_name], {in_name: batch})[0]
        # Post-processing: drop batch dim, CHW->HWC, denormalize, resize back.
        hwc = np.transpose(np.squeeze(raw, axis=0), (1, 2, 0))
        restored = np.clip(hwc * 255.0, 0, 255).astype(np.uint8)
        restored = cv2.resize(restored, (crop_w, crop_h), interpolation=cv2.INTER_LANCZOS4)
        return cv2.cvtColor(restored, cv2.COLOR_RGB2BGR)
    except Exception as e:
        logging.error(f"Error during face restoration for crop: {e}", exc_info=True)
        return face_crop_bgr
129
 
130
def histogram_match_channel(source_ch: np.ndarray, target_ch: np.ndarray) -> np.ndarray:
    """Match the intensity distribution of a uint8 channel to a target's.

    Builds a monotone 256-entry lookup table that maps each source intensity
    to the target intensity with the nearest cumulative frequency, then
    applies it.

    Args:
        source_ch: uint8 array (any shape) whose histogram is remapped.
        target_ch: uint8 array providing the reference histogram.

    Returns:
        uint8 array with source_ch's shape; source_ch itself if either
        input is empty.
    """
    # range=(0, 256) gives each of the 256 integer intensities its own bin.
    # (The previous range [0, 255] made bins narrower than 1, mis-binning
    # high intensities.)
    src_hist, _ = np.histogram(source_ch, bins=256, range=(0, 256))
    tgt_hist, _ = np.histogram(target_ch, bins=256, range=(0, 256))
    src_cdf = src_hist.cumsum()
    tgt_cdf = tgt_hist.cumsum()
    if src_cdf[-1] == 0:
        return source_ch  # Avoid division by zero for blank source
    if tgt_cdf[-1] == 0:
        return source_ch  # Avoid issues with blank target
    # Scale the source CDF into the target's pixel-count range.
    src_cdf_norm = src_cdf * float(tgt_cdf[-1]) / src_cdf[-1]
    # For each source level, find the first target level whose CDF reaches it
    # (vectorized equivalent of the classic two-pointer LUT-building loop).
    lut = np.searchsorted(tgt_cdf, src_cdf_norm, side='left')
    lut = np.clip(lut, 0, 255).astype(np.uint8)
    # Fancy indexing applies the LUT and preserves the source's shape.
    return lut[source_ch]
 
148
 
149
def histogram_match_color(source_img: np.ndarray, target_img: np.ndarray) -> np.ndarray:
    """Histogram-match each channel of source_img to target_img.

    Returns source_img unchanged (or an empty array for None input) when
    either image is missing/empty or matching raises.
    """
    if source_img is None or target_img is None or source_img.size == 0 or target_img.size == 0:
        return source_img if source_img is not None else np.array([])
    try:
        remapped = [
            histogram_match_channel(source_img[:, :, ch], target_img[:, :, ch])
            for ch in range(source_img.shape[2])
        ]
        return np.stack(remapped, axis=2)
    except Exception as e:
        logging.error(f"Error during histogram matching: {e}", exc_info=True)
        return source_img
160
 
161
def apply_naturalness_filters(face_region_bgr: np.ndarray, noise_level: int = 3) -> np.ndarray:
    """Soften a swapped face region and add faint Gaussian grain.

    A 3x3 median blur suppresses small swap artifacts; noise_level is the
    sigma of the added grain (0 disables it). The input is returned unchanged
    when empty or when filtering fails.
    """
    if face_region_bgr is None or face_region_bgr.size == 0:
        return face_region_bgr if face_region_bgr is not None else np.array([])
    result = face_region_bgr.copy()
    try:
        result = cv2.medianBlur(result, 3)
        if noise_level > 0:
            # Work in int16 so additive noise can go negative before clipping.
            grain = np.random.normal(0, noise_level, result.shape).astype(np.int16)
            result = np.clip(result.astype(np.int16) + grain, 0, 255).astype(np.uint8)
        logging.info("Applied naturalness filters.")
    except Exception as e:
        logging.error(f"Error applying naturalness filters: {e}", exc_info=True)
        return face_region_bgr
    return result
176
 
177
  # --- Main Processing Function ---
178
  def process_face_swap(source_pil_img: Image.Image, target_pil_img: Image.Image,
179
+ target_face_index: int,
180
  apply_enhancement: bool, apply_color_correction: bool,
181
  apply_naturalness: bool,
182
  progress=gr.Progress(track_tqdm=True)):
183
 
184
  progress(0, desc="Initializing process...")
185
+ if not core_models_loaded_successfully:
186
  gr.Error("CRITICAL: Core models (Face Analyzer or ReSwapper) not loaded. Cannot proceed.")
187
  return Image.new('RGB', (100,100), color='lightgrey'), None
188
 
 
190
  if target_pil_img is None: raise gr.Error("Target image not provided.")
191
 
192
  progress(0.05, desc="Converting images...")
193
+ source_np = convert_pil_to_cv2(source_pil_img)
194
+ target_np = convert_pil_to_cv2(target_pil_img)
195
  if source_np is None or target_np is None:
196
+ raise gr.Error("Image conversion failed.")
197
 
198
+ original_target_h, original_target_w = target_np.shape[:2]
199
 
200
+ progress(0.15, desc="Detecting source face...")
201
  source_faces = get_faces_from_image(source_np)
202
+ if not source_faces: raise gr.Error("No face in source.")
203
+ source_face = source_faces[0]
204
 
205
+ progress(0.25, desc="Detecting target faces...")
206
  target_faces = get_faces_from_image(target_np)
207
+ if not target_faces: raise gr.Error("No faces in target.")
208
  if not (0 <= target_face_index < len(target_faces)):
209
+ raise gr.Error(f"Target face index out of range.")
210
+ target_face_info = target_faces[int(target_face_index)]
211
 
212
+ swapped_bgr_img_final = target_np.copy() # Initialize for fallback
213
 
214
  try:
215
+ progress(0.35, desc="Preparing inputs for ReSwapper...")
216
 
217
+ # 1. Source Face Embedding
218
+ source_embedding = source_face.embedding.astype(np.float32)
 
 
219
  source_embedding_tensor = np.expand_dims(source_embedding, axis=0) # (1, 512)
220
 
221
  # 2. Target Image Preparation
222
+ target_img_rgb = cv2.cvtColor(target_np, cv2.COLOR_BGR2RGB)
223
+
224
+ # ***** FIX APPLIED HERE: Resize target image for ReSwapper *****
225
+ # The "256" in reswapper_256.onnx strongly suggests this expected input size.
226
+ # Your error indicated: Expected 256, Got 225 (or similar).
227
+ reswapper_input_size = (256, 256) # Expected H, W by the model
228
 
229
+ # Store original size of the target image if it's not already stored for final resize
230
+ # original_target_h, original_target_w are already available
231
+
232
+ target_img_resized_rgb = cv2.resize(target_img_rgb, (reswapper_input_size[1], reswapper_input_size[0]), interpolation=cv2.INTER_AREA)
233
+ # cv2.resize takes (width, height)
234
 
235
+ # --- VERIFY NORMALIZATION for reswapper_256.onnx ---
236
+ # Common options:
237
+ # Option A: [0, 1]
238
+ target_img_normalized = (target_img_resized_rgb / 255.0).astype(np.float32)
239
+ # Option B: [-1, 1]
240
+ # target_img_normalized = (target_img_resized_rgb / 127.5 - 1.0).astype(np.float32)
241
+ # Choose the one your reswapper_256.onnx model expects!
 
 
242
 
243
  target_img_chw = np.transpose(target_img_normalized, (2, 0, 1)) # HWC to CHW
244
  target_image_tensor = np.expand_dims(target_img_chw, axis=0) # Add batch dimension
245
 
246
+ # 3. Target Face Bounding Box (Highly Model Specific!)
247
+ # --- VERIFY IF AND HOW reswapper_256.onnx USES THIS ---
248
+ # If it does, it needs to be scaled to the resized_target_image (256x256)
249
+ # Original bbox is target_face_info.bbox relative to original_target_w, original_target_h
250
+ # For now, we'll pass the original bbox, but this might be incorrect if the model
251
+ # expects a bbox relative to the (potentially resized) target_image_tensor.
252
+ # Some models might not need it if they perform their own detection on the target_image_tensor.
253
+ target_bbox_original = target_face_info.bbox.astype(np.float32)
254
+ target_bbox_tensor_input = np.expand_dims(target_bbox_original, axis=0)
255
+
256
+
257
+ # --- VERIFY TENSOR NAMES AND INPUT ORDER for reswapper_256.onnx ---
258
+ # Use Netron to inspect your .onnx file.
259
+ # The following are GUESSES and placeholders.
260
+ onnx_input_list = reswapper_session.get_inputs()
261
+ onnx_output_list = reswapper_session.get_outputs()
262
+
263
+ input_feed = {}
264
+ # Example GUESS:
265
+ # Assumed input order: target_image, source_embedding, (optional) target_bbox
266
+ if len(onnx_input_list) > 0: input_feed[onnx_input_list[0].name] = target_image_tensor
267
+ if len(onnx_input_list) > 1: input_feed[onnx_input_list[1].name] = source_embedding_tensor
268
+ if len(onnx_input_list) > 2: # If model expects a third input (e.g., bbox)
269
+ input_feed[onnx_input_list[2].name] = target_bbox_tensor_input # This bbox might need scaling/normalization!
270
 
271
+ output_name = onnx_output_list[0].name # Assuming first output is the image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
  progress(0.5, desc="Running ReSwapper model...")
274
  onnx_outputs = reswapper_session.run([output_name], input_feed)
275
+ swapped_image_tensor_output = onnx_outputs[0]
276
 
277
  progress(0.6, desc="Post-processing ReSwapper output...")
278
+ if swapped_image_tensor_output.ndim == 4:
279
+ swapped_image_tensor_output = np.squeeze(swapped_image_tensor_output, axis=0)
280
+
281
+ swapped_image_hwc = np.transpose(swapped_image_tensor_output, (1, 2, 0))
282
+
283
+ # --- VERIFY DENORMALIZATION (inverse of normalization used above) ---
 
284
  # If normalized to [0, 1]:
285
  swapped_image_denormalized = swapped_image_hwc * 255.0
286
  # If normalized to [-1, 1]:
287
  # swapped_image_denormalized = (swapped_image_hwc + 1.0) * 127.5
288
 
289
+ swapped_image_uint8_rgb = np.clip(swapped_image_denormalized, 0, 255).astype(np.uint8)
290
 
291
+ # The output of reswapper will be 256x256 (if input was 256x256 and model maintains size)
292
+ # Resize this output back to the original target image's dimensions
293
+ swapped_bgr_resized_to_original = cv2.resize(
294
+ swapped_image_uint8_rgb,
295
+ (original_target_w, original_target_h),
296
+ interpolation=cv2.INTER_LANCZOS4
297
+ )
298
+ # Convert to BGR as the rest of the pipeline expects BGR
299
+ swapped_bgr_img_final = cv2.cvtColor(swapped_bgr_resized_to_original, cv2.COLOR_RGB2BGR)
300
 
301
  # --- Post-processing pipeline (Restoration, Color Correction, Naturalness) ---
302
+ current_processed_image = swapped_bgr_img_final.copy()
303
+ # ROI for post-processing is based on original bbox coordinates
304
+ roi_x1, roi_y1, roi_x2, roi_y2 = target_face_info.bbox.astype(int)
 
305
 
306
  face_roi_for_postprocessing = current_processed_image[roi_y1:roi_y2, roi_x1:roi_x2].copy()
307
 
308
+ if face_roi_for_postprocessing.size > 0:
309
+ if apply_enhancement and restoration_model_loaded_successfully:
310
+ progress(0.7, desc="Applying face restoration...")
311
+ face_roi_for_postprocessing = enhance_cropped_face(face_roi_for_postprocessing)
 
 
 
 
 
312
 
313
  if apply_color_correction:
314
  progress(0.8, desc="Applying color correction...")
315
+ original_target_roi_for_color = target_np[roi_y1:roi_y2, roi_x1:roi_x2] # Original target pixels
316
+ if original_target_roi_for_color.size > 0:
317
+ face_roi_for_postprocessing = histogram_match_color(face_roi_for_postprocessing, original_target_roi_for_color.copy())
318
 
319
  if apply_naturalness:
320
  progress(0.85, desc="Applying naturalness filters...")
321
  face_roi_for_postprocessing = apply_naturalness_filters(face_roi_for_postprocessing)
322
 
323
  # Paste the processed ROI back
324
+ if face_roi_for_postprocessing.shape[0] == (roi_y2 - roi_y1) and \
325
+ face_roi_for_postprocessing.shape[1] == (roi_x2 - roi_x1):
326
  current_processed_image[roi_y1:roi_y2, roi_x1:roi_x2] = face_roi_for_postprocessing
327
+ else: # If size changed during post-processing (should not happen with current funcs)
328
+ logging.warning("ROI size mismatch after post-processing. Attempting resize for paste.")
329
  try:
330
+ resized_roi = cv2.resize(face_roi_for_postprocessing, (roi_x2-roi_x1, roi_y2-roi_y1), interpolation=cv2.INTER_LANCZOS4)
331
+ current_processed_image[roi_y1:roi_y2, roi_x1:roi_x2] = resized_roi
332
+ except Exception as e_resize_paste:
333
+ logging.error(f"Failed to resize/paste processed ROI: {e_resize_paste}")
334
+ else:
335
+ logging.warning("ROI for post-processing is empty. Skipping these steps.")
336
 
337
+ swapped_bgr_img_final = current_processed_image
338
 
339
  except Exception as e:
340
  logging.error(f"Error during main processing pipeline: {e}", exc_info=True)
341
+ swapped_pil_img_on_error = convert_cv2_to_pil(swapped_bgr_img_final) # Use last good state
 
 
342
  if swapped_pil_img_on_error is None:
343
+ swapped_pil_img_on_error = Image.new('RGB', (original_target_w, original_target_h), color='lightgrey')
344
+ raise gr.Error(f"An error occurred: {str(e)}. Check console for details, especially ONNX input/output errors.")
345
 
346
  progress(0.9, desc="Finalizing image...")
347
+ swapped_pil_img = convert_cv2_to_pil(swapped_bgr_img_final)
348
  if swapped_pil_img is None:
349
+ gr.Error("Failed to convert final image.")
350
+ swapped_pil_img = Image.new('RGB', (original_target_w, original_target_h), color='lightgrey')
351
 
352
  temp_file_path = None
353
  try:
354
  with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False, mode="wb") as tmp_file:
355
  swapped_pil_img.save(tmp_file, format="JPEG")
356
  temp_file_path = tmp_file.name
 
357
  except Exception as e:
358
+ logging.error(f"Error saving temp file: {e}", exc_info=True)
359
+ gr.Warning("Could not save swapped image for download.")
360
 
361
  progress(1.0, desc="Processing complete!")
362
  return swapped_pil_img, temp_file_path
363
 
364
+ # --- Gradio Preview Function ---
365
def preview_target_faces(target_pil_img: Image.Image):
    """Annotate detected faces in the target image and configure the face-index slider.

    Returns (preview PIL image, gr.Slider update). The slider is interactive
    only when at least one face is found.
    """
    def _blank():
        # Neutral placeholder shown when there is nothing useful to preview.
        return Image.new('RGB', DETECTION_SIZE, color='lightgray')

    if target_pil_img is None:
        return _blank(), gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
    target_np = convert_pil_to_cv2(target_pil_img)
    if target_np is None:
        return _blank(), gr.Slider(interactive=False)
    faces = get_faces_from_image(target_np)
    if not faces:
        return convert_cv2_to_pil(target_np) or _blank(), gr.Slider(interactive=False)
    annotated = convert_cv2_to_pil(draw_detected_faces(target_np, faces)) or _blank()
    slider = gr.Slider(minimum=0, maximum=max(0, len(faces) - 1), value=0,
                       step=1, interactive=len(faces) > 0)
    return annotated, slider
379
+
380
+ # --- Gradio UI Definition ---
381
# --- Gradio UI Definition ---
# Builds the whole interface at import time and binds it to `demo`, which the
# __main__ block launches.
with gr.Blocks(title="ReSwapper AI Face Swap πŸ’Ž v2", theme=gr.themes.Soft()) as demo: # Changed title slightly
    gr.Markdown(
        """
        <div style="text-align: center;">
        <h1>πŸ’Ž ReSwapper AI Face Swap 🌠 v2</h1>
        <p>Utilizing <code>reswapper_256.onnx</code>. Please ensure model inputs/outputs are correctly configured in the code if issues arise.</p>
        </div>
        """
    )
    # NOTE(review): instantiating gr.Error at build time (outside an event
    # handler) may not display as intended — confirm against the Gradio docs.
    if not core_models_loaded_successfully:
        gr.Error("CRITICAL ERROR: Core models (Face Analyzer or ReSwapper) failed to load. Check console.")

    # Input images: source face (identity) and target scene.
    with gr.Row():
        with gr.Column(scale=1):
            source_image_input = gr.Image(label="πŸ‘€ Source Face Image", type="pil", sources=["upload","clipboard"], height=350)
        with gr.Column(scale=1):
            target_image_input = gr.Image(label="πŸ–ΌοΈ Target Scene Image", type="pil", sources=["upload","clipboard"], height=350)

    # Face selection: preview detections, then pick one by index.
    with gr.Row(equal_height=True):
        preview_button = gr.Button("πŸ” Preview & Select Target Face", variant="secondary")
        face_index_slider = gr.Slider(label="🎯 Select Target Face", minimum=0, maximum=0, step=1, value=0, interactive=False)
    target_faces_preview_output = gr.Image(label="πŸ‘€ Detected Faces in Target", interactive=False, height=350)
    gr.HTML("<hr style='margin:15px 0;'>")

    # Post-processing toggles; restoration is disabled when its model is absent.
    with gr.Row():
        with gr.Column(scale=1):
            enh_label = "✨ Face Restoration" + (" (Model N/A)" if not restoration_model_loaded_successfully else "")
            enhance_checkbox = gr.Checkbox(label=enh_label, value=restoration_model_loaded_successfully, interactive=restoration_model_loaded_successfully)
            if not restoration_model_loaded_successfully: gr.Markdown("<p style='color:orange;font-size:0.8em;'>⚠️ Restoration N/A</p>")
        with gr.Column(scale=1):
            color_correction_checkbox = gr.Checkbox(label="🎨 Color Correction", value=True)
        with gr.Column(scale=1):
            naturalness_checkbox = gr.Checkbox(label="🌿 Naturalness Filters", value=False)

    with gr.Row():
        swap_button = gr.Button("πŸš€ GENERATE SWAP!", variant="primary", scale=3, interactive=core_models_loaded_successfully)
        clear_button = gr.Button("🧹 Clear All", variant="stop", scale=1)
    with gr.Row():
        swapped_image_output = gr.Image(label="✨ Swapped Result", interactive=False, height=450)
        download_output_file = gr.File(label="⬇️ Download Swapped Image")

    # Event Handlers
    # Reset preview/slider when the target image is cleared.
    # NOTE(review): reading component .value inside the lambda returns the
    # component's initial value, not its current UI state — verify this
    # produces the intended no-op when x is not None.
    target_image_input.change(fn=lambda x: (Image.new('RGB', DETECTION_SIZE, color='lightgray') if x is None else target_faces_preview_output.value,
                                            gr.Slider(interactive=False) if x is None else face_index_slider.value),
                              inputs=[target_image_input],
                              outputs=[target_faces_preview_output, face_index_slider],
                              queue=False)
    preview_button.click(fn=preview_target_faces, inputs=[target_image_input], outputs=[target_faces_preview_output, face_index_slider])
    swap_button.click(fn=process_face_swap,
                      inputs=[source_image_input, target_image_input, face_index_slider,
                              enhance_checkbox, color_correction_checkbox, naturalness_checkbox],
                      outputs=[swapped_image_output, download_output_file])

    def clear_all():
        # Reset every input/output component to its initial state.
        return None, None, Image.new('RGB', DETECTION_SIZE, color='lightgray'), gr.Slider(interactive=False), None, None
    clear_button.click(fn=clear_all, inputs=None,
                       outputs=[source_image_input, target_image_input, target_faces_preview_output,
                                face_index_slider, swapped_image_output, download_output_file],
                       queue=False)

    # Example inputs; assumes the files exist under examples/ — TODO confirm.
    gr.Examples(
        [["examples/source_face.jpg", "examples/target_group.jpg", 0, True, True, True]],
        [source_image_input, target_image_input, face_index_slider, enhance_checkbox, color_correction_checkbox, naturalness_checkbox],
        [swapped_image_output, download_output_file],
        process_face_swap, cache_examples=False, label="Example"
    )
442
 
443
# --- Main Execution Block ---
if __name__ == "__main__":
    # Ensure expected directories exist before launch.
    os.makedirs("models", exist_ok=True)
    os.makedirs("examples", exist_ok=True)

    # Startup banner summarizing which models actually loaded.
    banner = "=" * 60
    print(f"\n{banner}\nπŸš€ ReSwapper AI FACE SWAP - STARTUP STATUS πŸš€\n{banner}")
    if core_models_loaded_successfully:
        print("🟒 Core models (Analyzer & ReSwapper) loaded.")
    else:
        analyzer_status = 'OK' if face_analyzer else 'FAIL'
        swapper_status = 'OK' if reswapper_session else 'FAIL'
        print(f"πŸ”΄ CRITICAL: Core models FAILED. Analyzer: {analyzer_status}, ReSwapper: {swapper_status}")
    if restoration_model_loaded_successfully:
        print("🟒 Restoration model loaded.")
    else:
        print(f"🟑 INFO: Restoration model ('{RESTORATION_MODEL_PATH}') N/A.")
    print(f"{banner}\nLaunching Gradio Interface...")
    demo.launch()