Spaces:
Running on Zero
Running on Zero
| """ | |
| Generate Audio for Video β multi-model Gradio app. | |
| Supported models | |
| ---------------- | |
| TARO β video-conditioned diffusion via CAVP + onset features (16 kHz, 8.192 s window) | |
| MMAudio β multimodal flow-matching with CLIP/Synchformer + text prompt (44 kHz, 8 s window) | |
| HunyuanFoley β text-guided foley via SigLIP2 + Synchformer + CLAP (48 kHz, up to 15 s) | |
| """ | |
| import html as _html | |
| import os | |
| import sys | |
| import json | |
| import shutil | |
| import tempfile | |
| import random | |
| import threading | |
| import time | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from pathlib import Path | |
| import torch | |
| import numpy as np | |
| import torchaudio | |
| import ffmpeg | |
| import spaces | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
# ================================================================== #
#                      CHECKPOINT CONFIGURATION                      #
# ================================================================== #
# Single HF Hub repo that aggregates the checkpoints for all three models.
CKPT_REPO_ID = "JackIsNotInTheBox/Generate_Audio_for_Video_Checkpoints"
# /tmp is ephemeral on Spaces — checkpoints re-download on every cold start.
CACHE_DIR = "/tmp/model_ckpts"
os.makedirs(CACHE_DIR, exist_ok=True)
# ---- Local directories that must exist before parallel downloads start ----
MMAUDIO_WEIGHTS_DIR = Path(CACHE_DIR) / "MMAudio" / "weights"
MMAUDIO_EXT_DIR = Path(CACHE_DIR) / "MMAudio" / "ext_weights"
HUNYUAN_MODEL_DIR = Path(CACHE_DIR) / "HunyuanFoley"
MMAUDIO_WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
MMAUDIO_EXT_DIR.mkdir(parents=True, exist_ok=True)
HUNYUAN_MODEL_DIR.mkdir(parents=True, exist_ok=True)
# ------------------------------------------------------------------ #
#  Parallel checkpoint + model downloads                             #
#  All downloads are I/O-bound (network), so running them in threads #
#  cuts Space cold-start time roughly proportional to the number of  #
#  independent groups (previously sequential, now concurrent).       #
#  hf_hub_download / snapshot_download are thread-safe.              #
# ------------------------------------------------------------------ #
def _dl_taro():
    """Fetch the three TARO checkpoints (CAVP, onset, diffusion) from the Hub.

    Returns (cavp_path, onset_path, taro_path) as local file paths.
    """
    cavp, onset, taro = (
        hf_hub_download(repo_id=CKPT_REPO_ID, filename=name, cache_dir=CACHE_DIR)
        for name in ("TARO/cavp_epoch66.ckpt", "TARO/onset_model.ckpt", "TARO/taro_ckpt.pt")
    )
    print("TARO checkpoints downloaded.")
    return cavp, onset, taro
def _dl_mmaudio():
    """Fetch MMAudio's net / VAE / Synchformer weights into their local dirs.

    Returns (model_path, vae_path, synchformer_path) as local file paths.
    """
    def _fetch(filename, local_dir):
        # local_dir places a real copy (no symlink) where MMAudio expects it.
        return hf_hub_download(repo_id=CKPT_REPO_ID, filename=filename,
                               cache_dir=CACHE_DIR, local_dir=str(local_dir),
                               local_dir_use_symlinks=False)

    model_path = _fetch("MMAudio/mmaudio_large_44k_v2.pth", MMAUDIO_WEIGHTS_DIR)
    vae_path = _fetch("MMAudio/v1-44.pth", MMAUDIO_EXT_DIR)
    sync_path = _fetch("MMAudio/synchformer_state_dict.pth", MMAUDIO_EXT_DIR)
    print("MMAudio checkpoints downloaded.")
    return model_path, vae_path, sync_path
def _dl_hunyuan():
    """Fetch HunyuanVideoFoley's three .pth weight files (no return value)."""
    for fname in (
        "HunyuanVideo-Foley/hunyuanvideo_foley.pth",
        "HunyuanVideo-Foley/vae_128d_48k.pth",
        "HunyuanVideo-Foley/synchformer_state_dict.pth",
    ):
        hf_hub_download(repo_id=CKPT_REPO_ID, filename=fname,
                        cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR),
                        local_dir_use_symlinks=False)
    print("HunyuanVideoFoley checkpoints downloaded.")
def _dl_clap():
    """Warm the HF cache with CLAP so from_pretrained() stays local in the ZeroGPU worker."""
    repo = "laion/larger_clap_general"
    snapshot_download(repo_id=repo)
    print("CLAP model pre-downloaded.")
def _dl_clip():
    """Warm the HF cache with MMAudio's CLIP backbone (~3.95 GB) ahead of GPU time."""
    repo = "apple/DFN5B-CLIP-ViT-H-14-384"
    snapshot_download(repo_id=repo)
    print("MMAudio CLIP model pre-downloaded.")
def _dl_audioldm2():
    """Warm the HF cache with AudioLDM2 (VAE + vocoder used by TARO's from_pretrained())."""
    repo = "cvssp/audioldm2"
    snapshot_download(repo_id=repo)
    print("AudioLDM2 pre-downloaded.")
def _dl_bigvgan():
    """Warm the HF cache with the BigVGAN vocoder (~489 MB) used by MMAudio."""
    repo = "nvidia/bigvgan_v2_44khz_128band_512x"
    snapshot_download(repo_id=repo)
    print("BigVGAN vocoder pre-downloaded.")
# Kick off all seven download groups concurrently; each is network-bound,
# so the threads overlap their wait time.
print("[startup] Starting parallel checkpoint + model downloads…")
_t_dl_start = time.perf_counter()
# One worker per group so no download queues behind another.
with ThreadPoolExecutor(max_workers=7) as _pool:
    _fut_taro = _pool.submit(_dl_taro)
    _fut_mmaudio = _pool.submit(_dl_mmaudio)
    _fut_hunyuan = _pool.submit(_dl_hunyuan)
    _fut_clap = _pool.submit(_dl_clap)
    _fut_clip = _pool.submit(_dl_clip)
    _fut_aldm2 = _pool.submit(_dl_audioldm2)
    _fut_bigvgan = _pool.submit(_dl_bigvgan)
    # Raise any download exceptions immediately — fail fast at startup
    # rather than at first inference.
    for _fut in as_completed([_fut_taro, _fut_mmaudio, _fut_hunyuan,
                              _fut_clap, _fut_clip, _fut_aldm2, _fut_bigvgan]):
        _fut.result()
# Unpack the checkpoint paths the rest of the module references by name.
cavp_ckpt_path, onset_ckpt_path, taro_ckpt_path = _fut_taro.result()
mmaudio_model_path, mmaudio_vae_path, mmaudio_synchformer_path = _fut_mmaudio.result()
print(f"[startup] All downloads done in {time.perf_counter() - _t_dl_start:.1f}s")
# ================================================================== #
#                   SHARED CONSTANTS / HELPERS                       #
# ================================================================== #
# CPU → GPU context passing via function-name-keyed global store.
#
# Problem: ZeroGPU runs @spaces.GPU functions on its own worker thread, so
# threading.local() is invisible to the GPU worker. Passing ctx as a
# function argument exposes it to Gradio's API endpoint, causing
# "Too many arguments" errors.
#
# Solution: store context in a plain global dict keyed by function name.
# A per-key Lock serialises concurrent callers for the same function
# (ZeroGPU is already synchronous — the wrapper blocks until the GPU fn
# returns — so in practice only one call per GPU fn is in-flight at a time).
# The global dict is readable from any thread.
_GPU_CTX: dict = {}               # fn_name -> pending context dict
_GPU_CTX_LOCK = threading.Lock()  # guards every read/write of _GPU_CTX
def _ctx_store(fn_name: str, data: dict) -> None:
    """Save *data* as the pending GPU context for *fn_name*.

    Any context previously stored under the same key is replaced.
    """
    with _GPU_CTX_LOCK:
        _GPU_CTX[fn_name] = data
def _ctx_load(fn_name: str) -> dict:
    """Remove and return the context stored for *fn_name* ({} if none).

    Popping (rather than reading) guarantees each stored context is
    consumed by exactly one GPU call.
    """
    with _GPU_CTX_LOCK:
        return _GPU_CTX.pop(fn_name, {})
MAX_SLOTS = 8  # max parallel generation slots shown in UI
MAX_SEGS = 8   # max segments per slot (same as MAX_SLOTS; video ≤ ~64 s at 8 s/seg)
# Segment overlay palette — shared between _build_waveform_html and
# _build_regen_pending_html. Each entry is an rgba() template whose "{a}"
# placeholder is filled with the desired alpha at render time.
SEG_COLORS = [
    "rgba(100,180,255,{a})", "rgba(255,160,100,{a})",
    "rgba(120,220,140,{a})", "rgba(220,120,220,{a})",
    "rgba(255,220,80,{a})", "rgba(80,220,220,{a})",
    "rgba(255,100,100,{a})", "rgba(180,255,180,{a})",
]
| # ------------------------------------------------------------------ # | |
| # Micro-helpers that eliminate repeated boilerplate across the file # | |
| # ------------------------------------------------------------------ # | |
| def _ensure_syspath(subdir: str) -> str: | |
| """Add *subdir* (relative to app.py) to sys.path if not already present. | |
| Returns the absolute path for convenience.""" | |
| p = os.path.join(os.path.dirname(os.path.abspath(__file__)), subdir) | |
| if p not in sys.path: | |
| sys.path.insert(0, p) | |
| return p | |
| def _get_device_and_dtype() -> tuple: | |
| """Return (device, weight_dtype) pair used by all GPU functions.""" | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| return device, torch.bfloat16 | |
def _extract_segment_clip(silent_video: str, seg_start: float, seg_dur: float,
                          output_path: str) -> str:
    """Cut [seg_start, seg_start + seg_dur) out of *silent_video* without re-encoding.

    The video stream is copied verbatim and audio is dropped (an=None).
    Returns *output_path* for call-site chaining.
    """
    clip = ffmpeg.input(silent_video, ss=seg_start, t=seg_dur)
    clip.output(output_path, vcodec="copy", an=None).run(
        overwrite_output=True, quiet=True
    )
    return output_path
# Per-slot locks — prevent concurrent regens on the same slot from
# producing a race condition where the second regen reads stale state
# (the shared seg_state textbox hasn't been updated yet by the first regen).
# NOTE: these are plain (non-reentrant) threading.Lock objects, despite the
# name suggesting otherwise elsewhere — do not acquire one twice on the
# same thread. Locks are keyed by slot_id string (e.g. "taro_0", "mma_2").
_SLOT_LOCKS: dict = {}                # slot_id -> threading.Lock
_SLOT_LOCKS_MUTEX = threading.Lock()  # guards creation of _SLOT_LOCKS entries
def _get_slot_lock(slot_id: str) -> threading.Lock:
    """Return the lock for *slot_id*, creating it on first use (thread-safe)."""
    with _SLOT_LOCKS_MUTEX:
        return _SLOT_LOCKS.setdefault(slot_id, threading.Lock())
def set_global_seed(seed: int) -> None:
    """Seed random, numpy and torch (CPU + CUDA when present) from one value.

    numpy only accepts 32-bit seeds, hence the modulo.
    """
    random.seed(seed)
    np.random.seed(seed % (2**32))
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
def get_random_seed() -> int:
    """Return a uniformly random unsigned 32-bit seed in [0, 2**32 - 1]."""
    return random.randint(0, 2**32 - 1)
def get_video_duration(video_path: str) -> float:
    """Probe *video_path* with ffprobe and return its duration in seconds (CPU only)."""
    info = ffmpeg.probe(video_path)
    return float(info["format"]["duration"])
def strip_audio_from_video(video_path: str, output_path: str) -> None:
    """Copy the video stream of *video_path* to *output_path* with audio removed.

    Pure stream-copy (no re-encode), so this is fast and lossless.
    """
    source = ffmpeg.input(video_path)
    source.output(output_path, vcodec="copy", an=None).run(
        overwrite_output=True, quiet=True
    )
| def _transcode_for_browser(video_path: str) -> str: | |
| """Re-encode uploaded video to H.264/AAC MP4 so the browser preview widget can play it. | |
| Returns a NEW path in a fresh /tmp/gradio/ subdirectory. Gradio probes the | |
| returned path fresh, sees H.264, and serves it directly without its own | |
| slow fallback converter. The in-place overwrite approach loses the race | |
| because Gradio probes the original path at upload time before this callback runs. | |
| Only called on upload β not during generation. | |
| """ | |
| if video_path is None: | |
| return video_path | |
| try: | |
| probe = ffmpeg.probe(video_path) | |
| has_audio = any(s["codec_type"] == "audio" for s in probe.get("streams", [])) | |
| # Check if already H.264 β skip transcode if so | |
| video_streams = [s for s in probe.get("streams", []) if s["codec_type"] == "video"] | |
| if video_streams and video_streams[0].get("codec_name") == "h264": | |
| print(f"[transcode_for_browser] already H.264, skipping") | |
| return video_path | |
| # Write the H.264 output into the SAME directory as the original upload. | |
| # Gradio's file server only allows paths under dirs it registered β the | |
| # upload dir is already allowed, so a sibling file there will serve fine. | |
| import os as _os | |
| upload_dir = _os.path.dirname(video_path) | |
| stem = _os.path.splitext(_os.path.basename(video_path))[0] | |
| out_path = _os.path.join(upload_dir, stem + "_h264.mp4") | |
| kwargs = dict( | |
| vcodec="libx264", preset="fast", crf=18, | |
| pix_fmt="yuv420p", movflags="+faststart", | |
| ) | |
| if has_audio: | |
| kwargs["acodec"] = "aac" | |
| kwargs["audio_bitrate"] = "128k" | |
| else: | |
| kwargs["an"] = None | |
| # map 0:v:0 explicitly to skip non-video streams (e.g. data/timecode tracks) | |
| ffmpeg.input(video_path).output(out_path, map="0:v:0", **kwargs).run( | |
| overwrite_output=True, quiet=True | |
| ) | |
| print(f"[transcode_for_browser] transcoded to H.264: {out_path}") | |
| return out_path | |
| except Exception as e: | |
| print(f"[transcode_for_browser] failed, using original: {e}") | |
| return video_path | |
# ------------------------------------------------------------------ #
# Temp directory registry — tracks dirs for cleanup on new generation #
# ------------------------------------------------------------------ #
_TEMP_DIRS: list = []  # tmp_dir paths created by generate_*, oldest first
_TEMP_DIRS_MAX = 10    # keep at most this many; older ones get cleaned up
def _register_tmp_dir(tmp_dir: str) -> str:
    """Track *tmp_dir* for later cleanup; delete the oldest dirs beyond the cap.

    Returns *tmp_dir* unchanged for call-site chaining.
    """
    _TEMP_DIRS.append(tmp_dir)
    while len(_TEMP_DIRS) > _TEMP_DIRS_MAX:
        stale = _TEMP_DIRS.pop(0)
        try:
            shutil.rmtree(stale, ignore_errors=True)
            print(f"[cleanup] Removed old temp dir: {stale}")
        except Exception:
            # Best-effort cleanup — never let a deletion failure break generation.
            pass
    return tmp_dir
| def _save_seg_wavs(wavs: list[np.ndarray], tmp_dir: str, prefix: str) -> list[str]: | |
| """Save a list of numpy wav arrays to .npy files, return list of paths. | |
| This avoids serialising large float arrays into JSON/HTML data-state.""" | |
| paths = [] | |
| for i, w in enumerate(wavs): | |
| p = os.path.join(tmp_dir, f"{prefix}_seg{i}.npy") | |
| np.save(p, w) | |
| paths.append(p) | |
| return paths | |
| def _load_seg_wavs(paths: list[str]) -> list[np.ndarray]: | |
| """Load segment wav arrays from .npy file paths.""" | |
| return [np.load(p) for p in paths] | |
# ------------------------------------------------------------------ #
# Shared model-loading helpers (deduplicate generate / regen code)   #
# ------------------------------------------------------------------ #
def _load_taro_models(device, weight_dtype):
    """Load TARO MMDiT + AudioLDM2 VAE/vocoder. Returns (model_net, vae, vocoder, latents_scale)."""
    from TARO.models import MMDiT
    from diffusers import AutoencoderKL
    from transformers import SpeechT5HifiGan
    # MMDiT geometry must match the checkpoint: 120-dim onset conditioning
    # (adm_in_channels), 768-dim CAVP context (z_dims), encoder_depth=4.
    model_net = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
    # Load the EMA weights from the checkpoint (note the ["ema"] key).
    model_net.load_state_dict(torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"])
    # Cast to weight_dtype AFTER loading the (float32) state dict.
    model_net.eval().to(weight_dtype)
    # VAE and vocoder are left in their default dtype — only MMDiT is cast.
    vae = AutoencoderKL.from_pretrained("cvssp/audioldm2", subfolder="vae").to(device).eval()
    vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm2", subfolder="vocoder").to(device)
    # 0.18215 is the AudioLDM2 VAE latent scale factor, broadcast over 8 channels.
    latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
    return model_net, vae, vocoder, latents_scale
def _load_taro_feature_extractors(device):
    """Load CAVP + onset extractors. Returns (extract_cavp, onset_model)."""
    from TARO.cavp_util import Extract_CAVP_Features
    from TARO.onset_util import VideoOnsetNet

    extract_cavp = Extract_CAVP_Features(
        device=device, config_path="TARO/cavp/cavp.yaml", ckpt_path=cavp_ckpt_path,
    )

    def _remap(key: str) -> str:
        # The onset checkpoint's keys carry a "model." wrapper prefix that the
        # bare VideoOnsetNet does not expect — strip it where it occurs.
        if "model.net.model" in key:
            return key.replace("model.net.model", "net.model")
        if "model.fc." in key:
            return key.replace("model.fc", "fc")
        return key

    raw_sd = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
    onset_sd = {_remap(k): v for k, v in raw_sd.items()}
    onset_model = VideoOnsetNet(pretrained=False).to(device)
    onset_model.load_state_dict(onset_sd)
    onset_model.eval()
    return extract_cavp, onset_model
def _load_mmaudio_models(device, dtype):
    """Load MMAudio net + feature_utils. Returns (net, feature_utils, model_cfg, seq_cfg).

    Starts from the stock "large_44k_v2" config, redirects its checkpoint
    paths to the files fetched at startup, then builds the network and the
    CLIP/Synchformer feature-extraction stack.
    """
    from mmaudio.eval_utils import all_model_cfg
    from mmaudio.model.networks import get_my_mmaudio
    from mmaudio.model.utils.features_utils import FeaturesUtils
    model_cfg = all_model_cfg["large_44k_v2"]
    # Point every checkpoint path at our pre-downloaded local copies.
    model_cfg.model_path = Path(mmaudio_model_path)
    model_cfg.vae_path = Path(mmaudio_vae_path)
    model_cfg.synchformer_ckpt = Path(mmaudio_synchformer_path)
    model_cfg.bigvgan_16k_path = None  # 44k variant — no 16 kHz vocoder checkpoint
    seq_cfg = model_cfg.seq_cfg
    net = get_my_mmaudio(model_cfg.model_name).to(device, dtype).eval()
    net.load_weights(torch.load(model_cfg.model_path, map_location=device, weights_only=True))
    feature_utils = FeaturesUtils(
        tod_vae_ckpt=str(model_cfg.vae_path),
        synchformer_ckpt=str(model_cfg.synchformer_ckpt),
        enable_conditions=True, mode=model_cfg.mode,
        # need_vae_encoder=False — presumably inference only decodes latents;
        # confirm against FeaturesUtils if the encoder is ever needed here.
        bigvgan_vocoder_ckpt=None, need_vae_encoder=False,
    ).to(device, dtype).eval()
    return net, feature_utils, model_cfg, seq_cfg
def _load_hunyuan_model(device, model_size):
    """Load HunyuanFoley model dict + config for the requested size ("xl"/"xxl").

    Unrecognised sizes fall back to the XXL config. Returns (model_dict, cfg).
    """
    from hunyuanvideo_foley.utils.model_utils import load_model

    size = model_size.lower()
    config_map = {
        "xl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xl.yaml",
        "xxl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xxl.yaml",
    }
    config_path = config_map.get(size, config_map["xxl"])
    weights_dir = str(HUNYUAN_MODEL_DIR / "HunyuanVideo-Foley")
    print(f"[HunyuanFoley] Loading {size.upper()} model from {weights_dir}")
    return load_model(weights_dir, config_path, device,
                      enable_offload=False, model_size=size)
def mux_video_audio(silent_video: str, audio_path: str, output_path: str,
                    model: str = None) -> None:
    """Combine *silent_video* and *audio_path* into *output_path*.

    HunyuanFoley ships its own merge_audio_video helper that handles its
    specific ffmpeg quirks; every other model goes through a plain
    libx264/AAC mux of the first video and audio streams.
    """
    if model == "hunyuan":
        _ensure_syspath("HunyuanVideo-Foley")
        from hunyuanvideo_foley.utils.media_utils import merge_audio_video
        merge_audio_video(audio_path, silent_video, output_path)
        return
    video_stream = ffmpeg.input(silent_video)
    audio_stream = ffmpeg.input(audio_path)
    ffmpeg.output(
        video_stream["v:0"],
        audio_stream["a:0"],
        output_path,
        vcodec="libx264", preset="fast", crf=18,
        pix_fmt="yuv420p",
        acodec="aac", audio_bitrate="128k",
        movflags="+faststart",
    ).run(overwrite_output=True, quiet=True)
| # ------------------------------------------------------------------ # | |
| # Shared sliding-window segmentation and crossfade helpers # | |
| # Used by all three models (TARO, MMAudio, HunyuanFoley). # | |
| # ------------------------------------------------------------------ # | |
| def _build_segments(total_dur_s: float, window_s: float, crossfade_s: float) -> list[tuple[float, float]]: | |
| """Return list of (start, end) pairs covering *total_dur_s* with a sliding | |
| window of *window_s* and *crossfade_s* overlap between consecutive segments.""" | |
| # Safety: clamp crossfade to < half the window so step_s stays positive | |
| crossfade_s = min(crossfade_s, window_s * 0.5) | |
| if total_dur_s <= window_s: | |
| return [(0.0, total_dur_s)] | |
| step_s = window_s - crossfade_s | |
| segments, seg_start = [], 0.0 | |
| while True: | |
| if seg_start + window_s >= total_dur_s: | |
| seg_start = max(0.0, total_dur_s - window_s) | |
| segments.append((seg_start, total_dur_s)) | |
| break | |
| segments.append((seg_start, seg_start + window_s)) | |
| seg_start += step_s | |
| return segments | |
| def _cf_join(a: np.ndarray, b: np.ndarray, | |
| crossfade_s: float, db_boost: float, sr: int) -> np.ndarray: | |
| """Equal-power crossfade join. Works for both mono (T,) and stereo (C, T) arrays. | |
| Stereo arrays are expected in (channels, samples) layout. | |
| db_boost is applied to the overlap region as a whole (after blending), so | |
| it compensates for the -3 dB equal-power dip without doubling amplitude. | |
| Applying gain to each side independently (the common mistake) causes a | |
| +3 dB loudness bump at the seam β this version avoids that.""" | |
| stereo = a.ndim == 2 | |
| n_a = a.shape[1] if stereo else len(a) | |
| n_b = b.shape[1] if stereo else len(b) | |
| cf = min(int(round(crossfade_s * sr)), n_a, n_b) | |
| if cf <= 0: | |
| return np.concatenate([a, b], axis=1 if stereo else 0) | |
| gain = 10 ** (db_boost / 20.0) | |
| t = np.linspace(0.0, 1.0, cf, dtype=np.float32) | |
| fade_out = np.cos(t * np.pi / 2) # 1 β 0 | |
| fade_in = np.sin(t * np.pi / 2) # 0 β 1 | |
| if stereo: | |
| # Blend first, then apply boost to the overlap region as a unit | |
| overlap = (a[:, -cf:] * fade_out + b[:, :cf] * fade_in) * gain | |
| return np.concatenate([a[:, :-cf], overlap, b[:, cf:]], axis=1) | |
| else: | |
| overlap = (a[-cf:] * fade_out + b[:cf] * fade_in) * gain | |
| return np.concatenate([a[:-cf], overlap, b[cf:]]) | |
# ================================================================== #
#                              TARO                                  #
# ================================================================== #
# Constants sourced from TARO/infer.py and TARO/models.py:
#   SR=16000, TRUNCATE=131072 → 8.192 s window
#   TRUNCATE_FRAME = 4 fps × 131072/16000 = 32 CAVP frames per window
#   TRUNCATE_ONSET = 120 onset frames per window
#   latent shape: (1, 8, 204, 16) — fixed by MMDiT architecture
#   latents_scale: [0.18215]*8 — AudioLDM2 VAE scale factor
# ================================================================== #
# ================================================================== #
#              MODEL CONSTANTS & CONFIGURATION REGISTRY              #
# ================================================================== #
# All per-model numeric constants live here — MODEL_CONFIGS is the
# single source of truth consumed by duration estimation, segmentation,
# and the UI. Standalone names kept only where other code references
# them by name (TARO geometry, TARGET_SR, GPU_DURATION_CAP).
# ================================================================== #
# TARO geometry — referenced directly in _taro_infer_segment
TARO_SR = 16000
TARO_TRUNCATE = 131072
TARO_FPS = 4
TARO_TRUNCATE_FRAME = int(TARO_FPS * TARO_TRUNCATE / TARO_SR)  # 32
TARO_TRUNCATE_ONSET = 120
TARO_MODEL_DUR = TARO_TRUNCATE / TARO_SR  # 8.192 s
GPU_DURATION_CAP = 300  # hard cap per @spaces.GPU call — never reserve more than this
MODEL_CONFIGS = {
    "taro": {
        "window_s": TARO_MODEL_DUR,  # 8.192 s
        "sr": TARO_SR,               # 16000 (output resampled to TARGET_SR)
        "secs_per_step": 0.025,      # measured 0.023 s/step on H200
        "load_overhead": 15,         # model load + CAVP feature extraction
        "tab_prefix": "taro",
        "label": "TARO",
        "regen_fn": None,            # set after function definitions (avoids forward-ref)
    },
    "mmaudio": {
        "window_s": 8.0,             # MMAudio's fixed generation window
        "sr": 48000,                 # resampled from 44100 in post-processing
        "secs_per_step": 0.25,       # measured 0.230 s/step on H200
        "load_overhead": 30,         # 15s warm + 15s model init
        "tab_prefix": "mma",
        "label": "MMAudio",
        "regen_fn": None,
    },
    "hunyuan": {
        "window_s": 15.0,            # HunyuanFoley max video duration
        "sr": 48000,
        "secs_per_step": 0.35,       # measured 0.328 s/step on H200
        "load_overhead": 55,         # ~55s to load the 10 GB XXL weights
        "tab_prefix": "hf",
        "label": "HunyuanFoley",
        "regen_fn": None,
    },
}
# Convenience aliases referenced directly in the per-model inference paths
# (not only TARO's, despite the original note).
TARO_SECS_PER_STEP = MODEL_CONFIGS["taro"]["secs_per_step"]
MMAUDIO_WINDOW = MODEL_CONFIGS["mmaudio"]["window_s"]
MMAUDIO_SECS_PER_STEP = MODEL_CONFIGS["mmaudio"]["secs_per_step"]
HUNYUAN_MAX_DUR = MODEL_CONFIGS["hunyuan"]["window_s"]
HUNYUAN_SECS_PER_STEP = MODEL_CONFIGS["hunyuan"]["secs_per_step"]
def _clamp_duration(secs: float, label: str) -> int:
    """Clamp a raw GPU-seconds estimate to [60, GPU_DURATION_CAP] and log it."""
    reserved = max(60, int(secs))
    reserved = min(GPU_DURATION_CAP, reserved)
    print(f"[duration] {label}: {secs:.0f}s raw → {reserved}s reserved")
    return reserved
def _estimate_gpu_duration(model_key: str, num_samples: int, num_steps: int,
                           total_dur_s: float = None, crossfade_s: float = 0,
                           video_file: str = None) -> int:
    """Estimate GPU seconds to reserve for a full generation call.

    Formula: num_samples × n_segs × num_steps × secs_per_step + load_overhead.
    If the video cannot be probed, assumes a single segment.
    """
    cfg = MODEL_CONFIGS[model_key]
    try:
        dur = get_video_duration(video_file) if total_dur_s is None else total_dur_s
        n_segs = len(_build_segments(dur, cfg["window_s"], float(crossfade_s)))
    except Exception:
        n_segs = 1
    secs = int(num_samples) * n_segs * int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"]
    print(f"[duration] {cfg['label']}: {int(num_samples)}samp × {n_segs}seg × "
          f"{int(num_steps)}steps → {secs:.0f}s → capped ", end="")
    return _clamp_duration(secs, cfg["label"])
def _estimate_regen_duration(model_key: str, num_steps: int) -> int:
    """Estimate GPU seconds to reserve for a single-segment regen call."""
    cfg = MODEL_CONFIGS[model_key]
    raw = cfg["load_overhead"] + int(num_steps) * cfg["secs_per_step"]
    print(f"[duration] {cfg['label']} regen: 1 seg × {int(num_steps)} steps → ", end="")
    return _clamp_duration(raw, f"{cfg['label']} regen")
_TARO_CACHE_MAXLEN = 16  # evict oldest entries beyond this limit
# Cache of TARO inference results,
# keyed by (video_file, seed, cfg, steps, mode, crossfade_s).
_TARO_INFERENCE_CACHE: dict = {}
_TARO_CACHE_LOCK = threading.Lock()  # guards _TARO_INFERENCE_CACHE
def _taro_calc_max_samples(total_dur_s: float, num_steps: int, crossfade_s: float) -> int:
    """How many parallel TARO samples fit in a ~600 s GPU budget for this video.

    Result is clamped to [1, MAX_SLOTS].
    """
    n_segs = len(_build_segments(total_dur_s, TARO_MODEL_DUR, crossfade_s))
    cost_per_sample = n_segs * (num_steps * TARO_SECS_PER_STEP)
    return max(1, min(int(600.0 / cost_per_sample), MAX_SLOTS))
def _taro_duration(video_file, seed_val, cfg_scale, num_steps, mode,
                   crossfade_s, crossfade_db, num_samples):
    """Pre-GPU duration callable — its signature must mirror _taro_gpu_infer exactly."""
    return _estimate_gpu_duration("taro", int(num_samples), int(num_steps),
                                  video_file=video_file, crossfade_s=crossfade_s)
def _taro_infer_segment(
    model, vae, vocoder,
    cavp_feats_full, onset_feats_full,
    seg_start_s: float, seg_end_s: float,
    device, weight_dtype,
    cfg_scale: float, num_steps: int, mode: str,
    latents_scale,
    euler_sampler, euler_maruyama_sampler,
) -> np.ndarray:
    """Single-segment TARO inference. Returns wav array trimmed to segment length.

    Slices the pre-extracted whole-video CAVP (4 fps) and onset (~14.65 fps)
    feature arrays at *seg_start_s*, zero-pads them to the fixed window sizes
    the model expects, runs the chosen sampler (Euler, or Euler-Maruyama when
    *mode* == "sde"), then decodes latent → mel → waveform via the AudioLDM2
    VAE and vocoder.
    NOTE(review): assumes cavp_feats_full is (frames, feat_dim) at 4 fps and
    onset_feats_full is 1-D — confirm against the feature extractors.
    """
    # CAVP features (4 fps)
    cavp_start = int(round(seg_start_s * TARO_FPS))
    cavp_slice = cavp_feats_full[cavp_start : cavp_start + TARO_TRUNCATE_FRAME]
    if cavp_slice.shape[0] < TARO_TRUNCATE_FRAME:
        # Zero-pad the tail so the final (short) segment still fills a window.
        pad = np.zeros(
            (TARO_TRUNCATE_FRAME - cavp_slice.shape[0],) + cavp_slice.shape[1:],
            dtype=cavp_slice.dtype,
        )
        cavp_slice = np.concatenate([cavp_slice, pad], axis=0)
    video_feats = torch.from_numpy(cavp_slice).unsqueeze(0).to(device, weight_dtype)
    # Onset features (onset_fps = TRUNCATE_ONSET / MODEL_DUR ≈ 14.65 fps)
    onset_fps = TARO_TRUNCATE_ONSET / TARO_MODEL_DUR
    onset_start = int(round(seg_start_s * onset_fps))
    onset_slice = onset_feats_full[onset_start : onset_start + TARO_TRUNCATE_ONSET]
    if onset_slice.shape[0] < TARO_TRUNCATE_ONSET:
        onset_slice = np.pad(
            onset_slice,
            ((0, TARO_TRUNCATE_ONSET - onset_slice.shape[0]),),
            mode="constant",
        )
    onset_feats_t = torch.from_numpy(onset_slice).unsqueeze(0).to(device, weight_dtype)
    # Latent noise — shape matches MMDiT architecture (in_channels=8, 204×16 spatial)
    z = torch.randn(1, model.in_channels, 204, 16, device=device, dtype=weight_dtype)
    sampling_kwargs = dict(
        model=model,
        latents=z,
        y=onset_feats_t,      # onset conditioning vector
        context=video_feats,  # CAVP cross-attention context
        num_steps=int(num_steps),
        heun=False,
        cfg_scale=float(cfg_scale),
        guidance_low=0.0,
        guidance_high=0.7,    # presumably restricts guidance to part of the trajectory — confirm in sampler
        path_type="linear",
    )
    with torch.no_grad():
        samples = (euler_maruyama_sampler if mode == "sde" else euler_sampler)(**sampling_kwargs)
    # samplers return (output_tensor, zs) — index [0] for the audio latent
    if isinstance(samples, tuple):
        samples = samples[0]
    # Decode: AudioLDM2 VAE → mel → vocoder → waveform
    samples = vae.decode(samples / latents_scale).sample
    wav = vocoder(samples.squeeze().float()).detach().cpu().numpy()
    # Trim the fixed 8.192 s window down to the actual segment duration.
    seg_samples = int(round((seg_end_s - seg_start_s) * TARO_SR))
    return wav[:seg_samples]
# ================================================================== #
#                   TARO 16 kHz → 48 kHz upsample                    #
# ================================================================== #
# TARO generates at 16 kHz; all other models output at 44.1/48 kHz.
# We upsample via sinc resampling (torchaudio, CPU-only) so the final
# stitched audio is uniformly at 48 kHz across all three models.
TARGET_SR = 48000        # unified output sample rate for all three models
TARO_SR_OUT = TARGET_SR  # alias: TARO's post-upsample output rate
| def _resample_to_target(wav: np.ndarray, src_sr: int, | |
| dst_sr: int = None) -> np.ndarray: | |
| """Resample *wav* (mono or stereo numpy float32) from *src_sr* to *dst_sr*. | |
| *dst_sr* defaults to TARGET_SR (48 kHz). No-op if src_sr == dst_sr. | |
| Uses torchaudio Kaiser-windowed sinc resampling β CPU-only, ZeroGPU-safe. | |
| """ | |
| if dst_sr is None: | |
| dst_sr = TARGET_SR | |
| if src_sr == dst_sr: | |
| return wav | |
| stereo = wav.ndim == 2 | |
| t = torch.from_numpy(np.ascontiguousarray(wav.astype(np.float32))) | |
| if not stereo: | |
| t = t.unsqueeze(0) # [1, T] | |
| t = torchaudio.functional.resample(t, src_sr, dst_sr) | |
| if not stereo: | |
| t = t.squeeze(0) # [T] | |
| return t.numpy() | |
def _upsample_taro(wav_16k: np.ndarray) -> np.ndarray:
    """Upsample a mono 16 kHz numpy array to 48 kHz via sinc resampling (CPU).

    torchaudio.functional.resample uses a Kaiser-windowed sinc filter —
    mathematically optimal for bandlimited signals, zero CUDA risk.
    Returns a mono float32 numpy array at 48 kHz.

    Bug fix: the completion log previously reported the expected duration as
    ``dur_in * 3``. Resampling 16 kHz → 48 kHz triples the SAMPLE COUNT, not
    the duration — the output must still be *dur_in* seconds long, so the old
    message always looked like a 3×-length failure.
    """
    dur_in = len(wav_16k) / TARO_SR
    print(f"[TARO upsample] {dur_in:.2f}s @ {TARO_SR}Hz → {TARGET_SR}Hz (sinc, CPU) …")
    result = _resample_to_target(wav_16k, TARO_SR)
    print(f"[TARO upsample] done → {len(result)/TARGET_SR:.2f}s @ {TARGET_SR}Hz "
          f"(expected {dur_in:.2f}s, 3× samples)")
    return result
| def _stitch_wavs(wavs: list[np.ndarray], crossfade_s: float, db_boost: float, | |
| total_dur_s: float, sr: int) -> np.ndarray: | |
| """Crossfade-join a list of wav arrays and trim to *total_dur_s*. | |
| Works for both mono (T,) and stereo (C, T) arrays.""" | |
| out = wavs[0] | |
| for nw in wavs[1:]: | |
| out = _cf_join(out, nw, crossfade_s, db_boost, sr) | |
| n = int(round(total_dur_s * sr)) | |
| return out[:, :n] if out.ndim == 2 else out[:n] | |
def _save_wav(path: str, wav: np.ndarray, sr: int) -> None:
    """Write a numpy waveform (mono (T,) or stereo (C, T)) to *path* at *sr* Hz.

    torchaudio.save expects a (channels, samples) tensor, so mono input gains
    a leading channel axis.
    """
    tensor = torch.from_numpy(np.ascontiguousarray(wav))
    if tensor.ndim == 1:
        tensor = tensor.unsqueeze(0)
    torchaudio.save(path, tensor, sr)
| def _log_inference_timing(label: str, elapsed: float, n_segs: int, | |
| num_steps: int, constant: float) -> None: | |
| """Print a standardised inference-timing summary line.""" | |
| total_steps = n_segs * num_steps | |
| secs_per_step = elapsed / total_steps if total_steps > 0 else 0 | |
| print(f"[{label}] Inference done: {n_segs} seg(s) Γ {num_steps} steps in " | |
| f"{elapsed:.1f}s wall β {secs_per_step:.3f}s/step " | |
| f"(current constant={constant})") | |
| def _build_seg_meta(*, segments, wav_paths, audio_path, video_path, | |
| silent_video, sr, model, crossfade_s, crossfade_db, | |
| total_dur_s, **extras) -> dict: | |
| """Build the seg_meta dict shared by all three generate_* functions. | |
| Model-specific keys are passed via **extras.""" | |
| meta = { | |
| "segments": segments, | |
| "wav_paths": wav_paths, | |
| "audio_path": audio_path, | |
| "video_path": video_path, | |
| "silent_video": silent_video, | |
| "sr": sr, | |
| "model": model, | |
| "crossfade_s": crossfade_s, | |
| "crossfade_db": crossfade_db, | |
| "total_dur_s": total_dur_s, | |
| } | |
| meta.update(extras) | |
| return meta | |
def _post_process_samples(results: list, *, model: str, tmp_dir: str,
                          silent_video: str, segments: list,
                          crossfade_s: float, crossfade_db: float,
                          total_dur_s: float, sr: int,
                          extra_meta_fn=None) -> list:
    """Shared CPU post-processing for the three generate_* wrappers.

    Each entry in *results* is a tuple whose first element is the list of
    per-segment wav arrays; any remaining elements are model-specific
    (e.g. TARO carries features, HunyuanFoley carries text_feats).
    *extra_meta_fn(sample_idx, result_tuple, tmp_dir) -> dict*, when given,
    supplies model-specific keys to merge into seg_meta (e.g. cavp_path,
    onset_path, text_feats_path).
    Returns a list of (video_path, audio_path, seg_meta) tuples.
    """
    processed = []
    for idx, result in enumerate(results):
        stem = f"{model}_{idx}"
        seg_wavs = result[0]
        # Crossfade the segments into one track and write it to disk.
        full_wav = _stitch_wavs(seg_wavs, crossfade_s, crossfade_db, total_dur_s, sr)
        audio_path = os.path.join(tmp_dir, f"{stem}.wav")
        _save_wav(audio_path, full_wav, sr)
        # Mux the generated audio back onto the silent input video.
        video_path = os.path.join(tmp_dir, f"{stem}.mp4")
        mux_video_audio(silent_video, audio_path, video_path, model=model)
        seg_wav_paths = _save_seg_wavs(seg_wavs, tmp_dir, stem)
        extras = {} if extra_meta_fn is None else extra_meta_fn(idx, result, tmp_dir)
        meta = _build_seg_meta(
            segments=segments, wav_paths=seg_wav_paths, audio_path=audio_path,
            video_path=video_path, silent_video=silent_video, sr=sr,
            model=model, crossfade_s=crossfade_s, crossfade_db=crossfade_db,
            total_dur_s=total_dur_s, **extras,
        )
        processed.append((video_path, audio_path, meta))
    return processed
def _cpu_preprocess(video_file: str, model_dur: float,
                    crossfade_s: float) -> tuple:
    """Shared CPU-only pre-processing for the generate_* wrappers.

    Strips the audio track, measures the clip length, and plans the
    sliding-window segments.
    Returns (tmp_dir, silent_video, total_dur_s, segments).
    """
    work_dir = _register_tmp_dir(tempfile.mkdtemp())
    muted_path = os.path.join(work_dir, "silent_input.mp4")
    strip_audio_from_video(video_file, muted_path)
    clip_dur = get_video_duration(video_file)
    windows = _build_segments(clip_dur, model_dur, crossfade_s)
    return work_dir, muted_path, clip_dur, windows
def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
                    crossfade_s, crossfade_db, num_samples):
    """GPU-only TARO inference — model loading + feature extraction + diffusion.

    Reads pre-computed context (tmp_dir, silent video, segments) from the
    "taro_gpu_infer" context store populated by generate_taro().
    Returns a list with one (wavs, cavp_feats, onset_feats) tuple per sample;
    on a cache hit onset_feats is None, so the caller must tolerate that.
    """
    seed_val = int(seed_val)
    crossfade_s = float(crossfade_s)
    num_samples = int(num_samples)
    # A negative seed means "randomize": draw a fresh 32-bit seed.
    if seed_val < 0:
        seed_val = random.randint(0, 2**32 - 1)
    torch.set_grad_enabled(False)  # inference only — no autograd graphs
    device, weight_dtype = _get_device_and_dtype()
    _ensure_syspath("TARO")
    from TARO.onset_util import extract_onset
    from TARO.samplers import euler_sampler, euler_maruyama_sampler
    # Context prepared on the CPU side by generate_taro() before GPU entry.
    ctx = _ctx_load("taro_gpu_infer")
    tmp_dir = ctx["tmp_dir"]
    silent_video = ctx["silent_video"]
    segments = ctx["segments"]
    total_dur_s = ctx["total_dur_s"]
    extract_cavp, onset_model = _load_taro_feature_extractors(device)
    # CAVP and onset features depend only on the video — extract once and
    # reuse across all samples and segments.
    cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
    onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
    # Free feature extractors before loading the heavier inference models
    del extract_cavp, onset_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    model, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
    results = []  # list of (wavs, cavp_feats, onset_feats) per sample
    for sample_idx in range(num_samples):
        # Deterministic per-sample offset: reproducible yet distinct samples.
        sample_seed = seed_val + sample_idx
        cache_key = (video_file, sample_seed, float(cfg_scale), int(num_steps), mode, crossfade_s)
        with _TARO_CACHE_LOCK:
            cached = _TARO_INFERENCE_CACHE.get(cache_key)
        if cached is not None:
            print(f"[TARO] Sample {sample_idx+1}: cache hit.")
            # onset_feats is None here — generate_taro only saves it when present.
            results.append((cached["wavs"], cavp_feats, None))
        else:
            set_global_seed(sample_seed)
            wavs = []
            _t_infer_start = time.perf_counter()
            for seg_start_s, seg_end_s in segments:
                print(f"[TARO] Sample {sample_idx+1} | {seg_start_s:.2f}s → {seg_end_s:.2f}s")
                wav = _taro_infer_segment(
                    model, vae, vocoder,
                    cavp_feats, onset_feats,
                    seg_start_s, seg_end_s,
                    device, weight_dtype,
                    cfg_scale, num_steps, mode,
                    latents_scale,
                    euler_sampler, euler_maruyama_sampler,
                )
                wavs.append(wav)
            _log_inference_timing("TARO", time.perf_counter() - _t_infer_start,
                                  len(segments), int(num_steps), TARO_SECS_PER_STEP)
            with _TARO_CACHE_LOCK:
                _TARO_INFERENCE_CACHE[cache_key] = {"wavs": wavs}
                # Evict oldest entries (dict insertion order) once over capacity.
                while len(_TARO_INFERENCE_CACHE) > _TARO_CACHE_MAXLEN:
                    _TARO_INFERENCE_CACHE.pop(next(iter(_TARO_INFERENCE_CACHE)))
            results.append((wavs, cavp_feats, onset_feats))
        # Free GPU memory between samples so VRAM fragmentation doesn't
        # degrade diffusion quality on samples 2, 3, 4, etc.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results
def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
                  crossfade_s, crossfade_db, num_samples):
    """TARO: video-conditioned diffusion, 16 kHz, 8.192 s sliding window.

    CPU pre/post-processing wraps the GPU-only inference to minimize
    ZeroGPU cost. Returns _pad_outputs(...) over one
    (video_path, audio_path, seg_meta) tuple per sample.
    """
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)
    num_samples = int(num_samples)
    # ── CPU pre-processing (no GPU needed) ──
    tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
        video_file, TARO_MODEL_DUR, crossfade_s)
    _ctx_store("taro_gpu_infer", {
        "tmp_dir": tmp_dir, "silent_video": silent_video,
        "segments": segments, "total_dur_s": total_dur_s,
    })
    # ── GPU inference only ──
    results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
                              crossfade_s, crossfade_db, num_samples)
    # ── CPU post-processing (no GPU needed) ──
    # Upsample 16kHz → 48kHz and cache CAVP/onset features to disk so that
    # segment regeneration can skip the feature-extractor reload.
    cavp_path = os.path.join(tmp_dir, "taro_cavp.npy")
    onset_path = os.path.join(tmp_dir, "taro_onset.npy")
    # BUG FIX: track the two feature files independently. A cache-hit sample
    # carries onset_feats=None, so a single "saved" flag would permanently
    # skip saving onset features whenever the FIRST sample is a cache hit,
    # even though a later fresh sample has them.
    _cavp_saved = False
    _onset_saved = False
    def _upsample_and_save_feats(result):
        nonlocal _cavp_saved, _onset_saved
        wavs, cavp_feats, onset_feats = result
        wavs = [_upsample_taro(w) for w in wavs]
        if not _cavp_saved:
            np.save(cavp_path, cavp_feats)
            _cavp_saved = True
        if onset_feats is not None and not _onset_saved:
            np.save(onset_path, onset_feats)
            _onset_saved = True
        return (wavs, cavp_feats, onset_feats)
    results = [_upsample_and_save_feats(r) for r in results]
    def _taro_extras(sample_idx, result, td):
        # Disk paths let regen_taro_segment reuse the cached features.
        return {"cavp_path": cavp_path, "onset_path": onset_path}
    outputs = _post_process_samples(
        results, model="taro", tmp_dir=tmp_dir,
        silent_video=silent_video, segments=segments,
        crossfade_s=crossfade_s, crossfade_db=crossfade_db,
        total_dur_s=total_dur_s, sr=TARO_SR_OUT,
        extra_meta_fn=_taro_extras,
    )
    return _pad_outputs(outputs)
| # ================================================================== # | |
| # MMAudio # | |
| # ================================================================== # | |
| # Constants sourced from MMAudio/mmaudio/model/sequence_config.py: | |
| # CONFIG_44K: duration=8.0 s, sampling_rate=44100 | |
| # CLIP encoder: 8 fps, 384Γ384 px | |
| # Synchformer: 25 fps, 224Γ224 px | |
| # Default variant: large_44k_v2 | |
| # MMAudio uses flow-matching (FlowMatching with euler inference). | |
| # generate() handles all feature extraction + decoding internally. | |
| # ================================================================== # | |
def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
                      cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
    """Pre-GPU duration estimate — signature must mirror _mmaudio_gpu_infer."""
    return _estimate_gpu_duration(
        "mmaudio",
        int(num_samples),
        int(num_steps),
        video_file=video_file,
        crossfade_s=crossfade_s,
    )
def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                       cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
    """GPU-only MMAudio inference — model loading + flow-matching generation.

    Reads segments and pre-extracted per-segment clip paths from the
    "mmaudio_gpu_infer" context store (populated by generate_mmaudio on the
    CPU side). Returns a list of (seg_audios, sr) tuples, one per sample;
    each seg_audios element is a (C, T) numpy array trimmed to its segment.
    """
    _ensure_syspath("MMAudio")
    from mmaudio.eval_utils import generate, load_video
    from mmaudio.model.flow_matching import FlowMatching
    seed_val = int(seed_val)
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    device, dtype = _get_device_and_dtype()
    net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
    ctx = _ctx_load("mmaudio_gpu_infer")
    segments = ctx["segments"]
    seg_clip_paths = ctx["seg_clip_paths"]
    sr = seq_cfg.sampling_rate  # 44100
    results = []
    for sample_idx in range(num_samples):
        # Per-sample RNG: seed_val >= 0 gives reproducible offset seeds;
        # a negative seed draws fresh entropy for every sample.
        rng = torch.Generator(device=device)
        if seed_val >= 0:
            rng.manual_seed(seed_val + sample_idx)
        else:
            rng.seed()
        seg_audios = []
        _t_mma_start = time.perf_counter()
        fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
        for seg_i, (seg_start, seg_end) in enumerate(segments):
            seg_dur = seg_end - seg_start
            seg_path = seg_clip_paths[seg_i]
            video_info = load_video(seg_path, seg_dur)
            clip_frames = video_info.clip_frames.unsqueeze(0)  # add batch dim
            sync_frames = video_info.sync_frames.unsqueeze(0)
            # load_video may clamp the duration; resize the network's
            # sequence lengths to match what was actually decoded.
            actual_dur = video_info.duration_sec
            seq_cfg.duration = actual_dur
            net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
            print(f"[MMAudio] Sample {sample_idx+1} | seg {seg_i+1}/{len(segments)} "
                  f"{seg_start:.1f}–{seg_end:.1f}s | dur={actual_dur:.2f}s | prompt='{prompt}'")
            with torch.no_grad():
                audios = generate(
                    clip_frames,
                    sync_frames,
                    [prompt],
                    negative_text=[negative_prompt] if negative_prompt else None,
                    feature_utils=feature_utils,
                    net=net,
                    fm=fm,
                    rng=rng,
                    cfg_strength=float(cfg_strength),
                )
            wav = audios.float().cpu()[0].numpy()  # (C, T)
            # Trim to the nominal segment length so crossfade stitching lines up.
            seg_samples = int(round(seg_dur * sr))
            wav = wav[:, :seg_samples]
            seg_audios.append(wav)
        _log_inference_timing("MMAudio", time.perf_counter() - _t_mma_start,
                              len(segments), int(num_steps), MMAUDIO_SECS_PER_STEP)
        results.append((seg_audios, sr))
        # Free GPU memory between samples to prevent VRAM fragmentation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results
def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
                     cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
    """MMAudio: flow-matching video-to-audio, 44.1 kHz, 8 s sliding window.

    CPU pre/post-processing wraps the GPU-only inference to minimize
    ZeroGPU cost.
    """
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)
    # ── CPU pre-processing ──
    tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
        video_file, MMAUDIO_WINDOW, crossfade_s)
    print(f"[MMAudio] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤8 s")
    # Pre-extract every segment clip on the CPU before entering the GPU stage.
    seg_clip_paths = []
    for i, (start, end) in enumerate(segments):
        clip_path = os.path.join(tmp_dir, f"mma_seg_{i}.mp4")
        seg_clip_paths.append(
            _extract_segment_clip(silent_video, start, end - start, clip_path))
    _ctx_store("mmaudio_gpu_infer", {"segments": segments, "seg_clip_paths": seg_clip_paths})
    # ── GPU inference only ──
    results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                                 cfg_strength, num_steps, crossfade_s, crossfade_db,
                                 num_samples)
    # ── CPU post-processing ──
    # Resample 44100 → 48000 and normalise tuples to (seg_wavs, ...)
    normalised = []
    for seg_audios, sr in results:
        if sr != TARGET_SR:
            print(f"[MMAudio upsample] resampling {sr}Hz → {TARGET_SR}Hz (sinc, CPU) …")
            seg_audios = [_resample_to_target(w, sr) for w in seg_audios]
            print(f"[MMAudio upsample] done → {len(seg_audios)} seg(s) @ {TARGET_SR}Hz")
        normalised.append((seg_audios,))
    outputs = _post_process_samples(
        normalised, model="mmaudio", tmp_dir=tmp_dir,
        silent_video=silent_video, segments=segments,
        crossfade_s=crossfade_s, crossfade_db=crossfade_db,
        total_dur_s=total_dur_s, sr=TARGET_SR,
    )
    return _pad_outputs(outputs)
| # ================================================================== # | |
| # HunyuanVideoFoley # | |
| # ================================================================== # | |
| # Constants sourced from HunyuanVideo-Foley/hunyuanvideo_foley/constants.py | |
| # and configs/hunyuanvideo-foley-xxl.yaml: | |
| # sample_rate = 48000 Hz (from DAC VAE) | |
| # audio_frame_rate = 50 (latent fps, xxl config) | |
| # max video duration = 15 s | |
| # SigLIP2 fps = 8, Synchformer fps = 25 | |
| # CLAP text encoder: laion/larger_clap_general (auto-downloaded from HF Hub) | |
| # Default guidance_scale=4.5, num_inference_steps=50 | |
| # ================================================================== # | |
def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
                      guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
                      num_samples):
    """Pre-GPU duration estimate — signature must mirror _hunyuan_gpu_infer."""
    return _estimate_gpu_duration(
        "hunyuan",
        int(num_samples),
        int(num_steps),
        video_file=video_file,
        crossfade_s=crossfade_s,
    )
def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                       guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
                       num_samples):
    """GPU-only HunyuanFoley inference — model loading + feature extraction + denoising.

    Reads segments and pre-extracted clip paths from the "hunyuan_gpu_infer"
    context store (populated by generate_hunyuan on the CPU side).
    Returns a list of (seg_wavs, sr, text_feats) tuples, one per sample.
    """
    _ensure_syspath("HunyuanVideo-Foley")
    from hunyuanvideo_foley.utils.model_utils import denoise_process
    from hunyuanvideo_foley.utils.feature_utils import feature_process
    seed_val = int(seed_val)
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    # Negative seed means "don't seed" — global RNG state is left untouched.
    if seed_val >= 0:
        set_global_seed(seed_val)
    device, _ = _get_device_and_dtype()
    device = torch.device(device)
    model_size = model_size.lower()
    model_dict, cfg = _load_hunyuan_model(device, model_size)
    ctx = _ctx_load("hunyuan_gpu_infer")
    segments = ctx["segments"]
    total_dur_s = ctx["total_dur_s"]
    dummy_seg_path = ctx["dummy_seg_path"]
    seg_clip_paths = ctx["seg_clip_paths"]
    # Text feature extraction (GPU) — runs once and is reused for every
    # segment of every sample.
    _, text_feats, _ = feature_process(
        dummy_seg_path,
        prompt if prompt else "",
        model_dict,
        cfg,
        neg_prompt=negative_prompt if negative_prompt else None,
    )
    # Import visual-only feature extractor to avoid redundant text extraction
    # per segment (text_feats already computed once above for the whole batch).
    from hunyuanvideo_foley.utils.feature_utils import encode_video_features
    results = []
    for sample_idx in range(num_samples):
        seg_wavs = []
        sr = 48000  # fallback; denoise_process returns the authoritative rate
        _t_hny_start = time.perf_counter()
        for seg_i, (seg_start, seg_end) in enumerate(segments):
            seg_dur = seg_end - seg_start
            seg_path = seg_clip_paths[seg_i]
            # Extract only visual features — reuse text_feats from above
            visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
            print(f"[HunyuanFoley] Sample {sample_idx+1} | seg {seg_i+1}/{len(segments)} "
                  f"{seg_start:.1f}–{seg_end:.1f}s → {seg_audio_len:.2f}s audio")
            audio_batch, sr = denoise_process(
                visual_feats,
                text_feats,
                seg_audio_len,
                model_dict,
                cfg,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_steps),
                batch_size=1,
            )
            wav = audio_batch[0].float().cpu().numpy()
            # Trim to the nominal segment length so crossfade stitching lines up.
            seg_samples = int(round(seg_dur * sr))
            wav = wav[:, :seg_samples]
            seg_wavs.append(wav)
        _log_inference_timing("HunyuanFoley", time.perf_counter() - _t_hny_start,
                              len(segments), int(num_steps), HUNYUAN_SECS_PER_STEP)
        results.append((seg_wavs, sr, text_feats))
        # Free GPU memory between samples to prevent VRAM fragmentation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results
def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
                     guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
    """HunyuanVideoFoley: text-guided foley, 48 kHz, up to 15 s.

    CPU pre/post-processing wraps the GPU-only inference to minimize
    ZeroGPU cost.
    """
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)
    # ── CPU pre-processing (no GPU needed) ──
    tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
        video_file, HUNYUAN_MAX_DUR, crossfade_s)
    print(f"[HunyuanFoley] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤15 s")
    # Dummy clip used only for the one-off text feature extraction (ffmpeg, CPU).
    dummy_seg_path = _extract_segment_clip(
        silent_video, 0, min(total_dur_s, HUNYUAN_MAX_DUR),
        os.path.join(tmp_dir, "_seg_dummy.mp4"),
    )
    # Pre-extract every segment clip on the CPU before entering the GPU stage.
    seg_clip_paths = []
    for i, (start, end) in enumerate(segments):
        clip_path = os.path.join(tmp_dir, f"hny_seg_{i}.mp4")
        seg_clip_paths.append(
            _extract_segment_clip(silent_video, start, end - start, clip_path))
    _ctx_store("hunyuan_gpu_infer", {
        "segments": segments, "total_dur_s": total_dur_s,
        "dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
    })
    # ── GPU inference only ──
    results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                                 guidance_scale, num_steps, model_size,
                                 crossfade_s, crossfade_db, num_samples)
    # ── CPU post-processing (no GPU needed) ──
    def _hunyuan_extras(sample_idx, result, td):
        # Persist text features so single-segment regens can skip re-extraction.
        _, _sr, text_feats = result
        feats_path = os.path.join(td, f"hunyuan_{sample_idx}_text_feats.pt")
        torch.save(text_feats, feats_path)
        return {"text_feats_path": feats_path}
    outputs = _post_process_samples(
        results, model="hunyuan", tmp_dir=tmp_dir,
        silent_video=silent_video, segments=segments,
        crossfade_s=crossfade_s, crossfade_db=crossfade_db,
        total_dur_s=total_dur_s, sr=48000,
        extra_meta_fn=_hunyuan_extras,
    )
    return _pad_outputs(outputs)
| # ================================================================== # | |
| # SEGMENT REGENERATION HELPERS # | |
| # ================================================================== # | |
| # Each regen function: | |
| # 1. Runs inference for ONE segment (random seed, current settings) | |
| # 2. Splices the new wav into the stored wavs list | |
| # 3. Re-stitches the full track, re-saves .wav and re-muxes .mp4 | |
| # 4. Returns (new_video_path, new_audio_path, updated_seg_meta, new_waveform_html) | |
| # ================================================================== # | |
def _splice_and_save(new_wav, seg_idx, meta, slot_id):
    """Swap wavs[seg_idx] for *new_wav*, then re-stitch, re-save and re-mux.

    Returns (video_path, audio_path, updated_meta, waveform_html).
    """
    seg_wavs = _load_seg_wavs(meta["wav_paths"])
    seg_wavs[seg_idx] = new_wav
    crossfade_s = float(meta["crossfade_s"])
    crossfade_db = float(meta["crossfade_db"])
    sr = int(meta["sr"])
    total_dur_s = float(meta["total_dur_s"])
    silent_video = meta["silent_video"]
    segments = meta["segments"]
    model = meta["model"]
    full_wav = _stitch_wavs(seg_wavs, crossfade_s, crossfade_db, total_dur_s, sr)
    # A fresh timestamped filename forces Gradio / the browser to treat the
    # output as a genuinely new file and reload the players.
    stamp = int(time.time() * 1000)
    tmp_dir = os.path.dirname(meta["audio_path"])
    audio_stem = os.path.splitext(os.path.basename(meta["audio_path"]))[0]
    # Drop any timestamp suffix left over from a previous regen pass.
    audio_stem = audio_stem.rsplit("_regen_", 1)[0]
    audio_path = os.path.join(tmp_dir, f"{audio_stem}_regen_{stamp}.wav")
    _save_wav(audio_path, full_wav, sr)
    # Re-mux into a fresh video file for the same cache-busting reason.
    video_stem = os.path.splitext(os.path.basename(meta["video_path"]))[0]
    video_stem = video_stem.rsplit("_regen_", 1)[0]
    video_path = os.path.join(tmp_dir, f"{video_stem}_regen_{stamp}.mp4")
    mux_video_audio(silent_video, audio_path, video_path, model=model)
    # Persist the updated per-segment wavs so later regens see this splice.
    updated_wav_paths = _save_seg_wavs(seg_wavs, tmp_dir, os.path.splitext(audio_stem)[0])
    updated_meta = dict(meta)
    updated_meta["wav_paths"] = updated_wav_paths
    updated_meta["audio_path"] = audio_path
    updated_meta["video_path"] = video_path
    state_json_new = json.dumps(updated_meta)
    waveform_html = _build_waveform_html(audio_path, segments, slot_id, "",
                                         state_json=state_json_new,
                                         video_path=video_path,
                                         crossfade_s=crossfade_s)
    return video_path, audio_path, updated_meta, waveform_html
def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
                         seed_val, cfg_scale, num_steps, mode,
                         crossfade_s, crossfade_db, slot_id=None):
    """Estimate the GPU budget for a single-segment TARO regen.

    When cached CAVP/onset feature files are still on disk, the ~10 s
    feature-extractor overhead is skipped, so a much tighter per-step
    estimate applies.
    """
    try:
        meta = json.loads(seg_meta_json)
        have_cavp = os.path.exists(meta.get("cavp_path", ""))
        have_onset = os.path.exists(meta.get("onset_path", ""))
        if have_cavp and have_onset:
            cfg = MODEL_CONFIGS["taro"]
            secs = int(num_steps) * cfg["secs_per_step"] + 5  # 5s model-load only
            result = min(GPU_DURATION_CAP, max(30, int(secs)))
            print(f"[duration] TARO regen (cache hit): 1 seg × {int(num_steps)} steps → {secs:.0f}s → capped {result}s")
            return result
    except Exception:
        # Best-effort: malformed meta falls through to the generic estimate.
        pass
    return _estimate_regen_duration("taro", int(num_steps))
def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
                    seed_val, cfg_scale, num_steps, mode,
                    crossfade_s, crossfade_db, slot_id=None):
    """GPU-only TARO regen — diffuse one segment with a fresh random seed.

    CAVP/onset features are loaded from the .npy paths recorded in
    *seg_meta_json* when available; otherwise they are re-extracted.
    Returns the new mono 16 kHz wav for segment *seg_idx*.
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start_s, seg_end_s = meta["segments"][seg_idx]
    torch.set_grad_enabled(False)  # inference only
    device, weight_dtype = _get_device_and_dtype()
    _ensure_syspath("TARO")
    from TARO.samplers import euler_sampler, euler_maruyama_sampler
    # Load cached CAVP/onset features from .npy files (CPU I/O, fast, outside GPU budget)
    cavp_path = meta.get("cavp_path", "")
    onset_path = meta.get("onset_path", "")
    if cavp_path and os.path.exists(cavp_path) and onset_path and os.path.exists(onset_path):
        print("[TARO regen] Loading cached CAVP + onset features from disk")
        cavp_feats = np.load(cavp_path)
        onset_feats = np.load(onset_path)
    else:
        print("[TARO regen] Cache miss → re-extracting CAVP + onset features")
        from TARO.onset_util import extract_onset
        extract_cavp, onset_model = _load_taro_feature_extractors(device)
        silent_video = meta["silent_video"]
        # NOTE(review): this mkdtemp is not registered for cleanup the way
        # _cpu_preprocess's is (no _register_tmp_dir) — confirm intentional.
        tmp_dir = tempfile.mkdtemp()
        cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
        onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
        # Drop the feature extractors before loading the heavier models.
        del extract_cavp, onset_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    model_net, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
    # Regens always roll a fresh random seed so each click yields a new take.
    set_global_seed(random.randint(0, 2**32 - 1))
    return _taro_infer_segment(
        model_net, vae, vocoder, cavp_feats, onset_feats,
        seg_start_s, seg_end_s, device, weight_dtype,
        float(cfg_scale), int(num_steps), mode, latents_scale,
        euler_sampler, euler_maruyama_sampler,
    )
def regen_taro_segment(video_file, seg_idx, seg_meta_json,
                       seed_val, cfg_scale, num_steps, mode,
                       crossfade_s, crossfade_db, slot_id):
    """Regenerate a single TARO segment (GPU inference + CPU splice/save)."""
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    # GPU pass — CAVP/onset features come from the disk paths in seg_meta_json.
    fresh_wav = _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
                                seed_val, cfg_scale, num_steps, mode,
                                crossfade_s, crossfade_db, slot_id)
    # CPU: bring the 16 kHz segment up to 48 kHz before splicing.
    fresh_wav = _upsample_taro(fresh_wav)
    video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
        fresh_wav, seg_idx, meta, slot_id
    )
    return video_path, audio_path, json.dumps(updated_meta), waveform_html
def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
                            prompt, negative_prompt, seed_val,
                            cfg_strength, num_steps, crossfade_s, crossfade_db,
                            slot_id=None):
    """Estimate the GPU budget for a single-segment MMAudio regen."""
    steps = int(num_steps)
    return _estimate_regen_duration("mmaudio", steps)
def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
                       prompt, negative_prompt, seed_val,
                       cfg_strength, num_steps, crossfade_s, crossfade_db,
                       slot_id=None):
    """GPU-only MMAudio regen — returns (new_wav, sr) for a single segment.

    Always rolls a fresh random seed; *seed_val* is accepted only to keep
    the signature aligned with the duration estimator.
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start, seg_end = meta["segments"][seg_idx]
    seg_dur = seg_end - seg_start
    _ensure_syspath("MMAudio")
    from mmaudio.eval_utils import generate, load_video
    from mmaudio.model.flow_matching import FlowMatching
    device, dtype = _get_device_and_dtype()
    net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
    sr = seq_cfg.sampling_rate
    # Extract segment clip inside the GPU function — ffmpeg is CPU-only and safe here.
    # This avoids any cross-process context passing that fails under ZeroGPU isolation.
    seg_path = _extract_segment_clip(
        meta["silent_video"], seg_start, seg_dur,
        os.path.join(tempfile.mkdtemp(), "regen_seg.mp4"),
    )
    rng = torch.Generator(device=device)
    rng.manual_seed(random.randint(0, 2**32 - 1))  # fresh take per click
    fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=int(num_steps))
    video_info = load_video(seg_path, seg_dur)
    clip_frames = video_info.clip_frames.unsqueeze(0)  # add batch dim
    sync_frames = video_info.sync_frames.unsqueeze(0)
    # load_video may clamp the duration; resize sequence lengths to match.
    actual_dur = video_info.duration_sec
    seq_cfg.duration = actual_dur
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
    with torch.no_grad():
        audios = generate(
            clip_frames, sync_frames, [prompt],
            negative_text=[negative_prompt] if negative_prompt else None,
            feature_utils=feature_utils, net=net, fm=fm, rng=rng,
            cfg_strength=float(cfg_strength),
        )
    new_wav = audios.float().cpu()[0].numpy()
    # Trim to the nominal segment length so splicing lines up.
    seg_samples = int(round(seg_dur * sr))
    new_wav = new_wav[:, :seg_samples]
    return new_wav, sr
def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
                          prompt, negative_prompt, seed_val,
                          cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id):
    """Regenerate a single MMAudio segment (GPU inference + CPU splice/save)."""
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    # GPU pass — the segment clip is extracted inside the GPU function.
    fresh_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
                                       prompt, negative_prompt, seed_val,
                                       cfg_strength, num_steps, crossfade_s,
                                       crossfade_db, slot_id)
    # MMAudio emits 44.1 kHz; bring it to the app-wide 48 kHz if needed.
    if sr != TARGET_SR:
        print(f"[MMAudio regen upsample] {sr}Hz → {TARGET_SR}Hz (sinc, CPU) …")
        fresh_wav = _resample_to_target(fresh_wav, sr)
        sr = TARGET_SR
    meta["sr"] = sr
    video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
        fresh_wav, seg_idx, meta, slot_id
    )
    return video_path, audio_path, json.dumps(updated_meta), waveform_html
def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
                            prompt, negative_prompt, seed_val,
                            guidance_scale, num_steps, model_size,
                            crossfade_s, crossfade_db, slot_id=None):
    """Estimate the GPU budget for a single-segment HunyuanFoley regen."""
    steps = int(num_steps)
    return _estimate_regen_duration("hunyuan", steps)
def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
                       prompt, negative_prompt, seed_val,
                       guidance_scale, num_steps, model_size,
                       crossfade_s, crossfade_db, slot_id=None):
    """GPU-only HunyuanFoley regen — returns (new_wav, sr) for a single segment.

    Reuses cached text features from *seg_meta_json* when the file still
    exists; otherwise falls back to full text + visual feature extraction.
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start, seg_end = meta["segments"][seg_idx]
    seg_dur = seg_end - seg_start
    _ensure_syspath("HunyuanVideo-Foley")
    from hunyuanvideo_foley.utils.model_utils import denoise_process
    from hunyuanvideo_foley.utils.feature_utils import feature_process
    device, _ = _get_device_and_dtype()
    device = torch.device(device)
    # NOTE(review): unlike _hunyuan_gpu_infer, model_size is not lower-cased
    # here — confirm callers always pass it already lower-case.
    model_dict, cfg = _load_hunyuan_model(device, model_size)
    # Regens always roll a fresh random seed so each click yields a new take.
    set_global_seed(random.randint(0, 2**32 - 1))
    # Extract segment clip inside the GPU function — ffmpeg is CPU-only and safe here.
    seg_path = _extract_segment_clip(
        meta["silent_video"], seg_start, seg_dur,
        os.path.join(tempfile.mkdtemp(), "regen_seg.mp4"),
    )
    text_feats_path = meta.get("text_feats_path", "")
    if text_feats_path and os.path.exists(text_feats_path):
        print("[HunyuanFoley regen] Loading cached text features from disk")
        from hunyuanvideo_foley.utils.feature_utils import encode_video_features
        visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
        # weights_only=False is needed to unpickle the feature object; the
        # file is app-generated (trusted), never user-supplied.
        text_feats = torch.load(text_feats_path, map_location=device, weights_only=False)
    else:
        print("[HunyuanFoley regen] Cache miss → extracting text + visual features")
        visual_feats, text_feats, seg_audio_len = feature_process(
            seg_path, prompt if prompt else "", model_dict, cfg,
            neg_prompt=negative_prompt if negative_prompt else None,
        )
    audio_batch, sr = denoise_process(
        visual_feats, text_feats, seg_audio_len, model_dict, cfg,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(num_steps),
        batch_size=1,
    )
    new_wav = audio_batch[0].float().cpu().numpy()
    # Trim to the nominal segment length so splicing lines up.
    seg_samples = int(round(seg_dur * sr))
    new_wav = new_wav[:, :seg_samples]
    return new_wav, sr
def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
                          prompt, negative_prompt, seed_val,
                          guidance_scale, num_steps, model_size,
                          crossfade_s, crossfade_db, slot_id):
    """Regenerate a single HunyuanFoley segment (GPU inference + CPU splice/save)."""
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    # GPU pass — the segment clip is extracted inside the GPU function.
    fresh_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
                                       prompt, negative_prompt, seed_val,
                                       guidance_scale, num_steps, model_size,
                                       crossfade_s, crossfade_db, slot_id)
    meta["sr"] = sr
    video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
        fresh_wav, seg_idx, meta, slot_id
    )
    return video_path, audio_path, json.dumps(updated_meta), waveform_html
# Wire up regen_fn references now that the functions are defined.
# MODEL_CONFIGS is declared earlier in the file (before these handlers
# exist), so the per-model regen callables are bound here, post-definition.
MODEL_CONFIGS["taro"]["regen_fn"] = regen_taro_segment
MODEL_CONFIGS["mmaudio"]["regen_fn"] = regen_mmaudio_segment
MODEL_CONFIGS["hunyuan"]["regen_fn"] = regen_hunyuan_segment
| # ================================================================== # | |
| # CROSS-MODEL REGEN WRAPPERS # | |
| # ================================================================== # | |
# Three shared endpoints — one per model — that can be called from #
| # *any* slot tab. slot_id is passed as plain string data so the # | |
| # result is applied back to the correct slot by the JS listener. # | |
| # The new segment is resampled to match the slot's existing SR before # | |
| # being handed to _splice_and_save, so TARO (16 kHz) / MMAudio # | |
| # (44.1 kHz) / Hunyuan (48 kHz) outputs can all be mixed freely. # | |
| # ================================================================== # | |
def _resample_to_slot_sr(wav: np.ndarray, src_sr: int, dst_sr: int,
                         slot_wav_ref: np.ndarray | None = None) -> np.ndarray:
    """Resample *wav* from src_sr to dst_sr, then match channel layout to
    *slot_wav_ref* (the first existing segment in the slot).

    TARO is mono (T,), MMAudio/Hunyuan are stereo (C, T). Mixing them
    without normalisation causes a shape mismatch in _cf_join. Rules:
      - stereo -> mono : average channels
      - mono  -> stereo: duplicate the single channel

    Args:
        wav: source waveform, mono ``(T,)`` or stereo ``(C, T)``.
        src_sr: sample rate of *wav* in Hz.
        dst_sr: target sample rate of the destination slot in Hz.
        slot_wav_ref: reference waveform whose channel layout is matched;
            when ``None`` the layout is left as-is.

    Returns:
        The resampled (and possibly channel-converted) waveform.
    """
    wav = _resample_to_target(wav, src_sr, dst_sr)
    # Match channel layout to the slot's existing segments
    if slot_wav_ref is not None:
        slot_stereo = slot_wav_ref.ndim == 2
        wav_stereo = wav.ndim == 2
        if slot_stereo and not wav_stereo:
            wav = np.stack([wav, wav], axis=0)   # mono -> stereo (C, T)
        elif not slot_stereo and wav_stereo:
            wav = wav.mean(axis=0)               # stereo -> mono (T,)
    return wav
def _xregen_splice(new_wav_raw: np.ndarray, src_sr: int,
                   meta: dict, seg_idx: int, slot_id: str) -> tuple:
    """Shared epilogue for all xregen_* functions: resample -> splice -> save.

    Returns (video_path, waveform_html)."""
    target_sr = int(meta["sr"])
    existing_wavs = _load_seg_wavs(meta["wav_paths"])
    # Conform the new segment to the slot's sample rate and channel layout,
    # using the slot's first segment as the layout reference.
    conformed = _resample_to_slot_sr(new_wav_raw, src_sr, target_sr,
                                     existing_wavs[0])
    vid_path, _aud_path, _meta_out, wf_html = _splice_and_save(
        conformed, seg_idx, meta, slot_id
    )
    return vid_path, wf_html
def _xregen_dispatch(state_json: str, seg_idx: int, slot_id: str, infer_fn):
    """Shared generator skeleton for all xregen_* wrappers.

    Yields pending HTML immediately, then calls *infer_fn()* -- a zero-argument
    callable that runs model-specific CPU prep + GPU inference and returns
    (wav_array, src_sr). For TARO, *infer_fn* should return the wav already
    upsampled to 48 kHz; pass TARO_SR_OUT as src_sr.

    Yields:
        First:  (gr.update(), gr.update(value=pending_html)) -- shown while GPU runs
        Second: (gr.update(value=video_path), gr.update(value=waveform_html))
    """
    slot_meta = json.loads(state_json)
    # Surface the amber "regenerating" placeholder before any GPU work starts.
    placeholder = _build_regen_pending_html(slot_meta["segments"], seg_idx,
                                            slot_id, "")
    yield gr.update(), gr.update(value=placeholder)
    wav_raw, wav_sr = infer_fn()
    vid_path, wf_html = _xregen_splice(wav_raw, wav_sr, slot_meta,
                                       seg_idx, slot_id)
    yield gr.update(value=vid_path), gr.update(value=wf_html)
def xregen_taro(seg_idx, state_json, slot_id,
                seed_val, cfg_scale, num_steps, mode,
                crossfade_s, crossfade_db,
                request: gr.Request = None):
    """Cross-model regen: run TARO inference and splice into *slot_id*.

    Generator: yields the pending placeholder first, then the final
    (video, waveform) updates via _xregen_dispatch.
    """
    seg_idx = int(seg_idx)
    # NOTE: the original parsed state_json into an unused local here; the
    # parse is done inside _xregen_dispatch, so it was dead code (removed).

    def _run():
        # CAVP/onset features are loaded from disk paths inside the GPU fn
        wav = _regen_taro_gpu(None, seg_idx, state_json,
                              seed_val, cfg_scale, num_steps, mode,
                              crossfade_s, crossfade_db, slot_id)
        # TARO generates at 16 kHz; upsample to 48 kHz on CPU.
        return _upsample_taro(wav), TARO_SR_OUT

    yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
def xregen_mmaudio(seg_idx, state_json, slot_id,
                   prompt, negative_prompt, seed_val,
                   cfg_strength, num_steps, crossfade_s, crossfade_db,
                   request: gr.Request = None):
    """Cross-model regen: run MMAudio inference and splice into *slot_id*."""
    seg_idx = int(seg_idx)

    def _infer():
        # Segment clip extraction happens inside _regen_mmaudio_gpu;
        # it already returns (wav, src_sr) in the shape dispatch expects.
        return _regen_mmaudio_gpu(None, seg_idx, state_json,
                                  prompt, negative_prompt, seed_val,
                                  cfg_strength, num_steps,
                                  crossfade_s, crossfade_db, slot_id)

    yield from _xregen_dispatch(state_json, seg_idx, slot_id, _infer)
def xregen_hunyuan(seg_idx, state_json, slot_id,
                   prompt, negative_prompt, seed_val,
                   guidance_scale, num_steps, model_size,
                   crossfade_s, crossfade_db,
                   request: gr.Request = None):
    """Cross-model regen: run HunyuanFoley inference and splice into *slot_id*."""
    seg_idx = int(seg_idx)

    def _infer():
        # Segment clip extraction happens inside _regen_hunyuan_gpu;
        # it already returns (wav, src_sr) in the shape dispatch expects.
        return _regen_hunyuan_gpu(None, seg_idx, state_json,
                                  prompt, negative_prompt, seed_val,
                                  guidance_scale, num_steps, model_size,
                                  crossfade_s, crossfade_db, slot_id)

    yield from _xregen_dispatch(state_json, seg_idx, slot_id, _infer)
| # ================================================================== # | |
| # SHARED UI HELPERS # | |
| # ================================================================== # | |
def _register_regen_handlers(tab_prefix, model_key, regen_seg_tb, regen_state_tb,
                             input_components, slot_vids, slot_waves):
    """Register per-slot regen button handlers for a model tab.

    This replaces the three nearly-identical for-loops that previously existed
    for TARO, MMAudio, and HunyuanFoley tabs.

    Args:
        tab_prefix: e.g. "taro", "mma", "hf"
        model_key: e.g. "taro", "mmaudio", "hunyuan"
        regen_seg_tb: gr.Textbox for seg_idx (render=False)
        regen_state_tb: gr.Textbox for state_json (render=False)
        input_components: list of Gradio input components (video, seed, etc.)
            -- order must match regen_fn signature after (seg_idx, state_json, video)
        slot_vids: list of gr.Video components per slot
        slot_waves: list of gr.HTML components per slot
    Returns:
        list of hidden gr.Buttons (one per slot)
    """
    cfg = MODEL_CONFIGS[model_key]
    regen_fn = cfg["regen_fn"]
    label = cfg["label"]
    btns = []
    for _i in range(MAX_SLOTS):
        _slot_id = f"{tab_prefix}_{_i}"
        # Hidden button: clicked programmatically by page JS, never rendered.
        _btn = gr.Button(render=False, elem_id=f"regen_btn_{_slot_id}")
        btns.append(_btn)
        print(f"[startup] registering regen handler for slot {_slot_id}")
        # Factory closure: binds the loop variables (_i, _slot_id, ...) at
        # definition time, avoiding Python's late-binding closure pitfall.
        # NOTE(review): _model_key is accepted but currently unused inside _do.
        def _make_regen(_si, _sid, _model_key, _label, _regen_fn):
            # Generator handler: first yield shows the "pending" placeholder,
            # second yield delivers the final video + waveform updates.
            def _do(seg_idx, state_json, *args):
                print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} "
                      f"state_json_len={len(state_json) if state_json else 0}")
                if not state_json:
                    # Nothing generated yet for this slot -- no-op update.
                    print(f"[regen {_label}] early-exit: state_json empty")
                    yield gr.update(), gr.update()
                    return
                # Serialize concurrent regens on the same slot.
                lock = _get_slot_lock(_sid)
                with lock:
                    state = json.loads(state_json)
                    pending_html = _build_regen_pending_html(
                        state["segments"], int(seg_idx), _sid, ""
                    )
                    yield gr.update(), gr.update(value=pending_html)
                    print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} β calling regen")
                    try:
                        # args[0] = video, args[1:] = model-specific params
                        vid, aud, new_meta_json, html = _regen_fn(
                            args[0], int(seg_idx), state_json, *args[1:], _sid,
                        )
                        print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} β done, vid={vid!r}")
                    except Exception as _e:
                        # Log for Space logs, then let Gradio surface the error.
                        print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} β ERROR: {_e}")
                        raise
                    yield gr.update(value=vid), gr.update(value=html)
            return _do
        _btn.click(
            fn=_make_regen(_i, _slot_id, model_key, label, regen_fn),
            inputs=[regen_seg_tb, regen_state_tb] + input_components,
            outputs=[slot_vids[_i], slot_waves[_i]],
            api_name=f"regen_{tab_prefix}_{_i}",
        )
    return btns
def _pad_outputs(outputs: list) -> list:
    """Flatten (video, audio, seg_meta) triples and pad to MAX_SLOTS * 3 with None.

    Each entry in *outputs* must be a (video_path, audio_path, seg_meta) tuple where
    seg_meta = {"segments": [...], "audio_path": str, "video_path": str,
                "sr": int, "model": str, "crossfade_s": float,
                "crossfade_db": float, "wav_paths": list[str]}
    """
    flat: list = []
    for slot in range(MAX_SLOTS):
        # Real triple for populated slots; a None-triple pads the rest.
        triple = outputs[slot] if slot < len(outputs) else (None, None, None)
        flat.extend(triple)
    return flat
| # ------------------------------------------------------------------ # | |
| # WaveSurfer waveform + segment marker HTML builder # | |
| # ------------------------------------------------------------------ # | |
def _build_regen_pending_html(segments: list, regen_seg_idx: int, slot_id: str,
                              hidden_input_id: str) -> str:
    """Return a waveform placeholder shown while a segment is being regenerated.

    Renders a dark bar with the active segment highlighted in amber + a spinner.

    Args:
        segments: list of (start_s, end_s) pairs for the slot's timeline.
        regen_seg_idx: index of the segment being regenerated (drawn amber,
            pulsing).
        slot_id: slot identifier; currently unused in the markup but kept
            for call-site symmetry with _build_waveform_html.
        hidden_input_id: unused; kept for call-site compatibility.

    Returns:
        Self-contained HTML string (inline <style> + segment bar + spinner).
    """
    # NOTE: the original computed json.dumps(segments) into an unused local
    # here; removed as dead code.
    seg_colors = [c.format(a="0.25") for c in SEG_COLORS]
    active_color = "rgba(255,180,0,0.55)"
    # Guard the percentage math below against a zero-length timeline.
    duration = (segments[-1][1] if segments else 1.0) or 1.0
    seg_divs = ""
    for i, seg in enumerate(segments):
        # Draw only the non-overlapping (unique) portion of each segment so that
        # overlapping windows don't visually bleed into adjacent segments.
        # Each segment owns the region from its own start up to the next segment's
        # start (or its own end for the final segment).
        seg_start = seg[0]
        seg_end = segments[i + 1][0] if i + 1 < len(segments) else seg[1]
        left_pct = seg_start / duration * 100
        width_pct = (seg_end - seg_start) / duration * 100
        color = active_color if i == regen_seg_idx else seg_colors[i % len(seg_colors)]
        extra = "border:2px solid #ffb300;animation:wf_pulse 0.8s ease-in-out infinite alternate;" if i == regen_seg_idx else ""
        seg_divs += (
            f'<div style="position:absolute;top:0;left:{left_pct:.2f}%;'
            f'width:{width_pct:.2f}%;height:100%;background:{color};{extra}">'
            f'<span style="color:rgba(255,255,255,0.7);font-size:10px;padding:2px 3px;">Seg {i+1}</span>'
            f'</div>'
        )
    spinner = (
        '<div style="position:absolute;top:50%;left:50%;transform:translate(-50%,-50%);'
        'display:flex;align-items:center;gap:6px;">'
        '<div style="width:14px;height:14px;border:2px solid #ffb300;'
        'border-top-color:transparent;border-radius:50%;'
        'animation:wf_spin 0.7s linear infinite;"></div>'
        f'<span style="color:#ffb300;font-size:12px;white-space:nowrap;">'
        f'Regenerating Seg {regen_seg_idx+1}β¦</span>'
        '</div>'
    )
    return f"""
<style>
@keyframes wf_pulse {{from{{opacity:0.5}}to{{opacity:1}}}}
@keyframes wf_spin {{to{{transform:rotate(360deg)}}}}
</style>
<div style="background:#1a1a1a;border-radius:8px;padding:10px;margin-top:6px;">
  <div style="position:relative;width:100%;height:80px;background:#1e1e2e;border-radius:4px;overflow:hidden;">
    {seg_divs}
    {spinner}
  </div>
  <div style="color:#888;font-size:11px;margin-top:6px;">Regenerating β please waitβ¦</div>
</div>
"""
def _build_waveform_html(audio_path: str, segments: list, slot_id: str,
                         hidden_input_id: str, state_json: str = "",
                         fn_index: int = -1, video_path: str = "",
                         crossfade_s: float = 0.0) -> str:
    """Return a self-contained HTML block with a Canvas waveform (display only),
    segment boundary markers, and a download link.

    Uses Web Audio API + Canvas -- no external libraries.
    The waveform is SILENT. The playhead tracks the Gradio <video> element
    in the same slot (located by walking the parent DOM, polled every 50 ms).

    Args:
        audio_path: path to the stitched WAV; served via the Gradio file API.
        segments: list of (start_s, end_s) pairs painted as colored regions.
        slot_id: slot identifier used in element ids (wf_container_<id>, ...).
        hidden_input_id: unused here; kept for call-site compatibility.
        state_json: slot metadata JSON, HTML-escaped into the container's
            data-state attribute so page JS can send it back on regen calls.
        fn_index: stored in the container's data-fn-index attribute (-1 = unset).
        video_path: when non-empty, a video download link is also rendered.
        crossfade_s: crossfade length in seconds; drawn as a hatched overlap
            zone between adjacent segment colors.

    Returns:
        HTML string, or a small "No audio yet." placeholder when *audio_path*
        is missing or does not exist on disk.
    """
    if not audio_path or not os.path.exists(audio_path):
        return "<p style='color:#888;font-size:12px'>No audio yet.</p>"
    # Serve audio via Gradio's file API instead of base64-encoding the entire
    # WAV inline. For a 25s stereo 44.1kHz track this saves ~5 MB per slot.
    audio_url = f"/gradio_api/file={audio_path}"
    segs_json = json.dumps(segments)
    seg_colors = [c.format(a="0.35") for c in SEG_COLORS]
    # NOTE: Gradio updates gr.HTML via innerHTML which does NOT execute <script> tags.
    # Solution: put the entire waveform (canvas + JS) inside an <iframe srcdoc="...">.
    # iframes always execute their scripts. The iframe posts messages to the parent for
    # segment-click events; the parent listens and fires the Gradio regen trigger.
    # For playhead sync, the iframe polls window.parent for a <video> element.
    iframe_inner = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
  * {{ margin:0; padding:0; box-sizing:border-box; }}
  body {{ background:#1a1a1a; overflow:hidden; }}
  #wrap {{ position:relative; width:100%; height:80px; }}
  canvas {{ display:block; }}
  #cv {{ position:absolute; top:0; left:0; width:100%; height:100%; }}
  #cvp {{ position:absolute; top:0; left:0; width:100%; height:100%; pointer-events:none; }}
</style>
</head>
<body>
<div id="wrap">
  <canvas id="cv"></canvas>
  <canvas id="cvp"></canvas>
</div>
<script>
(function() {{
  const SLOT_ID = '{slot_id}';
  const segments = {segs_json};
  const segColors = {json.dumps(seg_colors)};
  const crossfadeSec = {crossfade_s};
  let audioDuration = 0;

  // ββ Popup via postMessage to parent global listener βββββββββββββββββ
  // The parent page (Gradio) has a global window.addEventListener('message',...)
  // set up via gr.Blocks(js=...) that handles popup show/hide and regen trigger.
  function showPopup(idx, mx, my) {{
    console.log('[wf showPopup] slot='+SLOT_ID+' idx='+idx+' posting to parent');
    // Convert iframe-local coords to parent page coords
    try {{
      const fr = window.frameElement ? window.frameElement.getBoundingClientRect() : {{left:0,top:0}};
      window.parent.postMessage({{
        type:'wf_popup', action:'show',
        slot_id: SLOT_ID, seg_idx: idx,
        t0: segments[idx][0], t1: segments[idx][1],
        x: mx + fr.left, y: my + fr.top
      }}, '*');
      console.log('[wf showPopup] postMessage sent OK');
    }} catch(e) {{
      console.log('[wf showPopup] postMessage fallback, err='+e.message);
      window.parent.postMessage({{
        type:'wf_popup', action:'show',
        slot_id: SLOT_ID, seg_idx: idx,
        t0: segments[idx][0], t1: segments[idx][1],
        x: mx, y: my
      }}, '*');
    }}
  }}
  function hidePopup() {{
    window.parent.postMessage({{type:'wf_popup', action:'hide'}}, '*');
  }}

  // ββ Canvas waveform ββββββββββββββββββββββββββββββββββββββββββββββββ
  const cv  = document.getElementById('cv');
  const cvp = document.getElementById('cvp');
  const wrap= document.getElementById('wrap');

  function drawWaveform(channelData, duration) {{
    audioDuration = duration;
    const dpr = window.devicePixelRatio || 1;
    const W = wrap.getBoundingClientRect().width || window.innerWidth || 600;
    const H = 80;
    cv.width = W * dpr; cv.height = H * dpr;
    const ctx = cv.getContext('2d');
    ctx.scale(dpr, dpr);
    ctx.fillStyle = '#1e1e2e';
    ctx.fillRect(0, 0, W, H);
    segments.forEach(function(seg, idx) {{
      // Color boundary = midpoint of the crossfade zone = where the blend is
      // 50/50. This is also where the cut would land if crossfade were 0, and
      // where the listener perceptually hears the transition to the next segment.
      const x1 = (seg[0] / duration) * W;
      const xEnd = idx + 1 < segments.length
        ? ((segments[idx + 1][0] + crossfadeSec / 2) / duration) * W
        : (seg[1] / duration) * W;
      ctx.fillStyle = segColors[idx % segColors.length];
      ctx.fillRect(x1, 0, xEnd - x1, H);
      ctx.fillStyle = 'rgba(255,255,255,0.6)';
      ctx.font = '10px sans-serif';
      ctx.fillText('Seg '+(idx+1), x1+3, 12);
    }});
    const samples = channelData.length;
    const barW=2, gap=1, step=barW+gap;
    const numBars = Math.floor(W / step);
    const blockSz = Math.floor(samples / numBars);
    ctx.fillStyle = '#4a9eff';
    for (let i=0; i<numBars; i++) {{
      let max=0;
      const s=i*blockSz;
      for (let j=0; j<blockSz; j++) {{
        const v=Math.abs(channelData[s+j]||0);
        if (v>max) max=v;
      }}
      const barH=Math.max(1, max*H);
      ctx.fillRect(i*step, (H-barH)/2, barW, barH);
    }}
    segments.forEach(function(seg) {{
      [seg[0],seg[1]].forEach(function(t) {{
        const x=(t/duration)*W;
        ctx.strokeStyle='rgba(255,255,255,0.4)';
        ctx.lineWidth=1;
        ctx.beginPath(); ctx.moveTo(x,0); ctx.lineTo(x,H); ctx.stroke();
      }});
    }});
    // ββ Crossfade overlap indicators ββ
    // The color boundary is at segments[i+1][0] (= seg_i.end - crossfadeSec).
    // We centre the hatch on that edge: half the crossfade on each color side.
    if (crossfadeSec > 0 && segments.length > 1) {{
      for (let i = 0; i < segments.length - 1; i++) {{
        // Color edge = segments[i+1][0], hatch spans half on each side
        const edgeT = segments[i+1][0];
        const overlapStart = edgeT - crossfadeSec / 2;
        const overlapEnd   = edgeT + crossfadeSec / 2;
        const xL = (overlapStart / duration) * W;
        const xR = (overlapEnd / duration) * W;
        // Diagonal hatch pattern over the overlap zone
        ctx.save();
        ctx.beginPath();
        ctx.rect(xL, 0, xR - xL, H);
        ctx.clip();
        ctx.strokeStyle = 'rgba(255,255,255,0.35)';
        ctx.lineWidth = 1;
        const spacing = 6;
        for (let lx = xL - H; lx < xR + H; lx += spacing) {{
          ctx.beginPath();
          ctx.moveTo(lx, H);
          ctx.lineTo(lx + H, 0);
          ctx.stroke();
        }}
        ctx.restore();
      }}
    }}
    cv.onclick = function(e) {{
      const r=cv.getBoundingClientRect();
      const xRel=(e.clientX-r.left)/r.width;
      const tClick=xRel*duration;
      // Pick the segment whose unique (non-overlapping) region contains the click.
      // Each segment owns [seg[0], nextSeg[0]) visually; last segment owns [seg[0], seg[1]].
      let hit=-1;
      segments.forEach(function(seg,idx){{
        const uniqueEnd = idx + 1 < segments.length ? segments[idx+1][0] : seg[1];
        if (tClick >= seg[0] && tClick < uniqueEnd) hit = idx;
      }});
      console.log('[wf click] tClick='+tClick.toFixed(2)+' hit='+hit+' audioDuration='+audioDuration+' segments='+JSON.stringify(segments));
      if (hit>=0) showPopup(hit, e.clientX, e.clientY);
      else hidePopup();
    }};
  }}

  function drawPlayhead(progress) {{
    const dpr = window.devicePixelRatio || 1;
    const W = wrap.getBoundingClientRect().width || window.innerWidth || 600;
    const H = 80;
    if (cvp.width !== W*dpr) {{ cvp.width=W*dpr; cvp.height=H*dpr; }}
    const ctx = cvp.getContext('2d');
    ctx.clearRect(0,0,W*dpr,H*dpr);
    ctx.save();
    ctx.scale(dpr,dpr);
    const x=progress*W;
    ctx.strokeStyle='#fff';
    ctx.lineWidth=2;
    ctx.beginPath(); ctx.moveTo(x,0); ctx.lineTo(x,H); ctx.stroke();
    ctx.restore();
  }}

  // Poll parent for video time β find the video in the same wf_container slot
  function findSlotVideo() {{
    try {{
      const par = window.parent.document;
      // Walk up from our iframe to find wf_container_{slot_id}, then find its sibling video
      const container = par.getElementById('wf_container_{slot_id}');
      if (!container) return par.querySelector('video');
      // The video is inside the same gr.Group β walk up to find it
      let node = container.parentElement;
      while (node && node !== par.body) {{
        const v = node.querySelector('video');
        if (v) return v;
        node = node.parentElement;
      }}
      return null;
    }} catch(e) {{ return null; }}
  }}
  setInterval(function() {{
    const vid = findSlotVideo();
    if (vid && vid.duration && isFinite(vid.duration) && audioDuration > 0) {{
      drawPlayhead(vid.currentTime / vid.duration);
    }}
  }}, 50);

  // ββ Fetch + decode audio from Gradio file API ββββββββββββββββββββββ
  const audioUrl = '{audio_url}';
  fetch(audioUrl)
    .then(function(r) {{ return r.arrayBuffer(); }})
    .then(function(arrayBuf) {{
      const AudioCtx = window.AudioContext || window.webkitAudioContext;
      if (!AudioCtx) return;
      const tmpCtx = new AudioCtx({{sampleRate:44100}});
      tmpCtx.decodeAudioData(arrayBuf,
        function(ab) {{
          try {{ tmpCtx.close(); }} catch(e) {{}}
          function tryDraw() {{
            const W = wrap.getBoundingClientRect().width || window.innerWidth;
            if (W > 0) {{ drawWaveform(ab.getChannelData(0), ab.duration); }}
            else {{ setTimeout(tryDraw, 100); }}
          }}
          tryDraw();
        }},
        function(err) {{}}
      );
    }})
    .catch(function(e) {{}});
}})();
</script>
</body>
</html>"""
    # Escape for HTML attribute (srcdoc uses HTML entities)
    srcdoc = _html.escape(iframe_inner, quote=True)
    state_escaped = _html.escape(state_json or "", quote=True)
    # Outer wrapper: iframe + status bar + download links + segment label.
    return f"""
<div id="wf_container_{slot_id}"
     data-fn-index="{fn_index}"
     data-state="{state_escaped}"
     style="background:#1a1a1a;border-radius:8px;padding:10px;margin-top:6px;position:relative;">
  <div style="position:relative;width:100%;height:80px;">
    <iframe id="wf_iframe_{slot_id}"
            srcdoc="{srcdoc}"
            sandbox="allow-scripts allow-same-origin"
            style="width:100%;height:80px;border:none;border-radius:4px;display:block;"
            scrolling="no"></iframe>
  </div>
  <div style="display:flex;align-items:center;gap:8px;margin-top:6px;">
    <span id="wf_statusbar_{slot_id}" style="color:#888;font-size:11px;">Click a segment to regenerate | Playhead syncs to video</span>
    <a href="{audio_url}" download="audio_{slot_id}.wav"
       style="margin-left:auto;background:#333;color:#eee;border:1px solid #555;
              border-radius:4px;padding:3px 10px;font-size:12px;text-decoration:none;">
       ↓ Audio</a>{f'''
    <a href="/gradio_api/file={video_path}" download="video_{slot_id}.mp4"
       style="background:#333;color:#eee;border:1px solid #555;
              border-radius:4px;padding:3px 10px;font-size:12px;text-decoration:none;">
       ↓ Video</a>''' if video_path else ''}
  </div>
  <div id="wf_seglabel_{slot_id}"
       style="color:#aaa;font-size:11px;margin-top:4px;min-height:16px;"></div>
</div>
"""
def _make_output_slots(tab_prefix: str) -> tuple:
    """Build MAX_SLOTS output groups for one tab.

    Each slot has: video and waveform HTML. Regen is triggered via direct
    Gradio queue API calls from JS (no hidden trigger textboxes needed --
    DOM event dispatch is unreliable in Gradio 5 Svelte components). State
    JSON is embedded in the waveform HTML's data-state attribute and passed
    directly in the queue API payload.

    Returns (grps, vids, waveforms).
    """
    groups: list = []
    videos: list = []
    wave_panels: list = []
    placeholder = "<p style='color:#888;font-size:12px'>Generate audio to see waveform.</p>"
    for idx in range(MAX_SLOTS):
        sid = f"{tab_prefix}_{idx}"
        # Only the first slot is visible initially; the rest are revealed
        # by _update_slot_visibility.
        with gr.Group(visible=(idx == 0)) as grp:
            videos.append(gr.Video(label=f"Generation {idx+1} β Video",
                                   elem_id=f"slot_vid_{sid}",
                                   show_download_button=False))
            wave_panels.append(gr.HTML(value=placeholder,
                                       elem_id=f"slot_wave_{sid}"))
        groups.append(grp)
    return groups, videos, wave_panels
def _unpack_outputs(flat: list, n: int, tab_prefix: str) -> list:
    """Turn a flat _pad_outputs list into Gradio update lists.

    flat has MAX_SLOTS * 3 items: [vid0, aud0, meta0, vid1, aud1, meta1, ...]
    Returns updates for vids + waveforms only (NOT grps). Group visibility is
    handled separately via .then() to avoid Gradio 5 SSR 'Too many arguments'
    caused by mixing gr.Group updates with other outputs. State JSON is
    embedded in the waveform HTML data-state attribute so JS can read it when
    calling the Gradio queue API for regen.
    """
    n = int(n)
    empty_html = "<p style='color:#888;font-size:12px'>Generate audio to see waveform.</p>"
    vid_updates: list = []
    wave_updates: list = []
    for slot in range(MAX_SLOTS):
        vid_path, aud_path, meta = flat[slot * 3: slot * 3 + 3]
        vid_updates.append(gr.update(value=vid_path))
        if not (aud_path and meta):
            # Empty slot: show the placeholder text.
            wave_updates.append(gr.update(value=empty_html))
            continue
        sid = f"{tab_prefix}_{slot}"
        panel_html = _build_waveform_html(
            aud_path, meta["segments"], sid, "",
            state_json=json.dumps(meta),
            video_path=meta.get("video_path", ""),
            crossfade_s=float(meta.get("crossfade_s", 0)),
        )
        wave_updates.append(gr.update(value=panel_html))
    return vid_updates + wave_updates
def _on_video_upload_taro(video_file, num_steps, crossfade_s):
    """Recompute the TARO slot-count slider bounds after a video upload.

    Falls back to MAX_SLOTS when the duration probe or window math fails.
    """
    if video_file is None:
        return gr.update(maximum=MAX_SLOTS, value=1)
    try:
        clip_duration = get_video_duration(video_file)
        slot_cap = _taro_calc_max_samples(clip_duration, int(num_steps),
                                          float(crossfade_s))
    except Exception:
        slot_cap = MAX_SLOTS
    return gr.update(maximum=slot_cap, value=min(1, slot_cap))
def _update_slot_visibility(n):
    """Return visibility updates showing the first *n* slot groups only."""
    visible_count = int(n)
    return [gr.update(visible=(slot < visible_count))
            for slot in range(MAX_SLOTS)]
| # ================================================================== # | |
| # GRADIO UI # | |
| # ================================================================== # | |
# CSS injected via gr.Blocks(css=...): responsive slot videos, equal-width
# two-column layout, and hiding the built-in download button on the
# slot_vid_* video components (downloads go through the waveform panel).
_SLOT_CSS = """
/* Responsive video: fills column width, height auto from aspect ratio */
.gradio-video video {
    width: 100%;
    height: auto;
    max-height: 60vh;
    object-fit: contain;
}
/* Force two-column layout to stay equal-width */
.gradio-container .gradio-row > .gradio-column {
    flex: 1 1 0 !important;
    min-width: 0 !important;
    max-width: 50% !important;
}
/* Hide the built-in download button on output video slots β downloads are
   handled by the waveform panel links which always reflect the latest regen. */
[id^="slot_vid_"] .download-icon,
[id^="slot_vid_"] button[aria-label="Download"],
[id^="slot_vid_"] a[download] {
    display: none !important;
}
"""
| _GLOBAL_JS = """ | |
| () => { | |
| // Global postMessage handler for waveform iframe events. | |
| // Runs once on page load (Gradio js= parameter). | |
| // Handles: popup open/close relay, regen trigger via Gradio queue API. | |
| if (window._wf_global_listener) return; // already registered | |
| window._wf_global_listener = true; | |
| // ββ ZeroGPU quota attribution ββ | |
| // HF Spaces run inside an iframe on huggingface.co. Gradio's own JS client | |
| // gets ZeroGPU auth headers (x-zerogpu-token, x-zerogpu-uuid) by sending a | |
| // postMessage("zerogpu-headers") to the parent frame. The parent responds | |
| // with a Map of headers that must be included on queue/join calls. | |
| // We replicate this exact mechanism so our raw regen fetch() calls are | |
| // attributed to the logged-in user's Pro quota. | |
| function _fetchZerogpuHeaders() { | |
| return new Promise(function(resolve) { | |
| // Check if we're in an HF iframe with zerogpu support | |
| if (typeof window === 'undefined' || window.parent === window || !window.supports_zerogpu_headers) { | |
| console.log('[zerogpu] not in HF iframe or no zerogpu support'); | |
| resolve({}); | |
| return; | |
| } | |
| // Determine origin β same logic as Gradio's client | |
| var hostname = window.location.hostname; | |
| var hfhubdev = 'dev.spaces.huggingface.tech'; | |
| var origin = hostname.includes('.dev.') | |
| ? 'https://moon-' + hostname.split('.')[1] + '.' + hfhubdev | |
| : 'https://huggingface.co'; | |
| // Use MessageChannel just like Gradio's post_message helper | |
| var channel = new MessageChannel(); | |
| var done = false; | |
| channel.port1.onmessage = function(ev) { | |
| channel.port1.close(); | |
| done = true; | |
| var headers = ev.data; | |
| if (headers && typeof headers === 'object') { | |
| // Convert Map to plain object if needed | |
| var obj = {}; | |
| if (typeof headers.forEach === 'function') { | |
| headers.forEach(function(v, k) { obj[k] = v; }); | |
| } else { | |
| obj = headers; | |
| } | |
| console.log('[zerogpu] got headers from parent:', Object.keys(obj).join(', ')); | |
| resolve(obj); | |
| } else { | |
| resolve({}); | |
| } | |
| }; | |
| window.parent.postMessage('zerogpu-headers', origin, [channel.port2]); | |
| // Timeout: don't block regen if parent doesn't respond | |
| setTimeout(function() { if (!done) { done = true; channel.port1.close(); resolve({}); } }, 3000); | |
| }); | |
| } | |
| // Cache: api_name -> fn_index, built once from gradio_config.dependencies | |
| let _fnIndexCache = null; | |
| function getFnIndex(apiName) { | |
| if (!_fnIndexCache) { | |
| _fnIndexCache = {}; | |
| const deps = window.gradio_config && window.gradio_config.dependencies; | |
| if (deps) deps.forEach(function(d, i) { | |
| if (d.api_name) _fnIndexCache[d.api_name] = i; | |
| }); | |
| } | |
| return _fnIndexCache[apiName]; | |
| } | |
| // Read a component's current DOM value by elem_id. | |
| // For Number/Slider: reads the <input type="number"> or <input type="range">. | |
| // For Textbox/Radio: reads the <textarea> or checked <input type="radio">. | |
| // Returns null if not found. | |
| function readComponentValue(elemId) { | |
| const el = document.getElementById(elemId); | |
| if (!el) return null; | |
| const numInput = el.querySelector('input[type="number"]'); | |
| if (numInput) return parseFloat(numInput.value); | |
| const rangeInput = el.querySelector('input[type="range"]'); | |
| if (rangeInput) return parseFloat(rangeInput.value); | |
| const radio = el.querySelector('input[type="radio"]:checked'); | |
| if (radio) return radio.value; | |
| const ta = el.querySelector('textarea'); | |
| if (ta) return ta.value; | |
| const txt = el.querySelector('input[type="text"], input:not([type])'); | |
| if (txt) return txt.value; | |
| return null; | |
| } | |
// Fire regen for a given slot and segment by posting directly to the
// Gradio queue API β bypasses Svelte binding entirely.
// targetModel: 'taro' | 'mma' | 'hf' (which model to use for inference)
// If targetModel matches the slot's own prefix, uses the per-slot regen_* endpoint.
// Otherwise uses the shared xregen_* cross-model endpoint.
//
// BUGFIX: _regenInFlight[slot_id] is now released on EVERY failure path
// (missing state_json, unknown fn_index, missing event_id, fetch error).
// Previously an early return left the flag stuck at true, permanently
// blocking further regens for that slot.
function fireRegen(slot_id, seg_idx, targetModel) {
    // Block if a regen is already in-flight for this slot (rapid clicks).
    if (_regenInFlight[slot_id]) {
        console.log('[fireRegen] blocked β regen already in-flight for', slot_id);
        return;
    }
    _regenInFlight[slot_id] = true;
    const prefix = slot_id.split('_')[0]; // owning tab: 'taro'|'mma'|'hf'
    const slotNum = parseInt(slot_id.split('_')[1], 10);
    // Decide which endpoint to call
    const crossModel = (targetModel !== prefix);
    let apiName, data;
    // Read state_json from the waveform container data-state attribute
    const container = document.getElementById('wf_container_' + slot_id);
    const stateJson = container ? (container.getAttribute('data-state') || '') : '';
    if (!stateJson) {
        console.warn('[fireRegen] no state_json for slot', slot_id);
        _regenInFlight[slot_id] = false;  // release lock on early exit
        return;
    }
    if (!crossModel) {
        // ββ Same-model regen: per-slot endpoint, video passed as null ββ
        apiName = 'regen_' + prefix + '_' + slotNum;
        if (prefix === 'taro') {
            data = [seg_idx, stateJson, null,
                    readComponentValue('taro_seed'), readComponentValue('taro_cfg'),
                    readComponentValue('taro_steps'), readComponentValue('taro_mode'),
                    readComponentValue('taro_cf_dur'), readComponentValue('taro_cf_db')];
        } else if (prefix === 'mma') {
            data = [seg_idx, stateJson, null,
                    readComponentValue('mma_prompt'), readComponentValue('mma_neg'),
                    readComponentValue('mma_seed'), readComponentValue('mma_cfg'),
                    readComponentValue('mma_steps'),
                    readComponentValue('mma_cf_dur'), readComponentValue('mma_cf_db')];
        } else {
            data = [seg_idx, stateJson, null,
                    readComponentValue('hf_prompt'), readComponentValue('hf_neg'),
                    readComponentValue('hf_seed'), readComponentValue('hf_guidance'),
                    readComponentValue('hf_steps'), readComponentValue('hf_size'),
                    readComponentValue('hf_cf_dur'), readComponentValue('hf_cf_db')];
        }
    } else {
        // ββ Cross-model regen: shared xregen_* endpoint ββ
        // slot_id is passed so the server knows which slot's state to splice into.
        // UI params are read from the target model's tab inputs.
        if (targetModel === 'taro') {
            apiName = 'xregen_taro';
            data = [seg_idx, stateJson, slot_id,
                    readComponentValue('taro_seed'), readComponentValue('taro_cfg'),
                    readComponentValue('taro_steps'), readComponentValue('taro_mode'),
                    readComponentValue('taro_cf_dur'), readComponentValue('taro_cf_db')];
        } else if (targetModel === 'mma') {
            apiName = 'xregen_mmaudio';
            data = [seg_idx, stateJson, slot_id,
                    readComponentValue('mma_prompt'), readComponentValue('mma_neg'),
                    readComponentValue('mma_seed'), readComponentValue('mma_cfg'),
                    readComponentValue('mma_steps'),
                    readComponentValue('mma_cf_dur'), readComponentValue('mma_cf_db')];
        } else {
            apiName = 'xregen_hunyuan';
            data = [seg_idx, stateJson, slot_id,
                    readComponentValue('hf_prompt'), readComponentValue('hf_neg'),
                    readComponentValue('hf_seed'), readComponentValue('hf_guidance'),
                    readComponentValue('hf_steps'), readComponentValue('hf_size'),
                    readComponentValue('hf_cf_dur'), readComponentValue('hf_cf_db')];
        }
    }
    console.log('[fireRegen] calling api', apiName, 'seg', seg_idx);
    // Snapshot current waveform HTML + video src before mutating anything,
    // so we can restore on error (e.g. quota exceeded).
    var _preRegenWaveHtml = null;
    var _preRegenVideoSrc = null;
    var waveElSnap = document.getElementById('slot_wave_' + slot_id);
    if (waveElSnap) _preRegenWaveHtml = waveElSnap.innerHTML;
    var vidElSnap = document.getElementById('slot_vid_' + slot_id);
    if (vidElSnap) { var vSnap = vidElSnap.querySelector('video'); if (vSnap) _preRegenVideoSrc = vSnap.getAttribute('src'); }
    // Show spinner immediately
    const lbl = document.getElementById('wf_seglabel_' + slot_id);
    if (lbl) lbl.textContent = 'Regenerating Seg ' + (seg_idx + 1) + '...';
    const fnIndex = getFnIndex(apiName);
    if (fnIndex === undefined) {
        console.warn('[fireRegen] fn_index not found for api_name:', apiName);
        if (lbl) lbl.textContent = 'Error β see console';  // undo the spinner label
        _regenInFlight[slot_id] = false;  // release lock on early exit
        return;
    }
    // Get ZeroGPU auth headers from the HF parent frame (same mechanism
    // Gradio's own JS client uses), then fire the regen queue/join call.
    // Falls back to user-supplied HF token if zerogpu headers aren't available.
    _fetchZerogpuHeaders().then(function(zerogpuHeaders) {
        var regenHeaders = {'Content-Type': 'application/json'};
        var hasZerogpu = zerogpuHeaders && Object.keys(zerogpuHeaders).length > 0;
        if (hasZerogpu) {
            // Merge zerogpu headers (x-zerogpu-token, x-zerogpu-uuid)
            for (var k in zerogpuHeaders) { regenHeaders[k] = zerogpuHeaders[k]; }
            console.log('[fireRegen] using zerogpu headers from parent frame');
        } else {
            console.warn('[fireRegen] no zerogpu headers available β may use anonymous quota');
        }
        fetch('/gradio_api/queue/join', {
            method: 'POST',
            credentials: 'include',
            headers: regenHeaders,
            body: JSON.stringify({
                data: data,
                fn_index: fnIndex,
                api_name: '/' + apiName,
                session_hash: window.__gradio_session_hash__,
                event_data: null,
                trigger_id: null
            })
        }).then(function(r) { return r.json(); }).then(function(j) {
            if (!j.event_id) {
                console.error('[fireRegen] no event_id:', j);
                _regenInFlight[slot_id] = false;  // release lock: job never queued
                return;
            }
            console.log('[fireRegen] queued, event_id:', j.event_id);
            // _listenAndApply releases _regenInFlight when the event completes.
            _listenAndApply(j.event_id, slot_id, seg_idx, _preRegenWaveHtml, _preRegenVideoSrc);
        }).catch(function(e) {
            console.error('[fireRegen] fetch error:', e);
            _regenInFlight[slot_id] = false;  // release lock on network failure
            if (lbl) lbl.textContent = 'Error β see console';
            var sb = document.getElementById('wf_statusbar_' + slot_id);
            if (sb) { sb.style.color = '#e05252'; sb.textContent = '\u26a0 Request failed: ' + e.message; }
        });
    });
}
// (SSE note for the listener below: for regen handlers, output[0] is the
// video update and output[1] the waveform HTML update.)
// Point the slot's <video> element at newSrc and reload it.
// No-op (returns true) when the src is already correct; returns false when
// the slot's video element cannot be located in the DOM.
function _applyVideoSrc(slot_id, newSrc) {
    var wrapper = document.getElementById('slot_vid_' + slot_id);
    var video = wrapper ? wrapper.querySelector('video') : null;
    if (!video) return false;
    if (video.getAttribute('src') === newSrc) return true; // already correct
    // Set both the attribute and the property, then force a reload.
    video.setAttribute('src', newSrc);
    video.src = newSrc;
    video.load();
    console.log('[_applyVideoSrc] applied src to', 'slot_vid_' + slot_id, 'src:', newSrc.slice(-40));
    return true;
}
// Toast notification styled like ZeroGPU quota warnings.
// Errors are red and linger 8 s; success is green and fades after 3 s.
// pointer-events:none keeps the toast from swallowing clicks.
function _showRegenToast(message, isError) {
    var toast = document.createElement('div');
    var bg = isError ? '#7a1c1c' : '#1c4a1c';
    var border = isError ? '#c0392b' : '#27ae60';
    toast.style.cssText = 'position:fixed;bottom:24px;left:50%;transform:translateX(-50%);' +
        'z-index:2147483647;padding:12px 20px;border-radius:8px;font-family:sans-serif;' +
        'font-size:13px;max-width:520px;text-align:center;box-shadow:0 4px 20px rgba(0,0,0,.6);' +
        'background:' + bg + ';color:#fff;' +
        'border:1px solid ' + border + ';' +
        'pointer-events:none;';
    toast.textContent = message;
    document.body.appendChild(toast);
    var visibleMs = isError ? 8000 : 3000;
    setTimeout(function() {
        // Fade out, then detach once the transition has finished.
        toast.style.transition = 'opacity 0.5s';
        toast.style.opacity = '0';
        setTimeout(function() { if (toast.parentNode) toast.parentNode.removeChild(toast); }, 600);
    }, visibleMs);
}
// Subscribe to the Gradio SSE queue stream for one queued event and apply
// its streamed outputs directly to the slot's DOM elements.
// On error, restores the pre-regen waveform HTML / video src snapshots.
// Always releases _regenInFlight[slot_id] on completion or stream error.
function _listenAndApply(eventId, slot_id, seg_idx, preRegenWaveHtml, preRegenVideoSrc) {
    // Last video URL seen in any streamed chunk; applied once on completion.
    var _pendingVideoSrc = null;
    const es = new EventSource('/gradio_api/queue/data?session_hash=' + window.__gradio_session_hash__);
    es.onmessage = function(e) {
        var msg;
        try { msg = JSON.parse(e.data); } catch(_) { return; }
        // The session stream multiplexes all events; keep only ours.
        if (msg.event_id !== eventId) return;
        if (msg.msg === 'process_generating' || msg.msg === 'process_completed') {
            var out = msg.output;
            if (out && out.data) {
                // Regen handlers return [video update, waveform HTML update].
                var vidUpdate = out.data[0];
                var waveUpdate = out.data[1];
                var newSrc = null;
                // The video URL arrives in several payload shapes depending on
                // the Gradio version / update form β probe them in order.
                if (vidUpdate) {
                    if (vidUpdate.value && vidUpdate.value.video && vidUpdate.value.video.url) newSrc = vidUpdate.value.video.url;
                    else if (vidUpdate.video && vidUpdate.video.url) newSrc = vidUpdate.video.url;
                    else if (vidUpdate.value && vidUpdate.value.url) newSrc = vidUpdate.value.url;
                    else if (typeof vidUpdate.value === 'string') newSrc = vidUpdate.value;
                    else if (vidUpdate.url) newSrc = vidUpdate.url;
                }
                if (newSrc) _pendingVideoSrc = newSrc;
                var waveHtml = null;
                if (waveUpdate) {
                    if (typeof waveUpdate === 'string') waveHtml = waveUpdate;
                    else if (waveUpdate.value && typeof waveUpdate.value === 'string') waveHtml = waveUpdate.value;
                }
                if (waveHtml) {
                    var waveEl = document.getElementById('slot_wave_' + slot_id);
                    if (waveEl) {
                        // Prefer Gradio's inner wrapper so its styling survives.
                        var inner = waveEl.querySelector('.prose') || waveEl.querySelector('div');
                        if (inner) inner.innerHTML = waveHtml;
                        else waveEl.innerHTML = waveHtml;
                    }
                }
            }
            if (msg.msg === 'process_completed') {
                es.close();
                _regenInFlight[slot_id] = false;
                var errMsg = msg.output && msg.output.error;
                var hadError = !!errMsg;
                console.log('[fireRegen] completed for', slot_id, 'error:', hadError, errMsg || '');
                var lbl = document.getElementById('wf_seglabel_' + slot_id);
                if (hadError) {
                    var toastMsg = typeof errMsg === 'string' ? errMsg : JSON.stringify(errMsg);
                    // Restore previous waveform HTML and video src
                    if (preRegenWaveHtml !== null) {
                        var waveEl2 = document.getElementById('slot_wave_' + slot_id);
                        if (waveEl2) waveEl2.innerHTML = preRegenWaveHtml;
                    }
                    if (preRegenVideoSrc !== null) {
                        var vidElR = document.getElementById('slot_vid_' + slot_id);
                        if (vidElR) { var vR = vidElR.querySelector('video'); if (vR) { vR.setAttribute('src', preRegenVideoSrc); vR.src = preRegenVideoSrc; vR.load(); } }
                    }
                    // Update the statusbar (query after restore so we get the freshly-restored element)
                    var isAbort = toastMsg.toLowerCase().indexOf('aborted') !== -1;
                    var isTimeout = toastMsg.toLowerCase().indexOf('timeout') !== -1;
                    var failMsg = isAbort || isTimeout
                        ? '\u26a0 GPU cold-start β segment unchanged, try again'
                        : '\u26a0 Regen failed β segment unchanged';
                    var statusBar = document.getElementById('wf_statusbar_' + slot_id);
                    if (statusBar) {
                        statusBar.style.color = '#e05252';
                        statusBar.textContent = failMsg;
                        setTimeout(function() { statusBar.style.color = '#888'; statusBar.textContent = 'Click a segment to regenerate \u00a0|\u00a0 Playhead syncs to video'; }, 8000);
                    }
                } else {
                    if (lbl) lbl.textContent = 'Done';
                    var src = _pendingVideoSrc;
                    if (src) {
                        // Gradio may re-render the <video> after we set src; re-apply
                        // several times and watch DOM mutations for ~2 s to win the race.
                        _applyVideoSrc(slot_id, src);
                        setTimeout(function() { _applyVideoSrc(slot_id, src); }, 50);
                        setTimeout(function() { _applyVideoSrc(slot_id, src); }, 300);
                        setTimeout(function() { _applyVideoSrc(slot_id, src); }, 800);
                        var vidEl = document.getElementById('slot_vid_' + slot_id);
                        if (vidEl) {
                            var obs = new MutationObserver(function() { _applyVideoSrc(slot_id, src); });
                            obs.observe(vidEl, {subtree: true, attributes: true, attributeFilter: ['src'], childList: true});
                            setTimeout(function() { obs.disconnect(); }, 2000);
                        }
                    }
                }
            }
        }
        if (msg.msg === 'close_stream') { es.close(); }
    };
    // Stream error: stop listening and release the in-flight lock.
    es.onerror = function() { es.close(); _regenInFlight[slot_id] = false; };
}
// Track in-flight regen per slot (slot_id -> bool) β prevents queuing
// multiple jobs from rapid clicks on the same waveform.
var _regenInFlight = {};
// Shared popup element, created once by ensurePopup() and reused across all slots.
let _popup = null;
// Slot/segment the popup is currently targeting; consumed by its button handlers.
let _pendingSlot = null, _pendingIdx = null;
// Lazily build the shared "regenerate with which model?" popup and memoize it.
// One popup serves all slots; fireRegen is invoked with the slot/segment
// captured in _pendingSlot/_pendingIdx at the time a model button is clicked.
function ensurePopup() {
    if (_popup) return _popup;
    _popup = document.createElement('div');
    _popup.style.cssText = 'display:none;position:fixed;z-index:99999;' +
        'background:#2a2a2a;border:1px solid #555;border-radius:6px;' +
        'padding:8px 12px;box-shadow:0 4px 16px rgba(0,0,0,.5);font-family:sans-serif;';
    var btnStyle = 'color:#fff;border:none;border-radius:4px;padding:5px 10px;' +
        'font-size:11px;cursor:pointer;flex:1;';
    // Label row plus one button per model (ids _wf_popup_{taro,mma,hf}).
    _popup.innerHTML =
        '<div id="_wf_popup_lbl" style="color:#ccc;font-size:11px;margin-bottom:6px;white-space:nowrap;"></div>' +
        '<div style="display:flex;gap:5px;">' +
        '<button id="_wf_popup_taro" style="background:#1d6fa5;' + btnStyle + '">⟳ TARO</button>' +
        '<button id="_wf_popup_mma" style="background:#2d7a4a;' + btnStyle + '">⟳ MMAudio</button>' +
        '<button id="_wf_popup_hf" style="background:#7a3d8c;' + btnStyle + '">⟳ Hunyuan</button>' +
        '</div>';
    document.body.appendChild(_popup);
    ['taro','mma','hf'].forEach(function(model) {
        document.getElementById('_wf_popup_' + model).onclick = function(e) {
            // Stop the click from bubbling to the document-level dismiss handler.
            e.stopPropagation();
            // Capture the target before hidePopup() clears the pending fields.
            var slot = _pendingSlot, idx = _pendingIdx;
            hidePopup();
            if (slot !== null && idx !== null) fireRegen(slot, idx, model);
        };
    });
    // Use bubble phase (false) so stopPropagation() on the button click prevents this from firing
    document.addEventListener('click', function() { hidePopup(); }, false);
    return _popup;
}
// Dismiss the shared regen popup and clear its pending slot/segment target.
// Safe to call before ensurePopup() has ever run.
function hidePopup() {
    if (_popup !== null) {
        _popup.style.display = 'none';
    }
    _pendingSlot = null;
    _pendingIdx = null;
}
// Bridge from the waveform iframes/widgets: they postMessage
// {type:'wf_popup', action:'show'|'hide', slot_id, seg_idx, t0, t1, x, y}
// and this listener positions/shows (or hides) the shared popup.
window.addEventListener('message', function(e) {
    const d = e.data;
    console.log('[global msg] received type=' + (d && d.type) + ' action=' + (d && d.action));
    if (!d || d.type !== 'wf_popup') return;
    const p = ensurePopup();
    if (d.action === 'hide') { hidePopup(); return; }
    // action === 'show'
    _pendingSlot = d.slot_id;
    _pendingIdx = d.seg_idx;
    const lbl = document.getElementById('_wf_popup_lbl');
    // Label shows 1-based segment number and its time span.
    if (lbl) lbl.textContent = 'Seg ' + (d.seg_idx + 1) +
        ' (' + d.t0.toFixed(2) + 's \u2013 ' + d.t1.toFixed(2) + 's)';
    p.style.display = 'block';
    p.style.left = (d.x + 10) + 'px';
    p.style.top = (d.y + 10) + 'px';
    // After layout, nudge the popup back inside the viewport if it overflows.
    requestAnimationFrame(function() {
        const r = p.getBoundingClientRect();
        if (r.right > window.innerWidth - 8) p.style.left = (window.innerWidth - r.width - 8) + 'px';
        if (r.bottom > window.innerHeight - 8) p.style.top = (window.innerHeight - r.height - 8) + 'px';
    });
});
| } | |
| """ | |
# Top-level UI: three model tabs sharing the same output-slot layout.
# _SLOT_CSS / _GLOBAL_JS are defined earlier in this file.
with gr.Blocks(title="Generate Audio for Video", css=_SLOT_CSS, js=_GLOBAL_JS) as demo:
    # Landing header with a model-selection cheat sheet (Markdown table).
    gr.Markdown(
        "# Generate Audio for Video\n"
        "Choose a model and upload a video to generate synchronized audio.\n\n"
        "| Model | Best for | Avoid for |\n"
        "|-------|----------|-----------|\n"
        "| **TARO** | Natural, physics-driven impacts β footsteps, collisions, water, wind, crackling fire. Excels when the sound is tightly coupled to visible motion without needing a text description. | Dialogue, music, or complex layered soundscapes where semantic context matters. |\n"
        "| **MMAudio** | Mixed scenes where you want both visual grounding *and* semantic control via a text prompt β e.g. a busy street scene where you want to emphasize the rain rather than the traffic. Great for ambient textures and nuanced sound design. | Pure impact/foley shots where TARO's motion-coupling would be sharper, or cinematic music beds. |\n"
        "| **HunyuanFoley** | Cinematic foley requiring high fidelity and explicit creative direction β dramatic SFX, layered environmental design, or any scene where you have a clear written description of the desired sound palette. | Quick one-shot clips where you don't want to write a prompt, or raw impact sounds where timing precision matters more than richness. |"
    )
    with gr.Tabs():
        # ---------------------------------------------------------- #
        # Tab 1 - TARO                                               #
        # ---------------------------------------------------------- #
        with gr.Tab("TARO"):
            with gr.Row():
                with gr.Column(scale=1):
                    # elem_ids are read directly from the DOM by the regen JS
                    # (readComponentValue), so they must stay stable.
                    taro_video = gr.Video(label="Input Video")
                    taro_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0, elem_id="taro_seed")
                    taro_cfg = gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8.0, step=0.5, elem_id="taro_cfg")
                    taro_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1, elem_id="taro_steps")
                    taro_mode = gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde", elem_id="taro_mode")
                    taro_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=4, value=2, step=0.1, elem_id="taro_cf_dur")
                    taro_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3", elem_id="taro_cf_db")
                    taro_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
                    taro_btn = gr.Button("Generate", variant="primary")
                with gr.Column(scale=1):
                    (taro_slot_grps, taro_slot_vids,
                     taro_slot_waves) = _make_output_slots("taro")
            # Hidden regen plumbing - render=False so no DOM element is created,
            # avoiding Gradio's "Too many arguments" Svelte validation error.
            # JS passes values directly via queue/join data array at the correct
            # positional index (these show up as inputs to the fn but have no DOM).
            taro_regen_seg = gr.Textbox(value="0", render=False)
            taro_regen_state = gr.Textbox(value="", render=False)
            # Any change to the video or the timing-relevant params re-derives
            # the suggested generation count.
            for trigger in [taro_video, taro_steps, taro_cf_dur]:
                trigger.change(
                    fn=_on_video_upload_taro,
                    inputs=[taro_video, taro_steps, taro_cf_dur],
                    outputs=[taro_samples],
                )
            taro_samples.change(
                fn=_update_slot_visibility,
                inputs=[taro_samples],
                outputs=taro_slot_grps,
            )
            def _run_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n):
                # Thin adapter: run generation, then fan the flat result list
                # out across the per-slot output components.
                flat = generate_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n)
                return _unpack_outputs(flat, n, "taro")
            # Split group visibility into a separate .then() to avoid Gradio 5 SSR
            # "Too many arguments" caused by including gr.Group in mixed output lists.
            (taro_btn.click(
                fn=_run_taro,
                inputs=[taro_video, taro_seed, taro_cfg, taro_steps, taro_mode,
                        taro_cf_dur, taro_cf_db, taro_samples],
                outputs=taro_slot_vids + taro_slot_waves,
            ).then(
                fn=_update_slot_visibility,
                inputs=[taro_samples],
                outputs=taro_slot_grps,
            ))
            # Per-slot regen handlers - JS calls /gradio_api/queue/join with
            # fn_index (by api_name) + data=[seg_idx, state_json, video, ...params].
            taro_regen_btns = _register_regen_handlers(
                "taro", "taro", taro_regen_seg, taro_regen_state,
                [taro_video, taro_seed, taro_cfg, taro_steps,
                 taro_mode, taro_cf_dur, taro_cf_db],
                taro_slot_vids, taro_slot_waves,
            )
| # ---------------------------------------------------------- # | |
| # Tab 2 β MMAudio # | |
| # ---------------------------------------------------------- # | |
| with gr.Tab("MMAudio"): | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| mma_video = gr.Video(label="Input Video") | |
| mma_prompt = gr.Textbox(label="Prompt", placeholder="e.g. footsteps on gravel", elem_id="mma_prompt") | |
| mma_neg = gr.Textbox(label="Negative Prompt", value="music", placeholder="music, speech", elem_id="mma_neg") | |
| mma_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0, elem_id="mma_seed") | |
| mma_cfg = gr.Slider(label="CFG Strength", minimum=1, maximum=10, value=4.5, step=0.5, elem_id="mma_cfg") | |
| mma_steps = gr.Slider(label="Steps", minimum=10, maximum=50, value=25, step=1, elem_id="mma_steps") | |
| mma_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=4, value=2, step=0.1, elem_id="mma_cf_dur") | |
| mma_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3", elem_id="mma_cf_db") | |
| mma_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1) | |
| mma_btn = gr.Button("Generate", variant="primary") | |
| with gr.Column(scale=1): | |
| (mma_slot_grps, mma_slot_vids, | |
| mma_slot_waves) = _make_output_slots("mma") | |
| # Hidden regen plumbing β render=False so no DOM element is created, | |
| # avoiding Gradio's "Too many arguments" Svelte validation error. | |
| mma_regen_seg = gr.Textbox(value="0", render=False) | |
| mma_regen_state = gr.Textbox(value="", render=False) | |
| mma_samples.change( | |
| fn=_update_slot_visibility, | |
| inputs=[mma_samples], | |
| outputs=mma_slot_grps, | |
| ) | |
| def _run_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n): | |
| flat = generate_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n) | |
| return _unpack_outputs(flat, n, "mma") | |
| (mma_btn.click( | |
| fn=_run_mmaudio, | |
| inputs=[mma_video, mma_prompt, mma_neg, mma_seed, | |
| mma_cfg, mma_steps, mma_cf_dur, mma_cf_db, mma_samples], | |
| outputs=mma_slot_vids + mma_slot_waves, | |
| ).then( | |
| fn=_update_slot_visibility, | |
| inputs=[mma_samples], | |
| outputs=mma_slot_grps, | |
| )) | |
| mma_regen_btns = _register_regen_handlers( | |
| "mma", "mmaudio", mma_regen_seg, mma_regen_state, | |
| [mma_video, mma_prompt, mma_neg, mma_seed, | |
| mma_cfg, mma_steps, mma_cf_dur, mma_cf_db], | |
| mma_slot_vids, mma_slot_waves, | |
| ) | |
| # ---------------------------------------------------------- # | |
| # Tab 3 β HunyuanVideoFoley # | |
| # ---------------------------------------------------------- # | |
| with gr.Tab("HunyuanFoley"): | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| hf_video = gr.Video(label="Input Video") | |
| hf_prompt = gr.Textbox(label="Prompt", placeholder="e.g. rain hitting a metal roof", elem_id="hf_prompt") | |
| hf_neg = gr.Textbox(label="Negative Prompt", value="noisy, harsh", elem_id="hf_neg") | |
| hf_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0, elem_id="hf_seed") | |
| hf_guidance = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, value=4.5, step=0.5, elem_id="hf_guidance") | |
| hf_steps = gr.Slider(label="Steps", minimum=10, maximum=100, value=50, step=5, elem_id="hf_steps") | |
| hf_size = gr.Radio(label="Model Size", choices=["xl", "xxl"], value="xxl", elem_id="hf_size") | |
| hf_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=4, value=2, step=0.1, elem_id="hf_cf_dur") | |
| hf_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3", elem_id="hf_cf_db") | |
| hf_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1) | |
| hf_btn = gr.Button("Generate", variant="primary") | |
| with gr.Column(scale=1): | |
| (hf_slot_grps, hf_slot_vids, | |
| hf_slot_waves) = _make_output_slots("hf") | |
| # Hidden regen plumbing β render=False so no DOM element is created, | |
| # avoiding Gradio's "Too many arguments" Svelte validation error. | |
| hf_regen_seg = gr.Textbox(value="0", render=False) | |
| hf_regen_state = gr.Textbox(value="", render=False) | |
| hf_samples.change( | |
| fn=_update_slot_visibility, | |
| inputs=[hf_samples], | |
| outputs=hf_slot_grps, | |
| ) | |
| def _run_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n): | |
| flat = generate_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n) | |
| return _unpack_outputs(flat, n, "hf") | |
| (hf_btn.click( | |
| fn=_run_hunyuan, | |
| inputs=[hf_video, hf_prompt, hf_neg, hf_seed, | |
| hf_guidance, hf_steps, hf_size, hf_cf_dur, hf_cf_db, hf_samples], | |
| outputs=hf_slot_vids + hf_slot_waves, | |
| ).then( | |
| fn=_update_slot_visibility, | |
| inputs=[hf_samples], | |
| outputs=hf_slot_grps, | |
| )) | |
| hf_regen_btns = _register_regen_handlers( | |
| "hf", "hunyuan", hf_regen_seg, hf_regen_state, | |
| [hf_video, hf_prompt, hf_neg, hf_seed, | |
| hf_guidance, hf_steps, hf_size, hf_cf_dur, hf_cf_db], | |
| hf_slot_vids, hf_slot_waves, | |
| ) | |
    # ---- Browser-safe transcode on upload ----
    # Gradio serves the original uploaded file to the browser preview widget,
    # so H.265 sources show as blank. We re-encode to H.264 on upload and feed
    # the result back so the preview plays. mux_video_audio already re-encodes
    # to H.264 during generation, so no double-conversion conflict.
    taro_video.upload(fn=_transcode_for_browser, inputs=[taro_video], outputs=[taro_video])
    mma_video.upload(fn=_transcode_for_browser, inputs=[mma_video], outputs=[mma_video])
    hf_video.upload(fn=_transcode_for_browser, inputs=[hf_video], outputs=[hf_video])
    # ---- Cross-tab video sync ----
    # Changing the video on any tab mirrors it to the other two tabs.
    # NOTE(review): each .change here triggers the other tabs' .change handlers
    # in turn; presumably Gradio settles because the mirrored value is unchanged
    # on the second pass β confirm there is no update loop.
    _sync = lambda v: (gr.update(value=v), gr.update(value=v))
    taro_video.change(fn=_sync, inputs=[taro_video], outputs=[mma_video, hf_video])
    mma_video.change(fn=_sync, inputs=[mma_video], outputs=[taro_video, hf_video])
    hf_video.change(fn=_sync, inputs=[hf_video], outputs=[taro_video, mma_video])
    # ---- Cross-model regen endpoints ----
    # render=False inputs/outputs: no DOM elements created, no SSR validation impact.
    # JS calls these via /gradio_api/queue/join using the api_name and applies
    # the returned video+waveform directly to the target slot's DOM elements.
    # These Textbox components exist only to define the endpoints' positional
    # input signatures; their default values are placeholders that JS overrides.
    _xr_seg = gr.Textbox(value="0", render=False)
    _xr_state = gr.Textbox(value="", render=False)
    _xr_slot_id = gr.Textbox(value="", render=False)
    # Dummy outputs for xregen events: must be real rendered components so Gradio
    # can look them up in session state during postprocess_data. The JS listener
    # (_listenAndApply) applies the returned video/HTML directly to the correct
    # slot's DOM elements and ignores Gradio's own output routing, so these
    # slot-0 components simply act as sinks - their displayed value is overwritten
    # by the real JS update immediately after.
    _xr_dummy_vid = taro_slot_vids[0]
    _xr_dummy_wave = taro_slot_waves[0]
    # TARO cross-model regen inputs: seg_idx, state_json, slot_id, seed, cfg, steps, mode, cf_dur, cf_db
    _xr_taro_seed = gr.Textbox(value="-1", render=False)
    _xr_taro_cfg = gr.Textbox(value="7.5", render=False)
    _xr_taro_steps = gr.Textbox(value="25", render=False)
    _xr_taro_mode = gr.Textbox(value="sde", render=False)
    _xr_taro_cfd = gr.Textbox(value="2", render=False)
    _xr_taro_cfdb = gr.Textbox(value="3", render=False)
    # Unrendered button: registers the endpoint without any visible trigger.
    gr.Button(render=False).click(
        fn=xregen_taro,
        inputs=[_xr_seg, _xr_state, _xr_slot_id,
                _xr_taro_seed, _xr_taro_cfg, _xr_taro_steps,
                _xr_taro_mode, _xr_taro_cfd, _xr_taro_cfdb],
        outputs=[_xr_dummy_vid, _xr_dummy_wave],
        api_name="xregen_taro",
    )
    # MMAudio cross-model regen inputs: seg_idx, state_json, slot_id, prompt, neg, seed, cfg, steps, cf_dur, cf_db
    _xr_mma_prompt = gr.Textbox(value="", render=False)
    _xr_mma_neg = gr.Textbox(value="", render=False)
    _xr_mma_seed = gr.Textbox(value="-1", render=False)
    _xr_mma_cfg = gr.Textbox(value="4.5", render=False)
    _xr_mma_steps = gr.Textbox(value="25", render=False)
    _xr_mma_cfd = gr.Textbox(value="2", render=False)
    _xr_mma_cfdb = gr.Textbox(value="3", render=False)
    gr.Button(render=False).click(
        fn=xregen_mmaudio,
        inputs=[_xr_seg, _xr_state, _xr_slot_id,
                _xr_mma_prompt, _xr_mma_neg, _xr_mma_seed,
                _xr_mma_cfg, _xr_mma_steps, _xr_mma_cfd, _xr_mma_cfdb],
        outputs=[_xr_dummy_vid, _xr_dummy_wave],
        api_name="xregen_mmaudio",
    )
    # HunyuanFoley cross-model regen inputs: seg_idx, state_json, slot_id, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db
    _xr_hf_prompt = gr.Textbox(value="", render=False)
    _xr_hf_neg = gr.Textbox(value="", render=False)
    _xr_hf_seed = gr.Textbox(value="-1", render=False)
    _xr_hf_guide = gr.Textbox(value="4.5", render=False)
    _xr_hf_steps = gr.Textbox(value="50", render=False)
    _xr_hf_size = gr.Textbox(value="xxl", render=False)
    _xr_hf_cfd = gr.Textbox(value="2", render=False)
    _xr_hf_cfdb = gr.Textbox(value="3", render=False)
    gr.Button(render=False).click(
        fn=xregen_hunyuan,
        inputs=[_xr_seg, _xr_state, _xr_slot_id,
                _xr_hf_prompt, _xr_hf_neg, _xr_hf_seed,
                _xr_hf_guide, _xr_hf_steps, _xr_hf_size,
                _xr_hf_cfd, _xr_hf_cfdb],
        outputs=[_xr_dummy_vid, _xr_dummy_wave],
        api_name="xregen_hunyuan",
    )
# NOTE: ZeroGPU quota attribution is handled via postMessage("zerogpu-headers")
# to the HF parent frame - the same mechanism Gradio's own JS client uses.
# This replaced the old x-ip-token relay approach which was unreliable.
print("[startup] app.py fully loaded β regen handlers registered, SSR disabled")
# ssr_mode=False: the regen JS relies on classic client rendering;
# allowed_paths=["/tmp"]: generated media and cached checkpoints live there.
demo.queue(max_size=10).launch(ssr_mode=False, height=900, allowed_paths=["/tmp"])