# BoxOfColors's picture
# Replace FlashSR with sinc resampling for ZeroGPU compatibility
# dbba693
# raw
# history blame
# 124 kB
"""
Generate Audio for Video — multi-model Gradio app.
Supported models
----------------
TARO – video-conditioned diffusion via CAVP + onset features (16 kHz, 8.192 s window)
MMAudio – multimodal flow-matching with CLIP/Synchformer + text prompt (44 kHz, 8 s window)
HunyuanFoley – text-guided foley via SigLIP2 + Synchformer + CLAP (48 kHz, up to 15 s)
"""
import html as _html
import os
import sys
import json
import shutil
import tempfile
import random
import threading
import time
from pathlib import Path
import torch
import numpy as np
import torchaudio
import ffmpeg
import spaces
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
# ================================================================== #
#                     CHECKPOINT CONFIGURATION                       #
# ================================================================== #
# All model weights live in one consolidated HF repo; everything is
# downloaded eagerly at import time so the ZeroGPU workers never hit
# the network (child-process spawning is disallowed there).
CKPT_REPO_ID = "JackIsNotInTheBox/Generate_Audio_for_Video_Checkpoints"
CACHE_DIR = "/tmp/model_ckpts"
os.makedirs(CACHE_DIR, exist_ok=True)
# ---- TARO checkpoints (in TARO/ subfolder of the HF repo) ----
print("Downloading TARO checkpoints…")
cavp_ckpt_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/cavp_epoch66.ckpt", cache_dir=CACHE_DIR)
onset_ckpt_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/onset_model.ckpt", cache_dir=CACHE_DIR)
taro_ckpt_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/taro_ckpt.pt", cache_dir=CACHE_DIR)
print("TARO checkpoints downloaded.")
# ---- MMAudio checkpoints (in MMAudio/ subfolder) ----
# MMAudio normally auto-downloads from its own HF repo, but we
# override the paths so it pulls from our consolidated repo instead.
# NOTE(review): local_dir_use_symlinks is deprecated in recent
# huggingface_hub and ignored when local_dir is set — confirm hub version.
MMAUDIO_WEIGHTS_DIR = Path(CACHE_DIR) / "MMAudio" / "weights"
MMAUDIO_EXT_DIR = Path(CACHE_DIR) / "MMAudio" / "ext_weights"
MMAUDIO_WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
MMAUDIO_EXT_DIR.mkdir(parents=True, exist_ok=True)
print("Downloading MMAudio checkpoints…")
mmaudio_model_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/mmaudio_large_44k_v2.pth", cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_WEIGHTS_DIR), local_dir_use_symlinks=False)
mmaudio_vae_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/v1-44.pth", cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_EXT_DIR), local_dir_use_symlinks=False)
mmaudio_synchformer_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/synchformer_state_dict.pth", cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_EXT_DIR), local_dir_use_symlinks=False)
print("MMAudio checkpoints downloaded.")
# ---- HunyuanVideoFoley checkpoints (in HunyuanFoley/ subfolder) ----
# Paths are not bound to names here; _load_hunyuan_model reads them from
# HUNYUAN_MODEL_DIR on disk.
HUNYUAN_MODEL_DIR = Path(CACHE_DIR) / "HunyuanFoley"
HUNYUAN_MODEL_DIR.mkdir(parents=True, exist_ok=True)
print("Downloading HunyuanVideoFoley checkpoints…")
hf_hub_download(repo_id=CKPT_REPO_ID, filename="HunyuanVideo-Foley/hunyuanvideo_foley.pth", cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR), local_dir_use_symlinks=False)
hf_hub_download(repo_id=CKPT_REPO_ID, filename="HunyuanVideo-Foley/vae_128d_48k.pth", cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR), local_dir_use_symlinks=False)
hf_hub_download(repo_id=CKPT_REPO_ID, filename="HunyuanVideo-Foley/synchformer_state_dict.pth", cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR), local_dir_use_symlinks=False)
print("HunyuanVideoFoley checkpoints downloaded.")
# Pre-download CLAP model so from_pretrained() reads local cache inside the
# ZeroGPU daemonic worker (spawning child processes there is not allowed).
print("Pre-downloading CLAP model (laion/larger_clap_general)…")
snapshot_download(repo_id="laion/larger_clap_general")
print("CLAP model pre-downloaded.")
# ================================================================== #
# SHARED CONSTANTS / HELPERS #
# ================================================================== #
MAX_SLOTS = 8  # max parallel generation slots shown in UI
MAX_SEGS = 8   # max segments per slot (same as MAX_SLOTS; video ≤ ~64 s at 8 s/seg)
# Segment overlay palette — shared between _build_waveform_html and _build_regen_pending_html.
# Each entry is an rgba() template; "{a}" is filled with the desired alpha at render time.
SEG_COLORS = [
    "rgba(100,180,255,{a})", "rgba(255,160,100,{a})",
    "rgba(120,220,140,{a})", "rgba(220,120,220,{a})",
    "rgba(255,220,80,{a})", "rgba(80,220,220,{a})",
    "rgba(255,100,100,{a})", "rgba(180,255,180,{a})",
]
# ------------------------------------------------------------------ #
# Micro-helpers that eliminate repeated boilerplate across the file #
# ------------------------------------------------------------------ #
def _ensure_syspath(subdir: str) -> str:
    """Make *subdir* (relative to this file) importable.

    Prepends the absolute path to sys.path unless it is already listed,
    and returns that absolute path for convenience."""
    abs_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), subdir)
    if abs_dir not in sys.path:
        sys.path.insert(0, abs_dir)
    return abs_dir
def _get_device_and_dtype() -> tuple:
    """Return the (device, weight_dtype) pair shared by all GPU functions.

    Device is "cuda" when a GPU is visible, otherwise "cpu"; the weight
    dtype is always bfloat16."""
    has_cuda = torch.cuda.is_available()
    return ("cuda" if has_cuda else "cpu"), torch.bfloat16
def _extract_segment_clip(silent_video: str, seg_start: float, seg_dur: float,
                          output_path: str) -> str:
    """Stream-copy a segment from *silent_video* to *output_path*. Returns *output_path*."""
    clip = ffmpeg.input(silent_video, ss=seg_start, t=seg_dur)
    clip.output(output_path, vcodec="copy", an=None).run(
        overwrite_output=True, quiet=True)
    return output_path
# Per-slot locks — prevent concurrent regens on the same slot from
# producing a race condition where the second regen reads stale state
# (the shared seg_state textbox hasn't been updated yet by the first regen).
# Locks are keyed by slot_id string (e.g. "taro_0", "mma_2").
_SLOT_LOCKS: dict = {}
_SLOT_LOCKS_MUTEX = threading.Lock()
def _get_slot_lock(slot_id: str) -> threading.Lock:
    """Return the lock for *slot_id*, creating it on first use.

    dict.setdefault under the registry mutex replaces the original
    check-then-insert pair — one lookup, identical semantics."""
    with _SLOT_LOCKS_MUTEX:
        return _SLOT_LOCKS.setdefault(slot_id, threading.Lock())
def set_global_seed(seed: int) -> None:
    """Seed the python, numpy and torch RNGs (including CUDA when present)."""
    random.seed(seed)
    np.random.seed(seed % (2**32))  # numpy requires a uint32 seed
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
def get_random_seed() -> int:
    """Return a uniform random seed in [0, 2**32 - 1]."""
    return random.randrange(2**32)
def get_video_duration(video_path: str) -> float:
    """Return video duration in seconds (CPU only, via ffprobe)."""
    meta = ffmpeg.probe(video_path)
    duration = meta["format"]["duration"]
    return float(duration)
def strip_audio_from_video(video_path: str, output_path: str) -> None:
    """Write a silent copy of *video_path* to *output_path* (stream-copy, no re-encode)."""
    stream = ffmpeg.input(video_path)
    silent = stream.output(output_path, vcodec="copy", an=None)
    silent.run(overwrite_output=True, quiet=True)
# ------------------------------------------------------------------ #
# Temp directory registry — tracks dirs for cleanup on new generation #
# ------------------------------------------------------------------ #
_TEMP_DIRS: list = []   # tmp_dir paths created by generate_*, oldest first
_TEMP_DIRS_MAX = 10     # keep at most this many; older ones get cleaned up
def _register_tmp_dir(tmp_dir: str) -> str:
    """Register *tmp_dir* for later cleanup and return it unchanged.

    Keeps at most _TEMP_DIRS_MAX registered dirs; older ones are deleted
    best-effort. shutil.rmtree(ignore_errors=True) never raises for
    filesystem errors, so the try/except that used to wrap it was dead
    code and has been removed.
    """
    _TEMP_DIRS.append(tmp_dir)
    while len(_TEMP_DIRS) > _TEMP_DIRS_MAX:
        old = _TEMP_DIRS.pop(0)
        shutil.rmtree(old, ignore_errors=True)
        print(f"[cleanup] Removed old temp dir: {old}")
    return tmp_dir
def _save_seg_wavs(wavs: list[np.ndarray], tmp_dir: str, prefix: str) -> list[str]:
    """Persist each segment wav array as a .npy file and return the paths.

    Keeping segment audio on disk avoids serialising large float arrays
    into JSON/HTML data-state."""
    out_paths = []
    for idx, arr in enumerate(wavs):
        target = os.path.join(tmp_dir, f"{prefix}_seg{idx}.npy")
        np.save(target, arr)
        out_paths.append(target)
    return out_paths
def _load_seg_wavs(paths: list[str]) -> list[np.ndarray]:
    """Load segment wav arrays back from their .npy file paths."""
    loaded = []
    for p in paths:
        loaded.append(np.load(p))
    return loaded
# ------------------------------------------------------------------ #
# Shared model-loading helpers (deduplicate generate / regen code) #
# ------------------------------------------------------------------ #
def _load_taro_models(device, weight_dtype):
    """Load the TARO MMDiT diffusion net plus AudioLDM2 VAE/vocoder.

    Returns (model_net, vae, vocoder, latents_scale). Imports are local so
    the heavy deps are only pulled in inside the GPU worker."""
    from TARO.models import MMDiT
    from diffusers import AutoencoderKL
    from transformers import SpeechT5HifiGan
    # Hyper-parameters are fixed by the TARO checkpoint's architecture.
    model_net = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
    # The checkpoint stores an EMA copy of the weights — load that one.
    model_net.load_state_dict(torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"])
    model_net.eval().to(weight_dtype)
    # AudioLDM2 VAE decodes latents → mel; HiFiGAN vocoder renders mel → waveform.
    vae = AutoencoderKL.from_pretrained("cvssp/audioldm2", subfolder="vae").to(device).eval()
    vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm2", subfolder="vocoder").to(device)
    # 0.18215 = AudioLDM2 VAE latent scale factor, broadcast over the 8 latent channels.
    latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
    return model_net, vae, vocoder, latents_scale
def _load_taro_feature_extractors(device):
    """Load the CAVP video encoder and the onset detector.

    Returns (extract_cavp, onset_model). The onset checkpoint was trained
    under a Lightning-style wrapper, so its state-dict keys carry extra
    "model." prefixes that are stripped before loading."""
    from TARO.cavp_util import Extract_CAVP_Features
    from TARO.onset_util import VideoOnsetNet
    extract_cavp = Extract_CAVP_Features(
        device=device, config_path="TARO/cavp/cavp.yaml", ckpt_path=cavp_ckpt_path,
    )
    raw_sd = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
    # Remap wrapper key prefixes to match VideoOnsetNet's module layout.
    onset_sd = {}
    for k, v in raw_sd.items():
        if "model.net.model" in k: k = k.replace("model.net.model", "net.model")
        elif "model.fc." in k: k = k.replace("model.fc", "fc")
        onset_sd[k] = v
    onset_model = VideoOnsetNet(pretrained=False).to(device)
    onset_model.load_state_dict(onset_sd)
    onset_model.eval()
    return extract_cavp, onset_model
def _load_mmaudio_models(device, dtype):
    """Load the MMAudio flow-matching net and its feature extractors.

    Returns (net, feature_utils, model_cfg, seq_cfg). The stock
    "large_44k_v2" config is redirected at our pre-downloaded checkpoint
    paths so MMAudio does not attempt its own HF downloads."""
    from mmaudio.eval_utils import all_model_cfg
    from mmaudio.model.networks import get_my_mmaudio
    from mmaudio.model.utils.features_utils import FeaturesUtils
    model_cfg = all_model_cfg["large_44k_v2"]
    model_cfg.model_path = Path(mmaudio_model_path)
    model_cfg.vae_path = Path(mmaudio_vae_path)
    model_cfg.synchformer_ckpt = Path(mmaudio_synchformer_path)
    model_cfg.bigvgan_16k_path = None  # 44k variant does not use the 16k vocoder
    seq_cfg = model_cfg.seq_cfg
    net = get_my_mmaudio(model_cfg.model_name).to(device, dtype).eval()
    net.load_weights(torch.load(model_cfg.model_path, map_location=device, weights_only=True))
    # need_vae_encoder=False: inference only decodes latents, never encodes audio.
    feature_utils = FeaturesUtils(
        tod_vae_ckpt=str(model_cfg.vae_path),
        synchformer_ckpt=str(model_cfg.synchformer_ckpt),
        enable_conditions=True, mode=model_cfg.mode,
        bigvgan_vocoder_ckpt=None, need_vae_encoder=False,
    ).to(device, dtype).eval()
    return net, feature_utils, model_cfg, seq_cfg
def _load_hunyuan_model(device, model_size):
    """Load the HunyuanVideoFoley model dict + config for *model_size*.

    *model_size* is "xl" or "xxl" (case-insensitive); unknown values fall
    back to the XXL config. Returns whatever load_model returns —
    presumably a (model_dict, cfg) pair per the docstring; confirm against
    hunyuanvideo_foley.utils.model_utils."""
    from hunyuanvideo_foley.utils.model_utils import load_model
    model_size = model_size.lower()
    config_map = {
        "xl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xl.yaml",
        "xxl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xxl.yaml",
    }
    config_path = config_map.get(model_size, config_map["xxl"])
    # Weights were pre-downloaded into HUNYUAN_MODEL_DIR at import time.
    hunyuan_weights_dir = str(HUNYUAN_MODEL_DIR / "HunyuanVideo-Foley")
    print(f"[HunyuanFoley] Loading {model_size.upper()} model from {hunyuan_weights_dir}")
    return load_model(hunyuan_weights_dir, config_path, device,
                      enable_offload=False, model_size=model_size)
def mux_video_audio(silent_video: str, audio_path: str, output_path: str) -> None:
    """Mux a silent video with an audio file into *output_path* (stream-copy video, encode audio)."""
    video_in = ffmpeg.input(silent_video)
    audio_in = ffmpeg.input(audio_path)
    ffmpeg.output(
        video_in, audio_in, output_path,
        vcodec="copy", acodec="aac", strict="experimental",
    ).run(overwrite_output=True, quiet=True)
# ------------------------------------------------------------------ #
# Shared sliding-window segmentation and crossfade helpers #
# Used by all three models (TARO, MMAudio, HunyuanFoley). #
# ------------------------------------------------------------------ #
def _build_segments(total_dur_s: float, window_s: float, crossfade_s: float) -> list[tuple[float, float]]:
    """Return (start, end) pairs covering *total_dur_s* with a sliding
    window of *window_s* and *crossfade_s* overlap between consecutive
    segments. The last segment is right-aligned to the video end so it is
    always a full window (start clamped at 0 for short videos)."""
    # Keep the step positive: crossfade may not exceed half the window.
    crossfade_s = min(crossfade_s, window_s * 0.5)
    if total_dur_s <= window_s:
        return [(0.0, total_dur_s)]
    step_s = window_s - crossfade_s
    out: list[tuple[float, float]] = []
    pos = 0.0
    while pos + window_s < total_dur_s:
        out.append((pos, pos + window_s))
        pos += step_s
    # Final window snapped to the end of the video.
    out.append((max(0.0, total_dur_s - window_s), total_dur_s))
    return out
def _cf_join(a: np.ndarray, b: np.ndarray,
             crossfade_s: float, db_boost: float, sr: int) -> np.ndarray:
    """Equal-power crossfade join of two wav arrays.

    Handles mono (T,) and stereo (C, T) layouts; stereo is expected in
    (channels, samples) order. db_boost is applied to the blended overlap
    as a whole, which compensates the -3 dB equal-power dip without the
    +3 dB seam bump you get when each side is boosted independently."""
    is_2d = (a.ndim == 2)
    axis = 1 if is_2d else 0
    cf = min(int(round(crossfade_s * sr)), a.shape[axis], b.shape[axis])
    if cf <= 0:
        return np.concatenate([a, b], axis=axis)
    gain = 10 ** (db_boost / 20.0)
    ramp = np.linspace(0.0, 1.0, cf, dtype=np.float32)
    down = np.cos(ramp * np.pi / 2)  # 1 → 0
    up = np.sin(ramp * np.pi / 2)    # 0 → 1
    if is_2d:
        # Blend first, then apply boost to the overlap region as a unit.
        blended = (a[:, -cf:] * down + b[:, :cf] * up) * gain
        return np.concatenate([a[:, :-cf], blended, b[:, cf:]], axis=1)
    blended = (a[-cf:] * down + b[:cf] * up) * gain
    return np.concatenate([a[:-cf], blended, b[cf:]])
# ================================================================== #
#                     TARO constants                                 #
# ================================================================== #
# Constants sourced from TARO/infer.py and TARO/models.py:
#   SR=16000, TRUNCATE=131072 → 8.192 s window
#   TRUNCATE_FRAME = 4 fps × 131072/16000 = 32 CAVP frames per window
#   TRUNCATE_ONSET = 120 onset frames per window
#   latent shape: (1, 8, 204, 16) — fixed by MMDiT architecture
#   latents_scale: [0.18215]*8 — AudioLDM2 VAE scale factor
# ================================================================== #
TARO_SR = 16000
TARO_TRUNCATE = 131072
TARO_FPS = 4
TARO_TRUNCATE_FRAME = int(TARO_FPS * TARO_TRUNCATE / TARO_SR)  # 32
TARO_TRUNCATE_ONSET = 120
TARO_MODEL_DUR = TARO_TRUNCATE / TARO_SR  # 8.192 s
TARO_SECS_PER_STEP = 0.05  # measured 0.043s/step on H200 (8.2s video, 2 segs × 25 steps = 2.2s wall)
TARO_LOAD_OVERHEAD = 15  # seconds: model load + CAVP feature extraction
# ---- MMAudio timing/window constants ----
MMAUDIO_WINDOW = 8.0  # seconds — MMAudio's fixed generation window
MMAUDIO_SECS_PER_STEP = 0.25  # measured 0.230s/step on H200 (8.3s video, 2 segs × 25 steps = 11.5s wall)
MMAUDIO_LOAD_OVERHEAD = 15
# ---- HunyuanFoley timing/window constants ----
HUNYUAN_MAX_DUR = 15.0  # seconds — HunyuanFoley max video duration
HUNYUAN_SECS_PER_STEP = 0.35  # measured 0.328s/step on H200 (8.3s video, 1 seg × 50 steps = 16.4s wall)
HUNYUAN_LOAD_OVERHEAD = 55  # ~55s to load the 10GB XXL model weights into GPU
# ZeroGPU reservation hard cap per call — never reserve more than this.
GPU_DURATION_CAP = 300
# ------------------------------------------------------------------ #
# Model configuration registry — single source of truth for per-model #
# constants used by duration estimation, segmentation, and UI.        #
# "regen_fn" entries are populated after the regen functions are      #
# defined, to avoid forward references at module import time.         #
# ------------------------------------------------------------------ #
MODEL_CONFIGS = {
    "taro": {
        "window_s": TARO_MODEL_DUR,            # 8.192 s
        "sr": TARO_SR,                         # 16000
        "secs_per_step": TARO_SECS_PER_STEP,   # 0.05
        "load_overhead": TARO_LOAD_OVERHEAD,   # 15
        "tab_prefix": "taro",
        "regen_fn": None,  # set after function definitions (avoids forward-ref)
        "label": "TARO",
    },
    "mmaudio": {
        "window_s": MMAUDIO_WINDOW,            # 8.0 s
        "sr": 44100,
        "secs_per_step": MMAUDIO_SECS_PER_STEP,  # 0.25
        "load_overhead": MMAUDIO_LOAD_OVERHEAD,  # 15
        "tab_prefix": "mma",
        "regen_fn": None,
        "label": "MMAudio",
    },
    "hunyuan": {
        "window_s": HUNYUAN_MAX_DUR,           # 15.0 s
        "sr": 48000,
        "secs_per_step": HUNYUAN_SECS_PER_STEP,  # 0.35
        "load_overhead": HUNYUAN_LOAD_OVERHEAD,  # 55
        "tab_prefix": "hf",
        "regen_fn": None,
        "label": "HunyuanFoley",
    },
}
def _estimate_gpu_duration(model_key: str, num_samples: int, num_steps: int,
                           total_dur_s: float = None, crossfade_s: float = 0,
                           video_file: str = None) -> int:
    """Generic GPU duration estimator used by all models.

    Computes num_samples × n_segs × num_steps × secs_per_step + load_overhead,
    clamped to [60, GPU_DURATION_CAP]. If the video cannot be probed the
    estimate falls back to a single segment."""
    cfg = MODEL_CONFIGS[model_key]
    try:
        dur = total_dur_s if total_dur_s is not None else get_video_duration(video_file)
        n_segs = len(_build_segments(dur, cfg["window_s"], float(crossfade_s)))
    except Exception:
        n_segs = 1  # probe/segmentation failed — assume one window
    secs = int(num_samples) * n_segs * int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"]
    result = min(GPU_DURATION_CAP, max(60, int(secs)))
    print(f"[duration] {cfg['label']}: {int(num_samples)}samp × {n_segs}seg × "
          f"{int(num_steps)}steps → {secs:.0f}s → capped {result}s")
    return result
def _estimate_regen_duration(model_key: str, num_steps: int) -> int:
    """Generic GPU duration estimator for single-segment regen.

    Uses a lower floor (30s) than initial generation since regen only runs
    one segment — saves 30s of wasted ZeroGPU quota per regen call."""
    cfg = MODEL_CONFIGS[model_key]
    secs = int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"]
    capped = min(GPU_DURATION_CAP, max(30, int(secs)))
    print(f"[duration] {cfg['label']} regen: 1 seg × {int(num_steps)} steps → {secs:.0f}s → capped {capped}s")
    return capped
_TARO_CACHE_MAXLEN = 16  # evict oldest entries beyond this limit
_TARO_INFERENCE_CACHE: dict = {}  # keyed by (video_file, seed, cfg, steps, mode, crossfade_s)
_TARO_CACHE_LOCK = threading.Lock()  # guards cache reads/writes across concurrent requests
def _taro_calc_max_samples(total_dur_s: float, num_steps: int, crossfade_s: float) -> int:
    """How many parallel TARO samples fit in a 600 s GPU budget, capped to MAX_SLOTS."""
    seg_count = len(_build_segments(total_dur_s, TARO_MODEL_DUR, crossfade_s))
    per_seg_secs = num_steps * TARO_SECS_PER_STEP
    budget = int(600.0 / (seg_count * per_seg_secs))
    return max(1, min(budget, MAX_SLOTS))
def _taro_duration(video_file, seed_val, cfg_scale, num_steps, mode,
                   crossfade_s, crossfade_db, num_samples):
    """Pre-GPU callable — must match _taro_gpu_infer's input order exactly."""
    return _estimate_gpu_duration(
        "taro",
        int(num_samples),
        int(num_steps),
        video_file=video_file,
        crossfade_s=crossfade_s,
    )
def _taro_infer_segment(
    model, vae, vocoder,
    cavp_feats_full, onset_feats_full,
    seg_start_s: float, seg_end_s: float,
    device, weight_dtype,
    cfg_scale: float, num_steps: int, mode: str,
    latents_scale,
    euler_sampler, euler_maruyama_sampler,
) -> np.ndarray:
    """Run TARO diffusion for one window and decode it to audio.

    Slices the pre-extracted full-video CAVP and onset features down to
    this segment's window (zero-padding at the tail for the last,
    right-aligned segment), samples a latent with the chosen sampler
    ("sde" → euler_maruyama, anything else → euler), then decodes
    latent → mel (VAE) → waveform (vocoder). Returns the mono 16 kHz wav
    trimmed to the segment length."""
    # CAVP features (4 fps)
    cavp_start = int(round(seg_start_s * TARO_FPS))
    cavp_slice = cavp_feats_full[cavp_start : cavp_start + TARO_TRUNCATE_FRAME]
    if cavp_slice.shape[0] < TARO_TRUNCATE_FRAME:
        # Tail segment may fall short of 32 frames — pad with zeros.
        pad = np.zeros(
            (TARO_TRUNCATE_FRAME - cavp_slice.shape[0],) + cavp_slice.shape[1:],
            dtype=cavp_slice.dtype,
        )
        cavp_slice = np.concatenate([cavp_slice, pad], axis=0)
    video_feats = torch.from_numpy(cavp_slice).unsqueeze(0).to(device, weight_dtype)
    # Onset features (onset_fps = TRUNCATE_ONSET / MODEL_DUR ≈ 14.65 fps)
    onset_fps = TARO_TRUNCATE_ONSET / TARO_MODEL_DUR
    onset_start = int(round(seg_start_s * onset_fps))
    onset_slice = onset_feats_full[onset_start : onset_start + TARO_TRUNCATE_ONSET]
    if onset_slice.shape[0] < TARO_TRUNCATE_ONSET:
        onset_slice = np.pad(
            onset_slice,
            ((0, TARO_TRUNCATE_ONSET - onset_slice.shape[0]),),
            mode="constant",
        )
    onset_feats_t = torch.from_numpy(onset_slice).unsqueeze(0).to(device, weight_dtype)
    # Latent noise — shape matches MMDiT architecture (in_channels=8, 204×16 spatial)
    z = torch.randn(1, model.in_channels, 204, 16, device=device, dtype=weight_dtype)
    sampling_kwargs = dict(
        model=model,
        latents=z,
        y=onset_feats_t,
        context=video_feats,
        num_steps=int(num_steps),
        heun=False,
        cfg_scale=float(cfg_scale),
        guidance_low=0.0,
        guidance_high=0.7,
        path_type="linear",
    )
    with torch.no_grad():
        samples = (euler_maruyama_sampler if mode == "sde" else euler_sampler)(**sampling_kwargs)
    # samplers return (output_tensor, zs) — index [0] for the audio latent
    if isinstance(samples, tuple):
        samples = samples[0]
    # Decode: AudioLDM2 VAE → mel → vocoder → waveform
    samples = vae.decode(samples / latents_scale).sample
    wav = vocoder(samples.squeeze().float()).detach().cpu().numpy()
    # Trim the fixed 8.192 s window down to the actual segment duration.
    seg_samples = int(round((seg_end_s - seg_start_s) * TARO_SR))
    return wav[:seg_samples]
# ================================================================== #
#            16 kHz → 48 kHz upsampling (TARO post-processing)       #
# ================================================================== #
# Applied to TARO outputs only: TARO generates at 16 kHz and is
# upsampled to 48 kHz so all three models emit the same sample rate.
# NOTE: despite the name, FlashSR itself is no longer used — sinc
# resampling replaced it (see the docstring below for why); the
# function name is kept so existing call sites are unchanged.
FLASHSR_SR_IN = 16000   # TARO native sample rate
FLASHSR_SR_OUT = 48000  # common output sample rate across models
def _apply_flashsr(wav_16k: np.ndarray) -> np.ndarray:
    """Upsample a mono 16 kHz numpy array to 48 kHz using sinc resampling (CPU).
    FlashSR was attempted but its dependencies trigger torch.cuda.is_available()
    on import, which violates ZeroGPU's stateless-GPU rule and aborts subsequent
    GPU tasks. High-quality sinc resampling via torchaudio is ZeroGPU-safe and
    produces clean 16→48 kHz output for foley/ambient audio.
    """
    print(f"[upsample] {len(wav_16k)/FLASHSR_SR_IN:.2f}s @ 16kHz → 48kHz (sinc, CPU) …")
    # torchaudio expects a (channels, samples) tensor — add a channel dim.
    t = torch.from_numpy(wav_16k.astype(np.float32)).unsqueeze(0)
    out = torchaudio.functional.resample(t, FLASHSR_SR_IN, FLASHSR_SR_OUT)
    result = out.squeeze().numpy()
    print(f"[upsample] Done — {len(result)/FLASHSR_SR_OUT:.2f}s @ {FLASHSR_SR_OUT}Hz")
    return result
def _stitch_wavs(wavs: list[np.ndarray], crossfade_s: float, db_boost: float,
                 total_dur_s: float, sr: int) -> np.ndarray:
    """Crossfade-join segment wavs and trim the result to *total_dur_s*.
    Works for both mono (T,) and stereo (C, T) arrays."""
    joined = wavs[0]
    for seg in wavs[1:]:
        joined = _cf_join(joined, seg, crossfade_s, db_boost, sr)
    n_keep = int(round(total_dur_s * sr))
    if joined.ndim == 2:
        return joined[:, :n_keep]
    return joined[:n_keep]
def _save_wav(path: str, wav: np.ndarray, sr: int) -> None:
    """Save a numpy wav array (mono or stereo) to *path* via torchaudio."""
    tensor = torch.from_numpy(np.ascontiguousarray(wav))
    if tensor.ndim == 1:
        tensor = tensor.unsqueeze(0)  # torchaudio wants (channels, samples)
    torchaudio.save(path, tensor, sr)
def _log_inference_timing(label: str, elapsed: float, n_segs: int,
                          num_steps: int, constant: float) -> None:
    """Print a standardised inference-timing summary line."""
    steps_total = n_segs * num_steps
    per_step = 0 if steps_total <= 0 else elapsed / steps_total
    summary = (f"[{label}] Inference done: {n_segs} seg(s) × {num_steps} steps in "
               f"{elapsed:.1f}s wall → {per_step:.3f}s/step "
               f"(current constant={constant})")
    print(summary)
def _build_seg_meta(*, segments, wav_paths, audio_path, video_path,
                    silent_video, sr, model, crossfade_s, crossfade_db,
                    total_dur_s, **extras) -> dict:
    """Build the seg_meta dict shared by all three generate_* functions.
    Model-specific keys are passed via **extras and may override the
    common keys, matching dict.update semantics."""
    return {
        "segments": segments,
        "wav_paths": wav_paths,
        "audio_path": audio_path,
        "video_path": video_path,
        "silent_video": silent_video,
        "sr": sr,
        "model": model,
        "crossfade_s": crossfade_s,
        "crossfade_db": crossfade_db,
        "total_dur_s": total_dur_s,
        **extras,
    }
def _cpu_preprocess(video_file: str, model_dur: float,
                    crossfade_s: float) -> tuple:
    """Shared CPU pre-processing for all generate_* wrappers.
    Returns (tmp_dir, silent_video, total_dur_s, segments)."""
    work_dir = _register_tmp_dir(tempfile.mkdtemp())
    muted_path = os.path.join(work_dir, "silent_input.mp4")
    strip_audio_from_video(video_file, muted_path)
    duration = get_video_duration(video_file)
    segs = _build_segments(duration, model_dur, crossfade_s)
    return work_dir, muted_path, duration, segs
@spaces.GPU(duration=_taro_duration)
def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
                    crossfade_s, crossfade_db, num_samples):
    """GPU-only TARO inference — model loading + feature extraction + diffusion.

    Returns a list with one (wavs, cavp_feats, onset_feats) tuple per
    sample; on a cache hit onset_feats is None (the caller only saves the
    features once, from the first freshly-computed sample). Relies on the
    CPU wrapper having populated _taro_gpu_infer._cpu_ctx first."""
    seed_val = int(seed_val)
    crossfade_s = float(crossfade_s)
    num_samples = int(num_samples)
    if seed_val < 0:
        # Negative seed means "random": draw a fresh base seed.
        seed_val = random.randint(0, 2**32 - 1)
    torch.set_grad_enabled(False)
    device, weight_dtype = _get_device_and_dtype()
    _ensure_syspath("TARO")
    from TARO.onset_util import extract_onset
    from TARO.samplers import euler_sampler, euler_maruyama_sampler
    # Use pre-computed CPU results from the wrapper
    ctx = _taro_gpu_infer._cpu_ctx
    tmp_dir = ctx["tmp_dir"]
    silent_video = ctx["silent_video"]
    segments = ctx["segments"]
    total_dur_s = ctx["total_dur_s"]
    extract_cavp, onset_model = _load_taro_feature_extractors(device)
    cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
    # Onset features depend only on the video — extract once for all samples
    onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
    # Free feature extractors before loading the heavier inference models
    del extract_cavp, onset_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    model, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
    results = []  # one (wavs, cavp_feats, onset_feats) tuple per sample
    for sample_idx in range(num_samples):
        # Each parallel sample gets a deterministic offset of the base seed.
        sample_seed = seed_val + sample_idx
        cache_key = (video_file, sample_seed, float(cfg_scale), int(num_steps), mode, crossfade_s)
        with _TARO_CACHE_LOCK:
            cached = _TARO_INFERENCE_CACHE.get(cache_key)
        if cached is not None:
            print(f"[TARO] Sample {sample_idx+1}: cache hit.")
            # onset_feats=None signals "already saved elsewhere" to the caller.
            results.append((cached["wavs"], cavp_feats, None))
        else:
            set_global_seed(sample_seed)
            wavs = []
            _t_infer_start = time.perf_counter()
            for seg_start_s, seg_end_s in segments:
                print(f"[TARO] Sample {sample_idx+1} | {seg_start_s:.2f}s – {seg_end_s:.2f}s")
                wav = _taro_infer_segment(
                    model, vae, vocoder,
                    cavp_feats, onset_feats,
                    seg_start_s, seg_end_s,
                    device, weight_dtype,
                    cfg_scale, num_steps, mode,
                    latents_scale,
                    euler_sampler, euler_maruyama_sampler,
                )
                wavs.append(wav)
            _log_inference_timing("TARO", time.perf_counter() - _t_infer_start,
                                  len(segments), int(num_steps), TARO_SECS_PER_STEP)
            with _TARO_CACHE_LOCK:
                _TARO_INFERENCE_CACHE[cache_key] = {"wavs": wavs}
                # FIFO eviction: drop oldest insertion once over the limit.
                while len(_TARO_INFERENCE_CACHE) > _TARO_CACHE_MAXLEN:
                    _TARO_INFERENCE_CACHE.pop(next(iter(_TARO_INFERENCE_CACHE)))
            results.append((wavs, cavp_feats, onset_feats))
        # Free GPU memory between samples so VRAM fragmentation doesn't
        # degrade diffusion quality on samples 2, 3, 4, etc.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results
# Attach a context slot for the CPU wrapper to pass pre-computed data
_taro_gpu_infer._cpu_ctx = {}
def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
                  crossfade_s, crossfade_db, num_samples):
    """TARO: video-conditioned diffusion, 16 kHz, 8.192 s sliding window.

    CPU pre/post-processing wraps the GPU-only inference to minimize
    ZeroGPU cost. Per sample: upsample segments to 48 kHz, crossfade-
    stitch, save wav, mux with the silent video, and build seg_meta for
    the per-segment regen UI. Returns _pad_outputs(outputs) — a list of
    (video_path, audio_path, seg_meta) triples padded to the slot count
    (presumably; _pad_outputs is defined later in the file)."""
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)
    num_samples = int(num_samples)
    # ── CPU pre-processing (no GPU needed) ──
    tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
        video_file, TARO_MODEL_DUR, crossfade_s)
    # Pass pre-computed CPU results to the GPU function via context
    _taro_gpu_infer._cpu_ctx = {
        "tmp_dir": tmp_dir, "silent_video": silent_video,
        "segments": segments, "total_dur_s": total_dur_s,
    }
    # ── GPU inference only ──
    results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
                              crossfade_s, crossfade_db, num_samples)
    # ── CPU post-processing (no GPU needed) ──
    # Cache CAVP + onset features once (same for all samples — they depend only on the video)
    cavp_path = os.path.join(tmp_dir, "taro_cavp.npy")
    onset_path = os.path.join(tmp_dir, "taro_onset.npy")
    first_cavp_saved = False
    outputs = []
    for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
        # Upsample each segment 16kHz → 48kHz (CPU-only sinc resampling)
        wavs = [_apply_flashsr(w) for w in wavs]
        final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, FLASHSR_SR_OUT)
        audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
        _save_wav(audio_path, final_wav, FLASHSR_SR_OUT)
        video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
        mux_video_audio(silent_video, audio_path, video_path)
        # Segment wavs saved at 48 kHz (post-upsample) for the regen UI.
        wav_paths = _save_seg_wavs(wavs, tmp_dir, f"taro_{sample_idx}")
        # Save shared features once (not per-sample — they're identical);
        # onset_feats is None for cache-hit samples, hence the guard.
        if not first_cavp_saved:
            np.save(cavp_path, cavp_feats)
            if onset_feats is not None:
                np.save(onset_path, onset_feats)
            first_cavp_saved = True
        seg_meta = _build_seg_meta(
            segments=segments, wav_paths=wav_paths, audio_path=audio_path,
            video_path=video_path, silent_video=silent_video, sr=FLASHSR_SR_OUT,
            model="taro", crossfade_s=crossfade_s, crossfade_db=crossfade_db,
            total_dur_s=total_dur_s, cavp_path=cavp_path, onset_path=onset_path,
        )
        outputs.append((video_path, audio_path, seg_meta))
    return _pad_outputs(outputs)
# ================================================================== #
# MMAudio #
# ================================================================== #
# Constants sourced from MMAudio/mmaudio/model/sequence_config.py:
# CONFIG_44K: duration=8.0 s, sampling_rate=44100
# CLIP encoder: 8 fps, 384×384 px
# Synchformer: 25 fps, 224×224 px
# Default variant: large_44k_v2
# MMAudio uses flow-matching (FlowMatching with euler inference).
# generate() handles all feature extraction + decoding internally.
# ================================================================== #
def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
                      cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
    """Pre-GPU callable — must match _mmaudio_gpu_infer's input order exactly."""
    return _estimate_gpu_duration(
        "mmaudio",
        int(num_samples),
        int(num_steps),
        video_file=video_file,
        crossfade_s=crossfade_s,
    )
@spaces.GPU(duration=_mmaudio_duration)
def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                       cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
    """GPU-only MMAudio inference — model loading + flow-matching generation.

    Returns a list of (seg_audios, sr) per sample, where seg_audios is a
    list of (C, T) numpy arrays (one per segment) at sr=44100. Relies on
    the CPU wrapper having populated _mmaudio_gpu_infer._cpu_ctx with the
    segment list and per-segment clip paths."""
    _ensure_syspath("MMAudio")
    from mmaudio.eval_utils import generate, load_video
    from mmaudio.model.flow_matching import FlowMatching
    seed_val = int(seed_val)
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    device, dtype = _get_device_and_dtype()
    net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
    ctx = _mmaudio_gpu_infer._cpu_ctx
    segments = ctx["segments"]
    seg_clip_paths = ctx["seg_clip_paths"]
    sr = seq_cfg.sampling_rate  # 44100
    results = []
    for sample_idx in range(num_samples):
        # One generator per sample: deterministic when seed >= 0, random otherwise.
        rng = torch.Generator(device=device)
        if seed_val >= 0:
            rng.manual_seed(seed_val + sample_idx)
        else:
            rng.seed()
        seg_audios = []
        _t_mma_start = time.perf_counter()
        for seg_i, (seg_start, seg_end) in enumerate(segments):
            seg_dur = seg_end - seg_start
            seg_path = seg_clip_paths[seg_i]
            # Fresh scheduler per segment — num_steps euler steps.
            fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
            video_info = load_video(seg_path, seg_dur)
            clip_frames = video_info.clip_frames.unsqueeze(0)
            sync_frames = video_info.sync_frames.unsqueeze(0)
            actual_dur = video_info.duration_sec
            # Resize the net's sequence lengths to the clip's true duration.
            seq_cfg.duration = actual_dur
            net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
            print(f"[MMAudio] Sample {sample_idx+1} | seg {seg_i+1}/{len(segments)} "
                  f"{seg_start:.1f}{seg_end:.1f}s | dur={actual_dur:.2f}s | prompt='{prompt}'")
            with torch.no_grad():
                audios = generate(
                    clip_frames,
                    sync_frames,
                    [prompt],
                    negative_text=[negative_prompt] if negative_prompt else None,
                    feature_utils=feature_utils,
                    net=net,
                    fm=fm,
                    rng=rng,
                    cfg_strength=float(cfg_strength),
                )
            wav = audios.float().cpu()[0].numpy()  # (C, T)
            # Trim to the exact segment duration (generation may overshoot).
            seg_samples = int(round(seg_dur * sr))
            wav = wav[:, :seg_samples]
            seg_audios.append(wav)
        _log_inference_timing("MMAudio", time.perf_counter() - _t_mma_start,
                              len(segments), int(num_steps), MMAUDIO_SECS_PER_STEP)
        results.append((seg_audios, sr))
        # Free GPU memory between samples to prevent VRAM fragmentation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results
# Context slot for the CPU wrapper to pass pre-computed data.
_mmaudio_gpu_infer._cpu_ctx = {}
def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
                     cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
    """MMAudio: flow-matching video-to-audio, 44.1 kHz, 8 s sliding window.

    All ffmpeg / stitching work runs on CPU; only _mmaudio_gpu_infer touches
    the GPU, keeping the ZeroGPU billing window as small as possible.
    """
    n_samples = int(num_samples)
    xfade_s = float(crossfade_s)
    xfade_db = float(crossfade_db)

    # ── CPU pre-processing ──
    tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
        video_file, MMAUDIO_WINDOW, xfade_s)
    print(f"[MMAudio] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤8 s")

    # Pre-extract one silent clip per segment (ffmpeg, CPU)
    seg_clip_paths = []
    for i, (s, e) in enumerate(segments):
        clip_out = os.path.join(tmp_dir, f"mma_seg_{i}.mp4")
        seg_clip_paths.append(_extract_segment_clip(silent_video, s, e - s, clip_out))

    # Hand the CPU-prepared context to the GPU function via its side channel
    _mmaudio_gpu_infer._cpu_ctx = {
        "segments": segments, "seg_clip_paths": seg_clip_paths,
    }

    # ── GPU inference only ──
    results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                                 cfg_strength, num_steps, xfade_s, xfade_db, n_samples)

    # ── CPU post-processing ──
    outputs = []
    for idx, (seg_audios, sr) in enumerate(results):
        stitched = _stitch_wavs(seg_audios, xfade_s, xfade_db, total_dur_s, sr)
        audio_path = os.path.join(tmp_dir, f"mmaudio_{idx}.wav")
        _save_wav(audio_path, stitched, sr)
        video_path = os.path.join(tmp_dir, f"mmaudio_{idx}.mp4")
        mux_video_audio(silent_video, audio_path, video_path)
        wav_paths = _save_seg_wavs(seg_audios, tmp_dir, f"mmaudio_{idx}")
        seg_meta = _build_seg_meta(
            segments=segments, wav_paths=wav_paths, audio_path=audio_path,
            video_path=video_path, silent_video=silent_video, sr=sr,
            model="mmaudio", crossfade_s=xfade_s, crossfade_db=xfade_db,
            total_dur_s=total_dur_s,
        )
        outputs.append((video_path, audio_path, seg_meta))
    return _pad_outputs(outputs)
# ================================================================== #
# HunyuanVideoFoley #
# ================================================================== #
# Constants sourced from HunyuanVideo-Foley/hunyuanvideo_foley/constants.py
# and configs/hunyuanvideo-foley-xxl.yaml:
# sample_rate = 48000 Hz (from DAC VAE)
# audio_frame_rate = 50 (latent fps, xxl config)
# max video duration = 15 s
# SigLIP2 fps = 8, Synchformer fps = 25
# CLAP text encoder: laion/larger_clap_general (auto-downloaded from HF Hub)
# Default guidance_scale=4.5, num_inference_steps=50
# ================================================================== #
def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
                      guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
    """ZeroGPU duration estimator; signature must mirror _hunyuan_gpu_infer exactly."""
    samples, steps = int(num_samples), int(num_steps)
    return _estimate_gpu_duration("hunyuan", samples, steps,
                                  video_file=video_file, crossfade_s=crossfade_s)
@spaces.GPU(duration=_hunyuan_duration)
def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                       guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
    """GPU-only HunyuanFoley inference — model loading + feature extraction + denoising.

    Reads CPU-prepared context from ``_hunyuan_gpu_infer._cpu_ctx`` (populated
    by generate_hunyuan): segment boundaries, total duration, a dummy clip used
    for one-shot text feature extraction, and per-segment clip paths.

    Returns list of (seg_wavs, sr, text_feats) per sample; each entry of
    seg_wavs is a (C, T) float numpy array trimmed to the segment's nominal
    length.
    """
    _ensure_syspath("HunyuanVideo-Foley")
    from hunyuanvideo_foley.utils.model_utils import denoise_process
    from hunyuanvideo_foley.utils.feature_utils import feature_process
    seed_val = int(seed_val)
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    # Negative seed means "don't seed" — leave the global RNG state alone.
    if seed_val >= 0:
        set_global_seed(seed_val)
    device, _ = _get_device_and_dtype()
    device = torch.device(device)
    model_size = model_size.lower()
    model_dict, cfg = _load_hunyuan_model(device, model_size)
    # Context handed over from the CPU wrapper (see generate_hunyuan).
    ctx = _hunyuan_gpu_infer._cpu_ctx
    segments = ctx["segments"]
    total_dur_s = ctx["total_dur_s"]
    dummy_seg_path = ctx["dummy_seg_path"]
    seg_clip_paths = ctx["seg_clip_paths"]
    # Text feature extraction (GPU — runs once for all segments)
    _, text_feats, _ = feature_process(
        dummy_seg_path,
        prompt if prompt else "",
        model_dict,
        cfg,
        neg_prompt=negative_prompt if negative_prompt else None,
    )
    # Import visual-only feature extractor to avoid redundant text extraction
    # per segment (text_feats already computed once above for the whole batch).
    from hunyuanvideo_foley.utils.feature_utils import encode_video_features
    results = []
    for sample_idx in range(num_samples):
        seg_wavs = []
        # Default per the DAC VAE sample rate; denoise_process below returns
        # the authoritative value and overwrites this.
        sr = 48000
        _t_hny_start = time.perf_counter()
        for seg_i, (seg_start, seg_end) in enumerate(segments):
            seg_dur = seg_end - seg_start
            seg_path = seg_clip_paths[seg_i]
            # Extract only visual features — reuse text_feats from above
            visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
            print(f"[HunyuanFoley] Sample {sample_idx+1} | seg {seg_i+1}/{len(segments)} "
                  f"{seg_start:.1f}{seg_end:.1f}s → {seg_audio_len:.2f}s audio")
            audio_batch, sr = denoise_process(
                visual_feats,
                text_feats,
                seg_audio_len,
                model_dict,
                cfg,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_steps),
                batch_size=1,
            )
            wav = audio_batch[0].float().cpu().numpy()
            # Trim to the nominal segment length (generation may run long).
            seg_samples = int(round(seg_dur * sr))
            wav = wav[:, :seg_samples]
            seg_wavs.append(wav)
        _log_inference_timing("HunyuanFoley", time.perf_counter() - _t_hny_start,
                              len(segments), int(num_steps), HUNYUAN_SECS_PER_STEP)
        results.append((seg_wavs, sr, text_feats))
        # Free GPU memory between samples to prevent VRAM fragmentation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results
# Function-attribute side channel mirroring _mmaudio_gpu_infer._cpu_ctx:
# generate_hunyuan() fills this dict before invoking the GPU-decorated
# _hunyuan_gpu_infer, which reads it back inside the GPU context.
_hunyuan_gpu_infer._cpu_ctx = {}
def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
                     guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
    """HunyuanVideoFoley: text-guided foley, 48 kHz, up to 15 s.

    CPU pre/post-processing wraps the GPU-only inference to minimize ZeroGPU
    cost.  Returns the flattened, padded per-slot outputs from _pad_outputs.
    """
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)
    # ── CPU pre-processing (no GPU needed) ──
    tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
        video_file, HUNYUAN_MAX_DUR, crossfade_s)
    print(f"[HunyuanFoley] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤15 s")
    # Pre-extract dummy segment for text feature extraction (ffmpeg, CPU).
    # The GPU side runs feature_process once on this clip to obtain text_feats.
    dummy_seg_path = _extract_segment_clip(
        silent_video, 0, min(total_dur_s, HUNYUAN_MAX_DUR),
        os.path.join(tmp_dir, "_seg_dummy.mp4"),
    )
    # Pre-extract all segment clips (ffmpeg, CPU)
    seg_clip_paths = [
        _extract_segment_clip(silent_video, s, e - s, os.path.join(tmp_dir, f"hny_seg_{i}.mp4"))
        for i, (s, e) in enumerate(segments)
    ]
    # Hand the CPU-prepared context to the GPU function via its side channel
    _hunyuan_gpu_infer._cpu_ctx = {
        "segments": segments, "total_dur_s": total_dur_s,
        "dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
    }
    # ── GPU inference only ──
    results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                                 guidance_scale, num_steps, model_size,
                                 crossfade_s, crossfade_db, num_samples)
    # ── CPU post-processing (no GPU needed) ──
    _ensure_syspath("HunyuanVideo-Foley")
    from hunyuanvideo_foley.utils.media_utils import merge_audio_video
    outputs = []
    for sample_idx, (seg_wavs, sr, text_feats) in enumerate(results):
        full_wav = _stitch_wavs(seg_wavs, crossfade_s, crossfade_db, total_dur_s, sr)
        audio_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}.wav")
        _save_wav(audio_path, full_wav, sr)
        video_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}.mp4")
        merge_audio_video(audio_path, silent_video, video_path)
        wav_paths = _save_seg_wavs(seg_wavs, tmp_dir, f"hunyuan_{sample_idx}")
        # Cache the text features so segment regen can skip text extraction.
        text_feats_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}_text_feats.pt")
        torch.save(text_feats, text_feats_path)
        seg_meta = _build_seg_meta(
            segments=segments, wav_paths=wav_paths, audio_path=audio_path,
            video_path=video_path, silent_video=silent_video, sr=sr,
            model="hunyuan", crossfade_s=crossfade_s, crossfade_db=crossfade_db,
            total_dur_s=total_dur_s, text_feats_path=text_feats_path,
        )
        outputs.append((video_path, audio_path, seg_meta))
    return _pad_outputs(outputs)
# ================================================================== #
# SEGMENT REGENERATION HELPERS #
# ================================================================== #
# Each regen function:
# 1. Runs inference for ONE segment (random seed, current settings)
# 2. Splices the new wav into the stored wavs list
# 3. Re-stitches the full track, re-saves .wav and re-muxes .mp4
# 4. Returns (new_video_path, new_audio_path, updated_seg_meta, new_waveform_html)
# ================================================================== #
def _splice_and_save(new_wav, seg_idx, meta, slot_id):
    """Swap segment *seg_idx* for *new_wav*, then rebuild the full track.

    Re-stitches all segments, writes a freshly timestamped .wav, re-muxes a
    freshly timestamped .mp4 (so Gradio / the browser reload the players),
    and returns (video_path, audio_path, updated_meta, waveform_html).
    """
    seg_wavs = _load_seg_wavs(meta["wav_paths"])
    seg_wavs[seg_idx] = new_wav

    xfade_s = float(meta["crossfade_s"])
    xfade_db = float(meta["crossfade_db"])
    sample_rate = int(meta["sr"])
    total_dur = float(meta["total_dur_s"])
    silent_video = meta["silent_video"]
    segments = meta["segments"]
    model = meta["model"]

    full_wav = _stitch_wavs(seg_wavs, xfade_s, xfade_db, total_dur, sample_rate)

    # New timestamped filenames force the browser to treat the outputs as
    # genuinely different files and reload the players.
    stamp = int(time.time() * 1000)
    out_dir = os.path.dirname(meta["audio_path"])

    def _stem_without_regen(path):
        # Strip any previous "_regen_<ts>" suffix before adding a new one.
        stem = os.path.splitext(os.path.basename(path))[0]
        return stem.rsplit("_regen_", 1)[0]

    audio_stem = _stem_without_regen(meta["audio_path"])
    audio_path = os.path.join(out_dir, f"{audio_stem}_regen_{stamp}.wav")
    _save_wav(audio_path, full_wav, sample_rate)

    video_stem = _stem_without_regen(meta["video_path"])
    video_path = os.path.join(out_dir, f"{video_stem}_regen_{stamp}.mp4")
    if model == "hunyuan":
        # HunyuanFoley ships its own muxer
        _ensure_syspath("HunyuanVideo-Foley")
        from hunyuanvideo_foley.utils.media_utils import merge_audio_video
        merge_audio_video(audio_path, silent_video, video_path)
    else:
        mux_video_audio(silent_video, audio_path, video_path)

    # Persist the updated per-segment wavs as .npy files
    updated_wav_paths = _save_seg_wavs(seg_wavs, out_dir, os.path.splitext(audio_stem)[0])

    updated_meta = dict(meta)
    updated_meta["wav_paths"] = updated_wav_paths
    updated_meta["audio_path"] = audio_path
    updated_meta["video_path"] = video_path

    waveform_html = _build_waveform_html(audio_path, segments, slot_id, "",
                                         state_json=json.dumps(updated_meta),
                                         video_path=video_path,
                                         crossfade_s=xfade_s)
    return video_path, audio_path, updated_meta, waveform_html
def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
                         seed_val, cfg_scale, num_steps, mode,
                         crossfade_s, crossfade_db, slot_id=None):
    """ZeroGPU duration estimate for a single-segment TARO regen.

    When both the CAVP and onset feature caches exist on disk, the ~10 s
    feature-extractor overhead is skipped, so budget only model load plus
    denoising; otherwise fall back to the generic regen estimate.
    """
    steps = int(num_steps)
    try:
        meta = json.loads(seg_meta_json)
        have_cache = (os.path.exists(meta.get("cavp_path", ""))
                      and os.path.exists(meta.get("onset_path", "")))
        if have_cache:
            cfg = MODEL_CONFIGS["taro"]
            secs = steps * cfg["secs_per_step"] + 5  # 5s model-load only
            result = min(GPU_DURATION_CAP, max(30, int(secs)))
            print(f"[duration] TARO regen (cache hit): 1 seg × {steps} steps → {secs:.0f}s → capped {result}s")
            return result
    except Exception:
        pass  # best-effort: any parse/IO problem falls through to the default
    return _estimate_regen_duration("taro", steps)
@spaces.GPU(duration=_taro_regen_duration)
def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
                    seed_val, cfg_scale, num_steps, mode,
                    crossfade_s, crossfade_db, slot_id=None):
    """GPU-only TARO regen — returns new_wav for a single segment.

    Reuses the CAVP/onset feature caches referenced in *seg_meta_json* when
    they exist on disk (skipping the feature extractors entirely); otherwise
    re-extracts both from the silent video.  Always draws a fresh random seed —
    a regen click is meant to produce a new take, so *seed_val* is ignored.
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start_s, seg_end_s = meta["segments"][seg_idx]
    torch.set_grad_enabled(False)  # inference only — no autograd bookkeeping
    device, weight_dtype = _get_device_and_dtype()
    _ensure_syspath("TARO")
    from TARO.samplers import euler_sampler, euler_maruyama_sampler
    cavp_path = meta.get("cavp_path")
    onset_path = meta.get("onset_path")
    if cavp_path and os.path.exists(cavp_path) and onset_path and os.path.exists(onset_path):
        print("[TARO regen] Loading cached CAVP + onset features")
        cavp_feats = np.load(cavp_path)
        onset_feats = np.load(onset_path)
    else:
        print("[TARO regen] Cache miss — re-extracting CAVP + onset features")
        from TARO.onset_util import extract_onset
        extract_cavp, onset_model = _load_taro_feature_extractors(device)
        silent_video = meta["silent_video"]
        tmp_dir = tempfile.mkdtemp()
        cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
        onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
        # Free feature extractors before loading inference models
        del extract_cavp, onset_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    model_net, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
    # Fresh random seed on every regen click (see docstring).
    set_global_seed(random.randint(0, 2**32 - 1))
    return _taro_infer_segment(
        model_net, vae, vocoder, cavp_feats, onset_feats,
        seg_start_s, seg_end_s, device, weight_dtype,
        float(cfg_scale), int(num_steps), mode, latents_scale,
        euler_sampler, euler_maruyama_sampler,
    )
def regen_taro_segment(video_file, seg_idx, seg_meta_json,
                       seed_val, cfg_scale, num_steps, mode,
                       crossfade_s, crossfade_db, slot_id):
    """Regenerate one TARO segment. GPU inference + CPU splice/save.

    Returns (video_path, audio_path, updated_meta_json, waveform_html) for
    the slot identified by *slot_id*.
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    # GPU: inference only
    new_wav = _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
                              seed_val, cfg_scale, num_steps, mode,
                              crossfade_s, crossfade_db, slot_id)
    # Upsample 16 kHz → 48 kHz on CPU (no GPU needed).  NOTE(review): despite
    # the legacy name, _apply_flashsr now performs sinc resampling — FlashSR
    # was replaced for ZeroGPU compatibility; confirm against its definition.
    new_wav = _apply_flashsr(new_wav)
    # CPU: splice, stitch, mux, save
    video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
        new_wav, seg_idx, meta, slot_id
    )
    return video_path, audio_path, json.dumps(updated_meta), waveform_html
def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
                            prompt, negative_prompt, seed_val,
                            cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id=None):
    """ZeroGPU duration estimate for a single-segment MMAudio regen."""
    steps = int(num_steps)
    return _estimate_regen_duration("mmaudio", steps)
@spaces.GPU(duration=_mmaudio_regen_duration)
def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
                       prompt, negative_prompt, seed_val,
                       cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id=None):
    """GPU-only MMAudio regen — returns (new_wav, sr) for a single segment.

    Expects the CPU wrapper to have placed the pre-extracted segment clip path
    in ``_regen_mmaudio_gpu._cpu_ctx["seg_path"]``; falls back to extracting
    it here if missing.  Always uses a fresh random seed (regen = new take),
    so *seed_val* is ignored.
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start, seg_end = meta["segments"][seg_idx]
    seg_dur = seg_end - seg_start
    _ensure_syspath("MMAudio")
    from mmaudio.eval_utils import generate, load_video
    from mmaudio.model.flow_matching import FlowMatching
    device, dtype = _get_device_and_dtype()
    net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
    sr = seq_cfg.sampling_rate
    # Use pre-extracted segment clip from the wrapper
    seg_path = _regen_mmaudio_gpu._cpu_ctx.get("seg_path")
    if not seg_path:
        # Fallback: extract inside GPU (shouldn't happen)
        seg_path = _extract_segment_clip(
            meta["silent_video"], seg_start, seg_dur,
            os.path.join(tempfile.mkdtemp(), "regen_seg.mp4"),
        )
    rng = torch.Generator(device=device)
    rng.manual_seed(random.randint(0, 2**32 - 1))  # fresh seed per regen click
    fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=int(num_steps))
    video_info = load_video(seg_path, seg_dur)
    clip_frames = video_info.clip_frames.unsqueeze(0)
    sync_frames = video_info.sync_frames.unsqueeze(0)
    actual_dur = video_info.duration_sec
    # Re-fit the network's sequence lengths to this clip's actual duration.
    seq_cfg.duration = actual_dur
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
    with torch.no_grad():
        audios = generate(
            clip_frames, sync_frames, [prompt],
            negative_text=[negative_prompt] if negative_prompt else None,
            feature_utils=feature_utils, net=net, fm=fm, rng=rng,
            cfg_strength=float(cfg_strength),
        )
    new_wav = audios.float().cpu()[0].numpy()
    # Trim to the nominal segment length (generation may run slightly long).
    seg_samples = int(round(seg_dur * sr))
    new_wav = new_wav[:, :seg_samples]
    return new_wav, sr
# Side channel for the MMAudio regen wrappers: the CPU wrapper pre-extracts
# the segment clip and stores its path here under "seg_path" before the GPU
# function runs.  Initialized empty at import time.
_regen_mmaudio_gpu._cpu_ctx = {}
def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
                          prompt, negative_prompt, seed_val,
                          cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id):
    """Regenerate one MMAudio segment: CPU clip extraction, GPU inference,
    then CPU splice/stitch/mux."""
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    start_s, end_s = meta["segments"][seg_idx]

    # CPU: pre-extract the segment clip so the GPU function never runs ffmpeg
    work_dir = _register_tmp_dir(tempfile.mkdtemp())
    clip_path = _extract_segment_clip(
        meta["silent_video"], start_s, end_s - start_s,
        os.path.join(work_dir, "regen_seg.mp4"),
    )
    _regen_mmaudio_gpu._cpu_ctx = {"seg_path": clip_path}

    # GPU: inference only
    new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
                                     prompt, negative_prompt, seed_val,
                                     cfg_strength, num_steps,
                                     crossfade_s, crossfade_db, slot_id)
    meta["sr"] = sr

    # CPU: splice the new segment in and rebuild the outputs
    video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
        new_wav, seg_idx, meta, slot_id
    )
    return video_path, audio_path, json.dumps(updated_meta), waveform_html
def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
                            prompt, negative_prompt, seed_val,
                            guidance_scale, num_steps, model_size,
                            crossfade_s, crossfade_db, slot_id=None):
    """ZeroGPU duration estimate for a single-segment HunyuanFoley regen."""
    steps = int(num_steps)
    return _estimate_regen_duration("hunyuan", steps)
@spaces.GPU(duration=_hunyuan_regen_duration)
def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
                       prompt, negative_prompt, seed_val,
                       guidance_scale, num_steps, model_size,
                       crossfade_s, crossfade_db, slot_id=None):
    """GPU-only HunyuanFoley regen — returns (new_wav, sr) for a single segment.

    Reuses cached text features from *seg_meta_json* ("text_feats_path") when
    available so only the visual features need re-extraction.  The segment
    clip is expected in ``_regen_hunyuan_gpu._cpu_ctx["seg_path"]``
    (pre-extracted on CPU), with an in-GPU fallback.  Always uses a fresh
    random seed (regen = new take), so *seed_val* is ignored.
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start, seg_end = meta["segments"][seg_idx]
    seg_dur = seg_end - seg_start
    _ensure_syspath("HunyuanVideo-Foley")
    from hunyuanvideo_foley.utils.model_utils import denoise_process
    from hunyuanvideo_foley.utils.feature_utils import feature_process
    device, _ = _get_device_and_dtype()
    device = torch.device(device)
    model_dict, cfg = _load_hunyuan_model(device, model_size)
    set_global_seed(random.randint(0, 2**32 - 1))  # fresh seed per regen click
    # Use pre-extracted segment clip from wrapper
    seg_path = _regen_hunyuan_gpu._cpu_ctx.get("seg_path")
    if not seg_path:
        seg_path = _extract_segment_clip(
            meta["silent_video"], seg_start, seg_dur,
            os.path.join(tempfile.mkdtemp(), "regen_seg.mp4"),
        )
    text_feats_path = meta.get("text_feats_path")
    if text_feats_path and os.path.exists(text_feats_path):
        print("[HunyuanFoley regen] Loading cached text features, extracting visual only")
        from hunyuanvideo_foley.utils.feature_utils import encode_video_features
        visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
        text_feats = torch.load(text_feats_path, map_location=device, weights_only=False)
    else:
        print("[HunyuanFoley regen] Cache miss — extracting text + visual features")
        visual_feats, text_feats, seg_audio_len = feature_process(
            seg_path, prompt if prompt else "", model_dict, cfg,
            neg_prompt=negative_prompt if negative_prompt else None,
        )
    audio_batch, sr = denoise_process(
        visual_feats, text_feats, seg_audio_len, model_dict, cfg,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(num_steps),
        batch_size=1,
    )
    new_wav = audio_batch[0].float().cpu().numpy()
    # Trim to the nominal segment length (generation may run slightly long).
    seg_samples = int(round(seg_dur * sr))
    new_wav = new_wav[:, :seg_samples]
    return new_wav, sr
# Side channel for the HunyuanFoley regen wrappers: the CPU wrapper
# pre-extracts the segment clip and stores its path here under "seg_path"
# before the GPU function runs.  Initialized empty at import time.
_regen_hunyuan_gpu._cpu_ctx = {}
def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
                          prompt, negative_prompt, seed_val,
                          guidance_scale, num_steps, model_size,
                          crossfade_s, crossfade_db, slot_id):
    """Regenerate one HunyuanFoley segment: CPU clip extraction, GPU inference,
    then CPU splice/stitch/mux."""
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    start_s, end_s = meta["segments"][seg_idx]

    # CPU: pre-extract the segment clip so the GPU function never runs ffmpeg
    work_dir = _register_tmp_dir(tempfile.mkdtemp())
    clip_path = _extract_segment_clip(
        meta["silent_video"], start_s, end_s - start_s,
        os.path.join(work_dir, "regen_seg.mp4"),
    )
    _regen_hunyuan_gpu._cpu_ctx = {"seg_path": clip_path}

    # GPU: inference only
    new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
                                     prompt, negative_prompt, seed_val,
                                     guidance_scale, num_steps, model_size,
                                     crossfade_s, crossfade_db, slot_id)
    meta["sr"] = sr

    # CPU: splice the new segment in and rebuild the outputs
    video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
        new_wav, seg_idx, meta, slot_id
    )
    return video_path, audio_path, json.dumps(updated_meta), waveform_html
# Wire up regen_fn references now that the functions are defined.
# MODEL_CONFIGS itself is created earlier in the file (outside this section);
# the "regen_fn" entries are read by _register_regen_handlers below when the
# per-slot regen buttons are bound.
MODEL_CONFIGS["taro"]["regen_fn"] = regen_taro_segment
MODEL_CONFIGS["mmaudio"]["regen_fn"] = regen_mmaudio_segment
MODEL_CONFIGS["hunyuan"]["regen_fn"] = regen_hunyuan_segment
# ================================================================== #
# CROSS-MODEL REGEN WRAPPERS #
# ================================================================== #
# Three shared endpoints — one per model — that can be called from #
# *any* slot tab. slot_id is passed as plain string data so the #
# result is applied back to the correct slot by the JS listener. #
# The new segment is resampled to match the slot's existing SR before #
# being handed to _splice_and_save, so TARO (16 kHz) / MMAudio #
# (44.1 kHz) / Hunyuan (48 kHz) outputs can all be mixed freely. #
# ================================================================== #
def _resample_to_slot_sr(wav: np.ndarray, src_sr: int, dst_sr: int,
slot_wav_ref: np.ndarray = None) -> np.ndarray:
"""Resample *wav* from src_sr to dst_sr using torchaudio, then match
channel layout to *slot_wav_ref* (the first existing segment in the slot).
TARO is mono (T,), MMAudio/Hunyuan are stereo (C, T). Mixing them
without normalisation causes a shape mismatch in _cf_join. Rules:
• stereo → mono : average channels
• mono → stereo: duplicate the single channel
"""
# 1. Resample
if src_sr != dst_sr:
stereo_in = wav.ndim == 2
t = torch.from_numpy(np.ascontiguousarray(wav))
if not stereo_in:
t = t.unsqueeze(0)
t = torchaudio.functional.resample(t.float(), src_sr, dst_sr)
if not stereo_in:
t = t.squeeze(0)
wav = t.numpy()
# 2. Match channel layout to the slot's existing segments
if slot_wav_ref is not None:
slot_stereo = slot_wav_ref.ndim == 2
wav_stereo = wav.ndim == 2
if slot_stereo and not wav_stereo:
wav = np.stack([wav, wav], axis=0) # mono → stereo (C, T)
elif not slot_stereo and wav_stereo:
wav = wav.mean(axis=0) # stereo → mono (T,)
return wav
def _xregen_splice(new_wav_raw: np.ndarray, src_sr: int,
                   meta: dict, seg_idx: int, slot_id: str) -> tuple:
    """Shared epilogue for all xregen_* functions.

    Normalises *new_wav_raw* to the slot's sample rate and channel layout,
    splices it into segment *seg_idx*, and re-renders all slot outputs.
    Returns (video_path, waveform_html).
    """
    target_sr = int(meta["sr"])
    existing = _load_seg_wavs(meta["wav_paths"])
    adapted = _resample_to_slot_sr(new_wav_raw, src_sr, target_sr, existing[0])
    video_path, _audio_path, _updated_meta, waveform_html = _splice_and_save(
        adapted, seg_idx, meta, slot_id
    )
    return video_path, waveform_html
def xregen_taro(seg_idx, state_json, slot_id,
                seed_val, cfg_scale, num_steps, mode,
                crossfade_s, crossfade_db,
                request: gr.Request = None):
    """Cross-model regen: run TARO inference and splice into *slot_id*.

    Generator: the first yield shows the pending-waveform overlay immediately;
    the second yield delivers the final (video, waveform) updates.
    """
    meta = json.loads(state_json)
    seg_idx = int(seg_idx)
    # Show pending waveform immediately
    pending_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
    yield gr.update(), gr.update(value=pending_html)
    new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
                                  seed_val, cfg_scale, num_steps, mode,
                                  crossfade_s, crossfade_db, slot_id)
    # Upsample the 16 kHz TARO output on CPU (no GPU needed).  NOTE(review):
    # despite the legacy name, _apply_flashsr now performs sinc resampling
    # (FlashSR replaced for ZeroGPU); FLASHSR_SR_OUT is the post-resample rate.
    new_wav_raw = _apply_flashsr(new_wav_raw)
    video_path, waveform_html = _xregen_splice(new_wav_raw, FLASHSR_SR_OUT, meta, seg_idx, slot_id)
    yield gr.update(value=video_path), gr.update(value=waveform_html)
def xregen_mmaudio(seg_idx, state_json, slot_id,
                   prompt, negative_prompt, seed_val,
                   cfg_strength, num_steps, crossfade_s, crossfade_db,
                   request: gr.Request = None):
    """Cross-model regen: run MMAudio inference and splice into *slot_id*.

    Generator: yields a pending-waveform update first, then the final result.
    """
    meta = json.loads(state_json)
    seg_idx = int(seg_idx)
    start_s, end_s = meta["segments"][seg_idx]

    # Flash the pending overlay before any slow work starts
    yield gr.update(), gr.update(
        value=_build_regen_pending_html(meta["segments"], seg_idx, slot_id, ""))

    # CPU: pre-extract the clip, then hand it to the GPU fn via its side channel
    clip_path = _extract_segment_clip(
        meta["silent_video"], start_s, end_s - start_s,
        os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
    )
    _regen_mmaudio_gpu._cpu_ctx = {"seg_path": clip_path}
    new_wav_raw, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
                                             prompt, negative_prompt, seed_val,
                                             cfg_strength, num_steps,
                                             crossfade_s, crossfade_db, slot_id)
    video_path, waveform_html = _xregen_splice(new_wav_raw, src_sr, meta, seg_idx, slot_id)
    yield gr.update(value=video_path), gr.update(value=waveform_html)
def xregen_hunyuan(seg_idx, state_json, slot_id,
                   prompt, negative_prompt, seed_val,
                   guidance_scale, num_steps, model_size,
                   crossfade_s, crossfade_db,
                   request: gr.Request = None):
    """Cross-model regen: run HunyuanFoley inference and splice into *slot_id*.

    Generator: yields a pending-waveform update first, then the final result.
    """
    meta = json.loads(state_json)
    seg_idx = int(seg_idx)
    start_s, end_s = meta["segments"][seg_idx]

    # Flash the pending overlay before any slow work starts
    yield gr.update(), gr.update(
        value=_build_regen_pending_html(meta["segments"], seg_idx, slot_id, ""))

    # CPU: pre-extract the clip, then hand it to the GPU fn via its side channel
    clip_path = _extract_segment_clip(
        meta["silent_video"], start_s, end_s - start_s,
        os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
    )
    _regen_hunyuan_gpu._cpu_ctx = {"seg_path": clip_path}
    new_wav_raw, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
                                             prompt, negative_prompt, seed_val,
                                             guidance_scale, num_steps, model_size,
                                             crossfade_s, crossfade_db, slot_id)
    video_path, waveform_html = _xregen_splice(new_wav_raw, src_sr, meta, seg_idx, slot_id)
    yield gr.update(value=video_path), gr.update(value=waveform_html)
# ================================================================== #
# SHARED UI HELPERS #
# ================================================================== #
def _register_regen_handlers(tab_prefix, model_key, regen_seg_tb, regen_state_tb,
                             input_components, slot_vids, slot_waves):
    """Register per-slot regen button handlers for a model tab.

    This replaces the three nearly-identical for-loops that previously existed
    for TARO, MMAudio, and HunyuanFoley tabs.

    Args:
        tab_prefix: e.g. "taro", "mma", "hf"
        model_key: e.g. "taro", "mmaudio", "hunyuan"
        regen_seg_tb: gr.Textbox for seg_idx (render=False)
        regen_state_tb: gr.Textbox for state_json (render=False)
        input_components: list of Gradio input components (video, seed, etc.)
            — order must match regen_fn signature after (seg_idx, state_json, video)
        slot_vids: list of gr.Video components per slot
        slot_waves: list of gr.HTML components per slot

    Returns:
        list of hidden gr.Buttons (one per slot)
    """
    cfg = MODEL_CONFIGS[model_key]
    regen_fn = cfg["regen_fn"]
    label = cfg["label"]
    btns = []
    for _i in range(MAX_SLOTS):
        _slot_id = f"{tab_prefix}_{_i}"
        _btn = gr.Button(render=False, elem_id=f"regen_btn_{_slot_id}")
        btns.append(_btn)
        print(f"[startup] registering regen handler for slot {_slot_id}")
        # Factory binds the per-slot values at definition time — a bare closure
        # over the loop variables would late-bind them all to the last slot.
        def _make_regen(_si, _sid, _model_key, _label, _regen_fn):
            def _do(seg_idx, state_json, *args):
                # Generator handler: first yield shows the pending overlay,
                # second yield delivers the regenerated video + waveform.
                print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} "
                      f"state_json_len={len(state_json) if state_json else 0}")
                if not state_json:
                    print(f"[regen {_label}] early-exit: state_json empty")
                    yield gr.update(), gr.update()
                    return
                # Per-slot lock serializes concurrent regen clicks on one slot.
                lock = _get_slot_lock(_sid)
                with lock:
                    state = json.loads(state_json)
                    pending_html = _build_regen_pending_html(
                        state["segments"], int(seg_idx), _sid, ""
                    )
                    yield gr.update(), gr.update(value=pending_html)
                    print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} — calling regen")
                    try:
                        # args[0] = video, args[1:] = model-specific params
                        vid, aud, new_meta_json, html = _regen_fn(
                            args[0], int(seg_idx), state_json, *args[1:], _sid,
                        )
                        print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} — done, vid={vid!r}")
                    except Exception as _e:
                        print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} — ERROR: {_e}")
                        raise
                    yield gr.update(value=vid), gr.update(value=html)
            return _do
        _btn.click(
            fn=_make_regen(_i, _slot_id, model_key, label, regen_fn),
            inputs=[regen_seg_tb, regen_state_tb] + input_components,
            outputs=[slot_vids[_i], slot_waves[_i]],
            api_name=f"regen_{tab_prefix}_{_i}",
        )
    return btns
def _pad_outputs(outputs: list) -> list:
    """Flatten (video, audio, seg_meta) triples and pad to MAX_SLOTS * 3 with None.

    Each entry in *outputs* must be a (video_path, audio_path, seg_meta) tuple where
    seg_meta = {"segments": [...], "audio_path": str, "video_path": str,
                "sr": int, "model": str, "crossfade_s": float,
                "crossfade_db": float, "wav_paths": list[str]}
    """
    flat = []
    for triple in outputs[:MAX_SLOTS]:
        flat.extend(triple)  # 3 items: video, audio, meta
    missing_slots = MAX_SLOTS - min(len(outputs), MAX_SLOTS)
    flat.extend([None, None, None] * missing_slots)
    return flat
# ------------------------------------------------------------------ #
# WaveSurfer waveform + segment marker HTML builder #
# ------------------------------------------------------------------ #
def _build_regen_pending_html(segments: list, regen_seg_idx: int, slot_id: str,
                              hidden_input_id: str) -> str:
    """Return a waveform placeholder shown while a segment is being regenerated.

    Renders a dark bar with every segment drawn as a translucent band, the
    active segment highlighted in pulsing amber, and a centred spinner.

    Args:
        segments: list of (start_s, end_s) pairs covering the whole track.
        regen_seg_idx: index of the segment currently being regenerated.
        slot_id: slot identifier — unused in this static placeholder; kept for
            signature parity with _build_waveform_html.
        hidden_input_id: unused here as well, kept for signature parity.
    """
    seg_colors = [c.format(a="0.25") for c in SEG_COLORS]
    active_color = "rgba(255,180,0,0.55)"
    # Total duration for percent-based layout.  Guard against an empty segment
    # list and a degenerate zero-length track to avoid division by zero.
    # (Removed an unused `segs_json = json.dumps(segments)` local — the JSON
    # payload is only needed by the interactive waveform, not this placeholder.)
    duration = (segments[-1][1] if segments else 1.0) or 1.0
    seg_divs = ""
    for i, seg in enumerate(segments):
        # Draw only the non-overlapping (unique) portion of each segment so that
        # overlapping windows don't visually bleed into adjacent segments.
        # Each segment owns the region from its own start up to the next segment's
        # start (or its own end for the final segment).
        seg_start = seg[0]
        seg_end = segments[i + 1][0] if i + 1 < len(segments) else seg[1]
        left_pct = seg_start / duration * 100
        width_pct = (seg_end - seg_start) / duration * 100
        color = active_color if i == regen_seg_idx else seg_colors[i % len(seg_colors)]
        extra = "border:2px solid #ffb300;animation:wf_pulse 0.8s ease-in-out infinite alternate;" if i == regen_seg_idx else ""
        seg_divs += (
            f'<div style="position:absolute;top:0;left:{left_pct:.2f}%;'
            f'width:{width_pct:.2f}%;height:100%;background:{color};{extra}">'
            f'<span style="color:rgba(255,255,255,0.7);font-size:10px;padding:2px 3px;">Seg {i+1}</span>'
            f'</div>'
        )
    spinner = (
        '<div style="position:absolute;top:50%;left:50%;transform:translate(-50%,-50%);'
        'display:flex;align-items:center;gap:6px;">'
        '<div style="width:14px;height:14px;border:2px solid #ffb300;'
        'border-top-color:transparent;border-radius:50%;'
        'animation:wf_spin 0.7s linear infinite;"></div>'
        f'<span style="color:#ffb300;font-size:12px;white-space:nowrap;">'
        f'Regenerating Seg {regen_seg_idx+1}…</span>'
        '</div>'
    )
    return f"""
    <style>
    @keyframes wf_pulse {{from{{opacity:0.5}}to{{opacity:1}}}}
    @keyframes wf_spin {{to{{transform:rotate(360deg)}}}}
    </style>
    <div style="background:#1a1a1a;border-radius:8px;padding:10px;margin-top:6px;">
      <div style="position:relative;width:100%;height:80px;background:#1e1e2e;border-radius:4px;overflow:hidden;">
        {seg_divs}
        {spinner}
      </div>
      <div style="color:#888;font-size:11px;margin-top:6px;">Regenerating — please wait…</div>
    </div>
    """
def _build_waveform_html(audio_path: str, segments: list, slot_id: str,
                         hidden_input_id: str, state_json: str = "",
                         fn_index: int = -1, video_path: str = "",
                         crossfade_s: float = 0.0) -> str:
    """Return a self-contained HTML block with a Canvas waveform (display only),
    segment boundary markers, and a download link.
    Uses Web Audio API + Canvas — no external libraries.
    The waveform is SILENT. The playhead tracks the Gradio <video> element
    in the same slot via its timeupdate event.

    Args:
        audio_path: Path to the rendered WAV on disk; served to the browser
            via Gradio's file API rather than inlined as base64.
        segments: List of [start_s, end_s] pairs, one per generated segment.
        slot_id: "<tab>_<index>" identifier used to build all DOM element ids.
        hidden_input_id: Unused in this implementation — retained so existing
            call sites keep working.
        state_json: Regen state embedded in the container's data-state
            attribute; client JS reads it back for queue-API regen calls.
        fn_index: Stored in the container's data-fn-index attribute.
        video_path: Optional muxed-video path; when set, a video download
            link is rendered next to the audio one.
        crossfade_s: Crossfade duration in seconds, used to position segment
            color boundaries and draw the overlap hatching.

    Returns:
        An HTML string, or a small placeholder <p> when audio_path is
        missing or does not exist on disk.
    """
    if not audio_path or not os.path.exists(audio_path):
        return "<p style='color:#888;font-size:12px'>No audio yet.</p>"
    # Serve audio via Gradio's file API instead of base64-encoding the entire
    # WAV inline. For a 25s stereo 44.1kHz track this saves ~5 MB per slot.
    audio_url = f"/gradio_api/file={audio_path}"
    segs_json = json.dumps(segments)
    # Pre-bake the per-segment canvas fill colors at 0.35 alpha.
    seg_colors = [c.format(a="0.35") for c in SEG_COLORS]
    # NOTE: Gradio updates gr.HTML via innerHTML which does NOT execute <script> tags.
    # Solution: put the entire waveform (canvas + JS) inside an <iframe srcdoc="...">.
    # iframes always execute their scripts. The iframe posts messages to the parent for
    # segment-click events; the parent listens and fires the Gradio regen trigger.
    # For playhead sync, the iframe polls window.parent for a <video> element.
    iframe_inner = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
* {{ margin:0; padding:0; box-sizing:border-box; }}
body {{ background:#1a1a1a; overflow:hidden; }}
#wrap {{ position:relative; width:100%; height:80px; }}
canvas {{ display:block; }}
#cv {{ position:absolute; top:0; left:0; width:100%; height:100%; }}
#cvp {{ position:absolute; top:0; left:0; width:100%; height:100%; pointer-events:none; }}
</style>
</head>
<body>
<div id="wrap">
<canvas id="cv"></canvas>
<canvas id="cvp"></canvas>
</div>
<script>
(function() {{
const SLOT_ID = '{slot_id}';
const segments = {segs_json};
const segColors = {json.dumps(seg_colors)};
const crossfadeSec = {crossfade_s};
let audioDuration = 0;
// ── Popup via postMessage to parent global listener ─────────────────
// The parent page (Gradio) has a global window.addEventListener('message',...)
// set up via gr.Blocks(js=...) that handles popup show/hide and regen trigger.
function showPopup(idx, mx, my) {{
console.log('[wf showPopup] slot='+SLOT_ID+' idx='+idx+' posting to parent');
// Convert iframe-local coords to parent page coords
try {{
const fr = window.frameElement ? window.frameElement.getBoundingClientRect() : {{left:0,top:0}};
window.parent.postMessage({{
type:'wf_popup', action:'show',
slot_id: SLOT_ID, seg_idx: idx,
t0: segments[idx][0], t1: segments[idx][1],
x: mx + fr.left, y: my + fr.top
}}, '*');
console.log('[wf showPopup] postMessage sent OK');
}} catch(e) {{
console.log('[wf showPopup] postMessage fallback, err='+e.message);
window.parent.postMessage({{
type:'wf_popup', action:'show',
slot_id: SLOT_ID, seg_idx: idx,
t0: segments[idx][0], t1: segments[idx][1],
x: mx, y: my
}}, '*');
}}
}}
function hidePopup() {{
window.parent.postMessage({{type:'wf_popup', action:'hide'}}, '*');
}}
// ── Canvas waveform ────────────────────────────────────────────────
const cv = document.getElementById('cv');
const cvp = document.getElementById('cvp');
const wrap= document.getElementById('wrap');
function drawWaveform(channelData, duration) {{
audioDuration = duration;
const dpr = window.devicePixelRatio || 1;
const W = wrap.getBoundingClientRect().width || window.innerWidth || 600;
const H = 80;
cv.width = W * dpr; cv.height = H * dpr;
const ctx = cv.getContext('2d');
ctx.scale(dpr, dpr);
ctx.fillStyle = '#1e1e2e';
ctx.fillRect(0, 0, W, H);
segments.forEach(function(seg, idx) {{
// Color boundary = midpoint of the crossfade zone = where the blend is
// 50/50. This is also where the cut would land if crossfade were 0, and
// where the listener perceptually hears the transition to the next segment.
const x1 = (seg[0] / duration) * W;
const xEnd = idx + 1 < segments.length
? ((segments[idx + 1][0] + crossfadeSec / 2) / duration) * W
: (seg[1] / duration) * W;
ctx.fillStyle = segColors[idx % segColors.length];
ctx.fillRect(x1, 0, xEnd - x1, H);
ctx.fillStyle = 'rgba(255,255,255,0.6)';
ctx.font = '10px sans-serif';
ctx.fillText('Seg '+(idx+1), x1+3, 12);
}});
const samples = channelData.length;
const barW=2, gap=1, step=barW+gap;
const numBars = Math.floor(W / step);
const blockSz = Math.floor(samples / numBars);
ctx.fillStyle = '#4a9eff';
for (let i=0; i<numBars; i++) {{
let max=0;
const s=i*blockSz;
for (let j=0; j<blockSz; j++) {{
const v=Math.abs(channelData[s+j]||0);
if (v>max) max=v;
}}
const barH=Math.max(1, max*H);
ctx.fillRect(i*step, (H-barH)/2, barW, barH);
}}
segments.forEach(function(seg) {{
[seg[0],seg[1]].forEach(function(t) {{
const x=(t/duration)*W;
ctx.strokeStyle='rgba(255,255,255,0.4)';
ctx.lineWidth=1;
ctx.beginPath(); ctx.moveTo(x,0); ctx.lineTo(x,H); ctx.stroke();
}});
}});
// ── Crossfade overlap indicators ──
// The color boundary is at segments[i+1][0] (= seg_i.end - crossfadeSec).
// We centre the hatch on that edge: half the crossfade on each color side.
if (crossfadeSec > 0 && segments.length > 1) {{
for (let i = 0; i < segments.length - 1; i++) {{
// Color edge = segments[i+1][0], hatch spans half on each side
const edgeT = segments[i+1][0];
const overlapStart = edgeT - crossfadeSec / 2;
const overlapEnd = edgeT + crossfadeSec / 2;
const xL = (overlapStart / duration) * W;
const xR = (overlapEnd / duration) * W;
// Diagonal hatch pattern over the overlap zone
ctx.save();
ctx.beginPath();
ctx.rect(xL, 0, xR - xL, H);
ctx.clip();
ctx.strokeStyle = 'rgba(255,255,255,0.35)';
ctx.lineWidth = 1;
const spacing = 6;
for (let lx = xL - H; lx < xR + H; lx += spacing) {{
ctx.beginPath();
ctx.moveTo(lx, H);
ctx.lineTo(lx + H, 0);
ctx.stroke();
}}
ctx.restore();
}}
}}
cv.onclick = function(e) {{
const r=cv.getBoundingClientRect();
const xRel=(e.clientX-r.left)/r.width;
const tClick=xRel*duration;
// Pick the segment whose unique (non-overlapping) region contains the click.
// Each segment owns [seg[0], nextSeg[0]) visually; last segment owns [seg[0], seg[1]].
let hit=-1;
segments.forEach(function(seg,idx){{
const uniqueEnd = idx + 1 < segments.length ? segments[idx+1][0] : seg[1];
if (tClick >= seg[0] && tClick < uniqueEnd) hit = idx;
}});
console.log('[wf click] tClick='+tClick.toFixed(2)+' hit='+hit+' audioDuration='+audioDuration+' segments='+JSON.stringify(segments));
if (hit>=0) showPopup(hit, e.clientX, e.clientY);
else hidePopup();
}};
}}
function drawPlayhead(progress) {{
const dpr = window.devicePixelRatio || 1;
const W = wrap.getBoundingClientRect().width || window.innerWidth || 600;
const H = 80;
if (cvp.width !== W*dpr) {{ cvp.width=W*dpr; cvp.height=H*dpr; }}
const ctx = cvp.getContext('2d');
ctx.clearRect(0,0,W*dpr,H*dpr);
ctx.save();
ctx.scale(dpr,dpr);
const x=progress*W;
ctx.strokeStyle='#fff';
ctx.lineWidth=2;
ctx.beginPath(); ctx.moveTo(x,0); ctx.lineTo(x,H); ctx.stroke();
ctx.restore();
}}
// Poll parent for video time — find the video in the same wf_container slot
function findSlotVideo() {{
try {{
const par = window.parent.document;
// Walk up from our iframe to find wf_container_{slot_id}, then find its sibling video
const container = par.getElementById('wf_container_{slot_id}');
if (!container) return par.querySelector('video');
// The video is inside the same gr.Group — walk up to find it
let node = container.parentElement;
while (node && node !== par.body) {{
const v = node.querySelector('video');
if (v) return v;
node = node.parentElement;
}}
return null;
}} catch(e) {{ return null; }}
}}
setInterval(function() {{
const vid = findSlotVideo();
if (vid && vid.duration && isFinite(vid.duration) && audioDuration > 0) {{
drawPlayhead(vid.currentTime / vid.duration);
}}
}}, 50);
// ── Fetch + decode audio from Gradio file API ──────────────────────
const audioUrl = '{audio_url}';
fetch(audioUrl)
.then(function(r) {{ return r.arrayBuffer(); }})
.then(function(arrayBuf) {{
const AudioCtx = window.AudioContext || window.webkitAudioContext;
if (!AudioCtx) return;
const tmpCtx = new AudioCtx({{sampleRate:44100}});
tmpCtx.decodeAudioData(arrayBuf,
function(ab) {{
try {{ tmpCtx.close(); }} catch(e) {{}}
function tryDraw() {{
const W = wrap.getBoundingClientRect().width || window.innerWidth;
if (W > 0) {{ drawWaveform(ab.getChannelData(0), ab.duration); }}
else {{ setTimeout(tryDraw, 100); }}
}}
tryDraw();
}},
function(err) {{}}
);
}})
.catch(function(e) {{}});
}})();
</script>
</body>
</html>"""
    # Escape for HTML attribute (srcdoc uses HTML entities)
    srcdoc = _html.escape(iframe_inner, quote=True)
    state_escaped = _html.escape(state_json or "", quote=True)
    # Outer container: iframe (waveform) + status bar + download links +
    # per-slot status label used by the regen JS for progress messages.
    return f"""
<div id="wf_container_{slot_id}"
data-fn-index="{fn_index}"
data-state="{state_escaped}"
style="background:#1a1a1a;border-radius:8px;padding:10px;margin-top:6px;position:relative;">
<div style="position:relative;width:100%;height:80px;">
<iframe id="wf_iframe_{slot_id}"
srcdoc="{srcdoc}"
sandbox="allow-scripts allow-same-origin"
style="width:100%;height:80px;border:none;border-radius:4px;display:block;"
scrolling="no"></iframe>
</div>
<div style="display:flex;align-items:center;gap:8px;margin-top:6px;">
<span id="wf_statusbar_{slot_id}" style="color:#888;font-size:11px;">Click a segment to regenerate &nbsp;|&nbsp; Playhead syncs to video</span>
<a href="{audio_url}" download="audio_{slot_id}.wav"
style="margin-left:auto;background:#333;color:#eee;border:1px solid #555;
border-radius:4px;padding:3px 10px;font-size:12px;text-decoration:none;">
&#8595; Audio</a>{f'''
<a href="/gradio_api/file={video_path}" download="video_{slot_id}.mp4"
style="background:#333;color:#eee;border:1px solid #555;
border-radius:4px;padding:3px 10px;font-size:12px;text-decoration:none;">
&#8595; Video</a>''' if video_path else ''}
</div>
<div id="wf_seglabel_{slot_id}"
style="color:#aaa;font-size:11px;margin-top:4px;min-height:16px;"></div>
</div>
"""
def _make_output_slots(tab_prefix: str) -> tuple:
    """Create MAX_SLOTS output groups for a single tab.

    Each slot pairs a video player with a waveform HTML panel; only the
    first group starts visible.  Regeneration is driven from JS via direct
    Gradio queue API calls, so no hidden trigger textboxes are created
    here (DOM event dispatch is unreliable in Gradio 5 Svelte components).
    The per-slot state JSON lives in the waveform HTML's data-state
    attribute and is sent directly in the queue API payload.

    Returns:
        (groups, videos, waveform_htmls) — three parallel lists of
        length MAX_SLOTS.
    """
    groups, videos, waveform_htmls = [], [], []
    for idx in range(MAX_SLOTS):
        sid = f"{tab_prefix}_{idx}"
        with gr.Group(visible=(idx == 0)) as grp:
            video = gr.Video(label=f"Generation {idx + 1} — Video",
                             elem_id=f"slot_vid_{sid}",
                             show_download_button=False)
            wave = gr.HTML(
                value="<p style='color:#888;font-size:12px'>Generate audio to see waveform.</p>",
                elem_id=f"slot_wave_{sid}",
            )
        videos.append(video)
        waveform_htmls.append(wave)
        groups.append(grp)
    return groups, videos, waveform_htmls
def _unpack_outputs(flat: list, n: int, tab_prefix: str) -> list:
    """Convert a flat _pad_outputs list into Gradio update lists.

    ``flat`` holds MAX_SLOTS * 3 entries laid out as
    [vid0, aud0, meta0, vid1, aud1, meta1, ...].  Only video + waveform
    updates are returned (NOT group updates) — group visibility is toggled
    separately via a .then() handler to avoid the Gradio 5 SSR
    'Too many arguments' error caused by mixing gr.Group updates with
    other outputs.  The per-slot state JSON is embedded in the waveform
    HTML's data-state attribute so client-side JS can pass it to the
    queue API when regenerating.
    """
    n = int(n)
    empty_wave = "<p style='color:#888;font-size:12px'>Generate audio to see waveform.</p>"
    video_updates = []
    wave_updates = []
    for slot in range(MAX_SLOTS):
        vid_path, aud_path, meta = flat[slot * 3:slot * 3 + 3]
        video_updates.append(gr.update(value=vid_path))
        if not (aud_path and meta):
            # Slot was not generated — show the placeholder panel.
            wave_updates.append(gr.update(value=empty_wave))
            continue
        sid = f"{tab_prefix}_{slot}"
        wave_html = _build_waveform_html(
            aud_path, meta["segments"], sid, "",
            state_json=json.dumps(meta),
            video_path=meta.get("video_path", ""),
            crossfade_s=float(meta.get("crossfade_s", 0)),
        )
        wave_updates.append(gr.update(value=wave_html))
    return video_updates + wave_updates
def _on_video_upload_taro(video_file, num_steps, crossfade_s):
    """Re-cap the TARO 'Generations' slider when its inputs change.

    With no video selected, reset the slider to the full MAX_SLOTS range.
    Otherwise probe the clip's duration and derive the largest sample
    count TARO can produce for it; if probing fails, fall back to
    MAX_SLOTS.
    """
    if video_file is None:
        return gr.update(maximum=MAX_SLOTS, value=1)
    try:
        duration = get_video_duration(video_file)
        limit = _taro_calc_max_samples(duration, int(num_steps), float(crossfade_s))
    except Exception:
        # Best-effort: keep the UI usable even if ffprobe/duration fails.
        limit = MAX_SLOTS
    return gr.update(maximum=limit, value=min(1, limit))
def _update_slot_visibility(n):
    """Return MAX_SLOTS group updates showing only the first ``n`` slots."""
    visible_count = int(n)
    updates = []
    for slot in range(MAX_SLOTS):
        updates.append(gr.update(visible=(slot < visible_count)))
    return updates
# ================================================================== #
# GRADIO UI #
# ================================================================== #
# CSS injected into gr.Blocks(css=...): responsive output-video sizing,
# an equal-width two-column layout, and suppression of the built-in video
# download button (downloads go through the waveform panel links instead).
_SLOT_CSS = """
/* Responsive video: fills column width, height auto from aspect ratio */
.gradio-video video {
width: 100%;
height: auto;
max-height: 60vh;
object-fit: contain;
}
/* Force two-column layout to stay equal-width */
.gradio-container .gradio-row > .gradio-column {
flex: 1 1 0 !important;
min-width: 0 !important;
max-width: 50% !important;
}
/* Hide the built-in download button on output video slots — downloads are
handled by the waveform panel links which always reflect the latest regen. */
[id^="slot_vid_"] .download-icon,
[id^="slot_vid_"] button[aria-label="Download"],
[id^="slot_vid_"] a[download] {
display: none !important;
}
"""
_GLOBAL_JS = """
() => {
// Global postMessage handler for waveform iframe events.
// Runs once on page load (Gradio js= parameter).
// Handles: popup open/close relay, regen trigger via Gradio queue API.
if (window._wf_global_listener) return; // already registered
window._wf_global_listener = true;
// ── ZeroGPU quota attribution ──
// HF Spaces run inside an iframe on huggingface.co. Gradio's own JS client
// gets ZeroGPU auth headers (x-zerogpu-token, x-zerogpu-uuid) by sending a
// postMessage("zerogpu-headers") to the parent frame. The parent responds
// with a Map of headers that must be included on queue/join calls.
// We replicate this exact mechanism so our raw regen fetch() calls are
// attributed to the logged-in user's Pro quota.
function _fetchZerogpuHeaders() {
return new Promise(function(resolve) {
// Check if we're in an HF iframe with zerogpu support
if (typeof window === 'undefined' || window.parent === window || !window.supports_zerogpu_headers) {
console.log('[zerogpu] not in HF iframe or no zerogpu support');
resolve({});
return;
}
// Determine origin — same logic as Gradio's client
var hostname = window.location.hostname;
var hfhubdev = 'dev.spaces.huggingface.tech';
var origin = hostname.includes('.dev.')
? 'https://moon-' + hostname.split('.')[1] + '.' + hfhubdev
: 'https://huggingface.co';
// Use MessageChannel just like Gradio's post_message helper
var channel = new MessageChannel();
var done = false;
channel.port1.onmessage = function(ev) {
channel.port1.close();
done = true;
var headers = ev.data;
if (headers && typeof headers === 'object') {
// Convert Map to plain object if needed
var obj = {};
if (typeof headers.forEach === 'function') {
headers.forEach(function(v, k) { obj[k] = v; });
} else {
obj = headers;
}
console.log('[zerogpu] got headers from parent:', Object.keys(obj).join(', '));
resolve(obj);
} else {
resolve({});
}
};
window.parent.postMessage('zerogpu-headers', origin, [channel.port2]);
// Timeout: don't block regen if parent doesn't respond
setTimeout(function() { if (!done) { done = true; channel.port1.close(); resolve({}); } }, 3000);
});
}
// Cache: api_name -> fn_index, built once from gradio_config.dependencies
let _fnIndexCache = null;
function getFnIndex(apiName) {
if (!_fnIndexCache) {
_fnIndexCache = {};
const deps = window.gradio_config && window.gradio_config.dependencies;
if (deps) deps.forEach(function(d, i) {
if (d.api_name) _fnIndexCache[d.api_name] = i;
});
}
return _fnIndexCache[apiName];
}
// Read a component's current DOM value by elem_id.
// For Number/Slider: reads the <input type="number"> or <input type="range">.
// For Textbox/Radio: reads the <textarea> or checked <input type="radio">.
// Returns null if not found.
function readComponentValue(elemId) {
const el = document.getElementById(elemId);
if (!el) return null;
const numInput = el.querySelector('input[type="number"]');
if (numInput) return parseFloat(numInput.value);
const rangeInput = el.querySelector('input[type="range"]');
if (rangeInput) return parseFloat(rangeInput.value);
const radio = el.querySelector('input[type="radio"]:checked');
if (radio) return radio.value;
const ta = el.querySelector('textarea');
if (ta) return ta.value;
const txt = el.querySelector('input[type="text"], input:not([type])');
if (txt) return txt.value;
return null;
}
// Fire regen for a given slot and segment by posting directly to the
// Gradio queue API — bypasses Svelte binding entirely.
// targetModel: 'taro' | 'mma' | 'hf' (which model to use for inference)
// If targetModel matches the slot's own prefix, uses the per-slot regen_* endpoint.
// Otherwise uses the shared xregen_* cross-model endpoint.
function fireRegen(slot_id, seg_idx, targetModel) {
const prefix = slot_id.split('_')[0]; // owning tab: 'taro'|'mma'|'hf'
const slotNum = parseInt(slot_id.split('_')[1], 10);
// Decide which endpoint to call
const crossModel = (targetModel !== prefix);
let apiName, data;
// Read state_json from the waveform container data-state attribute
const container = document.getElementById('wf_container_' + slot_id);
const stateJson = container ? (container.getAttribute('data-state') || '') : '';
if (!stateJson) {
console.warn('[fireRegen] no state_json for slot', slot_id);
return;
}
if (!crossModel) {
// ── Same-model regen: per-slot endpoint, video passed as null ──
apiName = 'regen_' + prefix + '_' + slotNum;
if (prefix === 'taro') {
data = [seg_idx, stateJson, null,
readComponentValue('taro_seed'), readComponentValue('taro_cfg'),
readComponentValue('taro_steps'), readComponentValue('taro_mode'),
readComponentValue('taro_cf_dur'), readComponentValue('taro_cf_db')];
} else if (prefix === 'mma') {
data = [seg_idx, stateJson, null,
readComponentValue('mma_prompt'), readComponentValue('mma_neg'),
readComponentValue('mma_seed'), readComponentValue('mma_cfg'),
readComponentValue('mma_steps'),
readComponentValue('mma_cf_dur'), readComponentValue('mma_cf_db')];
} else {
data = [seg_idx, stateJson, null,
readComponentValue('hf_prompt'), readComponentValue('hf_neg'),
readComponentValue('hf_seed'), readComponentValue('hf_guidance'),
readComponentValue('hf_steps'), readComponentValue('hf_size'),
readComponentValue('hf_cf_dur'), readComponentValue('hf_cf_db')];
}
} else {
// ── Cross-model regen: shared xregen_* endpoint ──
// slot_id is passed so the server knows which slot's state to splice into.
// UI params are read from the target model's tab inputs.
if (targetModel === 'taro') {
apiName = 'xregen_taro';
data = [seg_idx, stateJson, slot_id,
readComponentValue('taro_seed'), readComponentValue('taro_cfg'),
readComponentValue('taro_steps'), readComponentValue('taro_mode'),
readComponentValue('taro_cf_dur'), readComponentValue('taro_cf_db')];
} else if (targetModel === 'mma') {
apiName = 'xregen_mmaudio';
data = [seg_idx, stateJson, slot_id,
readComponentValue('mma_prompt'), readComponentValue('mma_neg'),
readComponentValue('mma_seed'), readComponentValue('mma_cfg'),
readComponentValue('mma_steps'),
readComponentValue('mma_cf_dur'), readComponentValue('mma_cf_db')];
} else {
apiName = 'xregen_hunyuan';
data = [seg_idx, stateJson, slot_id,
readComponentValue('hf_prompt'), readComponentValue('hf_neg'),
readComponentValue('hf_seed'), readComponentValue('hf_guidance'),
readComponentValue('hf_steps'), readComponentValue('hf_size'),
readComponentValue('hf_cf_dur'), readComponentValue('hf_cf_db')];
}
}
console.log('[fireRegen] calling api', apiName, 'seg', seg_idx);
// Snapshot current waveform HTML + video src before mutating anything,
// so we can restore on error (e.g. quota exceeded).
var _preRegenWaveHtml = null;
var _preRegenVideoSrc = null;
var waveElSnap = document.getElementById('slot_wave_' + slot_id);
if (waveElSnap) _preRegenWaveHtml = waveElSnap.innerHTML;
var vidElSnap = document.getElementById('slot_vid_' + slot_id);
if (vidElSnap) { var vSnap = vidElSnap.querySelector('video'); if (vSnap) _preRegenVideoSrc = vSnap.getAttribute('src'); }
// Show spinner immediately
const lbl = document.getElementById('wf_seglabel_' + slot_id);
if (lbl) lbl.textContent = 'Regenerating Seg ' + (seg_idx + 1) + '...';
const fnIndex = getFnIndex(apiName);
if (fnIndex === undefined) {
console.warn('[fireRegen] fn_index not found for api_name:', apiName);
return;
}
// Get ZeroGPU auth headers from the HF parent frame (same mechanism
// Gradio's own JS client uses), then fire the regen queue/join call.
// Falls back to user-supplied HF token if zerogpu headers aren't available.
_fetchZerogpuHeaders().then(function(zerogpuHeaders) {
var regenHeaders = {'Content-Type': 'application/json'};
var hasZerogpu = zerogpuHeaders && Object.keys(zerogpuHeaders).length > 0;
if (hasZerogpu) {
// Merge zerogpu headers (x-zerogpu-token, x-zerogpu-uuid)
for (var k in zerogpuHeaders) { regenHeaders[k] = zerogpuHeaders[k]; }
console.log('[fireRegen] using zerogpu headers from parent frame');
} else {
console.warn('[fireRegen] no zerogpu headers available — may use anonymous quota');
}
fetch('/gradio_api/queue/join', {
method: 'POST',
credentials: 'include',
headers: regenHeaders,
body: JSON.stringify({
data: data,
fn_index: fnIndex,
api_name: '/' + apiName,
session_hash: window.__gradio_session_hash__,
event_data: null,
trigger_id: null
})
}).then(function(r) { return r.json(); }).then(function(j) {
if (!j.event_id) { console.error('[fireRegen] no event_id:', j); return; }
console.log('[fireRegen] queued, event_id:', j.event_id);
_listenAndApply(j.event_id, slot_id, seg_idx, _preRegenWaveHtml, _preRegenVideoSrc);
}).catch(function(e) {
console.error('[fireRegen] fetch error:', e);
if (lbl) lbl.textContent = 'Error — see console';
var sb = document.getElementById('wf_statusbar_' + slot_id);
if (sb) { sb.style.color = '#e05252'; sb.textContent = '\u26a0 Request failed: ' + e.message; }
});
});
}
// Subscribe to Gradio SSE stream for an event and apply outputs to DOM.
// For regen handlers, output[0] = video update, output[1] = waveform HTML update.
function _applyVideoSrc(slot_id, newSrc) {
var vidEl = document.getElementById('slot_vid_' + slot_id);
if (!vidEl) return false;
var video = vidEl.querySelector('video');
if (!video) return false;
if (video.getAttribute('src') === newSrc) return true; // already correct
video.setAttribute('src', newSrc);
video.src = newSrc;
video.load();
console.log('[_applyVideoSrc] applied src to', 'slot_vid_' + slot_id, 'src:', newSrc.slice(-40));
return true;
}
// Toast notification — styled like ZeroGPU quota warnings.
function _showRegenToast(message, isError) {
var t = document.createElement('div');
t.style.cssText = 'position:fixed;bottom:24px;left:50%;transform:translateX(-50%);' +
'z-index:2147483647;padding:12px 20px;border-radius:8px;font-family:sans-serif;' +
'font-size:13px;max-width:520px;text-align:center;box-shadow:0 4px 20px rgba(0,0,0,.6);' +
'background:' + (isError ? '#7a1c1c' : '#1c4a1c') + ';color:#fff;' +
'border:1px solid ' + (isError ? '#c0392b' : '#27ae60') + ';' +
'pointer-events:none;';
t.textContent = message;
document.body.appendChild(t);
setTimeout(function() {
t.style.transition = 'opacity 0.5s';
t.style.opacity = '0';
setTimeout(function() { t.parentNode && t.parentNode.removeChild(t); }, 600);
}, isError ? 8000 : 3000);
}
function _listenAndApply(eventId, slot_id, seg_idx, preRegenWaveHtml, preRegenVideoSrc) {
var _pendingVideoSrc = null;
const es = new EventSource('/gradio_api/queue/data?session_hash=' + window.__gradio_session_hash__);
es.onmessage = function(e) {
var msg;
try { msg = JSON.parse(e.data); } catch(_) { return; }
if (msg.event_id !== eventId) return;
if (msg.msg === 'process_generating' || msg.msg === 'process_completed') {
var out = msg.output;
if (out && out.data) {
var vidUpdate = out.data[0];
var waveUpdate = out.data[1];
var newSrc = null;
if (vidUpdate) {
if (vidUpdate.value && vidUpdate.value.video && vidUpdate.value.video.url) newSrc = vidUpdate.value.video.url;
else if (vidUpdate.video && vidUpdate.video.url) newSrc = vidUpdate.video.url;
else if (vidUpdate.value && vidUpdate.value.url) newSrc = vidUpdate.value.url;
else if (typeof vidUpdate.value === 'string') newSrc = vidUpdate.value;
else if (vidUpdate.url) newSrc = vidUpdate.url;
}
if (newSrc) _pendingVideoSrc = newSrc;
var waveHtml = null;
if (waveUpdate) {
if (typeof waveUpdate === 'string') waveHtml = waveUpdate;
else if (waveUpdate.value && typeof waveUpdate.value === 'string') waveHtml = waveUpdate.value;
}
if (waveHtml) {
var waveEl = document.getElementById('slot_wave_' + slot_id);
if (waveEl) {
var inner = waveEl.querySelector('.prose') || waveEl.querySelector('div');
if (inner) inner.innerHTML = waveHtml;
else waveEl.innerHTML = waveHtml;
}
}
}
if (msg.msg === 'process_completed') {
es.close();
var errMsg = msg.output && msg.output.error;
var hadError = !!errMsg;
console.log('[fireRegen] completed for', slot_id, 'error:', hadError, errMsg || '');
var lbl = document.getElementById('wf_seglabel_' + slot_id);
if (hadError) {
var toastMsg = typeof errMsg === 'string' ? errMsg : JSON.stringify(errMsg);
if (preRegenWaveHtml !== null) {
var waveEl2 = document.getElementById('slot_wave_' + slot_id);
if (waveEl2) waveEl2.innerHTML = preRegenWaveHtml;
}
if (preRegenVideoSrc !== null) {
var vidElR = document.getElementById('slot_vid_' + slot_id);
if (vidElR) { var vR = vidElR.querySelector('video'); if (vR) { vR.setAttribute('src', preRegenVideoSrc); vR.src = preRegenVideoSrc; vR.load(); } }
}
var statusBar = document.getElementById('wf_statusbar_' + slot_id);
if (statusBar) {
statusBar.style.color = '#e05252';
statusBar.textContent = '\u26a0 ' + toastMsg;
setTimeout(function() { statusBar.style.color = '#888'; statusBar.textContent = 'Click a segment to regenerate \u00a0|\u00a0 Playhead syncs to video'; }, 15000);
}
if (lbl) lbl.textContent = 'Quota exceeded — try again later';
} else {
if (lbl) lbl.textContent = 'Done';
var src = _pendingVideoSrc;
if (src) {
_applyVideoSrc(slot_id, src);
setTimeout(function() { _applyVideoSrc(slot_id, src); }, 50);
setTimeout(function() { _applyVideoSrc(slot_id, src); }, 300);
setTimeout(function() { _applyVideoSrc(slot_id, src); }, 800);
var vidEl = document.getElementById('slot_vid_' + slot_id);
if (vidEl) {
var obs = new MutationObserver(function() { _applyVideoSrc(slot_id, src); });
obs.observe(vidEl, {subtree: true, attributes: true, attributeFilter: ['src'], childList: true});
setTimeout(function() { obs.disconnect(); }, 2000);
}
}
}
}
}
if (msg.msg === 'close_stream') { es.close(); }
};
es.onerror = function() { es.close(); };
}
// Shared popup element created once and reused across all slots
let _popup = null;
let _pendingSlot = null, _pendingIdx = null;
function ensurePopup() {
if (_popup) return _popup;
_popup = document.createElement('div');
_popup.style.cssText = 'display:none;position:fixed;z-index:99999;' +
'background:#2a2a2a;border:1px solid #555;border-radius:6px;' +
'padding:8px 12px;box-shadow:0 4px 16px rgba(0,0,0,.5);font-family:sans-serif;';
var btnStyle = 'color:#fff;border:none;border-radius:4px;padding:5px 10px;' +
'font-size:11px;cursor:pointer;flex:1;';
_popup.innerHTML =
'<div id="_wf_popup_lbl" style="color:#ccc;font-size:11px;margin-bottom:6px;white-space:nowrap;"></div>' +
'<div style="display:flex;gap:5px;">' +
'<button id="_wf_popup_taro" style="background:#1d6fa5;' + btnStyle + '">&#10227; TARO</button>' +
'<button id="_wf_popup_mma" style="background:#2d7a4a;' + btnStyle + '">&#10227; MMAudio</button>' +
'<button id="_wf_popup_hf" style="background:#7a3d8c;' + btnStyle + '">&#10227; Hunyuan</button>' +
'</div>';
document.body.appendChild(_popup);
['taro','mma','hf'].forEach(function(model) {
document.getElementById('_wf_popup_' + model).onclick = function(e) {
e.stopPropagation();
var slot = _pendingSlot, idx = _pendingIdx;
hidePopup();
if (slot !== null && idx !== null) fireRegen(slot, idx, model);
};
});
// Use bubble phase (false) so stopPropagation() on the button click prevents this from firing
document.addEventListener('click', function() { hidePopup(); }, false);
return _popup;
}
function hidePopup() {
if (_popup) _popup.style.display = 'none';
_pendingSlot = null; _pendingIdx = null;
}
window.addEventListener('message', function(e) {
const d = e.data;
console.log('[global msg] received type=' + (d && d.type) + ' action=' + (d && d.action));
if (!d || d.type !== 'wf_popup') return;
const p = ensurePopup();
if (d.action === 'hide') { hidePopup(); return; }
// action === 'show'
_pendingSlot = d.slot_id;
_pendingIdx = d.seg_idx;
const lbl = document.getElementById('_wf_popup_lbl');
if (lbl) lbl.textContent = 'Seg ' + (d.seg_idx + 1) +
' (' + d.t0.toFixed(2) + 's \u2013 ' + d.t1.toFixed(2) + 's)';
p.style.display = 'block';
p.style.left = (d.x + 10) + 'px';
p.style.top = (d.y + 10) + 'px';
requestAnimationFrame(function() {
const r = p.getBoundingClientRect();
if (r.right > window.innerWidth - 8) p.style.left = (window.innerWidth - r.width - 8) + 'px';
if (r.bottom > window.innerHeight - 8) p.style.top = (window.innerHeight - r.height - 8) + 'px';
});
});
}
"""
with gr.Blocks(title="Generate Audio for Video", css=_SLOT_CSS, js=_GLOBAL_JS) as demo:
gr.Markdown(
"# Generate Audio for Video\n"
"Choose a model and upload a video to generate synchronized audio.\n\n"
"| Model | Best for | Avoid for |\n"
"|-------|----------|-----------|\n"
"| **TARO** | Natural, physics-driven impacts — footsteps, collisions, water, wind, crackling fire. Excels when the sound is tightly coupled to visible motion without needing a text description. | Dialogue, music, or complex layered soundscapes where semantic context matters. |\n"
"| **MMAudio** | Mixed scenes where you want both visual grounding *and* semantic control via a text prompt — e.g. a busy street scene where you want to emphasize the rain rather than the traffic. Great for ambient textures and nuanced sound design. | Pure impact/foley shots where TARO's motion-coupling would be sharper, or cinematic music beds. |\n"
"| **HunyuanFoley** | Cinematic foley requiring high fidelity and explicit creative direction — dramatic SFX, layered environmental design, or any scene where you have a clear written description of the desired sound palette. | Quick one-shot clips where you don't want to write a prompt, or raw impact sounds where timing precision matters more than richness. |"
)
with gr.Tabs():
# ---------------------------------------------------------- #
# Tab 1 — TARO #
# ---------------------------------------------------------- #
with gr.Tab("TARO"):
with gr.Row():
with gr.Column(scale=1):
taro_video = gr.Video(label="Input Video")
taro_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0, elem_id="taro_seed")
taro_cfg = gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8.0, step=0.5, elem_id="taro_cfg")
taro_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1, elem_id="taro_steps")
taro_mode = gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde", elem_id="taro_mode")
taro_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=4, value=2, step=0.1, elem_id="taro_cf_dur")
taro_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3", elem_id="taro_cf_db")
taro_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
taro_btn = gr.Button("Generate", variant="primary")
with gr.Column(scale=1):
(taro_slot_grps, taro_slot_vids,
taro_slot_waves) = _make_output_slots("taro")
# Hidden regen plumbing — render=False so no DOM element is created,
# avoiding Gradio's "Too many arguments" Svelte validation error.
# JS passes values directly via queue/join data array at the correct
# positional index (these show up as inputs to the fn but have no DOM).
taro_regen_seg = gr.Textbox(value="0", render=False)
taro_regen_state = gr.Textbox(value="", render=False)
for trigger in [taro_video, taro_steps, taro_cf_dur]:
trigger.change(
fn=_on_video_upload_taro,
inputs=[taro_video, taro_steps, taro_cf_dur],
outputs=[taro_samples],
)
taro_samples.change(
fn=_update_slot_visibility,
inputs=[taro_samples],
outputs=taro_slot_grps,
)
def _run_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n):
flat = generate_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n)
return _unpack_outputs(flat, n, "taro")
# Split group visibility into a separate .then() to avoid Gradio 5 SSR
# "Too many arguments" caused by including gr.Group in mixed output lists.
(taro_btn.click(
fn=_run_taro,
inputs=[taro_video, taro_seed, taro_cfg, taro_steps, taro_mode,
taro_cf_dur, taro_cf_db, taro_samples],
outputs=taro_slot_vids + taro_slot_waves,
).then(
fn=_update_slot_visibility,
inputs=[taro_samples],
outputs=taro_slot_grps,
))
# Per-slot regen handlers — JS calls /gradio_api/queue/join with
# fn_index (by api_name) + data=[seg_idx, state_json, video, ...params].
taro_regen_btns = _register_regen_handlers(
"taro", "taro", taro_regen_seg, taro_regen_state,
[taro_video, taro_seed, taro_cfg, taro_steps,
taro_mode, taro_cf_dur, taro_cf_db],
taro_slot_vids, taro_slot_waves,
)
# ---------------------------------------------------------- #
#  Tab 2 — MMAudio                                           #
# ---------------------------------------------------------- #
# Left column: input video + sampling controls; right column: up to
# MAX_SLOTS output slots (video + waveform pairs) built by
# _make_output_slots. Component creation order defines the layout.
with gr.Tab("MMAudio"):
    with gr.Row():
        with gr.Column(scale=1):
            mma_video = gr.Video(label="Input Video")
            mma_prompt = gr.Textbox(label="Prompt", placeholder="e.g. footsteps on gravel", elem_id="mma_prompt")
            mma_neg = gr.Textbox(label="Negative Prompt", value="music", placeholder="music, speech", elem_id="mma_neg")
            mma_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0, elem_id="mma_seed")
            mma_cfg = gr.Slider(label="CFG Strength", minimum=1, maximum=10, value=4.5, step=0.5, elem_id="mma_cfg")
            mma_steps = gr.Slider(label="Steps", minimum=10, maximum=50, value=25, step=1, elem_id="mma_steps")
            mma_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=4, value=2, step=0.1, elem_id="mma_cf_dur")
            mma_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3", elem_id="mma_cf_db")
            mma_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
            mma_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            (mma_slot_grps, mma_slot_vids,
             mma_slot_waves) = _make_output_slots("mma")
    # Hidden regen plumbing — render=False so no DOM element is created,
    # avoiding Gradio's "Too many arguments" Svelte validation error.
    mma_regen_seg = gr.Textbox(value="0", render=False)
    mma_regen_state = gr.Textbox(value="", render=False)
    # Show only as many output slots as generations requested.
    mma_samples.change(
        fn=_update_slot_visibility,
        inputs=[mma_samples],
        outputs=mma_slot_grps,
    )

    def _run_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n):
        """Run MMAudio generation, then fan the flat result list out to the
        fixed-size slot components (padding unused slots) via _unpack_outputs."""
        flat = generate_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n)
        return _unpack_outputs(flat, n, "mma")

    # Generate, then re-sync slot visibility with the requested count.
    (mma_btn.click(
        fn=_run_mmaudio,
        inputs=[mma_video, mma_prompt, mma_neg, mma_seed,
                mma_cfg, mma_steps, mma_cf_dur, mma_cf_db, mma_samples],
        outputs=mma_slot_vids + mma_slot_waves,
    ).then(
        fn=_update_slot_visibility,
        inputs=[mma_samples],
        outputs=mma_slot_grps,
    ))
    # Per-slot regen endpoints (called from JS via api_name) — parameter
    # list order must stay in sync with _register_regen_handlers.
    mma_regen_btns = _register_regen_handlers(
        "mma", "mmaudio", mma_regen_seg, mma_regen_state,
        [mma_video, mma_prompt, mma_neg, mma_seed,
         mma_cfg, mma_steps, mma_cf_dur, mma_cf_db],
        mma_slot_vids, mma_slot_waves,
    )
# ---------------------------------------------------------- #
#  Tab 3 — HunyuanVideoFoley                                 #
# ---------------------------------------------------------- #
# Same layout pattern as the MMAudio tab, plus a model-size radio
# ("xl"/"xxl"). Component creation order defines the layout.
with gr.Tab("HunyuanFoley"):
    with gr.Row():
        with gr.Column(scale=1):
            hf_video = gr.Video(label="Input Video")
            hf_prompt = gr.Textbox(label="Prompt", placeholder="e.g. rain hitting a metal roof", elem_id="hf_prompt")
            hf_neg = gr.Textbox(label="Negative Prompt", value="noisy, harsh", elem_id="hf_neg")
            hf_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0, elem_id="hf_seed")
            hf_guidance = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, value=4.5, step=0.5, elem_id="hf_guidance")
            hf_steps = gr.Slider(label="Steps", minimum=10, maximum=100, value=50, step=5, elem_id="hf_steps")
            hf_size = gr.Radio(label="Model Size", choices=["xl", "xxl"], value="xxl", elem_id="hf_size")
            hf_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=4, value=2, step=0.1, elem_id="hf_cf_dur")
            hf_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3", elem_id="hf_cf_db")
            hf_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
            hf_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            (hf_slot_grps, hf_slot_vids,
             hf_slot_waves) = _make_output_slots("hf")
    # Hidden regen plumbing — render=False so no DOM element is created,
    # avoiding Gradio's "Too many arguments" Svelte validation error.
    hf_regen_seg = gr.Textbox(value="0", render=False)
    hf_regen_state = gr.Textbox(value="", render=False)
    # Show only as many output slots as generations requested.
    hf_samples.change(
        fn=_update_slot_visibility,
        inputs=[hf_samples],
        outputs=hf_slot_grps,
    )

    def _run_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n):
        """Run HunyuanFoley generation, then fan the flat result list out to
        the fixed-size slot components (padding unused slots) via _unpack_outputs."""
        flat = generate_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n)
        return _unpack_outputs(flat, n, "hf")

    # Generate, then re-sync slot visibility with the requested count.
    (hf_btn.click(
        fn=_run_hunyuan,
        inputs=[hf_video, hf_prompt, hf_neg, hf_seed,
                hf_guidance, hf_steps, hf_size, hf_cf_dur, hf_cf_db, hf_samples],
        outputs=hf_slot_vids + hf_slot_waves,
    ).then(
        fn=_update_slot_visibility,
        inputs=[hf_samples],
        outputs=hf_slot_grps,
    ))
    # Per-slot regen endpoints (called from JS via api_name) — parameter
    # list order must stay in sync with _register_regen_handlers.
    hf_regen_btns = _register_regen_handlers(
        "hf", "hunyuan", hf_regen_seg, hf_regen_state,
        [hf_video, hf_prompt, hf_neg, hf_seed,
         hf_guidance, hf_steps, hf_size, hf_cf_dur, hf_cf_db],
        hf_slot_vids, hf_slot_waves,
    )
# ---- Cross-tab video sync ----
# Mirror an uploaded video into the other two tabs so the user does not
# have to re-upload the same clip when switching models.
# PEP 8 (E731): use a def instead of assigning a lambda to a name —
# behavior is identical, but the function gets a real name in tracebacks.
def _sync(v):
    """Return two gr.update payloads carrying *v*, one per sibling tab."""
    return gr.update(value=v), gr.update(value=v)

taro_video.change(fn=_sync, inputs=[taro_video], outputs=[mma_video, hf_video])
mma_video.change(fn=_sync, inputs=[mma_video], outputs=[taro_video, hf_video])
hf_video.change(fn=_sync, inputs=[hf_video], outputs=[taro_video, mma_video])
# ---- Cross-model regen endpoints ----
# render=False inputs/outputs: no DOM elements created, no SSR validation impact.
# JS calls these via /gradio_api/queue/join using the api_name and applies
# the returned video+waveform directly to the target slot's DOM elements.
# Each endpoint is an unrendered Button's click handler, existing solely to
# register an API route; the button itself is never shown or clicked in the UI.
_xr_seg = gr.Textbox(value="0", render=False)
_xr_state = gr.Textbox(value="", render=False)
_xr_slot_id = gr.Textbox(value="", render=False)
# Dummy outputs for xregen events: must be real rendered components so Gradio
# can look them up in session state during postprocess_data. The JS listener
# (_listenAndApply) applies the returned video/HTML directly to the correct
# slot's DOM elements and ignores Gradio's own output routing, so these
# slot-0 components simply act as sinks — their displayed value is overwritten
# by the real JS update immediately after.
_xr_dummy_vid = taro_slot_vids[0]
_xr_dummy_wave = taro_slot_waves[0]
# TARO cross-model regen inputs: seg_idx, state_json, slot_id, seed, cfg, steps, mode, cf_dur, cf_db
# All params are Textboxes (strings); the handler presumably parses them to
# the proper types — confirm against xregen_taro.
_xr_taro_seed = gr.Textbox(value="-1", render=False)
_xr_taro_cfg = gr.Textbox(value="7.5", render=False)
_xr_taro_steps = gr.Textbox(value="25", render=False)
_xr_taro_mode = gr.Textbox(value="sde", render=False)
_xr_taro_cfd = gr.Textbox(value="2", render=False)
_xr_taro_cfdb = gr.Textbox(value="3", render=False)
gr.Button(render=False).click(
    fn=xregen_taro,
    inputs=[_xr_seg, _xr_state, _xr_slot_id,
            _xr_taro_seed, _xr_taro_cfg, _xr_taro_steps,
            _xr_taro_mode, _xr_taro_cfd, _xr_taro_cfdb],
    outputs=[_xr_dummy_vid, _xr_dummy_wave],
    api_name="xregen_taro",  # route name the external JS depends on — do not rename
)
# MMAudio cross-model regen inputs: seg_idx, state_json, slot_id, prompt, neg, seed, cfg, steps, cf_dur, cf_db
_xr_mma_prompt = gr.Textbox(value="", render=False)
_xr_mma_neg = gr.Textbox(value="", render=False)
_xr_mma_seed = gr.Textbox(value="-1", render=False)
_xr_mma_cfg = gr.Textbox(value="4.5", render=False)
_xr_mma_steps = gr.Textbox(value="25", render=False)
_xr_mma_cfd = gr.Textbox(value="2", render=False)
_xr_mma_cfdb = gr.Textbox(value="3", render=False)
gr.Button(render=False).click(
    fn=xregen_mmaudio,
    inputs=[_xr_seg, _xr_state, _xr_slot_id,
            _xr_mma_prompt, _xr_mma_neg, _xr_mma_seed,
            _xr_mma_cfg, _xr_mma_steps, _xr_mma_cfd, _xr_mma_cfdb],
    outputs=[_xr_dummy_vid, _xr_dummy_wave],
    api_name="xregen_mmaudio",  # route name the external JS depends on — do not rename
)
# HunyuanFoley cross-model regen inputs: seg_idx, state_json, slot_id, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db
_xr_hf_prompt = gr.Textbox(value="", render=False)
_xr_hf_neg = gr.Textbox(value="", render=False)
_xr_hf_seed = gr.Textbox(value="-1", render=False)
_xr_hf_guide = gr.Textbox(value="4.5", render=False)
_xr_hf_steps = gr.Textbox(value="50", render=False)
_xr_hf_size = gr.Textbox(value="xxl", render=False)
_xr_hf_cfd = gr.Textbox(value="2", render=False)
_xr_hf_cfdb = gr.Textbox(value="3", render=False)
gr.Button(render=False).click(
    fn=xregen_hunyuan,
    inputs=[_xr_seg, _xr_state, _xr_slot_id,
            _xr_hf_prompt, _xr_hf_neg, _xr_hf_seed,
            _xr_hf_guide, _xr_hf_steps, _xr_hf_size,
            _xr_hf_cfd, _xr_hf_cfdb],
    outputs=[_xr_dummy_vid, _xr_dummy_wave],
    api_name="xregen_hunyuan",  # route name the external JS depends on — do not rename
)
# NOTE: ZeroGPU quota attribution is handled via postMessage("zerogpu-headers")
# to the HF parent frame — the same mechanism Gradio's own JS client uses.
# This replaced the old x-ip-token relay approach which was unreliable.
print("[startup] app.py fully loaded — regen handlers registered, SSR disabled")
# queue(max_size=10): cap pending jobs. ssr_mode=False: SSR is disabled
# (presumably because the custom JS regen plumbing bypasses Gradio's own
# output routing — see the xregen notes above; confirm before re-enabling).
# allowed_paths=["/tmp"]: let Gradio serve generated media written under /tmp.
demo.queue(max_size=10).launch(ssr_mode=False, height=900, allowed_paths=["/tmp"])