Jack Wu committed on
Commit
6cf4573
·
1 Parent(s): 53f384c

Restructure app.py: multi-model support (TARO, MMAudio, HunyuanFoley)

Browse files

- Split into three tabbed UI sections, one per model
- Updated checkpoint repo/folder paths to JackIsNotInTheBox/Generate_Audio_for_Video_Checkpoints
with TARO/, MMAudio/, HunyuanFoley/ subfolders
- TARO: preserve exact infer.py invocation (CAVP+VideoOnsetNet+MMDiT+AudioLDM2 decoder)
with sliding-window segmentation; fix samplers tuple return indexing
- MMAudio: use official load_video()+generate() pipeline from gradio_demo.py;
override model paths to our HF checkpoint repo; large_44k_v2 variant
- HunyuanFoley: use official load_model()+feature_process()+denoise_process()
pipeline; batch inference for multiple samples; xl/xxl size selector
- Add model-specific optimal duration constants from source configs
- Shared slot UI helper; MAX_SLOTS=8 maintained across all tabs

TARO/README.md ADDED
File without changes
{cavp β†’ TARO/cavp}/cavp.yaml RENAMED
File without changes
{cavp β†’ TARO/cavp}/model/cavp_model.py RENAMED
File without changes
{cavp β†’ TARO/cavp}/model/cavp_modules.py RENAMED
File without changes
cavp_util.py β†’ TARO/cavp_util.py RENAMED
File without changes
dataset.py β†’ TARO/dataset.py RENAMED
File without changes
infer.py β†’ TARO/infer.py RENAMED
File without changes
loss.py β†’ TARO/loss.py RENAMED
File without changes
models.py β†’ TARO/models.py RENAMED
File without changes
onset_util.py β†’ TARO/onset_util.py RENAMED
File without changes
{preprocess β†’ TARO/preprocess}/extract_cavp.py RENAMED
File without changes
{preprocess β†’ TARO/preprocess}/extract_fbank.py RENAMED
File without changes
{preprocess β†’ TARO/preprocess}/extract_mel.py RENAMED
File without changes
{preprocess β†’ TARO/preprocess}/extract_onset.py RENAMED
File without changes
samplers.py β†’ TARO/samplers.py RENAMED
File without changes
train.py β†’ TARO/train.py RENAMED
@@ -18,10 +18,10 @@ from accelerate import Accelerator
18
  from accelerate.logging import get_logger
19
  from accelerate.utils import ProjectConfiguration, set_seed
20
 
21
- from models import MMDiT
22
- from loss import SILoss
23
 
24
- from dataset import audio_video_spec_fullset_Dataset_Train, collate_fn_taro
25
  from diffusers import AudioLDM2Pipeline
26
  import wandb
27
 
 
18
  from accelerate.logging import get_logger
19
  from accelerate.utils import ProjectConfiguration, set_seed
20
 
21
+ from TARO.models import MMDiT
22
+ from TARO.loss import SILoss
23
 
24
+ from TARO.dataset import audio_video_spec_fullset_Dataset_Train, collate_fn_taro
25
  from diffusers import AudioLDM2Pipeline
26
  import wandb
27
 
train.sh β†’ TARO/train.sh RENAMED
File without changes
app.py CHANGED
@@ -1,8 +1,24 @@
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import subprocess
3
  import sys
4
- from math import ceil, floor
 
 
 
5
 
 
 
 
6
  try:
7
  import mmcv
8
  print("mmcv already installed")
@@ -13,133 +29,169 @@ except ImportError:
13
 
14
  import torch
15
  import numpy as np
16
- import random
17
  import soundfile as sf
18
  import ffmpeg
19
- import tempfile
20
  import spaces
21
  import gradio as gr
22
  from huggingface_hub import hf_hub_download
23
 
24
- REPO_ID = "JackIsNotInTheBox/Taro_checkpoints"
25
- CACHE_DIR = "/tmp/taro_ckpts"
26
- os.makedirs(CACHE_DIR, exist_ok=True)
27
-
28
- print("Downloading checkpoints...")
29
- cavp_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="cavp_epoch66.ckpt", cache_dir=CACHE_DIR)
30
- onset_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="onset_model.ckpt", cache_dir=CACHE_DIR)
31
- taro_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="taro_ckpt.pt", cache_dir=CACHE_DIR)
32
- print("Checkpoints downloaded.")
33
-
34
- # Model constants
35
- SR = 16000
36
- TRUNCATE = 131072
37
- FPS = 4
38
- TRUNCATE_FRAME = int(FPS * TRUNCATE / SR) # 32 cavp frames per model window
39
- TRUNCATE_ONSET = 120 # onset frames per model window
40
- MODEL_DUR = TRUNCATE / SR # 8.192 s
41
- MAX_SLOTS = 8 # max sample output slots in UI
42
- SECS_PER_STEP = 2.5 # estimated seconds of GPU time per diffusion step
43
-
44
- # ------------------------------------------------------------------ #
45
- # Inference cache #
46
- # Key: (video_path, seed, cfg_scale, num_steps, mode, crossfade_s) #
47
- # Value: {"wavs": [...], "total_dur_s": float, #
48
- # "tmp_dir": str, "silent_video": str} #
49
- # ------------------------------------------------------------------ #
50
- _INFERENCE_CACHE = {}
51
 
 
 
 
52
 
53
- def set_global_seed(seed):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  np.random.seed(seed % (2**32))
55
  random.seed(seed)
56
  torch.manual_seed(seed)
57
  torch.cuda.manual_seed(seed)
58
  torch.backends.cudnn.deterministic = True
59
 
 
 
60
 
61
- def strip_audio_from_video(video_path, output_path):
62
- """Strip any existing audio from a video file, outputting a silent video."""
63
- (
64
- ffmpeg
65
- .input(video_path)
66
- .output(output_path, vcodec="libx264", an=None)
67
- .run(overwrite_output=True, quiet=True)
68
- )
69
-
70
-
71
- def get_video_duration(video_path):
72
- """Read video duration in seconds using ffprobe (no GPU needed)."""
73
  probe = ffmpeg.probe(video_path)
74
  return float(probe["format"]["duration"])
75
 
 
 
 
 
 
76
 
77
- def build_segments(total_dur_s, crossfade_s):
78
- """
79
- Build list of (seg_start_s, seg_end_s) segment windows.
80
-
81
- For videos <= MODEL_DUR: single segment [0, total_dur_s].
82
- For longer videos: advance by step_s = MODEL_DUR - crossfade_s each time.
83
- The LAST segment is always anchored at [total_dur_s - MODEL_DUR, total_dur_s]
84
- so it is a full-length window with no zero-padding, giving the best quality
85
- at the tail end of the video.
86
- """
87
- if total_dur_s <= MODEL_DUR:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  return [(0.0, total_dur_s)]
89
-
90
- step_s = MODEL_DUR - crossfade_s
91
- segments = []
92
- seg_start = 0.0
93
  while True:
94
- seg_end = seg_start + MODEL_DUR
95
- if seg_end >= total_dur_s:
96
- # Replace this segment with a full-length tail-anchored window
97
- seg_start = max(0.0, total_dur_s - MODEL_DUR)
98
  segments.append((seg_start, total_dur_s))
99
  break
100
- segments.append((seg_start, seg_start + MODEL_DUR))
101
  seg_start += step_s
102
-
103
  return segments
104
 
105
 
106
- def calc_max_samples(total_dur_s, num_steps, crossfade_s):
107
- """Estimate max samples that fit within the 600s ZeroGPU budget."""
108
- num_segments = len(build_segments(total_dur_s, crossfade_s))
109
- time_per_seg = num_steps * SECS_PER_STEP
110
- budget = 600.0
111
- max_s = floor(budget / (num_segments * time_per_seg))
112
  return max(1, min(max_s, MAX_SLOTS))
113
 
114
 
115
- def infer_segment(model, vae, vocoder, cavp_feats_full, onset_feats_full,
116
- seg_start_s, seg_end_s,
117
- device, weight_dtype,
118
- cfg_scale, num_steps, mode,
119
- latents_scale,
120
- euler_sampler, euler_maruyama_sampler):
121
- """Run one model inference pass. Returns wav trimmed to segment duration."""
 
 
 
122
  # CAVP features (4 fps)
123
- cavp_start = int(round(seg_start_s * FPS))
124
- cavp_slice = cavp_feats_full[cavp_start : cavp_start + TRUNCATE_FRAME]
125
- if cavp_slice.shape[0] < TRUNCATE_FRAME:
126
  pad = np.zeros(
127
- (TRUNCATE_FRAME - cavp_slice.shape[0],) + cavp_slice.shape[1:],
128
  dtype=cavp_slice.dtype,
129
  )
130
  cavp_slice = np.concatenate([cavp_slice, pad], axis=0)
131
- video_feats = torch.from_numpy(cavp_slice).unsqueeze(0).to(device).to(weight_dtype)
132
 
133
- # Onset features
134
- onset_fps = TRUNCATE_ONSET / MODEL_DUR
135
  onset_start = int(round(seg_start_s * onset_fps))
136
- onset_slice = onset_feats_full[onset_start : onset_start + TRUNCATE_ONSET]
137
- if onset_slice.shape[0] < TRUNCATE_ONSET:
138
- pad_len = TRUNCATE_ONSET - onset_slice.shape[0]
139
- onset_slice = np.pad(onset_slice, ((0, pad_len),), mode="constant", constant_values=0)
140
- onset_feats_t = torch.from_numpy(onset_slice).unsqueeze(0).to(device).to(weight_dtype)
 
 
 
 
 
 
141
 
142
- z = torch.randn(1, model.in_channels, 204, 16, device=device).to(weight_dtype)
143
  sampling_kwargs = dict(
144
  model=model,
145
  latents=z,
@@ -153,169 +205,119 @@ def infer_segment(model, vae, vocoder, cavp_feats_full, onset_feats_full,
153
  path_type="linear",
154
  )
155
  with torch.no_grad():
156
- if mode == "sde":
157
- samples = euler_maruyama_sampler(**sampling_kwargs)
158
- else:
159
- samples = euler_sampler(**sampling_kwargs)
160
 
 
161
  samples = vae.decode(samples / latents_scale).sample
162
  wav = vocoder(samples.squeeze().float()).detach().cpu().numpy()
163
- seg_samples = int(round((seg_end_s - seg_start_s) * SR))
164
  return wav[:seg_samples]
165
 
166
 
167
- def crossfade_join(wav_a, wav_b, crossfade_s, db_boost):
168
- """
169
- Join two wav arrays with a crossfade.
170
- Both signals are scaled by gain = 10^(db_boost/20) in the overlap region
171
- and summed, producing a +db_boost bump at the midpoint.
172
- """
173
- cf_samples = int(round(crossfade_s * SR))
174
- cf_samples = min(cf_samples, len(wav_a), len(wav_b))
175
- if cf_samples <= 0:
176
  return np.concatenate([wav_a, wav_b])
177
-
178
  gain = 10 ** (db_boost / 20.0)
179
- overlap = wav_a[-cf_samples:] * gain + wav_b[:cf_samples] * gain
180
-
181
- return np.concatenate([wav_a[:-cf_samples], overlap, wav_b[cf_samples:]])
182
-
183
-
184
- def stitch_wavs(wavs, crossfade_s, db_boost, total_dur_s):
185
- """Stitch segment wavs with crossfades and clip to total_dur_s."""
186
- if len(wavs) == 1:
187
- final_wav = wavs[0]
188
- else:
189
- final_wav = wavs[0]
190
- for nw in wavs[1:]:
191
- final_wav = crossfade_join(final_wav, nw, crossfade_s, db_boost)
192
- return final_wav[:int(round(total_dur_s * SR))]
193
-
194
-
195
- def mux_video_audio(silent_video, audio_path, output_path):
196
- input_v = ffmpeg.input(silent_video)
197
- input_a = ffmpeg.input(audio_path)
198
- (
199
- ffmpeg
200
- .output(input_v, input_a, output_path,
201
- vcodec="libx264", acodec="aac", strict="experimental")
202
- .run(overwrite_output=True, quiet=True)
203
- )
204
 
205
 
206
- # ------------------------------------------------------------------ #
207
- # UI helpers (no GPU) #
208
- # ------------------------------------------------------------------ #
209
-
210
- def on_video_upload(video_file, num_steps, crossfade_s):
211
- """Called when video is uploaded or sliders change. Updates samples slider."""
212
- if video_file is None:
213
- return gr.update(maximum=MAX_SLOTS, value=1)
214
- try:
215
- D = get_video_duration(video_file)
216
- max_s = calc_max_samples(D, int(num_steps), float(crossfade_s))
217
- except Exception:
218
- max_s = MAX_SLOTS
219
- return gr.update(maximum=max_s, value=min(1, max_s))
220
 
221
 
222
- def get_random_seed():
223
- return random.randint(0, 2**32 - 1)
224
-
225
-
226
- # ------------------------------------------------------------------ #
227
- # Main inference #
228
- # ------------------------------------------------------------------ #
229
-
230
  @spaces.GPU(duration=600)
231
- def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode,
232
- crossfade_s, crossfade_db, num_samples):
233
- global _INFERENCE_CACHE
 
234
 
235
  seed_val = int(seed_val)
236
  crossfade_s = float(crossfade_s)
237
  crossfade_db = float(crossfade_db)
238
  num_samples = int(num_samples)
239
-
240
  if seed_val < 0:
241
  seed_val = random.randint(0, 2**32 - 1)
242
 
243
- # Load models once (shared across all samples this call)
244
  torch.set_grad_enabled(False)
245
  device = "cuda" if torch.cuda.is_available() else "cpu"
246
  weight_dtype = torch.bfloat16
247
 
248
- from cavp_util import Extract_CAVP_Features
249
- from onset_util import VideoOnsetNet, extract_onset
250
- from models import MMDiT
251
- from samplers import euler_sampler, euler_maruyama_sampler
252
- from diffusers import AudioLDM2Pipeline
 
253
 
 
254
  extract_cavp = Extract_CAVP_Features(
255
- device=device, config_path="./cavp/cavp.yaml", ckpt_path=cavp_ckpt_path
 
 
256
  )
257
 
258
- state_dict = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
259
- new_state_dict = {}
260
- for key, value in state_dict.items():
261
- if "model.net.model" in key:
262
- new_key = key.replace("model.net.model", "net.model")
263
- elif "model.fc." in key:
264
- new_key = key.replace("model.fc", "fc")
265
- else:
266
- new_key = key
267
- new_state_dict[new_key] = value
268
- onset_model = VideoOnsetNet(False).to(device)
269
- onset_model.load_state_dict(new_state_dict)
270
  onset_model.eval()
271
 
 
 
 
272
  model = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
273
- ckpt = torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"]
274
- model.load_state_dict(ckpt)
275
- model.eval()
276
- model.to(weight_dtype)
277
-
278
- model_audioldm = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
279
- vae = model_audioldm.vae.to(device)
280
- vae.eval()
281
- vocoder = model_audioldm.vocoder.to(device)
282
-
283
  latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
284
 
285
- # Prepare silent video (shared across all samples)
286
  tmp_dir = tempfile.mkdtemp()
287
  silent_video = os.path.join(tmp_dir, "silent_input.mp4")
288
  strip_audio_from_video(video_file, silent_video)
289
 
290
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
291
- total_frames = cavp_feats.shape[0]
292
- total_dur_s = total_frames / FPS
293
- segments = build_segments(total_dur_s, crossfade_s)
294
-
295
- # ------------------------------------------------------------------ #
296
- # Generate N samples #
297
- # ------------------------------------------------------------------ #
298
- outputs = [] # list of (video_path, audio_path)
299
 
 
300
  for sample_idx in range(num_samples):
301
  sample_seed = seed_val + sample_idx
302
- cache_key = (video_file, sample_seed, float(cfg_scale),
303
- int(num_steps), mode, crossfade_s)
304
 
305
- if cache_key in _INFERENCE_CACHE:
306
- print(f"Sample {sample_idx+1}: cache hit, re-stitching.")
307
- cached = _INFERENCE_CACHE[cache_key]
308
- wavs = cached["wavs"]
309
  else:
310
  set_global_seed(sample_seed)
311
- onset_feats = extract_onset(
312
- silent_video, onset_model, tmp_path=tmp_dir, device=device
313
- )
314
-
315
  wavs = []
316
  for seg_start_s, seg_end_s in segments:
317
- print(f" Sample {sample_idx+1} | segment {seg_start_s:.2f}s – {seg_end_s:.2f}s")
318
- wav = infer_segment(
319
  model, vae, vocoder,
320
  cavp_feats, onset_feats,
321
  seg_start_s, seg_end_s,
@@ -325,117 +327,421 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode,
325
  euler_sampler, euler_maruyama_sampler,
326
  )
327
  wavs.append(wav)
 
328
 
329
- _INFERENCE_CACHE[cache_key] = {"wavs": wavs}
 
 
 
 
 
330
 
331
- # Stitch
332
- final_wav = stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s)
333
 
334
- audio_path = os.path.join(tmp_dir, f"output_{sample_idx}.wav")
335
- sf.write(audio_path, final_wav, SR)
336
 
337
- video_path = os.path.join(tmp_dir, f"output_{sample_idx}.mp4")
338
- mux_video_audio(silent_video, audio_path, video_path)
 
 
 
 
 
 
 
 
 
339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  outputs.append((video_path, audio_path))
341
 
342
- # ------------------------------------------------------------------ #
343
- # Return flat list of (video, audio) pairs padded with None #
344
- # so Gradio output list length is always MAX_SLOTS * 2 #
345
- # ------------------------------------------------------------------ #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
  result = []
347
  for i in range(MAX_SLOTS):
348
  if i < len(outputs):
349
- result.append(outputs[i][0]) # video
350
- result.append(outputs[i][1]) # audio
351
  else:
352
- result.append(None)
353
- result.append(None)
354
  return result
355
 
356
 
357
- # ------------------------------------------------------------------ #
358
- # Build gr.Blocks UI #
359
- # ------------------------------------------------------------------ #
 
 
 
 
 
 
 
360
 
361
- with gr.Blocks(title="TARO: Video-to-Audio Synthesis") as demo:
 
 
 
 
 
 
 
 
 
362
  gr.Markdown(
363
- "# TARO: Video-to-Audio Synthesis (ICCV 2025)\n"
364
- "Upload a video and generate synchronized audio. "
365
- "Optimal clip duration is 8.2s. Longer videos are automatically "
366
- "split into overlapping segments and stitched with a crossfade."
 
 
 
367
  )
368
 
369
- with gr.Row():
370
- with gr.Column():
371
- video_input = gr.Video(label="Input Video")
372
- seed_input = gr.Number(label="Seed", value=get_random_seed, precision=0)
373
- cfg_input = gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5)
374
- steps_input = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1)
375
- mode_input = gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde")
376
- cf_dur_input = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=8, value=2, step=0.1)
377
- cf_db_input = gr.Textbox(label="Crossfade Boost (dB)", value="3")
378
- samples_input = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS,
379
- value=1, step=1)
380
- run_btn = gr.Button("Generate", variant="primary")
381
-
382
- with gr.Column():
383
- # All MAX_SLOTS slots pre-built.
384
- # Slot 0 is always visible (shows loading progress during inference).
385
- # Slots 1-N become visible when user drags the Generations slider.
386
- slot_videos = []
387
- slot_audios = []
388
- slot_grps = []
389
- for i in range(MAX_SLOTS):
390
- with gr.Group(visible=(i == 0)) as grp:
391
- sv = gr.Video(label=f"Generation {i+1} β€” Video")
392
- sa = gr.Audio(label=f"Generation {i+1} β€” Audio")
393
- slot_grps.append(grp)
394
- slot_videos.append(sv)
395
- slot_audios.append(sa)
396
-
397
- # -------------------------------------------------------------- #
398
- # Events #
399
- # -------------------------------------------------------------- #
400
-
401
- # Update Generations slider max on video upload / steps / crossfade change
402
- def _update_samples_slider(video_file, num_steps, crossfade_s):
403
- return on_video_upload(video_file, num_steps, crossfade_s)
404
-
405
- for trigger in [video_input, steps_input, cf_dur_input]:
406
- trigger.change(
407
- fn=_update_samples_slider,
408
- inputs=[video_input, steps_input, cf_dur_input],
409
- outputs=[samples_input],
410
- )
 
 
411
 
412
- # Show/hide output slots instantly when Generations slider is dragged
413
- def _update_slot_visibility(num_samples):
414
- n = int(num_samples)
415
- return [gr.update(visible=(i < n)) for i in range(MAX_SLOTS)]
 
 
 
 
 
 
 
 
 
 
416
 
417
- samples_input.change(
418
- fn=_update_slot_visibility,
419
- inputs=[samples_input],
420
- outputs=slot_grps,
421
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422
 
423
- # Main generate: calls inference then populates slots
424
- def _generate_and_update(video_file, seed_val, cfg_scale, num_steps, mode,
425
- crossfade_s, crossfade_db, num_samples):
426
- flat = generate_audio(video_file, seed_val, cfg_scale, num_steps, mode,
427
- crossfade_s, crossfade_db, num_samples)
428
- n = int(num_samples)
429
- grp_updates = [gr.update(visible=(i < n)) for i in range(MAX_SLOTS)]
430
- video_updates = [gr.update(value=flat[i * 2]) for i in range(MAX_SLOTS)]
431
- audio_updates = [gr.update(value=flat[i * 2 + 1]) for i in range(MAX_SLOTS)]
432
- return grp_updates + video_updates + audio_updates
433
-
434
- run_btn.click(
435
- fn=_generate_and_update,
436
- inputs=[video_input, seed_input, cfg_input, steps_input, mode_input,
437
- cf_dur_input, cf_db_input, samples_input],
438
- outputs=slot_grps + slot_videos + slot_audios,
439
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
 
441
  demo.queue().launch()
 
1
+ """
2
+ Generate Audio for Video β€” multi-model Gradio app.
3
+
4
+ Supported models
5
+ ----------------
6
+ TARO – video-conditioned diffusion via CAVP + onset features (16 kHz, 8.192 s window)
7
+ MMAudio – multimodal flow-matching with CLIP/Synchformer + text prompt (44 kHz, 8 s window)
8
+ HunyuanFoley – text-guided foley via SigLIP2 + Synchformer + CLAP (48 kHz, up to 15 s)
9
+ """
10
+
11
  import os
12
  import subprocess
13
  import sys
14
+ import tempfile
15
+ import random
16
+ from math import floor
17
+ from pathlib import Path
18
 
19
+ # ------------------------------------------------------------------ #
20
+ # mmcv bootstrap (needed by TARO's CAVP encoder) #
21
+ # ------------------------------------------------------------------ #
22
  try:
23
  import mmcv
24
  print("mmcv already installed")
 
29
 
30
  import torch
31
  import numpy as np
 
32
  import soundfile as sf
33
  import ffmpeg
 
34
  import spaces
35
  import gradio as gr
36
  from huggingface_hub import hf_hub_download
37
 
38
+ # ================================================================== #
39
+ # CHECKPOINT CONFIGURATION #
40
+ # ================================================================== #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ CKPT_REPO_ID = "JackIsNotInTheBox/Generate_Audio_for_Video_Checkpoints"
43
+ CACHE_DIR = "/tmp/model_ckpts"
44
+ os.makedirs(CACHE_DIR, exist_ok=True)
45
 
46
+ # ---- TARO checkpoints (in TARO/ subfolder of the HF repo) ----
47
+ print("Downloading TARO checkpoints…")
48
+ cavp_ckpt_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/cavp_epoch66.ckpt", cache_dir=CACHE_DIR)
49
+ onset_ckpt_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/onset_model.ckpt", cache_dir=CACHE_DIR)
50
+ taro_ckpt_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/taro_ckpt.pt", cache_dir=CACHE_DIR)
51
+ print("TARO checkpoints downloaded.")
52
+
53
+ # ---- MMAudio checkpoints (in MMAudio/ subfolder) ----
54
+ # MMAudio normally auto-downloads from its own HF repo, but we
55
+ # override the paths so it pulls from our consolidated repo instead.
56
+ MMAUDIO_WEIGHTS_DIR = Path(CACHE_DIR) / "MMAudio" / "weights"
57
+ MMAUDIO_EXT_DIR = Path(CACHE_DIR) / "MMAudio" / "ext_weights"
58
+ MMAUDIO_WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
59
+ MMAUDIO_EXT_DIR.mkdir(parents=True, exist_ok=True)
60
+
61
+ print("Downloading MMAudio checkpoints…")
62
+ mmaudio_model_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/mmaudio_large_44k_v2.pth", cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_WEIGHTS_DIR), local_dir_use_symlinks=False)
63
+ mmaudio_vae_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/v1-44.pth", cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_EXT_DIR), local_dir_use_symlinks=False)
64
+ mmaudio_synchformer_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/synchformer_state_dict.pth", cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_EXT_DIR), local_dir_use_symlinks=False)
65
+ print("MMAudio checkpoints downloaded.")
66
+
67
+ # ---- HunyuanVideoFoley checkpoints (in HunyuanFoley/ subfolder) ----
68
+ HUNYUAN_MODEL_DIR = Path(CACHE_DIR) / "HunyuanFoley"
69
+ HUNYUAN_MODEL_DIR.mkdir(parents=True, exist_ok=True)
70
+
71
+ print("Downloading HunyuanVideoFoley checkpoints…")
72
+ hf_hub_download(repo_id=CKPT_REPO_ID, filename="HunyuanFoley/hunyuanvideo_foley.pth", cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR), local_dir_use_symlinks=False)
73
+ hf_hub_download(repo_id=CKPT_REPO_ID, filename="HunyuanFoley/vae_128d_48k.pth", cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR), local_dir_use_symlinks=False)
74
+ hf_hub_download(repo_id=CKPT_REPO_ID, filename="HunyuanFoley/synchformer_state_dict.pth", cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR), local_dir_use_symlinks=False)
75
+ print("HunyuanVideoFoley checkpoints downloaded.")
76
+
77
+ # ================================================================== #
78
+ # SHARED CONSTANTS / HELPERS #
79
+ # ================================================================== #
80
+
81
+ MAX_SLOTS = 8 # max parallel generation slots shown in UI
82
+
83
+ def set_global_seed(seed: int):
84
  np.random.seed(seed % (2**32))
85
  random.seed(seed)
86
  torch.manual_seed(seed)
87
  torch.cuda.manual_seed(seed)
88
  torch.backends.cudnn.deterministic = True
89
 
90
+ def get_random_seed() -> int:
91
+ return random.randint(0, 2**32 - 1)
92
 
93
+ def get_video_duration(video_path: str) -> float:
94
+ """Return video duration in seconds (CPU only)."""
 
 
 
 
 
 
 
 
 
 
95
  probe = ffmpeg.probe(video_path)
96
  return float(probe["format"]["duration"])
97
 
98
+ def strip_audio_from_video(video_path: str, output_path: str):
99
+ """Write a silent copy of *video_path* to *output_path*."""
100
+ ffmpeg.input(video_path).output(output_path, vcodec="libx264", an=None).run(
101
+ overwrite_output=True, quiet=True
102
+ )
103
 
104
+ def mux_video_audio(silent_video: str, audio_path: str, output_path: str):
105
+ """Mux a silent video with an audio file into *output_path*."""
106
+ ffmpeg.output(
107
+ ffmpeg.input(silent_video),
108
+ ffmpeg.input(audio_path),
109
+ output_path,
110
+ vcodec="libx264", acodec="aac", strict="experimental",
111
+ ).run(overwrite_output=True, quiet=True)
112
+
113
+
114
+ # ================================================================== #
115
+ # TARO #
116
+ # ================================================================== #
117
+ # Constants sourced from TARO/infer.py and TARO/models.py:
118
+ # SR=16000, TRUNCATE=131072 β†’ 8.192 s window
119
+ # TRUNCATE_FRAME = 4 fps Γ— 131072/16000 = 32 CAVP frames per window
120
+ # TRUNCATE_ONSET = 120 onset frames per window
121
+ # latent shape: (1, 8, 204, 16) β€” fixed by MMDiT architecture
122
+ # latents_scale: [0.18215]*8 β€” AudioLDM2 VAE scale factor
123
+ # ================================================================== #
124
+
125
+ TARO_SR = 16000
126
+ TARO_TRUNCATE = 131072
127
+ TARO_FPS = 4
128
+ TARO_TRUNCATE_FRAME = int(TARO_FPS * TARO_TRUNCATE / TARO_SR) # 32
129
+ TARO_TRUNCATE_ONSET = 120
130
+ TARO_MODEL_DUR = TARO_TRUNCATE / TARO_SR # 8.192 s
131
+ TARO_SECS_PER_STEP = 2.5 # estimated GPU-seconds per diffusion step
132
+
133
+ _TARO_INFERENCE_CACHE: dict = {}
134
+
135
+
136
+ def _taro_build_segments(total_dur_s: float, crossfade_s: float) -> list:
137
+ """Sliding-window segmentation for videos longer than one TARO window."""
138
+ if total_dur_s <= TARO_MODEL_DUR:
139
  return [(0.0, total_dur_s)]
140
+ step_s = TARO_MODEL_DUR - crossfade_s
141
+ segments, seg_start = [], 0.0
 
 
142
  while True:
143
+ if seg_start + TARO_MODEL_DUR >= total_dur_s:
144
+ seg_start = max(0.0, total_dur_s - TARO_MODEL_DUR)
 
 
145
  segments.append((seg_start, total_dur_s))
146
  break
147
+ segments.append((seg_start, seg_start + TARO_MODEL_DUR))
148
  seg_start += step_s
 
149
  return segments
150
 
151
 
152
+ def _taro_calc_max_samples(total_dur_s: float, num_steps: int, crossfade_s: float) -> int:
153
+ n_segs = len(_taro_build_segments(total_dur_s, crossfade_s))
154
+ time_per_seg = num_steps * TARO_SECS_PER_STEP
155
+ max_s = floor(600.0 / (n_segs * time_per_seg))
 
 
156
  return max(1, min(max_s, MAX_SLOTS))
157
 
158
 
159
+ def _taro_infer_segment(
160
+ model, vae, vocoder,
161
+ cavp_feats_full, onset_feats_full,
162
+ seg_start_s: float, seg_end_s: float,
163
+ device, weight_dtype,
164
+ cfg_scale: float, num_steps: int, mode: str,
165
+ latents_scale,
166
+ euler_sampler, euler_maruyama_sampler,
167
+ ) -> np.ndarray:
168
+ """Single-segment TARO inference. Returns wav array trimmed to segment length."""
169
  # CAVP features (4 fps)
170
+ cavp_start = int(round(seg_start_s * TARO_FPS))
171
+ cavp_slice = cavp_feats_full[cavp_start : cavp_start + TARO_TRUNCATE_FRAME]
172
+ if cavp_slice.shape[0] < TARO_TRUNCATE_FRAME:
173
  pad = np.zeros(
174
+ (TARO_TRUNCATE_FRAME - cavp_slice.shape[0],) + cavp_slice.shape[1:],
175
  dtype=cavp_slice.dtype,
176
  )
177
  cavp_slice = np.concatenate([cavp_slice, pad], axis=0)
178
+ video_feats = torch.from_numpy(cavp_slice).unsqueeze(0).to(device, weight_dtype)
179
 
180
+ # Onset features (onset_fps = TRUNCATE_ONSET / MODEL_DUR β‰ˆ 14.65 fps)
181
+ onset_fps = TARO_TRUNCATE_ONSET / TARO_MODEL_DUR
182
  onset_start = int(round(seg_start_s * onset_fps))
183
+ onset_slice = onset_feats_full[onset_start : onset_start + TARO_TRUNCATE_ONSET]
184
+ if onset_slice.shape[0] < TARO_TRUNCATE_ONSET:
185
+ onset_slice = np.pad(
186
+ onset_slice,
187
+ ((0, TARO_TRUNCATE_ONSET - onset_slice.shape[0]),),
188
+ mode="constant",
189
+ )
190
+ onset_feats_t = torch.from_numpy(onset_slice).unsqueeze(0).to(device, weight_dtype)
191
+
192
+ # Latent noise β€” shape matches MMDiT architecture (in_channels=8, 204Γ—16 spatial)
193
+ z = torch.randn(1, model.in_channels, 204, 16, device=device, dtype=weight_dtype)
194
 
 
195
  sampling_kwargs = dict(
196
  model=model,
197
  latents=z,
 
205
  path_type="linear",
206
  )
207
  with torch.no_grad():
208
+ samples = (euler_maruyama_sampler if mode == "sde" else euler_sampler)(**sampling_kwargs)
209
+ # samplers return (output_tensor, zs) β€” index [0] for the audio latent
210
+ if isinstance(samples, tuple):
211
+ samples = samples[0]
212
 
213
+ # Decode: AudioLDM2 VAE β†’ mel β†’ vocoder β†’ waveform
214
  samples = vae.decode(samples / latents_scale).sample
215
  wav = vocoder(samples.squeeze().float()).detach().cpu().numpy()
216
+ seg_samples = int(round((seg_end_s - seg_start_s) * TARO_SR))
217
  return wav[:seg_samples]
218
 
219
 
220
+ def _crossfade_join(wav_a: np.ndarray, wav_b: np.ndarray,
221
+ crossfade_s: float, db_boost: float) -> np.ndarray:
222
+ cf = int(round(crossfade_s * TARO_SR))
223
+ cf = min(cf, len(wav_a), len(wav_b))
224
+ if cf <= 0:
 
 
 
 
225
  return np.concatenate([wav_a, wav_b])
 
226
  gain = 10 ** (db_boost / 20.0)
227
+ overlap = wav_a[-cf:] * gain + wav_b[:cf] * gain
228
+ return np.concatenate([wav_a[:-cf], overlap, wav_b[cf:]])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
 
231
+ def _stitch_wavs(wavs: list, crossfade_s: float, db_boost: float, total_dur_s: float) -> np.ndarray:
232
+ out = wavs[0]
233
+ for nw in wavs[1:]:
234
+ out = _crossfade_join(out, nw, crossfade_s, db_boost)
235
+ return out[:int(round(total_dur_s * TARO_SR))]
 
 
 
 
 
 
 
 
 
236
 
237
 
 
 
 
 
 
 
 
 
238
@spaces.GPU(duration=600)
def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
                  crossfade_s, crossfade_db, num_samples):
    """TARO: video-conditioned diffusion, 16 kHz, 8.192 s sliding window.

    Loads all TARO components (CAVP encoder, onset net, MMDiT, AudioLDM2
    decoder) inside the GPU context, extracts features from the input video
    once, then generates ``num_samples`` audio tracks; each sample runs the
    diffusion sampler per sliding-window segment and the segments are
    crossfade-stitched. Per-sample results are memoized in
    ``_TARO_INFERENCE_CACHE`` keyed on (video, seed, cfg, steps, mode,
    crossfade). Returns a flat [video, audio] * MAX_SLOTS list via
    ``_pad_outputs``.
    """
    global _TARO_INFERENCE_CACHE

    # Normalize UI inputs (sliders/textboxes arrive as floats/strings).
    seed_val = int(seed_val)
    crossfade_s = float(crossfade_s)
    crossfade_db = float(crossfade_db)
    num_samples = int(num_samples)
    if seed_val < 0:
        # -1 in the UI means "pick a random base seed".
        seed_val = random.randint(0, 2**32 - 1)

    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    weight_dtype = torch.bfloat16

    # Imports are inside the GPU context so the Space only pays for GPU time here
    from TARO.cavp_util import Extract_CAVP_Features
    from TARO.onset_util import VideoOnsetNet, extract_onset
    from TARO.models import MMDiT
    from TARO.samplers import euler_sampler, euler_maruyama_sampler
    from diffusers import AudioLDM2Pipeline

    # -- Load CAVP encoder (uses checkpoint from our HF repo) --
    extract_cavp = Extract_CAVP_Features(
        device=device,
        config_path="TARO/cavp/cavp.yaml",
        ckpt_path=cavp_ckpt_path,
    )

    # -- Load onset detection model --
    # Key remapping matches the original TARO infer.py exactly
    raw_sd = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
    onset_sd = {}
    for k, v in raw_sd.items():
        if "model.net.model" in k:
            k = k.replace("model.net.model", "net.model")
        elif "model.fc." in k:
            k = k.replace("model.fc", "fc")
        onset_sd[k] = v
    onset_model = VideoOnsetNet(pretrained=False).to(device)
    onset_model.load_state_dict(onset_sd)
    onset_model.eval()

    # -- Load TARO MMDiT --
    # Architecture params match TARO/train.py: adm_in_channels=120 (onset dim),
    # z_dims=[768] (CAVP dim), encoder_depth=4
    model = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
    model.load_state_dict(torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"])
    model.eval().to(weight_dtype)

    # -- Load AudioLDM2 VAE + vocoder (decoder pipeline only) --
    # TARO uses AudioLDM2's VAE and vocoder for decoding; no encoder needed at inference
    audioldm2 = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
    vae = audioldm2.vae.to(device).eval()
    vocoder = audioldm2.vocoder.to(device)

    # NOTE(review): latents_scale is defined but not passed in the visible
    # call below -- presumably it is one of the arguments on the diff lines
    # elided from this hunk. Confirm against the full file.
    latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)

    # -- Prepare silent video (shared across all samples) --
    tmp_dir = tempfile.mkdtemp()
    silent_video = os.path.join(tmp_dir, "silent_input.mp4")
    strip_audio_from_video(video_file, silent_video)

    # CAVP features determine the clip length: one feature row per frame.
    cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
    total_dur_s = cavp_feats.shape[0] / TARO_FPS
    segments = _taro_build_segments(total_dur_s, crossfade_s)

    outputs = []
    for sample_idx in range(num_samples):
        # Each sample gets a deterministic seed offset from the base seed.
        sample_seed = seed_val + sample_idx
        cache_key = (video_file, sample_seed, float(cfg_scale), int(num_steps), mode, crossfade_s)

        if cache_key in _TARO_INFERENCE_CACHE:
            print(f"[TARO] Sample {sample_idx+1}: cache hit.")
            wavs = _TARO_INFERENCE_CACHE[cache_key]["wavs"]
        else:
            set_global_seed(sample_seed)
            onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)

            wavs = []
            for seg_start_s, seg_end_s in segments:
                print(f"[TARO] Sample {sample_idx+1} | {seg_start_s:.2f}s – {seg_end_s:.2f}s")
                wav = _taro_infer_segment(
                    model, vae, vocoder,
                    cavp_feats, onset_feats,
                    seg_start_s, seg_end_s,
                    # NOTE(review): three argument lines (likely seed/cfg/steps/
                    # mode and latents_scale) are collapsed out of this diff
                    # hunk -- verify the call site in the full file.
                    euler_sampler, euler_maruyama_sampler,
                )
                wavs.append(wav)
            # Cache raw segment wavs so re-stitching with a different dB boost
            # does not require re-running diffusion.
            _TARO_INFERENCE_CACHE[cache_key] = {"wavs": wavs}

        final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s)
        audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
        sf.write(audio_path, final_wav, TARO_SR)
        video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
        mux_video_audio(silent_video, audio_path, video_path)
        outputs.append((video_path, audio_path))

    return _pad_outputs(outputs)
 
340
 
 
 
341
 
342
+ # ================================================================== #
343
+ # MMAudio #
344
+ # ================================================================== #
345
+ # Constants sourced from MMAudio/mmaudio/model/sequence_config.py:
346
+ # CONFIG_44K: duration=8.0 s, sampling_rate=44100
347
+ # CLIP encoder: 8 fps, 384Γ—384 px
348
+ # Synchformer: 25 fps, 224Γ—224 px
349
+ # Default variant: large_44k_v2
350
+ # MMAudio uses flow-matching (FlowMatching with euler inference).
351
+ # generate() handles all feature extraction + decoding internally.
352
+ # ================================================================== #
353
 
354
@spaces.GPU(duration=600)
def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
                     cfg_strength, num_steps, duration, num_samples):
    """MMAudio: flow-matching video-to-audio, 44.1 kHz, 8 s window, text-guided.

    Args:
        video_file: path to the input video.
        prompt: text prompt steering the generation.
        negative_prompt: optional negative text prompt ("" disables it).
        seed_val: base seed; -1 makes every sample non-deterministic.
        cfg_strength: classifier-free-guidance strength.
        num_steps: Euler step count for the flow-matching sampler.
        duration: requested seconds of video to process.
        num_samples: number of generations to produce.

    Returns:
        Flat [video, audio] * MAX_SLOTS list via _pad_outputs().
    """
    # Heavy imports stay inside the GPU-decorated function so the Space only
    # pays for GPU time here.
    import torchaudio
    from mmaudio.eval_utils import all_model_cfg, generate, load_video, make_video
    from mmaudio.model.flow_matching import FlowMatching
    from mmaudio.model.networks import get_my_mmaudio
    from mmaudio.model.utils.features_utils import FeaturesUtils

    # Normalize UI inputs once, up front.
    seed_val = int(seed_val)
    num_samples = int(num_samples)
    num_steps = int(num_steps)
    duration = float(duration)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16

    # Use large_44k_v2 variant; override paths to our consolidated HF checkpoint repo
    model_cfg = all_model_cfg["large_44k_v2"]
    from pathlib import Path as _Path
    model_cfg.model_path = _Path(mmaudio_model_path)
    model_cfg.vae_path = _Path(mmaudio_vae_path)
    model_cfg.synchformer_ckpt = _Path(mmaudio_synchformer_path)
    # large_44k_v2 is 44k mode, no BigVGAN vocoder needed
    model_cfg.bigvgan_16k_path = None
    seq_cfg = model_cfg.seq_cfg  # CONFIG_44K: 8 s, 44100 Hz

    # Load network weights
    net = get_my_mmaudio(model_cfg.model_name).to(device, dtype).eval()
    net.load_weights(torch.load(model_cfg.model_path, map_location=device, weights_only=True))

    # Load feature utilities: CLIP (auto-downloaded from apple/DFN5B-CLIP-ViT-H-14-384),
    # Synchformer (from our repo), VAE (from our repo), no BigVGAN for 44k mode
    feature_utils = FeaturesUtils(
        tod_vae_ckpt=str(model_cfg.vae_path),
        synchformer_ckpt=str(model_cfg.synchformer_ckpt),
        enable_conditions=True,
        mode=model_cfg.mode,  # "44k"
        bigvgan_vocoder_ckpt=None,
        need_vae_encoder=False,
    ).to(device, dtype).eval()

    # FIX: video decoding and sequence-length setup are independent of the
    # sample index, so hoist them out of the per-sample loop -- the original
    # re-decoded the video for every sample.
    # load_video() resamples to 8 fps (CLIP) and 25 fps (Synchformer) on the fly
    video_info = load_video(video_file, duration)
    clip_frames = video_info.clip_frames.unsqueeze(0)  # (1, T_clip, C, H, W)
    sync_frames = video_info.sync_frames.unsqueeze(0)  # (1, T_sync, C, H, W)
    actual_dur = video_info.duration_sec

    seq_cfg.duration = actual_dur
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    tmp_dir = tempfile.mkdtemp()
    outputs = []

    for sample_idx in range(num_samples):
        # Per-sample RNG: deterministic offsets from the base seed, or fully
        # random when the user asked for seed == -1.
        rng = torch.Generator(device=device)
        if seed_val >= 0:
            rng.manual_seed(seed_val + sample_idx)
        else:
            rng.seed()

        # Kept inside the loop to match the original per-sample construction
        # (FlowMatching is cheap to build; avoids any cross-sample state).
        fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)

        print(f"[MMAudio] Sample {sample_idx+1} | duration={actual_dur:.2f}s | prompt='{prompt}'")

        audios = generate(
            clip_frames,
            sync_frames,
            [prompt],
            negative_text=[negative_prompt] if negative_prompt else None,
            feature_utils=feature_utils,
            net=net,
            fm=fm,
            rng=rng,
            cfg_strength=float(cfg_strength),
        )
        audio = audios.float().cpu()[0]  # (C, T)

        audio_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.flac")
        torchaudio.save(audio_path, audio, seq_cfg.sampling_rate)

        video_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.mp4")
        make_video(video_info, video_path, audio, sampling_rate=seq_cfg.sampling_rate)
        outputs.append((video_path, audio_path))

    return _pad_outputs(outputs)
441
+
442
+
443
+ # ================================================================== #
444
+ # HunyuanVideoFoley #
445
+ # ================================================================== #
446
+ # Constants sourced from HunyuanVideo-Foley/hunyuanvideo_foley/constants.py
447
+ # and configs/hunyuanvideo-foley-xxl.yaml:
448
+ # sample_rate = 48000 Hz (from DAC VAE)
449
+ # audio_frame_rate = 50 (latent fps, xxl config)
450
+ # max video duration = 15 s
451
+ # SigLIP2 fps = 8, Synchformer fps = 25
452
+ # CLAP text encoder: laion/larger_clap_general (auto-downloaded from HF Hub)
453
+ # Default guidance_scale=4.5, num_inference_steps=50
454
+ # ================================================================== #
455
+
456
# Hard cap from the HunyuanVideo-Foley source configs (max video duration).
# NOTE(review): defined but not referenced in the visible code -- confirm
# that the UI or generate_hunyuan actually enforces this limit.
HUNYUAN_MAX_DUR = 15.0  # seconds
457
+
458
@spaces.GPU(duration=600)
def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
                     guidance_scale, num_steps, model_size, num_samples):
    """HunyuanVideoFoley: text-guided foley, 48 kHz, up to 15 s.

    Runs the official load_model() / feature_process() / denoise_process()
    pipeline once, generating all requested samples as a single batch, then
    muxes each generated track back onto the input video.
    """
    import sys
    import torchaudio

    # Make the vendored HunyuanVideo-Foley checkout importable as a package.
    pkg_root = str(Path("HunyuanVideo-Foley").resolve())
    if pkg_root not in sys.path:
        sys.path.insert(0, pkg_root)

    from hunyuanvideo_foley.utils.feature_utils import feature_process
    from hunyuanvideo_foley.utils.media_utils import merge_audio_video
    from hunyuanvideo_foley.utils.model_utils import denoise_process, load_model

    seed_val = int(seed_val)
    num_samples = int(num_samples)
    if seed_val >= 0:
        # Negative seed means "leave RNG state alone" (random every run).
        set_global_seed(seed_val)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    size_key = model_size.lower()  # "xl" or "xxl"

    configs = {
        "xl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xl.yaml",
        "xxl": "HunyuanVideo-Foley/configs/hunyuanvideo-foley-xxl.yaml",
    }
    cfg_file = configs.get(size_key, configs["xxl"])

    print(f"[HunyuanFoley] Loading {size_key.upper()} model from {HUNYUAN_MODEL_DIR}")
    # load_model() handles: HunyuanVideoFoley main model, DAC-VAE, SigLIP2, CLAP, Synchformer
    # CLAP (laion/larger_clap_general) and SigLIP2 (google/siglip2-base-patch16-512) are
    # downloaded from HuggingFace Hub automatically by load_model().
    model_dict, cfg = load_model(
        str(HUNYUAN_MODEL_DIR),
        cfg_file,
        device,
        enable_offload=False,
        model_size=size_key,
    )

    work_dir = tempfile.mkdtemp()
    results = []

    # feature_process() extracts SigLIP2 visual features + Synchformer sync features
    # + CLAP text embeddings — exactly as in HunyuanVideo-Foley/gradio_app.py
    visual_feats, text_feats, audio_len_in_s = feature_process(
        video_file,
        prompt if prompt else "",
        model_dict,
        cfg,
        neg_prompt=negative_prompt if negative_prompt else None,
    )
    print(f"[HunyuanFoley] Audio length: {audio_len_in_s:.2f}s | generating {num_samples} sample(s)")

    # denoise_process() runs the flow-matching diffusion loop and decodes with
    # DAC-VAE; batch_size=num_samples produces every sample in one pass.
    audio, sample_rate = denoise_process(
        visual_feats,
        text_feats,
        audio_len_in_s,
        model_dict,
        cfg,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(num_steps),
        batch_size=num_samples,
    )

    # audio shape: (batch, channels, samples) — write each batch entry out.
    for idx in range(num_samples):
        wav_path = os.path.join(work_dir, f"hunyuan_{idx}.wav")
        torchaudio.save(wav_path, audio[idx], sample_rate)
        mp4_path = os.path.join(work_dir, f"hunyuan_{idx}.mp4")
        merge_audio_video(wav_path, video_file, mp4_path)
        results.append((mp4_path, wav_path))

    return _pad_outputs(results)
534
+
535
+
536
+ # ================================================================== #
537
+ # SHARED UI HELPERS #
538
+ # ================================================================== #
539
+
540
def _pad_outputs(outputs: list) -> list:
    """Flatten (video, audio) pairs and pad to MAX_SLOTS * 2 with None."""
    flat = []
    for slot in range(MAX_SLOTS):
        pair = outputs[slot] if slot < len(outputs) else (None, None)
        flat.extend(pair)
    return flat
549
 
550
 
551
def _on_video_upload_taro(video_file, num_steps, crossfade_s):
    """Recompute the TARO 'Generations' slider cap when inputs change.

    With no video selected the cap reverts to MAX_SLOTS; otherwise the cap
    comes from _taro_calc_max_samples for the clip's duration. Any probing
    failure falls back to MAX_SLOTS rather than blocking the UI.
    """
    if video_file is None:
        return gr.update(maximum=MAX_SLOTS, value=1)
    try:
        clip_len = get_video_duration(video_file)
        cap = _taro_calc_max_samples(clip_len, int(num_steps), float(crossfade_s))
    except Exception:
        cap = MAX_SLOTS
    return gr.update(maximum=cap, value=min(1, cap))
560
+
561
 
562
def _update_slot_visibility(n):
    """Show the first n output slot groups and hide the remainder."""
    visible_count = int(n)
    updates = []
    for slot in range(MAX_SLOTS):
        updates.append(gr.update(visible=slot < visible_count))
    return updates
565
+
566
+
567
+ # ================================================================== #
568
+ # GRADIO UI #
569
+ # ================================================================== #
570
+
571
# Top-level Gradio UI: one tab per model, each with its own input column and
# MAX_SLOTS output slots. Component-creation order inside the context
# managers defines the page layout, so statement order here is behavior.
with gr.Blocks(title="Video-to-Audio Generation") as demo:
    gr.Markdown(
        "# Video-to-Audio Generation\n"
        "Choose a model and upload a video to generate synchronized audio.\n\n"
        "| Model | Sample rate | Optimal duration | Notes |\n"
        "|-------|------------|-----------------|-------|\n"
        "| **TARO** | 16 kHz | 8.2 s | Video-only, sliding window for longer clips |\n"
        "| **MMAudio** | 44.1 kHz | 8 s | Text prompt supported |\n"
        "| **HunyuanFoley** | 48 kHz | up to 15 s | Text-guided foley, highest fidelity |"
    )

    with gr.Tabs():

        # ---------------------------------------------------------- #
        # Tab 1 — TARO                                               #
        # ---------------------------------------------------------- #
        with gr.Tab("TARO"):
            gr.Markdown(
                "**TARO** — Video-conditioned diffusion (ICCV 2025). No text prompt needed. "
                "8.192 s model window; longer videos are split into overlapping segments "
                "and stitched with a crossfade."
            )
            with gr.Row():
                with gr.Column():
                    taro_video = gr.Video(label="Input Video")
                    taro_seed = gr.Number(label="Seed (-1 = random)", value=get_random_seed, precision=0)
                    taro_cfg = gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5)
                    taro_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1)
                    taro_mode = gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde")
                    taro_cf_dur = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=8, value=2, step=0.1)
                    # NOTE(review): free-text box cast with float() downstream;
                    # non-numeric input will raise at generation time.
                    taro_cf_db = gr.Textbox(label="Crossfade Boost (dB)", value="3")
                    taro_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
                    taro_btn = gr.Button("Generate", variant="primary")

                with gr.Column():
                    # Pre-create all MAX_SLOTS output groups; only slot 0 is
                    # visible until the user raises the sample count.
                    taro_slot_grps, taro_slot_vids, taro_slot_auds = [], [], []
                    for i in range(MAX_SLOTS):
                        with gr.Group(visible=(i == 0)) as g:
                            sv = gr.Video(label=f"Generation {i+1} — Video")
                            sa = gr.Audio(label=f"Generation {i+1} — Audio")
                        taro_slot_grps.append(g)
                        taro_slot_vids.append(sv)
                        taro_slot_auds.append(sa)

            # Re-probe the sample cap whenever the video, step count, or
            # crossfade duration changes (they all affect GPU-time budget).
            for trigger in [taro_video, taro_steps, taro_cf_dur]:
                trigger.change(
                    fn=_on_video_upload_taro,
                    inputs=[taro_video, taro_steps, taro_cf_dur],
                    outputs=[taro_samples],
                )
            taro_samples.change(
                fn=_update_slot_visibility,
                inputs=[taro_samples],
                outputs=taro_slot_grps,
            )

            # Adapter: flatten the generator's [video, audio]*MAX_SLOTS list
            # into visibility + value updates for every slot component.
            # NOTE(review): the three _run_* adapters below are structurally
            # identical — candidates for a shared factory helper.
            def _run_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n):
                flat = generate_taro(video, seed, cfg, steps, mode, cf_dur, cf_db, n)
                n = int(n)
                grp_upd = [gr.update(visible=(i < n)) for i in range(MAX_SLOTS)]
                vid_upd = [gr.update(value=flat[i * 2]) for i in range(MAX_SLOTS)]
                aud_upd = [gr.update(value=flat[i * 2 + 1]) for i in range(MAX_SLOTS)]
                return grp_upd + vid_upd + aud_upd

            taro_btn.click(
                fn=_run_taro,
                inputs=[taro_video, taro_seed, taro_cfg, taro_steps, taro_mode,
                        taro_cf_dur, taro_cf_db, taro_samples],
                outputs=taro_slot_grps + taro_slot_vids + taro_slot_auds,
            )

        # ---------------------------------------------------------- #
        # Tab 2 — MMAudio                                            #
        # ---------------------------------------------------------- #
        with gr.Tab("MMAudio"):
            gr.Markdown(
                "**MMAudio** — Multimodal flow-matching (CVPR 2025). "
                "Supports a text prompt for additional control. "
                "Native window is 8 s at 44.1 kHz. "
                "Duration slider lets you control how many seconds are processed."
            )
            with gr.Row():
                with gr.Column():
                    mma_video = gr.Video(label="Input Video")
                    mma_prompt = gr.Textbox(label="Prompt", placeholder="e.g. footsteps on gravel")
                    mma_neg = gr.Textbox(label="Negative Prompt", placeholder="music, speech")
                    mma_seed = gr.Number(label="Seed (-1 = random)", value=get_random_seed, precision=0)
                    mma_cfg = gr.Slider(label="CFG Strength", minimum=1, maximum=10, value=4.5, step=0.5)
                    mma_steps = gr.Slider(label="Steps", minimum=10, maximum=50, value=25, step=1)
                    mma_dur = gr.Slider(label="Duration (s)", minimum=1, maximum=10, value=8, step=0.5)
                    mma_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
                    mma_btn = gr.Button("Generate", variant="primary")

                with gr.Column():
                    mma_slot_grps, mma_slot_vids, mma_slot_auds = [], [], []
                    for i in range(MAX_SLOTS):
                        with gr.Group(visible=(i == 0)) as g:
                            sv = gr.Video(label=f"Generation {i+1} — Video")
                            sa = gr.Audio(label=f"Generation {i+1} — Audio")
                        mma_slot_grps.append(g)
                        mma_slot_vids.append(sv)
                        mma_slot_auds.append(sa)

            mma_samples.change(
                fn=_update_slot_visibility,
                inputs=[mma_samples],
                outputs=mma_slot_grps,
            )

            def _run_mmaudio(video, prompt, neg, seed, cfg, steps, dur, n):
                flat = generate_mmaudio(video, prompt, neg, seed, cfg, steps, dur, n)
                n = int(n)
                grp_upd = [gr.update(visible=(i < n)) for i in range(MAX_SLOTS)]
                vid_upd = [gr.update(value=flat[i * 2]) for i in range(MAX_SLOTS)]
                aud_upd = [gr.update(value=flat[i * 2 + 1]) for i in range(MAX_SLOTS)]
                return grp_upd + vid_upd + aud_upd

            mma_btn.click(
                fn=_run_mmaudio,
                inputs=[mma_video, mma_prompt, mma_neg, mma_seed,
                        mma_cfg, mma_steps, mma_dur, mma_samples],
                outputs=mma_slot_grps + mma_slot_vids + mma_slot_auds,
            )

        # ---------------------------------------------------------- #
        # Tab 3 — HunyuanVideoFoley                                  #
        # ---------------------------------------------------------- #
        with gr.Tab("HunyuanFoley"):
            gr.Markdown(
                "**HunyuanVideo-Foley** (Tencent Hunyuan). "
                "Professional-grade text-guided foley at 48 kHz, up to 15 s. "
                "Requires a text prompt describing the desired sound."
            )
            with gr.Row():
                with gr.Column():
                    hf_video = gr.Video(label="Input Video")
                    hf_prompt = gr.Textbox(label="Prompt", placeholder="e.g. rain hitting a metal roof")
                    hf_neg = gr.Textbox(label="Negative Prompt", value="noisy, harsh")
                    hf_seed = gr.Number(label="Seed (-1 = random)", value=get_random_seed, precision=0)
                    hf_guidance = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, value=4.5, step=0.5)
                    hf_steps = gr.Slider(label="Steps", minimum=10, maximum=100, value=50, step=5)
                    hf_size = gr.Radio(label="Model Size", choices=["xl", "xxl"], value="xxl")
                    hf_samples = gr.Slider(label="Generations", minimum=1, maximum=MAX_SLOTS, value=1, step=1)
                    hf_btn = gr.Button("Generate", variant="primary")

                with gr.Column():
                    hf_slot_grps, hf_slot_vids, hf_slot_auds = [], [], []
                    for i in range(MAX_SLOTS):
                        with gr.Group(visible=(i == 0)) as g:
                            sv = gr.Video(label=f"Generation {i+1} — Video")
                            sa = gr.Audio(label=f"Generation {i+1} — Audio")
                        hf_slot_grps.append(g)
                        hf_slot_vids.append(sv)
                        hf_slot_auds.append(sa)

            hf_samples.change(
                fn=_update_slot_visibility,
                inputs=[hf_samples],
                outputs=hf_slot_grps,
            )

            def _run_hunyuan(video, prompt, neg, seed, guidance, steps, size, n):
                flat = generate_hunyuan(video, prompt, neg, seed, guidance, steps, size, n)
                n = int(n)
                grp_upd = [gr.update(visible=(i < n)) for i in range(MAX_SLOTS)]
                vid_upd = [gr.update(value=flat[i * 2]) for i in range(MAX_SLOTS)]
                aud_upd = [gr.update(value=flat[i * 2 + 1]) for i in range(MAX_SLOTS)]
                return grp_upd + vid_upd + aud_upd

            hf_btn.click(
                fn=_run_hunyuan,
                inputs=[hf_video, hf_prompt, hf_neg, hf_seed,
                        hf_guidance, hf_steps, hf_size, hf_samples],
                outputs=hf_slot_grps + hf_slot_vids + hf_slot_auds,
            )

# Queue serializes GPU-bound requests on the Space; launch starts the server.
demo.queue().launch()