BoxOfColors committed on
Commit
51979c2
·
1 Parent(s): d141e30

refactor: extract shared helpers to reduce technical debt

Browse files

- Replace _splice_and_save inline stitch with _stitch_wavs (fixes latent stereo bug)
- Extract _build_seg_meta helper (deduplicates 3 identical dict constructions)
- Extract _cpu_preprocess helper (deduplicates 3 identical pre-processing blocks)
- Extract _save_wav helper (deduplicates mono/stereo torchaudio.save logic)
- Extract _log_inference_timing helper (deduplicates 3 identical timing blocks)
- Remove redundant 'from pathlib import Path as _Path' in _load_mmaudio_models
- Remove unnecessary 'global _TARO_INFERENCE_CACHE' statement

Files changed (1) hide show
  1. app.py +228 -246
app.py CHANGED
@@ -8,6 +8,7 @@ Supported models
8
  HunyuanFoley – text-guided foley via SigLIP2 + Synchformer + CLAP (48 kHz, up to 15 s)
9
  """
10
 
 
11
  import os
12
  import sys
13
  import json
@@ -78,6 +79,41 @@ print("CLAP model pre-downloaded.")
78
  MAX_SLOTS = 8 # max parallel generation slots shown in UI
79
  MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video ≤ ~64 s at 8 s/seg)
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # Per-slot reentrant locks — prevent concurrent regens on the same slot from
82
  # producing a race condition where the second regen reads stale state
83
  # (the shared seg_state textbox hasn't been updated yet by the first regen).
@@ -91,11 +127,12 @@ def _get_slot_lock(slot_id: str) -> threading.Lock:
91
  _SLOT_LOCKS[slot_id] = threading.Lock()
92
  return _SLOT_LOCKS[slot_id]
93
 
94
- def set_global_seed(seed: int):
95
  np.random.seed(seed % (2**32))
96
  random.seed(seed)
97
  torch.manual_seed(seed)
98
- torch.cuda.manual_seed(seed)
 
99
 
100
  def get_random_seed() -> int:
101
  return random.randint(0, 2**32 - 1)
@@ -105,7 +142,7 @@ def get_video_duration(video_path: str) -> float:
105
  probe = ffmpeg.probe(video_path)
106
  return float(probe["format"]["duration"])
107
 
108
- def strip_audio_from_video(video_path: str, output_path: str):
109
  """Write a silent copy of *video_path* to *output_path* (stream-copy, no re-encode)."""
110
  ffmpeg.input(video_path).output(output_path, vcodec="copy", an=None).run(
111
  overwrite_output=True, quiet=True
@@ -130,7 +167,7 @@ def _register_tmp_dir(tmp_dir: str) -> str:
130
  return tmp_dir
131
 
132
 
133
- def _save_seg_wavs(wavs: list, tmp_dir: str, prefix: str) -> list:
134
  """Save a list of numpy wav arrays to .npy files, return list of paths.
135
  This avoids serialising large float arrays into JSON/HTML data-state."""
136
  paths = []
@@ -141,7 +178,7 @@ def _save_seg_wavs(wavs: list, tmp_dir: str, prefix: str) -> list:
141
  return paths
142
 
143
 
144
- def _load_seg_wavs(paths: list) -> list:
145
  """Load segment wav arrays from .npy file paths."""
146
  return [np.load(p) for p in paths]
147
 
@@ -190,12 +227,11 @@ def _load_mmaudio_models(device, dtype):
190
  from mmaudio.eval_utils import all_model_cfg
191
  from mmaudio.model.networks import get_my_mmaudio
192
  from mmaudio.model.utils.features_utils import FeaturesUtils
193
- from pathlib import Path as _Path
194
 
195
  model_cfg = all_model_cfg["large_44k_v2"]
196
- model_cfg.model_path = _Path(mmaudio_model_path)
197
- model_cfg.vae_path = _Path(mmaudio_vae_path)
198
- model_cfg.synchformer_ckpt = _Path(mmaudio_synchformer_path)
199
  model_cfg.bigvgan_16k_path = None
200
  seq_cfg = model_cfg.seq_cfg
201
 
@@ -225,7 +261,7 @@ def _load_hunyuan_model(device, model_size):
225
  enable_offload=False, model_size=model_size)
226
 
227
 
228
- def mux_video_audio(silent_video: str, audio_path: str, output_path: str):
229
  """Mux a silent video with an audio file into *output_path* (stream-copy video, encode audio)."""
230
  ffmpeg.output(
231
  ffmpeg.input(silent_video),
@@ -240,7 +276,7 @@ def mux_video_audio(silent_video: str, audio_path: str, output_path: str):
240
  # Used by all three models (TARO, MMAudio, HunyuanFoley). #
241
  # ------------------------------------------------------------------ #
242
 
243
- def _build_segments(total_dur_s: float, window_s: float, crossfade_s: float) -> list:
244
  """Return list of (start, end) pairs covering *total_dur_s* with a sliding
245
  window of *window_s* and *crossfade_s* overlap between consecutive segments."""
246
  if total_dur_s <= window_s:
@@ -460,12 +496,66 @@ def _taro_infer_segment(
460
  return wav[:seg_samples]
461
 
462
 
463
- def _stitch_wavs(wavs: list, crossfade_s: float, db_boost: float,
464
  total_dur_s: float, sr: int) -> np.ndarray:
 
 
465
  out = wavs[0]
466
  for nw in wavs[1:]:
467
  out = _cf_join(out, nw, crossfade_s, db_boost, sr)
468
- return out[:int(round(total_dur_s * sr))]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
 
470
 
471
  @spaces.GPU(duration=_taro_duration)
@@ -473,8 +563,6 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
473
  crossfade_s, crossfade_db, num_samples):
474
  """GPU-only TARO inference — model loading + feature extraction + diffusion.
475
  Returns list of (wavs_list, onset_feats) per sample."""
476
- global _TARO_INFERENCE_CACHE
477
-
478
  seed_val = int(seed_val)
479
  crossfade_s = float(crossfade_s)
480
  num_samples = int(num_samples)
@@ -482,13 +570,9 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
482
  seed_val = random.randint(0, 2**32 - 1)
483
 
484
  torch.set_grad_enabled(False)
485
- device = "cuda" if torch.cuda.is_available() else "cpu"
486
- weight_dtype = torch.bfloat16
487
-
488
- _taro_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "TARO")
489
- if _taro_dir not in sys.path:
490
- sys.path.insert(0, _taro_dir)
491
 
 
492
  from TARO.onset_util import extract_onset
493
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
494
 
@@ -500,9 +584,16 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
500
  total_dur_s = ctx["total_dur_s"]
501
 
502
  extract_cavp, onset_model = _load_taro_feature_extractors(device)
503
- model, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
504
-
505
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
 
 
 
 
 
 
 
 
 
506
 
507
  results = [] # list of (wavs, onset_feats) per sample
508
  for sample_idx in range(num_samples):
@@ -516,7 +607,6 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
516
  results.append((cached["wavs"], cavp_feats, None))
517
  else:
518
  set_global_seed(sample_seed)
519
- onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
520
  wavs = []
521
  _t_infer_start = time.perf_counter()
522
  for seg_start_s, seg_end_s in segments:
@@ -531,12 +621,8 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
531
  euler_sampler, euler_maruyama_sampler,
532
  )
533
  wavs.append(wav)
534
- _t_infer_elapsed = time.perf_counter() - _t_infer_start
535
- _n_segs = len(segments)
536
- _secs_per_step = _t_infer_elapsed / (_n_segs * int(num_steps)) if _n_segs * int(num_steps) > 0 else 0
537
- print(f"[TARO] Inference done: {_n_segs} seg(s) × {int(num_steps)} steps in "
538
- f"{_t_infer_elapsed:.1f}s wall → {_secs_per_step:.3f}s/step "
539
- f"(current constant={TARO_SECS_PER_STEP})")
540
  with _TARO_CACHE_LOCK:
541
  _TARO_INFERENCE_CACHE[cache_key] = {"wavs": wavs}
542
  while len(_TARO_INFERENCE_CACHE) > _TARO_CACHE_MAXLEN:
@@ -563,11 +649,8 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
563
  num_samples = int(num_samples)
564
 
565
  # ── CPU pre-processing (no GPU needed) ──
566
- tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
567
- silent_video = os.path.join(tmp_dir, "silent_input.mp4")
568
- strip_audio_from_video(video_file, silent_video)
569
- total_dur_s = get_video_duration(video_file)
570
- segments = _build_segments(total_dur_s, TARO_MODEL_DUR, crossfade_s)
571
 
572
  # Pass pre-computed CPU results to the GPU function via context
573
  _taro_gpu_infer._cpu_ctx = {
@@ -588,7 +671,7 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
588
  for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
589
  final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR)
590
  audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
591
- torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(final_wav)).unsqueeze(0), TARO_SR)
592
  video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
593
  mux_video_audio(silent_video, audio_path, video_path)
594
  wav_paths = _save_seg_wavs(wavs, tmp_dir, f"taro_{sample_idx}")
@@ -598,20 +681,12 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
598
  if onset_feats is not None:
599
  np.save(onset_path, onset_feats)
600
  first_cavp_saved = True
601
- seg_meta = {
602
- "segments": segments,
603
- "wav_paths": wav_paths,
604
- "audio_path": audio_path,
605
- "video_path": video_path,
606
- "silent_video": silent_video,
607
- "sr": TARO_SR,
608
- "model": "taro",
609
- "crossfade_s": crossfade_s,
610
- "crossfade_db": crossfade_db,
611
- "total_dur_s": total_dur_s,
612
- "cavp_path": cavp_path,
613
- "onset_path": onset_path,
614
- }
615
  outputs.append((video_path, audio_path, seg_meta))
616
 
617
  return _pad_outputs(outputs)
@@ -643,10 +718,7 @@ def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
643
  cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
644
  """GPU-only MMAudio inference — model loading + flow-matching generation.
645
  Returns list of (seg_audios, sr) per sample."""
646
- _mmaudio_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "MMAudio")
647
- if _mmaudio_dir not in sys.path:
648
- sys.path.insert(0, _mmaudio_dir)
649
-
650
  from mmaudio.eval_utils import generate, load_video
651
  from mmaudio.model.flow_matching import FlowMatching
652
 
@@ -654,8 +726,7 @@ def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
654
  num_samples = int(num_samples)
655
  crossfade_s = float(crossfade_s)
656
 
657
- device = "cuda" if torch.cuda.is_available() else "cpu"
658
- dtype = torch.bfloat16
659
 
660
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
661
 
@@ -709,12 +780,8 @@ def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
709
  wav = wav[:, :seg_samples]
710
  seg_audios.append(wav)
711
 
712
- _t_mma_elapsed = time.perf_counter() - _t_mma_start
713
- _n_segs_mma = len(segments)
714
- _secs_per_step_mma = _t_mma_elapsed / (_n_segs_mma * int(num_steps)) if _n_segs_mma * int(num_steps) > 0 else 0
715
- print(f"[MMAudio] Inference done: {_n_segs_mma} seg(s) × {int(num_steps)} steps in "
716
- f"{_t_mma_elapsed:.1f}s wall → {_secs_per_step_mma:.3f}s/step "
717
- f"(current constant={MMAUDIO_SECS_PER_STEP})")
718
  results.append((seg_audios, sr))
719
 
720
  # Free GPU memory between samples to prevent VRAM fragmentation
@@ -735,21 +802,14 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
735
  crossfade_db = float(crossfade_db)
736
 
737
  # ── CPU pre-processing ──
738
- tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
739
- silent_video = os.path.join(tmp_dir, "silent_input.mp4")
740
- strip_audio_from_video(video_file, silent_video)
741
- total_dur_s = get_video_duration(video_file)
742
- segments = _build_segments(total_dur_s, MMAUDIO_WINDOW, crossfade_s)
743
  print(f"[MMAudio] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤8 s")
744
 
745
- seg_clip_paths = []
746
- for seg_i, (seg_start, seg_end) in enumerate(segments):
747
- seg_dur = seg_end - seg_start
748
- seg_path = os.path.join(tmp_dir, f"mma_seg_{seg_i}.mp4")
749
- ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
750
- seg_path, vcodec="copy", an=None
751
- ).run(overwrite_output=True, quiet=True)
752
- seg_clip_paths.append(seg_path)
753
 
754
  _mmaudio_gpu_infer._cpu_ctx = {
755
  "segments": segments, "seg_clip_paths": seg_clip_paths,
@@ -762,28 +822,19 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
762
  # ── CPU post-processing ──
763
  outputs = []
764
  for sample_idx, (seg_audios, sr) in enumerate(results):
765
- full_wav = seg_audios[0]
766
- for nw in seg_audios[1:]:
767
- full_wav = _cf_join(full_wav, nw, crossfade_s, crossfade_db, sr)
768
- full_wav = full_wav[:, : int(round(total_dur_s * sr))]
769
 
770
  audio_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.wav")
771
- torchaudio.save(audio_path, torch.from_numpy(full_wav), sr)
772
  video_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.mp4")
773
  mux_video_audio(silent_video, audio_path, video_path)
774
  wav_paths = _save_seg_wavs(seg_audios, tmp_dir, f"mmaudio_{sample_idx}")
775
- seg_meta = {
776
- "segments": segments,
777
- "wav_paths": wav_paths,
778
- "audio_path": audio_path,
779
- "video_path": video_path,
780
- "silent_video": silent_video,
781
- "sr": sr,
782
- "model": "mmaudio",
783
- "crossfade_s": crossfade_s,
784
- "crossfade_db": crossfade_db,
785
- "total_dur_s": total_dur_s,
786
- }
787
  outputs.append((video_path, audio_path, seg_meta))
788
 
789
  return _pad_outputs(outputs)
@@ -816,10 +867,7 @@ def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
816
  guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
817
  """GPU-only HunyuanFoley inference — model loading + feature extraction + denoising.
818
  Returns list of (seg_wavs, sr, text_feats) per sample."""
819
- _hf_path = str(Path("HunyuanVideo-Foley").resolve())
820
- if _hf_path not in sys.path:
821
- sys.path.insert(0, _hf_path)
822
-
823
  from hunyuanvideo_foley.utils.model_utils import denoise_process
824
  from hunyuanvideo_foley.utils.feature_utils import feature_process
825
 
@@ -829,8 +877,9 @@ def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
829
  if seed_val >= 0:
830
  set_global_seed(seed_val)
831
 
832
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
833
- model_size = model_size.lower()
 
834
 
835
  model_dict, cfg = _load_hunyuan_model(device, model_size)
836
 
@@ -882,12 +931,8 @@ def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
882
  wav = wav[:, :seg_samples]
883
  seg_wavs.append(wav)
884
 
885
- _t_hny_elapsed = time.perf_counter() - _t_hny_start
886
- _n_segs_hny = len(segments)
887
- _secs_per_step_hny = _t_hny_elapsed / (_n_segs_hny * int(num_steps)) if _n_segs_hny * int(num_steps) > 0 else 0
888
- print(f"[HunyuanFoley] Inference done: {_n_segs_hny} seg(s) × {int(num_steps)} steps in "
889
- f"{_t_hny_elapsed:.1f}s wall → {_secs_per_step_hny:.3f}s/step "
890
- f"(current constant={HUNYUAN_SECS_PER_STEP})")
891
  results.append((seg_wavs, sr, text_feats))
892
 
893
  # Free GPU memory between samples to prevent VRAM fragmentation
@@ -908,28 +953,21 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
908
  crossfade_db = float(crossfade_db)
909
 
910
  # ── CPU pre-processing (no GPU needed) ──
911
- tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
912
- silent_video = os.path.join(tmp_dir, "silent_input.mp4")
913
- strip_audio_from_video(video_file, silent_video)
914
- total_dur_s = get_video_duration(silent_video)
915
- segments = _build_segments(total_dur_s, HUNYUAN_MAX_DUR, crossfade_s)
916
  print(f"[HunyuanFoley] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤15 s")
917
 
918
  # Pre-extract dummy segment for text feature extraction (ffmpeg, CPU)
919
- dummy_seg_path = os.path.join(tmp_dir, "_seg_dummy.mp4")
920
- ffmpeg.input(silent_video, ss=0, t=min(total_dur_s, HUNYUAN_MAX_DUR)).output(
921
- dummy_seg_path, vcodec="copy", an=None
922
- ).run(overwrite_output=True, quiet=True)
923
 
924
  # Pre-extract all segment clips (ffmpeg, CPU)
925
- seg_clip_paths = []
926
- for seg_i, (seg_start, seg_end) in enumerate(segments):
927
- seg_dur = seg_end - seg_start
928
- seg_path = os.path.join(tmp_dir, f"hny_seg_{seg_i}.mp4")
929
- ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
930
- seg_path, vcodec="copy", an=None
931
- ).run(overwrite_output=True, quiet=True)
932
- seg_clip_paths.append(seg_path)
933
 
934
  _hunyuan_gpu_infer._cpu_ctx = {
935
  "segments": segments, "total_dur_s": total_dur_s,
@@ -942,38 +980,26 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
942
  crossfade_s, crossfade_db, num_samples)
943
 
944
  # ── CPU post-processing (no GPU needed) ──
945
- _hf_path = str(Path("HunyuanVideo-Foley").resolve())
946
- if _hf_path not in sys.path:
947
- sys.path.insert(0, _hf_path)
948
  from hunyuanvideo_foley.utils.media_utils import merge_audio_video
949
 
950
  outputs = []
951
  for sample_idx, (seg_wavs, sr, text_feats) in enumerate(results):
952
- full_wav = seg_wavs[0]
953
- for nw in seg_wavs[1:]:
954
- full_wav = _cf_join(full_wav, nw, crossfade_s, crossfade_db, sr)
955
- full_wav = full_wav[:, : int(round(total_dur_s * sr))]
956
 
957
  audio_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}.wav")
958
- torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(full_wav)), sr)
959
  video_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}.mp4")
960
  merge_audio_video(audio_path, silent_video, video_path)
961
  wav_paths = _save_seg_wavs(seg_wavs, tmp_dir, f"hunyuan_{sample_idx}")
962
  text_feats_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}_text_feats.pt")
963
  torch.save(text_feats, text_feats_path)
964
- seg_meta = {
965
- "segments": segments,
966
- "wav_paths": wav_paths,
967
- "audio_path": audio_path,
968
- "video_path": video_path,
969
- "silent_video": silent_video,
970
- "sr": sr,
971
- "model": "hunyuan",
972
- "crossfade_s": crossfade_s,
973
- "crossfade_db": crossfade_db,
974
- "total_dur_s": total_dur_s,
975
- "text_feats_path": text_feats_path,
976
- }
977
  outputs.append((video_path, audio_path, seg_meta))
978
 
979
  return _pad_outputs(outputs)
@@ -1003,16 +1029,7 @@ def _splice_and_save(new_wav, seg_idx, meta, slot_id):
1003
  segments = meta["segments"]
1004
  model = meta["model"]
1005
 
1006
- # Stitch (works for both mono and stereo)
1007
- stereo = wavs[0].ndim == 2
1008
- full_wav = wavs[0]
1009
- for nw in wavs[1:]:
1010
- full_wav = _cf_join(full_wav, nw, crossfade_s, crossfade_db, sr)
1011
- n_total = int(round(total_dur_s * sr))
1012
- if stereo:
1013
- full_wav = full_wav[:, :n_total]
1014
- else:
1015
- full_wav = full_wav[:n_total]
1016
 
1017
  # Save new audio — use a new timestamped filename so Gradio / the browser
1018
  # treats it as a genuinely different file and reloads the video player.
@@ -1022,10 +1039,7 @@ def _splice_and_save(new_wav, seg_idx, meta, slot_id):
1022
  # Strip any previous timestamp suffix before adding a new one
1023
  _base_clean = _base.rsplit("_regen_", 1)[0]
1024
  audio_path = os.path.join(tmp_dir, f"{_base_clean}_regen_{_ts}.wav")
1025
- if stereo:
1026
- torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(full_wav)), sr)
1027
- else:
1028
- torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(full_wav)).unsqueeze(0), sr)
1029
 
1030
  # Re-mux into a new video file so the browser is forced to reload it
1031
  _vid_base = os.path.splitext(os.path.basename(meta["video_path"]))[0]
@@ -1033,9 +1047,7 @@ def _splice_and_save(new_wav, seg_idx, meta, slot_id):
1033
  video_path = os.path.join(tmp_dir, f"{_vid_base_clean}_regen_{_ts}.mp4")
1034
  if model == "hunyuan":
1035
  # HunyuanFoley uses its own merge_audio_video
1036
- _hf_path = str(Path("HunyuanVideo-Foley").resolve())
1037
- if _hf_path not in sys.path:
1038
- sys.path.insert(0, _hf_path)
1039
  from hunyuanvideo_foley.utils.media_utils import merge_audio_video
1040
  merge_audio_video(audio_path, silent_video, video_path)
1041
  else:
@@ -1072,13 +1084,9 @@ def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
1072
  seg_start_s, seg_end_s = meta["segments"][seg_idx]
1073
 
1074
  torch.set_grad_enabled(False)
1075
- device = "cuda" if torch.cuda.is_available() else "cpu"
1076
- weight_dtype = torch.bfloat16
1077
-
1078
- _taro_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "TARO")
1079
- if _taro_dir not in sys.path:
1080
- sys.path.insert(0, _taro_dir)
1081
 
 
1082
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
1083
 
1084
  cavp_path = meta.get("cavp_path")
@@ -1095,6 +1103,10 @@ def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
1095
  tmp_dir = tempfile.mkdtemp()
1096
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
1097
  onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
 
 
 
 
1098
 
1099
  model_net, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
1100
 
@@ -1143,15 +1155,11 @@ def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1143
  seg_start, seg_end = meta["segments"][seg_idx]
1144
  seg_dur = seg_end - seg_start
1145
 
1146
- _mmaudio_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "MMAudio")
1147
- if _mmaudio_dir not in sys.path:
1148
- sys.path.insert(0, _mmaudio_dir)
1149
-
1150
  from mmaudio.eval_utils import generate, load_video
1151
  from mmaudio.model.flow_matching import FlowMatching
1152
 
1153
- device = "cuda" if torch.cuda.is_available() else "cpu"
1154
- dtype = torch.bfloat16
1155
 
1156
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
1157
  sr = seq_cfg.sampling_rate
@@ -1160,12 +1168,10 @@ def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1160
  seg_path = _regen_mmaudio_gpu._cpu_ctx.get("seg_path")
1161
  if not seg_path:
1162
  # Fallback: extract inside GPU (shouldn't happen)
1163
- silent_video = meta["silent_video"]
1164
- tmp_dir = tempfile.mkdtemp()
1165
- seg_path = os.path.join(tmp_dir, "regen_seg.mp4")
1166
- ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
1167
- seg_path, vcodec="copy", an=None
1168
- ).run(overwrite_output=True, quiet=True)
1169
 
1170
  rng = torch.Generator(device=device)
1171
  rng.manual_seed(random.randint(0, 2**32 - 1))
@@ -1203,12 +1209,11 @@ def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
1203
  seg_dur = seg_end - seg_start
1204
 
1205
  # CPU: pre-extract segment clip
1206
- silent_video = meta["silent_video"]
1207
- tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
1208
- seg_path = os.path.join(tmp_dir, "regen_seg.mp4")
1209
- ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
1210
- seg_path, vcodec="copy", an=None
1211
- ).run(overwrite_output=True, quiet=True)
1212
  _regen_mmaudio_gpu._cpu_ctx = {"seg_path": seg_path}
1213
 
1214
  # GPU: inference only
@@ -1243,14 +1248,12 @@ def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1243
  seg_start, seg_end = meta["segments"][seg_idx]
1244
  seg_dur = seg_end - seg_start
1245
 
1246
- _hf_path = str(Path("HunyuanVideo-Foley").resolve())
1247
- if _hf_path not in sys.path:
1248
- sys.path.insert(0, _hf_path)
1249
-
1250
  from hunyuanvideo_foley.utils.model_utils import denoise_process
1251
  from hunyuanvideo_foley.utils.feature_utils import feature_process
1252
 
1253
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
1254
  model_dict, cfg = _load_hunyuan_model(device, model_size)
1255
 
1256
  set_global_seed(random.randint(0, 2**32 - 1))
@@ -1258,12 +1261,10 @@ def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1258
  # Use pre-extracted segment clip from wrapper
1259
  seg_path = _regen_hunyuan_gpu._cpu_ctx.get("seg_path")
1260
  if not seg_path:
1261
- silent_video = meta["silent_video"]
1262
- tmp_dir = tempfile.mkdtemp()
1263
- seg_path = os.path.join(tmp_dir, "regen_seg.mp4")
1264
- ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
1265
- seg_path, vcodec="copy", an=None
1266
- ).run(overwrite_output=True, quiet=True)
1267
 
1268
  text_feats_path = meta.get("text_feats_path")
1269
  if text_feats_path and os.path.exists(text_feats_path):
@@ -1302,12 +1303,11 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1302
  seg_dur = seg_end - seg_start
1303
 
1304
  # CPU: pre-extract segment clip
1305
- silent_video = meta["silent_video"]
1306
- tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
1307
- seg_path = os.path.join(tmp_dir, "regen_seg.mp4")
1308
- ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
1309
- seg_path, vcodec="copy", an=None
1310
- ).run(overwrite_output=True, quiet=True)
1311
  _regen_hunyuan_gpu._cpu_ctx = {"seg_path": seg_path}
1312
 
1313
  # GPU: inference only
@@ -1374,6 +1374,19 @@ def _resample_to_slot_sr(wav: np.ndarray, src_sr: int, dst_sr: int,
1374
  return wav
1375
 
1376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1377
  def xregen_taro(seg_idx, state_json, slot_id,
1378
  seed_val, cfg_scale, num_steps, mode,
1379
  crossfade_s, crossfade_db,
@@ -1381,7 +1394,6 @@ def xregen_taro(seg_idx, state_json, slot_id,
1381
  """Cross-model regen: run TARO inference and splice into *slot_id*."""
1382
  meta = json.loads(state_json)
1383
  seg_idx = int(seg_idx)
1384
- slot_sr = int(meta["sr"])
1385
 
1386
  # Show pending waveform immediately
1387
  pending_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
@@ -1390,11 +1402,7 @@ def xregen_taro(seg_idx, state_json, slot_id,
1390
  new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
1391
  seed_val, cfg_scale, num_steps, mode,
1392
  crossfade_s, crossfade_db, slot_id)
1393
- slot_wavs = _load_seg_wavs(meta["wav_paths"])
1394
- new_wav = _resample_to_slot_sr(new_wav_raw, TARO_SR, slot_sr, slot_wavs[0])
1395
- video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1396
- new_wav, seg_idx, meta, slot_id
1397
- )
1398
  yield gr.update(value=video_path), gr.update(value=waveform_html)
1399
 
1400
 
@@ -1406,30 +1414,22 @@ def xregen_mmaudio(seg_idx, state_json, slot_id,
1406
  meta = json.loads(state_json)
1407
  seg_idx = int(seg_idx)
1408
  seg_start, seg_end = meta["segments"][seg_idx]
1409
- seg_dur = seg_end - seg_start
1410
- slot_sr = int(meta["sr"])
1411
 
1412
  # Show pending waveform immediately
1413
  pending_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
1414
  yield gr.update(), gr.update(value=pending_html)
1415
 
1416
- silent_video = meta["silent_video"]
1417
- tmp_dir = tempfile.mkdtemp()
1418
- seg_path = os.path.join(tmp_dir, "xregen_seg.mp4")
1419
- ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
1420
- seg_path, vcodec="copy", an=None
1421
- ).run(overwrite_output=True, quiet=True)
1422
  _regen_mmaudio_gpu._cpu_ctx = {"seg_path": seg_path}
1423
 
1424
  new_wav_raw, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
1425
  prompt, negative_prompt, seed_val,
1426
  cfg_strength, num_steps,
1427
  crossfade_s, crossfade_db, slot_id)
1428
- slot_wavs = _load_seg_wavs(meta["wav_paths"])
1429
- new_wav = _resample_to_slot_sr(new_wav_raw, src_sr, slot_sr, slot_wavs[0])
1430
- video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1431
- new_wav, seg_idx, meta, slot_id
1432
- )
1433
  yield gr.update(value=video_path), gr.update(value=waveform_html)
1434
 
1435
 
@@ -1442,30 +1442,22 @@ def xregen_hunyuan(seg_idx, state_json, slot_id,
1442
  meta = json.loads(state_json)
1443
  seg_idx = int(seg_idx)
1444
  seg_start, seg_end = meta["segments"][seg_idx]
1445
- seg_dur = seg_end - seg_start
1446
- slot_sr = int(meta["sr"])
1447
 
1448
  # Show pending waveform immediately
1449
  pending_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
1450
  yield gr.update(), gr.update(value=pending_html)
1451
 
1452
- silent_video = meta["silent_video"]
1453
- tmp_dir = tempfile.mkdtemp()
1454
- seg_path = os.path.join(tmp_dir, "xregen_seg.mp4")
1455
- ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
1456
- seg_path, vcodec="copy", an=None
1457
- ).run(overwrite_output=True, quiet=True)
1458
  _regen_hunyuan_gpu._cpu_ctx = {"seg_path": seg_path}
1459
 
1460
  new_wav_raw, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
1461
  prompt, negative_prompt, seed_val,
1462
  guidance_scale, num_steps, model_size,
1463
  crossfade_s, crossfade_db, slot_id)
1464
- slot_wavs = _load_seg_wavs(meta["wav_paths"])
1465
- new_wav = _resample_to_slot_sr(new_wav_raw, src_sr, slot_sr, slot_wavs[0])
1466
- video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1467
- new_wav, seg_idx, meta, slot_id
1468
- )
1469
  yield gr.update(value=video_path), gr.update(value=waveform_html)
1470
 
1471
 
@@ -1567,10 +1559,7 @@ def _build_regen_pending_html(segments: list, regen_seg_idx: int, slot_id: str,
1567
  Renders a dark bar with the active segment highlighted in amber + a spinner.
1568
  """
1569
  segs_json = json.dumps(segments)
1570
- seg_colors = ["rgba(100,180,255,0.25)", "rgba(255,160,100,0.25)",
1571
- "rgba(120,220,140,0.25)", "rgba(220,120,220,0.25)",
1572
- "rgba(255,220,80,0.25)", "rgba(80,220,220,0.25)",
1573
- "rgba(255,100,100,0.25)", "rgba(180,255,180,0.25)"]
1574
  active_color = "rgba(255,180,0,0.55)"
1575
  duration = segments[-1][1] if segments else 1.0
1576
 
@@ -1637,11 +1626,7 @@ def _build_waveform_html(audio_path: str, segments: list, slot_id: str,
1637
  audio_url = f"/gradio_api/file={audio_path}"
1638
 
1639
  segs_json = json.dumps(segments)
1640
-
1641
- seg_colors = ["rgba(100,180,255,0.35)", "rgba(255,160,100,0.35)",
1642
- "rgba(120,220,140,0.35)", "rgba(220,120,220,0.35)",
1643
- "rgba(255,220,80,0.35)", "rgba(80,220,220,0.35)",
1644
- "rgba(255,100,100,0.35)", "rgba(180,255,180,0.35)"]
1645
 
1646
  # NOTE: Gradio updates gr.HTML via innerHTML which does NOT execute <script> tags.
1647
  # Solution: put the entire waveform (canvas + JS) inside an <iframe srcdoc="...">.
@@ -1845,11 +1830,8 @@ def _build_waveform_html(audio_path: str, segments: list, slot_id: str,
1845
  </html>"""
1846
 
1847
  # Escape for HTML attribute (srcdoc uses HTML entities)
1848
- import html as _html
1849
- srcdoc = _html.escape(iframe_inner, quote=True)
1850
-
1851
- import html as _html2
1852
- state_escaped = _html2.escape(state_json or "", quote=True)
1853
 
1854
  return f"""
1855
  <div id="wf_container_{slot_id}"
 
8
  HunyuanFoley – text-guided foley via SigLIP2 + Synchformer + CLAP (48 kHz, up to 15 s)
9
  """
10
 
11
+ import html as _html
12
  import os
13
  import sys
14
  import json
 
79
  MAX_SLOTS = 8 # max parallel generation slots shown in UI
80
  MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video ≤ ~64 s at 8 s/seg)
81
 
82
+ # Segment overlay palette — shared between _build_waveform_html and _build_regen_pending_html
83
+ SEG_COLORS = [
84
+ "rgba(100,180,255,{a})", "rgba(255,160,100,{a})",
85
+ "rgba(120,220,140,{a})", "rgba(220,120,220,{a})",
86
+ "rgba(255,220,80,{a})", "rgba(80,220,220,{a})",
87
+ "rgba(255,100,100,{a})", "rgba(180,255,180,{a})",
88
+ ]
89
+
90
+ # ------------------------------------------------------------------ #
91
+ # Micro-helpers that eliminate repeated boilerplate across the file #
92
+ # ------------------------------------------------------------------ #
93
+
94
+ def _ensure_syspath(subdir: str) -> str:
95
+ """Add *subdir* (relative to app.py) to sys.path if not already present.
96
+ Returns the absolute path for convenience."""
97
+ p = os.path.join(os.path.dirname(os.path.abspath(__file__)), subdir)
98
+ if p not in sys.path:
99
+ sys.path.insert(0, p)
100
+ return p
101
+
102
+
103
+ def _get_device_and_dtype() -> tuple:
104
+ """Return (device, weight_dtype) pair used by all GPU functions."""
105
+ device = "cuda" if torch.cuda.is_available() else "cpu"
106
+ return device, torch.bfloat16
107
+
108
+
109
+ def _extract_segment_clip(silent_video: str, seg_start: float, seg_dur: float,
110
+ output_path: str) -> str:
111
+ """Stream-copy a segment from *silent_video* to *output_path*. Returns *output_path*."""
112
+ ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
113
+ output_path, vcodec="copy", an=None
114
+ ).run(overwrite_output=True, quiet=True)
115
+ return output_path
116
+
117
  # Per-slot reentrant locks — prevent concurrent regens on the same slot from
118
  # producing a race condition where the second regen reads stale state
119
  # (the shared seg_state textbox hasn't been updated yet by the first regen).
 
127
  _SLOT_LOCKS[slot_id] = threading.Lock()
128
  return _SLOT_LOCKS[slot_id]
129
 
130
+ def set_global_seed(seed: int) -> None:
131
  np.random.seed(seed % (2**32))
132
  random.seed(seed)
133
  torch.manual_seed(seed)
134
+ if torch.cuda.is_available():
135
+ torch.cuda.manual_seed(seed)
136
 
137
  def get_random_seed() -> int:
138
  return random.randint(0, 2**32 - 1)
 
142
  probe = ffmpeg.probe(video_path)
143
  return float(probe["format"]["duration"])
144
 
145
+ def strip_audio_from_video(video_path: str, output_path: str) -> None:
146
  """Write a silent copy of *video_path* to *output_path* (stream-copy, no re-encode)."""
147
  ffmpeg.input(video_path).output(output_path, vcodec="copy", an=None).run(
148
  overwrite_output=True, quiet=True
 
167
  return tmp_dir
168
 
169
 
170
+ def _save_seg_wavs(wavs: list[np.ndarray], tmp_dir: str, prefix: str) -> list[str]:
171
  """Save a list of numpy wav arrays to .npy files, return list of paths.
172
  This avoids serialising large float arrays into JSON/HTML data-state."""
173
  paths = []
 
178
  return paths
179
 
180
 
181
+ def _load_seg_wavs(paths: list[str]) -> list[np.ndarray]:
182
  """Load segment wav arrays from .npy file paths."""
183
  return [np.load(p) for p in paths]
184
 
 
227
  from mmaudio.eval_utils import all_model_cfg
228
  from mmaudio.model.networks import get_my_mmaudio
229
  from mmaudio.model.utils.features_utils import FeaturesUtils
 
230
 
231
  model_cfg = all_model_cfg["large_44k_v2"]
232
+ model_cfg.model_path = Path(mmaudio_model_path)
233
+ model_cfg.vae_path = Path(mmaudio_vae_path)
234
+ model_cfg.synchformer_ckpt = Path(mmaudio_synchformer_path)
235
  model_cfg.bigvgan_16k_path = None
236
  seq_cfg = model_cfg.seq_cfg
237
 
 
261
  enable_offload=False, model_size=model_size)
262
 
263
 
264
+ def mux_video_audio(silent_video: str, audio_path: str, output_path: str) -> None:
265
  """Mux a silent video with an audio file into *output_path* (stream-copy video, encode audio)."""
266
  ffmpeg.output(
267
  ffmpeg.input(silent_video),
 
276
  # Used by all three models (TARO, MMAudio, HunyuanFoley). #
277
  # ------------------------------------------------------------------ #
278
 
279
+ def _build_segments(total_dur_s: float, window_s: float, crossfade_s: float) -> list[tuple[float, float]]:
280
  """Return list of (start, end) pairs covering *total_dur_s* with a sliding
281
  window of *window_s* and *crossfade_s* overlap between consecutive segments."""
282
  if total_dur_s <= window_s:
 
496
  return wav[:seg_samples]
497
 
498
 
499
+ def _stitch_wavs(wavs: list[np.ndarray], crossfade_s: float, db_boost: float,
500
  total_dur_s: float, sr: int) -> np.ndarray:
501
+ """Crossfade-join a list of wav arrays and trim to *total_dur_s*.
502
+ Works for both mono (T,) and stereo (C, T) arrays."""
503
  out = wavs[0]
504
  for nw in wavs[1:]:
505
  out = _cf_join(out, nw, crossfade_s, db_boost, sr)
506
+ n = int(round(total_dur_s * sr))
507
+ return out[:, :n] if out.ndim == 2 else out[:n]
508
+
509
+
510
+ def _save_wav(path: str, wav: np.ndarray, sr: int) -> None:
511
+ """Save a numpy wav array (mono or stereo) to *path* via torchaudio."""
512
+ t = torch.from_numpy(np.ascontiguousarray(wav))
513
+ if t.ndim == 1:
514
+ t = t.unsqueeze(0)
515
+ torchaudio.save(path, t, sr)
516
+
517
+
518
+ def _log_inference_timing(label: str, elapsed: float, n_segs: int,
519
+ num_steps: int, constant: float) -> None:
520
+ """Print a standardised inference-timing summary line."""
521
+ total_steps = n_segs * num_steps
522
+ secs_per_step = elapsed / total_steps if total_steps > 0 else 0
523
+ print(f"[{label}] Inference done: {n_segs} seg(s) × {num_steps} steps in "
524
+ f"{elapsed:.1f}s wall → {secs_per_step:.3f}s/step "
525
+ f"(current constant={constant})")
526
+
527
+
528
+ def _build_seg_meta(*, segments, wav_paths, audio_path, video_path,
529
+ silent_video, sr, model, crossfade_s, crossfade_db,
530
+ total_dur_s, **extras) -> dict:
531
+ """Build the seg_meta dict shared by all three generate_* functions.
532
+ Model-specific keys are passed via **extras."""
533
+ meta = {
534
+ "segments": segments,
535
+ "wav_paths": wav_paths,
536
+ "audio_path": audio_path,
537
+ "video_path": video_path,
538
+ "silent_video": silent_video,
539
+ "sr": sr,
540
+ "model": model,
541
+ "crossfade_s": crossfade_s,
542
+ "crossfade_db": crossfade_db,
543
+ "total_dur_s": total_dur_s,
544
+ }
545
+ meta.update(extras)
546
+ return meta
547
+
548
+
549
+ def _cpu_preprocess(video_file: str, model_dur: float,
550
+ crossfade_s: float) -> tuple:
551
+ """Shared CPU pre-processing for all generate_* wrappers.
552
+ Returns (tmp_dir, silent_video, total_dur_s, segments)."""
553
+ tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
554
+ silent_video = os.path.join(tmp_dir, "silent_input.mp4")
555
+ strip_audio_from_video(video_file, silent_video)
556
+ total_dur_s = get_video_duration(video_file)
557
+ segments = _build_segments(total_dur_s, model_dur, crossfade_s)
558
+ return tmp_dir, silent_video, total_dur_s, segments
559
 
560
 
561
  @spaces.GPU(duration=_taro_duration)
 
563
  crossfade_s, crossfade_db, num_samples):
564
  """GPU-only TARO inference — model loading + feature extraction + diffusion.
565
  Returns list of (wavs_list, onset_feats) per sample."""
 
 
566
  seed_val = int(seed_val)
567
  crossfade_s = float(crossfade_s)
568
  num_samples = int(num_samples)
 
570
  seed_val = random.randint(0, 2**32 - 1)
571
 
572
  torch.set_grad_enabled(False)
573
+ device, weight_dtype = _get_device_and_dtype()
 
 
 
 
 
574
 
575
+ _ensure_syspath("TARO")
576
  from TARO.onset_util import extract_onset
577
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
578
 
 
584
  total_dur_s = ctx["total_dur_s"]
585
 
586
  extract_cavp, onset_model = _load_taro_feature_extractors(device)
 
 
587
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
588
+ # Onset features depend only on the video — extract once for all samples
589
+ onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
590
+
591
+ # Free feature extractors before loading the heavier inference models
592
+ del extract_cavp, onset_model
593
+ if torch.cuda.is_available():
594
+ torch.cuda.empty_cache()
595
+
596
+ model, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
597
 
598
  results = [] # list of (wavs, onset_feats) per sample
599
  for sample_idx in range(num_samples):
 
607
  results.append((cached["wavs"], cavp_feats, None))
608
  else:
609
  set_global_seed(sample_seed)
 
610
  wavs = []
611
  _t_infer_start = time.perf_counter()
612
  for seg_start_s, seg_end_s in segments:
 
621
  euler_sampler, euler_maruyama_sampler,
622
  )
623
  wavs.append(wav)
624
+ _log_inference_timing("TARO", time.perf_counter() - _t_infer_start,
625
+ len(segments), int(num_steps), TARO_SECS_PER_STEP)
 
 
 
 
626
  with _TARO_CACHE_LOCK:
627
  _TARO_INFERENCE_CACHE[cache_key] = {"wavs": wavs}
628
  while len(_TARO_INFERENCE_CACHE) > _TARO_CACHE_MAXLEN:
 
649
  num_samples = int(num_samples)
650
 
651
  # ── CPU pre-processing (no GPU needed) ──
652
+ tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
653
+ video_file, TARO_MODEL_DUR, crossfade_s)
 
 
 
654
 
655
  # Pass pre-computed CPU results to the GPU function via context
656
  _taro_gpu_infer._cpu_ctx = {
 
671
  for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
672
  final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR)
673
  audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
674
+ _save_wav(audio_path, final_wav, TARO_SR)
675
  video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
676
  mux_video_audio(silent_video, audio_path, video_path)
677
  wav_paths = _save_seg_wavs(wavs, tmp_dir, f"taro_{sample_idx}")
 
681
  if onset_feats is not None:
682
  np.save(onset_path, onset_feats)
683
  first_cavp_saved = True
684
+ seg_meta = _build_seg_meta(
685
+ segments=segments, wav_paths=wav_paths, audio_path=audio_path,
686
+ video_path=video_path, silent_video=silent_video, sr=TARO_SR,
687
+ model="taro", crossfade_s=crossfade_s, crossfade_db=crossfade_db,
688
+ total_dur_s=total_dur_s, cavp_path=cavp_path, onset_path=onset_path,
689
+ )
 
 
 
 
 
 
 
 
690
  outputs.append((video_path, audio_path, seg_meta))
691
 
692
  return _pad_outputs(outputs)
 
718
  cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
719
  """GPU-only MMAudio inference — model loading + flow-matching generation.
720
  Returns list of (seg_audios, sr) per sample."""
721
+ _ensure_syspath("MMAudio")
 
 
 
722
  from mmaudio.eval_utils import generate, load_video
723
  from mmaudio.model.flow_matching import FlowMatching
724
 
 
726
  num_samples = int(num_samples)
727
  crossfade_s = float(crossfade_s)
728
 
729
+ device, dtype = _get_device_and_dtype()
 
730
 
731
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
732
 
 
780
  wav = wav[:, :seg_samples]
781
  seg_audios.append(wav)
782
 
783
+ _log_inference_timing("MMAudio", time.perf_counter() - _t_mma_start,
784
+ len(segments), int(num_steps), MMAUDIO_SECS_PER_STEP)
 
 
 
 
785
  results.append((seg_audios, sr))
786
 
787
  # Free GPU memory between samples to prevent VRAM fragmentation
 
802
  crossfade_db = float(crossfade_db)
803
 
804
  # ── CPU pre-processing ──
805
+ tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
806
+ video_file, MMAUDIO_WINDOW, crossfade_s)
 
 
 
807
  print(f"[MMAudio] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤8 s")
808
 
809
+ seg_clip_paths = [
810
+ _extract_segment_clip(silent_video, s, e - s, os.path.join(tmp_dir, f"mma_seg_{i}.mp4"))
811
+ for i, (s, e) in enumerate(segments)
812
+ ]
 
 
 
 
813
 
814
  _mmaudio_gpu_infer._cpu_ctx = {
815
  "segments": segments, "seg_clip_paths": seg_clip_paths,
 
822
  # ── CPU post-processing ──
823
  outputs = []
824
  for sample_idx, (seg_audios, sr) in enumerate(results):
825
+ full_wav = _stitch_wavs(seg_audios, crossfade_s, crossfade_db, total_dur_s, sr)
 
 
 
826
 
827
  audio_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.wav")
828
+ _save_wav(audio_path, full_wav, sr)
829
  video_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.mp4")
830
  mux_video_audio(silent_video, audio_path, video_path)
831
  wav_paths = _save_seg_wavs(seg_audios, tmp_dir, f"mmaudio_{sample_idx}")
832
+ seg_meta = _build_seg_meta(
833
+ segments=segments, wav_paths=wav_paths, audio_path=audio_path,
834
+ video_path=video_path, silent_video=silent_video, sr=sr,
835
+ model="mmaudio", crossfade_s=crossfade_s, crossfade_db=crossfade_db,
836
+ total_dur_s=total_dur_s,
837
+ )
 
 
 
 
 
 
838
  outputs.append((video_path, audio_path, seg_meta))
839
 
840
  return _pad_outputs(outputs)
 
867
  guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
868
  """GPU-only HunyuanFoley inference — model loading + feature extraction + denoising.
869
  Returns list of (seg_wavs, sr, text_feats) per sample."""
870
+ _ensure_syspath("HunyuanVideo-Foley")
 
 
 
871
  from hunyuanvideo_foley.utils.model_utils import denoise_process
872
  from hunyuanvideo_foley.utils.feature_utils import feature_process
873
 
 
877
  if seed_val >= 0:
878
  set_global_seed(seed_val)
879
 
880
+ device, _ = _get_device_and_dtype()
881
+ device = torch.device(device)
882
+ model_size = model_size.lower()
883
 
884
  model_dict, cfg = _load_hunyuan_model(device, model_size)
885
 
 
931
  wav = wav[:, :seg_samples]
932
  seg_wavs.append(wav)
933
 
934
+ _log_inference_timing("HunyuanFoley", time.perf_counter() - _t_hny_start,
935
+ len(segments), int(num_steps), HUNYUAN_SECS_PER_STEP)
 
 
 
 
936
  results.append((seg_wavs, sr, text_feats))
937
 
938
  # Free GPU memory between samples to prevent VRAM fragmentation
 
953
  crossfade_db = float(crossfade_db)
954
 
955
  # ── CPU pre-processing (no GPU needed) ──
956
+ tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
957
+ video_file, HUNYUAN_MAX_DUR, crossfade_s)
 
 
 
958
  print(f"[HunyuanFoley] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤15 s")
959
 
960
  # Pre-extract dummy segment for text feature extraction (ffmpeg, CPU)
961
+ dummy_seg_path = _extract_segment_clip(
962
+ silent_video, 0, min(total_dur_s, HUNYUAN_MAX_DUR),
963
+ os.path.join(tmp_dir, "_seg_dummy.mp4"),
964
+ )
965
 
966
  # Pre-extract all segment clips (ffmpeg, CPU)
967
+ seg_clip_paths = [
968
+ _extract_segment_clip(silent_video, s, e - s, os.path.join(tmp_dir, f"hny_seg_{i}.mp4"))
969
+ for i, (s, e) in enumerate(segments)
970
+ ]
 
 
 
 
971
 
972
  _hunyuan_gpu_infer._cpu_ctx = {
973
  "segments": segments, "total_dur_s": total_dur_s,
 
980
  crossfade_s, crossfade_db, num_samples)
981
 
982
  # ── CPU post-processing (no GPU needed) ──
983
+ _ensure_syspath("HunyuanVideo-Foley")
 
 
984
  from hunyuanvideo_foley.utils.media_utils import merge_audio_video
985
 
986
  outputs = []
987
  for sample_idx, (seg_wavs, sr, text_feats) in enumerate(results):
988
+ full_wav = _stitch_wavs(seg_wavs, crossfade_s, crossfade_db, total_dur_s, sr)
 
 
 
989
 
990
  audio_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}.wav")
991
+ _save_wav(audio_path, full_wav, sr)
992
  video_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}.mp4")
993
  merge_audio_video(audio_path, silent_video, video_path)
994
  wav_paths = _save_seg_wavs(seg_wavs, tmp_dir, f"hunyuan_{sample_idx}")
995
  text_feats_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}_text_feats.pt")
996
  torch.save(text_feats, text_feats_path)
997
+ seg_meta = _build_seg_meta(
998
+ segments=segments, wav_paths=wav_paths, audio_path=audio_path,
999
+ video_path=video_path, silent_video=silent_video, sr=sr,
1000
+ model="hunyuan", crossfade_s=crossfade_s, crossfade_db=crossfade_db,
1001
+ total_dur_s=total_dur_s, text_feats_path=text_feats_path,
1002
+ )
 
 
 
 
 
 
 
1003
  outputs.append((video_path, audio_path, seg_meta))
1004
 
1005
  return _pad_outputs(outputs)
 
1029
  segments = meta["segments"]
1030
  model = meta["model"]
1031
 
1032
+ full_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, sr)
 
 
 
 
 
 
 
 
 
1033
 
1034
  # Save new audio — use a new timestamped filename so Gradio / the browser
1035
  # treats it as a genuinely different file and reloads the video player.
 
1039
  # Strip any previous timestamp suffix before adding a new one
1040
  _base_clean = _base.rsplit("_regen_", 1)[0]
1041
  audio_path = os.path.join(tmp_dir, f"{_base_clean}_regen_{_ts}.wav")
1042
+ _save_wav(audio_path, full_wav, sr)
 
 
 
1043
 
1044
  # Re-mux into a new video file so the browser is forced to reload it
1045
  _vid_base = os.path.splitext(os.path.basename(meta["video_path"]))[0]
 
1047
  video_path = os.path.join(tmp_dir, f"{_vid_base_clean}_regen_{_ts}.mp4")
1048
  if model == "hunyuan":
1049
  # HunyuanFoley uses its own merge_audio_video
1050
+ _ensure_syspath("HunyuanVideo-Foley")
 
 
1051
  from hunyuanvideo_foley.utils.media_utils import merge_audio_video
1052
  merge_audio_video(audio_path, silent_video, video_path)
1053
  else:
 
1084
  seg_start_s, seg_end_s = meta["segments"][seg_idx]
1085
 
1086
  torch.set_grad_enabled(False)
1087
+ device, weight_dtype = _get_device_and_dtype()
 
 
 
 
 
1088
 
1089
+ _ensure_syspath("TARO")
1090
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
1091
 
1092
  cavp_path = meta.get("cavp_path")
 
1103
  tmp_dir = tempfile.mkdtemp()
1104
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
1105
  onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
1106
+ # Free feature extractors before loading inference models
1107
+ del extract_cavp, onset_model
1108
+ if torch.cuda.is_available():
1109
+ torch.cuda.empty_cache()
1110
 
1111
  model_net, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
1112
 
 
1155
  seg_start, seg_end = meta["segments"][seg_idx]
1156
  seg_dur = seg_end - seg_start
1157
 
1158
+ _ensure_syspath("MMAudio")
 
 
 
1159
  from mmaudio.eval_utils import generate, load_video
1160
  from mmaudio.model.flow_matching import FlowMatching
1161
 
1162
+ device, dtype = _get_device_and_dtype()
 
1163
 
1164
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
1165
  sr = seq_cfg.sampling_rate
 
1168
  seg_path = _regen_mmaudio_gpu._cpu_ctx.get("seg_path")
1169
  if not seg_path:
1170
  # Fallback: extract inside GPU (shouldn't happen)
1171
+ seg_path = _extract_segment_clip(
1172
+ meta["silent_video"], seg_start, seg_dur,
1173
+ os.path.join(tempfile.mkdtemp(), "regen_seg.mp4"),
1174
+ )
 
 
1175
 
1176
  rng = torch.Generator(device=device)
1177
  rng.manual_seed(random.randint(0, 2**32 - 1))
 
1209
  seg_dur = seg_end - seg_start
1210
 
1211
  # CPU: pre-extract segment clip
1212
+ tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
1213
+ seg_path = _extract_segment_clip(
1214
+ meta["silent_video"], seg_start, seg_dur,
1215
+ os.path.join(tmp_dir, "regen_seg.mp4"),
1216
+ )
 
1217
  _regen_mmaudio_gpu._cpu_ctx = {"seg_path": seg_path}
1218
 
1219
  # GPU: inference only
 
1248
  seg_start, seg_end = meta["segments"][seg_idx]
1249
  seg_dur = seg_end - seg_start
1250
 
1251
+ _ensure_syspath("HunyuanVideo-Foley")
 
 
 
1252
  from hunyuanvideo_foley.utils.model_utils import denoise_process
1253
  from hunyuanvideo_foley.utils.feature_utils import feature_process
1254
 
1255
+ device, _ = _get_device_and_dtype()
1256
+ device = torch.device(device)
1257
  model_dict, cfg = _load_hunyuan_model(device, model_size)
1258
 
1259
  set_global_seed(random.randint(0, 2**32 - 1))
 
1261
  # Use pre-extracted segment clip from wrapper
1262
  seg_path = _regen_hunyuan_gpu._cpu_ctx.get("seg_path")
1263
  if not seg_path:
1264
+ seg_path = _extract_segment_clip(
1265
+ meta["silent_video"], seg_start, seg_dur,
1266
+ os.path.join(tempfile.mkdtemp(), "regen_seg.mp4"),
1267
+ )
 
 
1268
 
1269
  text_feats_path = meta.get("text_feats_path")
1270
  if text_feats_path and os.path.exists(text_feats_path):
 
1303
  seg_dur = seg_end - seg_start
1304
 
1305
  # CPU: pre-extract segment clip
1306
+ tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
1307
+ seg_path = _extract_segment_clip(
1308
+ meta["silent_video"], seg_start, seg_dur,
1309
+ os.path.join(tmp_dir, "regen_seg.mp4"),
1310
+ )
 
1311
  _regen_hunyuan_gpu._cpu_ctx = {"seg_path": seg_path}
1312
 
1313
  # GPU: inference only
 
1374
  return wav
1375
 
1376
 
1377
+ def _xregen_splice(new_wav_raw: np.ndarray, src_sr: int,
1378
+ meta: dict, seg_idx: int, slot_id: str) -> tuple:
1379
+ """Shared epilogue for all xregen_* functions: resample → splice → save.
1380
+ Returns (video_path, waveform_html)."""
1381
+ slot_sr = int(meta["sr"])
1382
+ slot_wavs = _load_seg_wavs(meta["wav_paths"])
1383
+ new_wav = _resample_to_slot_sr(new_wav_raw, src_sr, slot_sr, slot_wavs[0])
1384
+ video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1385
+ new_wav, seg_idx, meta, slot_id
1386
+ )
1387
+ return video_path, waveform_html
1388
+
1389
+
1390
  def xregen_taro(seg_idx, state_json, slot_id,
1391
  seed_val, cfg_scale, num_steps, mode,
1392
  crossfade_s, crossfade_db,
 
1394
  """Cross-model regen: run TARO inference and splice into *slot_id*."""
1395
  meta = json.loads(state_json)
1396
  seg_idx = int(seg_idx)
 
1397
 
1398
  # Show pending waveform immediately
1399
  pending_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
 
1402
  new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
1403
  seed_val, cfg_scale, num_steps, mode,
1404
  crossfade_s, crossfade_db, slot_id)
1405
+ video_path, waveform_html = _xregen_splice(new_wav_raw, TARO_SR, meta, seg_idx, slot_id)
 
 
 
 
1406
  yield gr.update(value=video_path), gr.update(value=waveform_html)
1407
 
1408
 
 
1414
  meta = json.loads(state_json)
1415
  seg_idx = int(seg_idx)
1416
  seg_start, seg_end = meta["segments"][seg_idx]
 
 
1417
 
1418
  # Show pending waveform immediately
1419
  pending_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
1420
  yield gr.update(), gr.update(value=pending_html)
1421
 
1422
+ seg_path = _extract_segment_clip(
1423
+ meta["silent_video"], seg_start, seg_end - seg_start,
1424
+ os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1425
+ )
 
 
1426
  _regen_mmaudio_gpu._cpu_ctx = {"seg_path": seg_path}
1427
 
1428
  new_wav_raw, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
1429
  prompt, negative_prompt, seed_val,
1430
  cfg_strength, num_steps,
1431
  crossfade_s, crossfade_db, slot_id)
1432
+ video_path, waveform_html = _xregen_splice(new_wav_raw, src_sr, meta, seg_idx, slot_id)
 
 
 
 
1433
  yield gr.update(value=video_path), gr.update(value=waveform_html)
1434
 
1435
 
 
1442
  meta = json.loads(state_json)
1443
  seg_idx = int(seg_idx)
1444
  seg_start, seg_end = meta["segments"][seg_idx]
 
 
1445
 
1446
  # Show pending waveform immediately
1447
  pending_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
1448
  yield gr.update(), gr.update(value=pending_html)
1449
 
1450
+ seg_path = _extract_segment_clip(
1451
+ meta["silent_video"], seg_start, seg_end - seg_start,
1452
+ os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1453
+ )
 
 
1454
  _regen_hunyuan_gpu._cpu_ctx = {"seg_path": seg_path}
1455
 
1456
  new_wav_raw, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
1457
  prompt, negative_prompt, seed_val,
1458
  guidance_scale, num_steps, model_size,
1459
  crossfade_s, crossfade_db, slot_id)
1460
+ video_path, waveform_html = _xregen_splice(new_wav_raw, src_sr, meta, seg_idx, slot_id)
 
 
 
 
1461
  yield gr.update(value=video_path), gr.update(value=waveform_html)
1462
 
1463
 
 
1559
  Renders a dark bar with the active segment highlighted in amber + a spinner.
1560
  """
1561
  segs_json = json.dumps(segments)
1562
+ seg_colors = [c.format(a="0.25") for c in SEG_COLORS]
 
 
 
1563
  active_color = "rgba(255,180,0,0.55)"
1564
  duration = segments[-1][1] if segments else 1.0
1565
 
 
1626
  audio_url = f"/gradio_api/file={audio_path}"
1627
 
1628
  segs_json = json.dumps(segments)
1629
+ seg_colors = [c.format(a="0.35") for c in SEG_COLORS]
 
 
 
 
1630
 
1631
  # NOTE: Gradio updates gr.HTML via innerHTML which does NOT execute <script> tags.
1632
  # Solution: put the entire waveform (canvas + JS) inside an <iframe srcdoc="...">.
 
1830
  </html>"""
1831
 
1832
  # Escape for HTML attribute (srcdoc uses HTML entities)
1833
+ srcdoc = _html.escape(iframe_inner, quote=True)
1834
+ state_escaped = _html.escape(state_json or "", quote=True)
 
 
 
1835
 
1836
  return f"""
1837
  <div id="wf_container_{slot_id}"