C4G-HKUST committed on
Commit fd55666 · 1 Parent(s): 490e55a

feat: timeout check

Files changed (3):
  1. app.py +127 -6
  2. wan/audio2video_multiID.py +2 -0
  3. wan/utils/infer_utils.py +53 -7
app.py CHANGED
@@ -11,6 +11,7 @@ import spaces
 warnings.filterwarnings('ignore')
 
 import random
+import math
 import torch
 import torch.distributed as dist
 from PIL import Image
@@ -435,7 +436,7 @@ def run_graio_demo(args):
     logging.info("Model and face processor loaded successfully.")
 
     def generate_video(img2vid_image, img2vid_prompt, n_prompt, img2vid_audio_1, img2vid_audio_2, img2vid_audio_3,
-                       sd_steps, seed, guide_scale, person_num_selector, audio_mode_selector, fixed_steps=None):
+                       sd_steps, seed, guide_scale, person_num_selector, audio_mode_selector, fixed_steps=None, trim_to_6s=False):
        # Following LivePortrait: use the cuda device directly in the worker process
        # Reference: https://huggingface.co/spaces/KlingTeam/LivePortrait/blob/main/src/gradio_pipeline.py
        # The @spaces.GPU decorator has already initialized the GPU, so it can be used directly here
@@ -480,7 +481,18 @@ def run_graio_demo(args):
         if audio_paths and len(audio_paths) > 0:
             # Use fps from cfg; fall back to the default of 24 if unavailable
             fps = getattr(cfg, 'fps', 24)
-            current_frame_num = calculate_frame_num_from_audio(audio_paths, fps, mode=audio_mode_selector)
+            calculated_frame_num = calculate_frame_num_from_audio(audio_paths, fps, mode=audio_mode_selector)
+
+            # Fast mode: if trim_to_6s is True, cap the frame count at the 6-second equivalent
+            if trim_to_6s:
+                # Frame count corresponding to 6 seconds (4n+1 format)
+                max_frames_6s = int(math.ceil(6.0 * fps))
+                max_frames_6s = ((max_frames_6s - 1) // 4) * 4 + 1
+                current_frame_num = min(calculated_frame_num, max_frames_6s)
+                logging.warning(f"Fast mode: Audio duration exceeds 6 seconds. Trimming to 6 seconds ({max_frames_6s} frames). Original: {calculated_frame_num} frames")
+            else:
+                current_frame_num = calculated_frame_num
+
             logging.info(f"Dynamically determined frame number: {current_frame_num} (mode: {audio_mode_selector})")
         else:
             # Use the default frame count when no audio is provided
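The 4n+1 cap added in this hunk is easy to sanity-check in isolation. A minimal sketch (the helper name frames_for_duration is ours, not from the repo):

import math

def frames_for_duration(seconds, fps):
    # Round the duration up to whole frames, then down to the nearest
    # 4n+1 value, since the model requires frame counts of that form.
    frames = int(math.ceil(seconds * fps))
    return ((frames - 1) // 4) * 4 + 1

assert frames_for_duration(6.0, 24) == 141  # ceil(144) -> ((143) // 4) * 4 + 1

Note that at the default 24 fps the "6-second" cap is actually 141 frames, i.e. about 5.875 s of video.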
@@ -519,6 +531,7 @@ def run_graio_demo(args):
             audio_paths=audio_paths,
             task_key="gradio_output",
             mode=audio_mode_selector,
+            trim_to_6s=trim_to_6s,
         )
 
         if isinstance(video, dict):
@@ -610,6 +623,63 @@ def run_graio_demo(args):
     def gpu_wrapped_generate_video_fast(*args, **kwargs):
         # Always use 10 denoising steps, passed as a keyword argument
         kwargs['fixed_steps'] = 10
+
+        # Fast-mode audio length check: detect whether the audio exceeds 6 seconds
+        # Argument order: img2vid_image, img2vid_prompt, n_prompt, img2vid_audio_1, img2vid_audio_2, img2vid_audio_3,
+        #                 sd_steps, seed, guide_scale, person_num_selector, audio_mode_selector
+        if len(args) >= 11:
+            img2vid_image = args[0]
+            img2vid_prompt = args[1]
+            n_prompt = args[2]
+            img2vid_audio_1 = args[3]
+            img2vid_audio_2 = args[4]
+            img2vid_audio_3 = args[5]
+            sd_steps = args[6]
+            seed = args[7]
+            guide_scale = args[8]
+            person_num_selector = args[9]
+            audio_mode_selector = args[10]
+
+            # Collect audio paths according to the number of persons
+            audio_paths = []
+            if person_num_selector == "1 Person":
+                if img2vid_audio_1:
+                    audio_paths.append(img2vid_audio_1)
+            elif person_num_selector == "2 Persons":
+                if img2vid_audio_1:
+                    audio_paths.append(img2vid_audio_1)
+                if img2vid_audio_2:
+                    audio_paths.append(img2vid_audio_2)
+            elif person_num_selector == "3 Persons":
+                if img2vid_audio_1:
+                    audio_paths.append(img2vid_audio_1)
+                if img2vid_audio_2:
+                    audio_paths.append(img2vid_audio_2)
+                if img2vid_audio_3:
+                    audio_paths.append(img2vid_audio_3)
+
+            # Check whether the audio length exceeds 6 seconds
+            if audio_paths and len(audio_paths) > 0:
+                fps = getattr(cfg, 'fps', 24)
+                try:
+                    calculated_frame_num = calculate_frame_num_from_audio(audio_paths, fps, mode=audio_mode_selector)
+                    # Frame count corresponding to 6 seconds
+                    max_frames_6s = int(math.ceil(6.0 * fps))
+                    max_frames_6s = ((max_frames_6s - 1) // 4) * 4 + 1
+
+                    if calculated_frame_num > max_frames_6s:
+                        # Over 6 seconds: set the trim_to_6s flag
+                        kwargs['trim_to_6s'] = True
+                        calculated_duration = calculated_frame_num / fps
+                        logging.warning(f"Fast mode: Audio duration ({calculated_duration:.2f}s) exceeds 6 seconds limit. Will trim to 6 seconds.")
+                    else:
+                        kwargs['trim_to_6s'] = False
+                except Exception as e:
+                    logging.warning(f"Failed to check audio duration: {e}")
+                    kwargs['trim_to_6s'] = False
+            else:
+                kwargs['trim_to_6s'] = False
+
         return gpu_wrapped_generate_video_worker(*args, **kwargs)
 
     # High-quality mode: 780 seconds, user-selected denoising steps
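The person-selector branching above is duplicated again in handle_fast_generation later in this diff. It could be collapsed into one helper; a sketch, assuming the same "1 Person" / "2 Persons" / "3 Persons" selector values (collect_audio_paths is a hypothetical name):

def collect_audio_paths(selector, audio_1, audio_2, audio_3):
    # Map the selector value to the number of active audio inputs,
    # then keep only the paths the user actually provided.
    count = {"1 Person": 1, "2 Persons": 2, "3 Persons": 3}.get(selector, 0)
    return [p for p in (audio_1, audio_2, audio_3)[:count] if p]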
@@ -666,7 +736,8 @@ def run_graio_demo(args):
         except Exception as e:
             logging.warning(f"Failed to move models to GPU: {e}")
 
-        return generate_video(*args, **kwargs)
+        result = generate_video(*args, **kwargs)
+        return result
 
 
 
@@ -815,22 +886,72 @@ def run_graio_demo(args):
         )
 
 
+        # Wrapper functions: handle displaying warning messages
+        def handle_fast_generation(img2vid_image, img2vid_prompt, n_prompt, img2vid_audio_1, img2vid_audio_2, img2vid_audio_3,
+                                   sd_steps, seed, guide_scale, person_num_selector, audio_mode_selector):
+            # Check the audio length before generation starts; warn immediately if it exceeds 6 seconds
+            # Collect audio paths according to the number of persons
+            audio_paths = []
+            if person_num_selector == "1 Person":
+                if img2vid_audio_1:
+                    audio_paths.append(img2vid_audio_1)
+            elif person_num_selector == "2 Persons":
+                if img2vid_audio_1:
+                    audio_paths.append(img2vid_audio_1)
+                if img2vid_audio_2:
+                    audio_paths.append(img2vid_audio_2)
+            elif person_num_selector == "3 Persons":
+                if img2vid_audio_1:
+                    audio_paths.append(img2vid_audio_1)
+                if img2vid_audio_2:
+                    audio_paths.append(img2vid_audio_2)
+                if img2vid_audio_3:
+                    audio_paths.append(img2vid_audio_3)
+
+            # Check whether the audio length exceeds 6 seconds
+            if audio_paths and len(audio_paths) > 0:
+                fps = getattr(cfg, 'fps', 24)
+                try:
+                    calculated_frame_num = calculate_frame_num_from_audio(audio_paths, fps, mode=audio_mode_selector)
+                    # Frame count corresponding to 6 seconds
+                    max_frames_6s = int(math.ceil(6.0 * fps))
+                    max_frames_6s = ((max_frames_6s - 1) // 4) * 4 + 1
+
+                    if calculated_frame_num > max_frames_6s:
+                        # Over 6 seconds: show a warning immediately
+                        calculated_duration = calculated_frame_num / fps
+                        warning_msg = f"⚠️ Warning: Your audio duration ({calculated_duration:.2f}s) exceeds the 6-second limit for Fast Mode. The audio will be automatically trimmed to 6 seconds to prevent timeout."
+                        gr.Warning(warning_msg, duration=5)
+                except Exception as e:
+                    logging.warning(f"Failed to check audio duration: {e}")
+
+            # Proceed with video generation
+            result = gpu_wrapped_generate_video_fast(
+                img2vid_image, img2vid_prompt, n_prompt, img2vid_audio_1, img2vid_audio_2, img2vid_audio_3,
+                sd_steps, seed, guide_scale, person_num_selector, audio_mode_selector
+            )
+            return result
+
+        def handle_quality_generation(*args):
+            result = gpu_wrapped_generate_video_quality(*args)
+            return result
+
         # Fast generation button: 210 seconds, fixed 10 denoising steps
         run_i2v_button_fast.click(
-            fn=gpu_wrapped_generate_video_fast,
+            fn=handle_fast_generation,
             inputs=[img2vid_image, img2vid_prompt, n_prompt, img2vid_audio_1, img2vid_audio_2, img2vid_audio_3, sd_steps, seed, guide_scale, person_num_selector, audio_mode_selector],
             outputs=[result_gallery],
         )
 
         # High-quality generation button: 780 seconds, user-selected step count
         run_i2v_button_quality.click(
-            fn=gpu_wrapped_generate_video_quality,
+            fn=handle_quality_generation,
             inputs=[img2vid_image, img2vid_prompt, n_prompt, img2vid_audio_1, img2vid_audio_2, img2vid_audio_3, sd_steps, seed, guide_scale, person_num_selector, audio_mode_selector],
             outputs=[result_gallery],
         )
     # Based on Meigen-MultiTalk's known-good configuration
     # On Hugging Face Spaces, Gradio handles port and server configuration automatically
-    demo.queue(max_size=4).launch(show_error=True)
+    demo.queue(max_size=10).launch(show_error=True)
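The pre-check pattern in this hunk relies on gr.Warning being a non-blocking toast: the wrapper warns first, then hands off to the long-running GPU worker. A self-contained sketch of the same pattern under that assumption (all names here are illustrative, not from the repo):

import gradio as gr

def fast_handler(duration_s):
    # Warn (non-blocking toast) before the long-running step, as
    # handle_fast_generation does above for over-long audio.
    if duration_s > 6.0:
        gr.Warning(f"Audio ({duration_s:.2f}s) exceeds 6s; it will be trimmed.", duration=5)
    return f"generated a {min(duration_s, 6.0):.2f}s clip"  # stand-in for the GPU worker

with gr.Blocks() as demo:
    dur = gr.Number(value=8.0, label="Audio duration (s)")
    out = gr.Textbox(label="Result")
    gr.Button("Generate").click(fn=fast_handler, inputs=[dur], outputs=[out])

demo.queue(max_size=10).launch()  # mirrors the queue bump in this commit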
 
wan/audio2video_multiID.py CHANGED
@@ -199,6 +199,7 @@ class WanAF2V:
         audio_paths=None,  # New: audio path list, supports multiple audio files
         task_key=None,
         mode="pad",  # Audio processing mode: "pad" or "concat"
+        trim_to_6s=False,  # Fast mode: trim audio to 6 seconds
     ):
         r"""
         Generates video frames from input image and text prompt using diffusion process.
@@ -514,6 +515,7 @@ class WanAF2V:
             half_dtype=self.half_dtype,
             preprocess_audio=preprocess_audio,
             resample_audio=resample_audio,
+            trim_to_6s=trim_to_6s,
         )
 
         # Prepare audio_ref_features - new list mode
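This file only threads the new flag through: the Gradio wrapper decides trim_to_6s, generate_video forwards it to WanAF2V.generate, which forwards it to process_audio_features. A stripped-down sketch of that chain (simplified stand-in signatures, not the real ones):

def fast_wrapper(**kwargs):
    kwargs["trim_to_6s"] = True                    # decided in app.py from the audio length
    return generate_video(**kwargs)

def generate_video(trim_to_6s=False):
    return wan_generate(trim_to_6s=trim_to_6s)     # WanAF2V.generate in audio2video_multiID.py

def wan_generate(trim_to_6s=False):
    return process_audio_features(trim_to_6s=trim_to_6s)  # wan/utils/infer_utils.py

def process_audio_features(trim_to_6s=False):
    return "trimmed to 6s" if trim_to_6s else "full length"

print(fast_wrapper())  # -> trimmed to 6s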
 
wan/utils/infer_utils.py CHANGED
@@ -118,6 +118,7 @@ def process_audio_features(
     half_dtype=None,
     preprocess_audio=None,
     resample_audio=None,
+    trim_to_6s=False,  # Fast mode: trim audio to 6 seconds
 ):
     """
     Process audio files and extract audio features.
@@ -202,6 +203,31 @@ def process_audio_features(
         total_length = sum(audio_lengths)
         print(f"Total audio length in concat mode (from processed frames): {total_length} frames")
 
+        # Fast mode: trim to 6 seconds if trim_to_6s is True
+        if trim_to_6s:
+            import math
+            # Calculate 6 seconds in frames
+            max_frames_6s = int(math.ceil(6.0 * fps))
+            max_frames_6s = ((max_frames_6s - 1) // 4) * 4 + 1
+            if total_length > max_frames_6s:
+                print(f"Fast mode: Trimming audio from {total_length} frames to {max_frames_6s} frames (6 seconds)")
+                # Truncate each audio proportionally
+                scale_factor = max_frames_6s / total_length
+                cumulative_length = 0
+                for i, audio_len in enumerate(audio_lengths):
+                    if audio_len > 0:
+                        new_audio_len = int(audio_len * scale_factor)
+                        # Ensure it fits within the remaining space
+                        remaining_space = max_frames_6s - cumulative_length
+                        new_audio_len = min(new_audio_len, remaining_space)
+                        audio_lengths[i] = new_audio_len
+                        # Truncate the corresponding raw audio feature
+                        if raw_audio_feat_list[i] is not None:
+                            raw_audio_feat_list[i] = raw_audio_feat_list[i][:new_audio_len]
+                        cumulative_length += new_audio_len
+                total_length = sum(audio_lengths)
+                print(f"After trimming: total_length = {total_length} frames")
+
         # Ensure total length is in 4n+1 format (model requirement)
         total_length = ((total_length - 1) // 4) * 4 + 1
         print(f"Adjusted total length to 4n+1 format: {total_length} frames")
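The proportional truncation above scales every clip's frame count by the same factor and clamps each to the budget that remains, so the running sum can never exceed the cap. A standalone sketch with worked numbers (trim_lengths is our name for it; the real code also truncates raw_audio_feat_list in step):

def trim_lengths(lengths, max_total):
    # Scale each length by max_total / total; clamp to the space left.
    total = sum(lengths)
    if total <= max_total:
        return lengths
    scale = max_total / total
    out, used = [], 0
    for n in lengths:
        new_n = min(int(n * scale), max_total - used) if n > 0 else 0
        out.append(new_n)
        used += new_n
    return out

# 180 frames of audio against a 141-frame budget (6 s at 24 fps, in 4n+1 form):
print(trim_lengths([100, 50, 30], 141))  # -> [78, 39, 23], sum 140 <= 141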
@@ -257,7 +283,7 @@ def process_audio_features(
             audio_feat_list.append(zero_audio_feat)
             print(f"Audio {i} is missing, created zero features with shape: {zero_audio_feat.shape}")
     else:
-        # Pad mode: keep existing logic, no changes needed
+        # Pad mode: keep existing logic, but apply trim_to_6s if needed
         for i, audio_path in enumerate(audio_paths):
             if audio_path and os.path.exists(audio_path):
                 print(f"Processing audio {i}: {audio_path}")
@@ -270,10 +296,19 @@ def process_audio_features(
                 with torch.no_grad():
                     print(f"wav2vec_model: {wav2vec_model}")
                     print(f"cache_dir:{cache_dir}")
-                    # Use dynamically determined frame number F
+                    # Fast mode: if trim_to_6s, limit to 6 seconds
+                    target_frames = F
+                    if trim_to_6s:
+                        import math
+                        max_frames_6s = int(math.ceil(6.0 * fps))
+                        max_frames_6s = ((max_frames_6s - 1) // 4) * 4 + 1
+                        target_frames = min(F, max_frames_6s)
+                        if F > max_frames_6s:
+                            print(f"Fast mode: Trimming audio {i} from {F} frames to {max_frames_6s} frames (6 seconds)")
+                    # Use dynamically determined frame number
                     audio_emb, audio_length = preprocess_audio(
                         wav_path=target_resampled_audio_path,
-                        num_generated_frames_per_clip=F,  # Use dynamically determined frame number
+                        num_generated_frames_per_clip=target_frames,  # Use target frames (may be trimmed)
                         fps=fps,
                         wav2vec_model=wav2vec_model,
                         vocal_separator_model=vocal_separator_model,
@@ -284,7 +319,8 @@ def process_audio_features(
                     audio_dtype = half_dtype if use_half else torch.bfloat16
                     audio_emb = audio_emb.to(device, dtype=audio_dtype)
 
-                    audio_feat = audio_emb[:F]  # Use dynamically determined frame number
+                    # Ensure we don't exceed F frames (for consistency with other tensors)
+                    audio_feat = audio_emb[:F]  # Use F to maintain consistency
                     audio_feat_list.append(audio_feat)
                     print(f"Audio {i} processed, shape: {audio_feat.shape}")
                 else:
@@ -310,10 +346,19 @@ def process_audio_features(
                     target_resampled_audio_path,
                 )
                 with torch.no_grad():
-                    # Use dynamically determined frame number F
+                    # Fast mode: if trim_to_6s, limit to 6 seconds
+                    target_frames = F
+                    if trim_to_6s:
+                        import math
+                        max_frames_6s = int(math.ceil(6.0 * fps))
+                        max_frames_6s = ((max_frames_6s - 1) // 4) * 4 + 1
+                        target_frames = min(F, max_frames_6s)
+                        if F > max_frames_6s:
+                            print(f"Fast mode: Trimming single audio from {F} frames to {max_frames_6s} frames (6 seconds)")
+                    # Use dynamically determined frame number
                     audio_emb, audio_length = preprocess_audio(
                         wav_path=audio,
-                        num_generated_frames_per_clip=F,  # Use dynamically determined frame number
+                        num_generated_frames_per_clip=target_frames,  # Use target frames (may be trimmed)
                         fps=fps,
                         wav2vec_model=wav2vec_model,
                         vocal_separator_model=vocal_separator_model,
@@ -324,7 +369,8 @@ def process_audio_features(
                     audio_dtype = half_dtype if use_half else torch.bfloat16
                     audio_emb = audio_emb.to(device, dtype=audio_dtype)
 
-                    audio_feat = audio_emb[:F]  # Use dynamically determined frame number
+                    # Ensure we don't exceed F frames (for consistency with other tensors)
+                    audio_feat = audio_emb[:F]  # Use F to maintain consistency
                     audio_feat_list.append(audio_feat)
                     print(f"Single audio processed, shape: {audio_feat.shape}")
                 else:
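Because the trim keeps the total at or below max_frames_6s, and max_frames_6s is itself of the form 4n+1, the later ((total - 1) // 4) * 4 + 1 adjustment can only round down, never push the total past the cap. A quick standalone check of that invariant:

def round_4n1(n):
    # Round down to the nearest frame count of the form 4n+1.
    return ((n - 1) // 4) * 4 + 1

max_frames_6s = round_4n1(144)  # 6 s at 24 fps -> 141
for total in range(1, max_frames_6s + 1):
    assert round_4n1(total) <= total <= max_frames_6s
print("cap holds at", max_frames_6s, "frames")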
 