Wei Liu Claude Sonnet 4.6 committed on
Commit
0cdce4a
·
1 Parent(s): a1190a9

CPU-first startup: load all models/scenes to CPU at module level, GPU transfer at generation time

Browse files

- startup() now uses device="cpu" for StreamingVideoGenerator and InteractiveSimulator
- Added move_pipeline_to_device() to StreamingVideoGenerator
- Added move_to_device() to _MinimalSVR and InteractiveSimulator
- do_generate() transfers everything to GPU at start, back to CPU in finally
- Warmup deferred to first generation call (CUDA kernel compile on GPU)
- Avoids ZeroGPU time limit: only fast tensor moves happen inside GPU slot

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (3) hide show
  1. app.py +19 -25
  2. simulation_engine.py +49 -0
  3. video_generator.py +22 -8
app.py CHANGED
@@ -286,6 +286,7 @@ def startup():
286
  use_ema=USE_EMA,
287
  seed=SEED,
288
  enable_taehv=ENABLE_TAEHV,
 
289
  )
290
  video_generator.setup()
291
  log_gpu("after video generator setup")
@@ -303,7 +304,7 @@ def startup():
303
  if case_name == "santa_cloth":
304
  config_overrides["skip_force_fields"] = True
305
 
306
- simulator = InteractiveSimulator(str(case_dir), config_overrides=config_overrides)
307
  simulator.config["debug"] = False
308
  log_gpu(f"after simulator init ({case_name})")
309
 
@@ -350,17 +351,11 @@ def startup():
350
  log_gpu("after finish_precompute")
351
 
352
  # ---- Step 5: Warmup ----
353
- first_case = list(cases.keys())[0]
354
- _warmup_pipeline(first_case)
355
-
356
- # Release per-case precomputed tensors (i2v_conditional, full_y,
357
- # default_text_features) back to CPU so ZeroGPU can reclaim VRAM
358
- # between the startup @spaces.GPU slot and the first generation call.
359
- # The main model weights (transformer, VAE, text_encoder) stay on GPU;
360
- # ZeroGPU preserves the process's GPU slot across calls for efficiency.
361
- video_generator.move_case_data_to_device("cpu")
362
  torch.cuda.empty_cache()
363
- print("[6/6] Startup complete — Gradio server starting.")
364
 
365
 
366
  # ---------------------------------------------------------------------------
@@ -488,20 +483,17 @@ def do_generate(case_name, prompt, d0, s0, d1, s1, d2, s2):
488
  _is_generating = True
489
  _stop_event.clear()
490
 
491
- # Lazy full initialization: load models + build physics scenes + precompute.
492
- # Runs only on the first generation; subsequent calls skip this branch.
493
  if video_generator is None:
494
- yield None, "First run: loading models and initializing physics (this takes a minute)..."
495
- try:
496
- startup()
497
- except Exception as e:
498
- import traceback; traceback.print_exc()
499
- _is_generating = False
500
- yield None, f"Initialization error: {e}"
501
- return
502
 
503
- # Move precomputed case tensors back to CUDA for this generation session.
 
504
  video_generator.move_case_data_to_device("cuda")
 
 
 
505
 
506
  bundle = cases[case_name]
507
 
@@ -691,10 +683,12 @@ def do_generate(case_name, prompt, d0, s0, d1, s1, d2, s2):
691
  render_thread.join(timeout=10)
692
  if warp_thread is not None:
693
  warp_thread.join(timeout=10)
694
- # Release precomputed case tensors to CPU so ZeroGPU can reclaim
695
- # VRAM for other users once this generation session ends.
696
  if video_generator is not None:
 
697
  video_generator.move_case_data_to_device("cpu")
 
 
 
698
  torch.cuda.empty_cache()
699
  _is_generating = False
700
 
@@ -921,7 +915,7 @@ def build_demo():
921
  # files are already on disk so snapshot_download() is a fast no-op. By doing
922
  # this here we avoid holding a ZeroGPU allocation while waiting on downloads.
923
  _ensure_models_downloaded()
924
-
925
  demo = build_demo()
926
 
927
  if __name__ == "__main__":
 
286
  use_ema=USE_EMA,
287
  seed=SEED,
288
  enable_taehv=ENABLE_TAEHV,
289
+ device="cpu",
290
  )
291
  video_generator.setup()
292
  log_gpu("after video generator setup")
 
304
  if case_name == "santa_cloth":
305
  config_overrides["skip_force_fields"] = True
306
 
307
+ simulator = InteractiveSimulator(str(case_dir), device="cpu", config_overrides=config_overrides)
308
  simulator.config["debug"] = False
309
  log_gpu(f"after simulator init ({case_name})")
310
 
 
351
  log_gpu("after finish_precompute")
352
 
353
  # ---- Step 5: Warmup ----
354
+ # Warmup (CUDA kernel compilation) is deferred to first generation call.
355
+ print("[5/6] Skipping warmup at CPU-only startup — CUDA kernels compile on first generation.")
356
+
 
 
 
 
 
 
357
  torch.cuda.empty_cache()
358
+ print("[6/6] CPU-only startup complete — models and scenes ready. GPU transfer at generation time.")
359
 
360
 
361
  # ---------------------------------------------------------------------------
 
483
  _is_generating = True
484
  _stop_event.clear()
485
 
 
 
486
  if video_generator is None:
487
+ _is_generating = False
488
+ yield None, "Error: models not initialized. Please reload the Space."
489
+ return
 
 
 
 
 
490
 
491
+ # Transfer all CPU-resident state to GPU for this generation session.
492
+ video_generator.move_pipeline_to_device("cuda")
493
  video_generator.move_case_data_to_device("cuda")
494
+ for _b in cases.values():
495
+ if _b.simulator is not None:
496
+ _b.simulator.move_to_device("cuda")
497
 
498
  bundle = cases[case_name]
499
 
 
683
  render_thread.join(timeout=10)
684
  if warp_thread is not None:
685
  warp_thread.join(timeout=10)
 
 
686
  if video_generator is not None:
687
+ video_generator.move_pipeline_to_device("cpu")
688
  video_generator.move_case_data_to_device("cpu")
689
+ for _b in cases.values():
690
+ if _b is not None and _b.simulator is not None:
691
+ _b.simulator.move_to_device("cpu")
692
  torch.cuda.empty_cache()
693
  _is_generating = False
694
 
 
915
  # files are already on disk so snapshot_download() is a fast no-op. By doing
916
  # this here we avoid holding a ZeroGPU allocation while waiting on downloads.
917
  _ensure_models_downloaded()
918
+ startup() # Load all models and scenes to CPU at module level
919
  demo = build_demo()
920
 
921
  if __name__ == "__main__":
simulation_engine.py CHANGED
@@ -306,6 +306,34 @@ class InteractiveSimulator:
306
  def set_demo_case_handler(self, handler):
307
  self.demo_case_handler = handler
308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  def _load_object_masks(self):
310
  masks_dir = self.demo_data_path / "fg_masks"
311
  if not masks_dir.exists():
@@ -615,6 +643,27 @@ class _MinimalSVR:
615
  compositor=AlphaCompositor(),
616
  )
617
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
618
  def update_fg_obj_info(self, all_obj_points):
619
  for idx, pts in enumerate(all_obj_points):
620
  self.fg_pcs[idx]["points"] = pts.clone()
 
306
  def set_demo_case_handler(self, handler):
307
  self.demo_case_handler = handler
308
 
309
+ def move_to_device(self, device):
310
+ """Move all renderer/simulation tensors to target device (CPU↔GPU)."""
311
+ dev = torch.device(device)
312
+ self.device = dev
313
+ # Move SVR (PyTorch3D renderer + camera + point clouds)
314
+ self.svr.move_to_device(dev)
315
+ # Move mesh data
316
+ for mesh in self.fg_meshes:
317
+ for k, v in list(mesh.items()):
318
+ if isinstance(v, torch.Tensor):
319
+ mesh[k] = v.to(dev)
320
+ # Move foreground point clouds
321
+ for pc_list in (self.fg_pcs_pt3d, self.fg_pcs_gs):
322
+ for pc in pc_list:
323
+ for k, v in list(pc.items()):
324
+ if isinstance(v, torch.Tensor):
325
+ pc[k] = v.to(dev)
326
+ # Move per-object transform matrices and initial particles
327
+ for k in list(self.initial_transform_matrix.keys()):
328
+ self.initial_transform_matrix[k] = self.initial_transform_matrix[k].to(dev)
329
+ for k in list(self._init_particles_gpu.keys()):
330
+ self._init_particles_gpu[k] = self._init_particles_gpu[k].to(dev)
331
+ # Move obj_info tensors (shared with case_handler by reference)
332
+ for obj_info in self.all_obj_info:
333
+ for k, v in list(obj_info.items()):
334
+ if isinstance(v, torch.Tensor):
335
+ obj_info[k] = v.to(dev)
336
+
337
  def _load_object_masks(self):
338
  masks_dir = self.demo_data_path / "fg_masks"
339
  if not masks_dir.exists():
 
643
  compositor=AlphaCompositor(),
644
  )
645
 
646
+ def move_to_device(self, device):
647
+ """Move all tensors to target device and rebuild renderers."""
648
+ from pytorch3d.renderer import PerspectiveCameras
649
+ cam = self.current_camera
650
+ self.current_camera = PerspectiveCameras(
651
+ K=cam.K.to(device),
652
+ R=cam.R.to(device),
653
+ T=cam.T.to(device),
654
+ in_ndc=False,
655
+ image_size=((512, 512),),
656
+ device=device,
657
+ )
658
+ self.bg_points = self.bg_points.to(device)
659
+ self.bg_points_colors = self.bg_points_colors.to(device)
660
+ for pc in self.fg_pcs:
661
+ pc['points'] = pc['points'].to(device)
662
+ pc['colors'] = pc['colors'].to(device)
663
+ self.device = device
664
+ self.cache_bg = None # stale after device change; recomputed on next render
665
+ self._build_cached_renderers()
666
+
667
  def update_fg_obj_info(self, all_obj_points):
668
  for idx, pts in enumerate(all_obj_points):
669
  self.fg_pcs[idx]["points"] = pts.clone()
video_generator.py CHANGED
@@ -109,13 +109,13 @@ class StreamingVideoGenerator:
109
  log_gpu("after checkpoint load (bf16, CPU)")
110
 
111
  if low_memory:
112
- DynamicSwapInstaller.install_model(self.pipeline.text_encoder, device=gpu)
113
  else:
114
- self.pipeline.text_encoder.to(device=gpu)
115
 
116
- self.pipeline.generator.to(device=gpu)
117
- self.pipeline.vae.to(device=gpu)
118
- self.pipeline.encode_vae.to(device=gpu, dtype=torch.bfloat16)
119
 
120
  if self.enable_taehv:
121
  import os
@@ -138,9 +138,9 @@ class StreamingVideoGenerator:
138
  self.taehv_decoder.requires_grad_(False)
139
 
140
  self.pipeline.processor_dtype = torch.float32
141
- self.pipeline.processor_device = gpu
142
- self.pipeline.processor_vae = WanVideoVAE().to(device=gpu, dtype=torch.float32)
143
- self.pipeline.processor_ienc = WanImageEncoder().to(device=gpu, dtype=torch.float32)
144
 
145
  self.pipeline.processor_vae.requires_grad_(False)
146
  self.pipeline.processor_ienc.requires_grad_(False)
@@ -477,3 +477,17 @@ class StreamingVideoGenerator:
477
  self.current_start_frame = 0
478
  self.conditional_dict = None
479
  self.taehv_cache = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  log_gpu("after checkpoint load (bf16, CPU)")
110
 
111
  if low_memory:
112
+ DynamicSwapInstaller.install_model(self.pipeline.text_encoder, device=self.device)
113
  else:
114
+ self.pipeline.text_encoder.to(device=self.device)
115
 
116
+ self.pipeline.generator.to(device=self.device)
117
+ self.pipeline.vae.to(device=self.device)
118
+ self.pipeline.encode_vae.to(device=self.device, dtype=torch.bfloat16)
119
 
120
  if self.enable_taehv:
121
  import os
 
138
  self.taehv_decoder.requires_grad_(False)
139
 
140
  self.pipeline.processor_dtype = torch.float32
141
+ self.pipeline.processor_device = self.device
142
+ self.pipeline.processor_vae = WanVideoVAE().to(device=self.device, dtype=torch.float32)
143
+ self.pipeline.processor_ienc = WanImageEncoder().to(device=self.device, dtype=torch.float32)
144
 
145
  self.pipeline.processor_vae.requires_grad_(False)
146
  self.pipeline.processor_ienc.requires_grad_(False)
 
477
  self.current_start_frame = 0
478
  self.conditional_dict = None
479
  self.taehv_cache = None
480
+
481
+ def move_pipeline_to_device(self, device: str):
482
+ """Move all pipeline models to target device (CPU→GPU at generation start, GPU→CPU at end)."""
483
+ dev = torch.device(device)
484
+ self.device = dev
485
+ pipeline = self.pipeline
486
+ if hasattr(pipeline, 'generator') and pipeline.generator is not None:
487
+ pipeline.generator.to(device=dev)
488
+ if hasattr(pipeline, 'vae') and pipeline.vae is not None:
489
+ pipeline.vae.to(device=dev)
490
+ if hasattr(pipeline, 'encode_vae') and pipeline.encode_vae is not None:
491
+ pipeline.encode_vae.to(device=dev)
492
+ if hasattr(pipeline, 'text_encoder') and pipeline.text_encoder is not None:
493
+ pipeline.text_encoder.to(device=dev)