Spaces:
Running
on
Zero
Running
on
Zero
feat: time out check
Browse files- app.py +127 -6
- wan/audio2video_multiID.py +2 -0
- wan/utils/infer_utils.py +53 -7
app.py
CHANGED
|
@@ -11,6 +11,7 @@ import spaces
|
|
| 11 |
warnings.filterwarnings('ignore')
|
| 12 |
|
| 13 |
import random
|
|
|
|
| 14 |
import torch
|
| 15 |
import torch.distributed as dist
|
| 16 |
from PIL import Image
|
|
@@ -435,7 +436,7 @@ def run_graio_demo(args):
|
|
| 435 |
logging.info("Model and face processor loaded successfully.")
|
| 436 |
|
| 437 |
def generate_video(img2vid_image, img2vid_prompt, n_prompt, img2vid_audio_1, img2vid_audio_2, img2vid_audio_3,
|
| 438 |
-
sd_steps, seed, guide_scale, person_num_selector, audio_mode_selector, fixed_steps=None):
|
| 439 |
# 参考 LivePortrait: 在 worker 进程中直接使用 cuda 设备
|
| 440 |
# 参考: https://huggingface.co/spaces/KlingTeam/LivePortrait/blob/main/src/gradio_pipeline.py
|
| 441 |
# @spaces.GPU 装饰器已经初始化了 GPU,这里直接使用即可
|
|
@@ -480,7 +481,18 @@ def run_graio_demo(args):
|
|
| 480 |
if audio_paths and len(audio_paths) > 0:
|
| 481 |
# 使用 cfg 中的 fps,如果不可用则使用默认值 24
|
| 482 |
fps = getattr(cfg, 'fps', 24)
|
| 483 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 484 |
logging.info(f"Dynamically determined frame number: {current_frame_num} (mode: {audio_mode_selector})")
|
| 485 |
else:
|
| 486 |
# 没有音频时使用默认帧数
|
|
@@ -519,6 +531,7 @@ def run_graio_demo(args):
|
|
| 519 |
audio_paths=audio_paths,
|
| 520 |
task_key="gradio_output",
|
| 521 |
mode=audio_mode_selector,
|
|
|
|
| 522 |
)
|
| 523 |
|
| 524 |
if isinstance(video, dict):
|
|
@@ -610,6 +623,63 @@ def run_graio_demo(args):
|
|
| 610 |
def gpu_wrapped_generate_video_fast(*args, **kwargs):
|
| 611 |
# 固定使用10步去噪,通过关键字参数传递
|
| 612 |
kwargs['fixed_steps'] = 10
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 613 |
return gpu_wrapped_generate_video_worker(*args, **kwargs)
|
| 614 |
|
| 615 |
# 高质量生成模式:780秒,用户选择去噪步数
|
|
@@ -666,7 +736,8 @@ def run_graio_demo(args):
|
|
| 666 |
except Exception as e:
|
| 667 |
logging.warning(f"Failed to move models to GPU: {e}")
|
| 668 |
|
| 669 |
-
|
|
|
|
| 670 |
|
| 671 |
|
| 672 |
|
|
@@ -815,22 +886,72 @@ def run_graio_demo(args):
|
|
| 815 |
)
|
| 816 |
|
| 817 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 818 |
# 快速生成按钮:210秒,固定10步
|
| 819 |
run_i2v_button_fast.click(
|
| 820 |
-
fn=
|
| 821 |
inputs=[img2vid_image, img2vid_prompt, n_prompt, img2vid_audio_1, img2vid_audio_2, img2vid_audio_3, sd_steps, seed, guide_scale, person_num_selector, audio_mode_selector],
|
| 822 |
outputs=[result_gallery],
|
| 823 |
)
|
| 824 |
|
| 825 |
# 高质量生成按钮:780秒,用户选择步数
|
| 826 |
run_i2v_button_quality.click(
|
| 827 |
-
fn=
|
| 828 |
inputs=[img2vid_image, img2vid_prompt, n_prompt, img2vid_audio_1, img2vid_audio_2, img2vid_audio_3, sd_steps, seed, guide_scale, person_num_selector, audio_mode_selector],
|
| 829 |
outputs=[result_gallery],
|
| 830 |
)
|
| 831 |
# 参考 Meigen-MultiTalk 的成功配置
|
| 832 |
# 在 Hugging Face Spaces 上,Gradio 会自动处理端口和服务器配置
|
| 833 |
-
demo.queue(max_size=
|
| 834 |
|
| 835 |
|
| 836 |
|
|
|
|
| 11 |
warnings.filterwarnings('ignore')
|
| 12 |
|
| 13 |
import random
|
| 14 |
+
import math
|
| 15 |
import torch
|
| 16 |
import torch.distributed as dist
|
| 17 |
from PIL import Image
|
|
|
|
| 436 |
logging.info("Model and face processor loaded successfully.")
|
| 437 |
|
| 438 |
def generate_video(img2vid_image, img2vid_prompt, n_prompt, img2vid_audio_1, img2vid_audio_2, img2vid_audio_3,
|
| 439 |
+
sd_steps, seed, guide_scale, person_num_selector, audio_mode_selector, fixed_steps=None, trim_to_6s=False):
|
| 440 |
# 参考 LivePortrait: 在 worker 进程中直接使用 cuda 设备
|
| 441 |
# 参考: https://huggingface.co/spaces/KlingTeam/LivePortrait/blob/main/src/gradio_pipeline.py
|
| 442 |
# @spaces.GPU 装饰器已经初始化了 GPU,这里直接使用即可
|
|
|
|
| 481 |
if audio_paths and len(audio_paths) > 0:
|
| 482 |
# 使用 cfg 中的 fps,如果不可用则使用默认值 24
|
| 483 |
fps = getattr(cfg, 'fps', 24)
|
| 484 |
+
calculated_frame_num = calculate_frame_num_from_audio(audio_paths, fps, mode=audio_mode_selector)
|
| 485 |
+
|
| 486 |
+
# Fast模式:如果trim_to_6s为True,强制限制为6秒对应的帧数
|
| 487 |
+
if trim_to_6s:
|
| 488 |
+
# 计算6秒对应的帧数(4n+1格式)
|
| 489 |
+
max_frames_6s = int(math.ceil(6.0 * fps))
|
| 490 |
+
max_frames_6s = ((max_frames_6s - 1) // 4) * 4 + 1
|
| 491 |
+
current_frame_num = min(calculated_frame_num, max_frames_6s)
|
| 492 |
+
logging.warning(f"Fast mode: Audio duration exceeds 6 seconds. Trimming to 6 seconds ({max_frames_6s} frames). Original: {calculated_frame_num} frames")
|
| 493 |
+
else:
|
| 494 |
+
current_frame_num = calculated_frame_num
|
| 495 |
+
|
| 496 |
logging.info(f"Dynamically determined frame number: {current_frame_num} (mode: {audio_mode_selector})")
|
| 497 |
else:
|
| 498 |
# 没有音频时使用默认帧数
|
|
|
|
| 531 |
audio_paths=audio_paths,
|
| 532 |
task_key="gradio_output",
|
| 533 |
mode=audio_mode_selector,
|
| 534 |
+
trim_to_6s=trim_to_6s,
|
| 535 |
)
|
| 536 |
|
| 537 |
if isinstance(video, dict):
|
|
|
|
| 623 |
def gpu_wrapped_generate_video_fast(*args, **kwargs):
    """Fast-mode wrapper: pin denoising to 10 steps and flag over-long audio.

    Forwards everything to ``gpu_wrapped_generate_video_worker`` after adding
    two keyword arguments:

    * ``fixed_steps`` is always set to 10.
    * ``trim_to_6s`` is set (only when at least 11 positional arguments are
      supplied) to True if the frame count derived from the selected audio
      inputs exceeds the 6-second frame budget, otherwise False.

    Expected positional order:
        img2vid_image, img2vid_prompt, n_prompt, img2vid_audio_1,
        img2vid_audio_2, img2vid_audio_3, sd_steps, seed, guide_scale,
        person_num_selector, audio_mode_selector

    Returns:
        Whatever ``gpu_wrapped_generate_video_worker`` returns.
    """
    # Fast mode always uses 10 denoising steps, passed through as a keyword.
    kwargs['fixed_steps'] = 10

    # The duration check needs the full positional argument list; with fewer
    # arguments the worker is called without a trim_to_6s keyword at all.
    if len(args) >= 11:
        person_num_selector = args[9]
        audio_mode_selector = args[10]

        # "N Person(s)" selects the first N audio slots (args[3:6]); empty
        # slots are dropped. An unrecognised selector selects no audio.
        person_labels = ("1 Person", "2 Persons", "3 Persons")
        if person_num_selector in person_labels:
            speaker_count = person_labels.index(person_num_selector) + 1
        else:
            speaker_count = 0
        audio_paths = [path for path in args[3:3 + speaker_count] if path]

        trim_needed = False
        if audio_paths:
            fps = getattr(cfg, 'fps', 24)
            try:
                calculated_frame_num = calculate_frame_num_from_audio(audio_paths, fps, mode=audio_mode_selector)
                # Frame budget for 6 seconds, rounded down to the 4n+1 format.
                max_frames_6s = int(math.ceil(6.0 * fps))
                max_frames_6s = ((max_frames_6s - 1) // 4) * 4 + 1

                if calculated_frame_num > max_frames_6s:
                    trim_needed = True
                    calculated_duration = calculated_frame_num / fps
                    logging.warning(f"Fast mode: Audio duration ({calculated_duration:.2f}s) exceeds 6 seconds limit. Will trim to 6 seconds.")
            except Exception as e:
                # Best effort: a failed duration probe must not block generation.
                logging.warning(f"Failed to check audio duration: {e}")
                trim_needed = False

        kwargs['trim_to_6s'] = trim_needed

    return gpu_wrapped_generate_video_worker(*args, **kwargs)
|
| 684 |
|
| 685 |
# 高质量生成模式:780秒,用户选择去噪步数
|
|
|
|
| 736 |
except Exception as e:
|
| 737 |
logging.warning(f"Failed to move models to GPU: {e}")
|
| 738 |
|
| 739 |
+
result = generate_video(*args, **kwargs)
|
| 740 |
+
return result
|
| 741 |
|
| 742 |
|
| 743 |
|
|
|
|
| 886 |
)
|
| 887 |
|
| 888 |
|
| 889 |
+
# 包装函数:处理警告信息显示
|
| 890 |
+
def handle_fast_generation(img2vid_image, img2vid_prompt, n_prompt, img2vid_audio_1, img2vid_audio_2, img2vid_audio_3,
                           sd_steps, seed, guide_scale, person_num_selector, audio_mode_selector):
    """Click handler for the fast-generation button.

    Before generation starts, estimates the frame count of the selected audio
    inputs and, if it exceeds the 6-second frame budget, shows a Gradio
    warning toast. Generation is then delegated to
    ``gpu_wrapped_generate_video_fast`` with the unchanged arguments, and its
    result is returned.
    """
    # Pick the audio slots that belong to the selected number of speakers and
    # drop the ones left empty. An unrecognised selector selects no audio.
    audio_slots = (img2vid_audio_1, img2vid_audio_2, img2vid_audio_3)
    known_selectors = ("1 Person", "2 Persons", "3 Persons")
    slots_in_use = 0
    for position, label in enumerate(known_selectors, start=1):
        if person_num_selector == label:
            slots_in_use = position
            break
    audio_paths = [candidate for candidate in audio_slots[:slots_in_use] if candidate]

    # Warn up front when the audio is longer than the fast-mode limit.
    if audio_paths:
        fps = getattr(cfg, 'fps', 24)
        try:
            calculated_frame_num = calculate_frame_num_from_audio(audio_paths, fps, mode=audio_mode_selector)
            # Frame budget for 6 seconds, rounded down to the 4n+1 format.
            frame_budget = int(math.ceil(6.0 * fps))
            frame_budget = ((frame_budget - 1) // 4) * 4 + 1

            if calculated_frame_num > frame_budget:
                calculated_duration = calculated_frame_num / fps
                warning_msg = f"⚠️ Warning: Your audio duration ({calculated_duration:.2f}s) exceeds the 6-second limit for Fast Mode. The audio will be automatically trimmed to 6 seconds to prevent timeout."
                gr.Warning(warning_msg, duration=5)
        except Exception as e:
            # Best effort: a failed duration probe must not block generation.
            logging.warning(f"Failed to check audio duration: {e}")

    # Carry on with the actual video generation.
    return gpu_wrapped_generate_video_fast(
        img2vid_image, img2vid_prompt, n_prompt, img2vid_audio_1, img2vid_audio_2, img2vid_audio_3,
        sd_steps, seed, guide_scale, person_num_selector, audio_mode_selector
    )
|
| 934 |
+
|
| 935 |
+
def handle_quality_generation(*args):
    """Click handler for the quality-generation button.

    Passes the positional arguments straight through to
    ``gpu_wrapped_generate_video_quality`` and returns its result.
    """
    return gpu_wrapped_generate_video_quality(*args)
|
| 938 |
+
|
| 939 |
# 快速生成按钮:210秒,固定10步
|
| 940 |
run_i2v_button_fast.click(
|
| 941 |
+
fn=handle_fast_generation,
|
| 942 |
inputs=[img2vid_image, img2vid_prompt, n_prompt, img2vid_audio_1, img2vid_audio_2, img2vid_audio_3, sd_steps, seed, guide_scale, person_num_selector, audio_mode_selector],
|
| 943 |
outputs=[result_gallery],
|
| 944 |
)
|
| 945 |
|
| 946 |
# 高质量生成按钮:780秒,用户选择步数
|
| 947 |
run_i2v_button_quality.click(
|
| 948 |
+
fn=handle_quality_generation,
|
| 949 |
inputs=[img2vid_image, img2vid_prompt, n_prompt, img2vid_audio_1, img2vid_audio_2, img2vid_audio_3, sd_steps, seed, guide_scale, person_num_selector, audio_mode_selector],
|
| 950 |
outputs=[result_gallery],
|
| 951 |
)
|
| 952 |
# 参考 Meigen-MultiTalk 的成功配置
|
| 953 |
# 在 Hugging Face Spaces 上,Gradio 会自动处理端口和服务器配置
|
| 954 |
+
demo.queue(max_size=10).launch(show_error=True)
|
| 955 |
|
| 956 |
|
| 957 |
|
wan/audio2video_multiID.py
CHANGED
|
@@ -199,6 +199,7 @@ class WanAF2V:
|
|
| 199 |
audio_paths=None, # New: audio path list, supports multiple audio files
|
| 200 |
task_key=None,
|
| 201 |
mode="pad", # Audio processing mode: "pad" or "concat"
|
|
|
|
| 202 |
):
|
| 203 |
r"""
|
| 204 |
Generates video frames from input image and text prompt using diffusion process.
|
|
@@ -514,6 +515,7 @@ class WanAF2V:
|
|
| 514 |
half_dtype=self.half_dtype,
|
| 515 |
preprocess_audio=preprocess_audio,
|
| 516 |
resample_audio=resample_audio,
|
|
|
|
| 517 |
)
|
| 518 |
|
| 519 |
# Prepare audio_ref_features - new list mode
|
|
|
|
| 199 |
audio_paths=None, # New: audio path list, supports multiple audio files
|
| 200 |
task_key=None,
|
| 201 |
mode="pad", # Audio processing mode: "pad" or "concat"
|
| 202 |
+
trim_to_6s=False, # Fast mode: trim audio to 6 seconds
|
| 203 |
):
|
| 204 |
r"""
|
| 205 |
Generates video frames from input image and text prompt using diffusion process.
|
|
|
|
| 515 |
half_dtype=self.half_dtype,
|
| 516 |
preprocess_audio=preprocess_audio,
|
| 517 |
resample_audio=resample_audio,
|
| 518 |
+
trim_to_6s=trim_to_6s,
|
| 519 |
)
|
| 520 |
|
| 521 |
# Prepare audio_ref_features - new list mode
|
wan/utils/infer_utils.py
CHANGED
|
@@ -118,6 +118,7 @@ def process_audio_features(
|
|
| 118 |
half_dtype=None,
|
| 119 |
preprocess_audio=None,
|
| 120 |
resample_audio=None,
|
|
|
|
| 121 |
):
|
| 122 |
"""
|
| 123 |
Process audio files and extract audio features.
|
|
@@ -202,6 +203,31 @@ def process_audio_features(
|
|
| 202 |
total_length = sum(audio_lengths)
|
| 203 |
print(f"Total audio length in concat mode (from processed frames): {total_length} frames")
|
| 204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
# Ensure total length is in 4n+1 format (model requirement)
|
| 206 |
total_length = ((total_length - 1) // 4) * 4 + 1
|
| 207 |
print(f"Adjusted total length to 4n+1 format: {total_length} frames")
|
|
@@ -257,7 +283,7 @@ def process_audio_features(
|
|
| 257 |
audio_feat_list.append(zero_audio_feat)
|
| 258 |
print(f"Audio {i} is missing, created zero features with shape: {zero_audio_feat.shape}")
|
| 259 |
else:
|
| 260 |
-
# Pad mode: keep existing logic,
|
| 261 |
for i, audio_path in enumerate(audio_paths):
|
| 262 |
if audio_path and os.path.exists(audio_path):
|
| 263 |
print(f"Processing audio {i}: {audio_path}")
|
|
@@ -270,10 +296,19 @@ def process_audio_features(
|
|
| 270 |
with torch.no_grad():
|
| 271 |
print(f"wav2vec_model: {wav2vec_model}")
|
| 272 |
print(f"cache_dir:{cache_dir}")
|
| 273 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
audio_emb, audio_length = preprocess_audio(
|
| 275 |
wav_path=target_resampled_audio_path,
|
| 276 |
-
num_generated_frames_per_clip=
|
| 277 |
fps=fps,
|
| 278 |
wav2vec_model=wav2vec_model,
|
| 279 |
vocal_separator_model=vocal_separator_model,
|
|
@@ -284,7 +319,8 @@ def process_audio_features(
|
|
| 284 |
audio_dtype = half_dtype if use_half else torch.bfloat16
|
| 285 |
audio_emb = audio_emb.to(device, dtype=audio_dtype)
|
| 286 |
|
| 287 |
-
|
|
|
|
| 288 |
audio_feat_list.append(audio_feat)
|
| 289 |
print(f"Audio {i} processed, shape: {audio_feat.shape}")
|
| 290 |
else:
|
|
@@ -310,10 +346,19 @@ def process_audio_features(
|
|
| 310 |
target_resampled_audio_path,
|
| 311 |
)
|
| 312 |
with torch.no_grad():
|
| 313 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
audio_emb, audio_length = preprocess_audio(
|
| 315 |
wav_path=audio,
|
| 316 |
-
num_generated_frames_per_clip=
|
| 317 |
fps=fps,
|
| 318 |
wav2vec_model=wav2vec_model,
|
| 319 |
vocal_separator_model=vocal_separator_model,
|
|
@@ -324,7 +369,8 @@ def process_audio_features(
|
|
| 324 |
audio_dtype = half_dtype if use_half else torch.bfloat16
|
| 325 |
audio_emb = audio_emb.to(device, dtype=audio_dtype)
|
| 326 |
|
| 327 |
-
|
|
|
|
| 328 |
audio_feat_list.append(audio_feat)
|
| 329 |
print(f"Single audio processed, shape: {audio_feat.shape}")
|
| 330 |
else:
|
|
|
|
| 118 |
half_dtype=None,
|
| 119 |
preprocess_audio=None,
|
| 120 |
resample_audio=None,
|
| 121 |
+
trim_to_6s=False, # Fast mode: trim audio to 6 seconds
|
| 122 |
):
|
| 123 |
"""
|
| 124 |
Process audio files and extract audio features.
|
|
|
|
| 203 |
total_length = sum(audio_lengths)
|
| 204 |
print(f"Total audio length in concat mode (from processed frames): {total_length} frames")
|
| 205 |
|
| 206 |
+
# Fast mode: trim to 6 seconds if trim_to_6s is True
|
| 207 |
+
if trim_to_6s:
|
| 208 |
+
import math
|
| 209 |
+
# Calculate 6 seconds in frames
|
| 210 |
+
max_frames_6s = int(math.ceil(6.0 * fps))
|
| 211 |
+
max_frames_6s = ((max_frames_6s - 1) // 4) * 4 + 1
|
| 212 |
+
if total_length > max_frames_6s:
|
| 213 |
+
print(f"Fast mode: Trimming audio from {total_length} frames to {max_frames_6s} frames (6 seconds)")
|
| 214 |
+
# Truncate each audio proportionally
|
| 215 |
+
scale_factor = max_frames_6s / total_length
|
| 216 |
+
cumulative_length = 0
|
| 217 |
+
for i, audio_len in enumerate(audio_lengths):
|
| 218 |
+
if audio_len > 0:
|
| 219 |
+
new_audio_len = int(audio_len * scale_factor)
|
| 220 |
+
# Ensure it fits within remaining space
|
| 221 |
+
remaining_space = max_frames_6s - cumulative_length
|
| 222 |
+
new_audio_len = min(new_audio_len, remaining_space)
|
| 223 |
+
audio_lengths[i] = new_audio_len
|
| 224 |
+
# Truncate the corresponding raw audio feature
|
| 225 |
+
if raw_audio_feat_list[i] is not None:
|
| 226 |
+
raw_audio_feat_list[i] = raw_audio_feat_list[i][:new_audio_len]
|
| 227 |
+
cumulative_length += new_audio_len
|
| 228 |
+
total_length = sum(audio_lengths)
|
| 229 |
+
print(f"After trimming: total_length = {total_length} frames")
|
| 230 |
+
|
| 231 |
# Ensure total length is in 4n+1 format (model requirement)
|
| 232 |
total_length = ((total_length - 1) // 4) * 4 + 1
|
| 233 |
print(f"Adjusted total length to 4n+1 format: {total_length} frames")
|
|
|
|
| 283 |
audio_feat_list.append(zero_audio_feat)
|
| 284 |
print(f"Audio {i} is missing, created zero features with shape: {zero_audio_feat.shape}")
|
| 285 |
else:
|
| 286 |
+
# Pad mode: keep existing logic, but apply trim_to_6s if needed
|
| 287 |
for i, audio_path in enumerate(audio_paths):
|
| 288 |
if audio_path and os.path.exists(audio_path):
|
| 289 |
print(f"Processing audio {i}: {audio_path}")
|
|
|
|
| 296 |
with torch.no_grad():
|
| 297 |
print(f"wav2vec_model: {wav2vec_model}")
|
| 298 |
print(f"cache_dir:{cache_dir}")
|
| 299 |
+
# Fast mode: if trim_to_6s, limit to 6 seconds
|
| 300 |
+
target_frames = F
|
| 301 |
+
if trim_to_6s:
|
| 302 |
+
import math
|
| 303 |
+
max_frames_6s = int(math.ceil(6.0 * fps))
|
| 304 |
+
max_frames_6s = ((max_frames_6s - 1) // 4) * 4 + 1
|
| 305 |
+
target_frames = min(F, max_frames_6s)
|
| 306 |
+
if F > max_frames_6s:
|
| 307 |
+
print(f"Fast mode: Trimming audio {i} from {F} frames to {max_frames_6s} frames (6 seconds)")
|
| 308 |
+
# Use dynamically determined frame number
|
| 309 |
audio_emb, audio_length = preprocess_audio(
|
| 310 |
wav_path=target_resampled_audio_path,
|
| 311 |
+
num_generated_frames_per_clip=target_frames, # Use target frames (may be trimmed)
|
| 312 |
fps=fps,
|
| 313 |
wav2vec_model=wav2vec_model,
|
| 314 |
vocal_separator_model=vocal_separator_model,
|
|
|
|
| 319 |
audio_dtype = half_dtype if use_half else torch.bfloat16
|
| 320 |
audio_emb = audio_emb.to(device, dtype=audio_dtype)
|
| 321 |
|
| 322 |
+
# Ensure we don't exceed F frames (for consistency with other tensors)
|
| 323 |
+
audio_feat = audio_emb[:F] # Use F to maintain consistency
|
| 324 |
audio_feat_list.append(audio_feat)
|
| 325 |
print(f"Audio {i} processed, shape: {audio_feat.shape}")
|
| 326 |
else:
|
|
|
|
| 346 |
target_resampled_audio_path,
|
| 347 |
)
|
| 348 |
with torch.no_grad():
|
| 349 |
+
# Fast mode: if trim_to_6s, limit to 6 seconds
|
| 350 |
+
target_frames = F
|
| 351 |
+
if trim_to_6s:
|
| 352 |
+
import math
|
| 353 |
+
max_frames_6s = int(math.ceil(6.0 * fps))
|
| 354 |
+
max_frames_6s = ((max_frames_6s - 1) // 4) * 4 + 1
|
| 355 |
+
target_frames = min(F, max_frames_6s)
|
| 356 |
+
if F > max_frames_6s:
|
| 357 |
+
print(f"Fast mode: Trimming single audio from {F} frames to {max_frames_6s} frames (6 seconds)")
|
| 358 |
+
# Use dynamically determined frame number
|
| 359 |
audio_emb, audio_length = preprocess_audio(
|
| 360 |
wav_path=audio,
|
| 361 |
+
num_generated_frames_per_clip=target_frames, # Use target frames (may be trimmed)
|
| 362 |
fps=fps,
|
| 363 |
wav2vec_model=wav2vec_model,
|
| 364 |
vocal_separator_model=vocal_separator_model,
|
|
|
|
| 369 |
audio_dtype = half_dtype if use_half else torch.bfloat16
|
| 370 |
audio_emb = audio_emb.to(device, dtype=audio_dtype)
|
| 371 |
|
| 372 |
+
# Ensure we don't exceed F frames (for consistency with other tensors)
|
| 373 |
+
audio_feat = audio_emb[:F] # Use F to maintain consistency
|
| 374 |
audio_feat_list.append(audio_feat)
|
| 375 |
print(f"Single audio processed, shape: {audio_feat.shape}")
|
| 376 |
else:
|