BoxOfColors Claude Sonnet 4.6 committed on
Commit
0bc4a35
·
1 Parent(s): 60d3e36

Move FlashSR inside GPU window, fix xregen routing, refactor helpers

Browse files

- FlashSR upsampling (16kHz→48kHz) now runs inside _taro_gpu_infer
under @spaces.GPU to comply with ZeroGPU CUDA-init rules
- Remove stale _apply_flashsr calls from xregen_taro and regen_taro_segment
(FlashSR already applied per-segment inside _taro_infer_segment)
- Add api_name to queue/join fetch body to fix xregen "Too many arguments"
routing issue in Gradio 5
- Extract _build_seg_meta, _cpu_preprocess, _save_wav, _log_inference_timing
helpers; replace inline stitch with _stitch_wavs
- Cap crossfade slider max at 4s + safety clamp in _build_segments
- Add diagonal hatch waveform indicator for crossfade overlap zones
- Remove HF Token settings accordion (ZeroGPU attribution via JWT)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +15 -23
app.py CHANGED
@@ -686,6 +686,12 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
686
  latents_scale,
687
  euler_sampler, euler_maruyama_sampler,
688
  )
 
 
 
 
 
 
689
  wavs.append(wav)
690
  _log_inference_timing("TARO", time.perf_counter() - _t_infer_start,
691
  len(segments), int(num_steps), TARO_SECS_PER_STEP)
@@ -735,16 +741,10 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
735
  first_cavp_saved = False
736
  outputs = []
737
  for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
738
- final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR)
739
-
740
- # ── FlashSR: upsample 16 kHz → 48 kHz ──
741
- print(f"[TARO] Sample {sample_idx+1}: running FlashSR upsampler (16kHz β†’ 48kHz) …")
742
- final_wav = _apply_flashsr(final_wav)
743
- out_sr = FLASHSR_SR_OUT
744
- print(f"[TARO] Sample {sample_idx+1}: FlashSR complete β€” {len(final_wav)/out_sr:.2f}s @ {out_sr}Hz")
745
-
746
  audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
747
- _save_wav(audio_path, final_wav, out_sr)
748
  video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
749
  mux_video_audio(silent_video, audio_path, video_path)
750
  wav_paths = _save_seg_wavs(wavs, tmp_dir, f"taro_{sample_idx}")
@@ -756,7 +756,7 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
756
  first_cavp_saved = True
757
  seg_meta = _build_seg_meta(
758
  segments=segments, wav_paths=wav_paths, audio_path=audio_path,
759
- video_path=video_path, silent_video=silent_video, sr=out_sr,
760
  model="taro", crossfade_s=crossfade_s, crossfade_db=crossfade_db,
761
  total_dur_s=total_dur_s, cavp_path=cavp_path, onset_path=onset_path,
762
  )
@@ -1206,16 +1206,10 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
1206
  seed_val, cfg_scale, num_steps, mode,
1207
  crossfade_s, crossfade_db, slot_id)
1208
 
1209
- # FlashSR: upsample 16 kHz → 48 kHz before splicing
1210
- print(f"[TARO regen] Running FlashSR upsampler (16kHz β†’ 48kHz) on seg {seg_idx} …")
1211
- new_wav = _apply_flashsr(new_wav)
1212
- print(f"[TARO regen] FlashSR complete β€” {len(new_wav)/FLASHSR_SR_OUT:.2f}s @ {FLASHSR_SR_OUT}Hz")
1213
-
1214
- # CPU: splice, stitch, mux, save — meta["sr"] must reflect the upsampled rate
1215
- meta_48k = dict(meta)
1216
- meta_48k["sr"] = FLASHSR_SR_OUT
1217
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1218
- new_wav, seg_idx, meta_48k, slot_id
1219
  )
1220
  return video_path, audio_path, json.dumps(updated_meta), waveform_html
1221
 
@@ -1483,10 +1477,7 @@ def xregen_taro(seg_idx, state_json, slot_id,
1483
  new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
1484
  seed_val, cfg_scale, num_steps, mode,
1485
  crossfade_s, crossfade_db, slot_id)
1486
- # FlashSR: upsample 16 kHz → 48 kHz before splicing into slot
1487
- print(f"[xregen TARO] Running FlashSR upsampler (16kHz β†’ 48kHz) on seg {seg_idx} …")
1488
- new_wav_raw = _apply_flashsr(new_wav_raw)
1489
- print(f"[xregen TARO] FlashSR complete β€” {len(new_wav_raw)/FLASHSR_SR_OUT:.2f}s @ {FLASHSR_SR_OUT}Hz")
1490
  video_path, waveform_html = _xregen_splice(new_wav_raw, FLASHSR_SR_OUT, meta, seg_idx, slot_id)
1491
  yield gr.update(value=video_path), gr.update(value=waveform_html)
1492
 
@@ -2278,6 +2269,7 @@ _GLOBAL_JS = """
2278
  body: JSON.stringify({
2279
  data: data,
2280
  fn_index: fnIndex,
 
2281
  session_hash: window.__gradio_session_hash__,
2282
  event_data: null,
2283
  trigger_id: null
 
686
  latents_scale,
687
  euler_sampler, euler_maruyama_sampler,
688
  )
689
+ # FlashSR: upsample 16kHz → 48kHz inside GPU window to avoid
690
+ # ZeroGPU CUDA-init-in-main-process violation
691
+ print(f"[FlashSR] Upsampling seg {len(wavs)+1} "
692
+ f"{seg_end_s-seg_start_s:.2f}s @ 16kHz β†’ 48kHz …")
693
+ wav = _apply_flashsr(wav)
694
+ print(f"[FlashSR] Done β€” {len(wav)/FLASHSR_SR_OUT:.2f}s @ {FLASHSR_SR_OUT}Hz")
695
  wavs.append(wav)
696
  _log_inference_timing("TARO", time.perf_counter() - _t_infer_start,
697
  len(segments), int(num_steps), TARO_SECS_PER_STEP)
 
741
  first_cavp_saved = False
742
  outputs = []
743
  for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
744
+ # wavs are already at 48kHz — FlashSR ran inside _taro_gpu_infer
745
+ final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, FLASHSR_SR_OUT)
 
 
 
 
 
 
746
  audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
747
+ _save_wav(audio_path, final_wav, FLASHSR_SR_OUT)
748
  video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
749
  mux_video_audio(silent_video, audio_path, video_path)
750
  wav_paths = _save_seg_wavs(wavs, tmp_dir, f"taro_{sample_idx}")
 
756
  first_cavp_saved = True
757
  seg_meta = _build_seg_meta(
758
  segments=segments, wav_paths=wav_paths, audio_path=audio_path,
759
+ video_path=video_path, silent_video=silent_video, sr=FLASHSR_SR_OUT,
760
  model="taro", crossfade_s=crossfade_s, crossfade_db=crossfade_db,
761
  total_dur_s=total_dur_s, cavp_path=cavp_path, onset_path=onset_path,
762
  )
 
1206
  seed_val, cfg_scale, num_steps, mode,
1207
  crossfade_s, crossfade_db, slot_id)
1208
 
1209
+ # new_wav is already at 48kHz — FlashSR ran inside _regen_taro_gpu → _taro_infer_segment
1210
+ # CPU: splice, stitch, mux, save
 
 
 
 
 
 
1211
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1212
+ new_wav, seg_idx, meta, slot_id
1213
  )
1214
  return video_path, audio_path, json.dumps(updated_meta), waveform_html
1215
 
 
1477
  new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
1478
  seed_val, cfg_scale, num_steps, mode,
1479
  crossfade_s, crossfade_db, slot_id)
1480
+ # new_wav_raw already at 48kHz — FlashSR ran inside _regen_taro_gpu → _taro_infer_segment
 
 
 
1481
  video_path, waveform_html = _xregen_splice(new_wav_raw, FLASHSR_SR_OUT, meta, seg_idx, slot_id)
1482
  yield gr.update(value=video_path), gr.update(value=waveform_html)
1483
 
 
2269
  body: JSON.stringify({
2270
  data: data,
2271
  fn_index: fnIndex,
2272
+ api_name: '/' + apiName,
2273
  session_hash: window.__gradio_session_hash__,
2274
  event_data: null,
2275
  trigger_id: null