BoxOfColors Claude Sonnet 4.6 committed on
Commit
3272260
·
1 Parent(s): 813c771

ZeroGPU optimizations, refactoring, and regen UX improvements

Browse files

ZeroGPU optimizations:
- Pre-load TARO CAVP/onset features on CPU via _preload_taro_regen_ctx()
- Pre-load HunyuanFoley text_feats on CPU via _preload_hunyuan_regen_ctx()
- Move FlowMatching creation outside per-segment loop in MMAudio
- Replace fallback video extraction in regen GPU fns with asserts
- Tighten TARO_SECS_PER_STEP 0.05β†’0.025 (measured 0.023s/step)
- Drop regen duration floor 60s→20s (cold-start spin-up not in timer)

Refactoring:
- Extract _post_process_samples() shared post-processing for all 3 models
- Unify mux_video_audio() to handle HunyuanFoley internally via model= param
- Extract _preload_taro_regen_ctx / _preload_hunyuan_regen_ctx helpers
- Make _resample_to_target() accept dst_sr; _resample_to_slot_sr delegates to it

Regen UX:
- Flash red border on waveform for 3s when regen aborts
- Show 'GPU cold-start β€” segment unchanged, try again' instead of misleading
'Quota exceeded' for GPU task aborted errors
- Status bar and seg label both auto-clear after 8s

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +198 -132
app.py CHANGED
@@ -261,14 +261,24 @@ def _load_hunyuan_model(device, model_size):
261
  enable_offload=False, model_size=model_size)
262
 
263
 
264
- def mux_video_audio(silent_video: str, audio_path: str, output_path: str) -> None:
265
- """Mux a silent video with an audio file into *output_path* (stream-copy video, encode audio)."""
266
- ffmpeg.output(
267
- ffmpeg.input(silent_video),
268
- ffmpeg.input(audio_path),
269
- output_path,
270
- vcodec="copy", acodec="aac", strict="experimental",
271
- ).run(overwrite_output=True, quiet=True)
 
 
 
 
 
 
 
 
 
 
272
 
273
 
274
  # ------------------------------------------------------------------ #
@@ -340,7 +350,7 @@ TARO_FPS = 4
340
  TARO_TRUNCATE_FRAME = int(TARO_FPS * TARO_TRUNCATE / TARO_SR) # 32
341
  TARO_TRUNCATE_ONSET = 120
342
  TARO_MODEL_DUR = TARO_TRUNCATE / TARO_SR # 8.192 s
343
- TARO_SECS_PER_STEP = 0.05 # measured 0.043s/step on H200 (8.2s video, 2 segs Γ— 25 steps = 2.2s wall)
344
 
345
  TARO_LOAD_OVERHEAD = 15 # seconds: model load + CAVP feature extraction
346
  MMAUDIO_WINDOW = 8.0 # seconds β€” MMAudio's fixed generation window
@@ -359,7 +369,7 @@ MODEL_CONFIGS = {
359
  "taro": {
360
  "window_s": TARO_MODEL_DUR, # 8.192 s
361
  "sr": TARO_SR, # 16000
362
- "secs_per_step": TARO_SECS_PER_STEP, # 0.05
363
  "load_overhead": TARO_LOAD_OVERHEAD, # 15
364
  "tab_prefix": "taro",
365
  "regen_fn": None, # set after function definitions (avoids forward-ref)
@@ -410,11 +420,14 @@ def _estimate_gpu_duration(model_key: str, num_samples: int, num_steps: int,
410
 
411
  def _estimate_regen_duration(model_key: str, num_steps: int) -> int:
412
  """Generic GPU duration estimator for single-segment regen.
413
- Uses a lower floor (30s) than initial generation since regen only runs
414
- one segment β€” saves 30s of wasted ZeroGPU quota per regen call."""
 
 
 
415
  cfg = MODEL_CONFIGS[model_key]
416
  secs = int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"]
417
- result = min(GPU_DURATION_CAP, max(60, int(secs)))
418
  print(f"[duration] {cfg['label']} regen: 1 seg Γ— {int(num_steps)} steps β†’ {secs:.0f}s β†’ capped {result}s")
419
  return result
420
 
@@ -509,19 +522,22 @@ TARGET_SR = 48000 # unified output sample rate for all three models
509
  TARO_SR_OUT = TARGET_SR
510
 
511
 
512
- def _resample_to_target(wav: np.ndarray, src_sr: int) -> np.ndarray:
513
- """Resample *wav* (mono or stereo numpy float32) from src_sr to TARGET_SR (48kHz).
 
514
 
515
- No-op if src_sr already equals TARGET_SR. Uses torchaudio Kaiser-windowed
516
- sinc resampling β€” CPU-only, ZeroGPU-safe.
517
  """
518
- if src_sr == TARGET_SR:
 
 
519
  return wav
520
  stereo = wav.ndim == 2
521
  t = torch.from_numpy(np.ascontiguousarray(wav.astype(np.float32)))
522
  if not stereo:
523
  t = t.unsqueeze(0) # [1, T]
524
- t = torchaudio.functional.resample(t, src_sr, TARGET_SR)
525
  if not stereo:
526
  t = t.squeeze(0) # [T]
527
  return t.numpy()
@@ -592,6 +608,45 @@ def _build_seg_meta(*, segments, wav_paths, audio_path, video_path,
592
  return meta
593
 
594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
595
  def _cpu_preprocess(video_file: str, model_dur: float,
596
  crossfade_s: float) -> tuple:
597
  """Shared CPU pre-processing for all generate_* wrappers.
@@ -709,34 +764,34 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
709
  crossfade_s, crossfade_db, num_samples)
710
 
711
  # ── CPU post-processing (no GPU needed) ──
712
- # Cache CAVP + onset features once (same for all samples β€” they depend only on the video)
713
  cavp_path = os.path.join(tmp_dir, "taro_cavp.npy")
714
  onset_path = os.path.join(tmp_dir, "taro_onset.npy")
715
- first_cavp_saved = False
716
- outputs = []
717
- for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
718
- # Upsample each segment 16kHz β†’ 48kHz (sinc, CPU)
 
719
  wavs = [_upsample_taro(w) for w in wavs]
720
- final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR_OUT)
721
- audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
722
- _save_wav(audio_path, final_wav, TARO_SR_OUT)
723
- video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
724
- mux_video_audio(silent_video, audio_path, video_path)
725
- wav_paths = _save_seg_wavs(wavs, tmp_dir, f"taro_{sample_idx}")
726
- # Save shared features once (not per-sample β€” they're identical)
727
- if not first_cavp_saved:
728
  np.save(cavp_path, cavp_feats)
729
  if onset_feats is not None:
730
  np.save(onset_path, onset_feats)
731
- first_cavp_saved = True
732
- seg_meta = _build_seg_meta(
733
- segments=segments, wav_paths=wav_paths, audio_path=audio_path,
734
- video_path=video_path, silent_video=silent_video, sr=TARO_SR_OUT,
735
- model="taro", crossfade_s=crossfade_s, crossfade_db=crossfade_db,
736
- total_dur_s=total_dur_s, cavp_path=cavp_path, onset_path=onset_path,
737
- )
738
- outputs.append((video_path, audio_path, seg_meta))
739
 
 
 
 
 
 
 
 
 
 
 
740
  return _pad_outputs(outputs)
741
 
742
 
@@ -794,12 +849,12 @@ def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
794
 
795
  seg_audios = []
796
  _t_mma_start = time.perf_counter()
 
797
 
798
  for seg_i, (seg_start, seg_end) in enumerate(segments):
799
  seg_dur = seg_end - seg_start
800
  seg_path = seg_clip_paths[seg_i]
801
 
802
- fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
803
  video_info = load_video(seg_path, seg_dur)
804
  clip_frames = video_info.clip_frames.unsqueeze(0)
805
  sync_frames = video_info.sync_frames.unsqueeze(0)
@@ -868,29 +923,21 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
868
  cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples)
869
 
870
  # ── CPU post-processing ──
871
- outputs = []
872
- for sample_idx, (seg_audios, sr) in enumerate(results):
873
- # Resample 44100 β†’ 48000 Hz so all three models share the same output SR
874
  if sr != TARGET_SR:
875
  print(f"[MMAudio upsample] resampling {sr}Hz β†’ {TARGET_SR}Hz (sinc, CPU) …")
876
  seg_audios = [_resample_to_target(w, sr) for w in seg_audios]
877
  print(f"[MMAudio upsample] done β€” {len(seg_audios)} seg(s) @ {TARGET_SR}Hz")
878
- sr = TARGET_SR
879
- full_wav = _stitch_wavs(seg_audios, crossfade_s, crossfade_db, total_dur_s, sr)
880
-
881
- audio_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.wav")
882
- _save_wav(audio_path, full_wav, sr)
883
- video_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.mp4")
884
- mux_video_audio(silent_video, audio_path, video_path)
885
- wav_paths = _save_seg_wavs(seg_audios, tmp_dir, f"mmaudio_{sample_idx}")
886
- seg_meta = _build_seg_meta(
887
- segments=segments, wav_paths=wav_paths, audio_path=audio_path,
888
- video_path=video_path, silent_video=silent_video, sr=sr,
889
- model="mmaudio", crossfade_s=crossfade_s, crossfade_db=crossfade_db,
890
- total_dur_s=total_dur_s,
891
- )
892
- outputs.append((video_path, audio_path, seg_meta))
893
 
 
 
 
 
 
 
894
  return _pad_outputs(outputs)
895
 
896
 
@@ -1034,28 +1081,19 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
1034
  crossfade_s, crossfade_db, num_samples)
1035
 
1036
  # ── CPU post-processing (no GPU needed) ──
1037
- _ensure_syspath("HunyuanVideo-Foley")
1038
- from hunyuanvideo_foley.utils.media_utils import merge_audio_video
1039
-
1040
- outputs = []
1041
- for sample_idx, (seg_wavs, sr, text_feats) in enumerate(results):
1042
- full_wav = _stitch_wavs(seg_wavs, crossfade_s, crossfade_db, total_dur_s, sr)
1043
-
1044
- audio_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}.wav")
1045
- _save_wav(audio_path, full_wav, sr)
1046
- video_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}.mp4")
1047
- merge_audio_video(audio_path, silent_video, video_path)
1048
- wav_paths = _save_seg_wavs(seg_wavs, tmp_dir, f"hunyuan_{sample_idx}")
1049
- text_feats_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}_text_feats.pt")
1050
- torch.save(text_feats, text_feats_path)
1051
- seg_meta = _build_seg_meta(
1052
- segments=segments, wav_paths=wav_paths, audio_path=audio_path,
1053
- video_path=video_path, silent_video=silent_video, sr=sr,
1054
- model="hunyuan", crossfade_s=crossfade_s, crossfade_db=crossfade_db,
1055
- total_dur_s=total_dur_s, text_feats_path=text_feats_path,
1056
- )
1057
- outputs.append((video_path, audio_path, seg_meta))
1058
-
1059
  return _pad_outputs(outputs)
1060
 
1061
 
@@ -1069,6 +1107,28 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
1069
  # 4. Returns (new_video_path, new_audio_path, updated_seg_meta, new_waveform_html)
1070
  # ================================================================== #
1071
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1072
  def _splice_and_save(new_wav, seg_idx, meta, slot_id):
1073
  """Replace wavs[seg_idx] with new_wav, re-stitch, re-save, re-mux.
1074
  Returns (video_path, audio_path, updated_meta, waveform_html).
@@ -1099,13 +1159,7 @@ def _splice_and_save(new_wav, seg_idx, meta, slot_id):
1099
  _vid_base = os.path.splitext(os.path.basename(meta["video_path"]))[0]
1100
  _vid_base_clean = _vid_base.rsplit("_regen_", 1)[0]
1101
  video_path = os.path.join(tmp_dir, f"{_vid_base_clean}_regen_{_ts}.mp4")
1102
- if model == "hunyuan":
1103
- # HunyuanFoley uses its own merge_audio_video
1104
- _ensure_syspath("HunyuanVideo-Foley")
1105
- from hunyuanvideo_foley.utils.media_utils import merge_audio_video
1106
- merge_audio_video(audio_path, silent_video, video_path)
1107
- else:
1108
- mux_video_audio(silent_video, audio_path, video_path)
1109
 
1110
  # Save updated segment wavs to .npy files
1111
  updated_wav_paths = _save_seg_wavs(wavs, tmp_dir, os.path.splitext(_base_clean)[0])
@@ -1157,12 +1211,12 @@ def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
1157
  _ensure_syspath("TARO")
1158
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
1159
 
1160
- cavp_path = meta.get("cavp_path")
1161
- onset_path = meta.get("onset_path")
1162
- if cavp_path and os.path.exists(cavp_path) and onset_path and os.path.exists(onset_path):
1163
- print("[TARO regen] Loading cached CAVP + onset features")
1164
- cavp_feats = np.load(cavp_path)
1165
- onset_feats = np.load(onset_path)
1166
  else:
1167
  print("[TARO regen] Cache miss β€” re-extracting CAVP + onset features")
1168
  from TARO.onset_util import extract_onset
@@ -1195,6 +1249,9 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
1195
  meta = json.loads(seg_meta_json)
1196
  seg_idx = int(seg_idx)
1197
 
 
 
 
1198
  # GPU: inference only
1199
  new_wav = _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
1200
  seed_val, cfg_scale, num_steps, mode,
@@ -1234,14 +1291,9 @@ def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1234
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
1235
  sr = seq_cfg.sampling_rate
1236
 
1237
- # Use pre-extracted segment clip from the wrapper
1238
  seg_path = _regen_mmaudio_gpu._cpu_ctx.get("seg_path")
1239
- if not seg_path:
1240
- # Fallback: extract inside GPU (shouldn't happen)
1241
- seg_path = _extract_segment_clip(
1242
- meta["silent_video"], seg_start, seg_dur,
1243
- os.path.join(tempfile.mkdtemp(), "regen_seg.mp4"),
1244
- )
1245
 
1246
  rng = torch.Generator(device=device)
1247
  rng.manual_seed(random.randint(0, 2**32 - 1))
@@ -1335,18 +1387,15 @@ def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1335
 
1336
  # Use pre-extracted segment clip from wrapper
1337
  seg_path = _regen_hunyuan_gpu._cpu_ctx.get("seg_path")
1338
- if not seg_path:
1339
- seg_path = _extract_segment_clip(
1340
- meta["silent_video"], seg_start, seg_dur,
1341
- os.path.join(tempfile.mkdtemp(), "regen_seg.mp4"),
1342
- )
1343
 
1344
- text_feats_path = meta.get("text_feats_path")
1345
- if text_feats_path and os.path.exists(text_feats_path):
1346
- print("[HunyuanFoley regen] Loading cached text features, extracting visual only")
 
1347
  from hunyuanvideo_foley.utils.feature_utils import encode_video_features
1348
  visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
1349
- text_feats = torch.load(text_feats_path, map_location=device, weights_only=False)
1350
  else:
1351
  print("[HunyuanFoley regen] Cache miss β€” extracting text + visual features")
1352
  visual_feats, text_feats, seg_audio_len = feature_process(
@@ -1377,13 +1426,13 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1377
  seg_start, seg_end = meta["segments"][seg_idx]
1378
  seg_dur = seg_end - seg_start
1379
 
1380
- # CPU: pre-extract segment clip
1381
  tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
1382
  seg_path = _extract_segment_clip(
1383
  meta["silent_video"], seg_start, seg_dur,
1384
  os.path.join(tmp_dir, "regen_seg.mp4"),
1385
  )
1386
- _regen_hunyuan_gpu._cpu_ctx = {"seg_path": seg_path}
1387
 
1388
  # GPU: inference only
1389
  new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
@@ -1419,26 +1468,17 @@ MODEL_CONFIGS["hunyuan"]["regen_fn"] = regen_hunyuan_segment
1419
 
1420
  def _resample_to_slot_sr(wav: np.ndarray, src_sr: int, dst_sr: int,
1421
  slot_wav_ref: np.ndarray = None) -> np.ndarray:
1422
- """Resample *wav* from src_sr to dst_sr using torchaudio, then match
1423
- channel layout to *slot_wav_ref* (the first existing segment in the slot).
1424
 
1425
  TARO is mono (T,), MMAudio/Hunyuan are stereo (C, T). Mixing them
1426
  without normalisation causes a shape mismatch in _cf_join. Rules:
1427
- β€’ stereo β†’ mono : average channels
1428
- β€’ mono β†’ stereo: duplicate the single channel
1429
  """
1430
- # 1. Resample
1431
- if src_sr != dst_sr:
1432
- stereo_in = wav.ndim == 2
1433
- t = torch.from_numpy(np.ascontiguousarray(wav))
1434
- if not stereo_in:
1435
- t = t.unsqueeze(0)
1436
- t = torchaudio.functional.resample(t.float(), src_sr, dst_sr)
1437
- if not stereo_in:
1438
- t = t.squeeze(0)
1439
- wav = t.numpy()
1440
-
1441
- # 2. Match channel layout to the slot's existing segments
1442
  if slot_wav_ref is not None:
1443
  slot_stereo = slot_wav_ref.ndim == 2
1444
  wav_stereo = wav.ndim == 2
@@ -1474,6 +1514,9 @@ def xregen_taro(seg_idx, state_json, slot_id,
1474
  pending_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
1475
  yield gr.update(), gr.update(value=pending_html)
1476
 
 
 
 
1477
  new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
1478
  seed_val, cfg_scale, num_steps, mode,
1479
  crossfade_s, crossfade_db, slot_id)
@@ -1528,7 +1571,7 @@ def xregen_hunyuan(seg_idx, state_json, slot_id,
1528
  meta["silent_video"], seg_start, seg_end - seg_start,
1529
  os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1530
  )
1531
- _regen_hunyuan_gpu._cpu_ctx = {"seg_path": seg_path}
1532
 
1533
  new_wav_raw, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
1534
  prompt, negative_prompt, seed_val,
@@ -2364,6 +2407,7 @@ _GLOBAL_JS = """
2364
  var lbl = document.getElementById('wf_seglabel_' + slot_id);
2365
  if (hadError) {
2366
  var toastMsg = typeof errMsg === 'string' ? errMsg : JSON.stringify(errMsg);
 
2367
  if (preRegenWaveHtml !== null) {
2368
  var waveEl2 = document.getElementById('slot_wave_' + slot_id);
2369
  if (waveEl2) waveEl2.innerHTML = preRegenWaveHtml;
@@ -2372,13 +2416,35 @@ _GLOBAL_JS = """
2372
  var vidElR = document.getElementById('slot_vid_' + slot_id);
2373
  if (vidElR) { var vR = vidElR.querySelector('video'); if (vR) { vR.setAttribute('src', preRegenVideoSrc); vR.src = preRegenVideoSrc; vR.load(); } }
2374
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2375
  var statusBar = document.getElementById('wf_statusbar_' + slot_id);
2376
  if (statusBar) {
2377
  statusBar.style.color = '#e05252';
2378
- statusBar.textContent = '\u26a0 ' + toastMsg;
2379
- setTimeout(function() { statusBar.style.color = '#888'; statusBar.textContent = 'Click a segment to regenerate \u00a0|\u00a0 Playhead syncs to video'; }, 15000);
 
 
 
 
 
2380
  }
2381
- if (lbl) lbl.textContent = 'Quota exceeded β€” try again later';
2382
  } else {
2383
  if (lbl) lbl.textContent = 'Done';
2384
  var src = _pendingVideoSrc;
 
261
  enable_offload=False, model_size=model_size)
262
 
263
 
264
def mux_video_audio(silent_video: str, audio_path: str, output_path: str,
                    model: str = None) -> None:
    """Combine *silent_video* and *audio_path* into *output_path*.

    When *model* is "hunyuan" the mux is delegated to HunyuanFoley's own
    merge_audio_video helper (it handles that model's specific ffmpeg quirks);
    every other model gets a plain stream-copy video + AAC audio mux.
    """
    if model == "hunyuan":
        # Lazy import: only needed (and only importable) once the Hunyuan
        # repo is on sys.path.
        _ensure_syspath("HunyuanVideo-Foley")
        from hunyuanvideo_foley.utils.media_utils import merge_audio_video
        merge_audio_video(audio_path, silent_video, output_path)
        return

    video_in = ffmpeg.input(silent_video)
    audio_in = ffmpeg.input(audio_path)
    muxed = ffmpeg.output(
        video_in,
        audio_in,
        output_path,
        vcodec="copy", acodec="aac", strict="experimental",
    )
    muxed.run(overwrite_output=True, quiet=True)
282
 
283
 
284
  # ------------------------------------------------------------------ #
 
350
  TARO_TRUNCATE_FRAME = int(TARO_FPS * TARO_TRUNCATE / TARO_SR) # 32
351
  TARO_TRUNCATE_ONSET = 120
352
  TARO_MODEL_DUR = TARO_TRUNCATE / TARO_SR # 8.192 s
353
+ TARO_SECS_PER_STEP = 0.025 # measured 0.023s/step on H200; was 0.05, tightened to halve GPU allocation
354
 
355
  TARO_LOAD_OVERHEAD = 15 # seconds: model load + CAVP feature extraction
356
  MMAUDIO_WINDOW = 8.0 # seconds β€” MMAudio's fixed generation window
 
369
  "taro": {
370
  "window_s": TARO_MODEL_DUR, # 8.192 s
371
  "sr": TARO_SR, # 16000
372
+ "secs_per_step": TARO_SECS_PER_STEP, # 0.025
373
  "load_overhead": TARO_LOAD_OVERHEAD, # 15
374
  "tab_prefix": "taro",
375
  "regen_fn": None, # set after function definitions (avoids forward-ref)
 
420
 
421
def _estimate_regen_duration(model_key: str, num_steps: int) -> int:
    """Generic GPU duration estimator for single-segment regen.

    Floor is 20s — enough headroom above the 10s ZeroGPU abort threshold
    for any model on a warm worker. Cold-start spin-up happens *before*
    the timer starts so raising the floor does not help with cold-start aborts.
    """
    cfg = MODEL_CONFIGS[model_key]
    steps = int(num_steps)
    # Linear model: per-step inference cost plus fixed load overhead.
    estimate = steps * cfg["secs_per_step"] + cfg["load_overhead"]
    # Clamp into [20, GPU_DURATION_CAP].
    result = min(GPU_DURATION_CAP, max(20, int(estimate)))
    print(f"[duration] {cfg['label']} regen: 1 seg × {steps} steps → {estimate:.0f}s → capped {result}s")
    return result
433
 
 
522
  TARO_SR_OUT = TARGET_SR
523
 
524
 
525
+ def _resample_to_target(wav: np.ndarray, src_sr: int,
526
+ dst_sr: int = None) -> np.ndarray:
527
+ """Resample *wav* (mono or stereo numpy float32) from *src_sr* to *dst_sr*.
528
 
529
+ *dst_sr* defaults to TARGET_SR (48 kHz). No-op if src_sr == dst_sr.
530
+ Uses torchaudio Kaiser-windowed sinc resampling β€” CPU-only, ZeroGPU-safe.
531
  """
532
+ if dst_sr is None:
533
+ dst_sr = TARGET_SR
534
+ if src_sr == dst_sr:
535
  return wav
536
  stereo = wav.ndim == 2
537
  t = torch.from_numpy(np.ascontiguousarray(wav.astype(np.float32)))
538
  if not stereo:
539
  t = t.unsqueeze(0) # [1, T]
540
+ t = torchaudio.functional.resample(t, src_sr, dst_sr)
541
  if not stereo:
542
  t = t.squeeze(0) # [T]
543
  return t.numpy()
 
608
  return meta
609
 
610
 
611
def _post_process_samples(results: list, *, model: str, tmp_dir: str,
                          silent_video: str, segments: list,
                          crossfade_s: float, crossfade_db: float,
                          total_dur_s: float, sr: int,
                          extra_meta_fn=None) -> list:
    """Shared CPU post-processing for all three generate_* wrappers.

    Each entry in *results* is a tuple whose first element is a list of
    per-segment wav arrays; the remaining elements are model-specific
    (e.g. TARO returns features, HunyuanFoley returns text_feats).

    *extra_meta_fn(sample_idx, result_tuple, tmp_dir) -> dict* is an optional
    callback returning model-specific extra keys merged into seg_meta
    (e.g. cavp_path, onset_path, text_feats_path).

    Returns a list of (video_path, audio_path, seg_meta) tuples.
    """
    def _finalize(sample_idx, result):
        # Stitch segments, write wav + muxed mp4, build per-sample metadata.
        seg_wavs = result[0]
        base = f"{model}_{sample_idx}"

        stitched = _stitch_wavs(seg_wavs, crossfade_s, crossfade_db, total_dur_s, sr)
        audio_path = os.path.join(tmp_dir, base + ".wav")
        _save_wav(audio_path, stitched, sr)
        video_path = os.path.join(tmp_dir, base + ".mp4")
        mux_video_audio(silent_video, audio_path, video_path, model=model)
        wav_paths = _save_seg_wavs(seg_wavs, tmp_dir, base)

        extras = extra_meta_fn(sample_idx, result, tmp_dir) if extra_meta_fn else {}
        meta = _build_seg_meta(
            segments=segments, wav_paths=wav_paths, audio_path=audio_path,
            video_path=video_path, silent_video=silent_video, sr=sr,
            model=model, crossfade_s=crossfade_s, crossfade_db=crossfade_db,
            total_dur_s=total_dur_s, **extras,
        )
        return (video_path, audio_path, meta)

    return [_finalize(i, r) for i, r in enumerate(results)]
648
+
649
+
650
  def _cpu_preprocess(video_file: str, model_dur: float,
651
  crossfade_s: float) -> tuple:
652
  """Shared CPU pre-processing for all generate_* wrappers.
 
764
  crossfade_s, crossfade_db, num_samples)
765
 
766
  # ── CPU post-processing (no GPU needed) ──
767
+ # Upsample 16kHz β†’ 48kHz and normalise result tuples to (seg_wavs, ...)
768
  cavp_path = os.path.join(tmp_dir, "taro_cavp.npy")
769
  onset_path = os.path.join(tmp_dir, "taro_onset.npy")
770
+ _feats_saved = False
771
+
772
+ def _upsample_and_save_feats(result):
773
+ nonlocal _feats_saved
774
+ wavs, cavp_feats, onset_feats = result
775
  wavs = [_upsample_taro(w) for w in wavs]
776
+ if not _feats_saved:
 
 
 
 
 
 
 
777
  np.save(cavp_path, cavp_feats)
778
  if onset_feats is not None:
779
  np.save(onset_path, onset_feats)
780
+ _feats_saved = True
781
+ return (wavs, cavp_feats, onset_feats)
782
+
783
+ results = [_upsample_and_save_feats(r) for r in results]
 
 
 
 
784
 
785
+ def _taro_extras(sample_idx, result, td):
786
+ return {"cavp_path": cavp_path, "onset_path": onset_path}
787
+
788
+ outputs = _post_process_samples(
789
+ results, model="taro", tmp_dir=tmp_dir,
790
+ silent_video=silent_video, segments=segments,
791
+ crossfade_s=crossfade_s, crossfade_db=crossfade_db,
792
+ total_dur_s=total_dur_s, sr=TARO_SR_OUT,
793
+ extra_meta_fn=_taro_extras,
794
+ )
795
  return _pad_outputs(outputs)
796
 
797
 
 
849
 
850
  seg_audios = []
851
  _t_mma_start = time.perf_counter()
852
+ fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
853
 
854
  for seg_i, (seg_start, seg_end) in enumerate(segments):
855
  seg_dur = seg_end - seg_start
856
  seg_path = seg_clip_paths[seg_i]
857
 
 
858
  video_info = load_video(seg_path, seg_dur)
859
  clip_frames = video_info.clip_frames.unsqueeze(0)
860
  sync_frames = video_info.sync_frames.unsqueeze(0)
 
923
  cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples)
924
 
925
  # ── CPU post-processing ──
926
+ # Resample 44100 β†’ 48000 and normalise tuples to (seg_wavs, ...)
927
+ resampled = []
928
+ for seg_audios, sr in results:
929
  if sr != TARGET_SR:
930
  print(f"[MMAudio upsample] resampling {sr}Hz β†’ {TARGET_SR}Hz (sinc, CPU) …")
931
  seg_audios = [_resample_to_target(w, sr) for w in seg_audios]
932
  print(f"[MMAudio upsample] done β€” {len(seg_audios)} seg(s) @ {TARGET_SR}Hz")
933
+ resampled.append((seg_audios,))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
934
 
935
+ outputs = _post_process_samples(
936
+ resampled, model="mmaudio", tmp_dir=tmp_dir,
937
+ silent_video=silent_video, segments=segments,
938
+ crossfade_s=crossfade_s, crossfade_db=crossfade_db,
939
+ total_dur_s=total_dur_s, sr=TARGET_SR,
940
+ )
941
  return _pad_outputs(outputs)
942
 
943
 
 
1081
  crossfade_s, crossfade_db, num_samples)
1082
 
1083
  # ── CPU post-processing (no GPU needed) ──
1084
+ def _hunyuan_extras(sample_idx, result, td):
1085
+ _, _sr, text_feats = result
1086
+ path = os.path.join(td, f"hunyuan_{sample_idx}_text_feats.pt")
1087
+ torch.save(text_feats, path)
1088
+ return {"text_feats_path": path}
1089
+
1090
+ outputs = _post_process_samples(
1091
+ results, model="hunyuan", tmp_dir=tmp_dir,
1092
+ silent_video=silent_video, segments=segments,
1093
+ crossfade_s=crossfade_s, crossfade_db=crossfade_db,
1094
+ total_dur_s=total_dur_s, sr=48000,
1095
+ extra_meta_fn=_hunyuan_extras,
1096
+ )
 
 
 
 
 
 
 
 
 
1097
  return _pad_outputs(outputs)
1098
 
1099
 
 
1107
  # 4. Returns (new_video_path, new_audio_path, updated_seg_meta, new_waveform_html)
1108
  # ================================================================== #
1109
 
1110
+ def _preload_taro_regen_ctx(meta: dict) -> dict:
1111
+ """Pre-load TARO CAVP/onset features on CPU for regen.
1112
+ Returns a dict suitable for _regen_taro_gpu._cpu_ctx."""
1113
+ cavp_path = meta.get("cavp_path", "")
1114
+ onset_path = meta.get("onset_path", "")
1115
+ ctx = {}
1116
+ if cavp_path and os.path.exists(cavp_path) and onset_path and os.path.exists(onset_path):
1117
+ ctx["cavp"] = np.load(cavp_path)
1118
+ ctx["onset"] = np.load(onset_path)
1119
+ return ctx
1120
+
1121
+
1122
+ def _preload_hunyuan_regen_ctx(meta: dict, seg_path: str) -> dict:
1123
+ """Pre-load HunyuanFoley text features + segment path on CPU for regen.
1124
+ Returns a dict suitable for _regen_hunyuan_gpu._cpu_ctx."""
1125
+ ctx = {"seg_path": seg_path}
1126
+ text_feats_path = meta.get("text_feats_path", "")
1127
+ if text_feats_path and os.path.exists(text_feats_path):
1128
+ ctx["text_feats"] = torch.load(text_feats_path, map_location="cpu", weights_only=False)
1129
+ return ctx
1130
+
1131
+
1132
  def _splice_and_save(new_wav, seg_idx, meta, slot_id):
1133
  """Replace wavs[seg_idx] with new_wav, re-stitch, re-save, re-mux.
1134
  Returns (video_path, audio_path, updated_meta, waveform_html).
 
1159
  _vid_base = os.path.splitext(os.path.basename(meta["video_path"]))[0]
1160
  _vid_base_clean = _vid_base.rsplit("_regen_", 1)[0]
1161
  video_path = os.path.join(tmp_dir, f"{_vid_base_clean}_regen_{_ts}.mp4")
1162
+ mux_video_audio(silent_video, audio_path, video_path, model=model)
 
 
 
 
 
 
1163
 
1164
  # Save updated segment wavs to .npy files
1165
  updated_wav_paths = _save_seg_wavs(wavs, tmp_dir, os.path.splitext(_base_clean)[0])
 
1211
  _ensure_syspath("TARO")
1212
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
1213
 
1214
+ # Use pre-loaded features from CPU wrapper (avoids np.load inside GPU window)
1215
+ ctx = _regen_taro_gpu._cpu_ctx
1216
+ if "cavp" in ctx and "onset" in ctx:
1217
+ print("[TARO regen] Using pre-loaded CAVP + onset features (CPU cache hit)")
1218
+ cavp_feats = ctx["cavp"]
1219
+ onset_feats = ctx["onset"]
1220
  else:
1221
  print("[TARO regen] Cache miss β€” re-extracting CAVP + onset features")
1222
  from TARO.onset_util import extract_onset
 
1249
  meta = json.loads(seg_meta_json)
1250
  seg_idx = int(seg_idx)
1251
 
1252
+ # CPU: pre-load cached features so np.load doesn't happen inside GPU window
1253
+ _regen_taro_gpu._cpu_ctx = _preload_taro_regen_ctx(meta)
1254
+
1255
  # GPU: inference only
1256
  new_wav = _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
1257
  seed_val, cfg_scale, num_steps, mode,
 
1291
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
1292
  sr = seq_cfg.sampling_rate
1293
 
1294
+ # Use pre-extracted segment clip from the CPU wrapper
1295
  seg_path = _regen_mmaudio_gpu._cpu_ctx.get("seg_path")
1296
+ assert seg_path, "[MMAudio regen] seg_path not set β€” wrapper must pre-extract segment clip"
 
 
 
 
 
1297
 
1298
  rng = torch.Generator(device=device)
1299
  rng.manual_seed(random.randint(0, 2**32 - 1))
 
1387
 
1388
  # Use pre-extracted segment clip from wrapper
1389
  seg_path = _regen_hunyuan_gpu._cpu_ctx.get("seg_path")
1390
+ assert seg_path, "[HunyuanFoley regen] seg_path not set β€” wrapper must pre-extract segment clip"
 
 
 
 
1391
 
1392
+ # Use pre-loaded text_feats from CPU wrapper (avoids torch.load inside GPU window)
1393
+ ctx = _regen_hunyuan_gpu._cpu_ctx
1394
+ if "text_feats" in ctx:
1395
+ print("[HunyuanFoley regen] Using pre-loaded text features (CPU cache hit)")
1396
  from hunyuanvideo_foley.utils.feature_utils import encode_video_features
1397
  visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
1398
+ text_feats = ctx["text_feats"].to(device)
1399
  else:
1400
  print("[HunyuanFoley regen] Cache miss β€” extracting text + visual features")
1401
  visual_feats, text_feats, seg_audio_len = feature_process(
 
1426
  seg_start, seg_end = meta["segments"][seg_idx]
1427
  seg_dur = seg_end - seg_start
1428
 
1429
+ # CPU: pre-extract segment clip + pre-load cached text features
1430
  tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
1431
  seg_path = _extract_segment_clip(
1432
  meta["silent_video"], seg_start, seg_dur,
1433
  os.path.join(tmp_dir, "regen_seg.mp4"),
1434
  )
1435
+ _regen_hunyuan_gpu._cpu_ctx = _preload_hunyuan_regen_ctx(meta, seg_path)
1436
 
1437
  # GPU: inference only
1438
  new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
 
1468
 
1469
  def _resample_to_slot_sr(wav: np.ndarray, src_sr: int, dst_sr: int,
1470
  slot_wav_ref: np.ndarray = None) -> np.ndarray:
1471
+ """Resample *wav* from src_sr to dst_sr, then match channel layout to
1472
+ *slot_wav_ref* (the first existing segment in the slot).
1473
 
1474
  TARO is mono (T,), MMAudio/Hunyuan are stereo (C, T). Mixing them
1475
  without normalisation causes a shape mismatch in _cf_join. Rules:
1476
+ - stereo β†’ mono : average channels
1477
+ - mono β†’ stereo: duplicate the single channel
1478
  """
1479
+ wav = _resample_to_target(wav, src_sr, dst_sr)
1480
+
1481
+ # Match channel layout to the slot's existing segments
 
 
 
 
 
 
 
 
 
1482
  if slot_wav_ref is not None:
1483
  slot_stereo = slot_wav_ref.ndim == 2
1484
  wav_stereo = wav.ndim == 2
 
1514
  pending_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
1515
  yield gr.update(), gr.update(value=pending_html)
1516
 
1517
+ # CPU: pre-load cached features so np.load doesn't happen inside GPU window
1518
+ _regen_taro_gpu._cpu_ctx = _preload_taro_regen_ctx(meta)
1519
+
1520
  new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
1521
  seed_val, cfg_scale, num_steps, mode,
1522
  crossfade_s, crossfade_db, slot_id)
 
1571
  meta["silent_video"], seg_start, seg_end - seg_start,
1572
  os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1573
  )
1574
+ _regen_hunyuan_gpu._cpu_ctx = _preload_hunyuan_regen_ctx(meta, seg_path)
1575
 
1576
  new_wav_raw, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
1577
  prompt, negative_prompt, seed_val,
 
2407
  var lbl = document.getElementById('wf_seglabel_' + slot_id);
2408
  if (hadError) {
2409
  var toastMsg = typeof errMsg === 'string' ? errMsg : JSON.stringify(errMsg);
2410
+ // Restore previous waveform HTML and video src
2411
  if (preRegenWaveHtml !== null) {
2412
  var waveEl2 = document.getElementById('slot_wave_' + slot_id);
2413
  if (waveEl2) waveEl2.innerHTML = preRegenWaveHtml;
 
2416
  var vidElR = document.getElementById('slot_vid_' + slot_id);
2417
  if (vidElR) { var vR = vidElR.querySelector('video'); if (vR) { vR.setAttribute('src', preRegenVideoSrc); vR.src = preRegenVideoSrc; vR.load(); } }
2418
  }
2419
+ // Flash the waveform iframe border red so it's obvious the segment didn't change
2420
+ var iframeEl = document.getElementById('wf_iframe_' + slot_id);
2421
+ if (!iframeEl) {
2422
+ // waveform may have been restored into preRegenWaveHtml β€” find via slot_wave wrapper
2423
+ var waveWrap = document.getElementById('slot_wave_' + slot_id);
2424
+ if (waveWrap) iframeEl = waveWrap.querySelector('iframe[id^="wf_iframe_"]');
2425
+ }
2426
+ if (iframeEl) {
2427
+ iframeEl.style.transition = 'box-shadow 0.15s';
2428
+ iframeEl.style.boxShadow = '0 0 0 2px #e05252';
2429
+ setTimeout(function() { iframeEl.style.boxShadow = 'none'; }, 3000);
2430
+ }
2431
+ // Pick a human-readable message based on the error text
2432
+ var isAbort = toastMsg.toLowerCase().indexOf('aborted') !== -1;
2433
+ var isTimeout = toastMsg.toLowerCase().indexOf('timeout') !== -1;
2434
+ var userMsg = isAbort || isTimeout
2435
+ ? '\u26a0\ufe0f GPU cold-start β€” segment unchanged, try again'
2436
+ : '\u26a0\ufe0f Regen failed β€” segment unchanged';
2437
  var statusBar = document.getElementById('wf_statusbar_' + slot_id);
2438
  if (statusBar) {
2439
  statusBar.style.color = '#e05252';
2440
+ statusBar.textContent = userMsg;
2441
+ setTimeout(function() { statusBar.style.color = '#888'; statusBar.textContent = 'Click a segment to regenerate \u00a0|\u00a0 Playhead syncs to video'; }, 8000);
2442
+ }
2443
+ if (lbl) {
2444
+ lbl.style.color = '#e05252';
2445
+ lbl.textContent = isAbort || isTimeout ? 'Cold-start abort β€” segment unchanged, try again' : 'Regen failed β€” segment unchanged';
2446
+ setTimeout(function() { lbl.style.color = '#aaa'; lbl.textContent = ''; }, 8000);
2447
  }
 
2448
  } else {
2449
  if (lbl) lbl.textContent = 'Done';
2450
  var src = _pendingVideoSrc;