Update app.py
Browse files
app.py
CHANGED
|
@@ -2,8 +2,8 @@ import os
|
|
| 2 |
import cv2
|
| 3 |
import numpy as np
|
| 4 |
import gradio as gr
|
| 5 |
-
from insightface.app import FaceAnalysis
|
| 6 |
-
from insightface.model_zoo import get_model
|
| 7 |
from PIL import Image
|
| 8 |
import tempfile
|
| 9 |
import logging
|
|
@@ -12,66 +12,71 @@ import onnxruntime
|
|
| 12 |
# --- Configuration & Setup ---
|
| 13 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
FACE_ANALYZER_NAME = 'buffalo_l' # InsightFace detection model
|
| 19 |
-
DETECTION_SIZE = (640, 640)
|
| 20 |
-
EXECUTION_PROVIDERS = ['CPUExecutionProvider'] # Or ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
| 21 |
|
| 22 |
# --- Global Variables (Lazy Loaded by initialize_models) ---
|
| 23 |
face_analyzer = None
|
| 24 |
-
|
| 25 |
face_restorer = None
|
|
|
|
| 26 |
|
| 27 |
# --- Initialization Functions ---
|
| 28 |
def initialize_models():
|
| 29 |
-
global face_analyzer,
|
| 30 |
try:
|
| 31 |
-
# Initialize FaceAnalysis model
|
| 32 |
if face_analyzer is None:
|
| 33 |
logging.info(f"Initializing FaceAnalysis model: {FACE_ANALYZER_NAME}")
|
| 34 |
face_analyzer = FaceAnalysis(name=FACE_ANALYZER_NAME, providers=EXECUTION_PROVIDERS)
|
| 35 |
-
face_analyzer.prepare(ctx_id=0, det_size=DETECTION_SIZE)
|
| 36 |
logging.info("FaceAnalysis model initialized successfully.")
|
| 37 |
|
| 38 |
-
# Initialize
|
| 39 |
-
if
|
| 40 |
if not os.path.exists(SWAPPER_MODEL_PATH):
|
| 41 |
-
logging.error(f"
|
| 42 |
else:
|
| 43 |
-
logging.info(f"Loading
|
| 44 |
try:
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
except
|
| 48 |
-
logging.
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
if face_restorer is None: # Only attempt to load if not already loaded/failed
|
| 54 |
if not os.path.exists(RESTORATION_MODEL_PATH):
|
| 55 |
-
logging.error(f"Face restoration model FILE NOT FOUND at: {RESTORATION_MODEL_PATH}. Enhancement
|
| 56 |
else:
|
|
|
|
| 57 |
logging.info(f"Attempting to load face restoration model from: {RESTORATION_MODEL_PATH}")
|
| 58 |
try:
|
| 59 |
face_restorer = onnxruntime.InferenceSession(RESTORATION_MODEL_PATH, providers=EXECUTION_PROVIDERS)
|
| 60 |
logging.info("Face restoration model loaded successfully.")
|
| 61 |
except Exception as e:
|
| 62 |
logging.error(f"Error loading face restoration model from {RESTORATION_MODEL_PATH}: {e}. Enhancement feature will be disabled.", exc_info=True)
|
| 63 |
-
face_restorer = None
|
|
|
|
|
|
|
|
|
|
| 64 |
except Exception as e:
|
| 65 |
logging.error(f"A critical error occurred during model initialization: {e}", exc_info=True)
|
| 66 |
-
# face_analyzer, swapper, or face_restorer might be None. Subsequent checks will handle this.
|
| 67 |
|
| 68 |
# --- Call Initialization Early ---
|
| 69 |
initialize_models()
|
| 70 |
-
|
| 71 |
-
core_models_loaded_successfully = face_analyzer is not None and swapper is not None
|
| 72 |
restoration_model_loaded_successfully = face_restorer is not None
|
|
|
|
| 73 |
|
| 74 |
-
# --- Image Utility Functions ---
|
| 75 |
def convert_pil_to_cv2(pil_image: Image.Image) -> np.ndarray | None:
|
| 76 |
if pil_image is None: return None
|
| 77 |
try:
|
|
@@ -88,139 +93,132 @@ def convert_cv2_to_pil(cv2_image: np.ndarray) -> Image.Image | None:
|
|
| 88 |
logging.error(f"Error converting CV2 to PIL: {e}")
|
| 89 |
return None
|
| 90 |
|
| 91 |
-
# --- Core AI & Image Processing Functions ---
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
def get_faces_from_image(img_np: np.ndarray) -> list:
|
| 93 |
-
if
|
| 94 |
-
# This condition should ideally be caught before calling this function if core models failed.
|
| 95 |
logging.error("Face analyzer not available for get_faces_from_image.")
|
| 96 |
return []
|
| 97 |
if img_np is None: return []
|
| 98 |
try:
|
| 99 |
-
return face_analyzer.get(img_np)
|
| 100 |
except Exception as e:
|
| 101 |
logging.error(f"Error during face detection: {e}", exc_info=True)
|
| 102 |
return []
|
| 103 |
|
| 104 |
-
|
|
|
|
| 105 |
img_with_boxes = img_np.copy()
|
| 106 |
for i, face in enumerate(faces):
|
| 107 |
-
box = face.bbox.astype(int)
|
| 108 |
x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
|
| 109 |
cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
| 110 |
-
label_position = (x1, max(0, y1 - 10))
|
| 111 |
cv2.putText(img_with_boxes, f"Face {i}", label_position,
|
| 112 |
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (36, 255, 12), 2)
|
| 113 |
return img_with_boxes
|
| 114 |
|
| 115 |
-
def enhance_cropped_face(face_crop_bgr: np.ndarray) -> np.ndarray:
|
| 116 |
if not restoration_model_loaded_successfully or face_restorer is None:
|
| 117 |
logging.warning("Face restorer model not available. Skipping enhancement for crop.")
|
| 118 |
return face_crop_bgr
|
| 119 |
if face_crop_bgr is None or face_crop_bgr.shape[0] == 0 or face_crop_bgr.shape[1] == 0:
|
| 120 |
logging.warning("Received empty or invalid face crop for enhancement.")
|
| 121 |
return face_crop_bgr if face_crop_bgr is not None else np.array([])
|
| 122 |
-
|
| 123 |
-
|
| 124 |
logging.info(f"Applying face restoration to crop of size {face_crop_bgr.shape[:2]}...")
|
| 125 |
crop_height, crop_width = face_crop_bgr.shape[:2]
|
| 126 |
-
|
| 127 |
try:
|
| 128 |
img_rgb = cv2.cvtColor(face_crop_bgr, cv2.COLOR_BGR2RGB)
|
| 129 |
-
|
| 130 |
-
# GFPGAN typically expects 512x512 input.
|
| 131 |
-
restorer_input_size = (512, 512) # Check your specific model's expected input size
|
| 132 |
img_resized_for_model = cv2.resize(img_rgb, restorer_input_size, interpolation=cv2.INTER_AREA)
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
input_tensor = np.expand_dims(img_chw, axis=0) # Add batch dimension
|
| 137 |
-
|
| 138 |
input_name = face_restorer.get_inputs()[0].name
|
| 139 |
output_name = face_restorer.get_outputs()[0].name
|
| 140 |
-
|
| 141 |
restored_output_model_size = face_restorer.run([output_name], {input_name: input_tensor})[0]
|
| 142 |
-
|
| 143 |
restored_img_chw = np.squeeze(restored_output_model_size, axis=0)
|
| 144 |
-
restored_img_hwc_model_size = np.transpose(restored_img_chw, (1, 2, 0))
|
| 145 |
restored_img_uint8_model_size = np.clip(restored_img_hwc_model_size * 255.0, 0, 255).astype(np.uint8)
|
| 146 |
-
|
| 147 |
-
# Resize back to the original crop's dimensions
|
| 148 |
restored_crop_rgb = cv2.resize(restored_img_uint8_model_size, (crop_width, crop_height), interpolation=cv2.INTER_LANCZOS4)
|
| 149 |
restored_crop_bgr = cv2.cvtColor(restored_crop_rgb, cv2.COLOR_RGB2BGR)
|
| 150 |
-
|
| 151 |
logging.info("Cropped face restoration complete.")
|
| 152 |
return restored_crop_bgr
|
| 153 |
except Exception as e:
|
| 154 |
logging.error(f"Error during face restoration for crop: {e}", exc_info=True)
|
| 155 |
-
return face_crop_bgr
|
| 156 |
|
| 157 |
-
def histogram_match_channel(source_channel: np.ndarray, target_channel: np.ndarray) -> np.ndarray:
|
| 158 |
-
"""Matches histogram of a single source channel to a target channel."""
|
| 159 |
source_shape = source_channel.shape
|
| 160 |
source_channel_flat = source_channel.flatten()
|
| 161 |
target_channel_flat = target_channel.flatten()
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
source_hist, bins = np.histogram(source_channel_flat, 256, [0,256])
|
| 165 |
source_cdf = source_hist.cumsum()
|
| 166 |
-
|
| 167 |
-
# Get the histogram and CDF of the target image
|
| 168 |
-
target_hist, bins = np.histogram(target_channel_flat, 256, [0,256])
|
| 169 |
target_cdf = target_hist.cumsum()
|
| 170 |
-
|
| 171 |
-
#
|
| 172 |
-
|
| 173 |
-
target_cdf_norm = target_cdf # No change needed for target if using its own max
|
| 174 |
-
|
| 175 |
-
# Create a lookup table
|
| 176 |
lookup_table = np.zeros(256, dtype='uint8')
|
| 177 |
-
|
| 178 |
-
for
|
| 179 |
-
while
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
lookup_table[gi] = gj
|
| 184 |
-
|
| 185 |
-
matched_channel_flat = cv2.LUT(source_channel, lookup_table) # Apply lookup table
|
| 186 |
return matched_channel_flat.reshape(source_shape)
|
| 187 |
|
| 188 |
-
|
| 189 |
-
def histogram_match_color(source_img: np.ndarray, target_img: np.ndarray) -> np.ndarray:
|
| 190 |
-
"""
|
| 191 |
-
Performs histogram matching on color images (BGR).
|
| 192 |
-
"""
|
| 193 |
if source_img is None or target_img is None or source_img.size == 0 or target_img.size == 0:
|
| 194 |
logging.warning("Cannot perform histogram matching on empty or None images.")
|
| 195 |
return source_img if source_img is not None else np.array([])
|
| 196 |
-
|
| 197 |
matched_img = np.zeros_like(source_img)
|
| 198 |
try:
|
| 199 |
-
for i in range(source_img.shape[2]):
|
| 200 |
matched_img[:,:,i] = histogram_match_channel(source_img[:,:,i], target_img[:,:,i])
|
| 201 |
return matched_img
|
| 202 |
except Exception as e:
|
| 203 |
logging.error(f"Error during histogram matching: {e}", exc_info=True)
|
| 204 |
-
return source_img
|
| 205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
|
| 207 |
-
# ---
|
| 208 |
def process_face_swap(source_pil_img: Image.Image, target_pil_img: Image.Image,
|
| 209 |
-
target_face_index: int,
|
|
|
|
|
|
|
| 210 |
progress=gr.Progress(track_tqdm=True)):
|
|
|
|
| 211 |
progress(0, desc="Initializing process...")
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
return Image.new('RGB', (100,100), color='lightgrey'), None # Placeholder for output
|
| 216 |
|
| 217 |
if source_pil_img is None: raise gr.Error("Source image not provided.")
|
| 218 |
if target_pil_img is None: raise gr.Error("Target image not provided.")
|
| 219 |
|
| 220 |
progress(0.05, desc="Converting images...")
|
| 221 |
-
source_np = convert_pil_to_cv2(source_pil_img)
|
| 222 |
-
target_np = convert_pil_to_cv2(target_pil_img)
|
| 223 |
-
|
| 224 |
if source_np is None or target_np is None:
|
| 225 |
raise gr.Error("Image conversion failed. Please try different images.")
|
| 226 |
|
|
@@ -228,98 +226,204 @@ def process_face_swap(source_pil_img: Image.Image, target_pil_img: Image.Image,
|
|
| 228 |
|
| 229 |
progress(0.15, desc="Detecting face in source image...")
|
| 230 |
source_faces = get_faces_from_image(source_np)
|
| 231 |
-
if not source_faces: raise gr.Error("No face found in the source image.
|
| 232 |
-
source_face = source_faces[0] #
|
| 233 |
|
| 234 |
progress(0.25, desc="Detecting faces in target image...")
|
| 235 |
target_faces = get_faces_from_image(target_np)
|
| 236 |
if not target_faces: raise gr.Error("No faces found in the target image.")
|
| 237 |
if not (0 <= target_face_index < len(target_faces)):
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
target_face_to_swap_info = target_faces[int(target_face_index)] # Face object from InsightFace
|
| 243 |
|
| 244 |
-
swapped_bgr_img = target_np.copy() # Start with a copy of target
|
| 245 |
try:
|
| 246 |
-
progress(0.4, desc="
|
| 247 |
-
# swapper.get returns the modified target_np with the face pasted back
|
| 248 |
-
swapped_bgr_img = swapper.get(target_np, target_face_to_swap_info, source_face, paste_back=True)
|
| 249 |
-
|
| 250 |
-
# Define bounding box for post-processing (enhancement and color correction)
|
| 251 |
-
# Use the bbox of the target face where the swap occurred.
|
| 252 |
-
# InsightFace bbox is [x1, y1, x2, y2]
|
| 253 |
-
bbox_coords = target_face_to_swap_info.bbox.astype(int)
|
| 254 |
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
if corrected_swapped_region.shape == swapped_face_region_to_correct.shape:
|
| 294 |
-
swapped_bgr_img[cc_y1:cc_y2, cc_x1:cc_x2] = corrected_swapped_region
|
| 295 |
-
else:
|
| 296 |
-
logging.warning("Color corrected region size mismatch. Skipping paste-back for color correction.")
|
| 297 |
else:
|
| 298 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
else:
|
| 300 |
-
logging.warning("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
| 302 |
|
|
|
|
|
|
|
| 303 |
except Exception as e:
|
| 304 |
-
logging.error(f"Error during
|
| 305 |
-
#
|
| 306 |
-
|
| 307 |
-
|
|
|
|
| 308 |
swapped_pil_img_on_error = Image.new('RGB', (target_w, target_h), color='lightgrey')
|
| 309 |
raise gr.Error(f"An error occurred: {str(e)}")
|
| 310 |
-
# return swapped_pil_img_on_error, None # Alternative for UI stability
|
| 311 |
|
| 312 |
progress(0.9, desc="Finalizing image...")
|
| 313 |
swapped_pil_img = convert_cv2_to_pil(swapped_bgr_img)
|
| 314 |
-
if swapped_pil_img is None:
|
| 315 |
gr.Error("Failed to convert final image to display format.")
|
| 316 |
-
swapped_pil_img = Image.new('RGB', (target_w, target_h), color='lightgrey')
|
| 317 |
|
| 318 |
temp_file_path = None
|
| 319 |
try:
|
| 320 |
-
# Using 'with' ensures the file is closed before Gradio tries to use it
|
| 321 |
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False, mode="wb") as tmp_file:
|
| 322 |
-
swapped_pil_img.save(tmp_file, format="JPEG")
|
| 323 |
temp_file_path = tmp_file.name
|
| 324 |
logging.info(f"Swapped image saved to temporary file: {temp_file_path}")
|
| 325 |
except Exception as e:
|
|
@@ -329,181 +433,131 @@ def process_face_swap(source_pil_img: Image.Image, target_pil_img: Image.Image,
|
|
| 329 |
progress(1.0, desc="Processing complete!")
|
| 330 |
return swapped_pil_img, temp_file_path
|
| 331 |
|
| 332 |
-
|
|
|
|
| 333 |
if target_pil_img is None:
|
| 334 |
blank_image_pil = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
|
| 335 |
return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
|
| 336 |
-
|
| 337 |
target_np = convert_pil_to_cv2(target_pil_img)
|
| 338 |
-
if target_np is None:
|
| 339 |
blank_image_pil = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
|
| 340 |
return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
|
| 341 |
-
|
| 342 |
faces = get_faces_from_image(target_np)
|
| 343 |
-
if not faces:
|
| 344 |
-
preview_pil_img = convert_cv2_to_pil(target_np)
|
| 345 |
-
if preview_pil_img is None:
|
| 346 |
preview_pil_img = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
|
| 347 |
return preview_pil_img, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
|
| 348 |
-
|
| 349 |
preview_np_img = draw_detected_faces(target_np, faces)
|
| 350 |
preview_pil_img = convert_cv2_to_pil(preview_np_img)
|
| 351 |
-
if preview_pil_img is None:
|
| 352 |
preview_pil_img = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
|
| 353 |
-
|
| 354 |
-
|
| 355 |
num_faces = len(faces)
|
| 356 |
slider_update = gr.Slider(minimum=0, maximum=max(0, num_faces - 1), value=0, step=1, interactive=(num_faces > 0))
|
| 357 |
return preview_pil_img, slider_update
|
| 358 |
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
with gr.Blocks(title="Ultimate Face Swap AI π v3", theme=gr.themes.Soft()) as demo: # Changed theme for variety
|
| 362 |
gr.Markdown(
|
| 363 |
"""
|
| 364 |
<div style="text-align: center;">
|
| 365 |
-
<h1>π
|
| 366 |
-
<p>Upload a source face, a target image, and let the AI work
|
| 367 |
-
<p>Optionally enhance with face restoration and color correction for stunning realism.</p>
|
| 368 |
</div>
|
| 369 |
"""
|
| 370 |
)
|
| 371 |
|
| 372 |
-
if not core_models_loaded_successfully:
|
| 373 |
-
gr.Error("CRITICAL ERROR: Core models (Face Analyzer or
|
| 374 |
|
| 375 |
with gr.Row():
|
| 376 |
with gr.Column(scale=1):
|
| 377 |
-
source_image_input = gr.Image(label="π€ Source Face Image
|
| 378 |
with gr.Column(scale=1):
|
| 379 |
target_image_input = gr.Image(label="πΌοΈ Target Scene Image", type="pil", sources=["upload", "clipboard"], height=350)
|
| 380 |
|
| 381 |
with gr.Row(equal_height=True):
|
| 382 |
preview_button = gr.Button("π Preview & Select Target Face", variant="secondary")
|
| 383 |
-
face_index_slider = gr.Slider(
|
| 384 |
-
label="π― Select Target Face (0-indexed)",
|
| 385 |
-
minimum=0, maximum=0, step=1, value=0, interactive=False
|
| 386 |
-
)
|
| 387 |
|
| 388 |
target_faces_preview_output = gr.Image(label="π Detected Faces in Target", interactive=False, height=350)
|
| 389 |
-
|
| 390 |
-
gr.HTML("<hr style='margin-top: 15px; margin-bottom: 15px;'>") # Styled HR
|
| 391 |
|
| 392 |
with gr.Row():
|
| 393 |
with gr.Column(scale=1):
|
| 394 |
-
enhance_checkbox_label = "β¨ Apply
|
| 395 |
-
if not restoration_model_loaded_successfully:
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
)
|
| 402 |
-
if not restoration_model_loaded_successfully:
|
| 403 |
-
gr.Markdown("<p style='color: orange; font-size:0.8em;'>β οΈ Face restoration model not loaded. Feature disabled. Check logs.</p>")
|
| 404 |
|
| 405 |
with gr.Column(scale=1):
|
| 406 |
-
|
| 407 |
|
| 408 |
-
|
| 409 |
with gr.Row():
|
| 410 |
-
swap_button = gr.Button("π GENERATE SWAP!", variant="primary", scale=3, interactive=core_models_loaded_successfully)
|
| 411 |
clear_button = gr.Button("π§Ή Clear All", variant="stop", scale=1)
|
| 412 |
|
| 413 |
-
|
| 414 |
with gr.Row():
|
| 415 |
swapped_image_output = gr.Image(label="β¨ Swapped Result", interactive=False, height=450)
|
| 416 |
download_output_file = gr.File(label="β¬οΈ Download Swapped Image")
|
| 417 |
|
| 418 |
-
#
|
| 419 |
def on_target_image_change_or_clear(target_img_pil):
|
| 420 |
-
if target_img_pil is None:
|
| 421 |
-
blank_image_pil = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color
|
| 422 |
return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
return target_faces_preview_output.value, face_index_slider.value # Keep current state on new image upload until preview
|
| 428 |
-
|
| 429 |
-
target_image_input.change(
|
| 430 |
-
fn=on_target_image_change_or_clear,
|
| 431 |
-
inputs=[target_image_input],
|
| 432 |
-
outputs=[target_faces_preview_output, face_index_slider],
|
| 433 |
-
queue=False
|
| 434 |
-
)
|
| 435 |
-
|
| 436 |
-
preview_button.click(
|
| 437 |
-
fn=preview_target_faces,
|
| 438 |
-
inputs=[target_image_input],
|
| 439 |
-
outputs=[target_faces_preview_output, face_index_slider]
|
| 440 |
-
)
|
| 441 |
|
| 442 |
swap_button.click(
|
| 443 |
fn=process_face_swap,
|
| 444 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
outputs=[swapped_image_output, download_output_file]
|
| 446 |
)
|
| 447 |
|
| 448 |
def clear_all_inputs_outputs():
|
| 449 |
blank_preview = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color = 'lightgray')
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
None, # target_image_input
|
| 454 |
-
blank_preview, # target_faces_preview_output
|
| 455 |
-
gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False), # face_index_slider
|
| 456 |
-
None, # swapped_image_output
|
| 457 |
-
None # download_output_file
|
| 458 |
-
)
|
| 459 |
-
|
| 460 |
-
clear_button.click(
|
| 461 |
-
fn=clear_all_inputs_outputs,
|
| 462 |
-
inputs=None, # No inputs for this function
|
| 463 |
-
outputs=[
|
| 464 |
-
source_image_input, target_image_input,
|
| 465 |
-
target_faces_preview_output, face_index_slider,
|
| 466 |
-
swapped_image_output, download_output_file
|
| 467 |
-
],
|
| 468 |
-
queue=False # No need to queue a simple clear
|
| 469 |
-
)
|
| 470 |
|
| 471 |
gr.Examples(
|
| 472 |
examples=[
|
| 473 |
-
["examples/source_face.jpg", "examples/target_group.jpg", 0, True, True],
|
| 474 |
-
["examples/source_actor.png", "examples/target_scene.png", 1, True, True],
|
| 475 |
-
["examples/source_face.jpg", "examples/target_group.jpg", 0, False, True], # No enhancement, with color correction
|
| 476 |
-
["examples/source_actor.png", "examples/target_scene.png", 0, True, False], # With enhancement, no color correction
|
| 477 |
],
|
| 478 |
-
inputs=[source_image_input, target_image_input, face_index_slider, enhance_checkbox, color_correction_checkbox],
|
| 479 |
outputs=[swapped_image_output, download_output_file],
|
| 480 |
-
fn=process_face_swap,
|
| 481 |
-
cache_examples=False, # Set to "lazy" or True if processing is slow and examples are static
|
| 482 |
label="Example Face Swaps (Click to run)"
|
| 483 |
)
|
| 484 |
|
| 485 |
# --- Main Execution Block ---
|
| 486 |
if __name__ == "__main__":
|
| 487 |
os.makedirs("models", exist_ok=True)
|
| 488 |
-
os.makedirs("examples", exist_ok=True)
|
| 489 |
|
| 490 |
-
# Console messages about model loading status
|
| 491 |
print("\n" + "="*60)
|
| 492 |
-
print("π
|
| 493 |
print("="*60)
|
| 494 |
if not core_models_loaded_successfully:
|
| 495 |
-
print("π΄ CRITICAL ERROR: Core models
|
| 496 |
print(f" - Face Analyzer: '{FACE_ANALYZER_NAME}' (Status: {'Loaded' if face_analyzer else 'Failed'})")
|
| 497 |
-
print(f" -
|
| 498 |
print(" The application UI will load but will be NON-FUNCTIONAL.")
|
| 499 |
-
print(" Please check model paths and ensure files exist and are not corrupted.")
|
| 500 |
else:
|
| 501 |
-
print("π’ Core models (Face Analyzer &
|
| 502 |
|
| 503 |
if not restoration_model_loaded_successfully:
|
| 504 |
-
print(f"π‘ INFO: Face Restoration model ('{RESTORATION_MODEL_PATH}') not loaded.")
|
| 505 |
-
print(" The 'Face Restoration' feature will be disabled.")
|
| 506 |
-
print(" To enable, ensure the model file exists at the specified path and is valid.")
|
| 507 |
else:
|
| 508 |
print("π’ Face Restoration model loaded successfully.")
|
| 509 |
print("="*60 + "\n")
|
|
|
|
| 2 |
import cv2
|
| 3 |
import numpy as np
|
| 4 |
import gradio as gr
|
| 5 |
+
from insightface.app import FaceAnalysis # Still needed for detection and embeddings
|
| 6 |
+
# from insightface.model_zoo import get_model # No longer using this for the swapper
|
| 7 |
from PIL import Image
|
| 8 |
import tempfile
|
| 9 |
import logging
|
|
|
|
| 12 |
# --- Configuration & Setup ---
|
| 13 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 14 |
|
| 15 |
+
# MODIFIED: Path to the new swapper model
|
| 16 |
+
SWAPPER_MODEL_PATH = "models/reswapper_256.onnx" # <--- USING RESWAPPER_256
|
| 17 |
+
RESTORATION_MODEL_PATH = "models/gfpgan_1.4.onnx" # Or your chosen restorer, e.g., restoreformer++.onnx
|
| 18 |
+
# UPSCALER_MODEL_PATH = "models/RealESRGANx4plus.onnx" # Path for AI Upscaler if you were to add it
|
| 19 |
|
| 20 |
+
FACE_ANALYZER_NAME = 'buffalo_l' # InsightFace detection and embedding model
|
| 21 |
+
DETECTION_SIZE = (640, 640)
|
| 22 |
+
EXECUTION_PROVIDERS = ['CPUExecutionProvider'] # Or ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
| 23 |
|
| 24 |
# --- Global Variables (Lazy Loaded by initialize_models) ---
|
| 25 |
face_analyzer = None
|
| 26 |
+
reswapper_session = None # MODIFIED: Will hold the onnxruntime session for reswapper
|
| 27 |
face_restorer = None
|
| 28 |
+
# face_upscaler = None # If you add AI upscaling
|
| 29 |
|
| 30 |
# --- Initialization Functions ---
|
| 31 |
def initialize_models():
|
| 32 |
+
global face_analyzer, reswapper_session, face_restorer # MODIFIED: reswapper_session
|
| 33 |
try:
|
| 34 |
+
# Initialize FaceAnalysis model (for detection and embeddings)
|
| 35 |
if face_analyzer is None:
|
| 36 |
logging.info(f"Initializing FaceAnalysis model: {FACE_ANALYZER_NAME}")
|
| 37 |
face_analyzer = FaceAnalysis(name=FACE_ANALYZER_NAME, providers=EXECUTION_PROVIDERS)
|
| 38 |
+
face_analyzer.prepare(ctx_id=0, det_size=DETECTION_SIZE)
|
| 39 |
logging.info("FaceAnalysis model initialized successfully.")
|
| 40 |
|
| 41 |
+
# MODIFIED: Initialize ReSwapper model directly with onnxruntime
|
| 42 |
+
if reswapper_session is None:
|
| 43 |
if not os.path.exists(SWAPPER_MODEL_PATH):
|
| 44 |
+
logging.error(f"ReSwapper model FILE NOT FOUND at {SWAPPER_MODEL_PATH}. Swapping will fail.")
|
| 45 |
else:
|
| 46 |
+
logging.info(f"Loading ReSwapper model from: {SWAPPER_MODEL_PATH}")
|
| 47 |
try:
|
| 48 |
+
reswapper_session = onnxruntime.InferenceSession(SWAPPER_MODEL_PATH, providers=EXECUTION_PROVIDERS)
|
| 49 |
+
logging.info(f"ReSwapper model ({SWAPPER_MODEL_PATH}) loaded successfully with onnxruntime.")
|
| 50 |
+
except Exception as e:
|
| 51 |
+
logging.error(f"Error loading ReSwapper model {SWAPPER_MODEL_PATH} with onnxruntime: {e}", exc_info=True)
|
| 52 |
+
reswapper_session = None # Ensure it's None if loading failed
|
| 53 |
+
|
| 54 |
+
# Initialize Face Restoration Model (same as before)
|
| 55 |
+
if face_restorer is None:
|
|
|
|
| 56 |
if not os.path.exists(RESTORATION_MODEL_PATH):
|
| 57 |
+
logging.error(f"Face restoration model FILE NOT FOUND at: {RESTORATION_MODEL_PATH}. Enhancement will be disabled.")
|
| 58 |
else:
|
| 59 |
+
# (loading logic for face_restorer as in previous complete code)
|
| 60 |
logging.info(f"Attempting to load face restoration model from: {RESTORATION_MODEL_PATH}")
|
| 61 |
try:
|
| 62 |
face_restorer = onnxruntime.InferenceSession(RESTORATION_MODEL_PATH, providers=EXECUTION_PROVIDERS)
|
| 63 |
logging.info("Face restoration model loaded successfully.")
|
| 64 |
except Exception as e:
|
| 65 |
logging.error(f"Error loading face restoration model from {RESTORATION_MODEL_PATH}: {e}. Enhancement feature will be disabled.", exc_info=True)
|
| 66 |
+
face_restorer = None
|
| 67 |
+
|
| 68 |
+
# (Initialize AI Face Upscaler Model - if you were adding it)
|
| 69 |
+
|
| 70 |
except Exception as e:
|
| 71 |
logging.error(f"A critical error occurred during model initialization: {e}", exc_info=True)
|
|
|
|
| 72 |
|
| 73 |
# --- Call Initialization Early ---
|
| 74 |
initialize_models()
|
| 75 |
+
core_models_loaded_successfully = face_analyzer is not None and reswapper_session is not None # MODIFIED
|
|
|
|
| 76 |
restoration_model_loaded_successfully = face_restorer is not None
|
| 77 |
+
# upscaler_model_loaded_successfully = face_upscaler is not None # If adding upscaler
|
| 78 |
|
| 79 |
+
# --- Image Utility Functions (convert_pil_to_cv2, convert_cv2_to_pil - same as before) ---
|
| 80 |
def convert_pil_to_cv2(pil_image: Image.Image) -> np.ndarray | None:
|
| 81 |
if pil_image is None: return None
|
| 82 |
try:
|
|
|
|
| 93 |
logging.error(f"Error converting CV2 to PIL: {e}")
|
| 94 |
return None
|
| 95 |
|
| 96 |
+
# --- Core AI & Image Processing Functions (get_faces_from_image, draw_detected_faces, enhance_cropped_face, histogram_match_color, apply_naturalness_filters - mostly same) ---
|
| 97 |
+
# These functions (except the main swap logic) should largely remain the same.
|
| 98 |
+
# Make sure they are copied from your previous complete version.
|
| 99 |
+
# For brevity here, I'll assume they are present.
|
| 100 |
+
# Example of get_faces_from_image:
|
| 101 |
def get_faces_from_image(img_np: np.ndarray) -> list:
|
| 102 |
+
if face_analyzer is None: # Check against the global face_analyzer
|
|
|
|
| 103 |
logging.error("Face analyzer not available for get_faces_from_image.")
|
| 104 |
return []
|
| 105 |
if img_np is None: return []
|
| 106 |
try:
|
| 107 |
+
return face_analyzer.get(img_np) # This provides Face objects with embeddings
|
| 108 |
except Exception as e:
|
| 109 |
logging.error(f"Error during face detection: {e}", exc_info=True)
|
| 110 |
return []
|
| 111 |
|
| 112 |
+
# ... (draw_detected_faces, enhance_cropped_face, histogram_match_color, apply_naturalness_filters functions from previous full code)
|
| 113 |
+
def draw_detected_faces(img_np: np.ndarray, faces: list) -> np.ndarray: # Copied for completeness
|
| 114 |
img_with_boxes = img_np.copy()
|
| 115 |
for i, face in enumerate(faces):
|
| 116 |
+
box = face.bbox.astype(int)
|
| 117 |
x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
|
| 118 |
cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
| 119 |
+
label_position = (x1, max(0, y1 - 10))
|
| 120 |
cv2.putText(img_with_boxes, f"Face {i}", label_position,
|
| 121 |
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (36, 255, 12), 2)
|
| 122 |
return img_with_boxes
|
| 123 |
|
| 124 |
+
def enhance_cropped_face(face_crop_bgr: np.ndarray) -> np.ndarray: # Copied for completeness
|
| 125 |
if not restoration_model_loaded_successfully or face_restorer is None:
|
| 126 |
logging.warning("Face restorer model not available. Skipping enhancement for crop.")
|
| 127 |
return face_crop_bgr
|
| 128 |
if face_crop_bgr is None or face_crop_bgr.shape[0] == 0 or face_crop_bgr.shape[1] == 0:
|
| 129 |
logging.warning("Received empty or invalid face crop for enhancement.")
|
| 130 |
return face_crop_bgr if face_crop_bgr is not None else np.array([])
|
|
|
|
|
|
|
| 131 |
logging.info(f"Applying face restoration to crop of size {face_crop_bgr.shape[:2]}...")
|
| 132 |
crop_height, crop_width = face_crop_bgr.shape[:2]
|
|
|
|
| 133 |
try:
|
| 134 |
img_rgb = cv2.cvtColor(face_crop_bgr, cv2.COLOR_BGR2RGB)
|
| 135 |
+
restorer_input_size = (512, 512)
|
|
|
|
|
|
|
| 136 |
img_resized_for_model = cv2.resize(img_rgb, restorer_input_size, interpolation=cv2.INTER_AREA)
|
| 137 |
+
img_normalized = (img_resized_for_model / 255.0).astype(np.float32)
|
| 138 |
+
img_chw = np.transpose(img_normalized, (2, 0, 1))
|
| 139 |
+
input_tensor = np.expand_dims(img_chw, axis=0)
|
|
|
|
|
|
|
| 140 |
input_name = face_restorer.get_inputs()[0].name
|
| 141 |
output_name = face_restorer.get_outputs()[0].name
|
|
|
|
| 142 |
restored_output_model_size = face_restorer.run([output_name], {input_name: input_tensor})[0]
|
|
|
|
| 143 |
restored_img_chw = np.squeeze(restored_output_model_size, axis=0)
|
| 144 |
+
restored_img_hwc_model_size = np.transpose(restored_img_chw, (1, 2, 0))
|
| 145 |
restored_img_uint8_model_size = np.clip(restored_img_hwc_model_size * 255.0, 0, 255).astype(np.uint8)
|
|
|
|
|
|
|
| 146 |
restored_crop_rgb = cv2.resize(restored_img_uint8_model_size, (crop_width, crop_height), interpolation=cv2.INTER_LANCZOS4)
|
| 147 |
restored_crop_bgr = cv2.cvtColor(restored_crop_rgb, cv2.COLOR_RGB2BGR)
|
|
|
|
| 148 |
logging.info("Cropped face restoration complete.")
|
| 149 |
return restored_crop_bgr
|
| 150 |
except Exception as e:
|
| 151 |
logging.error(f"Error during face restoration for crop: {e}", exc_info=True)
|
| 152 |
+
return face_crop_bgr
|
| 153 |
|
| 154 |
+
def histogram_match_channel(source_channel: np.ndarray, target_channel: np.ndarray) -> np.ndarray: # Copied
|
|
|
|
| 155 |
source_shape = source_channel.shape
|
| 156 |
source_channel_flat = source_channel.flatten()
|
| 157 |
target_channel_flat = target_channel.flatten()
|
| 158 |
+
source_hist, bins = np.histogram(source_channel_flat, 256, [0,256], density=True) # Use density for CDF
|
| 159 |
+
target_hist, bins = np.histogram(target_channel_flat, 256, [0,256], density=True)
|
|
|
|
| 160 |
source_cdf = source_hist.cumsum()
|
|
|
|
|
|
|
|
|
|
| 161 |
target_cdf = target_hist.cumsum()
|
| 162 |
+
# Normalize CDFs to range [0, 255] for lookup table
|
| 163 |
+
source_cdf_norm = (source_cdf * 255 / source_cdf[-1]).astype(np.uint8) # Ensure last val is 255
|
| 164 |
+
target_cdf_norm = (target_cdf * 255 / target_cdf[-1]).astype(np.uint8)
|
|
|
|
|
|
|
|
|
|
| 165 |
lookup_table = np.zeros(256, dtype='uint8')
|
| 166 |
+
idx = 0
|
| 167 |
+
for i in range(256):
|
| 168 |
+
while idx < 255 and target_cdf_norm[idx] < source_cdf_norm[i]:
|
| 169 |
+
idx += 1
|
| 170 |
+
lookup_table[i] = idx
|
| 171 |
+
matched_channel_flat = cv2.LUT(source_channel, lookup_table)
|
|
|
|
|
|
|
|
|
|
| 172 |
return matched_channel_flat.reshape(source_shape)
|
| 173 |
|
| 174 |
+
def histogram_match_color(source_img: np.ndarray, target_img: np.ndarray) -> np.ndarray: # Copied
    """Match every colour channel of *source_img* to *target_img*.

    Applies :func:`histogram_match_channel` per channel. On invalid input the
    source image is handed back unchanged (or an empty array when the source
    itself is ``None``); any failure mid-matching also falls back to the
    untouched source image.
    """
    invalid = (
        source_img is None or target_img is None
        or source_img.size == 0 or target_img.size == 0
    )
    if invalid:
        logging.warning("Cannot perform histogram matching on empty or None images.")
        return np.array([]) if source_img is None else source_img

    matched = np.zeros_like(source_img)
    try:
        for channel in range(source_img.shape[2]):
            matched[..., channel] = histogram_match_channel(
                source_img[..., channel], target_img[..., channel]
            )
    except Exception as e:
        logging.error(f"Error during histogram matching: {e}", exc_info=True)
        return source_img
    return matched
|
| 186 |
|
| 187 |
+
def apply_naturalness_filters(face_region_bgr: np.ndarray, noise_level: int = 3) -> np.ndarray: # Copied
    """Soften and lightly re-noise a swapped face crop so it blends in.

    A 3x3 median blur removes isolated artefacts, then (when ``noise_level``
    is positive) Gaussian grain is added in int16 space and clipped back to
    uint8. On invalid input or any processing error the original region is
    returned untouched.
    """
    if face_region_bgr is None or face_region_bgr.size == 0:
        logging.warning("Cannot apply naturalness filters to empty or None region.")
        return np.array([]) if face_region_bgr is None else face_region_bgr

    smoothed = face_region_bgr.copy()
    try:
        smoothed = cv2.medianBlur(smoothed, 3)
        if noise_level > 0:
            # int16 intermediate avoids uint8 wrap-around before clipping.
            grain = np.random.normal(0, noise_level, smoothed.shape).astype(np.int16)
            with_grain = smoothed.astype(np.int16) + grain
            smoothed = np.clip(with_grain, 0, 255).astype(np.uint8)
        logging.info("Applied naturalness filters to face region.")
    except Exception as e:
        logging.error(f"Error applying naturalness filters: {e}", exc_info=True)
        return face_region_bgr
    return smoothed
|
| 203 |
|
| 204 |
+
# --- Main Processing Function ---
|
| 205 |
def process_face_swap(source_pil_img: Image.Image, target_pil_img: Image.Image,
|
| 206 |
+
target_face_index: int, # apply_ai_upscaling: bool, (Removed for this example to focus on ReSwapper)
|
| 207 |
+
apply_enhancement: bool, apply_color_correction: bool,
|
| 208 |
+
apply_naturalness: bool,
|
| 209 |
progress=gr.Progress(track_tqdm=True)):
|
| 210 |
+
|
| 211 |
progress(0, desc="Initializing process...")
|
| 212 |
+
if not core_models_loaded_successfully: # Checks face_analyzer and reswapper_session
|
| 213 |
+
gr.Error("CRITICAL: Core models (Face Analyzer or ReSwapper) not loaded. Cannot proceed.")
|
| 214 |
+
return Image.new('RGB', (100,100), color='lightgrey'), None
|
|
|
|
| 215 |
|
| 216 |
if source_pil_img is None: raise gr.Error("Source image not provided.")
|
| 217 |
if target_pil_img is None: raise gr.Error("Target image not provided.")
|
| 218 |
|
| 219 |
progress(0.05, desc="Converting images...")
|
| 220 |
+
source_np = convert_pil_to_cv2(source_pil_img) # BGR
|
| 221 |
+
target_np = convert_pil_to_cv2(target_pil_img) # BGR
|
|
|
|
| 222 |
if source_np is None or target_np is None:
|
| 223 |
raise gr.Error("Image conversion failed. Please try different images.")
|
| 224 |
|
|
|
|
| 226 |
|
| 227 |
progress(0.15, desc="Detecting face in source image...")
|
| 228 |
source_faces = get_faces_from_image(source_np)
|
| 229 |
+
if not source_faces: raise gr.Error("No face found in the source image.")
|
| 230 |
+
source_face = source_faces[0] # InsightFace Face object
|
| 231 |
|
| 232 |
progress(0.25, desc="Detecting faces in target image...")
|
| 233 |
target_faces = get_faces_from_image(target_np)
|
| 234 |
if not target_faces: raise gr.Error("No faces found in the target image.")
|
| 235 |
if not (0 <= target_face_index < len(target_faces)):
|
| 236 |
+
raise gr.Error(f"Selected target face index ({target_face_index}) is out of range.")
|
| 237 |
+
target_face_to_swap_info = target_faces[int(target_face_index)] # InsightFace Face object
|
| 238 |
+
|
| 239 |
+
swapped_bgr_img = target_np.copy() # Initialize with a copy
|
|
|
|
| 240 |
|
|
|
|
| 241 |
try:
|
| 242 |
+
progress(0.4, desc="Preparing inputs for ReSwapper...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
+
########## USER VERIFICATION NEEDED FOR RESWAPPER_256.ONNX ##########
|
| 245 |
+
# 1. Source Face Embedding Preparation
|
| 246 |
+
source_embedding = source_face.embedding.astype(np.float32) # (512,)
|
| 247 |
+
# Most ONNX models expect a batch dimension.
|
| 248 |
+
source_embedding_tensor = np.expand_dims(source_embedding, axis=0) # (1, 512)
|
| 249 |
+
|
| 250 |
+
# 2. Target Image Preparation
|
| 251 |
+
# ASSUMPTION: reswapper_256.onnx might expect a fixed size input, e.g., 256x256.
|
| 252 |
+
# If it does, you MUST resize target_np to that size.
|
| 253 |
+
# For this example, let's assume it takes the target image and processes internally
|
| 254 |
+
# or expects a specific format. A common pattern is RGB, normalized, CHW.
|
| 255 |
+
# Let's assume input size of 256x256 for the target image input to the model.
|
| 256 |
+
# We will process the whole target_np, let the model do its work, and it should return the modified whole target_np.
|
| 257 |
+
|
| 258 |
+
target_img_for_onnx = cv2.cvtColor(target_np, cv2.COLOR_BGR2RGB) # Often RGB for ONNX
|
| 259 |
+
|
| 260 |
+
# RESIZE THE TARGET IMAGE IF THE MODEL EXPECTS A FIXED INPUT SIZE
|
| 261 |
+
# E.g., if reswapper_256.onnx expects 256x256 input for the *target image itself*:
|
| 262 |
+
# target_img_for_onnx = cv2.resize(target_img_for_onnx, (256, 256), interpolation=cv2.INTER_AREA)
|
| 263 |
+
# This is a CRITICAL assumption. If it operates on full-res target, no resize needed here.
|
| 264 |
+
# If the "256" in reswapper_256 refers to an internal patch size, it might handle varied input.
|
| 265 |
+
# For now, let's NOT resize the full target image, assuming the model handles it.
|
| 266 |
+
|
| 267 |
+
target_img_normalized = (target_img_for_onnx / 255.0).astype(np.float32) # Normalize [0, 1]
|
| 268 |
+
# Or normalize to [-1, 1]: (target_img_for_onnx / 127.5 - 1.0).astype(np.float32)
|
| 269 |
+
|
| 270 |
+
target_img_chw = np.transpose(target_img_normalized, (2, 0, 1)) # HWC to CHW
|
| 271 |
+
target_image_tensor = np.expand_dims(target_img_chw, axis=0) # Add batch dimension
|
| 272 |
+
|
| 273 |
+
# 3. Target Face Bounding Box (Optional, but many swappers need it)
|
| 274 |
+
# The InsightFace Face object (target_face_to_swap_info) contains .bbox
|
| 275 |
+
# bbox is [x1, y1, x2, y2]. Normalize if model expects normalized bbox.
|
| 276 |
+
target_bbox_tensor = target_face_to_swap_info.bbox.astype(np.float32)
|
| 277 |
+
# Reshape if needed, e.g., (1, 4)
|
| 278 |
+
target_bbox_tensor = np.expand_dims(target_bbox_tensor, axis=0)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
# Placeholder names for ONNX model inputs. YOU MUST FIND THE REAL NAMES.
|
| 282 |
+
# Common names could be "target", "source_embed", "target_face_box"
|
| 283 |
+
# Check model's .get_inputs() if loaded, or use Netron app to view ONNX.
|
| 284 |
+
onnx_input_names = [inp.name for inp in reswapper_session.get_inputs()]
|
| 285 |
+
# Crude assumption based on common patterns (highly likely to be WRONG):
|
| 286 |
+
# This is where you absolutely need the model's documentation.
|
| 287 |
+
# Let's assume: input0=target_image, input1=source_embedding, input2=target_bbox (OPTIONAL)
|
| 288 |
+
# This is a GUESS.
|
| 289 |
+
|
| 290 |
+
# For this example, let's try a simpler input signature if possible,
|
| 291 |
+
# like what an insightface swapper's .get() method might abstract.
|
| 292 |
+
# A common pattern for standalone ONNX swappers is:
|
| 293 |
+
# Input 1: Target image tensor (e.g., BCHW, float32, normalized)
|
| 294 |
+
# Input 2: Source embedding tensor (e.g., B x EmbDim, float32)
|
| 295 |
+
# Some might also take target face landmarks or a mask.
|
| 296 |
+
|
| 297 |
+
# Let's assume these input names based on typical conventions.
|
| 298 |
+
# YOU MUST VERIFY THESE using Netron or the model's documentation.
|
| 299 |
+
# input_feed = {
|
| 300 |
+
# "target_image": target_image_tensor, # Placeholder name
|
| 301 |
+
# "source_embedding": source_embedding_tensor # Placeholder name
|
| 302 |
+
# }
|
| 303 |
+
# If target_bbox is also an input:
|
| 304 |
+
# input_feed["target_bbox"] = target_bbox_tensor # Placeholder name
|
| 305 |
+
|
| 306 |
+
# A more robust way is to check input names from the model:
|
| 307 |
+
# input_name_target_img = reswapper_session.get_inputs()[0].name
|
| 308 |
+
# input_name_source_emb = reswapper_session.get_inputs()[1].name
|
| 309 |
+
# input_feed = {
|
| 310 |
+
# input_name_target_img: target_image_tensor,
|
| 311 |
+
# input_name_source_emb: source_embedding_tensor
|
| 312 |
+
# }
|
| 313 |
+
# if len(reswapper_session.get_inputs()) > 2:
|
| 314 |
+
# input_name_target_bbox = reswapper_session.get_inputs()[2].name
|
| 315 |
+
# input_feed[input_name_target_bbox] = target_bbox_tensor
|
| 316 |
+
|
| 317 |
+
# For this illustrative code, I will make up plausible input names.
|
| 318 |
+
# ***** REPLACE "actual_target_input_name", "actual_source_embed_input_name" etc. *****
|
| 319 |
+
# ***** WITH THE REAL TENSOR NAMES FROM reswapper_256.onnx. *****
|
| 320 |
+
# ***** Common names are often like 'input', 'target', 'source', 'embedding'.*****
|
| 321 |
+
# ***** Use Netron app to inspect your reswapper_256.onnx file! *****
|
| 322 |
+
input_feed = {
|
| 323 |
+
# Example, assuming first input is target image, second is source embedding
|
| 324 |
+
reswapper_session.get_inputs()[0].name: target_image_tensor,
|
| 325 |
+
reswapper_session.get_inputs()[1].name: source_embedding_tensor
|
| 326 |
+
}
|
| 327 |
+
# If your model takes a third input like target_bbox (very model specific):
|
| 328 |
+
if len(reswapper_session.get_inputs()) > 2: # Check if there's a third input
|
| 329 |
+
input_feed[reswapper_session.get_inputs()[2].name] = target_bbox_tensor
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
# Placeholder for output tensor name. YOU MUST FIND THE REAL NAME.
|
| 333 |
+
# output_name_placeholder = "output_image"
|
| 334 |
+
# More robust:
|
| 335 |
+
output_name = reswapper_session.get_outputs()[0].name
|
| 336 |
+
########## END OF USER VERIFICATION NEEDED ##########
|
| 337 |
+
|
| 338 |
+
progress(0.5, desc="Running ReSwapper model...")
|
| 339 |
+
onnx_outputs = reswapper_session.run([output_name], input_feed)
|
| 340 |
+
swapped_image_tensor = onnx_outputs[0] # Assuming first output is the image
|
| 341 |
+
|
| 342 |
+
progress(0.6, desc="Post-processing ReSwapper output...")
|
| 343 |
+
# Post-process the output tensor back to a BGR image
|
| 344 |
+
# 1. Remove batch dimension if present
|
| 345 |
+
if swapped_image_tensor.ndim == 4:
|
| 346 |
+
swapped_image_tensor = np.squeeze(swapped_image_tensor, axis=0) # (C, H, W)
|
| 347 |
+
# 2. Transpose CHW to HWC
|
| 348 |
+
swapped_image_hwc = np.transpose(swapped_image_tensor, (1, 2, 0)) # (H, W, C)
|
| 349 |
+
# 3. Denormalize (inverse of the normalization applied earlier)
|
| 350 |
+
# If normalized to [0, 1]:
|
| 351 |
+
swapped_image_denormalized = swapped_image_hwc * 255.0
|
| 352 |
+
# If normalized to [-1, 1]:
|
| 353 |
+
# swapped_image_denormalized = (swapped_image_hwc + 1.0) * 127.5
|
| 354 |
+
|
| 355 |
+
swapped_image_uint8 = np.clip(swapped_image_denormalized, 0, 255).astype(np.uint8)
|
| 356 |
+
|
| 357 |
+
# 4. Convert RGB to BGR if model outputs RGB
|
| 358 |
+
swapped_bgr_img = cv2.cvtColor(swapped_image_uint8, cv2.COLOR_RGB2BGR)
|
| 359 |
+
|
| 360 |
+
# IF `reswapper_256.onnx` returned a resized image (e.g. 256x256),
|
| 361 |
+
# you might need to resize `swapped_bgr_img` back to original target_h, target_w.
|
| 362 |
+
if swapped_bgr_img.shape[0] != target_h or swapped_bgr_img.shape[1] != target_w:
|
| 363 |
+
logging.info(f"Resizing ReSwapper output from {swapped_bgr_img.shape[:2]} to {(target_h, target_w)}")
|
| 364 |
+
swapped_bgr_img = cv2.resize(swapped_bgr_img, (target_w, target_h), interpolation=cv2.INTER_LANCZOS4)
|
| 365 |
|
| 366 |
+
# --- Post-processing pipeline (Restoration, Color Correction, Naturalness) ---
|
| 367 |
+
# These will now operate on `swapped_bgr_img` which came from reswapper.
|
| 368 |
+
current_processed_image = swapped_bgr_img.copy()
|
| 369 |
+
bbox_coords = target_face_to_swap_info.bbox.astype(int) # [x1, y1, x2, y2] for ROI processing
|
| 370 |
+
roi_x1, roi_y1, roi_x2, roi_y2 = bbox_coords[0], bbox_coords[1], bbox_coords[2], bbox_coords[3]
|
| 371 |
+
|
| 372 |
+
face_roi_for_postprocessing = current_processed_image[roi_y1:roi_y2, roi_x1:roi_x2].copy()
|
| 373 |
+
|
| 374 |
+
if face_roi_for_postprocessing.size == 0:
|
| 375 |
+
logging.warning("ROI for post-processing is empty. Skipping post-processing steps.")
|
| 376 |
+
else:
|
| 377 |
+
if apply_enhancement:
|
| 378 |
+
if restoration_model_loaded_successfully:
|
| 379 |
+
progress(0.7, desc="Applying face restoration...")
|
| 380 |
+
face_roi_for_postprocessing = enhance_cropped_face(face_roi_for_postprocessing)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
else:
|
| 382 |
+
gr.Info("Face restoration model N/A, enhancement skipped.")
|
| 383 |
+
|
| 384 |
+
if apply_color_correction:
|
| 385 |
+
progress(0.8, desc="Applying color correction...")
|
| 386 |
+
target_face_region_for_color = target_np[roi_y1:roi_y2, roi_x1:roi_x2]
|
| 387 |
+
if target_face_region_for_color.size > 0:
|
| 388 |
+
face_roi_for_postprocessing = histogram_match_color(face_roi_for_postprocessing, target_face_region_for_color.copy())
|
| 389 |
+
|
| 390 |
+
if apply_naturalness:
|
| 391 |
+
progress(0.85, desc="Applying naturalness filters...")
|
| 392 |
+
face_roi_for_postprocessing = apply_naturalness_filters(face_roi_for_postprocessing)
|
| 393 |
+
|
| 394 |
+
# Paste the processed ROI back
|
| 395 |
+
if face_roi_for_postprocessing.shape[0] == (roi_y2 - roi_y1) and face_roi_for_postprocessing.shape[1] == (roi_x2 - roi_x1):
|
| 396 |
+
current_processed_image[roi_y1:roi_y2, roi_x1:roi_x2] = face_roi_for_postprocessing
|
| 397 |
else:
|
| 398 |
+
logging.warning("Processed ROI size mismatch after post-processing. Attempting resize for paste.")
|
| 399 |
+
try:
|
| 400 |
+
resized_processed_roi = cv2.resize(face_roi_for_postprocessing, (roi_x2-roi_x1, roi_y2-roi_y1), interpolation=cv2.INTER_LANCZOS4)
|
| 401 |
+
current_processed_image[roi_y1:roi_y2, roi_x1:roi_x2] = resized_processed_roi
|
| 402 |
+
except Exception as e_resize:
|
| 403 |
+
logging.error(f"Failed to resize and paste processed ROI: {e_resize}")
|
| 404 |
|
| 405 |
|
| 406 |
+
swapped_bgr_img = current_processed_image # Final image after all steps
|
| 407 |
+
|
| 408 |
except Exception as e:
|
| 409 |
+
logging.error(f"Error during main processing pipeline: {e}", exc_info=True)
|
| 410 |
+
# Fallback to display the original target image or last good state if error occurs mid-process
|
| 411 |
+
img_to_show_on_error = swapped_bgr_img # Use current state which might be partially processed
|
| 412 |
+
swapped_pil_img_on_error = convert_cv2_to_pil(img_to_show_on_error)
|
| 413 |
+
if swapped_pil_img_on_error is None:
|
| 414 |
swapped_pil_img_on_error = Image.new('RGB', (target_w, target_h), color='lightgrey')
|
| 415 |
raise gr.Error(f"An error occurred: {str(e)}")
|
|
|
|
| 416 |
|
| 417 |
progress(0.9, desc="Finalizing image...")
|
| 418 |
swapped_pil_img = convert_cv2_to_pil(swapped_bgr_img)
|
| 419 |
+
if swapped_pil_img is None:
|
| 420 |
gr.Error("Failed to convert final image to display format.")
|
| 421 |
+
swapped_pil_img = Image.new('RGB', (target_w, target_h), color='lightgrey')
|
| 422 |
|
| 423 |
temp_file_path = None
|
| 424 |
try:
|
|
|
|
| 425 |
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False, mode="wb") as tmp_file:
|
| 426 |
+
swapped_pil_img.save(tmp_file, format="JPEG")
|
| 427 |
temp_file_path = tmp_file.name
|
| 428 |
logging.info(f"Swapped image saved to temporary file: {temp_file_path}")
|
| 429 |
except Exception as e:
|
|
|
|
| 433 |
progress(1.0, desc="Processing complete!")
|
| 434 |
return swapped_pil_img, temp_file_path
|
| 435 |
|
| 436 |
+
# --- Gradio Preview Function (preview_target_faces - same as before) ---
|
| 437 |
+
def preview_target_faces(target_pil_img: Image.Image): # Copied for completeness
    """Annotate detected faces in the target image and configure the index slider.

    Returns a preview PIL image (faces outlined when any were found) together
    with a slider update whose range covers the detected face indices; the
    slider is disabled whenever there is nothing selectable.
    """
    def _blank_canvas():
        # Neutral placeholder shown when there is nothing useful to preview.
        return Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')

    def _locked_slider():
        # Slider reset to a single, non-interactive position.
        return gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)

    if target_pil_img is None:
        return _blank_canvas(), _locked_slider()

    target_np = convert_pil_to_cv2(target_pil_img)
    if target_np is None:
        return _blank_canvas(), _locked_slider()

    faces = get_faces_from_image(target_np)
    if not faces:
        # No detections: show the plain target image so the user sees context.
        fallback = convert_cv2_to_pil(target_np)
        if fallback is None:
            fallback = _blank_canvas()
        return fallback, _locked_slider()

    annotated = convert_cv2_to_pil(draw_detected_faces(target_np, faces))
    if annotated is None:
        annotated = _blank_canvas()

    face_count = len(faces)
    slider_state = gr.Slider(minimum=0, maximum=max(0, face_count - 1), value=0, step=1, interactive=(face_count > 0))
    return annotated, slider_state
|
| 458 |
|
| 459 |
+
# --- Gradio UI Definition (Largely same, ensure inputs to process_face_swap are correct) ---
with gr.Blocks(title="ReSwapper AI Face Swap π", theme=gr.themes.Soft()) as demo:
    # Page header. The emoji glyphs appear mojibake'd in this file but are kept
    # byte-for-byte: they are user-visible runtime strings, not comments.
    gr.Markdown(
        """
        <div style="text-align: center;">
        <h1>π ReSwapper AI Face Swap π </h1>
        <p>Using ReSwapper_256.onnx. Upload a source face, a target image, and let the AI work!</p>
        </div>
        """
    )

    # Surface a model-load failure immediately; the UI still renders but the
    # swap button below stays non-interactive.
    if not core_models_loaded_successfully: # This now checks face_analyzer and reswapper_session
        gr.Error("CRITICAL ERROR: Core models (Face Analyzer or ReSwapper) failed to load. Application will not function correctly. Please check console logs and restart.")

    # Input images: the face to transplant and the scene to place it in.
    with gr.Row():
        with gr.Column(scale=1):
            source_image_input = gr.Image(label="π€ Source Face Image", type="pil", sources=["upload", "clipboard"], height=350)
        with gr.Column(scale=1):
            target_image_input = gr.Image(label="πΌοΈ Target Scene Image", type="pil", sources=["upload", "clipboard"], height=350)

    # Face selection: preview detections, then pick which target face to replace.
    with gr.Row(equal_height=True):
        preview_button = gr.Button("π Preview & Select Target Face", variant="secondary")
        face_index_slider = gr.Slider(label="π― Select Target Face (0-indexed)", minimum=0, maximum=0, step=1, value=0, interactive=False)

    target_faces_preview_output = gr.Image(label="π Detected Faces in Target", interactive=False, height=350)
    gr.HTML("<hr style='margin-top: 15px; margin-bottom: 15px;'>")

    # Post-processing toggles. Restoration is disabled when its model is absent.
    with gr.Row():
        with gr.Column(scale=1):
            enhance_checkbox_label = "β¨ Apply Face Restoration"
            if not restoration_model_loaded_successfully: enhance_checkbox_label += " (Model N/A)"
            enhance_checkbox = gr.Checkbox(label=enhance_checkbox_label, value=restoration_model_loaded_successfully, interactive=restoration_model_loaded_successfully)
            if not restoration_model_loaded_successfully: gr.Markdown("<p style='color: orange; font-size:0.8em;'>β οΈ Restoration model N/A.</p>")

        with gr.Column(scale=1):
            color_correction_checkbox = gr.Checkbox(label="π¨ Apply Color Correction", value=True)

        with gr.Column(scale=1):
            naturalness_checkbox = gr.Checkbox(label="πΏ Apply Subtle Naturalness Filters", value=False)

    # Primary actions: swap is interactive only when the core models loaded.
    with gr.Row():
        swap_button = gr.Button("π GENERATE SWAP!", variant="primary", scale=3, interactive=core_models_loaded_successfully)
        clear_button = gr.Button("π§Ή Clear All", variant="stop", scale=1)

    with gr.Row():
        swapped_image_output = gr.Image(label="β¨ Swapped Result", interactive=False, height=450)
        download_output_file = gr.File(label="β¬οΈ Download Swapped Image")

    # Event Handlers
    def on_target_image_change_or_clear(target_img_pil):
        # Reset the preview and slider when the target image is cleared.
        if target_img_pil is None:
            blank_image_pil = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
            return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
        # NOTE(review): reading component .value here returns the component's
        # *initial* value, not its live state, in current Gradio versions —
        # confirm this fall-through behaves as intended.
        return target_faces_preview_output.value, face_index_slider.value

    target_image_input.change(fn=on_target_image_change_or_clear, inputs=[target_image_input], outputs=[target_faces_preview_output, face_index_slider], queue=False)
    preview_button.click(fn=preview_target_faces, inputs=[target_image_input], outputs=[target_faces_preview_output, face_index_slider])

    # Main swap action: argument order must match process_face_swap's signature.
    swap_button.click(
        fn=process_face_swap,
        inputs=[
            source_image_input, target_image_input, face_index_slider,
            enhance_checkbox, color_correction_checkbox, naturalness_checkbox
            # Note: Removed apply_ai_upscaling from inputs list for this example as we focused on ReSwapper integration
        ],
        outputs=[swapped_image_output, download_output_file]
    )

    def clear_all_inputs_outputs():
        # Blank out both input images, the preview, the slider, and both outputs.
        blank_preview = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color = 'lightgray')
        return (None, None, blank_preview, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False), None, None)

    clear_button.click(fn=clear_all_inputs_outputs, inputs=None, outputs=[source_image_input, target_image_input, target_faces_preview_output, face_index_slider, swapped_image_output, download_output_file], queue=False)

    # Demo inputs; cache_examples=False so nothing runs at startup.
    gr.Examples(
        examples=[
            ["examples/source_face.jpg", "examples/target_group.jpg", 0, True, True, True],
        ],
        inputs=[source_image_input, target_image_input, face_index_slider, enhance_checkbox, color_correction_checkbox, naturalness_checkbox],
        outputs=[swapped_image_output, download_output_file],
        fn=process_face_swap, cache_examples=False,
        label="Example Face Swaps (Click to run)"
    )
|
| 542 |
|
| 543 |
# --- Main Execution Block ---
|
| 544 |
if __name__ == "__main__":
|
| 545 |
os.makedirs("models", exist_ok=True)
|
| 546 |
+
os.makedirs("examples", exist_ok=True)
|
| 547 |
|
|
|
|
| 548 |
print("\n" + "="*60)
|
| 549 |
+
print("π ReSwapper AI FACE SWAP - STARTUP STATUS π")
|
| 550 |
print("="*60)
|
| 551 |
if not core_models_loaded_successfully:
|
| 552 |
+
print(f"π΄ CRITICAL ERROR: Core models FAILED to load.")
|
| 553 |
print(f" - Face Analyzer: '{FACE_ANALYZER_NAME}' (Status: {'Loaded' if face_analyzer else 'Failed'})")
|
| 554 |
+
print(f" - ReSwapper Model: '{SWAPPER_MODEL_PATH}' (Status: {'Loaded' if reswapper_session else 'Failed'})") # MODIFIED
|
| 555 |
print(" The application UI will load but will be NON-FUNCTIONAL.")
|
|
|
|
| 556 |
else:
|
| 557 |
+
print("π’ Core models (Face Analyzer & ReSwapper) loaded successfully.")
|
| 558 |
|
| 559 |
if not restoration_model_loaded_successfully:
|
| 560 |
+
print(f"π‘ INFO: Face Restoration model ('{RESTORATION_MODEL_PATH}') not loaded. Enhancement feature disabled.")
|
|
|
|
|
|
|
| 561 |
else:
|
| 562 |
print("π’ Face Restoration model loaded successfully.")
|
| 563 |
print("="*60 + "\n")
|