Spaces:
Sleeping
Sleeping
| """ | |
| Face Re-Aging with ONNX (CPU) | |
| Based on Disney's FRAN (Face Re-Aging Network) architecture. | |
| Model: face_reaging.onnx from VisoMaster-Fusion. | |
| """ | |
| import os | |
| import time | |
| import cv2 | |
| import numpy as np | |
| import onnxruntime as ort | |
| import gradio as gr | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| # --------------------------------------------------------------------------- | |
| # Model loading | |
| # --------------------------------------------------------------------------- | |
# File name of the re-aging network, both on local disk and inside the Hub repo.
MODEL_PATH = "face_reaging.onnx"
# Hugging Face Hub repository that hosts the model file.
REPO_ID = "Luminia/Face-ReAging-CPU"


def get_model_path():
    """Return a filesystem path to the re-aging ONNX model.

    A copy named ``MODEL_PATH`` in the current working directory takes
    precedence; when it is absent the file is fetched from ``REPO_ID`` via
    ``hf_hub_download`` and the path returned by that call is used instead.
    """
    local_copy_missing = not os.path.exists(MODEL_PATH)
    if local_copy_missing:
        return hf_hub_download(repo_id=REPO_ID, filename=MODEL_PATH)
    return MODEL_PATH
print("Loading ONNX model...")
# os.cpu_count() returns None when the CPU count cannot be determined, and
# None cannot be assigned to the integer thread-count options. Falling back to
# 0 asks ONNX Runtime to pick its own default.
_num_threads = os.cpu_count() or 0
_so = ort.SessionOptions()
_so.intra_op_num_threads = _num_threads
_so.inter_op_num_threads = _num_threads
# Single shared inference session, created once at import time and reused by
# every request; only the CPU execution provider is enabled.
sess = ort.InferenceSession(
    get_model_path(),
    providers=["CPUExecutionProvider"],
    sess_options=_so,
)
print("Model loaded.")
| # --------------------------------------------------------------------------- | |
| # OpenCV DNN face detection (no extra dependencies) | |
| # --------------------------------------------------------------------------- | |
| # YuNet (cv2.FaceDetectorYN) is tried first; OpenCV's built-in Haar cascade is the fallback | |
# Haar cascade XML that ships with OpenCV; detect_face_box() uses this
# classifier when the YuNet path fails or finds nothing.
_haar_cascade_file = cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
_face_cascade = cv2.CascadeClassifier(_haar_cascade_file)

# Reserved handle for a DNN detector; nothing in this file assigns or reads it.
_dnn_net = None

# Preferred on-disk location of the YuNet weights: next to this script.
# _ensure_yunet() may rebind this to a Hugging Face cache path.
_dnn_model_path = os.path.join(
    os.path.dirname(__file__),
    "face_detection_yunet_2023mar.onnx",
)

# Direct download URL used when the Hugging Face Hub download fails.
YUNET_URL = (
    "https://github.com/opencv/opencv_zoo/raw/main/models/"
    "face_detection_yunet/face_detection_yunet_2023mar.onnx"
)
def _ensure_yunet():
    """Return a path to the YuNet face-detector weights, downloading if needed.

    If the file at ``_dnn_model_path`` already exists it is returned as is.
    Otherwise the model is fetched from the Hugging Face Hub, in which case
    the module-level ``_dnn_model_path`` is rebound to the Hub cache path.
    If the Hub download fails for any reason, the file is downloaded from
    ``YUNET_URL`` to the original ``_dnn_model_path`` location.

    Raises:
        Exception: whatever the fallback download raises when both sources
            fail (callers treat this as "YuNet unavailable").
    """
    global _dnn_model_path
    if os.path.exists(_dnn_model_path):
        return _dnn_model_path

    print("Downloading YuNet face detector...")
    try:
        _dnn_model_path = hf_hub_download(
            repo_id="opencv/opencv_zoo",
            filename="models/face_detection_yunet/face_detection_yunet_2023mar.onnx",
        )
    except Exception as hub_error:
        # Best-effort fallback: the Hub is only one of two sources, so report
        # the failure instead of hiding it and try the direct URL.
        print(f"Hugging Face download failed ({hub_error}); trying direct URL...")
        import urllib.request

        # Download to a temporary name and rename atomically. Writing straight
        # to the final path would leave a truncated file behind on an
        # interrupted transfer, and the os.path.exists() check above would
        # then accept that corrupt file on every later call.
        partial_path = _dnn_model_path + ".part"
        try:
            urllib.request.urlretrieve(YUNET_URL, partial_path)
            os.replace(partial_path, _dnn_model_path)
        finally:
            # Only present here if the download or rename failed.
            if os.path.exists(partial_path):
                os.remove(partial_path)
    print("YuNet downloaded.")
    return _dnn_model_path
def detect_face_box(image_rgb: np.ndarray):
    """Detect the largest face in an RGB image.

    YuNet (``cv2.FaceDetectorYN``) is tried first. If it raises, finds no
    face, or yields a box that is empty after clamping to the image, the
    Haar cascade in ``_face_cascade`` is used instead.

    Args:
        image_rgb: image array in RGB channel order, height x width x 3.

    Returns:
        ``(x1, y1, x2, y2)`` as Python ints in pixel coordinates, or ``None``
        when no face is found by either detector.
    """
    h, w = image_rgb.shape[:2]

    # Try YuNet first (more accurate).
    try:
        yunet_path = _ensure_yunet()
        detector = cv2.FaceDetectorYN.create(yunet_path, "", (w, h), 0.5, 0.3, 5000)
        # OpenCV detectors consume BGR frames (what cv2.imread produces), so
        # swap the channel order before detection.
        image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
        _, faces = detector.detect(image_bgr)
        if faces is not None and len(faces) > 0:
            # Each row starts with x, y, width, height; keep the largest area.
            largest = max(faces, key=lambda face: face[2] * face[3])
            x1 = max(int(largest[0]), 0)
            y1 = max(int(largest[1]), 0)
            x2 = min(int(largest[0] + largest[2]), w)
            y2 = min(int(largest[1] + largest[3]), h)
            # A box that collapses after clamping cannot be cropped; let the
            # Haar fallback have a go rather than returning an empty region.
            if x2 > x1 and y2 > y1:
                return (x1, y1, x2, y2)
    except Exception as e:
        print(f"YuNet failed, falling back to Haar: {e}")

    # Fallback: Haar cascade on the grayscale image.
    gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
    faces = _face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60))
    if len(faces) == 0:
        return None

    # Pick the largest detection and return plain ints, matching the YuNet path.
    x, y, fw, fh = max(faces, key=lambda rect: rect[2] * rect[3])
    return (int(x), int(y), int(x + fw), int(y + fh))
| # --------------------------------------------------------------------------- | |
| # Face cropping with margin | |
| # --------------------------------------------------------------------------- | |
def crop_face_region(image_rgb: np.ndarray, box):
    """Cut out the face plus surrounding context from an image.

    The face box is grown by 85 % of the face width horizontally (split
    evenly left/right) and by the same total amount vertically, then clamped
    to the image bounds.

    Args:
        image_rgb: image array, height x width x channels.
        box: ``(x1, y1, x2, y2)`` face bounding box in pixels.

    Returns:
        ``(cropped, (l_x, l_y, r_x, r_y))`` where ``cropped`` is the slice
        ``image_rgb[l_y:r_y, l_x:r_x, :]`` and the tuple holds the paste-back
        coordinates of that slice.
    """
    img_h, img_w = image_rgb.shape[:2]
    left, top, right, bottom = box

    box_w = right - left
    box_h = bottom - top

    pad_side = int(box_w * 0.85 / 2)
    pad_below = int(box_h * 0.37 * 0.85)
    # The upper padding takes whatever is left of the vertical budget
    # (2 * pad_side) after the lower padding. A separate 63 %-of-height
    # estimate would cancel out algebraically, so it is not computed here.
    pad_above = 2 * pad_side - pad_below

    # Clamp the expanded window to the image.
    l_x = max(left - pad_side, 0)
    r_x = min(right + pad_side, img_w)
    l_y = max(top - pad_above, 0)
    r_y = min(bottom + pad_below, img_h)

    window = image_rgb[l_y:r_y, l_x:r_x, :]
    return window, (l_x, l_y, r_x, r_y)
| # --------------------------------------------------------------------------- | |
| # Blending mask (soft feathered edges) | |
| # --------------------------------------------------------------------------- | |
def _edge_profile(length, feather):
    """Return a 1-D float32 weight profile that ramps up from both ends."""
    # At least one border sample, but never more than the axis holds, so a
    # zero-length axis or feather > 1 cannot index out of range.
    border = min(max(int(length * feather), 1), length)
    profile = np.ones(length, dtype=np.float32)
    if border > 0:
        # Linear ramp 0, 1/border, ..., (border-1)/border.
        ramp = (np.arange(border) / border).astype(np.float32)
        profile[:border] *= ramp
        # Mirror the ramp onto the far end (the last sample gets weight 0).
        profile[length - border:] *= ramp[::-1]
    return profile


def create_blend_mask(crop_h, crop_w, feather=0.15):
    """Create a soft feathered mask for pasting the re-aged crop back.

    The mask is 1.0 in the interior and ramps linearly to 0.0 at each edge
    over ``feather`` times the corresponding dimension (at least one pixel).
    Vertical and horizontal ramps are multiplied, so corners fade on both
    axes.

    Args:
        crop_h: mask height in pixels.
        crop_w: mask width in pixels.
        feather: fraction of each dimension used for the ramp on every side.

    Returns:
        ``float32`` array of shape ``(crop_h, crop_w, 1)``.
    """
    rows = _edge_profile(crop_h, feather)
    cols = _edge_profile(crop_w, feather)
    # The outer product applies the row ramp and the column ramp together.
    mask = rows[:, np.newaxis] * cols[np.newaxis, :]
    return mask[:, :, np.newaxis]  # (H, W, 1)
| # --------------------------------------------------------------------------- | |
| # Core inference | |
| # --------------------------------------------------------------------------- | |
def reage_face(
    image_pil: Image.Image,
    source_age: int,
    target_age: int,
):
    """Re-age the largest detected face in a PIL image.

    Pipeline: detect the face, crop it with margins, resize the crop to
    512x512, run the ONNX session on the image plus two constant age
    channels, add the predicted delta to the input, resize back and blend
    the result into a copy of the original image with a feathered mask.

    Args:
        image_pil: input image; converted to RGB before processing.
        source_age: current age of the person; fed to the model as age / 100.
        target_age: desired age; fed to the model as age / 100.

    Returns:
        ``(result_image, info)`` where ``result_image`` is a PIL image and
        ``info`` is a one-line summary with the elapsed time and both ages.

    Raises:
        gr.Error: if no usable face region is found in the image.
    """
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments.
    t0 = time.perf_counter()
    image_rgb = np.array(image_pil.convert("RGB"))

    # Detect face
    box = detect_face_box(image_rgb)
    if box is None:
        raise gr.Error("No face detected in the image. Please upload a clear photo with a visible face.")

    # Crop face region
    cropped, (l_x, l_y, r_x, r_y) = crop_face_region(image_rgb, box)
    crop_h, crop_w = cropped.shape[:2]
    if crop_h == 0 or crop_w == 0:
        # A degenerate box would otherwise surface as an opaque cv2.resize error.
        raise gr.Error("No face detected in the image. Please upload a clear photo with a visible face.")

    # Resize to 512x512 for the model
    cropped_resized = cv2.resize(cropped, (512, 512), interpolation=cv2.INTER_LINEAR)

    # Normalize to [0, 1] float32, CHW
    img_tensor = cropped_resized.astype(np.float32) / 255.0
    img_tensor = np.transpose(img_tensor, (2, 0, 1))  # (3, 512, 512)

    # Constant-valued age planes, one for the source age and one for the target.
    src_age_ch = np.full((1, 512, 512), source_age / 100.0, dtype=np.float32)
    tgt_age_ch = np.full((1, 512, 512), target_age / 100.0, dtype=np.float32)

    # Stack: (5, 512, 512) -> (1, 5, 512, 512)
    input_tensor = np.concatenate([img_tensor, src_age_ch, tgt_age_ch], axis=0)
    input_tensor = input_tensor[np.newaxis, ...]

    # Run inference; the first output is used as a per-pixel RGB delta,
    # expected shape (1, 3, 512, 512).
    delta = sess.run(None, {"input": input_tensor})[0]

    # Apply delta to the cropped image and keep values in range.
    aged = np.clip(img_tensor + delta[0], 0.0, 1.0)  # (3, 512, 512)

    # Convert back to HWC uint8. Round rather than truncate so the conversion
    # does not bias every pixel downwards.
    aged_hwc = np.transpose(aged, (1, 2, 0))  # (512, 512, 3)
    aged_hwc = np.rint(aged_hwc * 255.0).astype(np.uint8)

    # Resize back to original crop size
    aged_resized = cv2.resize(aged_hwc, (crop_w, crop_h), interpolation=cv2.INTER_LINEAR)

    # Blend back into a copy of the original image. The mask is 0 on the crop
    # border, so the outermost pixels stay untouched and no seam is visible.
    result = image_rgb.copy()
    blend_mask = create_blend_mask(crop_h, crop_w, feather=0.12)
    region = result[l_y:r_y, l_x:r_x].astype(np.float32)
    aged_f = aged_resized.astype(np.float32)
    blended = region * (1 - blend_mask) + aged_f * blend_mask
    # A convex combination of uint8 values stays within [0, 255]; round to nearest.
    result[l_y:r_y, l_x:r_x] = np.rint(blended).astype(np.uint8)

    elapsed = time.perf_counter() - t0
    info = f"Done in {elapsed:.2f}s | Source age: {source_age} | Target age: {target_age}"
    return Image.fromarray(result), info
| # --------------------------------------------------------------------------- | |
| # Gradio UI | |
| # --------------------------------------------------------------------------- | |
def process(image, source_age, target_age):
    """Gradio click handler: validate the upload and run the re-aging.

    Args:
        image: the uploaded image, or ``None`` when nothing was uploaded.
        source_age: slider value for the current age; converted with ``int``.
        target_age: slider value for the desired age; converted with ``int``.

    Returns:
        Whatever ``reage_face`` returns for the given image and ages.

    Raises:
        gr.Error: if ``image`` is ``None``.
    """
    if image is not None:
        return reage_face(image, int(source_age), int(target_age))
    raise gr.Error("Please upload an image.")
# Static Markdown shown above and below the controls.
_HEADER_MD = "# Face Re-Aging (CPU)\nAge or de-age faces using Disney FRAN-style model. Upload a photo, set source & target age."
_FOOTER_MD = (
    "**Model:** `face_reaging.onnx` (118 MB) from "
    "[VisoMaster-Fusion](https://github.com/VisoMasterFusion/VisoMaster-Fusion) | "
    "Based on [Disney FRAN](https://studios.disneyresearch.com/2022/11/30/production-ready-face-re-aging-for-visual-effects/)"
)
# Both age sliders share the same range and step.
_AGE_SLIDER_RANGE = {"minimum": 5, "maximum": 95, "step": 1}

with gr.Blocks(title="Face Re-Aging (CPU)") as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            source_age = gr.Slider(
                value=25,
                label="Source Age (current age of the person)",
                **_AGE_SLIDER_RANGE,
            )
            target_age = gr.Slider(
                value=65,
                label="Target Age (desired age)",
                **_AGE_SLIDER_RANGE,
            )
            run_btn = gr.Button("Re-Age Face", variant="primary")
        # Right column: result image and the one-line status text.
        with gr.Column():
            output_image = gr.Image(type="pil", label="Re-Aged Result")
            info_text = gr.Textbox(label="Info", interactive=False)

    # process() returns (image, info), matching the two outputs in order.
    run_btn.click(
        fn=process,
        inputs=[input_image, source_age, target_age],
        outputs=[output_image, info_text],
    )

    gr.Markdown(_FOOTER_MD)
# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    _launch_options = {
        "show_error": True,
        "ssr_mode": False,
        "theme": "NoCrypt/miku",
    }
    demo.launch(**_launch_options)