BoxOfColors committed on
Commit
02a1f95
·
1 Parent(s): 09846c9
Files changed (1) hide show
  1. app.py +745 -29
app.py CHANGED
@@ -10,6 +10,8 @@ Supported models
10
 
11
  import os
12
  import sys
 
 
13
  import tempfile
14
  import random
15
  from pathlib import Path
@@ -383,7 +385,19 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
383
  torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(final_wav)).unsqueeze(0), TARO_SR)
384
  video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
385
  mux_video_audio(silent_video, audio_path, video_path)
386
- outputs.append((video_path, audio_path))
 
 
 
 
 
 
 
 
 
 
 
 
387
 
388
  return _pad_outputs(outputs)
389
 
@@ -544,7 +558,19 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
544
 
545
  video_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.mp4")
546
  mux_video_audio(silent_video, audio_path, video_path)
547
- outputs.append((video_path, audio_path))
 
 
 
 
 
 
 
 
 
 
 
 
548
 
549
  return _pad_outputs(outputs)
550
 
@@ -707,45 +733,633 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
707
  torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(full_wav)), sr)
708
  video_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}.mp4")
709
  merge_audio_video(audio_path, silent_video, video_path)
710
- outputs.append((video_path, audio_path))
 
 
 
 
 
 
 
 
 
 
 
 
711
 
712
  return _pad_outputs(outputs)
713
 
714
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
715
  # ================================================================== #
716
  # SHARED UI HELPERS #
717
  # ================================================================== #
718
 
719
  def _pad_outputs(outputs: list) -> list:
720
- """Flatten (video, audio) pairs and pad to MAX_SLOTS * 2 with None."""
 
 
 
 
 
 
721
  result = []
722
  for i in range(MAX_SLOTS):
723
  if i < len(outputs):
724
- result.extend(outputs[i])
725
  else:
726
- result.extend([None, None])
727
  return result
728
 
729
 
730
- def _make_output_slots() -> tuple:
731
- """Build MAX_SLOTS video+audio output groups. Returns (grps, vids, auds)."""
732
- grps, vids, auds = [], [], []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
733
  for i in range(MAX_SLOTS):
734
  with gr.Group(visible=(i == 0)) as g:
 
735
  vids.append(gr.Video(label=f"Generation {i+1} — Video"))
736
- auds.append(gr.Audio(label=f"Generation {i+1} — Audio"))
 
 
 
 
 
 
 
 
 
 
 
737
  grps.append(g)
738
- return grps, vids, auds
739
 
740
 
741
- def _unpack_outputs(flat: list, n: int) -> list:
742
- """Turn a flat _pad_outputs list into Gradio update lists for grps+vids+auds."""
 
 
 
 
743
  n = int(n)
744
- return (
745
- [gr.update(visible=(i < n)) for i in range(MAX_SLOTS)] +
746
- [gr.update(value=flat[i * 2]) for i in range(MAX_SLOTS)] +
747
- [gr.update(value=flat[i * 2 + 1]) for i in range(MAX_SLOTS)]
748
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
749
 
750
 
751
  def _on_video_upload_taro(video_file, num_steps, crossfade_s):
@@ -798,7 +1412,9 @@ with gr.Blocks(title="Generate Audio for Video") as demo:
798
  taro_btn = gr.Button("Generate", variant="primary")
799
 
800
  with gr.Column():
801
- taro_slot_grps, taro_slot_vids, taro_slot_auds = _make_output_slots()
 
 
802
 
803
  for trigger in [taro_video, taro_steps, taro_cf_dur]:
804
  trigger.change(
@@ -813,15 +1429,49 @@ with gr.Blocks(title="Generate Audio for Video") as demo:
813
  )
814
 
815
  def _run_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n):
816
- return _unpack_outputs(generate_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n), n)
 
 
 
 
 
 
 
817
 
818
  taro_btn.click(
819
  fn=_run_taro,
820
  inputs=[taro_video, taro_seed, taro_cfg, taro_steps, taro_mode,
821
  taro_cf_dur, taro_cf_db, taro_samples],
822
- outputs=taro_slot_grps + taro_slot_vids + taro_slot_auds,
823
  )
824
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
825
  # ---------------------------------------------------------- #
826
  # Tab 2 — MMAudio #
827
  # ---------------------------------------------------------- #
@@ -840,7 +1490,9 @@ with gr.Blocks(title="Generate Audio for Video") as demo:
840
  mma_btn = gr.Button("Generate", variant="primary")
841
 
842
  with gr.Column():
843
- mma_slot_grps, mma_slot_vids, mma_slot_auds = _make_output_slots()
 
 
844
 
845
  mma_samples.change(
846
  fn=_update_slot_visibility,
@@ -849,15 +1501,47 @@ with gr.Blocks(title="Generate Audio for Video") as demo:
849
  )
850
 
851
  def _run_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n):
852
- return _unpack_outputs(generate_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n), n)
 
 
 
 
 
 
853
 
854
  mma_btn.click(
855
  fn=_run_mmaudio,
856
  inputs=[mma_video, mma_prompt, mma_neg, mma_seed,
857
  mma_cfg, mma_steps, mma_cf_dur, mma_cf_db, mma_samples],
858
- outputs=mma_slot_grps + mma_slot_vids + mma_slot_auds,
859
  )
860
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
861
  # ---------------------------------------------------------- #
862
  # Tab 3 — HunyuanVideoFoley #
863
  # ---------------------------------------------------------- #
@@ -877,7 +1561,9 @@ with gr.Blocks(title="Generate Audio for Video") as demo:
877
  hf_btn = gr.Button("Generate", variant="primary")
878
 
879
  with gr.Column():
880
- hf_slot_grps, hf_slot_vids, hf_slot_auds = _make_output_slots()
 
 
881
 
882
  hf_samples.change(
883
  fn=_update_slot_visibility,
@@ -886,18 +1572,48 @@ with gr.Blocks(title="Generate Audio for Video") as demo:
886
  )
887
 
888
  def _run_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n):
889
- return _unpack_outputs(generate_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n), n)
 
 
 
 
 
 
890
 
891
  hf_btn.click(
892
  fn=_run_hunyuan,
893
  inputs=[hf_video, hf_prompt, hf_neg, hf_seed,
894
  hf_guidance, hf_steps, hf_size, hf_cf_dur, hf_cf_db, hf_samples],
895
- outputs=hf_slot_grps + hf_slot_vids + hf_slot_auds,
896
  )
897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
898
  # ---- Cross-tab video sync ----
899
- # When any tab's video changes, push the value to the other two tabs.
900
- # Clearing (value=None) also propagates so the X button clears all.
901
  _sync = lambda v: (gr.update(value=v), gr.update(value=v))
902
  taro_video.change(fn=_sync, inputs=[taro_video], outputs=[mma_video, hf_video])
903
  mma_video.change(fn=_sync, inputs=[mma_video], outputs=[taro_video, hf_video])
 
10
 
11
  import os
12
  import sys
13
+ import json
14
+ import base64
15
  import tempfile
16
  import random
17
  from pathlib import Path
 
385
  torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(final_wav)).unsqueeze(0), TARO_SR)
386
  video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
387
  mux_video_audio(silent_video, audio_path, video_path)
388
+ seg_meta = {
389
+ "segments": segments,
390
+ "wavs": [w.copy() for w in wavs],
391
+ "audio_path": audio_path,
392
+ "video_path": video_path,
393
+ "silent_video": silent_video,
394
+ "sr": TARO_SR,
395
+ "model": "taro",
396
+ "crossfade_s": crossfade_s,
397
+ "crossfade_db": crossfade_db,
398
+ "total_dur_s": total_dur_s,
399
+ }
400
+ outputs.append((video_path, audio_path, seg_meta))
401
 
402
  return _pad_outputs(outputs)
403
 
 
558
 
559
  video_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.mp4")
560
  mux_video_audio(silent_video, audio_path, video_path)
561
+ seg_meta = {
562
+ "segments": segments,
563
+ "wavs": [w.copy() for w in seg_audios],
564
+ "audio_path": audio_path,
565
+ "video_path": video_path,
566
+ "silent_video": silent_video,
567
+ "sr": sr,
568
+ "model": "mmaudio",
569
+ "crossfade_s": crossfade_s,
570
+ "crossfade_db": crossfade_db,
571
+ "total_dur_s": total_dur_s,
572
+ }
573
+ outputs.append((video_path, audio_path, seg_meta))
574
 
575
  return _pad_outputs(outputs)
576
 
 
733
  torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(full_wav)), sr)
734
  video_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}.mp4")
735
  merge_audio_video(audio_path, silent_video, video_path)
736
+ seg_meta = {
737
+ "segments": segments,
738
+ "wavs": [w.copy() for w in seg_wavs],
739
+ "audio_path": audio_path,
740
+ "video_path": video_path,
741
+ "silent_video": silent_video,
742
+ "sr": sr,
743
+ "model": "hunyuan",
744
+ "crossfade_s": crossfade_s,
745
+ "crossfade_db": crossfade_db,
746
+ "total_dur_s": total_dur_s,
747
+ }
748
+ outputs.append((video_path, audio_path, seg_meta))
749
 
750
  return _pad_outputs(outputs)
751
 
752
 
753
+ # ================================================================== #
754
+ # SEGMENT REGENERATION HELPERS #
755
+ # ================================================================== #
756
+ # Each regen function:
757
+ # 1. Runs inference for ONE segment (random seed, current settings)
758
+ # 2. Splices the new wav into the stored wavs list
759
+ # 3. Re-stitches the full track, re-saves .wav and re-muxes .mp4
760
+ # 4. Returns (new_video_path, new_audio_path, updated_seg_meta, new_waveform_html)
761
+ # ================================================================== #
762
+
763
def _splice_and_save(new_wav, seg_idx, meta, slot_id):
    """Replace wavs[seg_idx] with new_wav, re-stitch, re-save, re-mux.

    Args:
        new_wav: numpy array for the regenerated segment (mono 1-D or
            stereo 2-D, matching the other entries in ``meta["wavs"]``).
        seg_idx: index of the segment being replaced.
        meta: seg_meta dict as produced by the generate_* functions, with
            "wavs" already deserialised back to numpy arrays.
        slot_id: UI slot identifier used to build element ids for the
            waveform HTML / hidden regen-trigger textbox.

    Returns:
        (video_path, audio_path, updated_meta, waveform_html)
    """
    # Work on copies so the caller's meta["wavs"] is never mutated.
    wavs = [w.copy() for w in meta["wavs"]]
    wavs[seg_idx] = new_wav
    crossfade_s = float(meta["crossfade_s"])
    crossfade_db = float(meta["crossfade_db"])
    sr = int(meta["sr"])
    total_dur_s = float(meta["total_dur_s"])
    silent_video = meta["silent_video"]
    segments = meta["segments"]
    model = meta["model"]

    # Stitch (works for both mono and stereo): crossfade consecutive
    # segments, then trim to the exact total duration in samples.
    stereo = wavs[0].ndim == 2
    full_wav = wavs[0]
    for nw in wavs[1:]:
        full_wav = _cf_join(full_wav, nw, crossfade_s, crossfade_db, sr)
    n_total = int(round(total_dur_s * sr))
    if stereo:
        full_wav = full_wav[:, :n_total]
    else:
        full_wav = full_wav[:n_total]

    # Save new audio over the previous take (path reused in-place).
    # NOTE: removed an unused `tmp_dir = os.path.dirname(...)` local here.
    audio_path = meta["audio_path"]  # overwrite in-place
    if stereo:
        torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(full_wav)), sr)
    else:
        torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(full_wav)).unsqueeze(0), sr)

    # Re-mux video
    video_path = meta["video_path"]  # overwrite in-place
    if model == "hunyuan":
        # HunyuanFoley uses its own merge_audio_video
        _hf_path = str(Path("HunyuanVideo-Foley").resolve())
        if _hf_path not in sys.path:
            sys.path.insert(0, _hf_path)
        from hunyuanvideo_foley.utils.media_utils import merge_audio_video
        merge_audio_video(audio_path, silent_video, video_path)
    else:
        mux_video_audio(silent_video, audio_path, video_path)

    # Return a fresh meta so the caller's dict is left untouched.
    updated_meta = dict(meta)
    updated_meta["wavs"] = wavs
    updated_meta["audio_path"] = audio_path
    updated_meta["video_path"] = video_path

    # Rebuild the waveform widget so the UI reflects the new audio.
    hidden_el_id = f"regen_trigger_{slot_id}"
    waveform_html = _build_waveform_html(audio_path, segments, slot_id, hidden_el_id)
    return video_path, audio_path, updated_meta, waveform_html
816
+
817
+
818
def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
                         seed_val, cfg_scale, num_steps, mode,
                         crossfade_s, crossfade_db):
    """Estimate the @spaces.GPU reservation (seconds) for a one-segment TARO regen."""
    # Only one segment is regenerated: per-step cost plus fixed load overhead.
    secs = int(num_steps) * TARO_SECS_PER_STEP + TARO_LOAD_OVERHEAD
    # Clamp the estimate into [60, GPU_DURATION_CAP].
    result = max(60, int(secs))
    if result > GPU_DURATION_CAP:
        result = GPU_DURATION_CAP
    print(f"[duration] TARO regen: 1 seg × {int(num_steps)} steps → {secs:.0f}s → capped {result}s")
    return result
825
+
826
+
827
@spaces.GPU(duration=_taro_regen_duration)
def regen_taro_segment(video_file, seg_idx, seg_meta_json,
                       seed_val, cfg_scale, num_steps, mode,
                       crossfade_s, crossfade_db, slot_id):
    """Regenerate one TARO segment with a fresh random seed.

    Args:
        video_file: not read here; present so the signature matches the
            Gradio input list wired to this handler.
        seg_idx: index of the segment to redo within meta["segments"].
        seg_meta_json: JSON string from a previous generate/regen call; its
            "wavs" entries are plain lists and are converted back to numpy.
        seed_val: not read — regeneration always draws a fresh random seed.
        cfg_scale: classifier-free guidance scale for the sampler.
        num_steps: number of sampling steps.
        mode: sampler mode forwarded to _taro_infer_segment.
        crossfade_s, crossfade_db: not read here; the stitch settings used
            by _splice_and_save come from the stored meta instead.
        slot_id: UI slot identifier, forwarded for waveform element ids.

    Returns:
        (video_path, audio_path, updated_meta_json, waveform_html)
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start_s, seg_end_s = meta["segments"][seg_idx]

    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    weight_dtype = torch.bfloat16

    # Make the vendored TARO package importable before the imports below.
    _taro_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "TARO")
    if _taro_dir not in sys.path:
        sys.path.insert(0, _taro_dir)

    from TARO.cavp_util import Extract_CAVP_Features
    from TARO.onset_util import VideoOnsetNet, extract_onset
    from TARO.models import MMDiT
    from TARO.samplers import euler_sampler, euler_maruyama_sampler
    from diffusers import AudioLDM2Pipeline

    silent_video = meta["silent_video"]
    tmp_dir = tempfile.mkdtemp()

    # Rebuild the full model stack — this runs inside a fresh @spaces.GPU worker.
    extract_cavp = Extract_CAVP_Features(device=device, config_path="TARO/cavp/cavp.yaml", ckpt_path=cavp_ckpt_path)
    raw_sd = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
    # Strip trainer-style key prefixes from the onset checkpoint.
    onset_sd = {}
    for k, v in raw_sd.items():
        if "model.net.model" in k: k = k.replace("model.net.model", "net.model")
        elif "model.fc." in k: k = k.replace("model.fc", "fc")
        onset_sd[k] = v
    onset_model = VideoOnsetNet(pretrained=False).to(device)
    onset_model.load_state_dict(onset_sd)
    onset_model.eval()
    model_net = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
    model_net.load_state_dict(torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"])
    model_net.eval().to(weight_dtype)
    audioldm2 = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
    vae = audioldm2.vae.to(device).eval()
    vocoder = audioldm2.vocoder.to(device)
    latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)

    cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
    # Fresh random seed: a regen is meant to produce a different take.
    set_global_seed(random.randint(0, 2**32 - 1))
    onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)

    new_wav = _taro_infer_segment(
        model_net, vae, vocoder, cavp_feats, onset_feats,
        seg_start_s, seg_end_s, device, weight_dtype,
        float(cfg_scale), int(num_steps), mode, latents_scale,
        euler_sampler, euler_maruyama_sampler,
    )

    # Deserialise stored wavs from lists back to numpy arrays (json roundtrip)
    stored_wavs = [np.array(w, dtype=np.float32) for w in meta["wavs"]]
    meta["wavs"] = stored_wavs

    video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
        new_wav, seg_idx, meta, slot_id
    )
    # Back to JSON-safe lists before the meta goes through Gradio state.
    updated_meta["wavs"] = [w.tolist() for w in updated_meta["wavs"]]
    return video_path, audio_path, json.dumps(updated_meta), waveform_html
891
+
892
+
893
def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
                            prompt, negative_prompt, seed_val,
                            cfg_strength, num_steps, crossfade_s, crossfade_db):
    """Estimate the @spaces.GPU reservation (seconds) for a one-segment MMAudio regen."""
    # Single-segment estimate: sampling cost plus fixed model-load overhead.
    estimate = int(num_steps) * MMAUDIO_SECS_PER_STEP + MMAUDIO_LOAD_OVERHEAD
    secs = estimate
    # Never reserve less than a minute nor more than the cap.
    result = min(GPU_DURATION_CAP, max(60, int(secs)))
    print(f"[duration] MMAudio regen: 1 seg × {int(num_steps)} steps → {secs:.0f}s → capped {result}s")
    return result
900
+
901
+
902
@spaces.GPU(duration=_mmaudio_regen_duration)
def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
                          prompt, negative_prompt, seed_val,
                          cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id):
    """Regenerate one MMAudio segment with a fresh random seed.

    Args:
        video_file: not read here; kept so the signature matches the UI inputs.
        seg_idx: index of the segment to redo within meta["segments"].
        seg_meta_json: JSON string from a previous generate/regen call.
        prompt, negative_prompt: text conditioning for MMAudio.
        seed_val: not read — a fresh random seed is always drawn.
        cfg_strength, num_steps: sampling settings.
        crossfade_s, crossfade_db: not read here; stitch settings come from meta.
        slot_id: UI slot identifier, forwarded for waveform element ids.

    Returns:
        (video_path, audio_path, updated_meta_json, waveform_html)
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start, seg_end = meta["segments"][seg_idx]
    seg_dur = seg_end - seg_start

    # Make the vendored MMAudio package importable before the imports below.
    _mmaudio_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "MMAudio")
    if _mmaudio_dir not in sys.path:
        sys.path.insert(0, _mmaudio_dir)

    from mmaudio.eval_utils import all_model_cfg, generate, load_video
    from mmaudio.model.flow_matching import FlowMatching
    from mmaudio.model.networks import get_my_mmaudio
    from mmaudio.model.utils.features_utils import FeaturesUtils
    from pathlib import Path as _Path

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16

    # Point the stock config at the locally downloaded checkpoints.
    model_cfg = all_model_cfg["large_44k_v2"]
    model_cfg.model_path = _Path(mmaudio_model_path)
    model_cfg.vae_path = _Path(mmaudio_vae_path)
    model_cfg.synchformer_ckpt = _Path(mmaudio_synchformer_path)
    model_cfg.bigvgan_16k_path = None
    seq_cfg = model_cfg.seq_cfg

    net = get_my_mmaudio(model_cfg.model_name).to(device, dtype).eval()
    net.load_weights(torch.load(model_cfg.model_path, map_location=device, weights_only=True))
    feature_utils = FeaturesUtils(
        tod_vae_ckpt=str(model_cfg.vae_path),
        synchformer_ckpt=str(model_cfg.synchformer_ckpt),
        enable_conditions=True, mode=model_cfg.mode,
        bigvgan_vocoder_ckpt=None, need_vae_encoder=False,
    ).to(device, dtype).eval()

    sr = seq_cfg.sampling_rate
    silent_video = meta["silent_video"]
    tmp_dir = tempfile.mkdtemp()
    # Cut just the target segment out of the silent video (stream copy, no re-encode).
    seg_path = os.path.join(tmp_dir, "regen_seg.mp4")
    ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
        seg_path, vcodec="copy", an=None
    ).run(overwrite_output=True, quiet=True)

    # Fresh random seed: a regen is meant to produce a different take.
    rng = torch.Generator(device=device)
    rng.manual_seed(random.randint(0, 2**32 - 1))

    fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=int(num_steps))
    video_info = load_video(seg_path, seg_dur)
    clip_frames = video_info.clip_frames.unsqueeze(0)
    sync_frames = video_info.sync_frames.unsqueeze(0)
    # Use the decoded clip's actual duration to size the sequence lengths.
    actual_dur = video_info.duration_sec
    seq_cfg.duration = actual_dur
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    with torch.no_grad():
        audios = generate(
            clip_frames, sync_frames, [prompt],
            negative_text=[negative_prompt] if negative_prompt else None,
            feature_utils=feature_utils, net=net, fm=fm, rng=rng,
            cfg_strength=float(cfg_strength),
        )
    new_wav = audios.float().cpu()[0].numpy()
    # Trim to the exact requested segment length in samples.
    seg_samples = int(round(seg_dur * sr))
    new_wav = new_wav[:, :seg_samples]

    # Deserialise stored wavs (json roundtrip turned them into lists).
    stored_wavs = [np.array(w, dtype=np.float32) for w in meta["wavs"]]
    meta["wavs"] = stored_wavs
    meta["sr"] = sr

    video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
        new_wav, seg_idx, meta, slot_id
    )
    # Back to JSON-safe lists before the meta goes through Gradio state.
    updated_meta["wavs"] = [w.tolist() for w in updated_meta["wavs"]]
    return video_path, audio_path, json.dumps(updated_meta), waveform_html
+ return video_path, audio_path, json.dumps(updated_meta), waveform_html
980
+
981
+
982
def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
                            prompt, negative_prompt, seed_val,
                            guidance_scale, num_steps, model_size,
                            crossfade_s, crossfade_db):
    """Estimate the @spaces.GPU reservation (seconds) for a one-segment HunyuanFoley regen."""
    # One segment only: per-step cost plus fixed model-load overhead.
    steps = int(num_steps)
    secs = steps * HUNYUAN_SECS_PER_STEP + HUNYUAN_LOAD_OVERHEAD
    # Clamp into [60, GPU_DURATION_CAP].
    floored = max(60, int(secs))
    result = GPU_DURATION_CAP if floored > GPU_DURATION_CAP else floored
    print(f"[duration] HunyuanFoley regen: 1 seg × {int(num_steps)} steps → {secs:.0f}s → capped {result}s")
    return result
990
+
991
+
992
@spaces.GPU(duration=_hunyuan_regen_duration)
def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
                          prompt, negative_prompt, seed_val,
                          guidance_scale, num_steps, model_size,
                          crossfade_s, crossfade_db, slot_id):
    """Regenerate one HunyuanFoley segment with a fresh random seed.

    Args:
        video_file: not read here; kept so the signature matches the UI inputs.
        seg_idx: index of the segment to redo within meta["segments"].
        seg_meta_json: JSON string from a previous generate/regen call.
        prompt, negative_prompt: text conditioning (empty strings allowed).
        seed_val: not read — a fresh random seed is always drawn.
        guidance_scale, num_steps: denoising settings.
        model_size: "xl" or "xxl"; anything else falls back to "xxl".
        crossfade_s, crossfade_db: not read here; stitch settings come from meta.
        slot_id: UI slot identifier, forwarded for waveform element ids.

    Returns:
        (video_path, audio_path, updated_meta_json, waveform_html)
    """
    meta = json.loads(seg_meta_json)
    seg_idx = int(seg_idx)
    seg_start, seg_end = meta["segments"][seg_idx]
    seg_dur = seg_end - seg_start

    # Make the vendored HunyuanVideo-Foley package importable.
    _hf_path = str(Path("HunyuanVideo-Foley").resolve())
    if _hf_path not in sys.path:
        sys.path.insert(0, _hf_path)

    from hunyuanvideo_foley.utils.model_utils import load_model, denoise_process
    from hunyuanvideo_foley.utils.feature_utils import feature_process

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_size = model_size.lower()
    config_map = {
        "xl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xl.yaml",
        "xxl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xxl.yaml",
    }
    # Unknown sizes fall back to the xxl config.
    config_path = config_map.get(model_size, config_map["xxl"])
    hunyuan_weights_dir = str(HUNYUAN_MODEL_DIR / "HunyuanVideo-Foley")
    model_dict, cfg = load_model(hunyuan_weights_dir, config_path, device,
                                 enable_offload=False, model_size=model_size)

    # Fresh random seed: a regen is meant to produce a different take.
    set_global_seed(random.randint(0, 2**32 - 1))

    silent_video = meta["silent_video"]
    tmp_dir = tempfile.mkdtemp()
    # Cut just the target segment out of the silent video (stream copy, no re-encode).
    seg_path = os.path.join(tmp_dir, "regen_seg.mp4")
    ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
        seg_path, vcodec="copy", an=None
    ).run(overwrite_output=True, quiet=True)

    visual_feats, text_feats, seg_audio_len = feature_process(
        seg_path, prompt if prompt else "", model_dict, cfg,
        neg_prompt=negative_prompt if negative_prompt else None,
    )
    audio_batch, sr = denoise_process(
        visual_feats, text_feats, seg_audio_len, model_dict, cfg,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(num_steps),
        batch_size=1,
    )
    new_wav = audio_batch[0].float().cpu().numpy()
    # Trim to the exact requested segment length in samples.
    seg_samples = int(round(seg_dur * sr))
    new_wav = new_wav[:, :seg_samples]

    # Deserialise stored wavs (json roundtrip turned them into lists).
    stored_wavs = [np.array(w, dtype=np.float32) for w in meta["wavs"]]
    meta["wavs"] = stored_wavs
    meta["sr"] = sr

    video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
        new_wav, seg_idx, meta, slot_id
    )
    # Back to JSON-safe lists before the meta goes through Gradio state.
    updated_meta["wavs"] = [w.tolist() for w in updated_meta["wavs"]]
    return video_path, audio_path, json.dumps(updated_meta), waveform_html
1053
+
1054
+
1055
  # ================================================================== #
1056
  # SHARED UI HELPERS #
1057
  # ================================================================== #
1058
 
1059
def _pad_outputs(outputs: list) -> list:
    """Flatten (video, audio, seg_meta) triples and pad to MAX_SLOTS * 3 with None.

    Each entry in *outputs* must be a (video_path, audio_path, seg_meta) tuple where
    seg_meta = {"segments": [...], "audio_path": str, "video_path": str,
                "sr": int, "model": str, "crossfade_s": float,
                "crossfade_db": float, "wavs": list[np.ndarray]}
    """
    flat = []
    for slot in range(MAX_SLOTS):
        # Real triple for filled slots, (None, None, None) for the rest.
        triple = outputs[slot] if slot < len(outputs) else (None, None, None)
        flat.extend(triple)
    return flat
1074
 
1075
 
1076
+ # ------------------------------------------------------------------ #
1077
+ # WaveSurfer waveform + segment marker HTML builder #
1078
+ # ------------------------------------------------------------------ #
1079
+
1080
# Pinned WaveSurfer v7 core + Regions plugin, loaded lazily from the CDN.
_WAVESURFER_CDN = "https://cdnjs.cloudflare.com/ajax/libs/wavesurfer.js/7.8.7/wavesurfer.min.js"
_REGIONS_CDN = "https://cdnjs.cloudflare.com/ajax/libs/wavesurfer.js/7.8.7/plugins/regions.min.js"


def _build_waveform_html(audio_path: str, segments: list, slot_id: str,
                         hidden_input_id: str) -> str:
    """Return a self-contained HTML block with a WaveSurfer waveform,
    segment boundary markers, a play/pause button, and a download link.

    Clicking a region shows a small popup near the cursor with a
    "Regenerate" button. Clicking elsewhere dismisses the popup.
    Clicking "Regenerate" fires the hidden Gradio textbox to trigger Python.

    Args:
        audio_path: absolute path to the .wav file
        segments: list of (start_s, end_s) tuples
        slot_id: unique string id for this slot (e.g. "taro_0")
        hidden_input_id: elem_id of the hidden gr.Textbox to fire
    """
    if not audio_path or not os.path.exists(audio_path):
        return "<p style='color:#888;font-size:12px'>No audio yet.</p>"

    # Inline the audio as a base64 data URI so the HTML needs no file serving.
    with open(audio_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    data_uri = f"data:audio/wav;base64,{b64}"

    segs_json = json.dumps(segments)

    # Per-segment overlay colours; cycled via modulo if there are more segments.
    colors = ["rgba(100,180,255,0.22)", "rgba(255,160,100,0.22)",
              "rgba(120,220,140,0.22)", "rgba(220,120,220,0.22)",
              "rgba(255,220,80,0.22)", "rgba(80,220,220,0.22)",
              "rgba(255,100,100,0.22)", "rgba(180,255,180,0.22)"]

    # NOTE: doubled braces ({{ }}) below are literal JS braces inside the f-string.
    return f"""
    <div id="wf_container_{slot_id}"
         style="background:#1a1a1a;border-radius:8px;padding:10px;margin-top:6px;position:relative;">
      <div id="wf_{slot_id}" style="width:100%;min-height:80px;"></div>
      <div style="display:flex;align-items:center;gap:8px;margin-top:6px;">
        <button id="wf_playbtn_{slot_id}" onclick="wf_toggle_{slot_id}()"
                style="background:#333;color:#eee;border:1px solid #555;border-radius:4px;
                       padding:3px 10px;font-size:12px;cursor:pointer;">&#9654; Play</button>
        <span style="color:#888;font-size:11px;">Click a segment to regenerate</span>
        <a href="{data_uri}" download="audio_{slot_id}.wav"
           style="margin-left:auto;background:#333;color:#eee;border:1px solid #555;
                  border-radius:4px;padding:3px 10px;font-size:12px;text-decoration:none;">
          &#8595; Download</a>
      </div>
      <div id="wf_seglabel_{slot_id}"
           style="color:#aaa;font-size:11px;margin-top:4px;min-height:16px;"></div>

      <!-- Popup that appears on segment click -->
      <div id="wf_popup_{slot_id}"
           style="display:none;position:fixed;z-index:9999;
                  background:#2a2a2a;border:1px solid #555;border-radius:6px;
                  padding:8px 12px;box-shadow:0 4px 16px rgba(0,0,0,0.5);">
        <div id="wf_popup_label_{slot_id}"
             style="color:#ccc;font-size:11px;margin-bottom:6px;white-space:nowrap;"></div>
        <button id="wf_regen_btn_{slot_id}"
                style="background:#1d6fa5;color:#fff;border:none;border-radius:4px;
                       padding:5px 14px;font-size:12px;cursor:pointer;width:100%;">
          &#10227; Regenerate
        </button>
      </div>
    </div>
    <script>
    (function() {{
      // Guard against double-init on Gradio re-renders
      if (window["_wf_init_{slot_id}"]) return;
      window["_wf_init_{slot_id}"] = true;

      let _pendingSegIdx_{slot_id} = null;

      function fireRegen(idx) {{
        const popup = document.getElementById('wf_popup_{slot_id}');
        if (popup) popup.style.display = 'none';
        const lbl = document.getElementById('wf_seglabel_{slot_id}');
        const segs = {segs_json};
        if (lbl) lbl.textContent = 'Regenerating Seg ' + (idx+1) +
          ' (' + segs[idx][0].toFixed(2) + 's \u2013 ' + segs[idx][1].toFixed(2) + 's)\u2026';
        // Trigger Gradio via the hidden textbox
        const el = document.getElementById('{hidden_input_id}');
        if (el) {{
          const input = el.querySelector('input, textarea');
          if (input) {{
            const setter =
              Object.getOwnPropertyDescriptor(window.HTMLInputElement.prototype, 'value').set ||
              Object.getOwnPropertyDescriptor(window.HTMLTextAreaElement.prototype, 'value').set;
            setter.call(input, '{slot_id}|' + idx);
            input.dispatchEvent(new Event('input', {{ bubbles: true }}));
          }}
        }}
      }}

      function showPopup(idx, mouseX, mouseY) {{
        _pendingSegIdx_{slot_id} = idx;
        const segs = {segs_json};
        const popup = document.getElementById('wf_popup_{slot_id}');
        const plbl = document.getElementById('wf_popup_label_{slot_id}');
        if (plbl) plbl.textContent =
          'Seg ' + (idx+1) + ' (' + segs[idx][0].toFixed(2) + 's \u2013 ' + segs[idx][1].toFixed(2) + 's)';
        if (popup) {{
          popup.style.display = 'block';
          // Position near cursor, keep inside viewport
          const vw = window.innerWidth, vh = window.innerHeight;
          let x = mouseX + 10, y = mouseY + 10;
          popup.style.left = x + 'px';
          popup.style.top = y + 'px';
          // nudge back if off screen
          requestAnimationFrame(function() {{
            const r = popup.getBoundingClientRect();
            if (r.right > vw - 8) popup.style.left = (vw - r.width - 8) + 'px';
            if (r.bottom > vh - 8) popup.style.top = (vh - r.height - 8) + 'px';
          }});
        }}
      }}

      function hidePopup() {{
        const popup = document.getElementById('wf_popup_{slot_id}');
        if (popup) popup.style.display = 'none';
        _pendingSegIdx_{slot_id} = null;
      }}

      // Wire the Regenerate button
      document.addEventListener('DOMContentLoaded', function() {{
        const btn = document.getElementById('wf_regen_btn_{slot_id}');
        if (btn) btn.addEventListener('click', function(e) {{
          e.stopPropagation();
          if (_pendingSegIdx_{slot_id} !== null) fireRegen(_pendingSegIdx_{slot_id});
        }});
      }});
      // Also wire immediately in case DOM already loaded
      (function tryWireBtn() {{
        const btn = document.getElementById('wf_regen_btn_{slot_id}');
        if (btn) {{
          btn.onclick = function(e) {{
            e.stopPropagation();
            if (_pendingSegIdx_{slot_id} !== null) fireRegen(_pendingSegIdx_{slot_id});
          }};
        }} else {{
          setTimeout(tryWireBtn, 100);
        }}
      }})();

      // Dismiss popup on click outside
      document.addEventListener('click', function(e) {{
        const popup = document.getElementById('wf_popup_{slot_id}');
        if (popup && popup.style.display !== 'none') {{
          if (!popup.contains(e.target)) hidePopup();
        }}
      }}, true);

      function loadWS() {{
        if (!window.WaveSurfer || !window.WaveSurfer.Regions) {{
          setTimeout(loadWS, 200);
          return;
        }}
        const RegionsPlugin = window.WaveSurfer.Regions.create();
        const ws = WaveSurfer.create({{
          container: '#wf_{slot_id}',
          waveColor: '#4a9eff',
          progressColor:'#1a5fa8',
          height: 80,
          barWidth: 2,
          barGap: 1,
          barRadius: 2,
          backend: 'WebAudio',
          url: '{data_uri}',
          plugins: [RegionsPlugin],
        }});
        window["_wf_ws_{slot_id}"] = ws;
        window["wf_toggle_{slot_id}"] = function() {{ ws.playPause(); }};

        const segments = {segs_json};
        const colors = {json.dumps(colors)};

        ws.on('ready', function() {{
          segments.forEach(function(seg, idx) {{
            RegionsPlugin.addRegion({{
              id: 'seg_' + idx,
              start: seg[0],
              end: seg[1],
              color: colors[idx % colors.length],
              drag: false,
              resize: false,
              content: 'Seg ' + (idx + 1),
            }});
          }});
        }});

        RegionsPlugin.on('region-clicked', function(region, e) {{
          e.stopPropagation();
          const idx = parseInt(region.id.replace('seg_', ''));
          showPopup(idx, e.clientX, e.clientY);
        }});

        ws.on('play', function() {{
          const b = document.getElementById('wf_playbtn_{slot_id}');
          if (b) b.textContent = '\u23f8 Pause';
        }});
        ws.on('pause', function() {{
          const b = document.getElementById('wf_playbtn_{slot_id}');
          if (b) b.textContent = '\u25b6 Play';
        }});
        ws.on('finish', function() {{
          const b = document.getElementById('wf_playbtn_{slot_id}');
          if (b) b.textContent = '\u25b6 Play';
        }});
      }}

      if (!document.getElementById('wavesurfer_script')) {{
        const s = document.createElement('script');
        s.id = 'wavesurfer_script';
        s.src = '{_WAVESURFER_CDN}';
        s.onload = function() {{
          const r = document.createElement('script');
          r.id = 'wavesurfer_regions_script';
          r.src = '{_REGIONS_CDN}';
          r.onload = loadWS;
          document.head.appendChild(r);
        }};
        document.head.appendChild(s);
      }} else {{
        loadWS();
      }}
    }})();
    </script>
    """
1306
+
1307
+
1308
def _make_output_slots(tab_prefix: str) -> tuple:
    """Create MAX_SLOTS output groups for one tab (only slot 0 visible).

    Per slot: a video player, a waveform HTML panel, a hidden textbox the
    front-end JS writes "<slot_id>|<seg_idx>" into to request a segment
    regeneration, and a gr.State holding the slot's segment metadata.

    Returns:
        (grps, vids, waveforms, regen_triggers, seg_states) — five parallel
        lists of Gradio components, one entry per slot.
    """
    grps = []
    vids = []
    waveforms = []
    regen_triggers = []
    seg_states = []
    placeholder = "<p style='color:#888;font-size:12px'>Generate audio to see waveform.</p>"
    for slot_idx in range(MAX_SLOTS):
        with gr.Group(visible=(slot_idx == 0)) as group:
            slot_id = f"{tab_prefix}_{slot_idx}"
            vids.append(gr.Video(label=f"Generation {slot_idx + 1} — Video"))
            waveforms.append(gr.HTML(
                value=placeholder,
                label=f"Generation {slot_idx + 1} — Waveform",
            ))
            # Hidden textbox: the in-page JS sets its value to fire a
            # per-segment regeneration event on the Python side.
            regen_triggers.append(gr.Textbox(
                value="",
                visible=False,
                elem_id=f"regen_trigger_{slot_id}",
                label=f"regen_trigger_{slot_id}",
            ))
            seg_states.append(gr.State(value=None))
        grps.append(group)
    return grps, vids, waveforms, regen_triggers, seg_states
1333
 
1334
 
1335
def _unpack_outputs(flat: list, n: int, tab_prefix: str) -> list:
    """Convert a flat _pad_outputs list into Gradio update lists.

    `flat` holds MAX_SLOTS * 3 entries laid out as
    [vid0, aud0, meta0, vid1, aud1, meta1, ...].

    Returns the concatenated updates for: slot-group visibility, videos,
    waveform HTML panels, and per-slot segment states (in that order —
    callers pass matching output lists in the same order).
    """
    n = int(n)
    empty_wave = "<p style='color:#888;font-size:12px'>Generate audio to see waveform.</p>"
    grp_updates = [gr.update(visible=(slot < n)) for slot in range(MAX_SLOTS)]
    vid_updates = []
    wave_updates = []
    state_updates = []
    for slot in range(MAX_SLOTS):
        vid_path, aud_path, meta = flat[slot * 3: slot * 3 + 3]
        vid_updates.append(gr.update(value=vid_path))
        if aud_path and meta:
            sid = f"{tab_prefix}_{slot}"
            html = _build_waveform_html(
                aud_path, meta["segments"], sid, f"regen_trigger_{sid}"
            )
            wave_updates.append(gr.update(value=html))
            state_updates.append(meta)
        else:
            wave_updates.append(gr.update(value=empty_wave))
            state_updates.append(None)
    return grp_updates + vid_updates + wave_updates + state_updates
1363
 
1364
 
1365
  def _on_video_upload_taro(video_file, num_steps, crossfade_s):
 
1412
  taro_btn = gr.Button("Generate", variant="primary")
1413
 
1414
  with gr.Column():
1415
+ (taro_slot_grps, taro_slot_vids,
1416
+ taro_slot_waves, taro_slot_rtrigs,
1417
+ taro_slot_states) = _make_output_slots("taro")
1418
 
1419
  for trigger in [taro_video, taro_steps, taro_cf_dur]:
1420
  trigger.change(
 
1429
  )
1430
 
1431
    def _run_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n):
        """Run TARO generation and adapt the flat output for the UI slots.

        numpy wavs inside each slot's meta are converted to plain lists so
        the meta survives a round-trip through gr.State / json.dumps.
        """
        flat = generate_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n)
        # Serialise wavs in meta to JSON-safe lists
        for i in range(MAX_SLOTS):
            meta = flat[i * 3 + 2]
            if meta is not None:
                meta["wavs"] = [w.tolist() for w in meta["wavs"]]
            flat[i * 3 + 2] = meta
        return _unpack_outputs(flat, n, "taro")

    # Output order must match what _unpack_outputs returns:
    # groups, then videos, then waveforms, then segment states.
    taro_btn.click(
        fn=_run_taro,
        inputs=[taro_video, taro_seed, taro_cfg, taro_steps, taro_mode,
                taro_cf_dur, taro_cf_db, taro_samples],
        outputs=taro_slot_grps + taro_slot_vids + taro_slot_waves + taro_slot_states,
    )

    # Per-slot regen trigger wiring for TARO
    for _i, _rtrig in enumerate(taro_slot_rtrigs):
        _slot_id = f"taro_{_i}"
        # Factory binds _sid per iteration, avoiding the late-binding
        # closure bug.  NOTE(review): _si is never used inside _do.
        def _make_taro_regen(_si, _sid):
            def _do(trigger_val, video, seed, cfg, steps, mode, cf_dur, cf_db, state):
                """Regenerate one audio segment when JS fires the trigger."""
                # Ignore spurious triggers: empty value, missing state, or a
                # value addressed to a different slot.
                if not trigger_val or not state:
                    return gr.update(), gr.update(), state, gr.update()
                parts = trigger_val.split("|")
                if len(parts) != 2 or parts[0] != _sid:
                    return gr.update(), gr.update(), state, gr.update()
                seg_idx = int(parts[1])
                meta_json = json.dumps(state)
                # NOTE(review): the returned audio path `aud` is unused.
                vid, aud, new_meta_json, html = regen_taro_segment(
                    video, seg_idx, meta_json,
                    seed, cfg, steps, mode, cf_dur, cf_db, _sid,
                )
                new_meta = json.loads(new_meta_json)
                # Clear the trigger textbox so the same value can fire again.
                return gr.update(value=vid), gr.update(value=html), new_meta, gr.update(value="")
            return _do
        _rtrig.change(
            fn=_make_taro_regen(_i, _slot_id),
            inputs=[_rtrig, taro_video, taro_seed, taro_cfg, taro_steps,
                    taro_mode, taro_cf_dur, taro_cf_db, taro_slot_states[_i]],
            outputs=[taro_slot_vids[_i], taro_slot_waves[_i],
                     taro_slot_states[_i], _rtrig],
        )
1474
+
1475
  # ---------------------------------------------------------- #
1476
  # Tab 2 — MMAudio #
1477
  # ---------------------------------------------------------- #
 
1490
  mma_btn = gr.Button("Generate", variant="primary")
1491
 
1492
  with gr.Column():
1493
+ (mma_slot_grps, mma_slot_vids,
1494
+ mma_slot_waves, mma_slot_rtrigs,
1495
+ mma_slot_states) = _make_output_slots("mma")
1496
 
1497
  mma_samples.change(
1498
  fn=_update_slot_visibility,
 
1501
  )
1502
 
1503
    def _run_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n):
        """Run MMAudio generation and adapt the flat output for the UI slots."""
        flat = generate_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n)
        # Convert numpy wavs to plain lists so meta is JSON/gr.State safe.
        for i in range(MAX_SLOTS):
            meta = flat[i * 3 + 2]
            if meta is not None:
                meta["wavs"] = [w.tolist() for w in meta["wavs"]]
            flat[i * 3 + 2] = meta
        return _unpack_outputs(flat, n, "mma")

    # Output order must match what _unpack_outputs returns.
    mma_btn.click(
        fn=_run_mmaudio,
        inputs=[mma_video, mma_prompt, mma_neg, mma_seed,
                mma_cfg, mma_steps, mma_cf_dur, mma_cf_db, mma_samples],
        outputs=mma_slot_grps + mma_slot_vids + mma_slot_waves + mma_slot_states,
    )

    # Per-slot regen trigger wiring for MMAudio.
    for _i, _rtrig in enumerate(mma_slot_rtrigs):
        _slot_id = f"mma_{_i}"
        # Factory binds _sid per iteration (late-binding-closure guard).
        # NOTE(review): _si is never used inside _do.
        def _make_mma_regen(_si, _sid):
            def _do(trigger_val, video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, state):
                """Regenerate one audio segment when JS fires the trigger."""
                if not trigger_val or not state:
                    return gr.update(), gr.update(), state, gr.update()
                parts = trigger_val.split("|")
                if len(parts) != 2 or parts[0] != _sid:
                    return gr.update(), gr.update(), state, gr.update()
                seg_idx = int(parts[1])
                meta_json = json.dumps(state)
                # NOTE(review): the returned audio path `aud` is unused.
                vid, aud, new_meta_json, html = regen_mmaudio_segment(
                    video, seg_idx, meta_json,
                    prompt, neg, seed, cfg, steps, cf_dur, cf_db, _sid,
                )
                new_meta = json.loads(new_meta_json)
                # Clear the trigger textbox so the same value can fire again.
                return gr.update(value=vid), gr.update(value=html), new_meta, gr.update(value="")
            return _do
        _rtrig.change(
            fn=_make_mma_regen(_i, _slot_id),
            inputs=[_rtrig, mma_video, mma_prompt, mma_neg, mma_seed,
                    mma_cfg, mma_steps, mma_cf_dur, mma_cf_db, mma_slot_states[_i]],
            outputs=[mma_slot_vids[_i], mma_slot_waves[_i],
                     mma_slot_states[_i], _rtrig],
        )
1544
+
1545
  # ---------------------------------------------------------- #
1546
  # Tab 3 — HunyuanVideoFoley #
1547
  # ---------------------------------------------------------- #
 
1561
  hf_btn = gr.Button("Generate", variant="primary")
1562
 
1563
  with gr.Column():
1564
+ (hf_slot_grps, hf_slot_vids,
1565
+ hf_slot_waves, hf_slot_rtrigs,
1566
+ hf_slot_states) = _make_output_slots("hf")
1567
 
1568
  hf_samples.change(
1569
  fn=_update_slot_visibility,
 
1572
  )
1573
 
1574
    def _run_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n):
        """Run HunyuanVideoFoley generation and adapt the flat output for the UI slots."""
        flat = generate_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n)
        # Convert numpy wavs to plain lists so meta is JSON/gr.State safe.
        for i in range(MAX_SLOTS):
            meta = flat[i * 3 + 2]
            if meta is not None:
                meta["wavs"] = [w.tolist() for w in meta["wavs"]]
            flat[i * 3 + 2] = meta
        return _unpack_outputs(flat, n, "hf")

    # Output order must match what _unpack_outputs returns.
    hf_btn.click(
        fn=_run_hunyuan,
        inputs=[hf_video, hf_prompt, hf_neg, hf_seed,
                hf_guidance, hf_steps, hf_size, hf_cf_dur, hf_cf_db, hf_samples],
        outputs=hf_slot_grps + hf_slot_vids + hf_slot_waves + hf_slot_states,
    )

    # Per-slot regen trigger wiring for HunyuanVideoFoley.
    for _i, _rtrig in enumerate(hf_slot_rtrigs):
        _slot_id = f"hf_{_i}"
        # Factory binds _sid per iteration (late-binding-closure guard).
        # NOTE(review): _si is never used inside _do.
        def _make_hf_regen(_si, _sid):
            def _do(trigger_val, video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, state):
                """Regenerate one audio segment when JS fires the trigger."""
                if not trigger_val or not state:
                    return gr.update(), gr.update(), state, gr.update()
                parts = trigger_val.split("|")
                if len(parts) != 2 or parts[0] != _sid:
                    return gr.update(), gr.update(), state, gr.update()
                seg_idx = int(parts[1])
                meta_json = json.dumps(state)
                # NOTE(review): the returned audio path `aud` is unused.
                vid, aud, new_meta_json, html = regen_hunyuan_segment(
                    video, seg_idx, meta_json,
                    prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, _sid,
                )
                new_meta = json.loads(new_meta_json)
                # Clear the trigger textbox so the same value can fire again.
                return gr.update(value=vid), gr.update(value=html), new_meta, gr.update(value="")
            return _do
        _rtrig.change(
            fn=_make_hf_regen(_i, _slot_id),
            inputs=[_rtrig, hf_video, hf_prompt, hf_neg, hf_seed,
                    hf_guidance, hf_steps, hf_size, hf_cf_dur, hf_cf_db, hf_slot_states[_i]],
            outputs=[hf_slot_vids[_i], hf_slot_waves[_i],
                     hf_slot_states[_i], _rtrig],
        )
1615
+
1616
  # ---- Cross-tab video sync ----
 
 
1617
  _sync = lambda v: (gr.update(value=v), gr.update(value=v))
1618
  taro_video.change(fn=_sync, inputs=[taro_video], outputs=[mma_video, hf_video])
1619
  mma_video.change(fn=_sync, inputs=[mma_video], outputs=[taro_video, hf_video])