| import spaces |
| import gc |
| import logging |
| import sys |
| import os |
| from datetime import datetime |
| from fractions import Fraction |
| from pathlib import Path |
|
|
| import gradio as gr |
| import torch |
| import torchaudio |
| from huggingface_hub import hf_hub_download |
|
|
| |
# Hub repository hosting every checkpoint this app downloads. HF_TOKEN comes
# from the environment (None when unset) and is passed to each download.
MY_VAULT_REPO = "GiorgioV/SFVS"
HF_TOKEN = os.environ.get("HF_TOKEN")

# Put the MMAudio checkout that sits next to this file at the front of the
# import path so the `mmaudio` imports below resolve to it.
current_dir = os.path.split(os.path.abspath(__file__))[0]
sys.path.insert(0, os.path.join(current_dir, "MMAudio"))
|
|
| from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image, |
| load_video, make_video, setup_eval_logging) |
| from mmaudio.model.flow_matching import FlowMatching |
| from mmaudio.model.networks import MMAudio, get_my_mmaudio |
| from mmaudio.model.sequence_config import SequenceConfig |
| from mmaudio.model.utils.features_utils import FeaturesUtils |
|
|
| |
# Permit TensorFloat-32 in cuDNN convolutions and CUDA matmuls
# (trades a little precision for speed).
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# logging.getLogger() with no name is the root logger; MMAudio's helper
# configures logging for the evaluation utilities.
log = logging.getLogger()
setup_eval_logging()

# GPU runs in bfloat16; the CPU fallback stays in float32.
if torch.cuda.is_available():
    device = 'cuda'
    dtype = torch.bfloat16
else:
    device = 'cpu'
    dtype = torch.float32

# Folder where generated media files are written.
output_dir = Path('./output/gradio')
|
|
| |
def get_weights() -> dict[str, str]:
    """Fetch every checkpoint the app needs from ``MY_VAULT_REPO``.

    Each file is obtained with ``hf_hub_download`` (using ``HF_TOKEN``), which
    returns the path of the file in the local Hugging Face cache.

    Returns:
        Local file paths keyed by role: ``"model"``, ``"vae"``, ``"sync"`` and
        ``"vocoder"``.
    """
    log.info("Syncing weights from %s...", MY_VAULT_REPO)

    def fetch(filename: str) -> str:
        # One place for the repo/token arguments shared by every download.
        return hf_hub_download(repo_id=MY_VAULT_REPO, filename=filename, token=HF_TOKEN)

    model_path = fetch("nsfw_gold_8.5k_final.pth")
    vae_path = fetch("ext_weights/v1-44.pth")
    sync_path = fetch("ext_weights/synchformer_state_dict.pth")

    return {
        "model": model_path,
        "vae": vae_path,
        "sync": sync_path,
        # The "vocoder" entry has always pointed at the very same file as the
        # VAE, so reuse that path instead of issuing a second Hub request.
        # NOTE(review): this value ends up in ModelConfig.bigvgan_16k_path;
        # presumably the 44k preset never reads it — confirm, or point it at a
        # real 16k vocoder checkpoint.
        "vocoder": vae_path,
    }
|
|
# Download (or reuse cached) checkpoints, then point the 'large_44k' preset at
# them. The attribute assignments mutate the preset object stored in
# all_model_cfg in place.
weight_paths = get_weights()

model: ModelConfig = all_model_cfg['large_44k']
model.model_path, model.vae_path = weight_paths["model"], weight_paths["vae"]
model.synchformer_ckpt, model.bigvgan_16k_path = weight_paths["sync"], weight_paths["vocoder"]
|
|
def load_all_models() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    """Build the MMAudio network and its feature utilities from ``model``.

    Both modules are moved to the module-level ``device``/``dtype`` and put in
    eval mode.

    Returns:
        ``(net, feature_utils, seq_cfg)`` where ``seq_cfg`` is ``model.seq_cfg``.
    """
    sequence_config = model.seq_cfg

    network: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
    state_dict = torch.load(model.model_path, map_location=device, weights_only=True)
    network.load_weights(state_dict)

    features = FeaturesUtils(
        tod_vae_ckpt=model.vae_path,
        synchformer_ckpt=model.synchformer_ckpt,
        enable_conditions=True,
        mode=model.mode,
        bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
        need_vae_encoder=False,  # the VAE encoder is not requested here
    )
    features = features.to(device, dtype).eval()

    return network, features, sequence_config


# Loaded once at import time; every request handler in this module shares them.
(net, feature_utils, seq_cfg) = load_all_models()
|
|
| |
|
|
def get_video_duration(video_path, default: float = 8):
    """Return the length in seconds of *video_path*'s audio, rounded to 0.01 s.

    The length is probed with ``torchaudio.info`` as ``num_frames / sample_rate``.
    This is best effort: any probing failure is logged and *default* is
    returned instead.

    Args:
        video_path: Path of the uploaded file, or ``None`` if nothing was uploaded.
        default: Duration (seconds) used when the real one cannot be determined.

    Returns:
        The probed duration, or *default* when ``video_path`` is ``None``, the
        probe fails, or it reports a non-positive length.
    """
    if video_path is None:
        return default
    try:
        info = torchaudio.info(video_path)
        duration = info.num_frames / info.sample_rate
    except Exception:
        # Presumably e.g. a file without an audio stream or an unsupported
        # container; keep the best-effort fallback but leave a trace.
        log.warning(
            "Could not probe duration of %s; using %s s", video_path, default, exc_info=True
        )
        return default
    if duration <= 0:
        # A zero frame count would otherwise yield a 0-second request.
        return default
    return round(duration, 2)
|
|
@spaces.GPU()
@torch.inference_mode()
def video_to_audio(video, prompt, negative_prompt, seed, num_steps, cfg_strength, duration=None):
    """Generate a soundtrack for *video* and return the path of the new MP4.

    Args:
        video: Path of the uploaded video (as delivered by ``gr.Video``).
        prompt: Text describing the desired audio.
        negative_prompt: Text describing audio to steer away from.
        seed: RNG seed; a negative value or ``None`` draws a random seed.
        num_steps: Number of Euler steps for the flow-matching sampler.
        cfg_strength: Guidance strength forwarded to ``generate``.
        duration: Requested seconds of video; defaults to
            ``get_video_duration(video)``. The length actually used is read
            back from ``video_info.duration_sec``.

    Returns:
        Path of the written ``.mp4`` file inside ``output_dir``.

    Raises:
        gr.Error: If no video was provided.
    """
    if video is None:
        raise gr.Error("Please upload a video first.")
    if duration is None:
        duration = get_video_duration(video)

    rng = torch.Generator(device=device)
    # A cleared seed field yields None; treat it like -1 (random).
    if seed is not None and seed >= 0:
        rng.manual_seed(int(seed))
    else:
        rng.seed()
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=int(num_steps))

    video_info = load_video(video, duration)
    # Add a leading batch dimension of 1 (a single prompt is generated below).
    clip_frames = video_info.clip_frames.unsqueeze(0)
    sync_frames = video_info.sync_frames.unsqueeze(0)

    # Resize the module-level sequence config / network to the loaded length.
    # NOTE(review): this mutates state shared by all handlers; confirm requests
    # are serialised (or isolated per process) in the deployment.
    seq_cfg.duration = video_info.duration_sec
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audio = generate(
        clip_frames, sync_frames, [prompt],
        negative_text=[negative_prompt],
        feature_utils=feature_utils, net=net, fm=fm, rng=rng,
        cfg_strength=cfg_strength,
    ).float().cpu()[0]

    output_dir.mkdir(exist_ok=True, parents=True)
    # Date + microseconds: a bare %H%M%S name collided for requests in the same
    # second and was reused (overwritten) on every following day.
    path = output_dir / f"v2a_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.mp4"
    make_video(video_info, path, audio, sampling_rate=seq_cfg.sampling_rate)
    gc.collect()
    return path
| |
@spaces.GPU()
@torch.inference_mode()
def image_to_audio(image, prompt, negative_prompt, seed, num_steps, cfg_strength, duration):
    """Generate audio for a still *image* and return it as a static MP4.

    Args:
        image: File path of the uploaded image (the UI uses ``type="filepath"``).
        prompt: Text describing the desired audio.
        negative_prompt: Text describing audio to steer away from.
        seed: RNG seed; a negative value or ``None`` draws a random seed.
        num_steps: Number of Euler steps for the flow-matching sampler.
        cfg_strength: Guidance strength forwarded to ``generate``.
        duration: Length of the generated clip in seconds; must be positive.

    Returns:
        Path of the written ``.mp4`` file inside ``output_dir``.

    Raises:
        gr.Error: If no image was provided or *duration* is missing/non-positive.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if duration is None or duration <= 0:
        raise gr.Error("Duration must be a positive number of seconds.")

    rng = torch.Generator(device=device)
    # A cleared seed field yields None; treat it like -1 (random).
    if seed is not None and seed >= 0:
        rng.manual_seed(int(seed))
    else:
        rng.seed()
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=int(num_steps))

    image_info = load_image(image)
    # Add a leading batch dimension of 1 (a single prompt is generated below).
    clip_frames = image_info.clip_frames.unsqueeze(0)
    sync_frames = image_info.sync_frames.unsqueeze(0)

    # Resize the module-level sequence config / network to the requested length.
    # NOTE(review): shared state across handlers; confirm requests are serialised.
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audio = generate(
        clip_frames, sync_frames, [prompt],
        negative_text=[negative_prompt],
        feature_utils=feature_utils, net=net, fm=fm, rng=rng,
        cfg_strength=cfg_strength, image_input=True,
    ).float().cpu()[0]

    output_dir.mkdir(exist_ok=True, parents=True)
    # Date + microseconds so concurrent or day-apart requests don't overwrite.
    path = output_dir / f"i2a_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.mp4"
    # Wrap the still image as a 1 fps "video" of the requested length.
    video_info = VideoInfo.from_image_info(image_info, duration, fps=Fraction(1))
    make_video(video_info, path, audio, sampling_rate=seq_cfg.sampling_rate)
    gc.collect()
    return path
|
|
@spaces.GPU()
@torch.inference_mode()
def text_to_audio(prompt, negative_prompt, seed, num_steps, cfg_strength, duration):
    """Generate audio from text alone and return the path of a FLAC file.

    Args:
        prompt: Text describing the desired audio.
        negative_prompt: Text describing audio to steer away from.
        seed: RNG seed; a negative value or ``None`` draws a random seed.
        num_steps: Number of Euler steps for the flow-matching sampler.
        cfg_strength: Guidance strength forwarded to ``generate``.
        duration: Length of the generated audio in seconds; must be positive.

    Returns:
        Path of the written ``.flac`` file inside ``output_dir``.

    Raises:
        gr.Error: If *duration* is missing or non-positive.
    """
    if duration is None or duration <= 0:
        raise gr.Error("Duration must be a positive number of seconds.")

    rng = torch.Generator(device=device)
    # A cleared seed field yields None; treat it like -1 (random).
    if seed is not None and seed >= 0:
        rng.manual_seed(int(seed))
    else:
        rng.seed()
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=int(num_steps))

    # Resize the module-level sequence config / network to the requested length.
    # NOTE(review): shared state across handlers; confirm requests are serialised.
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    # No visual conditioning: clip and sync frames are both None.
    audio = generate(
        None, None, [prompt],
        negative_text=[negative_prompt],
        feature_utils=feature_utils, net=net, fm=fm, rng=rng,
        cfg_strength=cfg_strength,
    ).float().cpu()[0]

    output_dir.mkdir(exist_ok=True, parents=True)
    # Date + microseconds so concurrent or day-apart requests don't overwrite.
    path = output_dir / f"t2a_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.flac"
    # assumes audio is (channels, samples) as torchaudio.save expects — TODO confirm
    torchaudio.save(path, audio, seq_cfg.sampling_rate)
    gc.collect()
    return path
|
|
| |
|
|
# Default texts pre-filled in every tab.
DEFAULT_PROMPT = "high-quality sound, detailed acoustic environment"
DEFAULT_NEGATIVE_PROMPT = "music, distorted, low quality, static, noise, talking, laughing"


def _generation_controls():
    """Create the prompt/negative/seed/steps/guidance widgets shared by all tabs.

    Must be called inside an open Gradio layout; widgets render in creation
    order. Returns them in the order the handlers expect their arguments.
    """
    prompt = gr.Textbox(label="Prompt", value=DEFAULT_PROMPT)
    negative_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
    # precision=0 makes Gradio deliver the seed as an int rather than a float.
    seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)
    num_steps = gr.Slider(label="Num Steps", minimum=1, maximum=100, value=25, step=1)
    cfg_strength = gr.Slider(label="Guidance Strength", minimum=1, maximum=15, value=4.5, step=0.1)
    return prompt, negative_prompt, seed, num_steps, cfg_strength


def _duration_input():
    """Create the clip-length field; ``minimum`` rejects zero/negative durations."""
    return gr.Number(label="Duration (sec)", value=8, minimum=1)


with gr.Blocks() as demo:
    gr.Markdown("# Video/Image/Text-to-Audio Generation")

    with gr.Tabs():

        with gr.Tab("Video-to-Audio"):
            with gr.Column():
                video_in = gr.Video(label="Input Video")
                v_prompt, v_neg, v_seed, v_steps, v_cfg = _generation_controls()
                v_btn = gr.Button("Generate Video Audio")
                v_out = gr.Video(label="Generated Video with Audio")

            # No duration input here: video_to_audio derives it from the upload.
            v_btn.click(
                fn=video_to_audio,
                inputs=[video_in, v_prompt, v_neg, v_seed, v_steps, v_cfg],
                outputs=v_out,
            )

        with gr.Tab("Image-to-Audio"):
            with gr.Column():
                img_in = gr.Image(label="Input Image", type="filepath")
                i_prompt, i_neg, i_seed, i_steps, i_cfg = _generation_controls()
                i_dur = _duration_input()
                i_btn = gr.Button("Generate Image Audio")
                i_out = gr.Video(label="Static Video with Audio")

            i_btn.click(
                fn=image_to_audio,
                inputs=[img_in, i_prompt, i_neg, i_seed, i_steps, i_cfg, i_dur],
                outputs=i_out,
            )

        with gr.Tab("Text-to-Audio"):
            with gr.Column():
                t_prompt, t_neg, t_seed, t_steps, t_cfg = _generation_controls()
                t_dur = _duration_input()
                t_btn = gr.Button("Generate Pure Audio")
                t_out = gr.Audio(label="Generated Audio")

            t_btn.click(
                fn=text_to_audio,
                inputs=[t_prompt, t_neg, t_seed, t_steps, t_cfg, t_dur],
                outputs=t_out,
            )
|
|
if __name__ == "__main__":
    # Create the output folder up front, before any request is served.
    output_dir.mkdir(parents=True, exist_ok=True)
    # Listen on all network interfaces, port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")