BoxOfColors Claude Sonnet 4.6 committed on
Commit
ac67bf3
·
1 Parent(s): 4c173d1

fix: use UUID-keyed global dict + ctx_key param for ZeroGPU context passing

Browse files

ctx_json arg was being set to None by ZeroGPU because the duration callable
signature didn't include it — ZeroGPU validates/forwards args against the
duration fn signature and silently drops any extra params not present there.

Fix: add ctx_key='' to ALL duration callables (not just GPU fns) so the
param survives ZeroGPU's arg-forwarding pipeline. Use a UUID-keyed global
dict (_GPU_CTX) instead of JSON-encoding the full context — the UUID is a
tiny hex string that round-trips safely through any arg marshalling, and the
global dict is readable from any thread (unlike threading.local).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +63 -35
app.py CHANGED
@@ -124,10 +124,35 @@ print(f"[startup] All downloads done in {time.perf_counter() - _t_dl_start:.1f}s
124
  # SHARED CONSTANTS / HELPERS #
125
  # ================================================================== #
126
 
127
- # CPU β†’ GPU context passing: each wrapper serialises pre-computed CPU data
128
- # into a JSON string and passes it as the last argument (ctx_json) to the
129
- # @spaces.GPU function. ZeroGPU forwards all arguments to the GPU worker
130
- # unchanged, so no shared state or thread-local tricks are needed.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  MAX_SLOTS = 8 # max parallel generation slots shown in UI
133
  MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video ≀ ~64 s at 8 s/seg)
@@ -552,7 +577,7 @@ def _taro_calc_max_samples(total_dur_s: float, num_steps: int, crossfade_s: floa
552
 
553
 
554
  def _taro_duration(video_file, seed_val, cfg_scale, num_steps, mode,
555
- crossfade_s, crossfade_db, num_samples):
556
  """Pre-GPU callable β€” must match _taro_gpu_infer's input order exactly."""
557
  return _estimate_gpu_duration("taro", int(num_samples), int(num_steps),
558
  video_file=video_file, crossfade_s=crossfade_s)
@@ -769,7 +794,7 @@ def _cpu_preprocess(video_file: str, model_dur: float,
769
 
770
  @spaces.GPU(duration=_taro_duration)
771
  def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
772
- crossfade_s, crossfade_db, num_samples, ctx_json="{}"):
773
  """GPU-only TARO inference β€” model loading + feature extraction + diffusion.
774
  Returns list of (wavs_list, onset_feats) per sample."""
775
  seed_val = int(seed_val)
@@ -785,7 +810,7 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
785
  from TARO.onset_util import extract_onset
786
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
787
 
788
- ctx = json.loads(ctx_json)
789
  tmp_dir = ctx["tmp_dir"]
790
  silent_video = ctx["silent_video"]
791
  segments = ctx["segments"]
@@ -857,14 +882,14 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
857
  tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
858
  video_file, TARO_MODEL_DUR, crossfade_s)
859
 
860
- ctx_json = json.dumps({
861
  "tmp_dir": tmp_dir, "silent_video": silent_video,
862
  "segments": segments, "total_dur_s": total_dur_s,
863
  })
864
 
865
  # ── GPU inference only ──
866
  results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
867
- crossfade_s, crossfade_db, num_samples, ctx_json)
868
 
869
  # ── CPU post-processing (no GPU needed) ──
870
  # Upsample 16kHz β†’ 48kHz and normalise result tuples to (seg_wavs, ...)
@@ -913,7 +938,8 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
913
 
914
 
915
  def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
916
- cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
 
917
  """Pre-GPU callable β€” must match _mmaudio_gpu_infer's input order exactly."""
918
  return _estimate_gpu_duration("mmaudio", int(num_samples), int(num_steps),
919
  video_file=video_file, crossfade_s=crossfade_s)
@@ -922,7 +948,7 @@ def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
922
  @spaces.GPU(duration=_mmaudio_duration)
923
  def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
924
  cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples,
925
- ctx_json="{}"):
926
  """GPU-only MMAudio inference β€” model loading + flow-matching generation.
927
  Returns list of (seg_audios, sr) per sample."""
928
  _ensure_syspath("MMAudio")
@@ -937,7 +963,7 @@ def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
937
 
938
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
939
 
940
- ctx = json.loads(ctx_json)
941
  segments = ctx["segments"]
942
  seg_clip_paths = ctx["seg_clip_paths"]
943
 
@@ -1016,12 +1042,12 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
1016
  for i, (s, e) in enumerate(segments)
1017
  ]
1018
 
1019
- ctx_json = json.dumps({"segments": segments, "seg_clip_paths": seg_clip_paths})
1020
 
1021
  # ── GPU inference only ──
1022
  results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1023
  cfg_strength, num_steps, crossfade_s, crossfade_db,
1024
- num_samples, ctx_json)
1025
 
1026
  # ── CPU post-processing ──
1027
  # Resample 44100 β†’ 48000 and normalise tuples to (seg_wavs, ...)
@@ -1058,7 +1084,8 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
1058
 
1059
 
1060
  def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
1061
- guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
 
1062
  """Pre-GPU callable β€” must match _hunyuan_gpu_infer's input order exactly."""
1063
  return _estimate_gpu_duration("hunyuan", int(num_samples), int(num_steps),
1064
  video_file=video_file, crossfade_s=crossfade_s)
@@ -1067,7 +1094,7 @@ def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
1067
  @spaces.GPU(duration=_hunyuan_duration)
1068
  def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1069
  guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
1070
- num_samples, ctx_json="{}"):
1071
  """GPU-only HunyuanFoley inference β€” model loading + feature extraction + denoising.
1072
  Returns list of (seg_wavs, sr, text_feats) per sample."""
1073
  _ensure_syspath("HunyuanVideo-Foley")
@@ -1086,7 +1113,7 @@ def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1086
 
1087
  model_dict, cfg = _load_hunyuan_model(device, model_size)
1088
 
1089
- ctx = json.loads(ctx_json)
1090
  segments = ctx["segments"]
1091
  total_dur_s = ctx["total_dur_s"]
1092
  dummy_seg_path = ctx["dummy_seg_path"]
@@ -1170,7 +1197,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
1170
  for i, (s, e) in enumerate(segments)
1171
  ]
1172
 
1173
- ctx_json = json.dumps({
1174
  "segments": segments, "total_dur_s": total_dur_s,
1175
  "dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
1176
  })
@@ -1178,7 +1205,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
1178
  # ── GPU inference only ──
1179
  results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1180
  guidance_scale, num_steps, model_size,
1181
- crossfade_s, crossfade_db, num_samples, ctx_json)
1182
 
1183
  # ── CPU post-processing (no GPU needed) ──
1184
  def _hunyuan_extras(sample_idx, result, td):
@@ -1258,7 +1285,7 @@ def _splice_and_save(new_wav, seg_idx, meta, slot_id):
1258
 
1259
  def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
1260
  seed_val, cfg_scale, num_steps, mode,
1261
- crossfade_s, crossfade_db, slot_id=None):
1262
  # If cached CAVP/onset features exist, skip ~10s feature-extractor overhead
1263
  try:
1264
  meta = json.loads(seg_meta_json)
@@ -1278,7 +1305,7 @@ def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
1278
  @spaces.GPU(duration=_taro_regen_duration)
1279
  def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
1280
  seed_val, cfg_scale, num_steps, mode,
1281
- crossfade_s, crossfade_db, slot_id=None):
1282
  """GPU-only TARO regen β€” returns new_wav for a single segment."""
1283
  meta = json.loads(seg_meta_json)
1284
  seg_idx = int(seg_idx)
@@ -1344,7 +1371,8 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
1344
 
1345
  def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
1346
  prompt, negative_prompt, seed_val,
1347
- cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id=None):
 
1348
  return _estimate_regen_duration("mmaudio", int(num_steps))
1349
 
1350
 
@@ -1352,7 +1380,7 @@ def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
1352
  def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1353
  prompt, negative_prompt, seed_val,
1354
  cfg_strength, num_steps, crossfade_s, crossfade_db,
1355
- slot_id=None, ctx_json="{}"):
1356
  """GPU-only MMAudio regen β€” returns (new_wav, sr) for a single segment."""
1357
  meta = json.loads(seg_meta_json)
1358
  seg_idx = int(seg_idx)
@@ -1368,7 +1396,7 @@ def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1368
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
1369
  sr = seq_cfg.sampling_rate
1370
 
1371
- seg_path = json.loads(ctx_json).get("seg_path")
1372
  assert seg_path, "[MMAudio regen] seg_path not set β€” wrapper must pre-extract segment clip"
1373
 
1374
  rng = torch.Generator(device=device)
@@ -1410,13 +1438,13 @@ def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
1410
  meta["silent_video"], seg_start, seg_dur,
1411
  os.path.join(tmp_dir, "regen_seg.mp4"),
1412
  )
1413
- ctx_json = json.dumps({"seg_path": seg_path})
1414
 
1415
  # GPU: inference only
1416
  new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1417
  prompt, negative_prompt, seed_val,
1418
  cfg_strength, num_steps, crossfade_s, crossfade_db,
1419
- slot_id, ctx_json)
1420
 
1421
  # Resample to 48kHz if needed (MMAudio outputs at 44100 Hz)
1422
  if sr != TARGET_SR:
@@ -1435,7 +1463,7 @@ def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
1435
  def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
1436
  prompt, negative_prompt, seed_val,
1437
  guidance_scale, num_steps, model_size,
1438
- crossfade_s, crossfade_db, slot_id=None):
1439
  return _estimate_regen_duration("hunyuan", int(num_steps))
1440
 
1441
 
@@ -1443,7 +1471,7 @@ def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
1443
  def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1444
  prompt, negative_prompt, seed_val,
1445
  guidance_scale, num_steps, model_size,
1446
- crossfade_s, crossfade_db, slot_id=None, ctx_json="{}"):
1447
  """GPU-only HunyuanFoley regen β€” returns (new_wav, sr) for a single segment."""
1448
  meta = json.loads(seg_meta_json)
1449
  seg_idx = int(seg_idx)
@@ -1460,7 +1488,7 @@ def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1460
 
1461
  set_global_seed(random.randint(0, 2**32 - 1))
1462
 
1463
- ctx = json.loads(ctx_json)
1464
  seg_path = ctx.get("seg_path")
1465
  assert seg_path, "[HunyuanFoley regen] seg_path not set β€” wrapper must pre-extract segment clip"
1466
 
@@ -1504,7 +1532,7 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1504
  meta["silent_video"], seg_start, seg_dur,
1505
  os.path.join(tmp_dir, "regen_seg.mp4"),
1506
  )
1507
- ctx_json = json.dumps({
1508
  "seg_path": seg_path,
1509
  "text_feats_path": meta.get("text_feats_path", ""),
1510
  })
@@ -1513,7 +1541,7 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1513
  new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1514
  prompt, negative_prompt, seed_val,
1515
  guidance_scale, num_steps, model_size,
1516
- crossfade_s, crossfade_db, slot_id, ctx_json)
1517
 
1518
  meta["sr"] = sr
1519
 
@@ -1630,11 +1658,11 @@ def xregen_mmaudio(seg_idx, state_json, slot_id,
1630
  meta["silent_video"], seg_start, seg_end - seg_start,
1631
  os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1632
  )
1633
- ctx_json = json.dumps({"seg_path": seg_path})
1634
  wav, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
1635
  prompt, negative_prompt, seed_val,
1636
  cfg_strength, num_steps,
1637
- crossfade_s, crossfade_db, slot_id, ctx_json)
1638
  return wav, src_sr
1639
 
1640
  yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
@@ -1655,14 +1683,14 @@ def xregen_hunyuan(seg_idx, state_json, slot_id,
1655
  meta["silent_video"], seg_start, seg_end - seg_start,
1656
  os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1657
  )
1658
- ctx_json = json.dumps({
1659
  "seg_path": seg_path,
1660
  "text_feats_path": meta.get("text_feats_path", ""),
1661
  })
1662
  wav, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
1663
  prompt, negative_prompt, seed_val,
1664
  guidance_scale, num_steps, model_size,
1665
- crossfade_s, crossfade_db, slot_id, ctx_json)
1666
  return wav, src_sr
1667
 
1668
  yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
 
124
  # SHARED CONSTANTS / HELPERS #
125
  # ================================================================== #
126
 
127
+ # CPU β†’ GPU context passing via UUID-keyed global store.
128
+ #
129
+ # ZeroGPU dispatches @spaces.GPU functions on its own worker thread, so
130
+ # threading.local() doesn't work. Passing context as a function argument
131
+ # is the right idea, but ZeroGPU validates args against the *duration*
132
+ # callable's signature β€” any extra param not present in the duration fn
133
+ # gets dropped or set to None before the GPU fn runs.
134
+ #
135
+ # Solution: add ctx_key="" to BOTH the duration fn AND the GPU fn.
136
+ # The wrapper stores the context dict in _GPU_CTX[uuid] and passes the
137
+ # uuid string as ctx_key. The GPU fn does _GPU_CTX.pop(ctx_key).
138
+ # Since the dict is global (not thread-local), the GPU worker thread can
139
+ # read it regardless of which thread wrote it. The uuid ensures
140
+ # concurrent requests don't collide.
141
+ import uuid as _uuid_mod
142
+ _GPU_CTX: dict = {}
143
+ _GPU_CTX_LOCK = threading.Lock()
144
+
145
+ def _ctx_store(data: dict) -> str:
146
+ """Store *data* in the global context dict; return the UUID key."""
147
+ key = _uuid_mod.uuid4().hex
148
+ with _GPU_CTX_LOCK:
149
+ _GPU_CTX[key] = data
150
+ return key
151
+
152
+ def _ctx_load(key: str) -> dict:
153
+ """Pop and return the context dict for *key*."""
154
+ with _GPU_CTX_LOCK:
155
+ return _GPU_CTX.pop(key, {})
156
 
157
  MAX_SLOTS = 8 # max parallel generation slots shown in UI
158
  MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video ≀ ~64 s at 8 s/seg)
 
577
 
578
 
579
  def _taro_duration(video_file, seed_val, cfg_scale, num_steps, mode,
580
+ crossfade_s, crossfade_db, num_samples, ctx_key=""):
581
  """Pre-GPU callable β€” must match _taro_gpu_infer's input order exactly."""
582
  return _estimate_gpu_duration("taro", int(num_samples), int(num_steps),
583
  video_file=video_file, crossfade_s=crossfade_s)
 
794
 
795
  @spaces.GPU(duration=_taro_duration)
796
  def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
797
+ crossfade_s, crossfade_db, num_samples, ctx_key=""):
798
  """GPU-only TARO inference β€” model loading + feature extraction + diffusion.
799
  Returns list of (wavs_list, onset_feats) per sample."""
800
  seed_val = int(seed_val)
 
810
  from TARO.onset_util import extract_onset
811
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
812
 
813
+ ctx = _ctx_load(ctx_key)
814
  tmp_dir = ctx["tmp_dir"]
815
  silent_video = ctx["silent_video"]
816
  segments = ctx["segments"]
 
882
  tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
883
  video_file, TARO_MODEL_DUR, crossfade_s)
884
 
885
+ ctx_key = _ctx_store({
886
  "tmp_dir": tmp_dir, "silent_video": silent_video,
887
  "segments": segments, "total_dur_s": total_dur_s,
888
  })
889
 
890
  # ── GPU inference only ──
891
  results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
892
+ crossfade_s, crossfade_db, num_samples, ctx_key)
893
 
894
  # ── CPU post-processing (no GPU needed) ──
895
  # Upsample 16kHz β†’ 48kHz and normalise result tuples to (seg_wavs, ...)
 
938
 
939
 
940
  def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
941
+ cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples,
942
+ ctx_key=""):
943
  """Pre-GPU callable β€” must match _mmaudio_gpu_infer's input order exactly."""
944
  return _estimate_gpu_duration("mmaudio", int(num_samples), int(num_steps),
945
  video_file=video_file, crossfade_s=crossfade_s)
 
948
  @spaces.GPU(duration=_mmaudio_duration)
949
  def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
950
  cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples,
951
+ ctx_key=""):
952
  """GPU-only MMAudio inference β€” model loading + flow-matching generation.
953
  Returns list of (seg_audios, sr) per sample."""
954
  _ensure_syspath("MMAudio")
 
963
 
964
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
965
 
966
+ ctx = _ctx_load(ctx_key)
967
  segments = ctx["segments"]
968
  seg_clip_paths = ctx["seg_clip_paths"]
969
 
 
1042
  for i, (s, e) in enumerate(segments)
1043
  ]
1044
 
1045
+ ctx_key = _ctx_store({"segments": segments, "seg_clip_paths": seg_clip_paths})
1046
 
1047
  # ── GPU inference only ──
1048
  results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1049
  cfg_strength, num_steps, crossfade_s, crossfade_db,
1050
+ num_samples, ctx_key)
1051
 
1052
  # ── CPU post-processing ──
1053
  # Resample 44100 β†’ 48000 and normalise tuples to (seg_wavs, ...)
 
1084
 
1085
 
1086
  def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
1087
+ guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
1088
+ num_samples, ctx_key=""):
1089
  """Pre-GPU callable β€” must match _hunyuan_gpu_infer's input order exactly."""
1090
  return _estimate_gpu_duration("hunyuan", int(num_samples), int(num_steps),
1091
  video_file=video_file, crossfade_s=crossfade_s)
 
1094
  @spaces.GPU(duration=_hunyuan_duration)
1095
  def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1096
  guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
1097
+ num_samples, ctx_key=""):
1098
  """GPU-only HunyuanFoley inference β€” model loading + feature extraction + denoising.
1099
  Returns list of (seg_wavs, sr, text_feats) per sample."""
1100
  _ensure_syspath("HunyuanVideo-Foley")
 
1113
 
1114
  model_dict, cfg = _load_hunyuan_model(device, model_size)
1115
 
1116
+ ctx = _ctx_load(ctx_key)
1117
  segments = ctx["segments"]
1118
  total_dur_s = ctx["total_dur_s"]
1119
  dummy_seg_path = ctx["dummy_seg_path"]
 
1197
  for i, (s, e) in enumerate(segments)
1198
  ]
1199
 
1200
+ ctx_key = _ctx_store({
1201
  "segments": segments, "total_dur_s": total_dur_s,
1202
  "dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
1203
  })
 
1205
  # ── GPU inference only ──
1206
  results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1207
  guidance_scale, num_steps, model_size,
1208
+ crossfade_s, crossfade_db, num_samples, ctx_key)
1209
 
1210
  # ── CPU post-processing (no GPU needed) ──
1211
  def _hunyuan_extras(sample_idx, result, td):
 
1285
 
1286
  def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
1287
  seed_val, cfg_scale, num_steps, mode,
1288
+ crossfade_s, crossfade_db, slot_id=None, ctx_key=""):
1289
  # If cached CAVP/onset features exist, skip ~10s feature-extractor overhead
1290
  try:
1291
  meta = json.loads(seg_meta_json)
 
1305
  @spaces.GPU(duration=_taro_regen_duration)
1306
  def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
1307
  seed_val, cfg_scale, num_steps, mode,
1308
+ crossfade_s, crossfade_db, slot_id=None, ctx_key=""):
1309
  """GPU-only TARO regen β€” returns new_wav for a single segment."""
1310
  meta = json.loads(seg_meta_json)
1311
  seg_idx = int(seg_idx)
 
1371
 
1372
  def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
1373
  prompt, negative_prompt, seed_val,
1374
+ cfg_strength, num_steps, crossfade_s, crossfade_db,
1375
+ slot_id=None, ctx_key=""):
1376
  return _estimate_regen_duration("mmaudio", int(num_steps))
1377
 
1378
 
 
1380
  def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1381
  prompt, negative_prompt, seed_val,
1382
  cfg_strength, num_steps, crossfade_s, crossfade_db,
1383
+ slot_id=None, ctx_key=""):
1384
  """GPU-only MMAudio regen β€” returns (new_wav, sr) for a single segment."""
1385
  meta = json.loads(seg_meta_json)
1386
  seg_idx = int(seg_idx)
 
1396
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
1397
  sr = seq_cfg.sampling_rate
1398
 
1399
+ seg_path = _ctx_load(ctx_key).get("seg_path")
1400
  assert seg_path, "[MMAudio regen] seg_path not set β€” wrapper must pre-extract segment clip"
1401
 
1402
  rng = torch.Generator(device=device)
 
1438
  meta["silent_video"], seg_start, seg_dur,
1439
  os.path.join(tmp_dir, "regen_seg.mp4"),
1440
  )
1441
+ ctx_key = _ctx_store({"seg_path": seg_path})
1442
 
1443
  # GPU: inference only
1444
  new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1445
  prompt, negative_prompt, seed_val,
1446
  cfg_strength, num_steps, crossfade_s, crossfade_db,
1447
+ slot_id, ctx_key)
1448
 
1449
  # Resample to 48kHz if needed (MMAudio outputs at 44100 Hz)
1450
  if sr != TARGET_SR:
 
1463
  def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
1464
  prompt, negative_prompt, seed_val,
1465
  guidance_scale, num_steps, model_size,
1466
+ crossfade_s, crossfade_db, slot_id=None, ctx_key=""):
1467
  return _estimate_regen_duration("hunyuan", int(num_steps))
1468
 
1469
 
 
1471
  def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1472
  prompt, negative_prompt, seed_val,
1473
  guidance_scale, num_steps, model_size,
1474
+ crossfade_s, crossfade_db, slot_id=None, ctx_key=""):
1475
  """GPU-only HunyuanFoley regen β€” returns (new_wav, sr) for a single segment."""
1476
  meta = json.loads(seg_meta_json)
1477
  seg_idx = int(seg_idx)
 
1488
 
1489
  set_global_seed(random.randint(0, 2**32 - 1))
1490
 
1491
+ ctx = _ctx_load(ctx_key)
1492
  seg_path = ctx.get("seg_path")
1493
  assert seg_path, "[HunyuanFoley regen] seg_path not set β€” wrapper must pre-extract segment clip"
1494
 
 
1532
  meta["silent_video"], seg_start, seg_dur,
1533
  os.path.join(tmp_dir, "regen_seg.mp4"),
1534
  )
1535
+ ctx_key = _ctx_store({
1536
  "seg_path": seg_path,
1537
  "text_feats_path": meta.get("text_feats_path", ""),
1538
  })
 
1541
  new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1542
  prompt, negative_prompt, seed_val,
1543
  guidance_scale, num_steps, model_size,
1544
+ crossfade_s, crossfade_db, slot_id, ctx_key)
1545
 
1546
  meta["sr"] = sr
1547
 
 
1658
  meta["silent_video"], seg_start, seg_end - seg_start,
1659
  os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1660
  )
1661
+ ctx_key = _ctx_store({"seg_path": seg_path})
1662
  wav, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
1663
  prompt, negative_prompt, seed_val,
1664
  cfg_strength, num_steps,
1665
+ crossfade_s, crossfade_db, slot_id, ctx_key)
1666
  return wav, src_sr
1667
 
1668
  yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
 
1683
  meta["silent_video"], seg_start, seg_end - seg_start,
1684
  os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1685
  )
1686
+ ctx_key = _ctx_store({
1687
  "seg_path": seg_path,
1688
  "text_feats_path": meta.get("text_feats_path", ""),
1689
  })
1690
  wav, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
1691
  prompt, negative_prompt, seed_val,
1692
  guidance_scale, num_steps, model_size,
1693
+ crossfade_s, crossfade_db, slot_id, ctx_key)
1694
  return wav, src_sr
1695
 
1696
  yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)