BoxOfColors Claude Opus 4.6 committed on
Commit
fe18eeb
·
1 Parent(s): dbba693

Unify all models to 48kHz; remove all FlashSR traces

Browse files

- Rename _apply_flashsr → _upsample_taro (TARO 16kHz→48kHz sinc)
- Add _resample_to_target helper (any SR → TARGET_SR=48000, sinc, CPU)
- Add TARGET_SR=48000 constant as single source of truth
- MMAudio (44100Hz): resample in generate_mmaudio and
regen_mmaudio_segment post-processing → all outputs now 48kHz
- HunyuanFoley already native 48kHz — no change needed
- Update MODEL_CONFIGS mmaudio sr: 44100 → 48000
- Console logs confirm upsample ratios and output durations
- Remove FLASHSR_SR_IN/FLASHSR_SR_OUT, zero FlashSR references remain

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +56 -29
app.py CHANGED
@@ -367,7 +367,7 @@ MODEL_CONFIGS = {
367
  },
368
  "mmaudio": {
369
  "window_s": MMAUDIO_WINDOW, # 8.0 s
370
- "sr": 44100,
371
  "secs_per_step": MMAUDIO_SECS_PER_STEP, # 0.25
372
  "load_overhead": MMAUDIO_LOAD_OVERHEAD, # 15
373
  "tab_prefix": "mma",
@@ -499,30 +499,46 @@ def _taro_infer_segment(
499
 
500
 
501
  # ================================================================== #
502
- # FlashSR (16 β†’ 48 kHz) #
503
  # ================================================================== #
504
- # FlashSR is used as a post-processing step on TARO outputs only.
505
- # TARO generates at 16 kHz; FlashSR upsamples to 48 kHz so all three
506
- # models produce output at the same sample rate.
507
- # Model weights are downloaded once from HF Hub and cached on disk.
508
 
509
- FLASHSR_SR_IN = 16000
510
- FLASHSR_SR_OUT = 48000
511
 
512
 
513
- def _apply_flashsr(wav_16k: np.ndarray) -> np.ndarray:
514
- """Upsample a mono 16 kHz numpy array to 48 kHz using sinc resampling (CPU).
515
 
516
- FlashSR was attempted but its dependencies trigger torch.cuda.is_available()
517
- on import, which violates ZeroGPU's stateless-GPU rule and aborts subsequent
518
- GPU tasks. High-quality sinc resampling via torchaudio is ZeroGPU-safe and
519
- produces clean 16β†’48 kHz output for foley/ambient audio.
520
  """
521
- print(f"[upsample] {len(wav_16k)/FLASHSR_SR_IN:.2f}s @ 16kHz β†’ 48kHz (sinc, CPU) …")
522
- t = torch.from_numpy(wav_16k.astype(np.float32)).unsqueeze(0)
523
- out = torchaudio.functional.resample(t, FLASHSR_SR_IN, FLASHSR_SR_OUT)
524
- result = out.squeeze().numpy()
525
- print(f"[upsample] Done β€” {len(result)/FLASHSR_SR_OUT:.2f}s @ {FLASHSR_SR_OUT}Hz")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
  return result
527
 
528
 
@@ -699,11 +715,11 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
699
  first_cavp_saved = False
700
  outputs = []
701
  for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
702
- # FlashSR: upsample each segment 16kHz β†’ 48kHz (CPU-only, no GPU needed)
703
- wavs = [_apply_flashsr(w) for w in wavs]
704
- final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, FLASHSR_SR_OUT)
705
  audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
706
- _save_wav(audio_path, final_wav, FLASHSR_SR_OUT)
707
  video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
708
  mux_video_audio(silent_video, audio_path, video_path)
709
  wav_paths = _save_seg_wavs(wavs, tmp_dir, f"taro_{sample_idx}")
@@ -715,7 +731,7 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
715
  first_cavp_saved = True
716
  seg_meta = _build_seg_meta(
717
  segments=segments, wav_paths=wav_paths, audio_path=audio_path,
718
- video_path=video_path, silent_video=silent_video, sr=FLASHSR_SR_OUT,
719
  model="taro", crossfade_s=crossfade_s, crossfade_db=crossfade_db,
720
  total_dur_s=total_dur_s, cavp_path=cavp_path, onset_path=onset_path,
721
  )
@@ -854,6 +870,12 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
854
  # ── CPU post-processing ──
855
  outputs = []
856
  for sample_idx, (seg_audios, sr) in enumerate(results):
 
 
 
 
 
 
857
  full_wav = _stitch_wavs(seg_audios, crossfade_s, crossfade_db, total_dur_s, sr)
858
 
859
  audio_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.wav")
@@ -1178,8 +1200,8 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
1178
  seed_val, cfg_scale, num_steps, mode,
1179
  crossfade_s, crossfade_db, slot_id)
1180
 
1181
- # FlashSR: upsample 16kHz β†’ 48kHz on CPU (no GPU needed)
1182
- new_wav = _apply_flashsr(new_wav)
1183
  # CPU: splice, stitch, mux, save
1184
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1185
  new_wav, seg_idx, meta, slot_id
@@ -1269,6 +1291,11 @@ def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
1269
  prompt, negative_prompt, seed_val,
1270
  cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id)
1271
 
 
 
 
 
 
1272
  meta["sr"] = sr
1273
 
1274
  # CPU: splice, stitch, mux, save
@@ -1450,9 +1477,9 @@ def xregen_taro(seg_idx, state_json, slot_id,
1450
  new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
1451
  seed_val, cfg_scale, num_steps, mode,
1452
  crossfade_s, crossfade_db, slot_id)
1453
- # FlashSR: upsample 16kHz β†’ 48kHz on CPU (no GPU needed)
1454
- new_wav_raw = _apply_flashsr(new_wav_raw)
1455
- video_path, waveform_html = _xregen_splice(new_wav_raw, FLASHSR_SR_OUT, meta, seg_idx, slot_id)
1456
  yield gr.update(value=video_path), gr.update(value=waveform_html)
1457
 
1458
 
 
367
  },
368
  "mmaudio": {
369
  "window_s": MMAUDIO_WINDOW, # 8.0 s
370
+ "sr": 48000, # resampled to 48kHz in post-processing
371
  "secs_per_step": MMAUDIO_SECS_PER_STEP, # 0.25
372
  "load_overhead": MMAUDIO_LOAD_OVERHEAD, # 15
373
  "tab_prefix": "mma",
 
499
 
500
 
501
  # ================================================================== #
502
+ # TARO 16 kHz β†’ 48 kHz upsample #
503
  # ================================================================== #
504
+ # TARO generates at 16 kHz; all other models output at 44.1/48 kHz.
505
+ # We upsample via sinc resampling (torchaudio, CPU-only) so the final
506
+ # stitched audio is uniformly at 48 kHz across all three models.
 
507
 
508
+ TARGET_SR = 48000 # unified output sample rate for all three models
509
+ TARO_SR_OUT = TARGET_SR
510
 
511
 
512
+ def _resample_to_target(wav: np.ndarray, src_sr: int) -> np.ndarray:
513
+ """Resample *wav* (mono or stereo numpy float32) from src_sr to TARGET_SR (48kHz).
514
 
515
+ No-op if src_sr already equals TARGET_SR. Uses torchaudio Kaiser-windowed
516
+ sinc resampling β€” CPU-only, ZeroGPU-safe.
 
 
517
  """
518
+ if src_sr == TARGET_SR:
519
+ return wav
520
+ stereo = wav.ndim == 2
521
+ t = torch.from_numpy(np.ascontiguousarray(wav.astype(np.float32)))
522
+ if not stereo:
523
+ t = t.unsqueeze(0) # [1, T]
524
+ t = torchaudio.functional.resample(t, src_sr, TARGET_SR)
525
+ if not stereo:
526
+ t = t.squeeze(0) # [T]
527
+ return t.numpy()
528
+
529
+
530
def _upsample_taro(wav_16k: np.ndarray) -> np.ndarray:
    """Upsample a mono 16 kHz numpy array to 48 kHz via sinc resampling (CPU).

    torchaudio.functional.resample uses a Kaiser-windowed sinc filter —
    mathematically optimal for bandlimited signals, zero CUDA risk.
    Returns a mono float32 numpy array at 48 kHz.
    """
    dur_in = len(wav_16k) / TARO_SR
    print(f"[TARO upsample] {dur_in:.2f}s @ {TARO_SR}Hz → {TARGET_SR}Hz (sinc, CPU) …")
    result = _resample_to_target(wav_16k, TARO_SR)
    # BUG FIX: resampling preserves DURATION — only the sample count scales by
    # TARGET_SR / TARO_SR (3x for 16→48 kHz). The previous log compared the
    # output duration against dur_in * 3, so it always reported a bogus
    # "expected" value three times the real one.
    print(f"[TARO upsample] done — {len(result)/TARGET_SR:.2f}s @ {TARGET_SR}Hz "
          f"(expected {dur_in:.2f}s, sample ratio {TARGET_SR / TARO_SR:.0f}x)")
    return result
543
 
544
 
 
715
  first_cavp_saved = False
716
  outputs = []
717
  for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
718
+ # Upsample each segment 16kHz β†’ 48kHz (sinc, CPU)
719
+ wavs = [_upsample_taro(w) for w in wavs]
720
+ final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR_OUT)
721
  audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
722
+ _save_wav(audio_path, final_wav, TARO_SR_OUT)
723
  video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
724
  mux_video_audio(silent_video, audio_path, video_path)
725
  wav_paths = _save_seg_wavs(wavs, tmp_dir, f"taro_{sample_idx}")
 
731
  first_cavp_saved = True
732
  seg_meta = _build_seg_meta(
733
  segments=segments, wav_paths=wav_paths, audio_path=audio_path,
734
+ video_path=video_path, silent_video=silent_video, sr=TARO_SR_OUT,
735
  model="taro", crossfade_s=crossfade_s, crossfade_db=crossfade_db,
736
  total_dur_s=total_dur_s, cavp_path=cavp_path, onset_path=onset_path,
737
  )
 
870
  # ── CPU post-processing ──
871
  outputs = []
872
  for sample_idx, (seg_audios, sr) in enumerate(results):
873
+ # Resample 44100 β†’ 48000 Hz so all three models share the same output SR
874
+ if sr != TARGET_SR:
875
+ print(f"[MMAudio upsample] resampling {sr}Hz β†’ {TARGET_SR}Hz (sinc, CPU) …")
876
+ seg_audios = [_resample_to_target(w, sr) for w in seg_audios]
877
+ print(f"[MMAudio upsample] done β€” {len(seg_audios)} seg(s) @ {TARGET_SR}Hz")
878
+ sr = TARGET_SR
879
  full_wav = _stitch_wavs(seg_audios, crossfade_s, crossfade_db, total_dur_s, sr)
880
 
881
  audio_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.wav")
 
1200
  seed_val, cfg_scale, num_steps, mode,
1201
  crossfade_s, crossfade_db, slot_id)
1202
 
1203
+ # Upsample 16kHz β†’ 48kHz (sinc, CPU)
1204
+ new_wav = _upsample_taro(new_wav)
1205
  # CPU: splice, stitch, mux, save
1206
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1207
  new_wav, seg_idx, meta, slot_id
 
1291
  prompt, negative_prompt, seed_val,
1292
  cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id)
1293
 
1294
+ # Resample to 48kHz if needed (MMAudio outputs at 44100 Hz)
1295
+ if sr != TARGET_SR:
1296
+ print(f"[MMAudio regen upsample] {sr}Hz β†’ {TARGET_SR}Hz (sinc, CPU) …")
1297
+ new_wav = _resample_to_target(new_wav, sr)
1298
+ sr = TARGET_SR
1299
  meta["sr"] = sr
1300
 
1301
  # CPU: splice, stitch, mux, save
 
1477
  new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
1478
  seed_val, cfg_scale, num_steps, mode,
1479
  crossfade_s, crossfade_db, slot_id)
1480
+ # Upsample 16kHz β†’ 48kHz (sinc, CPU)
1481
+ new_wav_raw = _upsample_taro(new_wav_raw)
1482
+ video_path, waveform_html = _xregen_splice(new_wav_raw, TARO_SR_OUT, meta, seg_idx, slot_id)
1483
  yield gr.update(value=video_path), gr.update(value=waveform_html)
1484
 
1485