BoxOfColors Claude Opus 4.6 committed on
Commit
c06e566
·
1 Parent(s): 13cc4e6

refactor: consolidate duplicated code via MODEL_CONFIGS registry

Browse files

- Add MODEL_CONFIGS dict as single source of truth for per-model constants
(window_s, sr, secs_per_step, load_overhead, tab_prefix, regen_fn, label)
- Replace 6 nearly-identical duration estimators with 2 generic functions:
_estimate_gpu_duration() and _estimate_regen_duration()
- Replace 3 duplicated regen button factory loops (~35 lines each, 40%
duplication) with single _register_regen_handlers() function
- Fix import redundancy: consolidate threading/shutil imports at top of file,
remove duplicate `import threading` at line 310 and inline `import shutil`

Net effect: ~170 lines of duplicated boilerplate eliminated, all model-specific
behavior now parameterized through the registry. Future model additions only need
a new MODEL_CONFIGS entry + model-specific GPU/regen functions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +171 -167
app.py CHANGED
@@ -11,12 +11,13 @@ Supported models
11
  import os
12
  import sys
13
  import json
 
14
  import tempfile
15
  import random
16
  import threading
 
17
  from pathlib import Path
18
 
19
- import time
20
  import torch
21
  import numpy as np
22
  import torchaudio
@@ -118,7 +119,6 @@ _TEMP_DIRS_MAX = 10 # keep at most this many; older ones get cleaned up
118
 
119
  def _register_tmp_dir(tmp_dir: str) -> str:
120
  """Register a temp dir so it can be cleaned up when newer ones replace it."""
121
- import shutil
122
  _TEMP_DIRS.append(tmp_dir)
123
  while len(_TEMP_DIRS) > _TEMP_DIRS_MAX:
124
  old = _TEMP_DIRS.pop(0)
@@ -305,9 +305,73 @@ HUNYUAN_SECS_PER_STEP = 0.35 # measured 0.328s/step on H200 (8.3s video, 1 seg
305
  HUNYUAN_LOAD_OVERHEAD = 55 # ~55s to load the 10GB XXL model weights into GPU
306
  GPU_DURATION_CAP = 300 # hard cap per call — never reserve more than this
307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  _TARO_CACHE_MAXLEN = 16 # evict oldest entries beyond this limit
309
  _TARO_INFERENCE_CACHE: dict = {} # keyed by (video_file, seed, cfg, steps, mode, crossfade_s)
310
- import threading
311
  _TARO_CACHE_LOCK = threading.Lock()
312
 
313
 
@@ -320,16 +384,9 @@ def _taro_calc_max_samples(total_dur_s: float, num_steps: int, crossfade_s: floa
320
 
321
  def _taro_duration(video_file, seed_val, cfg_scale, num_steps, mode,
322
  crossfade_s, crossfade_db, num_samples):
323
- """Pre-GPU callable — must match _run_taro's input order exactly."""
324
- try:
325
- total_s = get_video_duration(video_file)
326
- n_segs = len(_build_segments(total_s, TARO_MODEL_DUR, float(crossfade_s)))
327
- except Exception:
328
- n_segs = 1
329
- secs = int(num_samples) * n_segs * int(num_steps) * TARO_SECS_PER_STEP + TARO_LOAD_OVERHEAD
330
- result = min(GPU_DURATION_CAP, max(60, int(secs)))
331
- print(f"[duration] TARO: {int(num_samples)}samp × {n_segs}seg × {int(num_steps)}steps → {secs:.0f}s → capped {result}s")
332
- return result
333
 
334
 
335
  def _taro_infer_segment(
@@ -558,16 +615,9 @@ MMAUDIO_WINDOW = 8.0 # seconds — MMAudio's fixed generation window
558
 
559
  def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
560
  cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
561
- """Pre-GPU callable — must match _run_mmaudio's input order exactly."""
562
- try:
563
- total_s = get_video_duration(video_file)
564
- n_segs = len(_build_segments(total_s, MMAUDIO_WINDOW, float(crossfade_s)))
565
- except Exception:
566
- n_segs = 1
567
- secs = int(num_samples) * n_segs * int(num_steps) * MMAUDIO_SECS_PER_STEP + MMAUDIO_LOAD_OVERHEAD
568
- result = min(GPU_DURATION_CAP, max(60, int(secs)))
569
- print(f"[duration] MMAudio: {int(num_samples)}samp × {n_segs}seg × {int(num_steps)}steps → {secs:.0f}s → capped {result}s")
570
- return result
571
 
572
 
573
  @spaces.GPU(duration=_mmaudio_duration)
@@ -735,16 +785,9 @@ HUNYUAN_MAX_DUR = 15.0 # seconds
735
 
736
  def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
737
  guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
738
- """Pre-GPU callable — must match _run_hunyuan's input order exactly."""
739
- try:
740
- total_s = get_video_duration(video_file)
741
- n_segs = len(_build_segments(total_s, HUNYUAN_MAX_DUR, float(crossfade_s)))
742
- except Exception:
743
- n_segs = 1
744
- secs = int(num_samples) * n_segs * int(num_steps) * HUNYUAN_SECS_PER_STEP + HUNYUAN_LOAD_OVERHEAD
745
- result = min(GPU_DURATION_CAP, max(60, int(secs)))
746
- print(f"[duration] HunyuanFoley: {int(num_samples)}samp × {n_segs}seg × {int(num_steps)}steps → {secs:.0f}s → capped {result}s")
747
- return result
748
 
749
 
750
  @spaces.GPU(duration=_hunyuan_duration)
@@ -993,10 +1036,7 @@ def _splice_and_save(new_wav, seg_idx, meta, slot_id):
993
  def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
994
  seed_val, cfg_scale, num_steps, mode,
995
  crossfade_s, crossfade_db, slot_id=None):
996
- secs = int(num_steps) * TARO_SECS_PER_STEP + TARO_LOAD_OVERHEAD
997
- result = min(GPU_DURATION_CAP, max(60, int(secs)))
998
- print(f"[duration] TARO regen: 1 seg × {int(num_steps)} steps → {secs:.0f}s → capped {result}s")
999
- return result
1000
 
1001
 
1002
  @spaces.GPU(duration=_taro_regen_duration)
@@ -1067,10 +1107,7 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
1067
  def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
1068
  prompt, negative_prompt, seed_val,
1069
  cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id=None):
1070
- secs = int(num_steps) * MMAUDIO_SECS_PER_STEP + MMAUDIO_LOAD_OVERHEAD
1071
- result = min(GPU_DURATION_CAP, max(60, int(secs)))
1072
- print(f"[duration] MMAudio regen: 1 seg × {int(num_steps)} steps → {secs:.0f}s → capped {result}s")
1073
- return result
1074
 
1075
 
1076
  @spaces.GPU(duration=_mmaudio_regen_duration)
@@ -1169,10 +1206,7 @@ def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
1169
  prompt, negative_prompt, seed_val,
1170
  guidance_scale, num_steps, model_size,
1171
  crossfade_s, crossfade_db, slot_id=None):
1172
- secs = int(num_steps) * HUNYUAN_SECS_PER_STEP + HUNYUAN_LOAD_OVERHEAD
1173
- result = min(GPU_DURATION_CAP, max(60, int(secs)))
1174
- print(f"[duration] HunyuanFoley regen: 1 seg × {int(num_steps)} steps → {secs:.0f}s → capped {result}s")
1175
- return result
1176
 
1177
 
1178
  @spaces.GPU(duration=_hunyuan_regen_duration)
@@ -1268,10 +1302,82 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1268
  return video_path, audio_path, json.dumps(updated_meta), waveform_html
1269
 
1270
 
 
 
 
 
 
 
1271
  # ================================================================== #
1272
  # SHARED UI HELPERS #
1273
  # ================================================================== #
1274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1275
  def _pad_outputs(outputs: list) -> list:
1276
  """Flatten (video, audio, seg_meta) triples and pad to MAX_SLOTS * 3 with None.
1277
 
@@ -2073,52 +2179,14 @@ with gr.Blocks(title="Generate Audio for Video", css=_SLOT_CSS, js=_GLOBAL_JS) a
2073
  outputs=taro_slot_grps,
2074
  ))
2075
 
2076
- # Per-slot regen handlers for TARO.
2077
- # JS calls /gradio_api/queue/join directly with fn_index + data array:
2078
- # data = [seg_idx, state_json, video_path_or_null, seed, cfg, steps, mode, cf_dur, cf_db]
2079
- # fn_index is discovered at runtime from gradio_config.dependencies by api_name.
2080
- # The handlers are registered via a dummy gr.Button click so Gradio assigns them
2081
- # a stable fn_index and api_name.
2082
- taro_regen_btns = []
2083
- for _i in range(MAX_SLOTS):
2084
- _slot_id = f"taro_{_i}"
2085
- _btn = gr.Button(visible=False, elem_id=f"regen_btn_{_slot_id}")
2086
- taro_regen_btns.append(_btn)
2087
- print(f"[startup] registering regen handler for slot {_slot_id}")
2088
- def _make_taro_regen(_si, _sid):
2089
- def _do(seg_idx, state_json, video, seed, cfg, steps, mode, cf_dur, cf_db):
2090
- print(f"[regen TARO] slot={_sid} seg_idx={seg_idx} state_json_len={len(state_json) if state_json else 0}")
2091
- if not state_json:
2092
- print(f"[regen TARO] early-exit: state_json empty")
2093
- yield gr.update(), gr.update(); return
2094
- lock = _get_slot_lock(_sid)
2095
- with lock:
2096
- print(f"[regen TARO] slot={_sid} seg_idx={seg_idx} — lock acquired, showing spinner")
2097
- state = json.loads(state_json)
2098
- pending_html = _build_regen_pending_html(
2099
- state["segments"], int(seg_idx), _sid, ""
2100
- )
2101
- yield gr.update(), gr.update(value=pending_html)
2102
- print(f"[regen TARO] slot={_sid} seg_idx={seg_idx} — calling regen_taro_segment")
2103
- try:
2104
- vid, aud, new_meta_json, html = regen_taro_segment(
2105
- video, int(seg_idx), state_json,
2106
- seed, cfg, steps, mode, cf_dur, cf_db, _sid,
2107
- )
2108
- print(f"[regen TARO] slot={_sid} seg_idx={seg_idx} — done, vid={vid!r}")
2109
- except Exception as _e:
2110
- print(f"[regen TARO] slot={_sid} seg_idx={seg_idx} — ERROR: {_e}")
2111
- raise
2112
- yield gr.update(value=vid), gr.update(value=html)
2113
- return _do
2114
- _btn.click(
2115
- fn=_make_taro_regen(_i, _slot_id),
2116
- inputs=[taro_regen_seg, taro_regen_state,
2117
- taro_video, taro_seed, taro_cfg, taro_steps,
2118
- taro_mode, taro_cf_dur, taro_cf_db],
2119
- outputs=[taro_slot_vids[_i], taro_slot_waves[_i]],
2120
- api_name=f"regen_taro_{_i}",
2121
- )
2122
 
2123
  # ---------------------------------------------------------- #
2124
  # Tab 2 — MMAudio #
@@ -2167,44 +2235,12 @@ with gr.Blocks(title="Generate Audio for Video", css=_SLOT_CSS, js=_GLOBAL_JS) a
2167
  outputs=mma_slot_grps,
2168
  ))
2169
 
2170
- mma_regen_btns = []
2171
- for _i in range(MAX_SLOTS):
2172
- _slot_id = f"mma_{_i}"
2173
- _btn = gr.Button(visible=False, elem_id=f"regen_btn_{_slot_id}")
2174
- mma_regen_btns.append(_btn)
2175
- def _make_mma_regen(_si, _sid):
2176
- def _do(seg_idx, state_json, video, prompt, neg, seed, cfg, steps, cf_dur, cf_db):
2177
- print(f"[regen MMA] slot={_sid} seg_idx={seg_idx} state_json_len={len(state_json) if state_json else 0}")
2178
- if not state_json:
2179
- print(f"[regen MMA] early-exit: state_json empty")
2180
- yield gr.update(), gr.update(); return
2181
- lock = _get_slot_lock(_sid)
2182
- with lock:
2183
- state = json.loads(state_json)
2184
- pending_html = _build_regen_pending_html(
2185
- state["segments"], int(seg_idx), _sid, ""
2186
- )
2187
- yield gr.update(), gr.update(value=pending_html)
2188
- print(f"[regen MMA] slot={_sid} seg_idx={seg_idx} — calling regen_mmaudio_segment")
2189
- try:
2190
- vid, aud, new_meta_json, html = regen_mmaudio_segment(
2191
- video, int(seg_idx), state_json,
2192
- prompt, neg, seed, cfg, steps, cf_dur, cf_db, _sid,
2193
- )
2194
- print(f"[regen MMA] slot={_sid} seg_idx={seg_idx} — done, vid={vid!r}")
2195
- except Exception as _e:
2196
- print(f"[regen MMA] slot={_sid} seg_idx={seg_idx} — ERROR: {_e}")
2197
- raise
2198
- yield gr.update(value=vid), gr.update(value=html)
2199
- return _do
2200
- _btn.click(
2201
- fn=_make_mma_regen(_i, _slot_id),
2202
- inputs=[mma_regen_seg, mma_regen_state,
2203
- mma_video, mma_prompt, mma_neg, mma_seed,
2204
- mma_cfg, mma_steps, mma_cf_dur, mma_cf_db],
2205
- outputs=[mma_slot_vids[_i], mma_slot_waves[_i]],
2206
- api_name=f"regen_mma_{_i}",
2207
- )
2208
 
2209
  # ---------------------------------------------------------- #
2210
  # Tab 3 — HunyuanVideoFoley #
@@ -2254,44 +2290,12 @@ with gr.Blocks(title="Generate Audio for Video", css=_SLOT_CSS, js=_GLOBAL_JS) a
2254
  outputs=hf_slot_grps,
2255
  ))
2256
 
2257
- hf_regen_btns = []
2258
- for _i in range(MAX_SLOTS):
2259
- _slot_id = f"hf_{_i}"
2260
- _btn = gr.Button(visible=False, elem_id=f"regen_btn_{_slot_id}")
2261
- hf_regen_btns.append(_btn)
2262
- def _make_hf_regen(_si, _sid):
2263
- def _do(seg_idx, state_json, video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db):
2264
- print(f"[regen HF] slot={_sid} seg_idx={seg_idx} state_json_len={len(state_json) if state_json else 0}")
2265
- if not state_json:
2266
- print(f"[regen HF] early-exit: state_json empty")
2267
- yield gr.update(), gr.update(); return
2268
- lock = _get_slot_lock(_sid)
2269
- with lock:
2270
- state = json.loads(state_json)
2271
- pending_html = _build_regen_pending_html(
2272
- state["segments"], int(seg_idx), _sid, ""
2273
- )
2274
- yield gr.update(), gr.update(value=pending_html)
2275
- print(f"[regen HF] slot={_sid} seg_idx={seg_idx} — calling regen_hunyuan_segment")
2276
- try:
2277
- vid, aud, new_meta_json, html = regen_hunyuan_segment(
2278
- video, int(seg_idx), state_json,
2279
- prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, _sid,
2280
- )
2281
- print(f"[regen HF] slot={_sid} seg_idx={seg_idx} — done, vid={vid!r}")
2282
- except Exception as _e:
2283
- print(f"[regen HF] slot={_sid} seg_idx={seg_idx} — ERROR: {_e}")
2284
- raise
2285
- yield gr.update(value=vid), gr.update(value=html)
2286
- return _do
2287
- _btn.click(
2288
- fn=_make_hf_regen(_i, _slot_id),
2289
- inputs=[hf_regen_seg, hf_regen_state,
2290
- hf_video, hf_prompt, hf_neg, hf_seed,
2291
- hf_guidance, hf_steps, hf_size, hf_cf_dur, hf_cf_db],
2292
- outputs=[hf_slot_vids[_i], hf_slot_waves[_i]],
2293
- api_name=f"regen_hf_{_i}",
2294
- )
2295
 
2296
  # ---- Cross-tab video sync ----
2297
  _sync = lambda v: (gr.update(value=v), gr.update(value=v))
 
11
  import os
12
  import sys
13
  import json
14
+ import shutil
15
  import tempfile
16
  import random
17
  import threading
18
+ import time
19
  from pathlib import Path
20
 
 
21
  import torch
22
  import numpy as np
23
  import torchaudio
 
119
 
120
  def _register_tmp_dir(tmp_dir: str) -> str:
121
  """Register a temp dir so it can be cleaned up when newer ones replace it."""
 
122
  _TEMP_DIRS.append(tmp_dir)
123
  while len(_TEMP_DIRS) > _TEMP_DIRS_MAX:
124
  old = _TEMP_DIRS.pop(0)
 
305
  HUNYUAN_LOAD_OVERHEAD = 55 # ~55s to load the 10GB XXL model weights into GPU
306
  GPU_DURATION_CAP = 300 # hard cap per call — never reserve more than this
307
 
308
+ # ------------------------------------------------------------------ #
309
+ # Model configuration registry — single source of truth for per-model #
310
+ # constants used by duration estimation, segmentation, and UI. #
311
+ # ------------------------------------------------------------------ #
312
+ MODEL_CONFIGS = {
313
+ "taro": {
314
+ "window_s": TARO_MODEL_DUR, # 8.192 s
315
+ "sr": TARO_SR, # 16000
316
+ "secs_per_step": TARO_SECS_PER_STEP, # 0.05
317
+ "load_overhead": TARO_LOAD_OVERHEAD, # 15
318
+ "tab_prefix": "taro",
319
+ "regen_fn": None, # set after function definitions (avoids forward-ref)
320
+ "label": "TARO",
321
+ },
322
+ "mmaudio": {
323
+ "window_s": MMAUDIO_WINDOW, # 8.0 s
324
+ "sr": 44100,
325
+ "secs_per_step": MMAUDIO_SECS_PER_STEP, # 0.25
326
+ "load_overhead": MMAUDIO_LOAD_OVERHEAD, # 15
327
+ "tab_prefix": "mma",
328
+ "regen_fn": None,
329
+ "label": "MMAudio",
330
+ },
331
+ "hunyuan": {
332
+ "window_s": HUNYUAN_MAX_DUR, # 15.0 s
333
+ "sr": 48000,
334
+ "secs_per_step": HUNYUAN_SECS_PER_STEP, # 0.35
335
+ "load_overhead": HUNYUAN_LOAD_OVERHEAD, # 55
336
+ "tab_prefix": "hf",
337
+ "regen_fn": None,
338
+ "label": "HunyuanFoley",
339
+ },
340
+ }
341
+
342
+
343
+ def _estimate_gpu_duration(model_key: str, num_samples: int, num_steps: int,
344
+ total_dur_s: float = None, crossfade_s: float = 0,
345
+ video_file: str = None) -> int:
346
+ """Generic GPU duration estimator used by all models.
347
+
348
+ Computes: num_samples × n_segs × num_steps × secs_per_step + load_overhead
349
+ Clamped to [60, GPU_DURATION_CAP].
350
+ """
351
+ cfg = MODEL_CONFIGS[model_key]
352
+ try:
353
+ if total_dur_s is None:
354
+ total_dur_s = get_video_duration(video_file)
355
+ n_segs = len(_build_segments(total_dur_s, cfg["window_s"], float(crossfade_s)))
356
+ except Exception:
357
+ n_segs = 1
358
+ secs = int(num_samples) * n_segs * int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"]
359
+ result = min(GPU_DURATION_CAP, max(60, int(secs)))
360
+ print(f"[duration] {cfg['label']}: {int(num_samples)}samp × {n_segs}seg × "
361
+ f"{int(num_steps)}steps → {secs:.0f}s → capped {result}s")
362
+ return result
363
+
364
+
365
+ def _estimate_regen_duration(model_key: str, num_steps: int) -> int:
366
+ """Generic GPU duration estimator for single-segment regen."""
367
+ cfg = MODEL_CONFIGS[model_key]
368
+ secs = int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"]
369
+ result = min(GPU_DURATION_CAP, max(60, int(secs)))
370
+ print(f"[duration] {cfg['label']} regen: 1 seg × {int(num_steps)} steps → {secs:.0f}s → capped {result}s")
371
+ return result
372
+
373
  _TARO_CACHE_MAXLEN = 16 # evict oldest entries beyond this limit
374
  _TARO_INFERENCE_CACHE: dict = {} # keyed by (video_file, seed, cfg, steps, mode, crossfade_s)
 
375
  _TARO_CACHE_LOCK = threading.Lock()
376
 
377
 
 
384
 
385
  def _taro_duration(video_file, seed_val, cfg_scale, num_steps, mode,
386
  crossfade_s, crossfade_db, num_samples):
387
+ """Pre-GPU callable — must match _taro_gpu_infer's input order exactly."""
388
+ return _estimate_gpu_duration("taro", int(num_samples), int(num_steps),
389
+ video_file=video_file, crossfade_s=crossfade_s)
 
 
 
 
 
 
 
390
 
391
 
392
  def _taro_infer_segment(
 
615
 
616
  def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
617
  cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
618
+ """Pre-GPU callable — must match _mmaudio_gpu_infer's input order exactly."""
619
+ return _estimate_gpu_duration("mmaudio", int(num_samples), int(num_steps),
620
+ video_file=video_file, crossfade_s=crossfade_s)
 
 
 
 
 
 
 
621
 
622
 
623
  @spaces.GPU(duration=_mmaudio_duration)
 
785
 
786
  def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
787
  guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
788
+ """Pre-GPU callable — must match _hunyuan_gpu_infer's input order exactly."""
789
+ return _estimate_gpu_duration("hunyuan", int(num_samples), int(num_steps),
790
+ video_file=video_file, crossfade_s=crossfade_s)
 
 
 
 
 
 
 
791
 
792
 
793
  @spaces.GPU(duration=_hunyuan_duration)
 
1036
  def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
1037
  seed_val, cfg_scale, num_steps, mode,
1038
  crossfade_s, crossfade_db, slot_id=None):
1039
+ return _estimate_regen_duration("taro", int(num_steps))
 
 
 
1040
 
1041
 
1042
  @spaces.GPU(duration=_taro_regen_duration)
 
1107
  def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
1108
  prompt, negative_prompt, seed_val,
1109
  cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id=None):
1110
+ return _estimate_regen_duration("mmaudio", int(num_steps))
 
 
 
1111
 
1112
 
1113
  @spaces.GPU(duration=_mmaudio_regen_duration)
 
1206
  prompt, negative_prompt, seed_val,
1207
  guidance_scale, num_steps, model_size,
1208
  crossfade_s, crossfade_db, slot_id=None):
1209
+ return _estimate_regen_duration("hunyuan", int(num_steps))
 
 
 
1210
 
1211
 
1212
  @spaces.GPU(duration=_hunyuan_regen_duration)
 
1302
  return video_path, audio_path, json.dumps(updated_meta), waveform_html
1303
 
1304
 
1305
+ # Wire up regen_fn references now that the functions are defined
1306
+ MODEL_CONFIGS["taro"]["regen_fn"] = regen_taro_segment
1307
+ MODEL_CONFIGS["mmaudio"]["regen_fn"] = regen_mmaudio_segment
1308
+ MODEL_CONFIGS["hunyuan"]["regen_fn"] = regen_hunyuan_segment
1309
+
1310
+
1311
  # ================================================================== #
1312
  # SHARED UI HELPERS #
1313
  # ================================================================== #
1314
 
1315
+ def _register_regen_handlers(tab_prefix, model_key, regen_seg_tb, regen_state_tb,
1316
+ input_components, slot_vids, slot_waves):
1317
+ """Register per-slot regen button handlers for a model tab.
1318
+
1319
+ This replaces the three nearly-identical for-loops that previously existed
1320
+ for TARO, MMAudio, and HunyuanFoley tabs.
1321
+
1322
+ Args:
1323
+ tab_prefix: e.g. "taro", "mma", "hf"
1324
+ model_key: e.g. "taro", "mmaudio", "hunyuan"
1325
+ regen_seg_tb: gr.Textbox for seg_idx (render=False)
1326
+ regen_state_tb: gr.Textbox for state_json (render=False)
1327
+ input_components: list of Gradio input components (video, seed, etc.)
1328
+ — order must match regen_fn signature after (seg_idx, state_json, video)
1329
+ slot_vids: list of gr.Video components per slot
1330
+ slot_waves: list of gr.HTML components per slot
1331
+ Returns:
1332
+ list of hidden gr.Buttons (one per slot)
1333
+ """
1334
+ cfg = MODEL_CONFIGS[model_key]
1335
+ regen_fn = cfg["regen_fn"]
1336
+ label = cfg["label"]
1337
+ btns = []
1338
+ for _i in range(MAX_SLOTS):
1339
+ _slot_id = f"{tab_prefix}_{_i}"
1340
+ _btn = gr.Button(visible=False, elem_id=f"regen_btn_{_slot_id}")
1341
+ btns.append(_btn)
1342
+ print(f"[startup] registering regen handler for slot {_slot_id}")
1343
+
1344
+ def _make_regen(_si, _sid, _model_key, _label, _regen_fn):
1345
+ def _do(seg_idx, state_json, *args):
1346
+ print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} "
1347
+ f"state_json_len={len(state_json) if state_json else 0}")
1348
+ if not state_json:
1349
+ print(f"[regen {_label}] early-exit: state_json empty")
1350
+ yield gr.update(), gr.update()
1351
+ return
1352
+ lock = _get_slot_lock(_sid)
1353
+ with lock:
1354
+ state = json.loads(state_json)
1355
+ pending_html = _build_regen_pending_html(
1356
+ state["segments"], int(seg_idx), _sid, ""
1357
+ )
1358
+ yield gr.update(), gr.update(value=pending_html)
1359
+ print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} — calling regen")
1360
+ try:
1361
+ # args[0] = video, args[1:] = model-specific params
1362
+ vid, aud, new_meta_json, html = _regen_fn(
1363
+ args[0], int(seg_idx), state_json, *args[1:], _sid,
1364
+ )
1365
+ print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} — done, vid={vid!r}")
1366
+ except Exception as _e:
1367
+ print(f"[regen {_label}] slot={_sid} seg_idx={seg_idx} — ERROR: {_e}")
1368
+ raise
1369
+ yield gr.update(value=vid), gr.update(value=html)
1370
+ return _do
1371
+
1372
+ _btn.click(
1373
+ fn=_make_regen(_i, _slot_id, model_key, label, regen_fn),
1374
+ inputs=[regen_seg_tb, regen_state_tb] + input_components,
1375
+ outputs=[slot_vids[_i], slot_waves[_i]],
1376
+ api_name=f"regen_{tab_prefix}_{_i}",
1377
+ )
1378
+ return btns
1379
+
1380
+
1381
  def _pad_outputs(outputs: list) -> list:
1382
  """Flatten (video, audio, seg_meta) triples and pad to MAX_SLOTS * 3 with None.
1383
 
 
2179
  outputs=taro_slot_grps,
2180
  ))
2181
 
2182
+ # Per-slot regen handlers JS calls /gradio_api/queue/join with
2183
+ # fn_index (by api_name) + data=[seg_idx, state_json, video, ...params].
2184
+ taro_regen_btns = _register_regen_handlers(
2185
+ "taro", "taro", taro_regen_seg, taro_regen_state,
2186
+ [taro_video, taro_seed, taro_cfg, taro_steps,
2187
+ taro_mode, taro_cf_dur, taro_cf_db],
2188
+ taro_slot_vids, taro_slot_waves,
2189
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2190
 
2191
  # ---------------------------------------------------------- #
2192
  # Tab 2 — MMAudio #
 
2235
  outputs=mma_slot_grps,
2236
  ))
2237
 
2238
+ mma_regen_btns = _register_regen_handlers(
2239
+ "mma", "mmaudio", mma_regen_seg, mma_regen_state,
2240
+ [mma_video, mma_prompt, mma_neg, mma_seed,
2241
+ mma_cfg, mma_steps, mma_cf_dur, mma_cf_db],
2242
+ mma_slot_vids, mma_slot_waves,
2243
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2244
 
2245
  # ---------------------------------------------------------- #
2246
  # Tab 3 — HunyuanVideoFoley #
 
2290
  outputs=hf_slot_grps,
2291
  ))
2292
 
2293
+ hf_regen_btns = _register_regen_handlers(
2294
+ "hf", "hunyuan", hf_regen_seg, hf_regen_state,
2295
+ [hf_video, hf_prompt, hf_neg, hf_seed,
2296
+ hf_guidance, hf_steps, hf_size, hf_cf_dur, hf_cf_db],
2297
+ hf_slot_vids, hf_slot_waves,
2298
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2299
 
2300
  # ---- Cross-tab video sync ----
2301
  _sync = lambda v: (gr.update(value=v), gr.update(value=v))