BoxOfColors Claude Opus 4.6 committed on
Commit
a4226e1
·
1 Parent(s): 1b4e6b8

perf: comprehensive optimization pass — caching, dedup, cleanup

Browse files

- Complete .npy migration: segment wavs stored as file paths instead of
serialized arrays, removing all .tolist() calls
- Cache CAVP+onset features (TARO) and text features (HunyuanFoley) in
seg_meta so regen skips re-extraction (~5-7s saved per TARO regen,
~2-3s per HunyuanFoley regen)
- Extract shared model loading into helper functions (_load_taro_models,
_load_taro_feature_extractors, _load_mmaudio_models, _load_hunyuan_model)
to deduplicate generate/regen code
- Batch ffmpeg segment extraction before sample loop (MMAudio, HunyuanFoley)
so clips are extracted once and reused across samples
- Add temp directory registry with automatic cleanup of old dirs (keeps
last 10, removes older ones on new generation)
- Make TARO inference cache thread-safe with threading.Lock

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +232 -220
app.py CHANGED
@@ -11,7 +11,6 @@ Supported models
11
  import os
12
  import sys
13
  import json
14
- import base64
15
  import tempfile
16
  import random
17
  import threading
@@ -111,6 +110,121 @@ def strip_audio_from_video(video_path: str, output_path: str):
111
  overwrite_output=True, quiet=True
112
  )
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  def mux_video_audio(silent_video: str, audio_path: str, output_path: str):
115
  """Mux a silent video with an audio file into *output_path* (stream-copy video, encode audio)."""
116
  ffmpeg.output(
@@ -193,6 +307,8 @@ GPU_DURATION_CAP = 300 # hard cap per call — never reserve more than t
193
 
194
  _TARO_CACHE_MAXLEN = 16 # evict oldest entries beyond this limit
195
  _TARO_INFERENCE_CACHE: dict = {} # keyed by (video_file, seed, cfg, steps, mode, crossfade_s)
 
 
196
 
197
 
198
  def _taro_calc_max_samples(total_dur_s: float, num_steps: int, crossfade_s: float) -> int:
@@ -308,50 +424,14 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
308
  if _taro_dir not in sys.path:
309
  sys.path.insert(0, _taro_dir)
310
 
311
- # Imports are inside the GPU context so the Space only pays for GPU time here
312
- from TARO.cavp_util import Extract_CAVP_Features
313
- from TARO.onset_util import VideoOnsetNet, extract_onset
314
- from TARO.models import MMDiT
315
- from TARO.samplers import euler_sampler, euler_maruyama_sampler
316
- from diffusers import AutoencoderKL
317
- from transformers import SpeechT5HifiGan
318
-
319
- # -- Load CAVP encoder (uses checkpoint from our HF repo) --
320
- extract_cavp = Extract_CAVP_Features(
321
- device=device,
322
- config_path="TARO/cavp/cavp.yaml",
323
- ckpt_path=cavp_ckpt_path,
324
- )
325
-
326
- # -- Load onset detection model --
327
- # Key remapping matches the original TARO infer.py exactly
328
- raw_sd = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
329
- onset_sd = {}
330
- for k, v in raw_sd.items():
331
- if "model.net.model" in k:
332
- k = k.replace("model.net.model", "net.model")
333
- elif "model.fc." in k:
334
- k = k.replace("model.fc", "fc")
335
- onset_sd[k] = v
336
- onset_model = VideoOnsetNet(pretrained=False).to(device)
337
- onset_model.load_state_dict(onset_sd)
338
- onset_model.eval()
339
-
340
- # -- Load TARO MMDiT --
341
- # Architecture params match TARO/train.py: adm_in_channels=120 (onset dim),
342
- # z_dims=[768] (CAVP dim), encoder_depth=4
343
- model = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
344
- model.load_state_dict(torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"])
345
- model.eval().to(weight_dtype)
346
 
347
- # -- Load AudioLDM2 VAE + vocoder only (saves ~3-4 GB vs loading the full pipeline) --
348
- # TARO only needs VAE and vocoder for decoding; the text encoder and UNet are never used.
349
- vae = AutoencoderKL.from_pretrained("cvssp/audioldm2", subfolder="vae").to(device).eval()
350
- vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm2", subfolder="vocoder").to(device)
351
- latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
352
 
353
  # -- Prepare silent video (shared across all samples) --
354
- tmp_dir = tempfile.mkdtemp()
355
  silent_video = os.path.join(tmp_dir, "silent_input.mp4")
356
  strip_audio_from_video(video_file, silent_video)
357
 
@@ -366,9 +446,11 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
366
  sample_seed = seed_val + sample_idx
367
  cache_key = (video_file, sample_seed, float(cfg_scale), int(num_steps), mode, crossfade_s)
368
 
369
- if cache_key in _TARO_INFERENCE_CACHE:
 
 
370
  print(f"[TARO] Sample {sample_idx+1}: cache hit.")
371
- wavs = _TARO_INFERENCE_CACHE[cache_key]["wavs"]
372
  else:
373
  set_global_seed(sample_seed)
374
  onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
@@ -392,19 +474,25 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
392
  print(f"[TARO] Inference done: {_n_segs} seg(s) × {int(num_steps)} steps in "
393
  f"{_t_infer_elapsed:.1f}s wall → {_secs_per_step:.3f}s/step "
394
  f"(current constant={TARO_SECS_PER_STEP})")
395
- _TARO_INFERENCE_CACHE[cache_key] = {"wavs": wavs}
396
- # Evict oldest entries if cache exceeds max size
397
- while len(_TARO_INFERENCE_CACHE) > _TARO_CACHE_MAXLEN:
398
- _TARO_INFERENCE_CACHE.pop(next(iter(_TARO_INFERENCE_CACHE)))
399
 
400
  final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR)
401
  audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
402
  torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(final_wav)).unsqueeze(0), TARO_SR)
403
  video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
404
  mux_video_audio(silent_video, audio_path, video_path)
 
 
 
 
 
 
405
  seg_meta = {
406
  "segments": segments,
407
- "wavs": [w.copy() for w in wavs],
408
  "audio_path": audio_path,
409
  "video_path": video_path,
410
  "silent_video": silent_video,
@@ -413,6 +501,8 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
413
  "crossfade_s": crossfade_s,
414
  "crossfade_db": crossfade_db,
415
  "total_dur_s": total_dur_s,
 
 
416
  }
417
  outputs.append((video_path, audio_path, seg_meta))
418
 
@@ -456,10 +546,8 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
456
  if _mmaudio_dir not in sys.path:
457
  sys.path.insert(0, _mmaudio_dir)
458
 
459
- from mmaudio.eval_utils import all_model_cfg, generate, load_video, make_video
460
  from mmaudio.model.flow_matching import FlowMatching
461
- from mmaudio.model.networks import get_my_mmaudio
462
- from mmaudio.model.utils.features_utils import FeaturesUtils
463
 
464
  seed_val = int(seed_val)
465
  num_samples = int(num_samples)
@@ -469,33 +557,9 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
469
  device = "cuda" if torch.cuda.is_available() else "cpu"
470
  dtype = torch.bfloat16
471
 
472
- # Use large_44k_v2 variant; override paths to our consolidated HF checkpoint repo
473
- model_cfg = all_model_cfg["large_44k_v2"]
474
- # Patch checkpoint paths to our downloaded files
475
- from pathlib import Path as _Path
476
- model_cfg.model_path = _Path(mmaudio_model_path)
477
- model_cfg.vae_path = _Path(mmaudio_vae_path)
478
- model_cfg.synchformer_ckpt = _Path(mmaudio_synchformer_path)
479
- # large_44k_v2 is 44k mode, no BigVGAN vocoder needed
480
- model_cfg.bigvgan_16k_path = None
481
- seq_cfg = model_cfg.seq_cfg # CONFIG_44K: 8 s, 44100 Hz
482
-
483
- # Load network weights
484
- net = get_my_mmaudio(model_cfg.model_name).to(device, dtype).eval()
485
- net.load_weights(torch.load(model_cfg.model_path, map_location=device, weights_only=True))
486
-
487
- # Load feature utilities: CLIP (auto-downloaded from apple/DFN5B-CLIP-ViT-H-14-384),
488
- # Synchformer (from our repo), VAE (from our repo), no BigVGAN for 44k mode
489
- feature_utils = FeaturesUtils(
490
- tod_vae_ckpt=str(model_cfg.vae_path),
491
- synchformer_ckpt=str(model_cfg.synchformer_ckpt),
492
- enable_conditions=True,
493
- mode=model_cfg.mode, # "44k"
494
- bigvgan_vocoder_ckpt=None,
495
- need_vae_encoder=False,
496
- ).to(device, dtype).eval()
497
 
498
- tmp_dir = tempfile.mkdtemp()
499
  outputs = []
500
 
501
  # Strip original audio so the muxed output only contains the generated track
@@ -510,6 +574,16 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
510
 
511
  sr = seq_cfg.sampling_rate # 44100
512
 
 
 
 
 
 
 
 
 
 
 
513
  for sample_idx in range(num_samples):
514
  rng = torch.Generator(device=device)
515
  if seed_val >= 0:
@@ -522,11 +596,7 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
522
 
523
  for seg_i, (seg_start, seg_end) in enumerate(segments):
524
  seg_dur = seg_end - seg_start
525
- # Trim a clean video clip for this segment (stream-copy, no re-encode)
526
- seg_path = os.path.join(tmp_dir, f"mma_seg_{sample_idx}_{seg_i}.mp4")
527
- ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
528
- seg_path, vcodec="copy", an=None
529
- ).run(overwrite_output=True, quiet=True)
530
 
531
  fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
532
  video_info = load_video(seg_path, seg_dur)
@@ -575,9 +645,10 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
575
 
576
  video_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.mp4")
577
  mux_video_audio(silent_video, audio_path, video_path)
 
578
  seg_meta = {
579
  "segments": segments,
580
- "wavs": [w.copy() for w in seg_audios],
581
  "audio_path": audio_path,
582
  "video_path": video_path,
583
  "silent_video": silent_video,
@@ -631,7 +702,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
631
  if _hf_path not in sys.path:
632
  sys.path.insert(0, _hf_path)
633
 
634
- from hunyuanvideo_foley.utils.model_utils import load_model, denoise_process
635
  from hunyuanvideo_foley.utils.feature_utils import feature_process
636
  from hunyuanvideo_foley.utils.media_utils import merge_audio_video
637
 
@@ -645,25 +716,9 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
645
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
646
  model_size = model_size.lower() # "xl" or "xxl"
647
 
648
- config_map = {
649
- "xl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xl.yaml",
650
- "xxl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xxl.yaml",
651
- }
652
- config_path = config_map.get(model_size, config_map["xxl"])
653
 
654
- # hf_hub_download preserves the repo subfolder, so weights land in
655
- # HUNYUAN_MODEL_DIR/HunyuanVideo-Foley/ — pass that as the weights dir.
656
- hunyuan_weights_dir = str(HUNYUAN_MODEL_DIR / "HunyuanVideo-Foley")
657
- print(f"[HunyuanFoley] Loading {model_size.upper()} model from {hunyuan_weights_dir}")
658
- model_dict, cfg = load_model(
659
- hunyuan_weights_dir,
660
- config_path,
661
- device,
662
- enable_offload=False,
663
- model_size=model_size,
664
- )
665
-
666
- tmp_dir = tempfile.mkdtemp()
667
  outputs = []
668
 
669
  # Strip original audio so the muxed output only contains the generated track
@@ -690,6 +745,16 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
690
  neg_prompt=negative_prompt if negative_prompt else None,
691
  )
692
 
 
 
 
 
 
 
 
 
 
 
693
  # Generate audio per segment, then stitch
694
  for sample_idx in range(num_samples):
695
  seg_wavs = []
@@ -697,10 +762,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
697
  _t_hny_start = time.perf_counter()
698
  for seg_i, (seg_start, seg_end) in enumerate(segments):
699
  seg_dur = seg_end - seg_start
700
- seg_path = os.path.join(tmp_dir, f"seg_{sample_idx}_{seg_i}.mp4")
701
- ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
702
- seg_path, vcodec="copy", an=None
703
- ).run(overwrite_output=True, quiet=True)
704
 
705
  # feature_process returns (visual_feats, text_feats, audio_len).
706
  # We discard the returned text_feats (_) and use the pre-computed
@@ -750,9 +812,13 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
750
  torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(full_wav)), sr)
751
  video_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}.mp4")
752
  merge_audio_video(audio_path, silent_video, video_path)
 
 
 
 
753
  seg_meta = {
754
  "segments": segments,
755
- "wavs": [w.copy() for w in seg_wavs],
756
  "audio_path": audio_path,
757
  "video_path": video_path,
758
  "silent_video": silent_video,
@@ -761,6 +827,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
761
  "crossfade_s": crossfade_s,
762
  "crossfade_db": crossfade_db,
763
  "total_dur_s": total_dur_s,
 
764
  }
765
  outputs.append((video_path, audio_path, seg_meta))
766
 
@@ -781,7 +848,7 @@ def _splice_and_save(new_wav, seg_idx, meta, slot_id):
781
  """Replace wavs[seg_idx] with new_wav, re-stitch, re-save, re-mux.
782
  Returns (video_path, audio_path, updated_meta, waveform_html).
783
  """
784
- wavs = [w.copy() for w in meta["wavs"]]
785
  wavs[seg_idx]= new_wav
786
  crossfade_s = float(meta["crossfade_s"])
787
  crossfade_db = float(meta["crossfade_db"])
@@ -830,15 +897,14 @@ def _splice_and_save(new_wav, seg_idx, meta, slot_id):
830
  else:
831
  mux_video_audio(silent_video, audio_path, video_path)
832
 
 
 
833
  updated_meta = dict(meta)
834
- updated_meta["wavs"] = wavs
835
  updated_meta["audio_path"] = audio_path
836
  updated_meta["video_path"] = video_path
837
 
838
- # Serialise for embedding in waveform HTML data-state (wavs as lists for JSON)
839
- _serialised_meta = dict(updated_meta)
840
- _serialised_meta["wavs"] = [w.tolist() for w in wavs]
841
- state_json_new = json.dumps(_serialised_meta)
842
 
843
  waveform_html = _build_waveform_html(audio_path, segments, slot_id, "",
844
  state_json=state_json_new,
@@ -872,36 +938,27 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
872
  if _taro_dir not in sys.path:
873
  sys.path.insert(0, _taro_dir)
874
 
875
- from TARO.cavp_util import Extract_CAVP_Features
876
- from TARO.onset_util import VideoOnsetNet, extract_onset
877
- from TARO.models import MMDiT
878
- from TARO.samplers import euler_sampler, euler_maruyama_sampler
879
- from diffusers import AutoencoderKL
880
- from transformers import SpeechT5HifiGan
881
 
882
- silent_video = meta["silent_video"]
883
- tmp_dir = tempfile.mkdtemp()
 
 
 
 
 
 
 
 
 
 
 
 
 
884
 
885
- extract_cavp = Extract_CAVP_Features(device=device, config_path="TARO/cavp/cavp.yaml", ckpt_path=cavp_ckpt_path)
886
- raw_sd = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
887
- onset_sd = {}
888
- for k, v in raw_sd.items():
889
- if "model.net.model" in k: k = k.replace("model.net.model", "net.model")
890
- elif "model.fc." in k: k = k.replace("model.fc", "fc")
891
- onset_sd[k] = v
892
- onset_model = VideoOnsetNet(pretrained=False).to(device)
893
- onset_model.load_state_dict(onset_sd)
894
- onset_model.eval()
895
- model_net = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
896
- model_net.load_state_dict(torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"])
897
- model_net.eval().to(weight_dtype)
898
- vae = AutoencoderKL.from_pretrained("cvssp/audioldm2", subfolder="vae").to(device).eval()
899
- vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm2", subfolder="vocoder").to(device)
900
- latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
901
 
902
- cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
903
  set_global_seed(random.randint(0, 2**32 - 1))
904
- onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
905
 
906
  new_wav = _taro_infer_segment(
907
  model_net, vae, vocoder, cavp_feats, onset_feats,
@@ -910,14 +967,9 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
910
  euler_sampler, euler_maruyama_sampler,
911
  )
912
 
913
- # Deserialise stored wavs from lists back to numpy arrays (json roundtrip)
914
- stored_wavs = [np.array(w, dtype=np.float32) for w in meta["wavs"]]
915
- meta["wavs"] = stored_wavs
916
-
917
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
918
  new_wav, seg_idx, meta, slot_id
919
  )
920
- updated_meta["wavs"] = [w.tolist() for w in updated_meta["wavs"]]
921
  return video_path, audio_path, json.dumps(updated_meta), waveform_html
922
 
923
 
@@ -944,30 +996,13 @@ def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
944
  if _mmaudio_dir not in sys.path:
945
  sys.path.insert(0, _mmaudio_dir)
946
 
947
- from mmaudio.eval_utils import all_model_cfg, generate, load_video
948
- from mmaudio.model.flow_matching import FlowMatching
949
- from mmaudio.model.networks import get_my_mmaudio
950
- from mmaudio.model.utils.features_utils import FeaturesUtils
951
- from pathlib import Path as _Path
952
 
953
  device = "cuda" if torch.cuda.is_available() else "cpu"
954
  dtype = torch.bfloat16
955
 
956
- model_cfg = all_model_cfg["large_44k_v2"]
957
- model_cfg.model_path = _Path(mmaudio_model_path)
958
- model_cfg.vae_path = _Path(mmaudio_vae_path)
959
- model_cfg.synchformer_ckpt = _Path(mmaudio_synchformer_path)
960
- model_cfg.bigvgan_16k_path = None
961
- seq_cfg = model_cfg.seq_cfg
962
-
963
- net = get_my_mmaudio(model_cfg.model_name).to(device, dtype).eval()
964
- net.load_weights(torch.load(model_cfg.model_path, map_location=device, weights_only=True))
965
- feature_utils = FeaturesUtils(
966
- tod_vae_ckpt=str(model_cfg.vae_path),
967
- synchformer_ckpt=str(model_cfg.synchformer_ckpt),
968
- enable_conditions=True, mode=model_cfg.mode,
969
- bigvgan_vocoder_ckpt=None, need_vae_encoder=False,
970
- ).to(device, dtype).eval()
971
 
972
  sr = seq_cfg.sampling_rate
973
  silent_video = meta["silent_video"]
@@ -999,14 +1034,11 @@ def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
999
  seg_samples = int(round(seg_dur * sr))
1000
  new_wav = new_wav[:, :seg_samples]
1001
 
1002
- stored_wavs = [np.array(w, dtype=np.float32) for w in meta["wavs"]]
1003
- meta["wavs"] = stored_wavs
1004
  meta["sr"] = sr
1005
 
1006
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1007
  new_wav, seg_idx, meta, slot_id
1008
  )
1009
- updated_meta["wavs"] = [w.tolist() for w in updated_meta["wavs"]]
1010
  return video_path, audio_path, json.dumps(updated_meta), waveform_html
1011
 
1012
 
@@ -1035,19 +1067,11 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1035
  if _hf_path not in sys.path:
1036
  sys.path.insert(0, _hf_path)
1037
 
1038
- from hunyuanvideo_foley.utils.model_utils import load_model, denoise_process
1039
- from hunyuanvideo_foley.utils.feature_utils import feature_process
1040
 
1041
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1042
- model_size = model_size.lower()
1043
- config_map = {
1044
- "xl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xl.yaml",
1045
- "xxl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xxl.yaml",
1046
- }
1047
- config_path = config_map.get(model_size, config_map["xxl"])
1048
- hunyuan_weights_dir = str(HUNYUAN_MODEL_DIR / "HunyuanVideo-Foley")
1049
- model_dict, cfg = load_model(hunyuan_weights_dir, config_path, device,
1050
- enable_offload=False, model_size=model_size)
1051
 
1052
  set_global_seed(random.randint(0, 2**32 - 1))
1053
 
@@ -1058,10 +1082,19 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1058
  seg_path, vcodec="copy", an=None
1059
  ).run(overwrite_output=True, quiet=True)
1060
 
1061
- visual_feats, text_feats, seg_audio_len = feature_process(
1062
- seg_path, prompt if prompt else "", model_dict, cfg,
1063
- neg_prompt=negative_prompt if negative_prompt else None,
1064
- )
 
 
 
 
 
 
 
 
 
1065
  audio_batch, sr = denoise_process(
1066
  visual_feats, text_feats, seg_audio_len, model_dict, cfg,
1067
  guidance_scale=float(guidance_scale),
@@ -1072,14 +1105,11 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1072
  seg_samples = int(round(seg_dur * sr))
1073
  new_wav = new_wav[:, :seg_samples]
1074
 
1075
- stored_wavs = [np.array(w, dtype=np.float32) for w in meta["wavs"]]
1076
- meta["wavs"] = stored_wavs
1077
  meta["sr"] = sr
1078
 
1079
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1080
  new_wav, seg_idx, meta, slot_id
1081
  )
1082
- updated_meta["wavs"] = [w.tolist() for w in updated_meta["wavs"]]
1083
  return video_path, audio_path, json.dumps(updated_meta), waveform_html
1084
 
1085
 
@@ -1093,7 +1123,7 @@ def _pad_outputs(outputs: list) -> list:
1093
  Each entry in *outputs* must be a (video_path, audio_path, seg_meta) tuple where
1094
  seg_meta = {"segments": [...], "audio_path": str, "video_path": str,
1095
  "sr": int, "model": str, "crossfade_s": float,
1096
- "crossfade_db": float, "wavs": list[np.ndarray]}
1097
  """
1098
  result = []
1099
  for i in range(MAX_SLOTS):
@@ -1180,9 +1210,9 @@ def _build_waveform_html(audio_path: str, segments: list, slot_id: str,
1180
  if not audio_path or not os.path.exists(audio_path):
1181
  return "<p style='color:#888;font-size:12px'>No audio yet.</p>"
1182
 
1183
- with open(audio_path, "rb") as f:
1184
- b64 = base64.b64encode(f.read()).decode()
1185
- data_uri = f"data:audio/wav;base64,{b64}"
1186
 
1187
  segs_json = json.dumps(segments)
1188
 
@@ -1365,17 +1395,15 @@ def _build_waveform_html(audio_path: str, segments: list, slot_id: str,
1365
  }}
1366
  }}, 50);
1367
 
1368
- // ── Decode audio ───────────────────────────────────────────────────
1369
- const b64str = '{b64}';
1370
- const bin = atob(b64str);
1371
- const buf = new Uint8Array(bin.length);
1372
- for (let i=0; i<bin.length; i++) buf[i]=bin.charCodeAt(i);
1373
-
1374
- const AudioCtx = window.AudioContext || window.webkitAudioContext;
1375
- if (AudioCtx) {{
1376
- const tmpCtx = new AudioCtx({{sampleRate:44100}});
1377
- try {{
1378
- tmpCtx.decodeAudioData(buf.buffer.slice(0),
1379
  function(ab) {{
1380
  try {{ tmpCtx.close(); }} catch(e) {{}}
1381
  function tryDraw() {{
@@ -1387,8 +1415,8 @@ def _build_waveform_html(audio_path: str, segments: list, slot_id: str,
1387
  }},
1388
  function(err) {{}}
1389
  );
1390
- }} catch(e) {{}}
1391
- }}
1392
  }})();
1393
  </script>
1394
  </body>
@@ -1415,7 +1443,7 @@ def _build_waveform_html(audio_path: str, segments: list, slot_id: str,
1415
  </div>
1416
  <div style="display:flex;align-items:center;gap:8px;margin-top:6px;">
1417
  <span style="color:#888;font-size:11px;">Click a segment to regenerate &nbsp;|&nbsp; Playhead syncs to video</span>
1418
- <a href="{data_uri}" download="audio_{slot_id}.wav"
1419
  style="margin-left:auto;background:#333;color:#eee;border:1px solid #555;
1420
  border-radius:4px;padding:3px 10px;font-size:12px;text-decoration:none;">
1421
  &#8595; Audio</a>{f'''
@@ -1875,12 +1903,6 @@ with gr.Blocks(title="Generate Audio for Video", css=_SLOT_CSS, js=_GLOBAL_JS) a
1875
 
1876
  def _run_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n):
1877
  flat = generate_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n)
1878
- # Serialise wavs in meta to JSON-safe lists
1879
- for i in range(MAX_SLOTS):
1880
- meta = flat[i * 3 + 2]
1881
- if meta is not None:
1882
- meta["wavs"] = [w.tolist() for w in meta["wavs"]]
1883
- flat[i * 3 + 2] = meta
1884
  return _unpack_outputs(flat, n, "taro")
1885
 
1886
  # Split group visibility into a separate .then() to avoid Gradio 5 SSR
@@ -1977,11 +1999,6 @@ with gr.Blocks(title="Generate Audio for Video", css=_SLOT_CSS, js=_GLOBAL_JS) a
1977
 
1978
  def _run_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n):
1979
  flat = generate_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n)
1980
- for i in range(MAX_SLOTS):
1981
- meta = flat[i * 3 + 2]
1982
- if meta is not None:
1983
- meta["wavs"] = [w.tolist() for w in meta["wavs"]]
1984
- flat[i * 3 + 2] = meta
1985
  return _unpack_outputs(flat, n, "mma")
1986
 
1987
  (mma_btn.click(
@@ -2069,11 +2086,6 @@ with gr.Blocks(title="Generate Audio for Video", css=_SLOT_CSS, js=_GLOBAL_JS) a
2069
 
2070
  def _run_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n):
2071
  flat = generate_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n)
2072
- for i in range(MAX_SLOTS):
2073
- meta = flat[i * 3 + 2]
2074
- if meta is not None:
2075
- meta["wavs"] = [w.tolist() for w in meta["wavs"]]
2076
- flat[i * 3 + 2] = meta
2077
  return _unpack_outputs(flat, n, "hf")
2078
 
2079
  (hf_btn.click(
 
11
  import os
12
  import sys
13
  import json
 
14
  import tempfile
15
  import random
16
  import threading
 
110
  overwrite_output=True, quiet=True
111
  )
112
 
113
# ------------------------------------------------------------------ #
# Temp directory registry — tracks dirs for cleanup on new generation #
# ------------------------------------------------------------------ #
_TEMP_DIRS: list = []   # tmp_dir paths created by generate_*, oldest first
_TEMP_DIRS_MAX = 10     # keep at most this many; older ones get cleaned up


def _register_tmp_dir(tmp_dir: str) -> str:
    """Register *tmp_dir* for deferred cleanup and return it unchanged.

    Keeps at most ``_TEMP_DIRS_MAX`` directories alive; once the limit is
    exceeded, the oldest registered directories are removed from disk.
    """
    import shutil  # local import: only needed on the eviction path
    _TEMP_DIRS.append(tmp_dir)
    while len(_TEMP_DIRS) > _TEMP_DIRS_MAX:
        old = _TEMP_DIRS.pop(0)
        # ignore_errors=True already makes rmtree best-effort (it swallows
        # OSErrors internally), so the former try/except wrapper was dead code.
        shutil.rmtree(old, ignore_errors=True)
        print(f"[cleanup] Removed old temp dir: {old}")
    return tmp_dir
131
+
132
+
133
def _save_seg_wavs(wavs: list, tmp_dir: str, prefix: str) -> list:
    """Persist segment wav arrays as .npy files and return their paths.

    Storing file paths (instead of serialised arrays) keeps large float
    arrays out of the JSON/HTML data-state.
    """
    targets = [os.path.join(tmp_dir, f"{prefix}_seg{i}.npy") for i in range(len(wavs))]
    for target, wav in zip(targets, wavs):
        np.save(target, wav)
    return targets
142
+
143
+
144
def _load_seg_wavs(paths: list) -> list:
    """Load segment wav arrays back from their .npy file paths."""
    return list(map(np.load, paths))
147
+
148
+
149
+ # ------------------------------------------------------------------ #
150
+ # Shared model-loading helpers (deduplicate generate / regen code) #
151
+ # ------------------------------------------------------------------ #
152
+
153
def _load_taro_models(device, weight_dtype):
    """Load TARO MMDiT + AudioLDM2 VAE/vocoder. Returns (model_net, vae, vocoder, latents_scale)."""
    from TARO.models import MMDiT
    from diffusers import AutoencoderKL
    from transformers import SpeechT5HifiGan

    # Diffusion transformer: restore EMA weights, then freeze in eval mode
    # at the requested precision.
    net = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
    ckpt = torch.load(taro_ckpt_path, map_location=device, weights_only=False)
    net.load_state_dict(ckpt["ema"])
    net.eval().to(weight_dtype)

    # AudioLDM2 decoding stack: VAE + HiFi-GAN vocoder, plus the latent
    # scaling constant used when decoding.
    vae = AutoencoderKL.from_pretrained("cvssp/audioldm2", subfolder="vae").to(device).eval()
    vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm2", subfolder="vocoder").to(device)
    scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
    return net, vae, vocoder, scale
166
+
167
+
168
def _load_taro_feature_extractors(device):
    """Load CAVP + onset extractors. Returns (extract_cavp, onset_model)."""
    from TARO.cavp_util import Extract_CAVP_Features
    from TARO.onset_util import VideoOnsetNet

    cavp = Extract_CAVP_Features(
        device=device, config_path="TARO/cavp/cavp.yaml", ckpt_path=cavp_ckpt_path,
    )

    # Remap checkpoint keys onto VideoOnsetNet's module layout
    # (mirrors the key renaming done in the original TARO infer.py).
    state = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
    remapped = {}
    for key, tensor in state.items():
        if "model.net.model" in key:
            key = key.replace("model.net.model", "net.model")
        elif "model.fc." in key:
            key = key.replace("model.fc", "fc")
        remapped[key] = tensor

    onset_net = VideoOnsetNet(pretrained=False).to(device)
    onset_net.load_state_dict(remapped)
    onset_net.eval()
    return cavp, onset_net
186
+
187
+
188
def _load_mmaudio_models(device, dtype):
    """Load MMAudio net + feature_utils. Returns (net, feature_utils, model_cfg, seq_cfg)."""
    from mmaudio.eval_utils import all_model_cfg
    from mmaudio.model.networks import get_my_mmaudio
    from mmaudio.model.utils.features_utils import FeaturesUtils
    from pathlib import Path as _Path

    # large_44k_v2 variant, with all checkpoint paths redirected to our
    # downloaded files; 44k mode needs no BigVGAN vocoder.
    model_cfg = all_model_cfg["large_44k_v2"]
    model_cfg.model_path = _Path(mmaudio_model_path)
    model_cfg.vae_path = _Path(mmaudio_vae_path)
    model_cfg.synchformer_ckpt = _Path(mmaudio_synchformer_path)
    model_cfg.bigvgan_16k_path = None
    seq_cfg = model_cfg.seq_cfg

    # Network weights.
    mmaudio_net = get_my_mmaudio(model_cfg.model_name).to(device, dtype).eval()
    mmaudio_net.load_weights(torch.load(model_cfg.model_path, map_location=device, weights_only=True))

    # Conditioning/feature extraction stack (no VAE encoder needed here).
    extractor = FeaturesUtils(
        tod_vae_ckpt=str(model_cfg.vae_path),
        synchformer_ckpt=str(model_cfg.synchformer_ckpt),
        enable_conditions=True,
        mode=model_cfg.mode,
        bigvgan_vocoder_ckpt=None,
        need_vae_encoder=False,
    ).to(device, dtype).eval()
    return mmaudio_net, extractor, model_cfg, seq_cfg
211
+
212
+
213
def _load_hunyuan_model(device, model_size):
    """Load HunyuanFoley model dict + config. Returns (model_dict, cfg)."""
    from hunyuanvideo_foley.utils.model_utils import load_model

    size = model_size.lower()
    # Any unrecognised size string falls back to the XXL config.
    configs = {
        "xl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xl.yaml",
        "xxl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xxl.yaml",
    }
    cfg_path = configs.get(size, configs["xxl"])
    # hf_hub_download preserves the repo subfolder, so weights land under
    # HUNYUAN_MODEL_DIR/HunyuanVideo-Foley/ — pass that as the weights dir.
    weights_dir = str(HUNYUAN_MODEL_DIR / "HunyuanVideo-Foley")
    print(f"[HunyuanFoley] Loading {size.upper()} model from {weights_dir}")
    return load_model(weights_dir, cfg_path, device,
                      enable_offload=False, model_size=size)
226
+
227
+
228
  def mux_video_audio(silent_video: str, audio_path: str, output_path: str):
229
  """Mux a silent video with an audio file into *output_path* (stream-copy video, encode audio)."""
230
  ffmpeg.output(
 
307
 
308
_TARO_CACHE_MAXLEN = 16  # evict oldest entries beyond this limit
_TARO_INFERENCE_CACHE: dict = {}  # keyed by (video_file, seed, cfg, steps, mode, crossfade_s)
# Guards cache reads/evictions across concurrent generations. `threading` is
# already imported at the top of the file, so the mid-file re-import was
# redundant and has been dropped.
_TARO_CACHE_LOCK = threading.Lock()
312
 
313
 
314
  def _taro_calc_max_samples(total_dur_s: float, num_steps: int, crossfade_s: float) -> int:
 
424
  if _taro_dir not in sys.path:
425
  sys.path.insert(0, _taro_dir)
426
 
427
+ from TARO.onset_util import extract_onset
428
+ from TARO.samplers import euler_sampler, euler_maruyama_sampler
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
 
430
+ extract_cavp, onset_model = _load_taro_feature_extractors(device)
431
+ model, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
 
 
 
432
 
433
  # -- Prepare silent video (shared across all samples) --
434
+ tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
435
  silent_video = os.path.join(tmp_dir, "silent_input.mp4")
436
  strip_audio_from_video(video_file, silent_video)
437
 
 
446
  sample_seed = seed_val + sample_idx
447
  cache_key = (video_file, sample_seed, float(cfg_scale), int(num_steps), mode, crossfade_s)
448
 
449
+ with _TARO_CACHE_LOCK:
450
+ cached = _TARO_INFERENCE_CACHE.get(cache_key)
451
+ if cached is not None:
452
  print(f"[TARO] Sample {sample_idx+1}: cache hit.")
453
+ wavs = cached["wavs"]
454
  else:
455
  set_global_seed(sample_seed)
456
  onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
 
474
  print(f"[TARO] Inference done: {_n_segs} seg(s) × {int(num_steps)} steps in "
475
  f"{_t_infer_elapsed:.1f}s wall → {_secs_per_step:.3f}s/step "
476
  f"(current constant={TARO_SECS_PER_STEP})")
477
+ with _TARO_CACHE_LOCK:
478
+ _TARO_INFERENCE_CACHE[cache_key] = {"wavs": wavs}
479
+ while len(_TARO_INFERENCE_CACHE) > _TARO_CACHE_MAXLEN:
480
+ _TARO_INFERENCE_CACHE.pop(next(iter(_TARO_INFERENCE_CACHE)))
481
 
482
  final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR)
483
  audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
484
  torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(final_wav)).unsqueeze(0), TARO_SR)
485
  video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
486
  mux_video_audio(silent_video, audio_path, video_path)
487
+ wav_paths = _save_seg_wavs(wavs, tmp_dir, f"taro_{sample_idx}")
488
+ # Cache CAVP + onset features so regen can skip re-extraction (~5-7s saved)
489
+ cavp_path = os.path.join(tmp_dir, f"taro_{sample_idx}_cavp.npy")
490
+ onset_path = os.path.join(tmp_dir, f"taro_{sample_idx}_onset.npy")
491
+ np.save(cavp_path, cavp_feats)
492
+ np.save(onset_path, onset_feats)
493
  seg_meta = {
494
  "segments": segments,
495
+ "wav_paths": wav_paths,
496
  "audio_path": audio_path,
497
  "video_path": video_path,
498
  "silent_video": silent_video,
 
501
  "crossfade_s": crossfade_s,
502
  "crossfade_db": crossfade_db,
503
  "total_dur_s": total_dur_s,
504
+ "cavp_path": cavp_path,
505
+ "onset_path": onset_path,
506
  }
507
  outputs.append((video_path, audio_path, seg_meta))
508
 
 
546
  if _mmaudio_dir not in sys.path:
547
  sys.path.insert(0, _mmaudio_dir)
548
 
549
+ from mmaudio.eval_utils import generate, load_video, make_video
550
  from mmaudio.model.flow_matching import FlowMatching
 
 
551
 
552
  seed_val = int(seed_val)
553
  num_samples = int(num_samples)
 
557
  device = "cuda" if torch.cuda.is_available() else "cpu"
558
  dtype = torch.bfloat16
559
 
560
+ net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
561
 
562
+ tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
563
  outputs = []
564
 
565
  # Strip original audio so the muxed output only contains the generated track
 
574
 
575
  sr = seq_cfg.sampling_rate # 44100
576
 
577
+ # Pre-extract all segment clips once (shared across samples, saves ffmpeg overhead)
578
+ seg_clip_paths = []
579
+ for seg_i, (seg_start, seg_end) in enumerate(segments):
580
+ seg_dur = seg_end - seg_start
581
+ seg_path = os.path.join(tmp_dir, f"mma_seg_{seg_i}.mp4")
582
+ ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
583
+ seg_path, vcodec="copy", an=None
584
+ ).run(overwrite_output=True, quiet=True)
585
+ seg_clip_paths.append(seg_path)
586
+
587
  for sample_idx in range(num_samples):
588
  rng = torch.Generator(device=device)
589
  if seed_val >= 0:
 
596
 
597
  for seg_i, (seg_start, seg_end) in enumerate(segments):
598
  seg_dur = seg_end - seg_start
599
+ seg_path = seg_clip_paths[seg_i]
 
 
 
 
600
 
601
  fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
602
  video_info = load_video(seg_path, seg_dur)
 
645
 
646
  video_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.mp4")
647
  mux_video_audio(silent_video, audio_path, video_path)
648
+ wav_paths = _save_seg_wavs(seg_audios, tmp_dir, f"mmaudio_{sample_idx}")
649
  seg_meta = {
650
  "segments": segments,
651
+ "wav_paths": wav_paths,
652
  "audio_path": audio_path,
653
  "video_path": video_path,
654
  "silent_video": silent_video,
 
702
  if _hf_path not in sys.path:
703
  sys.path.insert(0, _hf_path)
704
 
705
+ from hunyuanvideo_foley.utils.model_utils import denoise_process
706
  from hunyuanvideo_foley.utils.feature_utils import feature_process
707
  from hunyuanvideo_foley.utils.media_utils import merge_audio_video
708
 
 
716
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
717
  model_size = model_size.lower() # "xl" or "xxl"
718
 
719
+ model_dict, cfg = _load_hunyuan_model(device, model_size)
 
 
 
 
720
 
721
+ tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
 
 
 
 
 
 
 
 
 
 
 
 
722
  outputs = []
723
 
724
  # Strip original audio so the muxed output only contains the generated track
 
745
  neg_prompt=negative_prompt if negative_prompt else None,
746
  )
747
 
748
+ # Pre-extract all segment clips once (shared across samples, saves ffmpeg overhead)
749
+ hny_seg_clip_paths = []
750
+ for seg_i, (seg_start, seg_end) in enumerate(segments):
751
+ seg_dur = seg_end - seg_start
752
+ seg_path = os.path.join(tmp_dir, f"hny_seg_{seg_i}.mp4")
753
+ ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
754
+ seg_path, vcodec="copy", an=None
755
+ ).run(overwrite_output=True, quiet=True)
756
+ hny_seg_clip_paths.append(seg_path)
757
+
758
  # Generate audio per segment, then stitch
759
  for sample_idx in range(num_samples):
760
  seg_wavs = []
 
762
  _t_hny_start = time.perf_counter()
763
  for seg_i, (seg_start, seg_end) in enumerate(segments):
764
  seg_dur = seg_end - seg_start
765
+ seg_path = hny_seg_clip_paths[seg_i]
 
 
 
766
 
767
  # feature_process returns (visual_feats, text_feats, audio_len).
768
  # We discard the returned text_feats (_) and use the pre-computed
 
812
  torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(full_wav)), sr)
813
  video_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}.mp4")
814
  merge_audio_video(audio_path, silent_video, video_path)
815
+ wav_paths = _save_seg_wavs(seg_wavs, tmp_dir, f"hunyuan_{sample_idx}")
816
+ # Cache text features so regen can skip text encoding (~2-3s saved)
817
+ text_feats_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}_text_feats.pt")
818
+ torch.save(text_feats, text_feats_path)
819
  seg_meta = {
820
  "segments": segments,
821
+ "wav_paths": wav_paths,
822
  "audio_path": audio_path,
823
  "video_path": video_path,
824
  "silent_video": silent_video,
 
827
  "crossfade_s": crossfade_s,
828
  "crossfade_db": crossfade_db,
829
  "total_dur_s": total_dur_s,
830
+ "text_feats_path": text_feats_path,
831
  }
832
  outputs.append((video_path, audio_path, seg_meta))
833
 
 
848
  """Replace wavs[seg_idx] with new_wav, re-stitch, re-save, re-mux.
849
  Returns (video_path, audio_path, updated_meta, waveform_html).
850
  """
851
+ wavs = _load_seg_wavs(meta["wav_paths"])
852
  wavs[seg_idx]= new_wav
853
  crossfade_s = float(meta["crossfade_s"])
854
  crossfade_db = float(meta["crossfade_db"])
 
897
  else:
898
  mux_video_audio(silent_video, audio_path, video_path)
899
 
900
+ # Save updated segment wavs to .npy files
901
+ updated_wav_paths = _save_seg_wavs(wavs, tmp_dir, os.path.splitext(_base_clean)[0])
902
  updated_meta = dict(meta)
903
+ updated_meta["wav_paths"] = updated_wav_paths
904
  updated_meta["audio_path"] = audio_path
905
  updated_meta["video_path"] = video_path
906
 
907
+ state_json_new = json.dumps(updated_meta)
 
 
 
908
 
909
  waveform_html = _build_waveform_html(audio_path, segments, slot_id, "",
910
  state_json=state_json_new,
 
938
  if _taro_dir not in sys.path:
939
  sys.path.insert(0, _taro_dir)
940
 
941
+ from TARO.samplers import euler_sampler, euler_maruyama_sampler
 
 
 
 
 
942
 
943
+ # Load cached CAVP + onset features if available (saves ~5-7s of GPU work)
944
+ cavp_path = meta.get("cavp_path")
945
+ onset_path = meta.get("onset_path")
946
+ if cavp_path and os.path.exists(cavp_path) and onset_path and os.path.exists(onset_path):
947
+ print("[TARO regen] Loading cached CAVP + onset features")
948
+ cavp_feats = np.load(cavp_path)
949
+ onset_feats = np.load(onset_path)
950
+ else:
951
+ print("[TARO regen] Cache miss — re-extracting CAVP + onset features")
952
+ from TARO.onset_util import extract_onset
953
+ extract_cavp, onset_model = _load_taro_feature_extractors(device)
954
+ silent_video = meta["silent_video"]
955
+ tmp_dir = tempfile.mkdtemp()
956
+ cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
957
+ onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
958
 
959
+ model_net, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
960
 
 
961
  set_global_seed(random.randint(0, 2**32 - 1))
 
962
 
963
  new_wav = _taro_infer_segment(
964
  model_net, vae, vocoder, cavp_feats, onset_feats,
 
967
  euler_sampler, euler_maruyama_sampler,
968
  )
969
 
 
 
 
 
970
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
971
  new_wav, seg_idx, meta, slot_id
972
  )
 
973
  return video_path, audio_path, json.dumps(updated_meta), waveform_html
974
 
975
 
 
996
  if _mmaudio_dir not in sys.path:
997
  sys.path.insert(0, _mmaudio_dir)
998
 
999
+ from mmaudio.eval_utils import generate, load_video
1000
+ from mmaudio.model.flow_matching import FlowMatching
 
 
 
1001
 
1002
  device = "cuda" if torch.cuda.is_available() else "cpu"
1003
  dtype = torch.bfloat16
1004
 
1005
+ net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1006
 
1007
  sr = seq_cfg.sampling_rate
1008
  silent_video = meta["silent_video"]
 
1034
  seg_samples = int(round(seg_dur * sr))
1035
  new_wav = new_wav[:, :seg_samples]
1036
 
 
 
1037
  meta["sr"] = sr
1038
 
1039
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1040
  new_wav, seg_idx, meta, slot_id
1041
  )
 
1042
  return video_path, audio_path, json.dumps(updated_meta), waveform_html
1043
 
1044
 
 
1067
  if _hf_path not in sys.path:
1068
  sys.path.insert(0, _hf_path)
1069
 
1070
+ from hunyuanvideo_foley.utils.model_utils import denoise_process
1071
+ from hunyuanvideo_foley.utils.feature_utils import feature_process
1072
 
1073
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1074
+ model_dict, cfg = _load_hunyuan_model(device, model_size)
 
 
 
 
 
 
 
 
1075
 
1076
  set_global_seed(random.randint(0, 2**32 - 1))
1077
 
 
1082
  seg_path, vcodec="copy", an=None
1083
  ).run(overwrite_output=True, quiet=True)
1084
 
1085
+ # Load cached text features if available (saves ~2-3s text encoding)
1086
+ text_feats_path = meta.get("text_feats_path")
1087
+ if text_feats_path and os.path.exists(text_feats_path):
1088
+ print("[HunyuanFoley regen] Loading cached text features, extracting visual only")
1089
+ from hunyuanvideo_foley.utils.feature_utils import encode_video_features
1090
+ visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
1091
+ text_feats = torch.load(text_feats_path, map_location=device, weights_only=False)
1092
+ else:
1093
+ print("[HunyuanFoley regen] Cache miss — extracting text + visual features")
1094
+ visual_feats, text_feats, seg_audio_len = feature_process(
1095
+ seg_path, prompt if prompt else "", model_dict, cfg,
1096
+ neg_prompt=negative_prompt if negative_prompt else None,
1097
+ )
1098
  audio_batch, sr = denoise_process(
1099
  visual_feats, text_feats, seg_audio_len, model_dict, cfg,
1100
  guidance_scale=float(guidance_scale),
 
1105
  seg_samples = int(round(seg_dur * sr))
1106
  new_wav = new_wav[:, :seg_samples]
1107
 
 
 
1108
  meta["sr"] = sr
1109
 
1110
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1111
  new_wav, seg_idx, meta, slot_id
1112
  )
 
1113
  return video_path, audio_path, json.dumps(updated_meta), waveform_html
1114
 
1115
 
 
1123
  Each entry in *outputs* must be a (video_path, audio_path, seg_meta) tuple where
1124
  seg_meta = {"segments": [...], "audio_path": str, "video_path": str,
1125
  "sr": int, "model": str, "crossfade_s": float,
1126
+ "crossfade_db": float, "wav_paths": list[str]}
1127
  """
1128
  result = []
1129
  for i in range(MAX_SLOTS):
 
1210
  if not audio_path or not os.path.exists(audio_path):
1211
  return "<p style='color:#888;font-size:12px'>No audio yet.</p>"
1212
 
1213
+ # Serve audio via Gradio's file API instead of base64-encoding the entire
1214
+ # WAV inline. For a 25s stereo 44.1kHz track this saves ~5 MB per slot.
1215
+ audio_url = f"/gradio_api/file={audio_path}"
1216
 
1217
  segs_json = json.dumps(segments)
1218
 
 
1395
  }}
1396
  }}, 50);
1397
 
1398
+ // ── Fetch + decode audio from Gradio file API ──────────────────────
1399
+ const audioUrl = '{audio_url}';
1400
+ fetch(audioUrl)
1401
+ .then(function(r) {{ return r.arrayBuffer(); }})
1402
+ .then(function(arrayBuf) {{
1403
+ const AudioCtx = window.AudioContext || window.webkitAudioContext;
1404
+ if (!AudioCtx) return;
1405
+ const tmpCtx = new AudioCtx({{sampleRate:44100}});
1406
+ tmpCtx.decodeAudioData(arrayBuf,
 
 
1407
  function(ab) {{
1408
  try {{ tmpCtx.close(); }} catch(e) {{}}
1409
  function tryDraw() {{
 
1415
  }},
1416
  function(err) {{}}
1417
  );
1418
+ }})
1419
+ .catch(function(e) {{}});
1420
  }})();
1421
  </script>
1422
  </body>
 
1443
  </div>
1444
  <div style="display:flex;align-items:center;gap:8px;margin-top:6px;">
1445
  <span style="color:#888;font-size:11px;">Click a segment to regenerate &nbsp;|&nbsp; Playhead syncs to video</span>
1446
+ <a href="{audio_url}" download="audio_{slot_id}.wav"
1447
  style="margin-left:auto;background:#333;color:#eee;border:1px solid #555;
1448
  border-radius:4px;padding:3px 10px;font-size:12px;text-decoration:none;">
1449
  &#8595; Audio</a>{f'''
 
1903
 
1904
  def _run_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n):
1905
  flat = generate_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n)
 
 
 
 
 
 
1906
  return _unpack_outputs(flat, n, "taro")
1907
 
1908
  # Split group visibility into a separate .then() to avoid Gradio 5 SSR
 
1999
 
2000
  def _run_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n):
2001
  flat = generate_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n)
 
 
 
 
 
2002
  return _unpack_outputs(flat, n, "mma")
2003
 
2004
  (mma_btn.click(
 
2086
 
2087
  def _run_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n):
2088
  flat = generate_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n)
 
 
 
 
 
2089
  return _unpack_outputs(flat, n, "hf")
2090
 
2091
  (hf_btn.click(