Remy committed on
Commit
33e4e1d
·
verified ·
1 Parent(s): 93f2115

Update ActionMesh space

Browse files
Files changed (2) hide show
  1. app.py +17 -6
  2. gradio_pipeline.py +115 -0
app.py CHANGED
@@ -1,3 +1,9 @@
 
 
 
 
 
 
1
  """
2
  ActionMesh Gradio Demo
3
 
@@ -205,11 +211,11 @@ from actionmesh.io.mesh_io import save_deformation
205
 
206
  # --- Import ActionMesh modules after setup ---
207
  from actionmesh.io.video_input import load_frames
208
- from actionmesh.pipeline import ActionMeshPipeline
209
  from actionmesh.render.utils import save_rgba_video
 
210
 
211
  # Global pipeline instance (loaded on CPU at startup)
212
- pipeline: ActionMeshPipeline | None = None
213
 
214
 
215
  def get_available_examples() -> list[tuple[str, str]]:
@@ -275,7 +281,7 @@ def load_example_images(evt: gr.SelectData) -> list[str]:
275
  return []
276
 
277
 
278
- def load_pipeline_cpu() -> ActionMeshPipeline:
279
  """Load the ActionMesh pipeline on CPU (called once at module load)."""
280
  global pipeline
281
  if pipeline is None:
@@ -283,7 +289,7 @@ def load_pipeline_cpu() -> ActionMeshPipeline:
283
  # Get config path from actionmesh cache directory
284
  cache_dir = Path.home() / ".cache" / "actionmesh"
285
  config_dir = str(cache_dir / "actionmesh" / "configs")
286
- pipeline = ActionMeshPipeline(
287
  config_name="actionmesh.yaml",
288
  config_dir=config_dir,
289
  )
@@ -355,7 +361,7 @@ def _run_actionmesh_impl(
355
  torch.cuda.empty_cache()
356
 
357
  # Run inference
358
- progress(0.3, desc="Running ActionMesh inference...")
359
 
360
  # Set steps based on quality mode
361
  if quality_mode == "⚡ Fast":
@@ -365,12 +371,17 @@ def _run_actionmesh_impl(
365
  stage_0_steps = 100
366
  stage_1_steps = 30
367
 
 
 
 
 
368
  meshes = pipe(
369
  input=input_data,
370
  anchor_idx=reference_frame - 1, # Convert from 1-indexed UI to 0-indexed
371
  stage_0_steps=stage_0_steps,
372
  stage_1_steps=stage_1_steps,
373
  seed=seed,
 
374
  )
375
 
376
  # Save input video
@@ -381,7 +392,7 @@ def _run_actionmesh_impl(
381
  return None, None, None, "Error: No meshes generated."
382
 
383
  # Save deformations and create animated GLB
384
- progress(0.9, desc="Creating animated GLB...")
385
 
386
  vertices_path, faces_path = save_deformation(
387
  meshes, path=f"{output_dir}/deformations"
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
  """
8
  ActionMesh Gradio Demo
9
 
 
211
 
212
  # --- Import ActionMesh modules after setup ---
213
  from actionmesh.io.video_input import load_frames
 
214
  from actionmesh.render.utils import save_rgba_video
215
+ from gradio_pipeline import GradioPipeline
216
 
217
  # Global pipeline instance (loaded on CPU at startup)
218
+ pipeline: GradioPipeline | None = None
219
 
220
 
221
  def get_available_examples() -> list[tuple[str, str]]:
 
281
  return []
282
 
283
 
284
+ def load_pipeline_cpu() -> GradioPipeline:
285
  """Load the ActionMesh pipeline on CPU (called once at module load)."""
286
  global pipeline
287
  if pipeline is None:
 
289
  # Get config path from actionmesh cache directory
290
  cache_dir = Path.home() / ".cache" / "actionmesh"
291
  config_dir = str(cache_dir / "actionmesh" / "configs")
292
+ pipeline = GradioPipeline(
293
  config_name="actionmesh.yaml",
294
  config_dir=config_dir,
295
  )
 
361
  torch.cuda.empty_cache()
362
 
363
  # Run inference
364
+ progress(None, desc="Starting pipeline...")
365
 
366
  # Set steps based on quality mode
367
  if quality_mode == "⚡ Fast":
 
371
  stage_0_steps = 100
372
  stage_1_steps = 30
373
 
374
+ # Create progress callback for the pipeline
375
+ def pipeline_progress_callback(value: float, desc: str) -> None:
376
+ progress(value, desc=desc)
377
+
378
  meshes = pipe(
379
  input=input_data,
380
  anchor_idx=reference_frame - 1, # Convert from 1-indexed UI to 0-indexed
381
  stage_0_steps=stage_0_steps,
382
  stage_1_steps=stage_1_steps,
383
  seed=seed,
384
+ progress_callback=pipeline_progress_callback,
385
  )
386
 
387
  # Save input video
 
392
  return None, None, None, "Error: No meshes generated."
393
 
394
  # Save deformations and create animated GLB
395
+ progress(1.0, desc="Creating animated GLB...")
396
 
397
  vertices_path, faces_path = save_deformation(
398
  meshes, path=f"{output_dir}/deformations"
gradio_pipeline.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ GradioPipeline: ActionMesh pipeline with Gradio progress tracking.
9
+
10
+ This module provides a subclass of ActionMeshPipeline that adds progress
11
+ callbacks for integration with Gradio's progress bar.
12
+ """
13
+
14
+ from typing import Callable, Optional
15
+
16
+ import torch
17
+ import trimesh
18
+ from actionmesh.io.video_input import ActionMeshInput
19
+ from actionmesh.pipeline import ActionMeshPipeline
20
+
21
+ ProgressCallback = Callable[[float, str], None]
22
+
23
+
24
class GradioPipeline(ActionMeshPipeline):
    """
    ActionMesh pipeline variant that reports progress to a Gradio UI.

    The whole run is mapped onto a single 0.0-1.0 progress range:
    - 0% -> 10%: anchor 3D generation (image_to_3d)
    - 10% -> 90%: Stage 1, flow-matching denoising (reported per step)
    - 90% -> 100%: Stage 2, mesh decoding (reported per step)
    """

    def __call__(
        self,
        input: ActionMeshInput,
        seed: int = 44,
        stage_0_steps: int | None = None,
        face_decimation: int | None = None,
        floaters_threshold: float | None = None,
        stage_1_steps: int | None = None,
        guidance_scales: list[float] | None = None,
        anchor_idx: int | None = None,
        progress_callback: Optional[ProgressCallback] = None,
    ) -> list[trimesh.Trimesh]:
        """Generate an animated mesh sequence, reporting progress as it runs.

        Args:
            input: Input frames wrapper; its ``frames`` attribute is
                preprocessed (background removal, crop/pad) in place.
            seed: RNG seed used by both generation stages.
            stage_0_steps: Override for the image-to-3D denoiser step count.
            face_decimation: Override for mesh face decimation.
            floaters_threshold: Override for floater-removal threshold.
            stage_1_steps: Override for the Stage 1 scheduler step count.
            guidance_scales: Override for classifier-free guidance scales.
            anchor_idx: Index of the anchor (reference) frame.
            progress_callback: Optional ``(fraction, description)`` hook;
                when omitted, progress reporting is a no-op.

        Returns:
            The ordered list of generated meshes, moved to CPU.
        """
        # Substitute a no-op so the reporting code below stays unconditional.
        notify: ProgressCallback = progress_callback or (lambda value, desc: None)

        # Apply per-call overrides onto the pipeline configuration.
        if stage_0_steps is not None:
            self.cfg.model.image_to_3D_denoiser.num_inference_steps = stage_0_steps
        if stage_1_steps is not None:
            self.scheduler.num_inference_steps = stage_1_steps
        if guidance_scales is not None:
            self.cf_guidance.guidance_scales = guidance_scales
        if face_decimation is not None:
            self.mesh_process.face_decimation = face_decimation
        if floaters_threshold is not None:
            self.mesh_process.floaters_threshold = floaters_threshold
        if anchor_idx is not None:
            self.cfg.anchor_idx = anchor_idx

        # Preprocess frames in place: strip backgrounds, then crop & pad.
        input.frames = self.background_removal.process_images(input.frames)
        input.frames = self.image_process.process_images(input.frames)

        def _stage_reporter(base: float, span: float, label: str):
            """Build a per-step callback mapping windowed step progress onto [base, base+span]."""

            def report(
                step: int, total_steps: int, window_idx: int, total_windows: int
            ) -> None:
                fraction = (window_idx + step / total_steps) / total_windows
                notify(base + span * fraction, f"{label}: step {step}/{total_steps} ")

            return report

        with torch.inference_mode():
            # Stage 0: anchor 3D mesh & latent from the reference frame.
            latent_bank, mesh_bank = self.init_banks_from_anchor(input, seed)
            notify(0.10, "Anchor 3D generated, starting Stage 1...")

            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                # Stage I: denoise synchronized 3D latents (10% -> 90%).
                latent_bank = self.generate_3d_latents(
                    input,
                    latent_bank=latent_bank,
                    seed=seed,
                    step_callback=_stage_reporter(0.10, 0.80, "Stage 1"),
                )
                # Stage II: decode latents into mesh displacements (90% -> 100%).
                mesh_bank = self.generate_mesh_animation(
                    latent_bank=latent_bank,
                    mesh_bank=mesh_bank,
                    step_callback=_stage_reporter(0.90, 0.10, "Stage 2"),
                )

            notify(1.0, "Pipeline complete!")

        return mesh_bank.get_ordered(device="cpu")[0]