"""Gradio app exposing video-, image- and text-to-audio generation with MMAudio.

Importing this module downloads the model weights from the Hugging Face Hub,
loads the networks onto the selected device and builds the Gradio UI (``demo``).
"""
# NOTE(review): `spaces` is imported before torch here; this ordering is kept
# on purpose — presumably required by the Hugging Face ZeroGPU runtime, confirm
# before re-sorting the imports.
import spaces
import gc
import logging
import sys
import os
from datetime import datetime
from fractions import Fraction
from pathlib import Path

import gradio as gr
import torch
import torchaudio
from huggingface_hub import hf_hub_download

# --- CONFIGURATION & PATHS ---
MY_VAULT_REPO = "GiorgioV/SFVS"
HF_TOKEN = os.getenv("HF_TOKEN")

# The MMAudio checkout lives next to this file; it must be on sys.path before
# the `mmaudio` imports below are executed.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, "MMAudio"))

from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image,
                                load_video, make_video, setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils

# Optimization flags
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

log = logging.getLogger()
setup_eval_logging()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16 if device == "cuda" else torch.float32
output_dir = Path('./output/gradio')

# Duration (seconds) used when the real length of an input cannot be determined,
# and as the default value of the duration fields in the UI.
DEFAULT_DURATION_SEC = 8

# Default texts shown in the prompt boxes of every tab.
DEFAULT_PROMPT = "high-quality sound, detailed acoustic environment"
DEFAULT_NEGATIVE_PROMPT = "music, distorted, low quality, static, noise, talking, laughing"

# Repository-relative file name of every checkpoint, keyed by its role.
# NOTE(review): "vocoder" points at the same file as "vae" — confirm this is
# intentional and not a copy/paste slip.
_WEIGHT_FILES = {
    "model": "nsfw_gold_8.5k_final.pth",
    "vae": "ext_weights/v1-44.pth",
    "sync": "ext_weights/synchformer_state_dict.pth",
    "vocoder": "ext_weights/v1-44.pth",
}


# --- WEIGHT SYNCHRONIZATION ---
def get_weights():
    """Download every checkpoint from ``MY_VAULT_REPO`` and return the local paths.

    Returns:
        dict mapping ``"model"``, ``"vae"``, ``"sync"`` and ``"vocoder"`` to the
        path returned by ``hf_hub_download`` for that file.
    """
    log.info("Syncing weights from %s...", MY_VAULT_REPO)
    return {
        role: hf_hub_download(repo_id=MY_VAULT_REPO, filename=filename, token=HF_TOKEN)
        for role, filename in _WEIGHT_FILES.items()
    }


weight_paths = get_weights()

# --- MODEL INITIALIZATION ---
# Point the stock 'large_44k' configuration at the downloaded checkpoints.
model: ModelConfig = all_model_cfg['large_44k']
model.model_path = weight_paths["model"]
model.vae_path = weight_paths["vae"]
model.synchformer_ckpt = weight_paths["sync"]
model.bigvgan_16k_path = weight_paths["vocoder"]


def load_all_models() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    """Build the generation network and the feature utilities in eval mode.

    Returns:
        ``(net, feature_utils, seq_cfg)`` where ``net`` has the weights from
        ``model.model_path`` loaded and ``seq_cfg`` is ``model.seq_cfg``.
    """
    seq_cfg = model.seq_cfg

    net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))

    feature_utils = FeaturesUtils(
        tod_vae_ckpt=model.vae_path,
        synchformer_ckpt=model.synchformer_ckpt,
        enable_conditions=True,
        mode=model.mode,
        bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
        need_vae_encoder=False,
    ).to(device, dtype).eval()

    return net, feature_utils, seq_cfg


net, feature_utils, seq_cfg = load_all_models()


# --- INFERENCE HELPERS ---
def _make_rng(seed):
    """Return a ``torch.Generator`` on ``device`` seeded from the UI value.

    A non-negative ``seed`` is used as-is (truncated to int); a negative or
    missing (``None``) seed draws a fresh random seed.
    """
    rng = torch.Generator(device=device)
    if seed is not None and seed >= 0:
        rng.manual_seed(int(seed))
    else:
        rng.seed()
    return rng


def _make_flow_matching(num_steps):
    """Return the Euler ``FlowMatching`` sampler configured for ``num_steps``."""
    return FlowMatching(min_sigma=0, inference_mode='euler', num_steps=int(num_steps))


def _apply_duration(duration):
    """Store ``duration`` on the shared ``seq_cfg`` and resize ``net`` to match."""
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)


def _require_positive_duration(duration):
    """Raise ``gr.Error`` when the duration field is empty or not positive."""
    if duration is None or duration <= 0:
        raise gr.Error("Duration must be a positive number of seconds.")


def _new_output_path(prefix, suffix):
    """Return a fresh path inside ``output_dir``, creating the directory if needed.

    The timestamp includes the date and microseconds so that requests arriving
    within the same second (or at the same time on different days) do not
    overwrite each other's files.
    """
    output_dir.mkdir(exist_ok=True, parents=True)
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    return output_dir / f"{prefix}_{stamp}{suffix}"


# --- INFERENCE FUNCTIONS ---
def get_video_duration(video_path):
    """Return the duration of ``video_path`` in seconds, rounded to 2 decimals.

    The value is computed as ``num_frames / sample_rate`` from
    ``torchaudio.info``. ``DEFAULT_DURATION_SEC`` is returned when no path is
    given, when the metadata cannot be read, or when the computed duration is
    not positive.
    """
    if video_path is None:
        return DEFAULT_DURATION_SEC

    try:
        info = torchaudio.info(video_path)
        duration = info.num_frames / info.sample_rate
    except Exception:
        # Best-effort probe: any failure falls back to the default, but is
        # logged so that unreadable inputs are not silently hidden.
        log.warning("Could not determine duration of %s; using %s s",
                    video_path, DEFAULT_DURATION_SEC, exc_info=True)
        return DEFAULT_DURATION_SEC

    if duration <= 0:
        # A non-positive length is unusable as a generation duration.
        return DEFAULT_DURATION_SEC
    return round(duration, 2)


# NOTE(review): the GPU decorator is disabled on this function only, unlike
# image_to_audio/text_to_audio — confirm whether that is intentional.
# @spaces.GPU()
@torch.inference_mode()
def video_to_audio(video, prompt, negative_prompt, seed, num_steps, cfg_strength, duration=None):
    """Generate audio for ``video`` and return the path of the muxed ``.mp4``.

    Args:
        video: Path of the input video.
        prompt: Text describing the desired sound.
        negative_prompt: Text describing what to avoid.
        seed: Non-negative for a fixed seed; negative or ``None`` for random.
        num_steps: Number of sampler steps.
        cfg_strength: Guidance strength forwarded to ``generate``.
        duration: Seconds to generate; probed from the file when ``None``.

    Raises:
        gr.Error: If no video was provided.
    """
    if video is None:
        raise gr.Error("Please provide an input video.")
    if duration is None:
        duration = get_video_duration(video)

    rng = _make_rng(seed)
    fm = _make_flow_matching(num_steps)

    video_info = load_video(video, duration)
    clip_frames = video_info.clip_frames.unsqueeze(0)
    sync_frames = video_info.sync_frames.unsqueeze(0)
    # Use the duration reported by the loaded clip rather than the request.
    _apply_duration(video_info.duration_sec)

    audio = generate(
        clip_frames, sync_frames, [prompt],
        negative_text=[negative_prompt],
        feature_utils=feature_utils,
        net=net, fm=fm, rng=rng,
        cfg_strength=cfg_strength,
    ).float().cpu()[0]

    path = _new_output_path("v2a", ".mp4")
    make_video(video_info, path, audio, sampling_rate=seq_cfg.sampling_rate)
    gc.collect()
    return path


@spaces.GPU()
@torch.inference_mode()
def image_to_audio(image, prompt, negative_prompt, seed, num_steps, cfg_strength, duration):
    """Generate audio for a still ``image`` and return the path of an ``.mp4``.

    The output video is built from the image at 1 frame per second for
    ``duration`` seconds.

    Args:
        image: Path of the input image.
        prompt: Text describing the desired sound.
        negative_prompt: Text describing what to avoid.
        seed: Non-negative for a fixed seed; negative or ``None`` for random.
        num_steps: Number of sampler steps.
        cfg_strength: Guidance strength forwarded to ``generate``.
        duration: Seconds of audio to generate.

    Raises:
        gr.Error: If no image was provided or ``duration`` is not positive.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")
    _require_positive_duration(duration)

    rng = _make_rng(seed)
    fm = _make_flow_matching(num_steps)

    image_info = load_image(image)
    clip_frames = image_info.clip_frames.unsqueeze(0)
    sync_frames = image_info.sync_frames.unsqueeze(0)
    _apply_duration(duration)

    audio = generate(
        clip_frames, sync_frames, [prompt],
        negative_text=[negative_prompt],
        feature_utils=feature_utils,
        net=net, fm=fm, rng=rng,
        cfg_strength=cfg_strength,
        image_input=True,
    ).float().cpu()[0]

    path = _new_output_path("i2a", ".mp4")
    video_info = VideoInfo.from_image_info(image_info, duration, fps=Fraction(1))
    make_video(video_info, path, audio, sampling_rate=seq_cfg.sampling_rate)
    gc.collect()
    return path


@spaces.GPU()
@torch.inference_mode()
def text_to_audio(prompt, negative_prompt, seed, num_steps, cfg_strength, duration):
    """Generate audio from text only and return the path of a ``.flac`` file.

    Args:
        prompt: Text describing the desired sound.
        negative_prompt: Text describing what to avoid.
        seed: Non-negative for a fixed seed; negative or ``None`` for random.
        num_steps: Number of sampler steps.
        cfg_strength: Guidance strength forwarded to ``generate``.
        duration: Seconds of audio to generate.

    Raises:
        gr.Error: If ``duration`` is not positive.
    """
    _require_positive_duration(duration)

    rng = _make_rng(seed)
    fm = _make_flow_matching(num_steps)
    _apply_duration(duration)

    # No visual conditioning: both frame inputs are None.
    audio = generate(
        None, None, [prompt],
        negative_text=[negative_prompt],
        feature_utils=feature_utils,
        net=net, fm=fm, rng=rng,
        cfg_strength=cfg_strength,
    ).float().cpu()[0]

    path = _new_output_path("t2a", ".flac")
    torchaudio.save(path, audio, seq_cfg.sampling_rate)
    gc.collect()
    return path


# --- GRADIO UI ---
def _generation_controls():
    """Create the five controls shared by every tab, in display order.

    Must be called inside an active Gradio layout context. Returns
    ``(prompt, negative_prompt, seed, num_steps, cfg_strength)`` components.
    """
    prompt = gr.Textbox(label="Prompt", value=DEFAULT_PROMPT)
    negative = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
    seed = gr.Number(label="Seed (-1 = random)", value=-1)
    steps = gr.Slider(label="Num Steps", minimum=1, maximum=100, value=25, step=1)
    cfg = gr.Slider(label="Guidance Strength", minimum=1, maximum=15, value=4.5, step=0.1)
    return prompt, negative, seed, steps, cfg


with gr.Blocks() as demo:
    gr.Markdown("# Video/Image/Text-to-Audio Generation")

    with gr.Tabs():
        # --- VIDEO TAB ---
        with gr.Tab("Video-to-Audio"):
            with gr.Column():
                video_in = gr.Video(label="Input Video")
                v_prompt, v_neg, v_seed, v_steps, v_cfg = _generation_controls()
                v_btn = gr.Button("Generate Video Audio")
                v_out = gr.Video(label="Generated Video with Audio")

            # No duration input: video_to_audio probes it from the file.
            v_btn.click(
                fn=video_to_audio,
                inputs=[video_in, v_prompt, v_neg, v_seed, v_steps, v_cfg],
                outputs=v_out,
            )

        # --- IMAGE TAB ---
        with gr.Tab("Image-to-Audio"):
            with gr.Column():
                img_in = gr.Image(label="Input Image", type="filepath")
                i_prompt, i_neg, i_seed, i_steps, i_cfg = _generation_controls()
                i_dur = gr.Number(label="Duration (sec)", value=8)
                i_btn = gr.Button("Generate Image Audio")
                i_out = gr.Video(label="Static Video with Audio")

            i_btn.click(
                fn=image_to_audio,
                inputs=[img_in, i_prompt, i_neg, i_seed, i_steps, i_cfg, i_dur],
                outputs=i_out,
            )

        # --- TEXT TAB ---
        with gr.Tab("Text-to-Audio"):
            with gr.Column():
                t_prompt, t_neg, t_seed, t_steps, t_cfg = _generation_controls()
                t_dur = gr.Number(label="Duration (sec)", value=8)
                t_btn = gr.Button("Generate Pure Audio")
                t_out = gr.Audio(label="Generated Audio")

            t_btn.click(
                fn=text_to_audio,
                inputs=[t_prompt, t_neg, t_seed, t_steps, t_cfg, t_dur],
                outputs=t_out,
            )

if __name__ == "__main__":
    output_dir.mkdir(exist_ok=True, parents=True)
    demo.launch(server_name="0.0.0.0", server_port=7860)