Spaces:
Running on Zero
Running on Zero
Commit ·
ef4f0ff
1
Parent(s): 0bc4a35
Optimize ZeroGPU: move FlashSR to CPU, fix regen 16kHz bug
Browse files
- Move FlashSR upsampling (16kHz → 48kHz) from inside @spaces.GPU
to CPU wrappers — saves ~1-2s GPU quota per segment since
FlashSR is CPU-only and does not need the GPU allocation
- Fix bug: regen_taro_segment and xregen_taro were returning
raw 16kHz wav without applying FlashSR (initial gen applied
it inside the GPU loop but regen never did)
- Reduce TARO regen duration estimate when CAVP/onset feature
cache exists — requests 5s overhead instead of 15s
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
app.py
CHANGED
|
@@ -686,12 +686,6 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
|
|
| 686 |
latents_scale,
|
| 687 |
euler_sampler, euler_maruyama_sampler,
|
| 688 |
)
|
| 689 |
-
# FlashSR: upsample 16kHz → 48kHz inside GPU window to avoid
|
| 690 |
-
# ZeroGPU CUDA-init-in-main-process violation
|
| 691 |
-
print(f"[FlashSR] Upsampling seg {len(wavs)+1} "
|
| 692 |
-
f"{seg_end_s-seg_start_s:.2f}s @ 16kHz β 48kHz β¦")
|
| 693 |
-
wav = _apply_flashsr(wav)
|
| 694 |
-
print(f"[FlashSR] Done β {len(wav)/FLASHSR_SR_OUT:.2f}s @ {FLASHSR_SR_OUT}Hz")
|
| 695 |
wavs.append(wav)
|
| 696 |
_log_inference_timing("TARO", time.perf_counter() - _t_infer_start,
|
| 697 |
len(segments), int(num_steps), TARO_SECS_PER_STEP)
|
|
@@ -741,7 +735,8 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
|
|
| 741 |
first_cavp_saved = False
|
| 742 |
outputs = []
|
| 743 |
for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
|
| 744 |
-
#
|
|
|
|
| 745 |
final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, FLASHSR_SR_OUT)
|
| 746 |
audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
|
| 747 |
_save_wav(audio_path, final_wav, FLASHSR_SR_OUT)
|
|
@@ -1145,6 +1140,19 @@ def _splice_and_save(new_wav, seg_idx, meta, slot_id):
|
|
| 1145 |
def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
|
| 1146 |
seed_val, cfg_scale, num_steps, mode,
|
| 1147 |
crossfade_s, crossfade_db, slot_id=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1148 |
return _estimate_regen_duration("taro", int(num_steps))
|
| 1149 |
|
| 1150 |
|
|
@@ -1206,7 +1214,8 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
|
|
| 1206 |
seed_val, cfg_scale, num_steps, mode,
|
| 1207 |
crossfade_s, crossfade_db, slot_id)
|
| 1208 |
|
| 1209 |
-
#
|
|
|
|
| 1210 |
# CPU: splice, stitch, mux, save
|
| 1211 |
video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
|
| 1212 |
new_wav, seg_idx, meta, slot_id
|
|
@@ -1477,7 +1486,8 @@ def xregen_taro(seg_idx, state_json, slot_id,
|
|
| 1477 |
new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
|
| 1478 |
seed_val, cfg_scale, num_steps, mode,
|
| 1479 |
crossfade_s, crossfade_db, slot_id)
|
| 1480 |
-
#
|
|
|
|
| 1481 |
video_path, waveform_html = _xregen_splice(new_wav_raw, FLASHSR_SR_OUT, meta, seg_idx, slot_id)
|
| 1482 |
yield gr.update(value=video_path), gr.update(value=waveform_html)
|
| 1483 |
|
|
|
|
| 686 |
latents_scale,
|
| 687 |
euler_sampler, euler_maruyama_sampler,
|
| 688 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 689 |
wavs.append(wav)
|
| 690 |
_log_inference_timing("TARO", time.perf_counter() - _t_infer_start,
|
| 691 |
len(segments), int(num_steps), TARO_SECS_PER_STEP)
|
|
|
|
| 735 |
first_cavp_saved = False
|
| 736 |
outputs = []
|
| 737 |
for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
|
| 738 |
+
# FlashSR: upsample each segment 16kHz → 48kHz (CPU-only, no GPU needed)
|
| 739 |
+
wavs = [_apply_flashsr(w) for w in wavs]
|
| 740 |
final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, FLASHSR_SR_OUT)
|
| 741 |
audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
|
| 742 |
_save_wav(audio_path, final_wav, FLASHSR_SR_OUT)
|
|
|
|
| 1140 |
def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
|
| 1141 |
seed_val, cfg_scale, num_steps, mode,
|
| 1142 |
crossfade_s, crossfade_db, slot_id=None):
|
| 1143 |
+
# If cached CAVP/onset features exist, skip ~10s feature-extractor overhead
|
| 1144 |
+
try:
|
| 1145 |
+
meta = json.loads(seg_meta_json)
|
| 1146 |
+
cavp_ok = os.path.exists(meta.get("cavp_path", ""))
|
| 1147 |
+
onset_ok = os.path.exists(meta.get("onset_path", ""))
|
| 1148 |
+
if cavp_ok and onset_ok:
|
| 1149 |
+
cfg = MODEL_CONFIGS["taro"]
|
| 1150 |
+
secs = int(num_steps) * cfg["secs_per_step"] + 5 # 5s model-load only
|
| 1151 |
+
result = min(GPU_DURATION_CAP, max(30, int(secs)))
|
| 1152 |
+
print(f"[duration] TARO regen (cache hit): 1 seg Γ {int(num_steps)} steps β {secs:.0f}s β capped {result}s")
|
| 1153 |
+
return result
|
| 1154 |
+
except Exception:
|
| 1155 |
+
pass
|
| 1156 |
return _estimate_regen_duration("taro", int(num_steps))
|
| 1157 |
|
| 1158 |
|
|
|
|
| 1214 |
seed_val, cfg_scale, num_steps, mode,
|
| 1215 |
crossfade_s, crossfade_db, slot_id)
|
| 1216 |
|
| 1217 |
+
# FlashSR: upsample 16kHz → 48kHz on CPU (no GPU needed)
|
| 1218 |
+
new_wav = _apply_flashsr(new_wav)
|
| 1219 |
# CPU: splice, stitch, mux, save
|
| 1220 |
video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
|
| 1221 |
new_wav, seg_idx, meta, slot_id
|
|
|
|
| 1486 |
new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
|
| 1487 |
seed_val, cfg_scale, num_steps, mode,
|
| 1488 |
crossfade_s, crossfade_db, slot_id)
|
| 1489 |
+
# FlashSR: upsample 16kHz → 48kHz on CPU (no GPU needed)
|
| 1490 |
+
new_wav_raw = _apply_flashsr(new_wav_raw)
|
| 1491 |
video_path, waveform_html = _xregen_splice(new_wav_raw, FLASHSR_SR_OUT, meta, seg_idx, slot_id)
|
| 1492 |
yield gr.update(value=video_path), gr.update(value=waveform_html)
|
| 1493 |
|