Nekochu committed on
Commit
f2a1251
·
1 Parent(s): 844e775

unified view: single input for image or video

Browse files
Files changed (1) hide show
  1. app.py +181 -298
app.py CHANGED
@@ -2,7 +2,7 @@
2
  Face Re-Aging with ONNX (CPU)
3
  Based on Disney's FRAN (Face Re-Aging Network) architecture.
4
  Model: face_reaging.onnx from VisoMaster-Fusion.
5
- Supports single image and video re-aging.
6
  """
7
 
8
  import os
@@ -47,18 +47,16 @@ sess = ort.InferenceSession(
47
  print("Model loaded.")
48
 
49
  # ---------------------------------------------------------------------------
50
- # OpenCV DNN face detection (no extra dependencies)
51
  # ---------------------------------------------------------------------------
52
  _face_cascade = cv2.CascadeClassifier(
53
  cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
54
  )
55
-
56
  _dnn_model_path = os.path.join(os.path.dirname(__file__), "face_detection_yunet_2023mar.onnx")
57
  YUNET_URL = "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx"
58
 
59
 
60
  def _ensure_yunet():
61
- """Download YuNet face detector if not present."""
62
  global _dnn_model_path
63
  if not os.path.exists(_dnn_model_path):
64
  print("Downloading YuNet face detector...")
@@ -76,26 +74,13 @@ def _ensure_yunet():
76
 
77
 
78
  def detect_face_box(image_rgb: np.ndarray):
79
- """
80
- Detect the largest face bounding box.
81
- Returns (x1, y1, x2, y2) in pixel coords or None.
82
- """
83
  h, w = image_rgb.shape[:2]
84
-
85
- # Try YuNet first (more accurate)
86
  try:
87
  yunet_path = _ensure_yunet()
88
  detector = cv2.FaceDetectorYN.create(yunet_path, "", (w, h), 0.5, 0.3, 5000)
89
  _, faces = detector.detect(image_rgb)
90
  if faces is not None and len(faces) > 0:
91
- best_idx = 0
92
- best_area = 0
93
- for i, face in enumerate(faces):
94
- fw, fh = face[2], face[3]
95
- area = fw * fh
96
- if area > best_area:
97
- best_area = area
98
- best_idx = i
99
  f = faces[best_idx]
100
  x1, y1 = int(f[0]), int(f[1])
101
  x2, y2 = int(f[0] + f[2]), int(f[1] + f[3])
@@ -103,172 +88,104 @@ def detect_face_box(image_rgb: np.ndarray):
103
  except Exception as e:
104
  print(f"YuNet failed, falling back to Haar: {e}")
105
 
106
- # Fallback: Haar cascade
107
  gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
108
  faces = _face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60))
109
  if len(faces) == 0:
110
  return None
111
-
112
  best_idx = np.argmax([fw * fh for (_, _, fw, fh) in faces])
113
  x, y, fw, fh = faces[best_idx]
114
  return (x, y, x + fw, y + fh)
115
 
116
  # ---------------------------------------------------------------------------
117
- # Face cropping with margin
118
  # ---------------------------------------------------------------------------
119
- def crop_face_region(image_rgb: np.ndarray, box):
120
- """
121
- Crop a square region around the detected face with generous margins.
122
- Returns: cropped image, (l_x, l_y, r_x, r_y) paste-back coords.
123
- """
124
  h, w = image_rgb.shape[:2]
125
  x1, y1, x2, y2 = box
126
-
127
- face_w = x2 - x1
128
- face_h = y2 - y1
129
-
130
  margin_top = int(face_h * 0.63 * 0.85)
131
  margin_bot = int(face_h * 0.37 * 0.85)
132
  margin_x = int(face_w * 0.85 / 2)
133
  margin_top += 2 * margin_x - margin_top - margin_bot
 
 
 
134
 
135
- l_y = max(y1 - margin_top, 0)
136
- r_y = min(y2 + margin_bot, h)
137
- l_x = max(x1 - margin_x, 0)
138
- r_x = min(x2 + margin_x, w)
139
-
140
- cropped = image_rgb[l_y:r_y, l_x:r_x, :]
141
- return cropped, (l_x, l_y, r_x, r_y)
142
 
143
- # ---------------------------------------------------------------------------
144
- # Blending mask (soft feathered edges)
145
- # ---------------------------------------------------------------------------
146
  def create_blend_mask(crop_h, crop_w, feather=0.15):
147
- """Create a soft feathered blending mask."""
148
  mask = np.ones((crop_h, crop_w), dtype=np.float32)
149
- border_y = max(int(crop_h * feather), 1)
150
- border_x = max(int(crop_w * feather), 1)
151
-
152
- for i in range(border_y):
153
- alpha = i / border_y
154
- mask[i, :] *= alpha
155
- mask[crop_h - 1 - i, :] *= alpha
156
-
157
- for j in range(border_x):
158
- alpha = j / border_x
159
- mask[:, j] *= alpha
160
- mask[:, crop_w - 1 - j] *= alpha
161
-
162
  return mask[:, :, np.newaxis]
163
 
164
- # ---------------------------------------------------------------------------
165
- # Core inference on a single frame (numpy RGB in, numpy RGB out)
166
- # ---------------------------------------------------------------------------
167
- def reage_frame(image_rgb: np.ndarray, source_age: int, target_age: int) -> np.ndarray:
168
- """
169
- Re-age the face in a numpy RGB image.
170
- Returns the re-aged image (same size), or original if no face found.
171
- """
172
  box = detect_face_box(image_rgb)
173
  if box is None:
174
- return image_rgb # no face, return unchanged
175
 
176
  cropped, (l_x, l_y, r_x, r_y) = crop_face_region(image_rgb, box)
177
  crop_h, crop_w = cropped.shape[:2]
178
-
179
  cropped_resized = cv2.resize(cropped, (512, 512), interpolation=cv2.INTER_LINEAR)
180
 
181
- img_tensor = cropped_resized.astype(np.float32) / 255.0
182
- img_tensor = np.transpose(img_tensor, (2, 0, 1))
183
-
184
- src_age_ch = np.full((1, 512, 512), source_age / 100.0, dtype=np.float32)
185
- tgt_age_ch = np.full((1, 512, 512), target_age / 100.0, dtype=np.float32)
186
-
187
- input_tensor = np.concatenate([img_tensor, src_age_ch, tgt_age_ch], axis=0)
188
- input_tensor = input_tensor[np.newaxis, ...]
189
-
190
- delta = sess.run(None, {"input": input_tensor})[0]
191
-
192
- aged = img_tensor + delta[0]
193
- aged = np.clip(aged, 0.0, 1.0)
194
-
195
- aged_hwc = np.transpose(aged, (1, 2, 0))
196
- aged_hwc = (aged_hwc * 255).astype(np.uint8)
197
 
 
 
 
198
  aged_resized = cv2.resize(aged_hwc, (crop_w, crop_h), interpolation=cv2.INTER_LINEAR)
199
 
200
  result = image_rgb.copy()
201
- blend_mask = create_blend_mask(crop_h, crop_w, feather=0.12)
202
  region = result[l_y:r_y, l_x:r_x].astype(np.float32)
203
- aged_f = aged_resized.astype(np.float32)
204
- blended = region * (1 - blend_mask) + aged_f * blend_mask
205
  result[l_y:r_y, l_x:r_x] = blended.astype(np.uint8)
206
-
207
  return result
208
 
209
- # ---------------------------------------------------------------------------
210
- # Image re-aging (wraps reage_frame for Gradio)
211
- # ---------------------------------------------------------------------------
212
- def reage_face(image_pil: Image.Image, source_age: int, target_age: int):
213
- """Re-age the face in the given PIL image."""
214
- t0 = time.time()
215
- image_rgb = np.array(image_pil.convert("RGB"))
216
-
217
- box = detect_face_box(image_rgb)
218
- if box is None:
219
- raise gr.Error("No face detected in the image. Please upload a clear photo with a visible face.")
220
-
221
- result = reage_frame(image_rgb, source_age, target_age)
222
- elapsed = time.time() - t0
223
- info = f"Done in {elapsed:.2f}s | Source age: {source_age} | Target age: {target_age}"
224
- return Image.fromarray(result), info
225
-
226
  # ---------------------------------------------------------------------------
227
  # ffmpeg helpers
228
  # ---------------------------------------------------------------------------
229
  def _find_ffmpeg():
230
- """Return ffmpeg path."""
231
  path = shutil.which("ffmpeg")
232
  if path:
233
  return path
234
- # HF Spaces usually have it
235
  for p in ["/usr/bin/ffmpeg", "/usr/local/bin/ffmpeg"]:
236
  if os.path.isfile(p):
237
  return p
238
- raise gr.Error("ffmpeg not found. Video processing requires ffmpeg.")
239
 
240
 
241
- def _get_video_info(video_path: str):
242
- """Get fps and frame count using ffprobe."""
243
  ffprobe = shutil.which("ffprobe") or shutil.which("ffprobe", path="/usr/bin:/usr/local/bin")
244
  if not ffprobe:
245
- # Fallback: use OpenCV just to read metadata
246
  cap = cv2.VideoCapture(video_path)
247
  fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
248
  count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
249
  cap.release()
250
  return fps, count
251
-
252
  try:
 
253
  r = subprocess.run(
254
  [ffprobe, "-v", "quiet", "-print_format", "json",
255
  "-show_streams", "-select_streams", "v:0", video_path],
256
  capture_output=True, text=True, timeout=30,
257
  )
258
- import json
259
- info = json.loads(r.stdout)
260
- stream = info["streams"][0]
261
- # fps
262
- fps_str = stream.get("r_frame_rate", "25/1")
263
- num, den = fps_str.split("/")
264
  fps = float(num) / float(den)
265
- # frame count
266
  nb = stream.get("nb_frames")
267
- if nb and nb != "N/A":
268
- count = int(nb)
269
- else:
270
- dur = float(stream.get("duration", 0))
271
- count = int(dur * fps)
272
  return fps, count
273
  except Exception:
274
  cap = cv2.VideoCapture(video_path)
@@ -278,216 +195,182 @@ def _get_video_info(video_path: str):
278
  return fps, count
279
 
280
 
281
- def _extract_frames(video_path: str, out_dir: str):
282
- """Extract frames from video using ffmpeg."""
283
  ffmpeg = _find_ffmpeg()
284
- out_pattern = os.path.join(out_dir, "frame_%06d.png")
285
- cmd = [ffmpeg, "-i", video_path, "-vsync", "0", out_pattern, "-y"]
286
  r = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
287
  if r.returncode != 0:
288
- raise gr.Error(f"ffmpeg frame extraction failed: {r.stderr[-500:]}")
289
 
290
 
291
- def _assemble_video(frames_dir: str, output_path: str, fps: float, audio_source: str = None):
292
- """Reassemble frames into MP4 using ffmpeg."""
293
  ffmpeg = _find_ffmpeg()
294
- in_pattern = os.path.join(frames_dir, "frame_%06d.png")
295
-
296
- cmd = [
297
- ffmpeg, "-y",
298
- "-framerate", str(fps),
299
- "-i", in_pattern,
300
- ]
301
-
302
- # Try to copy audio from original
303
  if audio_source:
304
  cmd += ["-i", audio_source, "-map", "0:v", "-map", "1:a?", "-shortest"]
305
-
306
- cmd += [
307
- "-c:v", "libx264",
308
- "-pix_fmt", "yuv420p",
309
- "-preset", "fast",
310
- "-crf", "20",
311
- "-movflags", "+faststart",
312
- output_path,
313
- ]
314
-
315
  r = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
316
  if r.returncode != 0:
317
- raise gr.Error(f"ffmpeg assembly failed: {r.stderr[-500:]}")
318
 
319
  # ---------------------------------------------------------------------------
320
- # Video re-aging
321
  # ---------------------------------------------------------------------------
322
- def reage_video(video_path: str, source_age: int, target_age: int, progress=gr.Progress()):
323
- """Re-age faces in every frame of a video."""
324
- if video_path is None:
325
- raise gr.Error("Please upload a video.")
326
 
327
- t0 = time.time()
328
 
329
- # Get video info
330
- fps, total_frames = _get_video_info(video_path)
331
- duration = total_frames / max(fps, 1)
332
 
333
- if duration > MAX_VIDEO_SECONDS:
334
- raise gr.Error(
335
- f"Video is {duration:.1f}s long. Maximum allowed is {MAX_VIDEO_SECONDS}s. "
336
- f"Please trim your video first."
337
- )
 
 
 
 
 
 
 
 
 
338
 
339
- if total_frames > MAX_FRAMES:
340
- raise gr.Error(
341
- f"Video has {total_frames} frames. Maximum allowed is {MAX_FRAMES}. "
342
- f"Please use a shorter video."
343
- )
344
 
345
- # Create temp dirs
346
- tmp_root = tempfile.mkdtemp(prefix="reage_")
347
- frames_in = os.path.join(tmp_root, "in")
348
- frames_out = os.path.join(tmp_root, "out")
349
- os.makedirs(frames_in, exist_ok=True)
350
- os.makedirs(frames_out, exist_ok=True)
351
 
352
- try:
353
- # Extract frames
354
- progress(0, desc="Extracting frames...")
355
- _extract_frames(video_path, frames_in)
356
-
357
- # Get frame list
358
- frame_files = sorted(glob_mod.glob(os.path.join(frames_in, "frame_*.png")))
359
- n_frames = len(frame_files)
360
- if n_frames == 0:
361
- raise gr.Error("No frames extracted from video. Is the file a valid video?")
362
-
363
- # Re-check limit after extraction
364
- if n_frames > MAX_FRAMES:
365
- raise gr.Error(f"Video has {n_frames} frames (max {MAX_FRAMES}). Please use a shorter video.")
366
-
367
- faces_found = 0
368
- faces_missed = 0
369
-
370
- # Process each frame
371
- for idx, fpath in enumerate(frame_files):
372
- progress((idx + 1) / n_frames, desc=f"Re-aging frame {idx + 1}/{n_frames}...")
373
-
374
- # Read frame (BGR -> RGB)
375
- frame_bgr = cv2.imread(fpath)
376
- if frame_bgr is None:
377
- continue
378
- frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
379
-
380
- # Detect and re-age
381
- box = detect_face_box(frame_rgb)
382
- if box is not None:
383
- result_rgb = reage_frame(frame_rgb, source_age, target_age)
384
- faces_found += 1
385
- else:
386
- result_rgb = frame_rgb
387
- faces_missed += 1
388
-
389
- # Save (RGB -> BGR)
390
- fname = os.path.basename(fpath)
391
- out_path = os.path.join(frames_out, fname)
392
- result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
393
- cv2.imwrite(out_path, result_bgr)
394
-
395
- # Assemble video
396
- progress(1.0, desc="Assembling video...")
397
- output_path = os.path.join(tmp_root, "output.mp4")
398
- _assemble_video(frames_out, output_path, fps, audio_source=video_path)
399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  elapsed = time.time() - t0
401
- speed = n_frames / max(elapsed, 0.01)
402
- info = (
403
- f"Done in {elapsed:.1f}s | {n_frames} frames at {speed:.1f} fps | "
404
- f"Faces found: {faces_found}, skipped: {faces_missed} | "
405
- f"Source age: {source_age} -> Target age: {target_age}"
406
- )
407
 
408
- return output_path, info
409
-
410
- except gr.Error:
411
- raise
412
- except Exception as e:
413
- raise gr.Error(f"Video processing failed: {str(e)}")
414
 
415
  # ---------------------------------------------------------------------------
416
- # Gradio UI
417
  # ---------------------------------------------------------------------------
418
- def process_image(image, source_age, target_age):
419
- if image is None:
420
- raise gr.Error("Please upload an image.")
421
- return reage_face(image, int(source_age), int(target_age))
422
-
423
-
424
- def process_video(video, source_age, target_age, progress=gr.Progress()):
425
- if video is None:
426
- raise gr.Error("Please upload a video.")
427
- return reage_video(video, int(source_age), int(target_age), progress)
428
-
429
-
430
  with gr.Blocks(title="Face Re-Aging (CPU)") as demo:
431
  gr.Markdown(
432
  "# Face Re-Aging (CPU)\n"
433
- "Age or de-age faces using Disney FRAN-style model. "
434
- "Works on both **images** and **videos**."
435
  )
436
 
437
- with gr.Tabs():
438
- # ---- Image Tab ----
439
- with gr.TabItem("Image"):
440
- with gr.Row():
441
- with gr.Column():
442
- img_input = gr.Image(type="pil", label="Input Image")
443
- img_src_age = gr.Slider(
444
- minimum=5, maximum=95, value=25, step=1,
445
- label="Source Age (current age)",
446
- )
447
- img_tgt_age = gr.Slider(
448
- minimum=5, maximum=95, value=65, step=1,
449
- label="Target Age (desired age)",
450
- )
451
- img_btn = gr.Button("Re-Age Face", variant="primary")
452
-
453
- with gr.Column():
454
- img_output = gr.Image(type="pil", label="Re-Aged Result")
455
- img_info = gr.Textbox(label="Info", interactive=False)
456
-
457
- img_btn.click(
458
- fn=process_image,
459
- inputs=[img_input, img_src_age, img_tgt_age],
460
- outputs=[img_output, img_info],
461
- )
462
-
463
- # ---- Video Tab ----
464
- with gr.TabItem("Video"):
465
- gr.Markdown(
466
- f"Upload a video (max **{MAX_VIDEO_SECONDS}s** / **{MAX_FRAMES} frames**). "
467
- f"Each frame is processed individually on CPU, so expect ~0.5-2 fps."
468
  )
469
- with gr.Row():
470
- with gr.Column():
471
- vid_input = gr.Video(label="Input Video")
472
- vid_src_age = gr.Slider(
473
- minimum=5, maximum=95, value=25, step=1,
474
- label="Source Age (current age)",
475
- )
476
- vid_tgt_age = gr.Slider(
477
- minimum=5, maximum=95, value=65, step=1,
478
- label="Target Age (desired age)",
479
- )
480
- vid_btn = gr.Button("Re-Age Video", variant="primary")
481
-
482
- with gr.Column():
483
- vid_output = gr.Video(label="Re-Aged Video")
484
- vid_info = gr.Textbox(label="Info", interactive=False)
485
-
486
- vid_btn.click(
487
- fn=process_video,
488
- inputs=[vid_input, vid_src_age, vid_tgt_age],
489
- outputs=[vid_output, vid_info],
490
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
491
 
492
  gr.Markdown(
493
  "**Model:** `face_reaging.onnx` (118 MB) from "
 
2
  Face Re-Aging with ONNX (CPU)
3
  Based on Disney's FRAN (Face Re-Aging Network) architecture.
4
  Model: face_reaging.onnx from VisoMaster-Fusion.
5
+ Supports image and video re-aging in a single unified view.
6
  """
7
 
8
  import os
 
47
  print("Model loaded.")
48
 
49
  # ---------------------------------------------------------------------------
50
+ # Face detection
51
  # ---------------------------------------------------------------------------
52
  _face_cascade = cv2.CascadeClassifier(
53
  cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
54
  )
 
55
  _dnn_model_path = os.path.join(os.path.dirname(__file__), "face_detection_yunet_2023mar.onnx")
56
  YUNET_URL = "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx"
57
 
58
 
59
  def _ensure_yunet():
 
60
  global _dnn_model_path
61
  if not os.path.exists(_dnn_model_path):
62
  print("Downloading YuNet face detector...")
 
74
 
75
 
76
  def detect_face_box(image_rgb: np.ndarray):
 
 
 
 
77
  h, w = image_rgb.shape[:2]
 
 
78
  try:
79
  yunet_path = _ensure_yunet()
80
  detector = cv2.FaceDetectorYN.create(yunet_path, "", (w, h), 0.5, 0.3, 5000)
81
  _, faces = detector.detect(image_rgb)
82
  if faces is not None and len(faces) > 0:
83
+ best_idx = int(np.argmax([f[2] * f[3] for f in faces]))
 
 
 
 
 
 
 
84
  f = faces[best_idx]
85
  x1, y1 = int(f[0]), int(f[1])
86
  x2, y2 = int(f[0] + f[2]), int(f[1] + f[3])
 
88
  except Exception as e:
89
  print(f"YuNet failed, falling back to Haar: {e}")
90
 
 
91
  gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
92
  faces = _face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60))
93
  if len(faces) == 0:
94
  return None
 
95
  best_idx = np.argmax([fw * fh for (_, _, fw, fh) in faces])
96
  x, y, fw, fh = faces[best_idx]
97
  return (x, y, x + fw, y + fh)
98
 
99
  # ---------------------------------------------------------------------------
100
+ # Core inference
101
  # ---------------------------------------------------------------------------
102
+ def crop_face_region(image_rgb, box):
 
 
 
 
103
  h, w = image_rgb.shape[:2]
104
  x1, y1, x2, y2 = box
105
+ face_w, face_h = x2 - x1, y2 - y1
 
 
 
106
  margin_top = int(face_h * 0.63 * 0.85)
107
  margin_bot = int(face_h * 0.37 * 0.85)
108
  margin_x = int(face_w * 0.85 / 2)
109
  margin_top += 2 * margin_x - margin_top - margin_bot
110
+ l_y, r_y = max(y1 - margin_top, 0), min(y2 + margin_bot, h)
111
+ l_x, r_x = max(x1 - margin_x, 0), min(x2 + margin_x, w)
112
+ return image_rgb[l_y:r_y, l_x:r_x, :], (l_x, l_y, r_x, r_y)
113
 
 
 
 
 
 
 
 
114
 
 
 
 
115
  def create_blend_mask(crop_h, crop_w, feather=0.15):
 
116
  mask = np.ones((crop_h, crop_w), dtype=np.float32)
117
+ by, bx = max(int(crop_h * feather), 1), max(int(crop_w * feather), 1)
118
+ for i in range(by):
119
+ a = i / by
120
+ mask[i, :] *= a
121
+ mask[crop_h - 1 - i, :] *= a
122
+ for j in range(bx):
123
+ a = j / bx
124
+ mask[:, j] *= a
125
+ mask[:, crop_w - 1 - j] *= a
 
 
 
 
126
  return mask[:, :, np.newaxis]
127
 
128
+
129
+ def reage_frame(image_rgb, source_age, target_age):
 
 
 
 
 
 
130
  box = detect_face_box(image_rgb)
131
  if box is None:
132
+ return image_rgb
133
 
134
  cropped, (l_x, l_y, r_x, r_y) = crop_face_region(image_rgb, box)
135
  crop_h, crop_w = cropped.shape[:2]
 
136
  cropped_resized = cv2.resize(cropped, (512, 512), interpolation=cv2.INTER_LINEAR)
137
 
138
+ img_t = cropped_resized.astype(np.float32) / 255.0
139
+ img_t = np.transpose(img_t, (2, 0, 1))
140
+ src_ch = np.full((1, 512, 512), source_age / 100.0, dtype=np.float32)
141
+ tgt_ch = np.full((1, 512, 512), target_age / 100.0, dtype=np.float32)
142
+ inp = np.concatenate([img_t, src_ch, tgt_ch], axis=0)[np.newaxis, ...]
 
 
 
 
 
 
 
 
 
 
 
143
 
144
+ delta = sess.run(None, {"input": inp})[0]
145
+ aged = np.clip(img_t + delta[0], 0.0, 1.0)
146
+ aged_hwc = (np.transpose(aged, (1, 2, 0)) * 255).astype(np.uint8)
147
  aged_resized = cv2.resize(aged_hwc, (crop_w, crop_h), interpolation=cv2.INTER_LINEAR)
148
 
149
  result = image_rgb.copy()
150
+ mask = create_blend_mask(crop_h, crop_w, feather=0.12)
151
  region = result[l_y:r_y, l_x:r_x].astype(np.float32)
152
+ blended = region * (1 - mask) + aged_resized.astype(np.float32) * mask
 
153
  result[l_y:r_y, l_x:r_x] = blended.astype(np.uint8)
 
154
  return result
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  # ---------------------------------------------------------------------------
157
  # ffmpeg helpers
158
  # ---------------------------------------------------------------------------
159
  def _find_ffmpeg():
 
160
  path = shutil.which("ffmpeg")
161
  if path:
162
  return path
 
163
  for p in ["/usr/bin/ffmpeg", "/usr/local/bin/ffmpeg"]:
164
  if os.path.isfile(p):
165
  return p
166
+ raise gr.Error("ffmpeg not found.")
167
 
168
 
169
+ def _get_video_info(video_path):
 
170
  ffprobe = shutil.which("ffprobe") or shutil.which("ffprobe", path="/usr/bin:/usr/local/bin")
171
  if not ffprobe:
 
172
  cap = cv2.VideoCapture(video_path)
173
  fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
174
  count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
175
  cap.release()
176
  return fps, count
 
177
  try:
178
+ import json
179
  r = subprocess.run(
180
  [ffprobe, "-v", "quiet", "-print_format", "json",
181
  "-show_streams", "-select_streams", "v:0", video_path],
182
  capture_output=True, text=True, timeout=30,
183
  )
184
+ stream = json.loads(r.stdout)["streams"][0]
185
+ num, den = stream.get("r_frame_rate", "25/1").split("/")
 
 
 
 
186
  fps = float(num) / float(den)
 
187
  nb = stream.get("nb_frames")
188
+ count = int(nb) if nb and nb != "N/A" else int(float(stream.get("duration", 0)) * fps)
 
 
 
 
189
  return fps, count
190
  except Exception:
191
  cap = cv2.VideoCapture(video_path)
 
195
  return fps, count
196
 
197
 
198
+ def _extract_frames(video_path, out_dir):
 
199
  ffmpeg = _find_ffmpeg()
200
+ cmd = [ffmpeg, "-i", video_path, "-vsync", "0", os.path.join(out_dir, "frame_%06d.png"), "-y"]
 
201
  r = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
202
  if r.returncode != 0:
203
+ raise gr.Error(f"Frame extraction failed: {r.stderr[-500:]}")
204
 
205
 
206
+ def _assemble_video(frames_dir, output_path, fps, audio_source=None):
 
207
  ffmpeg = _find_ffmpeg()
208
+ cmd = [ffmpeg, "-y", "-framerate", str(fps), "-i", os.path.join(frames_dir, "frame_%06d.png")]
 
 
 
 
 
 
 
 
209
  if audio_source:
210
  cmd += ["-i", audio_source, "-map", "0:v", "-map", "1:a?", "-shortest"]
211
+ cmd += ["-c:v", "libx264", "-pix_fmt", "yuv420p", "-preset", "fast", "-crf", "20",
212
+ "-movflags", "+faststart", output_path]
 
 
 
 
 
 
 
 
213
  r = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
214
  if r.returncode != 0:
215
+ raise gr.Error(f"Video assembly failed: {r.stderr[-500:]}")
216
 
217
  # ---------------------------------------------------------------------------
218
+ # Unified process function
219
  # ---------------------------------------------------------------------------
220
+ VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv", ".wmv", ".m4v"}
 
 
 
221
 
 
222
 
223
+ def process(input_file, source_age, target_age, progress=gr.Progress()):
224
+ if input_file is None:
225
+ raise gr.Error("Please upload an image or video.")
226
 
227
+ t0 = time.time()
228
+ source_age, target_age = int(source_age), int(target_age)
229
+
230
+ # Determine if image or video
231
+ if isinstance(input_file, Image.Image):
232
+ # Direct PIL image from gr.Image
233
+ image_rgb = np.array(input_file.convert("RGB"))
234
+ box = detect_face_box(image_rgb)
235
+ if box is None:
236
+ raise gr.Error("No face detected. Please upload a clear photo with a visible face.")
237
+ result = reage_frame(image_rgb, source_age, target_age)
238
+ elapsed = time.time() - t0
239
+ info = f"Done in {elapsed:.2f}s | {source_age} -> {target_age} years"
240
+ return Image.fromarray(result), None, info
241
 
242
+ # File path (could be image or video)
243
+ file_path = input_file if isinstance(input_file, str) else str(input_file)
244
+ ext = os.path.splitext(file_path)[1].lower()
 
 
245
 
246
+ if ext in VIDEO_EXTS:
247
+ # --- Video processing ---
248
+ fps, total_frames = _get_video_info(file_path)
249
+ duration = total_frames / max(fps, 1)
 
 
250
 
251
+ if duration > MAX_VIDEO_SECONDS:
252
+ raise gr.Error(f"Video is {duration:.1f}s (max {MAX_VIDEO_SECONDS}s). Please trim it.")
253
+ if total_frames > MAX_FRAMES:
254
+ raise gr.Error(f"Video has {total_frames} frames (max {MAX_FRAMES}).")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
+ tmp_root = tempfile.mkdtemp(prefix="reage_")
257
+ frames_in = os.path.join(tmp_root, "in")
258
+ frames_out = os.path.join(tmp_root, "out")
259
+ os.makedirs(frames_in, exist_ok=True)
260
+ os.makedirs(frames_out, exist_ok=True)
261
+
262
+ try:
263
+ progress(0, desc="Extracting frames...")
264
+ _extract_frames(file_path, frames_in)
265
+
266
+ frame_files = sorted(glob_mod.glob(os.path.join(frames_in, "frame_*.png")))
267
+ n_frames = len(frame_files)
268
+ if n_frames == 0:
269
+ raise gr.Error("No frames extracted. Is this a valid video?")
270
+ if n_frames > MAX_FRAMES:
271
+ raise gr.Error(f"{n_frames} frames (max {MAX_FRAMES}).")
272
+
273
+ faces_found, faces_missed = 0, 0
274
+ for idx, fpath in enumerate(frame_files):
275
+ progress((idx + 1) / n_frames, desc=f"Re-aging frame {idx + 1}/{n_frames}...")
276
+ frame_bgr = cv2.imread(fpath)
277
+ if frame_bgr is None:
278
+ continue
279
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
280
+ box = detect_face_box(frame_rgb)
281
+ if box is not None:
282
+ result_rgb = reage_frame(frame_rgb, source_age, target_age)
283
+ faces_found += 1
284
+ else:
285
+ result_rgb = frame_rgb
286
+ faces_missed += 1
287
+ out_path = os.path.join(frames_out, os.path.basename(fpath))
288
+ cv2.imwrite(out_path, cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR))
289
+
290
+ progress(1.0, desc="Assembling video...")
291
+ output_path = os.path.join(tmp_root, "output.mp4")
292
+ _assemble_video(frames_out, output_path, fps, audio_source=file_path)
293
+
294
+ elapsed = time.time() - t0
295
+ speed = n_frames / max(elapsed, 0.01)
296
+ info = (f"Done in {elapsed:.1f}s | {n_frames} frames at {speed:.1f} fps | "
297
+ f"Faces: {faces_found} found, {faces_missed} skipped | "
298
+ f"{source_age} -> {target_age} years")
299
+ return None, output_path, info
300
+
301
+ except gr.Error:
302
+ raise
303
+ except Exception as e:
304
+ raise gr.Error(f"Video processing failed: {e}")
305
+ else:
306
+ # --- Image processing ---
307
+ image_rgb = cv2.imread(file_path)
308
+ if image_rgb is None:
309
+ raise gr.Error("Could not read the file. Please upload a valid image or video.")
310
+ image_rgb = cv2.cvtColor(image_rgb, cv2.COLOR_BGR2RGB)
311
+ box = detect_face_box(image_rgb)
312
+ if box is None:
313
+ raise gr.Error("No face detected.")
314
+ result = reage_frame(image_rgb, source_age, target_age)
315
  elapsed = time.time() - t0
316
+ info = f"Done in {elapsed:.2f}s | {source_age} -> {target_age} years"
317
+ return Image.fromarray(result), None, info
 
 
 
 
318
 
 
 
 
 
 
 
319
 
320
  # ---------------------------------------------------------------------------
321
+ # Gradio UI - Single unified view
322
  # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
323
  with gr.Blocks(title="Face Re-Aging (CPU)") as demo:
324
  gr.Markdown(
325
  "# Face Re-Aging (CPU)\n"
326
+ "Upload an **image or video** to age or de-age faces. "
327
+ f"Videos: max {MAX_VIDEO_SECONDS}s, ~0.5-2 fps on CPU."
328
  )
329
 
330
+ with gr.Row():
331
+ with gr.Column():
332
+ file_input = gr.File(
333
+ label="Drop Image or Video Here",
334
+ file_types=["image", "video"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  )
336
+ # Also accept pasted/webcam images
337
+ img_input = gr.Image(
338
+ type="pil", label="Or paste/capture an image",
339
+ visible=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  )
341
+ src_age = gr.Slider(minimum=5, maximum=95, value=25, step=1,
342
+ label="Source Age (current)")
343
+ tgt_age = gr.Slider(minimum=5, maximum=95, value=65, step=1,
344
+ label="Target Age (desired)")
345
+ btn = gr.Button("Re-Age", variant="primary", size="lg")
346
+
347
+ with gr.Column():
348
+ img_output = gr.Image(type="pil", label="Result (Image)")
349
+ vid_output = gr.Video(label="Result (Video)")
350
+ info_box = gr.Textbox(label="Info", interactive=False)
351
+
352
+ def on_submit_file(file_obj, source_age, target_age, progress=gr.Progress()):
353
+ if file_obj is None:
354
+ raise gr.Error("Please upload a file.")
355
+ return process(file_obj, source_age, target_age, progress)
356
+
357
+ def on_submit_image(image, source_age, target_age, progress=gr.Progress()):
358
+ if image is None:
359
+ raise gr.Error("Please provide an image.")
360
+ return process(image, source_age, target_age, progress)
361
+
362
+ btn.click(
363
+ fn=on_submit_file,
364
+ inputs=[file_input, src_age, tgt_age],
365
+ outputs=[img_output, vid_output, info_box],
366
+ )
367
+
368
+ # Also trigger on image input (for paste/webcam)
369
+ img_input.change(
370
+ fn=on_submit_image,
371
+ inputs=[img_input, src_age, tgt_age],
372
+ outputs=[img_output, vid_output, info_box],
373
+ )
374
 
375
  gr.Markdown(
376
  "**Model:** `face_reaging.onnx` (118 MB) from "