# forge / run_audio_generator.py
# arcacolab's picture
# Upload run_audio_generator.py
# 6503586 verified
# V2A (Video-to-Audio) backend script (VRAM-optimized)
import sys
import os
import time
import glob
import gc
import torch
import subprocess
import random
import argparse
import shutil
from typing import Sequence, Mapping, Any, Union
# --- 0. Basic helper functions (same as the I2V script) ---
def to_bool(s: str) -> bool:
    """Interpret a textual flag as a boolean.

    The comparison is case-insensitive. Returns True when the lower-cased
    input is one of 'true', '1', 't', 'y', 'yes' or 'on'; any other text
    (including the empty string) yields False. No whitespace is stripped.
    """
    truthy_tokens = ('true', '1', 't', 'y', 'yes', 'on')
    normalized = s.lower()
    return normalized in truthy_tokens
def clear_memory():
    """Release unreferenced Python objects, then cached CUDA memory.

    The garbage collector runs first so that objects kept alive only by
    reference cycles are destroyed before the CUDA cache is emptied;
    emptying the cache first would leave the memory of such objects
    occupied. The CUDA calls are skipped when no CUDA device is available.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
# Root directory of the ComfyUI installation; it is appended to sys.path and
# used to locate extra_model_paths.yaml.
COMFYUI_BASE_PATH = '/content/ComfyUI'
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Fetch one output value from a node result.

    First tries ``obj[index]`` directly. If that raises KeyError or
    TypeError and ``obj`` is a dict holding a "result" entry, the value is
    read from ``obj["result"][index]`` instead; otherwise the original
    exception is re-raised.
    """
    try:
        value = obj[index]
    except (KeyError, TypeError):
        wraps_result = isinstance(obj, dict) and "result" in obj
        if not wraps_result:
            raise
        value = obj["result"][index]
    return value
def add_comfyui_directory_to_sys_path() -> None:
    """Append COMFYUI_BASE_PATH to sys.path when it is usable and absent.

    Nothing happens if the directory does not exist or is already listed
    in sys.path; otherwise the path is appended and a notice is printed.
    """
    if not os.path.isdir(COMFYUI_BASE_PATH):
        return
    if COMFYUI_BASE_PATH in sys.path:
        return
    sys.path.append(COMFYUI_BASE_PATH)
    print(f"'{COMFYUI_BASE_PATH}' added to sys.path")
def import_custom_nodes() -> None:
    """Prepare the asyncio environment and trigger ComfyUI node loading.

    Steps, in order:
      1. Apply ``nest_asyncio``; if it is not importable, try to install it
         with pip and apply it, printing (not raising) any failure.
      2. Import ComfyUI's ``execution``, ``server`` and ``nodes`` modules.
      3. Obtain a usable event loop, creating and registering a new one
         when the current loop is closed or cannot be obtained.
      4. Instantiate ``PromptServer`` and ``PromptQueue`` on that loop.
      5. Run ``init_extra_nodes()`` to completion when the loop is idle,
         otherwise schedule it on the running loop.
    """
    # 1. nest_asyncio (best effort).
    try:
        import nest_asyncio
        nest_asyncio.apply()
    except ImportError:
        print("nest_asyncio not found, installing...")
        pip_install_cmd = [sys.executable, "-m", "pip", "install", "-q", "nest_asyncio"]
        try:
            subprocess.run(pip_install_cmd, check=True)
            import nest_asyncio
            nest_asyncio.apply()
            print("nest_asyncio installed and applied.")
        except Exception as install_err:
            print(f"Failed to install or apply nest_asyncio: {install_err}")

    # 2. ComfyUI modules; these resolve only once ComfyUI is on sys.path.
    import asyncio
    import execution
    import server
    from nodes import init_extra_nodes

    # 3. Event loop acquisition.
    try:
        event_loop = asyncio.get_event_loop()
        if event_loop.is_closed():
            event_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(event_loop)
    except RuntimeError:
        event_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(event_loop)

    # 4. Server and queue objects are created for their side effects only.
    prompt_server = server.PromptServer(event_loop)
    execution.PromptQueue(prompt_server)

    # 5. Node initialisation.
    if event_loop.is_running():
        # NOTE(review): the coroutine is only scheduled here, not awaited;
        # presumably nodes may not be registered yet on return -- confirm.
        try:
            asyncio.ensure_future(init_extra_nodes())
        except Exception as schedule_err:
            print(f"Error trying async init_extra_nodes on running loop: {schedule_err}")
        return

    try:
        event_loop.run_until_complete(init_extra_nodes())
    except RuntimeError as sync_err:
        print(f"Note: Could not run init_extra_nodes synchronously: {sync_err}")
        try:
            asyncio.ensure_future(init_extra_nodes())
        except Exception as schedule_err:
            print(f"Error trying async init_extra_nodes: {schedule_err}")
# --- 1. ArgParser that receives every argument from the Gradio UI ---
def parse_args(argv: Union[Sequence[str], None] = None) -> argparse.Namespace:
    """Parse the command-line options of the V2A generation script.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, so existing callers that pass
            nothing are unaffected.

    Returns:
        The populated ``argparse.Namespace``. Boolean-like options
        (``mask_away_clip``, ``force_offload``, ``save_metadata``,
        ``trim_to_audio``, ``pingpong``) are kept as text such as
        "on"/"off" and are not converted here.
    """
    parser = argparse.ArgumentParser(description="ComfyUI V2A (Video-to-Audio) Generation Script")

    # 1. General settings (input / prompt)
    parser.add_argument("--input_video_path", type=str, required=True, help="์˜ค๋””์˜ค๋ฅผ ์ƒ์„ฑํ•  ์ž…๋ ฅ ๋น„๋””์˜ค ํŒŒ์ผ ๊ฒฝ๋กœ")
    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--negative_prompt", type=str, default="")

    # 2. Advanced settings - sampling (MMAudio Sampler)
    parser.add_argument("--steps", type=int, default=25)
    parser.add_argument("--cfg", type=float, default=4.5)
    parser.add_argument("--seed", type=int, default=-1)
    parser.add_argument("--mask_away_clip", type=str, default="off")  # boolean as text
    parser.add_argument("--force_offload", type=str, default="off")  # boolean as text

    # 3. Advanced settings - models (loaders)
    parser.add_argument("--mmaudio_model", type=str, default="mmaudio_large_44k_v2_fp16.safetensors")
    parser.add_argument("--base_precision", type=str, default="fp16")
    parser.add_argument("--vae_model", type=str, default="mmaudio_vae_44k_fp16.safetensors")
    parser.add_argument("--synchformer_model", type=str, default="mmaudio_synchformer_fp16.safetensors")
    parser.add_argument("--clip_model", type=str, default="apple_DFN5B-CLIP-ViT-H-14-384_fp16.safetensors")
    parser.add_argument("--mode", type=str, default="44k")
    parser.add_argument("--precision", type=str, default="fp16", help="Feature Utils Precision")

    # 4. Advanced settings - video loading (VHS LoadVideo)
    parser.add_argument("--force_rate", type=int, default=0)
    parser.add_argument("--custom_width", type=int, default=0)
    parser.add_argument("--custom_height", type=int, default=0)
    parser.add_argument("--frame_load_cap", type=int, default=0)
    parser.add_argument("--skip_first_frames", type=int, default=0)
    parser.add_argument("--select_every_nth", type=int, default=1)
    parser.add_argument("--load_format", type=str, default="AnimateDiff")

    # 5. Advanced settings - video combining (VHS VideoCombine)
    parser.add_argument("--loop_count", type=int, default=0)
    parser.add_argument("--filename_prefix", type=str, default="MMaudio")
    parser.add_argument("--combine_format", type=str, default="video/h264-mp4")
    parser.add_argument("--pix_fmt", type=str, default="yuv420p")
    parser.add_argument("--crf", type=int, default=19)
    parser.add_argument("--save_metadata", type=str, default="on")  # boolean as text
    parser.add_argument("--trim_to_audio", type=str, default="off")  # boolean as text
    parser.add_argument("--pingpong", type=str, default="off")  # boolean as text

    return parser.parse_args(argv)
# --- 2. VRAM-optimized main execution function ---
def _extract_final_video_path(combine_result):
    """Return the last file path reported by the VHS_VideoCombine result.

    The result is expected to look like
    ``{'result': ((saved_flag, [path, ...]),)}``. The last entry is used
    instead of a fixed index so the lookup keeps working when the list has
    a different number of entries (the observed three-item list had the
    audio-muxed video last). Raises KeyError/IndexError/TypeError when the
    structure is different or the list is empty.
    """
    file_path_list = combine_result['result'][0][1]
    return file_path_list[-1]


def main():
    """Generate audio for a video with MMAudio and mux it into the video.

    Pipeline:
      1. Load the input video at a forced 25 FPS for audio generation and
         read its duration and source frame rate.
      2. Load the MMAudio model and the feature-utility models.
      3. Sample the audio.
      4. Drop the models and the analysis frames and clear memory.
      5. Reload the video with the user-selected frame rate for combining.
      6. Combine frames and audio and save the result.
      7. Print ``LATEST_VIDEO_PATH:`` / ``ORIGINAL_COPY_PATH:`` marker lines
         for the final file and its copy, or an error message.

    Arguments are taken from the command line via ``parse_args()``.
    """
    args = parse_args()
    print("๐Ÿš€ V2A ์˜ค๋””์˜ค ์ƒ์„ฑ์„ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค (VRAM Optimized)...")

    # --- Environment setup ---
    add_comfyui_directory_to_sys_path()
    try:
        from utils.extra_config import load_extra_path_config
    except ImportError:
        print("โš ๏ธ ComfyUI์˜ extra_model_paths.yaml ๋กœ๋”ฉ ์‹คํŒจ (๋ฌด์‹œํ•˜๊ณ  ์ง„ํ–‰)")

        def load_extra_path_config(_path):
            # Fallback no-op so the script can continue without extra paths.
            return None

    extra_model_paths_file = os.path.join(COMFYUI_BASE_PATH, "extra_model_paths.yaml")
    if os.path.exists(extra_model_paths_file):
        load_extra_path_config(extra_model_paths_file)

    print("ComfyUI ์ปค์Šคํ…€ ๋…ธ๋“œ ์ดˆ๊ธฐํ™” ์ค‘...")
    import_custom_nodes()
    from nodes import NODE_CLASS_MAPPINGS
    print("์ปค์Šคํ…€ ๋…ธ๋“œ ์ดˆ๊ธฐํ™” ์™„๋ฃŒ.")

    # --- Node class instantiation ---
    mmaudiomodelloader = NODE_CLASS_MAPPINGS["MMAudioModelLoader"]()
    vhs_loadvideo = NODE_CLASS_MAPPINGS["VHS_LoadVideo"]()
    mmaudiofeatureutilsloader = NODE_CLASS_MAPPINGS["MMAudioFeatureUtilsLoader"]()
    vhs_videoinfo = NODE_CLASS_MAPPINGS["VHS_VideoInfo"]()
    mmaudiosampler = NODE_CLASS_MAPPINGS["MMAudioSampler"]()
    vhs_videocombine = NODE_CLASS_MAPPINGS["VHS_VideoCombine"]()

    # --- Seed selection ---
    if args.seed == -1:
        # randint is inclusive on both ends; 2**64 itself does not fit in an
        # unsigned 64-bit seed, so the upper bound is 2**64 - 1.
        final_seed = random.randint(1, 2**64 - 1)
        print(f" - ๋žœ๋ค ์‹œ๋“œ ์ƒ์„ฑ: {final_seed}")
    else:
        final_seed = args.seed
        print(f" - ๊ณ ์ • ์‹œ๋“œ ์‚ฌ์šฉ: {final_seed}")

    # --- VRAM-optimized pipeline ---
    with torch.inference_mode():
        # Step 1: load the video for audio generation (25 FPS forced).
        print(f"\n1๋‹จ๊ณ„: ์˜ค๋””์˜ค ์ƒ์„ฑ์„ ์œ„ํ•œ ๋น„๋””์˜ค ๋กœ๋“œ (25 FPS ๊ฐ•์ œ)... ({args.input_video_path})")
        vhs_loadvideo_91_audio = vhs_loadvideo.load_video(
            video=args.input_video_path,
            force_rate=25,  # fixed rate for the audio pass; args.force_rate is used in step 5
            custom_width=args.custom_width,
            custom_height=args.custom_height,
            frame_load_cap=args.frame_load_cap,
            skip_first_frames=args.skip_first_frames,
            select_every_nth=args.select_every_nth,
            format=args.load_format,
            unique_id=random.randint(1, 2**64)
        )
        images_for_audio = get_value_at_index(vhs_loadvideo_91_audio, 0)

        # Video information (duration, source FPS) from the first load.
        vhs_videoinfo_105 = vhs_videoinfo.get_video_info(
            video_info=get_value_at_index(vhs_loadvideo_91_audio, 3)
        )
        del vhs_loadvideo_91_audio  # drop the loader output right away
        duration = get_value_at_index(vhs_videoinfo_105, 7)
        original_frame_rate = get_value_at_index(vhs_videoinfo_105, 0)
        print(f" - ๋น„๋””์˜ค ์ •๋ณด: {duration}์ดˆ, ์›๋ณธ {original_frame_rate} FPS")
        clear_memory()

        # Step 2: load the audio models.
        print(f"\n2๋‹จ๊ณ„: ์˜ค๋””์˜ค ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
        print(f" - MMAudio ๋ชจ๋ธ: {args.mmaudio_model} ({args.base_precision})")
        mmaudiomodelloader_85 = mmaudiomodelloader.loadmodel(
            mmaudio_model=args.mmaudio_model,
            base_precision=args.base_precision
        )
        mmaudio_model = get_value_at_index(mmaudiomodelloader_85, 0)
        print(f" - ์œ ํ‹ธ๋ฆฌํ‹ฐ ๋ชจ๋ธ: (Mode: {args.mode}, Precision: {args.precision})")
        mmaudiofeatureutilsloader_102 = mmaudiofeatureutilsloader.loadmodel(
            vae_model=args.vae_model,
            synchformer_model=args.synchformer_model,
            clip_model=args.clip_model,
            mode=args.mode,
            precision=args.precision
        )
        feature_utils = get_value_at_index(mmaudiofeatureutilsloader_102, 0)

        # Step 3: generate the audio (sampling).
        print(f"\n3๋‹จ๊ณ„: ์˜ค๋””์˜ค ์ƒ์„ฑ ์ค‘... (Steps: {args.steps}, CFG: {args.cfg})")
        mmaudiosampler_92 = mmaudiosampler.sample(
            duration=duration,
            steps=args.steps,
            cfg=args.cfg,
            seed=final_seed,
            prompt=args.prompt,
            negative_prompt=args.negative_prompt,
            mask_away_clip=to_bool(args.mask_away_clip),
            force_offload=to_bool(args.force_offload),
            mmaudio_model=mmaudio_model,
            feature_utils=feature_utils,
            images=images_for_audio
        )
        generated_audio = get_value_at_index(mmaudiosampler_92, 0)

        # Step 4: release the models and the analysis frames.
        print(f"\n4๋‹จ๊ณ„: ๋ชจ๋ธ ๋ฐ ์˜ค๋””์˜ค์šฉ ์ด๋ฏธ์ง€ ๋ฉ”๋ชจ๋ฆฌ ํ•ด์ œ ์ค‘...")
        del mmaudiomodelloader_85, mmaudio_model, mmaudiofeatureutilsloader_102, feature_utils
        del images_for_audio
        clear_memory()

        # Step 5: reload the video for combining, using the user's FPS setting.
        print(f"\n5๋‹จ๊ณ„: ๋น„๋””์˜ค ๊ฒฐํ•ฉ์„ ์œ„ํ•œ ์›๋ณธ ๋น„๋””์˜ค ๋กœ๋“œ (์‚ฌ์šฉ์ž ์„ค์ • FPS: {args.force_rate})...")
        vhs_loadvideo_91_combine = vhs_loadvideo.load_video(
            video=args.input_video_path,
            force_rate=args.force_rate,
            custom_width=args.custom_width,
            custom_height=args.custom_height,
            frame_load_cap=args.frame_load_cap,
            skip_first_frames=args.skip_first_frames,
            select_every_nth=args.select_every_nth,
            format=args.load_format,
            unique_id=random.randint(1, 2**64)
        )
        images_for_combine = get_value_at_index(vhs_loadvideo_91_combine, 0)
        # The frames must be written at the rate they were actually loaded
        # at; the source FPS is wrong whenever force_rate or select_every_nth
        # resamples the video.
        # NOTE(review): index 5 is assumed to be the loaded FPS of this load
        # (VHS_VideoInfo output order) -- confirm against VideoHelperSuite.
        combine_video_info = vhs_videoinfo.get_video_info(
            video_info=get_value_at_index(vhs_loadvideo_91_combine, 3)
        )
        combine_frame_rate = get_value_at_index(combine_video_info, 5) or original_frame_rate
        del vhs_loadvideo_91_combine, combine_video_info
        clear_memory()

        # Step 6: combine video and audio, then save.
        print(f"\n6๋‹จ๊ณ„: ๋น„๋””์˜ค + ์˜ค๋””์˜ค ๊ฒฐํ•ฉ ๋ฐ ์ €์žฅ ์ค‘...")
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        final_filename_prefix = f"{args.filename_prefix}_{timestamp}"
        vhs_videocombine_97 = vhs_videocombine.combine_video(
            frame_rate=combine_frame_rate,
            loop_count=args.loop_count,
            filename_prefix=final_filename_prefix,
            format=args.combine_format,
            pix_fmt=args.pix_fmt,
            crf=args.crf,
            save_metadata=to_bool(args.save_metadata),
            trim_to_audio=to_bool(args.trim_to_audio),
            pingpong=to_bool(args.pingpong),
            save_output=True,  # required so the file is written to disk
            images=images_for_combine,
            audio=generated_audio,
            unique_id=random.randint(1, 2**64)
        )
        # Drop the tensors as soon as the combine step is done.
        del images_for_combine, generated_audio
        clear_memory()

    # Step 7: print the final file path for the Gradio UI.
    try:
        final_video_path = _extract_final_video_path(vhs_videocombine_97)
    except Exception as e:
        print(f"โŒ [์˜ค๋ฅ˜] ์ตœ์ข… ํŒŒ์ผ ๊ฒฝ๋กœ๋ฅผ ์ถ”์ถœํ•˜๋Š” ๋ฐ ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค: {e}")
        print(f" - ์ „์ฒด ๋ฐ˜ํ™˜๊ฐ’: {vhs_videocombine_97}")
        final_video_path = None

    if final_video_path and os.path.exists(final_video_path):
        print(f"โœ… ์˜ค๋””์˜ค ์ƒ์„ฑ ๋ฐ ๊ฒฐํ•ฉ ์™„๋ฃŒ!")
        print(f"LATEST_VIDEO_PATH:{final_video_path}")
        # Create a copy with an "_original" suffix next to the final file.
        base, ext = os.path.splitext(final_video_path)
        original_copy_path = f"{base}_original{ext}"
        try:
            shutil.copy2(final_video_path, original_copy_path)
            print(f"โœ… ์›๋ณธ ๋ณต์‚ฌ๋ณธ ์ƒ์„ฑ ์™„๋ฃŒ: {original_copy_path}")
            print(f"ORIGINAL_COPY_PATH:{original_copy_path}")
        except Exception as e:
            print(f"โŒ ์›๋ณธ ๋ณต์‚ฌ๋ณธ ์ƒ์„ฑ ์‹คํŒจ: {e}")
    else:
        print(f"โŒ ์ตœ์ข… ๋น„๋””์˜ค ํŒŒ์ผ ๊ฒฝ๋กœ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()