# SFVS / app.py
# Author: GiorgioV
# Last commit: "Update app.py" (2644331, verified)
import spaces
import gc
import logging
import sys
import os
from datetime import datetime
from fractions import Fraction
from pathlib import Path
import gradio as gr
import torch
import torchaudio
from huggingface_hub import hf_hub_download
# --- CONFIGURATION & PATHS ---
# Hub repository that holds the fine-tuned checkpoint and its auxiliary weights.
MY_VAULT_REPO = "GiorgioV/SFVS"
# Token used to download from MY_VAULT_REPO; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Put the vendored MMAudio checkout first on sys.path so `mmaudio` resolves to it.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, "MMAudio"))
from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image,
load_video, make_video, setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils
# Optimization flags: allow TF32 math for CUDA matmuls and cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Root logger; setup_eval_logging() (from mmaudio) configures logging output.
log = logging.getLogger()
setup_eval_logging()
# Run in bfloat16 on GPU and fall back to float32 on CPU.
if torch.cuda.is_available():
    device = 'cuda'
    dtype = torch.bfloat16
else:
    device = 'cpu'
    dtype = torch.float32
# Directory (relative to the working directory) for generated media files.
output_dir = Path('./output/gradio')
# --- WEIGHT SYNCHRONIZATION ---
def get_weights():
    """Fetch every checkpoint this app needs from MY_VAULT_REPO.

    Each file is obtained with ``hf_hub_download`` (authenticated with
    HF_TOKEN), so repeated launches reuse the local Hub cache.

    Returns:
        dict with keys "model", "vae", "sync" and "vocoder", each mapped to
        the local path returned by ``hf_hub_download``.
    """
    log.info(f"Syncing weights from {MY_VAULT_REPO}...")
    remote_files = (
        ("model", "nsfw_gold_8.5k_final.pth"),
        ("vae", "ext_weights/v1-44.pth"),
        ("sync", "ext_weights/synchformer_state_dict.pth"),
        # NOTE(review): same file as "vae". It is only passed on as
        # bigvgan_16k_path; confirm this is intended for the 44k model.
        ("vocoder", "ext_weights/v1-44.pth"),
    )
    local_paths = {}
    for role, filename in remote_files:
        local_paths[role] = hf_hub_download(repo_id=MY_VAULT_REPO, filename=filename, token=HF_TOKEN)
    return local_paths
weight_paths = get_weights()
# --- MODEL INITIALIZATION ---
# Start from the stock 'large_44k' preset and redirect each checkpoint field
# to the files downloaded above. This mutates the entry stored in
# all_model_cfg in place.
model: ModelConfig = all_model_cfg['large_44k']
(model.model_path, model.vae_path, model.synchformer_ckpt, model.bigvgan_16k_path) = (
    weight_paths["model"],
    weight_paths["vae"],
    weight_paths["sync"],
    weight_paths["vocoder"],
)
def load_all_models() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    """Instantiate the MMAudio network and its feature utilities from ``model``.

    Both modules are moved to the global ``device``/``dtype`` and switched to
    eval mode.

    Returns:
        (net, feature_utils, seq_cfg):
        - net: the network with the checkpoint at ``model.model_path`` loaded.
        - feature_utils: built from the VAE, Synchformer and vocoder paths on
          ``model``, with conditions enabled and without the VAE encoder.
        - seq_cfg: ``model.seq_cfg``.
    """
    sequence_cfg = model.seq_cfg

    network: MMAudio = get_my_mmaudio(model.model_name)
    network = network.to(device, dtype).eval()
    # weights_only=True restricts torch.load to plain tensors/containers.
    network.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))

    extractors = FeaturesUtils(
        tod_vae_ckpt=model.vae_path,
        synchformer_ckpt=model.synchformer_ckpt,
        enable_conditions=True,
        mode=model.mode,
        bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
        need_vae_encoder=False,  # this app never encodes audio with the VAE
    )
    extractors = extractors.to(device, dtype).eval()
    return network, extractors, sequence_cfg
# Load once at import time. The inference functions share these module-level
# objects (and mutate seq_cfg.duration per request).
(net, feature_utils, seq_cfg) = load_all_models()
# --- INFERENCE FUNCTIONS ---
def get_video_duration(video_path, default=8):
    """Return the duration of ``video_path`` in seconds, rounded to 2 decimals.

    The length is derived from ``torchaudio.info`` metadata
    (``num_frames / sample_rate``). It therefore measures the file's audio
    stream; files whose audio cannot be probed fall back to ``default``.

    Args:
        video_path: Path of the media file, or None when nothing was uploaded.
        default: Value returned when the path is None or no positive duration
            can be determined. Defaults to 8, the previous hard-coded fallback.

    Returns:
        The duration in seconds, or ``default``.
    """
    if video_path is None:
        return default
    try:
        info = torchaudio.info(video_path)
    except Exception as exc:  # best effort: any probe failure falls back
        log.warning("Could not probe duration of %s (%s); using %s s", video_path, exc, default)
        return default
    # Backends may report 0 frames / 0 Hz when unknown; a zero-length or
    # division-by-zero duration is never useful to the caller.
    if info.sample_rate <= 0 or info.num_frames <= 0:
        log.warning("Unknown duration for %s; using %s s", video_path, default)
        return default
    return round(info.num_frames / info.sample_rate, 2)
@spaces.GPU()
@torch.inference_mode()
def video_to_audio(video, prompt, negative_prompt, seed, num_steps, cfg_strength, duration=None):
    """Generate a soundtrack for ``video`` and write it muxed into a new MP4.

    Args:
        video: Uploaded video, presumably a filepath from ``gr.Video``; None
            when nothing was uploaded.
        prompt: Text prompt passed to ``generate``.
        negative_prompt: Text passed to ``generate`` as ``negative_text``.
        seed: Seed value; negative or None (cleared field) draws a random seed.
        num_steps: Number of Euler steps for flow matching.
        cfg_strength: Guidance strength passed to ``generate``.
        duration: Seconds of video to load. When None it is probed with
            ``get_video_duration``.

    Returns:
        Path of the written MP4 under ``output_dir``.

    Raises:
        gr.Error: if no video was provided.
    """
    if video is None:
        raise gr.Error("Please upload a video first.")
    if duration is None:
        # NOTE(review): the UI never passes duration, so the probed length of
        # the whole clip is used. Confirm long uploads are acceptable.
        duration = get_video_duration(video)

    rng = torch.Generator(device=device)
    if seed is not None and seed >= 0:
        rng.manual_seed(int(seed))
    else:
        # Negative seed (or an emptied Number field) -> nondeterministic seed.
        rng.seed()
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=int(num_steps))

    video_info = load_video(video, duration)
    clip_frames = video_info.clip_frames.unsqueeze(0)  # leading batch dim of 1
    sync_frames = video_info.sync_frames.unsqueeze(0)
    # Size the sequences from the duration load_video reports, which may be
    # shorter than requested. NOTE(review): seq_cfg and net are module
    # globals, so concurrent requests with different lengths could race.
    seq_cfg.duration = video_info.duration_sec
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audio = generate(
        clip_frames, sync_frames, [prompt],
        negative_text=[negative_prompt],
        feature_utils=feature_utils, net=net, fm=fm, rng=rng,
        cfg_strength=cfg_strength,
    ).float().cpu()[0]

    output_dir.mkdir(exist_ok=True, parents=True)
    # Date + microsecond timestamp. The old %H%M%S name was unique only per
    # second and repeated daily, so rapid requests overwrote each other.
    path = output_dir / f"v2a_{datetime.now():%Y%m%d_%H%M%S_%f}.mp4"
    make_video(video_info, path, audio, sampling_rate=seq_cfg.sampling_rate)
    gc.collect()
    return path
@spaces.GPU()
@torch.inference_mode()
def image_to_audio(image, prompt, negative_prompt, seed, num_steps, cfg_strength, duration):
    """Generate audio for a still image and render it as a static MP4.

    Args:
        image: Filepath of the input image (``gr.Image(type="filepath")``),
            or None when nothing was uploaded.
        prompt: Text prompt passed to ``generate``.
        negative_prompt: Text passed to ``generate`` as ``negative_text``.
        seed: Seed value; negative or None (cleared field) draws a random seed.
        num_steps: Number of Euler steps for flow matching.
        cfg_strength: Guidance strength passed to ``generate``.
        duration: Length in seconds of the generated audio and output video.

    Returns:
        Path of the written MP4 under ``output_dir``.

    Raises:
        gr.Error: if no image was provided or ``duration`` is missing or not
            positive.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if duration is None or duration <= 0:
        raise gr.Error("Duration must be a positive number of seconds.")

    rng = torch.Generator(device=device)
    if seed is not None and seed >= 0:
        rng.manual_seed(int(seed))
    else:
        # Negative seed (or an emptied Number field) -> nondeterministic seed.
        rng.seed()
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=int(num_steps))

    image_info = load_image(image)
    clip_frames = image_info.clip_frames.unsqueeze(0)  # leading batch dim of 1
    sync_frames = image_info.sync_frames.unsqueeze(0)
    # NOTE(review): seq_cfg and net are module globals, so concurrent requests
    # with different durations could race.
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audio = generate(
        clip_frames, sync_frames, [prompt],
        negative_text=[negative_prompt],
        feature_utils=feature_utils, net=net, fm=fm, rng=rng,
        cfg_strength=cfg_strength, image_input=True,
    ).float().cpu()[0]

    output_dir.mkdir(exist_ok=True, parents=True)
    # Date + microsecond timestamp so rapid requests don't overwrite each other.
    path = output_dir / f"i2a_{datetime.now():%Y%m%d_%H%M%S_%f}.mp4"
    # Render the still image as a 1 fps clip of the requested length.
    video_info = VideoInfo.from_image_info(image_info, duration, fps=Fraction(1))
    make_video(video_info, path, audio, sampling_rate=seq_cfg.sampling_rate)
    gc.collect()
    return path
@spaces.GPU()
@torch.inference_mode()
def text_to_audio(prompt, negative_prompt, seed, num_steps, cfg_strength, duration):
    """Generate audio from text alone and save it as a FLAC file.

    Args:
        prompt: Text prompt passed to ``generate``.
        negative_prompt: Text passed to ``generate`` as ``negative_text``.
        seed: Seed value; negative or None (cleared field) draws a random seed.
        num_steps: Number of Euler steps for flow matching.
        cfg_strength: Guidance strength passed to ``generate``.
        duration: Length of the generated audio in seconds.

    Returns:
        Path of the written FLAC file under ``output_dir``.

    Raises:
        gr.Error: if ``duration`` is missing or not positive.
    """
    if duration is None or duration <= 0:
        raise gr.Error("Duration must be a positive number of seconds.")

    rng = torch.Generator(device=device)
    if seed is not None and seed >= 0:
        rng.manual_seed(int(seed))
    else:
        # Negative seed (or an emptied Number field) -> nondeterministic seed.
        rng.seed()
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=int(num_steps))

    # NOTE(review): seq_cfg and net are module globals, so concurrent requests
    # with different durations could race.
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    # No visual conditioning: clip and sync frames are both None.
    audio = generate(
        None, None, [prompt],
        negative_text=[negative_prompt],
        feature_utils=feature_utils, net=net, fm=fm, rng=rng,
        cfg_strength=cfg_strength,
    ).float().cpu()[0]

    output_dir.mkdir(exist_ok=True, parents=True)
    # Date + microsecond timestamp so rapid requests don't overwrite each other.
    path = output_dir / f"t2a_{datetime.now():%Y%m%d_%H%M%S_%f}.flac"
    torchaudio.save(path, audio, seq_cfg.sampling_rate)
    gc.collect()
    return path
# --- GRADIO UI ---
def _shared_generation_controls():
    """Create the prompt / negative / seed / steps / guidance widgets used by every tab.

    Must be called inside an active Gradio layout context. The widgets are
    added to that context in the order they are returned.
    """
    prompt_box = gr.Textbox(label="Prompt", value="high-quality sound, detailed acoustic environment")
    negative_box = gr.Textbox(label="Negative Prompt", value="music, distorted, low quality, static, noise, talking, laughing")
    seed_box = gr.Number(label="Seed (-1 = random)", value=-1)
    steps_slider = gr.Slider(label="Num Steps", minimum=1, maximum=100, value=25, step=1)
    cfg_slider = gr.Slider(label="Guidance Strength", minimum=1, maximum=15, value=4.5, step=0.1)
    return prompt_box, negative_box, seed_box, steps_slider, cfg_slider


with gr.Blocks() as demo:
    gr.Markdown("# Video/Image/Text-to-Audio Generation")
    with gr.Tabs():
        # --- VIDEO TAB ---
        with gr.Tab("Video-to-Audio"):
            with gr.Column():
                video_in = gr.Video(label="Input Video")
                v_prompt, v_neg, v_seed, v_steps, v_cfg = _shared_generation_controls()
                v_btn = gr.Button("Generate Video Audio")
                v_out = gr.Video(label="Generated Video with Audio")
            v_btn.click(
                fn=video_to_audio,
                inputs=[video_in, v_prompt, v_neg, v_seed, v_steps, v_cfg],
                outputs=v_out,
            )

        # --- IMAGE TAB ---
        with gr.Tab("Image-to-Audio"):
            with gr.Column():
                img_in = gr.Image(label="Input Image", type="filepath")
                i_prompt, i_neg, i_seed, i_steps, i_cfg = _shared_generation_controls()
                i_dur = gr.Number(label="Duration (sec)", value=8)
                i_btn = gr.Button("Generate Image Audio")
                i_out = gr.Video(label="Static Video with Audio")
            i_btn.click(
                fn=image_to_audio,
                inputs=[img_in, i_prompt, i_neg, i_seed, i_steps, i_cfg, i_dur],
                outputs=i_out,
            )

        # --- TEXT TAB ---
        with gr.Tab("Text-to-Audio"):
            with gr.Column():
                t_prompt, t_neg, t_seed, t_steps, t_cfg = _shared_generation_controls()
                t_dur = gr.Number(label="Duration (sec)", value=8)
                t_btn = gr.Button("Generate Pure Audio")
                t_out = gr.Audio(label="Generated Audio")
            t_btn.click(
                fn=text_to_audio,
                inputs=[t_prompt, t_neg, t_seed, t_steps, t_cfg, t_dur],
                outputs=t_out,
            )
if __name__ == "__main__":
    # Create the output folder up front, then serve on all interfaces, port 7860.
    output_dir.mkdir(parents=True, exist_ok=True)
    demo.launch(server_port=7860, server_name="0.0.0.0")