Nekochu committed on
Commit
844e775
·
1 Parent(s): 2a828f1

Add video re-aging support with ffmpeg

Browse files
Files changed (2) hide show
  1. app.py +295 -68
  2. packages.txt +1 -0
app.py CHANGED
@@ -2,10 +2,16 @@
2
  Face Re-Aging with ONNX (CPU)
3
  Based on Disney's FRAN (Face Re-Aging Network) architecture.
4
  Model: face_reaging.onnx from VisoMaster-Fusion.
 
5
  """
6
 
7
  import os
 
 
 
8
  import time
 
 
9
  import cv2
10
  import numpy as np
11
  import onnxruntime as ort
@@ -14,11 +20,16 @@ from PIL import Image
14
  from huggingface_hub import hf_hub_download
15
 
16
  # ---------------------------------------------------------------------------
17
- # Model loading
18
  # ---------------------------------------------------------------------------
 
 
19
  MODEL_PATH = "face_reaging.onnx"
20
  REPO_ID = "Luminia/Face-ReAging-CPU"
21
 
 
 
 
22
  def get_model_path():
23
  if os.path.exists(MODEL_PATH):
24
  return MODEL_PATH
@@ -38,16 +49,14 @@ print("Model loaded.")
38
  # ---------------------------------------------------------------------------
39
  # OpenCV DNN face detection (no extra dependencies)
40
  # ---------------------------------------------------------------------------
41
- # Use OpenCV's built-in Haar cascade as primary, with DNN SSD as fallback
42
  _face_cascade = cv2.CascadeClassifier(
43
  cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
44
  )
45
 
46
- # Try to use the more accurate DNN face detector if available
47
- _dnn_net = None
48
  _dnn_model_path = os.path.join(os.path.dirname(__file__), "face_detection_yunet_2023mar.onnx")
49
  YUNET_URL = "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx"
50
 
 
51
  def _ensure_yunet():
52
  """Download YuNet face detector if not present."""
53
  global _dnn_model_path
@@ -79,7 +88,6 @@ def detect_face_box(image_rgb: np.ndarray):
79
  detector = cv2.FaceDetectorYN.create(yunet_path, "", (w, h), 0.5, 0.3, 5000)
80
  _, faces = detector.detect(image_rgb)
81
  if faces is not None and len(faces) > 0:
82
- # Pick largest face by area
83
  best_idx = 0
84
  best_area = 0
85
  for i, face in enumerate(faces):
@@ -101,7 +109,6 @@ def detect_face_box(image_rgb: np.ndarray):
101
  if len(faces) == 0:
102
  return None
103
 
104
- # Pick largest
105
  best_idx = np.argmax([fw * fh for (_, _, fw, fh) in faces])
106
  x, y, fw, fh = faces[best_idx]
107
  return (x, y, x + fw, y + fh)
@@ -111,8 +118,7 @@ def detect_face_box(image_rgb: np.ndarray):
111
  # ---------------------------------------------------------------------------
112
  def crop_face_region(image_rgb: np.ndarray, box):
113
  """
114
- Crop a square region around the detected face with generous margins
115
- (similar to FRAN's approach: forehead gets more margin).
116
  Returns: cropped image, (l_x, l_y, r_x, r_y) paste-back coords.
117
  """
118
  h, w = image_rgb.shape[:2]
@@ -121,12 +127,9 @@ def crop_face_region(image_rgb: np.ndarray, box):
121
  face_w = x2 - x1
122
  face_h = y2 - y1
123
 
124
- # Margins: top is larger (forehead), bottom smaller
125
  margin_top = int(face_h * 0.63 * 0.85)
126
  margin_bot = int(face_h * 0.37 * 0.85)
127
  margin_x = int(face_w * 0.85 / 2)
128
-
129
- # Adjust top margin to keep square
130
  margin_top += 2 * margin_x - margin_top - margin_bot
131
 
132
  l_y = max(y1 - margin_top, 0)
@@ -141,10 +144,7 @@ def crop_face_region(image_rgb: np.ndarray, box):
141
  # Blending mask (soft feathered edges)
142
  # ---------------------------------------------------------------------------
143
  def create_blend_mask(crop_h, crop_w, feather=0.15):
144
- """
145
- Create a soft feathered blending mask to avoid hard edges
146
- when pasting the re-aged face back.
147
- """
148
  mask = np.ones((crop_h, crop_w), dtype=np.float32)
149
  border_y = max(int(crop_h * feather), 1)
150
  border_x = max(int(crop_w * feather), 1)
@@ -159,63 +159,44 @@ def create_blend_mask(crop_h, crop_w, feather=0.15):
159
  mask[:, j] *= alpha
160
  mask[:, crop_w - 1 - j] *= alpha
161
 
162
- return mask[:, :, np.newaxis] # (H, W, 1)
163
 
164
  # ---------------------------------------------------------------------------
165
- # Core inference
166
  # ---------------------------------------------------------------------------
167
- def reage_face(
168
- image_pil: Image.Image,
169
- source_age: int,
170
- target_age: int,
171
- ):
172
  """
173
- Re-age the face in the given PIL image.
 
174
  """
175
- t0 = time.time()
176
-
177
- image_rgb = np.array(image_pil.convert("RGB"))
178
- h_orig, w_orig = image_rgb.shape[:2]
179
-
180
- # Detect face
181
  box = detect_face_box(image_rgb)
182
  if box is None:
183
- raise gr.Error("No face detected in the image. Please upload a clear photo with a visible face.")
184
 
185
- # Crop face region
186
  cropped, (l_x, l_y, r_x, r_y) = crop_face_region(image_rgb, box)
187
  crop_h, crop_w = cropped.shape[:2]
188
 
189
- # Resize to 512x512 for the model
190
  cropped_resized = cv2.resize(cropped, (512, 512), interpolation=cv2.INTER_LINEAR)
191
 
192
- # Normalize to [0, 1] float32, CHW
193
  img_tensor = cropped_resized.astype(np.float32) / 255.0
194
- img_tensor = np.transpose(img_tensor, (2, 0, 1)) # (3, 512, 512)
195
 
196
- # Create age channels
197
  src_age_ch = np.full((1, 512, 512), source_age / 100.0, dtype=np.float32)
198
  tgt_age_ch = np.full((1, 512, 512), target_age / 100.0, dtype=np.float32)
199
 
200
- # Stack: (5, 512, 512) -> (1, 5, 512, 512)
201
  input_tensor = np.concatenate([img_tensor, src_age_ch, tgt_age_ch], axis=0)
202
  input_tensor = input_tensor[np.newaxis, ...]
203
 
204
- # Run inference
205
- delta = sess.run(None, {"input": input_tensor})[0] # (1, 3, 512, 512)
206
 
207
- # Apply delta to the cropped image
208
- aged = img_tensor + delta[0] # (3, 512, 512)
209
  aged = np.clip(aged, 0.0, 1.0)
210
 
211
- # Convert back to HWC uint8
212
- aged_hwc = np.transpose(aged, (1, 2, 0)) # (512, 512, 3)
213
  aged_hwc = (aged_hwc * 255).astype(np.uint8)
214
 
215
- # Resize back to original crop size
216
  aged_resized = cv2.resize(aged_hwc, (crop_w, crop_h), interpolation=cv2.INTER_LINEAR)
217
 
218
- # Blend back into original image
219
  result = image_rgb.copy()
220
  blend_mask = create_blend_mask(crop_h, crop_w, feather=0.12)
221
  region = result[l_y:r_y, l_x:r_x].astype(np.float32)
@@ -223,45 +204,291 @@ def reage_face(
223
  blended = region * (1 - blend_mask) + aged_f * blend_mask
224
  result[l_y:r_y, l_x:r_x] = blended.astype(np.uint8)
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  elapsed = time.time() - t0
227
  info = f"Done in {elapsed:.2f}s | Source age: {source_age} | Target age: {target_age}"
228
-
229
  return Image.fromarray(result), info
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  # ---------------------------------------------------------------------------
232
  # Gradio UI
233
  # ---------------------------------------------------------------------------
234
- def process(image, source_age, target_age):
235
  if image is None:
236
  raise gr.Error("Please upload an image.")
237
  return reage_face(image, int(source_age), int(target_age))
238
 
239
- with gr.Blocks(title="Face Re-Aging (CPU)") as demo:
240
- gr.Markdown("# Face Re-Aging (CPU)\nAge or de-age faces using Disney FRAN-style model. Upload a photo, set source & target age.")
241
-
242
- with gr.Row():
243
- with gr.Column():
244
- input_image = gr.Image(type="pil", label="Input Image")
245
- source_age = gr.Slider(
246
- minimum=5, maximum=95, value=25, step=1,
247
- label="Source Age (current age of the person)",
248
- )
249
- target_age = gr.Slider(
250
- minimum=5, maximum=95, value=65, step=1,
251
- label="Target Age (desired age)",
252
- )
253
- run_btn = gr.Button("Re-Age Face", variant="primary")
254
 
255
- with gr.Column():
256
- output_image = gr.Image(type="pil", label="Re-Aged Result")
257
- info_text = gr.Textbox(label="Info", interactive=False)
 
 
258
 
259
- run_btn.click(
260
- fn=process,
261
- inputs=[input_image, source_age, target_age],
262
- outputs=[output_image, info_text],
 
263
  )
264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  gr.Markdown(
266
  "**Model:** `face_reaging.onnx` (118 MB) from "
267
  "[VisoMaster-Fusion](https://github.com/VisoMasterFusion/VisoMaster-Fusion) | "
 
2
  Face Re-Aging with ONNX (CPU)
3
  Based on Disney's FRAN (Face Re-Aging Network) architecture.
4
  Model: face_reaging.onnx from VisoMaster-Fusion.
5
+ Supports single image and video re-aging.
6
  """
7
 
8
  import os
9
+ import shutil
10
+ import subprocess
11
+ import tempfile
12
  import time
13
+ import glob as glob_mod
14
+
15
  import cv2
16
  import numpy as np
17
  import onnxruntime as ort
 
20
  from huggingface_hub import hf_hub_download
21
 
22
  # ---------------------------------------------------------------------------
23
+ # Constants
24
  # ---------------------------------------------------------------------------
25
+ MAX_VIDEO_SECONDS = 30
26
+ MAX_FRAMES = 900
27
  MODEL_PATH = "face_reaging.onnx"
28
  REPO_ID = "Luminia/Face-ReAging-CPU"
29
 
30
+ # ---------------------------------------------------------------------------
31
+ # Model loading
32
+ # ---------------------------------------------------------------------------
33
  def get_model_path():
34
  if os.path.exists(MODEL_PATH):
35
  return MODEL_PATH
 
49
  # ---------------------------------------------------------------------------
50
  # OpenCV DNN face detection (no extra dependencies)
51
  # ---------------------------------------------------------------------------
 
52
  _face_cascade = cv2.CascadeClassifier(
53
  cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
54
  )
55
 
 
 
56
  _dnn_model_path = os.path.join(os.path.dirname(__file__), "face_detection_yunet_2023mar.onnx")
57
  YUNET_URL = "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx"
58
 
59
+
60
  def _ensure_yunet():
61
  """Download YuNet face detector if not present."""
62
  global _dnn_model_path
 
88
  detector = cv2.FaceDetectorYN.create(yunet_path, "", (w, h), 0.5, 0.3, 5000)
89
  _, faces = detector.detect(image_rgb)
90
  if faces is not None and len(faces) > 0:
 
91
  best_idx = 0
92
  best_area = 0
93
  for i, face in enumerate(faces):
 
109
  if len(faces) == 0:
110
  return None
111
 
 
112
  best_idx = np.argmax([fw * fh for (_, _, fw, fh) in faces])
113
  x, y, fw, fh = faces[best_idx]
114
  return (x, y, x + fw, y + fh)
 
118
  # ---------------------------------------------------------------------------
119
  def crop_face_region(image_rgb: np.ndarray, box):
120
  """
121
+ Crop a square region around the detected face with generous margins.
 
122
  Returns: cropped image, (l_x, l_y, r_x, r_y) paste-back coords.
123
  """
124
  h, w = image_rgb.shape[:2]
 
127
  face_w = x2 - x1
128
  face_h = y2 - y1
129
 
 
130
  margin_top = int(face_h * 0.63 * 0.85)
131
  margin_bot = int(face_h * 0.37 * 0.85)
132
  margin_x = int(face_w * 0.85 / 2)
 
 
133
  margin_top += 2 * margin_x - margin_top - margin_bot
134
 
135
  l_y = max(y1 - margin_top, 0)
 
144
  # Blending mask (soft feathered edges)
145
  # ---------------------------------------------------------------------------
146
  def create_blend_mask(crop_h, crop_w, feather=0.15):
147
+ """Create a soft feathered blending mask."""
 
 
 
148
  mask = np.ones((crop_h, crop_w), dtype=np.float32)
149
  border_y = max(int(crop_h * feather), 1)
150
  border_x = max(int(crop_w * feather), 1)
 
159
  mask[:, j] *= alpha
160
  mask[:, crop_w - 1 - j] *= alpha
161
 
162
+ return mask[:, :, np.newaxis]
163
 
164
  # ---------------------------------------------------------------------------
165
+ # Core inference on a single frame (numpy RGB in, numpy RGB out)
166
  # ---------------------------------------------------------------------------
167
+ def reage_frame(image_rgb: np.ndarray, source_age: int, target_age: int) -> np.ndarray:
 
 
 
 
168
  """
169
+ Re-age the face in a numpy RGB image.
170
+ Returns the re-aged image (same size), or original if no face found.
171
  """
 
 
 
 
 
 
172
  box = detect_face_box(image_rgb)
173
  if box is None:
174
+ return image_rgb # no face, return unchanged
175
 
 
176
  cropped, (l_x, l_y, r_x, r_y) = crop_face_region(image_rgb, box)
177
  crop_h, crop_w = cropped.shape[:2]
178
 
 
179
  cropped_resized = cv2.resize(cropped, (512, 512), interpolation=cv2.INTER_LINEAR)
180
 
 
181
  img_tensor = cropped_resized.astype(np.float32) / 255.0
182
+ img_tensor = np.transpose(img_tensor, (2, 0, 1))
183
 
 
184
  src_age_ch = np.full((1, 512, 512), source_age / 100.0, dtype=np.float32)
185
  tgt_age_ch = np.full((1, 512, 512), target_age / 100.0, dtype=np.float32)
186
 
 
187
  input_tensor = np.concatenate([img_tensor, src_age_ch, tgt_age_ch], axis=0)
188
  input_tensor = input_tensor[np.newaxis, ...]
189
 
190
+ delta = sess.run(None, {"input": input_tensor})[0]
 
191
 
192
+ aged = img_tensor + delta[0]
 
193
  aged = np.clip(aged, 0.0, 1.0)
194
 
195
+ aged_hwc = np.transpose(aged, (1, 2, 0))
 
196
  aged_hwc = (aged_hwc * 255).astype(np.uint8)
197
 
 
198
  aged_resized = cv2.resize(aged_hwc, (crop_w, crop_h), interpolation=cv2.INTER_LINEAR)
199
 
 
200
  result = image_rgb.copy()
201
  blend_mask = create_blend_mask(crop_h, crop_w, feather=0.12)
202
  region = result[l_y:r_y, l_x:r_x].astype(np.float32)
 
204
  blended = region * (1 - blend_mask) + aged_f * blend_mask
205
  result[l_y:r_y, l_x:r_x] = blended.astype(np.uint8)
206
 
207
+ return result
208
+
209
+ # ---------------------------------------------------------------------------
210
+ # Image re-aging (wraps reage_frame for Gradio)
211
+ # ---------------------------------------------------------------------------
212
def reage_face(image_pil: Image.Image, source_age: int, target_age: int):
    """Re-age the face in the given PIL image.

    Returns (PIL.Image result, info string); raises gr.Error if no face is found.
    """
    start = time.time()
    rgb = np.array(image_pil.convert("RGB"))

    # NOTE(review): detection also runs inside reage_frame; this early pass
    # exists only so we can raise a clear error instead of silently returning
    # the unchanged input.
    if detect_face_box(rgb) is None:
        raise gr.Error("No face detected in the image. Please upload a clear photo with a visible face.")

    result = reage_frame(rgb, source_age, target_age)

    took = time.time() - start
    info = f"Done in {took:.2f}s | Source age: {source_age} | Target age: {target_age}"
    return Image.fromarray(result), info
225
 
226
+ # ---------------------------------------------------------------------------
227
+ # ffmpeg helpers
228
+ # ---------------------------------------------------------------------------
229
def _find_ffmpeg():
    """Locate the ffmpeg binary, raising a Gradio error when it is absent."""
    found = shutil.which("ffmpeg")
    if found:
        return found
    # HF Spaces usually have it at a fixed path even when PATH is odd.
    for candidate in ("/usr/bin/ffmpeg", "/usr/local/bin/ffmpeg"):
        if os.path.isfile(candidate):
            return candidate
    raise gr.Error("ffmpeg not found. Video processing requires ffmpeg.")
239
+
240
+
241
def _get_video_info(video_path: str):
    """Return (fps, frame_count) for *video_path*.

    Prefers ffprobe's container metadata; falls back to OpenCV when ffprobe
    is missing or its output cannot be parsed. Frame count may come from
    ``nb_frames`` or be estimated as duration * fps.
    """
    import json

    def _cv2_probe():
        # Single fallback path (the original duplicated this code twice):
        # OpenCV reads container metadata without decoding any frames.
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS) or 25.0  # 0.0 is falsy -> default 25
        count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.release()
        return fps, count

    ffprobe = shutil.which("ffprobe") or shutil.which("ffprobe", path="/usr/bin:/usr/local/bin")
    if not ffprobe:
        return _cv2_probe()

    try:
        r = subprocess.run(
            [ffprobe, "-v", "quiet", "-print_format", "json",
             "-show_streams", "-select_streams", "v:0", video_path],
            capture_output=True, text=True, timeout=30,
        )
        stream = json.loads(r.stdout)["streams"][0]

        # fps: "r_frame_rate" is a rational like "30000/1001".
        num, den = stream.get("r_frame_rate", "25/1").split("/")
        fps = float(num) / (float(den) or 1.0)  # guard against "0/0"

        # frame count: exact when the container reports it, else estimate.
        nb = stream.get("nb_frames")
        if nb and nb != "N/A":
            count = int(nb)
        else:
            count = int(float(stream.get("duration", 0)) * fps)
        return fps, count
    except Exception:
        return _cv2_probe()
279
+
280
+
281
def _extract_frames(video_path: str, out_dir: str):
    """Extract every frame of *video_path* as PNGs into *out_dir*.

    Frames are written as frame_000001.png, frame_000002.png, ...
    Raises gr.Error (with ffmpeg's stderr tail) on failure.
    """
    ffmpeg = _find_ffmpeg()
    out_pattern = os.path.join(out_dir, "frame_%06d.png")
    # Fix: "-y" must precede the output file. As a trailing option ffmpeg
    # warns "Trailing option(s) found ... may be ignored", so overwrite
    # confirmation could block or fail the run.
    cmd = [ffmpeg, "-y", "-i", video_path, "-vsync", "0", out_pattern]
    r = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
    if r.returncode != 0:
        raise gr.Error(f"ffmpeg frame extraction failed: {r.stderr[-500:]}")
289
+
290
+
291
def _assemble_video(frames_dir: str, output_path: str, fps: float, audio_source: str = None):
    """Reassemble frame_%06d.png files from *frames_dir* into an H.264 MP4.

    If *audio_source* is given, its audio track (when present) is muxed in.
    Raises gr.Error (with ffmpeg's stderr tail) on failure.
    """
    ffmpeg = _find_ffmpeg()
    in_pattern = os.path.join(frames_dir, "frame_%06d.png")

    cmd = [ffmpeg, "-y", "-framerate", str(fps), "-i", in_pattern]

    if audio_source:
        # "1:a?" maps the second input's audio only if it exists;
        # -shortest stops at the shorter of video/audio.
        cmd.extend(["-i", audio_source, "-map", "0:v", "-map", "1:a?", "-shortest"])

    cmd.extend([
        "-c:v", "libx264",
        "-pix_fmt", "yuv420p",   # broad player compatibility
        "-preset", "fast",
        "-crf", "20",
        "-movflags", "+faststart",  # moov atom up front for web streaming
        output_path,
    ])

    proc = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
    if proc.returncode != 0:
        raise gr.Error(f"ffmpeg assembly failed: {proc.stderr[-500:]}")
318
+
319
+ # ---------------------------------------------------------------------------
320
+ # Video re-aging
321
+ # ---------------------------------------------------------------------------
322
def reage_video(video_path: str, source_age: int, target_age: int, progress=gr.Progress()):
    """Re-age faces in every frame of a video.

    Returns (output_mp4_path, info_string).
    Raises gr.Error for missing input, over-length videos, or ffmpeg failures.

    Fixes over the previous version:
    - temp frame directories are removed when done (they were never cleaned
      up, leaking disk space on every request);
    - each frame is face-detected once instead of twice (reage_frame already
      detects internally and returns the input object unchanged when no face
      is found, so identity of the result tells us whether a face was aged).
    """
    if video_path is None:
        raise gr.Error("Please upload a video.")

    t0 = time.time()

    # Pre-flight length checks from container metadata (re-checked after
    # extraction, since metadata can be missing or wrong).
    fps, total_frames = _get_video_info(video_path)
    duration = total_frames / max(fps, 1)

    if duration > MAX_VIDEO_SECONDS:
        raise gr.Error(
            f"Video is {duration:.1f}s long. Maximum allowed is {MAX_VIDEO_SECONDS}s. "
            f"Please trim your video first."
        )

    if total_frames > MAX_FRAMES:
        raise gr.Error(
            f"Video has {total_frames} frames. Maximum allowed is {MAX_FRAMES}. "
            f"Please use a shorter video."
        )

    # Workspace: tmp_root holds the final MP4 (must outlive this call so
    # Gradio can serve it); the frame dirs are deleted in `finally`.
    tmp_root = tempfile.mkdtemp(prefix="reage_")
    frames_in = os.path.join(tmp_root, "in")
    frames_out = os.path.join(tmp_root, "out")
    os.makedirs(frames_in, exist_ok=True)
    os.makedirs(frames_out, exist_ok=True)

    try:
        progress(0, desc="Extracting frames...")
        _extract_frames(video_path, frames_in)

        frame_files = sorted(glob_mod.glob(os.path.join(frames_in, "frame_*.png")))
        n_frames = len(frame_files)
        if n_frames == 0:
            raise gr.Error("No frames extracted from video. Is the file a valid video?")

        # Re-check limit after extraction (metadata estimate may undercount).
        if n_frames > MAX_FRAMES:
            raise gr.Error(f"Video has {n_frames} frames (max {MAX_FRAMES}). Please use a shorter video.")

        faces_found = 0
        faces_missed = 0

        for idx, fpath in enumerate(frame_files):
            progress((idx + 1) / n_frames, desc=f"Re-aging frame {idx + 1}/{n_frames}...")

            frame_bgr = cv2.imread(fpath)
            if frame_bgr is None:
                continue  # unreadable frame: drop it rather than abort the whole video
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

            # Single detection pass: reage_frame returns the *same object*
            # untouched when no face is detected, a copy otherwise.
            result_rgb = reage_frame(frame_rgb, source_age, target_age)
            if result_rgb is frame_rgb:
                faces_missed += 1
            else:
                faces_found += 1

            out_path = os.path.join(frames_out, os.path.basename(fpath))
            cv2.imwrite(out_path, cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR))

        progress(1.0, desc="Assembling video...")
        output_path = os.path.join(tmp_root, "output.mp4")
        _assemble_video(frames_out, output_path, fps, audio_source=video_path)

        elapsed = time.time() - t0
        speed = n_frames / max(elapsed, 0.01)
        info = (
            f"Done in {elapsed:.1f}s | {n_frames} frames at {speed:.1f} fps | "
            f"Faces found: {faces_found}, skipped: {faces_missed} | "
            f"Source age: {source_age} -> Target age: {target_age}"
        )

        return output_path, info

    except gr.Error:
        shutil.rmtree(tmp_root, ignore_errors=True)  # nothing to serve on failure
        raise
    except Exception as e:
        shutil.rmtree(tmp_root, ignore_errors=True)
        raise gr.Error(f"Video processing failed: {str(e)}")
    finally:
        # Extracted/processed frames are no longer needed once the MP4
        # exists (or the call failed); keep only tmp_root/output.mp4.
        shutil.rmtree(frames_in, ignore_errors=True)
        shutil.rmtree(frames_out, ignore_errors=True)
414
+
415
  # ---------------------------------------------------------------------------
416
  # Gradio UI
417
  # ---------------------------------------------------------------------------
418
def process_image(image, source_age, target_age):
    """Gradio callback for the Image tab: validate, then delegate to reage_face."""
    if image is None:
        raise gr.Error("Please upload an image.")
    src, tgt = int(source_age), int(target_age)
    return reage_face(image, src, tgt)
422
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
 
424
def process_video(video, source_age, target_age, progress=gr.Progress()):
    """Gradio callback for the Video tab: validate, then delegate to reage_video."""
    if video is None:
        raise gr.Error("Please upload a video.")
    src, tgt = int(source_age), int(target_age)
    return reage_video(video, src, tgt, progress)
428
+
429
 
430
+ with gr.Blocks(title="Face Re-Aging (CPU)") as demo:
431
+ gr.Markdown(
432
+ "# Face Re-Aging (CPU)\n"
433
+ "Age or de-age faces using Disney FRAN-style model. "
434
+ "Works on both **images** and **videos**."
435
  )
436
 
437
+ with gr.Tabs():
438
+ # ---- Image Tab ----
439
+ with gr.TabItem("Image"):
440
+ with gr.Row():
441
+ with gr.Column():
442
+ img_input = gr.Image(type="pil", label="Input Image")
443
+ img_src_age = gr.Slider(
444
+ minimum=5, maximum=95, value=25, step=1,
445
+ label="Source Age (current age)",
446
+ )
447
+ img_tgt_age = gr.Slider(
448
+ minimum=5, maximum=95, value=65, step=1,
449
+ label="Target Age (desired age)",
450
+ )
451
+ img_btn = gr.Button("Re-Age Face", variant="primary")
452
+
453
+ with gr.Column():
454
+ img_output = gr.Image(type="pil", label="Re-Aged Result")
455
+ img_info = gr.Textbox(label="Info", interactive=False)
456
+
457
+ img_btn.click(
458
+ fn=process_image,
459
+ inputs=[img_input, img_src_age, img_tgt_age],
460
+ outputs=[img_output, img_info],
461
+ )
462
+
463
+ # ---- Video Tab ----
464
+ with gr.TabItem("Video"):
465
+ gr.Markdown(
466
+ f"Upload a video (max **{MAX_VIDEO_SECONDS}s** / **{MAX_FRAMES} frames**). "
467
+ f"Each frame is processed individually on CPU, so expect ~0.5-2 fps."
468
+ )
469
+ with gr.Row():
470
+ with gr.Column():
471
+ vid_input = gr.Video(label="Input Video")
472
+ vid_src_age = gr.Slider(
473
+ minimum=5, maximum=95, value=25, step=1,
474
+ label="Source Age (current age)",
475
+ )
476
+ vid_tgt_age = gr.Slider(
477
+ minimum=5, maximum=95, value=65, step=1,
478
+ label="Target Age (desired age)",
479
+ )
480
+ vid_btn = gr.Button("Re-Age Video", variant="primary")
481
+
482
+ with gr.Column():
483
+ vid_output = gr.Video(label="Re-Aged Video")
484
+ vid_info = gr.Textbox(label="Info", interactive=False)
485
+
486
+ vid_btn.click(
487
+ fn=process_video,
488
+ inputs=[vid_input, vid_src_age, vid_tgt_age],
489
+ outputs=[vid_output, vid_info],
490
+ )
491
+
492
  gr.Markdown(
493
  "**Model:** `face_reaging.onnx` (118 MB) from "
494
  "[VisoMaster-Fusion](https://github.com/VisoMasterFusion/VisoMaster-Fusion) | "
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg