BoxOfColors committed on
Commit
72afd74
·
1 Parent(s): 95c5c55

Refactor: consolidate segmentation and crossfade into shared helpers

Browse files

- Add _build_segments(total, window_s, crossfade_s) as universal segmentation
helper used by all three models (TARO via thin wrapper, MMAudio and
HunyuanFoley directly with their respective window constants)
- Add _cf_join_stereo(a, b, cf_s, db, sr) as shared equal-power crossfade for
stereo (C, T) arrays; MMAudio and HunyuanFoley both call this instead of
duplicating the same inline closure
- Remove duplicate _mma_build_segments and inline _cf_join closures from
generate_mmaudio and generate_hunyuan
- _taro_build_segments now delegates to _build_segments; TARO keeps its own
mono _crossfade_join/_stitch_wavs since it outputs 1D not (C,T)

Files changed (1) hide show
  1. app.py +46 -58
app.py CHANGED
@@ -105,6 +105,43 @@ def mux_video_audio(silent_video: str, audio_path: str, output_path: str):
105
  ).run(overwrite_output=True, quiet=True)
106
 
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  # ================================================================== #
109
  # TARO #
110
  # ================================================================== #
@@ -128,19 +165,8 @@ _TARO_INFERENCE_CACHE: dict = {}
128
 
129
 
130
  def _taro_build_segments(total_dur_s: float, crossfade_s: float) -> list:
131
- """Sliding-window segmentation for videos longer than one TARO window."""
132
- if total_dur_s <= TARO_MODEL_DUR:
133
- return [(0.0, total_dur_s)]
134
- step_s = TARO_MODEL_DUR - crossfade_s
135
- segments, seg_start = [], 0.0
136
- while True:
137
- if seg_start + TARO_MODEL_DUR >= total_dur_s:
138
- seg_start = max(0.0, total_dur_s - TARO_MODEL_DUR)
139
- segments.append((seg_start, total_dur_s))
140
- break
141
- segments.append((seg_start, seg_start + TARO_MODEL_DUR))
142
- seg_start += step_s
143
- return segments
144
 
145
 
146
  def _taro_calc_max_samples(total_dur_s: float, num_steps: int, crossfade_s: float) -> int:
@@ -411,25 +437,11 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
411
  outputs = []
412
 
413
  # MMAudio's fixed window is 8 s. For longer videos we slide over 8 s segments
414
- # with a 1 s crossfade overlap and stitch the results into a full-length track.
415
  total_dur_s = get_video_duration(video_file)
416
  MMA_CF_S = float(crossfade_s)
417
  MMA_CF_DB = float(crossfade_db)
418
-
419
- def _mma_build_segments(total_s, cf_s):
420
- if total_s <= MMAUDIO_WINDOW:
421
- return [(0.0, total_s)]
422
- step_s = MMAUDIO_WINDOW - cf_s
423
- segs, t = [], 0.0
424
- while True:
425
- if t + MMAUDIO_WINDOW >= total_s:
426
- segs.append((max(0.0, total_s - MMAUDIO_WINDOW), total_s))
427
- break
428
- segs.append((t, t + MMAUDIO_WINDOW))
429
- t += step_s
430
- return segs
431
-
432
- segments = _mma_build_segments(total_dur_s, MMA_CF_S)
433
  print(f"[MMAudio] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤8 s")
434
 
435
  sr = seq_cfg.sampling_rate # 44100
@@ -480,22 +492,10 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
480
  wav = wav[:, :seg_samples]
481
  seg_audios.append(wav)
482
 
483
- # Crossfade-stitch all segments (equal-power fade)
484
- def _cf_join(a, b, cf_s):
485
- cf = int(round(cf_s * sr))
486
- cf = min(cf, a.shape[1], b.shape[1])
487
- if cf <= 0:
488
- return np.concatenate([a, b], axis=1)
489
- gain = 10 ** (MMA_CF_DB / 20.0)
490
- t = np.linspace(0.0, 1.0, cf, dtype=np.float32)
491
- fade_out = np.cos(t * np.pi / 2)
492
- fade_in = np.sin(t * np.pi / 2)
493
- overlap = a[:, -cf:] * fade_out * gain + b[:, :cf] * fade_in * gain
494
- return np.concatenate([a[:, :-cf], overlap, b[:, cf:]], axis=1)
495
-
496
  full_wav = seg_audios[0]
497
  for nw in seg_audios[1:]:
498
- full_wav = _cf_join(full_wav, nw, MMA_CF_S)
499
  full_wav = full_wav[:, : int(round(total_dur_s * sr))]
500
 
501
  audio_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.flac")
@@ -574,7 +574,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
574
  total_dur_s = get_video_duration(video_file)
575
  CF_S = float(crossfade_s)
576
  CF_DB = float(crossfade_db)
577
- segments = _taro_build_segments(total_dur_s, CF_S) # reuse TARO helper
578
  print(f"[HunyuanFoley] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤15 s")
579
 
580
  # Pre-encode text features once (same for every segment)
@@ -628,22 +628,10 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
628
  wav = wav[:, :seg_samples]
629
  seg_wavs.append(wav)
630
 
631
- # Stitch segments with equal-power crossfade (operates on (channels, samples) arrays)
632
- def _cf_join_stereo(a, b, cf_s, db):
633
- cf = int(round(cf_s * sr))
634
- cf = min(cf, a.shape[1], b.shape[1])
635
- if cf <= 0:
636
- return np.concatenate([a, b], axis=1)
637
- gain = 10 ** (db / 20.0)
638
- t = np.linspace(0.0, 1.0, cf, dtype=np.float32)
639
- fade_out = np.cos(t * np.pi / 2)
640
- fade_in = np.sin(t * np.pi / 2)
641
- overlap = a[:, -cf:] * fade_out * gain + b[:, :cf] * fade_in * gain
642
- return np.concatenate([a[:, :-cf], overlap, b[:, cf:]], axis=1)
643
-
644
  full_wav = seg_wavs[0]
645
  for nw in seg_wavs[1:]:
646
- full_wav = _cf_join_stereo(full_wav, nw, CF_S, CF_DB)
647
  # Trim to exact video duration
648
  full_wav = full_wav[:, : int(round(total_dur_s * sr))]
649
 
 
105
  ).run(overwrite_output=True, quiet=True)
106
 
107
 
108
+ # ------------------------------------------------------------------ #
109
+ # Shared sliding-window segmentation and crossfade helpers #
110
+ # Used by all three models (TARO, MMAudio, HunyuanFoley). #
111
+ # ------------------------------------------------------------------ #
112
+
113
+ def _build_segments(total_dur_s: float, window_s: float, crossfade_s: float) -> list:
114
+ """Return list of (start, end) pairs covering *total_dur_s* with a sliding
115
+ window of *window_s* and *crossfade_s* overlap between consecutive segments."""
116
+ if total_dur_s <= window_s:
117
+ return [(0.0, total_dur_s)]
118
+ step_s = window_s - crossfade_s
119
+ segments, seg_start = [], 0.0
120
+ while True:
121
+ if seg_start + window_s >= total_dur_s:
122
+ seg_start = max(0.0, total_dur_s - window_s)
123
+ segments.append((seg_start, total_dur_s))
124
+ break
125
+ segments.append((seg_start, seg_start + window_s))
126
+ seg_start += step_s
127
+ return segments
128
+
129
+
130
+ def _cf_join_stereo(a: np.ndarray, b: np.ndarray,
131
+ crossfade_s: float, db_boost: float, sr: int) -> np.ndarray:
132
+ """Equal-power crossfade join for stereo (C, T) numpy arrays."""
133
+ cf = int(round(crossfade_s * sr))
134
+ cf = min(cf, a.shape[1], b.shape[1])
135
+ if cf <= 0:
136
+ return np.concatenate([a, b], axis=1)
137
+ gain = 10 ** (db_boost / 20.0)
138
+ t = np.linspace(0.0, 1.0, cf, dtype=np.float32)
139
+ fade_out = np.cos(t * np.pi / 2) # 1 → 0
140
+ fade_in = np.sin(t * np.pi / 2) # 0 → 1
141
+ overlap = a[:, -cf:] * fade_out * gain + b[:, :cf] * fade_in * gain
142
+ return np.concatenate([a[:, :-cf], overlap, b[:, cf:]], axis=1)
143
+
144
+
145
  # ================================================================== #
146
  # TARO #
147
  # ================================================================== #
 
165
 
166
 
167
def _taro_build_segments(total_dur_s: float, crossfade_s: float) -> list:
    """TARO-specific thin wrapper: segment *total_dur_s* via the shared
    helper, using the model's fixed window length (``TARO_MODEL_DUR``)."""
    return _build_segments(total_dur_s, TARO_MODEL_DUR, crossfade_s)
 
 
 
 
 
 
 
 
 
 
 
170
 
171
 
172
  def _taro_calc_max_samples(total_dur_s: float, num_steps: int, crossfade_s: float) -> int:
 
437
  outputs = []
438
 
439
  # MMAudio's fixed window is 8 s. For longer videos we slide over 8 s segments
440
+ # with a crossfade overlap and stitch the results into a full-length track.
441
  total_dur_s = get_video_duration(video_file)
442
  MMA_CF_S = float(crossfade_s)
443
  MMA_CF_DB = float(crossfade_db)
444
+ segments = _build_segments(total_dur_s, MMAUDIO_WINDOW, MMA_CF_S)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  print(f"[MMAudio] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤8 s")
446
 
447
  sr = seq_cfg.sampling_rate # 44100
 
492
  wav = wav[:, :seg_samples]
493
  seg_audios.append(wav)
494
 
495
+ # Crossfade-stitch all segments using shared equal-power helper
 
 
 
 
 
 
 
 
 
 
 
 
496
  full_wav = seg_audios[0]
497
  for nw in seg_audios[1:]:
498
+ full_wav = _cf_join_stereo(full_wav, nw, MMA_CF_S, MMA_CF_DB, sr)
499
  full_wav = full_wav[:, : int(round(total_dur_s * sr))]
500
 
501
  audio_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.flac")
 
574
  total_dur_s = get_video_duration(video_file)
575
  CF_S = float(crossfade_s)
576
  CF_DB = float(crossfade_db)
577
+ segments = _build_segments(total_dur_s, HUNYUAN_MAX_DUR, CF_S)
578
  print(f"[HunyuanFoley] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤15 s")
579
 
580
  # Pre-encode text features once (same for every segment)
 
628
  wav = wav[:, :seg_samples]
629
  seg_wavs.append(wav)
630
 
631
+ # Crossfade-stitch all segments using shared equal-power helper
 
 
 
 
 
 
 
 
 
 
 
 
632
  full_wav = seg_wavs[0]
633
  for nw in seg_wavs[1:]:
634
+ full_wav = _cf_join_stereo(full_wav, nw, CF_S, CF_DB, sr)
635
  # Trim to exact video duration
636
  full_wav = full_wav[:, : int(round(total_dur_s * sr))]
637