Spaces:
Running on Zero
Running on Zero
Commit ·
efe424b
1
Parent(s): 4d46101
feat: add FlashSR post-processing to upsample TARO 16kHz → 48kHz
Browse files
All three models now output at 48kHz (TARO via FlashSR, MMAudio
resampled from its native 44.1kHz, HunyuanFoley at 48kHz natively).
FlashSR is applied after generation and after each regen/xregen on
TARO outputs. Console logs confirm each upsampling step with duration
and sample rate. Falls back to sinc resampling if FlashSR errors.
- app.py +90 -5
- requirements.txt +1 -0
app.py
CHANGED
|
@@ -498,6 +498,73 @@ def _taro_infer_segment(
|
|
| 498 |
return wav[:seg_samples]
|
| 499 |
|
| 500 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
def _stitch_wavs(wavs: list[np.ndarray], crossfade_s: float, db_boost: float,
|
| 502 |
total_dur_s: float, sr: int) -> np.ndarray:
|
| 503 |
"""Crossfade-join a list of wav arrays and trim to *total_dur_s*.
|
|
@@ -672,8 +739,15 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
|
|
| 672 |
outputs = []
|
| 673 |
for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
|
| 674 |
final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 675 |
audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
|
| 676 |
-
_save_wav(audio_path, final_wav,
|
| 677 |
video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
|
| 678 |
mux_video_audio(silent_video, audio_path, video_path)
|
| 679 |
wav_paths = _save_seg_wavs(wavs, tmp_dir, f"taro_{sample_idx}")
|
|
@@ -685,7 +759,7 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
|
|
| 685 |
first_cavp_saved = True
|
| 686 |
seg_meta = _build_seg_meta(
|
| 687 |
segments=segments, wav_paths=wav_paths, audio_path=audio_path,
|
| 688 |
-
video_path=video_path, silent_video=silent_video, sr=
|
| 689 |
model="taro", crossfade_s=crossfade_s, crossfade_db=crossfade_db,
|
| 690 |
total_dur_s=total_dur_s, cavp_path=cavp_path, onset_path=onset_path,
|
| 691 |
)
|
|
@@ -1135,9 +1209,16 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
|
|
| 1135 |
seed_val, cfg_scale, num_steps, mode,
|
| 1136 |
crossfade_s, crossfade_db, slot_id)
|
| 1137 |
|
| 1138 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1139 |
video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
|
| 1140 |
-
new_wav, seg_idx,
|
| 1141 |
)
|
| 1142 |
return video_path, audio_path, json.dumps(updated_meta), waveform_html
|
| 1143 |
|
|
@@ -1405,7 +1486,11 @@ def xregen_taro(seg_idx, state_json, slot_id,
|
|
| 1405 |
new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
|
| 1406 |
seed_val, cfg_scale, num_steps, mode,
|
| 1407 |
crossfade_s, crossfade_db, slot_id)
|
| 1408 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1409 |
yield gr.update(value=video_path), gr.update(value=waveform_html)
|
| 1410 |
|
| 1411 |
|
|
|
|
| 498 |
return wav[:seg_samples]
|
| 499 |
|
| 500 |
|
| 501 |
+
# ================================================================== #
|
| 502 |
+
# FlashSR (16 → 48 kHz) #
|
| 503 |
+
# ================================================================== #
|
| 504 |
+
# FlashSR is used as a post-processing step on TARO outputs only.
|
| 505 |
+
# TARO generates at 16 kHz; FlashSR upsamples to 48 kHz so all three
|
| 506 |
+
# models produce output at the same sample rate.
|
| 507 |
+
# Model weights are downloaded once from HF Hub and cached on disk.
|
| 508 |
+
|
| 509 |
+
_FLASHSR_MODEL = None # module-level cache — loaded once per process
|
| 510 |
+
_FLASHSR_LOCK = threading.Lock()
|
| 511 |
+
|
| 512 |
+
FLASHSR_SR_IN = 16000
|
| 513 |
+
FLASHSR_SR_OUT = 48000
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def _load_flashsr():
    """Return the process-wide FlashSR upsampler, building it on first use.

    The instance is kept in the module-level ``_FLASHSR_MODEL`` slot and the
    whole check-and-create sequence runs under ``_FLASHSR_LOCK``. On the first
    call the ``upsampler.pth`` checkpoint is fetched with ``hf_hub_download``
    into a ``.flashsr_cache`` directory next to this file, wrapped in ``FASR``,
    and, when CUDA is available, its inner ``model`` is converted to fp16 and
    moved to the GPU.

    Returns:
        The cached ``FASR`` instance.
    """
    global _FLASHSR_MODEL
    with _FLASHSR_LOCK:
        if _FLASHSR_MODEL is None:
            print("[FlashSR] Loading model weights from HF Hub …")
            # Imported lazily so the dependencies are only needed once
            # upsampling is actually requested.
            from huggingface_hub import hf_hub_download
            from FastAudioSR import FASR

            app_dir = os.path.dirname(os.path.abspath(__file__))
            weights_file = hf_hub_download(
                repo_id="YatharthS/FlashSR",
                filename="upsampler.pth",
                local_dir=os.path.join(app_dir, ".flashsr_cache"),
            )
            upsampler = FASR(weights_file)

            on_gpu = torch.cuda.is_available()
            if on_gpu:
                upsampler.model.half().cuda()
            print(
                "[FlashSR] Model loaded on GPU (fp16)"
                if on_gpu
                else "[FlashSR] Model loaded on CPU (fp32)"
            )

            # Publish only after the model is fully set up, so a failed load
            # leaves the slot empty and the next call retries.
            _FLASHSR_MODEL = upsampler
        return _FLASHSR_MODEL
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
def _apply_flashsr(wav_16k: np.ndarray) -> np.ndarray:
    """Upsample a mono 16 kHz numpy array to 48 kHz using FlashSR.

    Args:
        wav_16k: Mono waveform sampled at ``FLASHSR_SR_IN``.

    Returns:
        A 1-D float32 numpy array at ``FLASHSR_SR_OUT``. The result is always
        flattened to one dimension, so callers can rely on ``len()`` even when
        the result holds a single sample. An empty input yields an empty array
        without loading the model.

    Falls back to torchaudio sinc resampling if anything in the FlashSR path
    raises (model download/load, inference, or output conversion).
    """
    if wav_16k.size == 0:
        # Nothing to upsample: skip the model load and the resampler entirely.
        return np.zeros(0, dtype=np.float32)

    try:
        model = _load_flashsr()
        t = torch.from_numpy(wav_16k.astype(np.float32)).unsqueeze(0)
        if torch.cuda.is_available():
            # _load_flashsr() converts the model to fp16 on the GPU under the
            # same condition, so the input has to match.
            t = t.half().cuda()
        print(f"[FlashSR] Upsampling {len(wav_16k)/FLASHSR_SR_IN:.2f}s @ 16kHz → 48kHz …")
        with torch.no_grad():
            out = model.run(t)
        # model.run may hand back a tensor or an array-like; normalise to a
        # 1-D float32 numpy array on the CPU. reshape(-1) is used instead of
        # squeeze() so a one-sample result stays 1-D rather than becoming 0-d.
        if isinstance(out, torch.Tensor):
            out = out.float().cpu().reshape(-1).numpy()
        else:
            out = np.array(out, dtype=np.float32).reshape(-1)
        print(f"[FlashSR] Done — output shape {out.shape}, sr={FLASHSR_SR_OUT}")
        return out
    except Exception as e:
        # Deliberate best-effort: any FlashSR failure degrades to plain
        # resampling instead of failing the whole generation.
        print(f"[FlashSR] ERROR: {e} — falling back to sinc resampling")
        t = torch.from_numpy(wav_16k.astype(np.float32)).unsqueeze(0)
        out = torchaudio.functional.resample(t, FLASHSR_SR_IN, FLASHSR_SR_OUT)
        return out.reshape(-1).numpy()
|
| 566 |
+
|
| 567 |
+
|
| 568 |
def _stitch_wavs(wavs: list[np.ndarray], crossfade_s: float, db_boost: float,
|
| 569 |
total_dur_s: float, sr: int) -> np.ndarray:
|
| 570 |
"""Crossfade-join a list of wav arrays and trim to *total_dur_s*.
|
|
|
|
| 739 |
outputs = []
|
| 740 |
for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
|
| 741 |
final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR)
|
| 742 |
+
|
| 743 |
+
# ── FlashSR: upsample 16 kHz → 48 kHz ──
|
| 744 |
+
print(f"[TARO] Sample {sample_idx+1}: running FlashSR upsampler (16kHz → 48kHz) …")
|
| 745 |
+
final_wav = _apply_flashsr(final_wav)
|
| 746 |
+
out_sr = FLASHSR_SR_OUT
|
| 747 |
+
print(f"[TARO] Sample {sample_idx+1}: FlashSR complete — {len(final_wav)/out_sr:.2f}s @ {out_sr}Hz")
|
| 748 |
+
|
| 749 |
audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
|
| 750 |
+
_save_wav(audio_path, final_wav, out_sr)
|
| 751 |
video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
|
| 752 |
mux_video_audio(silent_video, audio_path, video_path)
|
| 753 |
wav_paths = _save_seg_wavs(wavs, tmp_dir, f"taro_{sample_idx}")
|
|
|
|
| 759 |
first_cavp_saved = True
|
| 760 |
seg_meta = _build_seg_meta(
|
| 761 |
segments=segments, wav_paths=wav_paths, audio_path=audio_path,
|
| 762 |
+
video_path=video_path, silent_video=silent_video, sr=out_sr,
|
| 763 |
model="taro", crossfade_s=crossfade_s, crossfade_db=crossfade_db,
|
| 764 |
total_dur_s=total_dur_s, cavp_path=cavp_path, onset_path=onset_path,
|
| 765 |
)
|
|
|
|
| 1209 |
seed_val, cfg_scale, num_steps, mode,
|
| 1210 |
crossfade_s, crossfade_db, slot_id)
|
| 1211 |
|
| 1212 |
+
# FlashSR: upsample 16 kHz → 48 kHz before splicing
|
| 1213 |
+
print(f"[TARO regen] Running FlashSR upsampler (16kHz → 48kHz) on seg {seg_idx} …")
|
| 1214 |
+
new_wav = _apply_flashsr(new_wav)
|
| 1215 |
+
print(f"[TARO regen] FlashSR complete — {len(new_wav)/FLASHSR_SR_OUT:.2f}s @ {FLASHSR_SR_OUT}Hz")
|
| 1216 |
+
|
| 1217 |
+
# CPU: splice, stitch, mux, save — meta["sr"] must reflect the upsampled rate
|
| 1218 |
+
meta_48k = dict(meta)
|
| 1219 |
+
meta_48k["sr"] = FLASHSR_SR_OUT
|
| 1220 |
video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
|
| 1221 |
+
new_wav, seg_idx, meta_48k, slot_id
|
| 1222 |
)
|
| 1223 |
return video_path, audio_path, json.dumps(updated_meta), waveform_html
|
| 1224 |
|
|
|
|
| 1486 |
new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
|
| 1487 |
seed_val, cfg_scale, num_steps, mode,
|
| 1488 |
crossfade_s, crossfade_db, slot_id)
|
| 1489 |
+
# FlashSR: upsample 16 kHz → 48 kHz before splicing into slot
|
| 1490 |
+
print(f"[xregen TARO] Running FlashSR upsampler (16kHz → 48kHz) on seg {seg_idx} …")
|
| 1491 |
+
new_wav_raw = _apply_flashsr(new_wav_raw)
|
| 1492 |
+
print(f"[xregen TARO] FlashSR complete — {len(new_wav_raw)/FLASHSR_SR_OUT:.2f}s @ {FLASHSR_SR_OUT}Hz")
|
| 1493 |
+
video_path, waveform_html = _xregen_splice(new_wav_raw, FLASHSR_SR_OUT, meta, seg_idx, slot_id)
|
| 1494 |
yield gr.update(value=video_path), gr.update(value=waveform_html)
|
| 1495 |
|
| 1496 |
|
requirements.txt
CHANGED
|
@@ -21,6 +21,7 @@ loguru
|
|
| 21 |
torchdiffeq
|
| 22 |
open_clip_torch
|
| 23 |
git+https://github.com/descriptinc/audiotools
|
|
|
|
| 24 |
--extra-index-url https://download.pytorch.org/whl/cu124
|
| 25 |
torchaudio==2.5.1+cu124
|
| 26 |
--find-links https://download.openmmlab.com/mmcv/dist/cu121/torch2.4.0/index.html
|
|
|
|
| 21 |
torchdiffeq
|
| 22 |
open_clip_torch
|
| 23 |
git+https://github.com/descriptinc/audiotools
|
| 24 |
+
git+https://github.com/ysharma3501/FlashSR.git
|
| 25 |
--extra-index-url https://download.pytorch.org/whl/cu124
|
| 26 |
torchaudio==2.5.1+cu124
|
| 27 |
--find-links https://download.openmmlab.com/mmcv/dist/cu121/torch2.4.0/index.html
|