Update app.py
Browse files
app.py
CHANGED
|
@@ -68,11 +68,14 @@ net, feature_utils, seq_cfg = get_model()
|
|
| 68 |
@spaces.GPU(duration=60)
|
| 69 |
@torch.inference_mode()
|
| 70 |
def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
|
| 71 |
-
seed: int = -1, num_steps: int =
|
| 72 |
-
cfg_strength: float = 4.5, target_duration: float =
|
| 73 |
try:
|
| 74 |
logger.info("Starting audio generation process")
|
| 75 |
|
|
|
|
|
|
|
|
|
|
| 76 |
rng = torch.Generator(device=device)
|
| 77 |
if seed >= 0:
|
| 78 |
rng.manual_seed(seed)
|
|
@@ -81,9 +84,8 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
|
|
| 81 |
|
| 82 |
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
| 83 |
|
| 84 |
-
#
|
| 85 |
-
|
| 86 |
-
video_info = load_video(video_path, **kwargs)
|
| 87 |
|
| 88 |
if video_info is None:
|
| 89 |
logger.error("Failed to load video")
|
|
@@ -97,14 +99,13 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
|
|
| 97 |
logger.error("Failed to extract frames from video")
|
| 98 |
return video_path
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
|
|
|
| 102 |
|
| 103 |
-
# 시퀀스 길이 업데이트
|
| 104 |
seq_cfg.duration = actual_duration
|
| 105 |
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
| 106 |
|
| 107 |
-
# 오디오 생성
|
| 108 |
logger.info("Generating audio...")
|
| 109 |
audios = generate(clip_frames,
|
| 110 |
sync_frames,
|
|
@@ -122,14 +123,16 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
|
|
| 122 |
|
| 123 |
audio = audios.float().cpu()[0]
|
| 124 |
|
| 125 |
-
# 결과 비디오 생성
|
| 126 |
output_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
|
| 127 |
logger.info(f"Creating final video with audio at {output_path}")
|
| 128 |
|
| 129 |
-
make_video(video_info, output_path, audio, sampling_rate=seq_cfg.sampling_rate)
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
-
if not
|
| 132 |
-
logger.error("Failed to create
|
| 133 |
return video_path
|
| 134 |
|
| 135 |
logger.info(f'Successfully saved video with audio to {output_path}')
|
|
@@ -137,7 +140,8 @@ def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
|
|
| 137 |
|
| 138 |
except Exception as e:
|
| 139 |
logger.error(f"Error in video_to_audio: {str(e)}")
|
| 140 |
-
|
|
|
|
| 141 |
|
| 142 |
def upload_to_catbox(file_path):
|
| 143 |
"""catbox.moe API를 사용하여 파일 업로드"""
|
|
@@ -357,14 +361,13 @@ def generate_video(image, prompt):
|
|
| 357 |
prompt=prompt,
|
| 358 |
negative_prompt="music",
|
| 359 |
seed=-1,
|
| 360 |
-
num_steps=
|
| 361 |
cfg_strength=4.5,
|
| 362 |
-
target_duration=
|
| 363 |
)
|
| 364 |
|
| 365 |
if final_path_with_audio != final_path:
|
| 366 |
logger.info("Audio generation successful")
|
| 367 |
-
# 임시 파일 정리
|
| 368 |
try:
|
| 369 |
if output_path != final_path:
|
| 370 |
os.remove(output_path)
|
|
|
|
| 68 |
@spaces.GPU(duration=60)
|
| 69 |
@torch.inference_mode()
|
| 70 |
def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
|
| 71 |
+
seed: int = -1, num_steps: int = 20,
|
| 72 |
+
cfg_strength: float = 4.5, target_duration: float = 6.0):
|
| 73 |
try:
|
| 74 |
logger.info("Starting audio generation process")
|
| 75 |
|
| 76 |
+
# Free cached GPU memory before starting
|
| 77 |
+
torch.cuda.empty_cache()
|
| 78 |
+
|
| 79 |
rng = torch.Generator(device=device)
|
| 80 |
if seed >= 0:
|
| 81 |
rng.manual_seed(seed)
|
|
|
|
| 84 |
|
| 85 |
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
| 86 |
|
| 87 |
+
# Corrected load_video call
|
| 88 |
+
video_info = load_video(video_path, duration=target_duration) # changed static_duration to duration
|
|
|
|
| 89 |
|
| 90 |
if video_info is None:
|
| 91 |
logger.error("Failed to load video")
|
|
|
|
| 99 |
logger.error("Failed to extract frames from video")
|
| 100 |
return video_path
|
| 101 |
|
| 102 |
+
# Add a batch dimension and move the frames to the device as float16 for memory efficiency
|
| 103 |
+
clip_frames = clip_frames.unsqueeze(0).to(device, dtype=torch.float16)
|
| 104 |
+
sync_frames = sync_frames.unsqueeze(0).to(device, dtype=torch.float16)
|
| 105 |
|
|
|
|
| 106 |
seq_cfg.duration = actual_duration
|
| 107 |
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
| 108 |
|
|
|
|
| 109 |
logger.info("Generating audio...")
|
| 110 |
audios = generate(clip_frames,
|
| 111 |
sync_frames,
|
|
|
|
| 123 |
|
| 124 |
audio = audios.float().cpu()[0]
|
| 125 |
|
|
|
|
| 126 |
output_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
|
| 127 |
logger.info(f"Creating final video with audio at {output_path}")
|
| 128 |
|
| 129 |
+
success = make_video(video_info, output_path, audio, sampling_rate=seq_cfg.sampling_rate)
|
| 130 |
+
|
| 131 |
+
# Clean up GPU memory
|
| 132 |
+
torch.cuda.empty_cache()
|
| 133 |
|
| 134 |
+
if not success:
|
| 135 |
+
logger.error("Failed to create video with audio")
|
| 136 |
return video_path
|
| 137 |
|
| 138 |
logger.info(f'Successfully saved video with audio to {output_path}')
|
|
|
|
| 140 |
|
| 141 |
except Exception as e:
|
| 142 |
logger.error(f"Error in video_to_audio: {str(e)}")
|
| 143 |
+
torch.cuda.empty_cache()
|
| 144 |
+
return video_path
|
| 145 |
|
| 146 |
def upload_to_catbox(file_path):
|
| 147 |
"""catbox.moe API를 사용하여 파일 업로드"""
|
|
|
|
| 361 |
prompt=prompt,
|
| 362 |
negative_prompt="music",
|
| 363 |
seed=-1,
|
| 364 |
+
num_steps=20,
|
| 365 |
cfg_strength=4.5,
|
| 366 |
+
target_duration=6.0
|
| 367 |
)
|
| 368 |
|
| 369 |
if final_path_with_audio != final_path:
|
| 370 |
logger.info("Audio generation successful")
|
|
|
|
| 371 |
try:
|
| 372 |
if output_path != final_path:
|
| 373 |
os.remove(output_path)
|