BoxOfColors Claude Sonnet 4.6 committed on
Commit
121a071
·
1 Parent(s): b60f330

fix: cross-model regen — fix stereo/mono mismatch and add pending spinner

Browse files

Two bugs fixed:

1. ValueError 'arrays must have same number of dimensions':
TARO outputs mono (T,) while MMAudio/Hunyuan output stereo (C, T).
_resample_to_slot_sr now takes a slot_wav_ref and matches channel
layout after resampling — stereo→mono averages channels, mono→stereo
duplicates the channel — so _cf_join always receives matching shapes.

2. No loading indicator on cross-model regen buttons:
xregen_* functions were plain return functions; the pending waveform
spinner only appeared via the Python yield in same-model regen. All
three xregen_* are now generators that yield the pending HTML
immediately before calling the GPU, matching the same-model behaviour.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +56 -22
app.py CHANGED
@@ -1324,35 +1324,59 @@ MODEL_CONFIGS["hunyuan"]["regen_fn"] = regen_hunyuan_segment
1324
  # (44.1 kHz) / Hunyuan (48 kHz) outputs can all be mixed freely. #
1325
  # ================================================================== #
1326
 
1327
- def _resample_to_slot_sr(wav: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
1328
- """Resample *wav* from src_sr to dst_sr using torchaudio.
1329
- Works for mono (T,) and stereo (C, T) numpy arrays."""
1330
- if src_sr == dst_sr:
1331
- return wav
1332
- stereo = wav.ndim == 2
1333
- t = torch.from_numpy(np.ascontiguousarray(wav))
1334
- if not stereo:
1335
- t = t.unsqueeze(0) # (1, T)
1336
- t = torchaudio.functional.resample(t.float(), src_sr, dst_sr)
1337
- if not stereo:
1338
- t = t.squeeze(0) # (T,)
1339
- return t.numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1340
 
1341
 
1342
  def xregen_taro(seg_idx, state_json, slot_id,
1343
  seed_val, cfg_scale, num_steps, mode,
1344
  crossfade_s, crossfade_db):
1345
  """Cross-model regen: run TARO inference and splice into *slot_id*."""
1346
- meta = json.loads(state_json)
1347
- slot_sr = int(meta["sr"])
 
 
 
 
 
 
1348
  new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
1349
  seed_val, cfg_scale, num_steps, mode,
1350
  crossfade_s, crossfade_db, slot_id)
1351
- new_wav = _resample_to_slot_sr(new_wav_raw, TARO_SR, slot_sr)
 
1352
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1353
- new_wav, int(seg_idx), meta, slot_id
1354
  )
1355
- return gr.update(value=video_path), gr.update(value=waveform_html)
1356
 
1357
 
1358
  def xregen_mmaudio(seg_idx, state_json, slot_id,
@@ -1365,6 +1389,10 @@ def xregen_mmaudio(seg_idx, state_json, slot_id,
1365
  seg_dur = seg_end - seg_start
1366
  slot_sr = int(meta["sr"])
1367
 
 
 
 
 
1368
  silent_video = meta["silent_video"]
1369
  tmp_dir = tempfile.mkdtemp()
1370
  seg_path = os.path.join(tmp_dir, "xregen_seg.mp4")
@@ -1377,11 +1405,12 @@ def xregen_mmaudio(seg_idx, state_json, slot_id,
1377
  prompt, negative_prompt, seed_val,
1378
  cfg_strength, num_steps,
1379
  crossfade_s, crossfade_db, slot_id)
1380
- new_wav = _resample_to_slot_sr(new_wav_raw, src_sr, slot_sr)
 
1381
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1382
  new_wav, seg_idx, meta, slot_id
1383
  )
1384
- return gr.update(value=video_path), gr.update(value=waveform_html)
1385
 
1386
 
1387
  def xregen_hunyuan(seg_idx, state_json, slot_id,
@@ -1395,6 +1424,10 @@ def xregen_hunyuan(seg_idx, state_json, slot_id,
1395
  seg_dur = seg_end - seg_start
1396
  slot_sr = int(meta["sr"])
1397
 
 
 
 
 
1398
  silent_video = meta["silent_video"]
1399
  tmp_dir = tempfile.mkdtemp()
1400
  seg_path = os.path.join(tmp_dir, "xregen_seg.mp4")
@@ -1407,11 +1440,12 @@ def xregen_hunyuan(seg_idx, state_json, slot_id,
1407
  prompt, negative_prompt, seed_val,
1408
  guidance_scale, num_steps, model_size,
1409
  crossfade_s, crossfade_db, slot_id)
1410
- new_wav = _resample_to_slot_sr(new_wav_raw, src_sr, slot_sr)
 
1411
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1412
  new_wav, seg_idx, meta, slot_id
1413
  )
1414
- return gr.update(value=video_path), gr.update(value=waveform_html)
1415
 
1416
 
1417
  # ================================================================== #
 
1324
  # (44.1 kHz) / Hunyuan (48 kHz) outputs can all be mixed freely. #
1325
  # ================================================================== #
1326
 
1327
+ def _resample_to_slot_sr(wav: np.ndarray, src_sr: int, dst_sr: int,
1328
+ slot_wav_ref: np.ndarray = None) -> np.ndarray:
1329
+ """Resample *wav* from src_sr to dst_sr using torchaudio, then match
1330
+ channel layout to *slot_wav_ref* (the first existing segment in the slot).
1331
+
1332
+ TARO is mono (T,), MMAudio/Hunyuan are stereo (C, T). Mixing them
1333
+ without normalisation causes a shape mismatch in _cf_join. Rules:
1334
+ stereo → mono : average channels
1335
+ mono → stereo: duplicate the single channel
1336
+ """
1337
+ # 1. Resample
1338
+ if src_sr != dst_sr:
1339
+ stereo_in = wav.ndim == 2
1340
+ t = torch.from_numpy(np.ascontiguousarray(wav))
1341
+ if not stereo_in:
1342
+ t = t.unsqueeze(0)
1343
+ t = torchaudio.functional.resample(t.float(), src_sr, dst_sr)
1344
+ if not stereo_in:
1345
+ t = t.squeeze(0)
1346
+ wav = t.numpy()
1347
+
1348
+ # 2. Match channel layout to the slot's existing segments
1349
+ if slot_wav_ref is not None:
1350
+ slot_stereo = slot_wav_ref.ndim == 2
1351
+ wav_stereo = wav.ndim == 2
1352
+ if slot_stereo and not wav_stereo:
1353
+ wav = np.stack([wav, wav], axis=0) # mono → stereo (C, T)
1354
+ elif not slot_stereo and wav_stereo:
1355
+ wav = wav.mean(axis=0) # stereo → mono (T,)
1356
+ return wav
1357
 
1358
 
1359
def xregen_taro(seg_idx, state_json, slot_id,
                seed_val, cfg_scale, num_steps, mode,
                crossfade_s, crossfade_db):
    """Cross-model regen: run TARO inference and splice into *slot_id*.

    Generator: the first yield pushes the pending-spinner waveform HTML
    to the UI before the GPU call starts; the second yields the spliced
    video and the final waveform HTML.
    """
    meta = json.loads(state_json)
    slot_sr = int(meta["sr"])
    seg_idx = int(seg_idx)

    # Surface the loading indicator immediately, ahead of any GPU work.
    spinner_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
    yield gr.update(), gr.update(value=spinner_html)

    raw_wav = _regen_taro_gpu(None, seg_idx, state_json,
                              seed_val, cfg_scale, num_steps, mode,
                              crossfade_s, crossfade_db, slot_id)
    # Normalise sample rate and channel layout against the slot's
    # existing audio so _cf_join receives matching shapes.
    existing_wavs = _load_seg_wavs(meta["wav_paths"])
    matched_wav = _resample_to_slot_sr(raw_wav, TARO_SR, slot_sr, existing_wavs[0])
    video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
        matched_wav, seg_idx, meta, slot_id
    )
    yield gr.update(value=video_path), gr.update(value=waveform_html)
1380
 
1381
 
1382
  def xregen_mmaudio(seg_idx, state_json, slot_id,
 
1389
  seg_dur = seg_end - seg_start
1390
  slot_sr = int(meta["sr"])
1391
 
1392
+ # Show pending waveform immediately
1393
+ pending_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
1394
+ yield gr.update(), gr.update(value=pending_html)
1395
+
1396
  silent_video = meta["silent_video"]
1397
  tmp_dir = tempfile.mkdtemp()
1398
  seg_path = os.path.join(tmp_dir, "xregen_seg.mp4")
 
1405
  prompt, negative_prompt, seed_val,
1406
  cfg_strength, num_steps,
1407
  crossfade_s, crossfade_db, slot_id)
1408
+ slot_wavs = _load_seg_wavs(meta["wav_paths"])
1409
+ new_wav = _resample_to_slot_sr(new_wav_raw, src_sr, slot_sr, slot_wavs[0])
1410
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1411
  new_wav, seg_idx, meta, slot_id
1412
  )
1413
+ yield gr.update(value=video_path), gr.update(value=waveform_html)
1414
 
1415
 
1416
  def xregen_hunyuan(seg_idx, state_json, slot_id,
 
1424
  seg_dur = seg_end - seg_start
1425
  slot_sr = int(meta["sr"])
1426
 
1427
+ # Show pending waveform immediately
1428
+ pending_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
1429
+ yield gr.update(), gr.update(value=pending_html)
1430
+
1431
  silent_video = meta["silent_video"]
1432
  tmp_dir = tempfile.mkdtemp()
1433
  seg_path = os.path.join(tmp_dir, "xregen_seg.mp4")
 
1440
  prompt, negative_prompt, seed_val,
1441
  guidance_scale, num_steps, model_size,
1442
  crossfade_s, crossfade_db, slot_id)
1443
+ slot_wavs = _load_seg_wavs(meta["wav_paths"])
1444
+ new_wav = _resample_to_slot_sr(new_wav_raw, src_sr, slot_sr, slot_wavs[0])
1445
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1446
  new_wav, seg_idx, meta, slot_id
1447
  )
1448
+ yield gr.update(value=video_path), gr.update(value=waveform_html)
1449
 
1450
 
1451
  # ================================================================== #