JackIsNotInTheBox committed on
Commit
0429f8a
·
1 Parent(s): d0e121d

Support videos longer than 8.2s via overlapping inference + 2s +3dB crossfade stitching

Browse files
Files changed (1) hide show
  1. app.py +162 -57
app.py CHANGED
@@ -49,6 +49,105 @@ def strip_audio_from_video(video_path, output_path):
49
  )
50
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  @spaces.GPU(duration=300)
53
  def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
54
  seed_val = int(seed_val)
@@ -101,74 +200,80 @@ def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
101
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
102
  onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
103
 
104
- sr = 16000
105
- truncate = 131072
106
- fps = 4
107
- truncate_frame = int(fps * truncate / sr)
108
- truncate_onset = 120
 
 
 
 
 
109
 
110
  latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
111
 
112
- video_feats = torch.from_numpy(cavp_feats[:truncate_frame]).unsqueeze(0).to(device).to(weight_dtype)
113
-
114
- # Slice onset features and pad to truncate_onset if the video is shorter than expected
115
- onset_feats_sliced = onset_feats[:truncate_onset]
116
- actual_onset_len = onset_feats_sliced.shape[0]
117
- if actual_onset_len < truncate_onset:
118
- pad_len = truncate_onset - actual_onset_len
119
- onset_feats_sliced = np.pad(
120
- onset_feats_sliced,
121
- ((0, pad_len),),
122
- mode="constant",
123
- constant_values=0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  )
125
- onset_feats_t = torch.from_numpy(onset_feats_sliced).unsqueeze(0).to(device).to(weight_dtype)
126
-
127
- z = torch.randn(len(video_feats), model.in_channels, 204, 16, device=device).to(weight_dtype)
128
-
129
- sampling_kwargs = dict(
130
- model=model,
131
- latents=z,
132
- y=onset_feats_t,
133
- context=video_feats,
134
- num_steps=int(num_steps),
135
- heun=False,
136
- cfg_scale=float(cfg_scale),
137
- guidance_low=0.0,
138
- guidance_high=0.7,
139
- path_type="linear",
140
- )
141
-
142
- with torch.no_grad():
143
- if mode == "sde":
144
- samples = euler_maruyama_sampler(**sampling_kwargs)
145
- else:
146
- samples = euler_sampler(**sampling_kwargs)
147
-
148
- samples = vae.decode(samples / latents_scale).sample
149
-
150
- # Cast to float32 before vocoder (HiFi-GAN requires float32)
151
- wav_samples = vocoder(samples.squeeze().float()).detach().cpu().numpy()
152
 
153
  audio_path = os.path.join(tmp_dir, "output.wav")
154
- sf.write(audio_path, wav_samples, sr)
155
 
156
- duration = truncate / sr
157
- trimmed_video = os.path.join(tmp_dir, "trimmed.mp4")
158
  output_video = os.path.join(tmp_dir, "output.mp4")
159
-
160
- (
161
- ffmpeg
162
- .input(silent_video, ss=0, t=duration)
163
- .output(trimmed_video, vcodec="libx264", an=None)
164
- .run(overwrite_output=True, quiet=True)
165
- )
166
-
167
- input_v = ffmpeg.input(trimmed_video)
168
  input_a = ffmpeg.input(audio_path)
169
  (
170
  ffmpeg
171
- .output(input_v, input_a, output_video, vcodec="libx264", acodec="aac", strict="experimental")
 
172
  .run(overwrite_output=True, quiet=True)
173
  )
174
 
 
49
  )
50
 
51
 
52
+ def infer_segment(model, vae, vocoder, cavp_feats_full, onset_feats_full,
53
+ seg_start_s, seg_end_s,
54
+ sr, fps, truncate_frame, truncate_onset, model_dur,
55
+ latents_scale, device, weight_dtype,
56
+ cfg_scale, num_steps, mode,
57
+ euler_sampler, euler_maruyama_sampler):
58
+ """
59
+ Run one model inference pass for the video window [seg_start_s, seg_start_s + model_dur].
60
+ Returns a numpy float32 wav array of exactly round(model_dur * sr) samples,
61
+ trimmed to the actual segment length (seg_end_s - seg_start_s) when shorter.
62
+ """
63
+ # -- CAVP features: 4 fps --
64
+ cavp_start = int(round(seg_start_s * fps))
65
+ cavp_end = cavp_start + truncate_frame
66
+ cavp_slice = cavp_feats_full[cavp_start:cavp_end]
67
+ # pad if near end of video
68
+ if cavp_slice.shape[0] < truncate_frame:
69
+ pad = np.zeros((truncate_frame - cavp_slice.shape[0],) + cavp_slice.shape[1:], dtype=cavp_slice.dtype)
70
+ cavp_slice = np.concatenate([cavp_slice, pad], axis=0)
71
+ video_feats = torch.from_numpy(cavp_slice).unsqueeze(0).to(device).to(weight_dtype)
72
+
73
+ # -- Onset features: truncate_onset frames per model_dur --
74
+ onset_fps = truncate_onset / model_dur # frames per second of onset feats
75
+ onset_start = int(round(seg_start_s * onset_fps))
76
+ onset_slice = onset_feats_full[onset_start : onset_start + truncate_onset]
77
+ if onset_slice.shape[0] < truncate_onset:
78
+ pad_len = truncate_onset - onset_slice.shape[0]
79
+ onset_slice = np.pad(onset_slice, ((0, pad_len),), mode="constant", constant_values=0)
80
+ onset_feats_t = torch.from_numpy(onset_slice).unsqueeze(0).to(device).to(weight_dtype)
81
+
82
+ # -- Diffusion --
83
+ z = torch.randn(1, model.in_channels, 204, 16, device=device).to(weight_dtype)
84
+ sampling_kwargs = dict(
85
+ model=model,
86
+ latents=z,
87
+ y=onset_feats_t,
88
+ context=video_feats,
89
+ num_steps=int(num_steps),
90
+ heun=False,
91
+ cfg_scale=float(cfg_scale),
92
+ guidance_low=0.0,
93
+ guidance_high=0.7,
94
+ path_type="linear",
95
+ )
96
+ with torch.no_grad():
97
+ if mode == "sde":
98
+ samples = euler_maruyama_sampler(**sampling_kwargs)
99
+ else:
100
+ samples = euler_sampler(**sampling_kwargs)
101
+
102
+ samples = vae.decode(samples / latents_scale).sample
103
+ wav = vocoder(samples.squeeze().float()).detach().cpu().numpy()
104
+
105
+ # Trim to actual segment length
106
+ seg_samples = int(round((seg_end_s - seg_start_s) * sr))
107
+ return wav[:seg_samples]
108
+
109
+
110
+ def crossfade_join(wav_a, wav_b, crossfade_s, sr):
111
+ """
112
+ Join wav_a and wav_b with a 2-second equal-power (+3 dB) crossfade.
113
+
114
+ wav_a contains 1 s of 'extra' audio at its tail (the overlap region starts
115
+ 1 s before its end). wav_b contains 1 s of 'extra' audio at its head.
116
+ The crossfade window is crossfade_s wide; the midpoint sits at (crossfade_s/2)
117
+ into the window, where each gain = sqrt(0.5) ≈ -3 dB ... wait, we want +3 dB
118
+ at midpoint meaning both signals are at *full* amplitude there.
119
+
120
+ Equal-power (sqrt) ramps: at the midpoint t=0.5 the fade-out = sqrt(0.5) and
121
+ fade-in = sqrt(0.5), so combined power = 0.5+0.5 = 1.0 (+0 dB).
122
+ For a +3 dB bump at midpoint we use *linear* ramps instead:
123
+ fade_out = 1 - t, fade_in = t (t: 0->1 across window)
124
+ At t=0.5: both = 0.5, sum = 1.0 amplitude = +6 dB power... that is not right.
125
+
126
+ DaVinci Resolve "+3 dB" crossfade means the combined level at the midpoint
127
+ is +3 dB above either source, which equals the behaviour where each signal
128
+ is kept at full gain (1.0) across the entire overlap and the two are simply
129
+ summed — then the overlap region has 6 dB of headroom risk, but the *perceived*
130
+ loudness boost at the centre is +3 dB (sqrt(2) in amplitude).
131
+
132
+ Implementation: keep both signals at unity gain in the crossfade window and
133
+ sum them. Outside the window use the respective signal only.
134
+ """
135
+ cf_samples = int(round(crossfade_s * sr))
136
+
137
+ # The crossfade sits at the junction: last cf_samples of wav_a overlap with
138
+ # first cf_samples of wav_b.
139
+ tail_a = wav_a[-cf_samples:] # 1s before end of a
140
+ head_b = wav_b[:cf_samples] # 1s after start of b
141
+ overlap = tail_a + head_b # +3 dB sum at centre (unity + unity)
142
+
143
+ result = np.concatenate([
144
+ wav_a[:-cf_samples], # body of a (before crossfade)
145
+ overlap, # crossfade region
146
+ wav_b[cf_samples:], # body of b (after crossfade)
147
+ ])
148
+ return result
149
+
150
+
151
  @spaces.GPU(duration=300)
152
  def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
153
  seed_val = int(seed_val)
 
200
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
201
  onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
202
 
203
+ sr = 16000
204
+ truncate = 131072
205
+ fps = 4
206
+ truncate_frame = int(fps * truncate / sr) # 32 cavp frames per segment
207
+ truncate_onset = 120 # onset frames per segment
208
+ model_dur = truncate / sr # 8.192 s
209
+ crossfade_s = 2.0 # 2-second crossfade window
210
+ # Each segment starts (model_dur - crossfade_s) later than the previous,
211
+ # so the tails overlap by crossfade_s giving 1 s of extra audio on each side.
212
+ step_s = model_dur - crossfade_s # 6.192 s
213
 
214
  latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
215
 
216
+ # Total video duration from cavp features
217
+ total_frames = cavp_feats.shape[0]
218
+ total_dur_s = total_frames / fps
219
+
220
+ # ------------------------------------------------------------------ #
221
+ # Build segment list: each entry is (seg_start_s, seg_end_s) #
222
+ # seg_end_s is the actual content end (clipped to video length), #
223
+ # but we always run the model for a full model_dur window. #
224
+ # ------------------------------------------------------------------ #
225
+ segments = []
226
+ seg_start = 0.0
227
+ while True:
228
+ seg_end = min(seg_start + model_dur, total_dur_s)
229
+ segments.append((seg_start, seg_end))
230
+ if seg_end >= total_dur_s:
231
+ break
232
+ seg_start += step_s
233
+
234
+ # ------------------------------------------------------------------ #
235
+ # Run inference for every segment #
236
+ # ------------------------------------------------------------------ #
237
+ wavs = []
238
+ for seg_start_s, seg_end_s in segments:
239
+ print(f"Inferring segment {seg_start_s:.2f}s – {seg_end_s:.2f}s ...")
240
+ wav = infer_segment(
241
+ model, vae, vocoder,
242
+ cavp_feats, onset_feats,
243
+ seg_start_s, seg_end_s,
244
+ sr, fps, truncate_frame, truncate_onset, model_dur,
245
+ latents_scale, device, weight_dtype,
246
+ cfg_scale, num_steps, mode,
247
+ euler_sampler, euler_maruyama_sampler,
248
  )
249
+ wavs.append(wav)
250
+
251
+ # ------------------------------------------------------------------ #
252
+ # Stitch with crossfades #
253
+ # Single segment: no crossfade needed #
254
+ # ------------------------------------------------------------------ #
255
+ if len(wavs) == 1:
256
+ final_wav = wavs[0]
257
+ else:
258
+ final_wav = wavs[0]
259
+ for next_wav in wavs[1:]:
260
+ final_wav = crossfade_join(final_wav, next_wav, crossfade_s, sr)
261
+
262
+ # Clip to exact video duration
263
+ target_samples = int(round(total_dur_s * sr))
264
+ final_wav = final_wav[:target_samples]
 
 
 
 
 
 
 
 
 
 
 
265
 
266
  audio_path = os.path.join(tmp_dir, "output.wav")
267
+ sf.write(audio_path, final_wav, sr)
268
 
269
+ # Mux original silent video (full length) with generated audio
 
270
  output_video = os.path.join(tmp_dir, "output.mp4")
271
+ input_v = ffmpeg.input(silent_video)
 
 
 
 
 
 
 
 
272
  input_a = ffmpeg.input(audio_path)
273
  (
274
  ffmpeg
275
+ .output(input_v, input_a, output_video,
276
+ vcodec="libx264", acodec="aac", strict="experimental")
277
  .run(overwrite_output=True, quiet=True)
278
  )
279