BoxOfColors Claude Sonnet 4.6 committed on
Commit
4c173d1
·
1 Parent(s): e152b28

fix: pass CPU context as ctx_json argument to @spaces.GPU functions

Browse files

ZeroGPU runs GPU functions on its own worker thread pool — thread-local
storage and per-thread dicts both fail because the writer and reader are
on different threads with different thread IDs.

The only reliable approach: pass context as a JSON string argument.
ZeroGPU forwards all arguments to the GPU worker unchanged.

Changes:
- Add ctx_json='{}' parameter to all 6 @spaces.GPU functions
(_taro_gpu_infer, _mmaudio_gpu_infer, _hunyuan_gpu_infer,
_regen_taro_gpu, _regen_mmaudio_gpu, _regen_hunyuan_gpu)
- Each wrapper serialises its pre-computed data to json.dumps({...})
and passes it as the last positional argument
- GPU functions parse with json.loads(ctx_json)
- TARO/HunyuanFoley regen: numpy/tensor features loaded directly from
disk paths already stored in seg_meta_json — no pre-serialisation needed
- Remove dead _preload_taro_regen_ctx, _preload_hunyuan_regen_ctx helpers
- Remove _CTX/_ctx_set/_ctx_get infrastructure (replaced by arg passing)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +52 -92
app.py CHANGED
@@ -124,32 +124,10 @@ print(f"[startup] All downloads done in {time.perf_counter() - _t_dl_start:.1f}s
124
  # SHARED CONSTANTS / HELPERS #
125
  # ================================================================== #
126
 
127
- # Per-caller-thread context store for CPU β†’ GPU context passing.
128
- # Replaces the fragile function-attribute pattern (_fn._cpu_ctx = {...}).
129
- #
130
- # WHY NOT threading.local():
131
- # ZeroGPU dispatches @spaces.GPU functions on its OWN worker thread, not
132
- # the Gradio handler thread. threading.local() values are invisible across
133
- # threads, so the GPU worker would always see an empty namespace.
134
- #
135
- # SOLUTION β€” a plain dict keyed by (caller_thread_id, context_name):
136
- # The wrapper writes _CTX[(tid, key)] = value before calling the GPU fn.
137
- # The GPU fn reads _CTX.get((tid, key)) β€” same tid because ZeroGPU runs
138
- # the function synchronously on behalf of the calling thread (the caller
139
- # blocks until the GPU task completes, so there is no concurrent write).
140
- # Entries are deleted after the GPU fn reads them to avoid memory leaks.
141
- _CTX: dict = {}
142
- _CTX_LOCK = threading.Lock()
143
-
144
- def _ctx_set(key: str, value) -> None:
145
- tid = threading.get_ident()
146
- with _CTX_LOCK:
147
- _CTX[(tid, key)] = value
148
-
149
- def _ctx_get(key: str, default=None):
150
- tid = threading.get_ident()
151
- with _CTX_LOCK:
152
- return _CTX.pop((tid, key), default)
153
 
154
  MAX_SLOTS = 8 # max parallel generation slots shown in UI
155
  MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video ≀ ~64 s at 8 s/seg)
@@ -791,7 +769,7 @@ def _cpu_preprocess(video_file: str, model_dur: float,
791
 
792
  @spaces.GPU(duration=_taro_duration)
793
  def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
794
- crossfade_s, crossfade_db, num_samples):
795
  """GPU-only TARO inference β€” model loading + feature extraction + diffusion.
796
  Returns list of (wavs_list, onset_feats) per sample."""
797
  seed_val = int(seed_val)
@@ -807,8 +785,7 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
807
  from TARO.onset_util import extract_onset
808
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
809
 
810
- # Use pre-computed CPU results passed via cross-thread context store
811
- ctx = _ctx_get("taro_gen_ctx")
812
  tmp_dir = ctx["tmp_dir"]
813
  silent_video = ctx["silent_video"]
814
  segments = ctx["segments"]
@@ -880,15 +857,14 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
880
  tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
881
  video_file, TARO_MODEL_DUR, crossfade_s)
882
 
883
- # Pass pre-computed CPU results to the GPU function via cross-thread context store
884
- _ctx_set("taro_gen_ctx", {
885
  "tmp_dir": tmp_dir, "silent_video": silent_video,
886
  "segments": segments, "total_dur_s": total_dur_s,
887
  })
888
 
889
  # ── GPU inference only ──
890
  results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
891
- crossfade_s, crossfade_db, num_samples)
892
 
893
  # ── CPU post-processing (no GPU needed) ──
894
  # Upsample 16kHz β†’ 48kHz and normalise result tuples to (seg_wavs, ...)
@@ -945,7 +921,8 @@ def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
945
 
946
  @spaces.GPU(duration=_mmaudio_duration)
947
  def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
948
- cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
 
949
  """GPU-only MMAudio inference β€” model loading + flow-matching generation.
950
  Returns list of (seg_audios, sr) per sample."""
951
  _ensure_syspath("MMAudio")
@@ -960,7 +937,7 @@ def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
960
 
961
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
962
 
963
- ctx = _ctx_get("mmaudio_gen_ctx")
964
  segments = ctx["segments"]
965
  seg_clip_paths = ctx["seg_clip_paths"]
966
 
@@ -1039,13 +1016,12 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
1039
  for i, (s, e) in enumerate(segments)
1040
  ]
1041
 
1042
- _ctx_set("mmaudio_gen_ctx", {
1043
- "segments": segments, "seg_clip_paths": seg_clip_paths,
1044
- })
1045
 
1046
  # ── GPU inference only ──
1047
  results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1048
- cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples)
 
1049
 
1050
  # ── CPU post-processing ──
1051
  # Resample 44100 β†’ 48000 and normalise tuples to (seg_wavs, ...)
@@ -1090,7 +1066,8 @@ def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
1090
 
1091
  @spaces.GPU(duration=_hunyuan_duration)
1092
  def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1093
- guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
 
1094
  """GPU-only HunyuanFoley inference β€” model loading + feature extraction + denoising.
1095
  Returns list of (seg_wavs, sr, text_feats) per sample."""
1096
  _ensure_syspath("HunyuanVideo-Foley")
@@ -1109,7 +1086,7 @@ def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1109
 
1110
  model_dict, cfg = _load_hunyuan_model(device, model_size)
1111
 
1112
- ctx = _ctx_get("hunyuan_gen_ctx")
1113
  segments = ctx["segments"]
1114
  total_dur_s = ctx["total_dur_s"]
1115
  dummy_seg_path = ctx["dummy_seg_path"]
@@ -1193,7 +1170,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
1193
  for i, (s, e) in enumerate(segments)
1194
  ]
1195
 
1196
- _ctx_set("hunyuan_gen_ctx", {
1197
  "segments": segments, "total_dur_s": total_dur_s,
1198
  "dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
1199
  })
@@ -1201,7 +1178,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
1201
  # ── GPU inference only ──
1202
  results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1203
  guidance_scale, num_steps, model_size,
1204
- crossfade_s, crossfade_db, num_samples)
1205
 
1206
  # ── CPU post-processing (no GPU needed) ──
1207
  def _hunyuan_extras(sample_idx, result, td):
@@ -1230,27 +1207,6 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
1230
  # 4. Returns (new_video_path, new_audio_path, updated_seg_meta, new_waveform_html)
1231
  # ================================================================== #
1232
 
1233
- def _preload_taro_regen_ctx(meta: dict) -> dict:
1234
- """Pre-load TARO CAVP/onset features on CPU for regen.
1235
- Returns a dict for _ctx_set("taro_regen_ctx", ...)."""
1236
- cavp_path = meta.get("cavp_path", "")
1237
- onset_path = meta.get("onset_path", "")
1238
- ctx = {}
1239
- if cavp_path and os.path.exists(cavp_path) and onset_path and os.path.exists(onset_path):
1240
- ctx["cavp"] = np.load(cavp_path)
1241
- ctx["onset"] = np.load(onset_path)
1242
- return ctx
1243
-
1244
-
1245
- def _preload_hunyuan_regen_ctx(meta: dict, seg_path: str) -> dict:
1246
- """Pre-load HunyuanFoley text features + segment path on CPU for regen.
1247
- Returns a dict for _ctx_set("hunyuan_regen_ctx", ...)."""
1248
- ctx = {"seg_path": seg_path}
1249
- text_feats_path = meta.get("text_feats_path", "")
1250
- if text_feats_path and os.path.exists(text_feats_path):
1251
- ctx["text_feats"] = torch.load(text_feats_path, map_location="cpu", weights_only=False)
1252
- return ctx
1253
-
1254
 
1255
  def _splice_and_save(new_wav, seg_idx, meta, slot_id):
1256
  """Replace wavs[seg_idx] with new_wav, re-stitch, re-save, re-mux.
@@ -1334,12 +1290,13 @@ def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
1334
  _ensure_syspath("TARO")
1335
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
1336
 
1337
- # Use pre-loaded features from CPU wrapper (avoids np.load inside GPU window)
1338
- ctx = _ctx_get("taro_regen_ctx", {})
1339
- if "cavp" in ctx and "onset" in ctx:
1340
- print("[TARO regen] Using pre-loaded CAVP + onset features (CPU cache hit)")
1341
- cavp_feats = ctx["cavp"]
1342
- onset_feats = ctx["onset"]
 
1343
  else:
1344
  print("[TARO regen] Cache miss β€” re-extracting CAVP + onset features")
1345
  from TARO.onset_util import extract_onset
@@ -1348,7 +1305,6 @@ def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
1348
  tmp_dir = tempfile.mkdtemp()
1349
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
1350
  onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
1351
- # Free feature extractors before loading inference models
1352
  del extract_cavp, onset_model
1353
  if torch.cuda.is_available():
1354
  torch.cuda.empty_cache()
@@ -1372,10 +1328,7 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
1372
  meta = json.loads(seg_meta_json)
1373
  seg_idx = int(seg_idx)
1374
 
1375
- # CPU: pre-load cached features so np.load doesn't happen inside GPU window
1376
- _ctx_set("taro_regen_ctx", _preload_taro_regen_ctx(meta))
1377
-
1378
- # GPU: inference only
1379
  new_wav = _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
1380
  seed_val, cfg_scale, num_steps, mode,
1381
  crossfade_s, crossfade_db, slot_id)
@@ -1398,7 +1351,8 @@ def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
1398
  @spaces.GPU(duration=_mmaudio_regen_duration)
1399
  def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1400
  prompt, negative_prompt, seed_val,
1401
- cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id=None):
 
1402
  """GPU-only MMAudio regen β€” returns (new_wav, sr) for a single segment."""
1403
  meta = json.loads(seg_meta_json)
1404
  seg_idx = int(seg_idx)
@@ -1414,8 +1368,7 @@ def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1414
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
1415
  sr = seq_cfg.sampling_rate
1416
 
1417
- # Use pre-extracted segment clip from the CPU wrapper
1418
- seg_path = _ctx_get("mmaudio_regen_ctx", {}).get("seg_path")
1419
  assert seg_path, "[MMAudio regen] seg_path not set β€” wrapper must pre-extract segment clip"
1420
 
1421
  rng = torch.Generator(device=device)
@@ -1457,12 +1410,13 @@ def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
1457
  meta["silent_video"], seg_start, seg_dur,
1458
  os.path.join(tmp_dir, "regen_seg.mp4"),
1459
  )
1460
- _ctx_set("mmaudio_regen_ctx", {"seg_path": seg_path})
1461
 
1462
  # GPU: inference only
1463
  new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1464
  prompt, negative_prompt, seed_val,
1465
- cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id)
 
1466
 
1467
  # Resample to 48kHz if needed (MMAudio outputs at 44100 Hz)
1468
  if sr != TARGET_SR:
@@ -1489,7 +1443,7 @@ def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
1489
  def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1490
  prompt, negative_prompt, seed_val,
1491
  guidance_scale, num_steps, model_size,
1492
- crossfade_s, crossfade_db, slot_id=None):
1493
  """GPU-only HunyuanFoley regen β€” returns (new_wav, sr) for a single segment."""
1494
  meta = json.loads(seg_meta_json)
1495
  seg_idx = int(seg_idx)
@@ -1506,16 +1460,16 @@ def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1506
 
1507
  set_global_seed(random.randint(0, 2**32 - 1))
1508
 
1509
- # Use pre-extracted segment clip + text_feats from CPU wrapper
1510
- ctx = _ctx_get("hunyuan_regen_ctx", {})
1511
  seg_path = ctx.get("seg_path")
1512
  assert seg_path, "[HunyuanFoley regen] seg_path not set β€” wrapper must pre-extract segment clip"
1513
 
1514
- if "text_feats" in ctx:
1515
- print("[HunyuanFoley regen] Using pre-loaded text features (CPU cache hit)")
 
1516
  from hunyuanvideo_foley.utils.feature_utils import encode_video_features
1517
  visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
1518
- text_feats = ctx["text_feats"].to(device)
1519
  else:
1520
  print("[HunyuanFoley regen] Cache miss β€” extracting text + visual features")
1521
  visual_feats, text_feats, seg_audio_len = feature_process(
@@ -1550,13 +1504,16 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1550
  meta["silent_video"], seg_start, seg_dur,
1551
  os.path.join(tmp_dir, "regen_seg.mp4"),
1552
  )
1553
- _ctx_set("hunyuan_regen_ctx", _preload_hunyuan_regen_ctx(meta, seg_path))
 
 
 
1554
 
1555
  # GPU: inference only
1556
  new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1557
  prompt, negative_prompt, seed_val,
1558
  guidance_scale, num_steps, model_size,
1559
- crossfade_s, crossfade_db, slot_id)
1560
 
1561
  meta["sr"] = sr
1562
 
@@ -1650,7 +1607,7 @@ def xregen_taro(seg_idx, state_json, slot_id,
1650
  meta = json.loads(state_json)
1651
 
1652
  def _run():
1653
- _ctx_set("taro_regen_ctx", _preload_taro_regen_ctx(meta))
1654
  wav = _regen_taro_gpu(None, seg_idx, state_json,
1655
  seed_val, cfg_scale, num_steps, mode,
1656
  crossfade_s, crossfade_db, slot_id)
@@ -1673,11 +1630,11 @@ def xregen_mmaudio(seg_idx, state_json, slot_id,
1673
  meta["silent_video"], seg_start, seg_end - seg_start,
1674
  os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1675
  )
1676
- _ctx_set("mmaudio_regen_ctx", {"seg_path": seg_path})
1677
  wav, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
1678
  prompt, negative_prompt, seed_val,
1679
  cfg_strength, num_steps,
1680
- crossfade_s, crossfade_db, slot_id)
1681
  return wav, src_sr
1682
 
1683
  yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
@@ -1698,11 +1655,14 @@ def xregen_hunyuan(seg_idx, state_json, slot_id,
1698
  meta["silent_video"], seg_start, seg_end - seg_start,
1699
  os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1700
  )
1701
- _ctx_set("hunyuan_regen_ctx", _preload_hunyuan_regen_ctx(meta, seg_path))
 
 
 
1702
  wav, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
1703
  prompt, negative_prompt, seed_val,
1704
  guidance_scale, num_steps, model_size,
1705
- crossfade_s, crossfade_db, slot_id)
1706
  return wav, src_sr
1707
 
1708
  yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
 
124
  # SHARED CONSTANTS / HELPERS #
125
  # ================================================================== #
126
 
127
+ # CPU β†’ GPU context passing: each wrapper serialises pre-computed CPU data
128
+ # into a JSON string and passes it as the last argument (ctx_json) to the
129
+ # @spaces.GPU function. ZeroGPU forwards all arguments to the GPU worker
130
+ # unchanged, so no shared state or thread-local tricks are needed.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  MAX_SLOTS = 8 # max parallel generation slots shown in UI
133
  MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video ≀ ~64 s at 8 s/seg)
 
769
 
770
  @spaces.GPU(duration=_taro_duration)
771
  def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
772
+ crossfade_s, crossfade_db, num_samples, ctx_json="{}"):
773
  """GPU-only TARO inference β€” model loading + feature extraction + diffusion.
774
  Returns list of (wavs_list, onset_feats) per sample."""
775
  seed_val = int(seed_val)
 
785
  from TARO.onset_util import extract_onset
786
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
787
 
788
+ ctx = json.loads(ctx_json)
 
789
  tmp_dir = ctx["tmp_dir"]
790
  silent_video = ctx["silent_video"]
791
  segments = ctx["segments"]
 
857
  tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
858
  video_file, TARO_MODEL_DUR, crossfade_s)
859
 
860
+ ctx_json = json.dumps({
 
861
  "tmp_dir": tmp_dir, "silent_video": silent_video,
862
  "segments": segments, "total_dur_s": total_dur_s,
863
  })
864
 
865
  # ── GPU inference only ──
866
  results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
867
+ crossfade_s, crossfade_db, num_samples, ctx_json)
868
 
869
  # ── CPU post-processing (no GPU needed) ──
870
  # Upsample 16kHz β†’ 48kHz and normalise result tuples to (seg_wavs, ...)
 
921
 
922
  @spaces.GPU(duration=_mmaudio_duration)
923
  def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
924
+ cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples,
925
+ ctx_json="{}"):
926
  """GPU-only MMAudio inference β€” model loading + flow-matching generation.
927
  Returns list of (seg_audios, sr) per sample."""
928
  _ensure_syspath("MMAudio")
 
937
 
938
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
939
 
940
+ ctx = json.loads(ctx_json)
941
  segments = ctx["segments"]
942
  seg_clip_paths = ctx["seg_clip_paths"]
943
 
 
1016
  for i, (s, e) in enumerate(segments)
1017
  ]
1018
 
1019
+ ctx_json = json.dumps({"segments": segments, "seg_clip_paths": seg_clip_paths})
 
 
1020
 
1021
  # ── GPU inference only ──
1022
  results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1023
+ cfg_strength, num_steps, crossfade_s, crossfade_db,
1024
+ num_samples, ctx_json)
1025
 
1026
  # ── CPU post-processing ──
1027
  # Resample 44100 β†’ 48000 and normalise tuples to (seg_wavs, ...)
 
1066
 
1067
  @spaces.GPU(duration=_hunyuan_duration)
1068
  def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1069
+ guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
1070
+ num_samples, ctx_json="{}"):
1071
  """GPU-only HunyuanFoley inference β€” model loading + feature extraction + denoising.
1072
  Returns list of (seg_wavs, sr, text_feats) per sample."""
1073
  _ensure_syspath("HunyuanVideo-Foley")
 
1086
 
1087
  model_dict, cfg = _load_hunyuan_model(device, model_size)
1088
 
1089
+ ctx = json.loads(ctx_json)
1090
  segments = ctx["segments"]
1091
  total_dur_s = ctx["total_dur_s"]
1092
  dummy_seg_path = ctx["dummy_seg_path"]
 
1170
  for i, (s, e) in enumerate(segments)
1171
  ]
1172
 
1173
+ ctx_json = json.dumps({
1174
  "segments": segments, "total_dur_s": total_dur_s,
1175
  "dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
1176
  })
 
1178
  # ── GPU inference only ──
1179
  results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1180
  guidance_scale, num_steps, model_size,
1181
+ crossfade_s, crossfade_db, num_samples, ctx_json)
1182
 
1183
  # ── CPU post-processing (no GPU needed) ──
1184
  def _hunyuan_extras(sample_idx, result, td):
 
1207
  # 4. Returns (new_video_path, new_audio_path, updated_seg_meta, new_waveform_html)
1208
  # ================================================================== #
1209
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1210
 
1211
  def _splice_and_save(new_wav, seg_idx, meta, slot_id):
1212
  """Replace wavs[seg_idx] with new_wav, re-stitch, re-save, re-mux.
 
1290
  _ensure_syspath("TARO")
1291
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
1292
 
1293
+ # Load cached CAVP/onset features from .npy files (CPU I/O, fast, outside GPU budget)
1294
+ cavp_path = meta.get("cavp_path", "")
1295
+ onset_path = meta.get("onset_path", "")
1296
+ if cavp_path and os.path.exists(cavp_path) and onset_path and os.path.exists(onset_path):
1297
+ print("[TARO regen] Loading cached CAVP + onset features from disk")
1298
+ cavp_feats = np.load(cavp_path)
1299
+ onset_feats = np.load(onset_path)
1300
  else:
1301
  print("[TARO regen] Cache miss β€” re-extracting CAVP + onset features")
1302
  from TARO.onset_util import extract_onset
 
1305
  tmp_dir = tempfile.mkdtemp()
1306
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
1307
  onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
 
1308
  del extract_cavp, onset_model
1309
  if torch.cuda.is_available():
1310
  torch.cuda.empty_cache()
 
1328
  meta = json.loads(seg_meta_json)
1329
  seg_idx = int(seg_idx)
1330
 
1331
+ # GPU: inference β€” CAVP/onset features loaded from disk paths in seg_meta_json
 
 
 
1332
  new_wav = _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
1333
  seed_val, cfg_scale, num_steps, mode,
1334
  crossfade_s, crossfade_db, slot_id)
 
1351
  @spaces.GPU(duration=_mmaudio_regen_duration)
1352
  def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1353
  prompt, negative_prompt, seed_val,
1354
+ cfg_strength, num_steps, crossfade_s, crossfade_db,
1355
+ slot_id=None, ctx_json="{}"):
1356
  """GPU-only MMAudio regen β€” returns (new_wav, sr) for a single segment."""
1357
  meta = json.loads(seg_meta_json)
1358
  seg_idx = int(seg_idx)
 
1368
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
1369
  sr = seq_cfg.sampling_rate
1370
 
1371
+ seg_path = json.loads(ctx_json).get("seg_path")
 
1372
  assert seg_path, "[MMAudio regen] seg_path not set β€” wrapper must pre-extract segment clip"
1373
 
1374
  rng = torch.Generator(device=device)
 
1410
  meta["silent_video"], seg_start, seg_dur,
1411
  os.path.join(tmp_dir, "regen_seg.mp4"),
1412
  )
1413
+ ctx_json = json.dumps({"seg_path": seg_path})
1414
 
1415
  # GPU: inference only
1416
  new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1417
  prompt, negative_prompt, seed_val,
1418
+ cfg_strength, num_steps, crossfade_s, crossfade_db,
1419
+ slot_id, ctx_json)
1420
 
1421
  # Resample to 48kHz if needed (MMAudio outputs at 44100 Hz)
1422
  if sr != TARGET_SR:
 
1443
  def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1444
  prompt, negative_prompt, seed_val,
1445
  guidance_scale, num_steps, model_size,
1446
+ crossfade_s, crossfade_db, slot_id=None, ctx_json="{}"):
1447
  """GPU-only HunyuanFoley regen β€” returns (new_wav, sr) for a single segment."""
1448
  meta = json.loads(seg_meta_json)
1449
  seg_idx = int(seg_idx)
 
1460
 
1461
  set_global_seed(random.randint(0, 2**32 - 1))
1462
 
1463
+ ctx = json.loads(ctx_json)
 
1464
  seg_path = ctx.get("seg_path")
1465
  assert seg_path, "[HunyuanFoley regen] seg_path not set β€” wrapper must pre-extract segment clip"
1466
 
1467
+ text_feats_path = ctx.get("text_feats_path", "")
1468
+ if text_feats_path and os.path.exists(text_feats_path):
1469
+ print("[HunyuanFoley regen] Loading cached text features from disk")
1470
  from hunyuanvideo_foley.utils.feature_utils import encode_video_features
1471
  visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
1472
+ text_feats = torch.load(text_feats_path, map_location=device, weights_only=False)
1473
  else:
1474
  print("[HunyuanFoley regen] Cache miss β€” extracting text + visual features")
1475
  visual_feats, text_feats, seg_audio_len = feature_process(
 
1504
  meta["silent_video"], seg_start, seg_dur,
1505
  os.path.join(tmp_dir, "regen_seg.mp4"),
1506
  )
1507
+ ctx_json = json.dumps({
1508
+ "seg_path": seg_path,
1509
+ "text_feats_path": meta.get("text_feats_path", ""),
1510
+ })
1511
 
1512
  # GPU: inference only
1513
  new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1514
  prompt, negative_prompt, seed_val,
1515
  guidance_scale, num_steps, model_size,
1516
+ crossfade_s, crossfade_db, slot_id, ctx_json)
1517
 
1518
  meta["sr"] = sr
1519
 
 
1607
  meta = json.loads(state_json)
1608
 
1609
  def _run():
1610
+ # CAVP/onset features are loaded from disk paths inside the GPU fn
1611
  wav = _regen_taro_gpu(None, seg_idx, state_json,
1612
  seed_val, cfg_scale, num_steps, mode,
1613
  crossfade_s, crossfade_db, slot_id)
 
1630
  meta["silent_video"], seg_start, seg_end - seg_start,
1631
  os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1632
  )
1633
+ ctx_json = json.dumps({"seg_path": seg_path})
1634
  wav, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
1635
  prompt, negative_prompt, seed_val,
1636
  cfg_strength, num_steps,
1637
+ crossfade_s, crossfade_db, slot_id, ctx_json)
1638
  return wav, src_sr
1639
 
1640
  yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
 
1655
  meta["silent_video"], seg_start, seg_end - seg_start,
1656
  os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1657
  )
1658
+ ctx_json = json.dumps({
1659
+ "seg_path": seg_path,
1660
+ "text_feats_path": meta.get("text_feats_path", ""),
1661
+ })
1662
  wav, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
1663
  prompt, negative_prompt, seed_val,
1664
  guidance_scale, num_steps, model_size,
1665
+ crossfade_s, crossfade_db, slot_id, ctx_json)
1666
  return wav, src_sr
1667
 
1668
  yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)