wlyu-adobe commited on
Commit
090676f
·
1 Parent(s): f563fa2

Initial commit

Browse files
Files changed (1) hide show
  1. app.py +2 -131
app.py CHANGED
@@ -47,9 +47,6 @@ from gslrm.model.gaussians_renderer import render_turntable, imageseq2video
47
  from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
48
  from utils_folder.face_utils import preprocess_image, preprocess_image_without_cropping
49
 
50
- # Import convert function for PLY to SPLAT conversion
51
- import convert as ply_to_splat
52
-
53
  # HuggingFace repository configuration
54
  HF_REPO_ID = "wlyu/OpenFaceLift"
55
 
@@ -148,120 +145,6 @@ class FaceLiftPipeline:
148
  self._models_on_gpu = True
149
  print("Models on GPU, xformers enabled!")
150
 
151
- def _create_viewer_html(self, splat_path):
152
- """Create standalone HTML viewer for the gaussian splat."""
153
- import base64
154
-
155
- # Read the splat file and encode as base64
156
- with open(splat_path, 'rb') as f:
157
- splat_data = f.read()
158
- splat_b64 = base64.b64encode(splat_data).decode('utf-8')
159
-
160
- # Read the main.js content and modify it to use embedded data
161
- with open(Path(__file__).parent / "main.js", 'r') as f:
162
- js_content = f.read()
163
-
164
- # Replace the URL fetching part with blob URL
165
- js_content = js_content.replace(
166
- 'params.get("url") || "train.splat"',
167
- 'window.EMBEDDED_SPLAT_URL || "train.splat"'
168
- )
169
-
170
- html = f"""<!DOCTYPE html>
171
- <html lang="en">
172
- <head>
173
- <meta charset="UTF-8">
174
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
175
- <title>3D Gaussian Splat Viewer - FaceLift</title>
176
- <style>
177
- body {{
178
- margin: 0;
179
- padding: 0;
180
- overflow: hidden;
181
- font-family: Arial, sans-serif;
182
- background: #000;
183
- }}
184
- #canvas {{
185
- width: 100vw;
186
- height: 100vh;
187
- display: block;
188
- }}
189
- #info {{
190
- position: absolute;
191
- top: 10px;
192
- left: 10px;
193
- background: rgba(0, 0, 0, 0.7);
194
- color: white;
195
- padding: 10px;
196
- border-radius: 5px;
197
- font-size: 12px;
198
- max-width: 300px;
199
- z-index: 1000;
200
- }}
201
- #info h3 {{
202
- margin: 0 0 10px 0;
203
- font-size: 14px;
204
- }}
205
- #loading {{
206
- position: absolute;
207
- top: 50%;
208
- left: 50%;
209
- transform: translate(-50%, -50%);
210
- color: white;
211
- font-size: 18px;
212
- text-align: center;
213
- z-index: 999;
214
- }}
215
- </style>
216
- </head>
217
- <body>
218
- <canvas id="canvas"></canvas>
219
- <div id="loading">Loading 3D model...</div>
220
- <div id="info">
221
- <h3>Controls</h3>
222
- <p><b>Mouse:</b> Click and drag to orbit</p>
223
- <p><b>Right click/Ctrl+drag:</b> Move forward/back (up/down), strafe (left/right)</p>
224
- <p><b>Arrow keys:</b> Move forward/back, strafe left/right</p>
225
- <p><b>WASD:</b> Rotate camera</p>
226
- <p><b>Space:</b> Jump</p>
227
- <p><b>Q/E:</b> Roll camera</p>
228
- </div>
229
- <script>
230
- // Convert base64 to blob and create URL
231
- const SPLAT_DATA_B64 = '{splat_b64}';
232
-
233
- function b64toBlob(b64Data, contentType='application/octet-stream', sliceSize=512) {{
234
- const byteCharacters = atob(b64Data);
235
- const byteArrays = [];
236
- for (let offset = 0; offset < byteCharacters.length; offset += sliceSize) {{
237
- const slice = byteCharacters.slice(offset, offset + sliceSize);
238
- const byteNumbers = new Array(slice.length);
239
- for (let i = 0; i < slice.length; i++) {{
240
- byteNumbers[i] = slice.charCodeAt(i);
241
- }}
242
- const byteArray = new Uint8Array(byteNumbers);
243
- byteArrays.push(byteArray);
244
- }}
245
- return new Blob(byteArrays, {{type: contentType}});
246
- }}
247
-
248
- // Create blob URL for embedded splat data
249
- const blob = b64toBlob(SPLAT_DATA_B64);
250
- window.EMBEDDED_SPLAT_URL = URL.createObjectURL(blob);
251
-
252
- // Remove loading message once render starts
253
- setTimeout(() => {{
254
- const loading = document.getElementById('loading');
255
- if (loading) loading.style.display = 'none';
256
- }}, 2000);
257
- </script>
258
- <script>
259
- {js_content}
260
- </script>
261
- </body>
262
- </html>"""
263
- return html
264
-
265
  @spaces.GPU
266
  def generate_3d_head(self, image_path, auto_crop=True, guidance_scale=3.0,
267
  random_seed=4, num_steps=50):
@@ -358,16 +241,6 @@ class FaceLiftPipeline:
358
  ply_path = output_dir / "gaussians.ply"
359
  filtered_gaussians.save_ply(str(ply_path))
360
 
361
- # Convert PLY to SPLAT format for web viewer
362
- splat_path = output_dir / "gaussians.splat"
363
- ply_to_splat.convert(str(ply_path), str(splat_path))
364
-
365
- # Create HTML viewer
366
- viewer_html = self._create_viewer_html(str(splat_path))
367
- viewer_path = output_dir / "viewer.html"
368
- with open(viewer_path, 'w') as f:
369
- f.write(viewer_html)
370
-
371
  # Save output image
372
  comp_image = rearrange(comp_image, "x v c h w -> (x h) (v w) c")
373
  comp_image = (comp_image.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
@@ -384,7 +257,7 @@ class FaceLiftPipeline:
384
  imageseq2video(turntable_frames, str(turntable_path), fps=30)
385
 
386
  return str(input_path), str(multiview_path), str(output_path), \
387
- str(turntable_path), str(viewer_path), str(ply_path)
388
 
389
  except Exception as e:
390
  raise gr.Error(f"Generation failed: {str(e)}")
@@ -405,12 +278,11 @@ def main():
405
  fn=pipeline.generate_3d_head,
406
  title="FaceLift: Single Image 3D Face Reconstruction",
407
  description="""
408
- Transform a single portrait image into a complete 3D head model with an interactive WebGL viewer.
409
 
410
  **Tips:**
411
  - Use high-quality portrait images with clear facial features
412
  - If face detection fails, try disabling auto-cropping and manually crop to square
413
- - Download the Interactive 3D Viewer HTML file to explore your model in full screen
414
  """,
415
  inputs=[
416
  gr.Image(type="filepath", label="Input Portrait Image"),
@@ -424,7 +296,6 @@ def main():
424
  gr.Image(label="Multi-view Generation"),
425
  gr.Image(label="3D Reconstruction"),
426
  gr.PlayableVideo(label="Turntable Animation"),
427
- gr.File(label="Interactive 3D Viewer (.html) - Download & Open"),
428
  gr.File(label="3D Model (.ply)"),
429
  ],
430
  examples=examples,
 
47
  from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
48
  from utils_folder.face_utils import preprocess_image, preprocess_image_without_cropping
49
 
 
 
 
50
  # HuggingFace repository configuration
51
  HF_REPO_ID = "wlyu/OpenFaceLift"
52
 
 
145
  self._models_on_gpu = True
146
  print("Models on GPU, xformers enabled!")
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  @spaces.GPU
149
  def generate_3d_head(self, image_path, auto_crop=True, guidance_scale=3.0,
150
  random_seed=4, num_steps=50):
 
241
  ply_path = output_dir / "gaussians.ply"
242
  filtered_gaussians.save_ply(str(ply_path))
243
 
 
 
 
 
 
 
 
 
 
 
244
  # Save output image
245
  comp_image = rearrange(comp_image, "x v c h w -> (x h) (v w) c")
246
  comp_image = (comp_image.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
 
257
  imageseq2video(turntable_frames, str(turntable_path), fps=30)
258
 
259
  return str(input_path), str(multiview_path), str(output_path), \
260
+ str(turntable_path), str(ply_path)
261
 
262
  except Exception as e:
263
  raise gr.Error(f"Generation failed: {str(e)}")
 
278
  fn=pipeline.generate_3d_head,
279
  title="FaceLift: Single Image 3D Face Reconstruction",
280
  description="""
281
+ Transform a single portrait image into a complete 3D head model.
282
 
283
  **Tips:**
284
  - Use high-quality portrait images with clear facial features
285
  - If face detection fails, try disabling auto-cropping and manually crop to square
 
286
  """,
287
  inputs=[
288
  gr.Image(type="filepath", label="Input Portrait Image"),
 
296
  gr.Image(label="Multi-view Generation"),
297
  gr.Image(label="3D Reconstruction"),
298
  gr.PlayableVideo(label="Turntable Animation"),
 
299
  gr.File(label="3D Model (.ply)"),
300
  ],
301
  examples=examples,