# Fix regen GPU fns: move seg clip extraction inside GPU scope (commit d5399ac, BoxOfColors)
"""
Generate Audio for Video β€” multi-model Gradio app.
Supported models
----------------
TARO – video-conditioned diffusion via CAVP + onset features (16 kHz, 8.192 s window)
MMAudio – multimodal flow-matching with CLIP/Synchformer + text prompt (44 kHz, 8 s window)
HunyuanFoley – text-guided foley via SigLIP2 + Synchformer + CLAP (48 kHz, up to 15 s)
"""
import html as _html
import os
import sys
import json
import shutil
import tempfile
import random
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import torch
import numpy as np
import torchaudio
import ffmpeg
import spaces
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
# ================================================================== #
# CHECKPOINT CONFIGURATION #
# ================================================================== #
# Single HF repo that holds every model checkpoint used by this app.
CKPT_REPO_ID = "JackIsNotInTheBox/Generate_Audio_for_Video_Checkpoints"
CACHE_DIR = "/tmp/model_ckpts"  # local download root (ephemeral on Spaces)
os.makedirs(CACHE_DIR, exist_ok=True)
# ---- Local directories that must exist before parallel downloads start ----
MMAUDIO_WEIGHTS_DIR = Path(CACHE_DIR) / "MMAudio" / "weights"
MMAUDIO_EXT_DIR = Path(CACHE_DIR) / "MMAudio" / "ext_weights"
HUNYUAN_MODEL_DIR = Path(CACHE_DIR) / "HunyuanFoley"
MMAUDIO_WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
MMAUDIO_EXT_DIR.mkdir(parents=True, exist_ok=True)
HUNYUAN_MODEL_DIR.mkdir(parents=True, exist_ok=True)
# ------------------------------------------------------------------ #
# Parallel checkpoint + model downloads #
# All downloads are I/O-bound (network), so running them in threads #
# cuts Space cold-start time roughly proportional to the number of #
# independent groups (previously sequential, now concurrent). #
# hf_hub_download / snapshot_download are thread-safe. #
# ------------------------------------------------------------------ #
def _dl_taro():
    """Fetch the three TARO checkpoints; returns (cavp, onset, taro) local paths."""
    filenames = (
        "TARO/cavp_epoch66.ckpt",
        "TARO/onset_model.ckpt",
        "TARO/taro_ckpt.pt",
    )
    cavp, onset, taro = (
        hf_hub_download(repo_id=CKPT_REPO_ID, filename=name, cache_dir=CACHE_DIR)
        for name in filenames
    )
    print("TARO checkpoints downloaded.")
    return cavp, onset, taro
def _dl_mmaudio():
    """Fetch MMAudio .pth files; returns (model, vae, synchformer) local paths."""
    def _fetch(filename, dest_dir):
        # NOTE: local_dir_use_symlinks is deprecated in recent huggingface_hub,
        # kept here for behavioral parity with the rest of the download path.
        return hf_hub_download(repo_id=CKPT_REPO_ID, filename=filename,
                               cache_dir=CACHE_DIR, local_dir=str(dest_dir),
                               local_dir_use_symlinks=False)
    model = _fetch("MMAudio/mmaudio_large_44k_v2.pth", MMAUDIO_WEIGHTS_DIR)
    vae = _fetch("MMAudio/v1-44.pth", MMAUDIO_EXT_DIR)
    sync = _fetch("MMAudio/synchformer_state_dict.pth", MMAUDIO_EXT_DIR)
    print("MMAudio checkpoints downloaded.")
    return model, vae, sync
def _dl_hunyuan():
    """Fetch the three HunyuanVideoFoley .pth files into HUNYUAN_MODEL_DIR."""
    for filename in (
        "HunyuanVideo-Foley/hunyuanvideo_foley.pth",
        "HunyuanVideo-Foley/vae_128d_48k.pth",
        "HunyuanVideo-Foley/synchformer_state_dict.pth",
    ):
        hf_hub_download(repo_id=CKPT_REPO_ID, filename=filename,
                        cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR),
                        local_dir_use_symlinks=False)
    print("HunyuanVideoFoley checkpoints downloaded.")
def _dl_clap():
    """Warm the local HF cache with CLAP so from_pretrained() hits disk, not
    the network, inside the ZeroGPU worker."""
    repo = "laion/larger_clap_general"
    snapshot_download(repo_id=repo)
    print("CLAP model pre-downloaded.")
def _dl_clip():
    """Warm the cache with MMAudio's CLIP backbone (~3.95 GB) so the download
    doesn't eat into the GPU-window time budget."""
    repo = "apple/DFN5B-CLIP-ViT-H-14-384"
    snapshot_download(repo_id=repo)
    print("MMAudio CLIP model pre-downloaded.")
def _dl_audioldm2():
    """Warm the cache with the AudioLDM2 VAE/vocoder that TARO's
    from_pretrained() calls resolve against."""
    repo = "cvssp/audioldm2"
    snapshot_download(repo_id=repo)
    print("AudioLDM2 pre-downloaded.")
def _dl_bigvgan():
    """Warm the cache with the BigVGAN vocoder (~489 MB) used by MMAudio."""
    repo = "nvidia/bigvgan_v2_44khz_128band_512x"
    snapshot_download(repo_id=repo)
    print("BigVGAN vocoder pre-downloaded.")
print("[startup] Starting parallel checkpoint + model downloads…")
_t_dl_start = time.perf_counter()
# One worker per download group so all seven network-bound fetches overlap.
with ThreadPoolExecutor(max_workers=7) as _pool:
    _fut_taro = _pool.submit(_dl_taro)
    _fut_mmaudio = _pool.submit(_dl_mmaudio)
    _fut_hunyuan = _pool.submit(_dl_hunyuan)
    _fut_clap = _pool.submit(_dl_clap)
    _fut_clip = _pool.submit(_dl_clip)
    _fut_aldm2 = _pool.submit(_dl_audioldm2)
    _fut_bigvgan = _pool.submit(_dl_bigvgan)
    # Raise any download exceptions immediately (fail fast at startup rather
    # than with a confusing missing-file error at first inference).
    for _fut in as_completed([_fut_taro, _fut_mmaudio, _fut_hunyuan,
                              _fut_clap, _fut_clip, _fut_aldm2, _fut_bigvgan]):
        _fut.result()
# Only TARO and MMAudio return paths that later code references by name;
# the other groups populate caches that from_pretrained() resolves itself.
cavp_ckpt_path, onset_ckpt_path, taro_ckpt_path = _fut_taro.result()
mmaudio_model_path, mmaudio_vae_path, mmaudio_synchformer_path = _fut_mmaudio.result()
print(f"[startup] All downloads done in {time.perf_counter() - _t_dl_start:.1f}s")
# ================================================================== #
# SHARED CONSTANTS / HELPERS #
# ================================================================== #
# CPU β†’ GPU context passing via function-name-keyed global store.
#
# Problem: ZeroGPU runs @spaces.GPU functions on its own worker thread, so
# threading.local() is invisible to the GPU worker. Passing ctx as a
# function argument exposes it to Gradio's API endpoint, causing
# "Too many arguments" errors.
#
# Solution: store context in a plain global dict keyed by function name.
# A per-key Lock serialises concurrent callers for the same function
# (ZeroGPU is already synchronous β€” the wrapper blocks until the GPU fn
# returns β€” so in practice only one call per GPU fn is in-flight at a time).
# The global dict is readable from any thread.
_GPU_CTX: dict = {}  # fn_name -> pending context dict for the next GPU call
_GPU_CTX_LOCK = threading.Lock()  # guards every _GPU_CTX read/write
def _ctx_store(fn_name: str, data: dict) -> None:
    """Save *data* as the pending context for *fn_name*, replacing any prior entry."""
    with _GPU_CTX_LOCK:
        _GPU_CTX[fn_name] = data
def _ctx_load(fn_name: str) -> dict:
    """Remove and return the context stored for *fn_name* ({} when none exists)."""
    with _GPU_CTX_LOCK:
        data = _GPU_CTX.pop(fn_name, {})
    return data
MAX_SLOTS = 8  # max parallel generation slots shown in UI
MAX_SEGS = 8   # max segments per slot (same as MAX_SLOTS; video ≤ ~64 s at 8 s/seg)
# Segment overlay palette — shared between _build_waveform_html and _build_regen_pending_html.
# Each entry is an rgba() template; "{a}" is substituted with the alpha at the use site.
SEG_COLORS = [
    "rgba(100,180,255,{a})", "rgba(255,160,100,{a})",
    "rgba(120,220,140,{a})", "rgba(220,120,220,{a})",
    "rgba(255,220,80,{a})", "rgba(80,220,220,{a})",
    "rgba(255,100,100,{a})", "rgba(180,255,180,{a})",
]
# ------------------------------------------------------------------ #
# Micro-helpers that eliminate repeated boilerplate across the file #
# ------------------------------------------------------------------ #
def _ensure_syspath(subdir: str) -> str:
"""Add *subdir* (relative to app.py) to sys.path if not already present.
Returns the absolute path for convenience."""
p = os.path.join(os.path.dirname(os.path.abspath(__file__)), subdir)
if p not in sys.path:
sys.path.insert(0, p)
return p
def _get_device_and_dtype() -> tuple:
"""Return (device, weight_dtype) pair used by all GPU functions."""
device = "cuda" if torch.cuda.is_available() else "cpu"
return device, torch.bfloat16
def _extract_segment_clip(silent_video: str, seg_start: float, seg_dur: float,
                          output_path: str) -> str:
    """Cut [seg_start, seg_start + seg_dur) out of *silent_video* via stream copy
    (no re-encode). Writes to *output_path* and returns it."""
    clip = ffmpeg.input(silent_video, ss=seg_start, t=seg_dur)
    clip.output(output_path, vcodec="copy", an=None).run(
        overwrite_output=True, quiet=True
    )
    return output_path
# Per-slot reentrant locks β€” prevent concurrent regens on the same slot from
# producing a race condition where the second regen reads stale state
# (the shared seg_state textbox hasn't been updated yet by the first regen).
# Locks are keyed by slot_id string (e.g. "taro_0", "mma_2").
_SLOT_LOCKS: dict = {}
_SLOT_LOCKS_MUTEX = threading.Lock()
def _get_slot_lock(slot_id: str) -> threading.Lock:
with _SLOT_LOCKS_MUTEX:
if slot_id not in _SLOT_LOCKS:
_SLOT_LOCKS[slot_id] = threading.Lock()
return _SLOT_LOCKS[slot_id]
def set_global_seed(seed: int) -> None:
    """Seed random, numpy, and torch (plus CUDA when present) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed % (2 ** 32))  # numpy only accepts 32-bit seeds
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
def get_random_seed() -> int:
    """Return a fresh random seed in the unsigned 32-bit range."""
    upper = 2**32 - 1
    return random.randint(0, upper)
def get_video_duration(video_path: str) -> float:
    """Probe *video_path* and return its container duration in seconds (CPU only)."""
    info = ffmpeg.probe(video_path)
    return float(info["format"]["duration"])
def strip_audio_from_video(video_path: str, output_path: str) -> None:
    """Stream-copy *video_path* to *output_path* with the audio track dropped
    (no re-encode)."""
    source = ffmpeg.input(video_path)
    source.output(output_path, vcodec="copy", an=None).run(
        overwrite_output=True, quiet=True
    )
def _transcode_for_browser(video_path: str) -> str:
"""Re-encode uploaded video to H.264/AAC MP4 so the browser preview widget can play it.
Returns a NEW path in a fresh /tmp/gradio/ subdirectory. Gradio probes the
returned path fresh, sees H.264, and serves it directly without its own
slow fallback converter. The in-place overwrite approach loses the race
because Gradio probes the original path at upload time before this callback runs.
Only called on upload β€” not during generation.
"""
if video_path is None:
return video_path
try:
probe = ffmpeg.probe(video_path)
has_audio = any(s["codec_type"] == "audio" for s in probe.get("streams", []))
# Check if already H.264 β€” skip transcode if so
video_streams = [s for s in probe.get("streams", []) if s["codec_type"] == "video"]
if video_streams and video_streams[0].get("codec_name") == "h264":
print(f"[transcode_for_browser] already H.264, skipping")
return video_path
# Write the H.264 output into the SAME directory as the original upload.
# Gradio's file server only allows paths under dirs it registered β€” the
# upload dir is already allowed, so a sibling file there will serve fine.
import os as _os
upload_dir = _os.path.dirname(video_path)
stem = _os.path.splitext(_os.path.basename(video_path))[0]
out_path = _os.path.join(upload_dir, stem + "_h264.mp4")
kwargs = dict(
vcodec="libx264", preset="fast", crf=18,
pix_fmt="yuv420p", movflags="+faststart",
)
if has_audio:
kwargs["acodec"] = "aac"
kwargs["audio_bitrate"] = "128k"
else:
kwargs["an"] = None
# map 0:v:0 explicitly to skip non-video streams (e.g. data/timecode tracks)
ffmpeg.input(video_path).output(out_path, map="0:v:0", **kwargs).run(
overwrite_output=True, quiet=True
)
print(f"[transcode_for_browser] transcoded to H.264: {out_path}")
return out_path
except Exception as e:
print(f"[transcode_for_browser] failed, using original: {e}")
return video_path
# ------------------------------------------------------------------ #
# Temp directory registry β€” tracks dirs for cleanup on new generation #
# ------------------------------------------------------------------ #
_TEMP_DIRS: list = [] # list of tmp_dir paths created by generate_*
_TEMP_DIRS_MAX = 10 # keep at most this many; older ones get cleaned up
def _register_tmp_dir(tmp_dir: str) -> str:
"""Register a temp dir so it can be cleaned up when newer ones replace it."""
_TEMP_DIRS.append(tmp_dir)
while len(_TEMP_DIRS) > _TEMP_DIRS_MAX:
old = _TEMP_DIRS.pop(0)
try:
shutil.rmtree(old, ignore_errors=True)
print(f"[cleanup] Removed old temp dir: {old}")
except Exception:
pass
return tmp_dir
def _save_seg_wavs(wavs: list[np.ndarray], tmp_dir: str, prefix: str) -> list[str]:
"""Save a list of numpy wav arrays to .npy files, return list of paths.
This avoids serialising large float arrays into JSON/HTML data-state."""
paths = []
for i, w in enumerate(wavs):
p = os.path.join(tmp_dir, f"{prefix}_seg{i}.npy")
np.save(p, w)
paths.append(p)
return paths
def _load_seg_wavs(paths: list[str]) -> list[np.ndarray]:
"""Load segment wav arrays from .npy file paths."""
return [np.load(p) for p in paths]
# ------------------------------------------------------------------ #
# Shared model-loading helpers (deduplicate generate / regen code) #
# ------------------------------------------------------------------ #
def _load_taro_models(device, weight_dtype):
    """Load TARO MMDiT + AudioLDM2 VAE/vocoder.

    Returns (model_net, vae, vocoder, latents_scale)."""
    from TARO.models import MMDiT
    from diffusers import AutoencoderKL
    from transformers import SpeechT5HifiGan
    net = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
    ckpt = torch.load(taro_ckpt_path, map_location=device, weights_only=False)
    net.load_state_dict(ckpt["ema"])  # EMA weights are the inference weights
    net.eval().to(weight_dtype)
    vae = AutoencoderKL.from_pretrained("cvssp/audioldm2", subfolder="vae").to(device).eval()
    vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm2", subfolder="vocoder").to(device)
    # AudioLDM2 VAE scale factor, broadcast across the 8 latent channels.
    latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
    return net, vae, vocoder, latents_scale
def _load_taro_feature_extractors(device):
    """Load CAVP + onset extractors. Returns (extract_cavp, onset_model)."""
    from TARO.cavp_util import Extract_CAVP_Features
    from TARO.onset_util import VideoOnsetNet
    extract_cavp = Extract_CAVP_Features(
        device=device, config_path="TARO/cavp/cavp.yaml", ckpt_path=cavp_ckpt_path,
    )
    raw_sd = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
    def _remap(key: str) -> str:
        # Strip the training-time "model." wrapper from the two submodule groups.
        if "model.net.model" in key:
            return key.replace("model.net.model", "net.model")
        if "model.fc." in key:
            return key.replace("model.fc", "fc")
        return key
    onset_model = VideoOnsetNet(pretrained=False).to(device)
    onset_model.load_state_dict({_remap(k): v for k, v in raw_sd.items()})
    onset_model.eval()
    return extract_cavp, onset_model
def _load_mmaudio_models(device, dtype):
    """Load MMAudio net + feature_utils.

    Returns (net, feature_utils, model_cfg, seq_cfg)."""
    from mmaudio.eval_utils import all_model_cfg
    from mmaudio.model.networks import get_my_mmaudio
    from mmaudio.model.utils.features_utils import FeaturesUtils
    cfg = all_model_cfg["large_44k_v2"]
    cfg.model_path = Path(mmaudio_model_path)
    cfg.vae_path = Path(mmaudio_vae_path)
    cfg.synchformer_ckpt = Path(mmaudio_synchformer_path)
    cfg.bigvgan_16k_path = None  # 44 kHz variant — the 16 kHz vocoder is unused
    seq_cfg = cfg.seq_cfg
    net = get_my_mmaudio(cfg.model_name).to(device, dtype).eval()
    net.load_weights(torch.load(cfg.model_path, map_location=device, weights_only=True))
    feature_utils = FeaturesUtils(
        tod_vae_ckpt=str(cfg.vae_path),
        synchformer_ckpt=str(cfg.synchformer_ckpt),
        enable_conditions=True, mode=cfg.mode,
        bigvgan_vocoder_ckpt=None, need_vae_encoder=False,
    ).to(device, dtype).eval()
    return net, feature_utils, cfg, seq_cfg
def _load_hunyuan_model(device, model_size):
    """Load HunyuanFoley model dict + config. Returns (model_dict, cfg)."""
    from hunyuanvideo_foley.utils.model_utils import load_model
    size = model_size.lower()
    configs = {
        "xl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xl.yaml",
        "xxl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xxl.yaml",
    }
    cfg_path = configs.get(size, configs["xxl"])  # unknown sizes default to XXL
    weights_dir = str(HUNYUAN_MODEL_DIR / "HunyuanVideo-Foley")
    print(f"[HunyuanFoley] Loading {size.upper()} model from {weights_dir}")
    return load_model(weights_dir, cfg_path, device,
                      enable_offload=False, model_size=size)
def mux_video_audio(silent_video: str, audio_path: str, output_path: str,
                    model: str = None) -> None:
    """Mux a silent video with an audio file into *output_path*.

    HunyuanFoley (*model*="hunyuan") ships its own merge_audio_video that
    handles its specific ffmpeg quirks; every other model gets a plain
    ffmpeg mux with H.264/AAC output."""
    if model == "hunyuan":
        _ensure_syspath("HunyuanVideo-Foley")
        from hunyuanvideo_foley.utils.media_utils import merge_audio_video
        merge_audio_video(audio_path, silent_video, output_path)
        return
    video_in = ffmpeg.input(silent_video)
    audio_in = ffmpeg.input(audio_path)
    ffmpeg.output(
        video_in["v:0"], audio_in["a:0"], output_path,
        vcodec="libx264", preset="fast", crf=18,
        pix_fmt="yuv420p",
        acodec="aac", audio_bitrate="128k",
        movflags="+faststart",
    ).run(overwrite_output=True, quiet=True)
# ------------------------------------------------------------------ #
# Shared sliding-window segmentation and crossfade helpers #
# Used by all three models (TARO, MMAudio, HunyuanFoley). #
# ------------------------------------------------------------------ #
def _build_segments(total_dur_s: float, window_s: float, crossfade_s: float) -> list[tuple[float, float]]:
"""Return list of (start, end) pairs covering *total_dur_s* with a sliding
window of *window_s* and *crossfade_s* overlap between consecutive segments."""
# Safety: clamp crossfade to < half the window so step_s stays positive
crossfade_s = min(crossfade_s, window_s * 0.5)
if total_dur_s <= window_s:
return [(0.0, total_dur_s)]
step_s = window_s - crossfade_s
segments, seg_start = [], 0.0
while True:
if seg_start + window_s >= total_dur_s:
seg_start = max(0.0, total_dur_s - window_s)
segments.append((seg_start, total_dur_s))
break
segments.append((seg_start, seg_start + window_s))
seg_start += step_s
return segments
def _cf_join(a: np.ndarray, b: np.ndarray,
crossfade_s: float, db_boost: float, sr: int) -> np.ndarray:
"""Equal-power crossfade join. Works for both mono (T,) and stereo (C, T) arrays.
Stereo arrays are expected in (channels, samples) layout.
db_boost is applied to the overlap region as a whole (after blending), so
it compensates for the -3 dB equal-power dip without doubling amplitude.
Applying gain to each side independently (the common mistake) causes a
+3 dB loudness bump at the seam β€” this version avoids that."""
stereo = a.ndim == 2
n_a = a.shape[1] if stereo else len(a)
n_b = b.shape[1] if stereo else len(b)
cf = min(int(round(crossfade_s * sr)), n_a, n_b)
if cf <= 0:
return np.concatenate([a, b], axis=1 if stereo else 0)
gain = 10 ** (db_boost / 20.0)
t = np.linspace(0.0, 1.0, cf, dtype=np.float32)
fade_out = np.cos(t * np.pi / 2) # 1 β†’ 0
fade_in = np.sin(t * np.pi / 2) # 0 β†’ 1
if stereo:
# Blend first, then apply boost to the overlap region as a unit
overlap = (a[:, -cf:] * fade_out + b[:, :cf] * fade_in) * gain
return np.concatenate([a[:, :-cf], overlap, b[:, cf:]], axis=1)
else:
overlap = (a[-cf:] * fade_out + b[:cf] * fade_in) * gain
return np.concatenate([a[:-cf], overlap, b[cf:]])
# ================================================================== #
# TARO #
# ================================================================== #
# Constants sourced from TARO/infer.py and TARO/models.py:
# SR=16000, TRUNCATE=131072 β†’ 8.192 s window
# TRUNCATE_FRAME = 4 fps Γ— 131072/16000 = 32 CAVP frames per window
# TRUNCATE_ONSET = 120 onset frames per window
# latent shape: (1, 8, 204, 16) β€” fixed by MMDiT architecture
# latents_scale: [0.18215]*8 β€” AudioLDM2 VAE scale factor
# ================================================================== #
# ================================================================== #
# MODEL CONSTANTS & CONFIGURATION REGISTRY #
# ================================================================== #
# All per-model numeric constants live here β€” MODEL_CONFIGS is the #
# single source of truth consumed by duration estimation, segmentation,#
# and the UI. Standalone names kept only where other code references #
# them by name (TARO geometry, TARGET_SR, GPU_DURATION_CAP). #
# ================================================================== #
# TARO geometry — referenced directly in _taro_infer_segment
TARO_SR = 16000                # TARO's native output sample rate (Hz)
TARO_TRUNCATE = 131072         # samples per generation window (2**17)
TARO_FPS = 4                   # CAVP feature rate (frames per second)
TARO_TRUNCATE_FRAME = int(TARO_FPS * TARO_TRUNCATE / TARO_SR)  # 32 CAVP frames per window
TARO_TRUNCATE_ONSET = 120      # onset feature frames per window
TARO_MODEL_DUR = TARO_TRUNCATE / TARO_SR  # 8.192 s window duration
GPU_DURATION_CAP = 300  # hard cap per @spaces.GPU call — never reserve more than this
MODEL_CONFIGS = {
    "taro": {
        "window_s": TARO_MODEL_DUR,  # 8.192 s
        "sr": TARO_SR,               # 16000 (output resampled to TARGET_SR)
        "secs_per_step": 0.025,      # measured 0.023 s/step on H200
        "load_overhead": 15,         # model load + CAVP feature extraction
        "tab_prefix": "taro",
        "label": "TARO",
        "regen_fn": None,            # set after function definitions (avoids forward-ref)
    },
    "mmaudio": {
        "window_s": 8.0,             # MMAudio's fixed generation window
        "sr": 48000,                 # resampled from 44100 in post-processing
        "secs_per_step": 0.25,       # measured 0.230 s/step on H200
        "load_overhead": 30,         # 15s warm + 15s model init
        "tab_prefix": "mma",
        "label": "MMAudio",
        "regen_fn": None,
    },
    "hunyuan": {
        "window_s": 15.0,            # HunyuanFoley max video duration
        "sr": 48000,
        "secs_per_step": 0.35,       # measured 0.328 s/step on H200
        "load_overhead": 55,         # ~55s to load the 10 GB XXL weights
        "tab_prefix": "hf",
        "label": "HunyuanFoley",
        "regen_fn": None,
    },
}
# Convenience aliases used only in the TARO inference path
TARO_SECS_PER_STEP = MODEL_CONFIGS["taro"]["secs_per_step"]
MMAUDIO_WINDOW = MODEL_CONFIGS["mmaudio"]["window_s"]
MMAUDIO_SECS_PER_STEP = MODEL_CONFIGS["mmaudio"]["secs_per_step"]
HUNYUAN_MAX_DUR = MODEL_CONFIGS["hunyuan"]["window_s"]
HUNYUAN_SECS_PER_STEP = MODEL_CONFIGS["hunyuan"]["secs_per_step"]
def _clamp_duration(secs: float, label: str) -> int:
    """Clamp a raw GPU-seconds estimate into [60, GPU_DURATION_CAP]; log and return it."""
    reserved = int(secs)
    if reserved < 60:
        reserved = 60
    elif reserved > GPU_DURATION_CAP:
        reserved = GPU_DURATION_CAP
    print(f"[duration] {label}: {secs:.0f}s raw → {reserved}s reserved")
    return reserved
def _estimate_gpu_duration(model_key: str, num_samples: int, num_steps: int,
                           total_dur_s: float = None, crossfade_s: float = 0,
                           video_file: str = None) -> int:
    """Estimate GPU seconds for a full generation call.

    Formula: num_samples × n_segs × num_steps × secs_per_step + load_overhead.
    Falls back to a single segment when the video cannot be probed."""
    cfg = MODEL_CONFIGS[model_key]
    try:
        dur = total_dur_s if total_dur_s is not None else get_video_duration(video_file)
        n_segs = len(_build_segments(dur, cfg["window_s"], float(crossfade_s)))
    except Exception:
        n_segs = 1  # probe failed — assume the minimum
    raw = int(num_samples) * n_segs * int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"]
    print(f"[duration] {cfg['label']}: {int(num_samples)}samp × {n_segs}seg × "
          f"{int(num_steps)}steps → {raw:.0f}s → capped ", end="")
    return _clamp_duration(raw, cfg["label"])
def _estimate_regen_duration(model_key: str, num_steps: int) -> int:
    """Estimate GPU seconds for a single-segment regen call."""
    cfg = MODEL_CONFIGS[model_key]
    raw = int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"]
    print(f"[duration] {cfg['label']} regen: 1 seg × {int(num_steps)} steps → ", end="")
    return _clamp_duration(raw, f"{cfg['label']} regen")
_TARO_CACHE_MAXLEN = 16  # evict oldest entries beyond this limit (FIFO order)
_TARO_INFERENCE_CACHE: dict = {}  # keyed by (video_file, seed, cfg, steps, mode, crossfade_s)
_TARO_CACHE_LOCK = threading.Lock()  # guards cache reads/writes across calls
def _taro_calc_max_samples(total_dur_s: float, num_steps: int, crossfade_s: float) -> int:
    """Largest sample count that keeps a full TARO run inside a 600 s GPU
    budget, clamped to [1, MAX_SLOTS]."""
    n_segs = len(_build_segments(total_dur_s, TARO_MODEL_DUR, crossfade_s))
    seg_cost = num_steps * TARO_SECS_PER_STEP
    budgeted = int(600.0 / (n_segs * seg_cost))
    return min(max(budgeted, 1), MAX_SLOTS)
def _taro_duration(video_file, seed_val, cfg_scale, num_steps, mode,
                   crossfade_s, crossfade_db, num_samples):
    """Pre-GPU callable — must match _taro_gpu_infer's input order exactly."""
    return _estimate_gpu_duration(
        "taro", int(num_samples), int(num_steps),
        video_file=video_file, crossfade_s=crossfade_s,
    )
def _taro_infer_segment(
    model, vae, vocoder,
    cavp_feats_full, onset_feats_full,
    seg_start_s: float, seg_end_s: float,
    device, weight_dtype,
    cfg_scale: float, num_steps: int, mode: str,
    latents_scale,
    euler_sampler, euler_maruyama_sampler,
) -> np.ndarray:
    """Single-segment TARO inference. Returns wav array trimmed to segment length.

    Slices the pre-extracted whole-video CAVP and onset features down to this
    segment's window (zero-padding when the segment runs past the feature
    arrays), runs the diffusion sampler on fresh latent noise, then decodes
    latents -> mel -> waveform. mode == "sde" selects the Euler–Maruyama
    sampler; any other value uses the deterministic Euler sampler.
    """
    # CAVP features (4 fps)
    cavp_start = int(round(seg_start_s * TARO_FPS))
    cavp_slice = cavp_feats_full[cavp_start : cavp_start + TARO_TRUNCATE_FRAME]
    if cavp_slice.shape[0] < TARO_TRUNCATE_FRAME:
        # Segment extends past the extracted features — zero-pad to window size.
        pad = np.zeros(
            (TARO_TRUNCATE_FRAME - cavp_slice.shape[0],) + cavp_slice.shape[1:],
            dtype=cavp_slice.dtype,
        )
        cavp_slice = np.concatenate([cavp_slice, pad], axis=0)
    video_feats = torch.from_numpy(cavp_slice).unsqueeze(0).to(device, weight_dtype)
    # Onset features (onset_fps = TRUNCATE_ONSET / MODEL_DUR ≈ 14.65 fps)
    onset_fps = TARO_TRUNCATE_ONSET / TARO_MODEL_DUR
    onset_start = int(round(seg_start_s * onset_fps))
    onset_slice = onset_feats_full[onset_start : onset_start + TARO_TRUNCATE_ONSET]
    if onset_slice.shape[0] < TARO_TRUNCATE_ONSET:
        # Pad spec is 1-D — assumes onset features are a flat per-frame array.
        onset_slice = np.pad(
            onset_slice,
            ((0, TARO_TRUNCATE_ONSET - onset_slice.shape[0]),),
            mode="constant",
        )
    onset_feats_t = torch.from_numpy(onset_slice).unsqueeze(0).to(device, weight_dtype)
    # Latent noise — shape matches MMDiT architecture (in_channels=8, 204×16 spatial)
    z = torch.randn(1, model.in_channels, 204, 16, device=device, dtype=weight_dtype)
    sampling_kwargs = dict(
        model=model,
        latents=z,
        y=onset_feats_t,          # onset conditioning
        context=video_feats,      # CAVP video conditioning
        num_steps=int(num_steps),
        heun=False,
        cfg_scale=float(cfg_scale),
        guidance_low=0.0,
        guidance_high=0.7,
        path_type="linear",
    )
    with torch.no_grad():
        samples = (euler_maruyama_sampler if mode == "sde" else euler_sampler)(**sampling_kwargs)
    # samplers return (output_tensor, zs) — index [0] for the audio latent
    if isinstance(samples, tuple):
        samples = samples[0]
    # Decode: AudioLDM2 VAE → mel → vocoder → waveform
    samples = vae.decode(samples / latents_scale).sample
    wav = vocoder(samples.squeeze().float()).detach().cpu().numpy()
    # Trim the fixed 8.192 s window down to this segment's actual length.
    seg_samples = int(round((seg_end_s - seg_start_s) * TARO_SR))
    return wav[:seg_samples]
# ================================================================== #
# TARO 16 kHz β†’ 48 kHz upsample #
# ================================================================== #
# TARO generates at 16 kHz; all other models output at 44.1/48 kHz.
# We upsample via sinc resampling (torchaudio, CPU-only) so the final
# stitched audio is uniformly at 48 kHz across all three models.
TARGET_SR = 48000  # unified output sample rate for all three models
TARO_SR_OUT = TARGET_SR  # TARO's post-upsample rate equals the shared target
def _resample_to_target(wav: np.ndarray, src_sr: int,
dst_sr: int = None) -> np.ndarray:
"""Resample *wav* (mono or stereo numpy float32) from *src_sr* to *dst_sr*.
*dst_sr* defaults to TARGET_SR (48 kHz). No-op if src_sr == dst_sr.
Uses torchaudio Kaiser-windowed sinc resampling β€” CPU-only, ZeroGPU-safe.
"""
if dst_sr is None:
dst_sr = TARGET_SR
if src_sr == dst_sr:
return wav
stereo = wav.ndim == 2
t = torch.from_numpy(np.ascontiguousarray(wav.astype(np.float32)))
if not stereo:
t = t.unsqueeze(0) # [1, T]
t = torchaudio.functional.resample(t, src_sr, dst_sr)
if not stereo:
t = t.squeeze(0) # [T]
return t.numpy()
def _upsample_taro(wav_16k: np.ndarray) -> np.ndarray:
    """Upsample mono 16 kHz TARO output to 48 kHz via sinc resampling (CPU).

    torchaudio.functional.resample uses a Kaiser-windowed sinc filter —
    mathematically optimal for bandlimited signals, zero CUDA risk.
    Returns a mono float32 numpy array at 48 kHz.
    """
    dur_in = len(wav_16k) / TARO_SR
    print(f"[TARO upsample] {dur_in:.2f}s @ {TARO_SR}Hz → {TARGET_SR}Hz (sinc, CPU) …")
    result = _resample_to_target(wav_16k, TARO_SR)
    # Bug fix: resampling preserves *duration*; the 3× ratio applies to the
    # sample count/rate, not seconds. The old log claimed "expected dur×3 s",
    # which contradicted the measured len(result)/TARGET_SR ≈ dur_in.
    print(f"[TARO upsample] done — {len(result)/TARGET_SR:.2f}s @ {TARGET_SR}Hz "
          f"(expected {dur_in:.2f}s at 3× the sample count)")
    return result
def _stitch_wavs(wavs: list[np.ndarray], crossfade_s: float, db_boost: float,
                 total_dur_s: float, sr: int) -> np.ndarray:
    """Crossfade-join segment wavs in order, then trim to *total_dur_s* seconds.
    Accepts mono (T,) or stereo (C, T) arrays."""
    joined = wavs[0]
    for segment in wavs[1:]:
        joined = _cf_join(joined, segment, crossfade_s, db_boost, sr)
    n_keep = int(round(total_dur_s * sr))
    if joined.ndim == 2:
        return joined[:, :n_keep]
    return joined[:n_keep]
def _save_wav(path: str, wav: np.ndarray, sr: int) -> None:
    """Write a mono (T,) or stereo (C, T) numpy wav array to *path* via torchaudio."""
    tensor = torch.from_numpy(np.ascontiguousarray(wav))
    if tensor.ndim == 1:
        tensor = tensor.unsqueeze(0)  # torchaudio expects (channels, samples)
    torchaudio.save(path, tensor, sr)
def _log_inference_timing(label: str, elapsed: float, n_segs: int,
num_steps: int, constant: float) -> None:
"""Print a standardised inference-timing summary line."""
total_steps = n_segs * num_steps
secs_per_step = elapsed / total_steps if total_steps > 0 else 0
print(f"[{label}] Inference done: {n_segs} seg(s) Γ— {num_steps} steps in "
f"{elapsed:.1f}s wall β†’ {secs_per_step:.3f}s/step "
f"(current constant={constant})")
def _build_seg_meta(*, segments, wav_paths, audio_path, video_path,
silent_video, sr, model, crossfade_s, crossfade_db,
total_dur_s, **extras) -> dict:
"""Build the seg_meta dict shared by all three generate_* functions.
Model-specific keys are passed via **extras."""
meta = {
"segments": segments,
"wav_paths": wav_paths,
"audio_path": audio_path,
"video_path": video_path,
"silent_video": silent_video,
"sr": sr,
"model": model,
"crossfade_s": crossfade_s,
"crossfade_db": crossfade_db,
"total_dur_s": total_dur_s,
}
meta.update(extras)
return meta
def _post_process_samples(results: list, *, model: str, tmp_dir: str,
                          silent_video: str, segments: list,
                          crossfade_s: float, crossfade_db: float,
                          total_dur_s: float, sr: int,
                          extra_meta_fn=None) -> list:
    """Shared CPU post-processing for all three generate_* wrappers.

    Each entry of *results* is a tuple whose first element is the list of
    per-segment wav arrays; the remaining elements are model-specific.
    *extra_meta_fn(sample_idx, result_tuple, tmp_dir) -> dict*, when given,
    supplies model-specific extra keys merged into each seg_meta
    (e.g. cavp_path, onset_path, text_feats_path).
    Returns a list of (video_path, audio_path, seg_meta) tuples."""
    processed = []
    for idx, result in enumerate(results):
        seg_wavs = result[0]
        # Stitch segments into one track, write it, then mux with the video.
        stitched = _stitch_wavs(seg_wavs, crossfade_s, crossfade_db, total_dur_s, sr)
        wav_file = os.path.join(tmp_dir, f"{model}_{idx}.wav")
        _save_wav(wav_file, stitched, sr)
        mp4_file = os.path.join(tmp_dir, f"{model}_{idx}.mp4")
        mux_video_audio(silent_video, wav_file, mp4_file, model=model)
        # Persist per-segment wavs for later single-segment regens.
        seg_paths = _save_seg_wavs(seg_wavs, tmp_dir, f"{model}_{idx}")
        extras = extra_meta_fn(idx, result, tmp_dir) if extra_meta_fn else {}
        meta = _build_seg_meta(
            segments=segments, wav_paths=seg_paths, audio_path=wav_file,
            video_path=mp4_file, silent_video=silent_video, sr=sr,
            model=model, crossfade_s=crossfade_s, crossfade_db=crossfade_db,
            total_dur_s=total_dur_s, **extras,
        )
        processed.append((mp4_file, wav_file, meta))
    return processed
def _cpu_preprocess(video_file: str, model_dur: float,
                    crossfade_s: float) -> tuple:
    """Shared CPU pre-processing for all generate_* wrappers.

    Returns (tmp_dir, silent_video, total_dur_s, segments)."""
    work_dir = _register_tmp_dir(tempfile.mkdtemp())
    muted_path = os.path.join(work_dir, "silent_input.mp4")
    strip_audio_from_video(video_file, muted_path)
    duration = get_video_duration(video_file)
    spans = _build_segments(duration, model_dur, crossfade_s)
    return work_dir, muted_path, duration, spans
@spaces.GPU(duration=_taro_duration)
def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
                    crossfade_s, crossfade_db, num_samples):
    """GPU-only TARO inference β€” model loading + feature extraction + diffusion.

    Reads the CPU-side context stored by generate_taro() under the key
    "taro_gpu_infer" (tmp_dir / silent_video / segments / total_dur_s).
    crossfade_db is unused here but kept so the signature matches the
    @spaces.GPU duration estimator's argument order.

    Returns:
        list with one (wavs, cavp_feats, onset_feats) tuple per sample,
        where *wavs* holds one 16 kHz waveform per segment.
        (The previous docstring claimed (wavs_list, onset_feats) β€” fixed.)
    """
    seed_val = int(seed_val)
    crossfade_s = float(crossfade_s)
    num_samples = int(num_samples)
    if seed_val < 0:
        # Negative seed means "randomize": draw a fresh base seed.
        seed_val = random.randint(0, 2**32 - 1)
    torch.set_grad_enabled(False)
    device, weight_dtype = _get_device_and_dtype()
    _ensure_syspath("TARO")
    from TARO.onset_util import extract_onset
    from TARO.samplers import euler_sampler, euler_maruyama_sampler
    # Context prepared on CPU by generate_taro() before entering GPU scope.
    ctx = _ctx_load("taro_gpu_infer")
    tmp_dir = ctx["tmp_dir"]
    silent_video = ctx["silent_video"]
    segments = ctx["segments"]
    total_dur_s = ctx["total_dur_s"]
    extract_cavp, onset_model = _load_taro_feature_extractors(device)
    cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
    # Onset features depend only on the video β€” extract once for all samples
    onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
    # Free feature extractors before loading the heavier inference models
    del extract_cavp, onset_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    model, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
    results = []  # list of (wavs, cavp_feats, onset_feats) per sample
    for sample_idx in range(num_samples):
        sample_seed = seed_val + sample_idx
        # crossfade_db deliberately excluded: it only affects CPU stitching.
        cache_key = (video_file, sample_seed, float(cfg_scale), int(num_steps), mode, crossfade_s)
        with _TARO_CACHE_LOCK:
            cached = _TARO_INFERENCE_CACHE.get(cache_key)
        if cached is not None:
            print(f"[TARO] Sample {sample_idx+1}: cache hit.")
            # Fix: pass onset_feats (previously None) β€” it is always computed
            # above, and returning it lets the caller persist it to disk even
            # when every sample hits the cache, so segment regeneration can
            # skip the expensive feature re-extraction.
            results.append((cached["wavs"], cavp_feats, onset_feats))
        else:
            set_global_seed(sample_seed)
            wavs = []
            _t_infer_start = time.perf_counter()
            for seg_start_s, seg_end_s in segments:
                print(f"[TARO] Sample {sample_idx+1} | {seg_start_s:.2f}s – {seg_end_s:.2f}s")
                wav = _taro_infer_segment(
                    model, vae, vocoder,
                    cavp_feats, onset_feats,
                    seg_start_s, seg_end_s,
                    device, weight_dtype,
                    cfg_scale, num_steps, mode,
                    latents_scale,
                    euler_sampler, euler_maruyama_sampler,
                )
                wavs.append(wav)
            _log_inference_timing("TARO", time.perf_counter() - _t_infer_start,
                                  len(segments), int(num_steps), TARO_SECS_PER_STEP)
            with _TARO_CACHE_LOCK:
                _TARO_INFERENCE_CACHE[cache_key] = {"wavs": wavs}
                # Evict oldest entries (dicts preserve insertion order).
                while len(_TARO_INFERENCE_CACHE) > _TARO_CACHE_MAXLEN:
                    _TARO_INFERENCE_CACHE.pop(next(iter(_TARO_INFERENCE_CACHE)))
            results.append((wavs, cavp_feats, onset_feats))
        # Free GPU memory between samples so VRAM fragmentation doesn't
        # degrade diffusion quality on samples 2, 3, 4, etc.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results
def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
                  crossfade_s, crossfade_db, num_samples):
    """TARO: video-conditioned diffusion, 16 kHz, 8.192 s sliding window.
    CPU pre/post-processing wraps the GPU-only inference to minimize ZeroGPU cost."""
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)
    num_samples = int(num_samples)
    # ── CPU pre-processing (no GPU needed) ──
    tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
        video_file, TARO_MODEL_DUR, crossfade_s)
    # Hand CPU-side artifacts to the GPU fn via the shared context store
    # (keyed by the GPU function's name) instead of extra arguments.
    _ctx_store("taro_gpu_infer", {
        "tmp_dir": tmp_dir, "silent_video": silent_video,
        "segments": segments, "total_dur_s": total_dur_s,
    })
    # ── GPU inference only ──
    results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
                              crossfade_s, crossfade_db, num_samples)
    # ── CPU post-processing (no GPU needed) ──
    # Upsample 16kHz β†’ 48kHz and normalise result tuples to (seg_wavs, ...)
    cavp_path = os.path.join(tmp_dir, "taro_cavp.npy")
    onset_path = os.path.join(tmp_dir, "taro_onset.npy")
    _feats_saved = False
    def _upsample_and_save_feats(result):
        # Upsample each segment wav and persist CAVP/onset features (once)
        # so later segment regeneration can skip feature re-extraction.
        nonlocal _feats_saved
        wavs, cavp_feats, onset_feats = result
        wavs = [_upsample_taro(w) for w in wavs]
        if not _feats_saved:
            np.save(cavp_path, cavp_feats)
            # onset_feats may be None (e.g. cache-hit samples) β€” save only when present
            if onset_feats is not None:
                np.save(onset_path, onset_feats)
            _feats_saved = True
        return (wavs, cavp_feats, onset_feats)
    results = [_upsample_and_save_feats(r) for r in results]
    def _taro_extras(sample_idx, result, td):
        # Per-sample extra metadata merged into seg_meta by _post_process_samples.
        return {"cavp_path": cavp_path, "onset_path": onset_path}
    outputs = _post_process_samples(
        results, model="taro", tmp_dir=tmp_dir,
        silent_video=silent_video, segments=segments,
        crossfade_s=crossfade_s, crossfade_db=crossfade_db,
        total_dur_s=total_dur_s, sr=TARO_SR_OUT,
        extra_meta_fn=_taro_extras,
    )
    return _pad_outputs(outputs)
# ================================================================== #
# MMAudio #
# ================================================================== #
# Constants sourced from MMAudio/mmaudio/model/sequence_config.py:
# CONFIG_44K: duration=8.0 s, sampling_rate=44100
# CLIP encoder: 8 fps, 384Γ—384 px
# Synchformer: 25 fps, 224Γ—224 px
# Default variant: large_44k_v2
# MMAudio uses flow-matching (FlowMatching with euler inference).
# generate() handles all feature extraction + decoding internally.
# ================================================================== #
def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
                      cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
    """ZeroGPU duration estimator for _mmaudio_gpu_infer.

    Must accept exactly the same positional arguments, in the same order,
    as the @spaces.GPU function it is attached to.
    """
    return _estimate_gpu_duration(
        "mmaudio",
        int(num_samples),
        int(num_steps),
        video_file=video_file,
        crossfade_s=crossfade_s,
    )
@spaces.GPU(duration=_mmaudio_duration)
def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                       cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
    """GPU-only MMAudio inference β€” model loading + flow-matching generation.
    Returns list of (seg_audios, sr) per sample.

    Reads the context stored by generate_mmaudio() under "mmaudio_gpu_infer"
    (segment plan + pre-cut per-segment clips).
    """
    _ensure_syspath("MMAudio")
    from mmaudio.eval_utils import generate, load_video
    from mmaudio.model.flow_matching import FlowMatching
    seed_val = int(seed_val)
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    device, dtype = _get_device_and_dtype()
    net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
    ctx = _ctx_load("mmaudio_gpu_infer")
    segments = ctx["segments"]
    seg_clip_paths = ctx["seg_clip_paths"]
    sr = seq_cfg.sampling_rate  # 44100
    results = []
    for sample_idx in range(num_samples):
        # Per-sample RNG: deterministic offset from the base seed, or OS
        # entropy when seed_val is negative ("randomize").
        rng = torch.Generator(device=device)
        if seed_val >= 0:
            rng.manual_seed(seed_val + sample_idx)
        else:
            rng.seed()
        seg_audios = []
        _t_mma_start = time.perf_counter()
        fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
        for seg_i, (seg_start, seg_end) in enumerate(segments):
            seg_dur = seg_end - seg_start
            seg_path = seg_clip_paths[seg_i]
            video_info = load_video(seg_path, seg_dur)
            clip_frames = video_info.clip_frames.unsqueeze(0)
            sync_frames = video_info.sync_frames.unsqueeze(0)
            actual_dur = video_info.duration_sec
            # load_video may adjust the requested duration β€” resize the
            # model's sequence lengths to the clip's actual duration.
            seq_cfg.duration = actual_dur
            net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
            print(f"[MMAudio] Sample {sample_idx+1} | seg {seg_i+1}/{len(segments)} "
                  f"{seg_start:.1f}–{seg_end:.1f}s | dur={actual_dur:.2f}s | prompt='{prompt}'")
            with torch.no_grad():
                audios = generate(
                    clip_frames,
                    sync_frames,
                    [prompt],
                    negative_text=[negative_prompt] if negative_prompt else None,
                    feature_utils=feature_utils,
                    net=net,
                    fm=fm,
                    rng=rng,
                    cfg_strength=float(cfg_strength),
                )
            wav = audios.float().cpu()[0].numpy()  # (C, T)
            # Trim to the nominal segment length before stitching.
            seg_samples = int(round(seg_dur * sr))
            wav = wav[:, :seg_samples]
            seg_audios.append(wav)
        _log_inference_timing("MMAudio", time.perf_counter() - _t_mma_start,
                              len(segments), int(num_steps), MMAUDIO_SECS_PER_STEP)
        results.append((seg_audios, sr))
        # Free GPU memory between samples to prevent VRAM fragmentation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results
def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
                     cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
    """MMAudio: flow-matching video-to-audio, 44.1 kHz, 8 s sliding window.
    CPU pre/post-processing wraps the GPU-only inference to minimize ZeroGPU cost."""
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)
    # ── CPU pre-processing ──
    tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
        video_file, MMAUDIO_WINDOW, crossfade_s)
    print(f"[MMAudio] Video={total_dur_s:.2f}s | {len(segments)} segment(s) Γ— ≀8 s")
    # Pre-cut one clip per segment (ffmpeg, CPU) for the GPU function.
    seg_clip_paths = []
    for seg_i, (seg_start, seg_end) in enumerate(segments):
        clip_out = os.path.join(tmp_dir, f"mma_seg_{seg_i}.mp4")
        seg_clip_paths.append(
            _extract_segment_clip(silent_video, seg_start, seg_end - seg_start, clip_out))
    _ctx_store("mmaudio_gpu_infer", {"segments": segments, "seg_clip_paths": seg_clip_paths})
    # ── GPU inference only ──
    results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                                 cfg_strength, num_steps, crossfade_s, crossfade_db,
                                 num_samples)
    # ── CPU post-processing ──
    # Resample 44100 β†’ 48000 and normalise tuples to (seg_wavs, ...)
    norm_results = []
    for seg_wavs, sr in results:
        if sr != TARGET_SR:
            print(f"[MMAudio upsample] resampling {sr}Hz β†’ {TARGET_SR}Hz (sinc, CPU) …")
            seg_wavs = [_resample_to_target(w, sr) for w in seg_wavs]
            print(f"[MMAudio upsample] done β€” {len(seg_wavs)} seg(s) @ {TARGET_SR}Hz")
        norm_results.append((seg_wavs,))
    outputs = _post_process_samples(
        norm_results, model="mmaudio", tmp_dir=tmp_dir,
        silent_video=silent_video, segments=segments,
        crossfade_s=crossfade_s, crossfade_db=crossfade_db,
        total_dur_s=total_dur_s, sr=TARGET_SR,
    )
    return _pad_outputs(outputs)
# ================================================================== #
# HunyuanVideoFoley #
# ================================================================== #
# Constants sourced from HunyuanVideo-Foley/hunyuanvideo_foley/constants.py
# and configs/hunyuanvideo-foley-xxl.yaml:
# sample_rate = 48000 Hz (from DAC VAE)
# audio_frame_rate = 50 (latent fps, xxl config)
# max video duration = 15 s
# SigLIP2 fps = 8, Synchformer fps = 25
# CLAP text encoder: laion/larger_clap_general (auto-downloaded from HF Hub)
# Default guidance_scale=4.5, num_inference_steps=50
# ================================================================== #
def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
                      guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
                      num_samples):
    """ZeroGPU duration estimator for _hunyuan_gpu_infer.

    Must accept exactly the same positional arguments, in the same order,
    as the @spaces.GPU function it is attached to.
    """
    return _estimate_gpu_duration(
        "hunyuan",
        int(num_samples),
        int(num_steps),
        video_file=video_file,
        crossfade_s=crossfade_s,
    )
@spaces.GPU(duration=_hunyuan_duration)
def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                       guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
                       num_samples):
    """GPU-only HunyuanFoley inference β€” model loading + feature extraction + denoising.
    Returns list of (seg_wavs, sr, text_feats) per sample.

    Reads the context stored by generate_hunyuan() under "hunyuan_gpu_infer"
    (segment plan, dummy clip for text features, pre-cut segment clips).
    """
    _ensure_syspath("HunyuanVideo-Foley")
    from hunyuanvideo_foley.utils.model_utils import denoise_process
    from hunyuanvideo_foley.utils.feature_utils import feature_process
    seed_val = int(seed_val)
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    if seed_val >= 0:
        # Non-negative seed β†’ reproducible run; negative β†’ leave RNG as-is.
        set_global_seed(seed_val)
    device, _ = _get_device_and_dtype()
    device = torch.device(device)
    model_size = model_size.lower()
    model_dict, cfg = _load_hunyuan_model(device, model_size)
    ctx = _ctx_load("hunyuan_gpu_infer")
    segments = ctx["segments"]
    # NOTE(review): total_dur_s is loaded but not used below β€” kept for parity
    # with the stored context; confirm it can be dropped.
    total_dur_s = ctx["total_dur_s"]
    dummy_seg_path = ctx["dummy_seg_path"]
    seg_clip_paths = ctx["seg_clip_paths"]
    # Text feature extraction (GPU β€” runs once for all segments)
    _, text_feats, _ = feature_process(
        dummy_seg_path,
        prompt if prompt else "",
        model_dict,
        cfg,
        neg_prompt=negative_prompt if negative_prompt else None,
    )
    # Import visual-only feature extractor to avoid redundant text extraction
    # per segment (text_feats already computed once above for the whole batch).
    from hunyuanvideo_foley.utils.feature_utils import encode_video_features
    results = []
    for sample_idx in range(num_samples):
        seg_wavs = []
        sr = 48000  # fallback; overwritten by denoise_process's returned sr
        _t_hny_start = time.perf_counter()
        for seg_i, (seg_start, seg_end) in enumerate(segments):
            seg_dur = seg_end - seg_start
            seg_path = seg_clip_paths[seg_i]
            # Extract only visual features β€” reuse text_feats from above
            visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
            print(f"[HunyuanFoley] Sample {sample_idx+1} | seg {seg_i+1}/{len(segments)} "
                  f"{seg_start:.1f}–{seg_end:.1f}s β†’ {seg_audio_len:.2f}s audio")
            audio_batch, sr = denoise_process(
                visual_feats,
                text_feats,
                seg_audio_len,
                model_dict,
                cfg,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_steps),
                batch_size=1,
            )
            wav = audio_batch[0].float().cpu().numpy()
            # Trim to the nominal segment length before stitching.
            seg_samples = int(round(seg_dur * sr))
            wav = wav[:, :seg_samples]
            seg_wavs.append(wav)
        _log_inference_timing("HunyuanFoley", time.perf_counter() - _t_hny_start,
                              len(segments), int(num_steps), HUNYUAN_SECS_PER_STEP)
        results.append((seg_wavs, sr, text_feats))
        # Free GPU memory between samples to prevent VRAM fragmentation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results
def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
                     guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
    """HunyuanVideoFoley: text-guided foley, 48 kHz, up to 15 s.
    CPU pre/post-processing wraps the GPU-only inference to minimize ZeroGPU cost."""
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)
    # ── CPU pre-processing (no GPU needed) ──
    tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
        video_file, HUNYUAN_MAX_DUR, crossfade_s)
    print(f"[HunyuanFoley] Video={total_dur_s:.2f}s | {len(segments)} segment(s) Γ— ≀15 s")
    # Pre-extract dummy segment for text feature extraction (ffmpeg, CPU)
    dummy_seg_path = _extract_segment_clip(
        silent_video, 0, min(total_dur_s, HUNYUAN_MAX_DUR),
        os.path.join(tmp_dir, "_seg_dummy.mp4"),
    )
    # Pre-extract all segment clips (ffmpeg, CPU)
    seg_clip_paths = [
        _extract_segment_clip(silent_video, s, e - s, os.path.join(tmp_dir, f"hny_seg_{i}.mp4"))
        for i, (s, e) in enumerate(segments)
    ]
    # Hand CPU-side artifacts to the GPU fn via the shared context store.
    _ctx_store("hunyuan_gpu_infer", {
        "segments": segments, "total_dur_s": total_dur_s,
        "dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
    })
    # ── GPU inference only ──
    results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                                 guidance_scale, num_steps, model_size,
                                 crossfade_s, crossfade_db, num_samples)
    # ── CPU post-processing (no GPU needed) ──
    def _hunyuan_extras(sample_idx, result, td):
        # Persist text features so segment regen can skip text encoding.
        # NOTE(review): text_feats looks identical across samples (computed
        # once in the GPU fn), so one shared file would likely suffice β€”
        # confirm before de-duplicating.
        _, _sr, text_feats = result
        path = os.path.join(td, f"hunyuan_{sample_idx}_text_feats.pt")
        torch.save(text_feats, path)
        return {"text_feats_path": path}
    outputs = _post_process_samples(
        results, model="hunyuan", tmp_dir=tmp_dir,
        silent_video=silent_video, segments=segments,
        crossfade_s=crossfade_s, crossfade_db=crossfade_db,
        total_dur_s=total_dur_s, sr=48000,
        extra_meta_fn=_hunyuan_extras,
    )
    return _pad_outputs(outputs)
# ================================================================== #
# SEGMENT REGENERATION HELPERS #
# ================================================================== #
# Each regen function:
# 1. Runs inference for ONE segment (random seed, current settings)
# 2. Splices the new wav into the stored wavs list
# 3. Re-stitches the full track, re-saves .wav and re-muxes .mp4
# 4. Returns (new_video_path, new_audio_path, updated_seg_meta, new_waveform_html)
# ================================================================== #
def _splice_and_save(new_wav, seg_idx, meta, slot_id):
    """Replace wavs[seg_idx] with new_wav, re-stitch, re-save, re-mux.

    Returns (video_path, audio_path, updated_meta, waveform_html).
    """
    seg_wavs = _load_seg_wavs(meta["wav_paths"])
    seg_wavs[seg_idx] = new_wav
    xfade_s = float(meta["crossfade_s"])
    xfade_db = float(meta["crossfade_db"])
    sr = int(meta["sr"])
    total_dur_s = float(meta["total_dur_s"])
    silent_video = meta["silent_video"]
    segments = meta["segments"]
    model = meta["model"]
    stitched = _stitch_wavs(seg_wavs, xfade_s, xfade_db, total_dur_s, sr)
    # Timestamped filenames force Gradio / the browser to treat the outputs
    # as genuinely new files and reload the audio/video players.
    stamp = int(time.time() * 1000)
    out_dir = os.path.dirname(meta["audio_path"])
    audio_stem = os.path.splitext(os.path.basename(meta["audio_path"]))[0]
    # Strip any "_regen_<ts>" suffix left by a previous regeneration.
    audio_stem = audio_stem.rsplit("_regen_", 1)[0]
    audio_path = os.path.join(out_dir, f"{audio_stem}_regen_{stamp}.wav")
    _save_wav(audio_path, stitched, sr)
    video_stem = os.path.splitext(os.path.basename(meta["video_path"]))[0]
    video_stem = video_stem.rsplit("_regen_", 1)[0]
    video_path = os.path.join(out_dir, f"{video_stem}_regen_{stamp}.mp4")
    mux_video_audio(silent_video, audio_path, video_path, model=model)
    # Persist the updated per-segment wavs as .npy files
    updated_wav_paths = _save_seg_wavs(seg_wavs, out_dir, os.path.splitext(audio_stem)[0])
    updated_meta = dict(meta)
    updated_meta["wav_paths"] = updated_wav_paths
    updated_meta["audio_path"] = audio_path
    updated_meta["video_path"] = video_path
    waveform_html = _build_waveform_html(audio_path, segments, slot_id, "",
                                         state_json=json.dumps(updated_meta),
                                         video_path=video_path,
                                         crossfade_s=xfade_s)
    return video_path, audio_path, updated_meta, waveform_html
def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
                         seed_val, cfg_scale, num_steps, mode,
                         crossfade_s, crossfade_db, slot_id=None):
    """ZeroGPU duration estimator for _regen_taro_gpu.

    When cached CAVP/onset feature files exist on disk, the ~10 s
    feature-extractor warm-up is skipped, so a tighter GPU budget is enough.
    """
    try:
        meta = json.loads(seg_meta_json)
        have_cavp = os.path.exists(meta.get("cavp_path", ""))
        have_onset = os.path.exists(meta.get("onset_path", ""))
        if have_cavp and have_onset:
            cfg = MODEL_CONFIGS["taro"]
            secs = int(num_steps) * cfg["secs_per_step"] + 5  # 5s model-load only
            result = min(GPU_DURATION_CAP, max(30, int(secs)))
            print(f"[duration] TARO regen (cache hit): 1 seg Γ— {int(num_steps)} steps β†’ {secs:.0f}s β†’ capped {result}s")
            return result
    except Exception:
        # Malformed/missing meta β€” fall through to the conservative estimate.
        pass
    return _estimate_regen_duration("taro", int(num_steps))
@spaces.GPU(duration=_taro_regen_duration)
def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
                    seed_val, cfg_scale, num_steps, mode,
                    crossfade_s, crossfade_db, slot_id=None):
    """GPU-only TARO regen β€” returns new_wav for a single segment.

    A fresh random seed is always used (seed_val is not consulted here),
    so each regen press yields a different take. CAVP/onset features are
    loaded from the .npy paths in *seg_meta_json* when present, with full
    re-extraction as fallback.
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start_s, seg_end_s = meta["segments"][seg_idx]
    torch.set_grad_enabled(False)
    device, weight_dtype = _get_device_and_dtype()
    _ensure_syspath("TARO")
    from TARO.samplers import euler_sampler, euler_maruyama_sampler
    # Load cached CAVP/onset features from .npy files (CPU I/O, fast, outside GPU budget)
    cavp_path = meta.get("cavp_path", "")
    onset_path = meta.get("onset_path", "")
    if cavp_path and os.path.exists(cavp_path) and onset_path and os.path.exists(onset_path):
        print("[TARO regen] Loading cached CAVP + onset features from disk")
        cavp_feats = np.load(cavp_path)
        onset_feats = np.load(onset_path)
    else:
        print("[TARO regen] Cache miss β€” re-extracting CAVP + onset features")
        from TARO.onset_util import extract_onset
        extract_cavp, onset_model = _load_taro_feature_extractors(device)
        silent_video = meta["silent_video"]
        # NOTE(review): this mkdtemp is not registered for cleanup (unlike
        # _cpu_preprocess's _register_tmp_dir) β€” confirm it is reclaimed,
        # e.g. by the ZeroGPU worker's ephemeral filesystem.
        tmp_dir = tempfile.mkdtemp()
        cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
        onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
        # Free the extractors before loading the heavier inference models
        del extract_cavp, onset_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    model_net, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
    set_global_seed(random.randint(0, 2**32 - 1))
    return _taro_infer_segment(
        model_net, vae, vocoder, cavp_feats, onset_feats,
        seg_start_s, seg_end_s, device, weight_dtype,
        float(cfg_scale), int(num_steps), mode, latents_scale,
        euler_sampler, euler_maruyama_sampler,
    )
def regen_taro_segment(video_file, seg_idx, seg_meta_json,
                       seed_val, cfg_scale, num_steps, mode,
                       crossfade_s, crossfade_db, slot_id):
    """Regenerate one TARO segment. GPU inference + CPU splice/save."""
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    # GPU step β€” CAVP/onset features are loaded from disk paths in seg_meta_json.
    raw_wav = _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
                              seed_val, cfg_scale, num_steps, mode,
                              crossfade_s, crossfade_db, slot_id)
    # CPU step β€” sinc-upsample 16 kHz β†’ 48 kHz, then splice/stitch/mux/save.
    upsampled = _upsample_taro(raw_wav)
    video_path, audio_path, new_meta, waveform_html = _splice_and_save(
        upsampled, seg_idx, meta, slot_id)
    return video_path, audio_path, json.dumps(new_meta), waveform_html
def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
                            prompt, negative_prompt, seed_val,
                            cfg_strength, num_steps, crossfade_s, crossfade_db,
                            slot_id=None):
    """ZeroGPU duration estimator for _regen_mmaudio_gpu (single segment).

    Must accept the same positional arguments, in the same order, as the
    @spaces.GPU function it is attached to.
    """
    return _estimate_regen_duration("mmaudio", int(num_steps))
@spaces.GPU(duration=_mmaudio_regen_duration)
def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
                       prompt, negative_prompt, seed_val,
                       cfg_strength, num_steps, crossfade_s, crossfade_db,
                       slot_id=None):
    """GPU-only MMAudio regen β€” returns (new_wav, sr) for a single segment.

    A fresh random seed is always drawn (seed_val is not consulted), so each
    regen press yields a different take.
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start, seg_end = meta["segments"][seg_idx]
    seg_dur = seg_end - seg_start
    _ensure_syspath("MMAudio")
    from mmaudio.eval_utils import generate, load_video
    from mmaudio.model.flow_matching import FlowMatching
    device, dtype = _get_device_and_dtype()
    net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
    sr = seq_cfg.sampling_rate
    # Extract segment clip inside the GPU function β€” ffmpeg is CPU-only and safe here.
    # This avoids any cross-process context passing that fails under ZeroGPU isolation.
    # NOTE(review): this mkdtemp is not registered for cleanup β€” confirm it is
    # reclaimed (e.g. by the ZeroGPU worker's ephemeral filesystem).
    seg_path = _extract_segment_clip(
        meta["silent_video"], seg_start, seg_dur,
        os.path.join(tempfile.mkdtemp(), "regen_seg.mp4"),
    )
    rng = torch.Generator(device=device)
    rng.manual_seed(random.randint(0, 2**32 - 1))
    fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=int(num_steps))
    video_info = load_video(seg_path, seg_dur)
    clip_frames = video_info.clip_frames.unsqueeze(0)
    sync_frames = video_info.sync_frames.unsqueeze(0)
    actual_dur = video_info.duration_sec
    # Resize the model's sequence lengths to the clip's actual duration.
    seq_cfg.duration = actual_dur
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
    with torch.no_grad():
        audios = generate(
            clip_frames, sync_frames, [prompt],
            negative_text=[negative_prompt] if negative_prompt else None,
            feature_utils=feature_utils, net=net, fm=fm, rng=rng,
            cfg_strength=float(cfg_strength),
        )
    new_wav = audios.float().cpu()[0].numpy()
    # Trim to the nominal segment length before splicing.
    seg_samples = int(round(seg_dur * sr))
    new_wav = new_wav[:, :seg_samples]
    return new_wav, sr
def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
                          prompt, negative_prompt, seed_val,
                          cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id):
    """Regenerate one MMAudio segment. GPU inference + CPU splice/save."""
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    # GPU step (segment clip extraction happens inside the GPU function).
    raw_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
                                     prompt, negative_prompt, seed_val,
                                     cfg_strength, num_steps, crossfade_s,
                                     crossfade_db, slot_id)
    # MMAudio outputs 44.1 kHz β€” bring it up to the app-wide 48 kHz target.
    if sr != TARGET_SR:
        print(f"[MMAudio regen upsample] {sr}Hz β†’ {TARGET_SR}Hz (sinc, CPU) …")
        raw_wav = _resample_to_target(raw_wav, sr)
        sr = TARGET_SR
    meta["sr"] = sr
    # CPU step β€” splice, stitch, mux, save.
    video_path, audio_path, new_meta, waveform_html = _splice_and_save(
        raw_wav, seg_idx, meta, slot_id)
    return video_path, audio_path, json.dumps(new_meta), waveform_html
def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
                            prompt, negative_prompt, seed_val,
                            guidance_scale, num_steps, model_size,
                            crossfade_s, crossfade_db, slot_id=None):
    """ZeroGPU duration estimator for _regen_hunyuan_gpu (single segment).

    Must accept the same positional arguments, in the same order, as the
    @spaces.GPU function it is attached to.
    """
    return _estimate_regen_duration("hunyuan", int(num_steps))
@spaces.GPU(duration=_hunyuan_regen_duration)
def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
                       prompt, negative_prompt, seed_val,
                       guidance_scale, num_steps, model_size,
                       crossfade_s, crossfade_db, slot_id=None):
    """GPU-only HunyuanFoley regen β€” returns (new_wav, sr) for a single segment.

    A fresh random seed is always set (seed_val is not consulted). Cached
    text features from the original run are reused when available so only
    visual features need re-extracting.
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start, seg_end = meta["segments"][seg_idx]
    seg_dur = seg_end - seg_start
    _ensure_syspath("HunyuanVideo-Foley")
    from hunyuanvideo_foley.utils.model_utils import denoise_process
    from hunyuanvideo_foley.utils.feature_utils import feature_process
    device, _ = _get_device_and_dtype()
    device = torch.device(device)
    model_dict, cfg = _load_hunyuan_model(device, model_size)
    set_global_seed(random.randint(0, 2**32 - 1))
    # Extract segment clip inside the GPU function β€” ffmpeg is CPU-only and safe here.
    # NOTE(review): this mkdtemp is not registered for cleanup β€” confirm it is
    # reclaimed (e.g. by the ZeroGPU worker's ephemeral filesystem).
    seg_path = _extract_segment_clip(
        meta["silent_video"], seg_start, seg_dur,
        os.path.join(tempfile.mkdtemp(), "regen_seg.mp4"),
    )
    text_feats_path = meta.get("text_feats_path", "")
    if text_feats_path and os.path.exists(text_feats_path):
        print("[HunyuanFoley regen] Loading cached text features from disk")
        from hunyuanvideo_foley.utils.feature_utils import encode_video_features
        visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
        # weights_only=False: file was written by this app's own torch.save,
        # so arbitrary-object deserialization of a local file is acceptable.
        text_feats = torch.load(text_feats_path, map_location=device, weights_only=False)
    else:
        print("[HunyuanFoley regen] Cache miss β€” extracting text + visual features")
        visual_feats, text_feats, seg_audio_len = feature_process(
            seg_path, prompt if prompt else "", model_dict, cfg,
            neg_prompt=negative_prompt if negative_prompt else None,
        )
    audio_batch, sr = denoise_process(
        visual_feats, text_feats, seg_audio_len, model_dict, cfg,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(num_steps),
        batch_size=1,
    )
    new_wav = audio_batch[0].float().cpu().numpy()
    # Trim to the nominal segment length before splicing.
    seg_samples = int(round(seg_dur * sr))
    new_wav = new_wav[:, :seg_samples]
    return new_wav, sr
def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
                          prompt, negative_prompt, seed_val,
                          guidance_scale, num_steps, model_size,
                          crossfade_s, crossfade_db, slot_id):
    """Regenerate one HunyuanFoley segment. GPU inference + CPU splice/save."""
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    # GPU step (segment clip extraction happens inside the GPU function).
    raw_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
                                     prompt, negative_prompt, seed_val,
                                     guidance_scale, num_steps, model_size,
                                     crossfade_s, crossfade_db, slot_id)
    meta["sr"] = sr
    # CPU step β€” splice, stitch, mux, save.
    video_path, audio_path, new_meta, waveform_html = _splice_and_save(
        raw_wav, seg_idx, meta, slot_id)
    return video_path, audio_path, json.dumps(new_meta), waveform_html
# Wire up regen_fn references now that the functions are defined β€”
# MODEL_CONFIGS is declared earlier in the file (before these callables
# exist), so the regen_fn slots are filled in here, after the defs.
MODEL_CONFIGS["taro"]["regen_fn"] = regen_taro_segment
MODEL_CONFIGS["mmaudio"]["regen_fn"] = regen_mmaudio_segment
MODEL_CONFIGS["hunyuan"]["regen_fn"] = regen_hunyuan_segment
# ================================================================== #
# CROSS-MODEL REGEN WRAPPERS #
# ================================================================== #
# Three shared endpoints β€” one per model β€” that can be called from #
# *any* slot tab. slot_id is passed as plain string data so the #
# result is applied back to the correct slot by the JS listener. #
# The new segment is resampled to match the slot's existing SR before #
# being handed to _splice_and_save, so TARO (16 kHz) / MMAudio #
# (44.1 kHz) / Hunyuan (48 kHz) outputs can all be mixed freely. #
# ================================================================== #
def _resample_to_slot_sr(wav: np.ndarray, src_sr: int, dst_sr: int,
                         slot_wav_ref: np.ndarray = None) -> np.ndarray:
    """Resample *wav* from src_sr to dst_sr and match the slot's channel layout.

    *slot_wav_ref* is the first existing segment of the target slot. TARO
    segments are mono (T,) while MMAudio/Hunyuan are stereo (C, T); mixing
    layouts without normalisation breaks _cf_join with a shape mismatch, so:
      - stereo β†’ mono : average channels
      - mono β†’ stereo: duplicate the single channel
    """
    out = _resample_to_target(wav, src_sr, dst_sr)
    if slot_wav_ref is None:
        return out
    want_stereo = slot_wav_ref.ndim == 2
    have_stereo = out.ndim == 2
    if want_stereo and not have_stereo:
        out = np.stack([out, out], axis=0)  # mono β†’ stereo (C, T)
    elif have_stereo and not want_stereo:
        out = out.mean(axis=0)              # stereo β†’ mono (T,)
    return out
def _xregen_splice(new_wav_raw: np.ndarray, src_sr: int,
                   meta: dict, seg_idx: int, slot_id: str) -> tuple:
    """Shared epilogue for all xregen_* functions: resample β†’ splice β†’ save.
    Returns (video_path, waveform_html)."""
    target_sr = int(meta["sr"])
    existing_wavs = _load_seg_wavs(meta["wav_paths"])
    # Match both sample rate and channel layout of the slot's first segment.
    fitted = _resample_to_slot_sr(new_wav_raw, src_sr, target_sr, existing_wavs[0])
    video_path, _audio_path, _updated_meta, waveform_html = _splice_and_save(
        fitted, seg_idx, meta, slot_id)
    return video_path, waveform_html
def _xregen_dispatch(state_json: str, seg_idx: int, slot_id: str, infer_fn):
    """Shared generator skeleton for all xregen_* wrappers.

    *infer_fn* is a zero-argument callable that runs model-specific CPU prep
    plus GPU inference and returns (wav_array, src_sr). For TARO, *infer_fn*
    should return the wav already upsampled to 48 kHz, with TARO_SR_OUT as
    src_sr.

    Yields twice:
      1. (gr.update(), gr.update(value=pending_html)) β€” shown while GPU runs
      2. (gr.update(value=video_path), gr.update(value=waveform_html))
    """
    meta = json.loads(state_json)
    busy_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
    yield gr.update(), gr.update(value=busy_html)
    wav_raw, src_sr = infer_fn()
    video_path, wave_html = _xregen_splice(wav_raw, src_sr, meta, seg_idx, slot_id)
    yield gr.update(value=video_path), gr.update(value=wave_html)
def xregen_taro(seg_idx, state_json, slot_id,
                seed_val, cfg_scale, num_steps, mode,
                crossfade_s, crossfade_db,
                request: gr.Request = None):
    """Cross-model regen: run TARO inference and splice into *slot_id*."""
    seg_idx = int(seg_idx)

    def _infer():
        # CAVP/onset features are loaded from disk paths inside the GPU fn
        raw = _regen_taro_gpu(None, seg_idx, state_json,
                              seed_val, cfg_scale, num_steps, mode,
                              crossfade_s, crossfade_db, slot_id)
        # 16 kHz β†’ 48 kHz sinc upsample on CPU
        return _upsample_taro(raw), TARO_SR_OUT

    yield from _xregen_dispatch(state_json, seg_idx, slot_id, _infer)
def xregen_mmaudio(seg_idx, state_json, slot_id,
                   prompt, negative_prompt, seed_val,
                   cfg_strength, num_steps, crossfade_s, crossfade_db,
                   request: gr.Request = None):
    """Cross-model regen: run MMAudio inference and splice into *slot_id*."""
    seg_idx = int(seg_idx)

    def _infer():
        # Segment clip extraction happens inside _regen_mmaudio_gpu
        return _regen_mmaudio_gpu(None, seg_idx, state_json,
                                  prompt, negative_prompt, seed_val,
                                  cfg_strength, num_steps,
                                  crossfade_s, crossfade_db, slot_id)

    yield from _xregen_dispatch(state_json, seg_idx, slot_id, _infer)
def xregen_hunyuan(seg_idx, state_json, slot_id,
                   prompt, negative_prompt, seed_val,
                   guidance_scale, num_steps, model_size,
                   crossfade_s, crossfade_db,
                   request: gr.Request = None):
    """Cross-model regen: run HunyuanFoley inference and splice into *slot_id*."""
    seg_idx = int(seg_idx)

    def _infer():
        # Segment clip extraction happens inside _regen_hunyuan_gpu
        return _regen_hunyuan_gpu(None, seg_idx, state_json,
                                  prompt, negative_prompt, seed_val,
                                  guidance_scale, num_steps, model_size,
                                  crossfade_s, crossfade_db, slot_id)

    yield from _xregen_dispatch(state_json, seg_idx, slot_id, _infer)
# ================================================================== #
# SHARED UI HELPERS #
# ================================================================== #
def _register_regen_handlers(tab_prefix, model_key, regen_seg_tb, regen_state_tb,
                             input_components, slot_vids, slot_waves):
    """Register per-slot regen button handlers for a model tab.

    This replaces the three nearly-identical for-loops that previously existed
    for TARO, MMAudio, and HunyuanFoley tabs.

    Args:
        tab_prefix: e.g. "taro", "mma", "hf"
        model_key: e.g. "taro", "mmaudio", "hunyuan"
        regen_seg_tb: gr.Textbox for seg_idx (render=False)
        regen_state_tb: gr.Textbox for state_json (render=False)
        input_components: list of Gradio input components (video, seed, etc.)
            β€” order must match regen_fn signature after (seg_idx, state_json, video)
        slot_vids: list of gr.Video components per slot
        slot_waves: list of gr.HTML components per slot

    Returns:
        list of hidden gr.Buttons (one per slot)
    """
    cfg = MODEL_CONFIGS[model_key]
    regen_fn = cfg["regen_fn"]
    label = cfg["label"]
    btns = []
    for _i in range(MAX_SLOTS):
        _slot_id = f"{tab_prefix}_{_i}"
        _btn = gr.Button(render=False, elem_id=f"regen_btn_{_slot_id}")
        btns.append(_btn)
        print(f"[startup] registering regen handler for slot {_slot_id}")
        # Factory so each button's handler binds the *current* loop values
        # (avoids Python's late-binding-closure pitfall).
        def _make_regen(_si, _sid, _model_key, _label, _regen_fn):
            def _do(seg_idx, state_json, *args):
                # Generator handler: first yield shows the pending UI while
                # the GPU call runs; second yield delivers the new outputs.
                print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} "
                      f"state_json_len={len(state_json) if state_json else 0}")
                if not state_json:
                    print(f"[regen {_label}] early-exit: state_json empty")
                    yield gr.update(), gr.update()
                    return
                # Per-slot lock serialises concurrent regen clicks on one slot.
                lock = _get_slot_lock(_sid)
                with lock:
                    state = json.loads(state_json)
                    pending_html = _build_regen_pending_html(
                        state["segments"], int(seg_idx), _sid, ""
                    )
                    yield gr.update(), gr.update(value=pending_html)
                    print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} β€” calling regen")
                    try:
                        # args[0] = video, args[1:] = model-specific params
                        vid, aud, new_meta_json, html = _regen_fn(
                            args[0], int(seg_idx), state_json, *args[1:], _sid,
                        )
                        print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} β€” done, vid={vid!r}")
                    except Exception as _e:
                        print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} β€” ERROR: {_e}")
                        raise
                    yield gr.update(value=vid), gr.update(value=html)
            return _do
        _btn.click(
            fn=_make_regen(_i, _slot_id, model_key, label, regen_fn),
            inputs=[regen_seg_tb, regen_state_tb] + input_components,
            outputs=[slot_vids[_i], slot_waves[_i]],
            api_name=f"regen_{tab_prefix}_{_i}",
        )
    return btns
def _pad_outputs(outputs: list) -> list:
    """Flatten (video, audio, seg_meta) triples and pad to MAX_SLOTS * 3 with None.
    Each entry in *outputs* must be a (video_path, audio_path, seg_meta) tuple where
    seg_meta = {"segments": [...], "audio_path": str, "video_path": str,
                "sr": int, "model": str, "crossfade_s": float,
                "crossfade_db": float, "wav_paths": list[str]}
    """
    flat = []
    for triple in outputs[:MAX_SLOTS]:
        flat.extend(triple)  # 3 items: video, audio, meta
    # Fill any remaining slots with (None, None, None) placeholders.
    missing = MAX_SLOTS - min(len(outputs), MAX_SLOTS)
    flat.extend([None] * (missing * 3))
    return flat
# ------------------------------------------------------------------ #
# WaveSurfer waveform + segment marker HTML builder #
# ------------------------------------------------------------------ #
def _build_regen_pending_html(segments: list, regen_seg_idx: int, slot_id: str,
                              hidden_input_id: str) -> str:
    """Return a waveform placeholder shown while a segment is being regenerated.

    Renders a dark bar with the active segment highlighted in amber + a spinner.

    Args:
        segments: ordered list of (start_s, end_s) pairs; may be empty.
        regen_seg_idx: index of the segment currently being regenerated.
        slot_id: slot identifier β€” not interpolated into the markup, kept for
            signature parity with _build_waveform_html.
        hidden_input_id: legacy parameter, unused; retained for caller compatibility.

    Returns:
        A self-contained HTML string (inline <style> + absolutely-positioned divs).
    """
    # Fix: removed dead local `segs_json = json.dumps(segments)` β€” it was never used.
    seg_colors = [c.format(a="0.25") for c in SEG_COLORS]
    active_color = "rgba(255,180,0,0.55)"
    # Fall back to 1.0 s for an empty segment list so the percentage math never /0.
    duration = segments[-1][1] if segments else 1.0
    seg_divs = ""
    for i, seg in enumerate(segments):
        # Draw only the non-overlapping (unique) portion of each segment so that
        # overlapping windows don't visually bleed into adjacent segments.
        # Each segment owns the region from its own start up to the next segment's
        # start (or its own end for the final segment).
        seg_start = seg[0]
        seg_end = segments[i + 1][0] if i + 1 < len(segments) else seg[1]
        left_pct = seg_start / duration * 100
        width_pct = (seg_end - seg_start) / duration * 100
        color = active_color if i == regen_seg_idx else seg_colors[i % len(seg_colors)]
        extra = "border:2px solid #ffb300;animation:wf_pulse 0.8s ease-in-out infinite alternate;" if i == regen_seg_idx else ""
        seg_divs += (
            f'<div style="position:absolute;top:0;left:{left_pct:.2f}%;'
            f'width:{width_pct:.2f}%;height:100%;background:{color};{extra}">'
            f'<span style="color:rgba(255,255,255,0.7);font-size:10px;padding:2px 3px;">Seg {i+1}</span>'
            f'</div>'
        )
    spinner = (
        '<div style="position:absolute;top:50%;left:50%;transform:translate(-50%,-50%);'
        'display:flex;align-items:center;gap:6px;">'
        '<div style="width:14px;height:14px;border:2px solid #ffb300;'
        'border-top-color:transparent;border-radius:50%;'
        'animation:wf_spin 0.7s linear infinite;"></div>'
        f'<span style="color:#ffb300;font-size:12px;white-space:nowrap;">'
        f'Regenerating Seg {regen_seg_idx+1}…</span>'
        '</div>'
    )
    return f"""
<style>
@keyframes wf_pulse {{from{{opacity:0.5}}to{{opacity:1}}}}
@keyframes wf_spin {{to{{transform:rotate(360deg)}}}}
</style>
<div style="background:#1a1a1a;border-radius:8px;padding:10px;margin-top:6px;">
<div style="position:relative;width:100%;height:80px;background:#1e1e2e;border-radius:4px;overflow:hidden;">
{seg_divs}
{spinner}
</div>
<div style="color:#888;font-size:11px;margin-top:6px;">Regenerating β€” please wait…</div>
</div>
"""
def _build_waveform_html(audio_path: str, segments: list, slot_id: str,
                         hidden_input_id: str, state_json: str = "",
                         fn_index: int = -1, video_path: str = "",
                         crossfade_s: float = 0.0) -> str:
    """Return a self-contained HTML block with a Canvas waveform (display only),
    segment boundary markers, and a download link.
    Uses Web Audio API + Canvas β€” no external libraries.
    The waveform is SILENT. The playhead tracks the Gradio <video> element
    in the same slot via its timeupdate event.
    """
    # Guard: nothing generated yet (or file deleted) β€” return a placeholder.
    if not audio_path or not os.path.exists(audio_path):
        return "<p style='color:#888;font-size:12px'>No audio yet.</p>"
    # Serve audio via Gradio's file API instead of base64-encoding the entire
    # WAV inline. For a 25s stereo 44.1kHz track this saves ~5 MB per slot.
    audio_url = f"/gradio_api/file={audio_path}"
    # JSON payloads interpolated into the iframe's <script> below.
    segs_json = json.dumps(segments)
    seg_colors = [c.format(a="0.35") for c in SEG_COLORS]
    # NOTE: Gradio updates gr.HTML via innerHTML which does NOT execute <script> tags.
    # Solution: put the entire waveform (canvas + JS) inside an <iframe srcdoc="...">.
    # iframes always execute their scripts. The iframe posts messages to the parent for
    # segment-click events; the parent listens and fires the Gradio regen trigger.
    # For playhead sync, the iframe polls window.parent for a <video> element.
    iframe_inner = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
* {{ margin:0; padding:0; box-sizing:border-box; }}
body {{ background:#1a1a1a; overflow:hidden; }}
#wrap {{ position:relative; width:100%; height:80px; }}
canvas {{ display:block; }}
#cv {{ position:absolute; top:0; left:0; width:100%; height:100%; }}
#cvp {{ position:absolute; top:0; left:0; width:100%; height:100%; pointer-events:none; }}
</style>
</head>
<body>
<div id="wrap">
<canvas id="cv"></canvas>
<canvas id="cvp"></canvas>
</div>
<script>
(function() {{
const SLOT_ID = '{slot_id}';
const segments = {segs_json};
const segColors = {json.dumps(seg_colors)};
const crossfadeSec = {crossfade_s};
let audioDuration = 0;
// ── Popup via postMessage to parent global listener ─────────────────
// The parent page (Gradio) has a global window.addEventListener('message',...)
// set up via gr.Blocks(js=...) that handles popup show/hide and regen trigger.
function showPopup(idx, mx, my) {{
console.log('[wf showPopup] slot='+SLOT_ID+' idx='+idx+' posting to parent');
// Convert iframe-local coords to parent page coords
try {{
const fr = window.frameElement ? window.frameElement.getBoundingClientRect() : {{left:0,top:0}};
window.parent.postMessage({{
type:'wf_popup', action:'show',
slot_id: SLOT_ID, seg_idx: idx,
t0: segments[idx][0], t1: segments[idx][1],
x: mx + fr.left, y: my + fr.top
}}, '*');
console.log('[wf showPopup] postMessage sent OK');
}} catch(e) {{
console.log('[wf showPopup] postMessage fallback, err='+e.message);
window.parent.postMessage({{
type:'wf_popup', action:'show',
slot_id: SLOT_ID, seg_idx: idx,
t0: segments[idx][0], t1: segments[idx][1],
x: mx, y: my
}}, '*');
}}
}}
function hidePopup() {{
window.parent.postMessage({{type:'wf_popup', action:'hide'}}, '*');
}}
// ── Canvas waveform ────────────────────────────────────────────────
const cv = document.getElementById('cv');
const cvp = document.getElementById('cvp');
const wrap= document.getElementById('wrap');
function drawWaveform(channelData, duration) {{
audioDuration = duration;
const dpr = window.devicePixelRatio || 1;
const W = wrap.getBoundingClientRect().width || window.innerWidth || 600;
const H = 80;
cv.width = W * dpr; cv.height = H * dpr;
const ctx = cv.getContext('2d');
ctx.scale(dpr, dpr);
ctx.fillStyle = '#1e1e2e';
ctx.fillRect(0, 0, W, H);
segments.forEach(function(seg, idx) {{
// Color boundary = midpoint of the crossfade zone = where the blend is
// 50/50. This is also where the cut would land if crossfade were 0, and
// where the listener perceptually hears the transition to the next segment.
const x1 = (seg[0] / duration) * W;
const xEnd = idx + 1 < segments.length
? ((segments[idx + 1][0] + crossfadeSec / 2) / duration) * W
: (seg[1] / duration) * W;
ctx.fillStyle = segColors[idx % segColors.length];
ctx.fillRect(x1, 0, xEnd - x1, H);
ctx.fillStyle = 'rgba(255,255,255,0.6)';
ctx.font = '10px sans-serif';
ctx.fillText('Seg '+(idx+1), x1+3, 12);
}});
const samples = channelData.length;
const barW=2, gap=1, step=barW+gap;
const numBars = Math.floor(W / step);
const blockSz = Math.floor(samples / numBars);
ctx.fillStyle = '#4a9eff';
for (let i=0; i<numBars; i++) {{
let max=0;
const s=i*blockSz;
for (let j=0; j<blockSz; j++) {{
const v=Math.abs(channelData[s+j]||0);
if (v>max) max=v;
}}
const barH=Math.max(1, max*H);
ctx.fillRect(i*step, (H-barH)/2, barW, barH);
}}
segments.forEach(function(seg) {{
[seg[0],seg[1]].forEach(function(t) {{
const x=(t/duration)*W;
ctx.strokeStyle='rgba(255,255,255,0.4)';
ctx.lineWidth=1;
ctx.beginPath(); ctx.moveTo(x,0); ctx.lineTo(x,H); ctx.stroke();
}});
}});
// ── Crossfade overlap indicators ──
// The color boundary is at segments[i+1][0] (= seg_i.end - crossfadeSec).
// We centre the hatch on that edge: half the crossfade on each color side.
if (crossfadeSec > 0 && segments.length > 1) {{
for (let i = 0; i < segments.length - 1; i++) {{
// Color edge = segments[i+1][0], hatch spans half on each side
const edgeT = segments[i+1][0];
const overlapStart = edgeT - crossfadeSec / 2;
const overlapEnd = edgeT + crossfadeSec / 2;
const xL = (overlapStart / duration) * W;
const xR = (overlapEnd / duration) * W;
// Diagonal hatch pattern over the overlap zone
ctx.save();
ctx.beginPath();
ctx.rect(xL, 0, xR - xL, H);
ctx.clip();
ctx.strokeStyle = 'rgba(255,255,255,0.35)';
ctx.lineWidth = 1;
const spacing = 6;
for (let lx = xL - H; lx < xR + H; lx += spacing) {{
ctx.beginPath();
ctx.moveTo(lx, H);
ctx.lineTo(lx + H, 0);
ctx.stroke();
}}
ctx.restore();
}}
}}
cv.onclick = function(e) {{
const r=cv.getBoundingClientRect();
const xRel=(e.clientX-r.left)/r.width;
const tClick=xRel*duration;
// Pick the segment whose unique (non-overlapping) region contains the click.
// Each segment owns [seg[0], nextSeg[0]) visually; last segment owns [seg[0], seg[1]].
let hit=-1;
segments.forEach(function(seg,idx){{
const uniqueEnd = idx + 1 < segments.length ? segments[idx+1][0] : seg[1];
if (tClick >= seg[0] && tClick < uniqueEnd) hit = idx;
}});
console.log('[wf click] tClick='+tClick.toFixed(2)+' hit='+hit+' audioDuration='+audioDuration+' segments='+JSON.stringify(segments));
if (hit>=0) showPopup(hit, e.clientX, e.clientY);
else hidePopup();
}};
}}
function drawPlayhead(progress) {{
const dpr = window.devicePixelRatio || 1;
const W = wrap.getBoundingClientRect().width || window.innerWidth || 600;
const H = 80;
if (cvp.width !== W*dpr) {{ cvp.width=W*dpr; cvp.height=H*dpr; }}
const ctx = cvp.getContext('2d');
ctx.clearRect(0,0,W*dpr,H*dpr);
ctx.save();
ctx.scale(dpr,dpr);
const x=progress*W;
ctx.strokeStyle='#fff';
ctx.lineWidth=2;
ctx.beginPath(); ctx.moveTo(x,0); ctx.lineTo(x,H); ctx.stroke();
ctx.restore();
}}
// Poll parent for video time β€” find the video in the same wf_container slot
function findSlotVideo() {{
try {{
const par = window.parent.document;
// Walk up from our iframe to find wf_container_{slot_id}, then find its sibling video
const container = par.getElementById('wf_container_{slot_id}');
if (!container) return par.querySelector('video');
// The video is inside the same gr.Group β€” walk up to find it
let node = container.parentElement;
while (node && node !== par.body) {{
const v = node.querySelector('video');
if (v) return v;
node = node.parentElement;
}}
return null;
}} catch(e) {{ return null; }}
}}
setInterval(function() {{
const vid = findSlotVideo();
if (vid && vid.duration && isFinite(vid.duration) && audioDuration > 0) {{
drawPlayhead(vid.currentTime / vid.duration);
}}
}}, 50);
// ── Fetch + decode audio from Gradio file API ──────────────────────
const audioUrl = '{audio_url}';
fetch(audioUrl)
.then(function(r) {{ return r.arrayBuffer(); }})
.then(function(arrayBuf) {{
const AudioCtx = window.AudioContext || window.webkitAudioContext;
if (!AudioCtx) return;
const tmpCtx = new AudioCtx({{sampleRate:44100}});
tmpCtx.decodeAudioData(arrayBuf,
function(ab) {{
try {{ tmpCtx.close(); }} catch(e) {{}}
function tryDraw() {{
const W = wrap.getBoundingClientRect().width || window.innerWidth;
if (W > 0) {{ drawWaveform(ab.getChannelData(0), ab.duration); }}
else {{ setTimeout(tryDraw, 100); }}
}}
tryDraw();
}},
function(err) {{}}
);
}})
.catch(function(e) {{}});
}})();
</script>
</body>
</html>"""
    # Escape for HTML attribute (srcdoc uses HTML entities)
    srcdoc = _html.escape(iframe_inner, quote=True)
    state_escaped = _html.escape(state_json or "", quote=True)
    # Outer container: carries fn_index/state as data-attributes so client-side
    # JS can call the Gradio queue API directly for regen.
    return f"""
<div id="wf_container_{slot_id}"
data-fn-index="{fn_index}"
data-state="{state_escaped}"
style="background:#1a1a1a;border-radius:8px;padding:10px;margin-top:6px;position:relative;">
<div style="position:relative;width:100%;height:80px;">
<iframe id="wf_iframe_{slot_id}"
srcdoc="{srcdoc}"
sandbox="allow-scripts allow-same-origin"
style="width:100%;height:80px;border:none;border-radius:4px;display:block;"
scrolling="no"></iframe>
</div>
<div style="display:flex;align-items:center;gap:8px;margin-top:6px;">
<span id="wf_statusbar_{slot_id}" style="color:#888;font-size:11px;">Click a segment to regenerate &nbsp;|&nbsp; Playhead syncs to video</span>
<a href="{audio_url}" download="audio_{slot_id}.wav"
style="margin-left:auto;background:#333;color:#eee;border:1px solid #555;
border-radius:4px;padding:3px 10px;font-size:12px;text-decoration:none;">
&#8595; Audio</a>{f'''
<a href="/gradio_api/file={video_path}" download="video_{slot_id}.mp4"
style="background:#333;color:#eee;border:1px solid #555;
border-radius:4px;padding:3px 10px;font-size:12px;text-decoration:none;">
&#8595; Video</a>''' if video_path else ''}
</div>
<div id="wf_seglabel_{slot_id}"
style="color:#aaa;font-size:11px;margin-top:4px;min-height:16px;"></div>
</div>
"""
def _make_output_slots(tab_prefix: str) -> tuple:
    """Build MAX_SLOTS output groups for one tab.
    Each slot has: video and waveform HTML.
    Regen is triggered via direct Gradio queue API calls from JS (no hidden
    trigger textboxes needed β€” DOM event dispatch is unreliable in Gradio 5
    Svelte components). State JSON is embedded in the waveform HTML's
    data-state attribute and passed directly in the queue API payload.
    Returns (grps, vids, waveforms).
    """
    grps = []
    vids = []
    waveforms = []
    for slot_num in range(MAX_SLOTS):
        sid = f"{tab_prefix}_{slot_num}"
        # Only the first slot starts visible; the rest are toggled later.
        with gr.Group(visible=(slot_num == 0)) as grp:
            video_out = gr.Video(label=f"Generation {slot_num+1} β€” Video",
                                 elem_id=f"slot_vid_{sid}",
                                 show_download_button=False)
            wave_out = gr.HTML(
                value="<p style='color:#888;font-size:12px'>Generate audio to see waveform.</p>",
                elem_id=f"slot_wave_{sid}",
            )
        vids.append(video_out)
        waveforms.append(wave_out)
        grps.append(grp)
    return grps, vids, waveforms
def _unpack_outputs(flat: list, n: int, tab_prefix: str) -> list:
    """Turn a flat _pad_outputs list into Gradio update lists.
    flat has MAX_SLOTS * 3 items: [vid0, aud0, meta0, vid1, aud1, meta1, ...]
    Returns updates for vids + waveforms only (NOT grps).
    Group visibility is handled separately via .then() to avoid Gradio 5 SSR
    'Too many arguments' caused by mixing gr.Group updates with other outputs.
    State JSON is embedded in the waveform HTML data-state attribute so JS
    can read it when calling the Gradio queue API for regen.
    """
    n = int(n)  # validated here; slot visibility itself is applied elsewhere
    placeholder = "<p style='color:#888;font-size:12px'>Generate audio to see waveform.</p>"
    vid_updates = []
    wave_updates = []
    for slot_num in range(MAX_SLOTS):
        base = slot_num * 3
        vid_path, aud_path, meta = flat[base], flat[base + 1], flat[base + 2]
        vid_updates.append(gr.update(value=vid_path))
        if not (aud_path and meta):
            wave_updates.append(gr.update(value=placeholder))
            continue
        sid = f"{tab_prefix}_{slot_num}"
        html = _build_waveform_html(aud_path, meta["segments"], sid,
                                    "", state_json=json.dumps(meta),
                                    video_path=meta.get("video_path", ""),
                                    crossfade_s=float(meta.get("crossfade_s", 0)))
        wave_updates.append(gr.update(value=html))
    return vid_updates + wave_updates
def _on_video_upload_taro(video_file, num_steps, crossfade_s):
    """Recompute the sample-count slider ceiling when a TARO video is uploaded."""
    if video_file is None:
        return gr.update(maximum=MAX_SLOTS, value=1)
    try:
        duration = get_video_duration(video_file)
        slot_cap = _taro_calc_max_samples(duration, int(num_steps), float(crossfade_s))
    except Exception:
        # Best-effort: if probing the video fails, fall back to the global ceiling.
        slot_cap = MAX_SLOTS
    return gr.update(maximum=slot_cap, value=min(1, slot_cap))
def _update_slot_visibility(n):
    """Show the first *n* slot groups and hide the remainder."""
    count = int(n)
    updates = []
    for slot_num in range(MAX_SLOTS):
        updates.append(gr.update(visible=slot_num < count))
    return updates
# ================================================================== #
# GRADIO UI #
# ================================================================== #
# Custom CSS injected into gr.Blocks(css=...): responsive slot videos,
# equal-width two-column layout, and hiding the built-in video download button.
_SLOT_CSS = """
/* Responsive video: fills column width, height auto from aspect ratio */
.gradio-video video {
width: 100%;
height: auto;
max-height: 60vh;
object-fit: contain;
}
/* Force two-column layout to stay equal-width */
.gradio-container .gradio-row > .gradio-column {
flex: 1 1 0 !important;
min-width: 0 !important;
max-width: 50% !important;
}
/* Hide the built-in download button on output video slots β€” downloads are
handled by the waveform panel links which always reflect the latest regen. */
[id^="slot_vid_"] .download-icon,
[id^="slot_vid_"] button[aria-label="Download"],
[id^="slot_vid_"] a[download] {
display: none !important;
}
"""
_GLOBAL_JS = """
() => {
// Global postMessage handler for waveform iframe events.
// Runs once on page load (Gradio js= parameter).
// Handles: popup open/close relay, regen trigger via Gradio queue API.
if (window._wf_global_listener) return; // already registered
window._wf_global_listener = true;
// ── ZeroGPU quota attribution ──
// HF Spaces run inside an iframe on huggingface.co. Gradio's own JS client
// gets ZeroGPU auth headers (x-zerogpu-token, x-zerogpu-uuid) by sending a
// postMessage("zerogpu-headers") to the parent frame. The parent responds
// with a Map of headers that must be included on queue/join calls.
// We replicate this exact mechanism so our raw regen fetch() calls are
// attributed to the logged-in user's Pro quota.
function _fetchZerogpuHeaders() {
return new Promise(function(resolve) {
// Check if we're in an HF iframe with zerogpu support
if (typeof window === 'undefined' || window.parent === window || !window.supports_zerogpu_headers) {
console.log('[zerogpu] not in HF iframe or no zerogpu support');
resolve({});
return;
}
// Determine origin β€” same logic as Gradio's client
var hostname = window.location.hostname;
var hfhubdev = 'dev.spaces.huggingface.tech';
var origin = hostname.includes('.dev.')
? 'https://moon-' + hostname.split('.')[1] + '.' + hfhubdev
: 'https://huggingface.co';
// Use MessageChannel just like Gradio's post_message helper
var channel = new MessageChannel();
var done = false;
channel.port1.onmessage = function(ev) {
channel.port1.close();
done = true;
var headers = ev.data;
if (headers && typeof headers === 'object') {
// Convert Map to plain object if needed
var obj = {};
if (typeof headers.forEach === 'function') {
headers.forEach(function(v, k) { obj[k] = v; });
} else {
obj = headers;
}
console.log('[zerogpu] got headers from parent:', Object.keys(obj).join(', '));
resolve(obj);
} else {
resolve({});
}
};
window.parent.postMessage('zerogpu-headers', origin, [channel.port2]);
// Timeout: don't block regen if parent doesn't respond
setTimeout(function() { if (!done) { done = true; channel.port1.close(); resolve({}); } }, 3000);
});
}
// Cache: api_name -> fn_index, built once from gradio_config.dependencies
let _fnIndexCache = null;
function getFnIndex(apiName) {
if (!_fnIndexCache) {
_fnIndexCache = {};
const deps = window.gradio_config && window.gradio_config.dependencies;
if (deps) deps.forEach(function(d, i) {
if (d.api_name) _fnIndexCache[d.api_name] = i;
});
}
return _fnIndexCache[apiName];
}
// Read a component's current DOM value by elem_id.
// For Number/Slider: reads the <input type="number"> or <input type="range">.
// For Textbox/Radio: reads the <textarea> or checked <input type="radio">.
// Returns null if not found.
function readComponentValue(elemId) {
const el = document.getElementById(elemId);
if (!el) return null;
const numInput = el.querySelector('input[type="number"]');
if (numInput) return parseFloat(numInput.value);
const rangeInput = el.querySelector('input[type="range"]');
if (rangeInput) return parseFloat(rangeInput.value);
const radio = el.querySelector('input[type="radio"]:checked');
if (radio) return radio.value;
const ta = el.querySelector('textarea');
if (ta) return ta.value;
const txt = el.querySelector('input[type="text"], input:not([type])');
if (txt) return txt.value;
return null;
}
// Fire regen for a given slot and segment by posting directly to the
// Gradio queue API β€” bypasses Svelte binding entirely.
// targetModel: 'taro' | 'mma' | 'hf' (which model to use for inference)
// If targetModel matches the slot's own prefix, uses the per-slot regen_* endpoint.
// Otherwise uses the shared xregen_* cross-model endpoint.
function fireRegen(slot_id, seg_idx, targetModel) {
// Block if a regen is already in-flight for this slot
if (_regenInFlight[slot_id]) {
console.log('[fireRegen] blocked β€” regen already in-flight for', slot_id);
return;
}
_regenInFlight[slot_id] = true;
const prefix = slot_id.split('_')[0]; // owning tab: 'taro'|'mma'|'hf'
const slotNum = parseInt(slot_id.split('_')[1], 10);
// Decide which endpoint to call
const crossModel = (targetModel !== prefix);
let apiName, data;
// Read state_json from the waveform container data-state attribute
const container = document.getElementById('wf_container_' + slot_id);
const stateJson = container ? (container.getAttribute('data-state') || '') : '';
if (!stateJson) {
console.warn('[fireRegen] no state_json for slot', slot_id);
return;
}
if (!crossModel) {
// ── Same-model regen: per-slot endpoint, video passed as null ──
apiName = 'regen_' + prefix + '_' + slotNum;
if (prefix === 'taro') {
data = [seg_idx, stateJson, null,
readComponentValue('taro_seed'), readComponentValue('taro_cfg'),
readComponentValue('taro_steps'), readComponentValue('taro_mode'),
readComponentValue('taro_cf_dur'), readComponentValue('taro_cf_db')];
} else if (prefix === 'mma') {
data = [seg_idx, stateJson, null,
readComponentValue('mma_prompt'), readComponentValue('mma_neg'),
readComponentValue('mma_seed'), readComponentValue('mma_cfg'),
readComponentValue('mma_steps'),
readComponentValue('mma_cf_dur'), readComponentValue('mma_cf_db')];
} else {
data = [seg_idx, stateJson, null,
readComponentValue('hf_prompt'), readComponentValue('hf_neg'),
readComponentValue('hf_seed'), readComponentValue('hf_guidance'),
readComponentValue('hf_steps'), readComponentValue('hf_size'),
readComponentValue('hf_cf_dur'), readComponentValue('hf_cf_db')];
}
} else {
// ── Cross-model regen: shared xregen_* endpoint ──
// slot_id is passed so the server knows which slot's state to splice into.
// UI params are read from the target model's tab inputs.
if (targetModel === 'taro') {
apiName = 'xregen_taro';
data = [seg_idx, stateJson, slot_id,
readComponentValue('taro_seed'), readComponentValue('taro_cfg'),
readComponentValue('taro_steps'), readComponentValue('taro_mode'),
readComponentValue('taro_cf_dur'), readComponentValue('taro_cf_db')];
} else if (targetModel === 'mma') {
apiName = 'xregen_mmaudio';
data = [seg_idx, stateJson, slot_id,
readComponentValue('mma_prompt'), readComponentValue('mma_neg'),
readComponentValue('mma_seed'), readComponentValue('mma_cfg'),
readComponentValue('mma_steps'),
readComponentValue('mma_cf_dur'), readComponentValue('mma_cf_db')];
} else {
apiName = 'xregen_hunyuan';
data = [seg_idx, stateJson, slot_id,
readComponentValue('hf_prompt'), readComponentValue('hf_neg'),
readComponentValue('hf_seed'), readComponentValue('hf_guidance'),
readComponentValue('hf_steps'), readComponentValue('hf_size'),
readComponentValue('hf_cf_dur'), readComponentValue('hf_cf_db')];
}
}
console.log('[fireRegen] calling api', apiName, 'seg', seg_idx);
// Snapshot current waveform HTML + video src before mutating anything,
// so we can restore on error (e.g. quota exceeded).
var _preRegenWaveHtml = null;
var _preRegenVideoSrc = null;
var waveElSnap = document.getElementById('slot_wave_' + slot_id);
if (waveElSnap) _preRegenWaveHtml = waveElSnap.innerHTML;
var vidElSnap = document.getElementById('slot_vid_' + slot_id);
if (vidElSnap) { var vSnap = vidElSnap.querySelector('video'); if (vSnap) _preRegenVideoSrc = vSnap.getAttribute('src'); }
// Show spinner immediately
const lbl = document.getElementById('wf_seglabel_' + slot_id);
if (lbl) lbl.textContent = 'Regenerating Seg ' + (seg_idx + 1) + '...';
const fnIndex = getFnIndex(apiName);
if (fnIndex === undefined) {
console.warn('[fireRegen] fn_index not found for api_name:', apiName);
return;
}
// Get ZeroGPU auth headers from the HF parent frame (same mechanism
// Gradio's own JS client uses), then fire the regen queue/join call.
// Falls back to user-supplied HF token if zerogpu headers aren't available.
_fetchZerogpuHeaders().then(function(zerogpuHeaders) {
var regenHeaders = {'Content-Type': 'application/json'};
var hasZerogpu = zerogpuHeaders && Object.keys(zerogpuHeaders).length > 0;
if (hasZerogpu) {
// Merge zerogpu headers (x-zerogpu-token, x-zerogpu-uuid)
for (var k in zerogpuHeaders) { regenHeaders[k] = zerogpuHeaders[k]; }
console.log('[fireRegen] using zerogpu headers from parent frame');
} else {
console.warn('[fireRegen] no zerogpu headers available β€” may use anonymous quota');
}
fetch('/gradio_api/queue/join', {
method: 'POST',
credentials: 'include',
headers: regenHeaders,
body: JSON.stringify({
data: data,
fn_index: fnIndex,
api_name: '/' + apiName,
session_hash: window.__gradio_session_hash__,
event_data: null,
trigger_id: null
})
}).then(function(r) { return r.json(); }).then(function(j) {
if (!j.event_id) { console.error('[fireRegen] no event_id:', j); return; }
console.log('[fireRegen] queued, event_id:', j.event_id);
_listenAndApply(j.event_id, slot_id, seg_idx, _preRegenWaveHtml, _preRegenVideoSrc);
}).catch(function(e) {
console.error('[fireRegen] fetch error:', e);
if (lbl) lbl.textContent = 'Error β€” see console';
var sb = document.getElementById('wf_statusbar_' + slot_id);
if (sb) { sb.style.color = '#e05252'; sb.textContent = '\u26a0 Request failed: ' + e.message; }
});
});
}
// Subscribe to Gradio SSE stream for an event and apply outputs to DOM.
// For regen handlers, output[0] = video update, output[1] = waveform HTML update.
function _applyVideoSrc(slot_id, newSrc) {
var vidEl = document.getElementById('slot_vid_' + slot_id);
if (!vidEl) return false;
var video = vidEl.querySelector('video');
if (!video) return false;
if (video.getAttribute('src') === newSrc) return true; // already correct
video.setAttribute('src', newSrc);
video.src = newSrc;
video.load();
console.log('[_applyVideoSrc] applied src to', 'slot_vid_' + slot_id, 'src:', newSrc.slice(-40));
return true;
}
// Toast notification β€” styled like ZeroGPU quota warnings.
function _showRegenToast(message, isError) {
var t = document.createElement('div');
t.style.cssText = 'position:fixed;bottom:24px;left:50%;transform:translateX(-50%);' +
'z-index:2147483647;padding:12px 20px;border-radius:8px;font-family:sans-serif;' +
'font-size:13px;max-width:520px;text-align:center;box-shadow:0 4px 20px rgba(0,0,0,.6);' +
'background:' + (isError ? '#7a1c1c' : '#1c4a1c') + ';color:#fff;' +
'border:1px solid ' + (isError ? '#c0392b' : '#27ae60') + ';' +
'pointer-events:none;';
t.textContent = message;
document.body.appendChild(t);
setTimeout(function() {
t.style.transition = 'opacity 0.5s';
t.style.opacity = '0';
setTimeout(function() { t.parentNode && t.parentNode.removeChild(t); }, 600);
}, isError ? 8000 : 3000);
}
function _listenAndApply(eventId, slot_id, seg_idx, preRegenWaveHtml, preRegenVideoSrc) {
var _pendingVideoSrc = null;
const es = new EventSource('/gradio_api/queue/data?session_hash=' + window.__gradio_session_hash__);
es.onmessage = function(e) {
var msg;
try { msg = JSON.parse(e.data); } catch(_) { return; }
if (msg.event_id !== eventId) return;
if (msg.msg === 'process_generating' || msg.msg === 'process_completed') {
var out = msg.output;
if (out && out.data) {
var vidUpdate = out.data[0];
var waveUpdate = out.data[1];
var newSrc = null;
if (vidUpdate) {
if (vidUpdate.value && vidUpdate.value.video && vidUpdate.value.video.url) newSrc = vidUpdate.value.video.url;
else if (vidUpdate.video && vidUpdate.video.url) newSrc = vidUpdate.video.url;
else if (vidUpdate.value && vidUpdate.value.url) newSrc = vidUpdate.value.url;
else if (typeof vidUpdate.value === 'string') newSrc = vidUpdate.value;
else if (vidUpdate.url) newSrc = vidUpdate.url;
}
if (newSrc) _pendingVideoSrc = newSrc;
var waveHtml = null;
if (waveUpdate) {
if (typeof waveUpdate === 'string') waveHtml = waveUpdate;
else if (waveUpdate.value && typeof waveUpdate.value === 'string') waveHtml = waveUpdate.value;
}
if (waveHtml) {
var waveEl = document.getElementById('slot_wave_' + slot_id);
if (waveEl) {
var inner = waveEl.querySelector('.prose') || waveEl.querySelector('div');
if (inner) inner.innerHTML = waveHtml;
else waveEl.innerHTML = waveHtml;
}
}
}
if (msg.msg === 'process_completed') {
es.close();
_regenInFlight[slot_id] = false;
var errMsg = msg.output && msg.output.error;
var hadError = !!errMsg;
console.log('[fireRegen] completed for', slot_id, 'error:', hadError, errMsg || '');
var lbl = document.getElementById('wf_seglabel_' + slot_id);
if (hadError) {
var toastMsg = typeof errMsg === 'string' ? errMsg : JSON.stringify(errMsg);
// Restore previous waveform HTML and video src
if (preRegenWaveHtml !== null) {
var waveEl2 = document.getElementById('slot_wave_' + slot_id);
if (waveEl2) waveEl2.innerHTML = preRegenWaveHtml;
}
if (preRegenVideoSrc !== null) {
var vidElR = document.getElementById('slot_vid_' + slot_id);
if (vidElR) { var vR = vidElR.querySelector('video'); if (vR) { vR.setAttribute('src', preRegenVideoSrc); vR.src = preRegenVideoSrc; vR.load(); } }
}
// Update the statusbar (query after restore so we get the freshly-restored element)
var isAbort = toastMsg.toLowerCase().indexOf('aborted') !== -1;
var isTimeout = toastMsg.toLowerCase().indexOf('timeout') !== -1;
var failMsg = isAbort || isTimeout
? '\u26a0 GPU cold-start β€” segment unchanged, try again'
: '\u26a0 Regen failed β€” segment unchanged';
var statusBar = document.getElementById('wf_statusbar_' + slot_id);
if (statusBar) {
statusBar.style.color = '#e05252';
statusBar.textContent = failMsg;
setTimeout(function() { statusBar.style.color = '#888'; statusBar.textContent = 'Click a segment to regenerate \u00a0|\u00a0 Playhead syncs to video'; }, 8000);
}
} else {
if (lbl) lbl.textContent = 'Done';
var src = _pendingVideoSrc;
if (src) {
_applyVideoSrc(slot_id, src);
setTimeout(function() { _applyVideoSrc(slot_id, src); }, 50);
setTimeout(function() { _applyVideoSrc(slot_id, src); }, 300);
setTimeout(function() { _applyVideoSrc(slot_id, src); }, 800);
var vidEl = document.getElementById('slot_vid_' + slot_id);
if (vidEl) {
var obs = new MutationObserver(function() { _applyVideoSrc(slot_id, src); });
obs.observe(vidEl, {subtree: true, attributes: true, attributeFilter: ['src'], childList: true});
setTimeout(function() { obs.disconnect(); }, 2000);
}
}
}
}
}
if (msg.msg === 'close_stream') { es.close(); }
};
es.onerror = function() { es.close(); _regenInFlight[slot_id] = false; };
}
// Track in-flight regen per slot β€” prevents queuing multiple jobs from rapid clicks
var _regenInFlight = {};
// Shared popup element created once and reused across all slots
let _popup = null;
let _pendingSlot = null, _pendingIdx = null;
function ensurePopup() {
if (_popup) return _popup;
_popup = document.createElement('div');
_popup.style.cssText = 'display:none;position:fixed;z-index:99999;' +
'background:#2a2a2a;border:1px solid #555;border-radius:6px;' +
'padding:8px 12px;box-shadow:0 4px 16px rgba(0,0,0,.5);font-family:sans-serif;';
var btnStyle = 'color:#fff;border:none;border-radius:4px;padding:5px 10px;' +
'font-size:11px;cursor:pointer;flex:1;';
_popup.innerHTML =
'<div id="_wf_popup_lbl" style="color:#ccc;font-size:11px;margin-bottom:6px;white-space:nowrap;"></div>' +
'<div style="display:flex;gap:5px;">' +
'<button id="_wf_popup_taro" style="background:#1d6fa5;' + btnStyle + '">&#10227; TARO</button>' +
'<button id="_wf_popup_mma" style="background:#2d7a4a;' + btnStyle + '">&#10227; MMAudio</button>' +
'<button id="_wf_popup_hf" style="background:#7a3d8c;' + btnStyle + '">&#10227; Hunyuan</button>' +
'</div>';
document.body.appendChild(_popup);
['taro','mma','hf'].forEach(function(model) {
document.getElementById('_wf_popup_' + model).onclick = function(e) {
e.stopPropagation();
var slot = _pendingSlot, idx = _pendingIdx;
hidePopup();
if (slot !== null && idx !== null) fireRegen(slot, idx, model);
};
});
// Use bubble phase (false) so stopPropagation() on the button click prevents this from firing
document.addEventListener('click', function() { hidePopup(); }, false);
return _popup;
}
function hidePopup() {
if (_popup) _popup.style.display = 'none';
_pendingSlot = null; _pendingIdx = null;
}
window.addEventListener('message', function(e) {
const d = e.data;
console.log('[global msg] received type=' + (d && d.type) + ' action=' + (d && d.action));
if (!d || d.type !== 'wf_popup') return;
const p = ensurePopup();
if (d.action === 'hide') { hidePopup(); return; }
// action === 'show'
_pendingSlot = d.slot_id;
_pendingIdx = d.seg_idx;
const lbl = document.getElementById('_wf_popup_lbl');
if (lbl) lbl.textContent = 'Seg ' + (d.seg_idx + 1) +
' (' + d.t0.toFixed(2) + 's \u2013 ' + d.t1.toFixed(2) + 's)';
p.style.display = 'block';
p.style.left = (d.x + 10) + 'px';
p.style.top = (d.y + 10) + 'px';
requestAnimationFrame(function() {
const r = p.getBoundingClientRect();
if (r.right > window.innerWidth - 8) p.style.left = (window.innerWidth - r.width - 8) + 'px';
if (r.bottom > window.innerHeight - 8) p.style.top = (window.innerHeight - r.height - 8) + 'px';
});
});
}
"""
with gr.Blocks(title="Generate Audio for Video", css=_SLOT_CSS, js=_GLOBAL_JS) as demo:
    # Landing header — a model-selection cheat sheet rendered as a markdown table.
    gr.Markdown(
        "# Generate Audio for Video\n"
        "Choose a model and upload a video to generate synchronized audio.\n\n"
        "| Model | Best for | Avoid for |\n"
        "|-------|----------|-----------|\n"
        "| **TARO** | Natural, physics-driven impacts — footsteps, collisions, water, wind, crackling fire. Excels when the sound is tightly coupled to visible motion without needing a text description. | Dialogue, music, or complex layered soundscapes where semantic context matters. |\n"
        "| **MMAudio** | Mixed scenes where you want both visual grounding *and* semantic control via a text prompt — e.g. a busy street scene where you want to emphasize the rain rather than the traffic. Great for ambient textures and nuanced sound design. | Pure impact/foley shots where TARO's motion-coupling would be sharper, or cinematic music beds. |\n"
        "| **HunyuanFoley** | Cinematic foley requiring high fidelity and explicit creative direction — dramatic SFX, layered environmental design, or any scene where you have a clear written description of the desired sound palette. | Quick one-shot clips where you don't want to write a prompt, or raw impact sounds where timing precision matters more than richness. |"
    )
    with gr.Tabs():
        # ---------------------------------------------------------- #
        # Tab 1 — TARO                                               #
        # ---------------------------------------------------------- #
        with gr.Tab("TARO"):
            with gr.Row():
                with gr.Column(scale=1):
                    # elem_id values are read by _GLOBAL_JS when it builds the
                    # per-segment regen request payload — keep them stable.
                    taro_video = gr.Video(label="Input Video")
                    taro_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0, elem_id="taro_seed")
                    taro_cfg = gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8.0, step=0.5, elem_id="taro_cfg")
                    taro_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1, elem_id="taro_steps")
                    taro_mode = gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde", elem_id="taro_mode")
                    taro_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=4, value=2, step=0.1, elem_id="taro_cf_dur")
                    taro_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3", elem_id="taro_cf_db")
                    taro_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
                    taro_btn = gr.Button("Generate", variant="primary")
                with gr.Column(scale=1):
                    (taro_slot_grps, taro_slot_vids,
                     taro_slot_waves) = _make_output_slots("taro")
            # Hidden regen plumbing — render=False so no DOM element is created,
            # avoiding Gradio's "Too many arguments" Svelte validation error.
            # JS passes values directly via queue/join data array at the correct
            # positional index (these show up as inputs to the fn but have no DOM).
            taro_regen_seg = gr.Textbox(value="0", render=False)
            taro_regen_state = gr.Textbox(value="", render=False)
            # NOTE(review): _on_video_upload_taro is defined elsewhere; it appears
            # to recompute the allowed generation count whenever the video or a
            # cost-driving parameter (steps, crossfade duration) changes — confirm.
            for trigger in [taro_video, taro_steps, taro_cf_dur]:
                trigger.change(
                    fn=_on_video_upload_taro,
                    inputs=[taro_video, taro_steps, taro_cf_dur],
                    outputs=[taro_samples],
                )
            # Show/hide output slot groups so only the requested count is visible.
            taro_samples.change(
                fn=_update_slot_visibility,
                inputs=[taro_samples],
                outputs=taro_slot_grps,
            )
def _run_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n):
flat = generate_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n)
return _unpack_outputs(flat, n, "taro")
            # Split group visibility into a separate .then() to avoid Gradio 5 SSR
            # "Too many arguments" caused by including gr.Group in mixed output lists.
            (taro_btn.click(
                fn=_run_taro,
                inputs=[taro_video, taro_seed, taro_cfg, taro_steps, taro_mode,
                        taro_cf_dur, taro_cf_db, taro_samples],
                outputs=taro_slot_vids + taro_slot_waves,
            ).then(
                fn=_update_slot_visibility,
                inputs=[taro_samples],
                outputs=taro_slot_grps,
            ))
            # Per-slot regen handlers — JS calls /gradio_api/queue/join with
            # fn_index (by api_name) + data=[seg_idx, state_json, video, ...params].
            taro_regen_btns = _register_regen_handlers(
                "taro", "taro", taro_regen_seg, taro_regen_state,
                [taro_video, taro_seed, taro_cfg, taro_steps,
                 taro_mode, taro_cf_dur, taro_cf_db],
                taro_slot_vids, taro_slot_waves,
            )
        # ---------------------------------------------------------- #
        # Tab 2 — MMAudio                                            #
        # ---------------------------------------------------------- #
        with gr.Tab("MMAudio"):
            with gr.Row():
                with gr.Column(scale=1):
                    # elem_id values are read by _GLOBAL_JS for regen payloads.
                    mma_video = gr.Video(label="Input Video")
                    mma_prompt = gr.Textbox(label="Prompt", placeholder="e.g. footsteps on gravel", elem_id="mma_prompt")
                    mma_neg = gr.Textbox(label="Negative Prompt", value="music", placeholder="music, speech", elem_id="mma_neg")
                    mma_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0, elem_id="mma_seed")
                    mma_cfg = gr.Slider(label="CFG Strength", minimum=1, maximum=10, value=4.5, step=0.5, elem_id="mma_cfg")
                    mma_steps = gr.Slider(label="Steps", minimum=10, maximum=50, value=25, step=1, elem_id="mma_steps")
                    mma_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=4, value=2, step=0.1, elem_id="mma_cf_dur")
                    mma_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3", elem_id="mma_cf_db")
                    mma_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
                    mma_btn = gr.Button("Generate", variant="primary")
                with gr.Column(scale=1):
                    (mma_slot_grps, mma_slot_vids,
                     mma_slot_waves) = _make_output_slots("mma")
            # Hidden regen plumbing — render=False so no DOM element is created,
            # avoiding Gradio's "Too many arguments" Svelte validation error.
            mma_regen_seg = gr.Textbox(value="0", render=False)
            mma_regen_state = gr.Textbox(value="", render=False)
            # Show/hide output slot groups so only the requested count is visible.
            mma_samples.change(
                fn=_update_slot_visibility,
                inputs=[mma_samples],
                outputs=mma_slot_grps,
            )
def _run_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n):
flat = generate_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n)
return _unpack_outputs(flat, n, "mma")
            # Same two-stage pattern as TARO: media outputs first, then group
            # visibility in a chained .then() (Gradio 5 SSR workaround).
            (mma_btn.click(
                fn=_run_mmaudio,
                inputs=[mma_video, mma_prompt, mma_neg, mma_seed,
                        mma_cfg, mma_steps, mma_cf_dur, mma_cf_db, mma_samples],
                outputs=mma_slot_vids + mma_slot_waves,
            ).then(
                fn=_update_slot_visibility,
                inputs=[mma_samples],
                outputs=mma_slot_grps,
            ))
            # Per-slot regen handlers (see TARO tab for the JS-side contract).
            mma_regen_btns = _register_regen_handlers(
                "mma", "mmaudio", mma_regen_seg, mma_regen_state,
                [mma_video, mma_prompt, mma_neg, mma_seed,
                 mma_cfg, mma_steps, mma_cf_dur, mma_cf_db],
                mma_slot_vids, mma_slot_waves,
            )
        # ---------------------------------------------------------- #
        # Tab 3 — HunyuanVideoFoley                                  #
        # ---------------------------------------------------------- #
        with gr.Tab("HunyuanFoley"):
            with gr.Row():
                with gr.Column(scale=1):
                    # elem_id values are read by _GLOBAL_JS for regen payloads.
                    hf_video = gr.Video(label="Input Video")
                    hf_prompt = gr.Textbox(label="Prompt", placeholder="e.g. rain hitting a metal roof", elem_id="hf_prompt")
                    hf_neg = gr.Textbox(label="Negative Prompt", value="noisy, harsh", elem_id="hf_neg")
                    hf_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0, elem_id="hf_seed")
                    hf_guidance = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, value=4.5, step=0.5, elem_id="hf_guidance")
                    hf_steps = gr.Slider(label="Steps", minimum=10, maximum=100, value=50, step=5, elem_id="hf_steps")
                    hf_size = gr.Radio(label="Model Size", choices=["xl", "xxl"], value="xxl", elem_id="hf_size")
                    hf_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=4, value=2, step=0.1, elem_id="hf_cf_dur")
                    hf_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3", elem_id="hf_cf_db")
                    hf_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
                    hf_btn = gr.Button("Generate", variant="primary")
                with gr.Column(scale=1):
                    (hf_slot_grps, hf_slot_vids,
                     hf_slot_waves) = _make_output_slots("hf")
            # Hidden regen plumbing — render=False so no DOM element is created,
            # avoiding Gradio's "Too many arguments" Svelte validation error.
            hf_regen_seg = gr.Textbox(value="0", render=False)
            hf_regen_state = gr.Textbox(value="", render=False)
            # Show/hide output slot groups so only the requested count is visible.
            hf_samples.change(
                fn=_update_slot_visibility,
                inputs=[hf_samples],
                outputs=hf_slot_grps,
            )
def _run_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n):
flat = generate_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n)
return _unpack_outputs(flat, n, "hf")
(hf_btn.click(
fn=_run_hunyuan,
inputs=[hf_video, hf_prompt, hf_neg, hf_seed,
hf_guidance, hf_steps, hf_size, hf_cf_dur, hf_cf_db, hf_samples],
outputs=hf_slot_vids + hf_slot_waves,
).then(
fn=_update_slot_visibility,
inputs=[hf_samples],
outputs=hf_slot_grps,
))
hf_regen_btns = _register_regen_handlers(
"hf", "hunyuan", hf_regen_seg, hf_regen_state,
[hf_video, hf_prompt, hf_neg, hf_seed,
hf_guidance, hf_steps, hf_size, hf_cf_dur, hf_cf_db],
hf_slot_vids, hf_slot_waves,
)
# ---- Browser-safe transcode on upload ----
# Gradio serves the original uploaded file to the browser preview widget,
# so H.265 sources show as blank. We re-encode to H.264 on upload and feed
# the result back so the preview plays. mux_video_audio already re-encodes
# to H.264 during generation, so no double-conversion conflict.
taro_video.upload(fn=_transcode_for_browser, inputs=[taro_video], outputs=[taro_video])
mma_video.upload(fn=_transcode_for_browser, inputs=[mma_video], outputs=[mma_video])
hf_video.upload(fn=_transcode_for_browser, inputs=[hf_video], outputs=[hf_video])
# ---- Cross-tab video sync ----
_sync = lambda v: (gr.update(value=v), gr.update(value=v))
taro_video.change(fn=_sync, inputs=[taro_video], outputs=[mma_video, hf_video])
mma_video.change(fn=_sync, inputs=[mma_video], outputs=[taro_video, hf_video])
hf_video.change(fn=_sync, inputs=[hf_video], outputs=[taro_video, mma_video])
# ---- Cross-model regen endpoints ----
# render=False inputs/outputs: no DOM elements created, no SSR validation impact.
# JS calls these via /gradio_api/queue/join using the api_name and applies
# the returned video+waveform directly to the target slot's DOM elements.
_xr_seg = gr.Textbox(value="0", render=False)
_xr_state = gr.Textbox(value="", render=False)
_xr_slot_id = gr.Textbox(value="", render=False)
# Dummy outputs for xregen events: must be real rendered components so Gradio
# can look them up in session state during postprocess_data. The JS listener
# (_listenAndApply) applies the returned video/HTML directly to the correct
# slot's DOM elements and ignores Gradio's own output routing, so these
# slot-0 components simply act as sinks β€” their displayed value is overwritten
# by the real JS update immediately after.
_xr_dummy_vid = taro_slot_vids[0]
_xr_dummy_wave = taro_slot_waves[0]
# TARO cross-model regen inputs: seg_idx, state_json, slot_id, seed, cfg, steps, mode, cf_dur, cf_db
_xr_taro_seed = gr.Textbox(value="-1", render=False)
_xr_taro_cfg = gr.Textbox(value="7.5", render=False)
_xr_taro_steps = gr.Textbox(value="25", render=False)
_xr_taro_mode = gr.Textbox(value="sde", render=False)
_xr_taro_cfd = gr.Textbox(value="2", render=False)
_xr_taro_cfdb = gr.Textbox(value="3", render=False)
gr.Button(render=False).click(
fn=xregen_taro,
inputs=[_xr_seg, _xr_state, _xr_slot_id,
_xr_taro_seed, _xr_taro_cfg, _xr_taro_steps,
_xr_taro_mode, _xr_taro_cfd, _xr_taro_cfdb],
outputs=[_xr_dummy_vid, _xr_dummy_wave],
api_name="xregen_taro",
)
# MMAudio cross-model regen inputs: seg_idx, state_json, slot_id, prompt, neg, seed, cfg, steps, cf_dur, cf_db
_xr_mma_prompt = gr.Textbox(value="", render=False)
_xr_mma_neg = gr.Textbox(value="", render=False)
_xr_mma_seed = gr.Textbox(value="-1", render=False)
_xr_mma_cfg = gr.Textbox(value="4.5", render=False)
_xr_mma_steps = gr.Textbox(value="25", render=False)
_xr_mma_cfd = gr.Textbox(value="2", render=False)
_xr_mma_cfdb = gr.Textbox(value="3", render=False)
gr.Button(render=False).click(
fn=xregen_mmaudio,
inputs=[_xr_seg, _xr_state, _xr_slot_id,
_xr_mma_prompt, _xr_mma_neg, _xr_mma_seed,
_xr_mma_cfg, _xr_mma_steps, _xr_mma_cfd, _xr_mma_cfdb],
outputs=[_xr_dummy_vid, _xr_dummy_wave],
api_name="xregen_mmaudio",
)
# HunyuanFoley cross-model regen inputs: seg_idx, state_json, slot_id, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db
_xr_hf_prompt = gr.Textbox(value="", render=False)
_xr_hf_neg = gr.Textbox(value="", render=False)
_xr_hf_seed = gr.Textbox(value="-1", render=False)
_xr_hf_guide = gr.Textbox(value="4.5", render=False)
_xr_hf_steps = gr.Textbox(value="50", render=False)
_xr_hf_size = gr.Textbox(value="xxl", render=False)
_xr_hf_cfd = gr.Textbox(value="2", render=False)
_xr_hf_cfdb = gr.Textbox(value="3", render=False)
gr.Button(render=False).click(
fn=xregen_hunyuan,
inputs=[_xr_seg, _xr_state, _xr_slot_id,
_xr_hf_prompt, _xr_hf_neg, _xr_hf_seed,
_xr_hf_guide, _xr_hf_steps, _xr_hf_size,
_xr_hf_cfd, _xr_hf_cfdb],
outputs=[_xr_dummy_vid, _xr_dummy_wave],
api_name="xregen_hunyuan",
)
# NOTE: ZeroGPU quota attribution is handled via postMessage("zerogpu-headers")
# to the HF parent frame β€” the same mechanism Gradio's own JS client uses.
# This replaced the old x-ip-token relay approach which was unreliable.
print("[startup] app.py fully loaded β€” regen handlers registered, SSR disabled")
demo.queue(max_size=10).launch(ssr_mode=False, height=900, allowed_paths=["/tmp"])