"""
Generate Audio for Video — multi-model Gradio app.
Supported models
----------------
TARO – video-conditioned diffusion via CAVP + onset features (16 kHz, 8.192 s window)
MMAudio – multimodal flow-matching with CLIP/Synchformer + text prompt (44 kHz, 8 s window)
HunyuanFoley – text-guided foley via SigLIP2 + Synchformer + CLAP (48 kHz, up to 15 s)
"""
import os
import tempfile
import random
from math import floor
from pathlib import Path
import torch
import numpy as np
import torchaudio
import ffmpeg
import spaces
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
# ================================================================== #
# CHECKPOINT CONFIGURATION #
# ================================================================== #
# All model weights live in one consolidated HF repo; everything below runs
# at import time (before the Gradio UI starts) so GPU calls never download.
CKPT_REPO_ID = "JackIsNotInTheBox/Generate_Audio_for_Video_Checkpoints"
CACHE_DIR = "/tmp/model_ckpts"
os.makedirs(CACHE_DIR, exist_ok=True)

# ---- TARO checkpoints (in TARO/ subfolder of the HF repo) ----
print("Downloading TARO checkpoints…")
cavp_ckpt_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/cavp_epoch66.ckpt", cache_dir=CACHE_DIR)
onset_ckpt_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/onset_model.ckpt", cache_dir=CACHE_DIR)
taro_ckpt_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/taro_ckpt.pt", cache_dir=CACHE_DIR)
print("TARO checkpoints downloaded.")

# ---- MMAudio checkpoints (in MMAudio/ subfolder) ----
# MMAudio normally auto-downloads from its own HF repo, but we
# override the paths so it pulls from our consolidated repo instead.
MMAUDIO_WEIGHTS_DIR = Path(CACHE_DIR) / "MMAudio" / "weights"
MMAUDIO_EXT_DIR = Path(CACHE_DIR) / "MMAudio" / "ext_weights"
MMAUDIO_WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
MMAUDIO_EXT_DIR.mkdir(parents=True, exist_ok=True)
print("Downloading MMAudio checkpoints…")
# NOTE(review): `local_dir_use_symlinks` is deprecated in recent
# huggingface_hub releases (passing `local_dir` alone now writes real files) —
# confirm against the pinned huggingface_hub version before removing it.
mmaudio_model_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/mmaudio_large_44k_v2.pth", cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_WEIGHTS_DIR), local_dir_use_symlinks=False)
mmaudio_vae_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/v1-44.pth", cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_EXT_DIR), local_dir_use_symlinks=False)
mmaudio_synchformer_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/synchformer_state_dict.pth", cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_EXT_DIR), local_dir_use_symlinks=False)
print("MMAudio checkpoints downloaded.")

# ---- HunyuanVideoFoley checkpoints (in HunyuanFoley/ subfolder) ----
HUNYUAN_MODEL_DIR = Path(CACHE_DIR) / "HunyuanFoley"
HUNYUAN_MODEL_DIR.mkdir(parents=True, exist_ok=True)
print("Downloading HunyuanVideoFoley checkpoints…")
# Return paths intentionally unused — generate_hunyuan() reads these files
# from HUNYUAN_MODEL_DIR/HunyuanVideo-Foley/ (the repo subfolder is preserved).
hf_hub_download(repo_id=CKPT_REPO_ID, filename="HunyuanVideo-Foley/hunyuanvideo_foley.pth", cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR), local_dir_use_symlinks=False)
hf_hub_download(repo_id=CKPT_REPO_ID, filename="HunyuanVideo-Foley/vae_128d_48k.pth", cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR), local_dir_use_symlinks=False)
hf_hub_download(repo_id=CKPT_REPO_ID, filename="HunyuanVideo-Foley/synchformer_state_dict.pth", cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR), local_dir_use_symlinks=False)
print("HunyuanVideoFoley checkpoints downloaded.")

# Pre-download CLAP model so from_pretrained() reads local cache inside the
# ZeroGPU daemonic worker (spawning child processes there is not allowed).
print("Pre-downloading CLAP model (laion/larger_clap_general)…")
snapshot_download(repo_id="laion/larger_clap_general")
print("CLAP model pre-downloaded.")

# ================================================================== #
# SHARED CONSTANTS / HELPERS #
# ================================================================== #
MAX_SLOTS = 8  # max parallel generation slots shown in UI
def set_global_seed(seed: int):
    """Seed stdlib ``random``, numpy and torch (CPU + CUDA) from one value."""
    random.seed(seed)
    np.random.seed(seed % (1 << 32))  # numpy only accepts 32-bit seeds
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # deferred by torch; safe without a GPU
    torch.backends.cudnn.deterministic = True
def get_random_seed() -> int:
    """Return a uniformly random seed in [0, 2**32 - 1]."""
    return random.randrange(2**32)
def get_video_duration(video_path: str) -> float:
    """Return the duration of *video_path* in seconds via ffprobe (CPU only)."""
    info = ffmpeg.probe(video_path)
    return float(info["format"]["duration"])
def strip_audio_from_video(video_path: str, output_path: str):
    """Re-encode *video_path* to *output_path* with the audio track dropped."""
    stream = ffmpeg.input(video_path)
    sink = stream.output(output_path, vcodec="libx264", an=None)
    sink.run(overwrite_output=True, quiet=True)
def mux_video_audio(silent_video: str, audio_path: str, output_path: str):
    """Combine *silent_video* and *audio_path* into one file at *output_path*."""
    video_in = ffmpeg.input(silent_video)
    audio_in = ffmpeg.input(audio_path)
    muxed = ffmpeg.output(
        video_in,
        audio_in,
        output_path,
        vcodec="libx264",
        acodec="aac",
        strict="experimental",
    )
    muxed.run(overwrite_output=True, quiet=True)
# ------------------------------------------------------------------ #
# Shared sliding-window segmentation and crossfade helpers #
# Used by all three models (TARO, MMAudio, HunyuanFoley). #
# ------------------------------------------------------------------ #
def _build_segments(total_dur_s: float, window_s: float, crossfade_s: float) -> list:
"""Return list of (start, end) pairs covering *total_dur_s* with a sliding
window of *window_s* and *crossfade_s* overlap between consecutive segments."""
if total_dur_s <= window_s:
return [(0.0, total_dur_s)]
step_s = window_s - crossfade_s
segments, seg_start = [], 0.0
while True:
if seg_start + window_s >= total_dur_s:
seg_start = max(0.0, total_dur_s - window_s)
segments.append((seg_start, total_dur_s))
break
segments.append((seg_start, seg_start + window_s))
seg_start += step_s
return segments
def _cf_join(a: np.ndarray, b: np.ndarray,
crossfade_s: float, db_boost: float, sr: int) -> np.ndarray:
"""Equal-power crossfade join. Works for both mono (T,) and stereo (C, T) arrays.
Stereo arrays are expected in (channels, samples) layout."""
stereo = a.ndim == 2
n_a = a.shape[1] if stereo else len(a)
n_b = b.shape[1] if stereo else len(b)
cf = min(int(round(crossfade_s * sr)), n_a, n_b)
if cf <= 0:
return np.concatenate([a, b], axis=1 if stereo else 0)
gain = 10 ** (db_boost / 20.0)
t = np.linspace(0.0, 1.0, cf, dtype=np.float32)
fade_out = np.cos(t * np.pi / 2) # 1 → 0
fade_in = np.sin(t * np.pi / 2) # 0 → 1
if stereo:
overlap = a[:, -cf:] * fade_out * gain + b[:, :cf] * fade_in * gain
return np.concatenate([a[:, :-cf], overlap, b[:, cf:]], axis=1)
else:
overlap = a[-cf:] * fade_out * gain + b[:cf] * fade_in * gain
return np.concatenate([a[:-cf], overlap, b[cf:]])
# ================================================================== #
# TARO #
# ================================================================== #
# Constants sourced from TARO/infer.py and TARO/models.py:
# SR=16000, TRUNCATE=131072 → 8.192 s window
# TRUNCATE_FRAME = 4 fps × 131072/16000 = 32 CAVP frames per window
# TRUNCATE_ONSET = 120 onset frames per window
# latent shape: (1, 8, 204, 16) — fixed by MMDiT architecture
# latents_scale: [0.18215]*8 — AudioLDM2 VAE scale factor
# ================================================================== #
TARO_SR = 16000         # output sample rate (Hz)
TARO_TRUNCATE = 131072  # samples per generation window
TARO_FPS = 4            # CAVP feature rate (frames per second)
TARO_TRUNCATE_FRAME = int(TARO_FPS * TARO_TRUNCATE / TARO_SR)  # 32
TARO_TRUNCATE_ONSET = 120
TARO_MODEL_DUR = TARO_TRUNCATE / TARO_SR  # 8.192 s
# Rough GPU-time model used by the @spaces.GPU duration estimators below.
TARO_SECS_PER_STEP = 0.8  # estimated GPU-seconds per diffusion step on H200
TARO_LOAD_OVERHEAD = 20  # seconds: model load + CAVP feature extraction
MMAUDIO_SECS_PER_STEP = 0.8  # estimated GPU-seconds per flow-matching step on H200
MMAUDIO_LOAD_OVERHEAD = 15
HUNYUAN_SECS_PER_STEP = 2.0  # estimated GPU-seconds per denoising step on H200 (heavier model)
HUNYUAN_LOAD_OVERHEAD = 20
GPU_DURATION_CAP = 300  # hard cap per call — never reserve more than this
# Per-(video, seed, settings) cache of raw segment waveforms so repeated
# requests with identical inputs skip diffusion entirely.
_TARO_INFERENCE_CACHE: dict = {}
def _taro_calc_max_samples(total_dur_s: float, num_steps: int, crossfade_s: float) -> int:
    """Return how many TARO generations fit inside one GPU reservation.

    Estimates per-sample GPU time as (#segments × steps × secs/step) and
    divides the available budget by it, clamped to [1, MAX_SLOTS].

    Fix: the budget was hard-coded at 600 s, but _taro_duration caps every
    reservation at GPU_DURATION_CAP (300 s) — so the slider could offer more
    samples than a single capped call can actually finish. Derive the budget
    from GPU_DURATION_CAP minus the load overhead so both stay in sync.
    """
    n_segs = len(_build_segments(total_dur_s, TARO_MODEL_DUR, crossfade_s))
    time_per_seg = num_steps * TARO_SECS_PER_STEP
    # Budget for sampling only — model load / feature extraction is paid once.
    budget_s = float(GPU_DURATION_CAP - TARO_LOAD_OVERHEAD)
    max_s = floor(budget_s / (n_segs * time_per_seg))
    return max(1, min(max_s, MAX_SLOTS))
def _taro_duration(video_file, seed_val, cfg_scale, num_steps, mode,
                   crossfade_s, crossfade_db, num_samples):
    """Estimate GPU seconds for generate_taro — must mirror _run_taro's input order."""
    try:
        total_s = get_video_duration(video_file)
        n_segs = len(_build_segments(total_s, TARO_MODEL_DUR, float(crossfade_s)))
    except Exception:
        # Probing can fail (e.g. nothing uploaded yet) — assume one segment.
        n_segs = 1
    est = int(num_samples) * n_segs * int(num_steps) * TARO_SECS_PER_STEP + TARO_LOAD_OVERHEAD
    capped = min(GPU_DURATION_CAP, max(60, int(est)))
    print(f"[duration] TARO: {int(num_samples)}samp × {n_segs}seg × {int(num_steps)}steps → {est:.0f}s → capped {capped}s")
    return capped
def _taro_infer_segment(
    model, vae, vocoder,
    cavp_feats_full, onset_feats_full,
    seg_start_s: float, seg_end_s: float,
    device, weight_dtype,
    cfg_scale: float, num_steps: int, mode: str,
    latents_scale,
    euler_sampler, euler_maruyama_sampler,
) -> np.ndarray:
    """Single-segment TARO inference. Returns wav array trimmed to segment length.

    Slices the pre-computed CAVP and onset feature tracks down to this
    segment's window (zero-padding the tail when the video runs out), samples
    a latent with the chosen diffusion sampler, and decodes it through the
    AudioLDM2 VAE + vocoder into a 16 kHz waveform.
    """
    # CAVP features (4 fps)
    cavp_start = int(round(seg_start_s * TARO_FPS))
    cavp_slice = cavp_feats_full[cavp_start : cavp_start + TARO_TRUNCATE_FRAME]
    if cavp_slice.shape[0] < TARO_TRUNCATE_FRAME:
        # Final segment may fall short of a full 32-frame window — pad with zeros.
        pad = np.zeros(
            (TARO_TRUNCATE_FRAME - cavp_slice.shape[0],) + cavp_slice.shape[1:],
            dtype=cavp_slice.dtype,
        )
        cavp_slice = np.concatenate([cavp_slice, pad], axis=0)
    video_feats = torch.from_numpy(cavp_slice).unsqueeze(0).to(device, weight_dtype)
    # Onset features (onset_fps = TRUNCATE_ONSET / MODEL_DUR ≈ 14.65 fps)
    onset_fps = TARO_TRUNCATE_ONSET / TARO_MODEL_DUR
    onset_start = int(round(seg_start_s * onset_fps))
    onset_slice = onset_feats_full[onset_start : onset_start + TARO_TRUNCATE_ONSET]
    if onset_slice.shape[0] < TARO_TRUNCATE_ONSET:
        # Same tail-padding for the onset track (assumes a 1-D onset array).
        onset_slice = np.pad(
            onset_slice,
            ((0, TARO_TRUNCATE_ONSET - onset_slice.shape[0]),),
            mode="constant",
        )
    onset_feats_t = torch.from_numpy(onset_slice).unsqueeze(0).to(device, weight_dtype)
    # Latent noise — shape matches MMDiT architecture (in_channels=8, 204×16 spatial)
    z = torch.randn(1, model.in_channels, 204, 16, device=device, dtype=weight_dtype)
    sampling_kwargs = dict(
        model=model,
        latents=z,
        y=onset_feats_t,
        context=video_feats,
        num_steps=int(num_steps),
        heun=False,
        cfg_scale=float(cfg_scale),
        guidance_low=0.0,
        guidance_high=0.7,
        path_type="linear",
    )
    with torch.no_grad():
        # "sde" → stochastic Euler–Maruyama sampler; anything else → plain Euler.
        samples = (euler_maruyama_sampler if mode == "sde" else euler_sampler)(**sampling_kwargs)
    # samplers return (output_tensor, zs) — index [0] for the audio latent
    if isinstance(samples, tuple):
        samples = samples[0]
    # Decode: AudioLDM2 VAE → mel → vocoder → waveform
    samples = vae.decode(samples / latents_scale).sample
    wav = vocoder(samples.squeeze().float()).detach().cpu().numpy()
    # Vocoder output covers the full 8.192 s window — trim to segment length.
    seg_samples = int(round((seg_end_s - seg_start_s) * TARO_SR))
    return wav[:seg_samples]
def _stitch_wavs(wavs: list, crossfade_s: float, db_boost: float,
                 total_dur_s: float, sr: int) -> np.ndarray:
    """Crossfade-join *wavs* in order, then trim to *total_dur_s* seconds.

    Note: the trailing trim slices axis 0, so this helper expects mono (T,)
    arrays — which is what TARO produces.
    """
    stitched = wavs[0]
    for segment in wavs[1:]:
        stitched = _cf_join(stitched, segment, crossfade_s, db_boost, sr)
    n_total = int(round(total_dur_s * sr))
    return stitched[:n_total]
@spaces.GPU(duration=_taro_duration)
def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
                  crossfade_s, crossfade_db, num_samples):
    """TARO: video-conditioned diffusion, 16 kHz, 8.192 s sliding window.

    Loads every TARO component (CAVP encoder, onset model, MMDiT, AudioLDM2
    VAE + vocoder) inside the GPU context, generates *num_samples* tracks
    segment-by-segment, crossfade-stitches them, muxes each onto the silent
    input video, and returns a flat list padded to MAX_SLOTS (_pad_outputs).
    """
    global _TARO_INFERENCE_CACHE
    seed_val = int(seed_val)
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)
    num_samples = int(num_samples)
    if seed_val < 0:
        # A negative seed means "pick one at random".
        seed_val = random.randint(0, 2**32 - 1)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    weight_dtype = torch.bfloat16
    # TARO modules use bare imports (e.g. `from cavp_util import ...`) that
    # assume the TARO directory is on sys.path. Add it before importing.
    import sys, os as _os
    _taro_dir = _os.path.join(_os.path.dirname(_os.path.abspath(__file__)), "TARO")
    if _taro_dir not in sys.path:
        sys.path.insert(0, _taro_dir)
    # Imports are inside the GPU context so the Space only pays for GPU time here
    from TARO.cavp_util import Extract_CAVP_Features
    from TARO.onset_util import VideoOnsetNet, extract_onset
    from TARO.models import MMDiT
    from TARO.samplers import euler_sampler, euler_maruyama_sampler
    from diffusers import AudioLDM2Pipeline
    # -- Load CAVP encoder (uses checkpoint from our HF repo) --
    extract_cavp = Extract_CAVP_Features(
        device=device,
        config_path="TARO/cavp/cavp.yaml",
        ckpt_path=cavp_ckpt_path,
    )
    # -- Load onset detection model --
    # Key remapping matches the original TARO infer.py exactly
    raw_sd = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
    onset_sd = {}
    for k, v in raw_sd.items():
        if "model.net.model" in k:
            k = k.replace("model.net.model", "net.model")
        elif "model.fc." in k:
            k = k.replace("model.fc", "fc")
        onset_sd[k] = v
    onset_model = VideoOnsetNet(pretrained=False).to(device)
    onset_model.load_state_dict(onset_sd)
    onset_model.eval()
    # -- Load TARO MMDiT --
    # Architecture params match TARO/train.py: adm_in_channels=120 (onset dim),
    # z_dims=[768] (CAVP dim), encoder_depth=4
    model = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
    model.load_state_dict(torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"])
    model.eval().to(weight_dtype)
    # -- Load AudioLDM2 VAE + vocoder (decoder pipeline only) --
    # TARO uses AudioLDM2's VAE and vocoder for decoding; no encoder needed at inference
    audioldm2 = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
    vae = audioldm2.vae.to(device).eval()
    vocoder = audioldm2.vocoder.to(device)
    latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
    # -- Prepare silent video (shared across all samples) --
    tmp_dir = tempfile.mkdtemp()
    silent_video = os.path.join(tmp_dir, "silent_input.mp4")
    strip_audio_from_video(video_file, silent_video)
    cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
    # Total duration derived from the CAVP frame count (4 fps).
    total_dur_s = cavp_feats.shape[0] / TARO_FPS
    segments = _build_segments(total_dur_s, TARO_MODEL_DUR, crossfade_s)
    outputs = []
    for sample_idx in range(num_samples):
        sample_seed = seed_val + sample_idx
        # Cache key covers everything that changes the raw segment audio
        # (crossfade_db only affects stitching, so it is excluded).
        cache_key = (video_file, sample_seed, float(cfg_scale), int(num_steps), mode, crossfade_s)
        if cache_key in _TARO_INFERENCE_CACHE:
            print(f"[TARO] Sample {sample_idx+1}: cache hit.")
            wavs = _TARO_INFERENCE_CACHE[cache_key]["wavs"]
        else:
            set_global_seed(sample_seed)
            onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
            wavs = []
            for seg_start_s, seg_end_s in segments:
                print(f"[TARO] Sample {sample_idx+1} | {seg_start_s:.2f}s – {seg_end_s:.2f}s")
                wav = _taro_infer_segment(
                    model, vae, vocoder,
                    cavp_feats, onset_feats,
                    seg_start_s, seg_end_s,
                    device, weight_dtype,
                    cfg_scale, num_steps, mode,
                    latents_scale,
                    euler_sampler, euler_maruyama_sampler,
                )
                wavs.append(wav)
            _TARO_INFERENCE_CACHE[cache_key] = {"wavs": wavs}
        # Stitch the per-segment wavs and mux them back onto the silent video.
        final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR)
        audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
        torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(final_wav)).unsqueeze(0), TARO_SR)
        video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
        mux_video_audio(silent_video, audio_path, video_path)
        outputs.append((video_path, audio_path))
    return _pad_outputs(outputs)
# ================================================================== #
# MMAudio #
# ================================================================== #
# Constants sourced from MMAudio/mmaudio/model/sequence_config.py:
# CONFIG_44K: duration=8.0 s, sampling_rate=44100
# CLIP encoder: 8 fps, 384×384 px
# Synchformer: 25 fps, 224×224 px
# Default variant: large_44k_v2
# MMAudio uses flow-matching (FlowMatching with euler inference).
# generate() handles all feature extraction + decoding internally.
# ================================================================== #
MMAUDIO_WINDOW = 8.0  # seconds — MMAudio's fixed generation window
def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
                      cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
    """Estimate GPU seconds for generate_mmaudio — must mirror _run_mmaudio's input order."""
    try:
        total_s = get_video_duration(video_file)
        n_segs = len(_build_segments(total_s, MMAUDIO_WINDOW, float(crossfade_s)))
    except Exception:
        # Probing can fail (e.g. nothing uploaded yet) — assume one segment.
        n_segs = 1
    est = int(num_samples) * n_segs * int(num_steps) * MMAUDIO_SECS_PER_STEP + MMAUDIO_LOAD_OVERHEAD
    capped = min(GPU_DURATION_CAP, max(60, int(est)))
    print(f"[duration] MMAudio: {int(num_samples)}samp × {n_segs}seg × {int(num_steps)}steps → {est:.0f}s → capped {capped}s")
    return capped
@spaces.GPU(duration=_mmaudio_duration)
def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
                     cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
    """MMAudio: flow-matching video-to-audio, 44.1 kHz, 8 s sliding window.

    Loads the large_44k_v2 network + feature utilities inside the GPU context,
    generates audio per 8 s segment (text-prompted), crossfade-stitches the
    segments, and returns a flat list padded to MAX_SLOTS (_pad_outputs).
    """
    # MMAudio's package also assumes its directory is on sys.path.
    import sys as _sys, os as _os
    _mmaudio_dir = _os.path.join(_os.path.dirname(_os.path.abspath(__file__)), "MMAudio")
    if _mmaudio_dir not in _sys.path:
        _sys.path.insert(0, _mmaudio_dir)
    from mmaudio.eval_utils import all_model_cfg, generate, load_video, make_video
    from mmaudio.model.flow_matching import FlowMatching
    from mmaudio.model.networks import get_my_mmaudio
    from mmaudio.model.utils.features_utils import FeaturesUtils
    seed_val = int(seed_val)
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16
    # Use large_44k_v2 variant; override paths to our consolidated HF checkpoint repo
    model_cfg = all_model_cfg["large_44k_v2"]
    # Patch checkpoint paths to our downloaded files
    from pathlib import Path as _Path
    model_cfg.model_path = _Path(mmaudio_model_path)
    model_cfg.vae_path = _Path(mmaudio_vae_path)
    model_cfg.synchformer_ckpt = _Path(mmaudio_synchformer_path)
    # large_44k_v2 is 44k mode, no BigVGAN vocoder needed
    model_cfg.bigvgan_16k_path = None
    seq_cfg = model_cfg.seq_cfg  # CONFIG_44K: 8 s, 44100 Hz
    # Load network weights
    net = get_my_mmaudio(model_cfg.model_name).to(device, dtype).eval()
    net.load_weights(torch.load(model_cfg.model_path, map_location=device, weights_only=True))
    # Load feature utilities: CLIP (auto-downloaded from apple/DFN5B-CLIP-ViT-H-14-384),
    # Synchformer (from our repo), VAE (from our repo), no BigVGAN for 44k mode
    feature_utils = FeaturesUtils(
        tod_vae_ckpt=str(model_cfg.vae_path),
        synchformer_ckpt=str(model_cfg.synchformer_ckpt),
        enable_conditions=True,
        mode=model_cfg.mode,  # "44k"
        bigvgan_vocoder_ckpt=None,
        need_vae_encoder=False,
    ).to(device, dtype).eval()
    tmp_dir = tempfile.mkdtemp()
    outputs = []
    # MMAudio's fixed window is 8 s. For longer videos we slide over 8 s segments
    # with a crossfade overlap and stitch the results into a full-length track.
    total_dur_s = get_video_duration(video_file)
    segments = _build_segments(total_dur_s, MMAUDIO_WINDOW, crossfade_s)
    print(f"[MMAudio] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤8 s")
    sr = seq_cfg.sampling_rate  # 44100
    for sample_idx in range(num_samples):
        rng = torch.Generator(device=device)
        if seed_val >= 0:
            # Distinct but reproducible stream per sample.
            rng.manual_seed(seed_val + sample_idx)
        else:
            rng.seed()
        seg_audios = []  # list of (channels, samples) numpy arrays
        for seg_i, (seg_start, seg_end) in enumerate(segments):
            seg_dur = seg_end - seg_start
            # Trim a clean video clip for this segment
            seg_path = os.path.join(tmp_dir, f"mma_seg_{sample_idx}_{seg_i}.mp4")
            ffmpeg.input(video_file, ss=seg_start, t=seg_dur).output(
                seg_path, vcodec="libx264", acodec="aac", strict="experimental"
            ).run(overwrite_output=True, quiet=True)
            fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
            video_info = load_video(seg_path, seg_dur)
            clip_frames = video_info.clip_frames.unsqueeze(0)
            sync_frames = video_info.sync_frames.unsqueeze(0)
            actual_dur = video_info.duration_sec
            # Resize the network's sequence lengths to the clip's true duration.
            seq_cfg.duration = actual_dur
            net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
            print(f"[MMAudio] Sample {sample_idx+1} | seg {seg_i+1}/{len(segments)} "
                  f"{seg_start:.1f}{seg_end:.1f}s | dur={actual_dur:.2f}s | prompt='{prompt}'")
            with torch.no_grad():
                audios = generate(
                    clip_frames,
                    sync_frames,
                    [prompt],
                    negative_text=[negative_prompt] if negative_prompt else None,
                    feature_utils=feature_utils,
                    net=net,
                    fm=fm,
                    rng=rng,
                    cfg_strength=float(cfg_strength),
                )
            wav = audios.float().cpu()[0].numpy()  # (C, T)
            seg_samples = int(round(seg_dur * sr))
            wav = wav[:, :seg_samples]
            seg_audios.append(wav)
        # Crossfade-stitch all segments using shared equal-power helper
        full_wav = seg_audios[0]
        for nw in seg_audios[1:]:
            full_wav = _cf_join(full_wav, nw, crossfade_s, crossfade_db, sr)
        full_wav = full_wav[:, : int(round(total_dur_s * sr))]
        audio_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.flac")
        torchaudio.save(audio_path, torch.from_numpy(full_wav), sr)
        video_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.mp4")
        mux_video_audio(video_file, audio_path, video_path)
        outputs.append((video_path, audio_path))
    return _pad_outputs(outputs)
# ================================================================== #
# HunyuanVideoFoley #
# ================================================================== #
# Constants sourced from HunyuanVideo-Foley/hunyuanvideo_foley/constants.py
# and configs/hunyuanvideo-foley-xxl.yaml:
# sample_rate = 48000 Hz (from DAC VAE)
# audio_frame_rate = 50 (latent fps, xxl config)
# max video duration = 15 s
# SigLIP2 fps = 8, Synchformer fps = 25
# CLAP text encoder: laion/larger_clap_general (auto-downloaded from HF Hub)
# Default guidance_scale=4.5, num_inference_steps=50
# ================================================================== #
HUNYUAN_MAX_DUR = 15.0  # seconds — max length of one generation pass
def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
                      guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
    """Estimate GPU seconds for generate_hunyuan — must mirror _run_hunyuan's input order."""
    try:
        total_s = get_video_duration(video_file)
        n_segs = len(_build_segments(total_s, HUNYUAN_MAX_DUR, float(crossfade_s)))
    except Exception:
        # Probing can fail (e.g. nothing uploaded yet) — assume one segment.
        n_segs = 1
    est = int(num_samples) * n_segs * int(num_steps) * HUNYUAN_SECS_PER_STEP + HUNYUAN_LOAD_OVERHEAD
    capped = min(GPU_DURATION_CAP, max(60, int(est)))
    print(f"[duration] HunyuanFoley: {int(num_samples)}samp × {n_segs}seg × {int(num_steps)}steps → {est:.0f}s → capped {capped}s")
    return capped
@spaces.GPU(duration=_hunyuan_duration)
def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
                     guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
    """HunyuanVideoFoley: text-guided foley, 48 kHz, up to 15 s per pass.

    Loads the XL/XXL model inside the GPU context, encodes the text prompt
    once, generates audio per ≤15 s segment, crossfade-stitches the segments,
    and returns a flat list padded to MAX_SLOTS (_pad_outputs).
    """
    import sys as _sys
    # Ensure HunyuanVideo-Foley package is importable
    _hf_path = str(Path("HunyuanVideo-Foley").resolve())
    if _hf_path not in _sys.path:
        _sys.path.insert(0, _hf_path)
    from hunyuanvideo_foley.utils.model_utils import load_model, denoise_process
    from hunyuanvideo_foley.utils.feature_utils import feature_process
    from hunyuanvideo_foley.utils.media_utils import merge_audio_video
    seed_val = int(seed_val)
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)
    if seed_val >= 0:
        # Negative seed means "unseeded" — leave the global RNG state alone.
        set_global_seed(seed_val)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_size = model_size.lower()  # "xl" or "xxl"
    config_map = {
        "xl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xl.yaml",
        "xxl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xxl.yaml",
    }
    # Unknown sizes fall back to the xxl config.
    config_path = config_map.get(model_size, config_map["xxl"])
    # hf_hub_download preserves the repo subfolder, so weights land in
    # HUNYUAN_MODEL_DIR/HunyuanVideo-Foley/ — pass that as the weights dir.
    hunyuan_weights_dir = str(HUNYUAN_MODEL_DIR / "HunyuanVideo-Foley")
    print(f"[HunyuanFoley] Loading {model_size.upper()} model from {hunyuan_weights_dir}")
    model_dict, cfg = load_model(
        hunyuan_weights_dir,
        config_path,
        device,
        enable_offload=False,
        model_size=model_size,
    )
    tmp_dir = tempfile.mkdtemp()
    outputs = []
    # HunyuanFoley is limited to 15 s per pass. For longer videos we slice the
    # input into overlapping segments, generate audio for each, then crossfade-
    # stitch the results into a single full-length audio track.
    total_dur_s = get_video_duration(video_file)
    segments = _build_segments(total_dur_s, HUNYUAN_MAX_DUR, crossfade_s)
    print(f"[HunyuanFoley] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤15 s")
    # Pre-encode text features once (same for every segment)
    _dummy_seg_path = os.path.join(tmp_dir, "_seg_dummy.mp4")
    ffmpeg.input(video_file, ss=0, t=min(total_dur_s, HUNYUAN_MAX_DUR)).output(
        _dummy_seg_path, vcodec="libx264", acodec="aac", strict="experimental"
    ).run(overwrite_output=True, quiet=True)
    _, text_feats, _ = feature_process(
        _dummy_seg_path,
        prompt if prompt else "",
        model_dict,
        cfg,
        neg_prompt=negative_prompt if negative_prompt else None,
    )
    # Generate audio per segment, then stitch
    for sample_idx in range(num_samples):
        seg_wavs = []
        sr = 48000  # HunyuanFoley always outputs 48 kHz
        for seg_i, (seg_start, seg_end) in enumerate(segments):
            seg_dur = seg_end - seg_start
            seg_path = os.path.join(tmp_dir, f"seg_{sample_idx}_{seg_i}.mp4")
            ffmpeg.input(video_file, ss=seg_start, t=seg_dur).output(
                seg_path, vcodec="libx264", acodec="aac", strict="experimental"
            ).run(overwrite_output=True, quiet=True)
            # Visual features are per-segment; text features are reused from above.
            visual_feats, _, seg_audio_len = feature_process(
                seg_path,
                prompt if prompt else "",
                model_dict,
                cfg,
                neg_prompt=negative_prompt if negative_prompt else None,
            )
            print(f"[HunyuanFoley] Sample {sample_idx+1} | seg {seg_i+1}/{len(segments)} "
                  f"{seg_start:.1f}{seg_end:.1f}s → {seg_audio_len:.2f}s audio")
            audio_batch, sr = denoise_process(
                visual_feats,
                text_feats,
                seg_audio_len,
                model_dict,
                cfg,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_steps),
                batch_size=1,
            )
            # audio_batch shape: (1, channels, samples) — take first (and only) sample
            wav = audio_batch[0].float().cpu().numpy()  # (channels, samples)
            # Trim to exact segment length in samples
            seg_samples = int(round(seg_dur * sr))
            wav = wav[:, :seg_samples]
            seg_wavs.append(wav)
        # Crossfade-stitch all segments using shared equal-power helper
        full_wav = seg_wavs[0]
        for nw in seg_wavs[1:]:
            full_wav = _cf_join(full_wav, nw, crossfade_s, crossfade_db, sr)
        # Trim to exact video duration
        full_wav = full_wav[:, : int(round(total_dur_s * sr))]
        audio_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}.wav")
        torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(full_wav)), sr)
        video_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}.mp4")
        merge_audio_video(audio_path, video_file, video_path)
        outputs.append((video_path, audio_path))
    return _pad_outputs(outputs)
# ================================================================== #
# SHARED UI HELPERS #
# ================================================================== #
def _pad_outputs(outputs: list) -> list:
    """Flatten (video, audio) pairs and pad with None up to MAX_SLOTS pairs."""
    flat = []
    used = min(len(outputs), MAX_SLOTS)
    for pair in outputs[:used]:
        flat.extend(pair)
    # One (None, None) placeholder pair per unused slot.
    flat.extend([None] * (2 * (MAX_SLOTS - used)))
    return flat
def _make_output_slots() -> tuple:
    """Create MAX_SLOTS video+audio output groups; only slot 1 starts visible.

    Returns (groups, videos, audios) as parallel lists.
    """
    groups, videos, audios = [], [], []
    for idx in range(MAX_SLOTS):
        with gr.Group(visible=(idx == 0)) as grp:
            videos.append(gr.Video(label=f"Generation {idx+1} — Video"))
            audios.append(gr.Audio(label=f"Generation {idx+1} — Audio"))
        groups.append(grp)
    return groups, videos, audios
def _unpack_outputs(flat: list, n: int) -> list:
    """Turn a flat _pad_outputs list into gr.update lists for groups, videos, audios."""
    n = int(n)
    visibility = [gr.update(visible=(slot < n)) for slot in range(MAX_SLOTS)]
    videos = [gr.update(value=flat[slot * 2]) for slot in range(MAX_SLOTS)]
    audios = [gr.update(value=flat[slot * 2 + 1]) for slot in range(MAX_SLOTS)]
    return visibility + videos + audios
def _on_video_upload_taro(video_file, num_steps, crossfade_s):
    """Re-cap the TARO "Generations" slider when video/steps/crossfade change."""
    if video_file is None:
        return gr.update(maximum=MAX_SLOTS, value=1)
    try:
        dur = get_video_duration(video_file)
        max_s = _taro_calc_max_samples(dur, int(num_steps), float(crossfade_s))
    except Exception:
        # Probe failures fall back to the UI-wide maximum.
        max_s = MAX_SLOTS
    # max_s is always >= 1, so min(1, max_s) resolves to 1 (kept for parity).
    return gr.update(maximum=max_s, value=min(1, max_s))
def _update_slot_visibility(n):
    """Show the first *n* output slot groups and hide the rest."""
    count = int(n)
    return [gr.update(visible=(slot < count)) for slot in range(MAX_SLOTS)]
# ================================================================== #
# GRADIO UI #
# ================================================================== #
# Top-level UI definition: three tabs (one per model), each wiring its inputs
# to the matching generate_* function via a thin _run_* wrapper that also
# unpacks the padded outputs into slot updates.
with gr.Blocks(title="Generate Audio for Video") as demo:
    gr.Markdown(
        "# Generate Audio for Video\n"
        "Choose a model and upload a video to generate synchronized audio.\n\n"
        "| Model | Best for | Avoid for |\n"
        "|-------|----------|-----------|\n"
        "| **TARO** | Natural, physics-driven impacts — footsteps, collisions, water, wind, crackling fire. Excels when the sound is tightly coupled to visible motion without needing a text description. | Dialogue, music, or complex layered soundscapes where semantic context matters. |\n"
        "| **MMAudio** | Mixed scenes where you want both visual grounding *and* semantic control via a text prompt — e.g. a busy street scene where you want to emphasize the rain rather than the traffic. Great for ambient textures and nuanced sound design. | Pure impact/foley shots where TARO's motion-coupling would be sharper, or cinematic music beds. |\n"
        "| **HunyuanFoley** | Cinematic foley requiring high fidelity and explicit creative direction — dramatic SFX, layered environmental design, or any scene where you have a clear written description of the desired sound palette. | Quick one-shot clips where you don't want to write a prompt, or raw impact sounds where timing precision matters more than richness. |"
    )
    with gr.Tabs():
        # ---------------------------------------------------------- #
        # Tab 1 — TARO #
        # ---------------------------------------------------------- #
        with gr.Tab("TARO"):
            with gr.Row():
                with gr.Column():
                    taro_video = gr.Video(label="Input Video")
                    taro_seed = gr.Number(label="Seed (-1 = random)", value=get_random_seed(), precision=0)
                    taro_cfg = gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5)
                    taro_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1)
                    taro_mode = gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde")
                    taro_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=8, value=2, step=0.1)
                    taro_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3")
                    taro_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
                    taro_btn = gr.Button("Generate", variant="primary")
                with gr.Column():
                    taro_slot_grps, taro_slot_vids, taro_slot_auds = _make_output_slots()
            # Any input that changes the GPU-time estimate re-caps the slider.
            for trigger in [taro_video, taro_steps, taro_cf_dur]:
                trigger.change(
                    fn=_on_video_upload_taro,
                    inputs=[taro_video, taro_steps, taro_cf_dur],
                    outputs=[taro_samples],
                )
            taro_samples.change(
                fn=_update_slot_visibility,
                inputs=[taro_samples],
                outputs=taro_slot_grps,
            )
            # Wrapper: run generation, then convert padded outputs to slot updates.
            def _run_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n):
                return _unpack_outputs(generate_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n), n)
            taro_btn.click(
                fn=_run_taro,
                inputs=[taro_video, taro_seed, taro_cfg, taro_steps, taro_mode,
                        taro_cf_dur, taro_cf_db, taro_samples],
                outputs=taro_slot_grps + taro_slot_vids + taro_slot_auds,
            )
        # ---------------------------------------------------------- #
        # Tab 2 — MMAudio #
        # ---------------------------------------------------------- #
        with gr.Tab("MMAudio"):
            with gr.Row():
                with gr.Column():
                    mma_video = gr.Video(label="Input Video")
                    mma_prompt = gr.Textbox(label="Prompt", placeholder="e.g. footsteps on gravel")
                    mma_neg = gr.Textbox(label="Negative Prompt", placeholder="music, speech")
                    mma_seed = gr.Number(label="Seed (-1 = random)", value=get_random_seed(), precision=0)
                    mma_cfg = gr.Slider(label="CFG Strength", minimum=1, maximum=10, value=4.5, step=0.5)
                    mma_steps = gr.Slider(label="Steps", minimum=10, maximum=50, value=25, step=1)
                    mma_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=8, value=2, step=0.1)
                    mma_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3")
                    mma_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
                    mma_btn = gr.Button("Generate", variant="primary")
                with gr.Column():
                    mma_slot_grps, mma_slot_vids, mma_slot_auds = _make_output_slots()
            mma_samples.change(
                fn=_update_slot_visibility,
                inputs=[mma_samples],
                outputs=mma_slot_grps,
            )
            def _run_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n):
                return _unpack_outputs(generate_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n), n)
            mma_btn.click(
                fn=_run_mmaudio,
                inputs=[mma_video, mma_prompt, mma_neg, mma_seed,
                        mma_cfg, mma_steps, mma_cf_dur, mma_cf_db, mma_samples],
                outputs=mma_slot_grps + mma_slot_vids + mma_slot_auds,
            )
        # ---------------------------------------------------------- #
        # Tab 3 — HunyuanVideoFoley #
        # ---------------------------------------------------------- #
        with gr.Tab("HunyuanFoley"):
            with gr.Row():
                with gr.Column():
                    hf_video = gr.Video(label="Input Video")
                    hf_prompt = gr.Textbox(label="Prompt", placeholder="e.g. rain hitting a metal roof")
                    hf_neg = gr.Textbox(label="Negative Prompt", value="noisy, harsh")
                    hf_seed = gr.Number(label="Seed (-1 = random)", value=get_random_seed(), precision=0)
                    hf_guidance = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, value=4.5, step=0.5)
                    hf_steps = gr.Slider(label="Steps", minimum=10, maximum=100, value=50, step=5)
                    hf_size = gr.Radio(label="Model Size", choices=["xl", "xxl"], value="xxl")
                    hf_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=8, value=2, step=0.1)
                    hf_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3")
                    hf_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
                    hf_btn = gr.Button("Generate", variant="primary")
                with gr.Column():
                    hf_slot_grps, hf_slot_vids, hf_slot_auds = _make_output_slots()
            hf_samples.change(
                fn=_update_slot_visibility,
                inputs=[hf_samples],
                outputs=hf_slot_grps,
            )
            def _run_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n):
                return _unpack_outputs(generate_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n), n)
            hf_btn.click(
                fn=_run_hunyuan,
                inputs=[hf_video, hf_prompt, hf_neg, hf_seed,
                        hf_guidance, hf_steps, hf_size, hf_cf_dur, hf_cf_db, hf_samples],
                outputs=hf_slot_grps + hf_slot_vids + hf_slot_auds,
            )

# Queue limits concurrent requests; launch() starts the server.
demo.queue(max_size=10).launch()