| """ |
| Generate Audio for Video β multi-model Gradio app. |
| |
| Supported models |
| ---------------- |
| TARO β video-conditioned diffusion via CAVP + onset features (16 kHz, 8.192 s window) |
| MMAudio β multimodal flow-matching with CLIP/Synchformer + text prompt (44 kHz, 8 s window) |
| HunyuanFoley β text-guided foley via SigLIP2 + Synchformer + CLAP (48 kHz, up to 15 s) |
| """ |
|
|
| import html as _html |
| import os |
| import sys |
| import json |
| import shutil |
| import tempfile |
| import random |
| import threading |
| import time |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
| from pathlib import Path |
|
|
| import torch |
| import numpy as np |
| import torchaudio |
| import ffmpeg |
| import spaces |
| import gradio as gr |
| from huggingface_hub import hf_hub_download, snapshot_download |
|
|
| |
| |
| |
|
|
# Single HF Hub repo aggregating every model checkpoint this app needs, plus a
# writable scratch directory for downloads (/tmp is writable on Spaces).
CKPT_REPO_ID = "JackIsNotInTheBox/Generate_Audio_for_Video_Checkpoints"
CACHE_DIR = "/tmp/model_ckpts"
os.makedirs(CACHE_DIR, exist_ok=True)
|
|
| |
# Per-model weight directories under the shared cache. MMAudio keeps its main
# weights and "ext" weights (VAE / Synchformer) in separate folders.
MMAUDIO_WEIGHTS_DIR = Path(CACHE_DIR) / "MMAudio" / "weights"
MMAUDIO_EXT_DIR = Path(CACHE_DIR) / "MMAudio" / "ext_weights"
HUNYUAN_MODEL_DIR = Path(CACHE_DIR) / "HunyuanFoley"
MMAUDIO_WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
MMAUDIO_EXT_DIR.mkdir(parents=True, exist_ok=True)
HUNYUAN_MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
def _dl_taro():
    """Download TARO .ckpt/.pt files and return their local paths."""
    wanted = ("TARO/cavp_epoch66.ckpt", "TARO/onset_model.ckpt", "TARO/taro_ckpt.pt")
    local_paths = [
        hf_hub_download(repo_id=CKPT_REPO_ID, filename=name, cache_dir=CACHE_DIR)
        for name in wanted
    ]
    print("TARO checkpoints downloaded.")
    # Order matches `wanted`: (cavp, onset, taro).
    return tuple(local_paths)
|
|
def _dl_mmaudio():
    """Download MMAudio .pth files and return their local paths.

    Returns (main_model_path, vae_path, synchformer_path).

    NOTE(review): `local_dir_use_symlinks` is deprecated in recent
    huggingface_hub releases (and `cache_dir` is ignored when `local_dir`
    is set) — confirm against the pinned hub version.
    """
    m = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/mmaudio_large_44k_v2.pth",
                        cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_WEIGHTS_DIR), local_dir_use_symlinks=False)
    v = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/v1-44.pth",
                        cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_EXT_DIR), local_dir_use_symlinks=False)
    s = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/synchformer_state_dict.pth",
                        cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_EXT_DIR), local_dir_use_symlinks=False)
    print("MMAudio checkpoints downloaded.")
    return m, v, s
|
|
def _dl_hunyuan():
    """Download HunyuanVideoFoley .pth files."""
    # Paths are only materialised into HUNYUAN_MODEL_DIR; nothing is returned
    # because the loader reads them from that fixed directory.
    for checkpoint in (
        "HunyuanVideo-Foley/hunyuanvideo_foley.pth",
        "HunyuanVideo-Foley/vae_128d_48k.pth",
        "HunyuanVideo-Foley/synchformer_state_dict.pth",
    ):
        hf_hub_download(repo_id=CKPT_REPO_ID, filename=checkpoint,
                        cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR),
                        local_dir_use_symlinks=False)
    print("HunyuanVideoFoley checkpoints downloaded.")
|
|
def _dl_clap():
    """Pre-download CLAP so from_pretrained() hits local cache inside the ZeroGPU worker."""
    # Downloading in the CPU process keeps network time out of the GPU budget.
    snapshot_download(repo_id="laion/larger_clap_general")
    print("CLAP model pre-downloaded.")
|
|
def _dl_clip():
    """Pre-download MMAudio's CLIP model (~3.95 GB) to avoid GPU-window budget drain."""
    snapshot_download(repo_id="apple/DFN5B-CLIP-ViT-H-14-384")
    print("MMAudio CLIP model pre-downloaded.")
|
|
def _dl_audioldm2():
    """Pre-download AudioLDM2 VAE/vocoder used by TARO's from_pretrained() calls."""
    snapshot_download(repo_id="cvssp/audioldm2")
    print("AudioLDM2 pre-downloaded.")
|
|
def _dl_bigvgan():
    """Pre-download BigVGAN vocoder (~489 MB) used by MMAudio."""
    snapshot_download(repo_id="nvidia/bigvgan_v2_44khz_128band_512x")
    print("BigVGAN vocoder pre-downloaded.")
|
|
| print("[startup] Starting parallel checkpoint + model downloadsβ¦") |
| _t_dl_start = time.perf_counter() |
| with ThreadPoolExecutor(max_workers=7) as _pool: |
| _fut_taro = _pool.submit(_dl_taro) |
| _fut_mmaudio = _pool.submit(_dl_mmaudio) |
| _fut_hunyuan = _pool.submit(_dl_hunyuan) |
| _fut_clap = _pool.submit(_dl_clap) |
| _fut_clip = _pool.submit(_dl_clip) |
| _fut_aldm2 = _pool.submit(_dl_audioldm2) |
| _fut_bigvgan = _pool.submit(_dl_bigvgan) |
| |
| for _fut in as_completed([_fut_taro, _fut_mmaudio, _fut_hunyuan, |
| _fut_clap, _fut_clip, _fut_aldm2, _fut_bigvgan]): |
| _fut.result() |
|
|
| cavp_ckpt_path, onset_ckpt_path, taro_ckpt_path = _fut_taro.result() |
| mmaudio_model_path, mmaudio_vae_path, mmaudio_synchformer_path = _fut_mmaudio.result() |
| print(f"[startup] All downloads done in {time.perf_counter() - _t_dl_start:.1f}s") |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| _GPU_CTX: dict = {} |
| _GPU_CTX_LOCK = threading.Lock() |
|
|
| def _ctx_store(fn_name: str, data: dict) -> None: |
| """Store *data* under *fn_name* key (overwrites previous).""" |
| with _GPU_CTX_LOCK: |
| _GPU_CTX[fn_name] = data |
|
|
| def _ctx_load(fn_name: str) -> dict: |
| """Pop and return the context dict stored under *fn_name*.""" |
| with _GPU_CTX_LOCK: |
| return _GPU_CTX.pop(fn_name, {}) |
|
|
# UI limits: maximum number of parallel sample slots and displayed segments.
MAX_SLOTS = 8
MAX_SEGS = 8


# Per-segment highlight colours for the timeline UI; "{a}" is substituted with
# an alpha value at render time.
SEG_COLORS = [
    "rgba(100,180,255,{a})", "rgba(255,160,100,{a})",
    "rgba(120,220,140,{a})", "rgba(220,120,220,{a})",
    "rgba(255,220,80,{a})", "rgba(80,220,220,{a})",
    "rgba(255,100,100,{a})", "rgba(180,255,180,{a})",
]
|
|
| |
| |
| |
|
|
| def _ensure_syspath(subdir: str) -> str: |
| """Add *subdir* (relative to app.py) to sys.path if not already present. |
| Returns the absolute path for convenience.""" |
| p = os.path.join(os.path.dirname(os.path.abspath(__file__)), subdir) |
| if p not in sys.path: |
| sys.path.insert(0, p) |
| return p |
|
|
|
|
| def _get_device_and_dtype() -> tuple: |
| """Return (device, weight_dtype) pair used by all GPU functions.""" |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| return device, torch.bfloat16 |
|
|
|
|
def _extract_segment_clip(silent_video: str, seg_start: float, seg_dur: float,
                          output_path: str) -> str:
    """Stream-copy a segment from *silent_video* to *output_path*. Returns *output_path*.

    `vcodec="copy"` avoids re-encoding (cut lands on keyframe boundaries) and
    `an=None` drops any audio stream.
    """
    ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
        output_path, vcodec="copy", an=None
    ).run(overwrite_output=True, quiet=True)
    return output_path
|
|
| |
| |
| |
| |
| _SLOT_LOCKS: dict = {} |
| _SLOT_LOCKS_MUTEX = threading.Lock() |
|
|
| def _get_slot_lock(slot_id: str) -> threading.Lock: |
| with _SLOT_LOCKS_MUTEX: |
| if slot_id not in _SLOT_LOCKS: |
| _SLOT_LOCKS[slot_id] = threading.Lock() |
| return _SLOT_LOCKS[slot_id] |
|
|
def set_global_seed(seed: int) -> None:
    """Seed numpy, random and torch (CPU + CUDA) for reproducible sampling."""
    # numpy only accepts 32-bit seeds; the others take the value as-is.
    np.random.seed(seed % (2**32))
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
|
|
def get_random_seed() -> int:
    """Draw a fresh random seed, uniform over the full 32-bit range."""
    return random.randint(0, 2**32 - 1)
|
|
def get_video_duration(video_path: str) -> float:
    """Return video duration in seconds (CPU only, via ffprobe)."""
    probe = ffmpeg.probe(video_path)
    # Container-level duration; raises KeyError/ValueError if ffprobe reports none.
    return float(probe["format"]["duration"])
|
|
def strip_audio_from_video(video_path: str, output_path: str) -> None:
    """Write a silent copy of *video_path* to *output_path* (stream-copy, no re-encode).

    `an=None` removes every audio stream; the video stream is copied verbatim.
    """
    ffmpeg.input(video_path).output(output_path, vcodec="copy", an=None).run(
        overwrite_output=True, quiet=True
    )
|
|
| def _transcode_for_browser(video_path: str) -> str: |
| """Re-encode uploaded video to H.264/AAC MP4 so the browser preview widget can play it. |
| |
| Returns a NEW path in a fresh /tmp/gradio/ subdirectory. Gradio probes the |
| returned path fresh, sees H.264, and serves it directly without its own |
| slow fallback converter. The in-place overwrite approach loses the race |
| because Gradio probes the original path at upload time before this callback runs. |
| Only called on upload β not during generation. |
| """ |
| if video_path is None: |
| return video_path |
| try: |
| probe = ffmpeg.probe(video_path) |
| has_audio = any(s["codec_type"] == "audio" for s in probe.get("streams", [])) |
| |
| video_streams = [s for s in probe.get("streams", []) if s["codec_type"] == "video"] |
| if video_streams and video_streams[0].get("codec_name") == "h264": |
| print(f"[transcode_for_browser] already H.264, skipping") |
| return video_path |
| |
| |
| |
| import os as _os |
| upload_dir = _os.path.dirname(video_path) |
| stem = _os.path.splitext(_os.path.basename(video_path))[0] |
| out_path = _os.path.join(upload_dir, stem + "_h264.mp4") |
| kwargs = dict( |
| vcodec="libx264", preset="fast", crf=18, |
| pix_fmt="yuv420p", movflags="+faststart", |
| ) |
| if has_audio: |
| kwargs["acodec"] = "aac" |
| kwargs["audio_bitrate"] = "128k" |
| else: |
| kwargs["an"] = None |
| |
| ffmpeg.input(video_path).output(out_path, map="0:v:0", **kwargs).run( |
| overwrite_output=True, quiet=True |
| ) |
| print(f"[transcode_for_browser] transcoded to H.264: {out_path}") |
| return out_path |
| except Exception as e: |
| print(f"[transcode_for_browser] failed, using original: {e}") |
| return video_path |
|
|
|
|
| |
| |
| |
| _TEMP_DIRS: list = [] |
| _TEMP_DIRS_MAX = 10 |
|
|
| def _register_tmp_dir(tmp_dir: str) -> str: |
| """Register a temp dir so it can be cleaned up when newer ones replace it.""" |
| _TEMP_DIRS.append(tmp_dir) |
| while len(_TEMP_DIRS) > _TEMP_DIRS_MAX: |
| old = _TEMP_DIRS.pop(0) |
| try: |
| shutil.rmtree(old, ignore_errors=True) |
| print(f"[cleanup] Removed old temp dir: {old}") |
| except Exception: |
| pass |
| return tmp_dir |
|
|
|
|
| def _save_seg_wavs(wavs: list[np.ndarray], tmp_dir: str, prefix: str) -> list[str]: |
| """Save a list of numpy wav arrays to .npy files, return list of paths. |
| This avoids serialising large float arrays into JSON/HTML data-state.""" |
| paths = [] |
| for i, w in enumerate(wavs): |
| p = os.path.join(tmp_dir, f"{prefix}_seg{i}.npy") |
| np.save(p, w) |
| paths.append(p) |
| return paths |
|
|
|
|
| def _load_seg_wavs(paths: list[str]) -> list[np.ndarray]: |
| """Load segment wav arrays from .npy file paths.""" |
| return [np.load(p) for p in paths] |
|
|
|
|
| |
| |
| |
|
|
def _load_taro_models(device, weight_dtype):
    """Load TARO MMDiT + AudioLDM2 VAE/vocoder. Returns (model_net, vae, vocoder, latents_scale)."""
    from TARO.models import MMDiT
    from diffusers import AutoencoderKL
    from transformers import SpeechT5HifiGan

    # Diffusion backbone; weights come from the EMA shadow copy in the checkpoint.
    model_net = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
    model_net.load_state_dict(torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"])
    model_net.eval().to(weight_dtype)
    # AudioLDM2 mel VAE + HiFi-GAN vocoder (pre-downloaded at startup by _dl_audioldm2).
    vae = AutoencoderKL.from_pretrained("cvssp/audioldm2", subfolder="vae").to(device).eval()
    vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm2", subfolder="vocoder").to(device)
    # Latent scaling factor broadcast over the 8 latent channels (0.18215 is
    # the usual SD/AudioLDM scale — presumably matching TARO's training; TODO confirm).
    latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
    return model_net, vae, vocoder, latents_scale
|
|
|
|
def _load_taro_feature_extractors(device):
    """Load CAVP + onset extractors. Returns (extract_cavp, onset_model)."""
    from TARO.cavp_util import Extract_CAVP_Features
    from TARO.onset_util import VideoOnsetNet

    extract_cavp = Extract_CAVP_Features(
        device=device, config_path="TARO/cavp/cavp.yaml", ckpt_path=cavp_ckpt_path,
    )
    # The onset checkpoint keys carry a wrapper prefix ("model."); rewrite them
    # so they line up with the bare VideoOnsetNet module before loading.
    raw_sd = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
    onset_sd = {}
    for k, v in raw_sd.items():
        if "model.net.model" in k: k = k.replace("model.net.model", "net.model")
        elif "model.fc." in k: k = k.replace("model.fc", "fc")
        onset_sd[k] = v
    onset_model = VideoOnsetNet(pretrained=False).to(device)
    onset_model.load_state_dict(onset_sd)
    onset_model.eval()
    return extract_cavp, onset_model
|
|
|
|
def _load_mmaudio_models(device, dtype):
    """Load MMAudio net + feature_utils. Returns (net, feature_utils, model_cfg, seq_cfg)."""
    from mmaudio.eval_utils import all_model_cfg
    from mmaudio.model.networks import get_my_mmaudio
    from mmaudio.model.utils.features_utils import FeaturesUtils

    # Point the stock "large_44k_v2" config at the checkpoints fetched at startup.
    model_cfg = all_model_cfg["large_44k_v2"]
    model_cfg.model_path = Path(mmaudio_model_path)
    model_cfg.vae_path = Path(mmaudio_vae_path)
    model_cfg.synchformer_ckpt = Path(mmaudio_synchformer_path)
    model_cfg.bigvgan_16k_path = None  # 44 kHz variant: no 16 kHz BigVGAN checkpoint needed
    seq_cfg = model_cfg.seq_cfg

    net = get_my_mmaudio(model_cfg.model_name).to(device, dtype).eval()
    net.load_weights(torch.load(model_cfg.model_path, map_location=device, weights_only=True))
    # Conditioning stack (Synchformer + VAE decode); the VAE encoder is skipped
    # because generation never encodes reference audio.
    feature_utils = FeaturesUtils(
        tod_vae_ckpt=str(model_cfg.vae_path),
        synchformer_ckpt=str(model_cfg.synchformer_ckpt),
        enable_conditions=True, mode=model_cfg.mode,
        bigvgan_vocoder_ckpt=None, need_vae_encoder=False,
    ).to(device, dtype).eval()
    return net, feature_utils, model_cfg, seq_cfg
|
|
|
|
def _load_hunyuan_model(device, model_size):
    """Load HunyuanFoley model dict + config. Returns (model_dict, cfg).

    *model_size* is "xl" or "xxl" (case-insensitive); any unrecognised value
    falls back to the XXL config.
    """
    from hunyuanvideo_foley.utils.model_utils import load_model
    model_size = model_size.lower()
    config_map = {
        "xl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xl.yaml",
        "xxl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xxl.yaml",
    }
    config_path = config_map.get(model_size, config_map["xxl"])
    hunyuan_weights_dir = str(HUNYUAN_MODEL_DIR / "HunyuanVideo-Foley")
    print(f"[HunyuanFoley] Loading {model_size.upper()} model from {hunyuan_weights_dir}")
    return load_model(hunyuan_weights_dir, config_path, device,
                      enable_offload=False, model_size=model_size)
|
|
|
|
def mux_video_audio(silent_video: str, audio_path: str, output_path: str,
                    model: str = None) -> None:
    """Mux a silent video with an audio file into *output_path*.

    For HunyuanFoley (*model*="hunyuan") we use its own merge_audio_video which
    handles its specific ffmpeg quirks; all other models re-encode to
    H.264 + AAC MP4 (not a stream-copy) so the result plays in the browser.
    """
    if model == "hunyuan":
        _ensure_syspath("HunyuanVideo-Foley")
        from hunyuanvideo_foley.utils.media_utils import merge_audio_video
        merge_audio_video(audio_path, silent_video, output_path)
    else:
        v_in = ffmpeg.input(silent_video)
        a_in = ffmpeg.input(audio_path)
        ffmpeg.output(
            v_in["v:0"],   # first video stream from the silent clip
            a_in["a:0"],   # first audio stream from the generated wav
            output_path,
            vcodec="libx264", preset="fast", crf=18,
            pix_fmt="yuv420p",
            acodec="aac", audio_bitrate="128k",
            movflags="+faststart",  # moov atom up front for instant playback
        ).run(overwrite_output=True, quiet=True)
|
|
|
|
| |
| |
| |
| |
|
|
| def _build_segments(total_dur_s: float, window_s: float, crossfade_s: float) -> list[tuple[float, float]]: |
| """Return list of (start, end) pairs covering *total_dur_s*. |
| |
| Every segment uses the full *window_s* inference window. Segments are |
| equally spaced so every overlap is identical, guaranteeing the crossfade |
| setting is honoured at every boundary with no raw bleed. |
| |
| Algorithm |
| --------- |
| 1. Clamp crossfade_s so the step stays positive. |
| 2. Find the minimum n such that n segments of *window_s* cover |
| *total_dur_s* with overlap β₯ crossfade_s at every boundary: |
| n = ceil((total_dur_s - crossfade_s) / (window_s - crossfade_s)) |
| 3. Compute equal spacing: step = (total_dur_s - window_s) / (n - 1) |
| so that every gap is identical and the last segment ends exactly at |
| total_dur_s. |
| 4. Every segment is exactly *window_s* wide. The trailing audio of each |
| segment beyond its contact edge is discarded in _stitch_wavs. |
| """ |
| crossfade_s = min(crossfade_s, window_s * 0.5) |
| if total_dur_s <= window_s: |
| return [(0.0, total_dur_s)] |
| import math |
| step_min = window_s - crossfade_s |
| n = math.ceil((total_dur_s - crossfade_s) / step_min) |
| n = max(n, 2) |
| |
| step_s = (total_dur_s - window_s) / (n - 1) |
| return [(i * step_s, i * step_s + window_s) for i in range(n)] |
|
|
|
|
| def _cf_join(a: np.ndarray, b: np.ndarray, |
| crossfade_s: float, db_boost: float, sr: int) -> np.ndarray: |
| """Equal-power crossfade join. Works for both mono (T,) and stereo (C, T) arrays. |
| Stereo arrays are expected in (channels, samples) layout. |
| |
| db_boost is applied to the overlap region as a whole (after blending), so |
| it compensates for the -3 dB equal-power dip without doubling amplitude. |
| Applying gain to each side independently (the common mistake) causes a |
| +3 dB loudness bump at the seam β this version avoids that.""" |
| stereo = a.ndim == 2 |
| n_a = a.shape[1] if stereo else len(a) |
| n_b = b.shape[1] if stereo else len(b) |
| cf = min(int(round(crossfade_s * sr)), n_a, n_b) |
| if cf <= 0: |
| return np.concatenate([a, b], axis=1 if stereo else 0) |
| gain = 10 ** (db_boost / 20.0) |
| t = np.linspace(0.0, 1.0, cf, dtype=np.float32) |
| fade_out = np.cos(t * np.pi / 2) |
| fade_in = np.sin(t * np.pi / 2) |
| if stereo: |
| |
| overlap = (a[:, -cf:] * fade_out + b[:, :cf] * fade_in) * gain |
| return np.concatenate([a[:, :-cf], overlap, b[:, cf:]], axis=1) |
| else: |
| overlap = (a[-cf:] * fade_out + b[:cf] * fade_in) * gain |
| return np.concatenate([a[:-cf], overlap, b[cf:]]) |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
# ---- TARO model geometry -----------------------------------------------
TARO_SR = 16000                 # TARO's native output sample rate
TARO_TRUNCATE = 131072          # samples per inference window (2**17)
TARO_FPS = 4                    # CAVP feature rate (frames per second)
TARO_TRUNCATE_FRAME = int(TARO_FPS * TARO_TRUNCATE / TARO_SR)  # CAVP frames per window (= 32)
TARO_TRUNCATE_ONSET = 120       # onset feature length per window
TARO_MODEL_DUR = TARO_TRUNCATE / TARO_SR  # window length in seconds (= 8.192)


# Hard ceiling for a single ZeroGPU reservation, in seconds.
GPU_DURATION_CAP = 300


# Per-model knobs for GPU-time estimation and UI wiring. secs_per_step and
# load_overhead look like measured timing constants — TODO confirm provenance.
# regen_fn is None here; presumably filled in at UI wiring time.
MODEL_CONFIGS = {
    "taro": {
        "window_s": TARO_MODEL_DUR,
        "sr": TARO_SR,
        "secs_per_step": 0.025,
        "load_overhead": 15,
        "tab_prefix": "taro",
        "label": "TARO",
        "regen_fn": None,
    },
    "mmaudio": {
        "window_s": 8.0,
        "sr": 48000,
        "secs_per_step": 0.25,
        "load_overhead": 30,
        "tab_prefix": "mma",
        "label": "MMAudio",
        "regen_fn": None,
    },
    "hunyuan": {
        "window_s": 15.0,
        "sr": 48000,
        "secs_per_step": 0.35,
        "load_overhead": 55,
        "tab_prefix": "hf",
        "label": "HunyuanFoley",
        "regen_fn": None,
    },
}


# Convenience aliases for readability at call sites.
TARO_SECS_PER_STEP = MODEL_CONFIGS["taro"]["secs_per_step"]
MMAUDIO_WINDOW = MODEL_CONFIGS["mmaudio"]["window_s"]
MMAUDIO_SECS_PER_STEP = MODEL_CONFIGS["mmaudio"]["secs_per_step"]
HUNYUAN_MAX_DUR = MODEL_CONFIGS["hunyuan"]["window_s"]
HUNYUAN_SECS_PER_STEP = MODEL_CONFIGS["hunyuan"]["secs_per_step"]
|
|
|
|
def _clamp_duration(secs: float, label: str) -> int:
    """Clamp a raw GPU-seconds estimate to [60, GPU_DURATION_CAP] and log it."""
    reserved = int(secs)
    if reserved < 60:
        reserved = 60
    elif reserved > GPU_DURATION_CAP:
        reserved = GPU_DURATION_CAP
    print(f"[duration] {label}: {secs:.0f}s raw → {reserved}s reserved")
    return reserved
|
|
|
|
def _estimate_gpu_duration(model_key: str, num_samples: int, num_steps: int,
                           total_dur_s: float = None, crossfade_s: float = 0,
                           video_file: str = None) -> int:
    """Estimate GPU seconds for a full generation call.

    Formula: num_samples × n_segs × num_steps × secs_per_step + load_overhead

    *total_dur_s* may be given directly or derived by probing *video_file*;
    any probe failure (e.g. no file selected yet) falls back to one segment.
    """
    cfg = MODEL_CONFIGS[model_key]
    try:
        if total_dur_s is None:
            total_dur_s = get_video_duration(video_file)
        n_segs = len(_build_segments(total_dur_s, cfg["window_s"], float(crossfade_s)))
    except Exception:
        n_segs = 1  # best-effort estimate rather than failing the UI call
    secs = int(num_samples) * n_segs * int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"]
    # end="" so _clamp_duration's own print completes this log line.
    print(f"[duration] {cfg['label']}: {int(num_samples)}samp × {n_segs}seg × "
          f"{int(num_steps)}steps → {secs:.0f}s → capped ", end="")
    return _clamp_duration(secs, cfg["label"])
|
|
|
|
def _estimate_regen_duration(model_key: str, num_steps: int) -> int:
    """Estimate GPU seconds for a single-segment regen call.

    Same cost model as _estimate_gpu_duration with n_segs = num_samples = 1.
    """
    cfg = MODEL_CONFIGS[model_key]
    secs = int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"]
    # end="" so _clamp_duration's own print completes this log line.
    print(f"[duration] {cfg['label']} regen: 1 seg × {int(num_steps)} steps → ", end="")
    return _clamp_duration(secs, f"{cfg['label']} regen")
|
|
# Bounded cache of full TARO inference results, keyed by
# (video, seed, cfg_scale, steps, mode, crossfade). Evicted FIFO at 16 entries.
_TARO_CACHE_MAXLEN = 16
_TARO_INFERENCE_CACHE: dict = {}
_TARO_CACHE_LOCK = threading.Lock()
|
|
|
|
def _taro_calc_max_samples(total_dur_s: float, num_steps: int, crossfade_s: float) -> int:
    """Upper bound for the TARO sample-count slider given this video.

    A 600 GPU-second budget divided by the cost of one sample
    (segments × steps × secs/step), clamped to [1, MAX_SLOTS].
    """
    n_segs = len(_build_segments(total_dur_s, TARO_MODEL_DUR, crossfade_s))
    time_per_seg = num_steps * TARO_SECS_PER_STEP
    max_s = int(600.0 / (n_segs * time_per_seg))
    return max(1, min(max_s, MAX_SLOTS))
|
|
|
|
def _taro_duration(video_file, seed_val, cfg_scale, num_steps, mode,
                   crossfade_s, crossfade_db, num_samples):
    """Pre-GPU callable — must match _taro_gpu_infer's input order exactly.

    spaces.GPU(duration=...) invokes this with the decorated function's
    arguments to size the reservation, so unused parameters are part of the
    contract and must not be removed.
    """
    return _estimate_gpu_duration("taro", int(num_samples), int(num_steps),
                                  video_file=video_file, crossfade_s=crossfade_s)
|
|
|
|
def _taro_infer_segment(
    model, vae, vocoder,
    cavp_feats_full, onset_feats_full,
    seg_start_s: float, seg_end_s: float,
    device, weight_dtype,
    cfg_scale: float, num_steps: int, mode: str,
    latents_scale,
    euler_sampler, euler_maruyama_sampler,
) -> np.ndarray:
    """Single-segment TARO inference. Returns wav array trimmed to segment length.

    Slices the full-video CAVP/onset features down to this segment's window,
    samples latents with the chosen ODE/SDE sampler, then decodes via the
    AudioLDM2 VAE + vocoder. *mode* == "sde" selects euler_maruyama_sampler;
    anything else uses the deterministic euler_sampler.
    """
    # CAVP features run at TARO_FPS; slice this window and zero-pad the tail
    # when the video ends before the window does.
    cavp_start = int(round(seg_start_s * TARO_FPS))
    cavp_slice = cavp_feats_full[cavp_start : cavp_start + TARO_TRUNCATE_FRAME]
    if cavp_slice.shape[0] < TARO_TRUNCATE_FRAME:
        pad = np.zeros(
            (TARO_TRUNCATE_FRAME - cavp_slice.shape[0],) + cavp_slice.shape[1:],
            dtype=cavp_slice.dtype,
        )
        cavp_slice = np.concatenate([cavp_slice, pad], axis=0)
    video_feats = torch.from_numpy(cavp_slice).unsqueeze(0).to(device, weight_dtype)

    # Onset features use their own frame rate (120 frames per 8.192 s window);
    # same slice-then-pad treatment.
    onset_fps = TARO_TRUNCATE_ONSET / TARO_MODEL_DUR
    onset_start = int(round(seg_start_s * onset_fps))
    onset_slice = onset_feats_full[onset_start : onset_start + TARO_TRUNCATE_ONSET]
    if onset_slice.shape[0] < TARO_TRUNCATE_ONSET:
        onset_slice = np.pad(
            onset_slice,
            ((0, TARO_TRUNCATE_ONSET - onset_slice.shape[0]),),
            mode="constant",
        )
    onset_feats_t = torch.from_numpy(onset_slice).unsqueeze(0).to(device, weight_dtype)

    # Initial noise latent; 204 x 16 is the latent grid for one window
    # (presumably mel-time x mel-freq in latent space — TODO confirm).
    z = torch.randn(1, model.in_channels, 204, 16, device=device, dtype=weight_dtype)

    sampling_kwargs = dict(
        model=model,
        latents=z,
        y=onset_feats_t,
        context=video_feats,
        num_steps=int(num_steps),
        heun=False,
        cfg_scale=float(cfg_scale),
        guidance_low=0.0,
        guidance_high=0.7,
        path_type="linear",
    )
    with torch.no_grad():
        samples = (euler_maruyama_sampler if mode == "sde" else euler_sampler)(**sampling_kwargs)
    # Some sampler variants return (samples, aux); keep only the samples.
    if isinstance(samples, tuple):
        samples = samples[0]

    # Latents -> mel via VAE, mel -> waveform via HiFi-GAN vocoder.
    samples = vae.decode(samples / latents_scale).sample
    wav = vocoder(samples.squeeze().float()).detach().cpu().numpy()
    return wav
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
# Common delivery sample rate for all models; TARO's 16 kHz output is
# upsampled to this rate before stitching/saving.
TARGET_SR = 48000
TARO_SR_OUT = TARGET_SR
|
|
|
|
| def _resample_to_target(wav: np.ndarray, src_sr: int, |
| dst_sr: int = None) -> np.ndarray: |
| """Resample *wav* (mono or stereo numpy float32) from *src_sr* to *dst_sr*. |
| |
| *dst_sr* defaults to TARGET_SR (48 kHz). No-op if src_sr == dst_sr. |
| Uses torchaudio Kaiser-windowed sinc resampling β CPU-only, ZeroGPU-safe. |
| """ |
| if dst_sr is None: |
| dst_sr = TARGET_SR |
| if src_sr == dst_sr: |
| return wav |
| stereo = wav.ndim == 2 |
| t = torch.from_numpy(np.ascontiguousarray(wav.astype(np.float32))) |
| if not stereo: |
| t = t.unsqueeze(0) |
| t = torchaudio.functional.resample(t, src_sr, dst_sr) |
| if not stereo: |
| t = t.squeeze(0) |
| return t.numpy() |
|
|
|
|
def _upsample_taro(wav_16k: np.ndarray) -> np.ndarray:
    """Upsample a mono 16 kHz numpy array to 48 kHz via sinc resampling (CPU).

    torchaudio.functional.resample uses a Kaiser-windowed sinc filter —
    mathematically optimal for bandlimited signals, zero CUDA risk.
    Returns a mono float32 numpy array at 48 kHz (3x the input length).
    """
    dur_in = len(wav_16k) / TARO_SR
    print(f"[TARO upsample] {dur_in:.2f}s @ {TARO_SR}Hz → {TARGET_SR}Hz (sinc, CPU) …")
    result = _resample_to_target(wav_16k, TARO_SR)
    # 48000 / 16000 == 3, hence the "ratio 3x" sanity figure in the log.
    print(f"[TARO upsample] done → {len(result)/TARGET_SR:.2f}s @ {TARGET_SR}Hz "
          f"(expected {dur_in * 3:.2f}s, ratio 3×)")
    return result
|
|
|
|
def _stitch_wavs(wavs: list[np.ndarray], crossfade_s: float, db_boost: float,
                 total_dur_s: float, sr: int,
                 segments: list[tuple[float, float]] = None) -> np.ndarray:
    """Crossfade-join a list of wav arrays and trim to *total_dur_s*.
    Works for both mono (T,) and stereo (C, T) arrays.

    When *segments* is provided (list of (start, end) video-time pairs),
    each wav is trimmed to its contact-edge window before joining:

        contact_edge[i→i+1] = midpoint of overlap = (seg[i].end + seg[i+1].start) / 2
        half_cf = crossfade_s / 2

        seg i keep: [contact_edge[i-1→i] - half_cf, contact_edge[i→i+1] + half_cf]
        expressed as sample offsets into the generated audio for that segment.

    This guarantees every crossfade zone is exactly crossfade_s wide with no
    raw bleed regardless of how much the inference windows overlap.
    """
    def _trim(wav, start_s, end_s, seg_start_s):
        """Trim wav to [start_s, end_s] expressed in absolute video time,
        where the wav starts at seg_start_s in video time."""
        s = max(0, int(round((start_s - seg_start_s) * sr)))
        e = int(round((end_s - seg_start_s) * sr))
        e = min(e, wav.shape[1] if wav.ndim == 2 else len(wav))
        return wav[:, s:e] if wav.ndim == 2 else wav[s:e]

    # No segment map (or a single segment): join naively and trim the tail.
    if segments is None or len(segments) == 1:
        out = wavs[0]
        for nw in wavs[1:]:
            out = _cf_join(out, nw, crossfade_s, db_boost, sr)
        n = int(round(total_dur_s * sr))
        return out[:, :n] if out.ndim == 2 else out[:n]

    half_cf = crossfade_s / 2.0

    # Midpoint of each overlap region between consecutive windows.
    contact_edges = [
        (segments[i][1] + segments[i + 1][0]) / 2.0
        for i in range(len(segments) - 1)
    ]

    # Keep only [edge - half_cf, edge + half_cf] around each contact point so
    # the later _cf_join blends exactly crossfade_s of audio.
    trimmed = []
    for i, (wav, (seg_start, seg_end)) in enumerate(zip(wavs, segments)):
        keep_start = (contact_edges[i - 1] - half_cf) if i > 0 else seg_start
        keep_end = (contact_edges[i] + half_cf) if i < len(segments) - 1 else total_dur_s
        trimmed.append(_trim(wav, keep_start, keep_end, seg_start))

    # Chain-join the trimmed pieces left to right.
    out = trimmed[0]
    for nw in trimmed[1:]:
        out = _cf_join(out, nw, crossfade_s, db_boost, sr)

    # Final length clamp to the exact video duration.
    n = int(round(total_dur_s * sr))
    return out[:, :n] if out.ndim == 2 else out[:n]
|
|
|
|
def _save_wav(path: str, wav: np.ndarray, sr: int) -> None:
    """Write *wav* (mono or stereo numpy array) to *path* at rate *sr*."""
    tensor = torch.from_numpy(np.ascontiguousarray(wav))
    # torchaudio.save expects (channels, samples); promote mono to 1 channel.
    torchaudio.save(path, tensor.unsqueeze(0) if tensor.ndim == 1 else tensor, sr)
|
|
|
|
| def _log_inference_timing(label: str, elapsed: float, n_segs: int, |
| num_steps: int, constant: float) -> None: |
| """Print a standardised inference-timing summary line.""" |
| total_steps = n_segs * num_steps |
| secs_per_step = elapsed / total_steps if total_steps > 0 else 0 |
| print(f"[{label}] Inference done: {n_segs} seg(s) Γ {num_steps} steps in " |
| f"{elapsed:.1f}s wall β {secs_per_step:.3f}s/step " |
| f"(current constant={constant})") |
|
|
|
|
| def _build_seg_meta(*, segments, wav_paths, audio_path, video_path, |
| silent_video, sr, model, crossfade_s, crossfade_db, |
| total_dur_s, **extras) -> dict: |
| """Build the seg_meta dict shared by all three generate_* functions. |
| Model-specific keys are passed via **extras.""" |
| meta = { |
| "segments": segments, |
| "wav_paths": wav_paths, |
| "audio_path": audio_path, |
| "video_path": video_path, |
| "silent_video": silent_video, |
| "sr": sr, |
| "model": model, |
| "crossfade_s": crossfade_s, |
| "crossfade_db": crossfade_db, |
| "total_dur_s": total_dur_s, |
| } |
| meta.update(extras) |
| return meta |
|
|
|
|
def _post_process_samples(results: list, *, model: str, tmp_dir: str,
                          silent_video: str, segments: list,
                          crossfade_s: float, crossfade_db: float,
                          total_dur_s: float, sr: int,
                          extra_meta_fn=None) -> list:
    """Shared CPU post-processing for all three generate_* wrappers.

    Each entry in *results* is a tuple whose first element is a list of
    per-segment wav arrays. The remaining elements are model-specific
    (e.g. TARO returns features, HunyuanFoley returns text_feats).

    *extra_meta_fn(sample_idx, result_tuple, tmp_dir) -> dict* is an optional
    callback that returns model-specific extra keys to merge into seg_meta
    (e.g. cavp_path, onset_path, text_feats_path).

    Returns a list of (video_path, audio_path, seg_meta) tuples.
    """
    outputs = []
    for sample_idx, result in enumerate(results):
        seg_wavs = result[0]

        # Crossfade-stitch, write the wav, then mux it back under the video.
        full_wav = _stitch_wavs(seg_wavs, crossfade_s, crossfade_db, total_dur_s, sr, segments)
        audio_path = os.path.join(tmp_dir, f"{model}_{sample_idx}.wav")
        _save_wav(audio_path, full_wav, sr)
        video_path = os.path.join(tmp_dir, f"{model}_{sample_idx}.mp4")
        mux_video_audio(silent_video, audio_path, video_path, model=model)
        # Per-segment wavs go to disk so regen can re-stitch without re-inferring.
        wav_paths = _save_seg_wavs(seg_wavs, tmp_dir, f"{model}_{sample_idx}")

        extras = extra_meta_fn(sample_idx, result, tmp_dir) if extra_meta_fn else {}
        seg_meta = _build_seg_meta(
            segments=segments, wav_paths=wav_paths, audio_path=audio_path,
            video_path=video_path, silent_video=silent_video, sr=sr,
            model=model, crossfade_s=crossfade_s, crossfade_db=crossfade_db,
            total_dur_s=total_dur_s, **extras,
        )
        outputs.append((video_path, audio_path, seg_meta))
    return outputs
|
|
|
|
def _cpu_preprocess(video_file: str, model_dur: float,
                    crossfade_s: float) -> tuple:
    """Shared CPU pre-processing for all generate_* wrappers.

    Strips the audio track, measures the video, and plans the inference
    windows. Returns (tmp_dir, silent_video, total_dur_s, segments).
    """
    tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
    silent_video = os.path.join(tmp_dir, "silent_input.mp4")
    strip_audio_from_video(video_file, silent_video)
    total_dur_s = get_video_duration(video_file)
    segments = _build_segments(total_dur_s, model_dur, crossfade_s)
    return tmp_dir, silent_video, total_dur_s, segments
|
|
|
|
@spaces.GPU(duration=_taro_duration)
def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
                    crossfade_s, crossfade_db, num_samples):
    """GPU-only TARO inference — model loading + feature extraction + diffusion.
    Returns list of (wavs_list, cavp_feats, onset_feats) per sample.

    Runs inside the ZeroGPU worker; per-call context (tmp dir, silent video,
    segment plan) is handed over via _ctx_store/_ctx_load from generate_taro.
    A negative *seed_val* means "pick a random seed".
    """
    seed_val = int(seed_val)
    crossfade_s = float(crossfade_s)
    num_samples = int(num_samples)
    if seed_val < 0:
        seed_val = random.randint(0, 2**32 - 1)

    torch.set_grad_enabled(False)
    device, weight_dtype = _get_device_and_dtype()

    # TARO's package is vendored as a subdirectory; make it importable here.
    _ensure_syspath("TARO")
    from TARO.onset_util import extract_onset
    from TARO.samplers import euler_sampler, euler_maruyama_sampler

    # Context stashed by the CPU wrapper just before this call.
    ctx = _ctx_load("taro_gpu_infer")
    tmp_dir = ctx["tmp_dir"]
    silent_video = ctx["silent_video"]
    segments = ctx["segments"]
    total_dur_s = ctx["total_dur_s"]

    # Feature extraction over the whole video, once, before sampling.
    extract_cavp, onset_model = _load_taro_feature_extractors(device)
    cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)

    onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)

    # Free extractor VRAM before loading the diffusion stack.
    del extract_cavp, onset_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    model, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)

    results = []
    for sample_idx in range(num_samples):
        # Each sample gets a deterministic derived seed so runs are reproducible.
        sample_seed = seed_val + sample_idx
        cache_key = (video_file, sample_seed, float(cfg_scale), int(num_steps), mode, crossfade_s)

        with _TARO_CACHE_LOCK:
            cached = _TARO_INFERENCE_CACHE.get(cache_key)
        if cached is not None:
            print(f"[TARO] Sample {sample_idx+1}: cache hit.")
            # NOTE(review): cache hits carry None instead of onset_feats —
            # downstream must tolerate that asymmetry; confirm in generate_taro.
            results.append((cached["wavs"], cavp_feats, None))
        else:
            set_global_seed(sample_seed)
            wavs = []
            _t_infer_start = time.perf_counter()
            for seg_start_s, seg_end_s in segments:
                print(f"[TARO] Sample {sample_idx+1} | {seg_start_s:.2f}s → {seg_end_s:.2f}s")
                wav = _taro_infer_segment(
                    model, vae, vocoder,
                    cavp_feats, onset_feats,
                    seg_start_s, seg_end_s,
                    device, weight_dtype,
                    cfg_scale, num_steps, mode,
                    latents_scale,
                    euler_sampler, euler_maruyama_sampler,
                )
                wavs.append(wav)
            _log_inference_timing("TARO", time.perf_counter() - _t_infer_start,
                                  len(segments), int(num_steps), TARO_SECS_PER_STEP)
            with _TARO_CACHE_LOCK:
                _TARO_INFERENCE_CACHE[cache_key] = {"wavs": wavs}
                # FIFO eviction keeps the cache bounded.
                while len(_TARO_INFERENCE_CACHE) > _TARO_CACHE_MAXLEN:
                    _TARO_INFERENCE_CACHE.pop(next(iter(_TARO_INFERENCE_CACHE)))
            results.append((wavs, cavp_feats, onset_feats))

    # Leave VRAM clean for the next GPU window.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return results
|
|
|
|
def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
                  crossfade_s, crossfade_db, num_samples):
    """TARO: video-conditioned diffusion, 16 kHz, 8.192 s sliding window.
    CPU pre/post-processing wraps the GPU-only inference to minimize ZeroGPU cost.

    CAVP and onset features are persisted under tmp_dir so single-segment
    regens can skip the expensive re-extraction. The two feature arrays are
    tracked with independent "saved" flags: a cache-hit sample carries
    onset_feats=None, so onset features must be saved from the first sample
    that actually has them (a single shared flag would skip them forever if
    the first sample happened to be a cache hit).
    """
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)
    num_samples = int(num_samples)

    # CPU prep: strip audio and window the timeline into 8.192 s segments.
    tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
        video_file, TARO_MODEL_DUR, crossfade_s)

    _ctx_store("taro_gpu_infer", {
        "tmp_dir": tmp_dir, "silent_video": silent_video,
        "segments": segments, "total_dur_s": total_dur_s,
    })

    # GPU pass: one (wavs, cavp_feats, onset_feats) triple per sample.
    results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
                              crossfade_s, crossfade_db, num_samples)

    cavp_path = os.path.join(tmp_dir, "taro_cavp.npy")
    onset_path = os.path.join(tmp_dir, "taro_onset.npy")
    _cavp_saved = False
    _onset_saved = False

    def _upsample_and_save_feats(result):
        # Upsample each segment wav 16 kHz -> output rate and persist features
        # (once each) for later regen calls.
        nonlocal _cavp_saved, _onset_saved
        wavs, cavp_feats, onset_feats = result
        wavs = [_upsample_taro(w) for w in wavs]
        if not _cavp_saved:
            np.save(cavp_path, cavp_feats)
            _cavp_saved = True
        # onset_feats is None for cached samples; save from the first sample
        # that carries real onset features.
        if onset_feats is not None and not _onset_saved:
            np.save(onset_path, onset_feats)
            _onset_saved = True
        return (wavs, cavp_feats, onset_feats)

    results = [_upsample_and_save_feats(r) for r in results]

    def _taro_extras(sample_idx, result, td):
        # Per-sample metadata: feature paths shared by all samples.
        return {"cavp_path": cavp_path, "onset_path": onset_path}

    outputs = _post_process_samples(
        results, model="taro", tmp_dir=tmp_dir,
        silent_video=silent_video, segments=segments,
        crossfade_s=crossfade_s, crossfade_db=crossfade_db,
        total_dur_s=total_dur_s, sr=TARO_SR_OUT,
        extra_meta_fn=_taro_extras,
    )
    return _pad_outputs(outputs)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
|
|
def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
                      cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
    """Pre-GPU duration callable; parameter order must mirror _mmaudio_gpu_infer."""
    samples = int(num_samples)
    steps = int(num_steps)
    return _estimate_gpu_duration("mmaudio", samples, steps,
                                  video_file=video_file, crossfade_s=crossfade_s)
|
|
|
|
@spaces.GPU(duration=_mmaudio_duration)
def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                       cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
    """GPU-only MMAudio inference: model loading + flow-matching generation.

    Pre-cut segment clips are received out-of-band via
    _ctx_load("mmaudio_gpu_infer"), stored by the CPU-side caller.

    Returns:
        list of (seg_audios, sr) per sample, where seg_audios is a list of
        float32 numpy waveforms (one per segment) at the model's native
        sampling rate sr.
    """
    _ensure_syspath("MMAudio")
    from mmaudio.eval_utils import generate, load_video
    from mmaudio.model.flow_matching import FlowMatching

    seed_val = int(seed_val)
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)

    device, dtype = _get_device_and_dtype()

    net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)

    ctx = _ctx_load("mmaudio_gpu_infer")
    segments = ctx["segments"]
    seg_clip_paths = ctx["seg_clip_paths"]

    sr = seq_cfg.sampling_rate

    results = []
    for sample_idx in range(num_samples):
        # One generator per sample: seed >= 0 gives reproducible per-sample
        # streams (seed + sample_idx); negative seed means OS-random.
        rng = torch.Generator(device=device)
        if seed_val >= 0:
            rng.manual_seed(seed_val + sample_idx)
        else:
            rng.seed()

        seg_audios = []
        _t_mma_start = time.perf_counter()
        fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)

        for seg_i, (seg_start, seg_end) in enumerate(segments):
            seg_dur = seg_end - seg_start
            seg_path = seg_clip_paths[seg_i]

            video_info = load_video(seg_path, seg_dur)
            clip_frames = video_info.clip_frames.unsqueeze(0)  # add batch dim
            sync_frames = video_info.sync_frames.unsqueeze(0)  # add batch dim
            actual_dur = video_info.duration_sec

            # Resize the model's sequence lengths to the clip's measured
            # duration (may differ from the requested seg_dur).
            seq_cfg.duration = actual_dur
            net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

            print(f"[MMAudio] Sample {sample_idx+1} | seg {seg_i+1}/{len(segments)} "
                  f"{seg_start:.1f}β{seg_end:.1f}s | dur={actual_dur:.2f}s | prompt='{prompt}'")

            with torch.no_grad():
                audios = generate(
                    clip_frames,
                    sync_frames,
                    [prompt],
                    negative_text=[negative_prompt] if negative_prompt else None,
                    feature_utils=feature_utils,
                    net=net,
                    fm=fm,
                    rng=rng,
                    cfg_strength=float(cfg_strength),
                )
            wav = audios.float().cpu()[0].numpy()
            seg_audios.append(wav)

        _log_inference_timing("MMAudio", time.perf_counter() - _t_mma_start,
                              len(segments), int(num_steps), MMAUDIO_SECS_PER_STEP)
        results.append((seg_audios, sr))

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return results
|
|
|
|
def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
                     cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
    """MMAudio: flow-matching video-to-audio, 44.1 kHz, 8 s sliding window.
    CPU pre/post-processing wraps the GPU-only inference to minimize ZeroGPU cost."""
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)

    # CPU prep: strip audio, split the timeline into windows.
    tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
        video_file, MMAUDIO_WINDOW, crossfade_s)
    print(f"[MMAudio] Video={total_dur_s:.2f}s | {len(segments)} segment(s) Γ β€8 s")

    # Cut one clip file per window for the GPU worker.
    seg_clip_paths = []
    for idx, (start_s, end_s) in enumerate(segments):
        clip_out = os.path.join(tmp_dir, f"mma_seg_{idx}.mp4")
        seg_clip_paths.append(
            _extract_segment_clip(silent_video, start_s, end_s - start_s, clip_out))

    _ctx_store("mmaudio_gpu_infer", {"segments": segments, "seg_clip_paths": seg_clip_paths})

    # GPU pass.
    results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                                 cfg_strength, num_steps, crossfade_s, crossfade_db,
                                 num_samples)

    # CPU post: bring every sample up to TARGET_SR before stitching.
    resampled = []
    for seg_audios, sr in results:
        if sr != TARGET_SR:
            print(f"[MMAudio upsample] resampling {sr}Hz β {TARGET_SR}Hz (sinc, CPU) β¦")
            seg_audios = [_resample_to_target(w, sr) for w in seg_audios]
            print(f"[MMAudio upsample] done β {len(seg_audios)} seg(s) @ {TARGET_SR}Hz")
        resampled.append((seg_audios,))

    outputs = _post_process_samples(
        resampled, model="mmaudio", tmp_dir=tmp_dir,
        silent_video=silent_video, segments=segments,
        crossfade_s=crossfade_s, crossfade_db=crossfade_db,
        total_dur_s=total_dur_s, sr=TARGET_SR,
    )
    return _pad_outputs(outputs)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
|
|
def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
                      guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
                      num_samples):
    """Pre-GPU duration callable; parameter order must mirror _hunyuan_gpu_infer."""
    samples = int(num_samples)
    steps = int(num_steps)
    return _estimate_gpu_duration("hunyuan", samples, steps,
                                  video_file=video_file, crossfade_s=crossfade_s)
|
|
|
|
@spaces.GPU(duration=_hunyuan_duration)
def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                       guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
                       num_samples):
    """GPU-only HunyuanFoley inference: model loading + feature extraction + denoising.

    Segment clips and a dummy clip arrive via _ctx_load("hunyuan_gpu_infer").
    Text features are encoded once (dummy clip + prompt) and reused for every
    segment and sample; only the visual features are per-segment.

    Returns:
        list of (seg_wavs, sr, text_feats) per sample.
    """
    _ensure_syspath("HunyuanVideo-Foley")
    from hunyuanvideo_foley.utils.model_utils import denoise_process
    from hunyuanvideo_foley.utils.feature_utils import feature_process

    seed_val = int(seed_val)
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    if seed_val >= 0:
        # Seeded once up front (not per sample); samples after the first vary
        # because there is no re-seeding inside the loop.
        set_global_seed(seed_val)

    device, _ = _get_device_and_dtype()
    device = torch.device(device)
    model_size = model_size.lower()  # normalize UI label before model lookup

    model_dict, cfg = _load_hunyuan_model(device, model_size)

    ctx = _ctx_load("hunyuan_gpu_infer")
    segments = ctx["segments"]
    total_dur_s = ctx["total_dur_s"]  # read for ctx-schema parity; unused below
    dummy_seg_path = ctx["dummy_seg_path"]
    seg_clip_paths = ctx["seg_clip_paths"]

    # Encode text features once against a throwaway clip; the visual features
    # from this call are discarded (leading underscore slots).
    _, text_feats, _ = feature_process(
        dummy_seg_path,
        prompt if prompt else "",
        model_dict,
        cfg,
        neg_prompt=negative_prompt if negative_prompt else None,
    )

    from hunyuanvideo_foley.utils.feature_utils import encode_video_features

    results = []
    for sample_idx in range(num_samples):
        seg_wavs = []
        sr = 48000  # fallback; denoise_process overwrites this per segment
        _t_hny_start = time.perf_counter()
        for seg_i, (seg_start, seg_end) in enumerate(segments):
            seg_dur = seg_end - seg_start  # currently unused below
            seg_path = seg_clip_paths[seg_i]

            # Visual features are re-encoded per segment (text feats shared).
            visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
            print(f"[HunyuanFoley] Sample {sample_idx+1} | seg {seg_i+1}/{len(segments)} "
                  f"{seg_start:.1f}β{seg_end:.1f}s β {seg_audio_len:.2f}s audio")

            audio_batch, sr = denoise_process(
                visual_feats,
                text_feats,
                seg_audio_len,
                model_dict,
                cfg,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_steps),
                batch_size=1,
            )
            wav = audio_batch[0].float().cpu().numpy()
            seg_wavs.append(wav)

        _log_inference_timing("HunyuanFoley", time.perf_counter() - _t_hny_start,
                              len(segments), int(num_steps), HUNYUAN_SECS_PER_STEP)
        results.append((seg_wavs, sr, text_feats))

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return results
|
|
|
|
def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
                     guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
    """HunyuanVideoFoley: text-guided foley, 48 kHz, up to 15 s.
    CPU pre/post-processing wraps the GPU-only inference to minimize ZeroGPU cost."""
    num_samples = int(num_samples)
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)

    # CPU prep: silent video + windowed segments.
    tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
        video_file, HUNYUAN_MAX_DUR, crossfade_s)
    print(f"[HunyuanFoley] Video={total_dur_s:.2f}s | {len(segments)} segment(s) Γ β€15 s")

    # Throwaway clip: used on the GPU side only to encode text features once.
    dummy_seg_path = _extract_segment_clip(
        silent_video, 0, min(total_dur_s, HUNYUAN_MAX_DUR),
        os.path.join(tmp_dir, "_seg_dummy.mp4"),
    )

    # One clip per window for visual-feature extraction.
    seg_clip_paths = []
    for idx, (start_s, end_s) in enumerate(segments):
        seg_clip_paths.append(_extract_segment_clip(
            silent_video, start_s, end_s - start_s,
            os.path.join(tmp_dir, f"hny_seg_{idx}.mp4")))

    _ctx_store("hunyuan_gpu_infer", {
        "segments": segments, "total_dur_s": total_dur_s,
        "dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
    })

    # GPU pass.
    results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
                                 guidance_scale, num_steps, model_size,
                                 crossfade_s, crossfade_db, num_samples)

    def _hunyuan_extras(sample_idx, result, td):
        # Persist text features so single-segment regens can skip re-encoding.
        _wavs, _sr, text_feats = result
        feats_path = os.path.join(td, f"hunyuan_{sample_idx}_text_feats.pt")
        torch.save(text_feats, feats_path)
        return {"text_feats_path": feats_path}

    outputs = _post_process_samples(
        results, model="hunyuan", tmp_dir=tmp_dir,
        silent_video=silent_video, segments=segments,
        crossfade_s=crossfade_s, crossfade_db=crossfade_db,
        total_dur_s=total_dur_s, sr=48000,
        extra_meta_fn=_hunyuan_extras,
    )
    return _pad_outputs(outputs)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
def _splice_and_save(new_wav, seg_idx, meta, slot_id):
    """Replace wavs[seg_idx] with new_wav, re-stitch, re-save, re-mux.
    Returns (video_path, audio_path, updated_meta, waveform_html).
    """
    crossfade_s = float(meta["crossfade_s"])
    crossfade_db = float(meta["crossfade_db"])
    sr = int(meta["sr"])
    total_dur_s = float(meta["total_dur_s"])
    silent_video = meta["silent_video"]
    segments = meta["segments"]
    model = meta["model"]

    wavs = _load_seg_wavs(meta["wav_paths"])
    wavs[seg_idx] = new_wav
    full_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, sr, segments)

    # Timestamped output names; strip any previous "_regen_<ts>" suffix first
    # so repeated regens don't grow the filename without bound.
    stamp = int(time.time() * 1000)
    out_dir = os.path.dirname(meta["audio_path"])
    audio_stem = os.path.splitext(os.path.basename(meta["audio_path"]))[0]
    audio_stem = audio_stem.rsplit("_regen_", 1)[0]
    audio_path = os.path.join(out_dir, f"{audio_stem}_regen_{stamp}.wav")
    _save_wav(audio_path, full_wav, sr)

    video_stem = os.path.splitext(os.path.basename(meta["video_path"]))[0]
    video_stem = video_stem.rsplit("_regen_", 1)[0]
    video_path = os.path.join(out_dir, f"{video_stem}_regen_{stamp}.mp4")
    mux_video_audio(silent_video, audio_path, video_path, model=model)

    # Persist the edited per-segment wavs and roll the metadata forward.
    updated_wav_paths = _save_seg_wavs(wavs, out_dir, os.path.splitext(audio_stem)[0])
    updated_meta = dict(meta)
    updated_meta["wav_paths"] = updated_wav_paths
    updated_meta["audio_path"] = audio_path
    updated_meta["video_path"] = video_path

    state_json_new = json.dumps(updated_meta)
    waveform_html = _build_waveform_html(audio_path, segments, slot_id, "",
                                         state_json=state_json_new,
                                         video_path=video_path,
                                         crossfade_s=crossfade_s)
    return video_path, audio_path, updated_meta, waveform_html
|
|
|
|
def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
                         seed_val, cfg_scale, num_steps, mode,
                         crossfade_s, crossfade_db, slot_id=None):
    """Pre-GPU duration estimate for a single-segment TARO regen.

    When cached CAVP + onset feature files still exist on disk, feature
    extraction is skipped on the GPU side, so only one segment's diffusion
    steps need to be billed; otherwise fall back to the generic estimate.
    """
    steps = int(num_steps)
    try:
        meta = json.loads(seg_meta_json)
        feats_cached = (os.path.exists(meta.get("cavp_path", ""))
                        and os.path.exists(meta.get("onset_path", "")))
        if feats_cached:
            secs = steps * MODEL_CONFIGS["taro"]["secs_per_step"] + 5
            capped = min(GPU_DURATION_CAP, max(30, int(secs)))
            print(f"[duration] TARO regen (cache hit): 1 seg Γ {steps} steps β {secs:.0f}s β capped {capped}s")
            return capped
    except Exception:
        # Malformed/missing metadata: fall through to the generic estimate.
        pass
    return _estimate_regen_duration("taro", steps)
|
|
|
|
@spaces.GPU(duration=_taro_regen_duration)
def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
                    seed_val, cfg_scale, num_steps, mode,
                    crossfade_s, crossfade_db, slot_id=None):
    """GPU-only TARO regen: returns new_wav for a single segment.

    Reuses on-disk CAVP/onset .npy features when the paths recorded in the
    slot metadata still exist; otherwise re-extracts them from the silent
    video. A fresh random seed is drawn every call, so regens are
    intentionally non-deterministic (seed_val is unused here).
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start_s, seg_end_s = meta["segments"][seg_idx]

    torch.set_grad_enabled(False)  # inference only
    device, weight_dtype = _get_device_and_dtype()

    _ensure_syspath("TARO")
    from TARO.samplers import euler_sampler, euler_maruyama_sampler

    # Feature cache: hit = load .npy from disk, miss = full re-extraction.
    cavp_path = meta.get("cavp_path", "")
    onset_path = meta.get("onset_path", "")
    if cavp_path and os.path.exists(cavp_path) and onset_path and os.path.exists(onset_path):
        print("[TARO regen] Loading cached CAVP + onset features from disk")
        cavp_feats = np.load(cavp_path)
        onset_feats = np.load(onset_path)
    else:
        print("[TARO regen] Cache miss β re-extracting CAVP + onset features")
        from TARO.onset_util import extract_onset
        extract_cavp, onset_model = _load_taro_feature_extractors(device)
        silent_video = meta["silent_video"]
        # NOTE(review): this mkdtemp is not passed through _register_tmp_dir
        # like the main generate path's tmp dir — possible temp-dir leak; confirm.
        tmp_dir = tempfile.mkdtemp()
        cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
        onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
        # Free the extractors before loading the diffusion stack.
        del extract_cavp, onset_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    model_net, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)

    # Fresh random seed each call: every regen produces different audio.
    set_global_seed(random.randint(0, 2**32 - 1))

    return _taro_infer_segment(
        model_net, vae, vocoder, cavp_feats, onset_feats,
        seg_start_s, seg_end_s, device, weight_dtype,
        float(cfg_scale), int(num_steps), mode, latents_scale,
        euler_sampler, euler_maruyama_sampler,
    )
|
|
|
|
def regen_taro_segment(video_file, seg_idx, seg_meta_json,
                       seed_val, cfg_scale, num_steps, mode,
                       crossfade_s, crossfade_db, slot_id):
    """Regenerate one TARO segment. GPU inference + CPU splice/save."""
    seg_idx = int(seg_idx)
    meta = json.loads(seg_meta_json)

    # GPU: synthesize just the requested segment at TARO's native rate.
    new_wav = _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
                              seed_val, cfg_scale, num_steps, mode,
                              crossfade_s, crossfade_db, slot_id)

    # CPU: upsample to the output rate, then splice into the stitched track.
    new_wav = _upsample_taro(new_wav)
    vid_path, aud_path, upd_meta, wf_html = _splice_and_save(
        new_wav, seg_idx, meta, slot_id)
    return vid_path, aud_path, json.dumps(upd_meta), wf_html
|
|
|
|
def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
                            prompt, negative_prompt, seed_val,
                            cfg_strength, num_steps, crossfade_s, crossfade_db,
                            slot_id=None):
    """Pre-GPU duration callable; parameter order must mirror _regen_mmaudio_gpu."""
    steps = int(num_steps)
    return _estimate_regen_duration("mmaudio", steps)
|
|
|
|
@spaces.GPU(duration=_mmaudio_regen_duration)
def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
                       prompt, negative_prompt, seed_val,
                       cfg_strength, num_steps, crossfade_s, crossfade_db,
                       slot_id=None):
    """GPU-only MMAudio regen: returns (new_wav, sr) for a single segment.

    Re-cuts the segment clip from the slot's silent video, then runs one
    flow-matching generation with a fresh random seed (regens are
    intentionally non-deterministic; seed_val is unused here).
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start, seg_end = meta["segments"][seg_idx]
    seg_dur = seg_end - seg_start

    _ensure_syspath("MMAudio")
    from mmaudio.eval_utils import generate, load_video
    from mmaudio.model.flow_matching import FlowMatching

    device, dtype = _get_device_and_dtype()

    net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
    sr = seq_cfg.sampling_rate

    # NOTE(review): mkdtemp here is not passed through _register_tmp_dir
    # (unlike the main generate path) — possible temp-dir leak; confirm.
    seg_path = _extract_segment_clip(
        meta["silent_video"], seg_start, seg_dur,
        os.path.join(tempfile.mkdtemp(), "regen_seg.mp4"),
    )

    # Fresh seed each call: every regen produces different audio.
    rng = torch.Generator(device=device)
    rng.manual_seed(random.randint(0, 2**32 - 1))

    fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=int(num_steps))
    video_info = load_video(seg_path, seg_dur)
    clip_frames = video_info.clip_frames.unsqueeze(0)  # add batch dim
    sync_frames = video_info.sync_frames.unsqueeze(0)  # add batch dim
    actual_dur = video_info.duration_sec
    # Resize model sequence lengths to the clip's measured duration.
    seq_cfg.duration = actual_dur
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    with torch.no_grad():
        audios = generate(
            clip_frames, sync_frames, [prompt],
            negative_text=[negative_prompt] if negative_prompt else None,
            feature_utils=feature_utils, net=net, fm=fm, rng=rng,
            cfg_strength=float(cfg_strength),
        )
    new_wav = audios.float().cpu()[0].numpy()
    return new_wav, sr
|
|
|
|
def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
                          prompt, negative_prompt, seed_val,
                          cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id):
    """Regenerate one MMAudio segment. GPU inference + CPU splice/save."""
    seg_idx = int(seg_idx)
    meta = json.loads(seg_meta_json)

    # GPU pass: one segment at MMAudio's native rate.
    new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
                                     prompt, negative_prompt, seed_val,
                                     cfg_strength, num_steps, crossfade_s, crossfade_db,
                                     slot_id)

    # CPU: bring the segment up to the app-wide target rate if needed.
    if sr != TARGET_SR:
        print(f"[MMAudio regen upsample] {sr}Hz β {TARGET_SR}Hz (sinc, CPU) β¦")
        new_wav = _resample_to_target(new_wav, sr)
        sr = TARGET_SR
    meta["sr"] = sr

    vid_path, aud_path, upd_meta, wf_html = _splice_and_save(
        new_wav, seg_idx, meta, slot_id)
    return vid_path, aud_path, json.dumps(upd_meta), wf_html
|
|
|
|
def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
                            prompt, negative_prompt, seed_val,
                            guidance_scale, num_steps, model_size,
                            crossfade_s, crossfade_db, slot_id=None):
    """Pre-GPU duration callable; parameter order must mirror _regen_hunyuan_gpu."""
    steps = int(num_steps)
    return _estimate_regen_duration("hunyuan", steps)
|
|
|
|
@spaces.GPU(duration=_hunyuan_regen_duration)
def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
                       prompt, negative_prompt, seed_val,
                       guidance_scale, num_steps, model_size,
                       crossfade_s, crossfade_db, slot_id=None):
    """GPU-only HunyuanFoley regen: returns (new_wav, sr) for a single segment.

    Reuses cached text features from disk when the slot metadata points at a
    valid .pt file (skipping text re-encoding); otherwise runs the full
    feature_process. A fresh random seed is drawn every call, so regens are
    intentionally non-deterministic (seed_val is unused here).
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start, seg_end = meta["segments"][seg_idx]
    seg_dur = seg_end - seg_start

    _ensure_syspath("HunyuanVideo-Foley")
    from hunyuanvideo_foley.utils.model_utils import denoise_process
    from hunyuanvideo_foley.utils.feature_utils import feature_process

    device, _ = _get_device_and_dtype()
    device = torch.device(device)
    # Normalize the size label before model lookup, matching _hunyuan_gpu_infer
    # (this path previously passed the raw UI value straight through).
    model_size = model_size.lower()
    model_dict, cfg = _load_hunyuan_model(device, model_size)

    # Fresh seed each call: every regen produces different audio.
    set_global_seed(random.randint(0, 2**32 - 1))

    # NOTE(review): mkdtemp here is not passed through _register_tmp_dir
    # (unlike the main generate path) — possible temp-dir leak; confirm.
    seg_path = _extract_segment_clip(
        meta["silent_video"], seg_start, seg_dur,
        os.path.join(tempfile.mkdtemp(), "regen_seg.mp4"),
    )

    text_feats_path = meta.get("text_feats_path", "")
    if text_feats_path and os.path.exists(text_feats_path):
        print("[HunyuanFoley regen] Loading cached text features from disk")
        from hunyuanvideo_foley.utils.feature_utils import encode_video_features
        visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
        # weights_only=False: the .pt stores feature objects saved by
        # generate_hunyuan's extras hook, not bare tensors.
        text_feats = torch.load(text_feats_path, map_location=device, weights_only=False)
    else:
        print("[HunyuanFoley regen] Cache miss β extracting text + visual features")
        visual_feats, text_feats, seg_audio_len = feature_process(
            seg_path, prompt if prompt else "", model_dict, cfg,
            neg_prompt=negative_prompt if negative_prompt else None,
        )
    audio_batch, sr = denoise_process(
        visual_feats, text_feats, seg_audio_len, model_dict, cfg,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(num_steps),
        batch_size=1,
    )
    new_wav = audio_batch[0].float().cpu().numpy()
    return new_wav, sr
|
|
|
|
def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
                          prompt, negative_prompt, seed_val,
                          guidance_scale, num_steps, model_size,
                          crossfade_s, crossfade_db, slot_id):
    """Regenerate one HunyuanFoley segment. GPU inference + CPU splice/save."""
    seg_idx = int(seg_idx)
    meta = json.loads(seg_meta_json)

    # GPU pass for the single segment.
    new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
                                     prompt, negative_prompt, seed_val,
                                     guidance_scale, num_steps, model_size,
                                     crossfade_s, crossfade_db, slot_id)

    # Record the model's native rate before splicing.
    meta["sr"] = sr

    vid_path, aud_path, upd_meta, wf_html = _splice_and_save(
        new_wav, seg_idx, meta, slot_id)
    return vid_path, aud_path, json.dumps(upd_meta), wf_html
|
|
|
|
| |
# Register the per-model regen callbacks (defined above) into the shared
# MODEL_CONFIGS table so generic UI wiring can dispatch by model key.
MODEL_CONFIGS["taro"]["regen_fn"] = regen_taro_segment
MODEL_CONFIGS["mmaudio"]["regen_fn"] = regen_mmaudio_segment
MODEL_CONFIGS["hunyuan"]["regen_fn"] = regen_hunyuan_segment
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def _resample_to_slot_sr(wav: np.ndarray, src_sr: int, dst_sr: int,
                         slot_wav_ref: np.ndarray = None) -> np.ndarray:
    """Resample *wav* from src_sr to dst_sr and match the channel layout of
    *slot_wav_ref* (the first existing segment wav in the slot).

    TARO emits mono (T,); MMAudio/Hunyuan emit stereo (C, T). Mixing layouts
    breaks _cf_join with a shape mismatch, so:
      - mono wav into a stereo slot: duplicate the single channel
      - stereo wav into a mono slot: average the channels
    """
    out = _resample_to_target(wav, src_sr, dst_sr)
    if slot_wav_ref is None:
        return out

    ref_is_stereo = slot_wav_ref.ndim == 2
    out_is_stereo = out.ndim == 2
    if ref_is_stereo and not out_is_stereo:
        out = np.stack([out, out], axis=0)   # mono -> stereo
    elif out_is_stereo and not ref_is_stereo:
        out = out.mean(axis=0)               # stereo -> mono
    return out
|
|
|
|
| def _xregen_clip_window(meta: dict, seg_idx: int, target_window_s: float) -> tuple: |
| """Compute the video clip window for a cross-model regen. |
| |
| Centers *target_window_s* on the original segment's midpoint, clamped to |
| [0, total_dur_s]. Returns (clip_start, clip_end, clip_dur). |
| |
| If the video is shorter than *target_window_s*, the full video is used |
| (suboptimal but never breaks). If the segment span exceeds |
| *target_window_s*, the caller should run _build_segments on the span and |
| generate multiple sub-segments β but the clip window is still returned as |
| the full segment span so the caller can decide. |
| """ |
| total_dur_s = float(meta["total_dur_s"]) |
| seg_start, seg_end = meta["segments"][seg_idx] |
| seg_mid = (seg_start + seg_end) / 2.0 |
| half_win = target_window_s / 2.0 |
|
|
| clip_start = max(0.0, seg_mid - half_win) |
| clip_end = min(total_dur_s, seg_mid + half_win) |
| |
| if clip_start == 0.0: |
| clip_end = min(total_dur_s, target_window_s) |
| elif clip_end == total_dur_s: |
| clip_start = max(0.0, total_dur_s - target_window_s) |
| clip_dur = clip_end - clip_start |
| return clip_start, clip_end, clip_dur |
|
|
|
|
def _xregen_splice(new_wav_raw: np.ndarray, src_sr: int,
                   meta: dict, seg_idx: int, slot_id: str,
                   clip_start_s: float = None) -> tuple:
    """Shared epilogue for all xregen_* functions: resample β splice β save.
    Returns (video_path, waveform_html).

    *clip_start_s* is the absolute video time where new_wav_raw starts.
    When the clip was centered on the segment midpoint (not at seg_start),
    we need to shift the wav so _stitch_wavs can trim it correctly relative
    to the original segment's start. We do this by prepending silence so
    the wav's time origin aligns with the original segment's start.

    NOTE(review): only offset_s < 0 (clip starting *after* seg_start) is
    handled, via silence padding; when the clip starts *before* seg_start
    (offset_s > 0, the common centered case) no leading trim happens here —
    presumably _stitch_wavs tolerates the extra leading audio; confirm.
    """
    slot_sr = int(meta["sr"])
    slot_wavs = _load_seg_wavs(meta["wav_paths"])
    # Match sample rate and channel layout to the slot's existing audio.
    new_wav = _resample_to_slot_sr(new_wav_raw, src_sr, slot_sr, slot_wavs[0])

    if clip_start_s is not None:
        seg_start = meta["segments"][seg_idx][0]
        offset_s = seg_start - clip_start_s
        if offset_s < 0:
            # Clip began after the segment start: prepend |offset_s| of
            # silence so sample 0 of the wav corresponds to seg_start.
            pad_samples = int(round(abs(offset_s) * slot_sr))
            silence = np.zeros(
                (new_wav.shape[0], pad_samples) if new_wav.ndim == 2 else pad_samples,
                dtype=new_wav.dtype,
            )
            # axis=1 pads along time for stereo (C, T); axis=0 for mono (T,).
            new_wav = np.concatenate([silence, new_wav], axis=1 if new_wav.ndim == 2 else 0)

    video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
        new_wav, seg_idx, meta, slot_id
    )
    return video_path, waveform_html
|
|
|
|
def _xregen_dispatch(state_json: str, seg_idx: int, slot_id: str, infer_fn):
    """Generator skeleton shared by every xregen_* wrapper.

    First yield: (no-op video update, pending HTML) shown while the GPU runs.
    Second yield: (regenerated video, refreshed waveform HTML).

    *infer_fn* is a zero-argument callable doing model-specific CPU prep +
    GPU inference; it returns (wav_array, src_sr, clip_start_s). TARO callers
    should return audio already upsampled to 48 kHz with TARO_SR_OUT as src_sr.
    """
    meta = json.loads(state_json)
    yield gr.update(), gr.update(
        value=_build_regen_pending_html(meta["segments"], seg_idx, slot_id, ""))

    wav_raw, src_sr, clip_start_s = infer_fn()
    vid_path, wave_html = _xregen_splice(wav_raw, src_sr, meta, seg_idx, slot_id, clip_start_s)
    yield gr.update(value=vid_path), gr.update(value=wave_html)
|
|
|
|
def xregen_taro(seg_idx, state_json, slot_id,
                seed_val, cfg_scale, num_steps, mode,
                crossfade_s, crossfade_db,
                request: gr.Request = None):
    """Cross-model regen: run TARO on its optimal window, splice into *slot_id*.

    Yields twice via _xregen_dispatch: pending HTML first, then the final
    video + waveform.
    """
    seg_idx = int(seg_idx)
    meta = json.loads(state_json)

    def _run():
        # Center TARO's native window on the target segment and cut the clip.
        clip_start, clip_end, clip_dur = _xregen_clip_window(meta, seg_idx, TARO_MODEL_DUR)
        tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
        clip_path = _extract_segment_clip(
            meta["silent_video"], clip_start, clip_dur,
            os.path.join(tmp_dir, "xregen_taro_clip.mp4"),
        )
        sub_segs = _build_segments(clip_dur, TARO_MODEL_DUR, float(crossfade_s))
        # Hand the clip context to the GPU worker via the same channel the
        # main generate path uses. (A dead sub_meta_json json.dumps that was
        # never read has been removed.)
        _ctx_store("taro_gpu_infer", {
            "tmp_dir": tmp_dir, "silent_video": clip_path,
            "segments": sub_segs, "total_dur_s": clip_dur,
        })
        results = _taro_gpu_infer(clip_path, seed_val, cfg_scale, num_steps, mode,
                                  crossfade_s, crossfade_db, 1)
        wavs, _, _ = results[0]
        # Upsample 16 kHz -> 48 kHz, then stitch sub-segments into one wav.
        wavs = [_upsample_taro(w) for w in wavs]
        wav = _stitch_wavs(wavs, float(crossfade_s), float(crossfade_db),
                           clip_dur, TARO_SR_OUT, sub_segs)
        return wav, TARO_SR_OUT, clip_start

    yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
|
|
|
|
def xregen_mmaudio(seg_idx, state_json, slot_id,
                   prompt, negative_prompt, seed_val,
                   cfg_strength, num_steps, crossfade_s, crossfade_db,
                   request: gr.Request = None):
    """Cross-model regen: run MMAudio on its optimal window, splice into *slot_id*."""
    seg_idx = int(seg_idx)
    meta = json.loads(state_json)

    def _infer():
        # Center MMAudio's native window on the target segment.
        clip_start, _clip_end, clip_dur = _xregen_clip_window(meta, seg_idx, MMAUDIO_WINDOW)
        work_dir = _register_tmp_dir(tempfile.mkdtemp())
        clip_path = _extract_segment_clip(
            meta["silent_video"], clip_start, clip_dur,
            os.path.join(work_dir, "xregen_mmaudio_clip.mp4"),
        )
        sub_segs = _build_segments(clip_dur, MMAUDIO_WINDOW, float(crossfade_s))
        sub_clips = []
        for i, (s, e) in enumerate(sub_segs):
            sub_clips.append(_extract_segment_clip(
                clip_path, s, e - s,
                os.path.join(work_dir, f"xregen_mma_sub_{i}.mp4")))
        _ctx_store("mmaudio_gpu_infer",
                   {"segments": sub_segs, "seg_clip_paths": sub_clips})
        (seg_wavs, sr), = _mmaudio_gpu_infer(clip_path, prompt, negative_prompt, seed_val,
                                             cfg_strength, num_steps, crossfade_s,
                                             crossfade_db, 1)
        stitched = _stitch_wavs(seg_wavs, float(crossfade_s), float(crossfade_db),
                                clip_dur, sr, sub_segs)
        # Normalize to the app-wide target rate before splicing.
        if sr != TARGET_SR:
            stitched = _resample_to_target(stitched, sr)
            sr = TARGET_SR
        return stitched, sr, clip_start

    yield from _xregen_dispatch(state_json, seg_idx, slot_id, _infer)
|
|
|
|
def xregen_hunyuan(seg_idx, state_json, slot_id,
                   prompt, negative_prompt, seed_val,
                   guidance_scale, num_steps, model_size,
                   crossfade_s, crossfade_db,
                   request: gr.Request = None):
    """Cross-model regen: run HunyuanFoley on its optimal window, splice into *slot_id*."""
    seg_idx = int(seg_idx)
    meta = json.loads(state_json)

    def _infer():
        # Center Hunyuan's native window on the target segment.
        clip_start, _clip_end, clip_dur = _xregen_clip_window(meta, seg_idx, HUNYUAN_MAX_DUR)
        work_dir = _register_tmp_dir(tempfile.mkdtemp())
        clip_path = _extract_segment_clip(
            meta["silent_video"], clip_start, clip_dur,
            os.path.join(work_dir, "xregen_hunyuan_clip.mp4"),
        )
        sub_segs = _build_segments(clip_dur, HUNYUAN_MAX_DUR, float(crossfade_s))
        sub_clips = []
        for i, (s, e) in enumerate(sub_segs):
            sub_clips.append(_extract_segment_clip(
                clip_path, s, e - s,
                os.path.join(work_dir, f"xregen_hny_sub_{i}.mp4")))
        # Dummy clip feeds the one-off text-feature encoding on the GPU side.
        dummy_path = _extract_segment_clip(
            clip_path, 0, min(clip_dur, HUNYUAN_MAX_DUR),
            os.path.join(work_dir, "xregen_hny_dummy.mp4"),
        )
        _ctx_store("hunyuan_gpu_infer", {
            "segments": sub_segs, "total_dur_s": clip_dur,
            "dummy_seg_path": dummy_path, "seg_clip_paths": sub_clips,
        })
        (seg_wavs, sr, _feats), = _hunyuan_gpu_infer(clip_path, prompt, negative_prompt, seed_val,
                                                     guidance_scale, num_steps, model_size,
                                                     crossfade_s, crossfade_db, 1)
        stitched = _stitch_wavs(seg_wavs, float(crossfade_s), float(crossfade_db),
                                clip_dur, sr, sub_segs)
        return stitched, sr, clip_start

    yield from _xregen_dispatch(state_json, seg_idx, slot_id, _infer)
|
|
|
|
| |
| |
| |
|
|
def _register_regen_handlers(tab_prefix, model_key, regen_seg_tb, regen_state_tb,
                             input_components, slot_vids, slot_waves):
    """Register per-slot regen button handlers for a model tab.

    This replaces the three nearly-identical for-loops that previously existed
    for TARO, MMAudio, and HunyuanFoley tabs.

    Args:
        tab_prefix: e.g. "taro", "mma", "hf"
        model_key: e.g. "taro", "mmaudio", "hunyuan" — key into MODEL_CONFIGS
        regen_seg_tb: gr.Textbox for seg_idx (render=False)
        regen_state_tb: gr.Textbox for state_json (render=False)
        input_components: list of Gradio input components (video, seed, etc.)
            — order must match regen_fn signature after (seg_idx, state_json, video)
        slot_vids: list of gr.Video components per slot
        slot_waves: list of gr.HTML components per slot

    Returns:
        list of hidden gr.Buttons (one per slot)
    """
    cfg = MODEL_CONFIGS[model_key]
    regen_fn = cfg["regen_fn"]
    label = cfg["label"]
    btns = []
    for _i in range(MAX_SLOTS):
        _slot_id = f"{tab_prefix}_{_i}"
        _btn = gr.Button(render=False, elem_id=f"regen_btn_{_slot_id}")
        btns.append(_btn)
        print(f"[startup] registering regen handler for slot {_slot_id}")

        # Factory binds the per-slot values at definition time so the handler
        # closure does not capture the loop variables late.  (The previous
        # version also threaded the slot index and model_key through here,
        # but neither was ever used inside the handler.)
        def _make_regen(_sid, _label, _regen_fn):
            def _do(seg_idx, state_json, *args):
                """Per-slot regen generator: yields pending HTML, then result."""
                print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} "
                      f"state_json_len={len(state_json) if state_json else 0}")
                if not state_json:
                    # Nothing generated for this slot yet — emit no-op updates.
                    print(f"[regen {_label}] early-exit: state_json empty")
                    yield gr.update(), gr.update()
                    return
                # Serialize regens per slot so concurrent clicks cannot splice
                # into the same slot's audio simultaneously.
                lock = _get_slot_lock(_sid)
                with lock:
                    state = json.loads(state_json)
                    pending_html = _build_regen_pending_html(
                        state["segments"], int(seg_idx), _sid, ""
                    )
                    # First yield: leave the video untouched, show the spinner.
                    yield gr.update(), gr.update(value=pending_html)
                    print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} → calling regen")
                    try:
                        # regen_fn signature: (video, seg_idx, state_json,
                        # *model_params, slot_id); args[0] is the video input.
                        vid, aud, new_meta_json, html = _regen_fn(
                            args[0], int(seg_idx), state_json, *args[1:], _sid,
                        )
                        print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} → done, vid={vid!r}")
                    except Exception as _e:
                        print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} → ERROR: {_e}")
                        raise
                    yield gr.update(value=vid), gr.update(value=html)
            return _do

        _btn.click(
            fn=_make_regen(_slot_id, label, regen_fn),
            inputs=[regen_seg_tb, regen_state_tb] + input_components,
            outputs=[slot_vids[_i], slot_waves[_i]],
            api_name=f"regen_{tab_prefix}_{_i}",
        )
    return btns
|
|
|
|
def _pad_outputs(outputs: list) -> list:
    """Flatten (video, audio, seg_meta) triples into one list, padded with None.

    The returned list always has MAX_SLOTS * 3 entries; missing slots are
    filled with (None, None, None).  Each entry in *outputs* must be a
    (video_path, audio_path, seg_meta) tuple where
    seg_meta = {"segments": [...], "audio_path": str, "video_path": str,
                "sr": int, "model": str, "crossfade_s": float,
                "crossfade_db": float, "wav_paths": list[str]}
    """
    flat: list = []
    for slot in range(MAX_SLOTS):
        triple = outputs[slot] if slot < len(outputs) else (None, None, None)
        flat.extend(triple)
    return flat
|
|
|
|
| |
| |
| |
|
|
def _build_regen_pending_html(segments: list, regen_seg_idx: int, slot_id: str,
                              hidden_input_id: str) -> str:
    """Return a waveform placeholder shown while a segment is being regenerated.

    Renders a dark bar with the active segment highlighted in amber + a spinner.

    Args:
        segments: List of (start, end) second pairs for the slot's audio.
        regen_seg_idx: Index of the segment being regenerated (highlighted).
        slot_id: Slot identifier (not referenced in the markup; kept for
            call-site symmetry with _build_waveform_html).
        hidden_input_id: Unused; kept for call-site compatibility.

    Returns:
        A self-contained HTML string (keyframe styles + segment bar + spinner).
    """
    seg_colors = [c.format(a="0.25") for c in SEG_COLORS]
    active_color = "rgba(255,180,0,0.55)"
    duration = segments[-1][1] if segments else 1.0
    if duration <= 0:
        # Degenerate segment list (zero-length audio) — avoid division by zero.
        duration = 1.0

    seg_divs = ""
    for i, seg in enumerate(segments):
        # Each segment visually owns [its start, next segment's start); the
        # last segment runs to its own end.
        seg_start = seg[0]
        seg_end = segments[i + 1][0] if i + 1 < len(segments) else seg[1]
        left_pct = seg_start / duration * 100
        width_pct = (seg_end - seg_start) / duration * 100
        color = active_color if i == regen_seg_idx else seg_colors[i % len(seg_colors)]
        extra = "border:2px solid #ffb300;animation:wf_pulse 0.8s ease-in-out infinite alternate;" if i == regen_seg_idx else ""
        seg_divs += (
            f'<div style="position:absolute;top:0;left:{left_pct:.2f}%;'
            f'width:{width_pct:.2f}%;height:100%;background:{color};{extra}">'
            f'<span style="color:rgba(255,255,255,0.7);font-size:10px;padding:2px 3px;">Seg {i+1}</span>'
            f'</div>'
        )

    spinner = (
        '<div style="position:absolute;top:50%;left:50%;transform:translate(-50%,-50%);'
        'display:flex;align-items:center;gap:6px;">'
        '<div style="width:14px;height:14px;border:2px solid #ffb300;'
        'border-top-color:transparent;border-radius:50%;'
        'animation:wf_spin 0.7s linear infinite;"></div>'
        f'<span style="color:#ffb300;font-size:12px;white-space:nowrap;">'
        f'Regenerating Seg {regen_seg_idx+1}…</span>'
        '</div>'
    )

    return f"""
<style>
@keyframes wf_pulse {{from{{opacity:0.5}}to{{opacity:1}}}}
@keyframes wf_spin {{to{{transform:rotate(360deg)}}}}
</style>
<div style="background:#1a1a1a;border-radius:8px;padding:10px;margin-top:6px;">
  <div style="position:relative;width:100%;height:80px;background:#1e1e2e;border-radius:4px;overflow:hidden;">
    {seg_divs}
    {spinner}
  </div>
  <div style="color:#888;font-size:11px;margin-top:6px;">Regenerating — please wait…</div>
</div>
"""
|
|
|
|
def _build_waveform_html(audio_path: str, segments: list, slot_id: str,
                         hidden_input_id: str, state_json: str = "",
                         fn_index: int = -1, video_path: str = "",
                         crossfade_s: float = 0.0) -> str:
    """Return a self-contained HTML block with a Canvas waveform (display only),
    segment boundary markers, and a download link.

    Uses Web Audio API + Canvas — no external libraries.
    The waveform is SILENT. The playhead tracks the Gradio <video> element
    in the same slot via its timeupdate event.

    Args:
        audio_path: Local path of the rendered audio; served to the iframe via
            the Gradio file API.  If missing/empty, a placeholder <p> is returned.
        segments: List of (start, end) second pairs, one per generated segment.
        slot_id: e.g. "taro_0" — used to build unique DOM element ids.
        hidden_input_id: Unused in this implementation; kept for call-site
            compatibility (NOTE(review): candidate for removal — confirm callers).
        state_json: Slot metadata JSON; embedded attribute-escaped in the
            container's data-state attribute so page JS can read it for regen.
        fn_index: Embedded in data-fn-index (read by the page-level JS).
        video_path: Optional muxed-video path; when set, adds a video
            download link next to the audio one.
        crossfade_s: Crossfade duration in seconds; drawn by the iframe JS as
            hatched overlap zones between segment colors.
    """
    if not audio_path or not os.path.exists(audio_path):
        return "<p style='color:#888;font-size:12px'>No audio yet.</p>"

    # The iframe fetches and decodes this URL itself (Gradio file API route).
    audio_url = f"/gradio_api/file={audio_path}"

    segs_json = json.dumps(segments)
    seg_colors = [c.format(a="0.35") for c in SEG_COLORS]

    # The waveform lives in a sandboxed iframe (srcdoc) so its script cannot
    # clash with Gradio's own runtime; it talks to the parent page only via
    # postMessage (popup show/hide) and read-only DOM polling (playhead).
    iframe_inner = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
  * {{ margin:0; padding:0; box-sizing:border-box; }}
  body {{ background:#1a1a1a; overflow:hidden; }}
  #wrap {{ position:relative; width:100%; height:80px; }}
  canvas {{ display:block; }}
  #cv {{ position:absolute; top:0; left:0; width:100%; height:100%; }}
  #cvp {{ position:absolute; top:0; left:0; width:100%; height:100%; pointer-events:none; }}
</style>
</head>
<body>
<div id="wrap">
  <canvas id="cv"></canvas>
  <canvas id="cvp"></canvas>
</div>
<script>
(function() {{
  const SLOT_ID = '{slot_id}';
  const segments = {segs_json};
  const segColors = {json.dumps(seg_colors)};
  const crossfadeSec = {crossfade_s};
  let audioDuration = 0;

  // ── Popup via postMessage to parent global listener ─────────────────
  // The parent page (Gradio) has a global window.addEventListener('message',...)
  // set up via gr.Blocks(js=...) that handles popup show/hide and regen trigger.
  function showPopup(idx, mx, my) {{
    console.log('[wf showPopup] slot='+SLOT_ID+' idx='+idx+' posting to parent');
    // Convert iframe-local coords to parent page coords
    try {{
      const fr = window.frameElement ? window.frameElement.getBoundingClientRect() : {{left:0,top:0}};
      window.parent.postMessage({{
        type:'wf_popup', action:'show',
        slot_id: SLOT_ID, seg_idx: idx,
        t0: segments[idx][0], t1: segments[idx][1],
        x: mx + fr.left, y: my + fr.top
      }}, '*');
      console.log('[wf showPopup] postMessage sent OK');
    }} catch(e) {{
      console.log('[wf showPopup] postMessage fallback, err='+e.message);
      window.parent.postMessage({{
        type:'wf_popup', action:'show',
        slot_id: SLOT_ID, seg_idx: idx,
        t0: segments[idx][0], t1: segments[idx][1],
        x: mx, y: my
      }}, '*');
    }}
  }}
  function hidePopup() {{
    window.parent.postMessage({{type:'wf_popup', action:'hide'}}, '*');
  }}

  // ── Canvas waveform ────────────────────────────────────────────────
  const cv = document.getElementById('cv');
  const cvp = document.getElementById('cvp');
  const wrap= document.getElementById('wrap');

  function drawWaveform(channelData, duration) {{
    audioDuration = duration;
    const dpr = window.devicePixelRatio || 1;
    const W = wrap.getBoundingClientRect().width || window.innerWidth || 600;
    const H = 80;
    cv.width = W * dpr; cv.height = H * dpr;
    const ctx = cv.getContext('2d');
    ctx.scale(dpr, dpr);

    ctx.fillStyle = '#1e1e2e';
    ctx.fillRect(0, 0, W, H);

    segments.forEach(function(seg, idx) {{
      // Color boundary = midpoint of the crossfade zone = where the blend is
      // 50/50. This is also where the cut would land if crossfade were 0, and
      // where the listener perceptually hears the transition to the next segment.
      const x1 = (seg[0] / duration) * W;
      const xEnd = idx + 1 < segments.length
        ? ((segments[idx + 1][0] + crossfadeSec / 2) / duration) * W
        : (seg[1] / duration) * W;
      ctx.fillStyle = segColors[idx % segColors.length];
      ctx.fillRect(x1, 0, xEnd - x1, H);
      ctx.fillStyle = 'rgba(255,255,255,0.6)';
      ctx.font = '10px sans-serif';
      ctx.fillText('Seg '+(idx+1), x1+3, 12);
    }});

    const samples = channelData.length;
    const barW=2, gap=1, step=barW+gap;
    const numBars = Math.floor(W / step);
    const blockSz = Math.floor(samples / numBars);
    ctx.fillStyle = '#4a9eff';
    for (let i=0; i<numBars; i++) {{
      let max=0;
      const s=i*blockSz;
      for (let j=0; j<blockSz; j++) {{
        const v=Math.abs(channelData[s+j]||0);
        if (v>max) max=v;
      }}
      const barH=Math.max(1, max*H);
      ctx.fillRect(i*step, (H-barH)/2, barW, barH);
    }}

    segments.forEach(function(seg) {{
      [seg[0],seg[1]].forEach(function(t) {{
        const x=(t/duration)*W;
        ctx.strokeStyle='rgba(255,255,255,0.4)';
        ctx.lineWidth=1;
        ctx.beginPath(); ctx.moveTo(x,0); ctx.lineTo(x,H); ctx.stroke();
      }});
    }});

    // ── Crossfade overlap indicators ──
    // The color boundary is at segments[i+1][0] (= seg_i.end - crossfadeSec).
    // We centre the hatch on that edge: half the crossfade on each color side.
    if (crossfadeSec > 0 && segments.length > 1) {{
      for (let i = 0; i < segments.length - 1; i++) {{
        // Color edge = segments[i+1][0], hatch spans half on each side
        const edgeT = segments[i+1][0];
        const overlapStart = edgeT - crossfadeSec / 2;
        const overlapEnd = edgeT + crossfadeSec / 2;
        const xL = (overlapStart / duration) * W;
        const xR = (overlapEnd / duration) * W;
        // Diagonal hatch pattern over the overlap zone
        ctx.save();
        ctx.beginPath();
        ctx.rect(xL, 0, xR - xL, H);
        ctx.clip();
        ctx.strokeStyle = 'rgba(255,255,255,0.35)';
        ctx.lineWidth = 1;
        const spacing = 6;
        for (let lx = xL - H; lx < xR + H; lx += spacing) {{
          ctx.beginPath();
          ctx.moveTo(lx, H);
          ctx.lineTo(lx + H, 0);
          ctx.stroke();
        }}
        ctx.restore();
      }}
    }}

    cv.onclick = function(e) {{
      const r=cv.getBoundingClientRect();
      const xRel=(e.clientX-r.left)/r.width;
      const tClick=xRel*duration;
      // Pick the segment whose unique (non-overlapping) region contains the click.
      // Each segment owns [seg[0], nextSeg[0]) visually; last segment owns [seg[0], seg[1]].
      let hit=-1;
      segments.forEach(function(seg,idx){{
        const uniqueEnd = idx + 1 < segments.length ? segments[idx+1][0] : seg[1];
        if (tClick >= seg[0] && tClick < uniqueEnd) hit = idx;
      }});
      console.log('[wf click] tClick='+tClick.toFixed(2)+' hit='+hit+' audioDuration='+audioDuration+' segments='+JSON.stringify(segments));
      if (hit>=0) showPopup(hit, e.clientX, e.clientY);
      else hidePopup();
    }};
  }}

  function drawPlayhead(progress) {{
    const dpr = window.devicePixelRatio || 1;
    const W = wrap.getBoundingClientRect().width || window.innerWidth || 600;
    const H = 80;
    if (cvp.width !== W*dpr) {{ cvp.width=W*dpr; cvp.height=H*dpr; }}
    const ctx = cvp.getContext('2d');
    ctx.clearRect(0,0,W*dpr,H*dpr);
    ctx.save();
    ctx.scale(dpr,dpr);
    const x=progress*W;
    ctx.strokeStyle='#fff';
    ctx.lineWidth=2;
    ctx.beginPath(); ctx.moveTo(x,0); ctx.lineTo(x,H); ctx.stroke();
    ctx.restore();
  }}

  // Poll parent for video time — find the video in the same wf_container slot
  function findSlotVideo() {{
    try {{
      const par = window.parent.document;
      // Walk up from our iframe to find wf_container_{slot_id}, then find its sibling video
      const container = par.getElementById('wf_container_{slot_id}');
      if (!container) return par.querySelector('video');
      // The video is inside the same gr.Group — walk up to find it
      let node = container.parentElement;
      while (node && node !== par.body) {{
        const v = node.querySelector('video');
        if (v) return v;
        node = node.parentElement;
      }}
      return null;
    }} catch(e) {{ return null; }}
  }}

  setInterval(function() {{
    const vid = findSlotVideo();
    if (vid && vid.duration && isFinite(vid.duration) && audioDuration > 0) {{
      drawPlayhead(vid.currentTime / vid.duration);
    }}
  }}, 50);

  // ── Fetch + decode audio from Gradio file API ──────────────────────
  const audioUrl = '{audio_url}';
  fetch(audioUrl)
    .then(function(r) {{ return r.arrayBuffer(); }})
    .then(function(arrayBuf) {{
      const AudioCtx = window.AudioContext || window.webkitAudioContext;
      if (!AudioCtx) return;
      const tmpCtx = new AudioCtx({{sampleRate:44100}});
      tmpCtx.decodeAudioData(arrayBuf,
        function(ab) {{
          try {{ tmpCtx.close(); }} catch(e) {{}}
          function tryDraw() {{
            const W = wrap.getBoundingClientRect().width || window.innerWidth;
            if (W > 0) {{ drawWaveform(ab.getChannelData(0), ab.duration); }}
            else {{ setTimeout(tryDraw, 100); }}
          }}
          tryDraw();
        }},
        function(err) {{}}
      );
    }})
    .catch(function(e) {{}});
}})();
</script>
</body>
</html>"""

    # srcdoc content must be attribute-escaped; state_json likewise rides in
    # the container's data-state attribute for the page-level regen JS.
    srcdoc = _html.escape(iframe_inner, quote=True)
    state_escaped = _html.escape(state_json or "", quote=True)

    # Outer container: iframe + status bar + download link(s) + segment label.
    return f"""
<div id="wf_container_{slot_id}"
     data-fn-index="{fn_index}"
     data-state="{state_escaped}"
     style="background:#1a1a1a;border-radius:8px;padding:10px;margin-top:6px;position:relative;">
  <div style="position:relative;width:100%;height:80px;">
    <iframe id="wf_iframe_{slot_id}"
            srcdoc="{srcdoc}"
            sandbox="allow-scripts allow-same-origin"
            style="width:100%;height:80px;border:none;border-radius:4px;display:block;"
            scrolling="no"></iframe>
  </div>
  <div style="display:flex;align-items:center;gap:8px;margin-top:6px;">
    <span id="wf_statusbar_{slot_id}" style="color:#888;font-size:11px;">Click a segment to regenerate | Playhead syncs to video</span>
    <a href="{audio_url}" download="audio_{slot_id}.wav"
       style="margin-left:auto;background:#333;color:#eee;border:1px solid #555;
              border-radius:4px;padding:3px 10px;font-size:12px;text-decoration:none;">
       ↓ Audio</a>{f'''
    <a href="/gradio_api/file={video_path}" download="video_{slot_id}.mp4"
       style="background:#333;color:#eee;border:1px solid #555;
              border-radius:4px;padding:3px 10px;font-size:12px;text-decoration:none;">
       ↓ Video</a>''' if video_path else ''}
  </div>
  <div id="wf_seglabel_{slot_id}"
       style="color:#aaa;font-size:11px;margin-top:4px;min-height:16px;"></div>
</div>
"""
|
|
|
|
def _make_output_slots(tab_prefix: str) -> tuple:
    """Create MAX_SLOTS output groups (video + waveform HTML) for one tab.

    Regen is triggered via direct Gradio queue API calls from JS (no hidden
    trigger textboxes needed — DOM event dispatch is unreliable in Gradio 5
    Svelte components). State JSON is embedded in the waveform HTML's
    data-state attribute and passed directly in the queue API payload.

    Returns:
        (grps, vids, waveforms) — parallel lists of gr.Group, gr.Video, and
        gr.HTML components, one entry per slot.
    """
    groups: list = []
    videos: list = []
    wave_panels: list = []
    for idx in range(MAX_SLOTS):
        sid = f"{tab_prefix}_{idx}"
        # Only the first slot starts visible; visibility of the rest is
        # toggled separately based on the requested sample count.
        with gr.Group(visible=(idx == 0)) as grp:
            video = gr.Video(label=f"Generation {idx+1} — Video",
                             elem_id=f"slot_vid_{sid}",
                             show_download_button=False)
            panel = gr.HTML(
                value="<p style='color:#888;font-size:12px'>Generate audio to see waveform.</p>",
                elem_id=f"slot_wave_{sid}",
            )
        groups.append(grp)
        videos.append(video)
        wave_panels.append(panel)
    return groups, videos, wave_panels
|
|
|
|
def _unpack_outputs(flat: list, n: int, tab_prefix: str) -> list:
    """Turn a flat _pad_outputs list into Gradio update lists.

    *flat* has MAX_SLOTS * 3 items: [vid0, aud0, meta0, vid1, aud1, meta1, ...].
    Returns updates for vids + waveforms only (NOT grps): group visibility is
    handled separately via .then() to avoid Gradio 5 SSR 'Too many arguments'
    caused by mixing gr.Group updates with other outputs.  State JSON is
    embedded in the waveform HTML data-state attribute so JS can read it when
    calling the Gradio queue API for regen.
    """
    # *n* is validated for interface parity with the event wiring; slot count
    # itself is not used here (visibility is a separate handler).
    n = int(n)
    placeholder = "<p style='color:#888;font-size:12px'>Generate audio to see waveform.</p>"
    vid_updates: list = []
    wave_updates: list = []
    for slot in range(MAX_SLOTS):
        vid_path, aud_path, meta = flat[slot * 3], flat[slot * 3 + 1], flat[slot * 3 + 2]
        vid_updates.append(gr.update(value=vid_path))
        if not (aud_path and meta):
            wave_updates.append(gr.update(value=placeholder))
            continue
        sid = f"{tab_prefix}_{slot}"
        wave_html = _build_waveform_html(aud_path, meta["segments"], sid,
                                         "", state_json=json.dumps(meta),
                                         video_path=meta.get("video_path", ""),
                                         crossfade_s=float(meta.get("crossfade_s", 0)))
        wave_updates.append(gr.update(value=wave_html))
    return vid_updates + wave_updates
|
|
|
|
def _on_video_upload_taro(video_file, num_steps, crossfade_s):
    """Rebound the TARO sample-count control when a video is (un)loaded.

    Returns a gr.update for the slider: maximum is the per-video sample cap
    computed from duration/steps/crossfade, falling back to MAX_SLOTS when no
    video is set or probing fails.
    """
    if video_file is None:
        return gr.update(maximum=MAX_SLOTS, value=1)
    try:
        duration = get_video_duration(video_file)
        limit = _taro_calc_max_samples(duration, int(num_steps), float(crossfade_s))
    except Exception:
        # Best-effort: keep the UI usable even if the video can't be probed.
        limit = MAX_SLOTS
    return gr.update(maximum=limit, value=min(1, limit))
|
|
|
|
def _update_slot_visibility(n):
    """Show the first *n* slot groups and hide the remaining ones."""
    visible_count = int(n)
    updates = []
    for slot in range(MAX_SLOTS):
        updates.append(gr.update(visible=slot < visible_count))
    return updates
|
|
|
|
| |
| |
| |
|
|
# App-level CSS: responsive output videos, equal-width two-column layout, and
# suppression of the built-in per-video download button (the waveform panel's
# links are used instead).  Presumably passed as css= to the Blocks app —
# confirm at the construction site.
_SLOT_CSS = """
/* Responsive video: fills column width, height auto from aspect ratio */
.gradio-video video {
    width: 100%;
    height: auto;
    max-height: 60vh;
    object-fit: contain;
}
/* Force two-column layout to stay equal-width */
.gradio-container .gradio-row > .gradio-column {
    flex: 1 1 0 !important;
    min-width: 0 !important;
    max-width: 50% !important;
}
/* Hide the built-in download button on output video slots — downloads are
   handled by the waveform panel links which always reflect the latest regen. */
[id^="slot_vid_"] .download-icon,
[id^="slot_vid_"] button[aria-label="Download"],
[id^="slot_vid_"] a[download] {
    display: none !important;
}
"""
|
|
| _GLOBAL_JS = """ |
| () => { |
| // Global postMessage handler for waveform iframe events. |
| // Runs once on page load (Gradio js= parameter). |
| // Handles: popup open/close relay, regen trigger via Gradio queue API. |
| if (window._wf_global_listener) return; // already registered |
| window._wf_global_listener = true; |
| |
| // ββ ZeroGPU quota attribution ββ |
| // HF Spaces run inside an iframe on huggingface.co. Gradio's own JS client |
| // gets ZeroGPU auth headers (x-zerogpu-token, x-zerogpu-uuid) by sending a |
| // postMessage("zerogpu-headers") to the parent frame. The parent responds |
| // with a Map of headers that must be included on queue/join calls. |
| // We replicate this exact mechanism so our raw regen fetch() calls are |
| // attributed to the logged-in user's Pro quota. |
| function _fetchZerogpuHeaders() { |
| return new Promise(function(resolve) { |
| // Check if we're in an HF iframe with zerogpu support |
| if (typeof window === 'undefined' || window.parent === window || !window.supports_zerogpu_headers) { |
| console.log('[zerogpu] not in HF iframe or no zerogpu support'); |
| resolve({}); |
| return; |
| } |
| // Determine origin β same logic as Gradio's client |
| var hostname = window.location.hostname; |
| var hfhubdev = 'dev.spaces.huggingface.tech'; |
| var origin = hostname.includes('.dev.') |
| ? 'https://moon-' + hostname.split('.')[1] + '.' + hfhubdev |
| : 'https://huggingface.co'; |
| // Use MessageChannel just like Gradio's post_message helper |
| var channel = new MessageChannel(); |
| var done = false; |
| channel.port1.onmessage = function(ev) { |
| channel.port1.close(); |
| done = true; |
| var headers = ev.data; |
| if (headers && typeof headers === 'object') { |
| // Convert Map to plain object if needed |
| var obj = {}; |
| if (typeof headers.forEach === 'function') { |
| headers.forEach(function(v, k) { obj[k] = v; }); |
| } else { |
| obj = headers; |
| } |
| console.log('[zerogpu] got headers from parent:', Object.keys(obj).join(', ')); |
| resolve(obj); |
| } else { |
| resolve({}); |
| } |
| }; |
| window.parent.postMessage('zerogpu-headers', origin, [channel.port2]); |
| // Timeout: don't block regen if parent doesn't respond |
| setTimeout(function() { if (!done) { done = true; channel.port1.close(); resolve({}); } }, 3000); |
| }); |
| } |
| |
| // Cache: api_name -> fn_index, built once from gradio_config.dependencies |
| let _fnIndexCache = null; |
| function getFnIndex(apiName) { |
| if (!_fnIndexCache) { |
| _fnIndexCache = {}; |
| const deps = window.gradio_config && window.gradio_config.dependencies; |
| if (deps) deps.forEach(function(d, i) { |
| if (d.api_name) _fnIndexCache[d.api_name] = i; |
| }); |
| } |
| return _fnIndexCache[apiName]; |
| } |
| |
| // Read a component's current DOM value by elem_id. |
| // For Number/Slider: reads the <input type="number"> or <input type="range">. |
| // For Textbox/Radio: reads the <textarea> or checked <input type="radio">. |
| // Returns null if not found. |
| function readComponentValue(elemId) { |
| const el = document.getElementById(elemId); |
| if (!el) return null; |
| const numInput = el.querySelector('input[type="number"]'); |
| if (numInput) return parseFloat(numInput.value); |
| const rangeInput = el.querySelector('input[type="range"]'); |
| if (rangeInput) return parseFloat(rangeInput.value); |
| const radio = el.querySelector('input[type="radio"]:checked'); |
| if (radio) return radio.value; |
| const ta = el.querySelector('textarea'); |
| if (ta) return ta.value; |
| const txt = el.querySelector('input[type="text"], input:not([type])'); |
| if (txt) return txt.value; |
| return null; |
| } |
| |
| // Fire regen for a given slot and segment by posting directly to the |
| // Gradio queue API β bypasses Svelte binding entirely. |
| // targetModel: 'taro' | 'mma' | 'hf' (which model to use for inference) |
| // If targetModel matches the slot's own prefix, uses the per-slot regen_* endpoint. |
| // Otherwise uses the shared xregen_* cross-model endpoint. |
| function fireRegen(slot_id, seg_idx, targetModel) { |
| // Block if a regen is already in-flight for this slot |
| if (_regenInFlight[slot_id]) { |
| console.log('[fireRegen] blocked β regen already in-flight for', slot_id); |
| return; |
| } |
| _regenInFlight[slot_id] = true; |
| |
| const prefix = slot_id.split('_')[0]; // owning tab: 'taro'|'mma'|'hf' |
| const slotNum = parseInt(slot_id.split('_')[1], 10); |
| |
| // Decide which endpoint to call |
| const crossModel = (targetModel !== prefix); |
| let apiName, data; |
| |
| // Read state_json from the waveform container data-state attribute |
| const container = document.getElementById('wf_container_' + slot_id); |
| const stateJson = container ? (container.getAttribute('data-state') || '') : ''; |
| if (!stateJson) { |
| console.warn('[fireRegen] no state_json for slot', slot_id); |
| return; |
| } |
| |
| if (!crossModel) { |
| // ββ Same-model regen: per-slot endpoint, video passed as null ββ |
| apiName = 'regen_' + prefix + '_' + slotNum; |
| if (prefix === 'taro') { |
| data = [seg_idx, stateJson, null, |
| readComponentValue('taro_seed'), readComponentValue('taro_cfg'), |
| readComponentValue('taro_steps'), readComponentValue('taro_mode'), |
| readComponentValue('taro_cf_dur'), readComponentValue('taro_cf_db')]; |
| } else if (prefix === 'mma') { |
| data = [seg_idx, stateJson, null, |
| readComponentValue('mma_prompt'), readComponentValue('mma_neg'), |
| readComponentValue('mma_seed'), readComponentValue('mma_cfg'), |
| readComponentValue('mma_steps'), |
| readComponentValue('mma_cf_dur'), readComponentValue('mma_cf_db')]; |
| } else { |
| data = [seg_idx, stateJson, null, |
| readComponentValue('hf_prompt'), readComponentValue('hf_neg'), |
| readComponentValue('hf_seed'), readComponentValue('hf_guidance'), |
| readComponentValue('hf_steps'), readComponentValue('hf_size'), |
| readComponentValue('hf_cf_dur'), readComponentValue('hf_cf_db')]; |
| } |
| } else { |
| // ββ Cross-model regen: shared xregen_* endpoint ββ |
| // slot_id is passed so the server knows which slot's state to splice into. |
| // UI params are read from the target model's tab inputs. |
| if (targetModel === 'taro') { |
| apiName = 'xregen_taro'; |
| data = [seg_idx, stateJson, slot_id, |
| readComponentValue('taro_seed'), readComponentValue('taro_cfg'), |
| readComponentValue('taro_steps'), readComponentValue('taro_mode'), |
| readComponentValue('taro_cf_dur'), readComponentValue('taro_cf_db')]; |
| } else if (targetModel === 'mma') { |
| apiName = 'xregen_mmaudio'; |
| data = [seg_idx, stateJson, slot_id, |
| readComponentValue('mma_prompt'), readComponentValue('mma_neg'), |
| readComponentValue('mma_seed'), readComponentValue('mma_cfg'), |
| readComponentValue('mma_steps'), |
| readComponentValue('mma_cf_dur'), readComponentValue('mma_cf_db')]; |
| } else { |
| apiName = 'xregen_hunyuan'; |
| data = [seg_idx, stateJson, slot_id, |
| readComponentValue('hf_prompt'), readComponentValue('hf_neg'), |
| readComponentValue('hf_seed'), readComponentValue('hf_guidance'), |
| readComponentValue('hf_steps'), readComponentValue('hf_size'), |
| readComponentValue('hf_cf_dur'), readComponentValue('hf_cf_db')]; |
| } |
| } |
| |
| console.log('[fireRegen] calling api', apiName, 'seg', seg_idx); |
| |
| // Snapshot current waveform HTML + video src before mutating anything, |
| // so we can restore on error (e.g. quota exceeded). |
| var _preRegenWaveHtml = null; |
| var _preRegenVideoSrc = null; |
| var waveElSnap = document.getElementById('slot_wave_' + slot_id); |
| if (waveElSnap) _preRegenWaveHtml = waveElSnap.innerHTML; |
| var vidElSnap = document.getElementById('slot_vid_' + slot_id); |
| if (vidElSnap) { var vSnap = vidElSnap.querySelector('video'); if (vSnap) _preRegenVideoSrc = vSnap.getAttribute('src'); } |
| |
| // Show spinner immediately |
| const lbl = document.getElementById('wf_seglabel_' + slot_id); |
| if (lbl) lbl.textContent = 'Regenerating Seg ' + (seg_idx + 1) + '...'; |
| |
| const fnIndex = getFnIndex(apiName); |
| if (fnIndex === undefined) { |
| console.warn('[fireRegen] fn_index not found for api_name:', apiName); |
| return; |
| } |
| // Get ZeroGPU auth headers from the HF parent frame (same mechanism |
| // Gradio's own JS client uses), then fire the regen queue/join call. |
| // Falls back to user-supplied HF token if zerogpu headers aren't available. |
| _fetchZerogpuHeaders().then(function(zerogpuHeaders) { |
| var regenHeaders = {'Content-Type': 'application/json'}; |
| var hasZerogpu = zerogpuHeaders && Object.keys(zerogpuHeaders).length > 0; |
| if (hasZerogpu) { |
| // Merge zerogpu headers (x-zerogpu-token, x-zerogpu-uuid) |
| for (var k in zerogpuHeaders) { regenHeaders[k] = zerogpuHeaders[k]; } |
| console.log('[fireRegen] using zerogpu headers from parent frame'); |
| } else { |
| console.warn('[fireRegen] no zerogpu headers available β may use anonymous quota'); |
| } |
| fetch('/gradio_api/queue/join', { |
| method: 'POST', |
| credentials: 'include', |
| headers: regenHeaders, |
| body: JSON.stringify({ |
| data: data, |
| fn_index: fnIndex, |
| api_name: '/' + apiName, |
| session_hash: window.__gradio_session_hash__, |
| event_data: null, |
| trigger_id: null |
| }) |
| }).then(function(r) { return r.json(); }).then(function(j) { |
| if (!j.event_id) { console.error('[fireRegen] no event_id:', j); return; } |
| console.log('[fireRegen] queued, event_id:', j.event_id); |
| _listenAndApply(j.event_id, slot_id, seg_idx, _preRegenWaveHtml, _preRegenVideoSrc); |
| }).catch(function(e) { |
| console.error('[fireRegen] fetch error:', e); |
| if (lbl) lbl.textContent = 'Error β see console'; |
| var sb = document.getElementById('wf_statusbar_' + slot_id); |
| if (sb) { sb.style.color = '#e05252'; sb.textContent = '\u26a0 Request failed: ' + e.message; } |
| }); |
| }); |
| } |
| |
| // Subscribe to Gradio SSE stream for an event and apply outputs to DOM. |
| // For regen handlers, output[0] = video update, output[1] = waveform HTML update. |
| function _applyVideoSrc(slot_id, newSrc) { |
| var vidEl = document.getElementById('slot_vid_' + slot_id); |
| if (!vidEl) return false; |
| var video = vidEl.querySelector('video'); |
| if (!video) return false; |
| if (video.getAttribute('src') === newSrc) return true; // already correct |
| video.setAttribute('src', newSrc); |
| video.src = newSrc; |
| video.load(); |
| console.log('[_applyVideoSrc] applied src to', 'slot_vid_' + slot_id, 'src:', newSrc.slice(-40)); |
| return true; |
| } |
| |
| // Toast notification β styled like ZeroGPU quota warnings. |
| function _showRegenToast(message, isError) { |
| var t = document.createElement('div'); |
| t.style.cssText = 'position:fixed;bottom:24px;left:50%;transform:translateX(-50%);' + |
| 'z-index:2147483647;padding:12px 20px;border-radius:8px;font-family:sans-serif;' + |
| 'font-size:13px;max-width:520px;text-align:center;box-shadow:0 4px 20px rgba(0,0,0,.6);' + |
| 'background:' + (isError ? '#7a1c1c' : '#1c4a1c') + ';color:#fff;' + |
| 'border:1px solid ' + (isError ? '#c0392b' : '#27ae60') + ';' + |
| 'pointer-events:none;'; |
| t.textContent = message; |
| document.body.appendChild(t); |
| setTimeout(function() { |
| t.style.transition = 'opacity 0.5s'; |
| t.style.opacity = '0'; |
| setTimeout(function() { t.parentNode && t.parentNode.removeChild(t); }, 600); |
| }, isError ? 8000 : 3000); |
| } |
| |
// Subscribe to the Gradio SSE queue stream for one regen job and apply its
// streamed outputs (new video src + waveform HTML) to the slot's DOM in place.
// preRegenWaveHtml / preRegenVideoSrc are snapshots taken before the job was
// fired; they are used to roll the slot back if the job errors out.
function _listenAndApply(eventId, slot_id, seg_idx, preRegenWaveHtml, preRegenVideoSrc) {
    var _pendingVideoSrc = null;
    const es = new EventSource('/gradio_api/queue/data?session_hash=' + window.__gradio_session_hash__);
    es.onmessage = function(e) {
        var msg;
        try { msg = JSON.parse(e.data); } catch(_) { return; }
        if (msg.event_id !== eventId) return;  // the stream is per-session; filter to our job
        if (msg.msg === 'process_generating' || msg.msg === 'process_completed') {
            var out = msg.output;
            if (out && out.data) {
                // out.data[0] is the video update, out.data[1] the waveform HTML update.
                var vidUpdate = out.data[0];
                var waveUpdate = out.data[1];
                var newSrc = null;
                // The video URL can arrive in several payload shapes; probe in order.
                if (vidUpdate) {
                    if (vidUpdate.value && vidUpdate.value.video && vidUpdate.value.video.url) newSrc = vidUpdate.value.video.url;
                    else if (vidUpdate.video && vidUpdate.video.url) newSrc = vidUpdate.video.url;
                    else if (vidUpdate.value && vidUpdate.value.url) newSrc = vidUpdate.value.url;
                    else if (typeof vidUpdate.value === 'string') newSrc = vidUpdate.value;
                    else if (vidUpdate.url) newSrc = vidUpdate.url;
                }
                // Video is applied only on completion; just remember the latest URL.
                if (newSrc) _pendingVideoSrc = newSrc;
                var waveHtml = null;
                if (waveUpdate) {
                    if (typeof waveUpdate === 'string') waveHtml = waveUpdate;
                    else if (waveUpdate.value && typeof waveUpdate.value === 'string') waveHtml = waveUpdate.value;
                }
                if (waveHtml) {
                    // Inject into Gradio's inner wrapper when present, else directly.
                    var waveEl = document.getElementById('slot_wave_' + slot_id);
                    if (waveEl) {
                        var inner = waveEl.querySelector('.prose') || waveEl.querySelector('div');
                        if (inner) inner.innerHTML = waveHtml;
                        else waveEl.innerHTML = waveHtml;
                    }
                }
            }
            if (msg.msg === 'process_completed') {
                es.close();
                _regenInFlight[slot_id] = false;  // allow the next click on this slot
                var errMsg = msg.output && msg.output.error;
                var hadError = !!errMsg;
                console.log('[fireRegen] completed for', slot_id, 'error:', hadError, errMsg || '');
                var lbl = document.getElementById('wf_seglabel_' + slot_id);
                if (hadError) {
                    var toastMsg = typeof errMsg === 'string' ? errMsg : JSON.stringify(errMsg);
                    // Restore previous waveform HTML and video src
                    if (preRegenWaveHtml !== null) {
                        var waveEl2 = document.getElementById('slot_wave_' + slot_id);
                        if (waveEl2) waveEl2.innerHTML = preRegenWaveHtml;
                    }
                    if (preRegenVideoSrc !== null) {
                        var vidElR = document.getElementById('slot_vid_' + slot_id);
                        if (vidElR) { var vR = vidElR.querySelector('video'); if (vR) { vR.setAttribute('src', preRegenVideoSrc); vR.src = preRegenVideoSrc; vR.load(); } }
                    }
                    // Update the statusbar (query after restore so we get the freshly-restored element)
                    var isAbort = toastMsg.toLowerCase().indexOf('aborted') !== -1;
                    var isTimeout = toastMsg.toLowerCase().indexOf('timeout') !== -1;
                    var failMsg = isAbort || isTimeout
                        ? '\u26a0 GPU cold-start β segment unchanged, try again'
                        : '\u26a0 Regen failed β segment unchanged';
                    var statusBar = document.getElementById('wf_statusbar_' + slot_id);
                    if (statusBar) {
                        statusBar.style.color = '#e05252';
                        statusBar.textContent = failMsg;
                        setTimeout(function() { statusBar.style.color = '#888'; statusBar.textContent = 'Click a segment to regenerate \u00a0|\u00a0 Playhead syncs to video'; }, 8000);
                    }
                } else {
                    if (lbl) lbl.textContent = 'Done';
                    var src = _pendingVideoSrc;
                    if (src) {
                        // Re-apply the src several times and watch mutations for 2s —
                        // defensive against the component re-rendering and overwriting it.
                        _applyVideoSrc(slot_id, src);
                        setTimeout(function() { _applyVideoSrc(slot_id, src); }, 50);
                        setTimeout(function() { _applyVideoSrc(slot_id, src); }, 300);
                        setTimeout(function() { _applyVideoSrc(slot_id, src); }, 800);
                        var vidEl = document.getElementById('slot_vid_' + slot_id);
                        if (vidEl) {
                            var obs = new MutationObserver(function() { _applyVideoSrc(slot_id, src); });
                            obs.observe(vidEl, {subtree: true, attributes: true, attributeFilter: ['src'], childList: true});
                            setTimeout(function() { obs.disconnect(); }, 2000);
                        }
                    }
                }
            }
        }
        if (msg.msg === 'close_stream') { es.close(); }
    };
    es.onerror = function() { es.close(); _regenInFlight[slot_id] = false; };
}
| |
// Per-slot in-flight flag: prevents rapid clicks from queuing multiple regen jobs.
var _regenInFlight = {};

// Shared popup element, created lazily by ensurePopup() and reused across all slots.
let _popup = null;
// Slot / segment captured when the popup is shown; consumed when a model button is clicked.
let _pendingSlot = null, _pendingIdx = null;
| |
// Create (once) the floating model-picker popup shown when a waveform segment
// is clicked, and return the shared element. Buttons dispatch to fireRegen
// (defined elsewhere in this script) with the captured slot/segment.
function ensurePopup() {
    if (_popup) return _popup;
    _popup = document.createElement('div');
    _popup.style.cssText = 'display:none;position:fixed;z-index:99999;' +
        'background:#2a2a2a;border:1px solid #555;border-radius:6px;' +
        'padding:8px 12px;box-shadow:0 4px 16px rgba(0,0,0,.5);font-family:sans-serif;';
    var btnStyle = 'color:#fff;border:none;border-radius:4px;padding:5px 10px;' +
        'font-size:11px;cursor:pointer;flex:1;';
    _popup.innerHTML =
        '<div id="_wf_popup_lbl" style="color:#ccc;font-size:11px;margin-bottom:6px;white-space:nowrap;"></div>' +
        '<div style="display:flex;gap:5px;">' +
        '<button id="_wf_popup_taro" style="background:#1d6fa5;' + btnStyle + '">⟳ TARO</button>' +
        '<button id="_wf_popup_mma" style="background:#2d7a4a;' + btnStyle + '">⟳ MMAudio</button>' +
        '<button id="_wf_popup_hf" style="background:#7a3d8c;' + btnStyle + '">⟳ Hunyuan</button>' +
        '</div>';
    document.body.appendChild(_popup);
    ['taro','mma','hf'].forEach(function(model) {
        document.getElementById('_wf_popup_' + model).onclick = function(e) {
            // Stop bubbling so the document-level dismiss handler below doesn't fire.
            e.stopPropagation();
            var slot = _pendingSlot, idx = _pendingIdx;
            hidePopup();
            if (slot !== null && idx !== null) fireRegen(slot, idx, model);
        };
    });
    // Use bubble phase (false) so stopPropagation() on the button click prevents this from firing
    document.addEventListener('click', function() { hidePopup(); }, false);
    return _popup;
}
| |
// Hide the shared popup (if it exists) and clear the captured slot/segment.
function hidePopup() {
    _pendingSlot = null;
    _pendingIdx = null;
    if (_popup !== null) _popup.style.display = 'none';
}
| |
// Cross-frame bridge: wf_popup show/hide requests arrive via postMessage
// (presumably from the waveform widgets' iframes — the sender is not visible
// here) and are handled by the top window, which owns the shared popup.
window.addEventListener('message', function(e) {
    const d = e.data;
    console.log('[global msg] received type=' + (d && d.type) + ' action=' + (d && d.action));
    if (!d || d.type !== 'wf_popup') return;
    const p = ensurePopup();
    if (d.action === 'hide') { hidePopup(); return; }
    // action === 'show'
    _pendingSlot = d.slot_id;
    _pendingIdx = d.seg_idx;
    // Label: 1-based segment index plus its time range in seconds.
    const lbl = document.getElementById('_wf_popup_lbl');
    if (lbl) lbl.textContent = 'Seg ' + (d.seg_idx + 1) +
        ' (' + d.t0.toFixed(2) + 's \u2013 ' + d.t1.toFixed(2) + 's)';
    p.style.display = 'block';
    p.style.left = (d.x + 10) + 'px';
    p.style.top = (d.y + 10) + 'px';
    // After layout, nudge the popup back inside the viewport if it overflows.
    requestAnimationFrame(function() {
        const r = p.getBoundingClientRect();
        if (r.right > window.innerWidth - 8) p.style.left = (window.innerWidth - r.width - 8) + 'px';
        if (r.bottom > window.innerHeight - 8) p.style.top = (window.innerHeight - r.height - 8) + 'px';
    });
});
| } |
| """ |
|
|
# Top-level UI: one Blocks app with a tab per model plus shared JS/CSS.
with gr.Blocks(title="Generate Audio for Video", css=_SLOT_CSS, js=_GLOBAL_JS) as demo:
    # Header with a model-selection cheat sheet shown above the tabs.
    gr.Markdown(
        "# Generate Audio for Video\n"
        "Choose a model and upload a video to generate synchronized audio.\n\n"
        "| Model | Best for | Avoid for |\n"
        "|-------|----------|-----------|\n"
        "| **TARO** | Natural, physics-driven impacts β footsteps, collisions, water, wind, crackling fire. Excels when the sound is tightly coupled to visible motion without needing a text description. | Dialogue, music, or complex layered soundscapes where semantic context matters. |\n"
        "| **MMAudio** | Mixed scenes where you want both visual grounding *and* semantic control via a text prompt β e.g. a busy street scene where you want to emphasize the rain rather than the traffic. Great for ambient textures and nuanced sound design. | Pure impact/foley shots where TARO's motion-coupling would be sharper, or cinematic music beds. |\n"
        "| **HunyuanFoley** | Cinematic foley requiring high fidelity and explicit creative direction β dramatic SFX, layered environmental design, or any scene where you have a clear written description of the desired sound palette. | Quick one-shot clips where you don't want to write a prompt, or raw impact sounds where timing precision matters more than richness. |"
    )
|
|
    with gr.Tabs():

        # ------------------------------------------------------------------
        # TARO tab: video-only conditioning (no text prompt).
        # ------------------------------------------------------------------
        with gr.Tab("TARO"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Input video plus TARO sampling hyper-parameters.
                    taro_video = gr.Video(label="Input Video")
                    taro_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0, elem_id="taro_seed")
                    taro_cfg = gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8.0, step=0.5, elem_id="taro_cfg")
                    taro_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1, elem_id="taro_steps")
                    taro_mode = gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde", elem_id="taro_mode")
                    taro_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=4, value=2, step=0.1, elem_id="taro_cf_dur")
                    taro_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3", elem_id="taro_cf_db")
                    taro_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
                    taro_btn = gr.Button("Generate", variant="primary")

                with gr.Column(scale=1):
                    # One (video, waveform) output slot per requested generation.
                    (taro_slot_grps, taro_slot_vids,
                     taro_slot_waves) = _make_output_slots("taro")

            # Hidden fields used by the JS segment-regeneration bridge.
            taro_regen_seg = gr.Textbox(value="0", render=False)
            taro_regen_state = gr.Textbox(value="", render=False)

            # Recompute the allowed generation count whenever the video or the
            # cost-driving settings (steps, crossfade duration) change.
            for trigger in [taro_video, taro_steps, taro_cf_dur]:
                trigger.change(
                    fn=_on_video_upload_taro,
                    inputs=[taro_video, taro_steps, taro_cf_dur],
                    outputs=[taro_samples],
                )
            # Show/hide output slots to match the requested generation count.
            taro_samples.change(
                fn=_update_slot_visibility,
                inputs=[taro_samples],
                outputs=taro_slot_grps,
            )
|
| def _run_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n): |
| flat = generate_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n) |
| return _unpack_outputs(flat, n, "taro") |
|
|
| |
| |
            # Generate button: fill the output slots, then refresh slot
            # visibility afterwards.
            (taro_btn.click(
                fn=_run_taro,
                inputs=[taro_video, taro_seed, taro_cfg, taro_steps, taro_mode,
                        taro_cf_dur, taro_cf_db, taro_samples],
                outputs=taro_slot_vids + taro_slot_waves,
            ).then(
                fn=_update_slot_visibility,
                inputs=[taro_samples],
                outputs=taro_slot_grps,
            ))

            # Per-slot segment-regeneration handlers (invoked via the JS popup).
            taro_regen_btns = _register_regen_handlers(
                "taro", "taro", taro_regen_seg, taro_regen_state,
                [taro_video, taro_seed, taro_cfg, taro_steps,
                 taro_mode, taro_cf_dur, taro_cf_db],
                taro_slot_vids, taro_slot_waves,
            )
|
|
| |
| |
| |
        # ------------------------------------------------------------------
        # MMAudio tab: video + text prompt conditioning.
        # ------------------------------------------------------------------
        with gr.Tab("MMAudio"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Input video, prompts, and MMAudio sampling hyper-parameters.
                    mma_video = gr.Video(label="Input Video")
                    mma_prompt = gr.Textbox(label="Prompt", placeholder="e.g. footsteps on gravel", elem_id="mma_prompt")
                    mma_neg = gr.Textbox(label="Negative Prompt", value="music", placeholder="music, speech", elem_id="mma_neg")
                    mma_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0, elem_id="mma_seed")
                    mma_cfg = gr.Slider(label="CFG Strength", minimum=1, maximum=10, value=4.5, step=0.5, elem_id="mma_cfg")
                    mma_steps = gr.Slider(label="Steps", minimum=10, maximum=50, value=25, step=1, elem_id="mma_steps")
                    mma_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=4, value=2, step=0.1, elem_id="mma_cf_dur")
                    mma_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3", elem_id="mma_cf_db")
                    mma_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
                    mma_btn = gr.Button("Generate", variant="primary")

                with gr.Column(scale=1):
                    # One (video, waveform) output slot per requested generation.
                    (mma_slot_grps, mma_slot_vids,
                     mma_slot_waves) = _make_output_slots("mma")

            # Hidden fields used by the JS segment-regeneration bridge.
            mma_regen_seg = gr.Textbox(value="0", render=False)
            mma_regen_state = gr.Textbox(value="", render=False)

            # Show/hide output slots to match the requested generation count.
            mma_samples.change(
                fn=_update_slot_visibility,
                inputs=[mma_samples],
                outputs=mma_slot_grps,
            )
|
|
| def _run_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n): |
| flat = generate_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n) |
| return _unpack_outputs(flat, n, "mma") |
|
|
            # Generate button: fill the output slots, then refresh slot
            # visibility afterwards.
            (mma_btn.click(
                fn=_run_mmaudio,
                inputs=[mma_video, mma_prompt, mma_neg, mma_seed,
                        mma_cfg, mma_steps, mma_cf_dur, mma_cf_db, mma_samples],
                outputs=mma_slot_vids + mma_slot_waves,
            ).then(
                fn=_update_slot_visibility,
                inputs=[mma_samples],
                outputs=mma_slot_grps,
            ))

            # Per-slot segment-regeneration handlers (invoked via the JS popup).
            mma_regen_btns = _register_regen_handlers(
                "mma", "mmaudio", mma_regen_seg, mma_regen_state,
                [mma_video, mma_prompt, mma_neg, mma_seed,
                 mma_cfg, mma_steps, mma_cf_dur, mma_cf_db],
                mma_slot_vids, mma_slot_waves,
            )
|
|
| |
| |
| |
        # ------------------------------------------------------------------
        # HunyuanFoley tab: video + text prompt, with selectable model size.
        # ------------------------------------------------------------------
        with gr.Tab("HunyuanFoley"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Input video, prompts, and HunyuanFoley sampling hyper-parameters.
                    hf_video = gr.Video(label="Input Video")
                    hf_prompt = gr.Textbox(label="Prompt", placeholder="e.g. rain hitting a metal roof", elem_id="hf_prompt")
                    hf_neg = gr.Textbox(label="Negative Prompt", value="noisy, harsh", elem_id="hf_neg")
                    hf_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0, elem_id="hf_seed")
                    hf_guidance = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, value=4.5, step=0.5, elem_id="hf_guidance")
                    hf_steps = gr.Slider(label="Steps", minimum=10, maximum=100, value=50, step=5, elem_id="hf_steps")
                    hf_size = gr.Radio(label="Model Size", choices=["xl", "xxl"], value="xxl", elem_id="hf_size")
                    hf_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=4, value=2, step=0.1, elem_id="hf_cf_dur")
                    hf_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3", elem_id="hf_cf_db")
                    hf_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
                    hf_btn = gr.Button("Generate", variant="primary")

                with gr.Column(scale=1):
                    # One (video, waveform) output slot per requested generation.
                    (hf_slot_grps, hf_slot_vids,
                     hf_slot_waves) = _make_output_slots("hf")

            # Hidden fields used by the JS segment-regeneration bridge.
            hf_regen_seg = gr.Textbox(value="0", render=False)
            hf_regen_state = gr.Textbox(value="", render=False)

            # Show/hide output slots to match the requested generation count.
            hf_samples.change(
                fn=_update_slot_visibility,
                inputs=[hf_samples],
                outputs=hf_slot_grps,
            )
|
|
| def _run_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n): |
| flat = generate_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n) |
| return _unpack_outputs(flat, n, "hf") |
|
|
            # Generate button: fill the output slots, then refresh slot
            # visibility afterwards.
            (hf_btn.click(
                fn=_run_hunyuan,
                inputs=[hf_video, hf_prompt, hf_neg, hf_seed,
                        hf_guidance, hf_steps, hf_size, hf_cf_dur, hf_cf_db, hf_samples],
                outputs=hf_slot_vids + hf_slot_waves,
            ).then(
                fn=_update_slot_visibility,
                inputs=[hf_samples],
                outputs=hf_slot_grps,
            ))

            # Per-slot segment-regeneration handlers (invoked via the JS popup).
            hf_regen_btns = _register_regen_handlers(
                "hf", "hunyuan", hf_regen_seg, hf_regen_state,
                [hf_video, hf_prompt, hf_neg, hf_seed,
                 hf_guidance, hf_steps, hf_size, hf_cf_dur, hf_cf_db],
                hf_slot_vids, hf_slot_waves,
            )
|
|
| |
| |
| |
| |
| |
    # Transcode each upload in place so the browser can play it inline.
    taro_video.upload(fn=_transcode_for_browser, inputs=[taro_video], outputs=[taro_video])
    mma_video.upload(fn=_transcode_for_browser, inputs=[mma_video], outputs=[mma_video])
    hf_video.upload(fn=_transcode_for_browser, inputs=[hf_video], outputs=[hf_video])

    # Mirror the chosen video across all three tabs so switching tabs keeps it.
    # NOTE(review): each .change here re-triggers the other tabs' .change
    # handlers; this presumably settles because re-setting an identical value
    # stops the cascade — confirm there is no update loop in practice.
    _sync = lambda v: (gr.update(value=v), gr.update(value=v))
    taro_video.change(fn=_sync, inputs=[taro_video], outputs=[mma_video, hf_video])
    mma_video.change(fn=_sync, inputs=[mma_video], outputs=[taro_video, hf_video])
    hf_video.change(fn=_sync, inputs=[hf_video], outputs=[taro_video, mma_video])
|
|
| |
| |
| |
| |
    # Hidden inputs shared by the named cross-regen API routes below. These
    # routes are hit by name (api_name) — presumably by the in-page JS via
    # /gradio_api — rather than by visible buttons.
    _xr_seg = gr.Textbox(value="0", render=False)
    _xr_state = gr.Textbox(value="", render=False)
    _xr_slot_id = gr.Textbox(value="", render=False)

    # Reuse the first TARO slot's components as the routes' output targets.
    # NOTE(review): the JS stream listener applies out.data[0]/[1] to the slot
    # named by _xr_slot_id itself, so these appear to be placeholder outputs —
    # verify they are never rendered into directly.
    _xr_dummy_vid = taro_slot_vids[0]
    _xr_dummy_wave = taro_slot_waves[0]

    # TARO cross-regen route.
    # NOTE(review): default cfg here is "7.5" while the TARO tab slider
    # defaults to 8.0 — confirm which default is intended.
    _xr_taro_seed = gr.Textbox(value="-1", render=False)
    _xr_taro_cfg = gr.Textbox(value="7.5", render=False)
    _xr_taro_steps = gr.Textbox(value="25", render=False)
    _xr_taro_mode = gr.Textbox(value="sde", render=False)
    _xr_taro_cfd = gr.Textbox(value="2", render=False)
    _xr_taro_cfdb = gr.Textbox(value="3", render=False)
    gr.Button(render=False).click(
        fn=xregen_taro,
        inputs=[_xr_seg, _xr_state, _xr_slot_id,
                _xr_taro_seed, _xr_taro_cfg, _xr_taro_steps,
                _xr_taro_mode, _xr_taro_cfd, _xr_taro_cfdb],
        outputs=[_xr_dummy_vid, _xr_dummy_wave],
        api_name="xregen_taro",
    )
|
|
| |
    # MMAudio cross-regen route (same pattern: hidden defaults + named API route).
    _xr_mma_prompt = gr.Textbox(value="", render=False)
    _xr_mma_neg = gr.Textbox(value="", render=False)
    _xr_mma_seed = gr.Textbox(value="-1", render=False)
    _xr_mma_cfg = gr.Textbox(value="4.5", render=False)
    _xr_mma_steps = gr.Textbox(value="25", render=False)
    _xr_mma_cfd = gr.Textbox(value="2", render=False)
    _xr_mma_cfdb = gr.Textbox(value="3", render=False)
    gr.Button(render=False).click(
        fn=xregen_mmaudio,
        inputs=[_xr_seg, _xr_state, _xr_slot_id,
                _xr_mma_prompt, _xr_mma_neg, _xr_mma_seed,
                _xr_mma_cfg, _xr_mma_steps, _xr_mma_cfd, _xr_mma_cfdb],
        outputs=[_xr_dummy_vid, _xr_dummy_wave],
        api_name="xregen_mmaudio",
    )
|
|
| |
    # HunyuanFoley cross-regen route (same pattern: hidden defaults + named API route).
    _xr_hf_prompt = gr.Textbox(value="", render=False)
    _xr_hf_neg = gr.Textbox(value="", render=False)
    _xr_hf_seed = gr.Textbox(value="-1", render=False)
    _xr_hf_guide = gr.Textbox(value="4.5", render=False)
    _xr_hf_steps = gr.Textbox(value="50", render=False)
    _xr_hf_size = gr.Textbox(value="xxl", render=False)
    _xr_hf_cfd = gr.Textbox(value="2", render=False)
    _xr_hf_cfdb = gr.Textbox(value="3", render=False)
    gr.Button(render=False).click(
        fn=xregen_hunyuan,
        inputs=[_xr_seg, _xr_state, _xr_slot_id,
                _xr_hf_prompt, _xr_hf_neg, _xr_hf_seed,
                _xr_hf_guide, _xr_hf_steps, _xr_hf_size,
                _xr_hf_cfd, _xr_hf_cfdb],
        outputs=[_xr_dummy_vid, _xr_dummy_wave],
        api_name="xregen_hunyuan",
    )
|
|
| |
| |
| |
|
|
print("[startup] app.py fully loaded β regen handlers registered, SSR disabled")
# Bounded queue (max 10 pending jobs); SSR disabled; /tmp allow-listed so
# generated media under /tmp can be served to the browser.
demo.queue(max_size=10).launch(ssr_mode=False, height=900, allowed_paths=["/tmp"])
|
|