BoxOfColors Claude Opus 4.6 commited on
Commit
ef4f0ff
Β·
1 Parent(s): 0bc4a35

Optimize ZeroGPU: move FlashSR to CPU, fix regen 16kHz bug

Browse files

- Move FlashSR upsampling (16kHz→48kHz) from inside @spaces.GPU
to CPU wrappers β€” saves ~1-2s GPU quota per segment since
FlashSR is CPU-only and doesn not need the GPU allocation
- Fix bug: regen_taro_segment and xregen_taro were returning
raw 16kHz wav without applying FlashSR (initial gen applied
it inside the GPU loop but regen never did)
- Reduce TARO regen duration estimate when CAVP/onset feature
cache exists β€” requests 5s overhead instead of 15s

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +19 -9
app.py CHANGED
@@ -686,12 +686,6 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
686
  latents_scale,
687
  euler_sampler, euler_maruyama_sampler,
688
  )
689
- # FlashSR: upsample 16kHz β†’ 48kHz inside GPU window to avoid
690
- # ZeroGPU CUDA-init-in-main-process violation
691
- print(f"[FlashSR] Upsampling seg {len(wavs)+1} "
692
- f"{seg_end_s-seg_start_s:.2f}s @ 16kHz β†’ 48kHz …")
693
- wav = _apply_flashsr(wav)
694
- print(f"[FlashSR] Done β€” {len(wav)/FLASHSR_SR_OUT:.2f}s @ {FLASHSR_SR_OUT}Hz")
695
  wavs.append(wav)
696
  _log_inference_timing("TARO", time.perf_counter() - _t_infer_start,
697
  len(segments), int(num_steps), TARO_SECS_PER_STEP)
@@ -741,7 +735,8 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
741
  first_cavp_saved = False
742
  outputs = []
743
  for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
744
- # wavs are already at 48kHz β€” FlashSR ran inside _taro_gpu_infer
 
745
  final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, FLASHSR_SR_OUT)
746
  audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
747
  _save_wav(audio_path, final_wav, FLASHSR_SR_OUT)
@@ -1145,6 +1140,19 @@ def _splice_and_save(new_wav, seg_idx, meta, slot_id):
1145
  def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
1146
  seed_val, cfg_scale, num_steps, mode,
1147
  crossfade_s, crossfade_db, slot_id=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
1148
  return _estimate_regen_duration("taro", int(num_steps))
1149
 
1150
 
@@ -1206,7 +1214,8 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
1206
  seed_val, cfg_scale, num_steps, mode,
1207
  crossfade_s, crossfade_db, slot_id)
1208
 
1209
- # new_wav is already at 48kHz β€” FlashSR ran inside _regen_taro_gpu β†’ _taro_infer_segment
 
1210
  # CPU: splice, stitch, mux, save
1211
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1212
  new_wav, seg_idx, meta, slot_id
@@ -1477,7 +1486,8 @@ def xregen_taro(seg_idx, state_json, slot_id,
1477
  new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
1478
  seed_val, cfg_scale, num_steps, mode,
1479
  crossfade_s, crossfade_db, slot_id)
1480
- # new_wav_raw already at 48kHz β€” FlashSR ran inside _regen_taro_gpu β†’ _taro_infer_segment
 
1481
  video_path, waveform_html = _xregen_splice(new_wav_raw, FLASHSR_SR_OUT, meta, seg_idx, slot_id)
1482
  yield gr.update(value=video_path), gr.update(value=waveform_html)
1483
 
 
686
  latents_scale,
687
  euler_sampler, euler_maruyama_sampler,
688
  )
 
 
 
 
 
 
689
  wavs.append(wav)
690
  _log_inference_timing("TARO", time.perf_counter() - _t_infer_start,
691
  len(segments), int(num_steps), TARO_SECS_PER_STEP)
 
735
  first_cavp_saved = False
736
  outputs = []
737
  for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
738
+ # FlashSR: upsample each segment 16kHz β†’ 48kHz (CPU-only, no GPU needed)
739
+ wavs = [_apply_flashsr(w) for w in wavs]
740
  final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, FLASHSR_SR_OUT)
741
  audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
742
  _save_wav(audio_path, final_wav, FLASHSR_SR_OUT)
 
1140
  def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
1141
  seed_val, cfg_scale, num_steps, mode,
1142
  crossfade_s, crossfade_db, slot_id=None):
1143
+ # If cached CAVP/onset features exist, skip ~10s feature-extractor overhead
1144
+ try:
1145
+ meta = json.loads(seg_meta_json)
1146
+ cavp_ok = os.path.exists(meta.get("cavp_path", ""))
1147
+ onset_ok = os.path.exists(meta.get("onset_path", ""))
1148
+ if cavp_ok and onset_ok:
1149
+ cfg = MODEL_CONFIGS["taro"]
1150
+ secs = int(num_steps) * cfg["secs_per_step"] + 5 # 5s model-load only
1151
+ result = min(GPU_DURATION_CAP, max(30, int(secs)))
1152
+ print(f"[duration] TARO regen (cache hit): 1 seg Γ— {int(num_steps)} steps β†’ {secs:.0f}s β†’ capped {result}s")
1153
+ return result
1154
+ except Exception:
1155
+ pass
1156
  return _estimate_regen_duration("taro", int(num_steps))
1157
 
1158
 
 
1214
  seed_val, cfg_scale, num_steps, mode,
1215
  crossfade_s, crossfade_db, slot_id)
1216
 
1217
+ # FlashSR: upsample 16kHz β†’ 48kHz on CPU (no GPU needed)
1218
+ new_wav = _apply_flashsr(new_wav)
1219
  # CPU: splice, stitch, mux, save
1220
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1221
  new_wav, seg_idx, meta, slot_id
 
1486
  new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
1487
  seed_val, cfg_scale, num_steps, mode,
1488
  crossfade_s, crossfade_db, slot_id)
1489
+ # FlashSR: upsample 16kHz β†’ 48kHz on CPU (no GPU needed)
1490
+ new_wav_raw = _apply_flashsr(new_wav_raw)
1491
  video_path, waveform_html = _xregen_splice(new_wav_raw, FLASHSR_SR_OUT, meta, seg_idx, slot_id)
1492
  yield gr.update(value=video_path), gr.update(value=waveform_html)
1493