C4G-HKUST commited on
Commit
f2436d3
·
1 Parent(s): 7b63b72

Move models from CPU to GPU in worker process (like LivePortrait)

Browse files
Files changed (1) hide show
  1. app.py +33 -0
app.py CHANGED
@@ -600,6 +600,39 @@ def run_graio_demo(args):
600
  # @spaces.GPU 装饰器会自动处理 GPU 初始化,不需要手动初始化
601
  @spaces.GPU(duration=360)
602
  def gpu_wrapped_generate_video(*args, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
603
  return generate_video(*args, **kwargs)
604
 
605
  def toggle_audio_inputs(person_num):
 
600
  # @spaces.GPU 装饰器会自动处理 GPU 初始化,不需要手动初始化
601
  @spaces.GPU(duration=360)
602
  def gpu_wrapped_generate_video(*args, **kwargs):
603
+ # 在 worker 进程中将模型移动到 GPU(如果模型在 CPU 上)
604
+ # 参考 LivePortrait: 在 worker 进程中直接使用 .to("cuda")
605
+ if torch.cuda.is_available() and device == -1:
606
+ try:
607
+ logging.info("Moving models from CPU to GPU in worker process...")
608
+ cuda_device = torch.device("cuda")
609
+
610
+ # 移动主模型到 GPU
611
+ if hasattr(wan_a2v, 'model') and wan_a2v.model is not None:
612
+ wan_a2v.model = wan_a2v.model.to(cuda_device)
613
+
614
+ # 移动 VAE 模型到 GPU
615
+ if hasattr(wan_a2v, 'vae') and wan_a2v.vae is not None:
616
+ if hasattr(wan_a2v.vae, 'model'):
617
+ wan_a2v.vae.model = wan_a2v.vae.model.to(cuda_device)
618
+ # 移动 VAE 的 mean 和 std 张量
619
+ if hasattr(wan_a2v.vae, 'mean'):
620
+ wan_a2v.vae.mean = wan_a2v.vae.mean.to(cuda_device)
621
+ if hasattr(wan_a2v.vae, 'std'):
622
+ wan_a2v.vae.std = wan_a2v.vae.std.to(cuda_device)
623
+
624
+ # 移动 CLIP 模型到 GPU
625
+ if hasattr(wan_a2v, 'clip') and wan_a2v.clip is not None:
626
+ if hasattr(wan_a2v.clip, 'model'):
627
+ wan_a2v.clip.model = wan_a2v.clip.model.to(cuda_device)
628
+
629
+ # 更新设备信息
630
+ wan_a2v.device = cuda_device
631
+
632
+ logging.info("Models moved to GPU successfully")
633
+ except Exception as e:
634
+ logging.warning(f"Failed to move models to GPU: {e}")
635
+
636
  return generate_video(*args, **kwargs)
637
 
638
  def toggle_audio_inputs(person_num):