Spaces:
Running on Zero
Move FlashSR inside GPU window, fix xregen routing, refactor helpers
Browse files
- FlashSR upsampling (16kHz→48kHz) now runs inside _taro_gpu_infer
under @spaces.GPU to comply with ZeroGPU CUDA-init rules
- Remove stale _apply_flashsr calls from xregen_taro and regen_taro_segment
(FlashSR already applied per-segment inside _taro_infer_segment)
- Add api_name to queue/join fetch body to fix xregen "Too many arguments"
routing issue in Gradio 5
- Extract _build_seg_meta, _cpu_preprocess, _save_wav, _log_inference_timing
helpers; replace inline stitch with _stitch_wavs
- Cap crossfade slider max at 4s + safety clamp in _build_segments
- Add diagonal hatch waveform indicator for crossfade overlap zones
- Remove HF Token settings accordion (ZeroGPU attribution via JWT)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
|
@@ -686,6 +686,12 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
|
|
| 686 |
latents_scale,
|
| 687 |
euler_sampler, euler_maruyama_sampler,
|
| 688 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 689 |
wavs.append(wav)
|
| 690 |
_log_inference_timing("TARO", time.perf_counter() - _t_infer_start,
|
| 691 |
len(segments), int(num_steps), TARO_SECS_PER_STEP)
|
|
@@ -735,16 +741,10 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
|
|
| 735 |
first_cavp_saved = False
|
| 736 |
outputs = []
|
| 737 |
for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
# ── FlashSR: upsample 16 kHz → 48 kHz ──
|
| 741 |
-
print(f"[TARO] Sample {sample_idx+1}: running FlashSR upsampler (16kHz β 48kHz) β¦")
|
| 742 |
-
final_wav = _apply_flashsr(final_wav)
|
| 743 |
-
out_sr = FLASHSR_SR_OUT
|
| 744 |
-
print(f"[TARO] Sample {sample_idx+1}: FlashSR complete β {len(final_wav)/out_sr:.2f}s @ {out_sr}Hz")
|
| 745 |
-
|
| 746 |
audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
|
| 747 |
-
_save_wav(audio_path, final_wav,
|
| 748 |
video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
|
| 749 |
mux_video_audio(silent_video, audio_path, video_path)
|
| 750 |
wav_paths = _save_seg_wavs(wavs, tmp_dir, f"taro_{sample_idx}")
|
|
@@ -756,7 +756,7 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
|
|
| 756 |
first_cavp_saved = True
|
| 757 |
seg_meta = _build_seg_meta(
|
| 758 |
segments=segments, wav_paths=wav_paths, audio_path=audio_path,
|
| 759 |
-
video_path=video_path, silent_video=silent_video, sr=
|
| 760 |
model="taro", crossfade_s=crossfade_s, crossfade_db=crossfade_db,
|
| 761 |
total_dur_s=total_dur_s, cavp_path=cavp_path, onset_path=onset_path,
|
| 762 |
)
|
|
@@ -1206,16 +1206,10 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
|
|
| 1206 |
seed_val, cfg_scale, num_steps, mode,
|
| 1207 |
crossfade_s, crossfade_db, slot_id)
|
| 1208 |
|
| 1209 |
-
#
|
| 1210 |
-
|
| 1211 |
-
new_wav = _apply_flashsr(new_wav)
|
| 1212 |
-
print(f"[TARO regen] FlashSR complete β {len(new_wav)/FLASHSR_SR_OUT:.2f}s @ {FLASHSR_SR_OUT}Hz")
|
| 1213 |
-
|
| 1214 |
-
# CPU: splice, stitch, mux, save — meta["sr"] must reflect the upsampled rate
|
| 1215 |
-
meta_48k = dict(meta)
|
| 1216 |
-
meta_48k["sr"] = FLASHSR_SR_OUT
|
| 1217 |
video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
|
| 1218 |
-
new_wav, seg_idx,
|
| 1219 |
)
|
| 1220 |
return video_path, audio_path, json.dumps(updated_meta), waveform_html
|
| 1221 |
|
|
@@ -1483,10 +1477,7 @@ def xregen_taro(seg_idx, state_json, slot_id,
|
|
| 1483 |
new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
|
| 1484 |
seed_val, cfg_scale, num_steps, mode,
|
| 1485 |
crossfade_s, crossfade_db, slot_id)
|
| 1486 |
-
#
|
| 1487 |
-
print(f"[xregen TARO] Running FlashSR upsampler (16kHz β 48kHz) on seg {seg_idx} β¦")
|
| 1488 |
-
new_wav_raw = _apply_flashsr(new_wav_raw)
|
| 1489 |
-
print(f"[xregen TARO] FlashSR complete β {len(new_wav_raw)/FLASHSR_SR_OUT:.2f}s @ {FLASHSR_SR_OUT}Hz")
|
| 1490 |
video_path, waveform_html = _xregen_splice(new_wav_raw, FLASHSR_SR_OUT, meta, seg_idx, slot_id)
|
| 1491 |
yield gr.update(value=video_path), gr.update(value=waveform_html)
|
| 1492 |
|
|
@@ -2278,6 +2269,7 @@ _GLOBAL_JS = """
|
|
| 2278 |
body: JSON.stringify({
|
| 2279 |
data: data,
|
| 2280 |
fn_index: fnIndex,
|
|
|
|
| 2281 |
session_hash: window.__gradio_session_hash__,
|
| 2282 |
event_data: null,
|
| 2283 |
trigger_id: null
|
|
|
|
| 686 |
latents_scale,
|
| 687 |
euler_sampler, euler_maruyama_sampler,
|
| 688 |
)
|
| 689 |
+
# FlashSR: upsample 16kHz → 48kHz inside GPU window to avoid
|
| 690 |
+
# ZeroGPU CUDA-init-in-main-process violation
|
| 691 |
+
print(f"[FlashSR] Upsampling seg {len(wavs)+1} "
|
| 692 |
+
f"{seg_end_s-seg_start_s:.2f}s @ 16kHz β 48kHz β¦")
|
| 693 |
+
wav = _apply_flashsr(wav)
|
| 694 |
+
print(f"[FlashSR] Done β {len(wav)/FLASHSR_SR_OUT:.2f}s @ {FLASHSR_SR_OUT}Hz")
|
| 695 |
wavs.append(wav)
|
| 696 |
_log_inference_timing("TARO", time.perf_counter() - _t_infer_start,
|
| 697 |
len(segments), int(num_steps), TARO_SECS_PER_STEP)
|
|
|
|
| 741 |
first_cavp_saved = False
|
| 742 |
outputs = []
|
| 743 |
for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
|
| 744 |
+
# wavs are already at 48kHz — FlashSR ran inside _taro_gpu_infer
|
| 745 |
+
final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, FLASHSR_SR_OUT)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 746 |
audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
|
| 747 |
+
_save_wav(audio_path, final_wav, FLASHSR_SR_OUT)
|
| 748 |
video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
|
| 749 |
mux_video_audio(silent_video, audio_path, video_path)
|
| 750 |
wav_paths = _save_seg_wavs(wavs, tmp_dir, f"taro_{sample_idx}")
|
|
|
|
| 756 |
first_cavp_saved = True
|
| 757 |
seg_meta = _build_seg_meta(
|
| 758 |
segments=segments, wav_paths=wav_paths, audio_path=audio_path,
|
| 759 |
+
video_path=video_path, silent_video=silent_video, sr=FLASHSR_SR_OUT,
|
| 760 |
model="taro", crossfade_s=crossfade_s, crossfade_db=crossfade_db,
|
| 761 |
total_dur_s=total_dur_s, cavp_path=cavp_path, onset_path=onset_path,
|
| 762 |
)
|
|
|
|
| 1206 |
seed_val, cfg_scale, num_steps, mode,
|
| 1207 |
crossfade_s, crossfade_db, slot_id)
|
| 1208 |
|
| 1209 |
+
# new_wav is already at 48kHz — FlashSR ran inside _regen_taro_gpu → _taro_infer_segment
|
| 1210 |
+
# CPU: splice, stitch, mux, save
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1211 |
video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
|
| 1212 |
+
new_wav, seg_idx, meta, slot_id
|
| 1213 |
)
|
| 1214 |
return video_path, audio_path, json.dumps(updated_meta), waveform_html
|
| 1215 |
|
|
|
|
| 1477 |
new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
|
| 1478 |
seed_val, cfg_scale, num_steps, mode,
|
| 1479 |
crossfade_s, crossfade_db, slot_id)
|
| 1480 |
+
# new_wav_raw already at 48kHz — FlashSR ran inside _regen_taro_gpu → _taro_infer_segment
|
|
|
|
|
|
|
|
|
|
| 1481 |
video_path, waveform_html = _xregen_splice(new_wav_raw, FLASHSR_SR_OUT, meta, seg_idx, slot_id)
|
| 1482 |
yield gr.update(value=video_path), gr.update(value=waveform_html)
|
| 1483 |
|
|
|
|
| 2269 |
body: JSON.stringify({
|
| 2270 |
data: data,
|
| 2271 |
fn_index: fnIndex,
|
| 2272 |
+
api_name: '/' + apiName,
|
| 2273 |
session_hash: window.__gradio_session_hash__,
|
| 2274 |
event_data: null,
|
| 2275 |
trigger_id: null
|