Update app.py
Browse files
app.py
CHANGED
|
@@ -44,6 +44,59 @@ logger = logging.getLogger(__name__)
|
|
| 44 |
CATBOX_USER_HASH = "30f52c895fd9d9cb387eee489"
|
| 45 |
REPLICATE_API_TOKEN = os.getenv("API_KEY")
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
def upload_to_catbox(file_path):
|
| 48 |
"""catbox.moe API를 사용하여 파일 업로드"""
|
| 49 |
try:
|
|
@@ -287,7 +340,7 @@ footer {display: none}
|
|
| 287 |
"""
|
| 288 |
|
| 289 |
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
|
| 290 |
-
|
| 291 |
|
| 292 |
with gr.Row():
|
| 293 |
with gr.Column(scale=3):
|
|
|
|
| 44 |
CATBOX_USER_HASH = "30f52c895fd9d9cb387eee489"
|
| 45 |
REPLICATE_API_TOKEN = os.getenv("API_KEY")
|
| 46 |
|
| 47 |
+
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    """Build the MMAudio network and its feature-extraction utilities.

    Returns:
        A ``(net, feature_utils, seq_cfg)`` triple: the eval-mode network,
        the eval-mode feature utilities, and the model's sequence config.
    """
    # Network: construct, move to the target device/dtype, load the
    # checkpoint, and switch to eval mode.
    net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
    state_dict = torch.load(model.model_path, map_location=device, weights_only=True)
    net.load_weights(state_dict)
    logger.info(f'Loaded weights from {model.model_path}')

    # Feature utilities: VAE decoder + Synchformer + BigVGAN vocoder.
    # The VAE encoder is not needed for generation, so it is skipped.
    feature_utils = FeaturesUtils(
        tod_vae_ckpt=model.vae_path,
        synchformer_ckpt=model.synchformer_ckpt,
        enable_conditions=True,
        mode=model.mode,
        bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
        need_vae_encoder=False,
    ).to(device, dtype).eval()

    return net, feature_utils, model.seq_cfg
|
| 63 |
+
|
| 64 |
+
@spaces.GPU(duration=120)
@torch.inference_mode()
def video_to_audio(video_path: str, prompt: str, negative_prompt: str = "music",
                   seed: int = -1, num_steps: int = 25,
                   cfg_strength: float = 4.5, duration: float = 8):
    """Generate an audio track for a video and mux it into a new .mp4.

    Args:
        video_path: Path to the input video file.
        prompt: Positive text prompt conditioning the audio generation.
        negative_prompt: Negative text prompt (defaults to "music").
        seed: RNG seed; a negative value means "use a fresh random seed".
        num_steps: Number of Euler flow-matching inference steps.
        cfg_strength: Classifier-free guidance strength.
        duration: Requested clip duration in seconds; the loader's actual
            duration overrides this after the video is read.

    Returns:
        Path to a temporary .mp4 file containing the video with the
        generated audio track.
    """
    # Deterministic generation only when a non-negative seed is supplied.
    rng = torch.Generator(device=device)
    if seed >= 0:
        rng.manual_seed(seed)
    else:
        rng.seed()
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

    video_info = load_video(video_path, duration)
    # Add a batch dimension: the model consumes batched frame tensors.
    clip_frames = video_info.clip_frames.unsqueeze(0)
    sync_frames = video_info.sync_frames.unsqueeze(0)
    # The loader may clamp the requested duration; use the actual value.
    duration = video_info.duration_sec
    # NOTE(review): mutates the module-level seq_cfg/net — assumes requests
    # are handled one at a time (no concurrent calls); confirm.
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audios = generate(clip_frames,
                      sync_frames, [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=rng,
                      cfg_strength=cfg_strength)
    audio = audios.float().cpu()[0]

    # Fix: close the temp-file handle right away. The previous code kept the
    # NamedTemporaryFile handle open (resource leak), and on Windows the open
    # handle would also prevent make_video from reopening the path.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp:
        video_save_path = tmp.name
    make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
    logger.info(f'Saved video with audio to {video_save_path}')
    return video_save_path
|
| 99 |
+
|
| 100 |
def upload_to_catbox(file_path):
|
| 101 |
"""catbox.moe API를 사용하여 파일 업로드"""
|
| 102 |
try:
|
|
|
|
| 340 |
"""
|
| 341 |
|
| 342 |
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
|
| 343 |
+
|
| 344 |
|
| 345 |
with gr.Row():
|
| 346 |
with gr.Column(scale=3):
|