Spaces:
Running on Zero
Commit ·
e7175d4
1
Parent(s): ac67bf3
fix: remove ctx_key from all function signatures — use fn-name-keyed global dict
Browse files
ctx_key as a function argument exposed it to Gradio's API endpoint discovery,
causing 'Too many arguments provided for the endpoint' errors and GPU task aborts.
Fix: remove ctx_key from all @spaces.GPU function signatures and their duration
callables. Store/retrieve context using _ctx_store(fn_name, data) /
_ctx_load(fn_name) — a global dict keyed by function name. This is safe because
ZeroGPU is synchronous (wrapper blocks until GPU fn returns), so only one call
per GPU function is in-flight at a time within a single process.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
app.py
CHANGED
|
@@ -124,35 +124,30 @@ print(f"[startup] All downloads done in {time.perf_counter() - _t_dl_start:.1f}s
|
|
| 124 |
# SHARED CONSTANTS / HELPERS #
|
| 125 |
# ================================================================== #
|
| 126 |
|
| 127 |
-
# CPU → GPU context passing via
|
| 128 |
#
|
| 129 |
-
# ZeroGPU
|
| 130 |
-
# threading.local()
|
| 131 |
-
#
|
| 132 |
-
#
|
| 133 |
-
# gets dropped or set to None before the GPU fn runs.
|
| 134 |
#
|
| 135 |
-
# Solution:
|
| 136 |
-
#
|
| 137 |
-
#
|
| 138 |
-
#
|
| 139 |
-
#
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
def _ctx_store(data: dict) -> str:
|
| 146 |
-
"""Store *data* in the global context dict; return the UUID key."""
|
| 147 |
-
key = _uuid_mod.uuid4().hex
|
| 148 |
with _GPU_CTX_LOCK:
|
| 149 |
-
_GPU_CTX[
|
| 150 |
-
return key
|
| 151 |
|
| 152 |
-
def _ctx_load(
|
| 153 |
-
"""Pop and return the context dict
|
| 154 |
with _GPU_CTX_LOCK:
|
| 155 |
-
return _GPU_CTX.pop(
|
| 156 |
|
| 157 |
MAX_SLOTS = 8 # max parallel generation slots shown in UI
|
| 158 |
MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video ≤ ~64 s at 8 s/seg)
|
|
@@ -577,7 +572,7 @@ def _taro_calc_max_samples(total_dur_s: float, num_steps: int, crossfade_s: floa
|
|
| 577 |
|
| 578 |
|
| 579 |
def _taro_duration(video_file, seed_val, cfg_scale, num_steps, mode,
|
| 580 |
-
crossfade_s, crossfade_db, num_samples
|
| 581 |
"""Pre-GPU callable — must match _taro_gpu_infer's input order exactly."""
|
| 582 |
return _estimate_gpu_duration("taro", int(num_samples), int(num_steps),
|
| 583 |
video_file=video_file, crossfade_s=crossfade_s)
|
|
@@ -794,7 +789,7 @@ def _cpu_preprocess(video_file: str, model_dur: float,
|
|
| 794 |
|
| 795 |
@spaces.GPU(duration=_taro_duration)
|
| 796 |
def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
|
| 797 |
-
crossfade_s, crossfade_db, num_samples
|
| 798 |
"""GPU-only TARO inference — model loading + feature extraction + diffusion.
|
| 799 |
Returns list of (wavs_list, onset_feats) per sample."""
|
| 800 |
seed_val = int(seed_val)
|
|
@@ -810,7 +805,7 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
|
|
| 810 |
from TARO.onset_util import extract_onset
|
| 811 |
from TARO.samplers import euler_sampler, euler_maruyama_sampler
|
| 812 |
|
| 813 |
-
ctx = _ctx_load(
|
| 814 |
tmp_dir = ctx["tmp_dir"]
|
| 815 |
silent_video = ctx["silent_video"]
|
| 816 |
segments = ctx["segments"]
|
|
@@ -882,14 +877,14 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
|
|
| 882 |
tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
|
| 883 |
video_file, TARO_MODEL_DUR, crossfade_s)
|
| 884 |
|
| 885 |
-
|
| 886 |
"tmp_dir": tmp_dir, "silent_video": silent_video,
|
| 887 |
"segments": segments, "total_dur_s": total_dur_s,
|
| 888 |
})
|
| 889 |
|
| 890 |
# ── GPU inference only ──
|
| 891 |
results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
|
| 892 |
-
crossfade_s, crossfade_db, num_samples
|
| 893 |
|
| 894 |
# ── CPU post-processing (no GPU needed) ──
|
| 895 |
# Upsample 16kHz → 48kHz and normalise result tuples to (seg_wavs, ...)
|
|
@@ -938,8 +933,7 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
|
|
| 938 |
|
| 939 |
|
| 940 |
def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
|
| 941 |
-
cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples
|
| 942 |
-
ctx_key=""):
|
| 943 |
"""Pre-GPU callable — must match _mmaudio_gpu_infer's input order exactly."""
|
| 944 |
return _estimate_gpu_duration("mmaudio", int(num_samples), int(num_steps),
|
| 945 |
video_file=video_file, crossfade_s=crossfade_s)
|
|
@@ -947,8 +941,7 @@ def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
|
|
| 947 |
|
| 948 |
@spaces.GPU(duration=_mmaudio_duration)
|
| 949 |
def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
| 950 |
-
cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples
|
| 951 |
-
ctx_key=""):
|
| 952 |
"""GPU-only MMAudio inference — model loading + flow-matching generation.
|
| 953 |
Returns list of (seg_audios, sr) per sample."""
|
| 954 |
_ensure_syspath("MMAudio")
|
|
@@ -963,7 +956,7 @@ def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
|
| 963 |
|
| 964 |
net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
|
| 965 |
|
| 966 |
-
ctx = _ctx_load(
|
| 967 |
segments = ctx["segments"]
|
| 968 |
seg_clip_paths = ctx["seg_clip_paths"]
|
| 969 |
|
|
@@ -1042,12 +1035,12 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
|
|
| 1042 |
for i, (s, e) in enumerate(segments)
|
| 1043 |
]
|
| 1044 |
|
| 1045 |
-
|
| 1046 |
|
| 1047 |
# ── GPU inference only ──
|
| 1048 |
results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
| 1049 |
cfg_strength, num_steps, crossfade_s, crossfade_db,
|
| 1050 |
-
num_samples
|
| 1051 |
|
| 1052 |
# ── CPU post-processing ──
|
| 1053 |
# Resample 44100 → 48000 and normalise tuples to (seg_wavs, ...)
|
|
@@ -1085,7 +1078,7 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
|
|
| 1085 |
|
| 1086 |
def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
|
| 1087 |
guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
|
| 1088 |
-
num_samples
|
| 1089 |
"""Pre-GPU callable — must match _hunyuan_gpu_infer's input order exactly."""
|
| 1090 |
return _estimate_gpu_duration("hunyuan", int(num_samples), int(num_steps),
|
| 1091 |
video_file=video_file, crossfade_s=crossfade_s)
|
|
@@ -1094,7 +1087,7 @@ def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
|
|
| 1094 |
@spaces.GPU(duration=_hunyuan_duration)
|
| 1095 |
def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
| 1096 |
guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
|
| 1097 |
-
num_samples
|
| 1098 |
"""GPU-only HunyuanFoley inference — model loading + feature extraction + denoising.
|
| 1099 |
Returns list of (seg_wavs, sr, text_feats) per sample."""
|
| 1100 |
_ensure_syspath("HunyuanVideo-Foley")
|
|
@@ -1113,7 +1106,7 @@ def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
|
| 1113 |
|
| 1114 |
model_dict, cfg = _load_hunyuan_model(device, model_size)
|
| 1115 |
|
| 1116 |
-
ctx = _ctx_load(
|
| 1117 |
segments = ctx["segments"]
|
| 1118 |
total_dur_s = ctx["total_dur_s"]
|
| 1119 |
dummy_seg_path = ctx["dummy_seg_path"]
|
|
@@ -1197,7 +1190,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
|
|
| 1197 |
for i, (s, e) in enumerate(segments)
|
| 1198 |
]
|
| 1199 |
|
| 1200 |
-
|
| 1201 |
"segments": segments, "total_dur_s": total_dur_s,
|
| 1202 |
"dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
|
| 1203 |
})
|
|
@@ -1205,7 +1198,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
|
|
| 1205 |
# ── GPU inference only ──
|
| 1206 |
results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
| 1207 |
guidance_scale, num_steps, model_size,
|
| 1208 |
-
crossfade_s, crossfade_db, num_samples
|
| 1209 |
|
| 1210 |
# ── CPU post-processing (no GPU needed) ──
|
| 1211 |
def _hunyuan_extras(sample_idx, result, td):
|
|
@@ -1285,7 +1278,7 @@ def _splice_and_save(new_wav, seg_idx, meta, slot_id):
|
|
| 1285 |
|
| 1286 |
def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
|
| 1287 |
seed_val, cfg_scale, num_steps, mode,
|
| 1288 |
-
crossfade_s, crossfade_db, slot_id=None
|
| 1289 |
# If cached CAVP/onset features exist, skip ~10s feature-extractor overhead
|
| 1290 |
try:
|
| 1291 |
meta = json.loads(seg_meta_json)
|
|
@@ -1305,7 +1298,7 @@ def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
|
|
| 1305 |
@spaces.GPU(duration=_taro_regen_duration)
|
| 1306 |
def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
|
| 1307 |
seed_val, cfg_scale, num_steps, mode,
|
| 1308 |
-
crossfade_s, crossfade_db, slot_id=None
|
| 1309 |
"""GPU-only TARO regen — returns new_wav for a single segment."""
|
| 1310 |
meta = json.loads(seg_meta_json)
|
| 1311 |
seg_idx = int(seg_idx)
|
|
@@ -1372,7 +1365,7 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
|
|
| 1372 |
def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
|
| 1373 |
prompt, negative_prompt, seed_val,
|
| 1374 |
cfg_strength, num_steps, crossfade_s, crossfade_db,
|
| 1375 |
-
slot_id=None
|
| 1376 |
return _estimate_regen_duration("mmaudio", int(num_steps))
|
| 1377 |
|
| 1378 |
|
|
@@ -1380,7 +1373,7 @@ def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
|
|
| 1380 |
def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
|
| 1381 |
prompt, negative_prompt, seed_val,
|
| 1382 |
cfg_strength, num_steps, crossfade_s, crossfade_db,
|
| 1383 |
-
slot_id=None
|
| 1384 |
"""GPU-only MMAudio regen — returns (new_wav, sr) for a single segment."""
|
| 1385 |
meta = json.loads(seg_meta_json)
|
| 1386 |
seg_idx = int(seg_idx)
|
|
@@ -1396,7 +1389,7 @@ def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
|
|
| 1396 |
net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
|
| 1397 |
sr = seq_cfg.sampling_rate
|
| 1398 |
|
| 1399 |
-
seg_path = _ctx_load(
|
| 1400 |
assert seg_path, "[MMAudio regen] seg_path not set — wrapper must pre-extract segment clip"
|
| 1401 |
|
| 1402 |
rng = torch.Generator(device=device)
|
|
@@ -1438,13 +1431,13 @@ def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
|
|
| 1438 |
meta["silent_video"], seg_start, seg_dur,
|
| 1439 |
os.path.join(tmp_dir, "regen_seg.mp4"),
|
| 1440 |
)
|
| 1441 |
-
|
| 1442 |
|
| 1443 |
# GPU: inference only
|
| 1444 |
new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
|
| 1445 |
prompt, negative_prompt, seed_val,
|
| 1446 |
cfg_strength, num_steps, crossfade_s, crossfade_db,
|
| 1447 |
-
slot_id
|
| 1448 |
|
| 1449 |
# Resample to 48kHz if needed (MMAudio outputs at 44100 Hz)
|
| 1450 |
if sr != TARGET_SR:
|
|
@@ -1463,7 +1456,7 @@ def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
|
|
| 1463 |
def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
|
| 1464 |
prompt, negative_prompt, seed_val,
|
| 1465 |
guidance_scale, num_steps, model_size,
|
| 1466 |
-
crossfade_s, crossfade_db, slot_id=None
|
| 1467 |
return _estimate_regen_duration("hunyuan", int(num_steps))
|
| 1468 |
|
| 1469 |
|
|
@@ -1471,7 +1464,7 @@ def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
|
|
| 1471 |
def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
|
| 1472 |
prompt, negative_prompt, seed_val,
|
| 1473 |
guidance_scale, num_steps, model_size,
|
| 1474 |
-
crossfade_s, crossfade_db, slot_id=None
|
| 1475 |
"""GPU-only HunyuanFoley regen — returns (new_wav, sr) for a single segment."""
|
| 1476 |
meta = json.loads(seg_meta_json)
|
| 1477 |
seg_idx = int(seg_idx)
|
|
@@ -1488,7 +1481,7 @@ def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
|
|
| 1488 |
|
| 1489 |
set_global_seed(random.randint(0, 2**32 - 1))
|
| 1490 |
|
| 1491 |
-
ctx = _ctx_load(
|
| 1492 |
seg_path = ctx.get("seg_path")
|
| 1493 |
assert seg_path, "[HunyuanFoley regen] seg_path not set — wrapper must pre-extract segment clip"
|
| 1494 |
|
|
@@ -1532,7 +1525,7 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
|
|
| 1532 |
meta["silent_video"], seg_start, seg_dur,
|
| 1533 |
os.path.join(tmp_dir, "regen_seg.mp4"),
|
| 1534 |
)
|
| 1535 |
-
|
| 1536 |
"seg_path": seg_path,
|
| 1537 |
"text_feats_path": meta.get("text_feats_path", ""),
|
| 1538 |
})
|
|
@@ -1541,7 +1534,7 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
|
|
| 1541 |
new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
|
| 1542 |
prompt, negative_prompt, seed_val,
|
| 1543 |
guidance_scale, num_steps, model_size,
|
| 1544 |
-
crossfade_s, crossfade_db, slot_id
|
| 1545 |
|
| 1546 |
meta["sr"] = sr
|
| 1547 |
|
|
@@ -1658,11 +1651,11 @@ def xregen_mmaudio(seg_idx, state_json, slot_id,
|
|
| 1658 |
meta["silent_video"], seg_start, seg_end - seg_start,
|
| 1659 |
os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
|
| 1660 |
)
|
| 1661 |
-
|
| 1662 |
wav, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
|
| 1663 |
prompt, negative_prompt, seed_val,
|
| 1664 |
cfg_strength, num_steps,
|
| 1665 |
-
crossfade_s, crossfade_db, slot_id
|
| 1666 |
return wav, src_sr
|
| 1667 |
|
| 1668 |
yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
|
|
@@ -1683,14 +1676,14 @@ def xregen_hunyuan(seg_idx, state_json, slot_id,
|
|
| 1683 |
meta["silent_video"], seg_start, seg_end - seg_start,
|
| 1684 |
os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
|
| 1685 |
)
|
| 1686 |
-
|
| 1687 |
"seg_path": seg_path,
|
| 1688 |
"text_feats_path": meta.get("text_feats_path", ""),
|
| 1689 |
})
|
| 1690 |
wav, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
|
| 1691 |
prompt, negative_prompt, seed_val,
|
| 1692 |
guidance_scale, num_steps, model_size,
|
| 1693 |
-
crossfade_s, crossfade_db, slot_id
|
| 1694 |
return wav, src_sr
|
| 1695 |
|
| 1696 |
yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
|
|
|
|
| 124 |
# SHARED CONSTANTS / HELPERS #
|
| 125 |
# ================================================================== #
|
| 126 |
|
| 127 |
+
# CPU → GPU context passing via function-name-keyed global store.
|
| 128 |
#
|
| 129 |
+
# Problem: ZeroGPU runs @spaces.GPU functions on its own worker thread, so
|
| 130 |
+
# threading.local() is invisible to the GPU worker. Passing ctx as a
|
| 131 |
+
# function argument exposes it to Gradio's API endpoint, causing
|
| 132 |
+
# "Too many arguments" errors.
|
|
|
|
| 133 |
#
|
| 134 |
+
# Solution: store context in a plain global dict keyed by function name.
|
| 135 |
+
# A per-key Lock serialises concurrent callers for the same function
|
| 136 |
+
# (ZeroGPU is already synchronous — the wrapper blocks until the GPU fn
|
| 137 |
+
# returns — so in practice only one call per GPU fn is in-flight at a time).
|
| 138 |
+
# The global dict is readable from any thread.
|
| 139 |
+
_GPU_CTX: dict = {}
|
| 140 |
+
_GPU_CTX_LOCK = threading.Lock()
|
| 141 |
+
|
| 142 |
+
def _ctx_store(fn_name: str, data: dict) -> None:
|
| 143 |
+
"""Store *data* under *fn_name* key (overwrites previous)."""
|
|
|
|
|
|
|
|
|
|
| 144 |
with _GPU_CTX_LOCK:
|
| 145 |
+
_GPU_CTX[fn_name] = data
|
|
|
|
| 146 |
|
| 147 |
+
def _ctx_load(fn_name: str) -> dict:
|
| 148 |
+
"""Pop and return the context dict stored under *fn_name*."""
|
| 149 |
with _GPU_CTX_LOCK:
|
| 150 |
+
return _GPU_CTX.pop(fn_name, {})
|
| 151 |
|
| 152 |
MAX_SLOTS = 8 # max parallel generation slots shown in UI
|
| 153 |
MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video ≤ ~64 s at 8 s/seg)
|
|
|
|
| 572 |
|
| 573 |
|
| 574 |
def _taro_duration(video_file, seed_val, cfg_scale, num_steps, mode,
|
| 575 |
+
crossfade_s, crossfade_db, num_samples):
|
| 576 |
"""Pre-GPU callable — must match _taro_gpu_infer's input order exactly."""
|
| 577 |
return _estimate_gpu_duration("taro", int(num_samples), int(num_steps),
|
| 578 |
video_file=video_file, crossfade_s=crossfade_s)
|
|
|
|
| 789 |
|
| 790 |
@spaces.GPU(duration=_taro_duration)
|
| 791 |
def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
|
| 792 |
+
crossfade_s, crossfade_db, num_samples):
|
| 793 |
"""GPU-only TARO inference — model loading + feature extraction + diffusion.
|
| 794 |
Returns list of (wavs_list, onset_feats) per sample."""
|
| 795 |
seed_val = int(seed_val)
|
|
|
|
| 805 |
from TARO.onset_util import extract_onset
|
| 806 |
from TARO.samplers import euler_sampler, euler_maruyama_sampler
|
| 807 |
|
| 808 |
+
ctx = _ctx_load("taro_gpu_infer")
|
| 809 |
tmp_dir = ctx["tmp_dir"]
|
| 810 |
silent_video = ctx["silent_video"]
|
| 811 |
segments = ctx["segments"]
|
|
|
|
| 877 |
tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
|
| 878 |
video_file, TARO_MODEL_DUR, crossfade_s)
|
| 879 |
|
| 880 |
+
_ctx_store("taro_gpu_infer", {
|
| 881 |
"tmp_dir": tmp_dir, "silent_video": silent_video,
|
| 882 |
"segments": segments, "total_dur_s": total_dur_s,
|
| 883 |
})
|
| 884 |
|
| 885 |
# ── GPU inference only ──
|
| 886 |
results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
|
| 887 |
+
crossfade_s, crossfade_db, num_samples)
|
| 888 |
|
| 889 |
# ── CPU post-processing (no GPU needed) ──
|
| 890 |
# Upsample 16kHz → 48kHz and normalise result tuples to (seg_wavs, ...)
|
|
|
|
| 933 |
|
| 934 |
|
| 935 |
def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
|
| 936 |
+
cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
|
|
|
|
| 937 |
"""Pre-GPU callable — must match _mmaudio_gpu_infer's input order exactly."""
|
| 938 |
return _estimate_gpu_duration("mmaudio", int(num_samples), int(num_steps),
|
| 939 |
video_file=video_file, crossfade_s=crossfade_s)
|
|
|
|
| 941 |
|
| 942 |
@spaces.GPU(duration=_mmaudio_duration)
|
| 943 |
def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
| 944 |
+
cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
|
|
|
|
| 945 |
"""GPU-only MMAudio inference — model loading + flow-matching generation.
|
| 946 |
Returns list of (seg_audios, sr) per sample."""
|
| 947 |
_ensure_syspath("MMAudio")
|
|
|
|
| 956 |
|
| 957 |
net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
|
| 958 |
|
| 959 |
+
ctx = _ctx_load("mmaudio_gpu_infer")
|
| 960 |
segments = ctx["segments"]
|
| 961 |
seg_clip_paths = ctx["seg_clip_paths"]
|
| 962 |
|
|
|
|
| 1035 |
for i, (s, e) in enumerate(segments)
|
| 1036 |
]
|
| 1037 |
|
| 1038 |
+
_ctx_store("mmaudio_gpu_infer", {"segments": segments, "seg_clip_paths": seg_clip_paths})
|
| 1039 |
|
| 1040 |
# ── GPU inference only ──
|
| 1041 |
results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
| 1042 |
cfg_strength, num_steps, crossfade_s, crossfade_db,
|
| 1043 |
+
num_samples)
|
| 1044 |
|
| 1045 |
# ── CPU post-processing ──
|
| 1046 |
# Resample 44100 → 48000 and normalise tuples to (seg_wavs, ...)
|
|
|
|
| 1078 |
|
| 1079 |
def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
|
| 1080 |
guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
|
| 1081 |
+
num_samples):
|
| 1082 |
"""Pre-GPU callable — must match _hunyuan_gpu_infer's input order exactly."""
|
| 1083 |
return _estimate_gpu_duration("hunyuan", int(num_samples), int(num_steps),
|
| 1084 |
video_file=video_file, crossfade_s=crossfade_s)
|
|
|
|
| 1087 |
@spaces.GPU(duration=_hunyuan_duration)
|
| 1088 |
def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
| 1089 |
guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
|
| 1090 |
+
num_samples):
|
| 1091 |
"""GPU-only HunyuanFoley inference — model loading + feature extraction + denoising.
|
| 1092 |
Returns list of (seg_wavs, sr, text_feats) per sample."""
|
| 1093 |
_ensure_syspath("HunyuanVideo-Foley")
|
|
|
|
| 1106 |
|
| 1107 |
model_dict, cfg = _load_hunyuan_model(device, model_size)
|
| 1108 |
|
| 1109 |
+
ctx = _ctx_load("hunyuan_gpu_infer")
|
| 1110 |
segments = ctx["segments"]
|
| 1111 |
total_dur_s = ctx["total_dur_s"]
|
| 1112 |
dummy_seg_path = ctx["dummy_seg_path"]
|
|
|
|
| 1190 |
for i, (s, e) in enumerate(segments)
|
| 1191 |
]
|
| 1192 |
|
| 1193 |
+
_ctx_store("hunyuan_gpu_infer", {
|
| 1194 |
"segments": segments, "total_dur_s": total_dur_s,
|
| 1195 |
"dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
|
| 1196 |
})
|
|
|
|
| 1198 |
# ── GPU inference only ──
|
| 1199 |
results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
| 1200 |
guidance_scale, num_steps, model_size,
|
| 1201 |
+
crossfade_s, crossfade_db, num_samples)
|
| 1202 |
|
| 1203 |
# ── CPU post-processing (no GPU needed) ──
|
| 1204 |
def _hunyuan_extras(sample_idx, result, td):
|
|
|
|
| 1278 |
|
| 1279 |
def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
|
| 1280 |
seed_val, cfg_scale, num_steps, mode,
|
| 1281 |
+
crossfade_s, crossfade_db, slot_id=None):
|
| 1282 |
# If cached CAVP/onset features exist, skip ~10s feature-extractor overhead
|
| 1283 |
try:
|
| 1284 |
meta = json.loads(seg_meta_json)
|
|
|
|
| 1298 |
@spaces.GPU(duration=_taro_regen_duration)
|
| 1299 |
def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
|
| 1300 |
seed_val, cfg_scale, num_steps, mode,
|
| 1301 |
+
crossfade_s, crossfade_db, slot_id=None):
|
| 1302 |
"""GPU-only TARO regen — returns new_wav for a single segment."""
|
| 1303 |
meta = json.loads(seg_meta_json)
|
| 1304 |
seg_idx = int(seg_idx)
|
|
|
|
| 1365 |
def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
|
| 1366 |
prompt, negative_prompt, seed_val,
|
| 1367 |
cfg_strength, num_steps, crossfade_s, crossfade_db,
|
| 1368 |
+
slot_id=None):
|
| 1369 |
return _estimate_regen_duration("mmaudio", int(num_steps))
|
| 1370 |
|
| 1371 |
|
|
|
|
| 1373 |
def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
|
| 1374 |
prompt, negative_prompt, seed_val,
|
| 1375 |
cfg_strength, num_steps, crossfade_s, crossfade_db,
|
| 1376 |
+
slot_id=None):
|
| 1377 |
"""GPU-only MMAudio regen — returns (new_wav, sr) for a single segment."""
|
| 1378 |
meta = json.loads(seg_meta_json)
|
| 1379 |
seg_idx = int(seg_idx)
|
|
|
|
| 1389 |
net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
|
| 1390 |
sr = seq_cfg.sampling_rate
|
| 1391 |
|
| 1392 |
+
seg_path = _ctx_load("regen_mmaudio_gpu").get("seg_path")
|
| 1393 |
assert seg_path, "[MMAudio regen] seg_path not set — wrapper must pre-extract segment clip"
|
| 1394 |
|
| 1395 |
rng = torch.Generator(device=device)
|
|
|
|
| 1431 |
meta["silent_video"], seg_start, seg_dur,
|
| 1432 |
os.path.join(tmp_dir, "regen_seg.mp4"),
|
| 1433 |
)
|
| 1434 |
+
_ctx_store("regen_mmaudio_gpu", {"seg_path": seg_path})
|
| 1435 |
|
| 1436 |
# GPU: inference only
|
| 1437 |
new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
|
| 1438 |
prompt, negative_prompt, seed_val,
|
| 1439 |
cfg_strength, num_steps, crossfade_s, crossfade_db,
|
| 1440 |
+
slot_id)
|
| 1441 |
|
| 1442 |
# Resample to 48kHz if needed (MMAudio outputs at 44100 Hz)
|
| 1443 |
if sr != TARGET_SR:
|
|
|
|
| 1456 |
def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
|
| 1457 |
prompt, negative_prompt, seed_val,
|
| 1458 |
guidance_scale, num_steps, model_size,
|
| 1459 |
+
crossfade_s, crossfade_db, slot_id=None):
|
| 1460 |
return _estimate_regen_duration("hunyuan", int(num_steps))
|
| 1461 |
|
| 1462 |
|
|
|
|
| 1464 |
def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
|
| 1465 |
prompt, negative_prompt, seed_val,
|
| 1466 |
guidance_scale, num_steps, model_size,
|
| 1467 |
+
crossfade_s, crossfade_db, slot_id=None):
|
| 1468 |
"""GPU-only HunyuanFoley regen — returns (new_wav, sr) for a single segment."""
|
| 1469 |
meta = json.loads(seg_meta_json)
|
| 1470 |
seg_idx = int(seg_idx)
|
|
|
|
| 1481 |
|
| 1482 |
set_global_seed(random.randint(0, 2**32 - 1))
|
| 1483 |
|
| 1484 |
+
ctx = _ctx_load("regen_hunyuan_gpu")
|
| 1485 |
seg_path = ctx.get("seg_path")
|
| 1486 |
assert seg_path, "[HunyuanFoley regen] seg_path not set — wrapper must pre-extract segment clip"
|
| 1487 |
|
|
|
|
| 1525 |
meta["silent_video"], seg_start, seg_dur,
|
| 1526 |
os.path.join(tmp_dir, "regen_seg.mp4"),
|
| 1527 |
)
|
| 1528 |
+
_ctx_store("regen_hunyuan_gpu", {
|
| 1529 |
"seg_path": seg_path,
|
| 1530 |
"text_feats_path": meta.get("text_feats_path", ""),
|
| 1531 |
})
|
|
|
|
| 1534 |
new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
|
| 1535 |
prompt, negative_prompt, seed_val,
|
| 1536 |
guidance_scale, num_steps, model_size,
|
| 1537 |
+
crossfade_s, crossfade_db, slot_id)
|
| 1538 |
|
| 1539 |
meta["sr"] = sr
|
| 1540 |
|
|
|
|
| 1651 |
meta["silent_video"], seg_start, seg_end - seg_start,
|
| 1652 |
os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
|
| 1653 |
)
|
| 1654 |
+
_ctx_store("regen_mmaudio_gpu", {"seg_path": seg_path})
|
| 1655 |
wav, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
|
| 1656 |
prompt, negative_prompt, seed_val,
|
| 1657 |
cfg_strength, num_steps,
|
| 1658 |
+
crossfade_s, crossfade_db, slot_id)
|
| 1659 |
return wav, src_sr
|
| 1660 |
|
| 1661 |
yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
|
|
|
|
| 1676 |
meta["silent_video"], seg_start, seg_end - seg_start,
|
| 1677 |
os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
|
| 1678 |
)
|
| 1679 |
+
_ctx_store("regen_hunyuan_gpu", {
|
| 1680 |
"seg_path": seg_path,
|
| 1681 |
"text_feats_path": meta.get("text_feats_path", ""),
|
| 1682 |
})
|
| 1683 |
wav, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
|
| 1684 |
prompt, negative_prompt, seed_val,
|
| 1685 |
guidance_scale, num_steps, model_size,
|
| 1686 |
+
crossfade_s, crossfade_db, slot_id)
|
| 1687 |
return wav, src_sr
|
| 1688 |
|
| 1689 |
yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
|