"""
Face Re-Aging with ONNX (CPU)

Based on Disney's FRAN (Face Re-Aging Network) architecture.
Model: face_reaging.onnx from VisoMaster-Fusion.
Supports image and video re-aging in a single unified view.
"""

import glob as glob_mod
import json
import os
import shutil
import subprocess
import tempfile
import threading
import time
import urllib.request

import cv2
import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
MAX_VIDEO_SECONDS = 30
MAX_FRAMES = 900
MODEL_PATH = "face_reaging.onnx"
REPO_ID = "Luminia/Face-ReAging-CPU"

# Side length of the square crop fed to the re-aging network.
MODEL_INPUT_SIZE = 512
# Frame rate assumed when a video's rate cannot be determined or is invalid.
DEFAULT_FPS = 25.0

# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------


def get_model_path():
    """Return a local path to the re-aging model, downloading it from the Hub if absent."""
    if os.path.exists(MODEL_PATH):
        return MODEL_PATH
    return hf_hub_download(repo_id=REPO_ID, filename=MODEL_PATH)


print("Loading ONNX model...")
_so = ort.SessionOptions()
# os.cpu_count() may return None; 0 lets ONNX Runtime choose its own default.
_so.intra_op_num_threads = os.cpu_count() or 0
_so.inter_op_num_threads = os.cpu_count() or 0
sess = ort.InferenceSession(
    get_model_path(),
    providers=["CPUExecutionProvider"],
    sess_options=_so,
)
print("Model loaded.")

# ---------------------------------------------------------------------------
# Face detection
# ---------------------------------------------------------------------------
_face_cascade = cv2.CascadeClassifier(
    cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
)

_dnn_model_path = os.path.join(os.path.dirname(__file__), "face_detection_yunet_2023mar.onnx")
YUNET_URL = "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx"

# Shared YuNet detector, created lazily once. Creating it loads the ONNX file from
# disk, which is far too slow to repeat for every video frame. The lock guards the
# setInputSize()+detect() pair because Gradio may run event handlers concurrently.
_yunet_detector = None
# Set once YuNet could not be downloaded/loaded, so later calls go straight to Haar
# instead of retrying (and re-printing the warning) on every frame.
_yunet_failed = False
_yunet_lock = threading.Lock()


def _ensure_yunet():
    """Return a local path to the YuNet model, downloading it on first use.

    Tries the Hugging Face Hub first and falls back to the GitHub URL. The URL
    download goes through a ``.part`` file that is renamed into place only when
    complete, so an interrupted transfer never leaves a truncated model behind.
    """
    global _dnn_model_path
    if not os.path.exists(_dnn_model_path):
        print("Downloading YuNet face detector...")
        try:
            _dnn_model_path = hf_hub_download(
                repo_id="opencv/opencv_zoo",
                filename="models/face_detection_yunet/face_detection_yunet_2023mar.onnx",
            )
        except Exception:
            partial = _dnn_model_path + ".part"
            try:
                urllib.request.urlretrieve(YUNET_URL, partial)
                os.replace(partial, _dnn_model_path)
            finally:
                if os.path.exists(partial):
                    os.remove(partial)
        print("YuNet downloaded.")
    return _dnn_model_path


def _get_yunet(size):
    """Return the shared YuNet detector resized to ``size`` (w, h), or None if unavailable.

    Must be called with ``_yunet_lock`` held.
    """
    global _yunet_detector, _yunet_failed
    if _yunet_detector is None and not _yunet_failed:
        try:
            _yunet_detector = cv2.FaceDetectorYN.create(_ensure_yunet(), "", size, 0.5, 0.3, 5000)
        except Exception as e:
            _yunet_failed = True
            print(f"YuNet failed, falling back to Haar: {e}")
            return None
    if _yunet_detector is not None:
        _yunet_detector.setInputSize(size)
    return _yunet_detector


def detect_face_box(image_rgb: np.ndarray):
    """Return the largest face in an RGB image as ``(x1, y1, x2, y2)``, or None.

    YuNet is tried first. The Haar cascade is used when YuNet is unavailable,
    raises, finds no face, or yields a box that is empty after clamping to the
    image bounds.
    """
    h, w = image_rgb.shape[:2]
    try:
        # OpenCV's DNN face detector expects BGR input (the cv2.imread layout).
        image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
        with _yunet_lock:
            detector = _get_yunet((w, h))
            faces = None if detector is None else detector.detect(image_bgr)[1]
    except Exception as e:
        print(f"YuNet failed, falling back to Haar: {e}")
        faces = None

    if faces is not None and len(faces) > 0:
        # Each detection row starts with x, y, width, height; keep the largest area.
        f = max(faces, key=lambda det: det[2] * det[3])
        x1, y1 = max(int(f[0]), 0), max(int(f[1]), 0)
        x2, y2 = min(int(f[0] + f[2]), w), min(int(f[1] + f[3]), h)
        if x2 > x1 and y2 > y1:
            return (x1, y1, x2, y2)

    gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
    faces = _face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60))
    if len(faces) == 0:
        return None
    x, y, fw, fh = (int(v) for v in max(faces, key=lambda r: r[2] * r[3]))
    return (x, y, x + fw, y + fh)


# ---------------------------------------------------------------------------
# Core inference
# ---------------------------------------------------------------------------


def crop_face_region(image_rgb, box):
    """Expand a face box into a padded crop and return ``(crop, (l_x, l_y, r_x, r_y))``.

    Each side gets a horizontal margin of 0.425 x face width. The vertical
    margins add the same total (2 x margin_x) to the height: 0.3145 x face
    height goes below and the remainder above. The crop is clamped to the image.
    """
    h, w = image_rgb.shape[:2]
    x1, y1, x2, y2 = box
    face_w, face_h = x2 - x1, y2 - y1
    margin_top = int(face_h * 0.63 * 0.85)
    margin_bot = int(face_h * 0.37 * 0.85)
    margin_x = int(face_w * 0.85 / 2)
    # Rebalance so the added height equals the added width (2 * margin_x).
    margin_top += 2 * margin_x - margin_top - margin_bot
    l_y, r_y = max(y1 - margin_top, 0), min(y2 + margin_bot, h)
    l_x, r_x = max(x1 - margin_x, 0), min(x2 + margin_x, w)
    return image_rgb[l_y:r_y, l_x:r_x, :], (l_x, l_y, r_x, r_y)


def create_blend_mask(crop_h, crop_w, feather=0.15):
    """Return a ``(crop_h, crop_w, 1)`` float32 mask that fades to 0 at the borders.

    The outermost ``feather`` fraction of rows and of columns ramps linearly
    from 0 at the edge towards 1; the interior is 1. Row and column ramps are
    multiplied together, so corners fade in both directions.
    """

    def edge_ramp(n):
        band = max(int(n * feather), 1)
        weights = np.ones(n, dtype=np.float32)
        steps = np.arange(band, dtype=np.float32) / band  # 0, 1/band, ..., (band-1)/band
        weights[:band] *= steps
        weights[n - band:] *= steps[::-1]
        return weights

    return np.outer(edge_ramp(crop_h), edge_ramp(crop_w))[:, :, np.newaxis]


def reage_frame(image_rgb, source_age, target_age, box=None):
    """Re-age the largest face in an RGB uint8 frame and return the resulting frame.

    Args:
        image_rgb: RGB uint8 image (as produced by the callers in this file).
        source_age: current age in years; fed to the network as age / 100.
        target_age: desired age in years; fed to the network as age / 100.
        box: optional precomputed face box ``(x1, y1, x2, y2)``. When omitted,
            the face is detected here.

    Returns:
        ``image_rgb`` unchanged when no face is found. Otherwise, a copy with the
        network's output feather-blended into the face region.
    """
    if box is None:
        box = detect_face_box(image_rgb)
    if box is None:
        return image_rgb

    cropped, (l_x, l_y, r_x, r_y) = crop_face_region(image_rgb, box)
    crop_h, crop_w = cropped.shape[:2]
    if crop_h == 0 or crop_w == 0:
        return image_rgb

    size = MODEL_INPUT_SIZE
    cropped_resized = cv2.resize(cropped, (size, size), interpolation=cv2.INTER_LINEAR)
    img_t = np.transpose(cropped_resized.astype(np.float32) / 255.0, (2, 0, 1))  # HWC -> CHW
    # Network input: RGB plus two constant planes carrying source and target age.
    src_ch = np.full((1, size, size), source_age / 100.0, dtype=np.float32)
    tgt_ch = np.full((1, size, size), target_age / 100.0, dtype=np.float32)
    inp = np.concatenate([img_t, src_ch, tgt_ch], axis=0)[np.newaxis, ...]

    # The network's output is added to the input image (a residual).
    delta = sess.run(None, {"input": inp})[0]
    aged = np.clip(img_t + delta[0], 0.0, 1.0)
    # Round rather than truncate so the float->uint8 step does not bias pixels darker.
    aged_hwc = np.rint(np.transpose(aged, (1, 2, 0)) * 255).astype(np.uint8)
    aged_resized = cv2.resize(aged_hwc, (crop_w, crop_h), interpolation=cv2.INTER_LINEAR)

    result = image_rgb.copy()
    mask = create_blend_mask(crop_h, crop_w, feather=0.12)
    region = result[l_y:r_y, l_x:r_x].astype(np.float32)
    blended = region * (1 - mask) + aged_resized.astype(np.float32) * mask
    result[l_y:r_y, l_x:r_x] = np.rint(blended).astype(np.uint8)
    return result


# ---------------------------------------------------------------------------
# ffmpeg helpers
# ---------------------------------------------------------------------------


def _find_ffmpeg():
    """Return the ffmpeg executable path or raise gr.Error if it cannot be found."""
    path = shutil.which("ffmpeg")
    if path:
        return path
    for candidate in ("/usr/bin/ffmpeg", "/usr/local/bin/ffmpeg"):
        if os.path.isfile(candidate):
            return candidate
    raise gr.Error("ffmpeg not found.")


def _probe_with_opencv(video_path):
    """Read (fps, frame_count) via OpenCV; fps falls back to DEFAULT_FPS when reported as 0."""
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) or DEFAULT_FPS
        count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()
    return fps, count


def _get_video_info(video_path):
    """Return ``(fps, frame_count)`` for the first video stream.

    Uses ffprobe when available and falls back to OpenCV otherwise, or when
    ffprobe's output cannot be parsed. A missing, zero, negative or NaN frame
    rate is replaced by DEFAULT_FPS, so callers may divide by the result and
    pass it to ffmpeg.
    """
    ffprobe = shutil.which("ffprobe") or shutil.which("ffprobe", path="/usr/bin:/usr/local/bin")
    if not ffprobe:
        fps, count = _probe_with_opencv(video_path)
    else:
        try:
            r = subprocess.run(
                [ffprobe, "-v", "quiet", "-print_format", "json", "-show_streams",
                 "-select_streams", "v:0", video_path],
                capture_output=True, text=True, timeout=30,
            )
            stream = json.loads(r.stdout)["streams"][0]
            num, den = stream.get("r_frame_rate", "25/1").split("/")
            fps = float(num) / float(den)
            nb = stream.get("nb_frames")
            if nb and nb != "N/A":
                count = int(nb)
            else:
                count = int(float(stream.get("duration", 0)) * fps)
        except Exception:
            fps, count = _probe_with_opencv(video_path)
    if not fps > 0:  # also catches NaN
        fps = DEFAULT_FPS
    return fps, count


def _extract_frames(video_path, out_dir, max_frames=None):
    """Dump video frames as ``frame_%06d.png`` into ``out_dir``.

    When ``max_frames`` is given, extraction stops after that many frames.
    Raises gr.Error if ffmpeg fails.
    """
    ffmpeg = _find_ffmpeg()
    cmd = [ffmpeg, "-y", "-i", video_path, "-vsync", "0"]
    if max_frames is not None:
        cmd += ["-frames:v", str(max_frames)]
    cmd.append(os.path.join(out_dir, "frame_%06d.png"))
    r = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
    if r.returncode != 0:
        raise gr.Error(f"Frame extraction failed: {r.stderr[-500:]}")


def _assemble_video(frames_dir, output_path, fps, audio_source=None):
    """Encode ``frames_dir/frame_%06d.png`` into an H.264 MP4 at ``fps``.

    When ``audio_source`` is given, its audio stream (if any) is muxed in and
    the output is cut to the shorter of video and audio. Raises gr.Error if
    ffmpeg fails.
    """
    ffmpeg = _find_ffmpeg()
    cmd = [ffmpeg, "-y", "-framerate", str(fps), "-i", os.path.join(frames_dir, "frame_%06d.png")]
    if audio_source:
        cmd += ["-i", audio_source, "-map", "0:v", "-map", "1:a?", "-shortest"]
    cmd += [
        # libx264 with yuv420p rejects odd frame sizes; round down to even dimensions.
        "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2",
        "-c:v", "libx264", "-pix_fmt", "yuv420p", "-preset", "fast", "-crf", "20",
        "-movflags", "+faststart", output_path,
    ]
    r = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
    if r.returncode != 0:
        raise gr.Error(f"Video assembly failed: {r.stderr[-500:]}")


# ---------------------------------------------------------------------------
# Unified process function
# ---------------------------------------------------------------------------
VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv", ".wmv", ".m4v"}


def _reage_still(image_rgb, source_age, target_age, t0, no_face_msg):
    """Re-age one RGB image and return the ``(image, video, info)`` output triple."""
    box = detect_face_box(image_rgb)
    if box is None:
        raise gr.Error(no_face_msg)
    result = reage_frame(image_rgb, source_age, target_age, box=box)
    elapsed = time.time() - t0
    info = f"Done in {elapsed:.2f}s | {source_age} -> {target_age} years"
    return Image.fromarray(result), None, info


def _reage_video(file_path, source_age, target_age, t0, progress):
    """Re-age every frame of a video file and return the ``(image, video, info)`` triple.

    Frames are extracted to a temp directory, processed one by one, and
    re-encoded with the source audio. The PNG frame dumps are always deleted.
    On success only ``output.mp4`` is kept (its path is returned to Gradio);
    on failure the whole temp directory is removed.
    """
    fps, total_frames = _get_video_info(file_path)
    duration = total_frames / fps
    if duration > MAX_VIDEO_SECONDS:
        raise gr.Error(f"Video is {duration:.1f}s (max {MAX_VIDEO_SECONDS}s). Please trim it.")
    if total_frames > MAX_FRAMES:
        raise gr.Error(f"Video has {total_frames} frames (max {MAX_FRAMES}).")

    tmp_root = tempfile.mkdtemp(prefix="reage_")
    frames_in = os.path.join(tmp_root, "in")
    frames_out = os.path.join(tmp_root, "out")
    os.makedirs(frames_in, exist_ok=True)
    os.makedirs(frames_out, exist_ok=True)
    succeeded = False
    try:
        progress(0, desc="Extracting frames...")
        # Container metadata can be missing or wrong. Extracting one frame past
        # the limit is enough to detect an over-long video without dumping all of it.
        _extract_frames(file_path, frames_in, max_frames=MAX_FRAMES + 1)
        frame_files = sorted(glob_mod.glob(os.path.join(frames_in, "frame_*.png")))
        n_frames = len(frame_files)
        if n_frames == 0:
            raise gr.Error("No frames extracted. Is this a valid video?")
        if n_frames > MAX_FRAMES:
            raise gr.Error(f"Video exceeds {MAX_FRAMES} frames. Please trim it.")

        faces_found, faces_missed, written = 0, 0, 0
        for idx, fpath in enumerate(frame_files):
            progress((idx + 1) / n_frames, desc=f"Re-aging frame {idx + 1}/{n_frames}...")
            frame_bgr = cv2.imread(fpath)
            if frame_bgr is None:
                continue
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            box = detect_face_box(frame_rgb)
            if box is not None:
                result_rgb = reage_frame(frame_rgb, source_age, target_age, box=box)
                faces_found += 1
            else:
                result_rgb = frame_rgb
                faces_missed += 1
            # Number outputs contiguously: ffmpeg's image-sequence input stops at
            # the first missing index, so a skipped frame must not leave a gap.
            written += 1
            out_path = os.path.join(frames_out, f"frame_{written:06d}.png")
            if not cv2.imwrite(out_path, cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)):
                raise RuntimeError(f"could not write {out_path}")

        progress(1.0, desc="Assembling video...")
        output_path = os.path.join(tmp_root, "output.mp4")
        _assemble_video(frames_out, output_path, fps, audio_source=file_path)

        elapsed = time.time() - t0
        speed = n_frames / max(elapsed, 0.01)
        info = (f"Done in {elapsed:.1f}s | {n_frames} frames at {speed:.1f} fps | "
                f"Faces: {faces_found} found, {faces_missed} skipped | "
                f"{source_age} -> {target_age} years")
        succeeded = True
        return None, output_path, info
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"Video processing failed: {e}") from e
    finally:
        shutil.rmtree(frames_in, ignore_errors=True)
        shutil.rmtree(frames_out, ignore_errors=True)
        if not succeeded:
            shutil.rmtree(tmp_root, ignore_errors=True)


def process(input_file, source_age, target_age, progress=gr.Progress()):
    """Re-age a PIL image, an image file, or a video file.

    Args:
        input_file: a ``PIL.Image.Image``, a path (str / os.PathLike), or a
            file-like wrapper exposing its path as ``.name``.
        source_age: current age in years (converted with ``int``).
        target_age: desired age in years (converted with ``int``).
        progress: Gradio progress tracker, used for videos.

    Returns:
        ``(result_image, None, info)`` for images, or
        ``(None, output_video_path, info)`` for videos.

    Raises:
        gr.Error: on missing input, no detectable face in an image,
            unreadable files, over-long videos, or ffmpeg failures.
    """
    if input_file is None:
        raise gr.Error("Please upload an image or video.")
    t0 = time.time()
    source_age, target_age = int(source_age), int(target_age)

    if isinstance(input_file, Image.Image):
        # Direct PIL image from gr.Image
        image_rgb = np.array(input_file.convert("RGB"))
        return _reage_still(
            image_rgb, source_age, target_age, t0,
            "No face detected. Please upload a clear photo with a visible face.",
        )

    # File path (could be image or video). File-like wrappers such as tempfile
    # objects carry their path in .name; str() of those would give the repr.
    if isinstance(input_file, (str, os.PathLike)):
        file_path = os.fspath(input_file)
    else:
        file_path = getattr(input_file, "name", None) or str(input_file)

    ext = os.path.splitext(file_path)[1].lower()
    if ext in VIDEO_EXTS:
        return _reage_video(file_path, source_age, target_age, t0, progress)

    image_bgr = cv2.imread(file_path)
    if image_bgr is None:
        raise gr.Error("Could not read the file. Please upload a valid image or video.")
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    return _reage_still(image_rgb, source_age, target_age, t0, "No face detected.")


# ---------------------------------------------------------------------------
# Gradio UI - Single unified view
# ---------------------------------------------------------------------------


def on_submit_file(file_obj, source_age, target_age, progress=gr.Progress()):
    """Handle the Re-Age button for the uploaded file."""
    if file_obj is None:
        raise gr.Error("Please upload a file.")
    return process(file_obj, source_age, target_age, progress)


def on_submit_image(image, source_age, target_age, progress=gr.Progress()):
    """Handle a change of the paste/webcam image.

    The change event also fires when the image is cleared. That is not an
    error, so the outputs are left untouched instead of raising.
    """
    if image is None:
        return gr.update(), gr.update(), gr.update()
    return process(image, source_age, target_age, progress)


with gr.Blocks(title="Face Re-Aging (CPU)") as demo:
    gr.Markdown(
        "# Face Re-Aging (CPU)\n"
        "Upload an **image or video** to age or de-age faces. "
        f"Videos: max {MAX_VIDEO_SECONDS}s, ~0.5-2 fps on CPU."
    )

    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label="Drop Image or Video Here",
                file_types=["image", "video"],
            )
            # Also accept pasted/webcam images
            img_input = gr.Image(
                type="pil",
                label="Or paste/capture an image",
                visible=True,
            )
            src_age = gr.Slider(minimum=5, maximum=95, value=25, step=1, label="Source Age (current)")
            tgt_age = gr.Slider(minimum=5, maximum=95, value=65, step=1, label="Target Age (desired)")
            btn = gr.Button("Re-Age", variant="primary", size="lg")

        with gr.Column():
            img_output = gr.Image(type="pil", label="Result (Image)")
            vid_output = gr.Video(label="Result (Video)")
            info_box = gr.Textbox(label="Info", interactive=False)

    btn.click(
        fn=on_submit_file,
        inputs=[file_input, src_age, tgt_age],
        outputs=[img_output, vid_output, info_box],
    )
    # Also trigger on image input (for paste/webcam)
    img_input.change(
        fn=on_submit_image,
        inputs=[img_input, src_age, tgt_age],
        outputs=[img_output, vid_output, info_box],
    )

    gr.Markdown(
        "**Model:** `face_reaging.onnx` (118 MB) from "
        "[VisoMaster-Fusion](https://github.com/VisoMasterFusion/VisoMaster-Fusion) | "
        "Based on [Disney FRAN](https://studios.disneyresearch.com/2022/11/30/production-ready-face-re-aging-for-visual-effects/)"
    )


if __name__ == "__main__":
    demo.launch(show_error=True, ssr_mode=False, theme="NoCrypt/miku")