Update app.py
app.py
CHANGED
@@ -67,15 +67,25 @@ output_dir = Path('./output/gradio')
 setup_eval_logging()
 net, feature_utils, seq_cfg = get_model()
 
-
-@torch.inference_mode()
+
 def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
                    seed: int = -1, num_steps: int = 15,
-                   cfg_strength: float = 4.0, target_duration: float =
+                   cfg_strength: float = 4.0, target_duration: float = None):  # target_duration is now optional
     try:
         logger.info("Starting audio generation process")
         torch.cuda.empty_cache()
 
+        # Check the video duration
+        cap = cv2.VideoCapture(video_path)
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        video_duration = total_frames / fps
+        cap.release()
+
+        # Use the actual video duration as target_duration
+        target_duration = video_duration
+        logger.info(f"Video duration: {target_duration} seconds")
+
         rng = torch.Generator(device=device)
         if seed >= 0:
             rng.manual_seed(seed)
@@ -84,8 +94,8 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
 
         fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
 
-        # load_video
-        video_info = load_video(video_path, duration_sec=target_duration)
+        # Call load_video with the actual video duration
+        video_info = load_video(video_path, duration_sec=target_duration)
 
         if video_info is None:
             logger.error("Failed to load video")
@@ -99,16 +109,20 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
             logger.error("Failed to extract frames from video")
             return video_path
 
-        #
+        # Trim to the actual number of video frames
         clip_frames = clip_frames[:int(actual_duration * video_info.fps)]
         sync_frames = sync_frames[:int(actual_duration * video_info.fps)]
 
         clip_frames = clip_frames.unsqueeze(0).to(device, dtype=torch.float16)
         sync_frames = sync_frames.unsqueeze(0).to(device, dtype=torch.float16)
 
+        # Update the sequence config
         seq_cfg.duration = actual_duration
         net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
 
+        logger.info(f"Generating audio for {actual_duration} seconds...")
+
+
         logger.info("Generating audio...")
         with torch.cuda.amp.autocast():
             audios = generate(clip_frames,
@@ -356,6 +370,15 @@ def generate_video(image, prompt):
 
     final_path = add_watermark(output_path)
 
+    # Check the video duration
+    cap = cv2.VideoCapture(final_path)
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    video_duration = total_frames / fps
+    cap.release()
+
+    logger.info(f"Original video duration: {video_duration} seconds")
+
     # Add audio processing
     try:
         logger.info("Starting audio generation process")
@@ -365,8 +388,8 @@ def generate_video(image, prompt):
             negative_prompt="music",
            seed=-1,
            num_steps=20,
-           cfg_strength=4.5
-           target_duration
+           cfg_strength=4.5
+           # target_duration removed - the video duration is used automatically
        )
 
        if final_path_with_audio != final_path:
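
The commit reads the clip length by opening the file with OpenCV and dividing the frame count by the FPS, in both video_to_audio and generate_video, without guarding against a file that fails to open or a zero FPS report. A minimal sketch of how that probe could be factored into one reusable helper; the name get_video_duration and the None-on-failure behaviour are illustrative assumptions, not part of the commit:

from typing import Optional

import cv2


def get_video_duration(video_path: str) -> Optional[float]:
    # Hypothetical helper (not in the commit): return a clip's duration in seconds,
    # or None if the file cannot be opened or reports no FPS.
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        return None
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    cap.release()
    if fps <= 0:
        return None
    return total_frames / fps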
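
Because target_duration is now optional and is immediately overwritten inside video_to_audio, callers only need to pass the sampling parameters. A hedged illustration of calling the updated function directly; the file path and prompt are placeholders, while negative_prompt, seed, num_steps and cfg_strength mirror the values used in the diff:

# Illustrative call of the updated signature; the clip's own length is used,
# so no target_duration argument is passed.
result_path = video_to_audio(
    "example_clip.mp4",                   # placeholder path
    prompt="gentle rain on a tin roof",   # placeholder prompt
    negative_prompt="music",
    seed=-1,
    num_steps=20,
    cfg_strength=4.5,
)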