C4G-HKUST committed on
Commit
3c322ce
·
1 Parent(s): 694fec8

Track GPU device during initialization and use same GPU in worker process and inference

Browse files
Files changed (1) hide show
  1. app.py +44 -14
app.py CHANGED
@@ -385,14 +385,25 @@ def run_graio_demo(args):
385
  os.makedirs(args.audio_save_dir, exist_ok=True)
386
 
387
  # 运行时动态检测 GPU 可用性(参考 Meigen-MultiTalk)
 
 
 
 
 
388
  if torch.cuda.is_available():
389
  try:
390
  num_gpus = torch.cuda.device_count()
391
  if num_gpus > 0:
392
- gpu_name = torch.cuda.get_device_name(0)
393
- logging.info(f"GPU AVAILABLE: {num_gpus} GPU(s), Name: {gpu_name}")
394
- # 使用 GPU
395
- device = local_rank if world_size > 1 else 0
 
 
 
 
 
 
396
  else:
397
  logging.warning("CUDA is available but no GPU devices found. Using CPU.")
398
  device = -1 # 使用 CPU
@@ -426,6 +437,16 @@ def run_graio_demo(args):
426
 
427
  def generate_video(img2vid_image, img2vid_prompt, n_prompt, img2vid_audio_1, img2vid_audio_2, img2vid_audio_3,
428
  sd_steps, seed, guide_scale, person_num_selector, audio_mode_selector):
 
 
 
 
 
 
 
 
 
 
429
  input_data = {}
430
  input_data["prompt"] = img2vid_prompt
431
  input_data["cond_image"] = img2vid_image
@@ -577,16 +598,25 @@ def run_graio_demo(args):
577
  # 参考: https://huggingface.co/spaces/KlingTeam/LivePortrait/blob/main/app.py
578
  @spaces.GPU(duration=360)
579
  def gpu_wrapped_generate_video(*args, **kwargs):
580
- # 在 worker 进程中检查 GPU 可用性
581
- try:
582
- if torch.cuda.is_available():
583
- # 尝试访问 GPU 以确保它已准备好
584
- _ = torch.cuda.current_device()
585
- logging.info(f"GPU ready in worker process: {torch.cuda.get_device_name(0)}")
586
- else:
587
- logging.warning("GPU not available in worker process, but continuing...")
588
- except RuntimeError as e:
589
- logging.warning(f"GPU initialization error in worker process: {e}. Continuing anyway...")
 
 
 
 
 
 
 
 
 
590
 
591
  return generate_video(*args, **kwargs)
592
 
 
385
  os.makedirs(args.audio_save_dir, exist_ok=True)
386
 
387
  # 运行时动态检测 GPU 可用性(参考 Meigen-MultiTalk)
388
+ # 记录 GPU 信息,以便在 worker 进程中使用相同的 GPU
389
+ gpu_device_id = None
390
+ gpu_name = None
391
+ gpu_uuid = None
392
+
393
  if torch.cuda.is_available():
394
  try:
395
  num_gpus = torch.cuda.device_count()
396
  if num_gpus > 0:
397
+ gpu_device_id = local_rank if world_size > 1 else 0
398
+ torch.cuda.set_device(gpu_device_id)
399
+ gpu_name = torch.cuda.get_device_name(gpu_device_id)
400
+ # 尝试获取 GPU UUID(如果可用)
401
+ try:
402
+ gpu_uuid = torch.cuda.get_device_properties(gpu_device_id).uuid
403
+ except:
404
+ pass
405
+ logging.info(f"GPU AVAILABLE: {num_gpus} GPU(s), Device ID: {gpu_device_id}, Name: {gpu_name}, UUID: {gpu_uuid}")
406
+ device = gpu_device_id
407
  else:
408
  logging.warning("CUDA is available but no GPU devices found. Using CPU.")
409
  device = -1 # 使用 CPU
 
437
 
438
  def generate_video(img2vid_image, img2vid_prompt, n_prompt, img2vid_audio_1, img2vid_audio_2, img2vid_audio_3,
439
  sd_steps, seed, guide_scale, person_num_selector, audio_mode_selector):
440
+ # 确保使用初始化时记录的 GPU 设备
441
+ if gpu_device_id is not None and torch.cuda.is_available():
442
+ try:
443
+ torch.cuda.set_device(gpu_device_id)
444
+ current_device = torch.cuda.current_device()
445
+ current_gpu_name = torch.cuda.get_device_name(current_device)
446
+ logging.info(f"Using GPU device {current_device} ({current_gpu_name}) for inference")
447
+ except Exception as e:
448
+ logging.warning(f"Failed to set GPU device {gpu_device_id}: {e}")
449
+
450
  input_data = {}
451
  input_data["prompt"] = img2vid_prompt
452
  input_data["cond_image"] = img2vid_image
 
598
  # 参考: https://huggingface.co/spaces/KlingTeam/LivePortrait/blob/main/app.py
599
  @spaces.GPU(duration=360)
600
  def gpu_wrapped_generate_video(*args, **kwargs):
601
+ # 在 worker 进程中确保使用初始化时记录的 GPU
602
+ if gpu_device_id is not None:
603
+ try:
604
+ if torch.cuda.is_available():
605
+ # 设置到初始化时记录的 GPU 设备
606
+ torch.cuda.set_device(gpu_device_id)
607
+ current_device = torch.cuda.current_device()
608
+ current_gpu_name = torch.cuda.get_device_name(current_device)
609
+ logging.info(f"Worker process using GPU device {current_device} ({current_gpu_name}) - matching initialization")
610
+
611
+ # 验证 GPU 名称是否匹配(如果记录了名称)
612
+ if gpu_name and current_gpu_name != gpu_name:
613
+ logging.warning(f"GPU name mismatch: init={gpu_name}, worker={current_gpu_name}")
614
+ else:
615
+ logging.warning("GPU not available in worker process, but continuing...")
616
+ except RuntimeError as e:
617
+ logging.warning(f"GPU initialization error in worker process: {e}. Continuing anyway...")
618
+ else:
619
+ logging.info("No GPU device ID recorded, using default device")
620
 
621
  return generate_video(*args, **kwargs)
622