BoxOfColors committed on
Commit
9f07d3f
·
1 Parent(s): d5399ac

Fix segment bleed: equal-spaced windows + contact-edge trimming

Browse files

_build_segments now places n equally-spaced full-window segments so every
overlap is identical and >= crossfade_s, eliminating the unequal last-
boundary overlap that caused raw bleed.

_stitch_wavs now accepts segments list and trims each generated wav to its
contact-edge window (midpoint of overlap +/- half_cf) before crossfade-join,
so the crossfade zone is always exactly crossfade_s wide at every boundary.

Files changed (1) hide show
  1. app.py +77 -19
app.py CHANGED
@@ -411,22 +411,34 @@ def mux_video_audio(silent_video: str, audio_path: str, output_path: str,
411
  # ------------------------------------------------------------------ #
412
 
413
  def _build_segments(total_dur_s: float, window_s: float, crossfade_s: float) -> list[tuple[float, float]]:
414
- """Return list of (start, end) pairs covering *total_dur_s* with a sliding
415
- window of *window_s* and *crossfade_s* overlap between consecutive segments."""
416
- # Safety: clamp crossfade to < half the window so step_s stays positive
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417
  crossfade_s = min(crossfade_s, window_s * 0.5)
418
  if total_dur_s <= window_s:
419
  return [(0.0, total_dur_s)]
420
- step_s = window_s - crossfade_s
421
- segments, seg_start = [], 0.0
422
- while True:
423
- if seg_start + window_s >= total_dur_s:
424
- seg_start = max(0.0, total_dur_s - window_s)
425
- segments.append((seg_start, total_dur_s))
426
- break
427
- segments.append((seg_start, seg_start + window_s))
428
- seg_start += step_s
429
- return segments
430
 
431
 
432
  def _cf_join(a: np.ndarray, b: np.ndarray,
@@ -687,12 +699,58 @@ def _upsample_taro(wav_16k: np.ndarray) -> np.ndarray:
687
 
688
 
689
  def _stitch_wavs(wavs: list[np.ndarray], crossfade_s: float, db_boost: float,
690
- total_dur_s: float, sr: int) -> np.ndarray:
 
691
  """Crossfade-join a list of wav arrays and trim to *total_dur_s*.
692
- Works for both mono (T,) and stereo (C, T) arrays."""
693
- out = wavs[0]
694
- for nw in wavs[1:]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
695
  out = _cf_join(out, nw, crossfade_s, db_boost, sr)
 
696
  n = int(round(total_dur_s * sr))
697
  return out[:, :n] if out.ndim == 2 else out[:n]
698
 
@@ -757,7 +815,7 @@ def _post_process_samples(results: list, *, model: str, tmp_dir: str,
757
  for sample_idx, result in enumerate(results):
758
  seg_wavs = result[0]
759
 
760
- full_wav = _stitch_wavs(seg_wavs, crossfade_s, crossfade_db, total_dur_s, sr)
761
  audio_path = os.path.join(tmp_dir, f"{model}_{sample_idx}.wav")
762
  _save_wav(audio_path, full_wav, sr)
763
  video_path = os.path.join(tmp_dir, f"{model}_{sample_idx}.mp4")
@@ -1242,7 +1300,7 @@ def _splice_and_save(new_wav, seg_idx, meta, slot_id):
1242
  segments = meta["segments"]
1243
  model = meta["model"]
1244
 
1245
- full_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, sr)
1246
 
1247
  # Save new audio — use a new timestamped filename so Gradio / the browser
1248
  # treats it as a genuinely different file and reloads the video player.
 
411
  # ------------------------------------------------------------------ #
412
 
413
  def _build_segments(total_dur_s: float, window_s: float, crossfade_s: float) -> list[tuple[float, float]]:
414
+ """Return list of (start, end) pairs covering *total_dur_s*.
415
+
416
+ Every segment uses the full *window_s* inference window. Segments are
417
+ equally spaced so every overlap is identical, guaranteeing the crossfade
418
+ setting is honoured at every boundary with no raw bleed.
419
+
420
+ Algorithm
421
+ ---------
422
+ 1. Clamp crossfade_s so the step stays positive.
423
+ 2. Find the minimum n such that n segments of *window_s* cover
424
+ *total_dur_s* with overlap ≥ crossfade_s at every boundary:
425
+ n = ceil((total_dur_s - crossfade_s) / (window_s - crossfade_s))
426
+ 3. Compute equal spacing: step = (total_dur_s - window_s) / (n - 1)
427
+ so that every gap is identical and the last segment ends exactly at
428
+ total_dur_s.
429
+ 4. Every segment is exactly *window_s* wide. The trailing audio of each
430
+ segment beyond its contact edge is discarded in _stitch_wavs.
431
+ """
432
  crossfade_s = min(crossfade_s, window_s * 0.5)
433
  if total_dur_s <= window_s:
434
  return [(0.0, total_dur_s)]
435
+ import math
436
+ step_min = window_s - crossfade_s # minimum step to honour crossfade
437
+ n = math.ceil((total_dur_s - crossfade_s) / step_min)
438
+ n = max(n, 2)
439
+ # Equal step so first seg starts at 0 and last seg ends at total_dur_s
440
+ step_s = (total_dur_s - window_s) / (n - 1)
441
+ return [(i * step_s, i * step_s + window_s) for i in range(n)]
 
 
 
442
 
443
 
444
  def _cf_join(a: np.ndarray, b: np.ndarray,
 
699
 
700
 
701
def _stitch_wavs(wavs: list[np.ndarray], crossfade_s: float, db_boost: float,
                 total_dur_s: float, sr: int,
                 segments: list[tuple[float, float]] | None = None) -> np.ndarray:
    """Crossfade-join a list of wav arrays and trim to *total_dur_s*.
    Works for both mono (T,) and stereo (C, T) arrays.

    When *segments* is provided (list of (start, end) video-time pairs),
    each wav is trimmed to its contact-edge window before joining:

        contact_edge[i->i+1] = midpoint of overlap = (seg[i].end + seg[i+1].start) / 2
        half_cf = crossfade_s / 2

        seg i keep: [contact_edge[i-1->i] - half_cf, contact_edge[i->i+1] + half_cf]
        expressed as sample offsets into the generated audio for that segment.

    This guarantees every crossfade zone is exactly crossfade_s wide with no
    raw bleed regardless of how much the inference windows overlap.
    """
    def _trim(wav, start_s, end_s, seg_start_s):
        """Trim wav to [start_s, end_s] expressed in absolute video time,
        where the wav starts at seg_start_s in video time."""
        n_samples = wav.shape[1] if wav.ndim == 2 else len(wav)
        s = max(0, int(round((start_s - seg_start_s) * sr)))
        e = min(int(round((end_s - seg_start_s) * sr)), n_samples)
        e = max(e, s)  # guard: never request a negative-length slice
        return wav[:, s:e] if wav.ndim == 2 else wav[s:e]

    if segments is None or len(segments) == 1:
        # Legacy path: chain plain crossfade joins with no pre-trimming.
        out = wavs[0]
        for nw in wavs[1:]:
            out = _cf_join(out, nw, crossfade_s, db_boost, sr)
        n = int(round(total_dur_s * sr))
        return out[:, :n] if out.ndim == 2 else out[:n]

    half_cf = crossfade_s / 2.0

    # Contact edge between consecutive segments = midpoint of their overlap.
    contact_edges = [
        (segments[i][1] + segments[i + 1][0]) / 2.0
        for i in range(len(segments) - 1)
    ]

    # Trim each segment's audio to its keep window (edge +/- half crossfade).
    trimmed = []
    for i, (wav, (seg_start, seg_end)) in enumerate(zip(wavs, segments)):
        keep_start = (contact_edges[i - 1] - half_cf) if i > 0 else seg_start
        keep_end = (contact_edges[i] + half_cf) if i < len(segments) - 1 else total_dur_s
        trimmed.append(_trim(wav, keep_start, keep_end, seg_start))

    # Crossfade-join the trimmed segments; each junction is exactly
    # crossfade_s wide by construction.
    out = trimmed[0]
    for nw in trimmed[1:]:
        out = _cf_join(out, nw, crossfade_s, db_boost, sr)

    n = int(round(total_dur_s * sr))
    return out[:, :n] if out.ndim == 2 else out[:n]
756
 
 
815
  for sample_idx, result in enumerate(results):
816
  seg_wavs = result[0]
817
 
818
+ full_wav = _stitch_wavs(seg_wavs, crossfade_s, crossfade_db, total_dur_s, sr, segments)
819
  audio_path = os.path.join(tmp_dir, f"{model}_{sample_idx}.wav")
820
  _save_wav(audio_path, full_wav, sr)
821
  video_path = os.path.join(tmp_dir, f"{model}_{sample_idx}.mp4")
 
1300
  segments = meta["segments"]
1301
  model = meta["model"]
1302
 
1303
+ full_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, sr, segments)
1304
 
1305
  # Save new audio — use a new timestamped filename so Gradio / the browser
1306
  # treats it as a genuinely different file and reloads the video player.