fix: pass CPU context as ctx_json argument to @spaces.GPU functions
ZeroGPU runs GPU functions on its own worker thread pool — thread-local
storage and per-thread dicts both fail because the writer and reader are
on different threads with different thread IDs.
The only reliable approach: pass context as a JSON string argument.
ZeroGPU forwards all arguments to the GPU worker unchanged.
Changes:
- Add ctx_json='{}' parameter to all 6 @spaces.GPU functions
(_taro_gpu_infer, _mmaudio_gpu_infer, _hunyuan_gpu_infer,
_regen_taro_gpu, _regen_mmaudio_gpu, _regen_hunyuan_gpu)
- Each wrapper serialises its pre-computed data to json.dumps({...})
and passes it as the last positional argument
- GPU functions parse with json.loads(ctx_json)
- TARO/HunyuanFoley regen: numpy/tensor features loaded directly from
disk paths already stored in seg_meta_json — no pre-serialisation needed
- Remove dead _preload_taro_regen_ctx, _preload_hunyuan_regen_ctx helpers
- Remove _CTX/_ctx_set/_ctx_get infrastructure (replaced by arg passing)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
|
@@ -124,32 +124,10 @@ print(f"[startup] All downloads done in {time.perf_counter() - _t_dl_start:.1f}s
|
|
| 124 |
# SHARED CONSTANTS / HELPERS #
|
| 125 |
# ================================================================== #
|
| 126 |
|
| 127 |
-
#
|
| 128 |
-
#
|
| 129 |
-
#
|
| 130 |
-
#
|
| 131 |
-
# ZeroGPU dispatches @spaces.GPU functions on its OWN worker thread, not
|
| 132 |
-
# the Gradio handler thread. threading.local() values are invisible across
|
| 133 |
-
# threads, so the GPU worker would always see an empty namespace.
|
| 134 |
-
#
|
| 135 |
-
# SOLUTION β a plain dict keyed by (caller_thread_id, context_name):
|
| 136 |
-
# The wrapper writes _CTX[(tid, key)] = value before calling the GPU fn.
|
| 137 |
-
# The GPU fn reads _CTX.get((tid, key)) β same tid because ZeroGPU runs
|
| 138 |
-
# the function synchronously on behalf of the calling thread (the caller
|
| 139 |
-
# blocks until the GPU task completes, so there is no concurrent write).
|
| 140 |
-
# Entries are deleted after the GPU fn reads them to avoid memory leaks.
|
| 141 |
-
_CTX: dict = {}
|
| 142 |
-
_CTX_LOCK = threading.Lock()
|
| 143 |
-
|
| 144 |
-
def _ctx_set(key: str, value) -> None:
|
| 145 |
-
tid = threading.get_ident()
|
| 146 |
-
with _CTX_LOCK:
|
| 147 |
-
_CTX[(tid, key)] = value
|
| 148 |
-
|
| 149 |
-
def _ctx_get(key: str, default=None):
|
| 150 |
-
tid = threading.get_ident()
|
| 151 |
-
with _CTX_LOCK:
|
| 152 |
-
return _CTX.pop((tid, key), default)
|
| 153 |
|
| 154 |
MAX_SLOTS = 8 # max parallel generation slots shown in UI
|
| 155 |
MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video β€ ~64 s at 8 s/seg)
|
|
@@ -791,7 +769,7 @@ def _cpu_preprocess(video_file: str, model_dur: float,
|
|
| 791 |
|
| 792 |
@spaces.GPU(duration=_taro_duration)
|
| 793 |
def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
|
| 794 |
-
crossfade_s, crossfade_db, num_samples):
|
| 795 |
"""GPU-only TARO inference β model loading + feature extraction + diffusion.
|
| 796 |
Returns list of (wavs_list, onset_feats) per sample."""
|
| 797 |
seed_val = int(seed_val)
|
|
@@ -807,8 +785,7 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
|
|
| 807 |
from TARO.onset_util import extract_onset
|
| 808 |
from TARO.samplers import euler_sampler, euler_maruyama_sampler
|
| 809 |
|
| 810 |
-
|
| 811 |
-
ctx = _ctx_get("taro_gen_ctx")
|
| 812 |
tmp_dir = ctx["tmp_dir"]
|
| 813 |
silent_video = ctx["silent_video"]
|
| 814 |
segments = ctx["segments"]
|
|
@@ -880,15 +857,14 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
|
|
| 880 |
tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
|
| 881 |
video_file, TARO_MODEL_DUR, crossfade_s)
|
| 882 |
|
| 883 |
-
|
| 884 |
-
_ctx_set("taro_gen_ctx", {
|
| 885 |
"tmp_dir": tmp_dir, "silent_video": silent_video,
|
| 886 |
"segments": segments, "total_dur_s": total_dur_s,
|
| 887 |
})
|
| 888 |
|
| 889 |
# ββ GPU inference only ββ
|
| 890 |
results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
|
| 891 |
-
crossfade_s, crossfade_db, num_samples)
|
| 892 |
|
| 893 |
# ββ CPU post-processing (no GPU needed) ββ
|
| 894 |
# Upsample 16kHz β 48kHz and normalise result tuples to (seg_wavs, ...)
|
|
@@ -945,7 +921,8 @@ def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
|
|
| 945 |
|
| 946 |
@spaces.GPU(duration=_mmaudio_duration)
|
| 947 |
def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
| 948 |
-
cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples
|
|
|
|
| 949 |
"""GPU-only MMAudio inference β model loading + flow-matching generation.
|
| 950 |
Returns list of (seg_audios, sr) per sample."""
|
| 951 |
_ensure_syspath("MMAudio")
|
|
@@ -960,7 +937,7 @@ def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
|
| 960 |
|
| 961 |
net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
|
| 962 |
|
| 963 |
-
ctx =
|
| 964 |
segments = ctx["segments"]
|
| 965 |
seg_clip_paths = ctx["seg_clip_paths"]
|
| 966 |
|
|
@@ -1039,13 +1016,12 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
|
|
| 1039 |
for i, (s, e) in enumerate(segments)
|
| 1040 |
]
|
| 1041 |
|
| 1042 |
-
|
| 1043 |
-
"segments": segments, "seg_clip_paths": seg_clip_paths,
|
| 1044 |
-
})
|
| 1045 |
|
| 1046 |
# ββ GPU inference only ββ
|
| 1047 |
results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
| 1048 |
-
cfg_strength, num_steps, crossfade_s, crossfade_db,
|
|
|
|
| 1049 |
|
| 1050 |
# ββ CPU post-processing ββ
|
| 1051 |
# Resample 44100 β 48000 and normalise tuples to (seg_wavs, ...)
|
|
@@ -1090,7 +1066,8 @@ def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
|
|
| 1090 |
|
| 1091 |
@spaces.GPU(duration=_hunyuan_duration)
|
| 1092 |
def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
| 1093 |
-
guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
|
|
|
|
| 1094 |
"""GPU-only HunyuanFoley inference β model loading + feature extraction + denoising.
|
| 1095 |
Returns list of (seg_wavs, sr, text_feats) per sample."""
|
| 1096 |
_ensure_syspath("HunyuanVideo-Foley")
|
|
@@ -1109,7 +1086,7 @@ def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
|
| 1109 |
|
| 1110 |
model_dict, cfg = _load_hunyuan_model(device, model_size)
|
| 1111 |
|
| 1112 |
-
ctx =
|
| 1113 |
segments = ctx["segments"]
|
| 1114 |
total_dur_s = ctx["total_dur_s"]
|
| 1115 |
dummy_seg_path = ctx["dummy_seg_path"]
|
|
@@ -1193,7 +1170,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
|
|
| 1193 |
for i, (s, e) in enumerate(segments)
|
| 1194 |
]
|
| 1195 |
|
| 1196 |
-
|
| 1197 |
"segments": segments, "total_dur_s": total_dur_s,
|
| 1198 |
"dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
|
| 1199 |
})
|
|
@@ -1201,7 +1178,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
|
|
| 1201 |
# ββ GPU inference only ββ
|
| 1202 |
results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
| 1203 |
guidance_scale, num_steps, model_size,
|
| 1204 |
-
crossfade_s, crossfade_db, num_samples)
|
| 1205 |
|
| 1206 |
# ββ CPU post-processing (no GPU needed) ββ
|
| 1207 |
def _hunyuan_extras(sample_idx, result, td):
|
|
@@ -1230,27 +1207,6 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
|
|
| 1230 |
# 4. Returns (new_video_path, new_audio_path, updated_seg_meta, new_waveform_html)
|
| 1231 |
# ================================================================== #
|
| 1232 |
|
| 1233 |
-
def _preload_taro_regen_ctx(meta: dict) -> dict:
|
| 1234 |
-
"""Pre-load TARO CAVP/onset features on CPU for regen.
|
| 1235 |
-
Returns a dict for _ctx_set("taro_regen_ctx", ...)."""
|
| 1236 |
-
cavp_path = meta.get("cavp_path", "")
|
| 1237 |
-
onset_path = meta.get("onset_path", "")
|
| 1238 |
-
ctx = {}
|
| 1239 |
-
if cavp_path and os.path.exists(cavp_path) and onset_path and os.path.exists(onset_path):
|
| 1240 |
-
ctx["cavp"] = np.load(cavp_path)
|
| 1241 |
-
ctx["onset"] = np.load(onset_path)
|
| 1242 |
-
return ctx
|
| 1243 |
-
|
| 1244 |
-
|
| 1245 |
-
def _preload_hunyuan_regen_ctx(meta: dict, seg_path: str) -> dict:
|
| 1246 |
-
"""Pre-load HunyuanFoley text features + segment path on CPU for regen.
|
| 1247 |
-
Returns a dict for _ctx_set("hunyuan_regen_ctx", ...)."""
|
| 1248 |
-
ctx = {"seg_path": seg_path}
|
| 1249 |
-
text_feats_path = meta.get("text_feats_path", "")
|
| 1250 |
-
if text_feats_path and os.path.exists(text_feats_path):
|
| 1251 |
-
ctx["text_feats"] = torch.load(text_feats_path, map_location="cpu", weights_only=False)
|
| 1252 |
-
return ctx
|
| 1253 |
-
|
| 1254 |
|
| 1255 |
def _splice_and_save(new_wav, seg_idx, meta, slot_id):
|
| 1256 |
"""Replace wavs[seg_idx] with new_wav, re-stitch, re-save, re-mux.
|
|
@@ -1334,12 +1290,13 @@ def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
|
|
| 1334 |
_ensure_syspath("TARO")
|
| 1335 |
from TARO.samplers import euler_sampler, euler_maruyama_sampler
|
| 1336 |
|
| 1337 |
-
#
|
| 1338 |
-
|
| 1339 |
-
|
| 1340 |
-
|
| 1341 |
-
|
| 1342 |
-
|
|
|
|
| 1343 |
else:
|
| 1344 |
print("[TARO regen] Cache miss β re-extracting CAVP + onset features")
|
| 1345 |
from TARO.onset_util import extract_onset
|
|
@@ -1348,7 +1305,6 @@ def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
|
|
| 1348 |
tmp_dir = tempfile.mkdtemp()
|
| 1349 |
cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
|
| 1350 |
onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
|
| 1351 |
-
# Free feature extractors before loading inference models
|
| 1352 |
del extract_cavp, onset_model
|
| 1353 |
if torch.cuda.is_available():
|
| 1354 |
torch.cuda.empty_cache()
|
|
@@ -1372,10 +1328,7 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
|
|
| 1372 |
meta = json.loads(seg_meta_json)
|
| 1373 |
seg_idx = int(seg_idx)
|
| 1374 |
|
| 1375 |
-
#
|
| 1376 |
-
_ctx_set("taro_regen_ctx", _preload_taro_regen_ctx(meta))
|
| 1377 |
-
|
| 1378 |
-
# GPU: inference only
|
| 1379 |
new_wav = _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
|
| 1380 |
seed_val, cfg_scale, num_steps, mode,
|
| 1381 |
crossfade_s, crossfade_db, slot_id)
|
|
@@ -1398,7 +1351,8 @@ def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
|
|
| 1398 |
@spaces.GPU(duration=_mmaudio_regen_duration)
|
| 1399 |
def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
|
| 1400 |
prompt, negative_prompt, seed_val,
|
| 1401 |
-
cfg_strength, num_steps, crossfade_s, crossfade_db,
|
|
|
|
| 1402 |
"""GPU-only MMAudio regen β returns (new_wav, sr) for a single segment."""
|
| 1403 |
meta = json.loads(seg_meta_json)
|
| 1404 |
seg_idx = int(seg_idx)
|
|
@@ -1414,8 +1368,7 @@ def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
|
|
| 1414 |
net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
|
| 1415 |
sr = seq_cfg.sampling_rate
|
| 1416 |
|
| 1417 |
-
|
| 1418 |
-
seg_path = _ctx_get("mmaudio_regen_ctx", {}).get("seg_path")
|
| 1419 |
assert seg_path, "[MMAudio regen] seg_path not set β wrapper must pre-extract segment clip"
|
| 1420 |
|
| 1421 |
rng = torch.Generator(device=device)
|
|
@@ -1457,12 +1410,13 @@ def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
|
|
| 1457 |
meta["silent_video"], seg_start, seg_dur,
|
| 1458 |
os.path.join(tmp_dir, "regen_seg.mp4"),
|
| 1459 |
)
|
| 1460 |
-
|
| 1461 |
|
| 1462 |
# GPU: inference only
|
| 1463 |
new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
|
| 1464 |
prompt, negative_prompt, seed_val,
|
| 1465 |
-
cfg_strength, num_steps, crossfade_s, crossfade_db,
|
|
|
|
| 1466 |
|
| 1467 |
# Resample to 48kHz if needed (MMAudio outputs at 44100 Hz)
|
| 1468 |
if sr != TARGET_SR:
|
|
@@ -1489,7 +1443,7 @@ def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
|
|
| 1489 |
def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
|
| 1490 |
prompt, negative_prompt, seed_val,
|
| 1491 |
guidance_scale, num_steps, model_size,
|
| 1492 |
-
crossfade_s, crossfade_db, slot_id=None):
|
| 1493 |
"""GPU-only HunyuanFoley regen β returns (new_wav, sr) for a single segment."""
|
| 1494 |
meta = json.loads(seg_meta_json)
|
| 1495 |
seg_idx = int(seg_idx)
|
|
@@ -1506,16 +1460,16 @@ def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
|
|
| 1506 |
|
| 1507 |
set_global_seed(random.randint(0, 2**32 - 1))
|
| 1508 |
|
| 1509 |
-
|
| 1510 |
-
ctx = _ctx_get("hunyuan_regen_ctx", {})
|
| 1511 |
seg_path = ctx.get("seg_path")
|
| 1512 |
assert seg_path, "[HunyuanFoley regen] seg_path not set β wrapper must pre-extract segment clip"
|
| 1513 |
|
| 1514 |
-
|
| 1515 |
-
|
|
|
|
| 1516 |
from hunyuanvideo_foley.utils.feature_utils import encode_video_features
|
| 1517 |
visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
|
| 1518 |
-
text_feats =
|
| 1519 |
else:
|
| 1520 |
print("[HunyuanFoley regen] Cache miss β extracting text + visual features")
|
| 1521 |
visual_feats, text_feats, seg_audio_len = feature_process(
|
|
@@ -1550,13 +1504,16 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
|
|
| 1550 |
meta["silent_video"], seg_start, seg_dur,
|
| 1551 |
os.path.join(tmp_dir, "regen_seg.mp4"),
|
| 1552 |
)
|
| 1553 |
-
|
|
|
|
|
|
|
|
|
|
| 1554 |
|
| 1555 |
# GPU: inference only
|
| 1556 |
new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
|
| 1557 |
prompt, negative_prompt, seed_val,
|
| 1558 |
guidance_scale, num_steps, model_size,
|
| 1559 |
-
crossfade_s, crossfade_db, slot_id)
|
| 1560 |
|
| 1561 |
meta["sr"] = sr
|
| 1562 |
|
|
@@ -1650,7 +1607,7 @@ def xregen_taro(seg_idx, state_json, slot_id,
|
|
| 1650 |
meta = json.loads(state_json)
|
| 1651 |
|
| 1652 |
def _run():
|
| 1653 |
-
|
| 1654 |
wav = _regen_taro_gpu(None, seg_idx, state_json,
|
| 1655 |
seed_val, cfg_scale, num_steps, mode,
|
| 1656 |
crossfade_s, crossfade_db, slot_id)
|
|
@@ -1673,11 +1630,11 @@ def xregen_mmaudio(seg_idx, state_json, slot_id,
|
|
| 1673 |
meta["silent_video"], seg_start, seg_end - seg_start,
|
| 1674 |
os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
|
| 1675 |
)
|
| 1676 |
-
|
| 1677 |
wav, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
|
| 1678 |
prompt, negative_prompt, seed_val,
|
| 1679 |
cfg_strength, num_steps,
|
| 1680 |
-
crossfade_s, crossfade_db, slot_id)
|
| 1681 |
return wav, src_sr
|
| 1682 |
|
| 1683 |
yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
|
|
@@ -1698,11 +1655,14 @@ def xregen_hunyuan(seg_idx, state_json, slot_id,
|
|
| 1698 |
meta["silent_video"], seg_start, seg_end - seg_start,
|
| 1699 |
os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
|
| 1700 |
)
|
| 1701 |
-
|
|
|
|
|
|
|
|
|
|
| 1702 |
wav, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
|
| 1703 |
prompt, negative_prompt, seed_val,
|
| 1704 |
guidance_scale, num_steps, model_size,
|
| 1705 |
-
crossfade_s, crossfade_db, slot_id)
|
| 1706 |
return wav, src_sr
|
| 1707 |
|
| 1708 |
yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
|
|
|
|
| 124 |
# SHARED CONSTANTS / HELPERS #
|
| 125 |
# ================================================================== #
|
| 126 |
|
| 127 |
+
# CPU β GPU context passing: each wrapper serialises pre-computed CPU data
|
| 128 |
+
# into a JSON string and passes it as the last argument (ctx_json) to the
|
| 129 |
+
# @spaces.GPU function. ZeroGPU forwards all arguments to the GPU worker
|
| 130 |
+
# unchanged, so no shared state or thread-local tricks are needed.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
MAX_SLOTS = 8 # max parallel generation slots shown in UI
|
| 133 |
MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video β€ ~64 s at 8 s/seg)
|
|
|
|
| 769 |
|
| 770 |
@spaces.GPU(duration=_taro_duration)
|
| 771 |
def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
|
| 772 |
+
crossfade_s, crossfade_db, num_samples, ctx_json="{}"):
|
| 773 |
"""GPU-only TARO inference β model loading + feature extraction + diffusion.
|
| 774 |
Returns list of (wavs_list, onset_feats) per sample."""
|
| 775 |
seed_val = int(seed_val)
|
|
|
|
| 785 |
from TARO.onset_util import extract_onset
|
| 786 |
from TARO.samplers import euler_sampler, euler_maruyama_sampler
|
| 787 |
|
| 788 |
+
ctx = json.loads(ctx_json)
|
|
|
|
| 789 |
tmp_dir = ctx["tmp_dir"]
|
| 790 |
silent_video = ctx["silent_video"]
|
| 791 |
segments = ctx["segments"]
|
|
|
|
| 857 |
tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
|
| 858 |
video_file, TARO_MODEL_DUR, crossfade_s)
|
| 859 |
|
| 860 |
+
ctx_json = json.dumps({
|
|
|
|
| 861 |
"tmp_dir": tmp_dir, "silent_video": silent_video,
|
| 862 |
"segments": segments, "total_dur_s": total_dur_s,
|
| 863 |
})
|
| 864 |
|
| 865 |
# ββ GPU inference only ββ
|
| 866 |
results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
|
| 867 |
+
crossfade_s, crossfade_db, num_samples, ctx_json)
|
| 868 |
|
| 869 |
# ββ CPU post-processing (no GPU needed) ββ
|
| 870 |
# Upsample 16kHz β 48kHz and normalise result tuples to (seg_wavs, ...)
|
|
|
|
| 921 |
|
| 922 |
@spaces.GPU(duration=_mmaudio_duration)
|
| 923 |
def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
| 924 |
+
cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples,
|
| 925 |
+
ctx_json="{}"):
|
| 926 |
"""GPU-only MMAudio inference β model loading + flow-matching generation.
|
| 927 |
Returns list of (seg_audios, sr) per sample."""
|
| 928 |
_ensure_syspath("MMAudio")
|
|
|
|
| 937 |
|
| 938 |
net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
|
| 939 |
|
| 940 |
+
ctx = json.loads(ctx_json)
|
| 941 |
segments = ctx["segments"]
|
| 942 |
seg_clip_paths = ctx["seg_clip_paths"]
|
| 943 |
|
|
|
|
| 1016 |
for i, (s, e) in enumerate(segments)
|
| 1017 |
]
|
| 1018 |
|
| 1019 |
+
ctx_json = json.dumps({"segments": segments, "seg_clip_paths": seg_clip_paths})
|
|
|
|
|
|
|
| 1020 |
|
| 1021 |
# ββ GPU inference only ββ
|
| 1022 |
results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
| 1023 |
+
cfg_strength, num_steps, crossfade_s, crossfade_db,
|
| 1024 |
+
num_samples, ctx_json)
|
| 1025 |
|
| 1026 |
# ββ CPU post-processing ββ
|
| 1027 |
# Resample 44100 β 48000 and normalise tuples to (seg_wavs, ...)
|
|
|
|
| 1066 |
|
| 1067 |
@spaces.GPU(duration=_hunyuan_duration)
|
| 1068 |
def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
| 1069 |
+
guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
|
| 1070 |
+
num_samples, ctx_json="{}"):
|
| 1071 |
"""GPU-only HunyuanFoley inference β model loading + feature extraction + denoising.
|
| 1072 |
Returns list of (seg_wavs, sr, text_feats) per sample."""
|
| 1073 |
_ensure_syspath("HunyuanVideo-Foley")
|
|
|
|
| 1086 |
|
| 1087 |
model_dict, cfg = _load_hunyuan_model(device, model_size)
|
| 1088 |
|
| 1089 |
+
ctx = json.loads(ctx_json)
|
| 1090 |
segments = ctx["segments"]
|
| 1091 |
total_dur_s = ctx["total_dur_s"]
|
| 1092 |
dummy_seg_path = ctx["dummy_seg_path"]
|
|
|
|
| 1170 |
for i, (s, e) in enumerate(segments)
|
| 1171 |
]
|
| 1172 |
|
| 1173 |
+
ctx_json = json.dumps({
|
| 1174 |
"segments": segments, "total_dur_s": total_dur_s,
|
| 1175 |
"dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
|
| 1176 |
})
|
|
|
|
| 1178 |
# ββ GPU inference only ββ
|
| 1179 |
results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
| 1180 |
guidance_scale, num_steps, model_size,
|
| 1181 |
+
crossfade_s, crossfade_db, num_samples, ctx_json)
|
| 1182 |
|
| 1183 |
# ββ CPU post-processing (no GPU needed) ββ
|
| 1184 |
def _hunyuan_extras(sample_idx, result, td):
|
|
|
|
| 1207 |
# 4. Returns (new_video_path, new_audio_path, updated_seg_meta, new_waveform_html)
|
| 1208 |
# ================================================================== #
|
| 1209 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1210 |
|
| 1211 |
def _splice_and_save(new_wav, seg_idx, meta, slot_id):
|
| 1212 |
"""Replace wavs[seg_idx] with new_wav, re-stitch, re-save, re-mux.
|
|
|
|
| 1290 |
_ensure_syspath("TARO")
|
| 1291 |
from TARO.samplers import euler_sampler, euler_maruyama_sampler
|
| 1292 |
|
| 1293 |
+
# Load cached CAVP/onset features from .npy files (CPU I/O, fast, outside GPU budget)
|
| 1294 |
+
cavp_path = meta.get("cavp_path", "")
|
| 1295 |
+
onset_path = meta.get("onset_path", "")
|
| 1296 |
+
if cavp_path and os.path.exists(cavp_path) and onset_path and os.path.exists(onset_path):
|
| 1297 |
+
print("[TARO regen] Loading cached CAVP + onset features from disk")
|
| 1298 |
+
cavp_feats = np.load(cavp_path)
|
| 1299 |
+
onset_feats = np.load(onset_path)
|
| 1300 |
else:
|
| 1301 |
print("[TARO regen] Cache miss β re-extracting CAVP + onset features")
|
| 1302 |
from TARO.onset_util import extract_onset
|
|
|
|
| 1305 |
tmp_dir = tempfile.mkdtemp()
|
| 1306 |
cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
|
| 1307 |
onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
|
|
|
|
| 1308 |
del extract_cavp, onset_model
|
| 1309 |
if torch.cuda.is_available():
|
| 1310 |
torch.cuda.empty_cache()
|
|
|
|
| 1328 |
meta = json.loads(seg_meta_json)
|
| 1329 |
seg_idx = int(seg_idx)
|
| 1330 |
|
| 1331 |
+
# GPU: inference β CAVP/onset features loaded from disk paths in seg_meta_json
|
|
|
|
|
|
|
|
|
|
| 1332 |
new_wav = _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
|
| 1333 |
seed_val, cfg_scale, num_steps, mode,
|
| 1334 |
crossfade_s, crossfade_db, slot_id)
|
|
|
|
| 1351 |
@spaces.GPU(duration=_mmaudio_regen_duration)
|
| 1352 |
def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
|
| 1353 |
prompt, negative_prompt, seed_val,
|
| 1354 |
+
cfg_strength, num_steps, crossfade_s, crossfade_db,
|
| 1355 |
+
slot_id=None, ctx_json="{}"):
|
| 1356 |
"""GPU-only MMAudio regen β returns (new_wav, sr) for a single segment."""
|
| 1357 |
meta = json.loads(seg_meta_json)
|
| 1358 |
seg_idx = int(seg_idx)
|
|
|
|
| 1368 |
net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
|
| 1369 |
sr = seq_cfg.sampling_rate
|
| 1370 |
|
| 1371 |
+
seg_path = json.loads(ctx_json).get("seg_path")
|
|
|
|
| 1372 |
assert seg_path, "[MMAudio regen] seg_path not set β wrapper must pre-extract segment clip"
|
| 1373 |
|
| 1374 |
rng = torch.Generator(device=device)
|
|
|
|
| 1410 |
meta["silent_video"], seg_start, seg_dur,
|
| 1411 |
os.path.join(tmp_dir, "regen_seg.mp4"),
|
| 1412 |
)
|
| 1413 |
+
ctx_json = json.dumps({"seg_path": seg_path})
|
| 1414 |
|
| 1415 |
# GPU: inference only
|
| 1416 |
new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
|
| 1417 |
prompt, negative_prompt, seed_val,
|
| 1418 |
+
cfg_strength, num_steps, crossfade_s, crossfade_db,
|
| 1419 |
+
slot_id, ctx_json)
|
| 1420 |
|
| 1421 |
# Resample to 48kHz if needed (MMAudio outputs at 44100 Hz)
|
| 1422 |
if sr != TARGET_SR:
|
|
|
|
| 1443 |
def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
|
| 1444 |
prompt, negative_prompt, seed_val,
|
| 1445 |
guidance_scale, num_steps, model_size,
|
| 1446 |
+
crossfade_s, crossfade_db, slot_id=None, ctx_json="{}"):
|
| 1447 |
"""GPU-only HunyuanFoley regen β returns (new_wav, sr) for a single segment."""
|
| 1448 |
meta = json.loads(seg_meta_json)
|
| 1449 |
seg_idx = int(seg_idx)
|
|
|
|
| 1460 |
|
| 1461 |
set_global_seed(random.randint(0, 2**32 - 1))
|
| 1462 |
|
| 1463 |
+
ctx = json.loads(ctx_json)
|
|
|
|
| 1464 |
seg_path = ctx.get("seg_path")
|
| 1465 |
assert seg_path, "[HunyuanFoley regen] seg_path not set β wrapper must pre-extract segment clip"
|
| 1466 |
|
| 1467 |
+
text_feats_path = ctx.get("text_feats_path", "")
|
| 1468 |
+
if text_feats_path and os.path.exists(text_feats_path):
|
| 1469 |
+
print("[HunyuanFoley regen] Loading cached text features from disk")
|
| 1470 |
from hunyuanvideo_foley.utils.feature_utils import encode_video_features
|
| 1471 |
visual_feats, seg_audio_len = encode_video_features(seg_path, model_dict)
|
| 1472 |
+
text_feats = torch.load(text_feats_path, map_location=device, weights_only=False)
|
| 1473 |
else:
|
| 1474 |
print("[HunyuanFoley regen] Cache miss β extracting text + visual features")
|
| 1475 |
visual_feats, text_feats, seg_audio_len = feature_process(
|
|
|
|
| 1504 |
meta["silent_video"], seg_start, seg_dur,
|
| 1505 |
os.path.join(tmp_dir, "regen_seg.mp4"),
|
| 1506 |
)
|
| 1507 |
+
ctx_json = json.dumps({
|
| 1508 |
+
"seg_path": seg_path,
|
| 1509 |
+
"text_feats_path": meta.get("text_feats_path", ""),
|
| 1510 |
+
})
|
| 1511 |
|
| 1512 |
# GPU: inference only
|
| 1513 |
new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
|
| 1514 |
prompt, negative_prompt, seed_val,
|
| 1515 |
guidance_scale, num_steps, model_size,
|
| 1516 |
+
crossfade_s, crossfade_db, slot_id, ctx_json)
|
| 1517 |
|
| 1518 |
meta["sr"] = sr
|
| 1519 |
|
|
|
|
| 1607 |
meta = json.loads(state_json)
|
| 1608 |
|
| 1609 |
def _run():
|
| 1610 |
+
# CAVP/onset features are loaded from disk paths inside the GPU fn
|
| 1611 |
wav = _regen_taro_gpu(None, seg_idx, state_json,
|
| 1612 |
seed_val, cfg_scale, num_steps, mode,
|
| 1613 |
crossfade_s, crossfade_db, slot_id)
|
|
|
|
| 1630 |
meta["silent_video"], seg_start, seg_end - seg_start,
|
| 1631 |
os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
|
| 1632 |
)
|
| 1633 |
+
ctx_json = json.dumps({"seg_path": seg_path})
|
| 1634 |
wav, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
|
| 1635 |
prompt, negative_prompt, seed_val,
|
| 1636 |
cfg_strength, num_steps,
|
| 1637 |
+
crossfade_s, crossfade_db, slot_id, ctx_json)
|
| 1638 |
return wav, src_sr
|
| 1639 |
|
| 1640 |
yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
|
|
|
|
| 1655 |
meta["silent_video"], seg_start, seg_end - seg_start,
|
| 1656 |
os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
|
| 1657 |
)
|
| 1658 |
+
ctx_json = json.dumps({
|
| 1659 |
+
"seg_path": seg_path,
|
| 1660 |
+
"text_feats_path": meta.get("text_feats_path", ""),
|
| 1661 |
+
})
|
| 1662 |
wav, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
|
| 1663 |
prompt, negative_prompt, seed_val,
|
| 1664 |
guidance_scale, num_steps, model_size,
|
| 1665 |
+
crossfade_s, crossfade_db, slot_id, ctx_json)
|
| 1666 |
return wav, src_sr
|
| 1667 |
|
| 1668 |
yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
|