BoxOfColors committed on
Commit
8375700
·
1 Parent(s): 9ef3cf6
Files changed (1) hide show
  1. app.py +138 -41
app.py CHANGED
@@ -14,6 +14,7 @@ import json
14
  import base64
15
  import tempfile
16
  import random
 
17
  from pathlib import Path
18
 
19
  import time
@@ -75,6 +76,20 @@ print("CLAP model pre-downloaded.")
75
  # ================================================================== #
76
 
77
  MAX_SLOTS = 8 # max parallel generation slots shown in UI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  def set_global_seed(seed: int):
80
  np.random.seed(seed % (2**32))
@@ -1085,6 +1100,59 @@ def _pad_outputs(outputs: list) -> list:
1085
  # WaveSurfer waveform + segment marker HTML builder #
1086
  # ------------------------------------------------------------------ #
1087
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1088
  def _build_waveform_html(audio_path: str, segments: list, slot_id: str,
1089
  hidden_input_id: str) -> str:
1090
  """Return a self-contained HTML block with a Canvas waveform (display only),
@@ -1336,7 +1404,9 @@ def _build_waveform_html(audio_path: str, segments: list, slot_id: str,
1336
  def _make_output_slots(tab_prefix: str) -> tuple:
1337
  """Build MAX_SLOTS output groups for one tab.
1338
 
1339
- Each slot has: video, waveform HTML, hidden regen trigger textbox, seg state.
 
 
1340
  Returns (grps, vids, waveforms, regen_triggers, seg_states).
1341
  """
1342
  grps, vids, waveforms, regen_triggers, seg_states = [], [], [], [], []
@@ -1347,14 +1417,20 @@ def _make_output_slots(tab_prefix: str) -> tuple:
1347
  waveforms.append(gr.HTML(
1348
  value="<p style='color:#888;font-size:12px'>Generate audio to see waveform.</p>",
1349
  ))
1350
- # Hidden textbox: JS writes "<slot_id>|<seg_idx>" here to trigger regen
1351
  regen_triggers.append(gr.Textbox(
1352
  value="",
1353
  visible=False,
1354
  elem_id=f"regen_trigger_{slot_id}",
1355
  label=f"regen_trigger_{slot_id}",
1356
  ))
1357
- seg_states.append(gr.State(value=None))
 
 
 
 
 
 
1358
  grps.append(g)
1359
  return grps, vids, waveforms, regen_triggers, seg_states
1360
 
@@ -1381,12 +1457,13 @@ def _unpack_outputs(flat: list, n: int, tab_prefix: str) -> list:
1381
  hidden_el_id = f"regen_trigger_{slot_id}"
1382
  html = _build_waveform_html(aud_path, meta["segments"], slot_id, hidden_el_id)
1383
  wave_updates.append(gr.update(value=html))
1384
- state_updates.append(meta)
 
1385
  else:
1386
  wave_updates.append(gr.update(
1387
  value="<p style='color:#888;font-size:12px'>Generate audio to see waveform.</p>"
1388
  ))
1389
- state_updates.append(None)
1390
  return vid_updates + wave_updates + state_updates
1391
 
1392
 
@@ -1571,20 +1648,28 @@ with gr.Blocks(title="Generate Audio for Video", css=_SLOT_CSS, js=_GLOBAL_JS) a
1571
  for _i, _rtrig in enumerate(taro_slot_rtrigs):
1572
  _slot_id = f"taro_{_i}"
1573
  def _make_taro_regen(_si, _sid):
1574
- def _do(trigger_val, video, seed, cfg, steps, mode, cf_dur, cf_db, state):
1575
- if not trigger_val or not state:
1576
- return gr.update(), gr.update(), state, gr.update()
1577
  parts = trigger_val.split("|")
1578
  if len(parts) != 2 or parts[0] != _sid:
1579
- return gr.update(), gr.update(), state, gr.update()
1580
- seg_idx = int(parts[1])
1581
- meta_json = json.dumps(state)
1582
- vid, aud, new_meta_json, html = regen_taro_segment(
1583
- video, seg_idx, meta_json,
1584
- seed, cfg, steps, mode, cf_dur, cf_db, _sid,
1585
- )
1586
- new_meta = json.loads(new_meta_json)
1587
- return gr.update(value=vid), gr.update(value=html), new_meta, gr.update(value="")
 
 
 
 
 
 
 
 
1588
  return _do
1589
  _rtrig.change(
1590
  fn=_make_taro_regen(_i, _slot_id),
@@ -1645,20 +1730,26 @@ with gr.Blocks(title="Generate Audio for Video", css=_SLOT_CSS, js=_GLOBAL_JS) a
1645
  for _i, _rtrig in enumerate(mma_slot_rtrigs):
1646
  _slot_id = f"mma_{_i}"
1647
  def _make_mma_regen(_si, _sid):
1648
- def _do(trigger_val, video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, state):
1649
- if not trigger_val or not state:
1650
- return gr.update(), gr.update(), state, gr.update()
1651
  parts = trigger_val.split("|")
1652
  if len(parts) != 2 or parts[0] != _sid:
1653
- return gr.update(), gr.update(), state, gr.update()
1654
- seg_idx = int(parts[1])
1655
- meta_json = json.dumps(state)
1656
- vid, aud, new_meta_json, html = regen_mmaudio_segment(
1657
- video, seg_idx, meta_json,
1658
- prompt, neg, seed, cfg, steps, cf_dur, cf_db, _sid,
1659
- )
1660
- new_meta = json.loads(new_meta_json)
1661
- return gr.update(value=vid), gr.update(value=html), new_meta, gr.update(value="")
 
 
 
 
 
 
1662
  return _do
1663
  _rtrig.change(
1664
  fn=_make_mma_regen(_i, _slot_id),
@@ -1720,20 +1811,26 @@ with gr.Blocks(title="Generate Audio for Video", css=_SLOT_CSS, js=_GLOBAL_JS) a
1720
  for _i, _rtrig in enumerate(hf_slot_rtrigs):
1721
  _slot_id = f"hf_{_i}"
1722
  def _make_hf_regen(_si, _sid):
1723
- def _do(trigger_val, video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, state):
1724
- if not trigger_val or not state:
1725
- return gr.update(), gr.update(), state, gr.update()
1726
  parts = trigger_val.split("|")
1727
  if len(parts) != 2 or parts[0] != _sid:
1728
- return gr.update(), gr.update(), state, gr.update()
1729
- seg_idx = int(parts[1])
1730
- meta_json = json.dumps(state)
1731
- vid, aud, new_meta_json, html = regen_hunyuan_segment(
1732
- video, seg_idx, meta_json,
1733
- prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, _sid,
1734
- )
1735
- new_meta = json.loads(new_meta_json)
1736
- return gr.update(value=vid), gr.update(value=html), new_meta, gr.update(value="")
 
 
 
 
 
 
1737
  return _do
1738
  _rtrig.change(
1739
  fn=_make_hf_regen(_i, _slot_id),
 
14
  import base64
15
  import tempfile
16
  import random
17
+ import threading
18
  from pathlib import Path
19
 
20
  import time
 
76
  # ================================================================== #
77
 
78
  MAX_SLOTS = 8 # max parallel generation slots shown in UI
79
+ MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video ≤ ~64 s at 8 s/seg)
80
+
81
+ # Per-slot reentrant locks — prevent concurrent regens on the same slot from
82
+ # producing a race condition where the second regen reads stale state
83
+ # (the shared seg_state textbox hasn't been updated yet by the first regen).
84
+ # Locks are keyed by slot_id string (e.g. "taro_0", "mma_2").
85
+ _SLOT_LOCKS: dict = {}
86
+ _SLOT_LOCKS_MUTEX = threading.Lock()
87
+
88
+ def _get_slot_lock(slot_id: str) -> threading.Lock:
89
+ with _SLOT_LOCKS_MUTEX:
90
+ if slot_id not in _SLOT_LOCKS:
91
+ _SLOT_LOCKS[slot_id] = threading.Lock()
92
+ return _SLOT_LOCKS[slot_id]
93
 
94
  def set_global_seed(seed: int):
95
  np.random.seed(seed % (2**32))
 
1100
  # WaveSurfer waveform + segment marker HTML builder #
1101
  # ------------------------------------------------------------------ #
1102
 
1103
+ def _build_regen_pending_html(segments: list, regen_seg_idx: int, slot_id: str,
1104
+ hidden_input_id: str) -> str:
1105
+ """Return a waveform placeholder shown while a segment is being regenerated.
1106
+
1107
+ Renders a dark bar with the active segment highlighted in amber + a spinner.
1108
+ """
1109
+ segs_json = json.dumps(segments)
1110
+ seg_colors = ["rgba(100,180,255,0.25)", "rgba(255,160,100,0.25)",
1111
+ "rgba(120,220,140,0.25)", "rgba(220,120,220,0.25)",
1112
+ "rgba(255,220,80,0.25)", "rgba(80,220,220,0.25)",
1113
+ "rgba(255,100,100,0.25)", "rgba(180,255,180,0.25)"]
1114
+ active_color = "rgba(255,180,0,0.55)"
1115
+ duration = segments[-1][1] if segments else 1.0
1116
+
1117
+ seg_divs = ""
1118
+ for i, seg in enumerate(segments):
1119
+ left_pct = seg[0] / duration * 100
1120
+ width_pct = (seg[1] - seg[0]) / duration * 100
1121
+ color = active_color if i == regen_seg_idx else seg_colors[i % len(seg_colors)]
1122
+ extra = "border:2px solid #ffb300;animation:wf_pulse 0.8s ease-in-out infinite alternate;" if i == regen_seg_idx else ""
1123
+ seg_divs += (
1124
+ f'<div style="position:absolute;top:0;left:{left_pct:.2f}%;'
1125
+ f'width:{width_pct:.2f}%;height:100%;background:{color};{extra}">'
1126
+ f'<span style="color:rgba(255,255,255,0.7);font-size:10px;padding:2px 3px;">Seg {i+1}</span>'
1127
+ f'</div>'
1128
+ )
1129
+
1130
+ spinner = (
1131
+ '<div style="position:absolute;top:50%;left:50%;transform:translate(-50%,-50%);'
1132
+ 'display:flex;align-items:center;gap:6px;">'
1133
+ '<div style="width:14px;height:14px;border:2px solid #ffb300;'
1134
+ 'border-top-color:transparent;border-radius:50%;'
1135
+ 'animation:wf_spin 0.7s linear infinite;"></div>'
1136
+ f'<span style="color:#ffb300;font-size:12px;white-space:nowrap;">'
1137
+ f'Regenerating Seg {regen_seg_idx+1}…</span>'
1138
+ '</div>'
1139
+ )
1140
+
1141
+ return f"""
1142
+ <style>
1143
+ @keyframes wf_pulse {{from{{opacity:0.5}}to{{opacity:1}}}}
1144
+ @keyframes wf_spin {{to{{transform:rotate(360deg)}}}}
1145
+ </style>
1146
+ <div style="background:#1a1a1a;border-radius:8px;padding:10px;margin-top:6px;">
1147
+ <div style="position:relative;width:100%;height:80px;background:#1e1e2e;border-radius:4px;overflow:hidden;">
1148
+ {seg_divs}
1149
+ {spinner}
1150
+ </div>
1151
+ <div style="color:#888;font-size:11px;margin-top:6px;">Regenerating — please wait…</div>
1152
+ </div>
1153
+ """
1154
+
1155
+
1156
  def _build_waveform_html(audio_path: str, segments: list, slot_id: str,
1157
  hidden_input_id: str) -> str:
1158
  """Return a self-contained HTML block with a Canvas waveform (display only),
 
1404
  def _make_output_slots(tab_prefix: str) -> tuple:
1405
  """Build MAX_SLOTS output groups for one tab.
1406
 
1407
+ Each slot has: video, waveform HTML, hidden regen trigger textbox,
1408
+ hidden JSON state textbox (replaces gr.State to fix Gradio 5 SSR
1409
+ 'Too many arguments' caused by gr.State not counting in endpoint outputs).
1410
  Returns (grps, vids, waveforms, regen_triggers, seg_states).
1411
  """
1412
  grps, vids, waveforms, regen_triggers, seg_states = [], [], [], [], []
 
1417
  waveforms.append(gr.HTML(
1418
  value="<p style='color:#888;font-size:12px'>Generate audio to see waveform.</p>",
1419
  ))
1420
+ # Hidden textbox: JS writes "<slot_id>|<seg_idx>" to trigger regen
1421
  regen_triggers.append(gr.Textbox(
1422
  value="",
1423
  visible=False,
1424
  elem_id=f"regen_trigger_{slot_id}",
1425
  label=f"regen_trigger_{slot_id}",
1426
  ))
1427
+ # Hidden JSON textbox instead of gr.State — Gradio 5 SSR counts
1428
+ # gr.Textbox correctly in endpoint outputs but not gr.State.
1429
+ seg_states.append(gr.Textbox(
1430
+ value="",
1431
+ visible=False,
1432
+ label=f"seg_state_{slot_id}",
1433
+ ))
1434
  grps.append(g)
1435
  return grps, vids, waveforms, regen_triggers, seg_states
1436
 
 
1457
  hidden_el_id = f"regen_trigger_{slot_id}"
1458
  html = _build_waveform_html(aud_path, meta["segments"], slot_id, hidden_el_id)
1459
  wave_updates.append(gr.update(value=html))
1460
+ # Serialize meta to JSON string (seg_states are now gr.Textbox)
1461
+ state_updates.append(gr.update(value=json.dumps(meta)))
1462
  else:
1463
  wave_updates.append(gr.update(
1464
  value="<p style='color:#888;font-size:12px'>Generate audio to see waveform.</p>"
1465
  ))
1466
+ state_updates.append(gr.update(value=""))
1467
  return vid_updates + wave_updates + state_updates
1468
 
1469
 
 
1648
  for _i, _rtrig in enumerate(taro_slot_rtrigs):
1649
  _slot_id = f"taro_{_i}"
1650
  def _make_taro_regen(_si, _sid):
1651
+ def _do(trigger_val, video, seed, cfg, steps, mode, cf_dur, cf_db, state_json):
1652
+ if not trigger_val or not state_json:
1653
+ return gr.update(), gr.update(), gr.update(value=""), gr.update()
1654
  parts = trigger_val.split("|")
1655
  if len(parts) != 2 or parts[0] != _sid:
1656
+ return gr.update(), gr.update(), gr.update(value=""), gr.update()
1657
+ seg_idx = int(parts[1])
1658
+ # Acquire per-slot lock so concurrent regens on the same slot
1659
+ # don't read stale state (second regen waits for first to finish).
1660
+ lock = _get_slot_lock(_sid)
1661
+ with lock:
1662
+ state = json.loads(state_json)
1663
+ pending_html = _build_regen_pending_html(
1664
+ state["segments"], seg_idx, _sid,
1665
+ f"regen_trigger_{_sid}"
1666
+ )
1667
+ yield gr.update(), gr.update(value=pending_html), gr.update(value=state_json), gr.update()
1668
+ vid, aud, new_meta_json, html = regen_taro_segment(
1669
+ video, seg_idx, state_json,
1670
+ seed, cfg, steps, mode, cf_dur, cf_db, _sid,
1671
+ )
1672
+ yield gr.update(value=vid), gr.update(value=html), gr.update(value=new_meta_json), gr.update(value="")
1673
  return _do
1674
  _rtrig.change(
1675
  fn=_make_taro_regen(_i, _slot_id),
 
1730
  for _i, _rtrig in enumerate(mma_slot_rtrigs):
1731
  _slot_id = f"mma_{_i}"
1732
  def _make_mma_regen(_si, _sid):
1733
+ def _do(trigger_val, video, prompt, neg, seed, cfg, steps, cf_dur, cf_db, state_json):
1734
+ if not trigger_val or not state_json:
1735
+ return gr.update(), gr.update(), gr.update(value=""), gr.update()
1736
  parts = trigger_val.split("|")
1737
  if len(parts) != 2 or parts[0] != _sid:
1738
+ return gr.update(), gr.update(), gr.update(value=""), gr.update()
1739
+ seg_idx = int(parts[1])
1740
+ lock = _get_slot_lock(_sid)
1741
+ with lock:
1742
+ state = json.loads(state_json)
1743
+ pending_html = _build_regen_pending_html(
1744
+ state["segments"], seg_idx, _sid,
1745
+ f"regen_trigger_{_sid}"
1746
+ )
1747
+ yield gr.update(), gr.update(value=pending_html), gr.update(value=state_json), gr.update()
1748
+ vid, aud, new_meta_json, html = regen_mmaudio_segment(
1749
+ video, seg_idx, state_json,
1750
+ prompt, neg, seed, cfg, steps, cf_dur, cf_db, _sid,
1751
+ )
1752
+ yield gr.update(value=vid), gr.update(value=html), gr.update(value=new_meta_json), gr.update(value="")
1753
  return _do
1754
  _rtrig.change(
1755
  fn=_make_mma_regen(_i, _slot_id),
 
1811
  for _i, _rtrig in enumerate(hf_slot_rtrigs):
1812
  _slot_id = f"hf_{_i}"
1813
  def _make_hf_regen(_si, _sid):
1814
+ def _do(trigger_val, video, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, state_json):
1815
+ if not trigger_val or not state_json:
1816
+ return gr.update(), gr.update(), gr.update(value=""), gr.update()
1817
  parts = trigger_val.split("|")
1818
  if len(parts) != 2 or parts[0] != _sid:
1819
+ return gr.update(), gr.update(), gr.update(value=""), gr.update()
1820
+ seg_idx = int(parts[1])
1821
+ lock = _get_slot_lock(_sid)
1822
+ with lock:
1823
+ state = json.loads(state_json)
1824
+ pending_html = _build_regen_pending_html(
1825
+ state["segments"], seg_idx, _sid,
1826
+ f"regen_trigger_{_sid}"
1827
+ )
1828
+ yield gr.update(), gr.update(value=pending_html), gr.update(value=state_json), gr.update()
1829
+ vid, aud, new_meta_json, html = regen_hunyuan_segment(
1830
+ video, seg_idx, state_json,
1831
+ prompt, neg, seed, guidance, steps, size, cf_dur, cf_db, _sid,
1832
+ )
1833
+ yield gr.update(value=vid), gr.update(value=html), gr.update(value=new_meta_json), gr.update(value="")
1834
  return _do
1835
  _rtrig.change(
1836
  fn=_make_hf_regen(_i, _slot_id),