BoxOfColors committed on
Commit
8697d49
·
1 Parent(s): 585a112

Cleanup: remove duplication and dead code across app.py

Browse files

- Remove [DIAG] print left over from debugging
- Move torchaudio to top-level import; drop per-function import torchaudio
inside generate_mmaudio and generate_hunyuan
- Replace sf.write in generate_taro with torchaudio.save (consistent with
other two models); drop soundfile import
- Remove _taro_build_segments wrapper — was a one-liner; call _build_segments
with TARO_MODEL_DUR directly at both call sites
- Drop MMA_CF_S/MMA_CF_DB and CF_S/CF_DB local aliases in generate_mmaudio
and generate_hunyuan — cast crossfade_s/crossfade_db once at function top
- Extract _make_output_slots() — builds the 8-slot video+audio output column;
replaces identical 7-line loop duplicated across all 3 tabs
- Extract _unpack_outputs(flat, n) — turns _pad_outputs list into Gradio
update lists; replaces identical 4-line block in all 3 _run_* functions

Files changed (1) hide show
  1. app.py +43 -66
app.py CHANGED
@@ -15,9 +15,8 @@ from math import floor
15
  from pathlib import Path
16
 
17
  import torch
18
- print(f"[DIAG] torch={torch.__version__} cuda={torch.version.cuda}")
19
  import numpy as np
20
- import soundfile as sf
21
  import ffmpeg
22
  import spaces
23
  import gradio as gr
@@ -171,13 +170,8 @@ TARO_SECS_PER_STEP = 2.5 # estimated GPU-seconds per diffusion step
171
  _TARO_INFERENCE_CACHE: dict = {}
172
 
173
 
174
- def _taro_build_segments(total_dur_s: float, crossfade_s: float) -> list:
175
- """Sliding-window segmentation using TARO's 8.192 s window."""
176
- return _build_segments(total_dur_s, TARO_MODEL_DUR, crossfade_s)
177
-
178
-
179
  def _taro_calc_max_samples(total_dur_s: float, num_steps: int, crossfade_s: float) -> int:
180
- n_segs = len(_taro_build_segments(total_dur_s, crossfade_s))
181
  time_per_seg = num_steps * TARO_SECS_PER_STEP
182
  max_s = floor(600.0 / (n_segs * time_per_seg))
183
  return max(1, min(max_s, MAX_SLOTS))
@@ -325,7 +319,7 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
325
 
326
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
327
  total_dur_s = cavp_feats.shape[0] / TARO_FPS
328
- segments = _taro_build_segments(total_dur_s, crossfade_s)
329
 
330
  outputs = []
331
  for sample_idx in range(num_samples):
@@ -355,7 +349,7 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
355
 
356
  final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR)
357
  audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
358
- sf.write(audio_path, final_wav, TARO_SR)
359
  video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
360
  mux_video_audio(silent_video, audio_path, video_path)
361
  outputs.append((video_path, audio_path))
@@ -382,7 +376,6 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
382
  cfg_strength, num_steps, num_samples,
383
  crossfade_s=1.0, crossfade_db=3.0):
384
  """MMAudio: flow-matching video-to-audio, 44.1 kHz, 8 s sliding window."""
385
- # MMAudio is a local package in ./MMAudio/ — add it to sys.path so imports work.
386
  import sys as _sys, os as _os
387
  _mmaudio_dir = _os.path.join(_os.path.dirname(_os.path.abspath(__file__)), "MMAudio")
388
  if _mmaudio_dir not in _sys.path:
@@ -431,10 +424,10 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
431
 
432
  # MMAudio's fixed window is 8 s. For longer videos we slide over 8 s segments
433
  # with a crossfade overlap and stitch the results into a full-length track.
434
- total_dur_s = get_video_duration(video_file)
435
- MMA_CF_S = float(crossfade_s)
436
- MMA_CF_DB = float(crossfade_db)
437
- segments = _build_segments(total_dur_s, MMAUDIO_WINDOW, MMA_CF_S)
438
  print(f"[MMAudio] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤8 s")
439
 
440
  sr = seq_cfg.sampling_rate # 44100
@@ -488,7 +481,7 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
488
  # Crossfade-stitch all segments using shared equal-power helper
489
  full_wav = seg_audios[0]
490
  for nw in seg_audios[1:]:
491
- full_wav = _cf_join(full_wav, nw, MMA_CF_S, MMA_CF_DB, sr)
492
  full_wav = full_wav[:, : int(round(total_dur_s * sr))]
493
 
494
  audio_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.flac")
@@ -521,7 +514,6 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
521
  guidance_scale, num_steps, model_size, num_samples,
522
  crossfade_s=2.0, crossfade_db=3.0):
523
  """HunyuanVideoFoley: text-guided foley, 48 kHz, up to 15 s."""
524
- import torchaudio
525
  import sys as _sys
526
  # Ensure HunyuanVideo-Foley package is importable
527
  _hf_path = str(Path("HunyuanVideo-Foley").resolve())
@@ -564,10 +556,10 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
564
  # HunyuanFoley is limited to 15 s per pass. For longer videos we slice the
565
  # input into overlapping segments, generate audio for each, then crossfade-
566
  # stitch the results into a single full-length audio track.
567
- total_dur_s = get_video_duration(video_file)
568
- CF_S = float(crossfade_s)
569
- CF_DB = float(crossfade_db)
570
- segments = _build_segments(total_dur_s, HUNYUAN_MAX_DUR, CF_S)
571
  print(f"[HunyuanFoley] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤15 s")
572
 
573
  # Pre-encode text features once (same for every segment)
@@ -624,7 +616,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
624
  # Crossfade-stitch all segments using shared equal-power helper
625
  full_wav = seg_wavs[0]
626
  for nw in seg_wavs[1:]:
627
- full_wav = _cf_join(full_wav, nw, CF_S, CF_DB, sr)
628
  # Trim to exact video duration
629
  full_wav = full_wav[:, : int(round(total_dur_s * sr))]
630
 
@@ -652,6 +644,27 @@ def _pad_outputs(outputs: list) -> list:
652
  return result
653
 
654
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655
  def _on_video_upload_taro(video_file, num_steps, crossfade_s):
656
  if video_file is None:
657
  return gr.update(maximum=MAX_SLOTS, value=1)
@@ -702,14 +715,7 @@ with gr.Blocks(title="Generate Audio for Video") as demo:
702
  taro_btn = gr.Button("Generate", variant="primary")
703
 
704
  with gr.Column():
705
- taro_slot_grps, taro_slot_vids, taro_slot_auds = [], [], []
706
- for i in range(MAX_SLOTS):
707
- with gr.Group(visible=(i == 0)) as g:
708
- sv = gr.Video(label=f"Generation {i+1} — Video")
709
- sa = gr.Audio(label=f"Generation {i+1} — Audio")
710
- taro_slot_grps.append(g)
711
- taro_slot_vids.append(sv)
712
- taro_slot_auds.append(sa)
713
 
714
  for trigger in [taro_video, taro_steps, taro_cf_dur]:
715
  trigger.change(
@@ -724,12 +730,7 @@ with gr.Blocks(title="Generate Audio for Video") as demo:
724
  )
725
 
726
  def _run_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n):
727
- flat = generate_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n)
728
- n = int(n)
729
- grp_upd = [gr.update(visible=(i < n)) for i in range(MAX_SLOTS)]
730
- vid_upd = [gr.update(value=flat[i * 2]) for i in range(MAX_SLOTS)]
731
- aud_upd = [gr.update(value=flat[i * 2 + 1]) for i in range(MAX_SLOTS)]
732
- return grp_upd + vid_upd + aud_upd
733
 
734
  taro_btn.click(
735
  fn=_run_taro,
@@ -756,14 +757,7 @@ with gr.Blocks(title="Generate Audio for Video") as demo:
756
  mma_btn = gr.Button("Generate", variant="primary")
757
 
758
  with gr.Column():
759
- mma_slot_grps, mma_slot_vids, mma_slot_auds = [], [], []
760
- for i in range(MAX_SLOTS):
761
- with gr.Group(visible=(i == 0)) as g:
762
- sv = gr.Video(label=f"Generation {i+1} — Video")
763
- sa = gr.Audio(label=f"Generation {i+1} — Audio")
764
- mma_slot_grps.append(g)
765
- mma_slot_vids.append(sv)
766
- mma_slot_auds.append(sa)
767
 
768
  mma_samples.change(
769
  fn=_update_slot_visibility,
@@ -772,13 +766,8 @@ with gr.Blocks(title="Generate Audio for Video") as demo:
772
  )
773
 
774
  def _run_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n):
775
- flat = generate_mmaudio(video, prompt, neg, seed, cfg, steps, n,
776
- crossfade_s=cf_dur, crossfade_db=cf_db)
777
- n = int(n)
778
- grp_upd = [gr.update(visible=(i < n)) for i in range(MAX_SLOTS)]
779
- vid_upd = [gr.update(value=flat[i * 2]) for i in range(MAX_SLOTS)]
780
- aud_upd = [gr.update(value=flat[i * 2 + 1]) for i in range(MAX_SLOTS)]
781
- return grp_upd + vid_upd + aud_upd
782
 
783
  mma_btn.click(
784
  fn=_run_mmaudio,
@@ -806,14 +795,7 @@ with gr.Blocks(title="Generate Audio for Video") as demo:
806
  hf_btn = gr.Button("Generate", variant="primary")
807
 
808
  with gr.Column():
809
- hf_slot_grps, hf_slot_vids, hf_slot_auds = [], [], []
810
- for i in range(MAX_SLOTS):
811
- with gr.Group(visible=(i == 0)) as g:
812
- sv = gr.Video(label=f"Generation {i+1} — Video")
813
- sa = gr.Audio(label=f"Generation {i+1} — Audio")
814
- hf_slot_grps.append(g)
815
- hf_slot_vids.append(sv)
816
- hf_slot_auds.append(sa)
817
 
818
  hf_samples.change(
819
  fn=_update_slot_visibility,
@@ -822,13 +804,8 @@ with gr.Blocks(title="Generate Audio for Video") as demo:
822
  )
823
 
824
  def _run_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n):
825
- flat = generate_hunyuan(video, prompt, neg, seed, guidance, steps, size, n,
826
- crossfade_s=cf_dur, crossfade_db=cf_db)
827
- n = int(n)
828
- grp_upd = [gr.update(visible=(i < n)) for i in range(MAX_SLOTS)]
829
- vid_upd = [gr.update(value=flat[i * 2]) for i in range(MAX_SLOTS)]
830
- aud_upd = [gr.update(value=flat[i * 2 + 1]) for i in range(MAX_SLOTS)]
831
- return grp_upd + vid_upd + aud_upd
832
 
833
  hf_btn.click(
834
  fn=_run_hunyuan,
 
15
  from pathlib import Path
16
 
17
  import torch
 
18
  import numpy as np
19
+ import torchaudio
20
  import ffmpeg
21
  import spaces
22
  import gradio as gr
 
170
  _TARO_INFERENCE_CACHE: dict = {}
171
 
172
 
 
 
 
 
 
173
  def _taro_calc_max_samples(total_dur_s: float, num_steps: int, crossfade_s: float) -> int:
174
+ n_segs = len(_build_segments(total_dur_s, TARO_MODEL_DUR, crossfade_s))
175
  time_per_seg = num_steps * TARO_SECS_PER_STEP
176
  max_s = floor(600.0 / (n_segs * time_per_seg))
177
  return max(1, min(max_s, MAX_SLOTS))
 
319
 
320
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
321
  total_dur_s = cavp_feats.shape[0] / TARO_FPS
322
+ segments = _build_segments(total_dur_s, TARO_MODEL_DUR, crossfade_s)
323
 
324
  outputs = []
325
  for sample_idx in range(num_samples):
 
349
 
350
  final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR)
351
  audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
352
+ torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(final_wav)).unsqueeze(0), TARO_SR)
353
  video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
354
  mux_video_audio(silent_video, audio_path, video_path)
355
  outputs.append((video_path, audio_path))
 
376
  cfg_strength, num_steps, num_samples,
377
  crossfade_s=1.0, crossfade_db=3.0):
378
  """MMAudio: flow-matching video-to-audio, 44.1 kHz, 8 s sliding window."""
 
379
  import sys as _sys, os as _os
380
  _mmaudio_dir = _os.path.join(_os.path.dirname(_os.path.abspath(__file__)), "MMAudio")
381
  if _mmaudio_dir not in _sys.path:
 
424
 
425
  # MMAudio's fixed window is 8 s. For longer videos we slide over 8 s segments
426
  # with a crossfade overlap and stitch the results into a full-length track.
427
+ crossfade_s = float(crossfade_s)
428
+ crossfade_db = float(crossfade_db)
429
+ total_dur_s = get_video_duration(video_file)
430
+ segments = _build_segments(total_dur_s, MMAUDIO_WINDOW, crossfade_s)
431
  print(f"[MMAudio] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤8 s")
432
 
433
  sr = seq_cfg.sampling_rate # 44100
 
481
  # Crossfade-stitch all segments using shared equal-power helper
482
  full_wav = seg_audios[0]
483
  for nw in seg_audios[1:]:
484
+ full_wav = _cf_join(full_wav, nw, crossfade_s, crossfade_db, sr)
485
  full_wav = full_wav[:, : int(round(total_dur_s * sr))]
486
 
487
  audio_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.flac")
 
514
  guidance_scale, num_steps, model_size, num_samples,
515
  crossfade_s=2.0, crossfade_db=3.0):
516
  """HunyuanVideoFoley: text-guided foley, 48 kHz, up to 15 s."""
 
517
  import sys as _sys
518
  # Ensure HunyuanVideo-Foley package is importable
519
  _hf_path = str(Path("HunyuanVideo-Foley").resolve())
 
556
  # HunyuanFoley is limited to 15 s per pass. For longer videos we slice the
557
  # input into overlapping segments, generate audio for each, then crossfade-
558
  # stitch the results into a single full-length audio track.
559
+ crossfade_s = float(crossfade_s)
560
+ crossfade_db = float(crossfade_db)
561
+ total_dur_s = get_video_duration(video_file)
562
+ segments = _build_segments(total_dur_s, HUNYUAN_MAX_DUR, crossfade_s)
563
  print(f"[HunyuanFoley] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤15 s")
564
 
565
  # Pre-encode text features once (same for every segment)
 
616
  # Crossfade-stitch all segments using shared equal-power helper
617
  full_wav = seg_wavs[0]
618
  for nw in seg_wavs[1:]:
619
+ full_wav = _cf_join(full_wav, nw, crossfade_s, crossfade_db, sr)
620
  # Trim to exact video duration
621
  full_wav = full_wav[:, : int(round(total_dur_s * sr))]
622
 
 
644
  return result
645
 
646
 
647
+ def _make_output_slots() -> tuple:
648
+ """Build MAX_SLOTS video+audio output groups. Returns (grps, vids, auds)."""
649
+ grps, vids, auds = [], [], []
650
+ for i in range(MAX_SLOTS):
651
+ with gr.Group(visible=(i == 0)) as g:
652
+ vids.append(gr.Video(label=f"Generation {i+1} — Video"))
653
+ auds.append(gr.Audio(label=f"Generation {i+1} — Audio"))
654
+ grps.append(g)
655
+ return grps, vids, auds
656
+
657
+
658
+ def _unpack_outputs(flat: list, n: int) -> list:
659
+ """Turn a flat _pad_outputs list into Gradio update lists for grps+vids+auds."""
660
+ n = int(n)
661
+ return (
662
+ [gr.update(visible=(i < n)) for i in range(MAX_SLOTS)] +
663
+ [gr.update(value=flat[i * 2]) for i in range(MAX_SLOTS)] +
664
+ [gr.update(value=flat[i * 2 + 1]) for i in range(MAX_SLOTS)]
665
+ )
666
+
667
+
668
  def _on_video_upload_taro(video_file, num_steps, crossfade_s):
669
  if video_file is None:
670
  return gr.update(maximum=MAX_SLOTS, value=1)
 
715
  taro_btn = gr.Button("Generate", variant="primary")
716
 
717
  with gr.Column():
718
+ taro_slot_grps, taro_slot_vids, taro_slot_auds = _make_output_slots()
 
 
 
 
 
 
 
719
 
720
  for trigger in [taro_video, taro_steps, taro_cf_dur]:
721
  trigger.change(
 
730
  )
731
 
732
  def _run_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n):
733
+ return _unpack_outputs(generate_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n), n)
 
 
 
 
 
734
 
735
  taro_btn.click(
736
  fn=_run_taro,
 
757
  mma_btn = gr.Button("Generate", variant="primary")
758
 
759
  with gr.Column():
760
+ mma_slot_grps, mma_slot_vids, mma_slot_auds = _make_output_slots()
 
 
 
 
 
 
 
761
 
762
  mma_samples.change(
763
  fn=_update_slot_visibility,
 
766
  )
767
 
768
  def _run_mmaudio(video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, n):
769
+ return _unpack_outputs(generate_mmaudio(video, prompt, neg, seed, cfg, steps, n,
770
+ crossfade_s=cf_dur, crossfade_db=cf_db), n)
 
 
 
 
 
771
 
772
  mma_btn.click(
773
  fn=_run_mmaudio,
 
795
  hf_btn = gr.Button("Generate", variant="primary")
796
 
797
  with gr.Column():
798
+ hf_slot_grps, hf_slot_vids, hf_slot_auds = _make_output_slots()
 
 
 
 
 
 
 
799
 
800
  hf_samples.change(
801
  fn=_update_slot_visibility,
 
804
  )
805
 
806
  def _run_hunyuan(video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, n):
807
+ return _unpack_outputs(generate_hunyuan(video, prompt, neg, seed, guidance, steps, size, n,
808
+ crossfade_s=cf_dur, crossfade_db=cf_db), n)
 
 
 
 
 
809
 
810
  hf_btn.click(
811
  fn=_run_hunyuan,