Update app.py
Browse files
app.py
CHANGED
|
@@ -65,11 +65,11 @@ output_dir = Path('./output/gradio')
|
|
| 65 |
setup_eval_logging()
|
| 66 |
net, feature_utils, seq_cfg = get_model()
|
| 67 |
|
| 68 |
-
@spaces.GPU(duration=
|
| 69 |
@torch.inference_mode()
|
| 70 |
def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
|
| 71 |
-
seed: int = -1, num_steps: int =
|
| 72 |
-
cfg_strength: float = 4.
|
| 73 |
try:
|
| 74 |
logger.info("Starting audio generation process")
|
| 75 |
torch.cuda.empty_cache()
|
|
@@ -83,16 +83,12 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
|
|
| 83 |
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
| 84 |
|
| 85 |
# Fix the load_video function call
|
| 86 |
-
video_info = load_video(video_path) #
|
| 87 |
|
| 88 |
if video_info is None:
|
| 89 |
logger.error("Failed to load video")
|
| 90 |
return video_path
|
| 91 |
|
| 92 |
-
# If the video length needs adjusting, handle it here
|
| 93 |
-
if hasattr(video_info, 'set_duration'):
|
| 94 |
-
video_info.set_duration(target_duration)
|
| 95 |
-
|
| 96 |
clip_frames = video_info.clip_frames
|
| 97 |
sync_frames = video_info.sync_frames
|
| 98 |
actual_duration = video_info.duration_sec
|
|
@@ -101,6 +97,10 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
|
|
| 101 |
logger.error("Failed to extract frames from video")
|
| 102 |
return video_path
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
clip_frames = clip_frames.unsqueeze(0).to(device, dtype=torch.float16)
|
| 105 |
sync_frames = sync_frames.unsqueeze(0).to(device, dtype=torch.float16)
|
| 106 |
|
|
@@ -108,15 +108,16 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
|
|
| 108 |
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
| 109 |
|
| 110 |
logger.info("Generating audio...")
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
|
|
|
| 120 |
|
| 121 |
if audios is None:
|
| 122 |
logger.error("Failed to generate audio")
|
|
|
|
| 65 |
setup_eval_logging()
|
| 66 |
net, feature_utils, seq_cfg = get_model()
|
| 67 |
|
| 68 |
+
@spaces.GPU(duration=30) # limit to 30 seconds
|
| 69 |
@torch.inference_mode()
|
| 70 |
def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
|
| 71 |
+
seed: int = -1, num_steps: int = 15,
|
| 72 |
+
cfg_strength: float = 4.0, target_duration: float = 4.0):
|
| 73 |
try:
|
| 74 |
logger.info("Starting audio generation process")
|
| 75 |
torch.cuda.empty_cache()
|
|
|
|
| 83 |
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
| 84 |
|
| 85 |
# Fix the load_video function call
|
| 86 |
+
video_info = load_video(video_path, duration_sec=target_duration) # changed to pass the duration_sec parameter
|
| 87 |
|
| 88 |
if video_info is None:
|
| 89 |
logger.error("Failed to load video")
|
| 90 |
return video_path
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
clip_frames = video_info.clip_frames
|
| 93 |
sync_frames = video_info.sync_frames
|
| 94 |
actual_duration = video_info.duration_sec
|
|
|
|
| 97 |
logger.error("Failed to extract frames from video")
|
| 98 |
return video_path
|
| 99 |
|
| 100 |
+
# Memory optimization
|
| 101 |
+
clip_frames = clip_frames[:int(actual_duration * video_info.fps)]
|
| 102 |
+
sync_frames = sync_frames[:int(actual_duration * video_info.fps)]
|
| 103 |
+
|
| 104 |
clip_frames = clip_frames.unsqueeze(0).to(device, dtype=torch.float16)
|
| 105 |
sync_frames = sync_frames.unsqueeze(0).to(device, dtype=torch.float16)
|
| 106 |
|
|
|
|
| 108 |
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
| 109 |
|
| 110 |
logger.info("Generating audio...")
|
| 111 |
+
with torch.cuda.amp.autocast():
|
| 112 |
+
audios = generate(clip_frames,
|
| 113 |
+
sync_frames,
|
| 114 |
+
[prompt],
|
| 115 |
+
negative_text=[negative_prompt],
|
| 116 |
+
feature_utils=feature_utils,
|
| 117 |
+
net=net,
|
| 118 |
+
fm=fm,
|
| 119 |
+
rng=rng,
|
| 120 |
+
cfg_strength=cfg_strength)
|
| 121 |
|
| 122 |
if audios is None:
|
| 123 |
logger.error("Failed to generate audio")
|