Spaces:
Running on Zero
refactor: consolidate duplicated code via MODEL_CONFIGS registry
Browse files
- Add MODEL_CONFIGS dict as single source of truth for per-model constants
(window_s, sr, secs_per_step, load_overhead, tab_prefix, regen_fn, label)
- Replace 6 nearly-identical duration estimators with 2 generic functions:
_estimate_gpu_duration() and _estimate_regen_duration()
- Replace 3 duplicated regen button factory loops (~35 lines each, 40%
duplication) with single _register_regen_handlers() function
- Fix import redundancy: consolidate threading/shutil imports at top of file,
remove duplicate `import threading` at line 310 and inline `import shutil`
Net effect: ~170 lines of duplicated boilerplate eliminated, all model-specific
behavior now parameterized through the registry. Future model additions only need
a new MODEL_CONFIGS entry + model-specific GPU/regen functions.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
|
@@ -11,12 +11,13 @@ Supported models
|
|
| 11 |
import os
|
| 12 |
import sys
|
| 13 |
import json
|
|
|
|
| 14 |
import tempfile
|
| 15 |
import random
|
| 16 |
import threading
|
|
|
|
| 17 |
from pathlib import Path
|
| 18 |
|
| 19 |
-
import time
|
| 20 |
import torch
|
| 21 |
import numpy as np
|
| 22 |
import torchaudio
|
|
@@ -118,7 +119,6 @@ _TEMP_DIRS_MAX = 10 # keep at most this many; older ones get cleaned up
|
|
| 118 |
|
| 119 |
def _register_tmp_dir(tmp_dir: str) -> str:
|
| 120 |
"""Register a temp dir so it can be cleaned up when newer ones replace it."""
|
| 121 |
-
import shutil
|
| 122 |
_TEMP_DIRS.append(tmp_dir)
|
| 123 |
while len(_TEMP_DIRS) > _TEMP_DIRS_MAX:
|
| 124 |
old = _TEMP_DIRS.pop(0)
|
|
@@ -305,9 +305,73 @@ HUNYUAN_SECS_PER_STEP = 0.35 # measured 0.328s/step on H200 (8.3s video, 1 seg
|
|
| 305 |
HUNYUAN_LOAD_OVERHEAD = 55 # ~55s to load the 10GB XXL model weights into GPU
|
| 306 |
GPU_DURATION_CAP = 300 # hard cap per call — never reserve more than this
|
| 307 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
_TARO_CACHE_MAXLEN = 16 # evict oldest entries beyond this limit
|
| 309 |
_TARO_INFERENCE_CACHE: dict = {} # keyed by (video_file, seed, cfg, steps, mode, crossfade_s)
|
| 310 |
-
import threading
|
| 311 |
_TARO_CACHE_LOCK = threading.Lock()
|
| 312 |
|
| 313 |
|
|
@@ -320,16 +384,9 @@ def _taro_calc_max_samples(total_dur_s: float, num_steps: int, crossfade_s: floa
|
|
| 320 |
|
| 321 |
def _taro_duration(video_file, seed_val, cfg_scale, num_steps, mode,
|
| 322 |
crossfade_s, crossfade_db, num_samples):
|
| 323 |
-
"""Pre-GPU callable — must match
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
n_segs = len(_build_segments(total_s, TARO_MODEL_DUR, float(crossfade_s)))
|
| 327 |
-
except Exception:
|
| 328 |
-
n_segs = 1
|
| 329 |
-
secs = int(num_samples) * n_segs * int(num_steps) * TARO_SECS_PER_STEP + TARO_LOAD_OVERHEAD
|
| 330 |
-
result = min(GPU_DURATION_CAP, max(60, int(secs)))
|
| 331 |
-
print(f"[duration] TARO: {int(num_samples)}samp × {n_segs}seg × {int(num_steps)}steps → {secs:.0f}s → capped {result}s")
|
| 332 |
-
return result
|
| 333 |
|
| 334 |
|
| 335 |
def _taro_infer_segment(
|
|
@@ -558,16 +615,9 @@ MMAUDIO_WINDOW = 8.0 # seconds — MMAudio's fixed generation window
|
|
| 558 |
|
| 559 |
def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
|
| 560 |
cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
|
| 561 |
-
"""Pre-GPU callable — must match
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
n_segs = len(_build_segments(total_s, MMAUDIO_WINDOW, float(crossfade_s)))
|
| 565 |
-
except Exception:
|
| 566 |
-
n_segs = 1
|
| 567 |
-
secs = int(num_samples) * n_segs * int(num_steps) * MMAUDIO_SECS_PER_STEP + MMAUDIO_LOAD_OVERHEAD
|
| 568 |
-
result = min(GPU_DURATION_CAP, max(60, int(secs)))
|
| 569 |
-
print(f"[duration] MMAudio: {int(num_samples)}samp × {n_segs}seg × {int(num_steps)}steps → {secs:.0f}s → capped {result}s")
|
| 570 |
-
return result
|
| 571 |
|
| 572 |
|
| 573 |
@spaces.GPU(duration=_mmaudio_duration)
|
|
@@ -735,16 +785,9 @@ HUNYUAN_MAX_DUR = 15.0 # seconds
|
|
| 735 |
|
| 736 |
def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
|
| 737 |
guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
|
| 738 |
-
"""Pre-GPU callable — must match
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
n_segs = len(_build_segments(total_s, HUNYUAN_MAX_DUR, float(crossfade_s)))
|
| 742 |
-
except Exception:
|
| 743 |
-
n_segs = 1
|
| 744 |
-
secs = int(num_samples) * n_segs * int(num_steps) * HUNYUAN_SECS_PER_STEP + HUNYUAN_LOAD_OVERHEAD
|
| 745 |
-
result = min(GPU_DURATION_CAP, max(60, int(secs)))
|
| 746 |
-
print(f"[duration] HunyuanFoley: {int(num_samples)}samp × {n_segs}seg × {int(num_steps)}steps → {secs:.0f}s → capped {result}s")
|
| 747 |
-
return result
|
| 748 |
|
| 749 |
|
| 750 |
@spaces.GPU(duration=_hunyuan_duration)
|
|
@@ -993,10 +1036,7 @@ def _splice_and_save(new_wav, seg_idx, meta, slot_id):
|
|
| 993 |
def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
|
| 994 |
seed_val, cfg_scale, num_steps, mode,
|
| 995 |
crossfade_s, crossfade_db, slot_id=None):
|
| 996 |
-
|
| 997 |
-
result = min(GPU_DURATION_CAP, max(60, int(secs)))
|
| 998 |
-
print(f"[duration] TARO regen: 1 seg × {int(num_steps)} steps → {secs:.0f}s → capped {result}s")
|
| 999 |
-
return result
|
| 1000 |
|
| 1001 |
|
| 1002 |
@spaces.GPU(duration=_taro_regen_duration)
|
|
@@ -1067,10 +1107,7 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
|
|
| 1067 |
def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
|
| 1068 |
prompt, negative_prompt, seed_val,
|
| 1069 |
cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id=None):
|
| 1070 |
-
|
| 1071 |
-
result = min(GPU_DURATION_CAP, max(60, int(secs)))
|
| 1072 |
-
print(f"[duration] MMAudio regen: 1 seg × {int(num_steps)} steps → {secs:.0f}s → capped {result}s")
|
| 1073 |
-
return result
|
| 1074 |
|
| 1075 |
|
| 1076 |
@spaces.GPU(duration=_mmaudio_regen_duration)
|
|
@@ -1169,10 +1206,7 @@ def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
|
|
| 1169 |
prompt, negative_prompt, seed_val,
|
| 1170 |
guidance_scale, num_steps, model_size,
|
| 1171 |
crossfade_s, crossfade_db, slot_id=None):
|
| 1172 |
-
|
| 1173 |
-
result = min(GPU_DURATION_CAP, max(60, int(secs)))
|
| 1174 |
-
print(f"[duration] HunyuanFoley regen: 1 seg × {int(num_steps)} steps → {secs:.0f}s → capped {result}s")
|
| 1175 |
-
return result
|
| 1176 |
|
| 1177 |
|
| 1178 |
@spaces.GPU(duration=_hunyuan_regen_duration)
|
|
@@ -1268,10 +1302,82 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
|
|
| 1268 |
return video_path, audio_path, json.dumps(updated_meta), waveform_html
|
| 1269 |
|
| 1270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1271 |
# ================================================================== #
|
| 1272 |
# SHARED UI HELPERS #
|
| 1273 |
# ================================================================== #
|
| 1274 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1275 |
def _pad_outputs(outputs: list) -> list:
|
| 1276 |
"""Flatten (video, audio, seg_meta) triples and pad to MAX_SLOTS * 3 with None.
|
| 1277 |
|
|
@@ -2073,52 +2179,14 @@ with gr.Blocks(title="Generate Audio for Video", css=_SLOT_CSS, js=_GLOBAL_JS) a
|
|
| 2073 |
outputs=taro_slot_grps,
|
| 2074 |
))
|
| 2075 |
|
| 2076 |
-
# Per-slot regen handlers
|
| 2077 |
-
#
|
| 2078 |
-
|
| 2079 |
-
|
| 2080 |
-
|
| 2081 |
-
|
| 2082 |
-
|
| 2083 |
-
|
| 2084 |
-
_slot_id = f"taro_{_i}"
|
| 2085 |
-
_btn = gr.Button(visible=False, elem_id=f"regen_btn_{_slot_id}")
|
| 2086 |
-
taro_regen_btns.append(_btn)
|
| 2087 |
-
print(f"[startup] registering regen handler for slot {_slot_id}")
|
| 2088 |
-
def _make_taro_regen(_si, _sid):
|
| 2089 |
-
def _do(seg_idx, state_json, video, seed, cfg, steps, mode, cf_dur, cf_db):
|
| 2090 |
-
print(f"[regen TARO] slot={_sid} seg_idx={seg_idx} state_json_len={len(state_json) if state_json else 0}")
|
| 2091 |
-
if not state_json:
|
| 2092 |
-
print(f"[regen TARO] early-exit: state_json empty")
|
| 2093 |
-
yield gr.update(), gr.update(); return
|
| 2094 |
-
lock = _get_slot_lock(_sid)
|
| 2095 |
-
with lock:
|
| 2096 |
-
print(f"[regen TARO] slot={_sid} seg_idx={seg_idx} — lock acquired, showing spinner")
|
| 2097 |
-
state = json.loads(state_json)
|
| 2098 |
-
pending_html = _build_regen_pending_html(
|
| 2099 |
-
state["segments"], int(seg_idx), _sid, ""
|
| 2100 |
-
)
|
| 2101 |
-
yield gr.update(), gr.update(value=pending_html)
|
| 2102 |
-
print(f"[regen TARO] slot={_sid} seg_idx={seg_idx} — calling regen_taro_segment")
|
| 2103 |
-
try:
|
| 2104 |
-
vid, aud, new_meta_json, html = regen_taro_segment(
|
| 2105 |
-
video, int(seg_idx), state_json,
|
| 2106 |
-
seed, cfg, steps, mode, cf_dur, cf_db, _sid,
|
| 2107 |
-
)
|
| 2108 |
-
print(f"[regen TARO] slot={_sid} seg_idx={seg_idx} — done, vid={vid!r}")
|
| 2109 |
-
except Exception as _e:
|
| 2110 |
-
print(f"[regen TARO] slot={_sid} seg_idx={seg_idx} — ERROR: {_e}")
|
| 2111 |
-
raise
|
| 2112 |
-
yield gr.update(value=vid), gr.update(value=html)
|
| 2113 |
-
return _do
|
| 2114 |
-
_btn.click(
|
| 2115 |
-
fn=_make_taro_regen(_i, _slot_id),
|
| 2116 |
-
inputs=[taro_regen_seg, taro_regen_state,
|
| 2117 |
-
taro_video, taro_seed, taro_cfg, taro_steps,
|
| 2118 |
-
taro_mode, taro_cf_dur, taro_cf_db],
|
| 2119 |
-
outputs=[taro_slot_vids[_i], taro_slot_waves[_i]],
|
| 2120 |
-
api_name=f"regen_taro_{_i}",
|
| 2121 |
-
)
|
| 2122 |
|
| 2123 |
# ---------------------------------------------------------- #
|
| 2124 |
# Tab 2 — MMAudio #
|
|
@@ -2167,44 +2235,12 @@ with gr.Blocks(title="Generate Audio for Video", css=_SLOT_CSS, js=_GLOBAL_JS) a
|
|
| 2167 |
outputs=mma_slot_grps,
|
| 2168 |
))
|
| 2169 |
|
| 2170 |
-
mma_regen_btns =
|
| 2171 |
-
|
| 2172 |
-
|
| 2173 |
-
|
| 2174 |
-
|
| 2175 |
-
|
| 2176 |
-
def _do(seg_idx, state_json, video, prompt, neg, seed, cfg, steps, cf_dur, cf_db):
|
| 2177 |
-
print(f"[regen MMA] slot={_sid} seg_idx={seg_idx} state_json_len={len(state_json) if state_json else 0}")
|
| 2178 |
-
if not state_json:
|
| 2179 |
-
print(f"[regen MMA] early-exit: state_json empty")
|
| 2180 |
-
yield gr.update(), gr.update(); return
|
| 2181 |
-
lock = _get_slot_lock(_sid)
|
| 2182 |
-
with lock:
|
| 2183 |
-
state = json.loads(state_json)
|
| 2184 |
-
pending_html = _build_regen_pending_html(
|
| 2185 |
-
state["segments"], int(seg_idx), _sid, ""
|
| 2186 |
-
)
|
| 2187 |
-
yield gr.update(), gr.update(value=pending_html)
|
| 2188 |
-
print(f"[regen MMA] slot={_sid} seg_idx={seg_idx} — calling regen_mmaudio_segment")
|
| 2189 |
-
try:
|
| 2190 |
-
vid, aud, new_meta_json, html = regen_mmaudio_segment(
|
| 2191 |
-
video, int(seg_idx), state_json,
|
| 2192 |
-
prompt, neg, seed, cfg, steps, cf_dur, cf_db, _sid,
|
| 2193 |
-
)
|
| 2194 |
-
print(f"[regen MMA] slot={_sid} seg_idx={seg_idx} — done, vid={vid!r}")
|
| 2195 |
-
except Exception as _e:
|
| 2196 |
-
print(f"[regen MMA] slot={_sid} seg_idx={seg_idx} — ERROR: {_e}")
|
| 2197 |
-
raise
|
| 2198 |
-
yield gr.update(value=vid), gr.update(value=html)
|
| 2199 |
-
return _do
|
| 2200 |
-
_btn.click(
|
| 2201 |
-
fn=_make_mma_regen(_i, _slot_id),
|
| 2202 |
-
inputs=[mma_regen_seg, mma_regen_state,
|
| 2203 |
-
mma_video, mma_prompt, mma_neg, mma_seed,
|
| 2204 |
-
mma_cfg, mma_steps, mma_cf_dur, mma_cf_db],
|
| 2205 |
-
outputs=[mma_slot_vids[_i], mma_slot_waves[_i]],
|
| 2206 |
-
api_name=f"regen_mma_{_i}",
|
| 2207 |
-
)
|
| 2208 |
|
| 2209 |
# ---------------------------------------------------------- #
|
| 2210 |
# Tab 3 — HunyuanVideoFoley #
|
|
@@ -2254,44 +2290,12 @@ with gr.Blocks(title="Generate Audio for Video", css=_SLOT_CSS, js=_GLOBAL_JS) a
|
|
| 2254 |
outputs=hf_slot_grps,
|
| 2255 |
))
|
| 2256 |
|
| 2257 |
-
hf_regen_btns =
|
| 2258 |
-
|
| 2259 |
-
|
| 2260 |
-
|
| 2261 |
-
|
| 2262 |
-
|
| 2263 |
-
def _do(seg_idx, state_json, video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db):
|
| 2264 |
-
print(f"[regen HF] slot={_sid} seg_idx={seg_idx} state_json_len={len(state_json) if state_json else 0}")
|
| 2265 |
-
if not state_json:
|
| 2266 |
-
print(f"[regen HF] early-exit: state_json empty")
|
| 2267 |
-
yield gr.update(), gr.update(); return
|
| 2268 |
-
lock = _get_slot_lock(_sid)
|
| 2269 |
-
with lock:
|
| 2270 |
-
state = json.loads(state_json)
|
| 2271 |
-
pending_html = _build_regen_pending_html(
|
| 2272 |
-
state["segments"], int(seg_idx), _sid, ""
|
| 2273 |
-
)
|
| 2274 |
-
yield gr.update(), gr.update(value=pending_html)
|
| 2275 |
-
print(f"[regen HF] slot={_sid} seg_idx={seg_idx} — calling regen_hunyuan_segment")
|
| 2276 |
-
try:
|
| 2277 |
-
vid, aud, new_meta_json, html = regen_hunyuan_segment(
|
| 2278 |
-
video, int(seg_idx), state_json,
|
| 2279 |
-
prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, _sid,
|
| 2280 |
-
)
|
| 2281 |
-
print(f"[regen HF] slot={_sid} seg_idx={seg_idx} — done, vid={vid!r}")
|
| 2282 |
-
except Exception as _e:
|
| 2283 |
-
print(f"[regen HF] slot={_sid} seg_idx={seg_idx} — ERROR: {_e}")
|
| 2284 |
-
raise
|
| 2285 |
-
yield gr.update(value=vid), gr.update(value=html)
|
| 2286 |
-
return _do
|
| 2287 |
-
_btn.click(
|
| 2288 |
-
fn=_make_hf_regen(_i, _slot_id),
|
| 2289 |
-
inputs=[hf_regen_seg, hf_regen_state,
|
| 2290 |
-
hf_video, hf_prompt, hf_neg, hf_seed,
|
| 2291 |
-
hf_guidance, hf_steps, hf_size, hf_cf_dur, hf_cf_db],
|
| 2292 |
-
outputs=[hf_slot_vids[_i], hf_slot_waves[_i]],
|
| 2293 |
-
api_name=f"regen_hf_{_i}",
|
| 2294 |
-
)
|
| 2295 |
|
| 2296 |
# ---- Cross-tab video sync ----
|
| 2297 |
_sync = lambda v: (gr.update(value=v), gr.update(value=v))
|
|
|
|
| 11 |
import os
|
| 12 |
import sys
|
| 13 |
import json
|
| 14 |
+
import shutil
|
| 15 |
import tempfile
|
| 16 |
import random
|
| 17 |
import threading
|
| 18 |
+
import time
|
| 19 |
from pathlib import Path
|
| 20 |
|
|
|
|
| 21 |
import torch
|
| 22 |
import numpy as np
|
| 23 |
import torchaudio
|
|
|
|
| 119 |
|
| 120 |
def _register_tmp_dir(tmp_dir: str) -> str:
|
| 121 |
"""Register a temp dir so it can be cleaned up when newer ones replace it."""
|
|
|
|
| 122 |
_TEMP_DIRS.append(tmp_dir)
|
| 123 |
while len(_TEMP_DIRS) > _TEMP_DIRS_MAX:
|
| 124 |
old = _TEMP_DIRS.pop(0)
|
|
|
|
| 305 |
HUNYUAN_LOAD_OVERHEAD = 55 # ~55s to load the 10GB XXL model weights into GPU
|
| 306 |
GPU_DURATION_CAP = 300 # hard cap per call — never reserve more than this
|
| 307 |
|
| 308 |
+
# ------------------------------------------------------------------ #
#  Model configuration registry — single source of truth for per-model #
#  constants used by duration estimation, segmentation, and UI.        #
# ------------------------------------------------------------------ #
# Keys of each entry:
#   window_s       segment window length in seconds (fed to _build_segments)
#   sr             sample rate — presumably Hz; confirm against the loaders
#   secs_per_step  estimated GPU seconds per inference step
#   load_overhead  fixed seconds added to every estimate for model loading
#   tab_prefix     short prefix used in slot ids / api names ("taro", "mma", "hf")
#   regen_fn       per-model segment regen function; filled in later, once the
#                  regen functions have been defined
#   label          human-readable name used in log lines
#
# NOTE(review): MMAUDIO_WINDOW and HUNYUAN_MAX_DUR appear to be defined further
# down this module (in the MMAudio / HunyuanFoley sections), after this dict is
# evaluated. Referencing them here would raise NameError at import time, so
# their values are written as literals below. Keep them in sync with those
# constants, or move the constants above this registry and reference them.
MODEL_CONFIGS = {
    "taro": {
        "window_s": TARO_MODEL_DUR,               # 8.192 s
        "sr": TARO_SR,                            # 16000
        "secs_per_step": TARO_SECS_PER_STEP,      # 0.05
        "load_overhead": TARO_LOAD_OVERHEAD,      # 15
        "tab_prefix": "taro",
        "regen_fn": None,  # set after function definitions (avoids forward-ref)
        "label": "TARO",
    },
    "mmaudio": {
        "window_s": 8.0,                          # must equal MMAUDIO_WINDOW
        "sr": 44100,
        "secs_per_step": MMAUDIO_SECS_PER_STEP,   # 0.25
        "load_overhead": MMAUDIO_LOAD_OVERHEAD,   # 15
        "tab_prefix": "mma",
        "regen_fn": None,
        "label": "MMAudio",
    },
    "hunyuan": {
        "window_s": 15.0,                         # must equal HUNYUAN_MAX_DUR
        "sr": 48000,
        "secs_per_step": HUNYUAN_SECS_PER_STEP,   # 0.35
        "load_overhead": HUNYUAN_LOAD_OVERHEAD,   # 55
        "tab_prefix": "hf",
        "regen_fn": None,
        "label": "HunyuanFoley",
    },
}
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def _estimate_gpu_duration(model_key: str, num_samples: int, num_steps: int,
                           total_dur_s: float = None, crossfade_s: float = 0,
                           video_file: str = None) -> int:
    """Generic GPU duration estimator used by all models.

    Computes: num_samples × n_segs × num_steps × secs_per_step + load_overhead
    Clamped to [60, GPU_DURATION_CAP].

    Args:
        model_key: key into MODEL_CONFIGS ("taro", "mmaudio", "hunyuan").
        num_samples: number of samples to generate; coerced with int().
        num_steps: inference steps per segment; coerced with int().
        total_dur_s: video duration in seconds. When None it is looked up
            from ``video_file`` via get_video_duration().
        crossfade_s: crossfade length in seconds, forwarded to _build_segments.
        video_file: path used only when ``total_dur_s`` is None.

    Returns:
        Whole seconds of GPU time to reserve.
    """
    cfg = MODEL_CONFIGS[model_key]
    try:
        if total_dur_s is None:
            total_dur_s = get_video_duration(video_file)
        n_segs = len(_build_segments(total_dur_s, cfg["window_s"], float(crossfade_s)))
    except Exception as exc:
        # Best-effort: the estimate must never block a generation request, so
        # fall back to one segment — but say why instead of failing silently.
        print(f"[duration] {cfg['label']}: could not count segments ({exc!r}) "
              f"— assuming 1 segment")
        n_segs = 1
    # An empty segment list would drop the inference term from the estimate;
    # at least one segment is always assumed.
    n_segs = max(1, n_segs)
    secs = int(num_samples) * n_segs * int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"]
    result = min(GPU_DURATION_CAP, max(60, int(secs)))
    print(f"[duration] {cfg['label']}: {int(num_samples)}samp × {n_segs}seg × "
          f"{int(num_steps)}steps → {secs:.0f}s → capped {result}s")
    return result
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def _estimate_regen_duration(model_key: str, num_steps: int) -> int:
    """Estimate the GPU seconds to reserve for regenerating one segment.

    The raw estimate is ``num_steps × secs_per_step + load_overhead`` for the
    given model, floored at 60 and capped at GPU_DURATION_CAP.
    """
    spec = MODEL_CONFIGS[model_key]
    steps = int(num_steps)
    secs = steps * spec["secs_per_step"] + spec["load_overhead"]
    # Floor first, then cap — the same order as min(cap, max(60, x)).
    result = int(secs)
    if result < 60:
        result = 60
    if result > GPU_DURATION_CAP:
        result = GPU_DURATION_CAP
    print(f"[duration] {spec['label']} regen: 1 seg × {steps} steps → {secs:.0f}s → capped {result}s")
    return result
|
| 372 |
+
|
| 373 |
_TARO_CACHE_MAXLEN = 16 # evict oldest entries beyond this limit
|
| 374 |
_TARO_INFERENCE_CACHE: dict = {} # keyed by (video_file, seed, cfg, steps, mode, crossfade_s)
|
|
|
|
| 375 |
_TARO_CACHE_LOCK = threading.Lock()
|
| 376 |
|
| 377 |
|
|
|
|
| 384 |
|
| 385 |
def _taro_duration(video_file, seed_val, cfg_scale, num_steps, mode,
                   crossfade_s, crossfade_db, num_samples):
    """Pre-GPU callable — must match _taro_gpu_infer's input order exactly.

    Only video_file, num_steps, crossfade_s and num_samples feed the
    estimate; the remaining parameters exist to keep the signature aligned.
    """
    sample_count = int(num_samples)
    step_count = int(num_steps)
    return _estimate_gpu_duration(
        "taro",
        sample_count,
        step_count,
        crossfade_s=crossfade_s,
        video_file=video_file,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
|
| 391 |
|
| 392 |
def _taro_infer_segment(
|
|
|
|
| 615 |
|
| 616 |
def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
                      cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
    """Pre-GPU callable — must match _mmaudio_gpu_infer's input order exactly.

    Only video_file, num_steps, crossfade_s and num_samples feed the
    estimate; the remaining parameters exist to keep the signature aligned.
    """
    sample_count = int(num_samples)
    step_count = int(num_steps)
    return _estimate_gpu_duration(
        "mmaudio",
        sample_count,
        step_count,
        crossfade_s=crossfade_s,
        video_file=video_file,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
|
| 622 |
|
| 623 |
@spaces.GPU(duration=_mmaudio_duration)
|
|
|
|
| 785 |
|
| 786 |
def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
                      guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
    """Pre-GPU callable — must match _hunyuan_gpu_infer's input order exactly.

    Only video_file, num_steps, crossfade_s and num_samples feed the
    estimate; the remaining parameters exist to keep the signature aligned.
    """
    sample_count = int(num_samples)
    step_count = int(num_steps)
    return _estimate_gpu_duration(
        "hunyuan",
        sample_count,
        step_count,
        crossfade_s=crossfade_s,
        video_file=video_file,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 791 |
|
| 792 |
|
| 793 |
@spaces.GPU(duration=_hunyuan_duration)
|
|
|
|
| 1036 |
def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
                         seed_val, cfg_scale, num_steps, mode,
                         crossfade_s, crossfade_db, slot_id=None):
    """Duration callback for the TARO single-segment regen GPU call.

    Only num_steps affects the estimate; the other parameters are accepted
    but unused here.
    """
    step_count = int(num_steps)
    return _estimate_regen_duration("taro", step_count)
|
|
|
|
|
|
|
|
|
|
| 1040 |
|
| 1041 |
|
| 1042 |
@spaces.GPU(duration=_taro_regen_duration)
|
|
|
|
| 1107 |
def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
                            prompt, negative_prompt, seed_val,
                            cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id=None):
    """Duration callback for the MMAudio single-segment regen GPU call.

    Only num_steps affects the estimate; the other parameters are accepted
    but unused here.
    """
    step_count = int(num_steps)
    return _estimate_regen_duration("mmaudio", step_count)
|
|
|
|
|
|
|
|
|
|
| 1111 |
|
| 1112 |
|
| 1113 |
@spaces.GPU(duration=_mmaudio_regen_duration)
|
|
|
|
| 1206 |
prompt, negative_prompt, seed_val,
|
| 1207 |
guidance_scale, num_steps, model_size,
|
| 1208 |
crossfade_s, crossfade_db, slot_id=None):
|
| 1209 |
+
return _estimate_regen_duration("hunyuan", int(num_steps))
|
|
|
|
|
|
|
|
|
|
| 1210 |
|
| 1211 |
|
| 1212 |
@spaces.GPU(duration=_hunyuan_regen_duration)
|
|
|
|
| 1302 |
return video_path, audio_path, json.dumps(updated_meta), waveform_html
|
| 1303 |
|
| 1304 |
|
| 1305 |
+
# Fill in each registry entry's regen_fn slot. This has to happen here,
# after the three regen functions exist, because the registry itself is
# built before they are defined.
MODEL_CONFIGS["taro"].update(regen_fn=regen_taro_segment)
MODEL_CONFIGS["mmaudio"].update(regen_fn=regen_mmaudio_segment)
MODEL_CONFIGS["hunyuan"].update(regen_fn=regen_hunyuan_segment)
|
| 1309 |
+
|
| 1310 |
+
|
| 1311 |
# ================================================================== #
|
| 1312 |
# SHARED UI HELPERS #
|
| 1313 |
# ================================================================== #
|
| 1314 |
|
| 1315 |
+
def _register_regen_handlers(tab_prefix, model_key, regen_seg_tb, regen_state_tb,
                             input_components, slot_vids, slot_waves):
    """Register per-slot regen button handlers for a model tab.

    This replaces the three nearly-identical for-loops that previously existed
    for TARO, MMAudio, and HunyuanFoley tabs. One hidden button is created per
    slot (MAX_SLOTS in total); each click handler is a generator that first
    yields a "pending" waveform HTML and then the regenerated video + waveform.

    Args:
        tab_prefix:       e.g. "taro", "mma", "hf"
        model_key:        e.g. "taro", "mmaudio", "hunyuan"
        regen_seg_tb:     gr.Textbox for seg_idx (render=False)
        regen_state_tb:   gr.Textbox for state_json (render=False)
        input_components: list of Gradio input components (video, seed, etc.)
                          — order must match regen_fn signature after
                          (seg_idx, state_json, video)
        slot_vids:        list of gr.Video components per slot
        slot_waves:       list of gr.HTML components per slot

    Returns:
        list of hidden gr.Buttons (one per slot)

    Raises:
        RuntimeError: if MODEL_CONFIGS[model_key]["regen_fn"] has not been
            wired up yet (it would otherwise only fail on the first click).
    """
    cfg = MODEL_CONFIGS[model_key]
    regen_fn = cfg["regen_fn"]
    label = cfg["label"]
    if regen_fn is None:
        raise RuntimeError(
            f"MODEL_CONFIGS[{model_key!r}]['regen_fn'] is not set; wire up the "
            f"regen functions before registering handlers"
        )

    # Defined once, outside the loop: only the slot id differs per handler,
    # and it is bound as an argument so each closure keeps its own value.
    def _make_regen(_sid):
        def _do(seg_idx, state_json, *args):
            print(f"[regen {label}] slot={_sid} seg_idx={seg_idx} "
                  f"state_json_len={len(state_json) if state_json else 0}")
            if not state_json:
                print(f"[regen {label}] early-exit: state_json empty")
                yield gr.update(), gr.update()
                return
            lock = _get_slot_lock(_sid)
            with lock:
                state = json.loads(state_json)
                pending_html = _build_regen_pending_html(
                    state["segments"], int(seg_idx), _sid, ""
                )
                # First yield: leave the video untouched, show the pending waveform.
                yield gr.update(), gr.update(value=pending_html)
                print(f"[regen {label}] slot={_sid} seg_idx={seg_idx} — calling regen")
                try:
                    # args[0] = video, args[1:] = model-specific params
                    vid, aud, new_meta_json, html = regen_fn(
                        args[0], int(seg_idx), state_json, *args[1:], _sid,
                    )
                    print(f"[regen {label}] slot={_sid} seg_idx={seg_idx} — done, vid={vid!r}")
                except Exception as _e:
                    # Log for the server console, then let Gradio surface the error.
                    print(f"[regen {label}] slot={_sid} seg_idx={seg_idx} — ERROR: {_e}")
                    raise
            yield gr.update(value=vid), gr.update(value=html)
        return _do

    btns = []
    for _i in range(MAX_SLOTS):
        _slot_id = f"{tab_prefix}_{_i}"
        _btn = gr.Button(visible=False, elem_id=f"regen_btn_{_slot_id}")
        btns.append(_btn)
        print(f"[startup] registering regen handler for slot {_slot_id}")
        _btn.click(
            fn=_make_regen(_slot_id),
            inputs=[regen_seg_tb, regen_state_tb] + input_components,
            outputs=[slot_vids[_i], slot_waves[_i]],
            api_name=f"regen_{tab_prefix}_{_i}",
        )
    return btns
|
| 1379 |
+
|
| 1380 |
+
|
| 1381 |
def _pad_outputs(outputs: list) -> list:
|
| 1382 |
"""Flatten (video, audio, seg_meta) triples and pad to MAX_SLOTS * 3 with None.
|
| 1383 |
|
|
|
|
| 2179 |
outputs=taro_slot_grps,
|
| 2180 |
))
|
| 2181 |
|
| 2182 |
+
# Per-slot regen handlers — JS calls /gradio_api/queue/join with
|
| 2183 |
+
# fn_index (by api_name) + data=[seg_idx, state_json, video, ...params].
|
| 2184 |
+
taro_regen_btns = _register_regen_handlers(
|
| 2185 |
+
"taro", "taro", taro_regen_seg, taro_regen_state,
|
| 2186 |
+
[taro_video, taro_seed, taro_cfg, taro_steps,
|
| 2187 |
+
taro_mode, taro_cf_dur, taro_cf_db],
|
| 2188 |
+
taro_slot_vids, taro_slot_waves,
|
| 2189 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2190 |
|
| 2191 |
# ---------------------------------------------------------- #
|
| 2192 |
# Tab 2 — MMAudio #
|
|
|
|
| 2235 |
outputs=mma_slot_grps,
|
| 2236 |
))
|
| 2237 |
|
| 2238 |
+
mma_regen_btns = _register_regen_handlers(
|
| 2239 |
+
"mma", "mmaudio", mma_regen_seg, mma_regen_state,
|
| 2240 |
+
[mma_video, mma_prompt, mma_neg, mma_seed,
|
| 2241 |
+
mma_cfg, mma_steps, mma_cf_dur, mma_cf_db],
|
| 2242 |
+
mma_slot_vids, mma_slot_waves,
|
| 2243 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2244 |
|
| 2245 |
# ---------------------------------------------------------- #
|
| 2246 |
# Tab 3 — HunyuanVideoFoley #
|
|
|
|
| 2290 |
outputs=hf_slot_grps,
|
| 2291 |
))
|
| 2292 |
|
| 2293 |
+
hf_regen_btns = _register_regen_handlers(
|
| 2294 |
+
"hf", "hunyuan", hf_regen_seg, hf_regen_state,
|
| 2295 |
+
[hf_video, hf_prompt, hf_neg, hf_seed,
|
| 2296 |
+
hf_guidance, hf_steps, hf_size, hf_cf_dur, hf_cf_db],
|
| 2297 |
+
hf_slot_vids, hf_slot_waves,
|
| 2298 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2299 |
|
| 2300 |
# ---- Cross-tab video sync ----
|
| 2301 |
_sync = lambda v: (gr.update(value=v), gr.update(value=v))
|