Nekochu's picture
Switch to OpenCV face detection
2a828f1
raw
history blame
9.77 kB
"""
Face Re-Aging with ONNX (CPU)
Based on Disney's FRAN (Face Re-Aging Network) architecture.
Model: face_reaging.onnx from VisoMaster-Fusion.
"""
import os
import time
import cv2
import numpy as np
import onnxruntime as ort
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
MODEL_PATH = "face_reaging.onnx"
REPO_ID = "Luminia/Face-ReAging-CPU"


def get_model_path():
    """Return a path to the re-aging ONNX model file.

    A file named ``MODEL_PATH`` in the current working directory is preferred;
    when it is absent the file is fetched from the Hugging Face Hub repo
    ``REPO_ID`` and the path returned by ``hf_hub_download`` is used.
    """
    have_local_copy = os.path.exists(MODEL_PATH)
    if not have_local_copy:
        return hf_hub_download(repo_id=REPO_ID, filename=MODEL_PATH)
    return MODEL_PATH
print("Loading ONNX model...")
# os.cpu_count() returns None when the core count cannot be determined, and
# assigning None to the thread-count options fails; 0 asks ONNX Runtime to
# choose its own default instead.
_n_threads = os.cpu_count() or 0
_so = ort.SessionOptions()
_so.intra_op_num_threads = _n_threads
_so.inter_op_num_threads = _n_threads
# Single module-level session shared by every request (CPU only).
sess = ort.InferenceSession(
    get_model_path(),
    providers=["CPUExecutionProvider"],
    sess_options=_so,
)
print("Model loaded.")
# ---------------------------------------------------------------------------
# OpenCV face detection: YuNet DNN first, Haar cascade as fallback
# (no extra Python dependencies; the YuNet model file is downloaded on demand)
# ---------------------------------------------------------------------------
# Haar cascade bundled with OpenCV. detect_face_box() falls back to it when
# YuNet raises or finds no face.
_face_cascade = cv2.CascadeClassifier(
    cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
)
# NOTE(review): _dnn_net is never assigned a network or read in this file;
# kept only in case other code imports the name.
_dnn_net = None
# Preferred local location of the YuNet ONNX model (next to this file).
# _ensure_yunet() may rebind this to a Hugging Face cache path.
_dnn_model_path = os.path.join(os.path.dirname(__file__), "face_detection_yunet_2023mar.onnx")
# Direct-download fallback used by _ensure_yunet() when the Hub download fails.
YUNET_URL = "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx"
def _ensure_yunet():
    """Return a local path to the YuNet face-detector model, downloading it if missing.

    If the file at ``_dnn_model_path`` exists it is returned as-is. Otherwise
    the model is fetched from the Hugging Face Hub, and on success the module
    global ``_dnn_model_path`` is rebound to the downloaded path so later calls
    skip the download. If the Hub download raises, the file is downloaded
    directly from ``YUNET_URL`` into ``_dnn_model_path``.

    Raises:
        Whatever ``urllib.request.urlretrieve`` / ``os.replace`` raise when the
        fallback download fails; callers should be prepared for that.
    """
    global _dnn_model_path
    if os.path.exists(_dnn_model_path):
        return _dnn_model_path
    print("Downloading YuNet face detector...")
    try:
        _dnn_model_path = hf_hub_download(
            repo_id="opencv/opencv_zoo",
            filename="models/face_detection_yunet/face_detection_yunet_2023mar.onnx",
        )
    except Exception:
        import urllib.request

        # Download under a temporary name and move it into place only once the
        # transfer completed; otherwise an interrupted download would leave a
        # truncated file that the exists() check above accepts forever.
        partial_path = _dnn_model_path + ".part"
        try:
            urllib.request.urlretrieve(YUNET_URL, partial_path)
            os.replace(partial_path, _dnn_model_path)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)
    print("YuNet downloaded.")
    return _dnn_model_path
def detect_face_box(image_rgb: np.ndarray):
    """
    Detect the largest face bounding box.

    YuNet (cv2.FaceDetectorYN) is tried first; if it raises or finds no face,
    OpenCV's Haar cascade is used instead.

    Args:
        image_rgb: H x W x 3 image in RGB channel order.

    Returns:
        (x1, y1, x2, y2) as Python ints in pixel coords, or None if no face
        is found by either detector. YuNet boxes are clamped to the image.
    """
    h, w = image_rgb.shape[:2]
    # Try YuNet first (more accurate)
    try:
        yunet_path = _ensure_yunet()
        # Args: model, config, input size (w, h), score threshold,
        # NMS threshold, top-k.
        detector = cv2.FaceDetectorYN.create(yunet_path, "", (w, h), 0.5, 0.3, 5000)
        # YuNet expects BGR input (as produced by cv2.imread in OpenCV's
        # samples); passing RGB would swap the red and blue channels.
        image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
        _, faces = detector.detect(image_bgr)
        if faces is not None and len(faces) > 0:
            # Each row starts with x, y, w, h; keep the largest face by area.
            f = faces[int(np.argmax(faces[:, 2] * faces[:, 3]))]
            x1, y1 = int(f[0]), int(f[1])
            x2, y2 = int(f[0] + f[2]), int(f[1] + f[3])
            return (max(x1, 0), max(y1, 0), min(x2, w), min(y2, h))
    except Exception as e:
        # Best effort: any YuNet problem (download, model load, detection)
        # degrades to the Haar cascade instead of failing the request.
        print(f"YuNet failed, falling back to Haar: {e}")
    # Fallback: Haar cascade
    gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
    faces = _face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60))
    if len(faces) == 0:
        return None
    # Pick the largest (rows are x, y, w, h).
    best_idx = np.argmax([fw * fh for (_, _, fw, fh) in faces])
    x, y, fw, fh = (int(v) for v in faces[best_idx])
    return (x, y, x + fw, y + fh)
# ---------------------------------------------------------------------------
# Face cropping with margin
# ---------------------------------------------------------------------------
def crop_face_region(image_rgb: np.ndarray, box):
    """
    Cut a padded window around a face box for the re-aging model.

    Horizontal padding is int(box_w * 0.85 / 2) on each side. Below the face
    the padding is int(box_h * 0.37 * 0.85). Above it (the forehead) the
    padding is whatever makes the total vertical padding equal the total
    horizontal padding. The window is therefore square only when ``box`` is
    square and the window is not clipped by the image border.

    Returns:
        (crop, (left, top, right, bottom)): crop is
        image_rgb[top:bottom, left:right], and the tuple gives the paste-back
        coordinates.
    """
    img_h, img_w = image_rgb.shape[:2]
    left, top, right, bottom = box
    box_w = right - left
    box_h = bottom - top

    pad_below = int(box_h * 0.37 * 0.85)
    pad_side = int(box_w * 0.85 / 2)
    # Total vertical padding == total horizontal padding (2 * pad_side);
    # everything not used below the chin goes above the face.
    pad_above = 2 * pad_side - pad_below

    y0 = max(top - pad_above, 0)
    y1 = min(bottom + pad_below, img_h)
    x0 = max(left - pad_side, 0)
    x1 = min(right + pad_side, img_w)
    return image_rgb[y0:y1, x0:x1, :], (x0, y0, x1, y1)
# ---------------------------------------------------------------------------
# Blending mask (soft feathered edges)
# ---------------------------------------------------------------------------
def create_blend_mask(crop_h, crop_w, feather=0.15):
    """
    Build a float32 alpha mask of shape (crop_h, crop_w, 1) for pasting the
    re-aged crop back over the original without a hard seam.

    Interior values are 1.0. Along each axis a band of
    max(int(size * feather), 1) pixels at both ends is scaled by k / band for
    the k-th pixel from the edge, so the outermost row/column is 0.0. Where
    bands overlap (corners, or very small sizes), the factors multiply.
    """
    weights = np.ones((crop_h, crop_w), dtype=np.float32)
    # Rows first, then columns: the transposed view lets one loop handle
    # both axes while writing into the same buffer.
    passes = ((weights, crop_h), (weights.T, crop_w))
    bands = [max(int(extent * feather), 1) for _, extent in passes]
    for (view, extent), band in zip(passes, bands):
        for k in range(band):
            ramp = k / band
            view[k] *= ramp
            view[extent - 1 - k] *= ramp
    return weights[..., np.newaxis]  # (H, W, 1)
# ---------------------------------------------------------------------------
# Core inference
# ---------------------------------------------------------------------------
def _to_uint8(values: np.ndarray) -> np.ndarray:
    """Round float pixel values to the nearest integer and clip into uint8.

    A plain ``astype(np.uint8)`` truncates toward zero, which biases every
    pixel slightly darker; rounding avoids that.
    """
    return np.clip(np.rint(values), 0, 255).astype(np.uint8)


def _run_reaging_model(cropped: np.ndarray, source_age: int, target_age: int) -> np.ndarray:
    """Run the ONNX model on an RGB uint8 face crop; return the re-aged crop at the crop's size."""
    crop_h, crop_w = cropped.shape[:2]
    model_size = 512  # the crop is resized to 512x512 before inference
    resized = cv2.resize(cropped, (model_size, model_size), interpolation=cv2.INTER_LINEAR)
    # Normalize to [0, 1] float32 and reorder HWC -> CHW: (3, 512, 512).
    img_tensor = np.transpose(resized.astype(np.float32) / 255.0, (2, 0, 1))
    # Each age is fed as a constant plane scaled by 1/100.
    src_age_ch = np.full((1, model_size, model_size), source_age / 100.0, dtype=np.float32)
    tgt_age_ch = np.full((1, model_size, model_size), target_age / 100.0, dtype=np.float32)
    # Stack: (5, 512, 512) -> (1, 5, 512, 512)
    input_tensor = np.concatenate([img_tensor, src_age_ch, tgt_age_ch], axis=0)[np.newaxis, ...]
    # Use the model's declared input name rather than assuming "input".
    input_name = sess.get_inputs()[0].name
    # The first output is treated as an additive delta on the [0, 1] crop.
    delta = sess.run(None, {input_name: input_tensor})[0]  # (1, 3, 512, 512)
    aged = np.clip(img_tensor + delta[0], 0.0, 1.0)
    aged_hwc = _to_uint8(np.transpose(aged, (1, 2, 0)) * 255)  # (512, 512, 3)
    # cv2.resize takes (width, height).
    return cv2.resize(aged_hwc, (crop_w, crop_h), interpolation=cv2.INTER_LINEAR)


def reage_face(
    image_pil: Image.Image,
    source_age: int,
    target_age: int,
):
    """
    Re-age the largest detected face in the given PIL image.

    Args:
        image_pil: Input image (any PIL mode; converted to RGB).
        source_age: Current age of the person.
        target_age: Desired age.

    Returns:
        (result image as PIL RGB, info string with timing and the ages).

    Raises:
        gr.Error: If no face is detected or the face crop is empty.
    """
    t0 = time.time()
    image_rgb = np.array(image_pil.convert("RGB"))

    box = detect_face_box(image_rgb)
    if box is None:
        raise gr.Error("No face detected in the image. Please upload a clear photo with a visible face.")

    cropped, (l_x, l_y, r_x, r_y) = crop_face_region(image_rgb, box)
    crop_h, crop_w = cropped.shape[:2]
    if crop_h == 0 or crop_w == 0:
        # A degenerate box yields an empty crop, which cv2.resize rejects.
        raise gr.Error("Detected face region is empty. Please upload a clearer photo with a visible face.")

    aged_resized = _run_reaging_model(cropped, source_age, target_age)

    # Feathered alpha blend of the re-aged crop over the original pixels.
    result = image_rgb.copy()
    blend_mask = create_blend_mask(crop_h, crop_w, feather=0.12)
    region = result[l_y:r_y, l_x:r_x].astype(np.float32)
    blended = region * (1 - blend_mask) + aged_resized.astype(np.float32) * blend_mask
    result[l_y:r_y, l_x:r_x] = _to_uint8(blended)

    elapsed = time.time() - t0
    info = f"Done in {elapsed:.2f}s | Source age: {source_age} | Target age: {target_age}"
    return Image.fromarray(result), info
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
def process(image, source_age, target_age):
    """Click handler for the Re-Age button.

    Rejects a missing upload with gr.Error; otherwise casts both slider
    values to int and delegates to reage_face(), returning its
    (image, info) pair.
    """
    if image is not None:
        return reage_face(image, int(source_age), int(target_age))
    raise gr.Error("Please upload an image.")
# Page layout: title, then a row with inputs on the left and the result on
# the right, then a credits footer.
with gr.Blocks(title="Face Re-Aging (CPU)") as demo:
    gr.Markdown("# Face Re-Aging (CPU)\nAge or de-age faces using Disney FRAN-style model. Upload a photo, set source & target age.")
    with gr.Row():
        # Left column: image upload, the two age sliders, and the run button.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            # Integer ages from 5 to 95; process() casts them to int.
            source_age = gr.Slider(
                minimum=5, maximum=95, value=25, step=1,
                label="Source Age (current age of the person)",
            )
            target_age = gr.Slider(
                minimum=5, maximum=95, value=65, step=1,
                label="Target Age (desired age)",
            )
            run_btn = gr.Button("Re-Age Face", variant="primary")
        # Right column: outputs filled by process().
        with gr.Column():
            output_image = gr.Image(type="pil", label="Re-Aged Result")
            info_text = gr.Textbox(label="Info", interactive=False)
    # process() returns (image, info string), matching these outputs in order.
    run_btn.click(
        fn=process,
        inputs=[input_image, source_age, target_age],
        outputs=[output_image, info_text],
    )
    gr.Markdown(
        "**Model:** `face_reaging.onnx` (118 MB) from "
        "[VisoMaster-Fusion](https://github.com/VisoMasterFusion/VisoMaster-Fusion) | "
        "Based on [Disney FRAN](https://studios.disneyresearch.com/2022/11/30/production-ready-face-re-aging-for-visual-effects/)"
    )
if __name__ == "__main__":
    # show_error=True reports handler errors in the browser; ssr_mode=False
    # turns off server-side rendering.
    # NOTE(review): theme= is passed to launch() here; older Gradio releases
    # accept it only on gr.Blocks(...) — confirm against the installed version.
    demo.launch(show_error=True, ssr_mode=False, theme="NoCrypt/miku")