Update app.py
Browse files
app.py
CHANGED
|
@@ -3,7 +3,6 @@ import cv2
|
|
| 3 |
import numpy as np
|
| 4 |
import gradio as gr
|
| 5 |
from insightface.app import FaceAnalysis # Still needed for detection and embeddings
|
| 6 |
-
# from insightface.model_zoo import get_model # No longer using this for the swapper
|
| 7 |
from PIL import Image
|
| 8 |
import tempfile
|
| 9 |
import logging
|
|
@@ -12,33 +11,28 @@ import onnxruntime
|
|
| 12 |
# --- Configuration & Setup ---
|
| 13 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 14 |
|
| 15 |
-
|
| 16 |
-
SWAPPER_MODEL_PATH = "models/reswapper_256.onnx" # <--- USING RESWAPPER_256
|
| 17 |
RESTORATION_MODEL_PATH = "models/gfpgan_1.4.onnx" # Or your chosen restorer, e.g., restoreformer++.onnx
|
| 18 |
-
# UPSCALER_MODEL_PATH = "models/RealESRGANx4plus.onnx" # Path for AI Upscaler if you were to add it
|
| 19 |
|
| 20 |
-
FACE_ANALYZER_NAME = 'buffalo_l'
|
| 21 |
DETECTION_SIZE = (640, 640)
|
| 22 |
-
EXECUTION_PROVIDERS = ['CPUExecutionProvider']
|
| 23 |
|
| 24 |
-
# --- Global Variables
|
| 25 |
face_analyzer = None
|
| 26 |
-
reswapper_session = None #
|
| 27 |
face_restorer = None
|
| 28 |
-
# face_upscaler = None # If you add AI upscaling
|
| 29 |
|
| 30 |
# --- Initialization Functions ---
|
| 31 |
def initialize_models():
|
| 32 |
-
global face_analyzer, reswapper_session, face_restorer
|
| 33 |
try:
|
| 34 |
-
# Initialize FaceAnalysis model (for detection and embeddings)
|
| 35 |
if face_analyzer is None:
|
| 36 |
logging.info(f"Initializing FaceAnalysis model: {FACE_ANALYZER_NAME}")
|
| 37 |
face_analyzer = FaceAnalysis(name=FACE_ANALYZER_NAME, providers=EXECUTION_PROVIDERS)
|
| 38 |
face_analyzer.prepare(ctx_id=0, det_size=DETECTION_SIZE)
|
| 39 |
logging.info("FaceAnalysis model initialized successfully.")
|
| 40 |
|
| 41 |
-
# MODIFIED: Initialize ReSwapper model directly with onnxruntime
|
| 42 |
if reswapper_session is None:
|
| 43 |
if not os.path.exists(SWAPPER_MODEL_PATH):
|
| 44 |
logging.error(f"ReSwapper model FILE NOT FOUND at {SWAPPER_MODEL_PATH}. Swapping will fail.")
|
|
@@ -49,14 +43,12 @@ def initialize_models():
|
|
| 49 |
logging.info(f"ReSwapper model ({SWAPPER_MODEL_PATH}) loaded successfully with onnxruntime.")
|
| 50 |
except Exception as e:
|
| 51 |
logging.error(f"Error loading ReSwapper model {SWAPPER_MODEL_PATH} with onnxruntime: {e}", exc_info=True)
|
| 52 |
-
reswapper_session = None
|
| 53 |
-
|
| 54 |
-
# Initialize Face Restoration Model (same as before)
|
| 55 |
if face_restorer is None:
|
| 56 |
if not os.path.exists(RESTORATION_MODEL_PATH):
|
| 57 |
-
logging.error(f"Face restoration model FILE NOT FOUND at: {RESTORATION_MODEL_PATH}. Enhancement will be disabled.")
|
| 58 |
else:
|
| 59 |
-
# (loading logic for face_restorer as in previous complete code)
|
| 60 |
logging.info(f"Attempting to load face restoration model from: {RESTORATION_MODEL_PATH}")
|
| 61 |
try:
|
| 62 |
face_restorer = onnxruntime.InferenceSession(RESTORATION_MODEL_PATH, providers=EXECUTION_PROVIDERS)
|
|
@@ -64,19 +56,14 @@ def initialize_models():
|
|
| 64 |
except Exception as e:
|
| 65 |
logging.error(f"Error loading face restoration model from {RESTORATION_MODEL_PATH}: {e}. Enhancement feature will be disabled.", exc_info=True)
|
| 66 |
face_restorer = None
|
| 67 |
-
|
| 68 |
-
# (Initialize AI Face Upscaler Model - if you were adding it)
|
| 69 |
-
|
| 70 |
except Exception as e:
|
| 71 |
logging.error(f"A critical error occurred during model initialization: {e}", exc_info=True)
|
| 72 |
|
| 73 |
-
# --- Call Initialization Early ---
|
| 74 |
initialize_models()
|
| 75 |
-
core_models_loaded_successfully = face_analyzer is not None and reswapper_session is not None
|
| 76 |
restoration_model_loaded_successfully = face_restorer is not None
|
| 77 |
-
# upscaler_model_loaded_successfully = face_upscaler is not None # If adding upscaler
|
| 78 |
|
| 79 |
-
# --- Image Utility Functions
|
| 80 |
def convert_pil_to_cv2(pil_image: Image.Image) -> np.ndarray | None:
|
| 81 |
if pil_image is None: return None
|
| 82 |
try:
|
|
@@ -93,123 +80,109 @@ def convert_cv2_to_pil(cv2_image: np.ndarray) -> Image.Image | None:
|
|
| 93 |
logging.error(f"Error converting CV2 to PIL: {e}")
|
| 94 |
return None
|
| 95 |
|
| 96 |
-
# --- Core AI & Image Processing Functions
|
| 97 |
-
# These functions (except the main swap logic) should largely remain the same.
|
| 98 |
-
# Make sure they are copied from your previous complete version.
|
| 99 |
-
# For brevity here, I'll assume they are present.
|
| 100 |
-
# Example of get_faces_from_image:
|
| 101 |
def get_faces_from_image(img_np: np.ndarray) -> list:
|
| 102 |
-
if face_analyzer is None:
|
| 103 |
logging.error("Face analyzer not available for get_faces_from_image.")
|
| 104 |
return []
|
| 105 |
if img_np is None: return []
|
| 106 |
try:
|
| 107 |
-
return face_analyzer.get(img_np)
|
| 108 |
except Exception as e:
|
| 109 |
logging.error(f"Error during face detection: {e}", exc_info=True)
|
| 110 |
return []
|
| 111 |
|
| 112 |
-
|
| 113 |
-
def draw_detected_faces(img_np: np.ndarray, faces: list) -> np.ndarray: # Copied for completeness
|
| 114 |
img_with_boxes = img_np.copy()
|
| 115 |
for i, face in enumerate(faces):
|
| 116 |
box = face.bbox.astype(int)
|
| 117 |
x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
|
| 118 |
cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
| 119 |
label_position = (x1, max(0, y1 - 10))
|
| 120 |
-
cv2.putText(img_with_boxes, f"Face {i}", label_position,
|
| 121 |
-
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (36, 255, 12), 2)
|
| 122 |
return img_with_boxes
|
| 123 |
|
| 124 |
-
def enhance_cropped_face(face_crop_bgr: np.ndarray) -> np.ndarray:
|
| 125 |
if not restoration_model_loaded_successfully or face_restorer is None:
|
| 126 |
-
logging.warning("Face restorer model not available. Skipping enhancement for crop.")
|
| 127 |
return face_crop_bgr
|
| 128 |
-
if face_crop_bgr is None or face_crop_bgr.
|
| 129 |
-
logging.warning("Received empty or invalid face crop for enhancement.")
|
| 130 |
-
return face_crop_bgr if face_crop_bgr is not None else np.array([])
|
| 131 |
logging.info(f"Applying face restoration to crop of size {face_crop_bgr.shape[:2]}...")
|
| 132 |
-
|
| 133 |
try:
|
| 134 |
img_rgb = cv2.cvtColor(face_crop_bgr, cv2.COLOR_BGR2RGB)
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
img_chw = np.transpose(
|
| 139 |
input_tensor = np.expand_dims(img_chw, axis=0)
|
| 140 |
input_name = face_restorer.get_inputs()[0].name
|
| 141 |
output_name = face_restorer.get_outputs()[0].name
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
logging.info("Cropped face restoration complete.")
|
| 149 |
-
return restored_crop_bgr
|
| 150 |
except Exception as e:
|
| 151 |
logging.error(f"Error during face restoration for crop: {e}", exc_info=True)
|
| 152 |
return face_crop_bgr
|
| 153 |
|
| 154 |
-
def histogram_match_channel(
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
for i in range(256):
|
| 168 |
-
while
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
return matched_channel_flat.reshape(source_shape)
|
| 173 |
|
| 174 |
-
def histogram_match_color(source_img: np.ndarray, target_img: np.ndarray) -> np.ndarray:
|
| 175 |
if source_img is None or target_img is None or source_img.size == 0 or target_img.size == 0:
|
| 176 |
-
logging.warning("Cannot perform histogram matching on empty or None images.")
|
| 177 |
return source_img if source_img is not None else np.array([])
|
| 178 |
-
|
| 179 |
try:
|
| 180 |
for i in range(source_img.shape[2]):
|
| 181 |
-
|
| 182 |
-
return
|
| 183 |
except Exception as e:
|
| 184 |
logging.error(f"Error during histogram matching: {e}", exc_info=True)
|
| 185 |
return source_img
|
| 186 |
|
| 187 |
-
def apply_naturalness_filters(face_region_bgr: np.ndarray, noise_level: int = 3) -> np.ndarray:
|
| 188 |
if face_region_bgr is None or face_region_bgr.size == 0:
|
| 189 |
-
logging.warning("Cannot apply naturalness filters to empty or None region.")
|
| 190 |
return face_region_bgr if face_region_bgr is not None else np.array([])
|
| 191 |
-
|
| 192 |
try:
|
| 193 |
-
|
| 194 |
if noise_level > 0:
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
logging.info("Applied naturalness filters
|
| 199 |
except Exception as e:
|
| 200 |
logging.error(f"Error applying naturalness filters: {e}", exc_info=True)
|
| 201 |
return face_region_bgr
|
| 202 |
-
return
|
| 203 |
|
| 204 |
# --- Main Processing Function ---
|
| 205 |
def process_face_swap(source_pil_img: Image.Image, target_pil_img: Image.Image,
|
| 206 |
-
target_face_index: int,
|
| 207 |
apply_enhancement: bool, apply_color_correction: bool,
|
| 208 |
apply_naturalness: bool,
|
| 209 |
progress=gr.Progress(track_tqdm=True)):
|
| 210 |
|
| 211 |
progress(0, desc="Initializing process...")
|
| 212 |
-
if not core_models_loaded_successfully:
|
| 213 |
gr.Error("CRITICAL: Core models (Face Analyzer or ReSwapper) not loaded. Cannot proceed.")
|
| 214 |
return Image.new('RGB', (100,100), color='lightgrey'), None
|
| 215 |
|
|
@@ -217,350 +190,265 @@ def process_face_swap(source_pil_img: Image.Image, target_pil_img: Image.Image,
|
|
| 217 |
if target_pil_img is None: raise gr.Error("Target image not provided.")
|
| 218 |
|
| 219 |
progress(0.05, desc="Converting images...")
|
| 220 |
-
source_np = convert_pil_to_cv2(source_pil_img)
|
| 221 |
-
target_np = convert_pil_to_cv2(target_pil_img)
|
| 222 |
if source_np is None or target_np is None:
|
| 223 |
-
raise gr.Error("Image conversion failed.
|
| 224 |
|
| 225 |
-
|
| 226 |
|
| 227 |
-
progress(0.15, desc="Detecting
|
| 228 |
source_faces = get_faces_from_image(source_np)
|
| 229 |
-
if not source_faces: raise gr.Error("No face
|
| 230 |
-
source_face = source_faces[0]
|
| 231 |
|
| 232 |
-
progress(0.25, desc="Detecting
|
| 233 |
target_faces = get_faces_from_image(target_np)
|
| 234 |
-
if not target_faces: raise gr.Error("No faces
|
| 235 |
if not (0 <= target_face_index < len(target_faces)):
|
| 236 |
-
raise gr.Error(f"
|
| 237 |
-
|
| 238 |
|
| 239 |
-
|
| 240 |
|
| 241 |
try:
|
| 242 |
-
progress(0.
|
| 243 |
|
| 244 |
-
#
|
| 245 |
-
|
| 246 |
-
source_embedding = source_face.embedding.astype(np.float32) # (512,)
|
| 247 |
-
# Most ONNX models expect a batch dimension.
|
| 248 |
source_embedding_tensor = np.expand_dims(source_embedding, axis=0) # (1, 512)
|
| 249 |
|
| 250 |
# 2. Target Image Preparation
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
#
|
| 254 |
-
#
|
| 255 |
-
#
|
| 256 |
-
|
| 257 |
|
| 258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
|
| 260 |
-
#
|
| 261 |
-
#
|
| 262 |
-
#
|
| 263 |
-
|
| 264 |
-
#
|
| 265 |
-
#
|
| 266 |
-
|
| 267 |
-
target_img_normalized = (target_img_for_onnx / 255.0).astype(np.float32) # Normalize [0, 1]
|
| 268 |
-
# Or normalize to [-1, 1]: (target_img_for_onnx / 127.5 - 1.0).astype(np.float32)
|
| 269 |
|
| 270 |
target_img_chw = np.transpose(target_img_normalized, (2, 0, 1)) # HWC to CHW
|
| 271 |
target_image_tensor = np.expand_dims(target_img_chw, axis=0) # Add batch dimension
|
| 272 |
|
| 273 |
-
# 3. Target Face Bounding Box (
|
| 274 |
-
#
|
| 275 |
-
#
|
| 276 |
-
|
| 277 |
-
#
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
#
|
| 286 |
-
#
|
| 287 |
-
|
| 288 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
-
|
| 291 |
-
# like what an insightface swapper's .get() method might abstract.
|
| 292 |
-
# A common pattern for standalone ONNX swappers is:
|
| 293 |
-
# Input 1: Target image tensor (e.g., BCHW, float32, normalized)
|
| 294 |
-
# Input 2: Source embedding tensor (e.g., B x EmbDim, float32)
|
| 295 |
-
# Some might also take target face landmarks or a mask.
|
| 296 |
-
|
| 297 |
-
# Let's assume these input names based on typical conventions.
|
| 298 |
-
# YOU MUST VERIFY THESE using Netron or the model's documentation.
|
| 299 |
-
# input_feed = {
|
| 300 |
-
# "target_image": target_image_tensor, # Placeholder name
|
| 301 |
-
# "source_embedding": source_embedding_tensor # Placeholder name
|
| 302 |
-
# }
|
| 303 |
-
# If target_bbox is also an input:
|
| 304 |
-
# input_feed["target_bbox"] = target_bbox_tensor # Placeholder name
|
| 305 |
-
|
| 306 |
-
# A more robust way is to check input names from the model:
|
| 307 |
-
# input_name_target_img = reswapper_session.get_inputs()[0].name
|
| 308 |
-
# input_name_source_emb = reswapper_session.get_inputs()[1].name
|
| 309 |
-
# input_feed = {
|
| 310 |
-
# input_name_target_img: target_image_tensor,
|
| 311 |
-
# input_name_source_emb: source_embedding_tensor
|
| 312 |
-
# }
|
| 313 |
-
# if len(reswapper_session.get_inputs()) > 2:
|
| 314 |
-
# input_name_target_bbox = reswapper_session.get_inputs()[2].name
|
| 315 |
-
# input_feed[input_name_target_bbox] = target_bbox_tensor
|
| 316 |
-
|
| 317 |
-
# For this illustrative code, I will make up plausible input names.
|
| 318 |
-
# ***** REPLACE "actual_target_input_name", "actual_source_embed_input_name" etc. *****
|
| 319 |
-
# ***** WITH THE REAL TENSOR NAMES FROM reswapper_256.onnx. *****
|
| 320 |
-
# ***** Common names are often like 'input', 'target', 'source', 'embedding'.*****
|
| 321 |
-
# ***** Use Netron app to inspect your reswapper_256.onnx file! *****
|
| 322 |
-
input_feed = {
|
| 323 |
-
# Example, assuming first input is target image, second is source embedding
|
| 324 |
-
reswapper_session.get_inputs()[0].name: target_image_tensor,
|
| 325 |
-
reswapper_session.get_inputs()[1].name: source_embedding_tensor
|
| 326 |
-
}
|
| 327 |
-
# If your model takes a third input like target_bbox (very model specific):
|
| 328 |
-
if len(reswapper_session.get_inputs()) > 2: # Check if there's a third input
|
| 329 |
-
input_feed[reswapper_session.get_inputs()[2].name] = target_bbox_tensor
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
# Placeholder for output tensor name. YOU MUST FIND THE REAL NAME.
|
| 333 |
-
# output_name_placeholder = "output_image"
|
| 334 |
-
# More robust:
|
| 335 |
-
output_name = reswapper_session.get_outputs()[0].name
|
| 336 |
-
########## END OF USER VERIFICATION NEEDED ##########
|
| 337 |
|
| 338 |
progress(0.5, desc="Running ReSwapper model...")
|
| 339 |
onnx_outputs = reswapper_session.run([output_name], input_feed)
|
| 340 |
-
|
| 341 |
|
| 342 |
progress(0.6, desc="Post-processing ReSwapper output...")
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
# 3. Denormalize (inverse of the normalization applied earlier)
|
| 350 |
# If normalized to [0, 1]:
|
| 351 |
swapped_image_denormalized = swapped_image_hwc * 255.0
|
| 352 |
# If normalized to [-1, 1]:
|
| 353 |
# swapped_image_denormalized = (swapped_image_hwc + 1.0) * 127.5
|
| 354 |
|
| 355 |
-
|
| 356 |
|
| 357 |
-
#
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
|
|
|
| 365 |
|
| 366 |
# --- Post-processing pipeline (Restoration, Color Correction, Naturalness) ---
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
roi_x1, roi_y1, roi_x2, roi_y2 = bbox_coords[0], bbox_coords[1], bbox_coords[2], bbox_coords[3]
|
| 371 |
|
| 372 |
face_roi_for_postprocessing = current_processed_image[roi_y1:roi_y2, roi_x1:roi_x2].copy()
|
| 373 |
|
| 374 |
-
if face_roi_for_postprocessing.size
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
if restoration_model_loaded_successfully:
|
| 379 |
-
progress(0.7, desc="Applying face restoration...")
|
| 380 |
-
face_roi_for_postprocessing = enhance_cropped_face(face_roi_for_postprocessing)
|
| 381 |
-
else:
|
| 382 |
-
gr.Info("Face restoration model N/A, enhancement skipped.")
|
| 383 |
|
| 384 |
if apply_color_correction:
|
| 385 |
progress(0.8, desc="Applying color correction...")
|
| 386 |
-
|
| 387 |
-
if
|
| 388 |
-
face_roi_for_postprocessing = histogram_match_color(face_roi_for_postprocessing,
|
| 389 |
|
| 390 |
if apply_naturalness:
|
| 391 |
progress(0.85, desc="Applying naturalness filters...")
|
| 392 |
face_roi_for_postprocessing = apply_naturalness_filters(face_roi_for_postprocessing)
|
| 393 |
|
| 394 |
# Paste the processed ROI back
|
| 395 |
-
if face_roi_for_postprocessing.shape[0] == (roi_y2 - roi_y1) and
|
|
|
|
| 396 |
current_processed_image[roi_y1:roi_y2, roi_x1:roi_x2] = face_roi_for_postprocessing
|
| 397 |
-
else:
|
| 398 |
-
logging.warning("
|
| 399 |
try:
|
| 400 |
-
|
| 401 |
-
current_processed_image[roi_y1:roi_y2, roi_x1:roi_x2] =
|
| 402 |
-
except Exception as
|
| 403 |
-
logging.error(f"Failed to resize
|
| 404 |
-
|
|
|
|
| 405 |
|
| 406 |
-
|
| 407 |
|
| 408 |
except Exception as e:
|
| 409 |
logging.error(f"Error during main processing pipeline: {e}", exc_info=True)
|
| 410 |
-
|
| 411 |
-
img_to_show_on_error = swapped_bgr_img # Use current state which might be partially processed
|
| 412 |
-
swapped_pil_img_on_error = convert_cv2_to_pil(img_to_show_on_error)
|
| 413 |
if swapped_pil_img_on_error is None:
|
| 414 |
-
swapped_pil_img_on_error = Image.new('RGB', (
|
| 415 |
-
raise gr.Error(f"An error occurred: {str(e)}")
|
| 416 |
|
| 417 |
progress(0.9, desc="Finalizing image...")
|
| 418 |
-
swapped_pil_img = convert_cv2_to_pil(
|
| 419 |
if swapped_pil_img is None:
|
| 420 |
-
gr.Error("Failed to convert final image
|
| 421 |
-
swapped_pil_img = Image.new('RGB', (
|
| 422 |
|
| 423 |
temp_file_path = None
|
| 424 |
try:
|
| 425 |
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False, mode="wb") as tmp_file:
|
| 426 |
swapped_pil_img.save(tmp_file, format="JPEG")
|
| 427 |
temp_file_path = tmp_file.name
|
| 428 |
-
logging.info(f"Swapped image saved to temporary file: {temp_file_path}")
|
| 429 |
except Exception as e:
|
| 430 |
-
logging.error(f"Error saving
|
| 431 |
-
gr.Warning("Could not save
|
| 432 |
|
| 433 |
progress(1.0, desc="Processing complete!")
|
| 434 |
return swapped_pil_img, temp_file_path
|
| 435 |
|
| 436 |
-
# --- Gradio Preview Function
|
| 437 |
-
def preview_target_faces(target_pil_img: Image.Image):
|
| 438 |
if target_pil_img is None:
|
| 439 |
-
|
| 440 |
-
return
|
| 441 |
target_np = convert_pil_to_cv2(target_pil_img)
|
| 442 |
if target_np is None:
|
| 443 |
-
|
| 444 |
-
return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
|
| 445 |
faces = get_faces_from_image(target_np)
|
| 446 |
if not faces:
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
num_faces = len(faces)
|
| 456 |
-
slider_update = gr.Slider(minimum=0, maximum=max(0, num_faces - 1), value=0, step=1, interactive=(num_faces > 0))
|
| 457 |
-
return preview_pil_img, slider_update
|
| 458 |
-
|
| 459 |
-
# --- Gradio UI Definition (Largely same, ensure inputs to process_face_swap are correct) ---
|
| 460 |
-
with gr.Blocks(title="ReSwapper AI Face Swap π", theme=gr.themes.Soft()) as demo:
|
| 461 |
gr.Markdown(
|
| 462 |
"""
|
| 463 |
<div style="text-align: center;">
|
| 464 |
-
<h1>π ReSwapper AI Face Swap π </h1>
|
| 465 |
-
<p>
|
| 466 |
</div>
|
| 467 |
"""
|
| 468 |
)
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
gr.Error("CRITICAL ERROR: Core models (Face Analyzer or ReSwapper) failed to load. Application will not function correctly. Please check console logs and restart.")
|
| 472 |
|
| 473 |
with gr.Row():
|
| 474 |
with gr.Column(scale=1):
|
| 475 |
-
source_image_input = gr.Image(label="π€ Source Face Image", type="pil", sources=["upload",
|
| 476 |
with gr.Column(scale=1):
|
| 477 |
-
target_image_input = gr.Image(label="πΌοΈ Target Scene Image", type="pil", sources=["upload",
|
| 478 |
-
|
| 479 |
with gr.Row(equal_height=True):
|
| 480 |
preview_button = gr.Button("π Preview & Select Target Face", variant="secondary")
|
| 481 |
-
face_index_slider = gr.Slider(label="π― Select Target Face
|
| 482 |
-
|
| 483 |
target_faces_preview_output = gr.Image(label="π Detected Faces in Target", interactive=False, height=350)
|
| 484 |
-
gr.HTML("<hr style='margin
|
| 485 |
-
|
| 486 |
with gr.Row():
|
| 487 |
with gr.Column(scale=1):
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
if not restoration_model_loaded_successfully: gr.Markdown("<p style='color: orange; font-size:0.8em;'>β οΈ Restoration model N/A.</p>")
|
| 492 |
-
|
| 493 |
with gr.Column(scale=1):
|
| 494 |
-
color_correction_checkbox = gr.Checkbox(label="π¨
|
| 495 |
-
|
| 496 |
with gr.Column(scale=1):
|
| 497 |
-
naturalness_checkbox = gr.Checkbox(label="πΏ
|
| 498 |
-
|
| 499 |
with gr.Row():
|
| 500 |
swap_button = gr.Button("π GENERATE SWAP!", variant="primary", scale=3, interactive=core_models_loaded_successfully)
|
| 501 |
clear_button = gr.Button("π§Ή Clear All", variant="stop", scale=1)
|
| 502 |
-
|
| 503 |
with gr.Row():
|
| 504 |
swapped_image_output = gr.Image(label="β¨ Swapped Result", interactive=False, height=450)
|
| 505 |
download_output_file = gr.File(label="β¬οΈ Download Swapped Image")
|
| 506 |
|
| 507 |
# Event Handlers
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
target_image_input.change(fn=on_target_image_change_or_clear, inputs=[target_image_input], outputs=[target_faces_preview_output, face_index_slider], queue=False)
|
| 515 |
preview_button.click(fn=preview_target_faces, inputs=[target_image_input], outputs=[target_faces_preview_output, face_index_slider])
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
def clear_all_inputs_outputs():
|
| 528 |
-
blank_preview = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color = 'lightgray')
|
| 529 |
-
return (None, None, blank_preview, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False), None, None)
|
| 530 |
-
|
| 531 |
-
clear_button.click(fn=clear_all_inputs_outputs, inputs=None, outputs=[source_image_input, target_image_input, target_faces_preview_output, face_index_slider, swapped_image_output, download_output_file], queue=False)
|
| 532 |
-
|
| 533 |
gr.Examples(
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
],
|
| 537 |
-
|
| 538 |
-
outputs=[swapped_image_output, download_output_file],
|
| 539 |
-
fn=process_face_swap, cache_examples=False,
|
| 540 |
-
label="Example Face Swaps (Click to run)"
|
| 541 |
)
|
| 542 |
|
| 543 |
# --- Main Execution Block ---
|
| 544 |
if __name__ == "__main__":
|
| 545 |
os.makedirs("models", exist_ok=True)
|
| 546 |
os.makedirs("examples", exist_ok=True)
|
| 547 |
-
|
| 548 |
-
print("\n" + "="*60)
|
| 549 |
-
print("π ReSwapper AI FACE SWAP - STARTUP STATUS π")
|
| 550 |
-
print("="*60)
|
| 551 |
if not core_models_loaded_successfully:
|
| 552 |
-
print(f"π΄ CRITICAL
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
print("π’ Core models (Face Analyzer & ReSwapper) loaded successfully.")
|
| 558 |
-
|
| 559 |
-
if not restoration_model_loaded_successfully:
|
| 560 |
-
print(f"π‘ INFO: Face Restoration model ('{RESTORATION_MODEL_PATH}') not loaded. Enhancement feature disabled.")
|
| 561 |
-
else:
|
| 562 |
-
print("π’ Face Restoration model loaded successfully.")
|
| 563 |
-
print("="*60 + "\n")
|
| 564 |
-
|
| 565 |
-
print("Launching Gradio Interface...")
|
| 566 |
demo.launch()
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import gradio as gr
|
| 5 |
from insightface.app import FaceAnalysis # Still needed for detection and embeddings
|
|
|
|
| 6 |
from PIL import Image
|
| 7 |
import tempfile
|
| 8 |
import logging
|
|
|
|
| 11 |
# --- Configuration & Setup ---
|
| 12 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 13 |
|
| 14 |
+
SWAPPER_MODEL_PATH = "models/reswapper_256.onnx" # Using ReSwapper
|
|
|
|
| 15 |
RESTORATION_MODEL_PATH = "models/gfpgan_1.4.onnx" # Or your chosen restorer, e.g., restoreformer++.onnx
|
|
|
|
| 16 |
|
| 17 |
+
FACE_ANALYZER_NAME = 'buffalo_l'
|
| 18 |
DETECTION_SIZE = (640, 640)
|
| 19 |
+
EXECUTION_PROVIDERS = ['CPUExecutionProvider']
|
| 20 |
|
| 21 |
+
# --- Global Variables ---
|
| 22 |
face_analyzer = None
|
| 23 |
+
reswapper_session = None # For reswapper_256.onnx
|
| 24 |
face_restorer = None
|
|
|
|
| 25 |
|
| 26 |
# --- Initialization Functions ---
|
| 27 |
def initialize_models():
|
| 28 |
+
global face_analyzer, reswapper_session, face_restorer
|
| 29 |
try:
|
|
|
|
| 30 |
if face_analyzer is None:
|
| 31 |
logging.info(f"Initializing FaceAnalysis model: {FACE_ANALYZER_NAME}")
|
| 32 |
face_analyzer = FaceAnalysis(name=FACE_ANALYZER_NAME, providers=EXECUTION_PROVIDERS)
|
| 33 |
face_analyzer.prepare(ctx_id=0, det_size=DETECTION_SIZE)
|
| 34 |
logging.info("FaceAnalysis model initialized successfully.")
|
| 35 |
|
|
|
|
| 36 |
if reswapper_session is None:
|
| 37 |
if not os.path.exists(SWAPPER_MODEL_PATH):
|
| 38 |
logging.error(f"ReSwapper model FILE NOT FOUND at {SWAPPER_MODEL_PATH}. Swapping will fail.")
|
|
|
|
| 43 |
logging.info(f"ReSwapper model ({SWAPPER_MODEL_PATH}) loaded successfully with onnxruntime.")
|
| 44 |
except Exception as e:
|
| 45 |
logging.error(f"Error loading ReSwapper model {SWAPPER_MODEL_PATH} with onnxruntime: {e}", exc_info=True)
|
| 46 |
+
reswapper_session = None
|
| 47 |
+
|
|
|
|
| 48 |
if face_restorer is None:
|
| 49 |
if not os.path.exists(RESTORATION_MODEL_PATH):
|
| 50 |
+
logging.error(f"Face restoration model FILE NOT FOUND at: {RESTORATION_MODEL_PATH}. Enhancement feature will be disabled.")
|
| 51 |
else:
|
|
|
|
| 52 |
logging.info(f"Attempting to load face restoration model from: {RESTORATION_MODEL_PATH}")
|
| 53 |
try:
|
| 54 |
face_restorer = onnxruntime.InferenceSession(RESTORATION_MODEL_PATH, providers=EXECUTION_PROVIDERS)
|
|
|
|
| 56 |
except Exception as e:
|
| 57 |
logging.error(f"Error loading face restoration model from {RESTORATION_MODEL_PATH}: {e}. Enhancement feature will be disabled.", exc_info=True)
|
| 58 |
face_restorer = None
|
|
|
|
|
|
|
|
|
|
| 59 |
except Exception as e:
|
| 60 |
logging.error(f"A critical error occurred during model initialization: {e}", exc_info=True)
|
| 61 |
|
|
|
|
| 62 |
initialize_models()
|
| 63 |
+
core_models_loaded_successfully = face_analyzer is not None and reswapper_session is not None
|
| 64 |
restoration_model_loaded_successfully = face_restorer is not None
|
|
|
|
| 65 |
|
| 66 |
+
# --- Image Utility Functions ---
|
| 67 |
def convert_pil_to_cv2(pil_image: Image.Image) -> np.ndarray | None:
|
| 68 |
if pil_image is None: return None
|
| 69 |
try:
|
|
|
|
| 80 |
logging.error(f"Error converting CV2 to PIL: {e}")
|
| 81 |
return None
|
| 82 |
|
| 83 |
+
# --- Core AI & Image Processing Functions ---
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
def get_faces_from_image(img_np: np.ndarray) -> list:
|
| 85 |
+
if face_analyzer is None:
|
| 86 |
logging.error("Face analyzer not available for get_faces_from_image.")
|
| 87 |
return []
|
| 88 |
if img_np is None: return []
|
| 89 |
try:
|
| 90 |
+
return face_analyzer.get(img_np)
|
| 91 |
except Exception as e:
|
| 92 |
logging.error(f"Error during face detection: {e}", exc_info=True)
|
| 93 |
return []
|
| 94 |
|
| 95 |
+
def draw_detected_faces(img_np: np.ndarray, faces: list) -> np.ndarray:
|
|
|
|
| 96 |
img_with_boxes = img_np.copy()
|
| 97 |
for i, face in enumerate(faces):
|
| 98 |
box = face.bbox.astype(int)
|
| 99 |
x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
|
| 100 |
cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
| 101 |
label_position = (x1, max(0, y1 - 10))
|
| 102 |
+
cv2.putText(img_with_boxes, f"Face {i}", label_position, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (36, 255, 12), 2)
|
|
|
|
| 103 |
return img_with_boxes
|
| 104 |
|
| 105 |
+
def enhance_cropped_face(face_crop_bgr: np.ndarray) -> np.ndarray:
|
| 106 |
if not restoration_model_loaded_successfully or face_restorer is None:
|
|
|
|
| 107 |
return face_crop_bgr
|
| 108 |
+
if face_crop_bgr is None or face_crop_bgr.size == 0: return face_crop_bgr if face_crop_bgr is not None else np.array([])
|
|
|
|
|
|
|
| 109 |
logging.info(f"Applying face restoration to crop of size {face_crop_bgr.shape[:2]}...")
|
| 110 |
+
crop_h, crop_w = face_crop_bgr.shape[:2]
|
| 111 |
try:
|
| 112 |
img_rgb = cv2.cvtColor(face_crop_bgr, cv2.COLOR_BGR2RGB)
|
| 113 |
+
input_size = (512, 512) # Common for GFPGAN-like models
|
| 114 |
+
img_resized = cv2.resize(img_rgb, input_size, interpolation=cv2.INTER_AREA)
|
| 115 |
+
img_norm = (img_resized / 255.0).astype(np.float32)
|
| 116 |
+
img_chw = np.transpose(img_norm, (2,0,1))
|
| 117 |
input_tensor = np.expand_dims(img_chw, axis=0)
|
| 118 |
input_name = face_restorer.get_inputs()[0].name
|
| 119 |
output_name = face_restorer.get_outputs()[0].name
|
| 120 |
+
output = face_restorer.run([output_name], {input_name: input_tensor})[0]
|
| 121 |
+
out_chw = np.squeeze(output, axis=0)
|
| 122 |
+
out_hwc = np.transpose(out_chw, (1,2,0))
|
| 123 |
+
out_denorm = np.clip(out_hwc * 255.0, 0, 255).astype(np.uint8)
|
| 124 |
+
out_resized_rgb = cv2.resize(out_denorm, (crop_w, crop_h), interpolation=cv2.INTER_LANCZOS4)
|
| 125 |
+
return cv2.cvtColor(out_resized_rgb, cv2.COLOR_RGB2BGR)
|
|
|
|
|
|
|
| 126 |
except Exception as e:
|
| 127 |
logging.error(f"Error during face restoration for crop: {e}", exc_info=True)
|
| 128 |
return face_crop_bgr
|
| 129 |
|
| 130 |
+
def histogram_match_channel(source_ch: np.ndarray, target_ch: np.ndarray) -> np.ndarray:
|
| 131 |
+
src_shape = source_ch.shape
|
| 132 |
+
src_flat = source_ch.flatten()
|
| 133 |
+
tgt_flat = target_ch.flatten()
|
| 134 |
+
src_hist, _ = np.histogram(src_flat, 256, [0,255])
|
| 135 |
+
tgt_hist, _ = np.histogram(tgt_flat, 256, [0,255])
|
| 136 |
+
src_cdf = src_hist.cumsum()
|
| 137 |
+
tgt_cdf = tgt_hist.cumsum()
|
| 138 |
+
if src_cdf[-1] == 0: return source_ch # Avoid division by zero for blank source
|
| 139 |
+
if tgt_cdf[-1] == 0: return source_ch # Avoid issues with blank target
|
| 140 |
+
src_cdf_norm = src_cdf * float(tgt_cdf[-1]) / src_cdf[-1] # Normalize src CDF to target range
|
| 141 |
+
lut = np.zeros(256, dtype='uint8')
|
| 142 |
+
j = 0
|
| 143 |
for i in range(256):
|
| 144 |
+
while j < 255 and tgt_cdf[j] < src_cdf_norm[i]:
|
| 145 |
+
j += 1
|
| 146 |
+
lut[i] = j
|
| 147 |
+
return cv2.LUT(source_ch, lut).reshape(src_shape)
|
|
|
|
| 148 |
|
| 149 |
+
def histogram_match_color(source_img: np.ndarray, target_img: np.ndarray) -> np.ndarray:
|
| 150 |
if source_img is None or target_img is None or source_img.size == 0 or target_img.size == 0:
|
|
|
|
| 151 |
return source_img if source_img is not None else np.array([])
|
| 152 |
+
matched = np.zeros_like(source_img)
|
| 153 |
try:
|
| 154 |
for i in range(source_img.shape[2]):
|
| 155 |
+
matched[:,:,i] = histogram_match_channel(source_img[:,:,i], target_img[:,:,i])
|
| 156 |
+
return matched
|
| 157 |
except Exception as e:
|
| 158 |
logging.error(f"Error during histogram matching: {e}", exc_info=True)
|
| 159 |
return source_img
|
| 160 |
|
| 161 |
+
def apply_naturalness_filters(face_region_bgr: np.ndarray, noise_level: int = 3) -> np.ndarray:
|
| 162 |
if face_region_bgr is None or face_region_bgr.size == 0:
|
|
|
|
| 163 |
return face_region_bgr if face_region_bgr is not None else np.array([])
|
| 164 |
+
proc = face_region_bgr.copy()
|
| 165 |
try:
|
| 166 |
+
proc = cv2.medianBlur(proc, 3)
|
| 167 |
if noise_level > 0:
|
| 168 |
+
noise = np.random.normal(0, noise_level, proc.shape).astype(np.int16)
|
| 169 |
+
proc_int16 = proc.astype(np.int16) + noise
|
| 170 |
+
proc = np.clip(proc_int16, 0, 255).astype(np.uint8)
|
| 171 |
+
logging.info("Applied naturalness filters.")
|
| 172 |
except Exception as e:
|
| 173 |
logging.error(f"Error applying naturalness filters: {e}", exc_info=True)
|
| 174 |
return face_region_bgr
|
| 175 |
+
return proc
|
| 176 |
|
| 177 |
# --- Main Processing Function ---
|
| 178 |
def process_face_swap(source_pil_img: Image.Image, target_pil_img: Image.Image,
                      target_face_index: int,
                      apply_enhancement: bool, apply_color_correction: bool,
                      apply_naturalness: bool,
                      progress=gr.Progress(track_tqdm=True)):
    """Swap the source face onto the selected face in the target image.

    Pipeline: validate inputs -> detect faces in both images -> run the
    ReSwapper ONNX model on the resized target with the source embedding ->
    resize the result back -> optionally restore / color-correct / add
    naturalness filters to the face ROI -> save a temp JPEG for download.

    Returns:
        (PIL.Image result, temp-file path or None).

    Raises:
        gr.Error: for user-facing validation or processing failures.
    """
    progress(0, desc="Initializing process...")
    # NOTE(review): gr.Error is constructed but not raised here, so the
    # message may never surface in the UI — verify this is intentional.
    if not core_models_loaded_successfully:
        gr.Error("CRITICAL: Core models (Face Analyzer or ReSwapper) not loaded. Cannot proceed.")
        return Image.new('RGB', (100,100), color='lightgrey'), None

    # NOTE(review): no explicit None-check for source_pil_img is visible here
    # (only the target is validated) — confirm the source check exists upstream.
    if target_pil_img is None: raise gr.Error("Target image not provided.")

    progress(0.05, desc="Converting images...")
    source_np = convert_pil_to_cv2(source_pil_img)
    target_np = convert_pil_to_cv2(target_pil_img)
    if source_np is None or target_np is None:
        raise gr.Error("Image conversion failed.")

    # Remember the target's native size so the 256x256 model output can be
    # resized back at the end.
    original_target_h, original_target_w = target_np.shape[:2]

    progress(0.15, desc="Detecting source face...")
    source_faces = get_faces_from_image(source_np)
    if not source_faces: raise gr.Error("No face in source.")
    source_face = source_faces[0]  # first detected face is used as the identity

    progress(0.25, desc="Detecting target faces...")
    target_faces = get_faces_from_image(target_np)
    if not target_faces: raise gr.Error("No faces in target.")
    if not (0 <= target_face_index < len(target_faces)):
        raise gr.Error(f"Target face index out of range.")
    target_face_info = target_faces[int(target_face_index)]

    swapped_bgr_img_final = target_np.copy()  # Initialize for fallback

    try:
        progress(0.35, desc="Preparing inputs for ReSwapper...")

        # 1. Source face identity embedding -> (1, 512) float32 batch tensor.
        source_embedding = source_face.embedding.astype(np.float32)
        source_embedding_tensor = np.expand_dims(source_embedding, axis=0)  # (1, 512)

        # 2. Target image preparation: BGR -> RGB, resized to the model's
        # expected 256x256 input (the "256" in reswapper_256.onnx).
        target_img_rgb = cv2.cvtColor(target_np, cv2.COLOR_BGR2RGB)

        reswapper_input_size = (256, 256)  # Expected (H, W) by the model

        # cv2.resize takes (width, height), hence the index swap.
        target_img_resized_rgb = cv2.resize(target_img_rgb, (reswapper_input_size[1], reswapper_input_size[0]), interpolation=cv2.INTER_AREA)

        # Normalization to [0, 1] — assumed; if the model expects [-1, 1],
        # use (x / 127.5 - 1.0) instead. TODO confirm against the ONNX model.
        target_img_normalized = (target_img_resized_rgb / 255.0).astype(np.float32)

        target_img_chw = np.transpose(target_img_normalized, (2, 0, 1))  # HWC to CHW
        target_image_tensor = np.expand_dims(target_img_chw, axis=0)  # Add batch dimension

        # 3. Target face bounding box, passed only if the model declares a
        # third input. NOTE(review): the bbox is in ORIGINAL image coordinates;
        # if the model expects coordinates relative to the resized 256x256
        # tensor this needs scaling — verify with the model's spec.
        target_bbox_original = target_face_info.bbox.astype(np.float32)
        target_bbox_tensor_input = np.expand_dims(target_bbox_original, axis=0)

        # Feed inputs positionally by the model's declared input order.
        # NOTE(review): assumed order is (target_image, source_embedding,
        # optional bbox) — inspect the .onnx with Netron to confirm.
        onnx_input_list = reswapper_session.get_inputs()
        onnx_output_list = reswapper_session.get_outputs()

        input_feed = {}
        if len(onnx_input_list) > 0: input_feed[onnx_input_list[0].name] = target_image_tensor
        if len(onnx_input_list) > 1: input_feed[onnx_input_list[1].name] = source_embedding_tensor
        if len(onnx_input_list) > 2:  # If model expects a third input (e.g. bbox)
            input_feed[onnx_input_list[2].name] = target_bbox_tensor_input  # may need scaling/normalization!

        output_name = onnx_output_list[0].name  # Assuming first output is the image

        progress(0.5, desc="Running ReSwapper model...")
        onnx_outputs = reswapper_session.run([output_name], input_feed)
        swapped_image_tensor_output = onnx_outputs[0]

        progress(0.6, desc="Post-processing ReSwapper output...")
        # Drop batch dimension if present, then CHW -> HWC.
        if swapped_image_tensor_output.ndim == 4:
            swapped_image_tensor_output = np.squeeze(swapped_image_tensor_output, axis=0)

        swapped_image_hwc = np.transpose(swapped_image_tensor_output, (1, 2, 0))

        # Denormalize — must be the inverse of the normalization above
        # ([0,1] -> *255; for [-1,1] it would be (x + 1.0) * 127.5).
        swapped_image_denormalized = swapped_image_hwc * 255.0

        swapped_image_uint8_rgb = np.clip(swapped_image_denormalized, 0, 255).astype(np.uint8)

        # Resize the 256x256 model output back to the original target size,
        # then convert to BGR because the post-processing pipeline expects BGR.
        swapped_bgr_resized_to_original = cv2.resize(
            swapped_image_uint8_rgb,
            (original_target_w, original_target_h),
            interpolation=cv2.INTER_LANCZOS4
        )
        swapped_bgr_img_final = cv2.cvtColor(swapped_bgr_resized_to_original, cv2.COLOR_RGB2BGR)

        # --- Post-processing pipeline (Restoration, Color Correction, Naturalness) ---
        current_processed_image = swapped_bgr_img_final.copy()
        # ROI for post-processing uses the ORIGINAL bbox coordinates (valid
        # because the swapped image was resized back to the original size).
        roi_x1, roi_y1, roi_x2, roi_y2 = target_face_info.bbox.astype(int)

        face_roi_for_postprocessing = current_processed_image[roi_y1:roi_y2, roi_x1:roi_x2].copy()

        if face_roi_for_postprocessing.size > 0:
            if apply_enhancement and restoration_model_loaded_successfully:
                progress(0.7, desc="Applying face restoration...")
                face_roi_for_postprocessing = enhance_cropped_face(face_roi_for_postprocessing)

            if apply_color_correction:
                progress(0.8, desc="Applying color correction...")
                # Match colors against the UNSWAPPED target pixels in the ROI.
                original_target_roi_for_color = target_np[roi_y1:roi_y2, roi_x1:roi_x2]
                if original_target_roi_for_color.size > 0:
                    face_roi_for_postprocessing = histogram_match_color(face_roi_for_postprocessing, original_target_roi_for_color.copy())

            if apply_naturalness:
                progress(0.85, desc="Applying naturalness filters...")
                face_roi_for_postprocessing = apply_naturalness_filters(face_roi_for_postprocessing)

            # Paste the processed ROI back, resizing defensively if a
            # post-processing step changed its dimensions.
            if face_roi_for_postprocessing.shape[0] == (roi_y2 - roi_y1) and \
               face_roi_for_postprocessing.shape[1] == (roi_x2 - roi_x1):
                current_processed_image[roi_y1:roi_y2, roi_x1:roi_x2] = face_roi_for_postprocessing
            else:  # If size changed during post-processing (should not happen with current funcs)
                logging.warning("ROI size mismatch after post-processing. Attempting resize for paste.")
                try:
                    resized_roi = cv2.resize(face_roi_for_postprocessing, (roi_x2-roi_x1, roi_y2-roi_y1), interpolation=cv2.INTER_LANCZOS4)
                    current_processed_image[roi_y1:roi_y2, roi_x1:roi_x2] = resized_roi
                except Exception as e_resize_paste:
                    logging.error(f"Failed to resize/paste processed ROI: {e_resize_paste}")
        else:
            logging.warning("ROI for post-processing is empty. Skipping these steps.")

        swapped_bgr_img_final = current_processed_image

    except Exception as e:
        logging.error(f"Error during main processing pipeline: {e}", exc_info=True)
        # NOTE(review): this fallback image is built but then discarded by the
        # raise below — it never reaches the caller; verify intended.
        swapped_pil_img_on_error = convert_cv2_to_pil(swapped_bgr_img_final)  # Use last good state
        if swapped_pil_img_on_error is None:
            swapped_pil_img_on_error = Image.new('RGB', (original_target_w, original_target_h), color='lightgrey')
        raise gr.Error(f"An error occurred: {str(e)}. Check console for details, especially ONNX input/output errors.")

    progress(0.9, desc="Finalizing image...")
    swapped_pil_img = convert_cv2_to_pil(swapped_bgr_img_final)
    if swapped_pil_img is None:
        # Non-fatal: fall back to a blank image so the UI still gets a result.
        gr.Error("Failed to convert final image.")
        swapped_pil_img = Image.new('RGB', (original_target_w, original_target_h), color='lightgrey')

    # Persist a JPEG copy for the download widget; failure here only loses
    # the download link, not the displayed result.
    temp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False, mode="wb") as tmp_file:
            swapped_pil_img.save(tmp_file, format="JPEG")
            temp_file_path = tmp_file.name
    except Exception as e:
        logging.error(f"Error saving temp file: {e}", exc_info=True)
        gr.Warning("Could not save swapped image for download.")

    progress(1.0, desc="Processing complete!")
    return swapped_pil_img, temp_file_path
|
| 363 |
|
| 364 |
+
# --- Gradio Preview Function ---
|
| 365 |
+
def preview_target_faces(target_pil_img: Image.Image):
    """Draw detected faces on the target image and configure the index slider.

    Returns a (preview PIL image, gr.Slider update) pair; the slider is only
    interactive when at least one face was detected.
    """
    def _blank():
        return Image.new('RGB', DETECTION_SIZE, color='lightgray')

    # No image yet: blank preview, fully reset and disabled slider.
    if target_pil_img is None:
        return _blank(), gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)

    frame_bgr = convert_pil_to_cv2(target_pil_img)
    if frame_bgr is None:
        return _blank(), gr.Slider(interactive=False)

    detected = get_faces_from_image(frame_bgr)
    if not detected:
        # Show the plain image (no annotations) when nothing was found.
        return convert_cv2_to_pil(frame_bgr) or _blank(), gr.Slider(interactive=False)

    annotated = draw_detected_faces(frame_bgr, detected)
    annotated_pil = convert_cv2_to_pil(annotated) or _blank()
    slider_update = gr.Slider(minimum=0, maximum=max(0, len(detected) - 1), value=0,
                              step=1, interactive=len(detected) > 0)
    return annotated_pil, slider_update
|
| 379 |
+
|
| 380 |
+
# --- Gradio UI Definition ---
|
| 381 |
+
# Build the Gradio UI: inputs (source/target images), face preview + index
# selection, post-processing toggles, and the swap/clear actions.
with gr.Blocks(title="ReSwapper AI Face Swap π v2", theme=gr.themes.Soft()) as demo:  # Changed title slightly
    gr.Markdown(
        """
        <div style="text-align: center;">
        <h1>π ReSwapper AI Face Swap π v2</h1>
        <p>Utilizing <code>reswapper_256.onnx</code>. Please ensure model inputs/outputs are correctly configured in the code if issues arise.</p>
        </div>
        """
    )
    # NOTE(review): gr.Error here is constructed at build time, not raised —
    # it likely does not display anything; verify intended behavior.
    if not core_models_loaded_successfully:
        gr.Error("CRITICAL ERROR: Core models (Face Analyzer or ReSwapper) failed to load. Check console.")

    # --- Input images ---
    with gr.Row():
        with gr.Column(scale=1):
            source_image_input = gr.Image(label="π€ Source Face Image", type="pil", sources=["upload","clipboard"], height=350)
        with gr.Column(scale=1):
            target_image_input = gr.Image(label="πΌοΈ Target Scene Image", type="pil", sources=["upload","clipboard"], height=350)

    # --- Face preview + target-face selection ---
    with gr.Row(equal_height=True):
        preview_button = gr.Button("π Preview & Select Target Face", variant="secondary")
        face_index_slider = gr.Slider(label="π― Select Target Face", minimum=0, maximum=0, step=1, value=0, interactive=False)
    target_faces_preview_output = gr.Image(label="π Detected Faces in Target", interactive=False, height=350)
    gr.HTML("<hr style='margin:15px 0;'>")

    # --- Post-processing options (restoration disabled when model missing) ---
    with gr.Row():
        with gr.Column(scale=1):
            enh_label = "β¨ Face Restoration" + (" (Model N/A)" if not restoration_model_loaded_successfully else "")
            enhance_checkbox = gr.Checkbox(label=enh_label, value=restoration_model_loaded_successfully, interactive=restoration_model_loaded_successfully)
            if not restoration_model_loaded_successfully: gr.Markdown("<p style='color:orange;font-size:0.8em;'>β οΈ Restoration N/A</p>")
        with gr.Column(scale=1):
            color_correction_checkbox = gr.Checkbox(label="π¨ Color Correction", value=True)
        with gr.Column(scale=1):
            naturalness_checkbox = gr.Checkbox(label="πΏ Naturalness Filters", value=False)

    # --- Actions and outputs ---
    with gr.Row():
        swap_button = gr.Button("π GENERATE SWAP!", variant="primary", scale=3, interactive=core_models_loaded_successfully)
        clear_button = gr.Button("π§Ή Clear All", variant="stop", scale=1)
    with gr.Row():
        swapped_image_output = gr.Image(label="β¨ Swapped Result", interactive=False, height=450)
        download_output_file = gr.File(label="β¬οΈ Download Swapped Image")

    # Event Handlers
    # Clearing the target image resets preview + slider; otherwise the lambda
    # echoes back the components' current values.
    # NOTE(review): reading component .value at event time returns the build-time
    # default, not the live state — verify this behaves as intended.
    target_image_input.change(fn=lambda x: (Image.new('RGB', DETECTION_SIZE, color='lightgray') if x is None else target_faces_preview_output.value,
                                            gr.Slider(interactive=False) if x is None else face_index_slider.value),
                              inputs=[target_image_input],
                              outputs=[target_faces_preview_output, face_index_slider],
                              queue=False)
    preview_button.click(fn=preview_target_faces, inputs=[target_image_input], outputs=[target_faces_preview_output, face_index_slider])
    swap_button.click(fn=process_face_swap,
                      inputs=[source_image_input, target_image_input, face_index_slider,
                              enhance_checkbox, color_correction_checkbox, naturalness_checkbox],
                      outputs=[swapped_image_output, download_output_file])

    def clear_all():
        """Reset all inputs, the preview, the slider and both outputs."""
        return None, None, Image.new('RGB', DETECTION_SIZE, color='lightgray'), gr.Slider(interactive=False), None, None

    clear_button.click(fn=clear_all, inputs=None,
                       outputs=[source_image_input, target_image_input, target_faces_preview_output,
                                face_index_slider, swapped_image_output, download_output_file],
                       queue=False)

    # Example row (positional args: examples, inputs, outputs, fn).
    gr.Examples(
        [["examples/source_face.jpg", "examples/target_group.jpg", 0, True, True, True]],
        [source_image_input, target_image_input, face_index_slider, enhance_checkbox, color_correction_checkbox, naturalness_checkbox],
        [swapped_image_output, download_output_file],
        process_face_swap, cache_examples=False, label="Example"
    )
|
| 442 |
|
| 443 |
# --- Main Execution Block ---
|
| 444 |
if __name__ == "__main__":
    # Ensure expected directories exist before anything tries to read them.
    os.makedirs("models", exist_ok=True)
    os.makedirs("examples", exist_ok=True)
    # Print a startup status banner summarizing which models loaded.
    print("\n" + "="*60 + "\nπ ReSwapper AI FACE SWAP - STARTUP STATUS π\n" + "="*60)
    if not core_models_loaded_successfully:
        print(f"π΄ CRITICAL: Core models FAILED. Analyzer: {'OK' if face_analyzer else 'FAIL'}, ReSwapper: {'OK' if reswapper_session else 'FAIL'}")
    else: print("π’ Core models (Analyzer & ReSwapper) loaded.")
    # Restoration is optional: its absence only disables the enhancement option.
    if not restoration_model_loaded_successfully: print(f"π‘ INFO: Restoration model ('{RESTORATION_MODEL_PATH}') N/A.")
    else: print("π’ Restoration model loaded.")
    print("="*60 + "\nLaunching Gradio Interface...")
    demo.launch()
|