BoxOfColors's picture
Refactor: reduce technical debt across app.py
d9fa683
raw
history blame
139 kB
"""
Generate Audio for Video β€” multi-model Gradio app.
Supported models
----------------
TARO – video-conditioned diffusion via CAVP + onset features (16 kHz, 8.192 s window)
MMAudio – multimodal flow-matching with CLIP/Synchformer + text prompt (44 kHz, 8 s window)
HunyuanFoley – text-guided foley via SigLIP2 + Synchformer + CLAP (48 kHz, up to 15 s)
"""
import math
import os
import sys
import json
import shutil
import tempfile
import random
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import torch
import numpy as np
import torchaudio
import ffmpeg
import spaces
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
# ================================================================== #
# CHECKPOINT CONFIGURATION #
# ================================================================== #
# Single HF repo that aggregates every model's checkpoints for this Space.
CKPT_REPO_ID = "JackIsNotInTheBox/Generate_Audio_for_Video_Checkpoints"
# /tmp is the only reliably writable location on a Space; persists for the process lifetime.
CACHE_DIR = "/tmp/model_ckpts"
os.makedirs(CACHE_DIR, exist_ok=True)
# ---- Local directories that must exist before parallel downloads start ----
MMAUDIO_WEIGHTS_DIR = Path(CACHE_DIR) / "MMAudio" / "weights"
MMAUDIO_EXT_DIR = Path(CACHE_DIR) / "MMAudio" / "ext_weights"
HUNYUAN_MODEL_DIR = Path(CACHE_DIR) / "HunyuanFoley"
MMAUDIO_WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
MMAUDIO_EXT_DIR.mkdir(parents=True, exist_ok=True)
HUNYUAN_MODEL_DIR.mkdir(parents=True, exist_ok=True)
# ------------------------------------------------------------------ #
# Parallel checkpoint + model downloads #
# All downloads are I/O-bound (network), so running them in threads #
# cuts Space cold-start time roughly proportional to the number of #
# independent groups (previously sequential, now concurrent). #
# hf_hub_download / snapshot_download are thread-safe. #
# ------------------------------------------------------------------ #
def _dl_taro():
    """Download TARO .ckpt/.pt files and return their local paths."""
    filenames = (
        "TARO/cavp_epoch66.ckpt",   # CAVP video-feature extractor
        "TARO/onset_model.ckpt",    # onset detector
        "TARO/taro_ckpt.pt",        # diffusion backbone
    )
    cavp, onset, taro = (
        hf_hub_download(repo_id=CKPT_REPO_ID, filename=name, cache_dir=CACHE_DIR)
        for name in filenames
    )
    print("TARO checkpoints downloaded.")
    return cavp, onset, taro
def _dl_mmaudio():
    """Download MMAudio .pth files and return their local paths."""
    def _fetch(fname, target_dir):
        # Materialise real files (no symlinks) inside the pre-created dirs.
        return hf_hub_download(repo_id=CKPT_REPO_ID, filename=fname,
                               cache_dir=CACHE_DIR, local_dir=str(target_dir),
                               local_dir_use_symlinks=False)
    main_ckpt = _fetch("MMAudio/mmaudio_large_44k_v2.pth", MMAUDIO_WEIGHTS_DIR)
    vae_ckpt = _fetch("MMAudio/v1-44.pth", MMAUDIO_EXT_DIR)
    sync_ckpt = _fetch("MMAudio/synchformer_state_dict.pth", MMAUDIO_EXT_DIR)
    print("MMAudio checkpoints downloaded.")
    return main_ckpt, vae_ckpt, sync_ckpt
def _dl_hunyuan():
    """Download HunyuanVideoFoley .pth files."""
    for fname in (
        "HunyuanVideo-Foley/hunyuanvideo_foley.pth",
        "HunyuanVideo-Foley/vae_128d_48k.pth",
        "HunyuanVideo-Foley/synchformer_state_dict.pth",
    ):
        hf_hub_download(repo_id=CKPT_REPO_ID, filename=fname,
                        cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR),
                        local_dir_use_symlinks=False)
    print("HunyuanVideoFoley checkpoints downloaded.")
def _dl_clap():
    """Pre-fetch the CLAP snapshot so later from_pretrained() calls hit
    the local cache inside the ZeroGPU worker."""
    snapshot_download(repo_id="laion/larger_clap_general")
    print("CLAP model pre-downloaded.")
def _dl_clip():
    """Pre-fetch MMAudio's CLIP snapshot (~3.95 GB) so the download does
    not eat into the ZeroGPU time budget."""
    snapshot_download(repo_id="apple/DFN5B-CLIP-ViT-H-14-384")
    print("MMAudio CLIP model pre-downloaded.")
def _dl_audioldm2():
    """Pre-fetch the AudioLDM2 snapshot (VAE + vocoder) consumed by TARO's
    from_pretrained() calls."""
    snapshot_download(repo_id="cvssp/audioldm2")
    print("AudioLDM2 pre-downloaded.")
def _dl_bigvgan():
    """Pre-fetch the BigVGAN vocoder snapshot (~489 MB) used by MMAudio."""
    snapshot_download(repo_id="nvidia/bigvgan_v2_44khz_128band_512x")
    print("BigVGAN vocoder pre-downloaded.")
# Kick off all seven download groups concurrently (all network-bound).
print("[startup] Starting parallel checkpoint + model downloads…")
_t_dl_start = time.perf_counter()
with ThreadPoolExecutor(max_workers=7) as _pool:  # one worker per independent group
    _fut_taro = _pool.submit(_dl_taro)
    _fut_mmaudio = _pool.submit(_dl_mmaudio)
    _fut_hunyuan = _pool.submit(_dl_hunyuan)
    _fut_clap = _pool.submit(_dl_clap)
    _fut_clip = _pool.submit(_dl_clip)
    _fut_aldm2 = _pool.submit(_dl_audioldm2)
    _fut_bigvgan = _pool.submit(_dl_bigvgan)
    # Raise any download exceptions immediately
    for _fut in as_completed([_fut_taro, _fut_mmaudio, _fut_hunyuan,
                              _fut_clap, _fut_clip, _fut_aldm2, _fut_bigvgan]):
        _fut.result()
# Checkpoint paths consumed later by the model-loading helpers.
cavp_ckpt_path, onset_ckpt_path, taro_ckpt_path = _fut_taro.result()
mmaudio_model_path, mmaudio_vae_path, mmaudio_synchformer_path = _fut_mmaudio.result()
print(f"[startup] All downloads done in {time.perf_counter() - _t_dl_start:.1f}s")
# ================================================================== #
# SHARED CONSTANTS / HELPERS #
# ================================================================== #
# CPU → GPU context passing via function-name-keyed global store.
#
# Problem: ZeroGPU runs @spaces.GPU functions on its own worker thread, so
# threading.local() is invisible to the GPU worker. Passing ctx as a
# function argument exposes it to Gradio's API endpoint, causing
# "Too many arguments" errors.
#
# Solution: store context in a plain global dict keyed by function name.
# A per-key Lock serialises concurrent callers for the same function
# (ZeroGPU is already synchronous — the wrapper blocks until the GPU fn
# returns — so in practice only one call per GPU fn is in-flight at a time).
# The global dict is readable from any thread.
_GPU_CTX: dict = {}  # fn_name -> pending context dict for the next GPU call
_GPU_CTX_LOCK = threading.Lock()  # guards all reads/writes of _GPU_CTX
def _ctx_store(fn_name: str, data: dict) -> None:
    """Record *data* as the pending context for GPU function *fn_name*.

    Any previously stored context under the same key is replaced."""
    _GPU_CTX_LOCK.acquire()
    try:
        _GPU_CTX[fn_name] = data
    finally:
        _GPU_CTX_LOCK.release()
def _ctx_load(fn_name: str) -> dict:
    """Remove and return the context stored under *fn_name* ({} if none)."""
    _GPU_CTX_LOCK.acquire()
    try:
        return _GPU_CTX.pop(fn_name, {})
    finally:
        _GPU_CTX_LOCK.release()
MAX_SLOTS = 8  # max parallel generation slots shown in UI
MAX_SEGS = 8   # max segments per slot (same as MAX_SLOTS; video ≤ ~64 s at 8 s/seg)
# Segment overlay palette — shared between _build_waveform_html and _build_regen_pending_html.
# "{a}" is a str.format placeholder for the alpha channel, filled in at render time.
SEG_COLORS = [
    "rgba(100,180,255,{a})", "rgba(255,160,100,{a})",
    "rgba(120,220,140,{a})", "rgba(220,120,220,{a})",
    "rgba(255,220,80,{a})", "rgba(80,220,220,{a})",
    "rgba(255,100,100,{a})", "rgba(180,255,180,{a})",
]
# ------------------------------------------------------------------ #
# Micro-helpers that eliminate repeated boilerplate across the file #
# ------------------------------------------------------------------ #
def _ensure_syspath(subdir: str) -> str:
"""Add *subdir* (relative to app.py) to sys.path if not already present.
Returns the absolute path for convenience."""
p = os.path.join(os.path.dirname(os.path.abspath(__file__)), subdir)
if p not in sys.path:
sys.path.insert(0, p)
return p
def _get_device_and_dtype() -> tuple:
"""Return (device, weight_dtype) pair used by all GPU functions."""
device = "cuda" if torch.cuda.is_available() else "cpu"
return device, torch.bfloat16
def _extract_segment_clip(silent_video: str, seg_start: float, seg_dur: float,
                          output_path: str) -> str:
    """Stream-copy the [seg_start, seg_start + seg_dur] window of
    *silent_video* into *output_path* (no re-encode, audio dropped).
    Returns *output_path*."""
    clip = ffmpeg.input(silent_video, ss=seg_start, t=seg_dur)
    clip = clip.output(output_path, vcodec="copy", an=None)
    clip.run(overwrite_output=True, quiet=True)
    return output_path
# Per-slot reentrant locks β€” prevent concurrent regens on the same slot from
# producing a race condition where the second regen reads stale state
# (the shared seg_state textbox hasn't been updated yet by the first regen).
# Locks are keyed by slot_id string (e.g. "taro_0", "mma_2").
_SLOT_LOCKS: dict = {}
_SLOT_LOCKS_MUTEX = threading.Lock()
def _get_slot_lock(slot_id: str) -> threading.Lock:
with _SLOT_LOCKS_MUTEX:
if slot_id not in _SLOT_LOCKS:
_SLOT_LOCKS[slot_id] = threading.Lock()
return _SLOT_LOCKS[slot_id]
def set_global_seed(seed: int) -> None:
    """Seed numpy, random and torch (plus CUDA when available) from one value."""
    wrapped = seed % (2**32)  # numpy rejects seeds outside uint32 range
    np.random.seed(wrapped)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
def get_random_seed() -> int:
    """Return a uniformly random seed in [0, 2**32 - 1]."""
    upper = 2**32 - 1
    return random.randint(0, upper)
def _resolve_seed(seed_val) -> int:
"""Normalise seed_val to a non-negative int.
Negative values (UI default 'random') produce a fresh random seed."""
seed_val = int(seed_val)
return seed_val if seed_val >= 0 else get_random_seed()
def get_video_duration(video_path: str) -> float:
    """Return the container duration of *video_path* in seconds (CPU only)."""
    info = ffmpeg.probe(video_path)
    duration = info["format"]["duration"]
    return float(duration)
def strip_audio_from_video(video_path: str, output_path: str) -> None:
    """Write a silent copy of *video_path* to *output_path*.

    Video is stream-copied (no re-encode); an=None drops the audio track."""
    (
        ffmpeg.input(video_path)
        .output(output_path, vcodec="copy", an=None)
        .run(overwrite_output=True, quiet=True)
    )
def _transcode_for_browser(video_path: str) -> str:
"""Re-encode uploaded video to H.264/AAC MP4 so the browser preview widget can play it.
Returns a NEW path in a fresh /tmp/gradio/ subdirectory. Gradio probes the
returned path fresh, sees H.264, and serves it directly without its own
slow fallback converter. The in-place overwrite approach loses the race
because Gradio probes the original path at upload time before this callback runs.
Only called on upload β€” not during generation.
"""
if video_path is None:
return video_path
try:
probe = ffmpeg.probe(video_path)
has_audio = any(s["codec_type"] == "audio" for s in probe.get("streams", []))
# Check if already H.264 β€” skip transcode if so
video_streams = [s for s in probe.get("streams", []) if s["codec_type"] == "video"]
if video_streams and video_streams[0].get("codec_name") == "h264":
print(f"[transcode_for_browser] already H.264, skipping")
return video_path
# Write the H.264 output into the SAME directory as the original upload.
# Gradio's file server only allows paths under dirs it registered β€” the
# upload dir is already allowed, so a sibling file there will serve fine.
import os as _os
upload_dir = _os.path.dirname(video_path)
stem = _os.path.splitext(_os.path.basename(video_path))[0]
out_path = _os.path.join(upload_dir, stem + "_h264.mp4")
kwargs = dict(
vcodec="libx264", preset="fast", crf=18,
pix_fmt="yuv420p", movflags="+faststart",
)
if has_audio:
kwargs["acodec"] = "aac"
kwargs["audio_bitrate"] = "128k"
else:
kwargs["an"] = None
# map 0:v:0 explicitly to skip non-video streams (e.g. data/timecode tracks)
ffmpeg.input(video_path).output(out_path, map="0:v:0", **kwargs).run(
overwrite_output=True, quiet=True
)
print(f"[transcode_for_browser] transcoded to H.264: {out_path}")
return out_path
except Exception as e:
print(f"[transcode_for_browser] failed, using original: {e}")
return video_path
# ------------------------------------------------------------------ #
# Temp directory registry β€” tracks dirs for cleanup on new generation #
# ------------------------------------------------------------------ #
_TEMP_DIRS: list = [] # list of tmp_dir paths created by generate_*
_TEMP_DIRS_MAX = 10 # keep at most this many; older ones get cleaned up
def _register_tmp_dir(tmp_dir: str) -> str:
"""Register a temp dir so it can be cleaned up when newer ones replace it."""
_TEMP_DIRS.append(tmp_dir)
while len(_TEMP_DIRS) > _TEMP_DIRS_MAX:
old = _TEMP_DIRS.pop(0)
try:
shutil.rmtree(old, ignore_errors=True)
print(f"[cleanup] Removed old temp dir: {old}")
except Exception:
pass
return tmp_dir
def _save_seg_wavs(wavs: list[np.ndarray], tmp_dir: str, prefix: str) -> list[str]:
"""Save a list of numpy wav arrays to .npy files, return list of paths.
This avoids serialising large float arrays into JSON/HTML data-state."""
paths = []
for i, w in enumerate(wavs):
p = os.path.join(tmp_dir, f"{prefix}_seg{i}.npy")
np.save(p, w)
paths.append(p)
return paths
def _load_seg_wavs(paths: list[str]) -> list[np.ndarray]:
"""Load segment wav arrays from .npy file paths."""
return [np.load(p) for p in paths]
# ------------------------------------------------------------------ #
# Shared model-loading helpers (deduplicate generate / regen code) #
# ------------------------------------------------------------------ #
def _load_taro_models(device, weight_dtype):
    """Load TARO MMDiT + AudioLDM2 VAE/vocoder. Returns (model_net, vae, vocoder, latents_scale)."""
    from TARO.models import MMDiT
    from diffusers import AutoencoderKL
    from transformers import SpeechT5HifiGan
    # Geometry must match the checkpoint (adm_in_channels=120, z_dims=[768], depth=4).
    model_net = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
    # "ema" weights are used for inference; weights_only=False because the
    # checkpoint presumably contains non-tensor objects — TODO confirm.
    model_net.load_state_dict(torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"])
    model_net.eval().to(weight_dtype)
    # VAE decodes latents to mel; vocoder turns mel into a waveform.
    vae = AutoencoderKL.from_pretrained("cvssp/audioldm2", subfolder="vae").to(device).eval()
    vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm2", subfolder="vocoder").to(device)
    # AudioLDM2 VAE scale factor, broadcast over the 8 latent channels.
    latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
    return model_net, vae, vocoder, latents_scale
def _load_taro_feature_extractors(device):
    """Load CAVP + onset extractors. Returns (extract_cavp, onset_model)."""
    from TARO.cavp_util import Extract_CAVP_Features
    from TARO.onset_util import VideoOnsetNet
    extract_cavp = Extract_CAVP_Features(
        device=device, config_path="TARO/cavp/cavp.yaml", ckpt_path=cavp_ckpt_path,
    )
    # The onset checkpoint's keys carry "model." prefixes (presumably from a
    # training wrapper — verify against the checkpoint); remap them so they
    # line up with the bare VideoOnsetNet module.
    raw_sd = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
    onset_sd = {}
    for k, v in raw_sd.items():
        if "model.net.model" in k: k = k.replace("model.net.model", "net.model")
        elif "model.fc." in k: k = k.replace("model.fc", "fc")
        onset_sd[k] = v
    onset_model = VideoOnsetNet(pretrained=False).to(device)
    onset_model.load_state_dict(onset_sd)
    onset_model.eval()
    return extract_cavp, onset_model
def _load_mmaudio_models(device, dtype):
    """Load MMAudio net + feature_utils. Returns (net, feature_utils, model_cfg, seq_cfg)."""
    from mmaudio.eval_utils import all_model_cfg
    from mmaudio.model.networks import get_my_mmaudio
    from mmaudio.model.utils.features_utils import FeaturesUtils
    # Point the stock config at the checkpoints downloaded at startup.
    model_cfg = all_model_cfg["large_44k_v2"]
    model_cfg.model_path = Path(mmaudio_model_path)
    model_cfg.vae_path = Path(mmaudio_vae_path)
    model_cfg.synchformer_ckpt = Path(mmaudio_synchformer_path)
    model_cfg.bigvgan_16k_path = None  # 44 kHz variant does not use the 16 kHz vocoder
    seq_cfg = model_cfg.seq_cfg
    net = get_my_mmaudio(model_cfg.model_name).to(device, dtype).eval()
    net.load_weights(torch.load(model_cfg.model_path, map_location=device, weights_only=True))
    # need_vae_encoder=False: inference only decodes latents, never encodes audio.
    feature_utils = FeaturesUtils(
        tod_vae_ckpt=str(model_cfg.vae_path),
        synchformer_ckpt=str(model_cfg.synchformer_ckpt),
        enable_conditions=True, mode=model_cfg.mode,
        bigvgan_vocoder_ckpt=None, need_vae_encoder=False,
    ).to(device, dtype).eval()
    return net, feature_utils, model_cfg, seq_cfg
def _load_hunyuan_model(device, model_size):
    """Load HunyuanFoley model dict + config. Returns (model_dict, cfg).

    *model_size* is case-insensitive ("xl" / "xxl"); unknown values fall
    back to the XXL config."""
    from hunyuanvideo_foley.utils.model_utils import load_model
    model_size = model_size.lower()
    config_map = {
        "xl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xl.yaml",
        "xxl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xxl.yaml",
    }
    config_path = config_map.get(model_size, config_map["xxl"])
    # Weights live under the subdir created by _dl_hunyuan's local_dir layout.
    hunyuan_weights_dir = str(HUNYUAN_MODEL_DIR / "HunyuanVideo-Foley")
    print(f"[HunyuanFoley] Loading {model_size.upper()} model from {hunyuan_weights_dir}")
    return load_model(hunyuan_weights_dir, config_path, device,
                      enable_offload=False, model_size=model_size)
def mux_video_audio(silent_video: str, audio_path: str, output_path: str,
                    model: str = None) -> None:
    """Mux a silent video with an audio file into *output_path*.
    For HunyuanFoley (*model*="hunyuan") we use its own merge_audio_video which
    handles its specific ffmpeg quirks; all other models use ffmpeg muxing.
    """
    if model == "hunyuan":
        _ensure_syspath("HunyuanVideo-Foley")
        from hunyuanvideo_foley.utils.media_utils import merge_audio_video
        merge_audio_video(audio_path, silent_video, output_path)
    else:
        v_in = ffmpeg.input(silent_video)
        a_in = ffmpeg.input(audio_path)
        # Select exactly one video and one audio stream; re-encode to
        # browser-friendly H.264/AAC with faststart for streaming playback.
        ffmpeg.output(
            v_in["v:0"],
            a_in["a:0"],
            output_path,
            vcodec="libx264", preset="fast", crf=18,
            pix_fmt="yuv420p",
            acodec="aac", audio_bitrate="128k",
            movflags="+faststart",
        ).run(overwrite_output=True, quiet=True)
# ------------------------------------------------------------------ #
# Shared sliding-window segmentation and crossfade helpers #
# Used by all three models (TARO, MMAudio, HunyuanFoley). #
# ------------------------------------------------------------------ #
def _build_segments(total_dur_s: float, window_s: float, crossfade_s: float) -> list[tuple[float, float]]:
"""Return list of (start, end) pairs covering *total_dur_s*.
Every segment uses the full *window_s* inference window. Segments are
equally spaced so every overlap is identical, guaranteeing the crossfade
setting is honoured at every boundary with no raw bleed.
Algorithm
---------
1. Clamp crossfade_s so the step stays positive.
2. Find the minimum n such that n segments of *window_s* cover
*total_dur_s* with overlap β‰₯ crossfade_s at every boundary:
n = ceil((total_dur_s - crossfade_s) / (window_s - crossfade_s))
3. Compute equal spacing: step = (total_dur_s - window_s) / (n - 1)
so that every gap is identical and the last segment ends exactly at
total_dur_s.
4. Every segment is exactly *window_s* wide. The trailing audio of each
segment beyond its contact edge is discarded in _stitch_wavs.
"""
crossfade_s = min(crossfade_s, window_s * 0.5)
if total_dur_s <= window_s:
return [(0.0, total_dur_s)]
step_min = window_s - crossfade_s # minimum step to honour crossfade
n = math.ceil((total_dur_s - crossfade_s) / step_min)
n = max(n, 2)
# Equal step so first seg starts at 0 and last seg ends at total_dur_s
step_s = (total_dur_s - window_s) / (n - 1)
return [(i * step_s, i * step_s + window_s) for i in range(n)]
def _cf_join(a: np.ndarray, b: np.ndarray,
crossfade_s: float, db_boost: float, sr: int) -> np.ndarray:
"""Equal-power crossfade join. Works for both mono (T,) and stereo (C, T) arrays.
Stereo arrays are expected in (channels, samples) layout.
db_boost is applied to the overlap region as a whole (after blending), so
it compensates for the -3 dB equal-power dip without doubling amplitude.
Applying gain to each side independently (the common mistake) causes a
+3 dB loudness bump at the seam β€” this version avoids that."""
stereo = a.ndim == 2
n_a = a.shape[1] if stereo else len(a)
n_b = b.shape[1] if stereo else len(b)
cf = min(int(round(crossfade_s * sr)), n_a, n_b)
if cf <= 0:
return np.concatenate([a, b], axis=1 if stereo else 0)
gain = 10 ** (db_boost / 20.0)
t = np.linspace(0.0, 1.0, cf, dtype=np.float32)
fade_out = np.cos(t * np.pi / 2) # 1 β†’ 0
fade_in = np.sin(t * np.pi / 2) # 0 β†’ 1
if stereo:
# Blend first, then apply boost to the overlap region as a unit
overlap = (a[:, -cf:] * fade_out + b[:, :cf] * fade_in) * gain
return np.concatenate([a[:, :-cf], overlap, b[:, cf:]], axis=1)
else:
overlap = (a[-cf:] * fade_out + b[:cf] * fade_in) * gain
return np.concatenate([a[:-cf], overlap, b[cf:]])
# ================================================================== #
# TARO #
# ================================================================== #
# Constants sourced from TARO/infer.py and TARO/models.py:
# SR=16000, TRUNCATE=131072 β†’ 8.192 s window
# TRUNCATE_FRAME = 4 fps Γ— 131072/16000 = 32 CAVP frames per window
# TRUNCATE_ONSET = 120 onset frames per window
# latent shape: (1, 8, 204, 16) β€” fixed by MMDiT architecture
# latents_scale: [0.18215]*8 β€” AudioLDM2 VAE scale factor
# ================================================================== #
# ================================================================== #
# MODEL CONSTANTS & CONFIGURATION REGISTRY #
# ================================================================== #
# All per-model numeric constants live here β€” MODEL_CONFIGS is the #
# single source of truth consumed by duration estimation, segmentation,#
# and the UI. Standalone names kept only where other code references #
# them by name (TARO geometry, TARGET_SR, GPU_DURATION_CAP). #
# ================================================================== #
# TARO geometry β€” referenced directly in _taro_infer_segment
TARO_SR = 16000          # TARO native sample rate (Hz)
TARO_TRUNCATE = 131072   # samples per inference window (2**17)
TARO_FPS = 4             # CAVP feature rate (frames/s)
TARO_TRUNCATE_FRAME = int(TARO_FPS * TARO_TRUNCATE / TARO_SR)  # 32
TARO_TRUNCATE_ONSET = 120  # onset frames per window
TARO_MODEL_DUR = TARO_TRUNCATE / TARO_SR  # 8.192 s
GPU_DURATION_CAP = 300  # hard cap per @spaces.GPU call — never reserve more than this
# Single source of truth for per-model numeric constants; consumed by
# duration estimation, segmentation, and the UI.
MODEL_CONFIGS = {
    "taro": {
        "window_s": TARO_MODEL_DUR,  # 8.192 s
        "sr": TARO_SR,               # 16000 (output resampled to TARGET_SR)
        "secs_per_step": 0.025,      # measured 0.023 s/step on H200
        "load_overhead": 15,         # model load + CAVP feature extraction
        "tab_prefix": "taro",
        "label": "TARO",
        "regen_fn": None,  # set after function definitions (avoids forward-ref)
    },
    "mmaudio": {
        "window_s": 8.0,        # MMAudio's fixed generation window
        "sr": 48000,            # resampled from 44100 in post-processing
        "secs_per_step": 0.25,  # measured 0.230 s/step on H200
        "load_overhead": 30,    # 15s warm + 15s model init
        "tab_prefix": "mma",
        "label": "MMAudio",
        "regen_fn": None,
    },
    "hunyuan": {
        "window_s": 15.0,       # HunyuanFoley max video duration
        "sr": 48000,
        "secs_per_step": 0.35,  # measured 0.328 s/step on H200
        "load_overhead": 55,    # ~55s to load the 10 GB XXL weights
        "tab_prefix": "hf",
        "label": "HunyuanFoley",
        "regen_fn": None,
    },
}
# Convenience aliases used only in the TARO inference path
TARO_SECS_PER_STEP = MODEL_CONFIGS["taro"]["secs_per_step"]
MMAUDIO_WINDOW = MODEL_CONFIGS["mmaudio"]["window_s"]
MMAUDIO_SECS_PER_STEP = MODEL_CONFIGS["mmaudio"]["secs_per_step"]
HUNYUAN_MAX_DUR = MODEL_CONFIGS["hunyuan"]["window_s"]
HUNYUAN_SECS_PER_STEP = MODEL_CONFIGS["hunyuan"]["secs_per_step"]
def _clamp_duration(secs: float, label: str) -> int:
    """Clamp a raw GPU-seconds estimate to [60, GPU_DURATION_CAP] and log it."""
    reserved = int(secs)
    if reserved < 60:
        reserved = 60
    if reserved > GPU_DURATION_CAP:
        reserved = GPU_DURATION_CAP
    print(f"[duration] {label}: {secs:.0f}s raw → {reserved}s reserved")
    return reserved
def _estimate_gpu_duration(model_key: str, num_samples: int, num_steps: int,
                           total_dur_s: float = None, crossfade_s: float = 0,
                           video_file: str = None) -> int:
    """Estimate GPU seconds for a full generation call.
    Formula: num_samples × n_segs × num_steps × secs_per_step + load_overhead

    *total_dur_s* is probed from *video_file* when not given; any probe or
    segmentation failure falls back to a single segment rather than raising
    before the GPU call.
    """
    cfg = MODEL_CONFIGS[model_key]
    try:
        if total_dur_s is None:
            total_dur_s = get_video_duration(video_file)
        n_segs = len(_build_segments(total_dur_s, cfg["window_s"], float(crossfade_s)))
    except Exception:
        n_segs = 1  # conservative fallback when the video can't be probed
    secs = int(num_samples) * n_segs * int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"]
    # end="" so _clamp_duration's own log line completes this one.
    print(f"[duration] {cfg['label']}: {int(num_samples)}samp × {n_segs}seg × "
          f"{int(num_steps)}steps → {secs:.0f}s → capped ", end="")
    return _clamp_duration(secs, cfg["label"])
def _estimate_regen_duration(model_key: str, num_steps: int) -> int:
    """Estimate GPU seconds for a single-segment regen call."""
    cfg = MODEL_CONFIGS[model_key]
    estimate = cfg["secs_per_step"] * int(num_steps) + cfg["load_overhead"]
    # end="" so _clamp_duration's own log line completes this one.
    print(f"[duration] {cfg['label']} regen: 1 seg × {int(num_steps)} steps → ", end="")
    return _clamp_duration(estimate, f"{cfg['label']} regen")
_TARO_CACHE_MAXLEN = 16  # evict oldest entries beyond this limit
_TARO_INFERENCE_CACHE: dict = {}  # keyed by (video_file, seed, cfg, steps, mode, crossfade_s)
_TARO_CACHE_LOCK = threading.Lock()  # guards cache reads/writes across threads
def _taro_calc_max_samples(total_dur_s: float, num_steps: int, crossfade_s: float) -> int:
    """How many parallel TARO samples fit in a 600 s GPU budget for this video."""
    seg_count = len(_build_segments(total_dur_s, TARO_MODEL_DUR, crossfade_s))
    per_sample = seg_count * (num_steps * TARO_SECS_PER_STEP)
    budget = int(600.0 / per_sample)
    return max(1, min(budget, MAX_SLOTS))
def _taro_duration(video_file, seed_val, cfg_scale, num_steps, mode,
                   crossfade_s, crossfade_db, num_samples):
    """Pre-GPU callable — must match _taro_gpu_infer's input order exactly."""
    return _estimate_gpu_duration(
        "taro",
        int(num_samples),
        int(num_steps),
        video_file=video_file,
        crossfade_s=crossfade_s,
    )
def _taro_infer_segment(
    model, vae, vocoder,
    cavp_feats_full, onset_feats_full,
    seg_start_s: float, seg_end_s: float,
    device, weight_dtype,
    cfg_scale: float, num_steps: int, mode: str,
    latents_scale,
    euler_sampler, euler_maruyama_sampler,
) -> np.ndarray:
    """Single-segment TARO inference. Returns wav array trimmed to segment length.

    Slices the pre-computed full-video CAVP and onset features down to this
    segment's window, zero-pads them to the fixed window length if the video
    ends early, then runs the chosen sampler and decodes through the
    AudioLDM2 VAE + vocoder. *seg_end_s* is unused here — the full window is
    always generated and _stitch_wavs trims at the contact edges.
    """
    # CAVP features (4 fps)
    cavp_start = int(round(seg_start_s * TARO_FPS))
    cavp_slice = cavp_feats_full[cavp_start : cavp_start + TARO_TRUNCATE_FRAME]
    if cavp_slice.shape[0] < TARO_TRUNCATE_FRAME:
        # Zero-pad the tail so the model always sees exactly 32 frames.
        pad = np.zeros(
            (TARO_TRUNCATE_FRAME - cavp_slice.shape[0],) + cavp_slice.shape[1:],
            dtype=cavp_slice.dtype,
        )
        cavp_slice = np.concatenate([cavp_slice, pad], axis=0)
    video_feats = torch.from_numpy(cavp_slice).unsqueeze(0).to(device, weight_dtype)
    # Onset features (onset_fps = TRUNCATE_ONSET / MODEL_DUR ≈ 14.65 fps)
    onset_fps = TARO_TRUNCATE_ONSET / TARO_MODEL_DUR
    onset_start = int(round(seg_start_s * onset_fps))
    onset_slice = onset_feats_full[onset_start : onset_start + TARO_TRUNCATE_ONSET]
    if onset_slice.shape[0] < TARO_TRUNCATE_ONSET:
        onset_slice = np.pad(
            onset_slice,
            ((0, TARO_TRUNCATE_ONSET - onset_slice.shape[0]),),
            mode="constant",
        )
    onset_feats_t = torch.from_numpy(onset_slice).unsqueeze(0).to(device, weight_dtype)
    # Latent noise — shape matches MMDiT architecture (in_channels=8, 204×16 spatial)
    z = torch.randn(1, model.in_channels, 204, 16, device=device, dtype=weight_dtype)
    sampling_kwargs = dict(
        model=model,
        latents=z,
        y=onset_feats_t,        # onset conditioning
        context=video_feats,    # CAVP video conditioning
        num_steps=int(num_steps),
        heun=False,
        cfg_scale=float(cfg_scale),
        guidance_low=0.0,
        guidance_high=0.7,
        path_type="linear",
    )
    with torch.no_grad():
        # "sde" selects the stochastic Euler–Maruyama sampler; anything else
        # uses the deterministic Euler sampler.
        samples = (euler_maruyama_sampler if mode == "sde" else euler_sampler)(**sampling_kwargs)
    # samplers return (output_tensor, zs) — index [0] for the audio latent
    if isinstance(samples, tuple):
        samples = samples[0]
    # Decode: AudioLDM2 VAE → mel → vocoder → waveform
    samples = vae.decode(samples / latents_scale).sample
    wav = vocoder(samples.squeeze().float()).detach().cpu().numpy()
    return wav  # full window — _stitch_wavs handles contact-edge trimming
# ================================================================== #
# TARO 16 kHz β†’ 48 kHz upsample #
# ================================================================== #
# TARO generates at 16 kHz; all other models output at 44.1/48 kHz.
# We upsample via sinc resampling (torchaudio, CPU-only) so the final
# stitched audio is uniformly at 48 kHz across all three models.
TARGET_SR = 48000  # unified output sample rate for all three models
TARO_SR_OUT = TARGET_SR  # TARO's 16 kHz output is upsampled to this rate
def _resample_to_target(wav: np.ndarray, src_sr: int,
dst_sr: int = None) -> np.ndarray:
"""Resample *wav* (mono or stereo numpy float32) from *src_sr* to *dst_sr*.
*dst_sr* defaults to TARGET_SR (48 kHz). No-op if src_sr == dst_sr.
Uses torchaudio Kaiser-windowed sinc resampling β€” CPU-only, ZeroGPU-safe.
"""
if dst_sr is None:
dst_sr = TARGET_SR
if src_sr == dst_sr:
return wav
stereo = wav.ndim == 2
t = torch.from_numpy(np.ascontiguousarray(wav.astype(np.float32)))
if not stereo:
t = t.unsqueeze(0) # [1, T]
t = torchaudio.functional.resample(t, src_sr, dst_sr)
if not stereo:
t = t.squeeze(0) # [T]
return t.numpy()
def _upsample_taro(wav_16k: np.ndarray) -> np.ndarray:
    """Sinc-upsample a mono 16 kHz numpy array to 48 kHz on the CPU.

    torchaudio.functional.resample uses a Kaiser-windowed sinc filter —
    mathematically optimal for bandlimited signals, zero CUDA risk.
    Returns a mono float32 numpy array at 48 kHz.
    """
    src_dur = len(wav_16k) / TARO_SR
    print(f"[TARO upsample] {src_dur:.2f}s @ {TARO_SR}Hz → {TARGET_SR}Hz (sinc, CPU) …")
    wav_48k = _resample_to_target(wav_16k, TARO_SR)
    print(f"[TARO upsample] done — {len(wav_48k)/TARGET_SR:.2f}s @ {TARGET_SR}Hz "
          f"(expected {src_dur * 3:.2f}s, ratio 3×)")
    return wav_48k
def _stitch_wavs(wavs: list[np.ndarray], crossfade_s: float, db_boost: float,
                 total_dur_s: float, sr: int,
                 segments: list[tuple[float, float]] = None) -> np.ndarray:
    """Crossfade-join a list of wav arrays and trim to *total_dur_s*.
    Works for both mono (T,) and stereo (C, T) arrays.
    When *segments* is provided (list of (start, end) video-time pairs),
    each wav is trimmed to its contact-edge window before joining:
        contact_edge[i→i+1] = midpoint of overlap = (seg[i].end + seg[i+1].start) / 2
        half_cf = crossfade_s / 2
        seg i keep: [contact_edge[i-1→i] - half_cf, contact_edge[i→i+1] + half_cf]
    expressed as sample offsets into the generated audio for that segment.
    This guarantees every crossfade zone is exactly crossfade_s wide with no
    raw bleed regardless of how much the inference windows overlap.
    """
    def _trim(wav, start_s, end_s, seg_start_s):
        """Trim wav to [start_s, end_s] expressed in absolute video time,
        where the wav starts at seg_start_s in video time."""
        s = max(0, int(round((start_s - seg_start_s) * sr)))
        e = int(round((end_s - seg_start_s) * sr))
        # Never slice past the generated audio's actual length.
        e = min(e, wav.shape[1] if wav.ndim == 2 else len(wav))
        return wav[:, s:e] if wav.ndim == 2 else wav[s:e]
    if segments is None or len(segments) == 1:
        # Legacy path: no segment metadata — join the raw wavs directly.
        out = wavs[0]
        for nw in wavs[1:]:
            out = _cf_join(out, nw, crossfade_s, db_boost, sr)
        n = int(round(total_dur_s * sr))
        return out[:, :n] if out.ndim == 2 else out[:n]
    half_cf = crossfade_s / 2.0
    # Compute contact edges between consecutive segments
    contact_edges = [
        (segments[i][1] + segments[i + 1][0]) / 2.0
        for i in range(len(segments) - 1)
    ]
    # Trim each segment to its keep window
    trimmed = []
    for i, (wav, (seg_start, seg_end)) in enumerate(zip(wavs, segments)):
        # First segment keeps its head; last keeps its tail up to total_dur_s.
        keep_start = (contact_edges[i - 1] - half_cf) if i > 0 else seg_start
        keep_end = (contact_edges[i] + half_cf) if i < len(segments) - 1 else total_dur_s
        trimmed.append(_trim(wav, keep_start, keep_end, seg_start))
    # Crossfade-join the trimmed segments
    out = trimmed[0]
    for nw in trimmed[1:]:
        out = _cf_join(out, nw, crossfade_s, db_boost, sr)
    n = int(round(total_dur_s * sr))
    return out[:, :n] if out.ndim == 2 else out[:n]
def _save_wav(path: str, wav: np.ndarray, sr: int) -> None:
    """Write a numpy wav array (mono (T,) or stereo (C, T)) to *path* via torchaudio."""
    tensor = torch.from_numpy(np.ascontiguousarray(wav))
    if tensor.ndim == 1:
        tensor = tensor.unsqueeze(0)  # torchaudio.save expects (channels, samples)
    torchaudio.save(path, tensor, sr)
def _log_inference_timing(label: str, elapsed: float, n_segs: int,
num_steps: int, constant: float) -> None:
"""Print a standardised inference-timing summary line."""
total_steps = n_segs * num_steps
secs_per_step = elapsed / total_steps if total_steps > 0 else 0
print(f"[{label}] Inference done: {n_segs} seg(s) Γ— {num_steps} steps in "
f"{elapsed:.1f}s wall β†’ {secs_per_step:.3f}s/step "
f"(current constant={constant})")
def _build_seg_meta(*, segments, wav_paths, audio_path, video_path,
silent_video, sr, model, crossfade_s, crossfade_db,
total_dur_s, **extras) -> dict:
"""Build the seg_meta dict shared by all three generate_* functions.
Model-specific keys are passed via **extras."""
meta = {
"segments": segments,
"wav_paths": wav_paths,
"audio_path": audio_path,
"video_path": video_path,
"silent_video": silent_video,
"sr": sr,
"model": model,
"crossfade_s": crossfade_s,
"crossfade_db": crossfade_db,
"total_dur_s": total_dur_s,
}
meta.update(extras)
return meta
def _post_process_samples(results: list, *, model: str, tmp_dir: str,
                          silent_video: str, segments: list,
                          crossfade_s: float, crossfade_db: float,
                          total_dur_s: float, sr: int,
                          extra_meta_fn=None) -> list:
    """CPU-side finalisation shared by all three generate_* wrappers.

    Each entry in *results* is a tuple whose first element is the list of
    per-segment wav arrays (remaining elements are model-specific).  For each
    sample this stitches the segments into one track, writes the .wav, muxes
    it back onto the silent video, persists the per-segment wavs for later
    regeneration, and builds the seg_meta dict.

    *extra_meta_fn(sample_idx, result_tuple, tmp_dir) -> dict* optionally
    supplies model-specific extra seg_meta keys (e.g. cavp_path,
    text_feats_path).

    Returns a list of (video_path, audio_path, seg_meta) tuples.
    """
    finished = []
    for idx, sample in enumerate(results):
        per_seg_wavs = sample[0]
        stem = f"{model}_{idx}"
        # Stitch the segment windows into one continuous track and save it
        stitched = _stitch_wavs(per_seg_wavs, crossfade_s, crossfade_db,
                                total_dur_s, sr, segments)
        wav_file = os.path.join(tmp_dir, f"{stem}.wav")
        _save_wav(wav_file, stitched, sr)
        # Mux the fresh audio back onto the silent input video
        mp4_file = os.path.join(tmp_dir, f"{stem}.mp4")
        mux_video_audio(silent_video, wav_file, mp4_file, model=model)
        # Persist per-segment wavs so single segments can be regenerated later
        seg_paths = _save_seg_wavs(per_seg_wavs, tmp_dir, stem)
        extras = extra_meta_fn(idx, sample, tmp_dir) if extra_meta_fn else {}
        meta = _build_seg_meta(
            segments=segments, wav_paths=seg_paths, audio_path=wav_file,
            video_path=mp4_file, silent_video=silent_video, sr=sr,
            model=model, crossfade_s=crossfade_s, crossfade_db=crossfade_db,
            total_dur_s=total_dur_s, **extras,
        )
        finished.append((mp4_file, wav_file, meta))
    return finished
def _cpu_preprocess(video_file: str, model_dur: float,
                    crossfade_s: float) -> tuple:
    """Shared CPU prologue for every generate_* wrapper.

    Creates a scratch directory, strips the input's audio track, measures
    the video duration, and splits it into model-sized windows.

    Returns (tmp_dir, silent_video, total_dur_s, segments).
    """
    work_dir = _register_tmp_dir(tempfile.mkdtemp())
    muted_path = os.path.join(work_dir, "silent_input.mp4")
    strip_audio_from_video(video_file, muted_path)
    duration = get_video_duration(video_file)
    windows = _build_segments(duration, model_dur, crossfade_s)
    return work_dir, muted_path, duration, windows
@spaces.GPU(duration=_taro_duration)
def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
                    crossfade_s, crossfade_db, num_samples):
    """GPU-only TARO inference β€” model loading + feature extraction + diffusion.

    Returns a list with one tuple per sample:
      (wavs, cavp_feats, onset_feats)  on a cache miss
      (wavs, cavp_feats, None)         on a cache hit β€” None tells the caller
                                       onset features were not recomputed
    The wavs are raw 16 kHz segment arrays; upsampling happens CPU-side in
    generate_taro.
    """
    seed_val = _resolve_seed(seed_val)
    crossfade_s = float(crossfade_s)
    num_samples = int(num_samples)
    torch.set_grad_enabled(False)
    device, weight_dtype = _get_device_and_dtype()
    _ensure_syspath("TARO")
    from TARO.onset_util import extract_onset
    from TARO.samplers import euler_sampler, euler_maruyama_sampler
    # Pre-processing context was stored by generate_taro before the GPU call
    ctx = _ctx_load("taro_gpu_infer")
    tmp_dir = ctx["tmp_dir"]
    silent_video = ctx["silent_video"]
    segments = ctx["segments"]
    total_dur_s = ctx["total_dur_s"]
    extract_cavp, onset_model = _load_taro_feature_extractors(device)
    cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
    # Onset features depend only on the video β€” extract once for all samples
    onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
    # Free feature extractors before loading the heavier inference models
    del extract_cavp, onset_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    model, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
    results = []  # list of (wavs, cavp_feats, onset_feats) per sample
    for sample_idx in range(num_samples):
        # Deterministic per-sample seed: base seed + sample index
        sample_seed = seed_val + sample_idx
        cache_key = (video_file, sample_seed, float(cfg_scale), int(num_steps), mode, crossfade_s)
        with _TARO_CACHE_LOCK:
            cached = _TARO_INFERENCE_CACHE.get(cache_key)
        if cached is not None:
            print(f"[TARO] Sample {sample_idx+1}: cache hit.")
            results.append((cached["wavs"], cavp_feats, None))
        else:
            set_global_seed(sample_seed)
            wavs = []
            _t_infer_start = time.perf_counter()
            for seg_start_s, seg_end_s in segments:
                print(f"[TARO] Sample {sample_idx+1} | {seg_start_s:.2f}s – {seg_end_s:.2f}s")
                wav = _taro_infer_segment(
                    model, vae, vocoder,
                    cavp_feats, onset_feats,
                    seg_start_s, seg_end_s,
                    device, weight_dtype,
                    cfg_scale, num_steps, mode,
                    latents_scale,
                    euler_sampler, euler_maruyama_sampler,
                )
                wavs.append(wav)
            _log_inference_timing("TARO", time.perf_counter() - _t_infer_start,
                                  len(segments), int(num_steps), TARO_SECS_PER_STEP)
            # Insert into the bounded FIFO cache (oldest entries evicted first)
            with _TARO_CACHE_LOCK:
                _TARO_INFERENCE_CACHE[cache_key] = {"wavs": wavs}
                while len(_TARO_INFERENCE_CACHE) > _TARO_CACHE_MAXLEN:
                    _TARO_INFERENCE_CACHE.pop(next(iter(_TARO_INFERENCE_CACHE)))
            results.append((wavs, cavp_feats, onset_feats))
        # Free GPU memory between samples so VRAM fragmentation doesn't
        # degrade diffusion quality on samples 2, 3, 4, etc.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results
def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
                  crossfade_s, crossfade_db, num_samples):
    """TARO: video-conditioned diffusion, 16 kHz, 8.192 s sliding window.
    CPU pre/post-processing wraps the GPU-only inference to minimize ZeroGPU cost."""
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)
    num_samples = int(num_samples)
    # ── CPU pre-processing (no GPU needed) ──
    tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
        video_file, TARO_MODEL_DUR, crossfade_s)
    # Context for the GPU function β€” passed out-of-band via _ctx_store
    _ctx_store("taro_gpu_infer", {
        "tmp_dir": tmp_dir, "silent_video": silent_video,
        "segments": segments, "total_dur_s": total_dur_s,
    })
    # ── GPU inference only ──
    results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
                              crossfade_s, crossfade_db, num_samples)
    # ── CPU post-processing (no GPU needed) ──
    # Upsample 16kHz β†’ 48kHz and normalise result tuples to (seg_wavs, ...)
    cavp_path = os.path.join(tmp_dir, "taro_cavp.npy")
    onset_path = os.path.join(tmp_dir, "taro_onset.npy")
    _feats_saved = False
    def _upsample_and_save_feats(result):
        # Upsample each 16 kHz segment wav to 48 kHz and persist the CAVP /
        # onset features once so single-segment regens can skip re-extraction.
        nonlocal _feats_saved
        wavs, cavp_feats, onset_feats = result
        wavs = [_upsample_taro(w) for w in wavs]
        if not _feats_saved:
            np.save(cavp_path, cavp_feats)
            # onset_feats is None when the GPU call hit its inference cache,
            # in which case onset_path is simply not written here.
            if onset_feats is not None:
                np.save(onset_path, onset_feats)
            _feats_saved = True
        return (wavs, cavp_feats, onset_feats)
    results = [_upsample_and_save_feats(r) for r in results]
    def _taro_extras(sample_idx, result, td):
        # Same feature-cache paths for every sample (features are per-video)
        return {"cavp_path": cavp_path, "onset_path": onset_path}
    outputs = _post_process_samples(
        results, model="taro", tmp_dir=tmp_dir,
        silent_video=silent_video, segments=segments,
        crossfade_s=crossfade_s, crossfade_db=crossfade_db,
        total_dur_s=total_dur_s, sr=TARO_SR_OUT,
        extra_meta_fn=_taro_extras,
    )
    return _pad_outputs(outputs)
# ================================================================== #
# MMAudio #
# ================================================================== #
# Constants sourced from MMAudio/mmaudio/model/sequence_config.py:
# CONFIG_44K: duration=8.0 s, sampling_rate=44100
# CLIP encoder: 8 fps, 384Γ—384 px
# Synchformer: 25 fps, 224Γ—224 px
# Default variant: large_44k_v2
# MMAudio uses flow-matching (FlowMatching with euler inference).
# generate() handles all feature extraction + decoding internally.
# ================================================================== #
def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
                      cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
    """Estimate the ZeroGPU budget for an MMAudio run.

    The signature must mirror _mmaudio_gpu_infer exactly β€” spaces.GPU invokes
    this callable with the same positional arguments.
    """
    samples = int(num_samples)
    steps = int(num_steps)
    return _estimate_gpu_duration("mmaudio", samples, steps,
                                  video_file=video_file, crossfade_s=crossfade_s)
@spaces.GPU(duration=_mmaudio_duration)
def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                       cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
    """GPU-only MMAudio inference β€” model loading + flow-matching generation.

    Returns a list with one (seg_audios, sr) tuple per sample, where each
    segment audio is a full-window (C, T) numpy array at seq_cfg.sampling_rate.
    """
    _ensure_syspath("MMAudio")
    from mmaudio.eval_utils import generate, load_video
    from mmaudio.model.flow_matching import FlowMatching
    seed_val = _resolve_seed(seed_val)
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    device, dtype = _get_device_and_dtype()
    net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
    # Segment clips were pre-extracted on the CPU side by generate_mmaudio
    ctx = _ctx_load("mmaudio_gpu_infer")
    segments = ctx["segments"]
    seg_clip_paths = ctx["seg_clip_paths"]
    sr = seq_cfg.sampling_rate  # 44100
    results = []
    for sample_idx in range(num_samples):
        # Per-sample RNG: deterministic offset from the base seed
        rng = torch.Generator(device=device)
        rng.manual_seed(seed_val + sample_idx)
        seg_audios = []
        _t_mma_start = time.perf_counter()
        fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
        for seg_i, (seg_start, seg_end) in enumerate(segments):
            seg_dur = seg_end - seg_start
            seg_path = seg_clip_paths[seg_i]
            video_info = load_video(seg_path, seg_dur)
            clip_frames = video_info.clip_frames.unsqueeze(0)
            sync_frames = video_info.sync_frames.unsqueeze(0)
            # duration_sec may differ from the requested seg_dur, so the
            # network's sequence lengths are updated per segment
            actual_dur = video_info.duration_sec
            seq_cfg.duration = actual_dur
            net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
            print(f"[MMAudio] Sample {sample_idx+1} | seg {seg_i+1}/{len(segments)} "
                  f"{seg_start:.1f}–{seg_end:.1f}s | dur={actual_dur:.2f}s | prompt='{prompt}'")
            with torch.no_grad():
                audios = generate(
                    clip_frames,
                    sync_frames,
                    [prompt],
                    negative_text=[negative_prompt] if negative_prompt else None,
                    feature_utils=feature_utils,
                    net=net,
                    fm=fm,
                    rng=rng,
                    cfg_strength=float(cfg_strength),
                )
            wav = audios.float().cpu()[0].numpy()  # (C, T) β€” full window
            seg_audios.append(wav)
        _log_inference_timing("MMAudio", time.perf_counter() - _t_mma_start,
                              len(segments), int(num_steps), MMAUDIO_SECS_PER_STEP)
        results.append((seg_audios, sr))
        # Free GPU memory between samples to prevent VRAM fragmentation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results
def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
                     cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
    """MMAudio: flow-matching video-to-audio, 44.1 kHz, 8 s sliding window.

    All pre/post-processing runs on the CPU; only _mmaudio_gpu_infer consumes
    ZeroGPU time.
    """
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)
    # ── CPU pre-processing ──
    tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
        video_file, MMAUDIO_WINDOW, crossfade_s)
    print(f"[MMAudio] Video={total_dur_s:.2f}s | {len(segments)} segment(s) Γ— ≀8 s")
    # Pre-extract every segment clip so the GPU function does no ffmpeg work
    seg_clip_paths = []
    for idx, (start, end) in enumerate(segments):
        clip_out = os.path.join(tmp_dir, f"mma_seg_{idx}.mp4")
        seg_clip_paths.append(_extract_segment_clip(silent_video, start, end - start, clip_out))
    _ctx_store("mmaudio_gpu_infer", {"segments": segments, "seg_clip_paths": seg_clip_paths})
    # ── GPU inference only ──
    results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                                 cfg_strength, num_steps, crossfade_s, crossfade_db,
                                 num_samples)
    # ── CPU post-processing ──
    # Resample 44100 β†’ 48000 and normalise tuples to (seg_wavs, ...)
    resampled = []
    for seg_audios, sr in results:
        if sr != TARGET_SR:
            print(f"[MMAudio upsample] resampling {sr}Hz β†’ {TARGET_SR}Hz (sinc, CPU) …")
            seg_audios = [_resample_to_target(w, sr) for w in seg_audios]
            print(f"[MMAudio upsample] done β€” {len(seg_audios)} seg(s) @ {TARGET_SR}Hz")
        resampled.append((seg_audios,))
    outputs = _post_process_samples(
        resampled, model="mmaudio", tmp_dir=tmp_dir,
        silent_video=silent_video, segments=segments,
        crossfade_s=crossfade_s, crossfade_db=crossfade_db,
        total_dur_s=total_dur_s, sr=TARGET_SR,
    )
    return _pad_outputs(outputs)
# ================================================================== #
# HunyuanVideoFoley #
# ================================================================== #
# Constants sourced from HunyuanVideo-Foley/hunyuanvideo_foley/constants.py
# and configs/hunyuanvideo-foley-xxl.yaml:
# sample_rate = 48000 Hz (from DAC VAE)
# audio_frame_rate = 50 (latent fps, xxl config)
# max video duration = 15 s
# SigLIP2 fps = 8, Synchformer fps = 25
# CLAP text encoder: laion/larger_clap_general (auto-downloaded from HF Hub)
# Default guidance_scale=4.5, num_inference_steps=50
# ================================================================== #
def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
                      guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
                      num_samples):
    """Estimate the ZeroGPU budget for a HunyuanFoley run.

    The signature must mirror _hunyuan_gpu_infer exactly β€” spaces.GPU invokes
    this callable with the same positional arguments.
    """
    return _estimate_gpu_duration(
        "hunyuan",
        int(num_samples),
        int(num_steps),
        video_file=video_file,
        crossfade_s=crossfade_s,
    )
@spaces.GPU(duration=_hunyuan_duration)
def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                       guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
                       num_samples):
    """GPU-only HunyuanFoley inference β€” model loading + feature extraction + denoising.

    crossfade_s / crossfade_db are accepted only so the signature matches
    _hunyuan_duration; crossfading itself happens CPU-side after this call.
    (Dead locals seg_dur / total_dur_s from an earlier revision were removed.)

    Returns a list with one (seg_wavs, sr, text_feats) tuple per sample, where
    seg_wavs are full-window numpy arrays and sr is whatever denoise_process
    reports (48000 per the constants block above).
    """
    _ensure_syspath("HunyuanVideo-Foley")
    from hunyuanvideo_foley.utils.model_utils import denoise_process
    from hunyuanvideo_foley.utils.feature_utils import feature_process
    seed_val = _resolve_seed(seed_val)
    num_samples = int(num_samples)
    set_global_seed(seed_val)
    device, _ = _get_device_and_dtype()
    model_size = model_size.lower()
    model_dict, cfg = _load_hunyuan_model(device, model_size)
    # Pre-processing context was stored by generate_hunyuan before the GPU call
    ctx = _ctx_load("hunyuan_gpu_infer")
    segments = ctx["segments"]
    dummy_seg_path = ctx["dummy_seg_path"]
    seg_clip_paths = ctx["seg_clip_paths"]
    # Text feature extraction (GPU β€” runs once for all segments)
    _, text_feats, _ = feature_process(
        dummy_seg_path,
        prompt if prompt else "",
        model_dict,
        cfg,
        neg_prompt=negative_prompt if negative_prompt else None,
    )
    # Import visual-only feature extractor to avoid redundant text extraction
    # per segment (text_feats already computed once above for the whole batch).
    from hunyuanvideo_foley.utils.feature_utils import encode_video_features
    results = []
    for sample_idx in range(num_samples):
        seg_wavs = []
        sr = 48000  # fallback so the append below is defined even with zero segments
        _t_hny_start = time.perf_counter()
        for seg_i, (seg_start, seg_end) in enumerate(segments):
            seg_path = seg_clip_paths[seg_i]
            # Extract only visual features β€” reuse text_feats from above
            visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
            print(f"[HunyuanFoley] Sample {sample_idx+1} | seg {seg_i+1}/{len(segments)} "
                  f"{seg_start:.1f}–{seg_end:.1f}s β†’ {seg_audio_len:.2f}s audio")
            audio_batch, sr = denoise_process(
                visual_feats,
                text_feats,
                seg_audio_len,
                model_dict,
                cfg,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_steps),
                batch_size=1,
            )
            wav = audio_batch[0].float().cpu().numpy()  # full window
            seg_wavs.append(wav)
        _log_inference_timing("HunyuanFoley", time.perf_counter() - _t_hny_start,
                              len(segments), int(num_steps), HUNYUAN_SECS_PER_STEP)
        results.append((seg_wavs, sr, text_feats))
        # Free GPU memory between samples to prevent VRAM fragmentation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results
def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
                     guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
    """HunyuanVideoFoley: text-guided foley, 48 kHz, up to 15 s.

    All pre/post-processing runs on the CPU; only _hunyuan_gpu_infer consumes
    ZeroGPU time.
    """
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)
    # ── CPU pre-processing (no GPU needed) ──
    tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
        video_file, HUNYUAN_MAX_DUR, crossfade_s)
    print(f"[HunyuanFoley] Video={total_dur_s:.2f}s | {len(segments)} segment(s) Γ— ≀15 s")
    # Dummy clip used only for the one-shot text feature extraction (ffmpeg, CPU)
    dummy_seg_path = _extract_segment_clip(
        silent_video, 0, min(total_dur_s, HUNYUAN_MAX_DUR),
        os.path.join(tmp_dir, "_seg_dummy.mp4"),
    )
    # Pre-extract every segment clip so the GPU function does no ffmpeg work
    seg_clip_paths = []
    for idx, (start, end) in enumerate(segments):
        clip_out = os.path.join(tmp_dir, f"hny_seg_{idx}.mp4")
        seg_clip_paths.append(_extract_segment_clip(silent_video, start, end - start, clip_out))
    _ctx_store("hunyuan_gpu_infer", {
        "segments": segments, "total_dur_s": total_dur_s,
        "dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
    })
    # ── GPU inference only ──
    results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                                 guidance_scale, num_steps, model_size,
                                 crossfade_s, crossfade_db, num_samples)
    # ── CPU post-processing (no GPU needed) ──
    def _hunyuan_extras(sample_idx, result, td):
        # Persist the text features so segment regeneration can reuse them
        text_feats = result[2]
        feats_path = os.path.join(td, f"hunyuan_{sample_idx}_text_feats.pt")
        torch.save(text_feats, feats_path)
        return {"text_feats_path": feats_path}
    outputs = _post_process_samples(
        results, model="hunyuan", tmp_dir=tmp_dir,
        silent_video=silent_video, segments=segments,
        crossfade_s=crossfade_s, crossfade_db=crossfade_db,
        total_dur_s=total_dur_s, sr=48000,
        extra_meta_fn=_hunyuan_extras,
    )
    return _pad_outputs(outputs)
# ================================================================== #
# SEGMENT REGENERATION HELPERS #
# ================================================================== #
# Each regen function:
# 1. Runs inference for ONE segment (random seed, current settings)
# 2. Splices the new wav into the stored wavs list
# 3. Re-stitches the full track, re-saves .wav and re-muxes .mp4
# 4. Returns (new_video_path, new_audio_path, updated_seg_meta, new_waveform_html)
# ================================================================== #
def _splice_and_save(new_wav, seg_idx, meta, slot_id):
    """Swap segment *seg_idx* for *new_wav*, then re-stitch, re-save, re-mux.

    Returns (video_path, audio_path, updated_meta, waveform_html).
    """
    seg_wavs = _load_seg_wavs(meta["wav_paths"])
    seg_wavs[seg_idx] = new_wav
    cf_s = float(meta["crossfade_s"])
    cf_db = float(meta["crossfade_db"])
    rate = int(meta["sr"])
    duration = float(meta["total_dur_s"])
    muted = meta["silent_video"]
    seg_spans = meta["segments"]
    model_name = meta["model"]
    stitched = _stitch_wavs(seg_wavs, cf_s, cf_db, duration, rate, seg_spans)
    # Timestamped filenames force Gradio / the browser to treat the result as
    # a genuinely new file and reload the players instead of a stale cache.
    stamp = int(time.time() * 1000)
    out_dir = os.path.dirname(meta["audio_path"])
    audio_stem = os.path.splitext(os.path.basename(meta["audio_path"]))[0]
    # Drop any "_regen_<ts>" suffix left over from a previous regeneration
    audio_stem_clean = audio_stem.rsplit("_regen_", 1)[0]
    audio_path = os.path.join(out_dir, f"{audio_stem_clean}_regen_{stamp}.wav")
    _save_wav(audio_path, stitched, rate)
    # Re-mux into a new video file so the browser is forced to reload it
    video_stem = os.path.splitext(os.path.basename(meta["video_path"]))[0]
    video_stem_clean = video_stem.rsplit("_regen_", 1)[0]
    video_path = os.path.join(out_dir, f"{video_stem_clean}_regen_{stamp}.mp4")
    mux_video_audio(muted, audio_path, video_path, model=model_name)
    # Persist the updated per-segment wavs for any further regenerations
    new_wav_paths = _save_seg_wavs(seg_wavs, out_dir, os.path.splitext(audio_stem_clean)[0])
    refreshed = dict(meta)
    refreshed["wav_paths"] = new_wav_paths
    refreshed["audio_path"] = audio_path
    refreshed["video_path"] = video_path
    fresh_state = json.dumps(refreshed)
    waveform_html = _build_waveform_html(audio_path, seg_spans, slot_id, "",
                                         state_json=fresh_state,
                                         video_path=video_path,
                                         crossfade_s=cf_s)
    return video_path, audio_path, refreshed, waveform_html
def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
                         seed_val, cfg_scale, num_steps, mode,
                         crossfade_s, crossfade_db, slot_id=None):
    """ZeroGPU budget for a one-segment TARO regeneration.

    When both CAVP and onset feature caches exist on disk, the ~10 s
    feature-extractor overhead is skipped, so only the diffusion steps plus a
    ~5 s model load need budgeting.  Any parsing/lookup failure falls back to
    the generic estimate.
    """
    try:
        meta = json.loads(seg_meta_json)
        have_cavp = os.path.exists(meta.get("cavp_path", ""))
        have_onset = os.path.exists(meta.get("onset_path", ""))
        if have_cavp and have_onset:
            taro_cfg = MODEL_CONFIGS["taro"]
            secs = int(num_steps) * taro_cfg["secs_per_step"] + 5  # 5s model-load only
            result = min(GPU_DURATION_CAP, max(30, int(secs)))
            print(f"[duration] TARO regen (cache hit): 1 seg Γ— {int(num_steps)} steps β†’ {secs:.0f}s β†’ capped {result}s")
            return result
    except Exception:
        pass  # best-effort: fall through to the generic estimate
    return _estimate_regen_duration("taro", int(num_steps))
@spaces.GPU(duration=_taro_regen_duration)
def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
                    seed_val, cfg_scale, num_steps, mode,
                    crossfade_s, crossfade_db, slot_id=None):
    """GPU-only TARO regen β€” returns new_wav for a single segment.

    A fresh random seed is drawn on every call so repeated regenerations of
    the same segment give different results.  The returned wav is raw 16 kHz;
    the CPU wrapper (regen_taro_segment) upsamples it afterwards.
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start_s, seg_end_s = meta["segments"][seg_idx]
    torch.set_grad_enabled(False)
    device, weight_dtype = _get_device_and_dtype()
    _ensure_syspath("TARO")
    from TARO.samplers import euler_sampler, euler_maruyama_sampler
    # Load cached CAVP/onset features from .npy files (CPU I/O, fast, outside GPU budget)
    cavp_path = meta.get("cavp_path", "")
    onset_path = meta.get("onset_path", "")
    if cavp_path and os.path.exists(cavp_path) and onset_path and os.path.exists(onset_path):
        print("[TARO regen] Loading cached CAVP + onset features from disk")
        cavp_feats = np.load(cavp_path)
        onset_feats = np.load(onset_path)
    else:
        print("[TARO regen] Cache miss β€” re-extracting CAVP + onset features")
        from TARO.onset_util import extract_onset
        extract_cavp, onset_model = _load_taro_feature_extractors(device)
        silent_video = meta["silent_video"]
        tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
        cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
        onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
        # Drop the feature extractors before loading the heavier diffusion stack
        del extract_cavp, onset_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    model_net, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
    # Random seed (not the user seed) β€” regen is meant to produce a new take
    set_global_seed(random.randint(0, 2**32 - 1))
    return _taro_infer_segment(
        model_net, vae, vocoder, cavp_feats, onset_feats,
        seg_start_s, seg_end_s, device, weight_dtype,
        float(cfg_scale), int(num_steps), mode, latents_scale,
        euler_sampler, euler_maruyama_sampler,
    )
def regen_taro_segment(video_file, seg_idx, seg_meta_json,
                       seed_val, cfg_scale, num_steps, mode,
                       crossfade_s, crossfade_db, slot_id):
    """Regenerate a single TARO segment and splice it back into the slot.

    GPU inference produces the raw 16 kHz wav; upsampling, stitching and
    muxing all run on the CPU afterwards.
    """
    meta = json.loads(seg_meta_json)
    idx = int(seg_idx)
    # GPU: inference β€” CAVP/onset features are loaded from paths in seg_meta_json
    raw_wav = _regen_taro_gpu(video_file, idx, seg_meta_json,
                              seed_val, cfg_scale, num_steps, mode,
                              crossfade_s, crossfade_db, slot_id)
    # CPU: upsample 16kHz β†’ 48kHz (sinc), then splice/stitch/mux/save
    upsampled = _upsample_taro(raw_wav)
    video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
        upsampled, idx, meta, slot_id)
    return video_path, audio_path, json.dumps(updated_meta), waveform_html
def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
                            prompt, negative_prompt, seed_val,
                            cfg_strength, num_steps, crossfade_s, crossfade_db,
                            slot_id=None):
    """ZeroGPU budget for a one-segment MMAudio regeneration."""
    steps = int(num_steps)
    return _estimate_regen_duration("mmaudio", steps)
@spaces.GPU(duration=_mmaudio_regen_duration)
def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
                       prompt, negative_prompt, seed_val,
                       cfg_strength, num_steps, crossfade_s, crossfade_db,
                       slot_id=None):
    """GPU-only MMAudio regen β€” returns (new_wav, sr) for a single segment.

    A fresh random seed is drawn each call so every regen is a new take.
    sr is seq_cfg.sampling_rate and the wav covers the full model window β€”
    _stitch_wavs trims it during splicing.
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start, seg_end = meta["segments"][seg_idx]
    seg_dur = seg_end - seg_start
    _ensure_syspath("MMAudio")
    from mmaudio.eval_utils import generate, load_video
    from mmaudio.model.flow_matching import FlowMatching
    device, dtype = _get_device_and_dtype()
    net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
    sr = seq_cfg.sampling_rate
    # Extract segment clip inside the GPU function β€” ffmpeg is CPU-only and safe here.
    # This avoids any cross-process context passing that fails under ZeroGPU isolation.
    seg_path = _extract_segment_clip(
        meta["silent_video"], seg_start, seg_dur,
        os.path.join(_register_tmp_dir(tempfile.mkdtemp()), "regen_seg.mp4"),
    )
    rng = torch.Generator(device=device)
    rng.manual_seed(random.randint(0, 2**32 - 1))
    fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=int(num_steps))
    video_info = load_video(seg_path, seg_dur)
    clip_frames = video_info.clip_frames.unsqueeze(0)
    sync_frames = video_info.sync_frames.unsqueeze(0)
    # duration_sec may differ from the requested seg_dur β€” resize seq lengths
    actual_dur = video_info.duration_sec
    seq_cfg.duration = actual_dur
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
    with torch.no_grad():
        audios = generate(
            clip_frames, sync_frames, [prompt],
            negative_text=[negative_prompt] if negative_prompt else None,
            feature_utils=feature_utils, net=net, fm=fm, rng=rng,
            cfg_strength=float(cfg_strength),
        )
    new_wav = audios.float().cpu()[0].numpy()  # full window β€” _stitch_wavs trims
    return new_wav, sr
def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
                          prompt, negative_prompt, seed_val,
                          cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id):
    """Regenerate a single MMAudio segment and splice it back into the slot."""
    meta = json.loads(seg_meta_json)
    idx = int(seg_idx)
    # GPU: inference (segment clip extraction happens inside the GPU function)
    fresh_wav, rate = _regen_mmaudio_gpu(video_file, idx, seg_meta_json,
                                         prompt, negative_prompt, seed_val,
                                         cfg_strength, num_steps, crossfade_s,
                                         crossfade_db, slot_id)
    # MMAudio outputs 44.1 kHz β€” bring it up to the 48 kHz target if needed
    if rate != TARGET_SR:
        print(f"[MMAudio regen upsample] {rate}Hz β†’ {TARGET_SR}Hz (sinc, CPU) …")
        fresh_wav = _resample_to_target(fresh_wav, rate)
        rate = TARGET_SR
    meta["sr"] = rate
    # CPU: splice, stitch, mux, save
    video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
        fresh_wav, idx, meta, slot_id)
    return video_path, audio_path, json.dumps(updated_meta), waveform_html
def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
                            prompt, negative_prompt, seed_val,
                            guidance_scale, num_steps, model_size,
                            crossfade_s, crossfade_db, slot_id=None):
    """ZeroGPU budget for a one-segment HunyuanFoley regeneration."""
    steps = int(num_steps)
    return _estimate_regen_duration("hunyuan", steps)
@spaces.GPU(duration=_hunyuan_regen_duration)
def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
                       prompt, negative_prompt, seed_val,
                       guidance_scale, num_steps, model_size,
                       crossfade_s, crossfade_db, slot_id=None):
    """GPU-only HunyuanFoley regen β€” returns (new_wav, sr) for a single segment.

    Reuses cached text features from disk when available; otherwise falls
    back to full text + visual feature extraction.  The wav covers the full
    model window β€” _stitch_wavs trims it during splicing.
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start, seg_end = meta["segments"][seg_idx]
    seg_dur = seg_end - seg_start
    _ensure_syspath("HunyuanVideo-Foley")
    from hunyuanvideo_foley.utils.model_utils import denoise_process
    from hunyuanvideo_foley.utils.feature_utils import feature_process
    device, _ = _get_device_and_dtype()
    model_dict, cfg = _load_hunyuan_model(device, model_size)
    # Fresh random seed so repeated regenerations give different takes
    set_global_seed(random.randint(0, 2**32 - 1))
    # Extract segment clip inside the GPU function β€” ffmpeg is CPU-only and safe here.
    seg_path = _extract_segment_clip(
        meta["silent_video"], seg_start, seg_dur,
        os.path.join(_register_tmp_dir(tempfile.mkdtemp()), "regen_seg.mp4"),
    )
    text_feats_path = meta.get("text_feats_path", "")
    if text_feats_path and os.path.exists(text_feats_path):
        print("[HunyuanFoley regen] Loading cached text features from disk")
        from hunyuanvideo_foley.utils.feature_utils import encode_video_features
        visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
        # NOTE(review): weights_only=False deserializes arbitrary pickled
        # objects β€” safe only because this file is written by generate_hunyuan
        # itself; confirm it can never be user-supplied.
        text_feats = torch.load(text_feats_path, map_location=device, weights_only=False)
    else:
        print("[HunyuanFoley regen] Cache miss β€” extracting text + visual features")
        visual_feats, text_feats, seg_audio_len = feature_process(
            seg_path, prompt if prompt else "", model_dict, cfg,
            neg_prompt=negative_prompt if negative_prompt else None,
        )
    audio_batch, sr = denoise_process(
        visual_feats, text_feats, seg_audio_len, model_dict, cfg,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(num_steps),
        batch_size=1,
    )
    new_wav = audio_batch[0].float().cpu().numpy()  # full window β€” _stitch_wavs trims
    return new_wav, sr
def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
                          prompt, negative_prompt, seed_val,
                          guidance_scale, num_steps, model_size,
                          crossfade_s, crossfade_db, slot_id):
    """Regenerate a single HunyuanFoley segment and splice it back in."""
    meta = json.loads(seg_meta_json)
    idx = int(seg_idx)
    # GPU: inference (segment clip extraction happens inside the GPU function)
    fresh_wav, rate = _regen_hunyuan_gpu(video_file, idx, seg_meta_json,
                                         prompt, negative_prompt, seed_val,
                                         guidance_scale, num_steps, model_size,
                                         crossfade_s, crossfade_db, slot_id)
    meta["sr"] = rate
    # CPU: splice, stitch, mux, save
    video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
        fresh_wav, idx, meta, slot_id)
    return video_path, audio_path, json.dumps(updated_meta), waveform_html
# Wire up regen_fn references now that the functions are defined.
# MODEL_CONFIGS must already exist at this point; its callable slots are
# filled in here rather than at declaration time because the regen
# functions are defined after the config table.
MODEL_CONFIGS["taro"]["regen_fn"] = regen_taro_segment
MODEL_CONFIGS["mmaudio"]["regen_fn"] = regen_mmaudio_segment
MODEL_CONFIGS["hunyuan"]["regen_fn"] = regen_hunyuan_segment
# ================================================================== #
# CROSS-MODEL REGEN WRAPPERS #
# ================================================================== #
# Three shared endpoints β€” one per model β€” that can be called from #
# *any* slot tab. slot_id is passed as plain string data so the #
# result is applied back to the correct slot by the JS listener. #
# The new segment is resampled to match the slot's existing SR before #
# being handed to _splice_and_save, so TARO (16 kHz) / MMAudio #
# (44.1 kHz) / Hunyuan (48 kHz) outputs can all be mixed freely. #
# ================================================================== #
def _resample_to_slot_sr(wav: np.ndarray, src_sr: int, dst_sr: int,
                         slot_wav_ref: np.ndarray = None) -> np.ndarray:
    """Resample *wav* from src_sr to dst_sr and match the slot's channel layout.

    *slot_wav_ref* is the first existing segment in the slot.  TARO produces
    mono (T,) audio while MMAudio/Hunyuan produce stereo (C, T); mixing the
    two layouts would break _cf_join, so:
      - stereo wav into a mono slot  β†’ average the channels
      - mono wav into a stereo slot β†’ duplicate the single channel
    """
    resampled = _resample_to_target(wav, src_sr, dst_sr)
    if slot_wav_ref is None:
        return resampled
    want_stereo = slot_wav_ref.ndim == 2
    have_stereo = resampled.ndim == 2
    if want_stereo and not have_stereo:
        # mono β†’ stereo (C, T)
        resampled = np.stack([resampled, resampled], axis=0)
    elif have_stereo and not want_stereo:
        # stereo β†’ mono (T,)
        resampled = resampled.mean(axis=0)
    return resampled
def _xregen_clip_window(meta: dict, seg_idx: int, target_window_s: float) -> tuple:
"""Compute the video clip window for a cross-model regen.
Centers *target_window_s* on the original segment's midpoint, clamped to
[0, total_dur_s]. Returns (clip_start, clip_end, clip_dur).
If the video is shorter than *target_window_s*, the full video is used
(suboptimal but never breaks). If the segment span exceeds
*target_window_s*, the caller should run _build_segments on the span and
generate multiple sub-segments β€” but the clip window is still returned as
the full segment span so the caller can decide.
"""
total_dur_s = float(meta["total_dur_s"])
seg_start, seg_end = meta["segments"][seg_idx]
seg_mid = (seg_start + seg_end) / 2.0
half_win = target_window_s / 2.0
clip_start = max(0.0, seg_mid - half_win)
clip_end = min(total_dur_s, seg_mid + half_win)
# If clamped at one end, extend the other to preserve full window if possible
if clip_start == 0.0:
clip_end = min(total_dur_s, target_window_s)
elif clip_end == total_dur_s:
clip_start = max(0.0, total_dur_s - target_window_s)
clip_dur = clip_end - clip_start
return clip_start, clip_end, clip_dur
def _xregen_splice(new_wav_raw: np.ndarray, src_sr: int,
                   meta: dict, seg_idx: int, slot_id: str,
                   clip_start_s: float = None) -> tuple:
    """Shared epilogue for all xregen_* functions: resample -> splice -> save.

    *clip_start_s* is the absolute video time at which *new_wav_raw* begins.
    _stitch_wavs trims using the original segment's start as its time origin,
    so when the regen clip began AFTER seg_start (clip centered on the
    segment midpoint), silence is prepended so that sample index 0 of the
    wav lines up with seg_start in video time.

    Returns:
        (video_path, waveform_html)
    """
    slot_sr = int(meta["sr"])
    existing = _load_seg_wavs(meta["wav_paths"])
    wav = _resample_to_slot_sr(new_wav_raw, src_sr, slot_sr, existing[0])
    if clip_start_s is not None:
        seg_start = meta["segments"][seg_idx][0]
        lead_s = clip_start_s - seg_start  # > 0 when the clip starts after seg_start
        if lead_s > 0:
            n_pad = int(round(lead_s * slot_sr))
            if wav.ndim == 2:
                pad = np.zeros((wav.shape[0], n_pad), dtype=wav.dtype)
                wav = np.concatenate([pad, wav], axis=1)
            else:
                pad = np.zeros(n_pad, dtype=wav.dtype)
                wav = np.concatenate([pad, wav], axis=0)
    video_path, _audio_path, _updated_meta, waveform_html = _splice_and_save(
        wav, seg_idx, meta, slot_id
    )
    return video_path, waveform_html
def _xregen_dispatch(state_json: str, seg_idx: int, slot_id: str, infer_fn):
    """Generator skeleton shared by every xregen_* wrapper.

    Immediately yields a "pending" placeholder, then calls *infer_fn* - a
    zero-argument callable that runs the model-specific CPU prep + GPU
    inference and returns (wav_array, src_sr, clip_start_s) - and finally
    yields the spliced result. For TARO, *infer_fn* must return audio
    already upsampled to 48 kHz with TARO_SR_OUT as src_sr.

    Yields:
        1. (gr.update(), gr.update(value=pending_html)) while the GPU runs
        2. (gr.update(value=video_path), gr.update(value=waveform_html))
    """
    meta = json.loads(state_json)
    placeholder = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
    yield gr.update(), gr.update(value=placeholder)
    wav, sr, clip_start = infer_fn()
    video_path, waveform_html = _xregen_splice(wav, sr, meta, seg_idx,
                                               slot_id, clip_start)
    yield gr.update(value=video_path), gr.update(value=waveform_html)
def xregen_taro(seg_idx, state_json, slot_id,
                seed_val, cfg_scale, num_steps, mode,
                crossfade_s, crossfade_db,
                request: gr.Request = None):
    """Cross-model regen: run TARO on its optimal window, splice into *slot_id*.

    Args:
        seg_idx: index of the segment to regenerate (coerced to int).
        state_json: slot metadata JSON (segments, silent_video, sr, ...).
        slot_id: output slot receiving the spliced result, e.g. "mma_0".
        seed_val, cfg_scale, num_steps, mode: TARO sampling parameters,
            forwarded verbatim to _taro_gpu_infer.
        crossfade_s, crossfade_db: crossfade length/depth for stitching.
        request: Gradio request object (unused; kept for endpoint signature).

    Yields the two-step UI updates produced by _xregen_dispatch.
    """
    seg_idx = int(seg_idx)
    meta = json.loads(state_json)

    def _run():
        # Center TARO's native window on the segment midpoint.
        clip_start, _clip_end, clip_dur = _xregen_clip_window(meta, seg_idx, TARO_MODEL_DUR)
        tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
        clip_path = _extract_segment_clip(
            meta["silent_video"], clip_start, clip_dur,
            os.path.join(tmp_dir, "xregen_taro_clip.mp4"),
        )
        # Sub-segment the clip the same way the main generation path does.
        sub_segs = _build_segments(clip_dur, TARO_MODEL_DUR, float(crossfade_s))
        # Stash side-channel context for the GPU worker (mirrors the main path).
        # NOTE: a previous revision also built an unused `sub_meta_json`
        # string here; that dead code has been removed.
        _ctx_store("taro_gpu_infer", {
            "tmp_dir": tmp_dir, "silent_video": clip_path,
            "segments": sub_segs, "total_dur_s": clip_dur,
        })
        results = _taro_gpu_infer(clip_path, seed_val, cfg_scale, num_steps, mode,
                                  crossfade_s, crossfade_db, 1)
        wavs, _, _ = results[0]
        # Upsample each sub-segment to 48 kHz before stitching, as required
        # by _xregen_dispatch's contract for TARO.
        wavs = [_upsample_taro(w) for w in wavs]
        wav = _stitch_wavs(wavs, float(crossfade_s), float(crossfade_db),
                           clip_dur, TARO_SR_OUT, sub_segs)
        return wav, TARO_SR_OUT, clip_start

    yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
def xregen_mmaudio(seg_idx, state_json, slot_id,
                   prompt, negative_prompt, seed_val,
                   cfg_strength, num_steps, crossfade_s, crossfade_db,
                   request: gr.Request = None):
    """Cross-model regen: run MMAudio on its optimal window, splice into *slot_id*."""
    seg_idx = int(seg_idx)
    meta = json.loads(state_json)

    def _infer():
        clip_start, _, clip_dur = _xregen_clip_window(meta, seg_idx, MMAUDIO_WINDOW)
        work_dir = _register_tmp_dir(tempfile.mkdtemp())
        clip_path = _extract_segment_clip(
            meta["silent_video"], clip_start, clip_dur,
            os.path.join(work_dir, "xregen_mmaudio_clip.mp4"),
        )
        cf_len = float(crossfade_s)
        sub_segs = _build_segments(clip_dur, MMAUDIO_WINDOW, cf_len)
        # One physical sub-clip per stitching window.
        sub_clips = []
        for idx, (s, e) in enumerate(sub_segs):
            sub_clips.append(_extract_segment_clip(
                clip_path, s, e - s,
                os.path.join(work_dir, f"xregen_mma_sub_{idx}.mp4"),
            ))
        _ctx_store("mmaudio_gpu_infer", {
            "segments": sub_segs, "seg_clip_paths": sub_clips,
        })
        seg_wavs, sr = _mmaudio_gpu_infer(clip_path, prompt, negative_prompt,
                                          seed_val, cfg_strength, num_steps,
                                          crossfade_s, crossfade_db, 1)[0]
        wav = _stitch_wavs(seg_wavs, cf_len, float(crossfade_db),
                           clip_dur, sr, sub_segs)
        # Normalise to the app-wide target rate before splicing.
        if sr != TARGET_SR:
            wav = _resample_to_target(wav, sr)
            sr = TARGET_SR
        return wav, sr, clip_start

    yield from _xregen_dispatch(state_json, seg_idx, slot_id, _infer)
def xregen_hunyuan(seg_idx, state_json, slot_id,
                   prompt, negative_prompt, seed_val,
                   guidance_scale, num_steps, model_size,
                   crossfade_s, crossfade_db,
                   request: gr.Request = None):
    """Cross-model regen: run HunyuanFoley on its optimal window, splice into *slot_id*."""
    seg_idx = int(seg_idx)
    meta = json.loads(state_json)

    def _infer():
        clip_start, _, clip_dur = _xregen_clip_window(meta, seg_idx, HUNYUAN_MAX_DUR)
        work_dir = _register_tmp_dir(tempfile.mkdtemp())
        clip_path = _extract_segment_clip(
            meta["silent_video"], clip_start, clip_dur,
            os.path.join(work_dir, "xregen_hunyuan_clip.mp4"),
        )
        cf_len = float(crossfade_s)
        sub_segs = _build_segments(clip_dur, HUNYUAN_MAX_DUR, cf_len)
        sub_clips = []
        for idx, (s, e) in enumerate(sub_segs):
            sub_clips.append(_extract_segment_clip(
                clip_path, s, e - s,
                os.path.join(work_dir, f"xregen_hny_sub_{idx}.mp4"),
            ))
        # Dummy clip capped at the model's max duration (mirrors the main path).
        dummy_path = _extract_segment_clip(
            clip_path, 0, min(clip_dur, HUNYUAN_MAX_DUR),
            os.path.join(work_dir, "xregen_hny_dummy.mp4"),
        )
        _ctx_store("hunyuan_gpu_infer", {
            "segments": sub_segs, "total_dur_s": clip_dur,
            "dummy_seg_path": dummy_path, "seg_clip_paths": sub_clips,
        })
        seg_wavs, sr, _ = _hunyuan_gpu_infer(clip_path, prompt, negative_prompt,
                                             seed_val, guidance_scale, num_steps,
                                             model_size, crossfade_s, crossfade_db,
                                             1)[0]
        wav = _stitch_wavs(seg_wavs, cf_len, float(crossfade_db),
                           clip_dur, sr, sub_segs)
        return wav, sr, clip_start

    yield from _xregen_dispatch(state_json, seg_idx, slot_id, _infer)
# ================================================================== #
# SHARED UI HELPERS #
# ================================================================== #
def _register_regen_handlers(tab_prefix, model_key, regen_seg_tb, regen_state_tb,
                             input_components, slot_vids, slot_waves):
    """Register per-slot regen button handlers for a model tab.

    This replaces the three nearly-identical for-loops that previously existed
    for TARO, MMAudio, and HunyuanFoley tabs.

    Args:
        tab_prefix: e.g. "taro", "mma", "hf"
        model_key: e.g. "taro", "mmaudio", "hunyuan"
        regen_seg_tb: gr.Textbox for seg_idx (render=False)
        regen_state_tb: gr.Textbox for state_json (render=False)
        input_components: list of Gradio input components (video, seed, etc.)
            -- order must match regen_fn signature after (seg_idx, state_json, video)
        slot_vids: list of gr.Video components per slot
        slot_waves: list of gr.HTML components per slot

    Returns:
        list of hidden gr.Buttons (one per slot)
    """
    cfg = MODEL_CONFIGS[model_key]
    regen_fn = cfg["regen_fn"]
    label = cfg["label"]
    btns = []
    for _i in range(MAX_SLOTS):
        _slot_id = f"{tab_prefix}_{_i}"
        _btn = gr.Button(render=False, elem_id=f"regen_btn_{_slot_id}")
        btns.append(_btn)
        print(f"[startup] registering regen handler for slot {_slot_id}")

        # Factory binds the per-slot values at definition time so the handlers
        # don't all close over the final loop iteration (late-binding-closure
        # pitfall). The previously unused factory params (_si, _model_key)
        # have been removed.
        def _make_regen(_sid, _label, _regen_fn):
            def _do(seg_idx, state_json, *args):
                print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} "
                      f"state_json_len={len(state_json) if state_json else 0}")
                if not state_json:
                    print(f"[regen {_label}] early-exit: state_json empty")
                    yield gr.update(), gr.update()
                    return
                lock = _get_slot_lock(_sid)
                # NOTE: the lock is held across the yields below, so concurrent
                # regen requests for the same slot serialize; it is released
                # when this generator completes (or is closed).
                with lock:
                    state = json.loads(state_json)
                    pending_html = _build_regen_pending_html(
                        state["segments"], int(seg_idx), _sid, ""
                    )
                    yield gr.update(), gr.update(value=pending_html)
                    print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} β€” calling regen")
                    try:
                        # args[0] = video, args[1:] = model-specific params
                        vid, aud, new_meta_json, html = _regen_fn(
                            args[0], int(seg_idx), state_json, *args[1:], _sid,
                        )
                        print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} β€” done, vid={vid!r}")
                    except Exception as _e:
                        print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} β€” ERROR: {_e}")
                        raise
                    yield gr.update(value=vid), gr.update(value=html)
            return _do

        _btn.click(
            fn=_make_regen(_slot_id, label, regen_fn),
            inputs=[regen_seg_tb, regen_state_tb] + input_components,
            outputs=[slot_vids[_i], slot_waves[_i]],
            api_name=f"regen_{tab_prefix}_{_i}",
        )
    return btns
def _pad_outputs(outputs: list) -> list:
    """Flatten (video, audio, seg_meta) triples and pad to MAX_SLOTS * 3 with None.

    Each entry in *outputs* must be a (video_path, audio_path, seg_meta) tuple
    where seg_meta = {"segments": [...], "audio_path": str, "video_path": str,
    "sr": int, "model": str, "crossfade_s": float, "crossfade_db": float,
    "wav_paths": list[str]}
    """
    flat = []
    for triple in outputs[:MAX_SLOTS]:
        flat.extend(triple)  # 3 items: video, audio, meta
    # Fill any remaining slots with (None, None, None) placeholders.
    empty_slots = MAX_SLOTS - min(len(outputs), MAX_SLOTS)
    flat.extend([None, None, None] * empty_slots)
    return flat
# ------------------------------------------------------------------ #
# WaveSurfer waveform + segment marker HTML builder #
# ------------------------------------------------------------------ #
def _build_regen_pending_html(segments: list, regen_seg_idx: int, slot_id: str,
                              hidden_input_id: str) -> str:
    """Return a waveform placeholder shown while a segment is being regenerated.

    Renders a dark bar with the active segment highlighted in amber + a spinner.

    Args:
        segments: list of [start_s, end_s] pairs for the whole slot.
        regen_seg_idx: index of the segment being regenerated (highlighted).
        slot_id: slot identifier (currently unused here; kept for call-site parity).
        hidden_input_id: unused legacy parameter kept for call-site parity.
    """
    # Dead `segs_json = json.dumps(segments)` removed: it was never referenced.
    seg_colors = [c.format(a="0.25") for c in SEG_COLORS]
    active_color = "rgba(255,180,0,0.55)"
    # Guard against empty segment lists so the percentage math never divides by
    # an undefined duration.
    duration = segments[-1][1] if segments else 1.0
    seg_divs = ""
    for i, seg in enumerate(segments):
        # Draw only the non-overlapping (unique) portion of each segment so that
        # overlapping windows don't visually bleed into adjacent segments.
        # Each segment owns the region from its own start up to the next segment's
        # start (or its own end for the final segment).
        seg_start = seg[0]
        seg_end = segments[i + 1][0] if i + 1 < len(segments) else seg[1]
        left_pct = seg_start / duration * 100
        width_pct = (seg_end - seg_start) / duration * 100
        color = active_color if i == regen_seg_idx else seg_colors[i % len(seg_colors)]
        extra = "border:2px solid #ffb300;animation:wf_pulse 0.8s ease-in-out infinite alternate;" if i == regen_seg_idx else ""
        seg_divs += (
            f'<div style="position:absolute;top:0;left:{left_pct:.2f}%;'
            f'width:{width_pct:.2f}%;height:100%;background:{color};{extra}">'
            f'<span style="color:rgba(255,255,255,0.7);font-size:10px;padding:2px 3px;">Seg {i+1}</span>'
            f'</div>'
        )
    spinner = (
        '<div style="position:absolute;top:50%;left:50%;transform:translate(-50%,-50%);'
        'display:flex;align-items:center;gap:6px;">'
        '<div style="width:14px;height:14px;border:2px solid #ffb300;'
        'border-top-color:transparent;border-radius:50%;'
        'animation:wf_spin 0.7s linear infinite;"></div>'
        f'<span style="color:#ffb300;font-size:12px;white-space:nowrap;">'
        f'Regenerating Seg {regen_seg_idx+1}…</span>'
        '</div>'
    )
    return f"""
<style>
@keyframes wf_pulse {{from{{opacity:0.5}}to{{opacity:1}}}}
@keyframes wf_spin {{to{{transform:rotate(360deg)}}}}
</style>
<div style="background:#1a1a1a;border-radius:8px;padding:10px;margin-top:6px;">
<div style="position:relative;width:100%;height:80px;background:#1e1e2e;border-radius:4px;overflow:hidden;">
{seg_divs}
{spinner}
</div>
<div style="color:#888;font-size:11px;margin-top:6px;">Regenerating β€” please wait…</div>
</div>
"""
def _build_waveform_html(audio_path: str, segments: list, slot_id: str,
                         hidden_input_id: str, state_json: str = "",
                         fn_index: int = -1, video_path: str = "",
                         crossfade_s: float = 0.0) -> str:
    """Return a self-contained HTML block with a Canvas waveform (display only),
    segment boundary markers, and a download link.
    Uses Web Audio API + Canvas β€” no external libraries.
    The waveform is SILENT. The playhead tracks the Gradio <video> element
    in the same slot via its timeupdate event.

    Args:
        audio_path: local WAV path; served to the browser via Gradio's file API.
            When empty or missing, a plain "No audio yet." paragraph is returned.
        segments: list of [start_s, end_s] windows; drawn as colored bands with
            boundary lines and crossfade hatch overlays.
        slot_id: unique slot identifier; embedded in every element id so the
            page JS and the iframe can find each other.
        hidden_input_id: unused legacy parameter kept for call-site compatibility.
        state_json: slot metadata JSON, HTML-escaped into the container's
            data-state attribute so page JS can pass it back on regen calls.
        fn_index: stored in the data-fn-index attribute (presumably legacy;
            the regen JS appears to resolve fn_index from api_name — TODO confirm).
        video_path: when non-empty, an extra video download link is rendered.
        crossfade_s: crossfade length (seconds) injected into the iframe JS to
            place color boundaries and hatch overlays.
    """
    if not audio_path or not os.path.exists(audio_path):
        return "<p style='color:#888;font-size:12px'>No audio yet.</p>"
    # Serve audio via Gradio's file API instead of base64-encoding the entire
    # WAV inline. For a 25s stereo 44.1kHz track this saves ~5 MB per slot.
    audio_url = f"/gradio_api/file={audio_path}"
    segs_json = json.dumps(segments)
    seg_colors = [c.format(a="0.35") for c in SEG_COLORS]
    # NOTE: Gradio updates gr.HTML via innerHTML which does NOT execute <script> tags.
    # Solution: put the entire waveform (canvas + JS) inside an <iframe srcdoc="...">.
    # iframes always execute their scripts. The iframe posts messages to the parent for
    # segment-click events; the parent listens and fires the Gradio regen trigger.
    # For playhead sync, the iframe polls window.parent for a <video> element.
    iframe_inner = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
* {{ margin:0; padding:0; box-sizing:border-box; }}
body {{ background:#1a1a1a; overflow:hidden; }}
#wrap {{ position:relative; width:100%; height:80px; }}
canvas {{ display:block; }}
#cv {{ position:absolute; top:0; left:0; width:100%; height:100%; }}
#cvp {{ position:absolute; top:0; left:0; width:100%; height:100%; pointer-events:none; }}
</style>
</head>
<body>
<div id="wrap">
<canvas id="cv"></canvas>
<canvas id="cvp"></canvas>
</div>
<script>
(function() {{
const SLOT_ID = '{slot_id}';
const segments = {segs_json};
const segColors = {json.dumps(seg_colors)};
const crossfadeSec = {crossfade_s};
let audioDuration = 0;
// ── Popup via postMessage to parent global listener ─────────────────
// The parent page (Gradio) has a global window.addEventListener('message',...)
// set up via gr.Blocks(js=...) that handles popup show/hide and regen trigger.
function showPopup(idx, mx, my) {{
console.log('[wf showPopup] slot='+SLOT_ID+' idx='+idx+' posting to parent');
// Convert iframe-local coords to parent page coords
try {{
const fr = window.frameElement ? window.frameElement.getBoundingClientRect() : {{left:0,top:0}};
window.parent.postMessage({{
type:'wf_popup', action:'show',
slot_id: SLOT_ID, seg_idx: idx,
t0: segments[idx][0], t1: segments[idx][1],
x: mx + fr.left, y: my + fr.top
}}, '*');
console.log('[wf showPopup] postMessage sent OK');
}} catch(e) {{
console.log('[wf showPopup] postMessage fallback, err='+e.message);
window.parent.postMessage({{
type:'wf_popup', action:'show',
slot_id: SLOT_ID, seg_idx: idx,
t0: segments[idx][0], t1: segments[idx][1],
x: mx, y: my
}}, '*');
}}
}}
function hidePopup() {{
window.parent.postMessage({{type:'wf_popup', action:'hide'}}, '*');
}}
// ── Canvas waveform ────────────────────────────────────────────────
const cv = document.getElementById('cv');
const cvp = document.getElementById('cvp');
const wrap= document.getElementById('wrap');
function drawWaveform(channelData, duration) {{
audioDuration = duration;
const dpr = window.devicePixelRatio || 1;
const W = wrap.getBoundingClientRect().width || window.innerWidth || 600;
const H = 80;
cv.width = W * dpr; cv.height = H * dpr;
const ctx = cv.getContext('2d');
ctx.scale(dpr, dpr);
ctx.fillStyle = '#1e1e2e';
ctx.fillRect(0, 0, W, H);
segments.forEach(function(seg, idx) {{
// Color boundary = midpoint of the crossfade zone = where the blend is
// 50/50. This is also where the cut would land if crossfade were 0, and
// where the listener perceptually hears the transition to the next segment.
const x1 = (seg[0] / duration) * W;
const xEnd = idx + 1 < segments.length
? ((segments[idx + 1][0] + crossfadeSec / 2) / duration) * W
: (seg[1] / duration) * W;
ctx.fillStyle = segColors[idx % segColors.length];
ctx.fillRect(x1, 0, xEnd - x1, H);
ctx.fillStyle = 'rgba(255,255,255,0.6)';
ctx.font = '10px sans-serif';
ctx.fillText('Seg '+(idx+1), x1+3, 12);
}});
const samples = channelData.length;
const barW=2, gap=1, step=barW+gap;
const numBars = Math.floor(W / step);
const blockSz = Math.floor(samples / numBars);
ctx.fillStyle = '#4a9eff';
for (let i=0; i<numBars; i++) {{
let max=0;
const s=i*blockSz;
for (let j=0; j<blockSz; j++) {{
const v=Math.abs(channelData[s+j]||0);
if (v>max) max=v;
}}
const barH=Math.max(1, max*H);
ctx.fillRect(i*step, (H-barH)/2, barW, barH);
}}
segments.forEach(function(seg) {{
[seg[0],seg[1]].forEach(function(t) {{
const x=(t/duration)*W;
ctx.strokeStyle='rgba(255,255,255,0.4)';
ctx.lineWidth=1;
ctx.beginPath(); ctx.moveTo(x,0); ctx.lineTo(x,H); ctx.stroke();
}});
}});
// ── Crossfade overlap indicators ──
// The color boundary is at segments[i+1][0] (= seg_i.end - crossfadeSec).
// We centre the hatch on that edge: half the crossfade on each color side.
if (crossfadeSec > 0 && segments.length > 1) {{
for (let i = 0; i < segments.length - 1; i++) {{
// Color edge = segments[i+1][0], hatch spans half on each side
const edgeT = segments[i+1][0];
const overlapStart = edgeT - crossfadeSec / 2;
const overlapEnd = edgeT + crossfadeSec / 2;
const xL = (overlapStart / duration) * W;
const xR = (overlapEnd / duration) * W;
// Diagonal hatch pattern over the overlap zone
ctx.save();
ctx.beginPath();
ctx.rect(xL, 0, xR - xL, H);
ctx.clip();
ctx.strokeStyle = 'rgba(255,255,255,0.35)';
ctx.lineWidth = 1;
const spacing = 6;
for (let lx = xL - H; lx < xR + H; lx += spacing) {{
ctx.beginPath();
ctx.moveTo(lx, H);
ctx.lineTo(lx + H, 0);
ctx.stroke();
}}
ctx.restore();
}}
}}
cv.onclick = function(e) {{
const r=cv.getBoundingClientRect();
const xRel=(e.clientX-r.left)/r.width;
const tClick=xRel*duration;
// Pick the segment whose unique (non-overlapping) region contains the click.
// Each segment owns [seg[0], nextSeg[0]) visually; last segment owns [seg[0], seg[1]].
let hit=-1;
segments.forEach(function(seg,idx){{
const uniqueEnd = idx + 1 < segments.length ? segments[idx+1][0] : seg[1];
if (tClick >= seg[0] && tClick < uniqueEnd) hit = idx;
}});
console.log('[wf click] tClick='+tClick.toFixed(2)+' hit='+hit+' audioDuration='+audioDuration+' segments='+JSON.stringify(segments));
if (hit>=0) showPopup(hit, e.clientX, e.clientY);
else hidePopup();
}};
}}
function drawPlayhead(progress) {{
const dpr = window.devicePixelRatio || 1;
const W = wrap.getBoundingClientRect().width || window.innerWidth || 600;
const H = 80;
if (cvp.width !== W*dpr) {{ cvp.width=W*dpr; cvp.height=H*dpr; }}
const ctx = cvp.getContext('2d');
ctx.clearRect(0,0,W*dpr,H*dpr);
ctx.save();
ctx.scale(dpr,dpr);
const x=progress*W;
ctx.strokeStyle='#fff';
ctx.lineWidth=2;
ctx.beginPath(); ctx.moveTo(x,0); ctx.lineTo(x,H); ctx.stroke();
ctx.restore();
}}
// Poll parent for video time β€” find the video in the same wf_container slot
function findSlotVideo() {{
try {{
const par = window.parent.document;
// Walk up from our iframe to find wf_container_{slot_id}, then find its sibling video
const container = par.getElementById('wf_container_{slot_id}');
if (!container) return par.querySelector('video');
// The video is inside the same gr.Group β€” walk up to find it
let node = container.parentElement;
while (node && node !== par.body) {{
const v = node.querySelector('video');
if (v) return v;
node = node.parentElement;
}}
return null;
}} catch(e) {{ return null; }}
}}
setInterval(function() {{
const vid = findSlotVideo();
if (vid && vid.duration && isFinite(vid.duration) && audioDuration > 0) {{
drawPlayhead(vid.currentTime / vid.duration);
}}
}}, 50);
// ── Fetch + decode audio from Gradio file API ──────────────────────
const audioUrl = '{audio_url}';
fetch(audioUrl)
.then(function(r) {{ return r.arrayBuffer(); }})
.then(function(arrayBuf) {{
const AudioCtx = window.AudioContext || window.webkitAudioContext;
if (!AudioCtx) return;
const tmpCtx = new AudioCtx({{sampleRate:44100}});
tmpCtx.decodeAudioData(arrayBuf,
function(ab) {{
try {{ tmpCtx.close(); }} catch(e) {{}}
function tryDraw() {{
const W = wrap.getBoundingClientRect().width || window.innerWidth;
if (W > 0) {{ drawWaveform(ab.getChannelData(0), ab.duration); }}
else {{ setTimeout(tryDraw, 100); }}
}}
tryDraw();
}},
function(err) {{}}
);
}})
.catch(function(e) {{}});
}})();
</script>
</body>
</html>"""
    # Escape for HTML attribute (srcdoc uses HTML entities)
    srcdoc = _html.escape(iframe_inner, quote=True)
    state_escaped = _html.escape(state_json or "", quote=True)
    # Outer container: iframe + status bar + download links + segment label.
    return f"""
<div id="wf_container_{slot_id}"
data-fn-index="{fn_index}"
data-state="{state_escaped}"
style="background:#1a1a1a;border-radius:8px;padding:10px;margin-top:6px;position:relative;">
<div style="position:relative;width:100%;height:80px;">
<iframe id="wf_iframe_{slot_id}"
srcdoc="{srcdoc}"
sandbox="allow-scripts allow-same-origin"
style="width:100%;height:80px;border:none;border-radius:4px;display:block;"
scrolling="no"></iframe>
</div>
<div style="display:flex;align-items:center;gap:8px;margin-top:6px;">
<span id="wf_statusbar_{slot_id}" style="color:#888;font-size:11px;">Click a segment to regenerate &nbsp;|&nbsp; Playhead syncs to video</span>
<a href="{audio_url}" download="audio_{slot_id}.wav"
style="margin-left:auto;background:#333;color:#eee;border:1px solid #555;
border-radius:4px;padding:3px 10px;font-size:12px;text-decoration:none;">
&#8595; Audio</a>{f'''
<a href="/gradio_api/file={video_path}" download="video_{slot_id}.mp4"
style="background:#333;color:#eee;border:1px solid #555;
border-radius:4px;padding:3px 10px;font-size:12px;text-decoration:none;">
&#8595; Video</a>''' if video_path else ''}
</div>
<div id="wf_seglabel_{slot_id}"
style="color:#aaa;font-size:11px;margin-top:4px;min-height:16px;"></div>
</div>
"""
def _make_output_slots(tab_prefix: str) -> tuple:
    """Build MAX_SLOTS output groups for one tab.

    Each slot holds a video player plus a waveform HTML panel. Regen is
    triggered by JS calling the Gradio queue API directly (no hidden trigger
    textboxes: DOM event dispatch is unreliable in Gradio 5 Svelte
    components); the state JSON is embedded in the waveform HTML's
    data-state attribute and posted straight in the queue API payload.

    Returns (grps, vids, waveforms).
    """
    grps, vids, waveforms = [], [], []
    for idx in range(MAX_SLOTS):
        sid = f"{tab_prefix}_{idx}"
        # Only the first slot is visible initially; the rest are revealed
        # by _update_slot_visibility.
        with gr.Group(visible=(idx == 0)) as grp:
            video = gr.Video(label=f"Generation {idx+1} β€” Video",
                             elem_id=f"slot_vid_{sid}",
                             show_download_button=False)
            wave = gr.HTML(
                value="<p style='color:#888;font-size:12px'>Generate audio to see waveform.</p>",
                elem_id=f"slot_wave_{sid}",
            )
        vids.append(video)
        waveforms.append(wave)
        grps.append(grp)
    return grps, vids, waveforms
def _unpack_outputs(flat: list, n: int, tab_prefix: str) -> list:
    """Turn a flat _pad_outputs list into Gradio update lists.

    *flat* has MAX_SLOTS * 3 items: [vid0, aud0, meta0, vid1, aud1, meta1, ...].
    Returns updates for vids + waveforms only (NOT grps): group visibility is
    handled separately via .then() to avoid Gradio 5 SSR 'Too many arguments'
    caused by mixing gr.Group updates with other outputs, so *n* is coerced
    but otherwise unused here. State JSON is embedded in the waveform HTML
    data-state attribute so JS can read it when calling the Gradio queue API
    for regen.
    """
    n = int(n)
    placeholder = "<p style='color:#888;font-size:12px'>Generate audio to see waveform.</p>"
    vid_updates, wave_updates = [], []
    for i in range(MAX_SLOTS):
        vid_path, aud_path, meta = flat[i * 3: i * 3 + 3]
        vid_updates.append(gr.update(value=vid_path))
        if not (aud_path and meta):
            wave_updates.append(gr.update(value=placeholder))
            continue
        sid = f"{tab_prefix}_{i}"
        html = _build_waveform_html(
            aud_path, meta["segments"], sid, "",
            state_json=json.dumps(meta),
            video_path=meta.get("video_path", ""),
            crossfade_s=float(meta.get("crossfade_s", 0)),
        )
        wave_updates.append(gr.update(value=html))
    return vid_updates + wave_updates
def _on_video_upload_taro(video_file, num_steps, crossfade_s):
    """Adjust the TARO sample-count slider when a video is (un)loaded.

    With no video the slider is reset to its full range; otherwise the
    maximum is capped by _taro_calc_max_samples for the video's duration.
    Any probing failure falls back to the unrestricted maximum.
    """
    if video_file is None:
        return gr.update(maximum=MAX_SLOTS, value=1)
    try:
        duration = get_video_duration(video_file)
        limit = _taro_calc_max_samples(duration, int(num_steps), float(crossfade_s))
    except Exception:
        limit = MAX_SLOTS
    return gr.update(maximum=limit, value=min(1, limit))
def _update_slot_visibility(n):
    """Show the first *n* slot groups and hide the remaining ones."""
    count = int(n)
    updates = []
    for idx in range(MAX_SLOTS):
        updates.append(gr.update(visible=idx < count))
    return updates
# ================================================================== #
# GRADIO UI #
# ================================================================== #
# Injected via gr.Blocks(css=...) — presumably; confirm at the Blocks call site.
# Covers: responsive output videos, equal-width two-column layout, and hiding
# the built-in download button on output video slots.
_SLOT_CSS = """
/* Responsive video: fills column width, height auto from aspect ratio */
.gradio-video video {
width: 100%;
height: auto;
max-height: 60vh;
object-fit: contain;
}
/* Force two-column layout to stay equal-width */
.gradio-container .gradio-row > .gradio-column {
flex: 1 1 0 !important;
min-width: 0 !important;
max-width: 50% !important;
}
/* Hide the built-in download button on output video slots β€” downloads are
handled by the waveform panel links which always reflect the latest regen. */
[id^="slot_vid_"] .download-icon,
[id^="slot_vid_"] button[aria-label="Download"],
[id^="slot_vid_"] a[download] {
display: none !important;
}
"""
_GLOBAL_JS = """
() => {
// Global postMessage handler for waveform iframe events.
// Runs once on page load (Gradio js= parameter).
// Handles: popup open/close relay, regen trigger via Gradio queue API.
if (window._wf_global_listener) return; // already registered
window._wf_global_listener = true;
// ── ZeroGPU quota attribution ──
// HF Spaces run inside an iframe on huggingface.co. Gradio's own JS client
// gets ZeroGPU auth headers (x-zerogpu-token, x-zerogpu-uuid) by sending a
// postMessage("zerogpu-headers") to the parent frame. The parent responds
// with a Map of headers that must be included on queue/join calls.
// We replicate this exact mechanism so our raw regen fetch() calls are
// attributed to the logged-in user's Pro quota.
function _fetchZerogpuHeaders() {
return new Promise(function(resolve) {
// Check if we're in an HF iframe with zerogpu support
if (typeof window === 'undefined' || window.parent === window || !window.supports_zerogpu_headers) {
console.log('[zerogpu] not in HF iframe or no zerogpu support');
resolve({});
return;
}
// Determine origin β€” same logic as Gradio's client
var hostname = window.location.hostname;
var hfhubdev = 'dev.spaces.huggingface.tech';
var origin = hostname.includes('.dev.')
? 'https://moon-' + hostname.split('.')[1] + '.' + hfhubdev
: 'https://huggingface.co';
// Use MessageChannel just like Gradio's post_message helper
var channel = new MessageChannel();
var done = false;
channel.port1.onmessage = function(ev) {
channel.port1.close();
done = true;
var headers = ev.data;
if (headers && typeof headers === 'object') {
// Convert Map to plain object if needed
var obj = {};
if (typeof headers.forEach === 'function') {
headers.forEach(function(v, k) { obj[k] = v; });
} else {
obj = headers;
}
console.log('[zerogpu] got headers from parent:', Object.keys(obj).join(', '));
resolve(obj);
} else {
resolve({});
}
};
window.parent.postMessage('zerogpu-headers', origin, [channel.port2]);
// Timeout: don't block regen if parent doesn't respond
setTimeout(function() { if (!done) { done = true; channel.port1.close(); resolve({}); } }, 3000);
});
}
// Cache: api_name -> fn_index, built once from gradio_config.dependencies
let _fnIndexCache = null;
function getFnIndex(apiName) {
if (!_fnIndexCache) {
_fnIndexCache = {};
const deps = window.gradio_config && window.gradio_config.dependencies;
if (deps) deps.forEach(function(d, i) {
if (d.api_name) _fnIndexCache[d.api_name] = i;
});
}
return _fnIndexCache[apiName];
}
// Read a component's current DOM value by elem_id.
// For Number/Slider: reads the <input type="number"> or <input type="range">.
// For Textbox/Radio: reads the <textarea> or checked <input type="radio">.
// Returns null if not found.
function readComponentValue(elemId) {
const el = document.getElementById(elemId);
if (!el) return null;
const numInput = el.querySelector('input[type="number"]');
if (numInput) return parseFloat(numInput.value);
const rangeInput = el.querySelector('input[type="range"]');
if (rangeInput) return parseFloat(rangeInput.value);
const radio = el.querySelector('input[type="radio"]:checked');
if (radio) return radio.value;
const ta = el.querySelector('textarea');
if (ta) return ta.value;
const txt = el.querySelector('input[type="text"], input:not([type])');
if (txt) return txt.value;
return null;
}
// Fire regen for a given slot and segment by posting directly to the
// Gradio queue API β€” bypasses Svelte binding entirely.
// targetModel: 'taro' | 'mma' | 'hf' (which model to use for inference)
// If targetModel matches the slot's own prefix, uses the per-slot regen_* endpoint.
// Otherwise uses the shared xregen_* cross-model endpoint.
function fireRegen(slot_id, seg_idx, targetModel) {
// Block if a regen is already in-flight for this slot
if (_regenInFlight[slot_id]) {
console.log('[fireRegen] blocked β€” regen already in-flight for', slot_id);
return;
}
_regenInFlight[slot_id] = true;
const prefix = slot_id.split('_')[0]; // owning tab: 'taro'|'mma'|'hf'
const slotNum = parseInt(slot_id.split('_')[1], 10);
// Decide which endpoint to call
const crossModel = (targetModel !== prefix);
let apiName, data;
// Read state_json from the waveform container data-state attribute
const container = document.getElementById('wf_container_' + slot_id);
const stateJson = container ? (container.getAttribute('data-state') || '') : '';
if (!stateJson) {
console.warn('[fireRegen] no state_json for slot', slot_id);
return;
}
if (!crossModel) {
// ── Same-model regen: per-slot endpoint, video passed as null ──
apiName = 'regen_' + prefix + '_' + slotNum;
if (prefix === 'taro') {
data = [seg_idx, stateJson, null,
readComponentValue('taro_seed'), readComponentValue('taro_cfg'),
readComponentValue('taro_steps'), readComponentValue('taro_mode'),
readComponentValue('taro_cf_dur'), readComponentValue('taro_cf_db')];
} else if (prefix === 'mma') {
data = [seg_idx, stateJson, null,
readComponentValue('mma_prompt'), readComponentValue('mma_neg'),
readComponentValue('mma_seed'), readComponentValue('mma_cfg'),
readComponentValue('mma_steps'),
readComponentValue('mma_cf_dur'), readComponentValue('mma_cf_db')];
} else {
data = [seg_idx, stateJson, null,
readComponentValue('hf_prompt'), readComponentValue('hf_neg'),
readComponentValue('hf_seed'), readComponentValue('hf_guidance'),
readComponentValue('hf_steps'), readComponentValue('hf_size'),
readComponentValue('hf_cf_dur'), readComponentValue('hf_cf_db')];
}
} else {
// ── Cross-model regen: shared xregen_* endpoint ──
// slot_id is passed so the server knows which slot's state to splice into.
// UI params are read from the target model's tab inputs.
if (targetModel === 'taro') {
apiName = 'xregen_taro';
data = [seg_idx, stateJson, slot_id,
readComponentValue('taro_seed'), readComponentValue('taro_cfg'),
readComponentValue('taro_steps'), readComponentValue('taro_mode'),
readComponentValue('taro_cf_dur'), readComponentValue('taro_cf_db')];
} else if (targetModel === 'mma') {
apiName = 'xregen_mmaudio';
data = [seg_idx, stateJson, slot_id,
readComponentValue('mma_prompt'), readComponentValue('mma_neg'),
readComponentValue('mma_seed'), readComponentValue('mma_cfg'),
readComponentValue('mma_steps'),
readComponentValue('mma_cf_dur'), readComponentValue('mma_cf_db')];
} else {
apiName = 'xregen_hunyuan';
data = [seg_idx, stateJson, slot_id,
readComponentValue('hf_prompt'), readComponentValue('hf_neg'),
readComponentValue('hf_seed'), readComponentValue('hf_guidance'),
readComponentValue('hf_steps'), readComponentValue('hf_size'),
readComponentValue('hf_cf_dur'), readComponentValue('hf_cf_db')];
}
}
console.log('[fireRegen] calling api', apiName, 'seg', seg_idx);
// Snapshot current waveform HTML + video src before mutating anything,
// so we can restore on error (e.g. quota exceeded).
var _preRegenWaveHtml = null;
var _preRegenVideoSrc = null;
var waveElSnap = document.getElementById('slot_wave_' + slot_id);
if (waveElSnap) _preRegenWaveHtml = waveElSnap.innerHTML;
var vidElSnap = document.getElementById('slot_vid_' + slot_id);
if (vidElSnap) { var vSnap = vidElSnap.querySelector('video'); if (vSnap) _preRegenVideoSrc = vSnap.getAttribute('src'); }
// Show spinner immediately
const lbl = document.getElementById('wf_seglabel_' + slot_id);
if (lbl) lbl.textContent = 'Regenerating Seg ' + (seg_idx + 1) + '...';
const fnIndex = getFnIndex(apiName);
if (fnIndex === undefined) {
console.warn('[fireRegen] fn_index not found for api_name:', apiName);
return;
}
// Get ZeroGPU auth headers from the HF parent frame (same mechanism
// Gradio's own JS client uses), then fire the regen queue/join call.
// Falls back to user-supplied HF token if zerogpu headers aren't available.
_fetchZerogpuHeaders().then(function(zerogpuHeaders) {
var regenHeaders = {'Content-Type': 'application/json'};
var hasZerogpu = zerogpuHeaders && Object.keys(zerogpuHeaders).length > 0;
if (hasZerogpu) {
// Merge zerogpu headers (x-zerogpu-token, x-zerogpu-uuid)
for (var k in zerogpuHeaders) { regenHeaders[k] = zerogpuHeaders[k]; }
console.log('[fireRegen] using zerogpu headers from parent frame');
} else {
console.warn('[fireRegen] no zerogpu headers available β€” may use anonymous quota');
}
fetch('/gradio_api/queue/join', {
method: 'POST',
credentials: 'include',
headers: regenHeaders,
body: JSON.stringify({
data: data,
fn_index: fnIndex,
api_name: '/' + apiName,
session_hash: window.__gradio_session_hash__,
event_data: null,
trigger_id: null
})
}).then(function(r) { return r.json(); }).then(function(j) {
if (!j.event_id) { console.error('[fireRegen] no event_id:', j); return; }
console.log('[fireRegen] queued, event_id:', j.event_id);
_listenAndApply(j.event_id, slot_id, seg_idx, _preRegenWaveHtml, _preRegenVideoSrc);
}).catch(function(e) {
console.error('[fireRegen] fetch error:', e);
if (lbl) lbl.textContent = 'Error β€” see console';
var sb = document.getElementById('wf_statusbar_' + slot_id);
if (sb) { sb.style.color = '#e05252'; sb.textContent = '\u26a0 Request failed: ' + e.message; }
});
});
}
// Point the slot's <video> element at a new source URL and reload it.
// Returns true when the slot's player exists (whether or not a change was needed).
function _applyVideoSrc(slot_id, newSrc) {
    const wrapper = document.getElementById('slot_vid_' + slot_id);
    if (!wrapper) return false;
    const player = wrapper.querySelector('video');
    if (!player) return false;
    // No-op when the player already points at the desired source.
    if (player.getAttribute('src') === newSrc) return true;
    player.setAttribute('src', newSrc);
    player.src = newSrc;
    player.load();
    console.log('[_applyVideoSrc] applied src to', 'slot_vid_' + slot_id, 'src:', newSrc.slice(-40));
    return true;
}
// Toast notification — styled like ZeroGPU quota warnings.
// Errors stay on screen longer (8 s) than success messages (3 s).
function _showRegenToast(message, isError) {
    const toast = document.createElement('div');
    const bg = isError ? '#7a1c1c' : '#1c4a1c';
    const borderColor = isError ? '#c0392b' : '#27ae60';
    toast.style.cssText =
        'position:fixed;bottom:24px;left:50%;transform:translateX(-50%);' +
        'z-index:2147483647;padding:12px 20px;border-radius:8px;font-family:sans-serif;' +
        'font-size:13px;max-width:520px;text-align:center;box-shadow:0 4px 20px rgba(0,0,0,.6);' +
        'background:' + bg + ';color:#fff;' +
        'border:1px solid ' + borderColor + ';' +
        'pointer-events:none;';
    toast.textContent = message;
    document.body.appendChild(toast);
    // Fade out, then detach once the opacity transition has finished.
    setTimeout(function() {
        toast.style.transition = 'opacity 0.5s';
        toast.style.opacity = '0';
        setTimeout(function() { toast.parentNode && toast.parentNode.removeChild(toast); }, 600);
    }, isError ? 8000 : 3000);
}
// Subscribe to the Gradio SSE stream for a queued regen event and apply the
// outputs to the DOM. For regen handlers, output[0] = video component update,
// output[1] = waveform HTML update. On error, the pre-regen waveform/video
// snapshots are restored so the slot is left unchanged.
function _listenAndApply(eventId, slot_id, seg_idx, preRegenWaveHtml, preRegenVideoSrc) {
    // Latest video URL seen while streaming; only committed on successful completion.
    var _pendingVideoSrc = null;
    // One SSE subscription per regen job on the shared session stream;
    // messages for other events are filtered by event_id below.
    const es = new EventSource('/gradio_api/queue/data?session_hash=' + window.__gradio_session_hash__);
    es.onmessage = function(e) {
        var msg;
        try { msg = JSON.parse(e.data); } catch(_) { return; }
        if (msg.event_id !== eventId) return;
        if (msg.msg === 'process_generating' || msg.msg === 'process_completed') {
            var out = msg.output;
            if (out && out.data) {
                var vidUpdate = out.data[0];
                var waveUpdate = out.data[1];
                var newSrc = null;
                if (vidUpdate) {
                    // The update payload shape varies across Gradio versions and
                    // update kinds — probe the known locations for the file URL.
                    if (vidUpdate.value && vidUpdate.value.video && vidUpdate.value.video.url) newSrc = vidUpdate.value.video.url;
                    else if (vidUpdate.video && vidUpdate.video.url) newSrc = vidUpdate.video.url;
                    else if (vidUpdate.value && vidUpdate.value.url) newSrc = vidUpdate.value.url;
                    else if (typeof vidUpdate.value === 'string') newSrc = vidUpdate.value;
                    else if (vidUpdate.url) newSrc = vidUpdate.url;
                }
                if (newSrc) _pendingVideoSrc = newSrc;
                var waveHtml = null;
                if (waveUpdate) {
                    if (typeof waveUpdate === 'string') waveHtml = waveUpdate;
                    else if (waveUpdate.value && typeof waveUpdate.value === 'string') waveHtml = waveUpdate.value;
                }
                if (waveHtml) {
                    var waveEl = document.getElementById('slot_wave_' + slot_id);
                    if (waveEl) {
                        // Prefer Gradio's inner container so the wrapper's own
                        // classes/ids survive the innerHTML swap.
                        var inner = waveEl.querySelector('.prose') || waveEl.querySelector('div');
                        if (inner) inner.innerHTML = waveHtml;
                        else waveEl.innerHTML = waveHtml;
                    }
                }
            }
            if (msg.msg === 'process_completed') {
                es.close();
                _regenInFlight[slot_id] = false;
                var errMsg = msg.output && msg.output.error;
                var hadError = !!errMsg;
                console.log('[fireRegen] completed for', slot_id, 'error:', hadError, errMsg || '');
                var lbl = document.getElementById('wf_seglabel_' + slot_id);
                if (hadError) {
                    var toastMsg = typeof errMsg === 'string' ? errMsg : JSON.stringify(errMsg);
                    // Restore previous waveform HTML and video src
                    if (preRegenWaveHtml !== null) {
                        var waveEl2 = document.getElementById('slot_wave_' + slot_id);
                        if (waveEl2) waveEl2.innerHTML = preRegenWaveHtml;
                    }
                    if (preRegenVideoSrc !== null) {
                        var vidElR = document.getElementById('slot_vid_' + slot_id);
                        if (vidElR) { var vR = vidElR.querySelector('video'); if (vR) { vR.setAttribute('src', preRegenVideoSrc); vR.src = preRegenVideoSrc; vR.load(); } }
                    }
                    // Update the statusbar (query after restore so we get the freshly-restored element)
                    var isAbort = toastMsg.toLowerCase().indexOf('aborted') !== -1;
                    var isTimeout = toastMsg.toLowerCase().indexOf('timeout') !== -1;
                    var failMsg = isAbort || isTimeout
                        ? '\u26a0 GPU cold-start β€” segment unchanged, try again'
                        : '\u26a0 Regen failed β€” segment unchanged';
                    var statusBar = document.getElementById('wf_statusbar_' + slot_id);
                    if (statusBar) {
                        statusBar.style.color = '#e05252';
                        statusBar.textContent = failMsg;
                        // Revert the status bar to its idle hint after 8 s.
                        setTimeout(function() { statusBar.style.color = '#888'; statusBar.textContent = 'Click a segment to regenerate \u00a0|\u00a0 Playhead syncs to video'; }, 8000);
                    }
                } else {
                    if (lbl) lbl.textContent = 'Done';
                    var src = _pendingVideoSrc;
                    if (src) {
                        // Gradio may re-render the <video> element after completion and
                        // clobber our src — re-apply several times and watch mutations
                        // for ~2 s to win the race.
                        _applyVideoSrc(slot_id, src);
                        setTimeout(function() { _applyVideoSrc(slot_id, src); }, 50);
                        setTimeout(function() { _applyVideoSrc(slot_id, src); }, 300);
                        setTimeout(function() { _applyVideoSrc(slot_id, src); }, 800);
                        var vidEl = document.getElementById('slot_vid_' + slot_id);
                        if (vidEl) {
                            var obs = new MutationObserver(function() { _applyVideoSrc(slot_id, src); });
                            obs.observe(vidEl, {subtree: true, attributes: true, attributeFilter: ['src'], childList: true});
                            setTimeout(function() { obs.disconnect(); }, 2000);
                        }
                    }
                }
            }
        }
        if (msg.msg === 'close_stream') { es.close(); }
    };
    es.onerror = function() { es.close(); _regenInFlight[slot_id] = false; };
}
// Track in-flight regen per slot — prevents queuing multiple jobs from rapid clicks
var _regenInFlight = {};
// Shared popup element created once (lazily, by ensurePopup) and reused across all slots
let _popup = null;
// Slot/segment the popup is currently targeting; set on show, cleared on hide.
let _pendingSlot = null, _pendingIdx = null;
// Lazily build the shared "regenerate with model X" popup and wire its buttons.
// Idempotent: returns the already-built element on subsequent calls.
function ensurePopup() {
    if (_popup) return _popup;
    _popup = document.createElement('div');
    _popup.style.cssText = 'display:none;position:fixed;z-index:99999;' +
        'background:#2a2a2a;border:1px solid #555;border-radius:6px;' +
        'padding:8px 12px;box-shadow:0 4px 16px rgba(0,0,0,.5);font-family:sans-serif;';
    var btnStyle = 'color:#fff;border:none;border-radius:4px;padding:5px 10px;' +
        'font-size:11px;cursor:pointer;flex:1;';
    // Three buttons, one per target model; &#10227; is a "regenerate" arrow glyph.
    _popup.innerHTML =
        '<div id="_wf_popup_lbl" style="color:#ccc;font-size:11px;margin-bottom:6px;white-space:nowrap;"></div>' +
        '<div style="display:flex;gap:5px;">' +
        '<button id="_wf_popup_taro" style="background:#1d6fa5;' + btnStyle + '">&#10227; TARO</button>' +
        '<button id="_wf_popup_mma" style="background:#2d7a4a;' + btnStyle + '">&#10227; MMAudio</button>' +
        '<button id="_wf_popup_hf" style="background:#7a3d8c;' + btnStyle + '">&#10227; Hunyuan</button>' +
        '</div>';
    document.body.appendChild(_popup);
    ['taro','mma','hf'].forEach(function(model) {
        document.getElementById('_wf_popup_' + model).onclick = function(e) {
            e.stopPropagation();
            // Snapshot the target before hidePopup() clears the pending state.
            var slot = _pendingSlot, idx = _pendingIdx;
            hidePopup();
            if (slot !== null && idx !== null) fireRegen(slot, idx, model);
        };
    });
    // Use bubble phase (false) so stopPropagation() on the button click prevents this from firing
    document.addEventListener('click', function() { hidePopup(); }, false);
    return _popup;
}
// Hide the shared regen popup (if it has been built) and clear its pending target.
function hidePopup() {
    if (_popup) {
        _popup.style.display = 'none';
    }
    _pendingSlot = null;
    _pendingIdx = null;
}
// Bridge from the per-slot waveform widgets: they postMessage {type:'wf_popup'}
// with action 'show'/'hide' plus the clicked segment's index, time range, and
// click coordinates; we position the model-selection popup accordingly.
window.addEventListener('message', function(e) {
    const d = e.data;
    console.log('[global msg] received type=' + (d && d.type) + ' action=' + (d && d.action));
    if (!d || d.type !== 'wf_popup') return;
    const p = ensurePopup();
    if (d.action === 'hide') { hidePopup(); return; }
    // action === 'show'
    _pendingSlot = d.slot_id;
    _pendingIdx = d.seg_idx;
    const lbl = document.getElementById('_wf_popup_lbl');
    if (lbl) lbl.textContent = 'Seg ' + (d.seg_idx + 1) +
        ' (' + d.t0.toFixed(2) + 's \u2013 ' + d.t1.toFixed(2) + 's)';
    p.style.display = 'block';
    p.style.left = (d.x + 10) + 'px';
    p.style.top = (d.y + 10) + 'px';
    // After layout, clamp the popup back inside the viewport if it overflows.
    requestAnimationFrame(function() {
        const r = p.getBoundingClientRect();
        if (r.right > window.innerWidth - 8) p.style.left = (window.innerWidth - r.width - 8) + 'px';
        if (r.bottom > window.innerHeight - 8) p.style.top = (window.innerHeight - r.height - 8) + 'px';
    });
});
}
"""
# ------------------------------------------------------------------ #
# UI construction: three model tabs (TARO / MMAudio / HunyuanFoley), #
# each with its own parameter column, output slots, and per-slot     #
# regen handlers, plus cross-model regen endpoints driven by JS.     #
# ------------------------------------------------------------------ #
with gr.Blocks(title="Generate Audio for Video", css=_SLOT_CSS, js=_GLOBAL_JS) as demo:
    gr.Markdown(
        "# Generate Audio for Video\n"
        "Choose a model and upload a video to generate synchronized audio.\n\n"
        "| Model | Best for | Avoid for |\n"
        "|-------|----------|-----------|\n"
        "| **TARO** | Natural, physics-driven impacts β€” footsteps, collisions, water, wind, crackling fire. Excels when the sound is tightly coupled to visible motion without needing a text description. | Dialogue, music, or complex layered soundscapes where semantic context matters. |\n"
        "| **MMAudio** | Mixed scenes where you want both visual grounding *and* semantic control via a text prompt β€” e.g. a busy street scene where you want to emphasize the rain rather than the traffic. Great for ambient textures and nuanced sound design. | Pure impact/foley shots where TARO's motion-coupling would be sharper, or cinematic music beds. |\n"
        "| **HunyuanFoley** | Cinematic foley requiring high fidelity and explicit creative direction β€” dramatic SFX, layered environmental design, or any scene where you have a clear written description of the desired sound palette. | Quick one-shot clips where you don't want to write a prompt, or raw impact sounds where timing precision matters more than richness. |"
    )
    with gr.Tabs():
        # ---------------------------------------------------------- #
        # Tab 1 — TARO                                               #
        # ---------------------------------------------------------- #
        with gr.Tab("TARO"):
            with gr.Row():
                with gr.Column(scale=1):
                    # elem_id values are read by the JS layer (readComponentValue)
                    # when firing regen requests, so they must stay stable.
                    taro_video = gr.Video(label="Input Video")
                    taro_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0, elem_id="taro_seed")
                    taro_cfg = gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8.0, step=0.5, elem_id="taro_cfg")
                    taro_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1, elem_id="taro_steps")
                    taro_mode = gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde", elem_id="taro_mode")
                    taro_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=4, value=2, step=0.1, elem_id="taro_cf_dur")
                    taro_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3", elem_id="taro_cf_db")
                    taro_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
                    taro_btn = gr.Button("Generate", variant="primary")
                with gr.Column(scale=1):
                    (taro_slot_grps, taro_slot_vids,
                     taro_slot_waves) = _make_output_slots("taro")
            # Hidden regen plumbing — render=False so no DOM element is created,
            # avoiding Gradio's "Too many arguments" Svelte validation error.
            # JS passes values directly via queue/join data array at the correct
            # positional index (these show up as inputs to the fn but have no DOM).
            taro_regen_seg = gr.Textbox(value="0", render=False)
            taro_regen_state = gr.Textbox(value="", render=False)
            # Any change to video/steps/crossfade recomputes the allowed
            # number of generations for this clip.
            for trigger in [taro_video, taro_steps, taro_cf_dur]:
                trigger.change(
                    fn=_on_video_upload_taro,
                    inputs=[taro_video, taro_steps, taro_cf_dur],
                    outputs=[taro_samples],
                )
            taro_samples.change(
                fn=_update_slot_visibility,
                inputs=[taro_samples],
                outputs=taro_slot_grps,
            )

            def _run_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n):
                # Thin adapter: run generation, then fan the flat result list
                # out across the fixed set of output slots.
                flat = generate_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n)
                return _unpack_outputs(flat, n, "taro")

            # Split group visibility into a separate .then() to avoid Gradio 5 SSR
            # "Too many arguments" caused by including gr.Group in mixed output lists.
            (taro_btn.click(
                fn=_run_taro,
                inputs=[taro_video, taro_seed, taro_cfg, taro_steps, taro_mode,
                        taro_cf_dur, taro_cf_db, taro_samples],
                outputs=taro_slot_vids + taro_slot_waves,
            ).then(
                fn=_update_slot_visibility,
                inputs=[taro_samples],
                outputs=taro_slot_grps,
            ))
            # Per-slot regen handlers — JS calls /gradio_api/queue/join with
            # fn_index (by api_name) + data=[seg_idx, state_json, video, ...params].
            taro_regen_btns = _register_regen_handlers(
                "taro", "taro", taro_regen_seg, taro_regen_state,
                [taro_video, taro_seed, taro_cfg, taro_steps,
                 taro_mode, taro_cf_dur, taro_cf_db],
                taro_slot_vids, taro_slot_waves,
            )
        # ---------------------------------------------------------- #
        # Tab 2 — MMAudio                                            #
        # ---------------------------------------------------------- #
        with gr.Tab("MMAudio"):
            with gr.Row():
                with gr.Column(scale=1):
                    mma_video = gr.Video(label="Input Video")
                    mma_prompt = gr.Textbox(label="Prompt", placeholder="e.g. footsteps on gravel", elem_id="mma_prompt")
                    mma_neg = gr.Textbox(label="Negative Prompt", value="music", placeholder="music, speech", elem_id="mma_neg")
                    mma_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0, elem_id="mma_seed")
                    mma_cfg = gr.Slider(label="CFG Strength", minimum=1, maximum=10, value=4.5, step=0.5, elem_id="mma_cfg")
                    mma_steps = gr.Slider(label="Steps", minimum=10, maximum=50, value=25, step=1, elem_id="mma_steps")
                    mma_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=4, value=2, step=0.1, elem_id="mma_cf_dur")
                    mma_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3", elem_id="mma_cf_db")
                    mma_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
                    mma_btn = gr.Button("Generate", variant="primary")
                with gr.Column(scale=1):
                    (mma_slot_grps, mma_slot_vids,
                     mma_slot_waves) = _make_output_slots("mma")
            # Hidden regen plumbing — render=False so no DOM element is created,
            # avoiding Gradio's "Too many arguments" Svelte validation error.
            mma_regen_seg = gr.Textbox(value="0", render=False)
            mma_regen_state = gr.Textbox(value="", render=False)
            mma_samples.change(
                fn=_update_slot_visibility,
                inputs=[mma_samples],
                outputs=mma_slot_grps,
            )

            def _run_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n):
                # Thin adapter mirroring _run_taro for the MMAudio backend.
                flat = generate_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n)
                return _unpack_outputs(flat, n, "mma")

            (mma_btn.click(
                fn=_run_mmaudio,
                inputs=[mma_video, mma_prompt, mma_neg, mma_seed,
                        mma_cfg, mma_steps, mma_cf_dur, mma_cf_db, mma_samples],
                outputs=mma_slot_vids + mma_slot_waves,
            ).then(
                fn=_update_slot_visibility,
                inputs=[mma_samples],
                outputs=mma_slot_grps,
            ))
            mma_regen_btns = _register_regen_handlers(
                "mma", "mmaudio", mma_regen_seg, mma_regen_state,
                [mma_video, mma_prompt, mma_neg, mma_seed,
                 mma_cfg, mma_steps, mma_cf_dur, mma_cf_db],
                mma_slot_vids, mma_slot_waves,
            )
        # ---------------------------------------------------------- #
        # Tab 3 — HunyuanVideoFoley                                  #
        # ---------------------------------------------------------- #
        with gr.Tab("HunyuanFoley"):
            with gr.Row():
                with gr.Column(scale=1):
                    hf_video = gr.Video(label="Input Video")
                    hf_prompt = gr.Textbox(label="Prompt", placeholder="e.g. rain hitting a metal roof", elem_id="hf_prompt")
                    hf_neg = gr.Textbox(label="Negative Prompt", value="noisy, harsh", elem_id="hf_neg")
                    hf_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0, elem_id="hf_seed")
                    hf_guidance = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, value=4.5, step=0.5, elem_id="hf_guidance")
                    hf_steps = gr.Slider(label="Steps", minimum=10, maximum=100, value=50, step=5, elem_id="hf_steps")
                    hf_size = gr.Radio(label="Model Size", choices=["xl", "xxl"], value="xxl", elem_id="hf_size")
                    hf_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=4, value=2, step=0.1, elem_id="hf_cf_dur")
                    hf_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3", elem_id="hf_cf_db")
                    hf_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
                    hf_btn = gr.Button("Generate", variant="primary")
                with gr.Column(scale=1):
                    (hf_slot_grps, hf_slot_vids,
                     hf_slot_waves) = _make_output_slots("hf")
            # Hidden regen plumbing — render=False so no DOM element is created,
            # avoiding Gradio's "Too many arguments" Svelte validation error.
            hf_regen_seg = gr.Textbox(value="0", render=False)
            hf_regen_state = gr.Textbox(value="", render=False)
            hf_samples.change(
                fn=_update_slot_visibility,
                inputs=[hf_samples],
                outputs=hf_slot_grps,
            )

            def _run_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n):
                # Thin adapter mirroring _run_taro for the HunyuanFoley backend.
                flat = generate_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n)
                return _unpack_outputs(flat, n, "hf")

            (hf_btn.click(
                fn=_run_hunyuan,
                inputs=[hf_video, hf_prompt, hf_neg, hf_seed,
                        hf_guidance, hf_steps, hf_size, hf_cf_dur, hf_cf_db, hf_samples],
                outputs=hf_slot_vids + hf_slot_waves,
            ).then(
                fn=_update_slot_visibility,
                inputs=[hf_samples],
                outputs=hf_slot_grps,
            ))
            hf_regen_btns = _register_regen_handlers(
                "hf", "hunyuan", hf_regen_seg, hf_regen_state,
                [hf_video, hf_prompt, hf_neg, hf_seed,
                 hf_guidance, hf_steps, hf_size, hf_cf_dur, hf_cf_db],
                hf_slot_vids, hf_slot_waves,
            )
    # ---- Browser-safe transcode on upload ----
    # Gradio serves the original uploaded file to the browser preview widget,
    # so H.265 sources show as blank. We re-encode to H.264 on upload and feed
    # the result back so the preview plays. mux_video_audio already re-encodes
    # to H.264 during generation, so no double-conversion conflict.
    taro_video.upload(fn=_transcode_for_browser, inputs=[taro_video], outputs=[taro_video])
    mma_video.upload(fn=_transcode_for_browser, inputs=[mma_video], outputs=[mma_video])
    hf_video.upload(fn=_transcode_for_browser, inputs=[hf_video], outputs=[hf_video])
    # ---- Cross-tab video sync ----
    # Mirror whichever tab's video changed into the other two tabs.
    _sync = lambda v: (gr.update(value=v), gr.update(value=v))
    taro_video.change(fn=_sync, inputs=[taro_video], outputs=[mma_video, hf_video])
    mma_video.change(fn=_sync, inputs=[mma_video], outputs=[taro_video, hf_video])
    hf_video.change(fn=_sync, inputs=[hf_video], outputs=[taro_video, mma_video])
    # ---- Cross-model regen endpoints ----
    # render=False inputs/outputs: no DOM elements created, no SSR validation impact.
    # JS calls these via /gradio_api/queue/join using the api_name and applies
    # the returned video+waveform directly to the target slot's DOM elements.
    _xr_seg = gr.Textbox(value="0", render=False)
    _xr_state = gr.Textbox(value="", render=False)
    _xr_slot_id = gr.Textbox(value="", render=False)
    # Dummy outputs for xregen events: must be real rendered components so Gradio
    # can look them up in session state during postprocess_data. The JS listener
    # (_listenAndApply) applies the returned video/HTML directly to the correct
    # slot's DOM elements and ignores Gradio's own output routing, so these
    # slot-0 components simply act as sinks — their displayed value is overwritten
    # by the real JS update immediately after.
    _xr_dummy_vid = taro_slot_vids[0]
    _xr_dummy_wave = taro_slot_waves[0]
    # TARO cross-model regen inputs: seg_idx, state_json, slot_id, seed, cfg, steps, mode, cf_dur, cf_db
    _xr_taro_seed = gr.Textbox(value="-1", render=False)
    _xr_taro_cfg = gr.Textbox(value="7.5", render=False)
    _xr_taro_steps = gr.Textbox(value="25", render=False)
    _xr_taro_mode = gr.Textbox(value="sde", render=False)
    _xr_taro_cfd = gr.Textbox(value="2", render=False)
    _xr_taro_cfdb = gr.Textbox(value="3", render=False)
    # An unrendered button serves purely to register the API endpoint.
    gr.Button(render=False).click(
        fn=xregen_taro,
        inputs=[_xr_seg, _xr_state, _xr_slot_id,
                _xr_taro_seed, _xr_taro_cfg, _xr_taro_steps,
                _xr_taro_mode, _xr_taro_cfd, _xr_taro_cfdb],
        outputs=[_xr_dummy_vid, _xr_dummy_wave],
        api_name="xregen_taro",
    )
    # MMAudio cross-model regen inputs: seg_idx, state_json, slot_id, prompt, neg, seed, cfg, steps, cf_dur, cf_db
    _xr_mma_prompt = gr.Textbox(value="", render=False)
    _xr_mma_neg = gr.Textbox(value="", render=False)
    _xr_mma_seed = gr.Textbox(value="-1", render=False)
    _xr_mma_cfg = gr.Textbox(value="4.5", render=False)
    _xr_mma_steps = gr.Textbox(value="25", render=False)
    _xr_mma_cfd = gr.Textbox(value="2", render=False)
    _xr_mma_cfdb = gr.Textbox(value="3", render=False)
    gr.Button(render=False).click(
        fn=xregen_mmaudio,
        inputs=[_xr_seg, _xr_state, _xr_slot_id,
                _xr_mma_prompt, _xr_mma_neg, _xr_mma_seed,
                _xr_mma_cfg, _xr_mma_steps, _xr_mma_cfd, _xr_mma_cfdb],
        outputs=[_xr_dummy_vid, _xr_dummy_wave],
        api_name="xregen_mmaudio",
    )
    # HunyuanFoley cross-model regen inputs: seg_idx, state_json, slot_id, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db
    _xr_hf_prompt = gr.Textbox(value="", render=False)
    _xr_hf_neg = gr.Textbox(value="", render=False)
    _xr_hf_seed = gr.Textbox(value="-1", render=False)
    _xr_hf_guide = gr.Textbox(value="4.5", render=False)
    _xr_hf_steps = gr.Textbox(value="50", render=False)
    _xr_hf_size = gr.Textbox(value="xxl", render=False)
    _xr_hf_cfd = gr.Textbox(value="2", render=False)
    _xr_hf_cfdb = gr.Textbox(value="3", render=False)
    gr.Button(render=False).click(
        fn=xregen_hunyuan,
        inputs=[_xr_seg, _xr_state, _xr_slot_id,
                _xr_hf_prompt, _xr_hf_neg, _xr_hf_seed,
                _xr_hf_guide, _xr_hf_steps, _xr_hf_size,
                _xr_hf_cfd, _xr_hf_cfdb],
        outputs=[_xr_dummy_vid, _xr_dummy_wave],
        api_name="xregen_hunyuan",
    )
# NOTE: ZeroGPU quota attribution is handled via postMessage("zerogpu-headers")
# to the HF parent frame — the same mechanism Gradio's own JS client uses.
# This replaced the old x-ip-token relay approach which was unreliable.
print("[startup] app.py fully loaded β€” regen handlers registered, SSR disabled")
demo.queue(max_size=10).launch(ssr_mode=False, height=900, allowed_paths=["/tmp"])