Spaces:
Runtime error
Runtime error
Wei Liu Claude Sonnet 4.6 committed on
Commit ·
0cdce4a
1
Parent(s): a1190a9
CPU-first startup: load all models/scenes to CPU at module level, GPU transfer at generation time
Browse files
- startup() now uses device="cpu" for StreamingVideoGenerator and InteractiveSimulator
- Added move_pipeline_to_device() to StreamingVideoGenerator
- Added move_to_device() to _MinimalSVR and InteractiveSimulator
- do_generate() transfers everything to GPU at start, back to CPU in finally
- Warmup deferred to first generation call (CUDA kernel compile on GPU)
- Avoids ZeroGPU time limit: only fast tensor moves happen inside GPU slot
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- app.py +19 -25
- simulation_engine.py +49 -0
- video_generator.py +22 -8
app.py
CHANGED
|
@@ -286,6 +286,7 @@ def startup():
|
|
| 286 |
use_ema=USE_EMA,
|
| 287 |
seed=SEED,
|
| 288 |
enable_taehv=ENABLE_TAEHV,
|
|
|
|
| 289 |
)
|
| 290 |
video_generator.setup()
|
| 291 |
log_gpu("after video generator setup")
|
|
@@ -303,7 +304,7 @@ def startup():
|
|
| 303 |
if case_name == "santa_cloth":
|
| 304 |
config_overrides["skip_force_fields"] = True
|
| 305 |
|
| 306 |
-
simulator = InteractiveSimulator(str(case_dir), config_overrides=config_overrides)
|
| 307 |
simulator.config["debug"] = False
|
| 308 |
log_gpu(f"after simulator init ({case_name})")
|
| 309 |
|
|
@@ -350,17 +351,11 @@ def startup():
|
|
| 350 |
log_gpu("after finish_precompute")
|
| 351 |
|
| 352 |
# ---- Step 5: Warmup ----
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
# Release per-case precomputed tensors (i2v_conditional, full_y,
|
| 357 |
-
# default_text_features) back to CPU so ZeroGPU can reclaim VRAM
|
| 358 |
-
# between the startup @spaces.GPU slot and the first generation call.
|
| 359 |
-
# The main model weights (transformer, VAE, text_encoder) stay on GPU;
|
| 360 |
-
# ZeroGPU preserves the process's GPU slot across calls for efficiency.
|
| 361 |
-
video_generator.move_case_data_to_device("cpu")
|
| 362 |
torch.cuda.empty_cache()
|
| 363 |
-
print("[6/6]
|
| 364 |
|
| 365 |
|
| 366 |
# ---------------------------------------------------------------------------
|
|
@@ -488,20 +483,17 @@ def do_generate(case_name, prompt, d0, s0, d1, s1, d2, s2):
|
|
| 488 |
_is_generating = True
|
| 489 |
_stop_event.clear()
|
| 490 |
|
| 491 |
-
# Lazy full initialization: load models + build physics scenes + precompute.
|
| 492 |
-
# Runs only on the first generation; subsequent calls skip this branch.
|
| 493 |
if video_generator is None:
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
except Exception as e:
|
| 498 |
-
import traceback; traceback.print_exc()
|
| 499 |
-
_is_generating = False
|
| 500 |
-
yield None, f"Initialization error: {e}"
|
| 501 |
-
return
|
| 502 |
|
| 503 |
-
#
|
|
|
|
| 504 |
video_generator.move_case_data_to_device("cuda")
|
|
|
|
|
|
|
|
|
|
| 505 |
|
| 506 |
bundle = cases[case_name]
|
| 507 |
|
|
@@ -691,10 +683,12 @@ def do_generate(case_name, prompt, d0, s0, d1, s1, d2, s2):
|
|
| 691 |
render_thread.join(timeout=10)
|
| 692 |
if warp_thread is not None:
|
| 693 |
warp_thread.join(timeout=10)
|
| 694 |
-
# Release precomputed case tensors to CPU so ZeroGPU can reclaim
|
| 695 |
-
# VRAM for other users once this generation session ends.
|
| 696 |
if video_generator is not None:
|
|
|
|
| 697 |
video_generator.move_case_data_to_device("cpu")
|
|
|
|
|
|
|
|
|
|
| 698 |
torch.cuda.empty_cache()
|
| 699 |
_is_generating = False
|
| 700 |
|
|
@@ -921,7 +915,7 @@ def build_demo():
|
|
| 921 |
# files are already on disk so snapshot_download() is a fast no-op. By doing
|
| 922 |
# this here we avoid holding a ZeroGPU allocation while waiting on downloads.
|
| 923 |
_ensure_models_downloaded()
|
| 924 |
-
|
| 925 |
demo = build_demo()
|
| 926 |
|
| 927 |
if __name__ == "__main__":
|
|
|
|
| 286 |
use_ema=USE_EMA,
|
| 287 |
seed=SEED,
|
| 288 |
enable_taehv=ENABLE_TAEHV,
|
| 289 |
+
device="cpu",
|
| 290 |
)
|
| 291 |
video_generator.setup()
|
| 292 |
log_gpu("after video generator setup")
|
|
|
|
| 304 |
if case_name == "santa_cloth":
|
| 305 |
config_overrides["skip_force_fields"] = True
|
| 306 |
|
| 307 |
+
simulator = InteractiveSimulator(str(case_dir), device="cpu", config_overrides=config_overrides)
|
| 308 |
simulator.config["debug"] = False
|
| 309 |
log_gpu(f"after simulator init ({case_name})")
|
| 310 |
|
|
|
|
| 351 |
log_gpu("after finish_precompute")
|
| 352 |
|
| 353 |
# ---- Step 5: Warmup ----
|
| 354 |
+
# Warmup (CUDA kernel compilation) is deferred to first generation call.
|
| 355 |
+
print("[5/6] Skipping warmup at CPU-only startup — CUDA kernels compile on first generation.")
|
| 356 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
torch.cuda.empty_cache()
|
| 358 |
+
print("[6/6] CPU-only startup complete — models and scenes ready. GPU transfer at generation time.")
|
| 359 |
|
| 360 |
|
| 361 |
# ---------------------------------------------------------------------------
|
|
|
|
| 483 |
_is_generating = True
|
| 484 |
_stop_event.clear()
|
| 485 |
|
|
|
|
|
|
|
| 486 |
if video_generator is None:
|
| 487 |
+
_is_generating = False
|
| 488 |
+
yield None, "Error: models not initialized. Please reload the Space."
|
| 489 |
+
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
|
| 491 |
+
# Transfer all CPU-resident state to GPU for this generation session.
|
| 492 |
+
video_generator.move_pipeline_to_device("cuda")
|
| 493 |
video_generator.move_case_data_to_device("cuda")
|
| 494 |
+
for _b in cases.values():
|
| 495 |
+
if _b.simulator is not None:
|
| 496 |
+
_b.simulator.move_to_device("cuda")
|
| 497 |
|
| 498 |
bundle = cases[case_name]
|
| 499 |
|
|
|
|
| 683 |
render_thread.join(timeout=10)
|
| 684 |
if warp_thread is not None:
|
| 685 |
warp_thread.join(timeout=10)
|
|
|
|
|
|
|
| 686 |
if video_generator is not None:
|
| 687 |
+
video_generator.move_pipeline_to_device("cpu")
|
| 688 |
video_generator.move_case_data_to_device("cpu")
|
| 689 |
+
for _b in cases.values():
|
| 690 |
+
if _b is not None and _b.simulator is not None:
|
| 691 |
+
_b.simulator.move_to_device("cpu")
|
| 692 |
torch.cuda.empty_cache()
|
| 693 |
_is_generating = False
|
| 694 |
|
|
|
|
| 915 |
# files are already on disk so snapshot_download() is a fast no-op. By doing
|
| 916 |
# this here we avoid holding a ZeroGPU allocation while waiting on downloads.
|
| 917 |
_ensure_models_downloaded()
|
| 918 |
+
startup() # Load all models and scenes to CPU at module level
|
| 919 |
demo = build_demo()
|
| 920 |
|
| 921 |
if __name__ == "__main__":
|
simulation_engine.py
CHANGED
|
@@ -306,6 +306,34 @@ class InteractiveSimulator:
|
|
| 306 |
def set_demo_case_handler(self, handler):
|
| 307 |
self.demo_case_handler = handler
|
| 308 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
def _load_object_masks(self):
|
| 310 |
masks_dir = self.demo_data_path / "fg_masks"
|
| 311 |
if not masks_dir.exists():
|
|
@@ -615,6 +643,27 @@ class _MinimalSVR:
|
|
| 615 |
compositor=AlphaCompositor(),
|
| 616 |
)
|
| 617 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 618 |
def update_fg_obj_info(self, all_obj_points):
|
| 619 |
for idx, pts in enumerate(all_obj_points):
|
| 620 |
self.fg_pcs[idx]["points"] = pts.clone()
|
|
|
|
| 306 |
def set_demo_case_handler(self, handler):
|
| 307 |
self.demo_case_handler = handler
|
| 308 |
|
| 309 |
+
def move_to_device(self, device):
|
| 310 |
+
"""Move all renderer/simulation tensors to target device (CPU↔GPU)."""
|
| 311 |
+
dev = torch.device(device)
|
| 312 |
+
self.device = dev
|
| 313 |
+
# Move SVR (PyTorch3D renderer + camera + point clouds)
|
| 314 |
+
self.svr.move_to_device(dev)
|
| 315 |
+
# Move mesh data
|
| 316 |
+
for mesh in self.fg_meshes:
|
| 317 |
+
for k, v in list(mesh.items()):
|
| 318 |
+
if isinstance(v, torch.Tensor):
|
| 319 |
+
mesh[k] = v.to(dev)
|
| 320 |
+
# Move foreground point clouds
|
| 321 |
+
for pc_list in (self.fg_pcs_pt3d, self.fg_pcs_gs):
|
| 322 |
+
for pc in pc_list:
|
| 323 |
+
for k, v in list(pc.items()):
|
| 324 |
+
if isinstance(v, torch.Tensor):
|
| 325 |
+
pc[k] = v.to(dev)
|
| 326 |
+
# Move per-object transform matrices and initial particles
|
| 327 |
+
for k in list(self.initial_transform_matrix.keys()):
|
| 328 |
+
self.initial_transform_matrix[k] = self.initial_transform_matrix[k].to(dev)
|
| 329 |
+
for k in list(self._init_particles_gpu.keys()):
|
| 330 |
+
self._init_particles_gpu[k] = self._init_particles_gpu[k].to(dev)
|
| 331 |
+
# Move obj_info tensors (shared with case_handler by reference)
|
| 332 |
+
for obj_info in self.all_obj_info:
|
| 333 |
+
for k, v in list(obj_info.items()):
|
| 334 |
+
if isinstance(v, torch.Tensor):
|
| 335 |
+
obj_info[k] = v.to(dev)
|
| 336 |
+
|
| 337 |
def _load_object_masks(self):
|
| 338 |
masks_dir = self.demo_data_path / "fg_masks"
|
| 339 |
if not masks_dir.exists():
|
|
|
|
| 643 |
compositor=AlphaCompositor(),
|
| 644 |
)
|
| 645 |
|
| 646 |
+
def move_to_device(self, device):
|
| 647 |
+
"""Move all tensors to target device and rebuild renderers."""
|
| 648 |
+
from pytorch3d.renderer import PerspectiveCameras
|
| 649 |
+
cam = self.current_camera
|
| 650 |
+
self.current_camera = PerspectiveCameras(
|
| 651 |
+
K=cam.K.to(device),
|
| 652 |
+
R=cam.R.to(device),
|
| 653 |
+
T=cam.T.to(device),
|
| 654 |
+
in_ndc=False,
|
| 655 |
+
image_size=((512, 512),),
|
| 656 |
+
device=device,
|
| 657 |
+
)
|
| 658 |
+
self.bg_points = self.bg_points.to(device)
|
| 659 |
+
self.bg_points_colors = self.bg_points_colors.to(device)
|
| 660 |
+
for pc in self.fg_pcs:
|
| 661 |
+
pc['points'] = pc['points'].to(device)
|
| 662 |
+
pc['colors'] = pc['colors'].to(device)
|
| 663 |
+
self.device = device
|
| 664 |
+
self.cache_bg = None # stale after device change; recomputed on next render
|
| 665 |
+
self._build_cached_renderers()
|
| 666 |
+
|
| 667 |
def update_fg_obj_info(self, all_obj_points):
|
| 668 |
for idx, pts in enumerate(all_obj_points):
|
| 669 |
self.fg_pcs[idx]["points"] = pts.clone()
|
video_generator.py
CHANGED
|
@@ -109,13 +109,13 @@ class StreamingVideoGenerator:
|
|
| 109 |
log_gpu("after checkpoint load (bf16, CPU)")
|
| 110 |
|
| 111 |
if low_memory:
|
| 112 |
-
DynamicSwapInstaller.install_model(self.pipeline.text_encoder, device=
|
| 113 |
else:
|
| 114 |
-
self.pipeline.text_encoder.to(device=
|
| 115 |
|
| 116 |
-
self.pipeline.generator.to(device=
|
| 117 |
-
self.pipeline.vae.to(device=
|
| 118 |
-
self.pipeline.encode_vae.to(device=
|
| 119 |
|
| 120 |
if self.enable_taehv:
|
| 121 |
import os
|
|
@@ -138,9 +138,9 @@ class StreamingVideoGenerator:
|
|
| 138 |
self.taehv_decoder.requires_grad_(False)
|
| 139 |
|
| 140 |
self.pipeline.processor_dtype = torch.float32
|
| 141 |
-
self.pipeline.processor_device =
|
| 142 |
-
self.pipeline.processor_vae = WanVideoVAE().to(device=
|
| 143 |
-
self.pipeline.processor_ienc = WanImageEncoder().to(device=
|
| 144 |
|
| 145 |
self.pipeline.processor_vae.requires_grad_(False)
|
| 146 |
self.pipeline.processor_ienc.requires_grad_(False)
|
|
@@ -477,3 +477,17 @@ class StreamingVideoGenerator:
|
|
| 477 |
self.current_start_frame = 0
|
| 478 |
self.conditional_dict = None
|
| 479 |
self.taehv_cache = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
log_gpu("after checkpoint load (bf16, CPU)")
|
| 110 |
|
| 111 |
if low_memory:
|
| 112 |
+
DynamicSwapInstaller.install_model(self.pipeline.text_encoder, device=self.device)
|
| 113 |
else:
|
| 114 |
+
self.pipeline.text_encoder.to(device=self.device)
|
| 115 |
|
| 116 |
+
self.pipeline.generator.to(device=self.device)
|
| 117 |
+
self.pipeline.vae.to(device=self.device)
|
| 118 |
+
self.pipeline.encode_vae.to(device=self.device, dtype=torch.bfloat16)
|
| 119 |
|
| 120 |
if self.enable_taehv:
|
| 121 |
import os
|
|
|
|
| 138 |
self.taehv_decoder.requires_grad_(False)
|
| 139 |
|
| 140 |
self.pipeline.processor_dtype = torch.float32
|
| 141 |
+
self.pipeline.processor_device = self.device
|
| 142 |
+
self.pipeline.processor_vae = WanVideoVAE().to(device=self.device, dtype=torch.float32)
|
| 143 |
+
self.pipeline.processor_ienc = WanImageEncoder().to(device=self.device, dtype=torch.float32)
|
| 144 |
|
| 145 |
self.pipeline.processor_vae.requires_grad_(False)
|
| 146 |
self.pipeline.processor_ienc.requires_grad_(False)
|
|
|
|
| 477 |
self.current_start_frame = 0
|
| 478 |
self.conditional_dict = None
|
| 479 |
self.taehv_cache = None
|
| 480 |
+
|
| 481 |
+
def move_pipeline_to_device(self, device: str):
|
| 482 |
+
"""Move all pipeline models to target device (CPU→GPU at generation start, GPU→CPU at end)."""
|
| 483 |
+
dev = torch.device(device)
|
| 484 |
+
self.device = dev
|
| 485 |
+
pipeline = self.pipeline
|
| 486 |
+
if hasattr(pipeline, 'generator') and pipeline.generator is not None:
|
| 487 |
+
pipeline.generator.to(device=dev)
|
| 488 |
+
if hasattr(pipeline, 'vae') and pipeline.vae is not None:
|
| 489 |
+
pipeline.vae.to(device=dev)
|
| 490 |
+
if hasattr(pipeline, 'encode_vae') and pipeline.encode_vae is not None:
|
| 491 |
+
pipeline.encode_vae.to(device=dev)
|
| 492 |
+
if hasattr(pipeline, 'text_encoder') and pipeline.text_encoder is not None:
|
| 493 |
+
pipeline.text_encoder.to(device=dev)
|