""" Face Re-Aging with ONNX (CPU) Based on Disney's FRAN (Face Re-Aging Network) architecture. Model: face_reaging.onnx from VisoMaster-Fusion. """ import os import time import cv2 import numpy as np import onnxruntime as ort import gradio as gr from PIL import Image from huggingface_hub import hf_hub_download # --------------------------------------------------------------------------- # Model loading # --------------------------------------------------------------------------- MODEL_PATH = "face_reaging.onnx" REPO_ID = "Luminia/Face-ReAging-CPU" def get_model_path(): if os.path.exists(MODEL_PATH): return MODEL_PATH return hf_hub_download(repo_id=REPO_ID, filename=MODEL_PATH) print("Loading ONNX model...") _so = ort.SessionOptions() _so.intra_op_num_threads = os.cpu_count() _so.inter_op_num_threads = os.cpu_count() sess = ort.InferenceSession( get_model_path(), providers=["CPUExecutionProvider"], sess_options=_so, ) print("Model loaded.") # --------------------------------------------------------------------------- # OpenCV DNN face detection (no extra dependencies) # --------------------------------------------------------------------------- # Use OpenCV's built-in Haar cascade as primary, with DNN SSD as fallback _face_cascade = cv2.CascadeClassifier( cv2.data.haarcascades + "haarcascade_frontalface_default.xml" ) # Try to use the more accurate DNN face detector if available _dnn_net = None _dnn_model_path = os.path.join(os.path.dirname(__file__), "face_detection_yunet_2023mar.onnx") YUNET_URL = "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx" def _ensure_yunet(): """Download YuNet face detector if not present.""" global _dnn_model_path if not os.path.exists(_dnn_model_path): print("Downloading YuNet face detector...") try: path = hf_hub_download( repo_id="opencv/opencv_zoo", filename="models/face_detection_yunet/face_detection_yunet_2023mar.onnx", ) _dnn_model_path = path except Exception: import urllib.request urllib.request.urlretrieve(YUNET_URL, _dnn_model_path) print("YuNet downloaded.") return _dnn_model_path def detect_face_box(image_rgb: np.ndarray): """ Detect the largest face bounding box. Returns (x1, y1, x2, y2) in pixel coords or None. """ h, w = image_rgb.shape[:2] # Try YuNet first (more accurate) try: yunet_path = _ensure_yunet() detector = cv2.FaceDetectorYN.create(yunet_path, "", (w, h), 0.5, 0.3, 5000) _, faces = detector.detect(image_rgb) if faces is not None and len(faces) > 0: # Pick largest face by area best_idx = 0 best_area = 0 for i, face in enumerate(faces): fw, fh = face[2], face[3] area = fw * fh if area > best_area: best_area = area best_idx = i f = faces[best_idx] x1, y1 = int(f[0]), int(f[1]) x2, y2 = int(f[0] + f[2]), int(f[1] + f[3]) return (max(x1, 0), max(y1, 0), min(x2, w), min(y2, h)) except Exception as e: print(f"YuNet failed, falling back to Haar: {e}") # Fallback: Haar cascade gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY) faces = _face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60)) if len(faces) == 0: return None # Pick largest best_idx = np.argmax([fw * fh for (_, _, fw, fh) in faces]) x, y, fw, fh = faces[best_idx] return (x, y, x + fw, y + fh) # --------------------------------------------------------------------------- # Face cropping with margin # --------------------------------------------------------------------------- def crop_face_region(image_rgb: np.ndarray, box): """ Crop a square region around the detected face with generous margins (similar to FRAN's approach: forehead gets more margin). Returns: cropped image, (l_x, l_y, r_x, r_y) paste-back coords. """ h, w = image_rgb.shape[:2] x1, y1, x2, y2 = box face_w = x2 - x1 face_h = y2 - y1 # Margins: top is larger (forehead), bottom smaller margin_top = int(face_h * 0.63 * 0.85) margin_bot = int(face_h * 0.37 * 0.85) margin_x = int(face_w * 0.85 / 2) # Adjust top margin to keep square margin_top += 2 * margin_x - margin_top - margin_bot l_y = max(y1 - margin_top, 0) r_y = min(y2 + margin_bot, h) l_x = max(x1 - margin_x, 0) r_x = min(x2 + margin_x, w) cropped = image_rgb[l_y:r_y, l_x:r_x, :] return cropped, (l_x, l_y, r_x, r_y) # --------------------------------------------------------------------------- # Blending mask (soft feathered edges) # --------------------------------------------------------------------------- def create_blend_mask(crop_h, crop_w, feather=0.15): """ Create a soft feathered blending mask to avoid hard edges when pasting the re-aged face back. """ mask = np.ones((crop_h, crop_w), dtype=np.float32) border_y = max(int(crop_h * feather), 1) border_x = max(int(crop_w * feather), 1) for i in range(border_y): alpha = i / border_y mask[i, :] *= alpha mask[crop_h - 1 - i, :] *= alpha for j in range(border_x): alpha = j / border_x mask[:, j] *= alpha mask[:, crop_w - 1 - j] *= alpha return mask[:, :, np.newaxis] # (H, W, 1) # --------------------------------------------------------------------------- # Core inference # --------------------------------------------------------------------------- def reage_face( image_pil: Image.Image, source_age: int, target_age: int, ): """ Re-age the face in the given PIL image. """ t0 = time.time() image_rgb = np.array(image_pil.convert("RGB")) h_orig, w_orig = image_rgb.shape[:2] # Detect face box = detect_face_box(image_rgb) if box is None: raise gr.Error("No face detected in the image. Please upload a clear photo with a visible face.") # Crop face region cropped, (l_x, l_y, r_x, r_y) = crop_face_region(image_rgb, box) crop_h, crop_w = cropped.shape[:2] # Resize to 512x512 for the model cropped_resized = cv2.resize(cropped, (512, 512), interpolation=cv2.INTER_LINEAR) # Normalize to [0, 1] float32, CHW img_tensor = cropped_resized.astype(np.float32) / 255.0 img_tensor = np.transpose(img_tensor, (2, 0, 1)) # (3, 512, 512) # Create age channels src_age_ch = np.full((1, 512, 512), source_age / 100.0, dtype=np.float32) tgt_age_ch = np.full((1, 512, 512), target_age / 100.0, dtype=np.float32) # Stack: (5, 512, 512) -> (1, 5, 512, 512) input_tensor = np.concatenate([img_tensor, src_age_ch, tgt_age_ch], axis=0) input_tensor = input_tensor[np.newaxis, ...] # Run inference delta = sess.run(None, {"input": input_tensor})[0] # (1, 3, 512, 512) # Apply delta to the cropped image aged = img_tensor + delta[0] # (3, 512, 512) aged = np.clip(aged, 0.0, 1.0) # Convert back to HWC uint8 aged_hwc = np.transpose(aged, (1, 2, 0)) # (512, 512, 3) aged_hwc = (aged_hwc * 255).astype(np.uint8) # Resize back to original crop size aged_resized = cv2.resize(aged_hwc, (crop_w, crop_h), interpolation=cv2.INTER_LINEAR) # Blend back into original image result = image_rgb.copy() blend_mask = create_blend_mask(crop_h, crop_w, feather=0.12) region = result[l_y:r_y, l_x:r_x].astype(np.float32) aged_f = aged_resized.astype(np.float32) blended = region * (1 - blend_mask) + aged_f * blend_mask result[l_y:r_y, l_x:r_x] = blended.astype(np.uint8) elapsed = time.time() - t0 info = f"Done in {elapsed:.2f}s | Source age: {source_age} | Target age: {target_age}" return Image.fromarray(result), info # --------------------------------------------------------------------------- # Gradio UI # --------------------------------------------------------------------------- def process(image, source_age, target_age): if image is None: raise gr.Error("Please upload an image.") return reage_face(image, int(source_age), int(target_age)) with gr.Blocks(title="Face Re-Aging (CPU)") as demo: gr.Markdown("# Face Re-Aging (CPU)\nAge or de-age faces using Disney FRAN-style model. Upload a photo, set source & target age.") with gr.Row(): with gr.Column(): input_image = gr.Image(type="pil", label="Input Image") source_age = gr.Slider( minimum=5, maximum=95, value=25, step=1, label="Source Age (current age of the person)", ) target_age = gr.Slider( minimum=5, maximum=95, value=65, step=1, label="Target Age (desired age)", ) run_btn = gr.Button("Re-Age Face", variant="primary") with gr.Column(): output_image = gr.Image(type="pil", label="Re-Aged Result") info_text = gr.Textbox(label="Info", interactive=False) run_btn.click( fn=process, inputs=[input_image, source_age, target_age], outputs=[output_image, info_text], ) gr.Markdown( "**Model:** `face_reaging.onnx` (118 MB) from " "[VisoMaster-Fusion](https://github.com/VisoMasterFusion/VisoMaster-Fusion) | " "Based on [Disney FRAN](https://studios.disneyresearch.com/2022/11/30/production-ready-face-re-aging-for-visual-effects/)" ) if __name__ == "__main__": demo.launch(show_error=True, ssr_mode=False, theme="NoCrypt/miku")