unified view: single input for image or video
Browse files
app.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
Face Re-Aging with ONNX (CPU)
|
| 3 |
Based on Disney's FRAN (Face Re-Aging Network) architecture.
|
| 4 |
Model: face_reaging.onnx from VisoMaster-Fusion.
|
| 5 |
-
Supports
|
| 6 |
"""
|
| 7 |
|
| 8 |
import os
|
|
@@ -47,18 +47,16 @@ sess = ort.InferenceSession(
|
|
| 47 |
print("Model loaded.")
|
| 48 |
|
| 49 |
# ---------------------------------------------------------------------------
|
| 50 |
-
#
|
| 51 |
# ---------------------------------------------------------------------------
|
| 52 |
_face_cascade = cv2.CascadeClassifier(
|
| 53 |
cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
|
| 54 |
)
|
| 55 |
-
|
| 56 |
_dnn_model_path = os.path.join(os.path.dirname(__file__), "face_detection_yunet_2023mar.onnx")
|
| 57 |
YUNET_URL = "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx"
|
| 58 |
|
| 59 |
|
| 60 |
def _ensure_yunet():
|
| 61 |
-
"""Download YuNet face detector if not present."""
|
| 62 |
global _dnn_model_path
|
| 63 |
if not os.path.exists(_dnn_model_path):
|
| 64 |
print("Downloading YuNet face detector...")
|
|
@@ -76,26 +74,13 @@ def _ensure_yunet():
|
|
| 76 |
|
| 77 |
|
| 78 |
def detect_face_box(image_rgb: np.ndarray):
|
| 79 |
-
"""
|
| 80 |
-
Detect the largest face bounding box.
|
| 81 |
-
Returns (x1, y1, x2, y2) in pixel coords or None.
|
| 82 |
-
"""
|
| 83 |
h, w = image_rgb.shape[:2]
|
| 84 |
-
|
| 85 |
-
# Try YuNet first (more accurate)
|
| 86 |
try:
|
| 87 |
yunet_path = _ensure_yunet()
|
| 88 |
detector = cv2.FaceDetectorYN.create(yunet_path, "", (w, h), 0.5, 0.3, 5000)
|
| 89 |
_, faces = detector.detect(image_rgb)
|
| 90 |
if faces is not None and len(faces) > 0:
|
| 91 |
-
best_idx =
|
| 92 |
-
best_area = 0
|
| 93 |
-
for i, face in enumerate(faces):
|
| 94 |
-
fw, fh = face[2], face[3]
|
| 95 |
-
area = fw * fh
|
| 96 |
-
if area > best_area:
|
| 97 |
-
best_area = area
|
| 98 |
-
best_idx = i
|
| 99 |
f = faces[best_idx]
|
| 100 |
x1, y1 = int(f[0]), int(f[1])
|
| 101 |
x2, y2 = int(f[0] + f[2]), int(f[1] + f[3])
|
|
@@ -103,172 +88,104 @@ def detect_face_box(image_rgb: np.ndarray):
|
|
| 103 |
except Exception as e:
|
| 104 |
print(f"YuNet failed, falling back to Haar: {e}")
|
| 105 |
|
| 106 |
-
# Fallback: Haar cascade
|
| 107 |
gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
|
| 108 |
faces = _face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60))
|
| 109 |
if len(faces) == 0:
|
| 110 |
return None
|
| 111 |
-
|
| 112 |
best_idx = np.argmax([fw * fh for (_, _, fw, fh) in faces])
|
| 113 |
x, y, fw, fh = faces[best_idx]
|
| 114 |
return (x, y, x + fw, y + fh)
|
| 115 |
|
| 116 |
# ---------------------------------------------------------------------------
|
| 117 |
-
#
|
| 118 |
# ---------------------------------------------------------------------------
|
| 119 |
-
def crop_face_region(image_rgb
|
| 120 |
-
"""
|
| 121 |
-
Crop a square region around the detected face with generous margins.
|
| 122 |
-
Returns: cropped image, (l_x, l_y, r_x, r_y) paste-back coords.
|
| 123 |
-
"""
|
| 124 |
h, w = image_rgb.shape[:2]
|
| 125 |
x1, y1, x2, y2 = box
|
| 126 |
-
|
| 127 |
-
face_w = x2 - x1
|
| 128 |
-
face_h = y2 - y1
|
| 129 |
-
|
| 130 |
margin_top = int(face_h * 0.63 * 0.85)
|
| 131 |
margin_bot = int(face_h * 0.37 * 0.85)
|
| 132 |
margin_x = int(face_w * 0.85 / 2)
|
| 133 |
margin_top += 2 * margin_x - margin_top - margin_bot
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
l_y = max(y1 - margin_top, 0)
|
| 136 |
-
r_y = min(y2 + margin_bot, h)
|
| 137 |
-
l_x = max(x1 - margin_x, 0)
|
| 138 |
-
r_x = min(x2 + margin_x, w)
|
| 139 |
-
|
| 140 |
-
cropped = image_rgb[l_y:r_y, l_x:r_x, :]
|
| 141 |
-
return cropped, (l_x, l_y, r_x, r_y)
|
| 142 |
|
| 143 |
-
# ---------------------------------------------------------------------------
|
| 144 |
-
# Blending mask (soft feathered edges)
|
| 145 |
-
# ---------------------------------------------------------------------------
|
| 146 |
def create_blend_mask(crop_h, crop_w, feather=0.15):
|
| 147 |
-
"""Create a soft feathered blending mask."""
|
| 148 |
mask = np.ones((crop_h, crop_w), dtype=np.float32)
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
alpha = j / border_x
|
| 159 |
-
mask[:, j] *= alpha
|
| 160 |
-
mask[:, crop_w - 1 - j] *= alpha
|
| 161 |
-
|
| 162 |
return mask[:, :, np.newaxis]
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
# ---------------------------------------------------------------------------
|
| 167 |
-
def reage_frame(image_rgb: np.ndarray, source_age: int, target_age: int) -> np.ndarray:
|
| 168 |
-
"""
|
| 169 |
-
Re-age the face in a numpy RGB image.
|
| 170 |
-
Returns the re-aged image (same size), or original if no face found.
|
| 171 |
-
"""
|
| 172 |
box = detect_face_box(image_rgb)
|
| 173 |
if box is None:
|
| 174 |
-
return image_rgb
|
| 175 |
|
| 176 |
cropped, (l_x, l_y, r_x, r_y) = crop_face_region(image_rgb, box)
|
| 177 |
crop_h, crop_w = cropped.shape[:2]
|
| 178 |
-
|
| 179 |
cropped_resized = cv2.resize(cropped, (512, 512), interpolation=cv2.INTER_LINEAR)
|
| 180 |
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
input_tensor = np.concatenate([img_tensor, src_age_ch, tgt_age_ch], axis=0)
|
| 188 |
-
input_tensor = input_tensor[np.newaxis, ...]
|
| 189 |
-
|
| 190 |
-
delta = sess.run(None, {"input": input_tensor})[0]
|
| 191 |
-
|
| 192 |
-
aged = img_tensor + delta[0]
|
| 193 |
-
aged = np.clip(aged, 0.0, 1.0)
|
| 194 |
-
|
| 195 |
-
aged_hwc = np.transpose(aged, (1, 2, 0))
|
| 196 |
-
aged_hwc = (aged_hwc * 255).astype(np.uint8)
|
| 197 |
|
|
|
|
|
|
|
|
|
|
| 198 |
aged_resized = cv2.resize(aged_hwc, (crop_w, crop_h), interpolation=cv2.INTER_LINEAR)
|
| 199 |
|
| 200 |
result = image_rgb.copy()
|
| 201 |
-
|
| 202 |
region = result[l_y:r_y, l_x:r_x].astype(np.float32)
|
| 203 |
-
|
| 204 |
-
blended = region * (1 - blend_mask) + aged_f * blend_mask
|
| 205 |
result[l_y:r_y, l_x:r_x] = blended.astype(np.uint8)
|
| 206 |
-
|
| 207 |
return result
|
| 208 |
|
| 209 |
-
# ---------------------------------------------------------------------------
|
| 210 |
-
# Image re-aging (wraps reage_frame for Gradio)
|
| 211 |
-
# ---------------------------------------------------------------------------
|
| 212 |
-
def reage_face(image_pil: Image.Image, source_age: int, target_age: int):
|
| 213 |
-
"""Re-age the face in the given PIL image."""
|
| 214 |
-
t0 = time.time()
|
| 215 |
-
image_rgb = np.array(image_pil.convert("RGB"))
|
| 216 |
-
|
| 217 |
-
box = detect_face_box(image_rgb)
|
| 218 |
-
if box is None:
|
| 219 |
-
raise gr.Error("No face detected in the image. Please upload a clear photo with a visible face.")
|
| 220 |
-
|
| 221 |
-
result = reage_frame(image_rgb, source_age, target_age)
|
| 222 |
-
elapsed = time.time() - t0
|
| 223 |
-
info = f"Done in {elapsed:.2f}s | Source age: {source_age} | Target age: {target_age}"
|
| 224 |
-
return Image.fromarray(result), info
|
| 225 |
-
|
| 226 |
# ---------------------------------------------------------------------------
|
| 227 |
# ffmpeg helpers
|
| 228 |
# ---------------------------------------------------------------------------
|
| 229 |
def _find_ffmpeg():
|
| 230 |
-
"""Return ffmpeg path."""
|
| 231 |
path = shutil.which("ffmpeg")
|
| 232 |
if path:
|
| 233 |
return path
|
| 234 |
-
# HF Spaces usually have it
|
| 235 |
for p in ["/usr/bin/ffmpeg", "/usr/local/bin/ffmpeg"]:
|
| 236 |
if os.path.isfile(p):
|
| 237 |
return p
|
| 238 |
-
raise gr.Error("ffmpeg not found.
|
| 239 |
|
| 240 |
|
| 241 |
-
def _get_video_info(video_path
|
| 242 |
-
"""Get fps and frame count using ffprobe."""
|
| 243 |
ffprobe = shutil.which("ffprobe") or shutil.which("ffprobe", path="/usr/bin:/usr/local/bin")
|
| 244 |
if not ffprobe:
|
| 245 |
-
# Fallback: use OpenCV just to read metadata
|
| 246 |
cap = cv2.VideoCapture(video_path)
|
| 247 |
fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
|
| 248 |
count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 249 |
cap.release()
|
| 250 |
return fps, count
|
| 251 |
-
|
| 252 |
try:
|
|
|
|
| 253 |
r = subprocess.run(
|
| 254 |
[ffprobe, "-v", "quiet", "-print_format", "json",
|
| 255 |
"-show_streams", "-select_streams", "v:0", video_path],
|
| 256 |
capture_output=True, text=True, timeout=30,
|
| 257 |
)
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
stream = info["streams"][0]
|
| 261 |
-
# fps
|
| 262 |
-
fps_str = stream.get("r_frame_rate", "25/1")
|
| 263 |
-
num, den = fps_str.split("/")
|
| 264 |
fps = float(num) / float(den)
|
| 265 |
-
# frame count
|
| 266 |
nb = stream.get("nb_frames")
|
| 267 |
-
if nb and nb != "N/A"
|
| 268 |
-
count = int(nb)
|
| 269 |
-
else:
|
| 270 |
-
dur = float(stream.get("duration", 0))
|
| 271 |
-
count = int(dur * fps)
|
| 272 |
return fps, count
|
| 273 |
except Exception:
|
| 274 |
cap = cv2.VideoCapture(video_path)
|
|
@@ -278,216 +195,182 @@ def _get_video_info(video_path: str):
|
|
| 278 |
return fps, count
|
| 279 |
|
| 280 |
|
| 281 |
-
def _extract_frames(video_path
|
| 282 |
-
"""Extract frames from video using ffmpeg."""
|
| 283 |
ffmpeg = _find_ffmpeg()
|
| 284 |
-
|
| 285 |
-
cmd = [ffmpeg, "-i", video_path, "-vsync", "0", out_pattern, "-y"]
|
| 286 |
r = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
|
| 287 |
if r.returncode != 0:
|
| 288 |
-
raise gr.Error(f"
|
| 289 |
|
| 290 |
|
| 291 |
-
def _assemble_video(frames_dir
|
| 292 |
-
"""Reassemble frames into MP4 using ffmpeg."""
|
| 293 |
ffmpeg = _find_ffmpeg()
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
cmd = [
|
| 297 |
-
ffmpeg, "-y",
|
| 298 |
-
"-framerate", str(fps),
|
| 299 |
-
"-i", in_pattern,
|
| 300 |
-
]
|
| 301 |
-
|
| 302 |
-
# Try to copy audio from original
|
| 303 |
if audio_source:
|
| 304 |
cmd += ["-i", audio_source, "-map", "0:v", "-map", "1:a?", "-shortest"]
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
"-c:v", "libx264",
|
| 308 |
-
"-pix_fmt", "yuv420p",
|
| 309 |
-
"-preset", "fast",
|
| 310 |
-
"-crf", "20",
|
| 311 |
-
"-movflags", "+faststart",
|
| 312 |
-
output_path,
|
| 313 |
-
]
|
| 314 |
-
|
| 315 |
r = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
|
| 316 |
if r.returncode != 0:
|
| 317 |
-
raise gr.Error(f"
|
| 318 |
|
| 319 |
# ---------------------------------------------------------------------------
|
| 320 |
-
#
|
| 321 |
# ---------------------------------------------------------------------------
|
| 322 |
-
|
| 323 |
-
"""Re-age faces in every frame of a video."""
|
| 324 |
-
if video_path is None:
|
| 325 |
-
raise gr.Error("Please upload a video.")
|
| 326 |
|
| 327 |
-
t0 = time.time()
|
| 328 |
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
f"Please use a shorter video."
|
| 343 |
-
)
|
| 344 |
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
os.makedirs(frames_in, exist_ok=True)
|
| 350 |
-
os.makedirs(frames_out, exist_ok=True)
|
| 351 |
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
# Get frame list
|
| 358 |
-
frame_files = sorted(glob_mod.glob(os.path.join(frames_in, "frame_*.png")))
|
| 359 |
-
n_frames = len(frame_files)
|
| 360 |
-
if n_frames == 0:
|
| 361 |
-
raise gr.Error("No frames extracted from video. Is the file a valid video?")
|
| 362 |
-
|
| 363 |
-
# Re-check limit after extraction
|
| 364 |
-
if n_frames > MAX_FRAMES:
|
| 365 |
-
raise gr.Error(f"Video has {n_frames} frames (max {MAX_FRAMES}). Please use a shorter video.")
|
| 366 |
-
|
| 367 |
-
faces_found = 0
|
| 368 |
-
faces_missed = 0
|
| 369 |
-
|
| 370 |
-
# Process each frame
|
| 371 |
-
for idx, fpath in enumerate(frame_files):
|
| 372 |
-
progress((idx + 1) / n_frames, desc=f"Re-aging frame {idx + 1}/{n_frames}...")
|
| 373 |
-
|
| 374 |
-
# Read frame (BGR -> RGB)
|
| 375 |
-
frame_bgr = cv2.imread(fpath)
|
| 376 |
-
if frame_bgr is None:
|
| 377 |
-
continue
|
| 378 |
-
frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
|
| 379 |
-
|
| 380 |
-
# Detect and re-age
|
| 381 |
-
box = detect_face_box(frame_rgb)
|
| 382 |
-
if box is not None:
|
| 383 |
-
result_rgb = reage_frame(frame_rgb, source_age, target_age)
|
| 384 |
-
faces_found += 1
|
| 385 |
-
else:
|
| 386 |
-
result_rgb = frame_rgb
|
| 387 |
-
faces_missed += 1
|
| 388 |
-
|
| 389 |
-
# Save (RGB -> BGR)
|
| 390 |
-
fname = os.path.basename(fpath)
|
| 391 |
-
out_path = os.path.join(frames_out, fname)
|
| 392 |
-
result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
|
| 393 |
-
cv2.imwrite(out_path, result_bgr)
|
| 394 |
-
|
| 395 |
-
# Assemble video
|
| 396 |
-
progress(1.0, desc="Assembling video...")
|
| 397 |
-
output_path = os.path.join(tmp_root, "output.mp4")
|
| 398 |
-
_assemble_video(frames_out, output_path, fps, audio_source=video_path)
|
| 399 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
elapsed = time.time() - t0
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
f"Done in {elapsed:.1f}s | {n_frames} frames at {speed:.1f} fps | "
|
| 404 |
-
f"Faces found: {faces_found}, skipped: {faces_missed} | "
|
| 405 |
-
f"Source age: {source_age} -> Target age: {target_age}"
|
| 406 |
-
)
|
| 407 |
|
| 408 |
-
return output_path, info
|
| 409 |
-
|
| 410 |
-
except gr.Error:
|
| 411 |
-
raise
|
| 412 |
-
except Exception as e:
|
| 413 |
-
raise gr.Error(f"Video processing failed: {str(e)}")
|
| 414 |
|
| 415 |
# ---------------------------------------------------------------------------
|
| 416 |
-
# Gradio UI
|
| 417 |
# ---------------------------------------------------------------------------
|
| 418 |
-
def process_image(image, source_age, target_age):
|
| 419 |
-
if image is None:
|
| 420 |
-
raise gr.Error("Please upload an image.")
|
| 421 |
-
return reage_face(image, int(source_age), int(target_age))
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
def process_video(video, source_age, target_age, progress=gr.Progress()):
|
| 425 |
-
if video is None:
|
| 426 |
-
raise gr.Error("Please upload a video.")
|
| 427 |
-
return reage_video(video, int(source_age), int(target_age), progress)
|
| 428 |
-
|
| 429 |
-
|
| 430 |
with gr.Blocks(title="Face Re-Aging (CPU)") as demo:
|
| 431 |
gr.Markdown(
|
| 432 |
"# Face Re-Aging (CPU)\n"
|
| 433 |
-
"
|
| 434 |
-
"
|
| 435 |
)
|
| 436 |
|
| 437 |
-
with gr.
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
img_input = gr.Image(type="pil", label="Input Image")
|
| 443 |
-
img_src_age = gr.Slider(
|
| 444 |
-
minimum=5, maximum=95, value=25, step=1,
|
| 445 |
-
label="Source Age (current age)",
|
| 446 |
-
)
|
| 447 |
-
img_tgt_age = gr.Slider(
|
| 448 |
-
minimum=5, maximum=95, value=65, step=1,
|
| 449 |
-
label="Target Age (desired age)",
|
| 450 |
-
)
|
| 451 |
-
img_btn = gr.Button("Re-Age Face", variant="primary")
|
| 452 |
-
|
| 453 |
-
with gr.Column():
|
| 454 |
-
img_output = gr.Image(type="pil", label="Re-Aged Result")
|
| 455 |
-
img_info = gr.Textbox(label="Info", interactive=False)
|
| 456 |
-
|
| 457 |
-
img_btn.click(
|
| 458 |
-
fn=process_image,
|
| 459 |
-
inputs=[img_input, img_src_age, img_tgt_age],
|
| 460 |
-
outputs=[img_output, img_info],
|
| 461 |
-
)
|
| 462 |
-
|
| 463 |
-
# ---- Video Tab ----
|
| 464 |
-
with gr.TabItem("Video"):
|
| 465 |
-
gr.Markdown(
|
| 466 |
-
f"Upload a video (max **{MAX_VIDEO_SECONDS}s** / **{MAX_FRAMES} frames**). "
|
| 467 |
-
f"Each frame is processed individually on CPU, so expect ~0.5-2 fps."
|
| 468 |
)
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
minimum=5, maximum=95, value=25, step=1,
|
| 474 |
-
label="Source Age (current age)",
|
| 475 |
-
)
|
| 476 |
-
vid_tgt_age = gr.Slider(
|
| 477 |
-
minimum=5, maximum=95, value=65, step=1,
|
| 478 |
-
label="Target Age (desired age)",
|
| 479 |
-
)
|
| 480 |
-
vid_btn = gr.Button("Re-Age Video", variant="primary")
|
| 481 |
-
|
| 482 |
-
with gr.Column():
|
| 483 |
-
vid_output = gr.Video(label="Re-Aged Video")
|
| 484 |
-
vid_info = gr.Textbox(label="Info", interactive=False)
|
| 485 |
-
|
| 486 |
-
vid_btn.click(
|
| 487 |
-
fn=process_video,
|
| 488 |
-
inputs=[vid_input, vid_src_age, vid_tgt_age],
|
| 489 |
-
outputs=[vid_output, vid_info],
|
| 490 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 491 |
|
| 492 |
gr.Markdown(
|
| 493 |
"**Model:** `face_reaging.onnx` (118 MB) from "
|
|
|
|
| 2 |
Face Re-Aging with ONNX (CPU)
|
| 3 |
Based on Disney's FRAN (Face Re-Aging Network) architecture.
|
| 4 |
Model: face_reaging.onnx from VisoMaster-Fusion.
|
| 5 |
+
Supports image and video re-aging in a single unified view.
|
| 6 |
"""
|
| 7 |
|
| 8 |
import os
|
|
|
|
| 47 |
print("Model loaded.")
|
| 48 |
|
| 49 |
# ---------------------------------------------------------------------------
|
| 50 |
+
# Face detection
|
| 51 |
# ---------------------------------------------------------------------------
|
| 52 |
_face_cascade = cv2.CascadeClassifier(
|
| 53 |
cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
|
| 54 |
)
|
|
|
|
| 55 |
_dnn_model_path = os.path.join(os.path.dirname(__file__), "face_detection_yunet_2023mar.onnx")
|
| 56 |
YUNET_URL = "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx"
|
| 57 |
|
| 58 |
|
| 59 |
def _ensure_yunet():
|
|
|
|
| 60 |
global _dnn_model_path
|
| 61 |
if not os.path.exists(_dnn_model_path):
|
| 62 |
print("Downloading YuNet face detector...")
|
|
|
|
| 74 |
|
| 75 |
|
| 76 |
def detect_face_box(image_rgb: np.ndarray):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
h, w = image_rgb.shape[:2]
|
|
|
|
|
|
|
| 78 |
try:
|
| 79 |
yunet_path = _ensure_yunet()
|
| 80 |
detector = cv2.FaceDetectorYN.create(yunet_path, "", (w, h), 0.5, 0.3, 5000)
|
| 81 |
_, faces = detector.detect(image_rgb)
|
| 82 |
if faces is not None and len(faces) > 0:
|
| 83 |
+
best_idx = int(np.argmax([f[2] * f[3] for f in faces]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
f = faces[best_idx]
|
| 85 |
x1, y1 = int(f[0]), int(f[1])
|
| 86 |
x2, y2 = int(f[0] + f[2]), int(f[1] + f[3])
|
|
|
|
| 88 |
except Exception as e:
|
| 89 |
print(f"YuNet failed, falling back to Haar: {e}")
|
| 90 |
|
|
|
|
| 91 |
gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
|
| 92 |
faces = _face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60))
|
| 93 |
if len(faces) == 0:
|
| 94 |
return None
|
|
|
|
| 95 |
best_idx = np.argmax([fw * fh for (_, _, fw, fh) in faces])
|
| 96 |
x, y, fw, fh = faces[best_idx]
|
| 97 |
return (x, y, x + fw, y + fh)
|
| 98 |
|
| 99 |
# ---------------------------------------------------------------------------
|
| 100 |
+
# Core inference
|
| 101 |
# ---------------------------------------------------------------------------
|
| 102 |
+
def crop_face_region(image_rgb, box):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
h, w = image_rgb.shape[:2]
|
| 104 |
x1, y1, x2, y2 = box
|
| 105 |
+
face_w, face_h = x2 - x1, y2 - y1
|
|
|
|
|
|
|
|
|
|
| 106 |
margin_top = int(face_h * 0.63 * 0.85)
|
| 107 |
margin_bot = int(face_h * 0.37 * 0.85)
|
| 108 |
margin_x = int(face_w * 0.85 / 2)
|
| 109 |
margin_top += 2 * margin_x - margin_top - margin_bot
|
| 110 |
+
l_y, r_y = max(y1 - margin_top, 0), min(y2 + margin_bot, h)
|
| 111 |
+
l_x, r_x = max(x1 - margin_x, 0), min(x2 + margin_x, w)
|
| 112 |
+
return image_rgb[l_y:r_y, l_x:r_x, :], (l_x, l_y, r_x, r_y)
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
|
|
|
|
|
|
|
|
|
| 115 |
def create_blend_mask(crop_h, crop_w, feather=0.15):
|
|
|
|
| 116 |
mask = np.ones((crop_h, crop_w), dtype=np.float32)
|
| 117 |
+
by, bx = max(int(crop_h * feather), 1), max(int(crop_w * feather), 1)
|
| 118 |
+
for i in range(by):
|
| 119 |
+
a = i / by
|
| 120 |
+
mask[i, :] *= a
|
| 121 |
+
mask[crop_h - 1 - i, :] *= a
|
| 122 |
+
for j in range(bx):
|
| 123 |
+
a = j / bx
|
| 124 |
+
mask[:, j] *= a
|
| 125 |
+
mask[:, crop_w - 1 - j] *= a
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
return mask[:, :, np.newaxis]
|
| 127 |
|
| 128 |
+
|
| 129 |
+
def reage_frame(image_rgb, source_age, target_age):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
box = detect_face_box(image_rgb)
|
| 131 |
if box is None:
|
| 132 |
+
return image_rgb
|
| 133 |
|
| 134 |
cropped, (l_x, l_y, r_x, r_y) = crop_face_region(image_rgb, box)
|
| 135 |
crop_h, crop_w = cropped.shape[:2]
|
|
|
|
| 136 |
cropped_resized = cv2.resize(cropped, (512, 512), interpolation=cv2.INTER_LINEAR)
|
| 137 |
|
| 138 |
+
img_t = cropped_resized.astype(np.float32) / 255.0
|
| 139 |
+
img_t = np.transpose(img_t, (2, 0, 1))
|
| 140 |
+
src_ch = np.full((1, 512, 512), source_age / 100.0, dtype=np.float32)
|
| 141 |
+
tgt_ch = np.full((1, 512, 512), target_age / 100.0, dtype=np.float32)
|
| 142 |
+
inp = np.concatenate([img_t, src_ch, tgt_ch], axis=0)[np.newaxis, ...]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
+
delta = sess.run(None, {"input": inp})[0]
|
| 145 |
+
aged = np.clip(img_t + delta[0], 0.0, 1.0)
|
| 146 |
+
aged_hwc = (np.transpose(aged, (1, 2, 0)) * 255).astype(np.uint8)
|
| 147 |
aged_resized = cv2.resize(aged_hwc, (crop_w, crop_h), interpolation=cv2.INTER_LINEAR)
|
| 148 |
|
| 149 |
result = image_rgb.copy()
|
| 150 |
+
mask = create_blend_mask(crop_h, crop_w, feather=0.12)
|
| 151 |
region = result[l_y:r_y, l_x:r_x].astype(np.float32)
|
| 152 |
+
blended = region * (1 - mask) + aged_resized.astype(np.float32) * mask
|
|
|
|
| 153 |
result[l_y:r_y, l_x:r_x] = blended.astype(np.uint8)
|
|
|
|
| 154 |
return result
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
# ---------------------------------------------------------------------------
|
| 157 |
# ffmpeg helpers
|
| 158 |
# ---------------------------------------------------------------------------
|
| 159 |
def _find_ffmpeg():
|
|
|
|
| 160 |
path = shutil.which("ffmpeg")
|
| 161 |
if path:
|
| 162 |
return path
|
|
|
|
| 163 |
for p in ["/usr/bin/ffmpeg", "/usr/local/bin/ffmpeg"]:
|
| 164 |
if os.path.isfile(p):
|
| 165 |
return p
|
| 166 |
+
raise gr.Error("ffmpeg not found.")
|
| 167 |
|
| 168 |
|
| 169 |
+
def _get_video_info(video_path):
|
|
|
|
| 170 |
ffprobe = shutil.which("ffprobe") or shutil.which("ffprobe", path="/usr/bin:/usr/local/bin")
|
| 171 |
if not ffprobe:
|
|
|
|
| 172 |
cap = cv2.VideoCapture(video_path)
|
| 173 |
fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
|
| 174 |
count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 175 |
cap.release()
|
| 176 |
return fps, count
|
|
|
|
| 177 |
try:
|
| 178 |
+
import json
|
| 179 |
r = subprocess.run(
|
| 180 |
[ffprobe, "-v", "quiet", "-print_format", "json",
|
| 181 |
"-show_streams", "-select_streams", "v:0", video_path],
|
| 182 |
capture_output=True, text=True, timeout=30,
|
| 183 |
)
|
| 184 |
+
stream = json.loads(r.stdout)["streams"][0]
|
| 185 |
+
num, den = stream.get("r_frame_rate", "25/1").split("/")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
fps = float(num) / float(den)
|
|
|
|
| 187 |
nb = stream.get("nb_frames")
|
| 188 |
+
count = int(nb) if nb and nb != "N/A" else int(float(stream.get("duration", 0)) * fps)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
return fps, count
|
| 190 |
except Exception:
|
| 191 |
cap = cv2.VideoCapture(video_path)
|
|
|
|
| 195 |
return fps, count
|
| 196 |
|
| 197 |
|
| 198 |
+
def _extract_frames(video_path, out_dir):
|
|
|
|
| 199 |
ffmpeg = _find_ffmpeg()
|
| 200 |
+
cmd = [ffmpeg, "-i", video_path, "-vsync", "0", os.path.join(out_dir, "frame_%06d.png"), "-y"]
|
|
|
|
| 201 |
r = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
|
| 202 |
if r.returncode != 0:
|
| 203 |
+
raise gr.Error(f"Frame extraction failed: {r.stderr[-500:]}")
|
| 204 |
|
| 205 |
|
| 206 |
+
def _assemble_video(frames_dir, output_path, fps, audio_source=None):
|
|
|
|
| 207 |
ffmpeg = _find_ffmpeg()
|
| 208 |
+
cmd = [ffmpeg, "-y", "-framerate", str(fps), "-i", os.path.join(frames_dir, "frame_%06d.png")]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
if audio_source:
|
| 210 |
cmd += ["-i", audio_source, "-map", "0:v", "-map", "1:a?", "-shortest"]
|
| 211 |
+
cmd += ["-c:v", "libx264", "-pix_fmt", "yuv420p", "-preset", "fast", "-crf", "20",
|
| 212 |
+
"-movflags", "+faststart", output_path]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
r = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
|
| 214 |
if r.returncode != 0:
|
| 215 |
+
raise gr.Error(f"Video assembly failed: {r.stderr[-500:]}")
|
| 216 |
|
| 217 |
# ---------------------------------------------------------------------------
|
| 218 |
+
# Unified process function
|
| 219 |
# ---------------------------------------------------------------------------
|
| 220 |
+
VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv", ".wmv", ".m4v"}
|
|
|
|
|
|
|
|
|
|
| 221 |
|
|
|
|
| 222 |
|
| 223 |
+
def process(input_file, source_age, target_age, progress=gr.Progress()):
|
| 224 |
+
if input_file is None:
|
| 225 |
+
raise gr.Error("Please upload an image or video.")
|
| 226 |
|
| 227 |
+
t0 = time.time()
|
| 228 |
+
source_age, target_age = int(source_age), int(target_age)
|
| 229 |
+
|
| 230 |
+
# Determine if image or video
|
| 231 |
+
if isinstance(input_file, Image.Image):
|
| 232 |
+
# Direct PIL image from gr.Image
|
| 233 |
+
image_rgb = np.array(input_file.convert("RGB"))
|
| 234 |
+
box = detect_face_box(image_rgb)
|
| 235 |
+
if box is None:
|
| 236 |
+
raise gr.Error("No face detected. Please upload a clear photo with a visible face.")
|
| 237 |
+
result = reage_frame(image_rgb, source_age, target_age)
|
| 238 |
+
elapsed = time.time() - t0
|
| 239 |
+
info = f"Done in {elapsed:.2f}s | {source_age} -> {target_age} years"
|
| 240 |
+
return Image.fromarray(result), None, info
|
| 241 |
|
| 242 |
+
# File path (could be image or video)
|
| 243 |
+
file_path = input_file if isinstance(input_file, str) else str(input_file)
|
| 244 |
+
ext = os.path.splitext(file_path)[1].lower()
|
|
|
|
|
|
|
| 245 |
|
| 246 |
+
if ext in VIDEO_EXTS:
|
| 247 |
+
# --- Video processing ---
|
| 248 |
+
fps, total_frames = _get_video_info(file_path)
|
| 249 |
+
duration = total_frames / max(fps, 1)
|
|
|
|
|
|
|
| 250 |
|
| 251 |
+
if duration > MAX_VIDEO_SECONDS:
|
| 252 |
+
raise gr.Error(f"Video is {duration:.1f}s (max {MAX_VIDEO_SECONDS}s). Please trim it.")
|
| 253 |
+
if total_frames > MAX_FRAMES:
|
| 254 |
+
raise gr.Error(f"Video has {total_frames} frames (max {MAX_FRAMES}).")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
|
| 256 |
+
tmp_root = tempfile.mkdtemp(prefix="reage_")
|
| 257 |
+
frames_in = os.path.join(tmp_root, "in")
|
| 258 |
+
frames_out = os.path.join(tmp_root, "out")
|
| 259 |
+
os.makedirs(frames_in, exist_ok=True)
|
| 260 |
+
os.makedirs(frames_out, exist_ok=True)
|
| 261 |
+
|
| 262 |
+
try:
|
| 263 |
+
progress(0, desc="Extracting frames...")
|
| 264 |
+
_extract_frames(file_path, frames_in)
|
| 265 |
+
|
| 266 |
+
frame_files = sorted(glob_mod.glob(os.path.join(frames_in, "frame_*.png")))
|
| 267 |
+
n_frames = len(frame_files)
|
| 268 |
+
if n_frames == 0:
|
| 269 |
+
raise gr.Error("No frames extracted. Is this a valid video?")
|
| 270 |
+
if n_frames > MAX_FRAMES:
|
| 271 |
+
raise gr.Error(f"{n_frames} frames (max {MAX_FRAMES}).")
|
| 272 |
+
|
| 273 |
+
faces_found, faces_missed = 0, 0
|
| 274 |
+
for idx, fpath in enumerate(frame_files):
|
| 275 |
+
progress((idx + 1) / n_frames, desc=f"Re-aging frame {idx + 1}/{n_frames}...")
|
| 276 |
+
frame_bgr = cv2.imread(fpath)
|
| 277 |
+
if frame_bgr is None:
|
| 278 |
+
continue
|
| 279 |
+
frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
|
| 280 |
+
box = detect_face_box(frame_rgb)
|
| 281 |
+
if box is not None:
|
| 282 |
+
result_rgb = reage_frame(frame_rgb, source_age, target_age)
|
| 283 |
+
faces_found += 1
|
| 284 |
+
else:
|
| 285 |
+
result_rgb = frame_rgb
|
| 286 |
+
faces_missed += 1
|
| 287 |
+
out_path = os.path.join(frames_out, os.path.basename(fpath))
|
| 288 |
+
cv2.imwrite(out_path, cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR))
|
| 289 |
+
|
| 290 |
+
progress(1.0, desc="Assembling video...")
|
| 291 |
+
output_path = os.path.join(tmp_root, "output.mp4")
|
| 292 |
+
_assemble_video(frames_out, output_path, fps, audio_source=file_path)
|
| 293 |
+
|
| 294 |
+
elapsed = time.time() - t0
|
| 295 |
+
speed = n_frames / max(elapsed, 0.01)
|
| 296 |
+
info = (f"Done in {elapsed:.1f}s | {n_frames} frames at {speed:.1f} fps | "
|
| 297 |
+
f"Faces: {faces_found} found, {faces_missed} skipped | "
|
| 298 |
+
f"{source_age} -> {target_age} years")
|
| 299 |
+
return None, output_path, info
|
| 300 |
+
|
| 301 |
+
except gr.Error:
|
| 302 |
+
raise
|
| 303 |
+
except Exception as e:
|
| 304 |
+
raise gr.Error(f"Video processing failed: {e}")
|
| 305 |
+
else:
|
| 306 |
+
# --- Image processing ---
|
| 307 |
+
image_rgb = cv2.imread(file_path)
|
| 308 |
+
if image_rgb is None:
|
| 309 |
+
raise gr.Error("Could not read the file. Please upload a valid image or video.")
|
| 310 |
+
image_rgb = cv2.cvtColor(image_rgb, cv2.COLOR_BGR2RGB)
|
| 311 |
+
box = detect_face_box(image_rgb)
|
| 312 |
+
if box is None:
|
| 313 |
+
raise gr.Error("No face detected.")
|
| 314 |
+
result = reage_frame(image_rgb, source_age, target_age)
|
| 315 |
elapsed = time.time() - t0
|
| 316 |
+
info = f"Done in {elapsed:.2f}s | {source_age} -> {target_age} years"
|
| 317 |
+
return Image.fromarray(result), None, info
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
|
| 320 |
# ---------------------------------------------------------------------------
|
| 321 |
+
# Gradio UI - Single unified view
|
| 322 |
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
with gr.Blocks(title="Face Re-Aging (CPU)") as demo:
|
| 324 |
gr.Markdown(
|
| 325 |
"# Face Re-Aging (CPU)\n"
|
| 326 |
+
"Upload an **image or video** to age or de-age faces. "
|
| 327 |
+
f"Videos: max {MAX_VIDEO_SECONDS}s, ~0.5-2 fps on CPU."
|
| 328 |
)
|
| 329 |
|
| 330 |
+
with gr.Row():
|
| 331 |
+
with gr.Column():
|
| 332 |
+
file_input = gr.File(
|
| 333 |
+
label="Drop Image or Video Here",
|
| 334 |
+
file_types=["image", "video"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
)
|
| 336 |
+
# Also accept pasted/webcam images
|
| 337 |
+
img_input = gr.Image(
|
| 338 |
+
type="pil", label="Or paste/capture an image",
|
| 339 |
+
visible=True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
)
|
| 341 |
+
src_age = gr.Slider(minimum=5, maximum=95, value=25, step=1,
|
| 342 |
+
label="Source Age (current)")
|
| 343 |
+
tgt_age = gr.Slider(minimum=5, maximum=95, value=65, step=1,
|
| 344 |
+
label="Target Age (desired)")
|
| 345 |
+
btn = gr.Button("Re-Age", variant="primary", size="lg")
|
| 346 |
+
|
| 347 |
+
with gr.Column():
|
| 348 |
+
img_output = gr.Image(type="pil", label="Result (Image)")
|
| 349 |
+
vid_output = gr.Video(label="Result (Video)")
|
| 350 |
+
info_box = gr.Textbox(label="Info", interactive=False)
|
| 351 |
+
|
| 352 |
+
def on_submit_file(file_obj, source_age, target_age, progress=gr.Progress()):
|
| 353 |
+
if file_obj is None:
|
| 354 |
+
raise gr.Error("Please upload a file.")
|
| 355 |
+
return process(file_obj, source_age, target_age, progress)
|
| 356 |
+
|
| 357 |
+
def on_submit_image(image, source_age, target_age, progress=gr.Progress()):
|
| 358 |
+
if image is None:
|
| 359 |
+
raise gr.Error("Please provide an image.")
|
| 360 |
+
return process(image, source_age, target_age, progress)
|
| 361 |
+
|
| 362 |
+
btn.click(
|
| 363 |
+
fn=on_submit_file,
|
| 364 |
+
inputs=[file_input, src_age, tgt_age],
|
| 365 |
+
outputs=[img_output, vid_output, info_box],
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
# Also trigger on image input (for paste/webcam)
|
| 369 |
+
img_input.change(
|
| 370 |
+
fn=on_submit_image,
|
| 371 |
+
inputs=[img_input, src_age, tgt_age],
|
| 372 |
+
outputs=[img_output, vid_output, info_box],
|
| 373 |
+
)
|
| 374 |
|
| 375 |
gr.Markdown(
|
| 376 |
"**Model:** `face_reaging.onnx` (118 MB) from "
|