Update app.py
Browse files
app.py
CHANGED
|
@@ -13,16 +13,16 @@ import onnxruntime
|
|
| 13 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 14 |
|
| 15 |
SWAPPER_MODEL_PATH = "models/inswapper_128.onnx"
|
| 16 |
-
|
| 17 |
|
| 18 |
-
FACE_ANALYZER_NAME = 'buffalo_l'
|
| 19 |
-
DETECTION_SIZE = (640, 640)
|
| 20 |
-
EXECUTION_PROVIDERS = ['CPUExecutionProvider']
|
| 21 |
|
| 22 |
-
# --- Global Variables (Lazy Loaded) ---
|
| 23 |
face_analyzer = None
|
| 24 |
swapper = None
|
| 25 |
-
face_restorer = None
|
| 26 |
|
| 27 |
# --- Initialization Functions ---
|
| 28 |
def initialize_models():
|
|
@@ -32,197 +32,298 @@ def initialize_models():
|
|
| 32 |
if face_analyzer is None:
|
| 33 |
logging.info(f"Initializing FaceAnalysis model: {FACE_ANALYZER_NAME}")
|
| 34 |
face_analyzer = FaceAnalysis(name=FACE_ANALYZER_NAME, providers=EXECUTION_PROVIDERS)
|
| 35 |
-
face_analyzer.prepare(ctx_id=0, det_size=DETECTION_SIZE)
|
| 36 |
logging.info("FaceAnalysis model initialized successfully.")
|
| 37 |
|
| 38 |
# Initialize Swapper model
|
| 39 |
if swapper is None:
|
| 40 |
if not os.path.exists(SWAPPER_MODEL_PATH):
|
| 41 |
logging.error(f"Swapper model FILE NOT FOUND at {SWAPPER_MODEL_PATH}. Swapping will fail.")
|
| 42 |
-
# Keep swapper as None, subsequent checks should handle this
|
| 43 |
else:
|
| 44 |
logging.info(f"Loading swapper model from: {SWAPPER_MODEL_PATH}")
|
| 45 |
try:
|
|
|
|
| 46 |
swapper = get_model(SWAPPER_MODEL_PATH, download=False, providers=EXECUTION_PROVIDERS)
|
| 47 |
except TypeError:
|
| 48 |
logging.warning(f"Failed to pass 'providers' to swapper model {SWAPPER_MODEL_PATH}. Retrying without 'providers' argument.")
|
| 49 |
-
swapper = get_model(SWAPPER_MODEL_PATH, download=False)
|
| 50 |
logging.info("Swapper model loaded successfully.")
|
| 51 |
|
| 52 |
# Initialize Face Restoration Model
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
# This log is critical for the user to see
|
| 57 |
-
logging.error(f"Face restoration model FILE NOT FOUND at: {GFPGAN_MODEL_PATH}. Enhancement feature will be disabled.")
|
| 58 |
-
face_restorer = None # Explicitly ensure it's None
|
| 59 |
else:
|
| 60 |
-
logging.info(f"Attempting to load face restoration model from: {
|
| 61 |
try:
|
| 62 |
-
face_restorer = onnxruntime.InferenceSession(
|
| 63 |
logging.info("Face restoration model loaded successfully.")
|
| 64 |
except Exception as e:
|
| 65 |
-
|
| 66 |
-
logging.error(f"Error loading face restoration model from {GFPGAN_MODEL_PATH}: {e}. Enhancement feature will be disabled.", exc_info=True)
|
| 67 |
face_restorer = None # Ensure it's None if loading failed
|
| 68 |
-
|
| 69 |
-
except Exception as e: # Catch other unexpected errors during initialization
|
| 70 |
logging.error(f"A critical error occurred during model initialization: {e}", exc_info=True)
|
| 71 |
-
#
|
| 72 |
-
# The application might be in a partially usable or unusable state.
|
| 73 |
|
| 74 |
# --- Call Initialization Early ---
|
| 75 |
-
# This ensures models are loaded (or attempted to be loaded) before the UI is built,
|
| 76 |
-
# allowing the UI to reflect the status of optional models like the face_restorer.
|
| 77 |
initialize_models()
|
| 78 |
-
#
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
# Optionally raise an error here to stop the script if you deem these absolutely essential for any UI interaction.
|
| 82 |
-
# raise RuntimeError("Core models failed to load. Cannot start application.")
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
# ... (ensure these functions are present here from your previous code) ...
|
| 87 |
-
def convert_pil_to_cv2(pil_image: Image.Image) -> np.ndarray:
|
| 88 |
if pil_image is None: return None
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
-
def convert_cv2_to_pil(cv2_image: np.ndarray) -> Image.Image:
|
| 92 |
if cv2_image is None: return None
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
return []
|
| 99 |
if img_np is None: return []
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
-
def draw_detected_faces(img_np: np.ndarray, faces: list):
|
| 103 |
img_with_boxes = img_np.copy()
|
| 104 |
for i, face in enumerate(faces):
|
| 105 |
-
box = face.bbox.astype(int)
|
| 106 |
x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
|
| 107 |
cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
| 108 |
-
label_position = (x1, max(0, y1 - 10))
|
| 109 |
cv2.putText(img_with_boxes, f"Face {i}", label_position,
|
| 110 |
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (36, 255, 12), 2)
|
| 111 |
return img_with_boxes
|
| 112 |
|
| 113 |
def enhance_cropped_face(face_crop_bgr: np.ndarray) -> np.ndarray:
|
| 114 |
-
if face_restorer is None:
|
| 115 |
logging.warning("Face restorer model not available. Skipping enhancement for crop.")
|
| 116 |
return face_crop_bgr
|
| 117 |
-
if face_crop_bgr.shape[0] == 0 or face_crop_bgr.shape[1] == 0:
|
| 118 |
-
logging.warning("Received empty face crop for enhancement.")
|
| 119 |
-
return face_crop_bgr
|
|
|
|
| 120 |
|
| 121 |
logging.info(f"Applying face restoration to crop of size {face_crop_bgr.shape[:2]}...")
|
| 122 |
crop_height, crop_width = face_crop_bgr.shape[:2]
|
| 123 |
-
|
| 124 |
-
restorer_input_size = (512, 512)
|
| 125 |
-
img_resized_for_model = cv2.resize(img_rgb, restorer_input_size, interpolation=cv2.INTER_AREA)
|
| 126 |
-
img_normalized = (img_resized_for_model / 255.0).astype(np.float32)
|
| 127 |
-
img_chw = np.transpose(img_normalized, (2, 0, 1))
|
| 128 |
-
input_tensor = np.expand_dims(img_chw, axis=0)
|
| 129 |
-
input_name = face_restorer.get_inputs()[0].name
|
| 130 |
-
output_name = face_restorer.get_outputs()[0].name
|
| 131 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
restored_output_model_size = face_restorer.run([output_name], {input_name: input_tensor})[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
except Exception as e:
|
| 134 |
-
logging.error(f"Error during face restoration
|
| 135 |
-
return face_crop_bgr
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
# --- Gradio Interface Functions ---
|
| 145 |
def process_face_swap(source_pil_img: Image.Image, target_pil_img: Image.Image,
|
| 146 |
-
target_face_index: int, apply_enhancement: bool,
|
| 147 |
progress=gr.Progress(track_tqdm=True)):
|
| 148 |
-
progress(0, desc="Initializing...")
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
gr.Error("Core models (
|
| 152 |
-
|
| 153 |
-
# For safety, let's return empty results for Gradio components.
|
| 154 |
-
blank_pil = Image.new('RGB', (100,100), color='grey')
|
| 155 |
-
return blank_pil, None # For output_image and download_output
|
| 156 |
|
| 157 |
if source_pil_img is None: raise gr.Error("Source image not provided.")
|
| 158 |
if target_pil_img is None: raise gr.Error("Target image not provided.")
|
| 159 |
|
| 160 |
-
progress(0.
|
| 161 |
source_np = convert_pil_to_cv2(source_pil_img)
|
| 162 |
target_np = convert_pil_to_cv2(target_pil_img)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
target_h, target_w = target_np.shape[:2]
|
| 164 |
|
| 165 |
-
progress(0.
|
| 166 |
source_faces = get_faces_from_image(source_np)
|
| 167 |
-
if not source_faces: raise gr.Error("No face found in the source image.")
|
| 168 |
-
source_face = source_faces[0]
|
| 169 |
|
| 170 |
-
progress(0.
|
| 171 |
target_faces = get_faces_from_image(target_np)
|
| 172 |
if not target_faces: raise gr.Error("No faces found in the target image.")
|
| 173 |
if not (0 <= target_face_index < len(target_faces)):
|
| 174 |
-
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
-
swapped_bgr_img = target_np #
|
| 178 |
try:
|
| 179 |
-
progress(0.
|
|
|
|
| 180 |
swapped_bgr_img = swapper.get(target_np, target_face_to_swap_info, source_face, paste_back=True)
|
| 181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
if apply_enhancement:
|
| 183 |
-
if
|
| 184 |
-
progress(0.
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
if
|
| 193 |
-
face_crop_to_enhance = swapped_bgr_img[
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
swapped_bgr_img[
|
| 197 |
else:
|
| 198 |
-
logging.warning("
|
| 199 |
else:
|
| 200 |
logging.warning("Skipping enhancement, invalid crop dimensions after padding.")
|
| 201 |
else:
|
| 202 |
-
|
| 203 |
-
gr.Info("Face restoration model not available, enhancement skipped by process_face_swap.")
|
| 204 |
logging.warning("Enhancement requested at runtime but face restorer is not available.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
except Exception as e:
|
| 207 |
-
logging.error(f"Error during face swapping or
|
| 208 |
-
# Return the
|
| 209 |
-
|
| 210 |
-
swapped_pil_img_on_error
|
|
|
|
| 211 |
raise gr.Error(f"An error occurred: {str(e)}")
|
| 212 |
-
# return swapped_pil_img_on_error, None # Alternative
|
| 213 |
|
| 214 |
progress(0.9, desc="Finalizing image...")
|
| 215 |
swapped_pil_img = convert_cv2_to_pil(swapped_bgr_img)
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
temp_file_path = None
|
| 218 |
try:
|
| 219 |
# Using 'with' ensures the file is closed before Gradio tries to use it
|
| 220 |
-
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
|
| 221 |
-
swapped_pil_img.save(tmp_file
|
| 222 |
temp_file_path = tmp_file.name
|
| 223 |
logging.info(f"Swapped image saved to temporary file: {temp_file_path}")
|
| 224 |
except Exception as e:
|
| 225 |
-
logging.error(f"Error saving to temporary file: {e}")
|
| 226 |
gr.Warning("Could not save the swapped image for download.")
|
| 227 |
|
| 228 |
progress(1.0, desc="Processing complete!")
|
|
@@ -234,87 +335,104 @@ def preview_target_faces(target_pil_img: Image.Image):
|
|
| 234 |
return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
|
| 235 |
|
| 236 |
target_np = convert_pil_to_cv2(target_pil_img)
|
| 237 |
-
if target_np is None:
|
| 238 |
blank_image_pil = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
|
| 239 |
return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
|
| 240 |
|
| 241 |
faces = get_faces_from_image(target_np)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
preview_np_img = draw_detected_faces(target_np, faces)
|
| 243 |
preview_pil_img = convert_cv2_to_pil(preview_np_img)
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
num_faces = len(faces)
|
| 246 |
-
slider_update = gr.Slider(minimum=0, maximum=max(0, num_faces - 1), value=0, step=1, interactive=num_faces > 0)
|
| 247 |
return preview_pil_img, slider_update
|
| 248 |
|
| 249 |
|
| 250 |
# --- Gradio UI Definition ---
|
| 251 |
-
with gr.Blocks(title="Ultimate Face Swap AI π
|
| 252 |
gr.Markdown(
|
| 253 |
"""
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
|
|
|
|
|
|
| 257 |
"""
|
| 258 |
)
|
| 259 |
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
core_models_loaded = face_analyzer is not None and swapper is not None
|
| 263 |
-
|
| 264 |
-
if not core_models_loaded:
|
| 265 |
-
gr.Error("CRITICAL ERROR: Core models (Face Analyzer or Swapper) failed to load. Application will not function correctly. Please check the console logs for details (e.g., model file paths).")
|
| 266 |
|
| 267 |
with gr.Row():
|
| 268 |
with gr.Column(scale=1):
|
| 269 |
-
source_image_input = gr.Image(label="π€ Source Face Image", type="pil", sources=["upload", "clipboard"], height=
|
| 270 |
with gr.Column(scale=1):
|
| 271 |
-
target_image_input = gr.Image(label="πΌοΈ Target Scene Image", type="pil", sources=["upload", "clipboard"], height=
|
| 272 |
|
| 273 |
with gr.Row(equal_height=True):
|
| 274 |
-
preview_button = gr.Button("π Preview
|
| 275 |
face_index_slider = gr.Slider(
|
| 276 |
label="π― Select Target Face (0-indexed)",
|
| 277 |
minimum=0, maximum=0, step=1, value=0, interactive=False
|
| 278 |
)
|
| 279 |
|
| 280 |
-
target_faces_preview_output = gr.Image(label="π Detected Faces in Target", interactive=False, height=
|
| 281 |
|
| 282 |
-
gr.HTML("<hr style='margin-top:
|
| 283 |
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
|
| 288 |
-
enhance_checkbox = gr.Checkbox(
|
| 289 |
-
label=enhance_checkbox_label,
|
| 290 |
-
value=restoration_model_loaded, # Default to True only if available and loaded
|
| 291 |
-
interactive=restoration_model_loaded # Disable if not available/loaded
|
| 292 |
-
)
|
| 293 |
-
if not restoration_model_loaded:
|
| 294 |
-
gr.Markdown("<p style='color: orange; font-size:0.9em;'>β οΈ Face restoration model could not be loaded. Enhancement feature is disabled. Please check console logs for model path errors (e.g., `models/gfpgan_1.4.onnx`).</p>")
|
| 295 |
|
| 296 |
with gr.Row():
|
| 297 |
-
swap_button = gr.Button("π
|
| 298 |
clear_button = gr.Button("π§Ή Clear All", variant="stop", scale=1)
|
| 299 |
|
|
|
|
| 300 |
with gr.Row():
|
| 301 |
-
swapped_image_output = gr.Image(label="β¨ Swapped Result", interactive=False, height=
|
| 302 |
download_output_file = gr.File(label="β¬οΈ Download Swapped Image")
|
| 303 |
|
| 304 |
-
# Event Handlers
|
| 305 |
-
def on_target_image_change_or_clear(
|
| 306 |
-
if
|
| 307 |
blank_image_pil = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color = 'lightgray')
|
| 308 |
return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
|
| 309 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
target_image_input.change(
|
| 312 |
fn=on_target_image_change_or_clear,
|
| 313 |
inputs=[target_image_input],
|
| 314 |
outputs=[target_faces_preview_output, face_index_slider],
|
|
|
|
| 315 |
)
|
| 316 |
-
|
| 317 |
-
|
| 318 |
preview_button.click(
|
| 319 |
fn=preview_target_faces,
|
| 320 |
inputs=[target_image_input],
|
|
@@ -323,59 +441,72 @@ with gr.Blocks(title="Ultimate Face Swap AI π v2", theme=gr.themes.Glass()) a
|
|
| 323 |
|
| 324 |
swap_button.click(
|
| 325 |
fn=process_face_swap,
|
| 326 |
-
inputs=[source_image_input, target_image_input, face_index_slider, enhance_checkbox],
|
| 327 |
outputs=[swapped_image_output, download_output_file]
|
| 328 |
)
|
| 329 |
|
| 330 |
def clear_all_inputs_outputs():
|
| 331 |
blank_preview = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color = 'lightgray')
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
|
| 337 |
clear_button.click(
|
| 338 |
fn=clear_all_inputs_outputs,
|
| 339 |
-
inputs=None,
|
| 340 |
-
outputs=[
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
|
|
|
|
|
|
| 344 |
)
|
| 345 |
|
| 346 |
gr.Examples(
|
| 347 |
examples=[
|
| 348 |
-
["examples/source_face.jpg", "examples/target_group.jpg", 0, True],
|
| 349 |
-
["examples/source_actor.png", "examples/target_scene.png", 1, True],
|
| 350 |
-
["examples/source_face.jpg", "examples/target_group.jpg", 0, False],
|
|
|
|
| 351 |
],
|
| 352 |
-
inputs=[source_image_input, target_image_input, face_index_slider, enhance_checkbox],
|
| 353 |
outputs=[swapped_image_output, download_output_file],
|
| 354 |
-
fn=process_face_swap,
|
| 355 |
-
cache_examples="lazy"
|
| 356 |
-
label="Example Face Swaps (Click to run
|
| 357 |
)
|
| 358 |
|
| 359 |
-
|
| 360 |
if __name__ == "__main__":
|
| 361 |
os.makedirs("models", exist_ok=True)
|
| 362 |
-
os.makedirs("examples", exist_ok=True)
|
| 363 |
-
|
| 364 |
-
#
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
print("
|
| 370 |
-
print(f"
|
| 371 |
-
print("
|
| 372 |
-
print("
|
| 373 |
-
|
| 374 |
-
print("\n" + "="*50)
|
| 375 |
-
print("INFO: Face Restoration model not loaded.")
|
| 376 |
-
print(f"The enhancement feature will be disabled. Check path: '{GFPGAN_MODEL_PATH}'")
|
| 377 |
-
print("="*50 + "\n")
|
| 378 |
else:
|
| 379 |
-
print("
|
| 380 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
demo.launch()
|
|
|
|
| 13 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 14 |
|
| 15 |
SWAPPER_MODEL_PATH = "models/inswapper_128.onnx"
|
| 16 |
+
RESTORATION_MODEL_PATH = "models/gfpgan_1.4.onnx" # Path to your GFPGAN or chosen restoration ONNX model
|
| 17 |
|
| 18 |
+
FACE_ANALYZER_NAME = 'buffalo_l' # InsightFace detection model
|
| 19 |
+
DETECTION_SIZE = (640, 640) # Input size for face detection
|
| 20 |
+
EXECUTION_PROVIDERS = ['CPUExecutionProvider'] # Or ['CUDAExecutionProvider', 'CPUExecutionProvider'] for GPU
|
| 21 |
|
| 22 |
+
# --- Global Variables (Lazy Loaded by initialize_models) ---
|
| 23 |
face_analyzer = None
|
| 24 |
swapper = None
|
| 25 |
+
face_restorer = None
|
| 26 |
|
| 27 |
# --- Initialization Functions ---
|
| 28 |
def initialize_models():
|
|
|
|
| 32 |
if face_analyzer is None:
|
| 33 |
logging.info(f"Initializing FaceAnalysis model: {FACE_ANALYZER_NAME}")
|
| 34 |
face_analyzer = FaceAnalysis(name=FACE_ANALYZER_NAME, providers=EXECUTION_PROVIDERS)
|
| 35 |
+
face_analyzer.prepare(ctx_id=0, det_size=DETECTION_SIZE) # ctx_id=0 for CPU, -1 for auto
|
| 36 |
logging.info("FaceAnalysis model initialized successfully.")
|
| 37 |
|
| 38 |
# Initialize Swapper model
|
| 39 |
if swapper is None:
|
| 40 |
if not os.path.exists(SWAPPER_MODEL_PATH):
|
| 41 |
logging.error(f"Swapper model FILE NOT FOUND at {SWAPPER_MODEL_PATH}. Swapping will fail.")
|
|
|
|
| 42 |
else:
|
| 43 |
logging.info(f"Loading swapper model from: {SWAPPER_MODEL_PATH}")
|
| 44 |
try:
|
| 45 |
+
# Pass providers to get_model if supported
|
| 46 |
swapper = get_model(SWAPPER_MODEL_PATH, download=False, providers=EXECUTION_PROVIDERS)
|
| 47 |
except TypeError:
|
| 48 |
logging.warning(f"Failed to pass 'providers' to swapper model {SWAPPER_MODEL_PATH}. Retrying without 'providers' argument.")
|
| 49 |
+
swapper = get_model(SWAPPER_MODEL_PATH, download=False) # Fallback
|
| 50 |
logging.info("Swapper model loaded successfully.")
|
| 51 |
|
| 52 |
# Initialize Face Restoration Model
|
| 53 |
+
if face_restorer is None: # Only attempt to load if not already loaded/failed
|
| 54 |
+
if not os.path.exists(RESTORATION_MODEL_PATH):
|
| 55 |
+
logging.error(f"Face restoration model FILE NOT FOUND at: {RESTORATION_MODEL_PATH}. Enhancement feature will be disabled.")
|
|
|
|
|
|
|
|
|
|
| 56 |
else:
|
| 57 |
+
logging.info(f"Attempting to load face restoration model from: {RESTORATION_MODEL_PATH}")
|
| 58 |
try:
|
| 59 |
+
face_restorer = onnxruntime.InferenceSession(RESTORATION_MODEL_PATH, providers=EXECUTION_PROVIDERS)
|
| 60 |
logging.info("Face restoration model loaded successfully.")
|
| 61 |
except Exception as e:
|
| 62 |
+
logging.error(f"Error loading face restoration model from {RESTORATION_MODEL_PATH}: {e}. Enhancement feature will be disabled.", exc_info=True)
|
|
|
|
| 63 |
face_restorer = None # Ensure it's None if loading failed
|
| 64 |
+
except Exception as e:
|
|
|
|
| 65 |
logging.error(f"A critical error occurred during model initialization: {e}", exc_info=True)
|
| 66 |
+
# face_analyzer, swapper, or face_restorer might be None. Subsequent checks will handle this.
|
|
|
|
| 67 |
|
| 68 |
# --- Call Initialization Early ---
|
|
|
|
|
|
|
| 69 |
initialize_models()
|
| 70 |
+
# Global flags based on model loading status, used for UI conditional rendering
|
| 71 |
+
core_models_loaded_successfully = face_analyzer is not None and swapper is not None
|
| 72 |
+
restoration_model_loaded_successfully = face_restorer is not None
|
|
|
|
|
|
|
| 73 |
|
| 74 |
+
# --- Image Utility Functions ---
|
| 75 |
+
def convert_pil_to_cv2(pil_image: Image.Image) -> np.ndarray | None:
    """Convert a PIL image into an OpenCV BGR ndarray; None in, None out."""
    if pil_image is None:
        return None
    try:
        # Force RGB first so palettes/RGBA collapse to 3 channels, then flip to BGR.
        rgb_array = np.array(pil_image.convert('RGB'))
        return cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)
    except Exception as e:
        logging.error(f"Error converting PIL to CV2: {e}")
        return None
|
| 82 |
|
| 83 |
+
def convert_cv2_to_pil(cv2_image: np.ndarray) -> Image.Image | None:
    """Convert an OpenCV BGR ndarray into a PIL image; None in, None out."""
    if cv2_image is None:
        return None
    try:
        # OpenCV stores BGR; PIL expects RGB, so swap channels before wrapping.
        rgb_array = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB)
        return Image.fromarray(rgb_array)
    except Exception as e:
        logging.error(f"Error converting CV2 to PIL: {e}")
        return None
|
| 90 |
+
|
| 91 |
+
# --- Core AI & Image Processing Functions ---
|
| 92 |
+
def get_faces_from_image(img_np: np.ndarray) -> list:
    """Detect faces in a BGR image; returns an empty list whenever detection is impossible."""
    # Guard: callers should normally have bailed out already if core models failed to load.
    if not core_models_loaded_successfully or face_analyzer is None:
        logging.error("Face analyzer not available for get_faces_from_image.")
        return []
    if img_np is None:
        return []
    try:
        return face_analyzer.get(img_np)
    except Exception as e:
        logging.error(f"Error during face detection: {e}", exc_info=True)
        return []
|
| 103 |
|
| 104 |
+
def draw_detected_faces(img_np: np.ndarray, faces: list) -> np.ndarray:
    """Return a copy of the image with a green box and an index label per detected face."""
    annotated = img_np.copy()
    for idx, detected in enumerate(faces):
        # InsightFace bounding boxes are float [x1, y1, x2, y2]; cast for pixel coords.
        x1, y1, x2, y2 = detected.bbox.astype(int)
        cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Clamp the label anchor so it stays inside the frame at the top edge.
        anchor = (x1, max(0, y1 - 10))
        cv2.putText(annotated, f"Face {idx}", anchor,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (36, 255, 12), 2)
    return annotated
|
| 114 |
|
| 115 |
def enhance_cropped_face(face_crop_bgr: np.ndarray) -> np.ndarray:
    """Run the ONNX face-restoration model on a single BGR face crop.

    The crop is resized to the model's input resolution, inferred, then
    resized back to its original dimensions. Whenever the restorer is
    unavailable or anything fails, the crop is returned unmodified.
    """
    if not restoration_model_loaded_successfully or face_restorer is None:
        logging.warning("Face restorer model not available. Skipping enhancement for crop.")
        return face_crop_bgr
    if face_crop_bgr is None or face_crop_bgr.shape[0] == 0 or face_crop_bgr.shape[1] == 0:
        logging.warning("Received empty or invalid face crop for enhancement.")
        return face_crop_bgr if face_crop_bgr is not None else np.array([])

    logging.info(f"Applying face restoration to crop of size {face_crop_bgr.shape[:2]}...")
    crop_height, crop_width = face_crop_bgr.shape[:2]
    try:
        # Preprocess: BGR -> RGB, model resolution, [0, 1] floats, NCHW batch of 1.
        rgb_crop = cv2.cvtColor(face_crop_bgr, cv2.COLOR_BGR2RGB)
        # GFPGAN-style restorers commonly take 512x512 — confirm against the loaded model.
        model_resolution = (512, 512)
        resized_rgb = cv2.resize(rgb_crop, model_resolution, interpolation=cv2.INTER_AREA)
        batch = np.expand_dims(
            np.transpose((resized_rgb / 255.0).astype(np.float32), (2, 0, 1)), axis=0)

        input_name = face_restorer.get_inputs()[0].name
        output_name = face_restorer.get_outputs()[0].name
        model_output = face_restorer.run([output_name], {input_name: batch})[0]

        # Postprocess: drop batch dim, CHW -> HWC, back to uint8 pixel values.
        hwc_output = np.transpose(np.squeeze(model_output, axis=0), (1, 2, 0))
        pixels = np.clip(hwc_output * 255.0, 0, 255).astype(np.uint8)

        # Restore the crop's original dimensions and channel order.
        restored_rgb = cv2.resize(pixels, (crop_width, crop_height),
                                  interpolation=cv2.INTER_LANCZOS4)
        restored_bgr = cv2.cvtColor(restored_rgb, cv2.COLOR_RGB2BGR)
        logging.info("Cropped face restoration complete.")
        return restored_bgr
    except Exception as e:
        logging.error(f"Error during face restoration for crop: {e}", exc_info=True)
        return face_crop_bgr  # Return original crop on error
|
| 156 |
+
|
| 157 |
+
def histogram_match_channel(source_channel: np.ndarray, target_channel: np.ndarray) -> np.ndarray:
    """Match the histogram of a single uint8 source channel to a target channel.

    Classic histogram specification: both empirical CDFs are normalized to
    [0, 1] (so images of different sizes stay comparable), and each source
    intensity g is mapped to the smallest target intensity whose CDF is at
    least CDF_source(g).

    Args:
        source_channel: uint8 channel to be remapped.
        target_channel: uint8 channel providing the reference distribution.

    Returns:
        A uint8 array with source_channel's shape and a target-like histogram.
    """
    source_shape = source_channel.shape
    # Degenerate inputs: nothing to match against, return the source unchanged.
    if source_channel.size == 0 or target_channel.size == 0:
        return source_channel.copy()

    # Empirical CDFs over the full 8-bit range, each normalized by its own
    # pixel count. (The previous hist.max()-ratio scaling compared the two
    # CDFs on incompatible scales and divided by zero on a uniform channel.)
    source_hist, _ = np.histogram(source_channel.ravel(), 256, [0, 256])
    target_hist, _ = np.histogram(target_channel.ravel(), 256, [0, 256])
    source_cdf = source_hist.cumsum() / float(source_channel.size)
    target_cdf = target_hist.cumsum() / float(target_channel.size)

    # lookup_table[g] = smallest intensity whose target CDF >= source CDF(g).
    lookup_table = np.searchsorted(target_cdf, source_cdf, side='left')
    lookup_table = np.clip(lookup_table, 0, 255).astype(np.uint8)

    # Apply the per-intensity lookup; NumPy fancy indexing is equivalent to
    # cv2.LUT for uint8 data and drops the OpenCV dependency here.
    return lookup_table[source_channel.astype(np.uint8)].reshape(source_shape)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def histogram_match_color(source_img: np.ndarray, target_img: np.ndarray) -> np.ndarray:
    """Histogram-match a BGR source image to a target image, channel by channel."""
    invalid = (source_img is None or target_img is None
               or source_img.size == 0 or target_img.size == 0)
    if invalid:
        logging.warning("Cannot perform histogram matching on empty or None images.")
        return source_img if source_img is not None else np.array([])

    matched_img = np.zeros_like(source_img)
    try:
        # Match each of the B, G, R planes independently.
        for channel in range(source_img.shape[2]):
            matched_img[:, :, channel] = histogram_match_channel(
                source_img[:, :, channel], target_img[:, :, channel])
        return matched_img
    except Exception as e:
        logging.error(f"Error during histogram matching: {e}", exc_info=True)
        return source_img  # Return original on error
|
| 205 |
+
|
| 206 |
|
| 207 |
# --- Gradio Interface Functions ---
|
| 208 |
def process_face_swap(source_pil_img: Image.Image, target_pil_img: Image.Image,
|
| 209 |
+
target_face_index: int, apply_enhancement: bool, apply_color_correction: bool,
|
| 210 |
progress=gr.Progress(track_tqdm=True)):
|
| 211 |
+
progress(0, desc="Initializing process...")
|
| 212 |
+
|
| 213 |
+
if not core_models_loaded_successfully:
|
| 214 |
+
gr.Error("CRITICAL: Core models (Face Analyzer or Swapper) not loaded. Cannot proceed.")
|
| 215 |
+
return Image.new('RGB', (100,100), color='lightgrey'), None # Placeholder for output
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
if source_pil_img is None: raise gr.Error("Source image not provided.")
|
| 218 |
if target_pil_img is None: raise gr.Error("Target image not provided.")
|
| 219 |
|
| 220 |
+
progress(0.05, desc="Converting images...")
|
| 221 |
source_np = convert_pil_to_cv2(source_pil_img)
|
| 222 |
target_np = convert_pil_to_cv2(target_pil_img)
|
| 223 |
+
|
| 224 |
+
if source_np is None or target_np is None:
|
| 225 |
+
raise gr.Error("Image conversion failed. Please try different images.")
|
| 226 |
+
|
| 227 |
target_h, target_w = target_np.shape[:2]
|
| 228 |
|
| 229 |
+
progress(0.15, desc="Detecting face in source image...")
|
| 230 |
source_faces = get_faces_from_image(source_np)
|
| 231 |
+
if not source_faces: raise gr.Error("No face found in the source image. Please use a clear image of a single face.")
|
| 232 |
+
source_face = source_faces[0] # Assuming the first detected face is the one to use
|
| 233 |
|
| 234 |
+
progress(0.25, desc="Detecting faces in target image...")
|
| 235 |
target_faces = get_faces_from_image(target_np)
|
| 236 |
if not target_faces: raise gr.Error("No faces found in the target image.")
|
| 237 |
if not (0 <= target_face_index < len(target_faces)):
|
| 238 |
+
# This case should ideally be prevented by the slider's dynamic range
|
| 239 |
+
raise gr.Error(f"Selected target face index ({target_face_index}) is out of range. "
|
| 240 |
+
f"Detected {len(target_faces)} faces (indices 0 to {len(target_faces)-1}).")
|
| 241 |
+
|
| 242 |
+
target_face_to_swap_info = target_faces[int(target_face_index)] # Face object from InsightFace
|
| 243 |
|
| 244 |
+
swapped_bgr_img = target_np.copy() # Start with a copy of target
|
| 245 |
try:
|
| 246 |
+
progress(0.4, desc="Performing face swap...")
|
| 247 |
+
# swapper.get returns the modified target_np with the face pasted back
|
| 248 |
swapped_bgr_img = swapper.get(target_np, target_face_to_swap_info, source_face, paste_back=True)
|
| 249 |
|
| 250 |
+
# Define bounding box for post-processing (enhancement and color correction)
|
| 251 |
+
# Use the bbox of the target face where the swap occurred.
|
| 252 |
+
# InsightFace bbox is [x1, y1, x2, y2]
|
| 253 |
+
bbox_coords = target_face_to_swap_info.bbox.astype(int)
|
| 254 |
+
|
| 255 |
if apply_enhancement:
|
| 256 |
+
if restoration_model_loaded_successfully:
|
| 257 |
+
progress(0.6, desc="Applying selective face enhancement...")
|
| 258 |
+
padding_enh = 20 # Pixels of padding around the bbox for enhancement context
|
| 259 |
+
|
| 260 |
+
enh_x1 = max(0, bbox_coords[0] - padding_enh)
|
| 261 |
+
enh_y1 = max(0, bbox_coords[1] - padding_enh)
|
| 262 |
+
enh_x2 = min(target_w, bbox_coords[2] + padding_enh)
|
| 263 |
+
enh_y2 = min(target_h, bbox_coords[3] + padding_enh)
|
| 264 |
+
|
| 265 |
+
if enh_x1 < enh_x2 and enh_y1 < enh_y2: # Check if the crop is valid
|
| 266 |
+
face_crop_to_enhance = swapped_bgr_img[enh_y1:enh_y2, enh_x1:enh_x2]
|
| 267 |
+
enhanced_crop = enhance_cropped_face(face_crop_to_enhance.copy()) # Use .copy()
|
| 268 |
+
if enhanced_crop.shape == face_crop_to_enhance.shape: # Ensure enhanced crop is valid and same size
|
| 269 |
+
swapped_bgr_img[enh_y1:enh_y2, enh_x1:enh_x2] = enhanced_crop
|
| 270 |
else:
|
| 271 |
+
logging.warning("Enhanced crop size mismatch. Skipping paste-back for enhancement.")
|
| 272 |
else:
|
| 273 |
logging.warning("Skipping enhancement, invalid crop dimensions after padding.")
|
| 274 |
else:
|
| 275 |
+
gr.Info("Face restoration model not available, enhancement skipped by process.")
|
|
|
|
| 276 |
logging.warning("Enhancement requested at runtime but face restorer is not available.")
|
| 277 |
+
|
| 278 |
+
if apply_color_correction:
|
| 279 |
+
progress(0.75, desc="Applying color correction...")
|
| 280 |
+
# For color correction, we use the original target face region from `target_np` as the reference.
|
| 281 |
+
# The region to correct is the same bounding box in `swapped_bgr_img`.
|
| 282 |
+
cc_x1 = bbox_coords[0]
|
| 283 |
+
cc_y1 = bbox_coords[1]
|
| 284 |
+
cc_x2 = bbox_coords[2]
|
| 285 |
+
cc_y2 = bbox_coords[3]
|
| 286 |
+
|
| 287 |
+
if cc_x1 < cc_x2 and cc_y1 < cc_y2: # Valid bbox
|
| 288 |
+
target_face_region_for_color = target_np[cc_y1:cc_y2, cc_x1:cc_x2]
|
| 289 |
+
swapped_face_region_to_correct = swapped_bgr_img[cc_y1:cc_y2, cc_x1:cc_x2]
|
| 290 |
+
|
| 291 |
+
if target_face_region_for_color.size > 0 and swapped_face_region_to_correct.size > 0:
|
| 292 |
+
corrected_swapped_region = histogram_match_color(swapped_face_region_to_correct.copy(), target_face_region_for_color.copy())
|
| 293 |
+
if corrected_swapped_region.shape == swapped_face_region_to_correct.shape:
|
| 294 |
+
swapped_bgr_img[cc_y1:cc_y2, cc_x1:cc_x2] = corrected_swapped_region
|
| 295 |
+
else:
|
| 296 |
+
logging.warning("Color corrected region size mismatch. Skipping paste-back for color correction.")
|
| 297 |
+
else:
|
| 298 |
+
logging.warning("Skipping color correction, empty face region(s) for matching.")
|
| 299 |
+
else:
|
| 300 |
+
logging.warning("Skipping color correction, invalid bounding box for region extraction.")
|
| 301 |
+
|
| 302 |
|
| 303 |
except Exception as e:
|
| 304 |
+
logging.error(f"Error during face swapping or post-processing: {e}", exc_info=True)
|
| 305 |
+
# Return the current state of swapped_bgr_img to avoid losing partial work if possible
|
| 306 |
+
swapped_pil_img_on_error = convert_cv2_to_pil(swapped_bgr_img)
|
| 307 |
+
if swapped_pil_img_on_error is None: # Fallback if conversion itself fails
|
| 308 |
+
swapped_pil_img_on_error = Image.new('RGB', (target_w, target_h), color='lightgrey')
|
| 309 |
raise gr.Error(f"An error occurred: {str(e)}")
|
| 310 |
+
# return swapped_pil_img_on_error, None # Alternative for UI stability
|
| 311 |
|
| 312 |
progress(0.9, desc="Finalizing image...")
|
| 313 |
swapped_pil_img = convert_cv2_to_pil(swapped_bgr_img)
|
| 314 |
+
if swapped_pil_img is None: # Handle case where final conversion fails
|
| 315 |
+
gr.Error("Failed to convert final image to display format.")
|
| 316 |
+
swapped_pil_img = Image.new('RGB', (target_w, target_h), color='lightgrey') # Placeholder
|
| 317 |
|
| 318 |
temp_file_path = None
|
| 319 |
try:
|
| 320 |
# Using 'with' ensures the file is closed before Gradio tries to use it
|
| 321 |
+
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False, mode="wb") as tmp_file:
|
| 322 |
+
swapped_pil_img.save(tmp_file, format="JPEG") # Specify format for PIL save
|
| 323 |
temp_file_path = tmp_file.name
|
| 324 |
logging.info(f"Swapped image saved to temporary file: {temp_file_path}")
|
| 325 |
except Exception as e:
|
| 326 |
+
logging.error(f"Error saving to temporary file: {e}", exc_info=True)
|
| 327 |
gr.Warning("Could not save the swapped image for download.")
|
| 328 |
|
| 329 |
progress(1.0, desc="Processing complete!")
|
|
|
|
| 335 |
return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
|
| 336 |
|
| 337 |
target_np = convert_pil_to_cv2(target_pil_img)
|
| 338 |
+
if target_np is None: # If conversion failed
|
| 339 |
blank_image_pil = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
|
| 340 |
return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
|
| 341 |
|
| 342 |
faces = get_faces_from_image(target_np)
|
| 343 |
+
if not faces: # If no faces are detected
|
| 344 |
+
preview_pil_img = convert_cv2_to_pil(target_np) # Show original image if no faces
|
| 345 |
+
if preview_pil_img is None: # Fallback
|
| 346 |
+
preview_pil_img = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
|
| 347 |
+
return preview_pil_img, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
|
| 348 |
+
|
| 349 |
preview_np_img = draw_detected_faces(target_np, faces)
|
| 350 |
preview_pil_img = convert_cv2_to_pil(preview_np_img)
|
| 351 |
+
if preview_pil_img is None: # Fallback
|
| 352 |
+
preview_pil_img = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color='lightgray')
|
| 353 |
+
|
| 354 |
|
| 355 |
num_faces = len(faces)
|
| 356 |
+
slider_update = gr.Slider(minimum=0, maximum=max(0, num_faces - 1), value=0, step=1, interactive=(num_faces > 0))
|
| 357 |
return preview_pil_img, slider_update
|
| 358 |
|
| 359 |
|
| 360 |
# --- Gradio UI Definition ---
|
| 361 |
+
with gr.Blocks(title="Ultimate Face Swap AI π v3", theme=gr.themes.Soft()) as demo: # Changed theme for variety
|
| 362 |
gr.Markdown(
|
| 363 |
"""
|
| 364 |
+
<div style="text-align: center;">
|
| 365 |
+
<h1>π Ultimate Face Swap AI π v3</h1>
|
| 366 |
+
<p>Upload a source face, a target image, and let the AI work its magic!</p>
|
| 367 |
+
<p>Optionally enhance with face restoration and color correction for stunning realism.</p>
|
| 368 |
+
</div>
|
| 369 |
"""
|
| 370 |
)
|
| 371 |
|
| 372 |
+
if not core_models_loaded_successfully:
|
| 373 |
+
gr.Error("CRITICAL ERROR: Core models (Face Analyzer or Swapper) failed to load. Application will not function correctly. Please check the console logs for details (e.g., model file paths) and restart the application after fixing.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
|
| 375 |
with gr.Row():
|
| 376 |
with gr.Column(scale=1):
|
| 377 |
+
source_image_input = gr.Image(label="π€ Source Face Image (Clear, single face)", type="pil", sources=["upload", "clipboard"], height=350)
|
| 378 |
with gr.Column(scale=1):
|
| 379 |
+
target_image_input = gr.Image(label="πΌοΈ Target Scene Image", type="pil", sources=["upload", "clipboard"], height=350)
|
| 380 |
|
| 381 |
with gr.Row(equal_height=True):
|
| 382 |
+
preview_button = gr.Button("π Preview & Select Target Face", variant="secondary")
|
| 383 |
face_index_slider = gr.Slider(
|
| 384 |
label="π― Select Target Face (0-indexed)",
|
| 385 |
minimum=0, maximum=0, step=1, value=0, interactive=False
|
| 386 |
)
|
| 387 |
|
| 388 |
+
target_faces_preview_output = gr.Image(label="π Detected Faces in Target", interactive=False, height=350)
|
| 389 |
|
| 390 |
+
gr.HTML("<hr style='margin-top: 15px; margin-bottom: 15px;'>") # Styled HR
|
| 391 |
|
| 392 |
+
with gr.Row():
|
| 393 |
+
with gr.Column(scale=1):
|
| 394 |
+
enhance_checkbox_label = "β¨ Apply Selective Face Restoration"
|
| 395 |
+
if not restoration_model_loaded_successfully:
|
| 396 |
+
enhance_checkbox_label += " (Model N/A)"
|
| 397 |
+
enhance_checkbox = gr.Checkbox(
|
| 398 |
+
label=enhance_checkbox_label,
|
| 399 |
+
value=restoration_model_loaded_successfully, # Default to True only if available
|
| 400 |
+
interactive=restoration_model_loaded_successfully # Disable if not available
|
| 401 |
+
)
|
| 402 |
+
if not restoration_model_loaded_successfully:
|
| 403 |
+
gr.Markdown("<p style='color: orange; font-size:0.8em;'>β οΈ Face restoration model not loaded. Feature disabled. Check logs.</p>")
|
| 404 |
+
|
| 405 |
+
with gr.Column(scale=1):
|
| 406 |
+
color_correction_checkbox = gr.Checkbox(label="π¨ Apply Color Correction (Histogram Matching)", value=True) # Default to True
|
| 407 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
|
| 409 |
with gr.Row():
|
| 410 |
+
swap_button = gr.Button("π GENERATE SWAP!", variant="primary", scale=3, interactive=core_models_loaded_successfully) # Disable if core models failed
|
| 411 |
clear_button = gr.Button("π§Ή Clear All", variant="stop", scale=1)
|
| 412 |
|
| 413 |
+
|
| 414 |
with gr.Row():
|
| 415 |
+
swapped_image_output = gr.Image(label="β¨ Swapped Result", interactive=False, height=450)
|
| 416 |
download_output_file = gr.File(label="β¬οΈ Download Swapped Image")
|
| 417 |
|
| 418 |
+
# --- Event Handlers ---
|
| 419 |
+
def on_target_image_change_or_clear(target_img_pil):
|
| 420 |
+
if target_img_pil is None: # Cleared
|
| 421 |
blank_image_pil = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color = 'lightgray')
|
| 422 |
return blank_image_pil, gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False)
|
| 423 |
+
# If an image is uploaded, preview_target_faces will be called by the button click.
|
| 424 |
+
# This function is mainly for handling the clear/reset case or an initial state.
|
| 425 |
+
# For safety, return current values if not cleared, or trigger preview.
|
| 426 |
+
# Let's just return blank on clear, and let preview button handle new image.
|
| 427 |
+
return target_faces_preview_output.value, face_index_slider.value # Keep current state on new image upload until preview
|
| 428 |
|
| 429 |
target_image_input.change(
|
| 430 |
fn=on_target_image_change_or_clear,
|
| 431 |
inputs=[target_image_input],
|
| 432 |
outputs=[target_faces_preview_output, face_index_slider],
|
| 433 |
+
queue=False
|
| 434 |
)
|
| 435 |
+
|
|
|
|
| 436 |
preview_button.click(
|
| 437 |
fn=preview_target_faces,
|
| 438 |
inputs=[target_image_input],
|
|
|
|
| 441 |
|
| 442 |
swap_button.click(
|
| 443 |
fn=process_face_swap,
|
| 444 |
+
inputs=[source_image_input, target_image_input, face_index_slider, enhance_checkbox, color_correction_checkbox],
|
| 445 |
outputs=[swapped_image_output, download_output_file]
|
| 446 |
)
|
| 447 |
|
| 448 |
def clear_all_inputs_outputs():
|
| 449 |
blank_preview = Image.new('RGB', (DETECTION_SIZE[0], DETECTION_SIZE[1]), color = 'lightgray')
|
| 450 |
+
# Also reset the slider and preview output properly
|
| 451 |
+
return (
|
| 452 |
+
None, # source_image_input
|
| 453 |
+
None, # target_image_input
|
| 454 |
+
blank_preview, # target_faces_preview_output
|
| 455 |
+
gr.Slider(minimum=0, maximum=0, value=0, step=1, interactive=False), # face_index_slider
|
| 456 |
+
None, # swapped_image_output
|
| 457 |
+
None # download_output_file
|
| 458 |
+
)
|
| 459 |
|
| 460 |
clear_button.click(
|
| 461 |
fn=clear_all_inputs_outputs,
|
| 462 |
+
inputs=None, # No inputs for this function
|
| 463 |
+
outputs=[
|
| 464 |
+
source_image_input, target_image_input,
|
| 465 |
+
target_faces_preview_output, face_index_slider,
|
| 466 |
+
swapped_image_output, download_output_file
|
| 467 |
+
],
|
| 468 |
+
queue=False # No need to queue a simple clear
|
| 469 |
)
|
| 470 |
|
| 471 |
gr.Examples(
|
| 472 |
examples=[
|
| 473 |
+
["examples/source_face.jpg", "examples/target_group.jpg", 0, True, True],
|
| 474 |
+
["examples/source_actor.png", "examples/target_scene.png", 1, True, True],
|
| 475 |
+
["examples/source_face.jpg", "examples/target_group.jpg", 0, False, True], # No enhancement, with color correction
|
| 476 |
+
["examples/source_actor.png", "examples/target_scene.png", 0, True, False], # With enhancement, no color correction
|
| 477 |
],
|
| 478 |
+
inputs=[source_image_input, target_image_input, face_index_slider, enhance_checkbox, color_correction_checkbox],
|
| 479 |
outputs=[swapped_image_output, download_output_file],
|
| 480 |
+
fn=process_face_swap, # This will also show progress for examples
|
| 481 |
+
cache_examples=False, # Set to "lazy" or True if processing is slow and examples are static
|
| 482 |
+
label="Example Face Swaps (Click to run)"
|
| 483 |
)
|
| 484 |
|
| 485 |
+
# --- Main Execution Block ---
# Prints a startup status banner summarizing model-load results, then launches
# the Gradio UI. The flags core_models_loaded_successfully and
# restoration_model_loaded_successfully, plus face_analyzer / swapper / demo,
# are module-level names set earlier in this file (outside this chunk).
if __name__ == "__main__":
    # Ensure the directories the app reads from exist before launch.
    os.makedirs("models", exist_ok=True)
    os.makedirs("examples", exist_ok=True)  # Create if you intend to use example images

    # Console messages about model loading status.
    # NOTE(review): the status-glyph characters below look like mojibake of
    # emoji (rocket / red / green / yellow circles) — confirm file encoding.
    print("\n" + "="*60)
    print("π ULTIMATE FACE SWAP AI - STARTUP STATUS π")
    print("="*60)
    if not core_models_loaded_successfully:
        # Core pipeline (analyzer + swapper) is mandatory; without it the UI
        # loads but every swap request will fail.
        print("π΄ CRITICAL ERROR: Core models (Face Analyzer or Swapper) FAILED to load.")
        print(f" - Face Analyzer: '{FACE_ANALYZER_NAME}' (Status: {'Loaded' if face_analyzer else 'Failed'})")
        print(f" - Swapper Model: '{SWAPPER_MODEL_PATH}' (Status: {'Loaded' if swapper else 'Failed'})")
        print(" The application UI will load but will be NON-FUNCTIONAL.")
        print(" Please check model paths and ensure files exist and are not corrupted.")
    else:
        print("π’ Core models (Face Analyzer & Swapper) loaded successfully.")

    # Face restoration is optional: a missing model only disables enhancement.
    if not restoration_model_loaded_successfully:
        print(f"π‘ INFO: Face Restoration model ('{RESTORATION_MODEL_PATH}') not loaded.")
        print(" The 'Face Restoration' feature will be disabled.")
        print(" To enable, ensure the model file exists at the specified path and is valid.")
    else:
        print("π’ Face Restoration model loaded successfully.")
    print("="*60 + "\n")

    print("Launching Gradio Interface...")
    demo.launch()
|