JackIsNotInTheBox committed on
Commit
160db86
·
1 Parent(s): 5806ea4

Multi-sample support, last-segment tail anchor fix, dynamic samples cap, gr.Blocks UI, duration=600

Browse files
Files changed (1) hide show
  1. app.py +317 -193
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import subprocess
3
  import sys
 
4
 
5
  try:
6
  import mmcv
@@ -30,13 +31,23 @@ onset_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="onset_model.ckpt",
30
  taro_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="taro_ckpt.pt", cache_dir=CACHE_DIR)
31
  print("Checkpoints downloaded.")
32
 
 
 
 
 
 
 
 
 
 
 
33
  # ------------------------------------------------------------------ #
34
- # Inference cache: keyed by (video_path, seed, cfg_scale, #
35
- # num_steps, mode, crossfade_s) #
36
- # Stores the raw per-segment wavs so that only the dB value can be #
37
- # changed without re-running the model. #
38
  # ------------------------------------------------------------------ #
39
- _INFERENCE_CACHE = {} # key -> {"wavs": [...], "sr": int}
40
 
41
 
42
  def set_global_seed(seed):
@@ -57,33 +68,74 @@ def strip_audio_from_video(video_path, output_path):
57
  )
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def infer_segment(model, vae, vocoder, cavp_feats_full, onset_feats_full,
61
  seg_start_s, seg_end_s,
62
- sr, fps, truncate_frame, truncate_onset, model_dur,
63
- latents_scale, device, weight_dtype,
64
  cfg_scale, num_steps, mode,
 
65
  euler_sampler, euler_maruyama_sampler):
66
- """
67
- Run one model inference pass for the video window starting at seg_start_s.
68
- Returns a numpy float32 wav array trimmed to (seg_end_s - seg_start_s).
69
- """
70
- # CAVP features at fps (4 fps)
71
- cavp_start = int(round(seg_start_s * fps))
72
- cavp_slice = cavp_feats_full[cavp_start : cavp_start + truncate_frame]
73
- if cavp_slice.shape[0] < truncate_frame:
74
  pad = np.zeros(
75
- (truncate_frame - cavp_slice.shape[0],) + cavp_slice.shape[1:],
76
  dtype=cavp_slice.dtype,
77
  )
78
  cavp_slice = np.concatenate([cavp_slice, pad], axis=0)
79
  video_feats = torch.from_numpy(cavp_slice).unsqueeze(0).to(device).to(weight_dtype)
80
 
81
- # Onset features at truncate_onset / model_dur frames per second
82
- onset_fps = truncate_onset / model_dur
83
  onset_start = int(round(seg_start_s * onset_fps))
84
- onset_slice = onset_feats_full[onset_start : onset_start + truncate_onset]
85
- if onset_slice.shape[0] < truncate_onset:
86
- pad_len = truncate_onset - onset_slice.shape[0]
87
  onset_slice = np.pad(onset_slice, ((0, pad_len),), mode="constant", constant_values=0)
88
  onset_feats_t = torch.from_numpy(onset_slice).unsqueeze(0).to(device).to(weight_dtype)
89
 
@@ -108,214 +160,286 @@ def infer_segment(model, vae, vocoder, cavp_feats_full, onset_feats_full,
108
 
109
  samples = vae.decode(samples / latents_scale).sample
110
  wav = vocoder(samples.squeeze().float()).detach().cpu().numpy()
111
-
112
- seg_samples = int(round((seg_end_s - seg_start_s) * sr))
113
  return wav[:seg_samples]
114
 
115
 
116
- def crossfade_join(wav_a, wav_b, crossfade_s, db_boost, sr):
117
  """
118
- Join wav_a and wav_b with a crossfade_s-second crossfade.
119
-
120
- db_boost controls the gain applied to both signals in the overlap region:
121
- gain = 10 ** (db_boost / 20)
122
- At +3 dB (gain ≈ 1.414), the two summed unity signals produce +3 dB at midpoint.
123
- At 0 dB (gain = 1.0), each signal is kept at full amplitude — same as +3 dB sum
124
- since both are 1.0. The parameter lets the user tune the blend level freely.
125
-
126
- The crossfade window is the last crossfade_s seconds of wav_a overlapping with
127
- the first crossfade_s seconds of wav_b. Both are scaled by gain and summed.
128
  """
129
- cf_samples = int(round(crossfade_s * sr))
130
-
131
- # Guard: if either wav is shorter than the crossfade window, shrink the window
132
- cf_samples = min(cf_samples, len(wav_a), len(wav_b))
133
  if cf_samples <= 0:
134
  return np.concatenate([wav_a, wav_b])
135
 
136
- gain = 10 ** (db_boost / 20.0)
137
-
138
- tail_a = wav_a[-cf_samples:] * gain
139
- head_b = wav_b[:cf_samples] * gain
140
- overlap = tail_a + head_b
141
 
142
- return np.concatenate([
143
- wav_a[:-cf_samples],
144
- overlap,
145
- wav_b[cf_samples:],
146
- ])
147
 
148
 
149
- def stitch_wavs(wavs, crossfade_s, db_boost, sr, total_dur_s):
150
- """Stitch a list of wav arrays using crossfade_join, then clip to total_dur_s."""
151
  if len(wavs) == 1:
152
  final_wav = wavs[0]
153
  else:
154
  final_wav = wavs[0]
155
- for next_wav in wavs[1:]:
156
- final_wav = crossfade_join(final_wav, next_wav, crossfade_s, db_boost, sr)
 
157
 
158
- target_samples = int(round(total_dur_s * sr))
159
- return final_wav[:target_samples]
160
 
 
 
 
 
 
 
 
 
 
161
 
162
- @spaces.GPU(duration=300)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode,
164
- crossfade_s, crossfade_db):
165
  global _INFERENCE_CACHE
166
 
167
  seed_val = int(seed_val)
168
  crossfade_s = float(crossfade_s)
169
  crossfade_db = float(crossfade_db)
 
170
 
171
  if seed_val < 0:
172
  seed_val = random.randint(0, 2**32 - 1)
173
 
174
- sr = 16000
175
- truncate = 131072
176
- fps = 4
177
- truncate_frame = int(fps * truncate / sr)
178
- truncate_onset = 120
179
- model_dur = truncate / sr # 8.192 s
180
- step_s = model_dur - crossfade_s
181
-
182
- # Cache key covers everything that affects segmentation and inference
183
- cache_key = (video_file, seed_val, float(cfg_scale), int(num_steps), mode,
184
- crossfade_s)
185
-
186
- if cache_key in _INFERENCE_CACHE:
187
- print("Cache hit — skipping inference, re-stitching with new dB value.")
188
- cached = _INFERENCE_CACHE[cache_key]
189
- wavs = cached["wavs"]
190
- total_dur_s = cached["total_dur_s"]
191
- tmp_dir = cached["tmp_dir"]
192
- silent_video = cached["silent_video"]
193
- else:
194
- set_global_seed(seed_val)
195
- torch.set_grad_enabled(False)
196
- device = "cuda" if torch.cuda.is_available() else "cpu"
197
- weight_dtype = torch.bfloat16
198
-
199
- from cavp_util import Extract_CAVP_Features
200
- from onset_util import VideoOnsetNet, extract_onset
201
- from models import MMDiT
202
- from samplers import euler_sampler, euler_maruyama_sampler
203
- from diffusers import AudioLDM2Pipeline
204
-
205
- extract_cavp = Extract_CAVP_Features(
206
- device=device, config_path="./cavp/cavp.yaml", ckpt_path=cavp_ckpt_path
207
- )
208
 
209
- state_dict = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
210
- new_state_dict = {}
211
- for key, value in state_dict.items():
212
- if "model.net.model" in key:
213
- new_key = key.replace("model.net.model", "net.model")
214
- elif "model.fc." in key:
215
- new_key = key.replace("model.fc", "fc")
216
- else:
217
- new_key = key
218
- new_state_dict[new_key] = value
219
- onset_model = VideoOnsetNet(False).to(device)
220
- onset_model.load_state_dict(new_state_dict)
221
- onset_model.eval()
222
-
223
- model = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
224
- ckpt = torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"]
225
- model.load_state_dict(ckpt)
226
- model.eval()
227
- model.to(weight_dtype)
228
-
229
- model_audioldm = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
230
- vae = model_audioldm.vae.to(device)
231
- vae.eval()
232
- vocoder = model_audioldm.vocoder.to(device)
233
-
234
- tmp_dir = tempfile.mkdtemp()
235
- silent_video = os.path.join(tmp_dir, "silent_input.mp4")
236
- strip_audio_from_video(video_file, silent_video)
237
-
238
- cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
239
- onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
240
-
241
- latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
242
-
243
- total_frames = cavp_feats.shape[0]
244
- total_dur_s = total_frames / fps
245
-
246
- # Build segment list
247
- segments = []
248
- seg_start = 0.0
249
- while True:
250
- seg_end = min(seg_start + model_dur, total_dur_s)
251
- segments.append((seg_start, seg_end))
252
- if seg_end >= total_dur_s:
253
- break
254
- seg_start += step_s
255
-
256
- # Run inference for every segment
257
- wavs = []
258
- for seg_start_s, seg_end_s in segments:
259
- print(f"Inferring segment {seg_start_s:.2f}s – {seg_end_s:.2f}s ...")
260
- wav = infer_segment(
261
- model, vae, vocoder,
262
- cavp_feats, onset_feats,
263
- seg_start_s, seg_end_s,
264
- sr, fps, truncate_frame, truncate_onset, model_dur,
265
- latents_scale, device, weight_dtype,
266
- cfg_scale, num_steps, mode,
267
- euler_sampler, euler_maruyama_sampler,
268
- )
269
- wavs.append(wav)
270
 
271
- # Store in cache
272
- _INFERENCE_CACHE[cache_key] = {
273
- "wavs": wavs,
274
- "total_dur_s": total_dur_s,
275
- "tmp_dir": tmp_dir,
276
- "silent_video": silent_video,
277
- }
278
 
279
- # Stitch with current crossfade params
280
- device = "cuda" if torch.cuda.is_available() else "cpu"
281
- final_wav = stitch_wavs(wavs, crossfade_s, crossfade_db, sr, total_dur_s)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
- audio_path = os.path.join(tmp_dir, "output.wav")
284
- sf.write(audio_path, final_wav, sr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
 
286
- output_video = os.path.join(tmp_dir, "output.mp4")
287
- input_v = ffmpeg.input(silent_video)
288
- input_a = ffmpeg.input(audio_path)
289
- (
290
- ffmpeg
291
- .output(input_v, input_a, output_video,
292
- vcodec="libx264", acodec="aac", strict="experimental")
293
- .run(overwrite_output=True, quiet=True)
294
- )
295
 
296
- return output_video, audio_path
 
 
297
 
 
 
 
 
 
 
 
298
 
299
- def get_random_seed():
300
- return random.randint(0, 2**32 - 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
 
303
- demo = gr.Interface(
304
- fn=generate_audio,
305
- inputs=[
306
- gr.Video(label="Input Video"),
307
- gr.Number(label="Seed", value=get_random_seed, precision=0),
308
- gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5),
309
- gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1),
310
- gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde"),
311
- gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=8, value=2, step=0.1),
312
- gr.Textbox(label="Crossfade Boost (dB)", value="3"),
313
- ],
314
- outputs=[
315
- gr.Video(label="Output Video with Audio"),
316
- gr.Audio(label="Generated Audio"),
317
- ],
318
- title="TARO: Video-to-Audio Synthesis (ICCV 2025)",
319
- description="Upload a video and generate synchronized audio using TARO. Optimal clip duration is 8.2s. Longer videos are automatically split into overlapping segments and stitched with a crossfade.",
320
- )
321
  demo.queue().launch()
 
1
  import os
2
  import subprocess
3
  import sys
4
+ from math import ceil, floor
5
 
6
  try:
7
  import mmcv
 
31
  taro_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="taro_ckpt.pt", cache_dir=CACHE_DIR)
32
  print("Checkpoints downloaded.")
33
 
34
+ # Model constants
35
+ SR = 16000
36
+ TRUNCATE = 131072
37
+ FPS = 4
38
+ TRUNCATE_FRAME = int(FPS * TRUNCATE / SR) # 32 cavp frames per model window
39
+ TRUNCATE_ONSET = 120 # onset frames per model window
40
+ MODEL_DUR = TRUNCATE / SR # 8.192 s
41
+ MAX_SLOTS = 8 # max sample output slots in UI
42
+ SECS_PER_STEP = 2.5 # estimated seconds of GPU time per diffusion step
43
+
44
  # ------------------------------------------------------------------ #
45
+ # Inference cache #
46
+ # Key: (video_path, seed, cfg_scale, num_steps, mode, crossfade_s) #
47
+ # Value: {"wavs": [...], "total_dur_s": float, #
48
+ # "tmp_dir": str, "silent_video": str} #
49
  # ------------------------------------------------------------------ #
50
+ _INFERENCE_CACHE = {}
51
 
52
 
53
  def set_global_seed(seed):
 
68
  )
69
 
70
 
71
+ def get_video_duration(video_path):
72
+ """Read video duration in seconds using ffprobe (no GPU needed)."""
73
+ probe = ffmpeg.probe(video_path)
74
+ return float(probe["format"]["duration"])
75
+
76
+
77
+ def build_segments(total_dur_s, crossfade_s):
78
+ """
79
+ Build list of (seg_start_s, seg_end_s) segment windows.
80
+
81
+ For videos <= MODEL_DUR: single segment [0, total_dur_s].
82
+ For longer videos: advance by step_s = MODEL_DUR - crossfade_s each time.
83
+ The LAST segment is always anchored at [total_dur_s - MODEL_DUR, total_dur_s]
84
+ so it is a full-length window with no zero-padding, giving the best quality
85
+ at the tail end of the video.
86
+ """
87
+ if total_dur_s <= MODEL_DUR:
88
+ return [(0.0, total_dur_s)]
89
+
90
+ step_s = MODEL_DUR - crossfade_s
91
+ segments = []
92
+ seg_start = 0.0
93
+ while True:
94
+ seg_end = seg_start + MODEL_DUR
95
+ if seg_end >= total_dur_s:
96
+ # Replace this segment with a full-length tail-anchored window
97
+ seg_start = max(0.0, total_dur_s - MODEL_DUR)
98
+ segments.append((seg_start, total_dur_s))
99
+ break
100
+ segments.append((seg_start, seg_start + MODEL_DUR))
101
+ seg_start += step_s
102
+
103
+ return segments
104
+
105
+
106
+ def calc_max_samples(total_dur_s, num_steps, crossfade_s):
107
+ """Estimate max samples that fit within the 600s ZeroGPU budget."""
108
+ num_segments = len(build_segments(total_dur_s, crossfade_s))
109
+ time_per_seg = num_steps * SECS_PER_STEP
110
+ budget = 600.0
111
+ max_s = floor(budget / (num_segments * time_per_seg))
112
+ return max(1, min(max_s, MAX_SLOTS))
113
+
114
+
115
  def infer_segment(model, vae, vocoder, cavp_feats_full, onset_feats_full,
116
  seg_start_s, seg_end_s,
117
+ device, weight_dtype,
 
118
  cfg_scale, num_steps, mode,
119
+ latents_scale,
120
  euler_sampler, euler_maruyama_sampler):
121
+ """Run one model inference pass. Returns wav trimmed to segment duration."""
122
+ # CAVP features (4 fps)
123
+ cavp_start = int(round(seg_start_s * FPS))
124
+ cavp_slice = cavp_feats_full[cavp_start : cavp_start + TRUNCATE_FRAME]
125
+ if cavp_slice.shape[0] < TRUNCATE_FRAME:
 
 
 
126
  pad = np.zeros(
127
+ (TRUNCATE_FRAME - cavp_slice.shape[0],) + cavp_slice.shape[1:],
128
  dtype=cavp_slice.dtype,
129
  )
130
  cavp_slice = np.concatenate([cavp_slice, pad], axis=0)
131
  video_feats = torch.from_numpy(cavp_slice).unsqueeze(0).to(device).to(weight_dtype)
132
 
133
+ # Onset features
134
+ onset_fps = TRUNCATE_ONSET / MODEL_DUR
135
  onset_start = int(round(seg_start_s * onset_fps))
136
+ onset_slice = onset_feats_full[onset_start : onset_start + TRUNCATE_ONSET]
137
+ if onset_slice.shape[0] < TRUNCATE_ONSET:
138
+ pad_len = TRUNCATE_ONSET - onset_slice.shape[0]
139
  onset_slice = np.pad(onset_slice, ((0, pad_len),), mode="constant", constant_values=0)
140
  onset_feats_t = torch.from_numpy(onset_slice).unsqueeze(0).to(device).to(weight_dtype)
141
 
 
160
 
161
  samples = vae.decode(samples / latents_scale).sample
162
  wav = vocoder(samples.squeeze().float()).detach().cpu().numpy()
163
+ seg_samples = int(round((seg_end_s - seg_start_s) * SR))
 
164
  return wav[:seg_samples]
165
 
166
 
167
+ def crossfade_join(wav_a, wav_b, crossfade_s, db_boost):
168
  """
169
+ Join two wav arrays with a crossfade.
170
+ Both signals are scaled by gain = 10^(db_boost/20) in the overlap region
171
+ and summed, producing a +db_boost bump at the midpoint.
 
 
 
 
 
 
 
172
  """
173
+ cf_samples = int(round(crossfade_s * SR))
174
+ cf_samples = min(cf_samples, len(wav_a), len(wav_b))
 
 
175
  if cf_samples <= 0:
176
  return np.concatenate([wav_a, wav_b])
177
 
178
+ gain = 10 ** (db_boost / 20.0)
179
+ overlap = wav_a[-cf_samples:] * gain + wav_b[:cf_samples] * gain
 
 
 
180
 
181
+ return np.concatenate([wav_a[:-cf_samples], overlap, wav_b[cf_samples:]])
 
 
 
 
182
 
183
 
184
+ def stitch_wavs(wavs, crossfade_s, db_boost, total_dur_s):
185
+ """Stitch segment wavs with crossfades and clip to total_dur_s."""
186
  if len(wavs) == 1:
187
  final_wav = wavs[0]
188
  else:
189
  final_wav = wavs[0]
190
+ for nw in wavs[1:]:
191
+ final_wav = crossfade_join(final_wav, nw, crossfade_s, db_boost)
192
+ return final_wav[:int(round(total_dur_s * SR))]
193
 
 
 
194
 
195
+ def mux_video_audio(silent_video, audio_path, output_path):
196
+ input_v = ffmpeg.input(silent_video)
197
+ input_a = ffmpeg.input(audio_path)
198
+ (
199
+ ffmpeg
200
+ .output(input_v, input_a, output_path,
201
+ vcodec="libx264", acodec="aac", strict="experimental")
202
+ .run(overwrite_output=True, quiet=True)
203
+ )
204
 
205
+
206
+ # ------------------------------------------------------------------ #
207
+ # UI helpers (no GPU) #
208
+ # ------------------------------------------------------------------ #
209
+
210
+ def on_video_upload(video_file, num_steps, crossfade_s):
211
+ """Called when video is uploaded or sliders change. Updates samples slider."""
212
+ if video_file is None:
213
+ return gr.update(maximum=MAX_SLOTS, value=1)
214
+ try:
215
+ D = get_video_duration(video_file)
216
+ max_s = calc_max_samples(D, int(num_steps), float(crossfade_s))
217
+ except Exception:
218
+ max_s = MAX_SLOTS
219
+ return gr.update(maximum=max_s, value=min(1, max_s))
220
+
221
+
222
+ def get_random_seed():
223
+ return random.randint(0, 2**32 - 1)
224
+
225
+
226
+ # ------------------------------------------------------------------ #
227
+ # Main inference #
228
+ # ------------------------------------------------------------------ #
229
+
230
+ @spaces.GPU(duration=600)
231
  def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode,
232
+ crossfade_s, crossfade_db, num_samples):
233
  global _INFERENCE_CACHE
234
 
235
  seed_val = int(seed_val)
236
  crossfade_s = float(crossfade_s)
237
  crossfade_db = float(crossfade_db)
238
+ num_samples = int(num_samples)
239
 
240
  if seed_val < 0:
241
  seed_val = random.randint(0, 2**32 - 1)
242
 
243
+ # Load models once (shared across all samples this call)
244
+ torch.set_grad_enabled(False)
245
+ device = "cuda" if torch.cuda.is_available() else "cpu"
246
+ weight_dtype = torch.bfloat16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
+ from cavp_util import Extract_CAVP_Features
249
+ from onset_util import VideoOnsetNet, extract_onset
250
+ from models import MMDiT
251
+ from samplers import euler_sampler, euler_maruyama_sampler
252
+ from diffusers import AudioLDM2Pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
+ extract_cavp = Extract_CAVP_Features(
255
+ device=device, config_path="./cavp/cavp.yaml", ckpt_path=cavp_ckpt_path
256
+ )
 
 
 
 
257
 
258
+ state_dict = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
259
+ new_state_dict = {}
260
+ for key, value in state_dict.items():
261
+ if "model.net.model" in key:
262
+ new_key = key.replace("model.net.model", "net.model")
263
+ elif "model.fc." in key:
264
+ new_key = key.replace("model.fc", "fc")
265
+ else:
266
+ new_key = key
267
+ new_state_dict[new_key] = value
268
+ onset_model = VideoOnsetNet(False).to(device)
269
+ onset_model.load_state_dict(new_state_dict)
270
+ onset_model.eval()
271
+
272
+ model = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
273
+ ckpt = torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"]
274
+ model.load_state_dict(ckpt)
275
+ model.eval()
276
+ model.to(weight_dtype)
277
+
278
+ model_audioldm = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
279
+ vae = model_audioldm.vae.to(device)
280
+ vae.eval()
281
+ vocoder = model_audioldm.vocoder.to(device)
282
+
283
+ latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
284
+
285
+ # Prepare silent video (shared across all samples)
286
+ tmp_dir = tempfile.mkdtemp()
287
+ silent_video = os.path.join(tmp_dir, "silent_input.mp4")
288
+ strip_audio_from_video(video_file, silent_video)
289
+
290
+ cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
291
+ total_frames = cavp_feats.shape[0]
292
+ total_dur_s = total_frames / FPS
293
+ segments = build_segments(total_dur_s, crossfade_s)
294
+
295
+ # ------------------------------------------------------------------ #
296
+ # Generate N samples #
297
+ # ------------------------------------------------------------------ #
298
+ outputs = [] # list of (video_path, audio_path)
299
+
300
+ for sample_idx in range(num_samples):
301
+ sample_seed = seed_val + sample_idx
302
+ cache_key = (video_file, sample_seed, float(cfg_scale),
303
+ int(num_steps), mode, crossfade_s)
304
+
305
+ if cache_key in _INFERENCE_CACHE:
306
+ print(f"Sample {sample_idx+1}: cache hit, re-stitching.")
307
+ cached = _INFERENCE_CACHE[cache_key]
308
+ wavs = cached["wavs"]
309
+ else:
310
+ set_global_seed(sample_seed)
311
+ onset_feats = extract_onset(
312
+ silent_video, onset_model, tmp_path=tmp_dir, device=device
313
+ )
314
 
315
+ wavs = []
316
+ for seg_start_s, seg_end_s in segments:
317
+ print(f" Sample {sample_idx+1} | segment {seg_start_s:.2f}s – {seg_end_s:.2f}s")
318
+ wav = infer_segment(
319
+ model, vae, vocoder,
320
+ cavp_feats, onset_feats,
321
+ seg_start_s, seg_end_s,
322
+ device, weight_dtype,
323
+ cfg_scale, num_steps, mode,
324
+ latents_scale,
325
+ euler_sampler, euler_maruyama_sampler,
326
+ )
327
+ wavs.append(wav)
328
+
329
+ _INFERENCE_CACHE[cache_key] = {"wavs": wavs}
330
+
331
+ # Stitch
332
+ final_wav = stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s)
333
+
334
+ audio_path = os.path.join(tmp_dir, f"output_{sample_idx}.wav")
335
+ sf.write(audio_path, final_wav, SR)
336
+
337
+ video_path = os.path.join(tmp_dir, f"output_{sample_idx}.mp4")
338
+ mux_video_audio(silent_video, audio_path, video_path)
339
+
340
+ outputs.append((video_path, audio_path))
341
+
342
+ # ------------------------------------------------------------------ #
343
+ # Return flat list of (video, audio) pairs padded with None #
344
+ # so Gradio output list length is always MAX_SLOTS * 2 #
345
+ # ------------------------------------------------------------------ #
346
+ result = []
347
+ for i in range(MAX_SLOTS):
348
+ if i < len(outputs):
349
+ result.append(outputs[i][0]) # video
350
+ result.append(outputs[i][1]) # audio
351
+ else:
352
+ result.append(None)
353
+ result.append(None)
354
+ return result
355
 
 
 
 
 
 
 
 
 
 
356
 
357
+ # ------------------------------------------------------------------ #
358
+ # Build gr.Blocks UI #
359
+ # ------------------------------------------------------------------ #
360
 
361
+ with gr.Blocks(title="TARO: Video-to-Audio Synthesis") as demo:
362
+ gr.Markdown(
363
+ "# TARO: Video-to-Audio Synthesis (ICCV 2025)\n"
364
+ "Upload a video and generate synchronized audio. "
365
+ "Optimal clip duration is 8.2s. Longer videos are automatically "
366
+ "split into overlapping segments and stitched with a crossfade."
367
+ )
368
 
369
+ with gr.Row():
370
+ with gr.Column():
371
+ video_input = gr.Video(label="Input Video")
372
+ seed_input = gr.Number(label="Seed", value=get_random_seed, precision=0)
373
+ cfg_input = gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5)
374
+ steps_input = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1)
375
+ mode_input = gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde")
376
+ cf_dur_input = gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=8, value=2, step=0.1)
377
+ cf_db_input = gr.Textbox(label="Crossfade Boost (dB)", value="3")
378
+ samples_input = gr.Slider(label="Number of Samples", minimum=1, maximum=MAX_SLOTS,
379
+ value=1, step=1)
380
+ run_btn = gr.Button("Generate", variant="primary")
381
+
382
+ with gr.Column():
383
+ # Pre-build MAX_SLOTS output slots; hide all initially
384
+ slot_videos = []
385
+ slot_audios = []
386
+ for i in range(MAX_SLOTS):
387
+ with gr.Group(visible=False) as grp:
388
+ sv = gr.Video(label=f"Sample {i+1} — Video")
389
+ sa = gr.Audio(label=f"Sample {i+1} — Audio")
390
+ slot_videos.append((grp, sv))
391
+ slot_audios.append((grp, sa))
392
+
393
+ # ------------------------------------------------------------------ #
394
+ # Events #
395
+ # ------------------------------------------------------------------ #
396
+
397
+ # Update samples slider max when video uploaded or relevant sliders change
398
+ def _update_samples_slider(video_file, num_steps, crossfade_s):
399
+ return on_video_upload(video_file, num_steps, crossfade_s)
400
+
401
+ for trigger in [video_input, steps_input, cf_dur_input]:
402
+ trigger.change(
403
+ fn=_update_samples_slider,
404
+ inputs=[video_input, steps_input, cf_dur_input],
405
+ outputs=[samples_input],
406
+ )
407
 
408
+ # Collect all output components (flat: grp_visible, video, audio per slot)
409
+ all_outputs = []
410
+ for grp, sv in slot_videos:
411
+ all_outputs.append(grp)
412
+ for _, sa in slot_audios:
413
+ all_outputs.append(sa)
414
+ # Actually build properly: interleaved group + video + audio
415
+ all_outputs = []
416
+ slot_video_comps = [sv for _, sv in slot_videos]
417
+ slot_audio_comps = [sa for _, sa in slot_audios]
418
+ slot_grp_comps = [grp for grp, _ in slot_videos]
419
+
420
+ def _generate_and_update(video_file, seed_val, cfg_scale, num_steps, mode,
421
+ crossfade_s, crossfade_db, num_samples):
422
+ flat = generate_audio(video_file, seed_val, cfg_scale, num_steps, mode,
423
+ crossfade_s, crossfade_db, num_samples)
424
+ num_samples = int(num_samples)
425
+ # flat = [vid0, aud0, vid1, aud1, ...]
426
+ grp_updates = []
427
+ video_updates = []
428
+ audio_updates = []
429
+ for i in range(MAX_SLOTS):
430
+ visible = i < num_samples
431
+ vid = flat[i * 2]
432
+ aud = flat[i * 2 + 1]
433
+ grp_updates.append(gr.update(visible=visible))
434
+ video_updates.append(gr.update(value=vid))
435
+ audio_updates.append(gr.update(value=aud))
436
+ return grp_updates + video_updates + audio_updates
437
+
438
+ run_btn.click(
439
+ fn=_generate_and_update,
440
+ inputs=[video_input, seed_input, cfg_input, steps_input, mode_input,
441
+ cf_dur_input, cf_db_input, samples_input],
442
+ outputs=slot_grp_comps + slot_video_comps + slot_audio_comps,
443
+ )
444
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  demo.queue().launch()