JackIsNotInTheBox committed on
Commit
eaedb53
·
1 Parent(s): 0429f8a

Add crossfade duration/dB controls, inference caching, share=True

Browse files
Files changed (1) hide show
  1. app.py +172 -154
app.py CHANGED
@@ -30,6 +30,14 @@ onset_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="onset_model.ckpt",
30
  taro_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="taro_ckpt.pt", cache_dir=CACHE_DIR)
31
  print("Checkpoints downloaded.")
32
 
 
 
 
 
 
 
 
 
33
 
34
  def set_global_seed(seed):
35
  np.random.seed(seed % (2**32))
@@ -56,22 +64,22 @@ def infer_segment(model, vae, vocoder, cavp_feats_full, onset_feats_full,
56
  cfg_scale, num_steps, mode,
57
  euler_sampler, euler_maruyama_sampler):
58
  """
59
- Run one model inference pass for the video window [seg_start_s, seg_start_s + model_dur].
60
- Returns a numpy float32 wav array of exactly round(model_dur * sr) samples,
61
- trimmed to the actual segment length (seg_end_s - seg_start_s) when shorter.
62
  """
63
- # -- CAVP features: 4 fps --
64
  cavp_start = int(round(seg_start_s * fps))
65
- cavp_end = cavp_start + truncate_frame
66
- cavp_slice = cavp_feats_full[cavp_start:cavp_end]
67
- # pad if near end of video
68
  if cavp_slice.shape[0] < truncate_frame:
69
- pad = np.zeros((truncate_frame - cavp_slice.shape[0],) + cavp_slice.shape[1:], dtype=cavp_slice.dtype)
 
 
 
70
  cavp_slice = np.concatenate([cavp_slice, pad], axis=0)
71
  video_feats = torch.from_numpy(cavp_slice).unsqueeze(0).to(device).to(weight_dtype)
72
 
73
- # -- Onset features: truncate_onset frames per model_dur --
74
- onset_fps = truncate_onset / model_dur # frames per second of onset feats
75
  onset_start = int(round(seg_start_s * onset_fps))
76
  onset_slice = onset_feats_full[onset_start : onset_start + truncate_onset]
77
  if onset_slice.shape[0] < truncate_onset:
@@ -79,7 +87,6 @@ def infer_segment(model, vae, vocoder, cavp_feats_full, onset_feats_full,
79
  onset_slice = np.pad(onset_slice, ((0, pad_len),), mode="constant", constant_values=0)
80
  onset_feats_t = torch.from_numpy(onset_slice).unsqueeze(0).to(device).to(weight_dtype)
81
 
82
- # -- Diffusion --
83
  z = torch.randn(1, model.in_channels, 204, 16, device=device).to(weight_dtype)
84
  sampling_kwargs = dict(
85
  model=model,
@@ -102,171 +109,180 @@ def infer_segment(model, vae, vocoder, cavp_feats_full, onset_feats_full,
102
  samples = vae.decode(samples / latents_scale).sample
103
  wav = vocoder(samples.squeeze().float()).detach().cpu().numpy()
104
 
105
- # Trim to actual segment length
106
  seg_samples = int(round((seg_end_s - seg_start_s) * sr))
107
  return wav[:seg_samples]
108
 
109
 
110
- def crossfade_join(wav_a, wav_b, crossfade_s, sr):
111
  """
112
- Join wav_a and wav_b with a 2-second equal-power (+3 dB) crossfade.
113
-
114
- wav_a contains 1 s of 'extra' audio at its tail (the overlap region starts
115
- 1 s before its end). wav_b contains 1 s of 'extra' audio at its head.
116
- The crossfade window is crossfade_s wide; the midpoint sits at (crossfade_s/2)
117
- into the window, where each gain = sqrt(0.5) -3 dB ... wait, we want +3 dB
118
- at midpoint meaning both signals are at *full* amplitude there.
119
-
120
- Equal-power (sqrt) ramps: at the midpoint t=0.5 the fade-out = sqrt(0.5) and
121
- fade-in = sqrt(0.5), so combined power = 0.5+0.5 = 1.0 (+0 dB).
122
- For a +3 dB bump at midpoint we use *linear* ramps instead:
123
- fade_out = 1 - t, fade_in = t (t: 0->1 across window)
124
- At t=0.5: both = 0.5, sum = 1.0 amplitude = +6 dB power... that is not right.
125
-
126
- DaVinci Resolve "+3 dB" crossfade means the combined level at the midpoint
127
- is +3 dB above either source, which equals the behaviour where each signal
128
- is kept at full gain (1.0) across the entire overlap and the two are simply
129
- summed — then the overlap region has 6 dB of headroom risk, but the *perceived*
130
- loudness boost at the centre is +3 dB (sqrt(2) in amplitude).
131
-
132
- Implementation: keep both signals at unity gain in the crossfade window and
133
- sum them. Outside the window use the respective signal only.
134
  """
135
  cf_samples = int(round(crossfade_s * sr))
136
 
137
- # The crossfade sits at the junction: last cf_samples of wav_a overlap with
138
- # first cf_samples of wav_b.
139
- tail_a = wav_a[-cf_samples:] # 1s before end of a
140
- head_b = wav_b[:cf_samples] # 1s after start of b
141
- overlap = tail_a + head_b # +3 dB sum at centre (unity + unity)
142
 
143
- result = np.concatenate([
144
- wav_a[:-cf_samples], # body of a (before crossfade)
145
- overlap, # crossfade region
146
- wav_b[cf_samples:], # body of b (after crossfade)
 
 
 
 
 
 
147
  ])
148
- return result
149
 
150
 
151
- @spaces.GPU(duration=300)
152
- def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
153
- seed_val = int(seed_val)
154
- if seed_val < 0:
155
- seed_val = random.randint(0, 2**32 - 1)
156
- set_global_seed(seed_val)
157
- torch.set_grad_enabled(False)
158
- device = "cuda" if torch.cuda.is_available() else "cpu"
159
- weight_dtype = torch.bfloat16
 
 
160
 
161
- from cavp_util import Extract_CAVP_Features
162
- from onset_util import VideoOnsetNet, extract_onset
163
- from models import MMDiT
164
- from samplers import euler_sampler, euler_maruyama_sampler
165
- from diffusers import AudioLDM2Pipeline
166
 
167
- extract_cavp = Extract_CAVP_Features(
168
- device=device, config_path="./cavp/cavp.yaml", ckpt_path=cavp_ckpt_path
169
- )
 
170
 
171
- state_dict = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
172
- new_state_dict = {}
173
- for key, value in state_dict.items():
174
- if "model.net.model" in key:
175
- new_key = key.replace("model.net.model", "net.model")
176
- elif "model.fc." in key:
177
- new_key = key.replace("model.fc", "fc")
178
- else:
179
- new_key = key
180
- new_state_dict[new_key] = value
181
- onset_model = VideoOnsetNet(False).to(device)
182
- onset_model.load_state_dict(new_state_dict)
183
- onset_model.eval()
184
-
185
- model = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
186
- ckpt = torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"]
187
- model.load_state_dict(ckpt)
188
- model.eval()
189
- model.to(weight_dtype)
190
-
191
- model_audioldm = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
192
- vae = model_audioldm.vae.to(device)
193
- vae.eval()
194
- vocoder = model_audioldm.vocoder.to(device)
195
-
196
- tmp_dir = tempfile.mkdtemp()
197
- silent_video = os.path.join(tmp_dir, "silent_input.mp4")
198
- strip_audio_from_video(video_file, silent_video)
199
-
200
- cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
201
- onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
202
 
203
  sr = 16000
204
  truncate = 131072
205
  fps = 4
206
- truncate_frame = int(fps * truncate / sr) # 32 cavp frames per segment
207
- truncate_onset = 120 # onset frames per segment
208
- model_dur = truncate / sr # 8.192 s
209
- crossfade_s = 2.0 # 2-second crossfade window
210
- # Each segment starts (model_dur - crossfade_s) later than the previous,
211
- # so the tails overlap by crossfade_s giving 1 s of extra audio on each side.
212
- step_s = model_dur - crossfade_s # 6.192 s
213
-
214
- latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
215
-
216
- # Total video duration from cavp features
217
- total_frames = cavp_feats.shape[0]
218
- total_dur_s = total_frames / fps
219
-
220
- # ------------------------------------------------------------------ #
221
- # Build segment list: each entry is (seg_start_s, seg_end_s) #
222
- # seg_end_s is the actual content end (clipped to video length), #
223
- # but we always run the model for a full model_dur window. #
224
- # ------------------------------------------------------------------ #
225
- segments = []
226
- seg_start = 0.0
227
- while True:
228
- seg_end = min(seg_start + model_dur, total_dur_s)
229
- segments.append((seg_start, seg_end))
230
- if seg_end >= total_dur_s:
231
- break
232
- seg_start += step_s
233
-
234
- # ------------------------------------------------------------------ #
235
- # Run inference for every segment #
236
- # ------------------------------------------------------------------ #
237
- wavs = []
238
- for seg_start_s, seg_end_s in segments:
239
- print(f"Inferring segment {seg_start_s:.2f}s – {seg_end_s:.2f}s ...")
240
- wav = infer_segment(
241
- model, vae, vocoder,
242
- cavp_feats, onset_feats,
243
- seg_start_s, seg_end_s,
244
- sr, fps, truncate_frame, truncate_onset, model_dur,
245
- latents_scale, device, weight_dtype,
246
- cfg_scale, num_steps, mode,
247
- euler_sampler, euler_maruyama_sampler,
248
- )
249
- wavs.append(wav)
250
-
251
- # ------------------------------------------------------------------ #
252
- # Stitch with crossfades #
253
- # Single segment: no crossfade needed #
254
- # ------------------------------------------------------------------ #
255
- if len(wavs) == 1:
256
- final_wav = wavs[0]
257
  else:
258
- final_wav = wavs[0]
259
- for next_wav in wavs[1:]:
260
- final_wav = crossfade_join(final_wav, next_wav, crossfade_s, sr)
 
 
 
 
 
 
 
 
 
 
 
261
 
262
- # Clip to exact video duration
263
- target_samples = int(round(total_dur_s * sr))
264
- final_wav = final_wav[:target_samples]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
- audio_path = os.path.join(tmp_dir, "output.wav")
267
  sf.write(audio_path, final_wav, sr)
268
 
269
- # Mux original silent video (full length) with generated audio
270
  output_video = os.path.join(tmp_dir, "output.mp4")
271
  input_v = ffmpeg.input(silent_video)
272
  input_a = ffmpeg.input(audio_path)
@@ -292,12 +308,14 @@ demo = gr.Interface(
292
  gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5),
293
  gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1),
294
  gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde"),
 
 
295
  ],
296
  outputs=[
297
  gr.Video(label="Output Video with Audio"),
298
  gr.Audio(label="Generated Audio"),
299
  ],
300
  title="TARO: Video-to-Audio Synthesis (ICCV 2025)",
301
- description="Upload a video and generate synchronized audio using TARO. Optimal duration is 8.2s.",
302
  )
303
- demo.queue().launch()
 
30
  taro_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="taro_ckpt.pt", cache_dir=CACHE_DIR)
31
  print("Checkpoints downloaded.")
32
 
33
+ # ------------------------------------------------------------------ #
34
+ # Inference cache: keyed by (video_path, seed, cfg_scale, #
35
+ # num_steps, mode, crossfade_s) #
36
+ # Stores the raw per-segment wavs so that only the dB value can be #
37
+ # changed without re-running the model. #
38
+ # ------------------------------------------------------------------ #
39
+ _INFERENCE_CACHE = {} # key -> {"wavs": [...], "sr": int}
40
+
41
 
42
  def set_global_seed(seed):
43
  np.random.seed(seed % (2**32))
 
64
  cfg_scale, num_steps, mode,
65
  euler_sampler, euler_maruyama_sampler):
66
  """
67
+ Run one model inference pass for the video window starting at seg_start_s.
68
+ Returns a numpy float32 wav array trimmed to (seg_end_s - seg_start_s).
 
69
  """
70
+ # CAVP features at fps (4 fps)
71
  cavp_start = int(round(seg_start_s * fps))
72
+ cavp_slice = cavp_feats_full[cavp_start : cavp_start + truncate_frame]
 
 
73
  if cavp_slice.shape[0] < truncate_frame:
74
+ pad = np.zeros(
75
+ (truncate_frame - cavp_slice.shape[0],) + cavp_slice.shape[1:],
76
+ dtype=cavp_slice.dtype,
77
+ )
78
  cavp_slice = np.concatenate([cavp_slice, pad], axis=0)
79
  video_feats = torch.from_numpy(cavp_slice).unsqueeze(0).to(device).to(weight_dtype)
80
 
81
+ # Onset features at truncate_onset / model_dur frames per second
82
+ onset_fps = truncate_onset / model_dur
83
  onset_start = int(round(seg_start_s * onset_fps))
84
  onset_slice = onset_feats_full[onset_start : onset_start + truncate_onset]
85
  if onset_slice.shape[0] < truncate_onset:
 
87
  onset_slice = np.pad(onset_slice, ((0, pad_len),), mode="constant", constant_values=0)
88
  onset_feats_t = torch.from_numpy(onset_slice).unsqueeze(0).to(device).to(weight_dtype)
89
 
 
90
  z = torch.randn(1, model.in_channels, 204, 16, device=device).to(weight_dtype)
91
  sampling_kwargs = dict(
92
  model=model,
 
109
  samples = vae.decode(samples / latents_scale).sample
110
  wav = vocoder(samples.squeeze().float()).detach().cpu().numpy()
111
 
 
112
  seg_samples = int(round((seg_end_s - seg_start_s) * sr))
113
  return wav[:seg_samples]
114
 
115
 
116
def crossfade_join(wav_a, wav_b, crossfade_s, db_boost, sr):
    """Join two wav arrays by overlapping and summing their boundary regions.

    The last ``crossfade_s`` seconds of ``wav_a`` are overlapped with the
    first ``crossfade_s`` seconds of ``wav_b``; both overlap slices are
    scaled by ``gain = 10 ** (db_boost / 20)`` and summed sample-wise.

    Args:
        wav_a: 1-D numpy array, the earlier signal.
        wav_b: 1-D numpy array, the later signal.
        crossfade_s: Overlap duration in seconds. If the window exceeds
            either input's length it is shrunk to fit; a non-positive
            window degenerates to plain concatenation.
        db_boost: Gain in dB applied to BOTH signals inside the overlap.
            Note that with two equal-amplitude, correlated signals even
            0 dB (gain 1.0) already doubles the amplitude (+6 dB) at the
            seam, since the unity-gain slices are simply summed; db_boost
            scales that blend level up or down from there.
        sr: Sample rate in Hz.

    Returns:
        A numpy array of length ``len(wav_a) + len(wav_b) - overlap``.

    NOTE(review): no amplitude ramp is applied across the window — both
    signals enter and leave the overlap at full (scaled) level, which can
    produce an audible step at the window edges when gain != 1.0 or the
    two signals differ at the seam. Confirm this overlap-and-sum blend
    (rather than an equal-power ramped crossfade) is intended.
    """
    cf_samples = int(round(crossfade_s * sr))

    # Shrink the window if either signal is shorter than it; with no
    # usable window, fall back to a plain join.
    cf_samples = min(cf_samples, len(wav_a), len(wav_b))
    if cf_samples <= 0:
        return np.concatenate([wav_a, wav_b])

    gain = 10 ** (db_boost / 20.0)

    tail_a = wav_a[-cf_samples:] * gain
    head_b = wav_b[:cf_samples] * gain
    overlap = tail_a + head_b

    return np.concatenate([
        wav_a[:-cf_samples],
        overlap,
        wav_b[cf_samples:],
    ])
 
147
 
148
 
149
def stitch_wavs(wavs, crossfade_s, db_boost, sr, total_dur_s):
    """Stitch per-segment wav arrays into one track and clip it to length.

    Consecutive segments are joined left-to-right with ``crossfade_join``
    (overlap-and-sum, blend level controlled by ``db_boost``), then the
    result is truncated to ``total_dur_s`` seconds.

    Args:
        wavs: Non-empty list of 1-D numpy wav arrays, in playback order.
        crossfade_s: Crossfade window in seconds, forwarded to crossfade_join.
        db_boost: Overlap gain in dB, forwarded to crossfade_join.
        sr: Sample rate in Hz.
        total_dur_s: Target duration in seconds; excess samples are dropped
            (a shorter result is returned as-is, not padded).

    Returns:
        A single numpy wav array of at most ``round(total_dur_s * sr)`` samples.

    Raises:
        ValueError: If ``wavs`` is empty.
    """
    if not wavs:
        raise ValueError("stitch_wavs() requires at least one wav segment")

    # A single segment needs no joining; the fold below covers both the
    # one-segment and multi-segment cases (the loop body never runs for one).
    final_wav = wavs[0]
    for next_wav in wavs[1:]:
        final_wav = crossfade_join(final_wav, next_wav, crossfade_s, db_boost, sr)

    target_samples = int(round(total_dur_s * sr))
    return final_wav[:target_samples]
160
 
 
 
 
 
 
161
 
162
+ @spaces.GPU(duration=300)
163
+ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode,
164
+ crossfade_s, crossfade_db):
165
+ global _INFERENCE_CACHE
166
 
167
+ seed_val = int(seed_val)
168
+ crossfade_s = float(crossfade_s)
169
+ crossfade_db = float(crossfade_db)
170
+
171
+ if seed_val < 0:
172
+ seed_val = random.randint(0, 2**32 - 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
  sr = 16000
175
  truncate = 131072
176
  fps = 4
177
+ truncate_frame = int(fps * truncate / sr)
178
+ truncate_onset = 120
179
+ model_dur = truncate / sr # 8.192 s
180
+ step_s = model_dur - crossfade_s
181
+
182
+ # Cache key covers everything that affects segmentation and inference
183
+ cache_key = (video_file, seed_val, float(cfg_scale), int(num_steps), mode,
184
+ crossfade_s)
185
+
186
+ if cache_key in _INFERENCE_CACHE:
187
+ print("Cache hit skipping inference, re-stitching with new dB value.")
188
+ cached = _INFERENCE_CACHE[cache_key]
189
+ wavs = cached["wavs"]
190
+ total_dur_s = cached["total_dur_s"]
191
+ tmp_dir = cached["tmp_dir"]
192
+ silent_video = cached["silent_video"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  else:
194
+ set_global_seed(seed_val)
195
+ torch.set_grad_enabled(False)
196
+ device = "cuda" if torch.cuda.is_available() else "cpu"
197
+ weight_dtype = torch.bfloat16
198
+
199
+ from cavp_util import Extract_CAVP_Features
200
+ from onset_util import VideoOnsetNet, extract_onset
201
+ from models import MMDiT
202
+ from samplers import euler_sampler, euler_maruyama_sampler
203
+ from diffusers import AudioLDM2Pipeline
204
+
205
+ extract_cavp = Extract_CAVP_Features(
206
+ device=device, config_path="./cavp/cavp.yaml", ckpt_path=cavp_ckpt_path
207
+ )
208
 
209
+ state_dict = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
210
+ new_state_dict = {}
211
+ for key, value in state_dict.items():
212
+ if "model.net.model" in key:
213
+ new_key = key.replace("model.net.model", "net.model")
214
+ elif "model.fc." in key:
215
+ new_key = key.replace("model.fc", "fc")
216
+ else:
217
+ new_key = key
218
+ new_state_dict[new_key] = value
219
+ onset_model = VideoOnsetNet(False).to(device)
220
+ onset_model.load_state_dict(new_state_dict)
221
+ onset_model.eval()
222
+
223
+ model = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
224
+ ckpt = torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"]
225
+ model.load_state_dict(ckpt)
226
+ model.eval()
227
+ model.to(weight_dtype)
228
+
229
+ model_audioldm = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
230
+ vae = model_audioldm.vae.to(device)
231
+ vae.eval()
232
+ vocoder = model_audioldm.vocoder.to(device)
233
+
234
+ tmp_dir = tempfile.mkdtemp()
235
+ silent_video = os.path.join(tmp_dir, "silent_input.mp4")
236
+ strip_audio_from_video(video_file, silent_video)
237
+
238
+ cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
239
+ onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
240
+
241
+ latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
242
+
243
+ total_frames = cavp_feats.shape[0]
244
+ total_dur_s = total_frames / fps
245
+
246
+ # Build segment list
247
+ segments = []
248
+ seg_start = 0.0
249
+ while True:
250
+ seg_end = min(seg_start + model_dur, total_dur_s)
251
+ segments.append((seg_start, seg_end))
252
+ if seg_end >= total_dur_s:
253
+ break
254
+ seg_start += step_s
255
+
256
+ # Run inference for every segment
257
+ wavs = []
258
+ for seg_start_s, seg_end_s in segments:
259
+ print(f"Inferring segment {seg_start_s:.2f}s – {seg_end_s:.2f}s ...")
260
+ wav = infer_segment(
261
+ model, vae, vocoder,
262
+ cavp_feats, onset_feats,
263
+ seg_start_s, seg_end_s,
264
+ sr, fps, truncate_frame, truncate_onset, model_dur,
265
+ latents_scale, device, weight_dtype,
266
+ cfg_scale, num_steps, mode,
267
+ euler_sampler, euler_maruyama_sampler,
268
+ )
269
+ wavs.append(wav)
270
+
271
+ # Store in cache
272
+ _INFERENCE_CACHE[cache_key] = {
273
+ "wavs": wavs,
274
+ "total_dur_s": total_dur_s,
275
+ "tmp_dir": tmp_dir,
276
+ "silent_video": silent_video,
277
+ }
278
+
279
+ # Stitch with current crossfade params
280
+ device = "cuda" if torch.cuda.is_available() else "cpu"
281
+ final_wav = stitch_wavs(wavs, crossfade_s, crossfade_db, sr, total_dur_s)
282
 
283
+ audio_path = os.path.join(tmp_dir, "output.wav")
284
  sf.write(audio_path, final_wav, sr)
285
 
 
286
  output_video = os.path.join(tmp_dir, "output.mp4")
287
  input_v = ffmpeg.input(silent_video)
288
  input_a = ffmpeg.input(audio_path)
 
308
  gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5),
309
  gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1),
310
  gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde"),
311
+ gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=8, value=2, step=0.1),
312
+ gr.Textbox(label="Crossfade Boost (dB)", value="3"),
313
  ],
314
  outputs=[
315
  gr.Video(label="Output Video with Audio"),
316
  gr.Audio(label="Generated Audio"),
317
  ],
318
  title="TARO: Video-to-Audio Synthesis (ICCV 2025)",
319
+ description="Upload a video and generate synchronized audio using TARO. Optimal clip duration is 8.2s. Longer videos are automatically split into overlapping segments and stitched with a crossfade.",
320
  )
321
+ demo.queue().launch(share=True)