Spaces:
Running
on
Zero
Running
on
Zero
Move models from CPU to GPU in worker process (like LivePortrait)
Browse files
app.py
CHANGED
|
@@ -600,6 +600,39 @@ def run_graio_demo(args):
|
|
| 600 |
# @spaces.GPU 装饰器会自动处理 GPU 初始化,不需要手动初始化
|
| 601 |
@spaces.GPU(duration=360)
|
| 602 |
def gpu_wrapped_generate_video(*args, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
return generate_video(*args, **kwargs)
|
| 604 |
|
| 605 |
def toggle_audio_inputs(person_num):
|
|
|
|
| 600 |
# @spaces.GPU 装饰器会自动处理 GPU 初始化,不需要手动初始化
|
| 601 |
@spaces.GPU(duration=360)
|
| 602 |
def gpu_wrapped_generate_video(*args, **kwargs):
|
| 603 |
+
# 在 worker 进程中将模型移动到 GPU(如果模型在 CPU 上)
|
| 604 |
+
# 参考 LivePortrait: 在 worker 进程中直接使用 .to("cuda")
|
| 605 |
+
if torch.cuda.is_available() and device == -1:
|
| 606 |
+
try:
|
| 607 |
+
logging.info("Moving models from CPU to GPU in worker process...")
|
| 608 |
+
cuda_device = torch.device("cuda")
|
| 609 |
+
|
| 610 |
+
# 移动主模型到 GPU
|
| 611 |
+
if hasattr(wan_a2v, 'model') and wan_a2v.model is not None:
|
| 612 |
+
wan_a2v.model = wan_a2v.model.to(cuda_device)
|
| 613 |
+
|
| 614 |
+
# 移动 VAE 模型到 GPU
|
| 615 |
+
if hasattr(wan_a2v, 'vae') and wan_a2v.vae is not None:
|
| 616 |
+
if hasattr(wan_a2v.vae, 'model'):
|
| 617 |
+
wan_a2v.vae.model = wan_a2v.vae.model.to(cuda_device)
|
| 618 |
+
# 移动 VAE 的 mean 和 std 张量
|
| 619 |
+
if hasattr(wan_a2v.vae, 'mean'):
|
| 620 |
+
wan_a2v.vae.mean = wan_a2v.vae.mean.to(cuda_device)
|
| 621 |
+
if hasattr(wan_a2v.vae, 'std'):
|
| 622 |
+
wan_a2v.vae.std = wan_a2v.vae.std.to(cuda_device)
|
| 623 |
+
|
| 624 |
+
# 移动 CLIP 模型到 GPU
|
| 625 |
+
if hasattr(wan_a2v, 'clip') and wan_a2v.clip is not None:
|
| 626 |
+
if hasattr(wan_a2v.clip, 'model'):
|
| 627 |
+
wan_a2v.clip.model = wan_a2v.clip.model.to(cuda_device)
|
| 628 |
+
|
| 629 |
+
# 更新设备信息
|
| 630 |
+
wan_a2v.device = cuda_device
|
| 631 |
+
|
| 632 |
+
logging.info("Models moved to GPU successfully")
|
| 633 |
+
except Exception as e:
|
| 634 |
+
logging.warning(f"Failed to move models to GPU: {e}")
|
| 635 |
+
|
| 636 |
return generate_video(*args, **kwargs)
|
| 637 |
|
| 638 |
def toggle_audio_inputs(person_num):
|