Nekochu commited on
Commit
1f4128d
·
0 Parent(s):

Add Face Re-Aging CPU Space

Browse files
Files changed (6) hide show
  1. .gitattributes +1 -0
  2. .gitignore +5 -0
  3. README.md +12 -0
  4. app.py +239 -0
  5. face_reaging.onnx +3 -0
  6. requirements.txt +6 -0
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.onnx filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .venv/
2
+ __pycache__/
3
+ *.pyc
4
+ test_face.jpg
5
+ test_result_*.jpg
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Face Re-Aging CPU
3
+ emoji: 👴
4
+ colorFrom: indigo
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: "6.9.0"
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ private: true
12
+ ---
app.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Face Re-Aging with ONNX (CPU)
3
+ Based on Disney's FRAN (Face Re-Aging Network) architecture.
4
+ Model: face_reaging.onnx from VisoMaster-Fusion.
5
+ """
6
+
7
+ import os
8
+ import time
9
+ import cv2
10
+ import numpy as np
11
+ import onnxruntime as ort
12
+ import mediapipe as mp
13
+ import gradio as gr
14
+ from PIL import Image
15
+ from huggingface_hub import hf_hub_download
16
+
17
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
MODEL_PATH = "face_reaging.onnx"
REPO_ID = "Luminia/Face-ReAging-CPU"

def get_model_path():
    """Return a local path to the ONNX model, fetching it from the Hub when absent."""
    if not os.path.exists(MODEL_PATH):
        # Cold start on a fresh Space: pull the weights from the model repo.
        return hf_hub_download(repo_id=REPO_ID, filename=MODEL_PATH)
    return MODEL_PATH
27
+
28
print("Loading ONNX model...")
# Use every available core for the single CPU inference session.
_session_options = ort.SessionOptions()
_session_options.intra_op_num_threads = os.cpu_count()
_session_options.inter_op_num_threads = os.cpu_count()
sess = ort.InferenceSession(
    get_model_path(),
    sess_options=_session_options,
    providers=["CPUExecutionProvider"],
)
print("Model loaded.")
38
+
39
# ---------------------------------------------------------------------------
# MediaPipe face detection
# ---------------------------------------------------------------------------
mp_face_detection = mp.solutions.face_detection

def detect_face_box(image_rgb: np.ndarray):
    """
    Detect the largest face in *image_rgb* with MediaPipe.

    Returns (x1, y1, x2, y2) in pixel coordinates, or None when no face
    is found.
    """
    height, width = image_rgb.shape[:2]
    with mp_face_detection.FaceDetection(
        model_selection=1, min_detection_confidence=0.5
    ) as detector:
        detections = detector.process(image_rgb).detections
    if not detections:
        return None

    # Keep the detection covering the largest relative area.
    boxes = [d.location_data.relative_bounding_box for d in detections]
    largest = max(boxes, key=lambda b: b.width * b.height)

    # Scale the relative bounding box back to pixel coordinates.
    left = int(largest.xmin * width)
    top = int(largest.ymin * height)
    right = left + int(largest.width * width)
    bottom = top + int(largest.height * height)
    return (left, top, right, bottom)
75
+
76
# ---------------------------------------------------------------------------
# Face cropping with margin
# ---------------------------------------------------------------------------
def crop_face_region(image_rgb: np.ndarray, box):
    """
    Crop a roughly square region around *box* with generous margins, as in
    FRAN: extra headroom above the face (forehead), less below the chin.

    Returns
    -------
    (np.ndarray, tuple)
        The cropped image and its (l_x, l_y, r_x, r_y) coordinates for
        pasting the processed crop back into the full image.
    """
    img_h, img_w = image_rgb.shape[:2]
    x1, y1, x2, y2 = box
    face_w, face_h = x2 - x1, y2 - y1

    # Horizontal margin on each side; the vertical margins are balanced so
    # that (before clamping) margin_top + margin_bot == 2 * margin_x, i.e.
    # the crop stays square, with the bottom share fixed at 37% * 0.85.
    margin_x = int(face_w * 0.85 / 2)
    margin_bot = int(face_h * 0.37 * 0.85)
    margin_top = 2 * margin_x - margin_bot

    # Clamp the expanded box to the image bounds.
    l_x = max(x1 - margin_x, 0)
    r_x = min(x2 + margin_x, img_w)
    l_y = max(y1 - margin_top, 0)
    r_y = min(y2 + margin_bot, img_h)

    return image_rgb[l_y:r_y, l_x:r_x, :], (l_x, l_y, r_x, r_y)
106
+
107
# ---------------------------------------------------------------------------
# Blending mask (soft feathered edges)
# ---------------------------------------------------------------------------
def create_blend_mask(crop_h, crop_w, feather=0.15):
    """
    Build a (crop_h, crop_w, 1) float32 mask that is 1.0 in the interior
    and ramps linearly to 0.0 over a `feather` fraction of each edge, so
    the re-aged crop blends into the surrounding image without hard seams.
    """
    fade_rows = max(int(crop_h * feather), 1)
    fade_cols = max(int(crop_w * feather), 1)

    # 1-D edge weights; a corner pixel ends up with the product of its
    # row and column weights.
    row_w = np.ones(crop_h, dtype=np.float32)
    for k in range(fade_rows):
        ramp = k / fade_rows
        row_w[k] *= ramp
        row_w[crop_h - 1 - k] *= ramp

    col_w = np.ones(crop_w, dtype=np.float32)
    for k in range(fade_cols):
        ramp = k / fade_cols
        col_w[k] *= ramp
        col_w[crop_w - 1 - k] *= ramp

    return np.outer(row_w, col_w)[:, :, np.newaxis]  # (H, W, 1)
130
+
131
# ---------------------------------------------------------------------------
# Core inference
# ---------------------------------------------------------------------------
def reage_face(
    image_pil: Image.Image,
    source_age: int,
    target_age: int,
):
    """
    Re-age the largest face in *image_pil*.

    Parameters
    ----------
    image_pil : PIL.Image.Image
        Input photo containing at least one visible face.
    source_age : int
        Estimated current age of the person, in years.
    target_age : int
        Desired age, in years.

    Returns
    -------
    (PIL.Image.Image, str)
        The full-size image with the re-aged face blended back in, and a
        short human-readable status line.

    Raises
    ------
    gr.Error
        If no face is detected in the image.
    """
    t0 = time.time()

    image_rgb = np.array(image_pil.convert("RGB"))

    # Detect face
    box = detect_face_box(image_rgb)
    if box is None:
        raise gr.Error("No face detected in the image. Please upload a clear photo with a visible face.")

    # Crop face region (generous forehead margin, clamped to image bounds)
    cropped, (l_x, l_y, r_x, r_y) = crop_face_region(image_rgb, box)
    crop_h, crop_w = cropped.shape[:2]

    # Resize to the model's fixed 512x512 input resolution
    cropped_resized = cv2.resize(cropped, (512, 512), interpolation=cv2.INTER_LINEAR)

    # Normalize to [0, 1] float32, CHW
    img_tensor = cropped_resized.astype(np.float32) / 255.0
    img_tensor = np.transpose(img_tensor, (2, 0, 1))  # (3, 512, 512)

    # FRAN conditions on two constant-valued age planes (ages scaled to [0, 1])
    src_age_ch = np.full((1, 512, 512), source_age / 100.0, dtype=np.float32)
    tgt_age_ch = np.full((1, 512, 512), target_age / 100.0, dtype=np.float32)

    # Stack: (5, 512, 512) -> (1, 5, 512, 512)
    input_tensor = np.concatenate([img_tensor, src_age_ch, tgt_age_ch], axis=0)
    input_tensor = input_tensor[np.newaxis, ...]

    # Run inference. Resolve the input name from the session instead of
    # hard-coding "input", so differently-exported ONNX files still work.
    input_name = sess.get_inputs()[0].name
    delta = sess.run(None, {input_name: input_tensor})[0]  # (1, 3, 512, 512)

    # The model predicts a residual; add it to the input and clamp to [0, 1]
    aged = np.clip(img_tensor + delta[0], 0.0, 1.0)  # (3, 512, 512)

    # Convert back to HWC uint8
    aged_hwc = (np.transpose(aged, (1, 2, 0)) * 255).astype(np.uint8)  # (512, 512, 3)

    # Resize back to the original crop size
    aged_resized = cv2.resize(aged_hwc, (crop_w, crop_h), interpolation=cv2.INTER_LINEAR)

    # Feather-blend the re-aged crop into a copy of the original image
    result = image_rgb.copy()
    blend_mask = create_blend_mask(crop_h, crop_w, feather=0.12)
    region = result[l_y:r_y, l_x:r_x].astype(np.float32)
    aged_f = aged_resized.astype(np.float32)
    blended = region * (1 - blend_mask) + aged_f * blend_mask
    result[l_y:r_y, l_x:r_x] = blended.astype(np.uint8)

    elapsed = time.time() - t0
    info = f"Done in {elapsed:.2f}s | Source age: {source_age} | Target age: {target_age}"

    return Image.fromarray(result), info
197
+
198
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
def process(image, source_age, target_age):
    """Gradio click handler: validate the upload, then delegate to reage_face."""
    if image is None:
        raise gr.Error("Please upload an image.")
    # Sliders hand back floats; the model expects integer ages.
    return reage_face(image, int(source_age), int(target_age))
205
+
206
# NOTE: the theme must be passed to gr.Blocks(); Blocks.launch() has no
# `theme` parameter and would raise TypeError at startup.
with gr.Blocks(title="Face Re-Aging (CPU)", theme="NoCrypt/miku") as demo:
    gr.Markdown("# Face Re-Aging (CPU)\nAge or de-age faces using Disney FRAN-style model. Upload a photo, set source & target age.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            source_age = gr.Slider(
                minimum=5, maximum=95, value=25, step=1,
                label="Source Age (current age of the person)",
            )
            target_age = gr.Slider(
                minimum=5, maximum=95, value=65, step=1,
                label="Target Age (desired age)",
            )
            run_btn = gr.Button("Re-Age Face", variant="primary")

        with gr.Column():
            output_image = gr.Image(type="pil", label="Re-Aged Result")
            info_text = gr.Textbox(label="Info", interactive=False)

    run_btn.click(
        fn=process,
        inputs=[input_image, source_age, target_age],
        outputs=[output_image, info_text],
    )

    gr.Markdown(
        "**Model:** `face_reaging.onnx` (118 MB) from "
        "[VisoMaster-Fusion](https://github.com/VisoMasterFusion/VisoMaster-Fusion) | "
        "Based on [Disney FRAN](https://studios.disneyresearch.com/2022/11/30/production-ready-face-re-aging-for-visual-effects/)"
    )

if __name__ == "__main__":
    demo.launch(show_error=True, ssr_mode=False)
face_reaging.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62c62598a71067cf12680c8421230556d08069d172f1dc645be2a5ebe815fb1f
3
+ size 124230760
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=4.0
2
+ onnxruntime
3
+ opencv-python-headless
4
+ numpy
5
+ Pillow
6
+ mediapipe==0.10.14