BoxOfColors Claude Sonnet 4.6 committed on
Commit
e152b28
·
1 Parent(s): 7592f82

fix: replace threading.local with caller-thread-id dict for ZeroGPU context passing

Browse files

ZeroGPU dispatches @spaces.GPU functions on its own worker thread pool, not
the Gradio handler thread. threading.local() values are invisible across
threads, causing AttributeError when the GPU worker reads the context.

Replace _tl (threading.local) with _CTX (dict keyed by caller thread ID)
+ _ctx_set/_ctx_get helpers. The caller writes context under (tid, key)
before invoking the GPU function; the GPU worker reads using the same tid
(ZeroGPU runs synchronously on behalf of the caller, so no concurrent write
occurs). Entries are popped on read to prevent unbounded growth.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +47 -28
app.py CHANGED
@@ -124,13 +124,32 @@ print(f"[startup] All downloads done in {time.perf_counter() - _t_dl_start:.1f}s
124
  # SHARED CONSTANTS / HELPERS #
125
  # ================================================================== #
126
 
127
- # Thread-local storage for CPU β†’ GPU context passing.
128
  # Replaces the fragile function-attribute pattern (_fn._cpu_ctx = {...}).
129
- # Each wrapper writes its context under a unique key before calling the
130
- # @spaces.GPU function; the GPU function reads it back. Using thread-local
131
- # storage means concurrent requests on different threads don't clobber
132
- # each other's context — the function-attribute approach was not thread-safe.
133
- _tl = threading.local()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  MAX_SLOTS = 8 # max parallel generation slots shown in UI
136
  MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video ≤ ~64 s at 8 s/seg)
@@ -788,8 +807,8 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
788
  from TARO.onset_util import extract_onset
789
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
790
 
791
- # Use pre-computed CPU results passed via thread-local storage
792
- ctx = _tl.taro_gen_ctx
793
  tmp_dir = ctx["tmp_dir"]
794
  silent_video = ctx["silent_video"]
795
  segments = ctx["segments"]
@@ -861,11 +880,11 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
861
  tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
862
  video_file, TARO_MODEL_DUR, crossfade_s)
863
 
864
- # Pass pre-computed CPU results to the GPU function via thread-local storage
865
- _tl.taro_gen_ctx = {
866
  "tmp_dir": tmp_dir, "silent_video": silent_video,
867
  "segments": segments, "total_dur_s": total_dur_s,
868
- }
869
 
870
  # ── GPU inference only ──
871
  results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
@@ -941,7 +960,7 @@ def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
941
 
942
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
943
 
944
- ctx = _tl.mmaudio_gen_ctx
945
  segments = ctx["segments"]
946
  seg_clip_paths = ctx["seg_clip_paths"]
947
 
@@ -1020,9 +1039,9 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
1020
  for i, (s, e) in enumerate(segments)
1021
  ]
1022
 
1023
- _tl.mmaudio_gen_ctx = {
1024
  "segments": segments, "seg_clip_paths": seg_clip_paths,
1025
- }
1026
 
1027
  # ── GPU inference only ──
1028
  results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
@@ -1090,7 +1109,7 @@ def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1090
 
1091
  model_dict, cfg = _load_hunyuan_model(device, model_size)
1092
 
1093
- ctx = _tl.hunyuan_gen_ctx
1094
  segments = ctx["segments"]
1095
  total_dur_s = ctx["total_dur_s"]
1096
  dummy_seg_path = ctx["dummy_seg_path"]
@@ -1174,10 +1193,10 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
1174
  for i, (s, e) in enumerate(segments)
1175
  ]
1176
 
1177
- _tl.hunyuan_gen_ctx = {
1178
  "segments": segments, "total_dur_s": total_dur_s,
1179
  "dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
1180
- }
1181
 
1182
  # ── GPU inference only ──
1183
  results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
@@ -1213,7 +1232,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
1213
 
1214
  def _preload_taro_regen_ctx(meta: dict) -> dict:
1215
  """Pre-load TARO CAVP/onset features on CPU for regen.
1216
- Returns a dict for _tl.taro_regen_ctx (thread-local storage)."""
1217
  cavp_path = meta.get("cavp_path", "")
1218
  onset_path = meta.get("onset_path", "")
1219
  ctx = {}
@@ -1225,7 +1244,7 @@ def _preload_taro_regen_ctx(meta: dict) -> dict:
1225
 
1226
  def _preload_hunyuan_regen_ctx(meta: dict, seg_path: str) -> dict:
1227
  """Pre-load HunyuanFoley text features + segment path on CPU for regen.
1228
- Returns a dict for _tl.hunyuan_regen_ctx (thread-local storage)."""
1229
  ctx = {"seg_path": seg_path}
1230
  text_feats_path = meta.get("text_feats_path", "")
1231
  if text_feats_path and os.path.exists(text_feats_path):
@@ -1316,7 +1335,7 @@ def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
1316
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
1317
 
1318
  # Use pre-loaded features from CPU wrapper (avoids np.load inside GPU window)
1319
- ctx = getattr(_tl, "taro_regen_ctx", {})
1320
  if "cavp" in ctx and "onset" in ctx:
1321
  print("[TARO regen] Using pre-loaded CAVP + onset features (CPU cache hit)")
1322
  cavp_feats = ctx["cavp"]
@@ -1354,7 +1373,7 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
1354
  seg_idx = int(seg_idx)
1355
 
1356
  # CPU: pre-load cached features so np.load doesn't happen inside GPU window
1357
- _tl.taro_regen_ctx = _preload_taro_regen_ctx(meta)
1358
 
1359
  # GPU: inference only
1360
  new_wav = _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
@@ -1396,7 +1415,7 @@ def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1396
  sr = seq_cfg.sampling_rate
1397
 
1398
  # Use pre-extracted segment clip from the CPU wrapper
1399
- seg_path = getattr(_tl, "mmaudio_regen_ctx", {}).get("seg_path")
1400
  assert seg_path, "[MMAudio regen] seg_path not set — wrapper must pre-extract segment clip"
1401
 
1402
  rng = torch.Generator(device=device)
@@ -1438,7 +1457,7 @@ def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
1438
  meta["silent_video"], seg_start, seg_dur,
1439
  os.path.join(tmp_dir, "regen_seg.mp4"),
1440
  )
1441
- _tl.mmaudio_regen_ctx = {"seg_path": seg_path}
1442
 
1443
  # GPU: inference only
1444
  new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
@@ -1488,7 +1507,7 @@ def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1488
  set_global_seed(random.randint(0, 2**32 - 1))
1489
 
1490
  # Use pre-extracted segment clip + text_feats from CPU wrapper
1491
- ctx = getattr(_tl, "hunyuan_regen_ctx", {})
1492
  seg_path = ctx.get("seg_path")
1493
  assert seg_path, "[HunyuanFoley regen] seg_path not set — wrapper must pre-extract segment clip"
1494
 
@@ -1531,7 +1550,7 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1531
  meta["silent_video"], seg_start, seg_dur,
1532
  os.path.join(tmp_dir, "regen_seg.mp4"),
1533
  )
1534
- _tl.hunyuan_regen_ctx = _preload_hunyuan_regen_ctx(meta, seg_path)
1535
 
1536
  # GPU: inference only
1537
  new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
@@ -1631,7 +1650,7 @@ def xregen_taro(seg_idx, state_json, slot_id,
1631
  meta = json.loads(state_json)
1632
 
1633
  def _run():
1634
- _tl.taro_regen_ctx = _preload_taro_regen_ctx(meta)
1635
  wav = _regen_taro_gpu(None, seg_idx, state_json,
1636
  seed_val, cfg_scale, num_steps, mode,
1637
  crossfade_s, crossfade_db, slot_id)
@@ -1654,7 +1673,7 @@ def xregen_mmaudio(seg_idx, state_json, slot_id,
1654
  meta["silent_video"], seg_start, seg_end - seg_start,
1655
  os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1656
  )
1657
- _tl.mmaudio_regen_ctx = {"seg_path": seg_path}
1658
  wav, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
1659
  prompt, negative_prompt, seed_val,
1660
  cfg_strength, num_steps,
@@ -1679,7 +1698,7 @@ def xregen_hunyuan(seg_idx, state_json, slot_id,
1679
  meta["silent_video"], seg_start, seg_end - seg_start,
1680
  os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1681
  )
1682
- _tl.hunyuan_regen_ctx = _preload_hunyuan_regen_ctx(meta, seg_path)
1683
  wav, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
1684
  prompt, negative_prompt, seed_val,
1685
  guidance_scale, num_steps, model_size,
 
124
  # SHARED CONSTANTS / HELPERS #
125
  # ================================================================== #
126
 
127
+ # Per-caller-thread context store for CPU β†’ GPU context passing.
128
  # Replaces the fragile function-attribute pattern (_fn._cpu_ctx = {...}).
129
+ #
130
+ # WHY NOT threading.local():
131
+ # ZeroGPU dispatches @spaces.GPU functions on its OWN worker thread, not
132
+ # the Gradio handler thread. threading.local() values are invisible across
133
+ # threads, so the GPU worker would always see an empty namespace.
134
+ #
135
+ # SOLUTION β€” a plain dict keyed by (caller_thread_id, context_name):
136
+ # The wrapper writes _CTX[(tid, key)] = value before calling the GPU fn.
137
+ # The GPU fn reads _CTX.get((tid, key)) β€” same tid because ZeroGPU runs
138
+ # the function synchronously on behalf of the calling thread (the caller
139
+ # blocks until the GPU task completes, so there is no concurrent write).
140
+ # Entries are deleted after the GPU fn reads them to avoid memory leaks.
141
+ _CTX: dict = {}
142
+ _CTX_LOCK = threading.Lock()
143
+
144
+ def _ctx_set(key: str, value) -> None:
145
+ tid = threading.get_ident()
146
+ with _CTX_LOCK:
147
+ _CTX[(tid, key)] = value
148
+
149
+ def _ctx_get(key: str, default=None):
150
+ tid = threading.get_ident()
151
+ with _CTX_LOCK:
152
+ return _CTX.pop((tid, key), default)
153
 
154
  MAX_SLOTS = 8 # max parallel generation slots shown in UI
155
  MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video ≤ ~64 s at 8 s/seg)
 
807
  from TARO.onset_util import extract_onset
808
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
809
 
810
+ # Use pre-computed CPU results passed via cross-thread context store
811
+ ctx = _ctx_get("taro_gen_ctx")
812
  tmp_dir = ctx["tmp_dir"]
813
  silent_video = ctx["silent_video"]
814
  segments = ctx["segments"]
 
880
  tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
881
  video_file, TARO_MODEL_DUR, crossfade_s)
882
 
883
+ # Pass pre-computed CPU results to the GPU function via cross-thread context store
884
+ _ctx_set("taro_gen_ctx", {
885
  "tmp_dir": tmp_dir, "silent_video": silent_video,
886
  "segments": segments, "total_dur_s": total_dur_s,
887
+ })
888
 
889
  # ── GPU inference only ──
890
  results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
 
960
 
961
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
962
 
963
+ ctx = _ctx_get("mmaudio_gen_ctx")
964
  segments = ctx["segments"]
965
  seg_clip_paths = ctx["seg_clip_paths"]
966
 
 
1039
  for i, (s, e) in enumerate(segments)
1040
  ]
1041
 
1042
+ _ctx_set("mmaudio_gen_ctx", {
1043
  "segments": segments, "seg_clip_paths": seg_clip_paths,
1044
+ })
1045
 
1046
  # ── GPU inference only ──
1047
  results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
 
1109
 
1110
  model_dict, cfg = _load_hunyuan_model(device, model_size)
1111
 
1112
+ ctx = _ctx_get("hunyuan_gen_ctx")
1113
  segments = ctx["segments"]
1114
  total_dur_s = ctx["total_dur_s"]
1115
  dummy_seg_path = ctx["dummy_seg_path"]
 
1193
  for i, (s, e) in enumerate(segments)
1194
  ]
1195
 
1196
+ _ctx_set("hunyuan_gen_ctx", {
1197
  "segments": segments, "total_dur_s": total_dur_s,
1198
  "dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
1199
+ })
1200
 
1201
  # ── GPU inference only ──
1202
  results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
 
1232
 
1233
  def _preload_taro_regen_ctx(meta: dict) -> dict:
1234
  """Pre-load TARO CAVP/onset features on CPU for regen.
1235
+ Returns a dict for _ctx_set("taro_regen_ctx", ...)."""
1236
  cavp_path = meta.get("cavp_path", "")
1237
  onset_path = meta.get("onset_path", "")
1238
  ctx = {}
 
1244
 
1245
  def _preload_hunyuan_regen_ctx(meta: dict, seg_path: str) -> dict:
1246
  """Pre-load HunyuanFoley text features + segment path on CPU for regen.
1247
+ Returns a dict for _ctx_set("hunyuan_regen_ctx", ...)."""
1248
  ctx = {"seg_path": seg_path}
1249
  text_feats_path = meta.get("text_feats_path", "")
1250
  if text_feats_path and os.path.exists(text_feats_path):
 
1335
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
1336
 
1337
  # Use pre-loaded features from CPU wrapper (avoids np.load inside GPU window)
1338
+ ctx = _ctx_get("taro_regen_ctx", {})
1339
  if "cavp" in ctx and "onset" in ctx:
1340
  print("[TARO regen] Using pre-loaded CAVP + onset features (CPU cache hit)")
1341
  cavp_feats = ctx["cavp"]
 
1373
  seg_idx = int(seg_idx)
1374
 
1375
  # CPU: pre-load cached features so np.load doesn't happen inside GPU window
1376
+ _ctx_set("taro_regen_ctx", _preload_taro_regen_ctx(meta))
1377
 
1378
  # GPU: inference only
1379
  new_wav = _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
 
1415
  sr = seq_cfg.sampling_rate
1416
 
1417
  # Use pre-extracted segment clip from the CPU wrapper
1418
+ seg_path = _ctx_get("mmaudio_regen_ctx", {}).get("seg_path")
1419
  assert seg_path, "[MMAudio regen] seg_path not set — wrapper must pre-extract segment clip"
1420
 
1421
  rng = torch.Generator(device=device)
 
1457
  meta["silent_video"], seg_start, seg_dur,
1458
  os.path.join(tmp_dir, "regen_seg.mp4"),
1459
  )
1460
+ _ctx_set("mmaudio_regen_ctx", {"seg_path": seg_path})
1461
 
1462
  # GPU: inference only
1463
  new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
 
1507
  set_global_seed(random.randint(0, 2**32 - 1))
1508
 
1509
  # Use pre-extracted segment clip + text_feats from CPU wrapper
1510
+ ctx = _ctx_get("hunyuan_regen_ctx", {})
1511
  seg_path = ctx.get("seg_path")
1512
  assert seg_path, "[HunyuanFoley regen] seg_path not set — wrapper must pre-extract segment clip"
1513
 
 
1550
  meta["silent_video"], seg_start, seg_dur,
1551
  os.path.join(tmp_dir, "regen_seg.mp4"),
1552
  )
1553
+ _ctx_set("hunyuan_regen_ctx", _preload_hunyuan_regen_ctx(meta, seg_path))
1554
 
1555
  # GPU: inference only
1556
  new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
 
1650
  meta = json.loads(state_json)
1651
 
1652
  def _run():
1653
+ _ctx_set("taro_regen_ctx", _preload_taro_regen_ctx(meta))
1654
  wav = _regen_taro_gpu(None, seg_idx, state_json,
1655
  seed_val, cfg_scale, num_steps, mode,
1656
  crossfade_s, crossfade_db, slot_id)
 
1673
  meta["silent_video"], seg_start, seg_end - seg_start,
1674
  os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1675
  )
1676
+ _ctx_set("mmaudio_regen_ctx", {"seg_path": seg_path})
1677
  wav, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
1678
  prompt, negative_prompt, seed_val,
1679
  cfg_strength, num_steps,
 
1698
  meta["silent_video"], seg_start, seg_end - seg_start,
1699
  os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1700
  )
1701
+ _ctx_set("hunyuan_regen_ctx", _preload_hunyuan_regen_ctx(meta, seg_path))
1702
  wav, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
1703
  prompt, negative_prompt, seed_val,
1704
  guidance_scale, num_steps, model_size,