Spaces:
Running on Zero
Running on Zero
| import spaces | |
| import gc | |
| import logging | |
| import sys | |
| import os | |
| from datetime import datetime | |
| from fractions import Fraction | |
| from pathlib import Path | |
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| from huggingface_hub import hf_hub_download | |
# --- CONFIGURATION & PATHS ---
# Private HF repo holding every checkpoint this app loads.
MY_VAULT_REPO = "ibyteohdear/mmaudio-weights-vault"
# Optional access token read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Directory containing this script; the vendored MMAudio sources live beside it.
current_dir = os.path.dirname(os.path.abspath(__file__))
# Put the vendored MMAudio checkout first on the import path so `import mmaudio` resolves to it.
sys.path[:0] = [os.path.join(current_dir, "MMAudio")]
| from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image, | |
| load_video, make_video, setup_eval_logging) | |
| from mmaudio.model.flow_matching import FlowMatching | |
| from mmaudio.model.networks import MMAudio, get_my_mmaudio | |
| from mmaudio.model.sequence_config import SequenceConfig | |
| from mmaudio.model.utils.features_utils import FeaturesUtils | |
# Optimization flags: allow TF32 math for CUDA matmuls and cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Root logger; setup_eval_logging() is MMAudio's helper (presumably configures handlers/format).
log = logging.getLogger()
setup_eval_logging()

# Use the GPU with bfloat16 when one is visible, otherwise full-precision CPU.
if torch.cuda.is_available():
    device = 'cuda'
    dtype = torch.bfloat16
else:
    device = 'cpu'
    dtype = torch.float32

# Folder that receives every generated file.
output_dir = Path('./output/gradio')
# --- WEIGHT SYNCHRONIZATION ---
def get_weights():
    """Fetch every checkpoint the app needs from the weights vault repo.

    ``hf_hub_download`` returns a local file path and reuses its local cache
    when the file is already present.

    Returns:
        dict with keys "model", "vae", "sync" and "vocoder", each a local path.
    """
    log.info("Syncing weights from %s...", MY_VAULT_REPO)

    def _fetch(filename):
        # Every file lives in the same repo and uses the same (optional) token.
        return hf_hub_download(repo_id=MY_VAULT_REPO, filename=filename, token=HF_TOKEN)

    paths = {
        "model": _fetch("weights/mmaudio_large_44k_v2.pth"),
        "vae": _fetch("ext_weights/v1-44.pth"),
        "sync": _fetch("ext_weights/synchformer_state_dict.pth"),
    }
    # The vocoder entry points at the very same file as "vae"
    # (ext_weights/v1-44.pth); reuse that path instead of a duplicate hub request.
    # NOTE(review): confirm which vocoder checkpoint, if any, the 44k model needs here.
    paths["vocoder"] = paths["vae"]
    return paths


weight_paths = get_weights()
# --- MODEL INITIALIZATION ---
# (ModelConfig attribute, weight_paths key) pairs: point MMAudio's stock
# "large_44k_v2" config at the checkpoints synced from the vault.
_CHECKPOINT_FIELDS = (
    ("model_path", "model"),
    ("vae_path", "vae"),
    ("synchformer_ckpt", "sync"),
    ("bigvgan_16k_path", "vocoder"),
)
model: ModelConfig = all_model_cfg['large_44k_v2']
for _field, _role in _CHECKPOINT_FIELDS:
    setattr(model, _field, weight_paths[_role])
del _field, _role
def load_all_models() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    """Build the MMAudio network and its feature extractors on ``device``/``dtype``.

    Reads checkpoint locations from the module-level ``model`` config.

    Returns:
        ``(net, feature_utils, seq_cfg)`` — the network in eval mode with
        weights loaded, the ``FeaturesUtils`` wrapper in eval mode, and the
        config's sequence settings.
    """
    sequence_config = model.seq_cfg

    network = get_my_mmaudio(model.model_name)
    network = network.to(device, dtype).eval()
    network.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))

    extractor_kwargs = dict(
        tod_vae_ckpt=model.vae_path,
        synchformer_ckpt=model.synchformer_ckpt,
        enable_conditions=True,
        mode=model.mode,
        bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
        # presumably skips the VAE encoder, which generation does not use — verify
        need_vae_encoder=False,
    )
    extractors = FeaturesUtils(**extractor_kwargs).to(device, dtype).eval()

    return network, extractors, sequence_config


# Load once at import time; every request handler reuses these objects.
net, feature_utils, seq_cfg = load_all_models()
# --- INFERENCE FUNCTIONS ---
def get_video_duration(video_path, default=8):
    """Estimate the length in seconds of the media file at ``video_path``.

    The estimate comes from ``torchaudio.info``, i.e. from the file's audio
    stream. Missing paths, unreadable files, silent videos and streams that do
    not report a usable length all fall back to ``default``.

    Args:
        video_path: Path to the media file, or ``None``.
        default: Value returned when the duration cannot be determined.

    Returns:
        The duration rounded to two decimals, or ``default``.
    """
    if video_path is None:
        return default
    try:
        info = torchaudio.info(video_path)
        duration = info.num_frames / info.sample_rate
    except Exception as err:  # best-effort probe: any decoder failure means "unknown"
        log.warning("Could not probe duration of %s (%s); using %s s", video_path, err, default)
        return default
    if duration <= 0:
        # torchaudio may report 0 frames when the stream length is unknown;
        # a zero duration would produce an empty generation request.
        log.warning("No usable duration reported for %s; using %s s", video_path, default)
        return default
    return round(duration, 2)
# This Space runs on ZeroGPU: spaces.GPU attaches a GPU for the duration of the
# call (per the `spaces` docs it has no effect on other hardware).
@spaces.GPU(duration=120)
@torch.inference_mode()
def video_to_audio(video, prompt, negative_prompt, seed, num_steps, cfg_strength, duration=None):
    """Generate a soundtrack for ``video`` and return the path of the new .mp4.

    Args:
        video: File path of the uploaded video (as supplied by ``gr.Video``).
        prompt: Text describing the desired sound.
        negative_prompt: Text describing sounds to steer away from.
        seed: Integer seed; a negative value or ``None`` draws a random seed.
        num_steps: Number of Euler steps for the flow-matching sampler.
        cfg_strength: Guidance strength forwarded to ``generate``.
        duration: Seconds of video to process; auto-detected when ``None``.

    Returns:
        Path of the generated .mp4 inside ``output_dir``.

    Raises:
        gr.Error: If no video was provided.
    """
    if video is None:
        raise gr.Error("Please upload a video first.")
    if duration is None:
        duration = get_video_duration(video)

    rng = torch.Generator(device=device)
    if seed is not None and seed >= 0:
        rng.manual_seed(int(seed))
    else:
        rng.seed()
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=int(num_steps))

    video_info = load_video(video, duration)
    clip_frames = video_info.clip_frames.unsqueeze(0)  # add a leading batch dimension
    sync_frames = video_info.sync_frames.unsqueeze(0)

    # Resize the shared sequence config and the network's sequence lengths to the
    # duration load_video actually returned.
    # NOTE(review): this mutates module-level state shared by all three tabs.
    seq_cfg.duration = video_info.duration_sec
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audio = generate(
        clip_frames, sync_frames, [prompt],
        negative_text=[negative_prompt],
        feature_utils=feature_utils, net=net, fm=fm, rng=rng,
        cfg_strength=cfg_strength,
    ).float().cpu()[0]

    output_dir.mkdir(exist_ok=True, parents=True)
    # Full date + microseconds so back-to-back requests never overwrite each other.
    path = output_dir / f"v2a_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.mp4"
    make_video(video_info, path, audio, sampling_rate=seq_cfg.sampling_rate)
    gc.collect()
    return path
# This Space runs on ZeroGPU: spaces.GPU attaches a GPU for the duration of the
# call (per the `spaces` docs it has no effect on other hardware).
@spaces.GPU(duration=120)
@torch.inference_mode()
def image_to_audio(image, prompt, negative_prompt, seed, num_steps, cfg_strength, duration):
    """Generate audio for a still image and return it muxed into a static .mp4.

    Args:
        image: File path of the uploaded image (``gr.Image(type="filepath")``).
        prompt: Text describing the desired sound.
        negative_prompt: Text describing sounds to steer away from.
        seed: Integer seed; a negative value or ``None`` draws a random seed.
        num_steps: Number of Euler steps for the flow-matching sampler.
        cfg_strength: Guidance strength forwarded to ``generate``.
        duration: Length of the generated audio in seconds.

    Returns:
        Path of the generated .mp4 inside ``output_dir``.

    Raises:
        gr.Error: If no image was provided or ``duration`` is not positive.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if duration is None or duration <= 0:
        raise gr.Error("Duration must be a positive number of seconds.")

    rng = torch.Generator(device=device)
    if seed is not None and seed >= 0:
        rng.manual_seed(int(seed))
    else:
        rng.seed()
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=int(num_steps))

    image_info = load_image(image)
    clip_frames = image_info.clip_frames.unsqueeze(0)  # add a leading batch dimension
    sync_frames = image_info.sync_frames.unsqueeze(0)

    # Resize the shared sequence config and the network's sequence lengths.
    # NOTE(review): this mutates module-level state shared by all three tabs.
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audio = generate(
        clip_frames, sync_frames, [prompt],
        negative_text=[negative_prompt],
        feature_utils=feature_utils, net=net, fm=fm, rng=rng,
        cfg_strength=cfg_strength, image_input=True,
    ).float().cpu()[0]

    output_dir.mkdir(exist_ok=True, parents=True)
    # Full date + microseconds so back-to-back requests never overwrite each other.
    path = output_dir / f"i2a_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.mp4"
    # Render the still image as a 1 fps video spanning the requested duration.
    video_info = VideoInfo.from_image_info(image_info, duration, fps=Fraction(1))
    make_video(video_info, path, audio, sampling_rate=seq_cfg.sampling_rate)
    gc.collect()
    return path
# This Space runs on ZeroGPU: spaces.GPU attaches a GPU for the duration of the
# call (per the `spaces` docs it has no effect on other hardware).
@spaces.GPU(duration=120)
@torch.inference_mode()
def text_to_audio(prompt, negative_prompt, seed, num_steps, cfg_strength, duration):
    """Generate audio from text alone and return the path of the .flac file.

    Args:
        prompt: Text describing the desired sound.
        negative_prompt: Text describing sounds to steer away from.
        seed: Integer seed; a negative value or ``None`` draws a random seed.
        num_steps: Number of Euler steps for the flow-matching sampler.
        cfg_strength: Guidance strength forwarded to ``generate``.
        duration: Length of the generated audio in seconds.

    Returns:
        Path of the generated .flac inside ``output_dir``.

    Raises:
        gr.Error: If ``duration`` is not positive.
    """
    if duration is None or duration <= 0:
        raise gr.Error("Duration must be a positive number of seconds.")

    rng = torch.Generator(device=device)
    if seed is not None and seed >= 0:
        rng.manual_seed(int(seed))
    else:
        rng.seed()
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=int(num_steps))

    # Resize the shared sequence config and the network's sequence lengths.
    # NOTE(review): this mutates module-level state shared by all three tabs.
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    # No visual conditioning: clip and sync inputs are None.
    audio = generate(
        None, None, [prompt],
        negative_text=[negative_prompt],
        feature_utils=feature_utils, net=net, fm=fm, rng=rng,
        cfg_strength=cfg_strength,
    ).float().cpu()[0]

    output_dir.mkdir(exist_ok=True, parents=True)
    # Full date + microseconds so back-to-back requests never overwrite each other.
    path = output_dir / f"t2a_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.flac"
    torchaudio.save(str(path), audio, seq_cfg.sampling_rate)
    gc.collect()
    return path
# --- GRADIO UI ---
_DEFAULT_PROMPT = "high-quality sound, detailed acoustic environment"
_DEFAULT_NEGATIVE = "music, distorted, low quality, static, noise, talking, laughing"


def _generation_controls():
    """Create the prompt, negative-prompt, seed, steps and guidance widgets.

    Every tab shows these five controls in the same order. The components are
    created inside whatever Blocks/Column context is active at call time.

    Returns:
        Tuple ``(prompt, negative_prompt, seed, num_steps, cfg_strength)`` of components.
    """
    prompt_box = gr.Textbox(label="Prompt", value=_DEFAULT_PROMPT)
    negative_box = gr.Textbox(label="Negative Prompt", value=_DEFAULT_NEGATIVE)
    seed_box = gr.Number(label="Seed (-1 = random)", value=-1)
    steps_slider = gr.Slider(label="Num Steps", minimum=1, maximum=100, value=25, step=1)
    cfg_slider = gr.Slider(label="Guidance Strength", minimum=1, maximum=15, value=4.5, step=0.1)
    return prompt_box, negative_box, seed_box, steps_slider, cfg_slider


with gr.Blocks() as demo:
    gr.Markdown("# MMAudio: Video/Image/Text-to-Audio Generation")
    with gr.Tabs():
        # Video tab: duration is auto-detected from the uploaded file.
        with gr.Tab("Video-to-Audio"):
            with gr.Column():
                video_in = gr.Video(label="Input Video")
                v_prompt, v_neg, v_seed, v_steps, v_cfg = _generation_controls()
                v_btn = gr.Button("Generate Video Audio")
                v_out = gr.Video(label="Generated Video with Audio")
                v_btn.click(
                    fn=video_to_audio,
                    inputs=[video_in, v_prompt, v_neg, v_seed, v_steps, v_cfg],
                    outputs=v_out,
                )

        # Image tab: audio for a still image, returned as a static video.
        with gr.Tab("Image-to-Audio"):
            with gr.Column():
                img_in = gr.Image(label="Input Image", type="filepath")
                i_prompt, i_neg, i_seed, i_steps, i_cfg = _generation_controls()
                i_dur = gr.Number(label="Duration (sec)", value=8)
                i_btn = gr.Button("Generate Image Audio")
                i_out = gr.Video(label="Static Video with Audio")
                i_btn.click(
                    fn=image_to_audio,
                    inputs=[img_in, i_prompt, i_neg, i_seed, i_steps, i_cfg, i_dur],
                    outputs=i_out,
                )

        # Text tab: audio from the prompt alone.
        with gr.Tab("Text-to-Audio"):
            with gr.Column():
                t_prompt, t_neg, t_seed, t_steps, t_cfg = _generation_controls()
                t_dur = gr.Number(label="Duration (sec)", value=8)
                t_btn = gr.Button("Generate Pure Audio")
                t_out = gr.Audio(label="Generated Audio")
                t_btn.click(
                    fn=text_to_audio,
                    inputs=[t_prompt, t_neg, t_seed, t_steps, t_cfg, t_dur],
                    outputs=t_out,
                )
if __name__ == "__main__":
    # Make sure the output folder exists before serving, then listen on all interfaces.
    output_dir.mkdir(parents=True, exist_ok=True)
    demo.launch(server_port=7860, server_name="0.0.0.0")