""" Generate Audio for Video — multi-model Gradio app. Supported models ---------------- TARO – video-conditioned diffusion via CAVP + onset features (16 kHz, 8.192 s window) MMAudio – multimodal flow-matching with CLIP/Synchformer + text prompt (44 kHz, 8 s window) HunyuanFoley – text-guided foley via SigLIP2 + Synchformer + CLAP (48 kHz, up to 15 s) """ import html as _html import math import os import sys import json import shutil import tempfile import random import threading import time from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path import torch import numpy as np import torchaudio import ffmpeg import spaces import gradio as gr from huggingface_hub import hf_hub_download, snapshot_download # ================================================================== # # CHECKPOINT CONFIGURATION # # ================================================================== # CKPT_REPO_ID = "JackIsNotInTheBox/Generate_Audio_for_Video_Checkpoints" CACHE_DIR = "/tmp/model_ckpts" os.makedirs(CACHE_DIR, exist_ok=True) # ---- Local directories that must exist before parallel downloads start ---- MMAUDIO_WEIGHTS_DIR = Path(CACHE_DIR) / "MMAudio" / "weights" MMAUDIO_EXT_DIR = Path(CACHE_DIR) / "MMAudio" / "ext_weights" HUNYUAN_MODEL_DIR = Path(CACHE_DIR) / "HunyuanFoley" MMAUDIO_WEIGHTS_DIR.mkdir(parents=True, exist_ok=True) MMAUDIO_EXT_DIR.mkdir(parents=True, exist_ok=True) HUNYUAN_MODEL_DIR.mkdir(parents=True, exist_ok=True) # ------------------------------------------------------------------ # # Parallel checkpoint + model downloads # # All downloads are I/O-bound (network), so running them in threads # # cuts Space cold-start time roughly proportional to the number of # # independent groups (previously sequential, now concurrent). # # hf_hub_download / snapshot_download are thread-safe. 
# ------------------------------------------------------------------ #


def _dl_taro():
    """Download TARO .ckpt/.pt files and return their local paths.

    Returns (cavp_ckpt_path, onset_ckpt_path, taro_ckpt_path) — unpacked
    into module-level globals after the executor below completes.
    """
    c = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/cavp_epoch66.ckpt", cache_dir=CACHE_DIR)
    o = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/onset_model.ckpt", cache_dir=CACHE_DIR)
    t = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/taro_ckpt.pt", cache_dir=CACHE_DIR)
    print("TARO checkpoints downloaded.")
    return c, o, t


def _dl_mmaudio():
    """Download MMAudio .pth files and return their local paths.

    Returns (model_path, vae_path, synchformer_path).
    """
    # NOTE(review): local_dir_use_symlinks is deprecated in recent
    # huggingface_hub releases — confirm the pinned version still accepts it.
    m = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/mmaudio_large_44k_v2.pth",
                        cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_WEIGHTS_DIR),
                        local_dir_use_symlinks=False)
    v = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/v1-44.pth",
                        cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_EXT_DIR),
                        local_dir_use_symlinks=False)
    s = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/synchformer_state_dict.pth",
                        cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_EXT_DIR),
                        local_dir_use_symlinks=False)
    print("MMAudio checkpoints downloaded.")
    return m, v, s


def _dl_hunyuan():
    """Download HunyuanVideoFoley .pth files.

    Only downloads into HUNYUAN_MODEL_DIR — the loader later resolves files
    by directory, so no paths are returned.
    """
    hf_hub_download(repo_id=CKPT_REPO_ID, filename="HunyuanVideo-Foley/hunyuanvideo_foley.pth",
                    cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR),
                    local_dir_use_symlinks=False)
    hf_hub_download(repo_id=CKPT_REPO_ID, filename="HunyuanVideo-Foley/vae_128d_48k.pth",
                    cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR),
                    local_dir_use_symlinks=False)
    hf_hub_download(repo_id=CKPT_REPO_ID, filename="HunyuanVideo-Foley/synchformer_state_dict.pth",
                    cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR),
                    local_dir_use_symlinks=False)
    print("HunyuanVideoFoley checkpoints downloaded.")


def _dl_clap():
    """Pre-download CLAP so from_pretrained() hits local cache inside the ZeroGPU worker."""
    snapshot_download(repo_id="laion/larger_clap_general")
    print("CLAP model pre-downloaded.")


def _dl_clip():
    """Pre-download MMAudio's CLIP model (~3.95 GB) to avoid GPU-window budget drain."""
    snapshot_download(repo_id="apple/DFN5B-CLIP-ViT-H-14-384")
    print("MMAudio CLIP model pre-downloaded.")


def _dl_audioldm2():
    """Pre-download AudioLDM2 VAE/vocoder used by TARO's from_pretrained() calls."""
    snapshot_download(repo_id="cvssp/audioldm2")
    print("AudioLDM2 pre-downloaded.")


def _dl_bigvgan():
    """Pre-download BigVGAN vocoder (~489 MB) used by MMAudio."""
    snapshot_download(repo_id="nvidia/bigvgan_v2_44khz_128band_512x")
    print("BigVGAN vocoder pre-downloaded.")


# Module-level fan-out: runs once at import time, before the UI is built.
print("[startup] Starting parallel checkpoint + model downloads…")
_t_dl_start = time.perf_counter()
with ThreadPoolExecutor(max_workers=7) as _pool:
    _fut_taro = _pool.submit(_dl_taro)
    _fut_mmaudio = _pool.submit(_dl_mmaudio)
    _fut_hunyuan = _pool.submit(_dl_hunyuan)
    _fut_clap = _pool.submit(_dl_clap)
    _fut_clip = _pool.submit(_dl_clip)
    _fut_aldm2 = _pool.submit(_dl_audioldm2)
    _fut_bigvgan = _pool.submit(_dl_bigvgan)
    # Raise any download exceptions immediately
    for _fut in as_completed([_fut_taro, _fut_mmaudio, _fut_hunyuan,
                              _fut_clap, _fut_clip, _fut_aldm2, _fut_bigvgan]):
        _fut.result()

# Globals consumed by the model-loading helpers further down.
cavp_ckpt_path, onset_ckpt_path, taro_ckpt_path = _fut_taro.result()
mmaudio_model_path, mmaudio_vae_path, mmaudio_synchformer_path = _fut_mmaudio.result()
print(f"[startup] All downloads done in {time.perf_counter() - _t_dl_start:.1f}s")

# ================================================================== #
#  SHARED CONSTANTS / HELPERS                                         #
# ================================================================== #

# CPU → GPU context passing via function-name-keyed global store.
#
# Problem: ZeroGPU runs @spaces.GPU functions on its own worker thread, so
#   threading.local() is invisible to the GPU worker. Passing ctx as a
#   function argument exposes it to Gradio's API endpoint, causing
#   "Too many arguments" errors.
#
# Solution: store context in a plain global dict keyed by function name.
# A per-key Lock serialises concurrent callers for the same function
# (ZeroGPU is already synchronous — the wrapper blocks until the GPU fn
# returns — so in practice only one call per GPU fn is in-flight at a time).
# The global dict is readable from any thread.
_GPU_CTX: dict = {}
_GPU_CTX_LOCK = threading.Lock()


def _ctx_store(fn_name: str, data: dict) -> None:
    """Store *data* under *fn_name* key (overwrites previous)."""
    with _GPU_CTX_LOCK:
        _GPU_CTX[fn_name] = data


def _ctx_load(fn_name: str) -> dict:
    """Pop and return the context dict stored under *fn_name*.

    Returns an empty dict when nothing was stored, so callers can safely
    index with a KeyError only when a required key is truly missing.
    """
    with _GPU_CTX_LOCK:
        return _GPU_CTX.pop(fn_name, {})


MAX_SLOTS = 8  # max parallel generation slots shown in UI
MAX_SEGS = 8   # max segments per slot (same as MAX_SLOTS; video ≤ ~64 s at 8 s/seg)

# Segment overlay palette — shared between _build_waveform_html and _build_regen_pending_html
SEG_COLORS = [
    "rgba(100,180,255,{a})",
    "rgba(255,160,100,{a})",
    "rgba(120,220,140,{a})",
    "rgba(220,120,220,{a})",
    "rgba(255,220,80,{a})",
    "rgba(80,220,220,{a})",
    "rgba(255,100,100,{a})",
    "rgba(180,255,180,{a})",
]

# ------------------------------------------------------------------ #
# Micro-helpers that eliminate repeated boilerplate across the file  #
# ------------------------------------------------------------------ #


def _ensure_syspath(subdir: str) -> str:
    """Add *subdir* (relative to app.py) to sys.path if not already present.

    Returns the absolute path for convenience.
    """
    p = os.path.join(os.path.dirname(os.path.abspath(__file__)), subdir)
    if p not in sys.path:
        sys.path.insert(0, p)
    return p


def _get_device_and_dtype() -> tuple:
    """Return (device, weight_dtype) pair used by all GPU functions."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return device, torch.bfloat16


def _extract_segment_clip(silent_video: str, seg_start: float, seg_dur: float,
                          output_path: str) -> str:
    """Stream-copy a segment from *silent_video* to *output_path*.

    Returns *output_path*.
    NOTE(review): with vcodec="copy", ffmpeg seeks to the nearest keyframe,
    so segment boundaries may be approximate — confirm this is acceptable
    for the regen preview use case.
    """
    ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
        output_path, vcodec="copy", an=None
    ).run(overwrite_output=True, quiet=True)
    return output_path


# Per-slot locks (non-reentrant threading.Lock) — prevent concurrent regens
# on the same slot from producing a race condition where the second regen
# reads stale state (the shared seg_state textbox hasn't been updated yet by
# the first regen).  Locks are keyed by slot_id string (e.g. "taro_0", "mma_2").
_SLOT_LOCKS: dict = {}
_SLOT_LOCKS_MUTEX = threading.Lock()


def _get_slot_lock(slot_id: str) -> threading.Lock:
    """Return the lock for *slot_id*, creating it on first use (thread-safe)."""
    with _SLOT_LOCKS_MUTEX:
        if slot_id not in _SLOT_LOCKS:
            _SLOT_LOCKS[slot_id] = threading.Lock()
        return _SLOT_LOCKS[slot_id]


def set_global_seed(seed: int) -> None:
    """Seed numpy, random and torch (CPU + CUDA) for reproducible sampling."""
    np.random.seed(seed % (2**32))  # numpy only accepts 32-bit seeds
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)


def get_random_seed() -> int:
    """Return a fresh random seed in the 32-bit range."""
    return random.randint(0, 2**32 - 1)


def _resolve_seed(seed_val) -> int:
    """Normalise seed_val to a non-negative int.

    Negative values (UI default 'random') produce a fresh random seed.
    """
    seed_val = int(seed_val)
    return seed_val if seed_val >= 0 else get_random_seed()


def get_video_duration(video_path: str) -> float:
    """Return video duration in seconds (CPU only)."""
    probe = ffmpeg.probe(video_path)
    return float(probe["format"]["duration"])


def strip_audio_from_video(video_path: str, output_path: str) -> None:
    """Write a silent copy of *video_path* to *output_path* (stream-copy, no re-encode)."""
    ffmpeg.input(video_path).output(output_path, vcodec="copy", an=None).run(
        overwrite_output=True, quiet=True
    )


def _transcode_for_browser(video_path: str) -> str:
    """Re-encode uploaded video to H.264/AAC MP4 so the browser preview widget
    can play it.

    Returns a NEW path in a fresh /tmp/gradio/ subdirectory. Gradio probes the
    returned path fresh, sees H.264, and serves it directly without its own
    slow fallback converter.
    The in-place overwrite approach loses the race because Gradio probes the
    original path at upload time before this callback runs.
    Only called on upload — not during generation.
    """
    if video_path is None:
        return video_path
    try:
        probe = ffmpeg.probe(video_path)
        has_audio = any(s["codec_type"] == "audio" for s in probe.get("streams", []))
        # Check if already H.264 — skip transcode if so
        video_streams = [s for s in probe.get("streams", []) if s["codec_type"] == "video"]
        if video_streams and video_streams[0].get("codec_name") == "h264":
            print("[transcode_for_browser] already H.264, skipping")
            return video_path
        # Write the H.264 output into the SAME directory as the original upload.
        # Gradio's file server only allows paths under dirs it registered — the
        # upload dir is already allowed, so a sibling file there will serve fine.
        # (Uses the module-level `os` — no need for a shadowing local import.)
        upload_dir = os.path.dirname(video_path)
        stem = os.path.splitext(os.path.basename(video_path))[0]
        out_path = os.path.join(upload_dir, stem + "_h264.mp4")
        kwargs = dict(
            vcodec="libx264",
            preset="fast",
            crf=18,
            pix_fmt="yuv420p",
            movflags="+faststart",
        )
        if has_audio:
            kwargs["acodec"] = "aac"
            kwargs["audio_bitrate"] = "128k"
        else:
            kwargs["an"] = None
        # map 0:v:0 explicitly to skip non-video streams (e.g. data/timecode tracks)
        ffmpeg.input(video_path).output(out_path, map="0:v:0", **kwargs).run(
            overwrite_output=True, quiet=True
        )
        print(f"[transcode_for_browser] transcoded to H.264: {out_path}")
        return out_path
    except Exception as e:
        # Best-effort: a failed transcode falls back to the original file so the
        # upload is never rejected outright.
        print(f"[transcode_for_browser] failed, using original: {e}")
        return video_path


# ------------------------------------------------------------------ #
# Temp directory registry — tracks dirs for cleanup on new generation #
# ------------------------------------------------------------------ #
_TEMP_DIRS: list = []   # list of tmp_dir paths created by generate_*
_TEMP_DIRS_MAX = 10     # keep at most this many; older ones get cleaned up


def _register_tmp_dir(tmp_dir: str) -> str:
    """Register a temp dir so it can be cleaned up when newer ones replace it.

    Returns *tmp_dir* unchanged so it can be used inline.
    """
    _TEMP_DIRS.append(tmp_dir)
    while len(_TEMP_DIRS) > _TEMP_DIRS_MAX:
        old = _TEMP_DIRS.pop(0)
        try:
            shutil.rmtree(old, ignore_errors=True)
            print(f"[cleanup] Removed old temp dir: {old}")
        except Exception:
            pass  # best-effort cleanup; never fail a generation over it
    return tmp_dir


def _save_seg_wavs(wavs: list[np.ndarray], tmp_dir: str, prefix: str) -> list[str]:
    """Save a list of numpy wav arrays to .npy files, return list of paths.

    This avoids serialising large float arrays into JSON/HTML data-state.
    """
    paths = []
    for i, w in enumerate(wavs):
        p = os.path.join(tmp_dir, f"{prefix}_seg{i}.npy")
        np.save(p, w)
        paths.append(p)
    return paths


def _load_seg_wavs(paths: list[str]) -> list[np.ndarray]:
    """Load segment wav arrays from .npy file paths."""
    return [np.load(p) for p in paths]


# ------------------------------------------------------------------ #
# Shared model-loading helpers (deduplicate generate / regen code)   #
# ------------------------------------------------------------------ #


def _load_taro_models(device, weight_dtype):
    """Load TARO MMDiT + AudioLDM2 VAE/vocoder.

    Returns (model_net, vae, vocoder, latents_scale).
    """
    from TARO.models import MMDiT
    from diffusers import AutoencoderKL
    from transformers import SpeechT5HifiGan

    model_net = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
    # weights_only=False: the checkpoint stores an EMA dict — trusted source
    # (our own HF repo), so the pickle load is acceptable here.
    model_net.load_state_dict(
        torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"]
    )
    model_net.eval().to(weight_dtype)
    vae = AutoencoderKL.from_pretrained("cvssp/audioldm2", subfolder="vae").to(device).eval()
    vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm2", subfolder="vocoder").to(device)
    latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
    return model_net, vae, vocoder, latents_scale


def _load_taro_feature_extractors(device):
    """Load CAVP + onset extractors.  Returns (extract_cavp, onset_model)."""
    from TARO.cavp_util import Extract_CAVP_Features
    from TARO.onset_util import VideoOnsetNet

    extract_cavp = Extract_CAVP_Features(
        device=device,
        config_path="TARO/cavp/cavp.yaml",
        ckpt_path=cavp_ckpt_path,
    )
    raw_sd = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
    # Strip the Lightning "model." prefix variants so the keys match VideoOnsetNet.
    onset_sd = {}
    for k, v in raw_sd.items():
        if "model.net.model" in k:
            k = k.replace("model.net.model", "net.model")
        elif "model.fc." in k:
            k = k.replace("model.fc", "fc")
        onset_sd[k] = v
    onset_model = VideoOnsetNet(pretrained=False).to(device)
    onset_model.load_state_dict(onset_sd)
    onset_model.eval()
    return extract_cavp, onset_model


def _load_mmaudio_models(device, dtype):
    """Load MMAudio net + feature_utils.

    Returns (net, feature_utils, model_cfg, seq_cfg).
    """
    from mmaudio.eval_utils import all_model_cfg
    from mmaudio.model.networks import get_my_mmaudio
    from mmaudio.model.utils.features_utils import FeaturesUtils

    model_cfg = all_model_cfg["large_44k_v2"]
    model_cfg.model_path = Path(mmaudio_model_path)
    model_cfg.vae_path = Path(mmaudio_vae_path)
    model_cfg.synchformer_ckpt = Path(mmaudio_synchformer_path)
    model_cfg.bigvgan_16k_path = None  # 44 kHz variant has no 16 kHz vocoder
    seq_cfg = model_cfg.seq_cfg
    net = get_my_mmaudio(model_cfg.model_name).to(device, dtype).eval()
    net.load_weights(torch.load(model_cfg.model_path, map_location=device, weights_only=True))
    feature_utils = FeaturesUtils(
        tod_vae_ckpt=str(model_cfg.vae_path),
        synchformer_ckpt=str(model_cfg.synchformer_ckpt),
        enable_conditions=True,
        mode=model_cfg.mode,
        bigvgan_vocoder_ckpt=None,
        need_vae_encoder=False,
    ).to(device, dtype).eval()
    return net, feature_utils, model_cfg, seq_cfg


def _load_hunyuan_model(device, model_size):
    """Load HunyuanFoley model dict + config.  Returns (model_dict, cfg)."""
    from hunyuanvideo_foley.utils.model_utils import load_model

    model_size = model_size.lower()
    config_map = {
        "xl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xl.yaml",
        "xxl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xxl.yaml",
    }
    config_path = config_map.get(model_size, config_map["xxl"])  # default to XXL
    hunyuan_weights_dir = str(HUNYUAN_MODEL_DIR / "HunyuanVideo-Foley")
    print(f"[HunyuanFoley] Loading {model_size.upper()} model from {hunyuan_weights_dir}")
    return load_model(hunyuan_weights_dir, config_path, device,
                      enable_offload=False, model_size=model_size)


def mux_video_audio(silent_video: str, audio_path: str, output_path: str,
                    model: str = None) -> None:
    """Mux a silent video with an audio file into *output_path*.

    For HunyuanFoley (*model*="hunyuan") we use its own merge_audio_video
    which handles its specific ffmpeg quirks; all other models use
    stream-copy muxing.
    """
    if model == "hunyuan":
        _ensure_syspath("HunyuanVideo-Foley")
        from hunyuanvideo_foley.utils.media_utils import merge_audio_video
        merge_audio_video(audio_path, silent_video, output_path)
    else:
        v_in = ffmpeg.input(silent_video)
        a_in = ffmpeg.input(audio_path)
        ffmpeg.output(
            v_in["v:0"], a_in["a:0"], output_path,
            vcodec="libx264", preset="fast", crf=18, pix_fmt="yuv420p",
            acodec="aac", audio_bitrate="128k", movflags="+faststart",
        ).run(overwrite_output=True, quiet=True)


# ------------------------------------------------------------------ #
# Shared sliding-window segmentation and crossfade helpers           #
# Used by all three models (TARO, MMAudio, HunyuanFoley).            #
# ------------------------------------------------------------------ #


def _build_segments(total_dur_s: float, window_s: float,
                    crossfade_s: float) -> list[tuple[float, float]]:
    """Return list of (start, end) pairs covering *total_dur_s*.

    Every segment uses the full *window_s* inference window. Segments are
    equally spaced so every overlap is identical, guaranteeing the crossfade
    setting is honoured at every boundary with no raw bleed.

    Algorithm
    ---------
    1. Clamp crossfade_s so the step stays positive.
    2. Find the minimum n such that n segments of *window_s* cover
       *total_dur_s* with overlap ≥ crossfade_s at every boundary:
           n = ceil((total_dur_s - crossfade_s) / (window_s - crossfade_s))
    3. Compute equal spacing: step = (total_dur_s - window_s) / (n - 1)
       so that every gap is identical and the last segment ends exactly at
       total_dur_s.
    4. Every segment is exactly *window_s* wide. The trailing audio of each
       segment beyond its contact edge is discarded in _stitch_wavs.
    """
    crossfade_s = min(crossfade_s, window_s * 0.5)
    if total_dur_s <= window_s:
        return [(0.0, total_dur_s)]
    step_min = window_s - crossfade_s  # minimum step to honour crossfade
    n = math.ceil((total_dur_s - crossfade_s) / step_min)
    n = max(n, 2)
    # Equal step so first seg starts at 0 and last seg ends at total_dur_s
    step_s = (total_dur_s - window_s) / (n - 1)
    return [(i * step_s, i * step_s + window_s) for i in range(n)]


def _cf_join(a: np.ndarray, b: np.ndarray, crossfade_s: float,
             db_boost: float, sr: int) -> np.ndarray:
    """Equal-power crossfade join.

    Works for both mono (T,) and stereo (C, T) arrays. Stereo arrays are
    expected in (channels, samples) layout.

    db_boost is applied to the overlap region as a whole (after blending), so
    it compensates for the -3 dB equal-power dip without doubling amplitude.
    Applying gain to each side independently (the common mistake) causes a
    +3 dB loudness bump at the seam — this version avoids that.
    """
    stereo = a.ndim == 2
    n_a = a.shape[1] if stereo else len(a)
    n_b = b.shape[1] if stereo else len(b)
    cf = min(int(round(crossfade_s * sr)), n_a, n_b)
    if cf <= 0:
        return np.concatenate([a, b], axis=1 if stereo else 0)
    gain = 10 ** (db_boost / 20.0)
    t = np.linspace(0.0, 1.0, cf, dtype=np.float32)
    fade_out = np.cos(t * np.pi / 2)  # 1 → 0
    fade_in = np.sin(t * np.pi / 2)   # 0 → 1
    if stereo:
        # Blend first, then apply boost to the overlap region as a unit
        overlap = (a[:, -cf:] * fade_out + b[:, :cf] * fade_in) * gain
        return np.concatenate([a[:, :-cf], overlap, b[:, cf:]], axis=1)
    else:
        overlap = (a[-cf:] * fade_out + b[:cf] * fade_in) * gain
        return np.concatenate([a[:-cf], overlap, b[cf:]])


# ================================================================== #
#  TARO                                                               #
# ================================================================== #
# Constants sourced from TARO/infer.py and TARO/models.py:
#   SR=16000, TRUNCATE=131072 → 8.192 s window
#   TRUNCATE_FRAME = 4 fps × 131072/16000 = 32 CAVP frames per window
#   TRUNCATE_ONSET = 120 onset frames per
window # latent shape: (1, 8, 204, 16) — fixed by MMDiT architecture # latents_scale: [0.18215]*8 — AudioLDM2 VAE scale factor # ================================================================== # # ================================================================== # # MODEL CONSTANTS & CONFIGURATION REGISTRY # # ================================================================== # # All per-model numeric constants live here — MODEL_CONFIGS is the # # single source of truth consumed by duration estimation, segmentation,# # and the UI. Standalone names kept only where other code references # # them by name (TARO geometry, TARGET_SR, GPU_DURATION_CAP). # # ================================================================== # # TARO geometry — referenced directly in _taro_infer_segment TARO_SR = 16000 TARO_TRUNCATE = 131072 TARO_FPS = 4 TARO_TRUNCATE_FRAME = int(TARO_FPS * TARO_TRUNCATE / TARO_SR) # 32 TARO_TRUNCATE_ONSET = 120 TARO_MODEL_DUR = TARO_TRUNCATE / TARO_SR # 8.192 s GPU_DURATION_CAP = 300 # hard cap per @spaces.GPU call — never reserve more than this MODEL_CONFIGS = { "taro": { "window_s": TARO_MODEL_DUR, # 8.192 s "sr": TARO_SR, # 16000 (output resampled to TARGET_SR) "secs_per_step": 0.025, # measured 0.023 s/step on H200 "load_overhead": 15, # model load + CAVP feature extraction "tab_prefix": "taro", "label": "TARO", "regen_fn": None, # set after function definitions (avoids forward-ref) }, "mmaudio": { "window_s": 8.0, # MMAudio's fixed generation window "sr": 48000, # resampled from 44100 in post-processing "secs_per_step": 0.25, # measured 0.230 s/step on H200 "load_overhead": 30, # 15s warm + 15s model init "tab_prefix": "mma", "label": "MMAudio", "regen_fn": None, }, "hunyuan": { "window_s": 15.0, # HunyuanFoley max video duration "sr": 48000, "secs_per_step": 0.35, # measured 0.328 s/step on H200 "load_overhead": 55, # ~55s to load the 10 GB XXL weights "tab_prefix": "hf", "label": "HunyuanFoley", "regen_fn": None, }, } # Convenience aliases used 
only in the TARO inference path TARO_SECS_PER_STEP = MODEL_CONFIGS["taro"]["secs_per_step"] MMAUDIO_WINDOW = MODEL_CONFIGS["mmaudio"]["window_s"] MMAUDIO_SECS_PER_STEP = MODEL_CONFIGS["mmaudio"]["secs_per_step"] HUNYUAN_MAX_DUR = MODEL_CONFIGS["hunyuan"]["window_s"] HUNYUAN_SECS_PER_STEP = MODEL_CONFIGS["hunyuan"]["secs_per_step"] def _clamp_duration(secs: float, label: str) -> int: """Clamp a raw GPU-seconds estimate to [60, GPU_DURATION_CAP] and log it.""" result = min(GPU_DURATION_CAP, max(60, int(secs))) print(f"[duration] {label}: {secs:.0f}s raw → {result}s reserved") return result def _estimate_gpu_duration(model_key: str, num_samples: int, num_steps: int, total_dur_s: float = None, crossfade_s: float = 0, video_file: str = None) -> int: """Estimate GPU seconds for a full generation call. Formula: num_samples × n_segs × num_steps × secs_per_step + load_overhead """ cfg = MODEL_CONFIGS[model_key] try: if total_dur_s is None: total_dur_s = get_video_duration(video_file) n_segs = len(_build_segments(total_dur_s, cfg["window_s"], float(crossfade_s))) except Exception: n_segs = 1 secs = int(num_samples) * n_segs * int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"] print(f"[duration] {cfg['label']}: {int(num_samples)}samp × {n_segs}seg × " f"{int(num_steps)}steps → {secs:.0f}s → capped ", end="") return _clamp_duration(secs, cfg["label"]) def _estimate_regen_duration(model_key: str, num_steps: int) -> int: """Estimate GPU seconds for a single-segment regen call.""" cfg = MODEL_CONFIGS[model_key] secs = int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"] print(f"[duration] {cfg['label']} regen: 1 seg × {int(num_steps)} steps → ", end="") return _clamp_duration(secs, f"{cfg['label']} regen") _TARO_CACHE_MAXLEN = 16 # evict oldest entries beyond this limit _TARO_INFERENCE_CACHE: dict = {} # keyed by (video_file, seed, cfg, steps, mode, crossfade_s) _TARO_CACHE_LOCK = threading.Lock() def _taro_calc_max_samples(total_dur_s: float, num_steps: int, 
crossfade_s: float) -> int: n_segs = len(_build_segments(total_dur_s, TARO_MODEL_DUR, crossfade_s)) time_per_seg = num_steps * TARO_SECS_PER_STEP max_s = int(600.0 / (n_segs * time_per_seg)) return max(1, min(max_s, MAX_SLOTS)) def _taro_duration(video_file, seed_val, cfg_scale, num_steps, mode, crossfade_s, crossfade_db, num_samples): """Pre-GPU callable — must match _taro_gpu_infer's input order exactly.""" return _estimate_gpu_duration("taro", int(num_samples), int(num_steps), video_file=video_file, crossfade_s=crossfade_s) def _taro_infer_segment( model, vae, vocoder, cavp_feats_full, onset_feats_full, seg_start_s: float, seg_end_s: float, device, weight_dtype, cfg_scale: float, num_steps: int, mode: str, latents_scale, euler_sampler, euler_maruyama_sampler, ) -> np.ndarray: """Single-segment TARO inference. Returns wav array trimmed to segment length.""" # CAVP features (4 fps) cavp_start = int(round(seg_start_s * TARO_FPS)) cavp_slice = cavp_feats_full[cavp_start : cavp_start + TARO_TRUNCATE_FRAME] if cavp_slice.shape[0] < TARO_TRUNCATE_FRAME: pad = np.zeros( (TARO_TRUNCATE_FRAME - cavp_slice.shape[0],) + cavp_slice.shape[1:], dtype=cavp_slice.dtype, ) cavp_slice = np.concatenate([cavp_slice, pad], axis=0) video_feats = torch.from_numpy(cavp_slice).unsqueeze(0).to(device, weight_dtype) # Onset features (onset_fps = TRUNCATE_ONSET / MODEL_DUR ≈ 14.65 fps) onset_fps = TARO_TRUNCATE_ONSET / TARO_MODEL_DUR onset_start = int(round(seg_start_s * onset_fps)) onset_slice = onset_feats_full[onset_start : onset_start + TARO_TRUNCATE_ONSET] if onset_slice.shape[0] < TARO_TRUNCATE_ONSET: onset_slice = np.pad( onset_slice, ((0, TARO_TRUNCATE_ONSET - onset_slice.shape[0]),), mode="constant", ) onset_feats_t = torch.from_numpy(onset_slice).unsqueeze(0).to(device, weight_dtype) # Latent noise — shape matches MMDiT architecture (in_channels=8, 204×16 spatial) z = torch.randn(1, model.in_channels, 204, 16, device=device, dtype=weight_dtype) sampling_kwargs = dict( model=model, 
latents=z, y=onset_feats_t, context=video_feats, num_steps=int(num_steps), heun=False, cfg_scale=float(cfg_scale), guidance_low=0.0, guidance_high=0.7, path_type="linear", ) with torch.no_grad(): samples = (euler_maruyama_sampler if mode == "sde" else euler_sampler)(**sampling_kwargs) # samplers return (output_tensor, zs) — index [0] for the audio latent if isinstance(samples, tuple): samples = samples[0] # Decode: AudioLDM2 VAE → mel → vocoder → waveform samples = vae.decode(samples / latents_scale).sample wav = vocoder(samples.squeeze().float()).detach().cpu().numpy() return wav # full window — _stitch_wavs handles contact-edge trimming # ================================================================== # # TARO 16 kHz → 48 kHz upsample # # ================================================================== # # TARO generates at 16 kHz; all other models output at 44.1/48 kHz. # We upsample via sinc resampling (torchaudio, CPU-only) so the final # stitched audio is uniformly at 48 kHz across all three models. TARGET_SR = 48000 # unified output sample rate for all three models TARO_SR_OUT = TARGET_SR def _resample_to_target(wav: np.ndarray, src_sr: int, dst_sr: int = None) -> np.ndarray: """Resample *wav* (mono or stereo numpy float32) from *src_sr* to *dst_sr*. *dst_sr* defaults to TARGET_SR (48 kHz). No-op if src_sr == dst_sr. Uses torchaudio Kaiser-windowed sinc resampling — CPU-only, ZeroGPU-safe. """ if dst_sr is None: dst_sr = TARGET_SR if src_sr == dst_sr: return wav stereo = wav.ndim == 2 t = torch.from_numpy(np.ascontiguousarray(wav.astype(np.float32))) if not stereo: t = t.unsqueeze(0) # [1, T] t = torchaudio.functional.resample(t, src_sr, dst_sr) if not stereo: t = t.squeeze(0) # [T] return t.numpy() def _upsample_taro(wav_16k: np.ndarray) -> np.ndarray: """Upsample a mono 16 kHz numpy array to 48 kHz via sinc resampling (CPU). 
torchaudio.functional.resample uses a Kaiser-windowed sinc filter — mathematically optimal for bandlimited signals, zero CUDA risk. Returns a mono float32 numpy array at 48 kHz. """ dur_in = len(wav_16k) / TARO_SR print(f"[TARO upsample] {dur_in:.2f}s @ {TARO_SR}Hz → {TARGET_SR}Hz (sinc, CPU) …") result = _resample_to_target(wav_16k, TARO_SR) print(f"[TARO upsample] done — {len(result)/TARGET_SR:.2f}s @ {TARGET_SR}Hz " f"(expected {dur_in * 3:.2f}s, ratio 3×)") return result def _stitch_wavs(wavs: list[np.ndarray], crossfade_s: float, db_boost: float, total_dur_s: float, sr: int, segments: list[tuple[float, float]] = None) -> np.ndarray: """Crossfade-join a list of wav arrays and trim to *total_dur_s*. Works for both mono (T,) and stereo (C, T) arrays. When *segments* is provided (list of (start, end) video-time pairs), each wav is trimmed to its contact-edge window before joining: contact_edge[i→i+1] = midpoint of overlap = (seg[i].end + seg[i+1].start) / 2 half_cf = crossfade_s / 2 seg i keep: [contact_edge[i-1→i] - half_cf, contact_edge[i→i+1] + half_cf] expressed as sample offsets into the generated audio for that segment. This guarantees every crossfade zone is exactly crossfade_s wide with no raw bleed regardless of how much the inference windows overlap. 
""" def _trim(wav, start_s, end_s, seg_start_s): """Trim wav to [start_s, end_s] expressed in absolute video time, where the wav starts at seg_start_s in video time.""" s = max(0, int(round((start_s - seg_start_s) * sr))) e = int(round((end_s - seg_start_s) * sr)) e = min(e, wav.shape[1] if wav.ndim == 2 else len(wav)) return wav[:, s:e] if wav.ndim == 2 else wav[s:e] if segments is None or len(segments) == 1: out = wavs[0] for nw in wavs[1:]: out = _cf_join(out, nw, crossfade_s, db_boost, sr) n = int(round(total_dur_s * sr)) return out[:, :n] if out.ndim == 2 else out[:n] half_cf = crossfade_s / 2.0 # Compute contact edges between consecutive segments contact_edges = [ (segments[i][1] + segments[i + 1][0]) / 2.0 for i in range(len(segments) - 1) ] # Trim each segment to its keep window trimmed = [] for i, (wav, (seg_start, seg_end)) in enumerate(zip(wavs, segments)): keep_start = (contact_edges[i - 1] - half_cf) if i > 0 else seg_start keep_end = (contact_edges[i] + half_cf) if i < len(segments) - 1 else total_dur_s trimmed.append(_trim(wav, keep_start, keep_end, seg_start)) # Crossfade-join the trimmed segments out = trimmed[0] for nw in trimmed[1:]: out = _cf_join(out, nw, crossfade_s, db_boost, sr) n = int(round(total_dur_s * sr)) return out[:, :n] if out.ndim == 2 else out[:n] def _save_wav(path: str, wav: np.ndarray, sr: int) -> None: """Save a numpy wav array (mono or stereo) to *path* via torchaudio.""" t = torch.from_numpy(np.ascontiguousarray(wav)) if t.ndim == 1: t = t.unsqueeze(0) torchaudio.save(path, t, sr) def _log_inference_timing(label: str, elapsed: float, n_segs: int, num_steps: int, constant: float) -> None: """Print a standardised inference-timing summary line.""" total_steps = n_segs * num_steps secs_per_step = elapsed / total_steps if total_steps > 0 else 0 print(f"[{label}] Inference done: {n_segs} seg(s) × {num_steps} steps in " f"{elapsed:.1f}s wall → {secs_per_step:.3f}s/step " f"(current constant={constant})") def _build_seg_meta(*, 
segments, wav_paths, audio_path, video_path, silent_video, sr, model, crossfade_s, crossfade_db, total_dur_s, **extras) -> dict: """Build the seg_meta dict shared by all three generate_* functions. Model-specific keys are passed via **extras.""" meta = { "segments": segments, "wav_paths": wav_paths, "audio_path": audio_path, "video_path": video_path, "silent_video": silent_video, "sr": sr, "model": model, "crossfade_s": crossfade_s, "crossfade_db": crossfade_db, "total_dur_s": total_dur_s, } meta.update(extras) return meta def _post_process_samples(results: list, *, model: str, tmp_dir: str, silent_video: str, segments: list, crossfade_s: float, crossfade_db: float, total_dur_s: float, sr: int, extra_meta_fn=None) -> list: """Shared CPU post-processing for all three generate_* wrappers. Each entry in *results* is a tuple whose first element is a list of per-segment wav arrays. The remaining elements are model-specific (e.g. TARO returns features, HunyuanFoley returns text_feats). *extra_meta_fn(sample_idx, result_tuple, tmp_dir) -> dict* is an optional callback that returns model-specific extra keys to merge into seg_meta (e.g. cavp_path, onset_path, text_feats_path). Returns a list of (video_path, audio_path, seg_meta) tuples. 
    """
    outputs = []
    for sample_idx, result in enumerate(results):
        seg_wavs = result[0]
        # Stitch the per-segment wavs into one full-length track.
        full_wav = _stitch_wavs(seg_wavs, crossfade_s, crossfade_db,
                                total_dur_s, sr, segments)
        audio_path = os.path.join(tmp_dir, f"{model}_{sample_idx}.wav")
        _save_wav(audio_path, full_wav, sr)
        video_path = os.path.join(tmp_dir, f"{model}_{sample_idx}.mp4")
        mux_video_audio(silent_video, audio_path, video_path, model=model)
        # Persist per-segment wavs so single-segment regeneration can
        # splice without re-running the other segments.
        wav_paths = _save_seg_wavs(seg_wavs, tmp_dir, f"{model}_{sample_idx}")
        extras = extra_meta_fn(sample_idx, result, tmp_dir) if extra_meta_fn else {}
        seg_meta = _build_seg_meta(
            segments=segments, wav_paths=wav_paths, audio_path=audio_path,
            video_path=video_path, silent_video=silent_video, sr=sr,
            model=model, crossfade_s=crossfade_s, crossfade_db=crossfade_db,
            total_dur_s=total_dur_s, **extras,
        )
        outputs.append((video_path, audio_path, seg_meta))
    return outputs


def _cpu_preprocess(video_file: str, model_dur: float, crossfade_s: float) -> tuple:
    """Shared CPU pre-processing for all generate_* wrappers.
    Returns (tmp_dir, silent_video, total_dur_s, segments)."""
    tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
    silent_video = os.path.join(tmp_dir, "silent_input.mp4")
    strip_audio_from_video(video_file, silent_video)
    total_dur_s = get_video_duration(video_file)
    segments = _build_segments(total_dur_s, model_dur, crossfade_s)
    return tmp_dir, silent_video, total_dur_s, segments


@spaces.GPU(duration=_taro_duration)
def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
                    crossfade_s, crossfade_db, num_samples):
    """GPU-only TARO inference — model loading + feature extraction + diffusion.
    Returns list of (wavs_list, onset_feats) per sample."""
    seed_val = _resolve_seed(seed_val)
    crossfade_s = float(crossfade_s)
    num_samples = int(num_samples)
    torch.set_grad_enabled(False)
    device, weight_dtype = _get_device_and_dtype()
    _ensure_syspath("TARO")
    from TARO.onset_util import extract_onset
    from TARO.samplers import euler_sampler, euler_maruyama_sampler

    # Inputs are handed over via the shared context written by generate_taro
    # (the ZeroGPU worker runs in a separate process).
    ctx = _ctx_load("taro_gpu_infer")
    tmp_dir = ctx["tmp_dir"]
    silent_video = ctx["silent_video"]
    segments = ctx["segments"]
    total_dur_s = ctx["total_dur_s"]

    extract_cavp, onset_model = _load_taro_feature_extractors(device)
    cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
    # Onset features depend only on the video — extract once for all samples
    onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
    # Free feature extractors before loading the heavier inference models
    del extract_cavp, onset_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    model, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)

    results = []  # list of (wavs, onset_feats) per sample
    for sample_idx in range(num_samples):
        # Each sample gets a deterministic, distinct seed.
        sample_seed = seed_val + sample_idx
        cache_key = (video_file, sample_seed, float(cfg_scale), int(num_steps), mode, crossfade_s)
        with _TARO_CACHE_LOCK:
            cached = _TARO_INFERENCE_CACHE.get(cache_key)
        if cached is not None:
            print(f"[TARO] Sample {sample_idx+1}: cache hit.")
            # Cache stores only the wavs; onset_feats is reported as None here.
            results.append((cached["wavs"], cavp_feats, None))
        else:
            set_global_seed(sample_seed)
            wavs = []
            _t_infer_start = time.perf_counter()
            for seg_start_s, seg_end_s in segments:
                print(f"[TARO] Sample {sample_idx+1} | {seg_start_s:.2f}s – {seg_end_s:.2f}s")
                wav = _taro_infer_segment(
                    model, vae, vocoder, cavp_feats, onset_feats,
                    seg_start_s, seg_end_s, device, weight_dtype,
                    cfg_scale, num_steps, mode, latents_scale,
                    euler_sampler, euler_maruyama_sampler,
                )
                wavs.append(wav)
            _log_inference_timing("TARO", time.perf_counter() - _t_infer_start,
                                  len(segments), int(num_steps), TARO_SECS_PER_STEP)
            with _TARO_CACHE_LOCK:
                _TARO_INFERENCE_CACHE[cache_key] = {"wavs": wavs}
                # Evict oldest entries (dict preserves insertion order).
                while len(_TARO_INFERENCE_CACHE) > _TARO_CACHE_MAXLEN:
                    _TARO_INFERENCE_CACHE.pop(next(iter(_TARO_INFERENCE_CACHE)))
            results.append((wavs, cavp_feats, onset_feats))
        # Free GPU memory between samples so VRAM fragmentation doesn't
        # degrade diffusion quality on samples 2, 3, 4, etc.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results


def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
                  crossfade_s, crossfade_db, num_samples):
    """TARO: video-conditioned diffusion, 16 kHz, 8.192 s sliding window.
    CPU pre/post-processing wraps the GPU-only inference to minimize ZeroGPU cost."""
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)
    num_samples = int(num_samples)

    # ── CPU pre-processing (no GPU needed) ──
    tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
        video_file, TARO_MODEL_DUR, crossfade_s)
    # Hand the CPU-derived context to the GPU worker process.
    _ctx_store("taro_gpu_infer", {
        "tmp_dir": tmp_dir,
        "silent_video": silent_video,
        "segments": segments,
        "total_dur_s": total_dur_s,
    })

    # ── GPU inference only ──
    results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
                              crossfade_s, crossfade_db, num_samples)

    # ── CPU post-processing (no GPU needed) ──
    # Upsample 16kHz → 48kHz and normalise result tuples to (seg_wavs, ...)
    cavp_path = os.path.join(tmp_dir, "taro_cavp.npy")
    onset_path = os.path.join(tmp_dir, "taro_onset.npy")
    _feats_saved = False

    def _upsample_and_save_feats(result):
        # Upsample each segment wav and persist CAVP/onset features once so
        # segment regeneration can reload them from disk later.
        nonlocal _feats_saved
        wavs, cavp_feats, onset_feats = result
        wavs = [_upsample_taro(w) for w in wavs]
        if not _feats_saved:
            np.save(cavp_path, cavp_feats)
            # NOTE(review): onset_feats is None for cache-hit samples.  If the
            # FIRST sample is a cache hit, _feats_saved is still set below and
            # the onset file is never written even when a later sample carries
            # real onset features — regen then falls back to re-extraction.
            # Confirm this best-effort behaviour is intended.
            if onset_feats is not None:
                np.save(onset_path, onset_feats)
            _feats_saved = True
        return (wavs, cavp_feats, onset_feats)

    results = [_upsample_and_save_feats(r) for r in results]

    def _taro_extras(sample_idx, result, td):
        # Extra seg_meta keys pointing at the cached feature files on disk.
        return {"cavp_path": cavp_path, "onset_path": onset_path}

    outputs = _post_process_samples(
        results, model="taro", tmp_dir=tmp_dir, silent_video=silent_video,
        segments=segments, crossfade_s=crossfade_s, crossfade_db=crossfade_db,
        total_dur_s=total_dur_s, sr=TARO_SR_OUT, extra_meta_fn=_taro_extras,
    )
    return _pad_outputs(outputs)


# ================================================================== #
#  MMAudio                                                           #
# ================================================================== #
# Constants sourced from MMAudio/mmaudio/model/sequence_config.py:
#   CONFIG_44K: duration=8.0 s, sampling_rate=44100
#   CLIP encoder: 8 fps, 384×384 px
#   Synchformer: 25 fps, 224×224 px
# Default variant: large_44k_v2
# MMAudio uses flow-matching (FlowMatching with euler inference).
# generate() handles all feature extraction + decoding internally.
# ================================================================== #
def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
                      cfg_strength, num_steps, crossfade_s, crossfade_db,
                      num_samples, silent_video=None, segments_json=None):
    """Pre-GPU callable — must match _mmaudio_gpu_infer's input order exactly."""
    return _estimate_gpu_duration("mmaudio", int(num_samples), int(num_steps),
                                  video_file=video_file, crossfade_s=crossfade_s)


@spaces.GPU(duration=_mmaudio_duration)
def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                       cfg_strength, num_steps, crossfade_s, crossfade_db,
                       num_samples, silent_video=None, segments_json=None):
    """GPU-only MMAudio inference — model loading + flow-matching generation.
    Returns list of (seg_audios, sr) per sample.

    *silent_video* and *segments_json* are passed explicitly to avoid
    cross-process shared-state (ZeroGPU isolation).  Segment clips are
    extracted here via ffmpeg (CPU-safe inside GPU window).
    """
    _ensure_syspath("MMAudio")
    from mmaudio.eval_utils import generate, load_video
    from mmaudio.model.flow_matching import FlowMatching

    seed_val = _resolve_seed(seed_val)
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    device, dtype = _get_device_and_dtype()
    net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)

    # Extract segment clips inside GPU fn — ffmpeg is CPU-only, safe here.
    segments = json.loads(segments_json)
    tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
    seg_clip_paths = [
        _extract_segment_clip(silent_video, s, e - s,
                              os.path.join(tmp_dir, f"mma_seg_{i}.mp4"))
        for i, (s, e) in enumerate(segments)
    ]

    sr = seq_cfg.sampling_rate  # 44100
    results = []
    for sample_idx in range(num_samples):
        # Distinct deterministic seed per sample.
        rng = torch.Generator(device=device)
        rng.manual_seed(seed_val + sample_idx)
        seg_audios = []
        _t_mma_start = time.perf_counter()
        fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
        for seg_i, (seg_start, seg_end) in enumerate(segments):
            seg_dur = seg_end - seg_start
            seg_path = seg_clip_paths[seg_i]
            video_info = load_video(seg_path, seg_dur)
            clip_frames = video_info.clip_frames.unsqueeze(0)
            sync_frames = video_info.sync_frames.unsqueeze(0)
            actual_dur = video_info.duration_sec
            # Resize the model's sequence lengths to the clip's real duration.
            seq_cfg.duration = actual_dur
            net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len,
                                   seq_cfg.sync_seq_len)
            print(f"[MMAudio] Sample {sample_idx+1} | seg {seg_i+1}/{len(segments)} "
                  f"{seg_start:.1f}–{seg_end:.1f}s | dur={actual_dur:.2f}s | prompt='{prompt}'")
            with torch.no_grad():
                audios = generate(
                    clip_frames, sync_frames, [prompt],
                    negative_text=[negative_prompt] if negative_prompt else None,
                    feature_utils=feature_utils, net=net, fm=fm, rng=rng,
                    cfg_strength=float(cfg_strength),
                )
            wav = audios.float().cpu()[0].numpy()  # (C, T) — full window
            seg_audios.append(wav)
        _log_inference_timing("MMAudio", time.perf_counter() - _t_mma_start,
                              len(segments), int(num_steps), MMAUDIO_SECS_PER_STEP)
        results.append((seg_audios, sr))
        # Free GPU memory between samples to prevent VRAM fragmentation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results


def generate_mmaudio(video_file, prompt, negative_prompt, seed_val, cfg_strength,
                     num_steps, crossfade_s, crossfade_db, num_samples):
    """MMAudio: flow-matching video-to-audio, 44.1 kHz, 8 s sliding window.
    CPU pre/post-processing wraps the GPU-only inference to minimize ZeroGPU cost."""
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)

    # ── CPU pre-processing ──
    tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
        video_file, MMAUDIO_WINDOW, crossfade_s)
    print(f"[MMAudio] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤8 s")

    # ── GPU inference only ──
    results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                                 cfg_strength, num_steps, crossfade_s, crossfade_db,
                                 num_samples, silent_video=silent_video,
                                 segments_json=json.dumps(segments))

    # ── CPU post-processing ──
    # Resample 44100 → 48000 and normalise tuples to (seg_wavs, ...)
    resampled = []
    for seg_audios, sr in results:
        if sr != TARGET_SR:
            print(f"[MMAudio upsample] resampling {sr}Hz → {TARGET_SR}Hz (sinc, CPU) …")
            seg_audios = [_resample_to_target(w, sr) for w in seg_audios]
            print(f"[MMAudio upsample] done — {len(seg_audios)} seg(s) @ {TARGET_SR}Hz")
        # 1-tuple shape matches the _post_process_samples contract
        # (first element is the list of per-segment wavs).
        resampled.append((seg_audios,))

    outputs = _post_process_samples(
        resampled, model="mmaudio", tmp_dir=tmp_dir, silent_video=silent_video,
        segments=segments, crossfade_s=crossfade_s, crossfade_db=crossfade_db,
        total_dur_s=total_dur_s, sr=TARGET_SR,
    )
    return _pad_outputs(outputs)


# ================================================================== #
#  HunyuanVideoFoley                                                 #
# ================================================================== #
# Constants sourced from HunyuanVideo-Foley/hunyuanvideo_foley/constants.py
# and configs/hunyuanvideo-foley-xxl.yaml:
#   sample_rate = 48000 Hz (from DAC VAE)
#   audio_frame_rate = 50 (latent fps, xxl config)
#   max video duration = 15 s
#   SigLIP2 fps = 8, Synchformer fps = 25
# CLAP text encoder: laion/larger_clap_general (auto-downloaded from HF Hub)
# Default guidance_scale=4.5, num_inference_steps=50
# ================================================================== #
def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
                      guidance_scale, num_steps, model_size, crossfade_s,
                      crossfade_db, num_samples, silent_video=None,
                      segments_json=None, total_dur_s=None):
    """Pre-GPU callable — must match _hunyuan_gpu_infer's input order exactly."""
    return _estimate_gpu_duration("hunyuan", int(num_samples), int(num_steps),
                                  video_file=video_file, crossfade_s=crossfade_s)


@spaces.GPU(duration=_hunyuan_duration)
def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                       guidance_scale, num_steps, model_size, crossfade_s,
                       crossfade_db, num_samples, silent_video=None,
                       segments_json=None, total_dur_s=None):
    """GPU-only HunyuanFoley inference — model loading + feature extraction + denoising.
    Returns list of (seg_wavs, sr, text_feats) per sample.

    *silent_video*, *segments_json*, and *total_dur_s* are passed explicitly
    to avoid cross-process shared-state under ZeroGPU isolation.
    """
    _ensure_syspath("HunyuanVideo-Foley")
    from hunyuanvideo_foley.utils.model_utils import denoise_process
    from hunyuanvideo_foley.utils.feature_utils import feature_process

    seed_val = _resolve_seed(seed_val)
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    total_dur_s = float(total_dur_s)
    set_global_seed(seed_val)
    device, _ = _get_device_and_dtype()
    model_size = model_size.lower()
    model_dict, cfg = _load_hunyuan_model(device, model_size)

    # Extract segment clips inside GPU fn — ffmpeg is CPU-only, safe here.
    segments = json.loads(segments_json)
    tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
    # Dummy clip only drives the one-off text-feature extraction below.
    dummy_seg_path = _extract_segment_clip(
        silent_video, 0, min(total_dur_s, HUNYUAN_MAX_DUR),
        os.path.join(tmp_dir, "_seg_dummy.mp4"),
    )
    seg_clip_paths = [
        _extract_segment_clip(silent_video, s, e - s,
                              os.path.join(tmp_dir, f"hny_seg_{i}.mp4"))
        for i, (s, e) in enumerate(segments)
    ]

    # Text feature extraction (GPU — runs once for all segments)
    _, text_feats, _ = feature_process(
        dummy_seg_path,
        prompt if prompt else "",
        model_dict, cfg,
        neg_prompt=negative_prompt if negative_prompt else None,
    )

    # Import visual-only feature extractor to avoid redundant text extraction
    # per segment (text_feats already computed once above for the whole batch).
    from hunyuanvideo_foley.utils.feature_utils import encode_video_features

    results = []
    for sample_idx in range(num_samples):
        seg_wavs = []
        # Fallback sample rate; denoise_process overwrites it below.
        sr = 48000
        _t_hny_start = time.perf_counter()
        for seg_i, (seg_start, seg_end) in enumerate(segments):
            seg_dur = seg_end - seg_start
            seg_path = seg_clip_paths[seg_i]
            # Extract only visual features — reuse text_feats from above
            visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
            print(f"[HunyuanFoley] Sample {sample_idx+1} | seg {seg_i+1}/{len(segments)} "
                  f"{seg_start:.1f}–{seg_end:.1f}s → {seg_audio_len:.2f}s audio")
            audio_batch, sr = denoise_process(
                visual_feats, text_feats, seg_audio_len, model_dict, cfg,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_steps),
                batch_size=1,
            )
            wav = audio_batch[0].float().cpu().numpy()  # full window
            seg_wavs.append(wav)
        _log_inference_timing("HunyuanFoley", time.perf_counter() - _t_hny_start,
                              len(segments), int(num_steps), HUNYUAN_SECS_PER_STEP)
        results.append((seg_wavs, sr, text_feats))
        # Free GPU memory between samples to prevent VRAM fragmentation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results


def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
                     guidance_scale, num_steps, model_size, crossfade_s,
                     crossfade_db, num_samples):
    """HunyuanVideoFoley: text-guided foley, 48 kHz, up to 15 s.
    CPU pre/post-processing wraps the GPU-only inference to minimize ZeroGPU cost."""
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)

    # ── CPU pre-processing (no GPU needed) ──
    tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
        video_file, HUNYUAN_MAX_DUR, crossfade_s)
    print(f"[HunyuanFoley] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤15 s")

    # ── GPU inference only ──
    results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                                 guidance_scale, num_steps, model_size,
                                 crossfade_s, crossfade_db, num_samples,
                                 silent_video=silent_video,
                                 segments_json=json.dumps(segments),
                                 total_dur_s=total_dur_s)

    # ── CPU post-processing (no GPU needed) ──
    def _hunyuan_extras(sample_idx, result, td):
        # Persist text features so single-segment regeneration can skip the
        # text encoder and reload them from disk.
        _, _sr, text_feats = result
        path = os.path.join(td, f"hunyuan_{sample_idx}_text_feats.pt")
        torch.save(text_feats, path)
        return {"text_feats_path": path}

    outputs = _post_process_samples(
        results, model="hunyuan", tmp_dir=tmp_dir, silent_video=silent_video,
        segments=segments, crossfade_s=crossfade_s, crossfade_db=crossfade_db,
        total_dur_s=total_dur_s, sr=48000, extra_meta_fn=_hunyuan_extras,
    )
    return _pad_outputs(outputs)


# ================================================================== #
#  SEGMENT REGENERATION HELPERS                                      #
# ================================================================== #
# Each regen function:
#   1. Runs inference for ONE segment (random seed, current settings)
#   2. Splices the new wav into the stored wavs list
#   3. Re-stitches the full track, re-saves .wav and re-muxes .mp4
#   4. Returns (new_video_path, new_audio_path, updated_seg_meta, new_waveform_html)
# ================================================================== #
def _splice_and_save(new_wav, seg_idx, meta, slot_id):
    """Replace wavs[seg_idx] with new_wav, re-stitch, re-save, re-mux.
    Returns (video_path, audio_path, updated_meta, waveform_html).
    """
    wavs = _load_seg_wavs(meta["wav_paths"])
    wavs[seg_idx] = new_wav
    crossfade_s = float(meta["crossfade_s"])
    crossfade_db = float(meta["crossfade_db"])
    sr = int(meta["sr"])
    total_dur_s = float(meta["total_dur_s"])
    silent_video = meta["silent_video"]
    segments = meta["segments"]
    model = meta["model"]

    full_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, sr, segments)

    # Save new audio — use a new timestamped filename so Gradio / the browser
    # treats it as a genuinely different file and reloads the video player.
    _ts = int(time.time() * 1000)
    tmp_dir = os.path.dirname(meta["audio_path"])
    _base = os.path.splitext(os.path.basename(meta["audio_path"]))[0]
    # Strip any previous timestamp suffix before adding a new one
    _base_clean = _base.rsplit("_regen_", 1)[0]
    audio_path = os.path.join(tmp_dir, f"{_base_clean}_regen_{_ts}.wav")
    _save_wav(audio_path, full_wav, sr)

    # Re-mux into a new video file so the browser is forced to reload it
    _vid_base = os.path.splitext(os.path.basename(meta["video_path"]))[0]
    _vid_base_clean = _vid_base.rsplit("_regen_", 1)[0]
    video_path = os.path.join(tmp_dir, f"{_vid_base_clean}_regen_{_ts}.mp4")
    mux_video_audio(silent_video, audio_path, video_path, model=model)

    # Save updated segment wavs to .npy files
    # (same base name as the originals, so the files are overwritten in place)
    updated_wav_paths = _save_seg_wavs(wavs, tmp_dir, os.path.splitext(_base_clean)[0])

    updated_meta = dict(meta)
    updated_meta["wav_paths"] = updated_wav_paths
    updated_meta["audio_path"] = audio_path
    updated_meta["video_path"] = video_path

    state_json_new = json.dumps(updated_meta)
    waveform_html = _build_waveform_html(audio_path, segments, slot_id, "",
                                         state_json=state_json_new,
                                         video_path=video_path,
                                         crossfade_s=crossfade_s)
    return video_path, audio_path, updated_meta, waveform_html


def _taro_regen_duration(video_file, seg_idx, seg_meta_json, seed_val, cfg_scale,
                         num_steps, mode, crossfade_s, crossfade_db, slot_id=None):
    # If cached CAVP/onset features exist, skip ~10s feature-extractor overhead
    try:
        meta = json.loads(seg_meta_json)
        cavp_ok = os.path.exists(meta.get("cavp_path", ""))
        onset_ok = os.path.exists(meta.get("onset_path", ""))
        if cavp_ok and onset_ok:
            cfg = MODEL_CONFIGS["taro"]
            secs = int(num_steps) * cfg["secs_per_step"] + 5  # 5s model-load only
            result = min(GPU_DURATION_CAP, max(30, int(secs)))
            print(f"[duration] TARO regen (cache hit): 1 seg × {int(num_steps)} steps → {secs:.0f}s → capped {result}s")
            return result
    except Exception:
        # Malformed meta / missing keys — fall through to the generic estimate.
        pass
    return _estimate_regen_duration("taro", int(num_steps))


@spaces.GPU(duration=_taro_regen_duration)
def _regen_taro_gpu(video_file, seg_idx, seg_meta_json, seed_val, cfg_scale,
                    num_steps, mode, crossfade_s, crossfade_db, slot_id=None):
    """GPU-only TARO regen — returns new_wav for a single segment."""
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start_s, seg_end_s = meta["segments"][seg_idx]
    torch.set_grad_enabled(False)
    device, weight_dtype = _get_device_and_dtype()
    _ensure_syspath("TARO")
    from TARO.samplers import euler_sampler, euler_maruyama_sampler

    # Load cached CAVP/onset features from .npy files (CPU I/O, fast, outside GPU budget)
    cavp_path = meta.get("cavp_path", "")
    onset_path = meta.get("onset_path", "")
    if cavp_path and os.path.exists(cavp_path) and onset_path and os.path.exists(onset_path):
        print("[TARO regen] Loading cached CAVP + onset features from disk")
        cavp_feats = np.load(cavp_path)
        onset_feats = np.load(onset_path)
    else:
        print("[TARO regen] Cache miss — re-extracting CAVP + onset features")
        from TARO.onset_util import extract_onset
        extract_cavp, onset_model = _load_taro_feature_extractors(device)
        silent_video = meta["silent_video"]
        tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
        cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
        onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
        # Release the extractors before loading the heavier inference models.
        del extract_cavp, onset_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    model_net, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
    # Fresh random seed — regeneration should produce a different take.
    set_global_seed(random.randint(0, 2**32 - 1))
    return _taro_infer_segment(
        model_net, vae, vocoder, cavp_feats, onset_feats,
        seg_start_s, seg_end_s, device, weight_dtype,
        float(cfg_scale), int(num_steps), mode, latents_scale,
        euler_sampler, euler_maruyama_sampler,
    )


def regen_taro_segment(video_file, seg_idx, seg_meta_json, seed_val, cfg_scale,
                       num_steps, mode, crossfade_s, crossfade_db, slot_id):
    """Regenerate one TARO segment.  GPU inference + CPU splice/save."""
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    # GPU: inference — CAVP/onset features loaded from disk paths in seg_meta_json
    new_wav = _regen_taro_gpu(video_file, seg_idx, seg_meta_json, seed_val,
                              cfg_scale, num_steps, mode, crossfade_s,
                              crossfade_db, slot_id)
    # Upsample 16kHz → 48kHz (sinc, CPU)
    new_wav = _upsample_taro(new_wav)
    # CPU: splice, stitch, mux, save
    video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
        new_wav, seg_idx, meta, slot_id
    )
    return video_path, audio_path, json.dumps(updated_meta), waveform_html


def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json, prompt,
                            negative_prompt, seed_val, cfg_strength, num_steps,
                            crossfade_s, crossfade_db, slot_id=None):
    # Pre-GPU callable — must match _regen_mmaudio_gpu's input order exactly.
    return _estimate_regen_duration("mmaudio", int(num_steps))


@spaces.GPU(duration=_mmaudio_regen_duration)
def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json, prompt, negative_prompt,
                       seed_val, cfg_strength, num_steps, crossfade_s, crossfade_db,
                       slot_id=None):
    """GPU-only MMAudio regen — returns (new_wav, sr) for a single segment."""
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start, seg_end = meta["segments"][seg_idx]
    seg_dur = seg_end - seg_start
    _ensure_syspath("MMAudio")
    from mmaudio.eval_utils import generate, load_video
    from mmaudio.model.flow_matching import FlowMatching

    device, dtype = _get_device_and_dtype()
    net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
    sr = seq_cfg.sampling_rate

    # Extract segment clip inside the GPU function — ffmpeg is CPU-only and safe here.
    # This avoids any cross-process context passing that fails under ZeroGPU isolation.
    seg_path = _extract_segment_clip(
        meta["silent_video"], seg_start, seg_dur,
        os.path.join(_register_tmp_dir(tempfile.mkdtemp()), "regen_seg.mp4"),
    )

    # Fresh random seed — regeneration should produce a different take.
    rng = torch.Generator(device=device)
    rng.manual_seed(random.randint(0, 2**32 - 1))
    fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=int(num_steps))
    video_info = load_video(seg_path, seg_dur)
    clip_frames = video_info.clip_frames.unsqueeze(0)
    sync_frames = video_info.sync_frames.unsqueeze(0)
    actual_dur = video_info.duration_sec
    seq_cfg.duration = actual_dur
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len,
                           seq_cfg.sync_seq_len)
    with torch.no_grad():
        audios = generate(
            clip_frames, sync_frames, [prompt],
            negative_text=[negative_prompt] if negative_prompt else None,
            feature_utils=feature_utils, net=net, fm=fm, rng=rng,
            cfg_strength=float(cfg_strength),
        )
    new_wav = audios.float().cpu()[0].numpy()  # full window — _stitch_wavs trims
    return new_wav, sr


def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json, prompt, negative_prompt,
                          seed_val, cfg_strength, num_steps, crossfade_s, crossfade_db,
                          slot_id):
    """Regenerate one MMAudio segment.
    GPU inference + CPU splice/save."""
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    # GPU: inference (segment clip extraction happens inside the GPU function)
    new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json, prompt,
                                     negative_prompt, seed_val, cfg_strength,
                                     num_steps, crossfade_s, crossfade_db, slot_id)
    # Resample to 48kHz if needed (MMAudio outputs at 44100 Hz)
    if sr != TARGET_SR:
        print(f"[MMAudio regen upsample] {sr}Hz → {TARGET_SR}Hz (sinc, CPU) …")
        new_wav = _resample_to_target(new_wav, sr)
        sr = TARGET_SR
    meta["sr"] = sr
    # CPU: splice, stitch, mux, save
    video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
        new_wav, seg_idx, meta, slot_id
    )
    return video_path, audio_path, json.dumps(updated_meta), waveform_html


def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json, prompt,
                            negative_prompt, seed_val, guidance_scale, num_steps,
                            model_size, crossfade_s, crossfade_db, slot_id=None):
    # Pre-GPU callable — must match _regen_hunyuan_gpu's input order exactly.
    return _estimate_regen_duration("hunyuan", int(num_steps))


@spaces.GPU(duration=_hunyuan_regen_duration)
def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json, prompt, negative_prompt,
                       seed_val, guidance_scale, num_steps, model_size, crossfade_s,
                       crossfade_db, slot_id=None):
    """GPU-only HunyuanFoley regen — returns (new_wav, sr) for a single segment."""
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start, seg_end = meta["segments"][seg_idx]
    seg_dur = seg_end - seg_start
    _ensure_syspath("HunyuanVideo-Foley")
    from hunyuanvideo_foley.utils.model_utils import denoise_process
    from hunyuanvideo_foley.utils.feature_utils import feature_process

    device, _ = _get_device_and_dtype()
    model_dict, cfg = _load_hunyuan_model(device, model_size)
    # Fresh random seed — regeneration should produce a different take.
    set_global_seed(random.randint(0, 2**32 - 1))

    # Extract segment clip inside the GPU function — ffmpeg is CPU-only and safe here.
    seg_path = _extract_segment_clip(
        meta["silent_video"], seg_start, seg_dur,
        os.path.join(_register_tmp_dir(tempfile.mkdtemp()), "regen_seg.mp4"),
    )

    text_feats_path = meta.get("text_feats_path", "")
    if text_feats_path and os.path.exists(text_feats_path):
        print("[HunyuanFoley regen] Loading cached text features from disk")
        from hunyuanvideo_foley.utils.feature_utils import encode_video_features
        visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
        # NOTE(review): weights_only=False unpickles arbitrary objects — fine
        # for a file this app wrote itself, but never load untrusted files this way.
        text_feats = torch.load(text_feats_path, map_location=device, weights_only=False)
    else:
        print("[HunyuanFoley regen] Cache miss — extracting text + visual features")
        visual_feats, text_feats, seg_audio_len = feature_process(
            seg_path,
            prompt if prompt else "",
            model_dict, cfg,
            neg_prompt=negative_prompt if negative_prompt else None,
        )

    audio_batch, sr = denoise_process(
        visual_feats, text_feats, seg_audio_len, model_dict, cfg,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(num_steps),
        batch_size=1,
    )
    new_wav = audio_batch[0].float().cpu().numpy()  # full window — _stitch_wavs trims
    return new_wav, sr


def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json, prompt, negative_prompt,
                          seed_val, guidance_scale, num_steps, model_size, crossfade_s,
                          crossfade_db, slot_id):
    """Regenerate one HunyuanFoley segment.
    GPU inference + CPU splice/save."""
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    # GPU: inference (segment clip extraction happens inside the GPU function)
    new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json, prompt,
                                     negative_prompt, seed_val, guidance_scale,
                                     num_steps, model_size, crossfade_s,
                                     crossfade_db, slot_id)
    meta["sr"] = sr
    # CPU: splice, stitch, mux, save
    video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
        new_wav, seg_idx, meta, slot_id
    )
    return video_path, audio_path, json.dumps(updated_meta), waveform_html


# Wire up regen_fn references now that the functions are defined
MODEL_CONFIGS["taro"]["regen_fn"] = regen_taro_segment
MODEL_CONFIGS["mmaudio"]["regen_fn"] = regen_mmaudio_segment
MODEL_CONFIGS["hunyuan"]["regen_fn"] = regen_hunyuan_segment


# ================================================================== #
#  CROSS-MODEL REGEN WRAPPERS                                        #
# ================================================================== #
# Three shared endpoints — one per model — that can be called from
# *any* slot tab.  slot_id is passed as plain string data so the
# result is applied back to the correct slot by the JS listener.
# The new segment is resampled to match the slot's existing SR before
# being handed to _splice_and_save, so TARO (16 kHz) / MMAudio
# (44.1 kHz) / Hunyuan (48 kHz) outputs can all be mixed freely.
# ================================================================== #
def _resample_to_slot_sr(wav: np.ndarray, src_sr: int, dst_sr: int,
                         slot_wav_ref: np.ndarray = None) -> np.ndarray:
    """Resample *wav* from src_sr to dst_sr, then match channel layout to
    *slot_wav_ref* (the first existing segment in the slot).

    TARO is mono (T,), MMAudio/Hunyuan are stereo (C, T).  Mixing them
    without normalisation causes a shape mismatch in _cf_join.
    Rules:
      - stereo → mono : average channels
      - mono → stereo : duplicate the single channel
    """
    wav = _resample_to_target(wav, src_sr, dst_sr)
    # Match channel layout to the slot's existing segments
    if slot_wav_ref is not None:
        slot_stereo = slot_wav_ref.ndim == 2
        wav_stereo = wav.ndim == 2
        if slot_stereo and not wav_stereo:
            wav = np.stack([wav, wav], axis=0)  # mono → stereo (C, T)
        elif not slot_stereo and wav_stereo:
            wav = wav.mean(axis=0)  # stereo → mono (T,)
    return wav


def _xregen_clip_window(meta: dict, seg_idx: int, target_window_s: float) -> tuple:
    """Compute the video clip window for a cross-model regen.

    Centers *target_window_s* on the original segment's midpoint, clamped
    to [0, total_dur_s].  Returns (clip_start, clip_end, clip_dur).

    If the video is shorter than *target_window_s*, the full video is used
    (suboptimal but never breaks).  If the segment span exceeds
    *target_window_s*, the caller should run _build_segments on the span
    and generate multiple sub-segments — but the clip window is still
    returned as the full segment span so the caller can decide.
    """
    total_dur_s = float(meta["total_dur_s"])
    seg_start, seg_end = meta["segments"][seg_idx]
    seg_mid = (seg_start + seg_end) / 2.0
    half_win = target_window_s / 2.0
    clip_start = max(0.0, seg_mid - half_win)
    clip_end = min(total_dur_s, seg_mid + half_win)
    # If clamped at one end, extend the other to preserve full window if possible
    if clip_start == 0.0:
        clip_end = min(total_dur_s, target_window_s)
    elif clip_end == total_dur_s:
        clip_start = max(0.0, total_dur_s - target_window_s)
    clip_dur = clip_end - clip_start
    return clip_start, clip_end, clip_dur


def _xregen_splice(new_wav_raw: np.ndarray, src_sr: int, meta: dict,
                   seg_idx: int, slot_id: str, clip_start_s: float = None) -> tuple:
    """Shared epilogue for all xregen_* functions: resample → splice → save.
    Returns (video_path, waveform_html).

    *clip_start_s* is the absolute video time where new_wav_raw starts.
    When the clip was centered on the segment midpoint (not at seg_start),
    we need to shift the wav so _stitch_wavs can trim it correctly relative
    to the original segment's start.  We do this by prepending silence so
    the wav's time origin aligns with the original segment's start.
    """
    slot_sr = int(meta["sr"])
    slot_wavs = _load_seg_wavs(meta["wav_paths"])
    # Match sample rate AND channel layout to the slot's existing segments.
    new_wav = _resample_to_slot_sr(new_wav_raw, src_sr, slot_sr, slot_wavs[0])

    # Align new_wav so sample index 0 corresponds to seg_start in video time.
    # _stitch_wavs trims using seg_start as the time origin, so if the clip
    # started AFTER seg_start (clip_start_s > seg_start), we prepend silence
    # equal to (clip_start_s - seg_start) to shift the audio back to seg_start.
    if clip_start_s is not None:
        seg_start = meta["segments"][seg_idx][0]
        offset_s = seg_start - clip_start_s  # negative when clip starts after seg_start
        if offset_s < 0:
            pad_samples = int(round(abs(offset_s) * slot_sr))
            silence = np.zeros(
                (new_wav.shape[0], pad_samples) if new_wav.ndim == 2 else pad_samples,
                dtype=new_wav.dtype,
            )
            new_wav = np.concatenate([silence, new_wav], axis=1 if new_wav.ndim == 2 else 0)

    video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
        new_wav, seg_idx, meta, slot_id
    )
    return video_path, waveform_html


def _xregen_dispatch(state_json: str, seg_idx: int, slot_id: str, infer_fn):
    """Shared generator skeleton for all xregen_* wrappers.

    Yields pending HTML immediately, then calls *infer_fn()* — a
    zero-argument callable that runs model-specific CPU prep + GPU
    inference and returns (wav_array, src_sr, clip_start_s).

    For TARO, *infer_fn* should return the wav already upsampled to
    48 kHz; pass TARO_SR_OUT as src_sr.
    Yields:
        First: (gr.update(), gr.update(value=pending_html)) — shown while GPU runs
        Second: (gr.update(value=video_path), gr.update(value=waveform_html))
    """
    meta = json.loads(state_json)
    # Show the amber "regenerating" placeholder right away, before the
    # (potentially slow) inference call.
    pending_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
    yield gr.update(), gr.update(value=pending_html)
    new_wav_raw, src_sr, clip_start_s = infer_fn()
    video_path, waveform_html = _xregen_splice(new_wav_raw, src_sr, meta, seg_idx,
                                               slot_id, clip_start_s)
    yield gr.update(value=video_path), gr.update(value=waveform_html)


def xregen_taro(seg_idx, state_json, slot_id, seed_val, cfg_scale, num_steps,
                mode, crossfade_s, crossfade_db, request: gr.Request = None):
    """Cross-model regen: run TARO on its optimal window, splice into *slot_id*."""
    seg_idx = int(seg_idx)
    meta = json.loads(state_json)

    def _run():
        # Clip the source video to TARO's native window, centered on the segment.
        clip_start, clip_end, clip_dur = _xregen_clip_window(meta, seg_idx, TARO_MODEL_DUR)
        tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
        clip_path = _extract_segment_clip(
            meta["silent_video"], clip_start, clip_dur,
            os.path.join(tmp_dir, "xregen_taro_clip.mp4"),
        )
        # Build a minimal fake-video meta so generate_taro can run on clip_path
        sub_segs = _build_segments(clip_dur, TARO_MODEL_DUR, float(crossfade_s))
        # NOTE(review): sub_meta_json appears unused below — _ctx_store receives a
        # dict instead; confirm whether this is dead code before removing.
        sub_meta_json = json.dumps({
            "segments": sub_segs,
            "silent_video": clip_path,
            "total_dur_s": clip_dur,
        })
        # Run full TARO generation pipeline on the clip
        _ctx_store("taro_gpu_infer", {
            "tmp_dir": tmp_dir,
            "silent_video": clip_path,
            "segments": sub_segs,
            "total_dur_s": clip_dur,
        })
        results = _taro_gpu_infer(clip_path, seed_val, cfg_scale, num_steps, mode,
                                  crossfade_s, crossfade_db, 1)
        wavs, _, _ = results[0]
        # Upsample each segment wav to 48 kHz before stitching (TARO_SR_OUT).
        wavs = [_upsample_taro(w) for w in wavs]
        wav = _stitch_wavs(wavs, float(crossfade_s), float(crossfade_db),
                           clip_dur, TARO_SR_OUT, sub_segs)
        return wav, TARO_SR_OUT, clip_start

    yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)


def xregen_mmaudio(seg_idx, state_json, slot_id, prompt, negative_prompt,
                   seed_val, cfg_strength, num_steps, crossfade_s,
                   crossfade_db, request: gr.Request = None):
    """Cross-model regen: run MMAudio on its optimal window, splice into *slot_id*."""
    seg_idx = int(seg_idx)
    meta = json.loads(state_json)

    def _run():
        # Clip the source video to MMAudio's native window, centered on the segment.
        clip_start, clip_end, clip_dur = _xregen_clip_window(meta, seg_idx, MMAUDIO_WINDOW)
        tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
        clip_path = _extract_segment_clip(
            meta["silent_video"], clip_start, clip_dur,
            os.path.join(tmp_dir, "xregen_mmaudio_clip.mp4"),
        )
        sub_segs = _build_segments(clip_dur, MMAUDIO_WINDOW, float(crossfade_s))
        results = _mmaudio_gpu_infer(clip_path, prompt, negative_prompt, seed_val,
                                     cfg_strength, num_steps, crossfade_s,
                                     crossfade_db, 1, silent_video=clip_path,
                                     segments_json=json.dumps(sub_segs))
        seg_wavs, sr = results[0]
        wav = _stitch_wavs(seg_wavs, float(crossfade_s), float(crossfade_db),
                          clip_dur, sr, sub_segs)
        # Normalize to the app-wide sample rate so splicing matches the slot.
        # NOTE(review): 2-arg _resample_to_target call — presumably its dst_sr
        # parameter defaults to TARGET_SR; verify against its definition.
        if sr != TARGET_SR:
            wav = _resample_to_target(wav, sr)
            sr = TARGET_SR
        return wav, sr, clip_start

    yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)


def xregen_hunyuan(seg_idx, state_json, slot_id, prompt, negative_prompt,
                   seed_val, guidance_scale, num_steps, model_size,
                   crossfade_s, crossfade_db, request: gr.Request = None):
    """Cross-model regen: run HunyuanFoley on its optimal window, splice into *slot_id*."""
    seg_idx = int(seg_idx)
    meta = json.loads(state_json)

    def _run():
        # Clip the source video to HunyuanFoley's max window, centered on the segment.
        clip_start, clip_end, clip_dur = _xregen_clip_window(meta, seg_idx, HUNYUAN_MAX_DUR)
        tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
        clip_path = _extract_segment_clip(
            meta["silent_video"], clip_start, clip_dur,
            os.path.join(tmp_dir, "xregen_hunyuan_clip.mp4"),
        )
        sub_segs = _build_segments(clip_dur, HUNYUAN_MAX_DUR, float(crossfade_s))
        results = _hunyuan_gpu_infer(clip_path, prompt, negative_prompt, seed_val,
                                     guidance_scale, num_steps, model_size,
                                     crossfade_s, crossfade_db, 1,
                                     silent_video=clip_path,
                                     segments_json=json.dumps(sub_segs),
                                     total_dur_s=clip_dur)
        seg_wavs, sr, _ = results[0]
        wav = _stitch_wavs(seg_wavs, float(crossfade_s), float(crossfade_db),
                           clip_dur, sr, sub_segs)
        return wav, sr, clip_start

    yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)


# ================================================================== #
#                        SHARED UI HELPERS                           #
# ================================================================== #

def _register_regen_handlers(tab_prefix, model_key, regen_seg_tb, regen_state_tb,
                             input_components, slot_vids, slot_waves):
    """Register per-slot regen button handlers for a model tab.

    This replaces the three nearly-identical for-loops that previously existed
    for TARO, MMAudio, and HunyuanFoley tabs.

    Args:
        tab_prefix: e.g. "taro", "mma", "hf"
        model_key: e.g. "taro", "mmaudio", "hunyuan"
        regen_seg_tb: gr.Textbox for seg_idx (render=False)
        regen_state_tb: gr.Textbox for state_json (render=False)
        input_components: list of Gradio input components (video, seed, etc.)
            — order must match regen_fn signature after (seg_idx, state_json, video)
        slot_vids: list of gr.Video components per slot
        slot_waves: list of gr.HTML components per slot

    Returns:
        list of hidden gr.Buttons (one per slot)
    """
    cfg = MODEL_CONFIGS[model_key]
    regen_fn = cfg["regen_fn"]
    label = cfg["label"]
    btns = []
    for _i in range(MAX_SLOTS):
        _slot_id = f"{tab_prefix}_{_i}"
        # Hidden button; the frontend triggers it by elem_id via JS.
        _btn = gr.Button(render=False, elem_id=f"regen_btn_{_slot_id}")
        btns.append(_btn)
        print(f"[startup] registering regen handler for slot {_slot_id}")

        # Factory binds loop variables as parameters to avoid the classic
        # late-binding-closure bug (all handlers seeing the last _i/_slot_id).
        def _make_regen(_si, _sid, _model_key, _label, _regen_fn):
            def _do(seg_idx, state_json, *args):
                print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} "
                      f"state_json_len={len(state_json) if state_json else 0}")
                if not state_json:
                    # No generation has happened yet for this slot; nothing to regen.
                    print(f"[regen {_label}] early-exit: state_json empty")
                    yield gr.update(), gr.update()
                    return
                # Serialize regens per slot so concurrent clicks can't clobber
                # each other's splice state.
                lock = _get_slot_lock(_sid)
                with lock:
                    state = json.loads(state_json)
                    pending_html = _build_regen_pending_html(
                        state["segments"], int(seg_idx), _sid, ""
                    )
                    # First yield: show the pending placeholder while GPU runs.
                    yield gr.update(), gr.update(value=pending_html)
                    print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} — calling regen")
                    try:
                        # args[0] = video,
                        # args[1:] = model-specific params
                        vid, aud, new_meta_json, html = _regen_fn(
                            args[0], int(seg_idx), state_json, *args[1:], _sid,
                        )
                        print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} — done, vid={vid!r}")
                    except Exception as _e:
                        # Log with slot context, then re-raise so Gradio surfaces
                        # the error to the user instead of silently swallowing it.
                        print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} — ERROR: {_e}")
                        raise
                    # Second yield: push the updated video and waveform HTML.
                    yield gr.update(value=vid), gr.update(value=html)
            return _do

        _btn.click(
            fn=_make_regen(_i, _slot_id, model_key, label, regen_fn),
            inputs=[regen_seg_tb, regen_state_tb] + input_components,
            outputs=[slot_vids[_i], slot_waves[_i]],
            api_name=f"regen_{tab_prefix}_{_i}",
        )
    return btns


def _pad_outputs(outputs: list) -> list:
    """Flatten (video, audio, seg_meta) triples and pad to MAX_SLOTS * 3 with None.

    Each entry in *outputs* must be a (video_path, audio_path, seg_meta) tuple
    where seg_meta = {"segments": [...], "audio_path": str, "video_path": str,
    "sr": int, "model": str, "crossfade_s": float, "crossfade_db": float,
    "wav_paths": list[str]}
    """
    result = []
    for i in range(MAX_SLOTS):
        if i < len(outputs):
            result.extend(outputs[i])  # 3 items: video, audio, meta
        else:
            # Empty slot: keep the flat output list a fixed MAX_SLOTS * 3 long.
            result.extend([None, None, None])
    return result


# ------------------------------------------------------------------ #
#   WaveSurfer waveform + segment marker HTML builder                #
# ------------------------------------------------------------------ #

def _build_regen_pending_html(segments: list, regen_seg_idx: int, slot_id: str,
                              hidden_input_id: str) -> str:
    """Return a waveform placeholder shown while a segment is being regenerated.

    Renders a dark bar with the active segment highlighted in amber + a spinner.
    """
    segs_json = json.dumps(segments)
    # Inactive segments use the shared palette at low alpha; the regen target
    # gets a distinct amber highlight.
    seg_colors = [c.format(a="0.25") for c in SEG_COLORS]
    active_color = "rgba(255,180,0,0.55)"
    # Fall back to 1.0 to avoid division by zero when there are no segments.
    duration = segments[-1][1] if segments else 1.0
    seg_divs = ""
    for i, seg in enumerate(segments):
        # Draw only the non-overlapping (unique) portion of each segment so that
        # overlapping windows don't visually bleed into adjacent segments.
# Each segment owns the region from its own start up to the next segment's # start (or its own end for the final segment). seg_start = seg[0] seg_end = segments[i + 1][0] if i + 1 < len(segments) else seg[1] left_pct = seg_start / duration * 100 width_pct = (seg_end - seg_start) / duration * 100 color = active_color if i == regen_seg_idx else seg_colors[i % len(seg_colors)] extra = "border:2px solid #ffb300;animation:wf_pulse 0.8s ease-in-out infinite alternate;" if i == regen_seg_idx else "" seg_divs += ( f'